## Analysis notebook

### Imports

In [None]:
from bokeh.plotting import output_file, save
import json
import os
import pickle
from locpix_points.data_loading import datastruc
from locpix_points.scripts.visualise import visualise_torch_geometric, visualise_parquet, load_file
from locpix_points.evaluation.featanalyse import (
    explain,
    generate_umap_embedding,
    visualise_umap_embedding,
    generate_pca_embedding,
    visualise_pca_embedding,
    visualise_explanation,
    k_means_fn,
    get_prediction,
    subgraph_eval,
    pgex_eval,
    attention_eval,
    test_ensemble_averaging,
)

import matplotlib.pyplot as plt
import pandas as pd
import polars as pl
import seaborn as sns
from sklearn.preprocessing import StandardScaler
import torch
import umap
import yaml

### Functions

In [None]:
def find_graph_path(project_directory, file_name, file_folder):
    """Locate the processed graph (``.pt``) file that corresponds to a raw file name.

    Each split folder (``train``, ``val``, ``test``) under ``file_folder`` holds a
    ``file_map.csv`` with columns ``file_name`` and ``idx``; the graph for a file is
    stored as ``<idx>.pt`` in the split where it appears.

    Args:
        project_directory (string): Location of project directory
        file_name (string): Name of the file to look up (as stored in ``file_map.csv``)
        file_folder (string): Folder, relative to ``project_directory``, that contains
            the ``train``/``val``/``test`` split folders

    Returns:
        string: Path ``<project_directory>/<file_folder>/<split>/<idx>.pt``. If the name
        appears in more than one split, the later split in the order train, val, test
        wins (i.e. test takes precedence).

    Raises:
        FileNotFoundError: If ``file_name`` does not appear in any split's file map.
    """
    match = None
    for split in ("train", "val", "test"):
        file_map = pd.read_csv(
            os.path.join(project_directory, f"{file_folder}/{split}/file_map.csv")
        )
        rows = file_map[file_map["file_name"] == file_name]
        if len(rows) > 0:
            # Later splits overwrite earlier matches: precedence test > val > train.
            match = (split, rows["idx"].values[0])

    if match is None:
        raise FileNotFoundError(
            f"{file_name!r} not found in any file_map.csv under "
            f"{os.path.join(project_directory, file_folder)}"
        )

    split, idx = match
    return os.path.join(project_directory, f"{file_folder}/{split}/{idx}.pt")

### Parameters

In [None]:
# Root of the project; every other path in this notebook is built relative to it.
project_directory = ".."


def _read_yaml_config(relative_path):
    """Return the parsed contents of a YAML file located under project_directory."""
    with open(os.path.join(project_directory, relative_path), "r") as ymlfile:
        return yaml.safe_load(ymlfile)


config_manual = _read_yaml_config("config/featanalyse_manual.yaml")
config_nn = _read_yaml_config("config/featanalyse_nn.yaml")

# The notebook uses a single label map, so both configs must agree on it.
label_map = config_manual["label_map"]
assert label_map == config_nn["label_map"]
manual_features = config_manual["features"]

In [None]:
# Analysis parameters.
# final_test: True reads the held-out test data layout ("preprocessed/test/...", "processed");
# False reads the per-fold layout ("preprocessed/...", "processed/fold_<fold>").
final_test = True
# UMAP / PCA settings (the UMAP embeddings visualised below are loaded from pre-computed pickles).
umap_n_neighbours = 20
umap_min_dist = 0.5
pca_n_components = 2
# Prefer the GPU but fall back to CPU so the notebook also runs on machines without CUDA.
device = "cuda" if torch.cuda.is_available() else "cpu"
n_repeats = 1

## Analyse the neural-network features

In [None]:
if final_test:
    # NN feature tables for the test set, one per output level (loc, cluster, fov).
    test_df_nn_loc = pd.read_csv(
        os.path.join(project_directory, "output/test_df_nn_loc.csv")
    )
    test_df_nn_cluster = pd.read_csv(
        os.path.join(project_directory, "output/test_df_nn_cluster.csv")
    )
    test_df_nn_fov = pd.read_csv(
        os.path.join(project_directory, "output/test_df_nn_fov.csv")
    )


#### UMAP

