# Tests on how to launch the explanation module for our data

In [None]:
import torch
import torch_geometric
from torch_geometric.loader import DataLoader
import numpy as np
import matplotlib.pyplot as plt
from models import GATv2Lightning
from utils.dataloader_utils import HDFDataset_Writer, HDFDatasetLoader, GraphDataset
from torch_geometric.nn import Sequential
from sklearn.utils.class_weight import compute_class_weight

In [None]:
# --- Settings forwarded to HDFDataset_Writer ---
SEIZURE_LOOKBACK = 600
BUFFER_TIME = 15
TIMESTEP = 6
INTER_OVERLAP = 0
PREICTAL_OVERLAP = 0
ICTAL_OVERLAP = 0
DOWNSAMPLING_F = 60
SFREQ = 256  # also handed to HDFDatasetLoader as `sampling_f`
SMOTE_FLAG = False
CONNECTIVITY_METRIC = "spectral_corr"

# --- Settings forwarded to HDFDatasetLoader ---
TRAIN_VAL_SPLIT = 0.1
SEED = 42
FFT = False
loso_patient = "chb20"
MNE_FEATURES = True
KFOLD_CVAL_MODE = False
NORMALIZING_PERIOD = "interictal"
USED_CLASSES_DICT = dict(interictal=False, preictal=True, ictal=True)

# --- Batching and data location ---
BATCH_SIZE = 32
data_pth = "../../data"

# The writer's get_dataset() returns the path that the loader takes as `root`.
writer = HDFDataset_Writer(
    seizure_lookback=SEIZURE_LOOKBACK,
    buffer_time=BUFFER_TIME,
    sample_timestep=TIMESTEP,
    inter_overlap=INTER_OVERLAP,
    preictal_overlap=PREICTAL_OVERLAP,
    ictal_overlap=ICTAL_OVERLAP,
    downsample=DOWNSAMPLING_F,
    sampling_f=SFREQ,
    smote=SMOTE_FLAG,
    connectivity_metric=CONNECTIVITY_METRIC,
    npy_dataset_path=f"{data_pth}/npy_data_full",
    event_tables_path=f"{data_pth}/event_tables",
    cache_folder=f"{data_pth}/cache",
)
cache_file_path = writer.get_dataset()

loader = HDFDatasetLoader(
    root=cache_file_path,
    train_val_split_ratio=TRAIN_VAL_SPLIT,
    loso_subject=loso_patient,
    sampling_f=SFREQ,
    extract_features=MNE_FEATURES,
    fft=FFT,
    seed=SEED,
    used_classes_dict=USED_CLASSES_DICT,
    normalize_with=NORMALIZING_PERIOD,
    kfold_cval_mode=KFOLD_CVAL_MODE,
)

# get_datasets() yields three paths: train, validation and loso, in that order.
train_ds_path, valid_ds_path, loso_ds_path = loader.get_datasets()
train_dataset, valid_dataset, loso_dataset = map(
    GraphDataset, (train_ds_path, valid_ds_path, loso_ds_path)
)

# Only the training loader shuffles; all loaders run in the main process.
train_dataloader = DataLoader(
    train_dataset, batch_size=BATCH_SIZE, shuffle=True, num_workers=0
)
valid_dataloader = DataLoader(
    valid_dataset, batch_size=BATCH_SIZE, shuffle=False, num_workers=0
)
# batch_size equals the dataset length, so a single batch spans the whole
# loso dataset.
loso_dataloader = DataLoader(
    loso_dataset, batch_size=len(loso_dataset), shuffle=False, num_workers=0
)

In [None]:
# Size of the last axis of `x` on the first training graph; used below as the
# model's `in_features`.
first_train_graph = train_dataset[0]
features_shape = first_train_graph.x.shape[-1]

