## EMGNN Reproduction Test
### This notebook demonstrates the reproduction of EMGNN model performance using pre-trained weights.

#### Library import and settings

In [1]:
import torch
import numpy as np
import pandas as pd
import os
from torch_geometric.data import Data
from torch_geometric.utils import add_self_loops
from sklearn.preprocessing import StandardScaler

from reproduction_utils import EMGNN, cal_metrics

# Run on the GPU when PyTorch reports one as available; otherwise use the CPU.
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
print(f"Using device: {device}")

# Seed every random number generator this notebook touches (NumPy, PyTorch,
# and PyTorch's CUDA generator when a GPU is available) with the same value.
seed = 42
np.random.seed(seed)
torch.manual_seed(seed)
if torch.cuda.is_available():
    torch.cuda.manual_seed(seed)

  from .autonotebook import tqdm as notebook_tqdm


Using device: cuda


#### Setting up and loading data

In [None]:
# Index of the cross-validation fold whose labels/masks are loaded below, and
# the location of the pre-trained checkpoint (relative to the working directory).
FOLD_TO_LOAD = 3
MODEL_PATH = "./Data/EMGNN.pt"


class Args:
    """Hyperparameter bundle that is handed to ``EMGNN`` through its ``args`` argument."""

    # Model-type flags: only the GCN flag is switched on.
    gcn = True
    gat = False
    gin = False
    mlp = False

    # Network size and regularisation settings.
    hidden = 64
    n_layers = 3
    dropout = 0.5
    alpha = 0.2
    nb_heads = 1


args = Args()

print(">>> Loading Global Data & Preprocessing...")

# Canonical column order of the multi-omics feature matrix: the four omics
# prefixes (MF, METH, GE, CNA) one after another, each followed by the same
# 16 cohort codes in the same order, giving 64 "<omics>: <cohort>" names.
features_order = [
    f"{omics}: {cohort}"
    for omics in ('MF', 'METH', 'GE', 'CNA')
    for cohort in ('KIRC', 'BRCA', 'READ', 'PRAD', 'STAD', 'HNSC', 'LUAD', 'THCA',
                   'BLCA', 'ESCA', 'LIHC', 'UCEC', 'COAD', 'LUSC', 'CESC', 'KIRP')
]

# Gene <-> node-index mapping: the row order of the "gene" column defines the
# node order used for the adjacency matrix below.
gene_index_file = "./Data/STRING_gene_index_mapping.tsv"
node_df = pd.read_csv(gene_index_file, sep="\t")
node_names = node_df["gene"].values
num_nodes = len(node_names)
node_to_index = {gene: idx for idx, gene in enumerate(node_names)}

# PPI edge list; keep only interactions whose two endpoints are both known genes.
ppi_file = "../../Data/STRING/STRING_ppi_edgelist.tsv"
ppi_df = pd.read_csv(ppi_file, sep="\t")
ppi_df = ppi_df[ppi_df["partner1"].isin(node_to_index) & ppi_df["partner2"].isin(node_to_index)]

# Dense symmetric adjacency matrix holding confidence / 1000 for every edge
# (presumably rescaling STRING's 0-1000 score to [0, 1] -- TODO confirm).
# The columns are zipped instead of using DataFrame.iterrows(): iterrows builds
# a Series for every row, which is very slow on a large edge list. Iterating
# row by row (rather than one fancy-index assignment) keeps the original
# "last row wins" result if the same gene pair is listed more than once.
adj = np.zeros((num_nodes, num_nodes))
for partner1, partner2, raw_confidence in zip(
    ppi_df["partner1"], ppi_df["partner2"], ppi_df["confidence"]
):
    i, j = node_to_index[partner1], node_to_index[partner2]
    confidence = raw_confidence / 1000
    adj[i, j] = confidence
    adj[j, i] = confidence

# Multi-omics feature table: first column is the index, remaining columns are features.
# NOTE(review): rows are used positionally from here on; nothing in this cell
# checks that the table's index is in the same order as node_names -- confirm.
features_file = "../../Data/STRING/multiomics_features_STRING.tsv"
features_df = pd.read_csv(features_file, sep="\t", index_col=0)
features = features_df.values
feature_names = [str(f) for f in features_df.columns.values]

# Column lookup by name (first occurrence wins, matching list.index semantics).
feature_col = {}
for col, name in enumerate(feature_names):
    feature_col.setdefault(name, col)

# Columns listed in features_order but absent from the file are skipped, as
# before, but no longer silently: a narrower feature matrix will not match the
# 64-wide meta_x buffer / pre-trained weights, so make the cause visible here.
missing_features = [f_n for f_n in features_order if f_n not in feature_col]
if missing_features:
    print(f"WARNING: {len(missing_features)} expected feature column(s) not found in "
          f"{features_file} and skipped: {missing_features}")

# Reorder the columns into the canonical features_order.
feature_ind = [feature_col[f_n] for f_n in features_order if f_n in feature_col]
features = features[:, feature_ind]

# Convert both the adjacency matrix and the feature matrix to float32 tensors.
adj = torch.FloatTensor(adj)
features = torch.FloatTensor(features)

# Give every distinct gene name a dense index, numbered in order of first appearance.
node2idx = {}
for gene in node_names:
    node2idx.setdefault(gene, len(node2idx))
