# Analysis notebook

### Imports

In [None]:
import json
import os
from locpix_points.scripts.visualise import visualise_torch_geometric, visualise_parquet, load_file
from locpix_points.evaluate.locanalyse import(
    analyse_locs,
    explain,
)
from mpl_toolkits.axes_grid1.anchored_artists import AnchoredSizeBar
from torch_geometric.utils.convert import to_networkx
import networkx as nx
from networkx.drawing import draw_networkx, draw
import numpy as np
import matplotlib.pyplot as plt
from matplotlib.colors import LinearSegmentedColormap
import pandas as pd
from torch_geometric.data import Data
import torch
import yaml

### Functions

In [2]:
def find_graph_path(project_directory, file_name, file_folder):
    """Locate the saved graph (.pt) file that corresponds to a named item

    The file_map.csv of each of the train, val and test splits under
    file_folder is read, file_name is looked up in its "file_name" column
    and the matching "idx" value is used to build the path to the .pt file.

    Args:
        project_directory (string): Location of project directory
        file_name (string) : Name of file to look up in the file maps
        file_folder (string) : Which folder (relative to the project
            directory) contains the train/val/test split folders

    Returns:
        path (string): project_directory/file_folder/<split>/<idx>.pt
            If the name is listed in more than one split, the last match
            in the order train, val, test is used.

    Raises:
        ValueError: If file_name is not listed in any of the file maps"""

    match = None
    # every split's file map is read (a missing map is an error), and a later
    # split overrides an earlier one when the name appears in several
    for split in ("train", "val", "test"):
        file_map_path = os.path.join(project_directory, f"{file_folder}/{split}/file_map.csv")
        file_map = pd.read_csv(file_map_path)
        rows = file_map[file_map["file_name"] == file_name]
        if len(rows) > 0:
            match = (split, rows["idx"].values[0])

    if match is None:
        raise ValueError(
            f"{file_name} is not listed in the train, val or test file_map.csv "
            f"under {os.path.join(project_directory, file_folder)}"
        )

    folder, idx = match
    return os.path.join(project_directory, f"{file_folder}/{folder}/{idx}.pt")

### Parameters

In [3]:
# root of the project; the config, metadata.json, data folders and the
# output folder are all resolved relative to this
project_directory = ".."
# when True the raw files are read from preprocessed/test/gt_label rather
# than preprocessed/gt_label; the flag is also passed on to analyse_locs
final_test = True


In [4]:
# load config
with open(os.path.join(project_directory, "config/locanalyse.yaml"), "r") as ymlfile:
    config = yaml.safe_load(ymlfile)
label_map = config["label_map"]

# load in gt_label_map from the project metadata (the metadata is only read
# here, nothing is written back)
metadata_path = os.path.join(project_directory, "metadata.json")
with open(metadata_path) as file:
    metadata = json.load(file)
gt_label_map = metadata["gt_label_map"]

# the config label map, inverted and with its values stringified, has to be
# identical to the label map stored in the metadata
assert {str(val): key for key, val in label_map.items()} == gt_label_map, (
    f"label_map in config/locanalyse.yaml is not the inverse of gt_label_map in {metadata_path}"
)
# JSON object keys are always strings; convert them to integer keys
gt_label_map = {int(key): val for key, val in gt_label_map.items()}

### Graph structure explainability

In [5]:
# get item to evaluate on: the file name without its extension
# (".parquet" is appended when the raw file is loaded further down)
file_name = "three_1004"

#### Visualise raw file

In [None]:
# the raw localisation files sit in a separate folder for the final test
file_folder = "preprocessed/test/gt_label" if final_test else "preprocessed/gt_label"
file_path = os.path.join(project_directory, file_folder, file_name + ".parquet")
print(file_path)

# channels 0-3 are displayed as "channel_0" ... "channel_3"
visualise_parquet(
    file_path,
    'y',
    'x',
    None,
    'channel',
    {chan: f"channel_{chan}" for chan in range(4)},
    cmap=['k'],
    spheres=True,
    sphere_size=0.004,
)

In [None]:
# Scatter plot of the raw localisations with a scale bar

