# Explaining the network's predictions in terms of connectivity changes
# and feature importance. To run it, one needs to have the final model
# checkpoint saved for every fold and the corresponding testing set.
# Explanations are computed by the model trained on a given fold,
# on its testing data, and then merged to create average values.
## The structure of the folder with checkpoints should be:
checkpoints/ <br>
  - fold_0/ <br>
    - checkpoint_name.ckpt <br>
  - fold_1/ <br>
    - checkpoint_name.ckpt <br>
  - ... <br>
  - fold_n/ <br>
    - checkpoint_name.ckpt <br>  
## The same structure follows for the testing data folder.

In [1]:
import torch
import torch_geometric
from torch_geometric.loader import DataLoader
import numpy as np
import matplotlib.pyplot as plt
from models import GATv2Lightning
from utils.dataloader_utils import GraphDataset
from torch_geometric.nn import Sequential
from sklearn.utils.class_weight import compute_class_weight
import lightning.pytorch as pl
import os
import json
import networkx as nx
from torchmetrics.classification import MulticlassConfusionMatrix
from sklearn.metrics import balanced_accuracy_score
import seaborn as sns
import matplotlib as mpl
from statistics import mean, stdev


In [2]:
# Root folders; each one is expected to contain one sub-folder per fold,
# with identical sub-folder names in both.
# The checkpoint root is written with the same "../data/" prefix as the
# data root (a missing slash would point at a folder literally named "..data").
checkpoint_dir = "../data/final_kfold/checkpoints_final_folds/"
data_dir = "../data/final_kfold/saved_folds/"
save_dir_att = "../explainability_results/attention_connectivity"
# Sort the fold names once so that the three lists below are index-aligned
# by construction instead of relying on three independent sorts.
fold_list = sorted(os.listdir(checkpoint_dir))
checkpoint_fold_list = [os.path.join(checkpoint_dir, fold) for fold in fold_list]
data_fold_list = [os.path.join(data_dir, fold) for fold in fold_list]

In [None]:
# Confusion matrix and balanced accuracy for every fold, followed by a
# summary over all folds.


def _confusion_annotations(matrix):
    """Build the text shown inside each confusion-matrix cell.

    Every cell is labelled with its raw count and, on a second line, the
    share of its row that the count represents.

    Args:
        matrix: 2-D array of counts.

    Returns:
        Array of strings with the same shape as ``matrix``.
    """
    matrix = np.asarray(matrix)
    labels = []
    for row in matrix:
        row_total = row.sum()
        # An all-zero row would otherwise divide by zero and be labelled "nan%".
        fractions = row / row_total if row_total else np.zeros(len(row))
        labels.append(
            [f"{count}\n{fraction:.2%}" for count, fraction in zip(row, fractions)]
        )
    return np.asarray(labels).reshape(matrix.shape)


# Hyperparameters handed to ``load_from_checkpoint``. They do not change
# between folds, and the explanation cells further down read them from the
# notebook namespace, so they stay module-level names.
n_gat_layers = 1
hidden_dim = 32
dropout = 0.0
slope = 0.0025
pooling_method = "mean"
norm_method = "batch"
activation = "leaky_relu"
n_heads = 9
lr = 0.0012
weight_decay = 0.0078
n_classes = 3
class_names = ["preictal", "ictal", "interictal"]

