In [1]:
# TxGNN Subgraph Training and Evaluation

# --- Imports and environment check ---
import torch
from txgnn import TxData, TxGNN, TxEval
import os, sys

# Report the interpreter, PyTorch build, and GPU visibility so the saved
# output records the environment this run used.
for label, value in (
    ('Python version:', sys.version),
    ('Torch version:', torch.__version__),
    ('CUDA available:', torch.cuda.is_available()),
):
    print(label, value)

  from .autonotebook import tqdm as notebook_tqdm


Python version: 3.9.7 (tags/v3.9.7:1016ef3, Aug 30 2021, 20:19:38) [MSC v.1929 64 bit (AMD64)]
Torch version: 2.8.0+cpu
CUDA available: False


In [2]:
# --- Config ---
# Every value a reader might tune lives in this cell; later cells only read
# these names.

# Path to a subgraph folder (e.g., drug-disease-gene subgraph)
DATA_DIR = '../data/subgraphs/drug-disease-gene'

# Device: first CUDA GPU when PyTorch can see one, otherwise CPU
DEVICE = 'cuda:0' if torch.cuda.is_available() else 'cpu'

# Seed: passed to the data split and used for PyTorch's global RNG
SEED = 42
# Seed PyTorch here so torch-driven randomness (e.g. weight initialisation)
# repeats across Restart & Run All.
# NOTE(review): Python `random`, NumPy and DGL RNGs are not seeded by this
# call - seed them as well if TxGNN draws from them.
torch.manual_seed(SEED)

# Save paths (handed to Tx.finetune and TxEval.eval_disease_centric as `save_name`)
SAVE_MODEL_PATH = './finetune_subgraph.pt'
SAVE_EVAL_PATH = './eval_subgraph'

# Model dimensions (small for Colab / CPU)
N_HID, N_INP, N_OUT = 64, 64, 64
# Remaining arguments forwarded to Tx.model_initialize
PROTO, PROTO_NUM = True, 2
ATTN = False
SIM_MEASURE = 'all_nodes_profile'
AGG_MEASURE = 'rarity'
NUM_WALKS, PATH_LEN = 50, 2

# Training schedule
PRETRAIN_EPOCHS = 1      # 0 to skip pretrain
PRETRAIN_LR = 1e-3
PRETRAIN_BATCH = 512
FINETUNE_EPOCHS = 30
FINETUNE_LR = 5e-4
TRAIN_PRINT_N = 10       # print training metrics every N steps
VALID_EVERY_N = 10       # passed to Tx.finetune as valid_per_n

In [3]:
# --- Build the TxData object and prepare its split ---
# NOTE(review): 'full_graph' is described in the TxGNN README as a split with
# no held-out test set - confirm this is intended, since a later cell
# evaluates on 'test_set'.
split_options = {'split': 'full_graph', 'seed': SEED}

print('Loading data from:', DATA_DIR)
TxDataObj = TxData(data_folder_path=DATA_DIR)
TxDataObj.prepare_split(**split_options)
print('Data ready.')

Loading data from: ../data/subgraphs/drug-disease-gene
Found local copy...
Found local copy...
Found local copy...
Found saved processed KG... Loading...
Splits detected... Loading splits....
Creating DGL graph....
Done!
Data ready.


In [4]:
# --- Initialize model ---
# Wrap the prepared data in a TxGNN object, with `weight_bias_track` off.
Tx = TxGNN(
    data=TxDataObj,
    weight_bias_track=False,
    proj_name='TxGNN',
    exp_name='subgraph_notebook',
    device=DEVICE,
)

# Architecture settings, all sourced from the Config cell.
model_options = {
    'n_hid': N_HID,
    'n_inp': N_INP,
    'n_out': N_OUT,
    'proto': PROTO,
    'proto_num': PROTO_NUM,
    'attention': ATTN,
    'sim_measure': SIM_MEASURE,
    'agg_measure': AGG_MEASURE,
    'num_walks': NUM_WALKS,
    'path_length': PATH_LEN,
}
Tx.model_initialize(**model_options)
print('Model initialized on', DEVICE)