In [None]:
# Hyper-parameters for the GATv2 model. The model is built fresh here; this
# cell does not load any trained weights.
gat_hparams = dict(
    in_features=features_shape,
    n_gat_layers=4,
    hidden_dim=32,
    n_heads=4,
    activation="leaky_relu",
    norm_method="batch",
    pooling_method="mean",
    class_weights=torch.tensor([1.]),
)
model = GATv2Lightning(**gat_hparams)


# Example for attention explanation and connectivity measures

In [None]:
# Pull the first batch from the loso loader to use as the explainer input.
loso_batches = iter(loso_dataloader)
batch_unpacked = next(loso_batches)

In [None]:
# Bare trailing expression: the notebook renders the batch's repr as the
# cell output.
batch_unpacked

In [None]:
from torch_geometric.explain import AttentionExplainer, Explainer, ModelConfig

# Explain the model through an edge-level ("object") mask produced by
# AttentionExplainer.
att_explainer = AttentionExplainer()
torch_geometric.seed_everything(42)
config = ModelConfig(
    "binary_classification", task_level="graph", return_type="raw"
)
explainer = Explainer(
    model,
    algorithm=att_explainer,
    explanation_type="model",
    model_config=config,
    edge_mask_type="object",
)
# With explanation_type="model" the Explainer derives the target from the
# model's own prediction; a user-supplied `target` is ignored and only
# triggers a warning, so none is passed. `pyg_batch` is an extra keyword
# argument that the Explainer forwards to the model's forward call.
explanation = explainer(
    x=batch_unpacked.x,
    edge_index=batch_unpacked.edge_index,
    pyg_batch=batch_unpacked.batch,
)
# Optional: explanation.visualize_graph()

In [None]:
# Sample connection_strengths dictionary, drawn as a circular graph whose
# edge colour encodes the connection strength.
import networkx as nx
import matplotlib.pyplot as plt

connection_strengths = {
    ('node1', 'node2'): 0.8, ('node1', 'node3'): 0.6, ('node1', 'node4'): 0.9, ('node1', 'node5'): 0.7,
    ('node2', 'node3'): 0.7, ('node2', 'node4'): 0.5, ('node2', 'node5'): 0.6, ('node2', 'node6'): 0.4,
    ('node3', 'node4'): 0.4, ('node3', 'node5'): 0.3, ('node3', 'node6'): 0.2, ('node3', 'node7'): 0.1,
    ('node4', 'node5'): 0.3, ('node4', 'node6'): 0.2, ('node4', 'node7'): 0.1, ('node4', 'node8'): 0.6,
    ('node5', 'node6'): 0.1, ('node5', 'node7'): 0.2, ('node5', 'node8'): 0.3, ('node5', 'node9'): 0.9
}

# Build an undirected graph whose edges carry their strength as an attribute.
G = nx.Graph()
for (node_a, node_b), strength in connection_strengths.items():
    G.add_edge(node_a, node_b, strength=strength)

# Get positions for circular layout
pos = nx.circular_layout(G)

# Strengths in G.edges order, which is the edge order nx.draw uses by default.
strengths = [s for _, _, s in G.edges(data="strength")]

# Connection strengths as edge labels, formatted to two decimals.
edge_labels = {(u, v): f"{s:.2f}" for u, v, s in G.edges(data="strength")}

# A single normalisation shared by the edge colours and the colorbar, so the
# colorbar ticks correspond to the strengths actually drawn.
strength_norm = plt.Normalize(vmin=min(strengths), vmax=max(strengths))

# Draw on an explicit Axes: the colorbar's ScalarMappable is not attached to
# any Axes, so recent matplotlib versions need `ax=` to know where to put it.
fig, ax = plt.subplots()
nx.draw(
    G,
    pos,
    ax=ax,
    with_labels=True,
    node_size=1000,
    font_size=10,
    node_color='skyblue',
    font_color='black',
    width=2.0,
    edge_color=strengths,
    edge_cmap=plt.cm.Blues,
    edge_vmin=strength_norm.vmin,
    edge_vmax=strength_norm.vmax,
)
nx.draw_networkx_edge_labels(G, pos, edge_labels=edge_labels, font_size=8, ax=ax)
fig.colorbar(
    plt.cm.ScalarMappable(norm=strength_norm, cmap=plt.cm.Blues),
    ax=ax,
    label="Connection Strength",
)
# Show the plot
plt.show()


