# In [2]:
# Imports

import numpy as np
import matplotlib.pyplot as plt
import pandas as pd
import plotly.express as px
import scipy as sp
# import contextily as cx

import torch
import pygsp
import optuna
import joblib
import gc
import argparse
import os
import matplotlib

from matplotlib.ticker import ScalarFormatter, StrMethodFormatter, FormatStrFormatter, FuncFormatter
from matplotlib.animation import FuncAnimation
from IPython.display import HTML

from sklearn.metrics import mean_squared_error, confusion_matrix, auc
from sklearn.preprocessing import MinMaxScaler, StandardScaler

from optuna.samplers import TPESampler
from torch.nn import Linear
from torch_geometric.nn.models import GraphUNet
from torch_geometric.nn import GCNConv, Sequential
from torch_geometric.data import Data
from torch_geometric.datasets import KarateClub
from torch_geometric.utils import to_networkx, grid
from torchvision import datasets, transforms

from importlib import reload
from pyprojroot import here
# Project root as a plain string, and the raw InSAR data folder beneath it.
ROOT_DIR = str(here())
insar_path = f"{ROOT_DIR}/data/raw/insar/"

# Process-wide matplotlib text defaults.
matplotlib.rcParams.update({
    'font.size': 20,
    'font.family': 'Times New Roman',
})

# import dario.models.mismatch_analysis as mma
# mma = reload(mma)

# Function definitions

def plot_anim(outputs, epochs, img_shape=(28, 28), fps=2):
    """Animate input/output image pairs side by side, one frame per epoch.

    Parameters
    ----------
    outputs : sequence
        Indexed by epoch. ``outputs[epoch][1]`` is used as the input image and
        ``outputs[epoch][2]`` as the model output; both are expected to be
        torch tensors holding ``img_shape[0] * img_shape[1]`` values.
    epochs : int
        Number of frames; frame ``i`` shows ``outputs[i]`` for
        ``i in range(epochs)``.
    img_shape : tuple of int, optional
        (rows, cols) each tensor is reshaped to. Defaults to ``(28, 28)``.
    fps : float, optional
        Frames per second of the animation. Defaults to ``2``.

    Returns
    -------
    matplotlib.animation.FuncAnimation
        The animation; the caller must keep a reference to it and render it
        (e.g. ``HTML(ani.to_jshtml())``).
    """
    def generate_matrix(epoch):
        # detach().cpu() makes the conversion work for tensors that require
        # grad or live on an accelerator, and is a no-op otherwise.
        inp = outputs[epoch][1].detach().cpu().numpy().reshape(img_shape)
        out = outputs[epoch][2].detach().cpu().numpy().reshape(img_shape)
        # Input on the left, output on the right.
        return np.c_[inp, out]

    fig, ax = plt.subplots()

    def close_own_figure():
        # Close *this* figure only. A bare plt.close() targets whatever figure
        # is current, so once ours is gone it would start closing unrelated
        # figures on every frame. Presumably the close is here to keep
        # notebooks from also showing a static copy of the figure.
        plt.close(fig)

    def init():
        ax.clear()
        close_own_figure()

    def update(frame):
        # Start from empty axes so images from earlier frames do not pile up
        # (the animation repeats, so they would accumulate without bound).
        ax.clear()
        ax.imshow(generate_matrix(frame), cmap='gray', vmin=0, vmax=1)
        # Hide all ticks (and with them the tick labels).
        ax.set_xticks([])
        ax.set_yticks([])
        ax.set_title(f'{frame}', fontdict={'color':'white'})
        close_own_figure()

    ani = FuncAnimation(fig, update, frames=range(epochs), interval=1000/fps,
                        repeat=True, blit=False, init_func=init)
    return ani

def roc_params(metric, label, interp=True):
    """Sweep thresholds over ``metric`` and return the points of a ROC curve.

    A sample is predicted positive when ``metric > threshold``. The thresholds
    are the 1001 evenly spaced values from 0 to ``metric.max()``, visited in
    increasing order. Points are recorded up to and including the first
    threshold that yields no false positives. The sweep then continues,
    without recording, until there are no true positives either; that last
    threshold is attached to the ``(0, 0)`` point prepended to the curve.

    Parameters
    ----------
    metric : numpy.ndarray or pandas.Series
        Score per sample. Must provide ``.max()`` and elementwise ``>``.
    label : array-like
        Ground truth per sample, same length as ``metric``. Values are
        interpreted as booleans (nonzero / True means positive).
    interp : bool, optional
        If True (default), resample the curve with ``np.interp`` onto 101
        evenly spaced false-positive rates in ``[0, 1]``.

    Returns
    -------
    tpr, fpr, thr : list
        True-positive rates, false-positive rates and thresholds, in that
        order, sorted by decreasing threshold and starting with the
        prepended ``(0, 0)`` point.

    Notes
    -----
    When ``label`` contains a single class, the rate whose denominator is
    zero is reported as NaN instead of raising.
    """
    # Counting directly avoids re-validating the inputs in sklearn on every
    # threshold and the 1x1 confusion matrix it returns when only one class
    # is present (which cannot be unpacked into tn, fp, fn, tp).
    truth = np.asarray(label).astype(bool)
    negatives = ~truth
    n_pos = int(np.count_nonzero(truth))
    n_neg = int(np.count_nonzero(negatives))
    nan = float("nan")

    def count_hits(threshold):
        """Return (false positives, true positives) at ``threshold``."""
        predicted = np.asarray(metric > threshold)
        false_pos = int(np.count_nonzero(predicted & negatives))
        true_pos = int(np.count_nonzero(predicted & truth))
        return false_pos, true_pos

    # One shared iterator so the second phase resumes where the first stopped.
    sweep = iter(np.linspace(0, metric.max(), 1001))

    fpr, tpr, thr = [], [], []
    tp = 0

    # Phase 1: record points until the false positives vanish (inclusive).
    for threshold in sweep:
        fp, tp = count_hits(threshold)
        fpr.append(fp / n_neg if n_neg else nan)
        tpr.append(tp / n_pos if n_pos else nan)
        thr.append(threshold)
        if fp == 0:
            break

    # Phase 2: keep raising the threshold until nothing is detected at all;
    # only the final threshold is kept, for the (0, 0) point.
    if tp > 0:
        for threshold in sweep:
            _, tp = count_hits(threshold)
            if tp == 0:
                break

    # The sweep went from low to high threshold; flip so FPR is ascending.
    fpr.reverse()
    tpr.reverse()
    thr.reverse()

    if interp:
        fpr_base = np.linspace(0, 1, 101)
        tpr = list(np.interp(fpr_base, fpr, tpr))
        thr = list(np.interp(fpr_base, fpr, thr))
        fpr = list(fpr_base)

    fpr.insert(0, 0)
    tpr.insert(0, 0)
    thr.insert(0, threshold)

    return tpr, fpr, thr

