## TREE Reproduction Test
### This notebook reproduces the TREE model's performance using pre-trained weights.

#### Importing libraries and functions

In [1]:
import os
import numpy as np
import h5py
import tensorflow as tf
from sklearn.model_selection import StratifiedKFold
from config import ModelConfig, PROCESSED_DATA_DIR, ADJ_TEMPLATE, FEATURE_TEMPLATE, SPATIAL_TEMPLATE, SUBGRAPHA_TEMPLATE, H5_FILE
from utils import format_filename
from models import TREE
from main import get_optimizer, seed_tensorflow

# Restrict this process to GPU 0.
# NOTE(review): TensorFlow reads CUDA_VISIBLE_DEVICES when it first initialises
# its GPU devices. This assignment only takes effect if nothing imported above
# has already touched the GPU. Setting it before `import tensorflow` is safer.
os.environ.update({'CUDA_VISIBLE_DEVICES': '0'})
seed_tensorflow(42)  # global seed (42) for repeatable runs; see main.seed_tensorflow

#### Setting up and loading data

In [2]:
# ---- Run configuration ------------------------------------------------------
DATASET = 'STRINGdb'
FOLD_TO_LOAD = 1                 # 1-indexed cross-validation fold to rebuild
MODEL_PATH = "./Data/TREE.h5"    # pre-trained TREE weights

# Architecture hyper-parameters. Presumably these must match the values used
# to train MODEL_PATH, otherwise load_weights will fail or mis-assign tensors.
N_GRAPHS = 3
N_NEIGHBORS = 8
N_LAYERS = 5
SPATIAL_TYPE = 'rw'
EMBED_DIM = 64
NUM_HEADS = 4
DFF = 256
DROPOUT = 0.5
BATCH_SIZE = 64  # NOTE(review): not referenced by any cell visible here

# Cross-validation settings. Keep these identical to training so the same
# train/validation split is reproduced.
N_SPLITS = 10
CV_RANDOM_STATE = 100

# Fail fast on a bad fold number: e.g. 0 would otherwise silently select the
# last fold via negative indexing.
if not 1 <= FOLD_TO_LOAD <= N_SPLITS:
    raise ValueError(f"FOLD_TO_LOAD must be in [1, {N_SPLITS}], got {FOLD_TO_LOAD}")

print(f">>> Loading Data for {DATASET}...")

datapath = H5_FILE[DATASET]
with h5py.File(datapath, 'r') as f:
    y_train = f['y_train'][:]
    y_test = f['y_test'][:]
    # The validation split is optional in the HDF5 file.
    y_val = f['y_val'][:] if 'y_val' in f else None
    train_mask = f['mask_train'][:]
    test_mask = f['mask_test'][:]
    val_mask = f['mask_val'][:] if 'mask_val' in f else None

# Merge train + validation into a single labelled pool that is re-split by
# cross-validation below. When the file has no validation split, the pool is
# the training split alone. Passing None to np.logical_or would produce an
# object array of True/None rather than a boolean mask.
if y_val is None:
    y_train_val = y_train.astype(bool)
else:
    y_train_val = np.logical_or(y_train, y_val)
if val_mask is None:
    mask_train_val = train_mask.astype(bool)
else:
    mask_train_val = np.logical_or(train_mask, val_mask)

label_idx = np.where(mask_train_val == 1)[0]
kf = StratifiedKFold(n_splits=N_SPLITS, shuffle=True, random_state=CV_RANDOM_STATE)
splits = list(kf.split(label_idx, y_train_val[label_idx]))

# split() returns positions into label_idx; map them back to node indices.
train_indices, val_indices = splits[FOLD_TO_LOAD - 1]
train_id = label_idx[train_indices]
val_id = label_idx[val_indices]
test_id = np.where(test_mask == 1)[0]

train_label = y_train_val[train_id]
val_label = y_train_val[val_id]
test_label = y_test[test_id]

