In [None]:
import sys
import os
import torch

# Make the parent directory importable so `base_experiment` resolves.
# os.path.abspath resolves "../" against the kernel's current working
# directory, so this depends on where the kernel was started.
# The membership check keeps the cell idempotent: re-running it does not
# append duplicate entries to sys.path.
_PROJECT_ROOT = os.path.abspath("../")
if _PROJECT_ROOT not in sys.path:
    sys.path.append(_PROJECT_ROOT)

from base_experiment import ExperimentRunner

In [None]:
# Experiment configuration: three seeds and a single hyper-parameter setting.
# FairGNN's alpha is set to 10 for NBA.
nba_params = [{"alpha": 10}]

experiment = ExperimentRunner(
    experiment_name="nba_fair_gnn_main",
    seeds=[40, 41, 42],
    data_path="dataset/NBA",
    log_dir="experiments/fair_gnn/logs/nba",
    # NOTE(review): hard-coded device index; confirm it exists on the target machine.
    device=3,
    params=nba_params,
)

# These imports deliberately come only after the experiment has been set up.
# Presumably ExperimentRunner prepares the environment (e.g. device selection)
# that these modules rely on (TODO confirm), so keep them below it.
from dataset import NBA
from models.gnn import WrappedGNNConfig
from models.fair.gnn import FairGNN, Trainer

In [None]:
# Load the NBA graph. All three input files live under the experiment's data directory.
nba_dir = experiment.data_path
dataset = NBA(
    nodes_path=nba_dir / "nba.csv",
    edges_path=nba_dir / "nba_relationship.txt",
    embedding_path=nba_dir / "nba_embedding10.npy",
    feat_drop_rate=0.3,
    device=experiment.device,
)

# Quick sanity summary of what was loaded.
nba_graph = dataset.graph
print(f"Loaded dataset with {nba_graph.num_nodes()} nodes and {nba_graph.num_edges()} edges")
print(f"Using feat_drop_rate: {dataset.feat_drop_rate}")

In [None]:
# Training budget and thresholds forwarded to Trainer for every run.
# Presumably min_acc / min_roc gate which results Trainer treats as acceptable;
# TODO confirm against models.fair.gnn.Trainer.
EPOCHS = 3000
MIN_ACC = 0.65
MIN_ROC = 0.69
# Pre-trained estimator file under the experiment's data directory. The name
# suggests a GCN estimator of the sensitive attribute (TODO confirm).
ESTIMATOR_NAME = "GCN_sens_nba_ns_50"

# One training run per (seed, log_dir, device, params) tuple yielded by the runner.
# NOTE(review): `dataset` is built once, outside this loop, so any randomness in
# its feature dropping is not re-drawn per seed. Confirm that is intended.
for (seed, log_dir, device, params) in experiment.runs():
    print("===========================")
    print(f"Running {experiment.experiment_name} using seed {seed}")
    print(f"Log directory: {log_dir}")
    print(f"Params: {params}")
    print("===========================")

    # A fresh FairGNN is built each iteration, so no weights carry over between runs.
    fair_gnn = FairGNN(
        num_features=dataset.features.shape[1],
        alpha=params["alpha"],
        device=device,
    ).to(device)

    # Load the pre-trained estimator.
    fair_gnn.load_estimator(experiment.data_path / ESTIMATOR_NAME)
    print("Created FairGNN model with 1 sensitive class")

    trainer = Trainer(
        dataset=dataset,
        fair_gnn=fair_gnn,
        device=device,
        log_dir=log_dir,
        min_acc=MIN_ACC,
        min_roc=MIN_ROC,
    )
    # NOTE(review): "GCN" is a fixed label, not read from fair_gnn. Confirm it
    # matches the backbone actually used.
    print(f"Created trainer with GCN model, using LOG_DIR: {log_dir}")

    print("Starting training phase")
    trainer.train(epochs=EPOCHS)
