## MODIG Reproduction Test
### This notebook demonstrates how to reproduce the MODIG model's performance using pre-trained weights.

#### Library import and settings

In [None]:
import torch
import numpy as np
import pandas as pd
import os
from sklearn import metrics
from reproduction_utils import MODIG, ModigGraph, cal_metrics, load_fold_data

# Seed the PyTorch and NumPy random number generators with one fixed value.
seed = 42
torch.manual_seed(seed)
np.random.seed(seed)

# Run on the GPU when PyTorch reports one as available, otherwise on the CPU.
cuda_available = torch.cuda.is_available()
device = torch.device('cuda' if cuda_available else 'cpu')
print(f"Using device: {device}")

if cuda_available:
    torch.cuda.manual_seed(seed)

Using device: cuda


#### Setting up and loading data

In [2]:
# --- Run configuration ---
# PPI_TYPE names the graph sub-directory and prefixes every graph file name.
PPI_TYPE = 'STRING_ppi_edgelist'
# CANCER_TYPE selects the label file and is handed to ModigGraph.
CANCER_TYPE = 'pancan'
# Fold number passed to load_fold_data.
FOLD_TO_LOAD = 9
# Checkpoint holding the pre-trained weights.
MODEL_PATH = "./Data/MODIG.pt"

# Thresholds for the GO / EXP / SEQ / PATH networks; each value is embedded
# in the corresponding graph file name and passed to generate_graph.
THR_GO, THR_EXP, THR_SEQ = 0.8, 0.8, 0.8
THR_PATH = 0.6

graph_path = os.path.join('./Data/graph', PPI_TYPE)

print(">>> Loading Graph Data...")

# Needed on both paths: it generates graphs when the cache is missing and
# converts each network into a featured graph afterwards.
modig_input = ModigGraph(graph_path, PPI_TYPE, CANCER_TYPE)


def _graph_file(suffix):
    """Return the path of `<PPI_TYPE>_<suffix>.tsv` inside `graph_path`."""
    return os.path.join(graph_path, f"{PPI_TYPE}_{suffix}.tsv")


def _read_tsv(path):
    """Read a tab-separated table, using its first column as the index."""
    return pd.read_csv(path, sep='\t', index_col=0)


ppi_path = _graph_file('ppi')
go_path = _graph_file(f"{THR_GO}_go")
exp_path = _graph_file(f"{THR_EXP}_exp")
seq_path = _graph_file(f"{THR_SEQ}_seq")
path_path = _graph_file(f"{THR_PATH}_path")
omic_path = _graph_file('omics')

# The cached files are used only if all six of them are present.
network_paths = (ppi_path, go_path, exp_path, seq_path, path_path)
files_exist = all(os.path.exists(p) for p in network_paths + (omic_path,))

if not files_exist:
    print("Graph files not found. Generating from raw data (this may take time)...")

    omicsfeature, final_gene_node = modig_input.get_node_omicfeature()
    if not os.path.exists(graph_path):
        os.makedirs(graph_path)
    # Persist the omics features, indexed by gene, to omic_path.
    omicsfeature_df = pd.DataFrame(omicsfeature.numpy(), index=final_gene_node)
    omicsfeature_df.to_csv(omic_path, sep='\t')

    ppi_network, go_network, exp_network, seq_network, path_network = modig_input.generate_graph(
        THR_GO, THR_EXP, THR_SEQ, THR_PATH)
    print("Graph generation complete.")
else:
    print("Found pre-generated graph files. Loading...")
    ppi_network, go_network, exp_network, seq_network, path_network = map(_read_tsv, network_paths)
    omicsfeature = _read_tsv(omic_path)
    # Node order is taken from the row index of the omics table.
    final_gene_node = list(omicsfeature.index)
    

name_of_network = ['PPI', 'GO', 'EXP', 'SEQ', 'PATH']
print("Converting to PyG Data objects...")

# omicsfeature is a DataFrame when read from disk and a tensor when freshly
# generated; build a tensor form of it for either case.
# NOTE(review): omics_tensor is not consumed below -- load_featured_graph is
# given the raw omicsfeature. Confirm which input type the helper expects
# before switching the argument.
if isinstance(omicsfeature, torch.Tensor):
    omics_tensor = omicsfeature
else:
    omics_tensor = torch.FloatTensor(omicsfeature.values)

# One featured graph per network, moved to the compute device, in the same
# order as name_of_network.
networks = [ppi_network, go_network, exp_network, seq_network, path_network]
graphlist = [
    modig_input.load_featured_graph(net, omicsfeature).to(device)
    for net in networks
]

# Feature width is read from the first graph's node-feature matrix.
n_fdim = graphlist[0].x.shape[1]
print(f"Feature Dimension: {n_fdim}")
print("Graph Data Loaded Successfully.")