# The accumulators are re-created on every run of the cell, so re-running it
# never adds to totals left over from a previous run.
balanced_acc_list = []
summary_conf_matrix = None
summary_balanced_acc = 0.0
for n, fold in enumerate(fold_list):
    # The first file found in the fold folder is taken as the checkpoint.
    checkpoint_path = os.path.join(checkpoint_fold_list[n], os.listdir(checkpoint_fold_list[n])[0])

    trainer = pl.Trainer(
        accelerator="auto",
        max_epochs=1,
        enable_progress_bar=True,
        deterministic=False,
        log_every_n_steps=1,
        enable_model_summary=False,
    )

    dataset = GraphDataset(data_fold_list[n])
    features_shape = dataset[0].x.shape[-1]

    model = GATv2Lightning.load_from_checkpoint(
        checkpoint_path,
        in_features=features_shape,
        n_classes=n_classes,
        n_gat_layers=n_gat_layers,
        hidden_dim=hidden_dim,
        n_heads=n_heads,
        slope=slope,
        dropout=dropout,
        pooling_method=pooling_method,
        activation=activation,
        norm_method=norm_method,
        lr=lr,
        weight_decay=weight_decay,
        map_location=torch.device('cpu')
    )
    # A single batch holding the whole fold.
    loader = DataLoader(dataset, batch_size=len(dataset), shuffle=False)

    preds = trainer.predict(model, loader)
    preds = torch.cat(preds, dim=0)
    preds = torch.nn.functional.softmax(preds, dim=1).argmax(dim=1)
    ground_truth = torch.tensor([data.y.int().item() for data in dataset])
    balanced_acc = balanced_accuracy_score(ground_truth, preds)
    balanced_acc_list.append(balanced_acc)
    print(f"Balanced accuracy {fold}: {balanced_acc}")

    metric = MulticlassConfusionMatrix(n_classes)
    conf_matrix = metric(preds, ground_truth).int().numpy()
    annots = _confusion_annotations(conf_matrix)
    disp = sns.heatmap(
        conf_matrix,
        xticklabels=class_names,
        yticklabels=class_names,
        annot=annots,
        fmt='',
        cmap='Blues',
    )
    # Hide the colour-bar ticks. The colour bar is looked up on the heatmap's
    # own mesh and checked before use, so a missing bar cannot raise here.
    cbar = disp.collections[0].colorbar
    if cbar is not None:
        cbar.set_ticks([])

    plt.show()

    if summary_conf_matrix is None:
        # Copy so that the in-place additions below never touch this fold's matrix.
        summary_conf_matrix = conf_matrix.copy()
    else:
        summary_conf_matrix += conf_matrix
    summary_balanced_acc += balanced_acc

if summary_conf_matrix is None:
    raise RuntimeError(f"No folds found in {checkpoint_dir!r}; nothing to summarise.")

annots = _confusion_annotations(summary_conf_matrix)
print(f"Summary balanced accuracy: {mean(balanced_acc_list):.4%} +/- {stdev(balanced_acc_list):.4%}")
display_final = sns.heatmap(
    summary_conf_matrix,
    xticklabels=class_names,
    yticklabels=class_names,
    annot=annots,
    fmt='',
    cbar=False,
    cmap='Blues',
)
fig = plt.gcf()
fig.tight_layout()
# savefig does not create missing folders.
os.makedirs("plots", exist_ok=True)
fig.savefig("plots/confusion_matrix.pdf", dpi=400)
plt.show()

In [3]:
def create_graph_from_dict(input_dict, self_loops=False, threshold=0.0):
    """Build an undirected graph from a mapping of edges to strengths.

    Args:
        input_dict: Mapping from an edge to its strength. An edge is either a
            pair ``(u, v)`` or the string form of such a pair, e.g.
            ``"(0, 1)"`` (the attention cell stringifies its keys with
            ``str(tuple(edge))`` before dumping them to JSON).
        self_loops: When False (default), edges whose two endpoints are equal
            are left out.
        threshold: Edges whose strength is strictly below this value are
            left out.

    Returns:
        An ``nx.Graph`` whose edges carry the strength in the ``"strength"``
        edge attribute.
    """
    from ast import literal_eval

    g = nx.Graph()
    for edge, value in input_dict.items():
        if value < threshold:
            continue
        # A stringified pair can never be exactly two characters long, so
        # two-character strings keep their old meaning (one character per
        # node); every other string is parsed back into a tuple instead of
        # being unpacked character by character.
        if isinstance(edge, str) and len(edge) != 2:
            edge = literal_eval(edge)
        source, target = edge
        if self_loops or source != target:
            g.add_edge(source, target, strength=value)
    return g

# Attention explanations

In [4]:
# Root folders; each one is expected to contain one sub-folder per fold,
# with identical sub-folder names in both.
# The checkpoint root is written with the same "../data/" prefix as the
# data root (a missing slash would point at a folder literally named "..data").
checkpoint_dir = "../data/final_kfold/checkpoints_final_folds/"
data_dir = "../data/final_kfold/saved_folds/"
save_dir_att = "../explainability_results/attention_connectivity"
# Sort the fold names once so that the three lists below are index-aligned
# by construction instead of relying on three independent sorts.
fold_list = sorted(os.listdir(checkpoint_dir))
checkpoint_fold_list = [os.path.join(checkpoint_dir, fold) for fold in fold_list]
data_fold_list = [os.path.join(data_dir, fold) for fold in fold_list]

