## Mount Drive

In [1]:
# Mount Google Drive into the Colab runtime and switch the working directory to
# the project folder, so the relative paths used later in this notebook
# (data/..., results/...) resolve against it.
import os
from google.colab import drive
drive.mount('/content/drive/', force_remount=True)
os.chdir('/content/drive/My Drive/Colab_Notebooks/github/GGAT-GatedFusion')

Mounted at /content/drive/


# For Torch

In [2]:
# Report the installed PyTorch build (version + CUDA tag) for reference when
# choosing the wheel index in the install cell below.
import torch
print(torch.__version__)

2.9.0+cu126


In [3]:
# Install PyTorch Geometric and its compiled extension packages from the PyG wheel index.
# NOTE(review): the index URL targets torch 2.6.0+cu124, while the cell above reported
# torch 2.9.0+cu126 — confirm these wheels actually load against the installed torch.
!pip install torch-scatter torch-sparse torch-cluster torch-spline-conv torch-geometric -f https://data.pyg.org/whl/torch-2.6.0+cu124.html

Looking in links: https://data.pyg.org/whl/torch-2.6.0+cu124.html
Collecting torch-scatter
  Downloading https://data.pyg.org/whl/torch-2.6.0%2Bcu124/torch_scatter-2.1.2%2Bpt26cu124-cp312-cp312-linux_x86_64.whl (10.8 MB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m10.8/10.8 MB[0m [31m124.1 MB/s[0m eta [36m0:00:00[0m