counter = len(node2idx)

# Zero-filled, pre-allocated lookup tables indexed by node2idx.
# NOTE(review): meta_x has 100000 rows but meta_y has 1000000 -- the mismatch
# looks unintentional; confirm before changing either size.
meta_x = torch.zeros((100000, 64))
meta_y = torch.zeros(1000000, 1)

# One integer label per line for the selected fold.
label_file = f"../../Data/STRING/10fold/fold_{FOLD_TO_LOAD}/labels.txt"
with open(label_file, "r") as f:
    labels_all = [int(line.strip()) for line in f]
y_all = torch.tensor(labels_all, dtype=torch.long)

# Copy each node's feature row into meta_x; record its label in meta_y only
# while the stored value is still 0.
for i, label in enumerate(y_all):
    idx = node2idx[node_names[i]]
    meta_x[idx] = features[i] if isinstance(features, torch.Tensor) else torch.tensor(features[i])

    if meta_y[idx] == 0:
        meta_y[idx] = label

print("Global Data Setup Complete.")

>>> Loading Global Data & Preprocessing...
Global Data Setup Complete.


#### Load fold data

In [3]:
print(">>> Loading Data ...")

fold_path = f"../../Data/STRING/10fold/fold_{FOLD_TO_LOAD}"


def _load_fold_ints(filename):
    """Read one integer text file from the selected fold directory."""
    return np.loadtxt(os.path.join(fold_path, filename), dtype=int)


def _mask_to_indices(mask):
    """Return, as a LongTensor, the node ids in range(num_nodes) whose mask entry is truthy."""
    return torch.LongTensor([node for node in range(num_nodes) if mask[node]])


# Per-fold split masks and labels.
train_mask = _load_fold_ints("train_mask.txt")
val_mask = _load_fold_ints("valid_mask.txt")
test_mask = _load_fold_ints("test_mask.txt")
labels = _load_fold_ints("labels.txt")

idx_train = _mask_to_indices(train_mask)
idx_test = _mask_to_indices(test_mask)

# Every strictly positive adjacency entry becomes an edge; self-loops are then added.
edge_index = torch.nonzero(adj > 0).t()
edge_index = add_self_loops(edge_index)[0]

# Bundle everything into a PyG Data object and move what inference needs onto the device.
data = Data(x=features, edge_index=edge_index, y=torch.tensor(labels), node_names=node_names).to(device)
meta_x = meta_x.to(device)
idx_test = idx_test.to(device)

print(f"Data Loaded. Test Nodes: {len(idx_test)}")

>>> Loading Data ...
Data Loaded. Test Nodes: 97


#### Initialize model and load weights

In [5]:
print(">>> Initializing Model...")

# Input width is whatever the (reordered) feature matrix ended up with.
nfeat = features.shape[1]

model = EMGNN(nfeat,
              args.hidden,
              args.n_layers,
              nclass=2,
              args=args,
              data=data,
              meta_x=meta_x,
              node2idx=node2idx)

model = model.to(device)

if not os.path.exists(MODEL_PATH):
    raise FileNotFoundError(f"Checkpoint not found: {MODEL_PATH}")

# NOTE(security): torch.load unpickles the file; only load checkpoints from a
# trusted source (consider weights_only=True where the installed torch supports it).
state_dict = torch.load(MODEL_PATH, map_location=device)

# strict=False tolerates key mismatches, so inspect the returned report instead
# of discarding it: otherwise a checkpoint that matches *no* parameter would
# still be announced as loaded and the metrics would come from random weights.
load_result = model.load_state_dict(state_dict, strict=False)
if load_result.missing_keys:
    print(f"WARNING: {len(load_result.missing_keys)} model parameter(s) not found in the "
          f"checkpoint (left at their initial values): {load_result.missing_keys}")
if load_result.unexpected_keys:
    print(f"WARNING: {len(load_result.unexpected_keys)} checkpoint entr(ies) not used by the "
          f"model: {load_result.unexpected_keys}")

print(f"Successfully loaded weights from {MODEL_PATH}")

>>> Initializing Model...
Successfully loaded weights from ./Data/EMGNN.pt


#### Inference and Performance Evaluation

In [None]:
print(">>> Running Inference...")

# Evaluation mode, no gradient tracking: a single forward pass over the whole graph.
model.eval()
with torch.no_grad():
    output = model(data.x, data.edge_index, data)

    # exp() of the model output is used as class probabilities
    # (assumes the model returns log-probabilities -- TODO confirm in EMGNN).
    output_prob = torch.exp(output)

    # Column 1 (the second class) for the held-out test nodes, plus their true labels.
    y_pred_prob = output_prob[idx_test, 1].cpu().numpy()
    y_true = data.y[idx_test].cpu().numpy()

    acc, auroc, auprc, f1 = cal_metrics(y_true, y_pred_prob)

# Summary report.
banner = "=" * 40
print("\n" + banner)
print("EMGNN Reproduction Results")
print(banner)
print(f"AUROC    : {auroc:.4f}")
print(f"AUPRC    : {auprc:.4f}")
print(f"F1-Score : {f1:.4f}")
print(f"Accuracy : {acc:.4f}")
print(banner)