>>> Loading Graph Data...
Found pre-generated graph files. Loading...
Converting to PyG Data objects...
Feature Dimension: 64
Graph Data Loaded Successfully.


#### Load fold data

In [5]:
print(">>> Loading Labels and Masks ...")

# Two-column label table; the file's own header row is replaced by these names.
label_file_path = f'./Data/label/{CANCER_TYPE}_genelist_for_train_new.tsv'
label_file = pd.read_csv(label_file_path, sep='\t', names=['Hugosymbol', 'Label'], header=0)

# Left-join the labels onto the sorted gene list; genes without a label get NaN.
# NOTE(review): row positions here are later used to index the model output,
# which assumes the graph nodes follow the same sorted gene order and that the
# label file has no duplicate symbols -- confirm against load_featured_graph.
genes_match = pd.merge(pd.Series(sorted(final_gene_node), name='Hugosymbol'), label_file, on='Hugosymbol', how='left')

gene_to_index = {gene: idx for idx, gene in genes_match["Hugosymbol"].items()}

# Gene names, one per line; the fold indices below refer to positions in this list.
feature_genename_file = './feature_genename.txt'
geneList = pd.read_csv(feature_genename_file, header=None).iloc[:, 0].tolist()

train_data_idx, test_data_idx = load_fold_data(FOLD_TO_LOAD)

# Resolve the test genes once and derive the node indices from that same list,
# so test_genes_list and test_mask are position-aligned by construction.
# Genes that are absent from the graph are skipped.
test_genes_list = [gene for gene in (geneList[i] for i in test_data_idx) if gene in gene_to_index]
test_mask_indices = [gene_to_index[gene] for gene in test_genes_list]
test_mask = torch.tensor(test_mask_indices, dtype=torch.long).to(device)

# Unlabeled genes are encoded as -1 so they can be filtered out at evaluation time.
labels_tensor = torch.tensor(genes_match["Label"].fillna(-1).values, dtype=torch.float).to(device)
test_label = labels_tensor[test_mask]

print(f"Test Size: {len(test_mask)}")

>>> Loading Labels and Masks ...
Test Size: 97


#### Initialize model and load weights

In [6]:
print(">>> Initializing Model...")

# Model hyper-parameters.
# NOTE(review): presumably these must match the architecture the checkpoint
# was trained with -- confirm before changing them.
HS1, HS2 = 300, 100
DROPOUT = 0.25

model = MODIG(nfeat=n_fdim, hidden_size1=HS1, hidden_size2=HS2, dropout=DROPOUT).to(device)

# Fail fast with an explicit error when the checkpoint is missing.
if not os.path.exists(MODEL_PATH):
    raise FileNotFoundError(f"Checkpoint not found: {MODEL_PATH}")

# NOTE(review): torch.load unpickles the file; only load checkpoints that come
# from a trusted source.
model.load_state_dict(torch.load(MODEL_PATH, map_location=device))
print(f"Successfully loaded weights from {MODEL_PATH}")

>>> Initializing Model...
Successfully loaded weights from ./Data/MODIG.pt


#### Inference and Performance Evaluation

In [7]:
print(">>> Running Inference...")

model.eval()
# Only the forward pass needs gradient tracking switched off.
with torch.no_grad():
    output = model(graphlist)
    pred_logits = output[test_mask].squeeze()
    # squeeze() collapses a single selected row to a 0-d tensor; atleast_1d
    # keeps the probabilities indexable by the boolean mask built below.
    pred_prob = np.atleast_1d(torch.sigmoid(pred_logits).cpu().numpy())

# Drop test genes that have no label (encoded as -1 when the labels were built).
label_cpu = test_label.cpu().numpy()
valid_indices = label_cpu != -1

final_labels = label_cpu[valid_indices]
final_probs = pred_prob[valid_indices]
# Gene names aligned with final_labels / final_probs, kept for inspection.
final_genes = np.array(test_genes_list)[valid_indices]

# cal_metrics returns (accuracy, AUROC, AUPRC, F1) in that order.
acc, auroc, auprc, f1 = cal_metrics(final_probs, final_labels)

print("\n" + "="*40)
print("MODIG Reproduction Results")
print("="*40)
print(f"AUROC    : {auroc:.4f}")
print(f"AUPRC    : {auprc:.4f}")
print(f"F1-Score : {f1:.4f}")
print(f"Accuracy : {acc:.4f}")
print("="*40)

>>> Running Inference...

MODIG Reproduction Results
AUROC    : 0.8677
AUPRC    : 0.8000
F1-Score : 0.6914
Accuracy : 0.7423
