In [None]:
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import plotly.express as px

import matplotlib
from matplotlib.colors import ListedColormap

import os
import gc
import argparse
import torch
import optuna
import joblib
import pickle

from torch_geometric.utils import dense_to_sparse
from sklearn.cluster import KMeans, BisectingKMeans, SpectralClustering

import geoad.nn.models as models
import geoad.utils.utils as utils
import geoad.utils.fault_detection as fd

from geoad.utils.utils import roc_params, compute_auc

from importlib import reload

# Re-import the project modules so edits to their source are picked up when this
# cell is re-run in a live kernel. Names bound earlier via `from ... import ...`
# (roc_params, compute_auc) are NOT rebound by reload and keep the old objects.
models = reload(models)
utils = reload(utils)
fd = reload(fd)

from pyprojroot import here
# Repository root as located by pyprojroot.
# NOTE(review): root_dir is not used in the visible cells and data_dir below is
# home-relative rather than root-relative — confirm which location is intended.
root_dir = str(here())

# Directory holding the interim parquet files. '~' is expanded here so the path
# works with any consumer (open(), os.path.exists, ...), not only with readers
# that expand it themselves.
data_dir = os.path.expanduser('~/data/interim/')

# Global figure style for every matplotlib plot in the notebook.
matplotlib.rcParams.update({'font.size': 20, 'font.family': 'DejaVu Serif'})

----------

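#### Showcasing the data

Load the dataset, downsample each `pid` to 30-day means (observations before June 2022 only), generate labelled data with `utils.generate_data`, and visualise the labels on the graph and on a map.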

In [None]:
# Run configuration for this section.
rng_seed = 0          # single source of truth for every RNG seeded below
use_weight = False    # when False, edge weights are discarded (edge_weight = None)
device = 'cpu'

# Seed each RNG exactly once, from rng_seed, so changing rng_seed actually
# changes the run (a second, hard-coded seeding would silently override it).
np.random.seed(rng_seed)
torch.manual_seed(rng_seed)
torch.cuda.manual_seed(rng_seed)  # documented as safe to call without a GPU

# Load the dataset and keep only observations before June 2022.
dataset = 'output_name'
df_orig = pd.read_parquet(f'{data_dir}{dataset}.parq')

# Downsample every pid's series to 30-day means over `timestamp`.
df_ds = (
    df_orig[df_orig.timestamp < '2022-06']
    .copy()
    .groupby('pid')
    .resample('30d', on='timestamp')
    .mean()
    .reset_index()
)

# Build the data/label arrays, per-index frames, graph and node list from the
# downsampled frame.
# NOTE(review): the meaning of the positional arguments (10, 0, 10) is not visible
# here — consider passing them by keyword.
data, labels, data_dfs, G, nodes = utils.generate_data(
    df_ds, 10, 0, 10,
    anomalous_nodes=20, noise=0.1, label_noise=True,
)

# Dense float32 adjacency from the graph's weight matrix, then converted to the
# (edge_index, edge_weight) form that torch_geometric expects.
A = torch.tensor(G.W.toarray()).float()
edge_index, edge_weight = dense_to_sparse(A)
edge_index = edge_index.to(device)
if use_weight:
    edge_weight = edge_weight.to(device)
else:
    edge_weight = None  # treat the graph as unweighted

# Size of the third axis of `data` (presumably the time axis — confirm).
n_timestamps = data.shape[2]

# Colour the graph vertices by one row of the generated label matrix.
# NOTE(review): row 3 looks like an arbitrary example — confirm the intended row.
label = labels[3, :]
# Discrete 3-colour map sampled from viridis; presumably one colour per label
# value (0/1/2) — TODO confirm the label range.
label_cmap = ListedColormap(plt.cm.viridis(np.linspace(0, 1, 3)))

plotting_params = dict(
    edge_color='darkgray',
    edge_width=1.5,
    vertex_color='black',
    vertex_size=50,
)
G.plotting.update(plotting_params)

fig, ax = plt.subplots(figsize=(10, 5))
G.plot_signal(label, ax=ax, plot_name='Label')
# Replace the colormap of the first collection drawn by plot_signal (presumably
# the vertex markers) with the discrete one defined above.
ax.collections[0].set_cmap(label_cmap)
ax.axis('off')
plt.show()


In [None]:
# Attach one coordinate row per pid to data_dfs[1], then plot the pixels whose
# `anomaly` value is 2.
coords_plot = df_ds[['pid', 'latitude', 'longitude', 'easting', 'northing']].drop_duplicates('pid')
df_plot = data_dfs[1].merge(coords_plot, how='left', on='pid')
selected_pids = df_plot.loc[df_plot['anomaly'] == 2, 'pid'].unique()
fd.plot_selected_pixels(df_plot, id_list=selected_pids)

In [None]:
# Per-pid summary of data_dfs[5] (column-wise maximum over each pid's rows),
# joined with its map coordinates.
# Deduplicate coordinates by pid (as the previous cell does) so the left merge
# can never multiply a pid's row, even if easting/northing vary within a pid.
coords_visualize = df_orig[['easting', 'northing', 'pid']].drop_duplicates('pid')
df_visualize = (
    data_dfs[5]
    .groupby('pid', as_index=False)
    .max()
    .merge(coords_visualize, how='left', on='pid')
)
# NOTE(review): only anomaly == 1 is labelled 'anomaly' here, while the cell above
# selects anomaly == 2 and the graph colormap has 3 levels — confirm which codes
# mean "anomalous"; any other value is labelled 'normal'.
df_visualize['label'] = df_visualize['anomaly'].map(lambda code: 'anomaly' if code == 1 else 'normal')
df_visualize

In [None]:
# Map every pid coloured by its 'label' column, with a uniform marker size.
map_kwargs = dict(
    color='label',
    zoom=15.5,
    size=np.ones(len(df_visualize)),
    size_max=5,
    discrete_colormap=px.colors.qualitative.Plotly,
    transparent=True,
)
utils.visualize_map(df_visualize, **map_kwargs)