## DISFusion Reproduction Test
### This notebook demonstrates the reproduction of DISFusion model performance using pre-trained weights.

#### Importing libraries and functions

In [2]:
import torch
import numpy as np
import pandas as pd
import random
import os
from sklearn.preprocessing import StandardScaler

from reproduction_utils import (
    graph_ChebNet, hypergrph_HGNN, DISFusion,
    processingIncidenceMatrix, getData, _generate_G_from_H_weight, 
    load_fold_data, get_train_test_indices_from_gene_names, cal_metrics
)

# Compute device: use CUDA when PyTorch reports it as available, else the CPU.
if torch.cuda.is_available():
    device = torch.device('cuda')
else:
    device = torch.device('cpu')
print(f"Using device: {device}")

# Reproducibility: seed Python's `random`, NumPy and PyTorch with one value.
# Later cells draw from the global `random` generator, so this cell has to be
# re-run before them to get the same draws again.
seed = 42
for _seed_fn in (random.seed, np.random.seed, torch.manual_seed):
    _seed_fn(seed)
if torch.cuda.is_available():
    # Seed the current CUDA device, then every visible CUDA device.
    torch.cuda.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)

  from .autonotebook import tqdm as notebook_tqdm


Using device: cuda


#### Preparing primary data

In [3]:
class Args:
    """Plain namespace holding the hyperparameters used by this notebook."""

    # Number of leading feature columns kept from the scaled omics matrix.
    in_channels = 48
    # Hidden width handed to the hypergraph, graph and fusion models.
    n_hid = 256

    # Dropout rate and `lambdinter` value passed to the DISFusion constructor.
    dropout = 0.5
    lambdinter = 1e-4

    # Optimisation settings; not read by any cell visible in this notebook.
    lr = 1e-5
    epochs = 200


args = Args()

# Paths and run configuration (all relative to the notebook's working directory)
dataPath = '../../Data/STRING'
FOLD_TO_LOAD = 8  # number N of the `fold_N` directory that is loaded later
MODEL_PATH = "./Data/DISFusion.pt"

# Positive / negative gene lists handed to getData()
positiveGenePath = dataPath + '/dataset/pan-cancer/715true.txt'
negativeGenePath = dataPath + '/dataset/pan-cancer/1231false.txt'

# Gene-name file and PPI edge list for the STRING network
feature_genename_file = dataPath + '/feature_genename.txt'
ppi_edgelist_file = dataPath + '/STRING_ppi_edgelist.tsv'

print(">>> Loading Global Data...")

# Gene list: the file has no header row; the second column is kept as a list.
geneList = list(
    pd.read_csv(r'./Data/geneList.txt', header=None).iloc[:, 1].values
)

# Incidence matrix of the function hypergraph, built by the project helper
# from the gene list and the data directory.
print("   Generating Incidence Matrix...")
Function_hypergraph = processingIncidenceMatrix(geneList, dataPath)

# PPI network: build the edge-index tensor consumed by the graph model.
print("   Loading PPI Network...")
# Same file as `feature_genename_file`; `ppi_edgelist_file` is already defined
# with the other paths at the top of this cell.
gene_name_file = feature_genename_file

# Map each gene name to its 0-based line position in the gene-name file.
# The `with` block closes the file handle, which a bare `open(...)` inside the
# comprehension never did.
with open(gene_name_file) as gene_name_handle:
    gene_to_idx = {gene.strip(): i for i, gene in enumerate(gene_name_handle)}

# Translate both endpoint columns from gene names to indices. Genes absent from
# the mapping become NaN, and edges with such an endpoint are dropped.
ppi_df = pd.read_csv(ppi_edgelist_file, sep='\t', usecols=['partner1', 'partner2'])
ppi_df = ppi_df.assign(
    partner1=ppi_df['partner1'].map(gene_to_idx),
    partner2=ppi_df['partner2'].map(gene_to_idx),
).dropna()

# Transpose to shape (2, num_edges): one row per endpoint column.
edge = ppi_df.astype(int).values.transpose()
PPI_graph = torch.from_numpy(edge).long().to(device)