def compute_auc(tpr, fpr):
    """Return the area under the curve through the (fpr, tpr) points.

    The area is accumulated with the trapezoidal rule over consecutive
    points, in the order given. Fewer than two points give an area of 0.
    """
    return sum(
        (fpr[k] - fpr[k - 1]) * (tpr[k] + tpr[k - 1]) / 2
        for k in range(1, len(fpr))
    )

# def detection(df_metrics, column_name='wse', threshold_min=1000, threshold_max=np.inf, selector='group',
#               detection_param='detection_sum', detection_param_threshold=None):
#     # df_relevant contains data from nodes that, at some point, have lower<=wse<=upper, and their neighbors.
#     # nodes are put into groups if they are close to each other.

#     if detection_param_threshold is None:
#         detection_param_threshold = df_metrics.timestamp.nunique()//2

#     df_relevant = mma.relevant_neighborhood(df_metrics, column_name=column_name,
#                                             lower=threshold_min, upper=threshold_max,
#                                             only_relevant=True, return_df=True, plot=False, filter_dates=False)

#     # Treating disconnected nodes as individual groups. Assigning new values
#     new_group_values = df_relevant.query('group==0').pid.factorize()[0] + df_relevant.group.max()+1
#     df_relevant.loc[df_relevant.group==0, 'group'] = new_group_values


#     df_relevant['detection'] = (df_relevant[column_name]>=threshold_min) & (df_relevant[column_name]<=threshold_max)
#     df_detection = df_relevant.groupby('pid').agg({column_name:['max','mean'],
#                                                     'detection':['sum',mma.consecutive_ones],
#                                                     'group':'mean'}).reset_index()

#     df_detection.columns = [f"{level1}_{level2}" if level2 else level1 for level1, level2 in df_detection.columns]
#     df_detection.rename({'group_mean':'group'}, axis=1, inplace=True)

#     query = f'{detection_param}>{detection_param_threshold}'
#     selected = df_detection.query(query)[selector].unique()

#     return df_relevant, selected

# def skew(df):
#     return np.abs(sp.stats.skew(df.mean_velocity))


# def compute_metric(df_test, cut=2, radius=15):

#     df_metrics = []
#     for cluster in sorted(df_test.cluster.unique()):

#         df, nodes = mma.treat_nodes(df_test.query('cluster==@cluster'))
#         G, nodes['subgraph'] = mma.NNGraph(nodes, radius=radius, subgraphs=True)

#         df_metrics_cluster = []
#         for sub_index in sorted(nodes.subgraph.unique())[1:]:

#             subnodes = nodes.query('subgraph==@sub_index').copy()
#             subdf = df[df.pid.isin(subnodes.pid)].copy()

#             G = mma.NNGraph(subnodes, radius=radius)

#             w, V = np.linalg.eigh(G.L.toarray())
#             wh = np.ones(G.N)
#             wh[w<cut] = 0
#             Hh = V @ np.diag(wh) @ V.T

#             smoothed = subdf[['pid', 'timestamp', 'smoothed' ]].pivot(index='pid', columns='timestamp')

#             subdf['hf'] = np.abs((Hh @ smoothed.values).reshape((-1,), order='C'))

#             df_metrics_cluster.append(subdf)

#         df_metrics_cluster = pd.concat(df_metrics_cluster)
#         df_metrics.append(df_metrics_cluster)

#     df_metrics = pd.concat(df_metrics)
#     return df_metrics


# def hfilter(G, cut=2):
#     L = G.L.toarray()
#     w, V = np.linalg.eigh(L)
#     wh = np.ones(G.N)
#     wh[w<cut] = 0
#     Hh = V @ np.diag(wh) @ V.T
#     return Hh

# def matplotlib_roc(save=None, ax=None):
#     matplotlib.rcParams.update({'font.size': 20})
#     matplotlib.rcParams.update({'font.family': 'Times New Roman'})

#     if ax is None:
#         fig, ax = plt.subplots(figsize=(12,5))
#     # sc = ax.scatter(fpr, tpr, c=thr, cmap='viridis', label='Threshold')
#     sc = ax.plot(fpr, tpr, linestyle='dotted', linewidth=1, color='black')

#     # # Colorbar
#     # cbar = plt.colorbar(sc, ax=ax)
#     # cbar.set_label('Threshold', rotation=270, labelpad=15)

#     plt.xlabel('False Positive Rate')
#     plt.ylabel('True Positive Rate')
#     # plt.grid()
#     # plt.tight_layout()

#     if save is not None:
#         plt.savefig(save, transparent=True)
