## MNGCL Reproduction Test
### This notebook reproduces the reported performance of the MNGCL model using pre-trained weights.


#### Setup and Configurations

In [1]:
import torch
import numpy as np
import pandas as pd
from sklearn import metrics, linear_model
from sklearn.preprocessing import StandardScaler
from sklearn.metrics import average_precision_score, f1_score
import matplotlib.pyplot as plt
import os
import sys

from reproduction_utils import MNGCL, GCN

# Device Configuration
DEVICE = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f"Using device: {DEVICE}")

# Seed every RNG this notebook touches so that re-runs are comparable.
# NOTE(review): it is not visible here whether MNGCL's forward pass is
# stochastic; seeding is a cheap safeguard either way.
SEED = 42
np.random.seed(SEED)
torch.manual_seed(SEED)  # seeds the CPU and all CUDA devices

# Paths
DATA_PATH = "../../Data/STRING/"  # multi-omics features, PPI / pathway / GO graphs, folds
DATA_PATH_2 = "./Data/"           # local artifacts: node2vec features, self-loop PPI graph
FOLD_PATH = "../../Data/STRING/10fold/fold_3"
MODEL_PATH = "./Data/MNGCL.pt"

# Check if model exists (warn early; the model cell refuses to run without it)
if not os.path.exists(MODEL_PATH):
    print(f"Warning: Model file '{MODEL_PATH}' not found. Please train the model and save the weights first.")

  from .autonotebook import tqdm as notebook_tqdm


Using device: cuda


#### Data Loading Functions

In [2]:
def load_kfold_data(fold_path, device):
    """Load train/valid/test split indices and masks for a specific fold.

    Reads seven whitespace-delimited text files from ``fold_path``:
    ``train.txt``, ``valid.txt``, ``test.txt`` (integer indices),
    ``train_mask.txt``, ``valid_mask.txt``, ``test_mask.txt`` (boolean masks)
    and ``labels.txt``.

    Args:
        fold_path: Directory containing the fold files.
        device: torch device on which the masks and labels are placed.

    Returns:
        Tuple ``(train_idx, valid_idx, test_idx, train_mask, valid_mask,
        test_mask, labels)``. The index arrays are NumPy int arrays with at
        least one dimension, the masks are ``torch.bool`` tensors and the
        labels are a ``torch.float32`` tensor.

    Raises:
        ValueError: if any mask length differs from the number of label rows.
            The masks are used to index the labels, so the lengths must agree.
    """

    def _read(name, dtype=float):
        # ndmin=1 keeps single-entry files 1-D. Plain loadtxt returns a 0-d
        # scalar for them, which breaks len(), iteration and mask indexing.
        return np.loadtxt(os.path.join(fold_path, f"{name}.txt"), dtype=dtype, ndmin=1)

    # Load Indices
    train_idx = _read("train", int)
    valid_idx = _read("valid", int)
    test_idx = _read("test", int)

    # Load Masks
    train_mask = torch.tensor(_read("train_mask", bool), device=device)
    valid_mask = torch.tensor(_read("valid_mask", bool), device=device)
    test_mask = torch.tensor(_read("test_mask", bool), device=device)

    # Load Labels
    labels = torch.tensor(_read("labels"), dtype=torch.float32, device=device)

    n_rows = labels.shape[0]
    for mask_name, mask in (("train_mask", train_mask),
                            ("valid_mask", valid_mask),
                            ("test_mask", test_mask)):
        if mask.shape[0] != n_rows:
            raise ValueError(
                f"{mask_name} has {mask.shape[0]} entries but labels has {n_rows} rows"
            )

    print("train/valid/test data load completed")

    return train_idx, valid_idx, test_idx, train_mask, valid_mask, test_mask, labels

#### Load Data (Features & Graphs)

In [3]:
print("Loading Multi-omics Features...")

# Number of leading feature columns that make up the pan-cancer block.
N_PANCANCER_FEATURES = 48