In [None]:
# Paths to the pickled UMAP embeddings of the test-set NN features (loc, cluster, fov).
(
    test_umap_embedding_nn_loc_path,
    test_umap_embedding_nn_cluster_path,
    test_umap_embedding_nn_fov_path,
) = (
    os.path.join(project_directory, relative_path)
    for relative_path in (
        "output/test_umap_embedding_nn_loc.pkl",
        "output/test_umap_embedding_nn_cluster.pkl",
        "output/test_umap_embedding_nn_fov.pkl",
    )
)


In [None]:
print("------ HANDCRAFTED FEATURES -------")
# Only the test-set embedding is shown; no train_df_nn_* tables are loaded in this notebook.
if final_test:
    # pickle can execute arbitrary code: only load embeddings this project wrote itself.
    with open(test_umap_embedding_nn_loc_path, "rb") as pickled_embedding:
        test_umap_embedding_nn_loc = pickle.load(pickled_embedding)
    # To save instead, add: save=True,
    # save_name="clusternet_only_handcrafted_cluster_features_umap_nn_20_mindist_0.5",
    # project_directory=project_directory, point_size=0.001
    visualise_umap_embedding(
        test_umap_embedding_nn_loc,
        test_df_nn_loc,
        label_map,
        interactive=True,
    )

In [None]:
print("------ CLUSTER ENCODER -------")
# Only the test-set embedding is shown; no train_df_nn_* tables are loaded in this notebook.
if final_test:
    # pickle can execute arbitrary code: only load embeddings this project wrote itself.
    with open(test_umap_embedding_nn_cluster_path, "rb") as pickled_embedding:
        test_umap_embedding_nn_cluster = pickle.load(pickled_embedding)
    # To save instead, add: save=True,
    # save_name="clusternet_only_nn_cluster_encoder_umap_nn_20_mindist_0.5",
    # project_directory=project_directory, point_size=0.001
    visualise_umap_embedding(
        test_umap_embedding_nn_cluster,
        test_df_nn_cluster,
        label_map,
        interactive=True,
    )

In [None]:
print("------ FOV ENCODER -------")
# Only the test-set embedding is shown; no train_df_nn_* tables are loaded in this notebook.
if final_test:
    # pickle can execute arbitrary code: only load embeddings this project wrote itself.
    with open(test_umap_embedding_nn_fov_path, "rb") as pickled_embedding:
        test_umap_embedding_nn_fov = pickle.load(pickled_embedding)
    # The returned figure is kept as `plot`: later cells read its data source and recolour it.
    # To save instead, add: save=True,
    # save_name="clusternet_only_nn_fov_encoder_umap_nn_20_mindist_0.5",
    # project_directory=project_directory, point_size=0.001
    plot = visualise_umap_embedding(
        test_umap_embedding_nn_fov,
        test_df_nn_fov,
        label_map,
        interactive=True,
    )

#### Load the ground-truth label map

In [None]:
# load in gt_label_map
metadata_path = os.path.join(project_directory, "metadata.json")
with open(
    metadata_path,
) as file:
    metadata = json.load(file)
    # add time ran this script to metadata
    gt_label_map = metadata["gt_label_map"]

gt_label_map = {int(key): val for key, val in gt_label_map.items()}

### Load model

In [None]:
# The file holds a whole pickled model object (it is used directly via .to()/.eval()), not a
# state_dict, so weights_only=False is required on torch>=2.6 where the default became True.
# Only do this for files this project produced itself: unpickling can execute arbitrary code.
# map_location lets a GPU-saved model load when `device` is the CPU.
cluster_model = torch.load(
    os.path.join(project_directory, "output/cluster_model.pt"),
    map_location=device,
    weights_only=False,
)
cluster_model.to(device)
cluster_model.eval()

#### Load datasets

In [None]:
cluster_train_folder = os.path.join(project_directory, "processed/featanalysis/train")
cluster_val_folder = os.path.join(project_directory, "processed/featanalysis/val")
cluster_test_folder = os.path.join(project_directory, "processed/featanalysis/test")


def _open_cluster_dataset(folder):
    """Build a ClusterDataset over ``folder`` with every optional setting left as None."""
    return datastruc.ClusterDataset(
        None,
        folder,
        label_level=None,
        pre_filter=None,
        save_on_gpu=None,
        transform=None,
        pre_transform=None,
        fov_x=None,
        fov_y=None,
    )