# Multi-omics features: read, drop incomplete rows, standardise, move to device.
print("   Loading Omics Features...")
data_x_df = pd.read_csv(
    f'{dataPath}/multiomics_features_STRING.tsv', sep='\t', index_col=0
).dropna()
# NOTE(review): if dropna() removes any row, row positions here presumably no
# longer line up with the indices in `gene_to_idx` / `PPI_graph` — confirm the
# feature table has no missing values.

# Standardise each feature column with StandardScaler's default settings.
scaler = StandardScaler()
features_scaled = scaler.fit_transform(data_x_df.values)

# Keep only the first `args.in_channels` columns as model input.
data_x = torch.tensor(features_scaled, dtype=torch.float32)
multi_feature = data_x[:, :args.in_channels].to(device)

# Label frame: gene names come from the first column of the headerless file.
filtered_geneList = pd.read_csv(gene_name_file, header=None)[0].tolist()

# Containers for per-fold metrics and per-gene results (not filled by the
# cells visible in this notebook).
auroc_list = []
auprc_list = []
f1_list = []
evaluation_res = pd.DataFrame(index=filtered_geneList)

# Directories fold_1 ... fold_10 under the 10-fold split root.
fold_dir = f'{dataPath}/10fold'
fold_paths = [os.path.join(fold_dir, f'fold_{fold_no}') for fold_no in range(1, 11)]

# Project helper: returns sample indices, labels and a label frame built from
# the positive / negative gene files and the gene list.
sampleIndex, label, labelFrame = getData(
    positiveGenePath, negativeGenePath, filtered_geneList
)

print("Global Data Loaded.")

>>> Loading Global Data...
   Generating Incidence Matrix...
Original geneList size: 13627 → Filtered size: 10251
Final incidenceMatrix shape: (10251, 20528)
   Loading PPI Network...
   Loading Omics Features...
Global Data Loaded.


#### Data Load and Graph Reconfiguration

In [6]:
print(">>> Reconstructing Graph")

# Load the stored split for the configured fold. Only train_idx and test_idx
# are used by the cells visible in this notebook.
fold_path = dataPath + '/10fold/fold_' + str(FOLD_TO_LOAD)
(train_idx, valid_idx, test_idx,
 train_mask, valid_mask, test_mask, labels) = load_fold_data(fold_path)

# Project helper that derives train / test index collections from the fold
# indices, the gene-name file, the sample indices and the label frame.
trainIndex, testIndex = get_train_test_indices_from_gene_names(
    feature_genename_file, train_idx, test_idx, sampleIndex, labelFrame
)
print(f"Train Size: {len(trainIndex)}, Test Size: {len(testIndex)}")

# Training genes labelled positive (== 1): every other row becomes NaN in
# `where` and is removed by `dropna`.
trainFrame = labelFrame.iloc[trainIndex]
trainPositiveGene = list(trainFrame.where(trainFrame == 1).dropna().index)
# Per-hyperedge (column) sum of the incidence entries of those positive genes.
positiveMatrixSum = Function_hypergraph.loc[trainPositiveGene].sum()

# Disease-specific hyperedge selection: keep hyperedges whose positive-gene sum
# reaches the threshold.
MIN_POSITIVE_PER_HYPEREDGE = 3
selHyperedgeIndex = np.flatnonzero(
    positiveMatrixSum.to_numpy() >= MIN_POSITIVE_PER_HYPEREDGE
)
if selHyperedgeIndex.size == 0:
    # Without this check the failure surfaced later as an opaque
    # "empty range" ValueError from random.randint.
    raise ValueError(
        "No hyperedge contains at least "
        f"{MIN_POSITIVE_PER_HYPEREDGE} training-positive genes; "
        "cannot build the hypergraph."
    )
selHyperedge = Function_hypergraph.iloc[:, selHyperedgeIndex]

# `selHyperedgeIndex` holds column POSITIONS, so index the Series with .iloc.
# Plain `series[int_array]` on a label-indexed Series relies on pandas'
# deprecated positional fallback.
hyperedgeWeight = positiveMatrixSum.iloc[selHyperedgeIndex].values
# Weight = positive-gene sum divided by the hyperedge's total column sum.
selHyperedgeWeightSum = selHyperedge.values.sum(0)
hyperedgeWeight = hyperedgeWeight / selHyperedgeWeightSum