In [None]:
from collections import defaultdict

from torch_geometric.explain import AttentionExplainer, Explainer, ModelConfig

# For every fold: run the attention explainer on each test graph, average the
# resulting edge masks over all graphs and over the graphs of each class, and
# write the four averages to JSON files in save_dir_att/fold_<i>/.
# NOTE: features_shape, n_classes and the remaining hyperparameters are not
# assigned in this cell; they must already exist in the notebook namespace
# (the confusion-matrix cell assigns them).
att_explainer = AttentionExplainer()
torch_geometric.seed_everything(42)
for i, fold in enumerate(fold_list):
    print(f"Fold {i}")
    # The first file found in the fold folder is taken as the checkpoint.
    checkpoint_path = os.path.join(checkpoint_fold_list[i], os.listdir(checkpoint_fold_list[i])[0])

    model = GATv2Lightning.load_from_checkpoint(
        checkpoint_path,
        in_features=features_shape,
        n_classes=n_classes,
        n_gat_layers=n_gat_layers,
        hidden_dim=hidden_dim,
        n_heads=n_heads,
        slope=slope,
        dropout=dropout,
        pooling_method=pooling_method,
        activation=activation,
        norm_method=norm_method,
        lr=lr,
        weight_decay=weight_decay,
        map_location=torch.device('cpu')
    )

    dataset = GraphDataset(data_fold_list[i])
    config = ModelConfig(
        "multiclass_classification", task_level="graph", return_type="raw"
    )
    explainer = Explainer(
        model,
        algorithm=att_explainer,
        explanation_type="model",
        model_config=config,
        edge_mask_type="object",
    )
    # One graph per batch, so batch.y holds the label of that single graph.
    loader = DataLoader(
        dataset, batch_size=1, shuffle=False, drop_last=False
    )

    # Running sums of the edge masks, keyed by the stringified edge.
    edge_connection_dict_all = defaultdict(float)
    edge_connection_dict_preictal = defaultdict(float)
    edge_connection_dict_interictal = defaultdict(float)
    edge_connection_dict_ictal = defaultdict(float)
    # Class index -> accumulator for that class / number of graphs seen.
    class_edge_sums = {
        0: edge_connection_dict_preictal,
        1: edge_connection_dict_ictal,
        2: edge_connection_dict_interictal,
    }
    class_counts = {0: 0, 1: 0, 2: 0}
    for n, batch in enumerate(loader):
        explanation = explainer(
            x=batch.x,
            edge_index=batch.edge_index,
            target=batch.y,
            pyg_batch=batch.batch,
        )
        # Read the label once per graph instead of comparing tensors per edge.
        label = batch.y.item()
        class_sums = class_edge_sums.get(label)
        if class_sums is not None:
            class_counts[label] += 1

        endpoints = explanation.edge_index.t().tolist()
        edge_masks = explanation.edge_mask.tolist()
        for pair, edge_mask in zip(endpoints, edge_masks):
            # Sorting the endpoints makes (u, v) and (v, u) share one key, so
            # the masks of both directions are summed into the same entry.
            # Keys are strings because JSON objects cannot have tuple keys.
            edge = str(tuple(sorted(pair)))
            edge_connection_dict_all[edge] += edge_mask
            if class_sums is not None:
                class_sums[edge] += edge_mask

        if n % 100 == 0:
            print(f"Batch {n} done")

    preictal_cntr = class_counts[0]
    ictal_cntr = class_counts[1]
    interictal_cntr = class_counts[2]
    n_graphs = n + 1

    # Turn the sums into means. A class without graphs has an empty
    # accumulator, so its comprehension never divides by its zero counter.
    edge_connection_dict_all = {key: value / n_graphs for key, value in edge_connection_dict_all.items()}
    edge_connection_dict_interictal = {key: value / interictal_cntr for key, value in edge_connection_dict_interictal.items()}
    edge_connection_dict_ictal = {key: value / ictal_cntr for key, value in edge_connection_dict_ictal.items()}
    edge_connection_dict_preictal = {key: value / preictal_cntr for key, value in edge_connection_dict_preictal.items()}

    save_path_fold = os.path.join(save_dir_att, f"fold_{i}")
    os.makedirs(save_path_fold, exist_ok=True)
    averaged_dicts = {
        "all": edge_connection_dict_all,
        "interictal": edge_connection_dict_interictal,
        "ictal": edge_connection_dict_ictal,
        "preictal": edge_connection_dict_preictal,
    }
    for subset, edge_dict in averaged_dicts.items():
        with open(os.path.join(save_path_fold, f"edge_connection_dict_{subset}.json"), "w") as f:
            json.dump(edge_dict, f)
    print(f"Fold {i} done")