cluster_train_set = _open_cluster_dataset(cluster_train_folder)
cluster_val_set = _open_cluster_dataset(cluster_val_folder)
cluster_test_set = _open_cluster_dataset(cluster_test_folder)

### Identify incorrectly predicted test-set points in the UMAP

In [None]:
# Re-run the cluster model on every file shown in the FOV-encoder UMAP and record the
# misclassified ones. `plot` only exists when final_test is True (FOV-encoder UMAP cell).
files = [
    name.removesuffix(".parquet")
    for name in plot.renderers[0].data_source.data["file_name"]
]

wrong_files = []
for file_name in files:
    x, pred = get_prediction(
        file_name,
        cluster_model,
        cluster_train_set,
        cluster_val_set,
        cluster_test_set,
        project_directory,
        device,
        gt_label_map,
    )
    # NOTE(review): assumes x.y is the ground-truth label and `pred` the predicted one.
    if x.y.detach().item() != pred:
        wrong_files.append(x.name)

# Black = correctly predicted, red = misclassified. Looking names up in a set keeps this
# pass linear instead of scanning the wrong_files list for every point.
wrong_file_set = set(wrong_files)
new_colors = ["#FF0000" if name in wrong_file_set else "#000000" for name in files]

In [None]:
# umap's plotting helpers live in the separately imported `umap.plot` submodule;
# `import umap` alone does not make `umap.plot` available.
import umap.plot

# Add the correctness colours as a column and point the glyph's fill/line colours at it
# (a non-colour string is interpreted by Bokeh as a data-source column name).
plot.renderers[0].data_source.data["new_colors"] = new_colors
plot.renderers[0].glyph.fill_color = "new_colors"
plot.renderers[0].glyph.line_color = "new_colors"
umap.plot.show(plot)

### Publication figures

In [None]:
from mpl_toolkits.axes_grid1.anchored_artists import AnchoredSizeBar
from torch_geometric.utils.convert import to_networkx
import networkx as nx
from networkx.drawing import draw_networkx, draw
import numpy as np
from matplotlib.colors import ListedColormap, Normalize


### For each class, identify the points closest to and furthest from the class centroid

In [None]:
# Collect the FOV-encoder UMAP points: coordinates, class label, correctness colour
# ("#000000" correct, "#FF0000" misclassified) and file name.
umap_data = plot.renderers[0].data_source.data
df = pd.DataFrame(
    {
        "x": umap_data["x"],
        "y": umap_data["y"],
        "label": umap_data["item"],
        "correct": umap_data["new_colors"],
        "name": umap_data["file_name"],
    }
)

# For each class, print the point closest to and the point furthest from the class
# centroid in UMAP space. Labels are sorted so the output order is reproducible
# (iteration order of a set of strings changes between interpreter runs).
for label in sorted(df["label"].unique()):
    print("Label: ", label)
    class_df = df[df["label"] == label]
    x_mean = np.mean(class_df["x"])
    y_mean = np.mean(class_df["y"])
    # .assign builds a new frame, avoiding SettingWithCopyWarning on the filtered slice.
    class_df = class_df.assign(
        dist=((class_df["x"] - x_mean) ** 2 + (class_df["y"] - y_mean) ** 2) ** 0.5
    )
    # Named closest/furthest rather than min/max so the builtins are not shadowed.
    closest = class_df.loc[class_df["dist"].idxmin()]
    furthest = class_df.loc[class_df["dist"].idxmax()]
    print("--- Min ---")
    print(closest)
    print("--- Max ---")
    print(furthest)
    print("---------")


### Selected files: for each class, the closest point followed by the furthest point

In [None]:
# Two files per class, taken from the output above: closest to, then furthest from, the class centroid.
files = [
    "two_4524", "two_1407",
    "T_467", "T_33",
    "L_375", "L_825",
    "grid_7301", "grid_3676",
    "O_68", "O_292",
    "one_4017", "one_928",
    "three_2437", "three_2463",
]

#### Raw files

In [None]:
# Raw localisations of each selected file, one row per file, drawn in black.
if not final_test:
    file_folder = "preprocessed/gt_label"
else:
    file_folder = "preprocessed/test/gt_label"

