# GNN model for modelling EMT

Note: the code that produces the figures is in `figures_paper_gnn.ipynb`.

Environment: use `base_env_gnn.yml`.

In [None]:
import os
import numpy as np
import pandas as pd
import networkx as nx
import torch
from tqdm import tqdm

from src.config import *
from src.data_processing import load_and_process_data, assign_emt_labels
from src.graph_utils import build_graph, assign_one_hot_celltype, assign_spatial_blocks, create_edge_index
from src.dataset import XeniumSpatialPyg
from src.models import Net
from src.train import cross_validation_training
from src.evaluation import evaluate_classification,evaluate_regression  # or evaluate_regression if continuous
from src.explanations import generate_node_explanations, generate_edge_explanations

# 1. Load and process data
xenium_labels, pca_columns = load_and_process_data()
# Attach EMT labels. four_states=True is the default four-state analysis;
# pass False to analyse the two-state model instead.
xenium_labels, mapping = assign_emt_labels(
    xenium_labels,
    four_states=True,
)

# 2. Build the graph
# The graph is built from the labelled table with neighbour_no=10.
G = build_graph(
    xenium_labels,
    neighbour_no=10,
)
# `lb` is kept because its `classes_` are passed to the node explanations
# further down.
lb = assign_one_hot_celltype(
    G,
    attribute='celltype_minor',
)
# Return value is not used here -- presumably this annotates G in place
# (TODO confirm against src.graph_utils).
assign_spatial_blocks(
    G,
    n_blocks=4,
)
# Edge index plus the matching edge weights (edge_lengths=True).
edge_index, edge_weights = create_edge_index(
    G,
    edge_lengths=True,
)

# 3. Create cross-validation datasets
num_folds = 2

# Options shared by every fold; only fold_idx differs between the datasets.
dataset_options = dict(
    edge_index=edge_index,
    edge_lengths=edge_weights,
    inductive_split=False,
    test_cnv_prediction=True,
    continuous_y=True,
    tme_and_cnv=True,
    num_folds=num_folds,
)

fold_datasets = []
# Fold indices are passed 1-based: 1 .. num_folds inclusive.
for fold_idx in range(1, num_folds + 1):
    dataset = XeniumSpatialPyg(G, xenium_labels, fold_idx=fold_idx, **dataset_options)
    # Only element 0 of each fold's dataset is kept for training.
    fold_datasets.append(dataset[0])

# 4. Train models via cross-validation
# The three results are indexed per fold below (e.g. all_predicted[0],
# models[0] belong to the first fold dataset).
all_predicted, all_true, models = cross_validation_training(
    fold_datasets,
    learning_rate=0.01,
    num_epochs=1000,
)

# 5. Evaluate classification performance
# NOTE(review): the fold datasets are built with continuous_y=True -- confirm
# that classification (rather than regression) evaluation is intended here.
# The class count is read from the second dimension of the first fold's predictions.
n_classes = all_predicted[0].shape[1]
df_auc = evaluate_classification(all_predicted, all_true, n_classes)
# DataFrame.to_csv does not create missing parent directories, so make sure
# results/ exists before writing (no-op when it is already there).
os.makedirs('results', exist_ok=True)
df_auc.to_csv('results/auc_scores_per_fold_and_class.csv', index=False)

# Regression alternative: uncomment to score with evaluate_regression instead.
#df_rmse = evaluate_regression(all_predicted, all_true)
#df_rmse.to_csv('results/rmse_scores_per_fold.csv', index=False)

# 6. Generate explanations for one fold
# Everything below uses the first fold: its dataset, predictions and model.
data, predicted, model = fold_datasets[0], all_predicted[0], models[0]

# Number of explanations: set this >100 for most accurate node explanations;
# 300 was used for the results in the paper.
no_explan = 300

# Positions of the entries flagged in the test mask.
mask = data['test_mask']
indices = [position for position, flagged in enumerate(mask) if flagged]

# Predicted label = index of the maximum along dim 1, moved to the CPU.
predict_labels = torch.max(torch.tensor(predicted), dim=1).indices.cpu()

# Pair each test-mask position with its predicted label.
test_labels_index = pd.DataFrame({'Indices': indices, 'Label': predict_labels})

# Generate node explanations; the node_pvals table is saved to
# node_explanations.csv in the current working directory.
node_expls, node_pvals = generate_node_explanations(
    model,
    data,
    test_labels_index,
    no_explan=no_explan,
    mask_type="node",
    label_list=lb.classes_,
)
# Plain string literal: the file name has no placeholders, so no f-string is needed.
node_pvals.to_csv("node_explanations.csv", index=True)

# Generate edge explanations: takes a while to run
# Make sure the output directory exists before handing it to the explainer
# (no-op when it is already there).
os.makedirs("results", exist_ok=True)
edge_results = generate_edge_explanations(
    model,
    data,
    G,
    no_explan=no_explan,
    cell_type="xenium",
    output_dir="results/",
)


Plot results

View the results and the downstream analysis in `figures_paper_gnn.ipynb`.