try:
    data_x_df = pd.read_csv(DATA_PATH + 'multiomics_features_STRING.tsv', sep='\t', index_col=0)
    data_x_df = data_x_df.dropna()
    # StandardScaler standardises each column independently. Slicing the first
    # N_PANCANCER_FEATURES columns after scaling therefore gives the same values
    # as scaling only those columns.
    scaler = StandardScaler()
    features_scaled = scaler.fit_transform(data_x_df.values)
    data_x = torch.tensor(features_scaled, dtype=torch.float32, device=DEVICE)

    # Use Pan-Cancer features (first N_PANCANCER_FEATURES dimensions)
    data_x = data_x[:, :N_PANCANCER_FEATURES]
    print(f"Feature shape: {data_x.shape}")

    try:
        # NOTE(review): torch.load unpickles the file; only load trusted artifacts.
        dataz = torch.load(DATA_PATH_2 + "Str_feature_STRING.pkl", map_location=DEVICE)
        data_x = torch.cat((data_x, dataz), 1)
        print(f"Features with node2vec: {data_x.shape}")
    except FileNotFoundError:
        # Optional input. Without it the feature width changes (64 -> 48 in the
        # saved run), so a checkpoint trained on the concatenated features will
        # presumably fail to load in the model cell.
        print("Node2vec features not found, proceeding without them.")

except Exception as e:
    print(f"Error loading features: {e}")
    # Re-raise: swallowing the error leaves `data_x` undefined, and later cells
    # would fail with an unrelated NameError.
    raise

print("\nLoading Graph Adjacency Matrices...")
try:
    # NOTE(review): torch.load unpickles these files; only load trusted artifacts.
    ppiAdj = torch.load(DATA_PATH + 'STRING_ppi.pkl', map_location=DEVICE)
    ppiAdj_self = torch.load(DATA_PATH_2 + 'STRING_ppi_selfloop.pkl', map_location=DEVICE)
    pathAdj = torch.load(DATA_PATH + 'pathway_SimMatrix_filtered.pkl', map_location=DEVICE)
    goAdj = torch.load(DATA_PATH + 'GO_SimMatrix_filtered.pkl', map_location=DEVICE)

    # Dense self-loop PPI adjacency; passed to MNGCL as `pos`.
    pos = ppiAdj_self.to_dense().to(DEVICE)

    print("Graphs loaded successfully.")
except Exception as e:
    print(f"Error loading graphs: {e}")
    # Re-raise for the same reason as above: later cells need these graphs.
    raise

print("\nLoading ...")
train_idx, valid_idx, test_idx, train_mask, valid_mask, test_mask, Y = load_kfold_data(FOLD_PATH, DEVICE)

# The evaluation cell indexes the model embeddings (presumably one row per
# feature row) and Y with the same masks, so these counts must agree. dropna()
# above could otherwise shrink the feature table silently.
assert data_x.shape[0] == Y.shape[0] == train_mask.shape[0] == test_mask.shape[0], (
    f"Node count mismatch: features={data_x.shape[0]}, labels={Y.shape[0]}, "
    f"train_mask={train_mask.shape[0]}, test_mask={test_mask.shape[0]}"
)

Loading Multi-omics Features...
Feature shape: torch.Size([10251, 48])
Features with node2vec: torch.Size([10251, 64])

Loading Graph Adjacency Matrices...
Graphs loaded successfully.

Loading ...
train/valid/test data load completed


#### Initialize Model & Load Weights

In [6]:
# Configuration (training config).
# These must match the values the checkpoint was trained with; otherwise
# load_state_dict raises on mismatched parameter shapes.
INPUT_DIM = data_x.shape[1]
HIDDEN_DIM = 300
OUTPUT_DIM = 100
PROJ_HIDDEN_DIM = 300  # MNGCL projection head hidden size
PROJ_DIM = 100         # MNGCL projection head output size
TAU = 0.5

# Initialize GCN
gcn = GCN(INPUT_DIM, HIDDEN_DIM, OUTPUT_DIM).to(DEVICE)

