# GNN model for MERFISH data

Note: the figure code is in figures_paper_gnn.ipynb.

env: use base_env_gnn.yml

In [None]:

import os
import random
import numpy as np
import pandas as pd
import networkx as nx
import torch
from tqdm import tqdm
from src.config import *  
from src.dataset import MouseSpatialPyg
from src.models import Net
from src.train import train_and_evaluate_model, binary_auc
from src.explanations import *
# Load the MERFISH table; the first CSV column is used as the row index,
# and those index values later serve as graph node ids.
merfish_data = pd.read_csv("data/merfish_dataset.csv", index_col=0)

# Preprocessing: drop cell types that have too few cells.
min_cells_per_type = 1000
cell_types = merfish_data['Cell_Type'].unique()
unique, counts = np.unique(merfish_data['Cell_Type'], return_counts=True)
print(dict(zip(unique, counts)))
# Select the *names* of the rare cell types. Indexing `counts` here would
# yield the count values themselves, so the membership test below would
# compare names against integers and never filter anything out.
cell_types_to_remove = set(unique[counts < min_cells_per_type])
cell_types = [x for x in cell_types if x not in cell_types_to_remove]

# Parameters
# Number of node explanations to generate; set this >100 for the most
# accurate node explanations.
no_explan = 150
# Number of training epochs per model.
epoch = 1000
# One record (cell type, split number, AUC) is appended per trained model.
auc_records = []
# Number of seeded train/evaluate repetitions per cell type.
num_splits = 1
just_explan = True

# Imports needed only by this loop, hoisted so they are not re-executed on
# every iteration.
import gc

from sklearn.neighbors import NearestNeighbors
from sklearn.preprocessing import LabelBinarizer

# Loop-invariant setup: neighbourhood size requested from the k-NN query
# (the query point itself is one of the returned neighbours) and the
# per-sample grouping of the cells.
neighbour_no = 20
grouped = merfish_data.groupby('Sample')

for cell_type in tqdm(cell_types):
    print(f"Processing cell type: {cell_type}")
    # The graph is rebuilt for every cell type because the nodes of the
    # target type are removed from it further down.
    G = nx.Graph()

    # Build a spatial k-NN graph for each sample; edges never cross samples.
    for sample_name, group in grouped:
        coords = group[['x', 'y']].values
        # A k-NN query cannot ask for more neighbours than there are cells,
        # so clamp k for very small samples instead of raising.
        k = min(neighbour_no, len(coords))
        nbrs = NearestNeighbors(n_neighbors=k, algorithm='ball_tree').fit(coords)
        distances, indices = nbrs.kneighbors(coords)
        for i, node_id in enumerate(group.index):
            G.add_node(node_id, **group.iloc[i].to_dict())
            for neighbor_idx, distance in zip(indices[i], distances[i]):
                # Skip the query cell itself. It is normally the first hit,
                # but with duplicated coordinates it can appear at another
                # position, which would otherwise create a self-loop.
                if neighbor_idx == i:
                    continue
                G.add_edge(node_id, group.index[neighbor_idx], weight=distance)

    # Binary target: 1 when the nearest graph neighbour of a node has the
    # current cell type, 0 otherwise.
    nx.set_node_attributes(G, 0, 'Closest_cell_type_label')
    for node in G.nodes():
        adjacent = G[node]
        if adjacent:
            closest_neighbor = min(adjacent, key=lambda nb: adjacent[nb]['weight'])
            if G.nodes[closest_neighbor]['Cell_Type'] == cell_type:
                G.nodes[node]['Closest_cell_type_label'] = 1

    # One-hot encode cell types. Each class is encoded once and looked up
    # per node instead of calling transform() for every node.
    lb = LabelBinarizer()
    unique_labels = list(set(nx.get_node_attributes(G, 'Cell_Type').values()))
    lb.fit(unique_labels)
    one_hot_by_label = dict(zip(lb.classes_, lb.transform(lb.classes_)))
    for node, attrs in G.nodes(data=True):
        # Copy so that every node owns its own array, as before.
        attrs['one_hot_Cell_Type'] = one_hot_by_label[attrs['Cell_Type']].copy()

    # Remove the nodes of the target cell type (done after labelling, so the
    # labels above were computed while those nodes were still in the graph).
    G.remove_nodes_from([node for node, data in G.nodes(data=True) if data['Cell_Type'] == cell_type])

    # For each split
    for split_num in range(num_splits):
        print(f"Split {split_num + 1} for cell type {cell_type}")
        # Seed every RNG in scope so that a split is reproducible.
        random.seed(split_num)
        torch.manual_seed(split_num)
        np.random.seed(split_num)
        dataset = MouseSpatialPyg(G, merfish_data, edge_lengths=True, inductive_split=True)
        data = dataset[0]
        device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
        model = Net(data, num_features=int(data.num_features), hidden_dim1=16, hidden_dim2=32, dropout_rate=0.5).to(device)
        best_performance, best_model, predicted_classes_best_model = train_and_evaluate_model(
            model, data, learning_rate=0.001, num_epochs=epoch, weights=True, return_lowest_loss=True
        )

        ground_truth = data.y[data.test_mask].cpu().numpy()
        cell_type_safe = cell_type.replace("/", "_").replace(" ", "_")
        # presumably the model emits log-probabilities, hence exp() -- confirm against Net
        probs = torch.exp(predicted_classes_best_model)
        roc_auc = binary_auc(probs.cpu().detach().numpy(), ground_truth, data, file_save_name=f"results/{cell_type_safe}_auc.pdf")
        print(f"AUC for split {split_num + 1}: {roc_auc}")
        auc_records.append({'Cell_Type': cell_type, 'Split_Num': split_num + 1, 'AUC': roc_auc})

        mask = data['test_mask']
        # Positions of the True entries in the test mask.
        test_indices = torch.nonzero(torch.as_tensor(mask)).flatten().tolist()
        # Predicted class per test node = argmax over the class dimension.
        # detach() avoids re-wrapping an existing tensor with torch.tensor().
        predict_labels = predicted_classes_best_model.detach().argmax(dim=1).cpu()
        test_labels_index = pd.DataFrame({
            'Indices': test_indices,
            'Label':  predict_labels
        })

        # NOTE(review): explanations are generated for `model`, while the
        # predictions above come from `best_model` -- confirm these are the
        # same weights after training.
        score_df_intgrad_aggregated, p_values_df = generate_node_explanations(model, data, test_labels_index, no_explan, mask_type="node", label_list=lb.classes_)
        # Save the node explanation p-values.
        safe_cell_type = str(cell_type).replace("/", "_")  # Replace slashes with underscores
        # first row = baseline, bottom row = node of interest
        p_values_df.to_csv(f"node_{safe_cell_type}_explanations.csv", index=True)

        # Optional edge explanations (disabled):
        #edge_explan=generate_edge_explanations(model, data, G, no_explan,cell_type=cell_type)
        # Clean up to free memory before the next model is trained.
        del model, data
        gc.collect()
        torch.cuda.empty_cache()

# Save the aggregated AUC records (one row per cell type and split).
auc_df = pd.DataFrame(auc_records)
# Make sure the output directory exists so the write cannot fail on a
# fresh checkout.
os.makedirs("results", exist_ok=True)
auc_df.to_csv("results/auc_scores.csv", index=False)


Downstream analysis of these results (e.g. plotting the node and edge results) is in figures_paper_gnn.ipynb.