# Visualize the attention weights

In [None]:
# Draw the averaged preictal attention graph left behind by the attention
# cell (i.e. the one of the last processed fold) on a circular layout.
ch_names = [
    "Fp1",
    "Fp2",
    "F7",
    "F3",
    "Fz",
    "F4",
    "F8",
    "T7",
    "C3",
    "Cz",
    "C4",
    "T8",
    "P7",
    "P3",
    "P4",
    "P8",
    "O1",
    "O2",
]
# Node index -> channel name.
custom_labels = dict(enumerate(ch_names))
g = create_graph_from_dict(edge_connection_dict_preictal, threshold=0.3)
# Map each strength linearly onto [0.2, 1.0] before it is passed as the edge colour value.
edge_opacities = [
    0.2 + strength * 0.8
    for strength in nx.get_edge_attributes(g, "strength").values()
]
pos = nx.circular_layout(g)
# Thresholding can leave some channels out of the graph; a label whose node
# has no position makes the label drawing fail, so only label nodes that exist.
visible_labels = {node: name for node, name in custom_labels.items() if node in g}

nx.draw(
    g,
    with_labels=True,
    labels=visible_labels,
    pos=pos,
    font_weight="bold",
    edge_color=edge_opacities,
    width=2,
    node_size=1000,
    node_color="white",
    edge_cmap=plt.cm.autumn,
)
# The mappable is not attached to any Axes, so the Axes to take space from is
# named explicitly (recent matplotlib refuses to guess it).
# NOTE(review): this colour bar spans the colormap's default range, while the
# edge colours are scaled by networkx (no edge_vmin/edge_vmax given) - confirm
# that the bar is only meant as a qualitative legend.
colorbar = plt.colorbar(
    plt.cm.ScalarMappable(cmap=plt.cm.autumn),
    ax=plt.gca(),
)
color = "white"
colorbar.set_label(label="Connection Strength", color=color)
colorbar.ax.yaxis.set_tick_params(color=color, labelcolor=color)
fig = plt.gcf()
fig.set_facecolor("black")  # Set the background color here
plt.show()

# Feature importances

In [None]:
# Root folders; each one is expected to contain one sub-folder per fold,
# with identical sub-folder names in both.
# The checkpoint root is written with the same "../data/" prefix as the
# data root (a missing slash would point at a folder literally named "..data").
checkpoint_dir = "../data/final_kfold/checkpoints_final_folds/"
data_dir = "../data/final_kfold/saved_folds/"
save_dir = "../explainability_results/feature_importance"
# Sort the fold names once so that the three lists below are index-aligned
# by construction instead of relying on three independent sorts.
fold_list = sorted(os.listdir(checkpoint_dir))
checkpoint_fold_list = [os.path.join(checkpoint_dir, fold) for fold in fold_list]
data_fold_list = [os.path.join(data_dir, fold) for fold in fold_list]

In [None]:
from torch_geometric.explain import GNNExplainer, Explainer, ModelConfig

# For every fold: run GNNExplainer on each test graph, average the node
# (feature) masks over all graphs and over the graphs of each class, and save
# the four averaged explanations to save_dir/<fold>/.
# NOTE: features_shape, n_classes and the remaining hyperparameters are not
# assigned in this cell; they must already exist in the notebook namespace
# (the confusion-matrix cell assigns them).
torch_geometric.seed_everything(42)