# Incidence matrix H and the weighted degree of every vertex (row).
H = np.array(selHyperedge).astype('float')
DV = np.sum(H * hyperedgeWeight, axis=1)

# Isolated vertices (weighted degree 0) get a tiny entry in one randomly chosen
# hyperedge. The draws come from the global `random` generator in ascending row
# order, so the choice depends on the seed set in the first cell and on how
# many draws were made since.
for i in np.flatnonzero(DV == 0):
    t = random.randint(0, H.shape[1] - 1)
    H[i][t] = 0.0001

# Adjacency matrix G, computed by the project helper from H and the weights.
G = _generate_G_from_H_weight(H, hyperedgeWeight)
N = H.shape[0]

adj_hyperGraph = torch.Tensor(G).float().to(device)
# Identity matrix used as the hypergraph model's input features.
fh = torch.eye(N).float().to(device)
# Flatten the label frame into a 1-D tensor on the compute device.
theLabels = torch.from_numpy(labelFrame.values.reshape(-1,)).to(device)

print(f"Hypergraph Constructed (N={N}).")

>>> Reconstructing Graph
Train Size: 1752, Test Size: 97
Hypergraph Constructed (N=10251).


#### Initialize model and load weights

In [8]:
print(">>> Initializing Models & Loading Weights...")

# Build the three sub-models on the target device. The hypergraph model takes
# the N x N identity as input, hence in_ch=N.
model_hypergrph = hypergrph_HGNN(in_ch=N, n_hid=args.n_hid, dropout=0.2).to(device)
model_graph = graph_ChebNet(hdim=args.n_hid, dropout=0.5).to(device)
model_fusion = DISFusion(
    input_dim=args.n_hid,
    length=2,
    lambdinter=args.lambdinter,
    attention=0,
    nb_classes=2,
    dropout=args.dropout,
).to(device)

# Fail fast when the checkpoint file is missing.
if not os.path.exists(MODEL_PATH):
    raise FileNotFoundError(f"Checkpoint not found: {MODEL_PATH}")

# NOTE(review): torch.load unpickles the file; only load checkpoints from a
# trusted source.
checkpoint = torch.load(MODEL_PATH, map_location=device)

# The checkpoint is indexed by one key per sub-model, each holding a state dict.
for state_key, module in (
    ('model_hypergrph', model_hypergrph),
    ('model_graph', model_graph),
    ('model_fusion', model_fusion),
):
    module.load_state_dict(checkpoint[state_key])

print(f"All models loaded from {MODEL_PATH}")

>>> Initializing Models & Loading Weights...
All models loaded from ./Data/DISFusion.pt


#### Inference and Performance Evaluation

In [10]:
print(">>> Running Inference...")

# Switch every sub-model to evaluation mode before the forward pass.
for module in (model_hypergrph, model_graph, model_fusion):
    module.eval()

# Forward pass and metric computation without gradient tracking.
with torch.no_grad():
    out_hyper = model_hypergrph(fh, adj_hyperGraph)
    out_graph = model_graph(multi_feature, PPI_graph)

    # The fusion model returns a pair; only the second element is scored.
    _, output = model_fusion(out_hyper, out_graph)

    # Evaluate on the held-out test genes only.
    test_output = output[testIndex]
    test_labels = theLabels[testIndex]

    auc, auprc, f1 = cal_metrics(test_output, test_labels)

# Report the three metrics between two 40-character rules.
rule = "=" * 40
print("\n" + rule)
print("DISFusion Reproduction Results")
print(rule)
print(f"AUROC    : {auc:.4f}")
print(f"AUPRC    : {auprc:.4f}")
print(f"F1-Score : {f1:.4f}")
print(rule)

>>> Running Inference...

DISFusion Reproduction Results
AUROC    : 0.9207
AUPRC    : 0.8631
F1-Score : 0.8148