# the raw localisation files sit in a separate folder for the final test
file_folder = "preprocessed/test/gt_label" if final_test else "preprocessed/gt_label"
file_path = os.path.join(project_directory, file_folder, file_name + ".parquet")

# read the coordinates
df, unique_chans = load_file(file_path, "x", "y", None, "channel")
x = df["x"].to_numpy()
y = df["y"].to_numpy()

# single square panel without axes
fig, ax = plt.subplots(1, 1, figsize=(10, 10), sharex=True, sharey=True)
ax.set_aspect('equal', adjustable='box')
# y is drawn along the horizontal axis and x along the vertical axis
ax.scatter(y, x, s=10, c='k')
ax.axis('off')

# unlabelled scale bar of length 0.1 in data units
scalebar = AnchoredSizeBar(
    ax.transData,
    0.1,
    '',
    'lower left',
    pad=1,
    color='k',
    frameon=False,
    size_vertical=0.01,
)
ax.add_artist(scalebar)

output_path = os.path.join(project_directory, "output", "combined" + '_raw_s_10.svg')
plt.subplots_adjust(wspace=0, hspace=0)
# saving is disabled; uncomment to write the figure to output_path
#plt.savefig(output_path, transparent=True, bbox_inches="tight", pad_inches=0)

### Explain

In [None]:
# analyse_locs returns eight values which are unpacked below.
# NOTE: `config` is rebound here to the value returned by analyse_locs, so
# from this cell on it is no longer the dictionary read from the yaml file.
(
    train_set,
    train_map,
    test_set,
    test_map,
    model,
    model_type,
    config,
    device,
) = analyse_locs(project_directory, config, final_test, False)

#### SubgraphX

In [9]:
# Settings handed to explain(..., type='subgraphx') as subgraph_config
subgraph_config = {
    # number of iterations to get prediction
    "rollout": 20, 
    # number of atoms of leaf node in search tree
    "min_atoms": 5,
    # hyperparameter that encourages exploration
    "c_puct": 10.0,
    # number of atoms to expand when extending the child nodes in the search tree
    "expand_atoms": 14,
    # whether to expand the children nodes from high degree to low degree when extending the child nodes in the search tree
    "high2low": False,
    # number of local radius to calculate
    "local_radius": 4,
    # sampling time of the Monte Carlo approximation
    "sample_num": 100, 
    # reward method
    "reward_method": "mc_l_shapley",
    # subgraph building method
    "subgraph_building_method": "split",
    # maximum number of nodes to include in subgraph when generating explanation
    "max_nodes": 15,
    # number of classes
    "num_classes": 7,
}

In [None]:
 # ---- subgraphx -----
output = explain(
    [file_name],
    train_map,
    train_set,
    model,
    model_type,
    config,
    device,
    type='subgraphx',
    subgraph_config=subgraph_config,
    intermediate=True,
)
# a single file name was passed in, so only the first entry of the output is used
subgraph, complement, data, node_imp = output[0]

In [None]:
# Draw the explained graph with matplotlib, highlighting the nodes selected
# by SubgraphX, and save the figure as an svg

# graph: undirected and without self loops
nx_g = to_networkx(data, to_undirected=True)
nx_g.remove_edges_from(nx.selfloop_edges(nx_g))
pos = data.pos.cpu().numpy()
# nodes for which node_imp is truthy are red, all other nodes are blue
node_color = np.where(node_imp.cpu().numpy(), 'r', 'b')

# single square panel
fig, ax = plt.subplots(1, 1, figsize=(20, 20), sharex=True, sharey=True)
ax.set_aspect('equal', adjustable='box')

# unlabelled scale bar of length 0.1 in data units
scalebar = AnchoredSizeBar(
    ax.transData,
    0.1,
    '',
    'lower right',
    pad=1,
    color='k',
    frameon=False,
    size_vertical=0.01,
)
ax.add_artist(scalebar)

# the two coordinate columns are swapped before drawing
draw(
    nx_g,
    pos=np.flip(pos, axis=1),
    ax=ax,
    node_color=node_color,
    node_size=50,
)

output_path = os.path.join(project_directory, "output", "combined" + '_subgraphx_s_10.svg')
plt.subplots_adjust(wspace=0, hspace=0)
plt.savefig(output_path, transparent=True, bbox_inches="tight", pad_inches=0)