# Initialize MNGCL
model = MNGCL(
    gnn=gcn,
    pos=pos,
    tau=TAU,
    gnn_outsize=OUTPUT_DIM,
    projection_hidden_size=PROJ_HIDDEN_DIM,
    projection_size=PROJ_DIM
).to(DEVICE)

# Load Pre-trained Weights.
# Fail loudly: continuing with randomly initialised weights would let the next
# cell print "reproduction" metrics that mean nothing.
if not os.path.exists(MODEL_PATH):
    raise FileNotFoundError(
        f"Model weights not found at '{MODEL_PATH}'. Cannot proceed with inference."
    )
# NOTE(review): torch.load unpickles the checkpoint; only load trusted files.
model.load_state_dict(torch.load(MODEL_PATH, map_location=DEVICE))
print("Pre-trained model weights loaded successfully.")

Pre-trained model weights loaded successfully.


#### Inference & Evaluation

In [5]:
def LogReg(train_x, train_y, test_x):
    """Fit a logistic-regression head on training embeddings and score test embeddings.

    Args:
        train_x: Training feature matrix passed to ``LogisticRegression.fit``.
        train_y: Training labels; flattened with ``ravel()`` before fitting.
        test_x: Feature matrix to score.

    Returns:
        Column 1 of ``predict_proba(test_x)``, i.e. the probability of the
        second entry of the fitted classifier's ``classes_``.
        NOTE(review): assumes binary 0/1 labels so that column is the positive
        class; confirm.
    """
    classifier = linear_model.LogisticRegression(max_iter=10000)
    classifier.fit(train_x, train_y.ravel())
    return classifier.predict_proba(test_x)[:, 1]

# Probability cut-off for turning LogReg scores into hard labels for F1.
F1_THRESHOLD = 0.5

# NOTE(review): in the saved notebook this cell's execution count (In [5]) is
# lower than the model-initialisation cell's (In [6]). The recorded metrics may
# therefore come from an out-of-order run. Restart the kernel and run all cells
# to confirm them.

# If a fold's train and test masks overlapped, labels would leak into the
# classifier and inflate every metric below.
assert not (train_mask & test_mask).any(), "train_mask and test_mask overlap"

print("Running Inference...")
model.eval()
with torch.no_grad():
    # Prepare sparse indices for ChebConv
    ppiAdj_index = ppiAdj.coalesce().indices().to(DEVICE)
    pathAdj_index = pathAdj.coalesce().indices().to(DEVICE)
    goAdj_index = goAdj.coalesce().indices().to(DEVICE)

    # Forward Pass
    # MNGCL forward returns: emb1, emb2, emb3, concatenated_emb, loss
    _, _, _, emb, _ = model(ppiAdj_index, pathAdj_index, goAdj_index,
                            data_x, data_x, data_x)

# `emb` was produced under no_grad, so it carries no autograd graph. No
# .detach() is needed, and the NumPy / scikit-learn work below can run outside
# the no_grad context.
train_x_emb = torch.sigmoid(emb[train_mask]).cpu().numpy()
train_y_label = Y[train_mask].cpu().numpy().ravel()

test_x_emb = torch.sigmoid(emb[test_mask]).cpu().numpy()
# Flatten so the sklearn metrics always receive a 1-D label vector.
test_y_label = Y[test_mask].cpu().numpy().ravel()

# Final Prediction using Logistic Regression
print("Training Logistic Regression on learned embeddings...")
preds = LogReg(train_x_emb, train_y_label, test_x_emb)

# Calculate Metrics
auc = metrics.roc_auc_score(test_y_label, preds)
auprc = average_precision_score(test_y_label, preds)
f1 = f1_score(test_y_label, (preds > F1_THRESHOLD).astype(int))

print("\n" + "="*40)
print("MNGCL Reproduction Results")
print("="*40)
print(f"AUROC    : {auc:.4f}")
print(f"AUPRC    : {auprc:.4f}")
print(f"F1-Score : {f1:.4f}")
print("="*40)

Running Inference...
Training Logistic Regression on learned embeddings...

MNGCL Reproduction Results
AUROC    : 0.8757
AUPRC    : 0.7809
F1-Score : 0.7568