[?25hCollecting torch-sparse
  Downloading https://data.pyg.org/whl/torch-2.6.0%2Bcu124/torch_sparse-0.6.18%2Bpt26cu124-cp312-cp312-linux_x86_64.whl (5.0 MB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m5.0/5.0 MB[0m [31m137.0 MB/s[0m eta [36m0:00:00[0m
[?25hCollecting torch-cluster
  Downloading https://data.pyg.org/whl/torch-2.6.0%2Bcu124/torch_cluster-1.6.3%2Bpt26cu124-cp312-cp312-linux_x86_64.whl (3.4 MB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m3.4/3.4 MB[0m [31m133.3 MB/s[0m eta [36m0:00:00[0m
[?25hCollecting torch-spline-conv
  Downloading https://data.pyg.org/whl/torch-2.6.0%2Bcu124/torch

In [4]:
# Enable expandable segments in PyTorch's CUDA allocator to help with memory fragmentation.
# NOTE(review): this was meant to run before importing torch, but torch is already
# imported in an earlier cell of this notebook — confirm the setting still takes effect.
%env PYTORCH_CUDA_ALLOC_CONF=expandable_segments:True

env: PYTORCH_CUDA_ALLOC_CONF=expandable_segments:True


In [5]:
# Debugging aid: force CUDA to raise an error at the exact line where it happens,
# instead of later (e.g. at an unrelated torch.tensor(...) call).
os.environ["CUDA_LAUNCH_BLOCKING"] = "1"

# Import from single channel

In [6]:
from GGAT_singleChannel import *

  import torch_geometric.typing
  import torch_geometric.typing
  import torch_geometric.typing
  import torch_geometric.typing


# Main

## Set Parameters

In [7]:
# Run configuration: these three values select the input files below and are
# embedded in every result/checkpoint file name built later in the notebook.
dataset = 'RR1'   # dataset folder under data/ and results/
fold = 5          # selects the Fold{fold} sub-directory
epochs = 200      # number of epochs passed to the fusion training call
node_file_path = f'data/{dataset}/interactom_nodes.txt'   # stores the nodes for the largest connected component in human Interactome
edge_list_file_path = f'data/{dataset}/interactom_edges.txt' # stores the edges for the largest connected component in human Interactome
train_file_path = f'data/{dataset}/Fold{fold}/train_set.tsv'
test_file_path = f'data/{dataset}/Fold{fold}/test_set.tsv'

### Fusion args

In [8]:
class Args:
  """Settings for the fusion stage of this notebook."""
  # device selection (read when `args.cuda` is computed)
  no_cuda = False
  # NOTE(review): `seed` and `min_delta` are not referenced by any cell visible
  # in this notebook — confirm whether they are meant to be used.
  seed = 42
  min_delta = 1e-4
  # passed to split_train_val as `use_valid`
  use_valid = True


# 0. instantiate the fusion settings
args = Args()

### Single Channel args

In [9]:
class n2vArgs:
  """Settings handed to build_ggat for the "n2v" single-channel model."""
  # --- model identity ---
  model = "GGATGRU"
  model_type = "n2v"          # one of: "label", "n2v"
  dataset = "dis"

  # --- runtime ---
  no_cuda = False
  seed = 42

  # --- optimisation ---
  epochs = 3000
  lr = 0.005
  weight_decay = 5e-4

  # --- architecture ---
  hidden = 8                  # per head
  nb_heads = 8
  inter_dim = 32              # controls the dims between different layers

  # --- dropout ---
  input_dropout = 0           # input dropout disabled (0.6 and 0.3 were tried earlier)
  gat_dropout = 0.4           # GAT dropout, kept separate from the input dropout

  # --- validation / early stopping ---
  use_valid = True
  earlystop = False
  early_stop_patience = 3000  # stop if no improvement after this many epochs
  min_delta = 1e-4            # only gains above this count as an improvement


# 0. instantiate the n2v settings
n2vargs = n2vArgs()

In [10]:
class labelArgs:
  """Settings handed to build_ggat for the "label" single-channel model."""
  # --- model identity ---
  model = "GGATGRU"
  model_type = "label"        # one of: "label", "n2v"
  dataset = "dis"

  # --- runtime ---
  no_cuda = False
  seed = 42

  # --- optimisation ---
  epochs = 3000
  lr = 0.005
  weight_decay = 5e-4

  # --- architecture ---
  hidden = 8                  # per head
  nb_heads = 8
  inter_dim = 32              # controls the dims between different layers

  # --- dropout ---
  input_dropout = 0           # input dropout disabled (0.6 and 0.3 were tried earlier)
  gat_dropout = 0.4           # GAT dropout, kept separate from the input dropout

  # --- validation / early stopping ---
  use_valid = True
  earlystop = False
  early_stop_patience = 3000  # stop if no improvement after this many epochs
  min_delta = 1e-4            # only gains above this count as an improvement


# 0. instantiate the label settings
labelargs = labelArgs()

### Checkpoint, log, and output paths

In [11]:
# the file that saved best model
model_path_label = f'results/{dataset}/Fold{fold}/{dataset}_fold{fold}_best_model_{labelargs.model}_label_{labelargs.epochs}epochs_{labelargs.inter_dim}inter.pt'
model_path_n2v = f'results/{dataset}/Fold{fold}/{dataset}_fold{fold}_best_model_{n2vargs.model}_n2v_{n2vargs.epochs}epochs_{n2vargs.inter_dim}inter.pt'

# training log file
log_file = open(f"results/{dataset}/Fold{fold}/{dataset}_fold{fold}_training_log_gated_fusion_{epochs}epochs.txt", "w")

# to save the predictions
output_file_gated_best = f'results/{dataset}/Fold{fold}/{dataset}_fold{fold}_gated_fusion_predictions_best_model_{epochs}epochs.tsv'
output_file_gated_last =f'results/{dataset}/Fold{fold}/{dataset}_fold{fold}_gated_fusion_predictions_last_model_{epochs}epochs.tsv'


# to save best model
best_model_path = f'results/{dataset}/Fold{fold}/{dataset}_fold{fold}_gated_fusion_best_model_{epochs}epochs.pt'
last_model_path = f'results/{dataset}/Fold{fold}/{dataset}_fold{fold}_gated_fusion_last_model_{epochs}epochs.pt'

## Load Data Set

In [12]:
# 1. get graph original nodes
node_idx_dict = get_gene_idx_dict_from_file(node_file_path)
node_gene_dict = {v:k for k,v in node_idx_dict.items()}
# print(node_idx_dict)

# 2. get selected disease pairs
# [(disA, disB), ...], [label, ...], {disease: [gene_1, gene_2, ...]}]
train_dis_pairs, train_labels, train_disease_genes_dict = get_disease_sets(train_file_path)
test_dis_pairs, test_labels, test_disease_genes_dict = get_disease_sets(test_file_path)
train_disease_pair_rr = get_disease_pair_rr_list(train_dis_pairs, train_labels, train_disease_genes_dict, node_idx_dict)
test_disease_pair_rr = get_disease_pair_rr_list(test_dis_pairs, test_labels, test_disease_genes_dict, node_idx_dict)
# print(train_disease_pair_rr[0])


## Prep Model and Data

In [13]:
# train_disease_pair_rr = [(gene_list, label), ...]
# train_disease_pair_rr holds (gene_list, label) records; separate the two fields.
gene_lists = [record[0] for record in train_disease_pair_rr]
labels = [record[1] for record in train_disease_pair_rr]

# Logger handed to the splitter: echo each message to stdout.
log_fn = lambda msg: print(msg)

# Split into training and validation subsets.
train_set, val_set = split_train_val(
    gene_lists,
    labels,
    use_valid=args.use_valid,
    log_fn=log_fn,
)

Train Label Distribution: Counter({1: 5078, 0: 3624})
Val Label Distribution: Counter({1: 564, 0: 403})


In [14]:
# Echo the run configuration, one value per line: dataset, fold, then
# (model, model_type, epochs) for the n2v channel and for the label channel.
print(
    dataset,
    fold,
    n2vargs.model,
    n2vargs.model_type,
    n2vargs.epochs,
    labelargs.model,
    labelargs.model_type,
    labelargs.epochs,
    sep="\n",
)

RR1
5
GGATGRU
n2v
3000
GGATGRU
label
3000


# Fuse embeddings and train

## prep models and data

In [15]:
# Build the two single-channel models and their graph data, then move them to the device.
# prepare_data / build_ggat / to_device are not defined in this notebook — presumably
# provided by the `GGAT_singleChannel` star import; verify.
args.cuda = not args.no_cuda and torch.cuda.is_available() # Uses GPU if (CUDA is not explicitly disabled by the user) and (it is available)
device = 'cuda' if args.cuda else 'cpu' # Stores the current device, need it when constructing tensors or models on the same device.

##### n2v model
# Reset every handle first so a failed step cannot leave stale objects from an earlier run.
# NOTE(review): fusion_proj is only assigned None here and is not used by any visible cell.
n2v_data = label_data = n2v_model = label_model = fusion_proj = None

# The model's input width is taken from the feature matrix returned by prepare_data.
n2v_data, n2v_x = prepare_data('./data/dis/', 'dis')
n2v_model = build_ggat(n2v_x.shape[1], n2vargs)




##### label model
label_data, label_x = prepare_data('./data/dis/', 'label2vec')
label_model = build_ggat(label_x.shape[1], labelargs)


# # loss function
# loss_fn = nn.BCEWithLogitsLoss()

# If use cuda: move the model and data to CUDA.
label_model, label_data, n2v_model, n2v_data = to_device(
    label_model, label_data, n2v_model, n2v_data, device = device)



Loading dis dataset...
Loading label2vec dataset...


## load weights from saved models

In [17]:
# Restore the pre-trained weights of both single-channel models.
# SECURITY NOTE: weights_only=False lets torch.load unpickle arbitrary objects, which can
# execute code — only load checkpoint files that you produced or otherwise trust.
n2v_checkpoint = torch.load(model_path_n2v, map_location=device, weights_only=False)
print("Checkpoint keys:", n2v_checkpoint.keys())

# Only the graph model's weights are restored; the checkpoint's other entries
# (see the printed keys) are not used in this notebook.
n2v_model.load_state_dict(n2v_checkpoint['n2v_model_state_dict'])

print(f"Model loaded from checkpoint: {model_path_n2v}")

label_checkpoint = torch.load(model_path_label, map_location=device, weights_only=False)
print("Checkpoint keys:", label_checkpoint.keys())

label_model.load_state_dict(label_checkpoint['label_model_state_dict'])

print(f"Model loaded from checkpoint: {model_path_label}")

Checkpoint keys: dict_keys(['predictor_state_dict', 'pooler_state_dict', 'n2v_model_state_dict', 'best_val_auc'])
Model loaded from checkpoint: results/RR1/Fold5/RR1_fold5_best_model_GGATGRU_n2v_3000epochs_32inter.pt
Checkpoint keys: dict_keys(['predictor_state_dict', 'pooler_state_dict', 'label_model_state_dict', 'best_val_auc'])
Model loaded from checkpoint: results/RR1/Fold5/RR1_fold5_best_model_GGATGRU_label_3000epochs_32inter.pt


In [18]:
# Compute the embeddings of both channels once, in eval mode and without gradient
# tracking: the single-channel models stay frozen while the fusion model is trained
# on top of these fixed tensors.
label_model.eval()
n2v_model.eval()

with torch.no_grad():
    label_embs = label_model(label_data.x, label_data.edge_index)
    n2v_embs = n2v_model(n2v_data.x, n2v_data.edge_index)

## Fusion model

In [19]:
import csv

import matplotlib.pyplot as plt
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from sklearn.metrics import (
    accuracy_score,
    average_precision_score,
    matthews_corrcoef,
    roc_auc_score,
    roc_curve,
)
from torch.utils.data import DataLoader


class AttentionPooling(nn.Module):
    """Attention-weighted pooling of a set of node embeddings into one vector.

    A small MLP (Linear -> Tanh -> Linear) produces one score per input row,
    the scores are soft-maxed across rows (dim 0), and the rows are summed
    with those weights.

    Args:
        input_dim: Size of the last dimension of the embeddings to pool.
        return_weights: When True, ``forward`` also returns the attention weights.
    """

    def __init__(self, input_dim, return_weights=False):
        super().__init__()
        self.att_mlp = nn.Sequential(
            nn.Linear(input_dim, input_dim * 2),
            nn.Tanh(),
            nn.Linear(input_dim * 2, 1)
        )
        self.return_weights = return_weights

    def forward(self, node_embs):
        """Pool ``node_embs`` (expected shape ``[num_nodes, input_dim]``).

        Returns:
            ``pooled`` of shape ``[input_dim]``; when ``return_weights`` is True,
            the tuple ``(pooled, weights)`` with ``weights`` of shape ``[num_nodes]``.
        """
        scores = self.att_mlp(node_embs)              # one score per row: [num_nodes, 1]
        attn_weights = torch.softmax(scores, dim=0)   # normalise across the rows
        pooled = (attn_weights * node_embs).sum(dim=0)
        if self.return_weights:
            # Drop only the trailing score dimension. A bare squeeze() would also
            # collapse a single-node input to a 0-d scalar, breaking callers that
            # expect a 1-D weight vector (one weight per node).
            return pooled, attn_weights.squeeze(-1)
        return pooled


class GatedFusionLayer(nn.Module):
    """Element-wise gated blend of two embeddings of equal width.

    The gate network (Linear -> ReLU -> Linear -> Sigmoid) sees both embeddings
    concatenated along dim 1 and emits one value in (0, 1) per feature.

    Args:
        input_dim: Width of each of the two embeddings.
    """

    def __init__(self, input_dim):
        super().__init__()
        gate_layers = [
            nn.Linear(input_dim * 2, input_dim),
            nn.ReLU(),
            nn.Linear(input_dim, input_dim),
            nn.Sigmoid(),
        ]
        self.gate_net = nn.Sequential(*gate_layers)

    def forward(self, emb1, emb2):
        """Return ``g * emb1 + (1 - g) * emb2`` where ``g`` is the learned gate."""
        both = torch.cat([emb1, emb2], dim=1)
        g = self.gate_net(both)
        from_first = g * emb1
        from_second = (1 - g) * emb2
        return from_first + from_second


class FusionModel(nn.Module):
    """Fuse two per-gene embedding tables and score one gene set.

    Pipeline: fuse the two tables (gated fusion, or plain concatenation when
    ``use_fusion_proj`` is False) -> select the rows in ``gene_indices`` ->
    attention-pool them into a single vector -> MLP producing one logit.

    Args:
        input_dim: Width of each of the two embedding tables.
        use_fusion_proj: True -> gated fusion (fused width ``input_dim``);
            False -> concatenation (fused width ``2 * input_dim``).
        attn_return_weights: When True, ``forward`` returns ``(logit, attn_weights)``
            instead of just ``logit``.
    """

    def __init__(self, input_dim, use_fusion_proj=True, attn_return_weights=False):
        super().__init__()
        self.use_fusion_proj = use_fusion_proj
        self.attn_return_weights = attn_return_weights

        # Width of the fused per-gene vectors; pooling keeps that width, so the
        # pooling layer AND the predictor must both be sized from it. (Previously
        # the predictor was always built for `input_dim`, which made the
        # concatenation branch fail with a shape mismatch.)
        fused_dim = input_dim if use_fusion_proj else input_dim * 2

        self.fusion = GatedFusionLayer(input_dim) if use_fusion_proj else None
        self.att_pool = AttentionPooling(fused_dim, return_weights=attn_return_weights)

        self.rr_predictor = nn.Sequential(
            nn.Linear(fused_dim, 128),
            nn.ReLU(),
            nn.Linear(128, 64),
            nn.ReLU(),
            nn.Linear(64, 1)
        )

    def forward(self, label_embs, n2v_embs, gene_indices):
        """Score the gene set ``gene_indices``.

        Args:
            label_embs: First embedding table (rows are indexed by ``gene_indices``).
            n2v_embs: Second embedding table, row-aligned with ``label_embs``.
            gene_indices: Row indices selecting the genes to pool.

        Returns:
            The squeezed logit, or ``(logit, attn_weights)`` when the model was
            built with ``attn_return_weights=True``.
        """
        if self.use_fusion_proj:
            fused = self.fusion(label_embs, n2v_embs)
        else:
            fused = torch.cat([label_embs, n2v_embs], dim=1)

        subset = fused[gene_indices]

        attn_weights = None
        if self.attn_return_weights:
            pooled, attn_weights = self.att_pool(subset)
        else:
            pooled = self.att_pool(subset)

        # The predictor expects a batch dimension; add it, then squeeze the single logit.
        logit = self.rr_predictor(pooled.unsqueeze(0)).squeeze()

        if self.attn_return_weights:
            return logit, attn_weights
        return logit


def collate_disease_pairs(batch):
    """Collate ``(gene_list, label)`` samples for a DataLoader.

    Gene lists have varying lengths, so they are kept as a plain Python list
    rather than stacked; only the labels become a tensor.

    Returns:
        ``(gene_sets, labels)`` — a list of gene lists and a float32 label tensor.
    """
    gene_sets = [sample[0] for sample in batch]
    targets = [sample[1] for sample in batch]
    return gene_sets, torch.tensor(targets, dtype=torch.float32)


def evaluate_on_validation(model, label_embs, n2v_embs, val_set, device='cuda'):
    """Compute accuracy and ROC AUC of ``model`` over ``val_set``.

    Args:
        model: Callable as ``model(label_embs, n2v_embs, gene_indices)``; may return
            a logit or a ``(logit, extra)`` tuple.
        label_embs, n2v_embs: Embedding tensors; moved to ``device`` before use.
        val_set: Iterable of ``(gene_indices, label)`` pairs.
        device: Device the embeddings are moved to.

    Returns:
        ``(accuracy, auc)``. Accuracy uses a 0.5 probability cut-off; ``auc`` is
        NaN unless exactly two distinct labels are present.
    """
    model.eval()
    label_embs, n2v_embs = label_embs.to(device), n2v_embs.to(device)

    prob_list, label_list = [], []
    with torch.no_grad():
        for gene_indices, target in val_set:
            result = model(label_embs, n2v_embs, gene_indices)
            if isinstance(result, tuple):
                logit, _ = result          # discard the extra output (e.g. attention weights)
            else:
                logit = result
            prob_list.append(torch.sigmoid(logit).item())
            label_list.append(target)

    pred_list = [int(p >= 0.5) for p in prob_list]

    labels = torch.tensor(label_list)
    preds = torch.tensor(pred_list)
    probs = torch.tensor(prob_list)

    acc = (preds == labels).float().mean().item()

    # AUC is only computed when both classes occur in the labels.
    if len(set(labels.numpy())) == 2:
        auc = roc_auc_score(labels.numpy(), probs.numpy())
    else:
        auc = float('nan')
    return acc, auc


def plot_attention(weights, gene_ids=None, title="Attention Weights"):
    """Plot attention weights as a line with one marker per position.

    Args:
        weights: Tensor of attention weights (detached and moved to CPU here).
        gene_ids: Optional sequence of tick labels, one per weight. May be a list,
            tuple, numpy array or tensor.
        title: Figure title.
    """
    plt.figure(figsize=(10, 4))
    plt.plot(weights.detach().cpu().numpy(), marker='o')
    # Explicit None/length test: a plain truthiness test (`if gene_ids:`) raises for
    # numpy arrays or tensors holding more than one element.
    if gene_ids is not None and len(gene_ids) > 0:
        plt.xticks(ticks=range(len(gene_ids)), labels=gene_ids, rotation=90)
    plt.title(title)
    plt.xlabel("Gene Index")
    plt.ylabel("Attention Weight")
    plt.tight_layout()
    plt.show()


def train_fusion_model(
    model,
    label_embs,
    n2v_embs,
    train_set,
    val_set,
    epochs=30,
    lr=1e-3,
    device='cuda',
    batch_size=1,
    log_every=1,
    log_file=None,
    min_delta=0.001,
    best_model_path="best_model.pt",
    last_model_path="last_model.pt",
):
    """Train the fusion model on fixed embeddings and checkpoint best/last weights.

    Args:
        model: Fusion model, callable as ``model(label_embs, n2v_embs, gene_set)``;
            may return a logit or a ``(logit, attn_weights)`` tuple.
        label_embs, n2v_embs: Embedding tensors passed unchanged to the model.
        train_set: Sequence of ``(gene_list, label)`` training samples.
        val_set: Sequence of ``(gene_list, label)`` samples used for model selection.
        epochs: Number of passes over ``train_set``.
        lr: Adam learning rate.
        device: Device the stacked logits and labels are moved to for the loss.
        batch_size: Number of gene sets per optimisation step.
        log_every: Validate (and possibly checkpoint) every ``log_every`` epochs.
        log_file: Optional open file handle. Messages go through ``log_print`` when
            given, else to stdout. NOTE: the handle is CLOSED at the end of training.
        min_delta: Minimum validation-AUC gain required to replace the best checkpoint.
        best_model_path: Destination of the best-validation-AUC checkpoint.
        last_model_path: Destination of the final-epoch checkpoint.

    Side effects:
        Writes ``best_model_path`` (whenever validation AUC improves) and
        ``last_model_path`` (once, after the last epoch). Both hold
        ``'fusionModel_dict'`` and ``'best_val_auc'``; the last checkpoint also
        holds ``'last_val_auc'``.
    """
    optimizer = torch.optim.Adam(model.parameters(), lr=lr)
    loss_fn = nn.BCEWithLogitsLoss()

    # Built once: the loader is re-iterated every epoch instead of being re-created.
    train_loader = DataLoader(train_set, batch_size=batch_size, shuffle=True, collate_fn=collate_disease_pairs)

    best_val_auc = -1.0
    last_val_auc = float('nan')   # most recent validation AUC (belongs to the last model)
    best_saved = False            # whether best_model_path was ever written in this run

    for epoch in range(1, epochs + 1):
        model.train()
        total_loss = 0
        all_preds = []
        all_labels = []
        all_probs = []

        for gene_sets, labels in train_loader:
            optimizer.zero_grad()

            # Gene sets differ in length, so each one is scored separately and the
            # resulting scalar logits are stacked into one batch for the loss.
            logits = []
            for gset in gene_sets:
                output = model(label_embs, n2v_embs, gset)
                logit = output[0] if isinstance(output, tuple) else output
                logits.append(logit)

            logits = torch.stack(logits).to(device)
            labels = labels.to(device)
            loss = loss_fn(logits, labels)
            loss.backward()
            optimizer.step()
            total_loss += loss.item()   # summed (not averaged) over the batches

            probs = torch.sigmoid(logits).detach()
            preds = (probs > 0.5).float()
            all_preds.extend(preds.cpu().numpy())
            all_labels.extend(labels.cpu().numpy())
            all_probs.extend(probs.cpu().numpy())

        acc = (np.array(all_preds) == np.array(all_labels)).mean()
        try:
            auc = roc_auc_score(all_labels, all_probs)
        except ValueError:
            # roc_auc_score rejects inputs it cannot score (e.g. a single class);
            # report NaN for this epoch instead of hiding unrelated errors.
            auc = float('nan')

        log_msg = f"[Epoch {epoch}] || [In-training Metrics]: Train Loss: {total_loss:.4f} | Train Acc: {acc:.4f} | Train AUC: {auc:.4f}"

        if epoch % log_every == 0:
            val_acc, val_auc = evaluate_on_validation(model, label_embs, n2v_embs, val_set, device)
            last_val_auc = val_auc
            log_msg += f" || [End-of-epoch metrics]: Val Acc: {val_acc:.4f} | Val AUC: {val_auc:.4f}"
            # A NaN val_auc never satisfies this comparison, so it cannot become "best".
            if val_auc > best_val_auc + min_delta:
                best_val_auc = val_auc
                best_states = {'fusionModel_dict': model.state_dict(), 'best_val_auc': best_val_auc}
                torch.save(best_states, best_model_path)
                best_saved = True
                log_msg += "\n===Model Updated==="

        if log_file:
            log_print(log_msg, log_file)
        else:
            print(log_msg)

    last_states = {
        'fusionModel_dict': model.state_dict(),
        'best_val_auc': best_val_auc,
        'last_val_auc': last_val_auc,
    }
    torch.save(last_states, last_model_path)

    # Report each checkpoint with its OWN validation AUC (the last model used to be
    # reported with the best model's AUC), and do not claim a best checkpoint that
    # was never written.
    final_msg = f"Last model saved to {last_model_path} with Val ROC AUC: {last_val_auc:.4f}\n"
    if best_saved:
        final_msg += f"Best model saved to {best_model_path} with Val ROC AUC: {best_val_auc:.4f}"
    else:
        final_msg += "No best model was saved: validation AUC never improved."

    if log_file:
        log_print(final_msg, log_file)
        log_file.close()
    else:
        print(final_msg)


## test functions

In [20]:
def get_test_probs(fusion_model, label_model, n2v_model, label_data, n2v_data, test_set, device='cuda'):
    """Predict a probability for every test sample.

    The two single-channel models are run once to produce the embedding tables,
    then the fusion model scores each gene list.

    Args:
        fusion_model: Callable as ``fusion_model(label_embs, n2v_embs, gene_list)``;
            may return a logit or a ``(logit, attn_weights)`` tuple.
        label_model, n2v_model: Called as ``model(data.x, data.edge_index)``.
        label_data, n2v_data: Objects exposing ``.x`` and ``.edge_index`` tensors.
        test_set: Iterable of ``(gene_list, label)`` pairs.
        device: Device the inputs are moved to. ``'cuda'`` falls back to ``'cpu'``
            when CUDA is unavailable.

    Returns:
        ``(y_test, test_probs)`` as numpy arrays, in ``test_set`` order.
    """
    if device == 'cuda' and not torch.cuda.is_available():
        # The default assumes a GPU runtime; degrade gracefully on CPU-only machines.
        device = 'cpu'

    fusion_model.eval()
    label_model.eval()
    n2v_model.eval()

    test_probs = []
    y_test = []

    with torch.no_grad():
        label_embs = label_model(label_data.x.to(device), label_data.edge_index.to(device))
        n2v_embs = n2v_model(n2v_data.x.to(device), n2v_data.edge_index.to(device))

        for gene_list, label in test_set:
            output = fusion_model(label_embs, n2v_embs, gene_list)
            # Accept both return conventions (with / without attention weights),
            # like the training and validation code does.
            logit = output[0] if isinstance(output, tuple) else output
            test_probs.append(torch.sigmoid(logit).item())
            y_test.append(label)

    return np.array(y_test), np.array(test_probs)
from sklearn.metrics import accuracy_score

def save_predictions_to_tsv(rows, output_file):
    """
      Save prediction results and metrics to a tab-separated file.

      The header is taken from the keys of the first row, in their insertion order.

      Args:
          rows (list of dict): Prediction records, each with keys like 'prob', 'label', 'acc', etc.
          output_file (str): Path to the output .tsv file

      Raises:
          ValueError: If ``rows`` is empty (there would be no header to write).
    """
    if not rows:
        # Fail with a clear message before the output file is created/truncated.
        raise ValueError("save_predictions_to_tsv: 'rows' is empty, nothing to write")

    fieldnames = list(rows[0].keys())
    # newline='' lets the csv module control line endings itself.
    with open(output_file, 'w', newline='') as f:
        writer = csv.DictWriter(f, fieldnames=fieldnames, delimiter='\t')
        writer.writeheader()
        writer.writerows(rows)

    print(f"Saved test predictions and metrics to {output_file}")

def find_best_thresh(test_probs, labels=None):
  """Scan thresholds 0.00..1.00 (step 0.01) and keep the one with the highest MCC.

  Args:
      test_probs: Predicted probabilities (array-like).
      labels: Ground-truth labels aligned with ``test_probs``. When omitted, the
          notebook-level ``y_true`` variable is used; that fallback exists only
          for backward compatibility — always pass the labels explicitly, otherwise
          the search is done against whatever ``y_true`` currently holds.

  Returns:
      ``(best_mcc, best_thresh)``. Ties keep the lowest threshold.
  """
  if labels is None:
    labels = y_true  # legacy behaviour: read the notebook-level variable
  test_probs = np.asarray(test_probs)

  best_thresh = 0.5
  best_mcc = -1
  for thresh in np.arange(0.0, 1.01, 0.01):
    temp_preds = (test_probs >= thresh).astype(int)
    mcc = matthews_corrcoef(labels, temp_preds)
    if mcc > best_mcc:
      best_mcc = mcc
      best_thresh = thresh
  return best_mcc, best_thresh


def compute_metrics(y_true, test_probs):
  """Pick the MCC-optimal threshold, then compute and print the summary metrics.

  NOTE(review): the threshold is selected on the same labels/probabilities that
  the reported accuracy and MCC are computed from, so those two numbers are
  optimistic when this is called on a test set — confirm this is intended.

  Args:
      y_true: Ground-truth binary labels.
      test_probs: Predicted probabilities aligned with ``y_true``.

  Returns:
      ``(test_preds, best_mcc, best_thresh, accuracy, auprc, roc_auc)`` where
      ``test_preds`` are the 0/1 predictions at ``best_thresh``.
  """
  test_probs = np.asarray(test_probs)

  # Find best threshold using MCC, scored against THIS call's labels
  # (previously the notebook-level `y_true` was used implicitly).
  best_mcc, best_thresh = find_best_thresh(test_probs, y_true)

  # Final predictions using best threshold
  test_preds = (test_probs >= best_thresh).astype(int)

  # Compute metrics
  roc_auc = roc_auc_score(y_true, test_probs)
  accuracy = accuracy_score(y_true, test_preds)
  auprc = average_precision_score(y_true, test_probs)

  # Print results
  print(f"Best Threshold: {best_thresh:.4f}")
  print(f"ROC AUC:       {roc_auc:.4f}")
  print(f"Accuracy:      {accuracy:.4f}")
  print(f"AUPRC:         {auprc:.4f}")
  print(f"MCC:           {best_mcc:.4f}")

  return test_preds, best_mcc, best_thresh, accuracy, auprc, roc_auc

def save_to_rows(test_dis_pairs, y_true, test_preds, test_probs, best_thresh, accuracy, mcc, auprc, roc_auc):
  """Build one record (dict) per test pair, ready to be written as a TSV row.

  Args:
      test_dis_pairs: Sequence of disease pairs; each pair is a sequence of strings.
      y_true: Ground-truth labels, aligned with ``test_dis_pairs``.
      test_preds: Binary (0/1) predictions, already thresholded by the caller.
      test_probs: Predicted probabilities.
      best_thresh, accuracy, mcc, auprc, roc_auc: Run-level values; repeated on
          every row, formatted with 4 decimals.

  Returns:
      List of dicts with keys, in order: pair_id, disease_pair, label, prob,
      pred, acc, mcc, auprc, roc_auc, best_thresh.
  """
  rows = []
  for i, pair in enumerate(test_dis_pairs):
    rows.append({
        "pair_id": i,
        "disease_pair": "&".join(pair),
        "label": int(y_true[i]),
        "prob": test_probs[i],
        # test_preds is already binary. Comparing it against the threshold again
        # (as done previously) turned every positive into 0 when best_thresh == 1.0.
        "pred": int(test_preds[i]),
        "acc": f"{accuracy:.4f}",
        "mcc": f"{mcc:.4f}",
        "auprc": f"{auprc:.4f}",
        "roc_auc": f"{roc_auc:.4f}",
        "best_thresh": f"{best_thresh:.4f}",
    })
  return rows


def evaluate_predictions_from_file_J_score(pred_file):
    """Re-compute metrics from a saved prediction TSV and print them.

    Accuracy and MCC use the ``pred`` column stored in the file; ROC AUC and AUPRC
    use the ``prob`` column. The printed threshold is the ROC point maximising
    Youden's J (tpr - fpr); it is informational only and does not affect the
    accuracy or MCC printed here.

    Args:
        pred_file: Path to a tab-separated file with ``label``, ``pred`` and
            ``prob`` columns.
    """
    # Read every record first, then split it into typed columns.
    with open(pred_file, 'r') as f:
        records = list(csv.DictReader(f, delimiter='\t'))

    labels = np.array([int(rec["label"]) for rec in records])
    preds = np.array([int(rec["pred"]) for rec in records])
    probs = np.array([float(rec["prob"]) for rec in records])

    # Threshold-dependent metrics come from the stored predictions.
    acc = accuracy_score(labels, preds)
    mcc = matthews_corrcoef(labels, preds)
    auprc = average_precision_score(labels, probs)

    # AUC and Youden's J optimal threshold; both fall back to NaN on ValueError.
    try:
        auc = roc_auc_score(labels, probs)
        fpr, tpr, thresholds = roc_curve(labels, probs)
        best_idx = np.argmax(tpr - fpr)
        best_thresh = thresholds[best_idx]
    except ValueError:
        auc = float('nan')
        best_thresh = float('nan')

    print(f"\nFrom Predictions — Best Threshold (J): {best_thresh:.4f} | "
      f"Acc: {acc:.4f} | ROC AUC: {auc:.4f} | MCC: {mcc:.4f} | AUPRC: {auprc:.4f}")

def _best_mcc_threshold(labels, probs):
    """Scan thresholds 0.00..1.00 (step 0.01); return ``(best_mcc, best_thresh)``."""
    best_mcc, best_thresh = -1, 0.5
    for thresh in np.arange(0.0, 1.01, 0.01):
        mcc = matthews_corrcoef(labels, (probs >= thresh).astype(int))
        if mcc > best_mcc:
            best_mcc, best_thresh = mcc, thresh
    return best_mcc, best_thresh


def evaluate_predictions_from_file_mcc(pred_file):
    """Re-compute metrics from a saved prediction TSV using an MCC-optimal threshold.

    The threshold is searched against the labels stored in the file itself.
    (Previously the search silently used the notebook-level ``y_true`` variable,
    so the result depended on whatever that variable held at call time.)

    Args:
        pred_file: Path to a tab-separated file with ``label`` and ``prob`` columns.
    """
    labels = []
    probs = []

    # read predictions (the stored `pred` column is not needed: predictions are
    # re-derived from `prob` with the threshold found below)
    with open(pred_file, 'r') as f:
        reader = csv.DictReader(f, delimiter='\t')
        for row in reader:
            labels.append(int(row["label"]))
            probs.append(float(row["prob"]))

    labels = np.array(labels)
    probs = np.array(probs)

    # Find best threshold using MCC
    best_mcc, best_thresh = _best_mcc_threshold(labels, probs)

    # Final predictions using best threshold
    test_preds = (probs >= best_thresh).astype(int)

    # Compute metrics
    auc = roc_auc_score(labels, probs)
    acc = accuracy_score(labels, test_preds)
    auprc = average_precision_score(labels, probs)

    print(f"\nFrom Predictions — Best thresh: {best_thresh:.4f} | "
      f"Acc: {acc:.4f} | ROC AUC: {auc:.4f} | MCC: {best_mcc:.4f} | AUPRC: {auprc:.4f}")


## Train

In [None]:
# Build the fusion model and train it on the frozen embeddings computed above.
fusion_model = FusionModel(
    input_dim=32,               # per-gene embedding dimension of each channel
    use_fusion_proj=True,       # use gated fusion
    attn_return_weights=True    # return attention weights (for optional plotting)
).to(device)

# lr, batch_size and min_delta are not passed, so the function defaults apply
# (args.min_delta is not forwarded). train_fusion_model closes log_file when it
# finishes, so the cell that opens log_file must be re-run before training again.
train_fusion_model(
    model=fusion_model,
    label_embs=label_embs,        # expected [N_genes, 32]
    n2v_embs=n2v_embs,            # expected [N_genes, 32]
    train_set=train_set + val_set,  # trains on the full data (train + validation splits)
    val_set=train_set + val_set,    # NOTE(review): evaluation also runs on train + val, not on val_set alone, so the "best model" is selected on data seen during training — confirm this is intended
    epochs=epochs,
    log_file=log_file,
    log_every=1,
    best_model_path=best_model_path,
    last_model_path=last_model_path
)


## Test

In [22]:
# Score the test pairs with the fusion model currently in memory (its weights after the
# final training epoch). The device is passed explicitly so this also runs on a CPU-only
# runtime instead of relying on get_test_probs' 'cuda' default.
y_true, test_probs = get_test_probs(
    fusion_model, label_model, n2v_model, label_data, n2v_data, test_disease_pair_rr,
    device=device,
)

In [None]:
# Compute metrics for the last-epoch model (threshold chosen by best MCC), build one
# record per test pair, and write them to the "last model" prediction file.
test_preds, best_mcc, best_thresh, accuracy, auprc, roc_auc = compute_metrics(y_true, test_probs)
rows = save_to_rows(test_dis_pairs, y_true, test_preds, test_probs, best_thresh, accuracy, best_mcc, auprc, roc_auc)
save_predictions_to_tsv(rows, output_file_gated_last)

In [None]:
# Last model: re-read the saved prediction file and print metrics twice, once with
# the Youden's-J threshold report and once with the MCC-optimal threshold.
evaluate_predictions_from_file_J_score(output_file_gated_last)
evaluate_predictions_from_file_mcc(output_file_gated_last)

# load from files

In [None]:
# Rebuild everything from the saved files and evaluate the BEST fusion checkpoint.
# This cell repeats the model/data preparation done earlier so that it can be run
# without re-training.
args.cuda = not args.no_cuda and torch.cuda.is_available() # Uses GPU if (CUDA is not explicitly disabled by the user) and (it is available)
device = 'cuda' if args.cuda else 'cpu' # Stores the current device, need it when constructing tensors or models on the same device.

##### n2v model
# Reset every handle first so a failed step cannot leave stale objects from an earlier run.
n2v_data = label_data = n2v_model = label_model = fusion_proj = None

n2v_data, n2v_x = prepare_data('./data/dis/', 'dis')
n2v_model = build_ggat(n2v_x.shape[1], n2vargs)


##### label model
label_data, label_x = prepare_data('./data/dis/', 'label2vec')
label_model = build_ggat(label_x.shape[1], labelargs)


# # loss function
# loss_fn = nn.BCEWithLogitsLoss()

# If use cuda: move the model and data to CUDA.
label_model, label_data, n2v_model, n2v_data = to_device(
    label_model, label_data, n2v_model, n2v_data, device = device)


# SECURITY NOTE: weights_only=False unpickles arbitrary objects — only load trusted checkpoints.
n2v_checkpoint = torch.load(model_path_n2v, map_location=device, weights_only=False)
print("Checkpoint keys:", n2v_checkpoint.keys())

n2v_model.load_state_dict(n2v_checkpoint['n2v_model_state_dict'])

print(f"Model loaded from checkpoint: {model_path_n2v}")

label_checkpoint = torch.load(model_path_label, map_location=device, weights_only=False)
print("Checkpoint keys:", label_checkpoint.keys())

label_model.load_state_dict(label_checkpoint['label_model_state_dict'])

print(f"Model loaded from checkpoint: {model_path_label}")


label_model.eval()
n2v_model.eval()

with torch.no_grad():
    label_embs = label_model(label_data.x, label_data.edge_index)
    n2v_embs = n2v_model(n2v_data.x, n2v_data.edge_index)

# Best fusion checkpoint written by train_fusion_model.
checkpoint = torch.load(best_model_path, map_location=device, weights_only=False)
print("Checkpoint keys:", checkpoint.keys())
# print(checkpoint['best_val_auc'])

fusion_model = FusionModel(
    input_dim=32,
    use_fusion_proj=True,
    attn_return_weights=True  # only needed if you're inspecting attention weights
).to(device)

fusion_model.load_state_dict(checkpoint['fusionModel_dict'])

# Pass the device explicitly so this also runs on a CPU-only runtime instead of
# relying on get_test_probs' 'cuda' default.
y_true, test_probs = get_test_probs(
    fusion_model, label_model, n2v_model, label_data, n2v_data, test_disease_pair_rr,
    device=device,
)
test_preds, best_mcc, best_thresh, accuracy, auprc, roc_auc = compute_metrics(y_true, test_probs)
rows = save_to_rows(test_dis_pairs, y_true, test_preds, test_probs, best_thresh, accuracy, best_mcc, auprc, roc_auc)
save_predictions_to_tsv(rows, output_file_gated_best)

In [None]:
# Best model: re-read the prediction file written for the best checkpoint and print
# metrics with the Youden's-J threshold report and with the MCC-optimal threshold.
evaluate_predictions_from_file_J_score(output_file_gated_best)
evaluate_predictions_from_file_mcc(output_file_gated_best)

# Disconnect after Done with Train and Test

In [None]:
# Imported here so the cell does not depend on `time` leaking in from another module.
import time
from IPython.display import Javascript, display

time.sleep(600)  # Wait 10min to ensure all things finish
# Disconnect runtime
display(Javascript('google.colab.kernel.disconnect();'))