In [None]:
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import plotly.express as px

import matplotlib
from matplotlib.colors import ListedColormap

import os
import gc
import argparse
import torch
import optuna
import joblib
import pickle

from optuna.samplers import TPESampler
from sklearn.cluster import KMeans
from sklearn.metrics import f1_score
from tqdm.notebook import tqdm

from collections import OrderedDict

from torch import nn
from torch.nn import Linear, ReLU, Sequential, LayerNorm
from torch.utils.data import DataLoader, TensorDataset

from torch_geometric.nn.models import GraphUNet
from torch_geometric.utils import dense_to_sparse, to_dense_adj
from torch_geometric_temporal.nn.recurrent import GConvGRU

import geoad.nn.models as models
import geoad.utils.utils as utils
import geoad.utils.fault_detection as fd

from geoad.utils.utils import roc_params, compute_auc

from importlib import reload

# Re-import the project modules so edits to geoad are picked up without restarting
# the kernel. All three geoad modules are reloaded, not just two of them.
models = reload(models)
utils = reload(utils)
fd = reload(fd)
# `from geoad.utils.utils import roc_params, compute_auc` ran before the reload, so those
# names still point at the old function objects; rebind them to the reloaded module.
roc_params = utils.roc_params
compute_auc = utils.compute_auc

from pyprojroot import here
root_dir = str(here())

# Directory holding the interim parquet files; expand '~' explicitly so the path is
# absolute regardless of which reader consumes it.
data_dir = os.path.expanduser('~/data/interim/')

matplotlib.rcParams.update({'font.size': 20})
matplotlib.rcParams.update({'font.family': 'DejaVu Serif'})

### Test GALA on synthetic cluster anomalies

In [None]:
# Seed numpy and torch RNGs so Restart & Run All reproduces the same numbers.
# NOTE(review): assumes generate_cluster_anomaly draws from numpy's global RNG — TODO confirm.
RNG_SEED = 0
np.random.seed(RNG_SEED)
torch.manual_seed(RNG_SEED)
torch.cuda.manual_seed(RNG_SEED)

# If True, pass the edge weights derived from G.W to the model; otherwise edges are unweighted.
use_weight = False

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# --- Load data and build the node graph ---
dataset = 'df_StOlavs_D1L2B'
df_orig = pd.read_parquet(data_dir + f'{dataset}.parq')

df, nodes = fd.treat_nodes(df_orig)
_, nodes['subgraph'] = fd.NNGraph(nodes, radius=15, subgraphs=True)

# Keep only the subgraph label with the most rows in `nodes`, and only the
# time-series rows belonging to those nodes.
main_graph = nodes.subgraph.value_counts().index[0]
nodes = nodes.query('subgraph==@main_graph').copy()
G = fd.NNGraph(nodes, radius=15)
df = df[df.pid.isin(nodes.pid.unique())].copy()

# Weighted adjacency matrix (W as a float tensor) -> edge list (+ weights) for torch_geometric.
A = torch.tensor(G.W.toarray()).float()
edge_index, edge_weight = dense_to_sparse(A)
edge_index = edge_index.to(device)
edge_weight = edge_weight.to(device) if use_weight else None

# Synthetic benchmark with injected cluster anomalies. `data` is indexed below as
# (sample, node, timestamp) and `labels` as (sample, node).
data, labels, data_dfs = utils.generate_cluster_anomaly(df, nodes, G, data_size=1)

n_timestamps = data.shape[2]

# Hyperparameters
n_encoding_layers = 3
reduction = 0.75
N_epochs = 10
batch_size = 2048  # NOTE: currently unused — training below is full-graph, one step per epoch

model = models.GALA(n_timestamps, n_encoding_layers, reduction).to(device)
loss_function = nn.MSELoss()
score_function = nn.MSELoss(reduction='none')  # element-wise errors, averaged per node below

aucs = []  # one ROC AUC per synthetic sample
for i in tqdm(range(data.shape[0]), desc='samples'):
    X = torch.tensor(data[i, :, :]).float().to(device)
    # normalized_shape == X.shape: standardize the entire matrix to zero mean / unit variance.
    norm_X = LayerNorm(X.shape, elementwise_affine=False)
    X = norm_X(X)
    label = labels[i, :]

    # Train from scratch on each sample. The optimizer is rebuilt as well: reusing one
    # Adam instance would carry its moment estimates over from the previous sample.
    model.train()
    model.reset_parameters()
    optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)

    loss_values = []
    for epoch in tqdm(range(N_epochs), desc=f'sample {i}', leave=False):
        optimizer.zero_grad()
        output = model(X, edge_index, edge_weight)
        loss = loss_function(X, output)
        loss_values.append(loss.item())
        loss.backward()
        optimizer.step()

    # Anomaly score = per-node mean reconstruction error, computed for every sample.
    model.eval()
    with torch.no_grad():
        Y = model(X, edge_index, edge_weight)
    score = score_function(X, Y).mean(dim=1).cpu().numpy()

    tpr, fpr, _ = roc_params(metric=score, label=label, interp=True)
    aucs.append(compute_auc(tpr, fpr))

# Training-loss curve of the last sample.
fig, ax = plt.subplots()
ax.plot(loss_values)
ax.set(xlabel='epoch', ylabel='MSE loss', title=f'Training loss (sample {i})')
plt.show()

# Attach the last sample's labels and scores to `nodes` for the plots below.
nodes['anomaly'] = label
nodes['score'] = score

auc = aucs[-1]
auc

In [None]:
# Two discrete viridis colours for the binary ground-truth label.
label_cmap = ListedColormap(plt.cm.viridis(np.linspace(0, 1, 2)))

plotting_params = {
    'edge_color': 'darkgray',
    'edge_width': 1.5,
    'vertex_color': 'black',
    'vertex_size': 50,
}

fig, ax = plt.subplots(ncols=2, figsize=(16, 5))
G.plotting.update(plotting_params)

# Side-by-side panels on the graph: ground truth (left) vs. anomaly score (right).
panel_specs = [
    (label, 'Label', label_cmap),
    (score, 'Anomaly Score', 'viridis'),
]
for panel_ax, (signal, panel_name, panel_cmap) in zip(ax, panel_specs):
    G.plot_signal(signal, ax=panel_ax, plot_name=panel_name)
    panel_ax.collections[0].set_cmap(panel_cmap)  # recolour the plotted vertices
    panel_ax.axis('off')

plt.show()

In [None]:
# Shared map settings. `size` needs one entry per row of `nodes`; len(nodes) stays
# correct even if a pid appears on more than one row (nunique() would not).
map_kwargs = dict(size=np.ones(len(nodes)), size_max=10, hover_data=['cluster'],
                  zoom=15, figsize=(600, 600), colormap='viridis')

utils.visualize_map(nodes, color='anomaly', title='Label', **map_kwargs)
utils.visualize_map(nodes, color='score', title='Anomaly score', **map_kwargs)

In [None]:
# Area under the ROC curve of the per-node anomaly score against the ground-truth
# labels (repeats the computation at the end of the training cell).
tpr, fpr, _ = roc_params(label=label, metric=score, interp=True)
auc = compute_auc(tpr, fpr)
auc