print("Data Prepared.")

>>> Loading Data for STRINGdb...
Data Prepared.


#### Initializing the model and loading weights

In [3]:
print(">>> Initializing Model...")

# Fail fast: check that the pre-trained weights exist before loading the
# processed data arrays and building the model, both of which can be slow.
if not os.path.isfile(MODEL_PATH):
    raise FileNotFoundError(f"Weights not found at {MODEL_PATH}")

# NOTE(review): np.load(..., allow_pickle=True) can execute arbitrary code
# embedded in a crafted file. Only point PROCESSED_DATA_DIR at trusted data.
# The adjacency matrix is read from disk once and reused below for both
# max_degree and distance_matrix.
adj_matrix = np.load(format_filename(PROCESSED_DATA_DIR, ADJ_TEMPLATE, dataset=DATASET), allow_pickle=True)

# Config
config = ModelConfig()
config.dataset = DATASET
config.d_model = EMBED_DIM
config.n_layers = N_LAYERS
config.concat_n_layers = N_LAYERS
config.n_graphs = N_GRAPHS
config.n_neighbors = N_NEIGHBORS
config.num_heads = NUM_HEADS
config.dff = DFF
config.d_sp_enc = DFF
config.dropout = DROPOUT
# Largest row sum of the adjacency matrix, plus one. This assumes a dense
# (N, N) array whose row sums are node degrees (TODO confirm).
config.max_degree = int(np.sum(adj_matrix, axis=-1).max() + 1)
# Optimizer and loss weight: presumably only needed so TREE can build/compile
# the model. No training happens in this notebook (training=False).
config.optimizer = get_optimizer('adam', 0.004)
config.loss_mul = 0.24
config.training = False

# NOTE(review): the adjacency matrix is passed as distance_matrix; confirm
# this is the input TREE expects.
config.distance_matrix = adj_matrix
config.node_feature = np.load(format_filename(PROCESSED_DATA_DIR, FEATURE_TEMPLATE, dataset=DATASET), allow_pickle=True)
# NOTE(review): the neighbour-sampling strategy is hard-coded to 'rw' and does
# not follow SPATIAL_TYPE; confirm whether the two are meant to be independent.
config.node_neighbor = np.load(format_filename(PROCESSED_DATA_DIR, SUBGRAPHA_TEMPLATE, dataset=DATASET, strategy='rw', n_channel=N_GRAPHS, n_neighbor=N_NEIGHBORS), allow_pickle=True)
config.spatial_matrix = np.load(format_filename(PROCESSED_DATA_DIR, SPATIAL_TEMPLATE, dataset=DATASET, strategy=SPATIAL_TYPE, n_channel=N_GRAPHS, n_neighbor=N_NEIGHBORS), allow_pickle=True)

model = TREE(config)
model.model.load_weights(MODEL_PATH)
print(f"Successfully loaded weights from {MODEL_PATH}")

>>> Initializing Model...
Successfully loaded weights from ./Data/TREE.h5


#### Inference and Performance Evaluation

In [4]:
print(">>> Running Inference...")

# Evaluate the held-out test nodes. All eight returned values stay bound at
# notebook scope; only the four headline metrics are printed below.
auc, acc, spe, sen, f1, aupr, fpr, tpr = model.score(x=test_id, y=test_label)

# (padded label, value) rows of the results table, in display order.
report_rows = (
    ("AUROC    ", auc),
    ("AUPRC    ", aupr),
    ("F1-Score ", f1),
    ("Accuracy ", acc),
)
separator = "=" * 40

print("\n" + separator)
print("TREE Reproduction Results")
print(separator)
for label, value in report_rows:
    print(f"{label}: {value:.4f}")
print(separator)

>>> Running Inference...

TREE Reproduction Results
AUROC    : 0.8124
AUPRC    : 0.7010
F1-Score : 0.6708
Accuracy : 0.7282