#### Attention

In [None]:
# ---- attention -----
output = explain(
    [file_name],
    train_map,
    train_set,
    model,
    model_type,
    config,
    device,
    type='attention',
)
# a single file name was passed in, so only the first entry of the output is used
positions, edge_indices, alphas = output[0]

In [30]:
# threshold the attention values: strictly positive entries become 1.0,
# everything else becomes 0.0
alphas_ = [torch.where(alpha > 0.0, 1.0, 0.0) for alpha in alphas]

In [None]:
# Visualise the thresholded attention on the graph edges using matplotlib:
# one panel per entry of `positions`, edges coloured from black (0) to red (1)

# when True, edges whose thresholded attention is 0 are dropped as well and
# the red channel of the remaining edge colours is rescaled to [0, 1]
remove_self_loops_and_neg_edges = False

if not final_test:
    fold = config["fold"]
    file_folder = f"processed/fold_{fold}"
else:
    file_folder = "processed"

colors = [(0, 0, 0), (1, 0, 0)] # first color is black, last is red
cm = LinearSegmentedColormap.from_list("br", colors)

# NOTE(review): three panels are hard-coded, i.e. this assumes explain()
# returned at most three sets of positions - confirm against the model
fig, ax = plt.subplots(3,1,figsize=(20,20), sharex=True, sharey=True)

for idx, position in enumerate(positions):

    ax[idx].set_aspect('equal', adjustable='box')
    ax[idx].axis('off')
    # unlabelled scale bar of length 0.1 in data units
    scalebar = AnchoredSizeBar(ax[idx].transData,
                                0.1, '', 'lower right',
                                pad=1,
                                color='k',
                                frameon=False,
                                size_vertical=0.01)

    ax[idx].add_artist(scalebar)

    # graph
    dataitem = Data(x=None, edge_index=edge_indices[idx], pos=position)
    nx_g = to_networkx(dataitem)
    pos = dataitem.pos.cpu().numpy()
    # NOTE(review): the colours are in edge_index order whereas the graph is
    # filtered and drawn in nx_g.edges order; this presumes both orders
    # agree - verify for the edge_index produced by the model
    edge_color = cm(alphas_[idx].cpu())

    if remove_self_loops_and_neg_edges:
        # drop the edges whose thresholded attention is zero, together with
        # their colours
        neg_rows = np.argwhere(alphas_[idx].cpu().numpy() == 0.0)[:, 0]
        edges = list(nx_g.edges)
        neg_edges = {edges[i] for i in neg_rows}
        remove_indices = [i for i, item in enumerate(edges) if item in neg_edges]
        edge_color = np.delete(edge_color, remove_indices, axis=0)
        nx_g.remove_edges_from(neg_edges)

    # drop the self loops together with their colours
    self_loops = set(nx.selfloop_edges(nx_g))
    edges = list(nx_g.edges)
    remove_indices = [i for i, item in enumerate(edges) if item in self_loops]
    edge_color = np.delete(edge_color, remove_indices, axis=0)
    nx_g.remove_edges_from(self_loops)

    if remove_self_loops_and_neg_edges and len(edge_color) > 0:
        # rescale the red channel to [0, 1]; local names are used so that the
        # builtins min/max are not shadowed for the rest of the notebook
        red_min = np.min(edge_color[:,0])
        red_max = np.max(edge_color[:,0])
        # when every remaining edge has the same colour there is nothing to
        # rescale, and dividing by a zero range would turn the colours to NaN
        if red_max > red_min:
            edge_color[:,0] = (edge_color[:,0] - red_min)/(red_max - red_min)

    # the two coordinate columns are swapped before drawing
    draw(nx_g, pos=np.flip(pos, axis= 1), ax=ax[idx], edge_color=edge_color, node_size=1, node_color='k')

# figure-level settings do not depend on the panel, so apply them once
output_path = os.path.join(project_directory, "output", "combined" + '_attention_all_edges_s_1.svg')
plt.subplots_adjust(wspace=0, hspace=0)

# saving is disabled; uncomment to write the figure to output_path
#plt.savefig(output_path, transparent=True, bbox_inches="tight", pad_inches=0)