for i, fold in enumerate(fold_list):
    print(fold)
    # The first file found in the fold folder is taken as the checkpoint.
    checkpoint_path = os.path.join(checkpoint_fold_list[i], os.listdir(checkpoint_fold_list[i])[0])

    model = GATv2Lightning.load_from_checkpoint(
        checkpoint_path,
        in_features=features_shape,
        n_classes=n_classes,
        n_gat_layers=n_gat_layers,
        hidden_dim=hidden_dim,
        n_heads=n_heads,
        slope=slope,
        dropout=dropout,
        pooling_method=pooling_method,
        activation=activation,
        norm_method=norm_method,
        lr=lr,
        weight_decay=weight_decay,
        map_location=torch.device('cpu')
    )

    dataset = GraphDataset(data_fold_list[i])
    # One graph per batch, so batch.y holds the label of that single graph.
    loader = DataLoader(
        dataset, batch_size=1, shuffle=False, drop_last=False, num_workers=8, prefetch_factor=20
    )

    gnn_explainer = GNNExplainer(epochs=100, lr=0.01)
    config = ModelConfig(
        "multiclass_classification", task_level="graph", return_type="raw"
    )
    explainer = Explainer(
        model,
        algorithm=gnn_explainer,
        explanation_type="model",
        model_config=config,
        node_mask_type="attributes",
        edge_mask_type='object'
    )

    # The accumulators are allocated from the first mask, so their shape
    # follows the data instead of a hard-coded (18, 10).
    sum_masks = None
    interictal_masks = None
    ictal_masks = None
    preictal_masks = None
    interictal_cntr = 0
    preictal_cntr = 0
    ictal_cntr = 0
    for n, batch in enumerate(loader):
        explanation = explainer(
            x=batch.x,
            edge_index=batch.edge_index,
            target=batch.y,
            pyg_batch=batch.batch,
        )
        node_mask = explanation.node_mask
        if sum_masks is None:
            sum_masks = torch.zeros_like(node_mask)
            interictal_masks = torch.zeros_like(node_mask)
            ictal_masks = torch.zeros_like(node_mask)
            preictal_masks = torch.zeros_like(node_mask)

        sum_masks += node_mask

        label = batch.y.item()
        if label == 0:
            preictal_masks += node_mask
            preictal_cntr += 1
        elif label == 1:
            ictal_masks += node_mask
            ictal_cntr += 1
        elif label == 2:
            interictal_masks += node_mask
            interictal_cntr += 1
        if n % 100 == 0:
            print(f"Batch {n} done")

    # Sums -> means. A class with no graph in this fold divides 0 by 0 and
    # therefore ends up as a NaN mask.
    sum_masks /= n + 1
    interictal_masks /= interictal_cntr
    ictal_masks /= ictal_cntr
    preictal_masks /= preictal_cntr

    # The last explanation is reused as a container; only its node mask is
    # replaced by the averaged one.
    final_explanation_sum = explanation.clone()
    final_explanation_interictal = explanation.clone()
    final_explanation_preictal = explanation.clone()
    final_explanation_ictal = explanation.clone()

    final_explanation_sum.node_mask = sum_masks
    final_explanation_interictal.node_mask = interictal_masks
    final_explanation_preictal.node_mask = preictal_masks
    final_explanation_ictal.node_mask = ictal_masks

    save_path_fold = os.path.join(save_dir, fold)
    os.makedirs(save_path_fold, exist_ok=True)
    torch.save(final_explanation_sum, os.path.join(save_path_fold, "final_explanation_sum.pt"))
    torch.save(final_explanation_interictal, os.path.join(save_path_fold, "final_explanation_interictal.pt"))
    torch.save(final_explanation_preictal, os.path.join(save_path_fold, "final_explanation_preictal.pt"))
    torch.save(final_explanation_ictal, os.path.join(save_path_fold, "final_explanation_ictal.pt"))
    print(f" {fold} done")

  

# Visualize

In [None]:
# Names of the node features, in the order passed to the plot as feat_labels.
feature_labels = [
    "variance",
    "hjorth_mobility",
    "hjorth_complexity",
    "line_length",
    "katz_fd",
    "higuchi_fd",
    "delta_energy",
    "theta_energy",
    "alpha_energy",
    "beta_energy",
]
# save_path_fold is whatever the feature-importance loop assigned last, so
# this loads the averaged explanation of the last processed fold.
loaded_explanation = torch.load(
    os.path.join(save_path_fold, "final_explanation_sum.pt")
)
loaded_explanation.visualize_feature_importance(feat_labels=feature_labels)