Model initialized on cpu


In [5]:
# --- Optional pretraining ---
# Runs only when the Config cell requests at least one pretraining epoch.
if PRETRAIN_EPOCHS <= 0:
    print('Skipping pretrain step.')
else:
    print('Starting pretraining...')
    pretrain_options = {
        'n_epoch': PRETRAIN_EPOCHS,
        'learning_rate': PRETRAIN_LR,
        'batch_size': PRETRAIN_BATCH,
        'train_print_per_n': TRAIN_PRINT_N,
    }
    Tx.pretrain(**pretrain_options)

Starting pretraining...
Creating minibatch pretraining dataloader...
Start pre-training with #param: 419200
Epoch: 0 Step: 0 LR: 0.00100 Loss 0.6945, Pretrain Micro AUROC 0.4759 Pretrain Micro AUPRC 0.4867 Pretrain Macro AUROC 0.4447 Pretrain Macro AUPRC 0.5878
Epoch: 0 Step: 10 LR: 0.00100 Loss 0.6871, Pretrain Micro AUROC 0.5748 Pretrain Micro AUPRC 0.5821 Pretrain Macro AUROC 0.5700 Pretrain Macro AUPRC 0.6945
Epoch: 0 Step: 20 LR: 0.00100 Loss 0.6833, Pretrain Micro AUROC 0.5875 Pretrain Micro AUPRC 0.5855 Pretrain Macro AUROC 0.5832 Pretrain Macro AUPRC 0.6786
Epoch: 0 Step: 30 LR: 0.00100 Loss 0.6779, Pretrain Micro AUROC 0.6074 Pretrain Micro AUPRC 0.5946 Pretrain Macro AUROC 0.5940 Pretrain Macro AUPRC 0.7044
Epoch: 0 Step: 40 LR: 0.00100 Loss 0.6733, Pretrain Micro AUROC 0.6039 Pretrain Micro AUPRC 0.5972 Pretrain Macro AUROC 0.5984 Pretrain Macro AUPRC 0.7030


KeyboardInterrupt: 

In [None]:
# --- Finetune on drug–disease prediction ---
# Schedule values come from the Config cell.
# Per the TxGNN README, `save_name` is the path where finetune writes its
# results, not a model checkpoint; weights are persisted separately with
# `Tx.save_model(...)`. The closing message says so, to avoid implying that a
# loadable model exists at SAVE_MODEL_PATH.
print('Starting finetune...')
Tx.finetune(
    n_epoch=FINETUNE_EPOCHS,
    learning_rate=FINETUNE_LR,
    train_print_per_n=TRAIN_PRINT_N,
    valid_per_n=VALID_EVERY_N,
    save_name=SAVE_MODEL_PATH,
)
print('Finetune complete. Finetune results saved to', SAVE_MODEL_PATH)

In [None]:
# --- Evaluate ---
# Disease-centric evaluation of the finetuned model; results are also written
# under SAVE_EVAL_PATH because save_result=True.
# NOTE(review): the data cell uses split='full_graph'; confirm that split
# actually provides a 'test_set' to evaluate on.
print('Evaluating...')
TxE = TxEval(model=Tx)
eval_options = {
    'disease_idxs': 'test_set',
    'show_plot': False,
    'verbose': True,
    'save_result': True,
    'return_raw': False,
    'save_name': SAVE_EVAL_PATH,
}
results = TxE.eval_disease_centric(**eval_options)
print('\nEvaluation summary:', results, sep='\n')

---
### Notes
- You can re-run this notebook for different subgraph folders by changing `DATA_DIR` at the top.
- The printed output and metrics remain visible in the `.ipynb` file after saving, so your professor can see the training/evaluation logs. This only holds for cells that ran to completion, so re-run any interrupted cell before saving.
- For faster testing on CPU, reduce `FINETUNE_EPOCHS` or skip pretraining by setting `PRETRAIN_EPOCHS = 0`.