# Size the grid from `files` instead of a hard-coded 14 rows (80 in tall for 14 files);
# squeeze=False keeps `ax` 2-D so ax[row, 0] also works for a single file.
fig, ax = plt.subplots(
    len(files), 1, figsize=(20, 80 * len(files) / 14), sharex=True, sharey=True, squeeze=False
)
for row, file_name in enumerate(files):
    axis = ax[row, 0]
    file_path = os.path.join(project_directory, file_folder, file_name + ".parquet")
    df, _ = load_file(file_path, "x", "y", None, "channel")
    x = df["x"].to_numpy()
    y = df["y"].to_numpy()
    axis.set_aspect("equal", adjustable="box")
    # y is drawn horizontally and x vertically, as in the other figures in this notebook.
    axis.scatter(y, x, s=1, c="k")
    axis.axis("off")
    # Unlabelled scale bar, 0.1 long in data coordinates.
    axis.add_artist(
        AnchoredSizeBar(
            axis.transData,
            0.1,
            "",
            "lower left",
            pad=1,
            color="k",
            frameon=False,
            size_vertical=0.01,
        )
    )
output_path = os.path.join(project_directory, "output", "combined" + "_raw_s_1.svg")
plt.subplots_adjust(wspace=0, hspace=0)
# plt.savefig(output_path, transparent=True, bbox_inches="tight", pad_inches=0)

#### Clustering

In [None]:
# Localisations of each selected file coloured by cluster; unclustered ones are grey.
if not final_test:
    gt_file_folder = "preprocessed/gt_label"
    feat_file_folder = "preprocessed/featextract/locs"
else:
    gt_file_folder = "preprocessed/test/gt_label"
    feat_file_folder = "preprocessed/test/featextract/locs"

noise_color = (0.8, 0.8, 0.8)  # same grey as matplotlib's "0.8" shorthand
colors = [(0.0, 1.0, 0.0), (0.9198330167772646, 0.00019544195496590255, 0.9023663764628042), (0.022826063681157582, 0.5658432009989469, 0.9292042754527637), (1.0, 0.5, 0.0), (0.2022271667963922, 0.004776515828955663, 0.892404204324589), (0.3303283202899151, 0.4608491026134133, 0.2941030733894585), (0.5, 1.0, 0.5), (0.7723074963983451, 0.0066115490293984225, 0.15243662980903372), (0.9136952591189091, 0.5104151769385785, 0.7797496184063708), (1.0, 1.0, 0.0), (0.0, 1.0, 1.0), (0.4996633088717094, 0.7906621743682507, 0.01563627319525085)]
# Cluster k is drawn with colors[k % len(colors)]. Indexing the palette directly (instead of a
# ListedColormap with automatic min/max normalisation) prevents real clusters from falling into
# the grey noise bin when a file has more clusters than palette entries.
cluster_palette = np.array(colors)

# Grid sized from `files` (80 in tall for 14 files); squeeze=False keeps `ax` 2-D.
fig, ax = plt.subplots(
    len(files), 1, figsize=(20, 80 * len(files) / 14), sharex=True, sharey=True, squeeze=False
)
for row, file_name in enumerate(files):
    axis = ax[row, 0]
    df_gt = pl.read_parquet(os.path.join(project_directory, gt_file_folder, file_name + ".parquet"))
    df_feat = pl.read_parquet(os.path.join(project_directory, feat_file_folder, file_name + ".parquet"))
    # NOTE(review): this relies on an outer join that coalesces the key columns. Since polars
    # 0.20, how="outer" (renamed "full" in 1.0) keeps the right-hand keys as separate "*_right"
    # columns instead, which leaves x/y/channel/frame null for rows found only in df_gt and
    # breaks the channel assertion below. Check against the installed polars version.
    df = df_feat.join(df_gt, on=["x", "y", "channel", "frame"], how="outer")
    # Rows without a cluster id (e.g. localisations only in the ground-truth file) become -1.
    df = df.with_columns(pl.col("clusterID").fill_null(-1))
    assert df["channel"].unique().item() == 0  # single-channel data only
    x = df["x"].to_numpy()
    y = df["y"].to_numpy()
    cluster_ids = df["clusterID"].to_numpy().astype(np.int64)
    # Fancy indexing returns a copy, so overwriting the noise rows leaves the palette intact.
    point_colors = cluster_palette[cluster_ids % len(cluster_palette)]
    point_colors[cluster_ids == -1] = noise_color
    axis.set_aspect("equal", adjustable="box")
    axis.scatter(y, x, s=1, c=point_colors)
    axis.axis("off")
    # Unlabelled scale bar, 0.1 long in data coordinates.
    axis.add_artist(
        AnchoredSizeBar(
            axis.transData,
            0.1,
            "",
            "lower left",
            pad=1,
            color="k",
            frameon=False,
            size_vertical=0.01,
        )
    )