In [None]:
# Bare trailing expression: the notebook renders the repr of the most
# recently computed `explanation` as the cell output.
explanation

# Example for feature importance explanation

In [None]:
from torch_geometric.explain import GNNExplainer, Explainer, ModelConfig

# GNNExplainer learns a per-feature node mask ("attributes") and an edge
# mask ("object") over 100 optimisation epochs.
gnn_explainer = GNNExplainer(epochs=100, lr=0.001)
torch_geometric.seed_everything(42)
config = ModelConfig(
    "binary_classification", task_level="graph", return_type="raw"
)
explainer = Explainer(
    model,
    algorithm=gnn_explainer,
    explanation_type="model",
    model_config=config,
    node_mask_type="attributes",
    edge_mask_type='object'
)
# With explanation_type="model" the Explainer derives the target from the
# model's own prediction; a user-supplied `target` is ignored and only
# triggers a warning, so none is passed. `pyg_batch` is an extra keyword
# argument that the Explainer forwards to the model's forward call.
explanation = explainer(
    x=batch_unpacked.x,
    edge_index=batch_unpacked.edge_index,
    pyg_batch=batch_unpacked.batch,
)
# Display names for the node-feature columns, one per column of `x`
# (presumably in the order the features were extracted; verify against the
# feature-extraction code).
feature_labels = ['variance', 'hjorth_mobility','hjorth_complexity',
                  "line_length", "katz_fd", "higuchi_fd", "delta_energy",
                  "theta_energy", "alpha_energy", "beta_energy"
                  ]
explanation.visualize_feature_importance(feat_labels=feature_labels)

In [None]:
# Render the most recently computed explanation as a graph, using the
# Explanation object's built-in plot.
explanation.visualize_graph()

# Captum feature importance

In [None]:
from torch_geometric.explain import CaptumExplainer, Explainer, ModelConfig

# Feature attribution through Captum's ShapleyValueSampling method.
captum_explainer = CaptumExplainer("ShapleyValueSampling")
torch_geometric.seed_everything(42)
# NOTE(review): CaptumExplainer may accept only return_type="probs" for
# binary classification in some PyG versions - confirm against the installed
# version and against what the model's forward actually returns.
config = ModelConfig(
    "binary_classification", task_level="graph", return_type="raw"
)
explainer = Explainer(
    model,
    algorithm=captum_explainer,
    explanation_type="model",
    model_config=config,
    node_mask_type="attributes",
)
idx = 0
# Fetch the sample once instead of indexing the dataset for every field.
sample = train_dataset[idx]
# A single graph is explained, so every node is assigned to graph 0. The
# vector length follows the sample's node count rather than a fixed constant.
pyg_batch = torch.zeros((sample.x.shape[0],), dtype=torch.long)
# With explanation_type="model" the Explainer derives the target from the
# model's own prediction; a user-supplied `target` is ignored and only
# triggers a warning, so none is passed.
explanation = explainer(
    x=sample.x,
    edge_index=sample.edge_index,
    pyg_batch=pyg_batch,
)
# Display names for the node-feature columns, one per column of `x`
# (presumably in the order the features were extracted; verify against the
# feature-extraction code).
feature_labels = ['variance', 'hjorth_mobility','hjorth_complexity',
                  "line_length", "katz_fd", "higuchi_fd", "delta_energy",
                  "theta_energy", "alpha_energy", "beta_energy"
                  ]
explanation.visualize_feature_importance(feat_labels=feature_labels)