In [None]:
import optuna
import torch
import numpy as np
import random

from args import Args
from data_processing import load_data_and_process
from train_and_test import train_gnn_cv

In [None]:
SEED = 42  # single source for every RNG seed set in this cell

# Seed all RNGs touched by training: Python, NumPy, PyTorch (CPU and all GPUs).
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)
torch.cuda.manual_seed_all(SEED)
# Trade cuDNN autotuning speed for run-to-run determinism.
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False

device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

# load_data_and_process() creates an Args() internally and returns it (args_base),
# together with the dev/test split and the phenotypic affinity matrix.
(
    args_base,
    dev_subject_IDs,
    test_subject_IDs,
    dev_features,
    test_features,
    dev_y_true,
    test_y_true,
    pheno_affinity_matrix,
) = load_data_and_process()

In [None]:
N_FOLDS = 5  # number of cross-validation folds on the development set


def create_trial_args(base_args, trial):
    """
    Build a fresh Args object for a single Optuna trial.

    The non-tuned settings are copied from ``base_args`` so every trial matches
    the main training script. Optional settings are copied only if present on
    ``base_args``. The tuned hyperparameters are then sampled from ``trial``.

    Parameters
    ----------
    base_args : Args
        Reference settings (as returned by ``load_data_and_process()``).
    trial : optuna.trial.Trial
        Trial used to sample the hyperparameters.

    Returns
    -------
    Args
        New Args instance with copied settings and sampled hyperparameters.

    Raises
    ------
    AttributeError
        If ``base_args`` lacks one of the required (non-optional) settings.
    """
    trial_args = Args()

    # (attribute name, optional?) -- optional ones are copied only when base_args
    # defines them. Order mirrors the main script.
    inherited_settings = (
        ("split_mode", False),
        ("test_percentage", True),
        ("use_pheno_data", False),
        ("model", False),
        ("lg", False),
        ("patience", False),
        ("num_classes", False),
        ("use_batching", False),
        ("batch_size", True),
        ("num_neighbors", True),
        ("activation", False),
        ("prelu_unit", True),
        ("cls_hidden", False),
        ("use_combat", False),
        ("hiddenU", False),
    )
    for attr_name, is_optional in inherited_settings:
        if is_optional and not hasattr(base_args, attr_name):
            continue
        setattr(trial_args, attr_name, getattr(base_args, attr_name))

    trial_args.ckpt_path = None  # train_gnn_cv checks `if args.ckpt_path`

    # --- Search space (suggest order kept fixed) ---
    trial_args.dropout = trial.suggest_float("dropout", 0.02, 0.15)
    trial_args.edge_dropout = trial.suggest_float("edge_dropout", 0.05, 0.25)
    # Log-uniform search for regularisation strength and learning rate.
    trial_args.weight_decay = trial.suggest_float("weight_decay", 1e-5, 1e-2, log=True)
    trial_args.lr = trial.suggest_float("lr", 5e-5, 1e-2, log=True)
    # Number of selected SC features per subject (categorical choice).
    # NOTE(review): dev_features and pheno_affinity_matrix are built once by
    # load_data_and_process(); confirm train_gnn_cv re-applies node_ftr_dim and
    # affinity_threshold per trial, otherwise these two have no effect.
    trial_args.node_ftr_dim = trial.suggest_categorical("node_ftr_dim", [200, 300, 400, 500])
    trial_args.affinity_threshold = trial.suggest_float("affinity_threshold", 0.5, 0.7)

    trial_args.epochs = 300
    return trial_args

In [None]:
def objective(trial):
    """
    Optuna objective: run k-fold CV on the development set for one trial.

    Builds trial-specific Args via ``create_trial_args``, trains with
    ``train_gnn_cv`` and records summary metrics, the results directory and the
    per-fold results as trial user attributes.

    Parameters
    ----------
    trial : optuna.trial.Trial
        The trial being evaluated.

    Returns
    -------
    Value of ``mean_results["auc"][0]`` (the mean validation AUC), which the
    study maximises.
    """
    trial_args = create_trial_args(args_base, trial)

    # Run k-fold CV on the development set; the trained model itself is not needed here.
    _model, fold_results, mean_results, results_dir = train_gnn_cv(
        trial_args,
        dev_features,
        dev_y_true,
        pheno_affinity_matrix,
        n_folds=N_FOLDS,
        device=device,
    )

    # Objective: element 0 of each mean_results entry is used as the mean
    # (presumably [mean, std, ...] -- confirm against train_gnn_cv).
    val_auc_mean = mean_results["auc"][0]

    # user-attr name -> key in mean_results
    summary_metrics = {
        "train_loss": "train_loss",
        "val_loss": "val_loss",
        "train_acc": "train_accuracy",
        "val_acc": "val_accuracy",
        "precision": "precision",
        "recall": "recall",
        "f1": "f1_score",
        "specificity": "specificity",
        "npv": "npv",
    }
    for attr_name, result_key in summary_metrics.items():
        trial.set_user_attr(attr_name, mean_results[result_key][0])

    # Set once (previously duplicated). NOTE(review): with an RDB storage backend
    # user attrs must be JSON-serialisable -- verify fold_results is.
    trial.set_user_attr("results_dir", results_dir)
    trial.set_user_attr("fold_results", fold_results)

    return val_auc_mean

In [None]:
def print_callback(study, trial):
    """
    Optuna callback that prints a short summary after each finished trial.

    Parameters
    ----------
    study : optuna.study.Study
        Unused; required by Optuna's callback signature.
    trial : optuna.trial.FrozenTrial
        The trial that just finished.
    """
    print(f"Trial {trial.number} finished.")
    if trial.value is None:
        # Pruned trials and trials whose objective returned NaN (marked FAIL
        # without re-raising) reach callbacks with value None; formatting None
        # with :.4f would raise TypeError and abort the whole study.
        print(f"  State: {trial.state.name} (no value)")
    else:
        print(f"  Value (Val AUC): {trial.value:.4f}")
    print("-" * 40)

N_TRIALS = 200
SAMPLER_SEED = 42  # same value as the global seeds set at the top of the notebook

study = optuna.create_study(
    direction="maximize",
    study_name="gnn_abide_val_auc",
    # Optuna samplers keep their own RNG, independent of the global
    # numpy/torch/random seeds; seed the (default) TPE sampler explicitly so the
    # sequence of sampled hyperparameters is reproducible.
    sampler=optuna.samplers.TPESampler(seed=SAMPLER_SEED),
)

study.optimize(
    objective,
    n_trials=N_TRIALS,
    callbacks=[print_callback],
)

print("Best value (mean val AUC):", study.best_value)
print("Best params:", study.best_trial.params)

In [None]:
import pandas as pd

# All trials ranked by objective value (mean val AUC), best first; NaN values sort last.
df = study.trials_dataframe()
df_sorted = df.sort_values(by="value", ascending=False)

# Top-10 hyperparameter combinations
df_top10 = df_sorted.head(n=10)
df_top10

In [None]:
# Wider view: the 50 best-ranked trials.
df_sorted.iloc[:50]

In [None]:
from pathlib import Path

# NOTE(review): absolute, machine-specific path; adjust RESULTS_DIR when running elsewhere.
RESULTS_DIR = Path("C:/Users/20202932/8STAGE/Code/8STAGE_Internship/GNN_model/Results")

# Write the full ranked trials table. The filename now carries a .csv extension
# (the original had none) and reflects the number of trials actually recorded.
df_sorted.to_csv(RESULTS_DIR / f"optuna_{len(df_sorted)}trials.csv", index=False)