output_path = os.path.join(project_directory, "output", "combined" + "_clustered_s_1.svg")
plt.subplots_adjust(wspace=0, hspace=0)
# plt.savefig(output_path, transparent=True, bbox_inches="tight", pad_inches=0)

#### SubgraphX

In [None]:
# SubgraphX explainer settings.
# NOTE(review): "rollout" is 20 here, but the explanation files loaded below are named
# "*_rollout_100.pt"; confirm which setting produced them.
subgraph_config = dict(
    # number of iterations used to obtain the prediction
    rollout=20,
    # number of atoms in a leaf node of the search tree
    min_atoms=5,
    # hyperparameter that encourages exploration
    c_puct=10.0,
    # number of atoms to expand when extending child nodes in the search tree
    expand_atoms=14,
    # whether to expand child nodes from high degree to low degree when extending the search tree
    high2low=False,
    # local radius to calculate
    local_radius=4,
    # number of samples for the Monte Carlo approximation
    sample_num=100,
    # reward method
    reward_method="mc_l_shapley",
    # subgraph building method
    subgraph_building_method="split",
    # maximum number of nodes to include in the explanation subgraph
    max_nodes=8,
    # number of classes
    num_classes=7,
)

In [None]:
# Overlay each file's SubgraphX explanation graph (important nodes in green) on its
# processed localisations (grey), both centred on the localisations' mean position.
dataitems = torch.load(os.path.join(project_directory, "output/subgraphx_dataitems_rollout_100.pt"))
node_imps = torch.load(os.path.join(project_directory, "output/subgraphx_nodeimps_rollout_100.pt"))

if not final_test:
    # NOTE(review): the original read `config["fold"]`, but no `config` is defined in this
    # notebook (only config_manual / config_nn), so it raised NameError. This assumes the fold
    # number lives in featanalyse_nn.yaml; confirm.
    fold = config_nn["fold"]
    file_folder = f"processed/fold_{fold}"
else:
    file_folder = "processed"

# Two columns, as many rows as needed for `files` (40 in tall for 7 rows).
n_cols = 2
n_rows = -(-len(files) // n_cols)  # ceiling division
fig, ax = plt.subplots(
    n_rows, n_cols, figsize=(10, 40 * n_rows / 7), sharex=True, sharey=True, squeeze=False
)
for idx, file_name in enumerate(files):
    axis = ax[idx // n_cols, idx % n_cols]

    # Processed localisations, centred on their mean.
    processed_file = torch.load(find_graph_path(project_directory, file_name, file_folder))
    locs = processed_file.pos_dict["locs"].cpu().numpy()
    x_mean = np.mean(locs[:, 0])
    y_mean = np.mean(locs[:, 1])
    x = locs[:, 0] - x_mean
    y = locs[:, 1] - y_mean
    axis.set_aspect("equal", adjustable="box")
    axis.scatter(y, x, s=1, c="0.8")
    # Unlabelled scale bar, 0.1 long in data coordinates.
    axis.add_artist(
        AnchoredSizeBar(
            axis.transData,
            0.1,
            "",
            "lower right",
            pad=1,
            color="k",
            frameon=False,
            size_vertical=0.01,
        )
    )

    # Explanation graph.
    # NOTE(review): assumes dataitems[idx] / node_imps[idx] follow the same order as `files`.
    nx_g = to_networkx(dataitems[idx], to_undirected=True)
    # Materialise the self-loop edges before removing them from the graph being iterated.
    nx_g.remove_edges_from(list(nx.selfloop_edges(nx_g)))
    # Copy before centring: .cpu().numpy() can share memory with the tensor, so in-place
    # subtraction would otherwise also shift dataitems[idx].pos.
    pos = dataitems[idx].pos.cpu().numpy().copy()
    pos[:, 0] -= x_mean
    pos[:, 1] -= y_mean
    node_color = np.where(node_imps[idx].cpu().numpy(), "#00FF00", "k")
    # Flip columns so the graph uses the same (y horizontal, x vertical) orientation as the scatter.
    draw(nx_g, pos=np.flip(pos, axis=1), ax=axis, node_color=node_color, node_size=50)

# Hide grid cells left empty when len(files) is odd.
for spare_axis in ax.flat[len(files):]:
    spare_axis.axis("off")

output_path = os.path.join(project_directory, "output", "combined" + "_subgraphx_s_1.svg")
plt.subplots_adjust(wspace=0, hspace=0)
# plt.savefig(output_path, transparent=True, bbox_inches="tight", pad_inches=0)


#### Visualise clustering in processed graph

In [None]:
# Colour each localisation in the processed graph by the cluster it is linked to;
# localisations without a cluster are grey. Positions are centred on their mean.
if not final_test:
    # NOTE(review): the original read `config["fold"]`, but no `config` is defined in this
    # notebook (only config_manual / config_nn), so it raised NameError. This assumes the fold
    # number lives in featanalyse_nn.yaml; confirm.
    fold = config_nn["fold"]
    file_folder = f"processed/fold_{fold}"
else:
    file_folder = "processed"

noise_color = (0.8, 0.8, 0.8)  # same grey as matplotlib's "0.8" shorthand
colors = [(0.0, 1.0, 0.0), (0.9198330167772646, 0.00019544195496590255, 0.9023663764628042), (0.022826063681157582, 0.5658432009989469, 0.9292042754527637), (1.0, 0.5, 0.0), (0.2022271667963922, 0.004776515828955663, 0.892404204324589), (0.3303283202899151, 0.4608491026134133, 0.2941030733894585), (0.5, 1.0, 0.5), (0.7723074963983451, 0.0066115490293984225, 0.15243662980903372), (0.9136952591189091, 0.5104151769385785, 0.7797496184063708), (1.0, 1.0, 0.0), (0.0, 1.0, 1.0), (0.4996633088717094, 0.7906621743682507, 0.01563627319525085)]
# Cluster k is drawn with colors[k % len(colors)]; indexing the palette directly avoids
# clusters sharing the grey noise colour when there are more clusters than colours.
cluster_palette = np.array(colors)

# Grid sized from `files` (80 in tall for 14 files); squeeze=False keeps `ax` 2-D.
fig, ax = plt.subplots(
    len(files), 1, figsize=(20, 80 * len(files) / 14), sharex=True, sharey=True, squeeze=False
)
for row, file_name in enumerate(files):
    axis = ax[row, 0]
    processed_file = torch.load(find_graph_path(project_directory, file_name, file_folder))
    locs = processed_file.pos_dict["locs"].cpu().numpy()

    # ("locs", "in", "clusters") edges: by PyG convention row 0 holds source (loc) indices
    # and row 1 destination (cluster) indices.
    loc_idx, cluster_idx = processed_file.edge_index_dict["locs", "in", "clusters"].cpu().numpy()
    # Scatter cluster ids onto loc indices so colours line up with `locs` regardless of edge
    # order; locs with no edge keep -1 and are drawn as noise.
    cluster_ids = np.full(len(locs), -1, dtype=np.int64)
    cluster_ids[loc_idx] = cluster_idx
    # Fancy indexing returns a copy, so overwriting the noise rows leaves the palette intact.
    point_colors = cluster_palette[cluster_ids % len(cluster_palette)]
    point_colors[cluster_ids == -1] = noise_color

    x_mean = np.mean(locs[:, 0])
    y_mean = np.mean(locs[:, 1])
    x = locs[:, 0] - x_mean
    y = locs[:, 1] - y_mean
    axis.set_aspect("equal", adjustable="box")
    axis.scatter(y, x, s=1, c=point_colors)
    axis.axis("off")
    # Unlabelled scale bar, 0.1 long in data coordinates.
    axis.add_artist(
        AnchoredSizeBar(
            axis.transData,
            0.1,
            "",
            "lower right",
            pad=1,
            color="k",
            frameon=False,
            size_vertical=0.01,
        )
    )

output_path = os.path.join(project_directory, "output/processed_clusters.svg")
plt.subplots_adjust(wspace=0, hspace=0)
# plt.savefig(output_path, transparent=True, bbox_inches="tight", pad_inches=0)
