In [1]:
import sys
import os
import torch
sys.path.append(os.path.abspath("../"))
from base_experiment import ExperimentRunner

In [2]:
# Experiment runner configuration: three seeds, one hyper-parameter setting.
# alpha is set to 10 for NBA.
# NOTE(review): device=3 presumably selects CUDA device index 3 — confirm
# against ExperimentRunner before running on a different machine.
experiment = ExperimentRunner(
    experiment_name="nba_fair_gnn_main",
    seeds=[40, 41, 42],
    data_path="dataset/NBA",
    log_dir="experiments/fair_gnn/logs/nba",
    device=3,
    params=[{"alpha": 10}],
)

# after we set up the experiment, we can import the rest
from dataset import NBA
from models.gnn import WrappedGNNConfig
from models.fair.gnn import FairGNN, Trainer

In [3]:
# Load the NBA graph from the experiment's data directory: node table,
# relationship (edge) list and the precomputed embedding file, placed on the
# experiment's device. Features are dropped at rate 0.3.
dataset = NBA(
    nodes_path=experiment.data_path / "nba.csv",
    edges_path=experiment.data_path / "nba_relationship.txt",
    embedding_path=experiment.data_path / "nba_embedding10.npy",
    feat_drop_rate=0.3,
    device=experiment.device,
)

# Quick sanity summary of what was loaded.
num_nodes, num_edges = dataset.graph.num_nodes(), dataset.graph.num_edges()
print(f"Loaded dataset with {num_nodes} nodes and {num_edges} edges")
print(f"Using feat_drop_rate: {dataset.feat_drop_rate}")

Loaded dataset with 403 nodes and 21645 edges
Using feat_drop_rate: 0.3


In [None]:
# Train one FairGNN per (seed, log_dir, device, params) combination yielded
# by the experiment runner.
for seed, log_dir, device, params in experiment.runs():
    print("===========================")
    print(f"Running {experiment.experiment_name} using seed {seed}")
    print(f"Log directory: {log_dir}")
    print(f"Params: {params}")
    print("===========================")

    # Apply the run's seed explicitly so each run is reproducible from this
    # cell alone (previously `seed` was only printed). torch.manual_seed seeds
    # torch's generators on all devices.
    # NOTE(review): ExperimentRunner.runs() may already seed internally; if so,
    # re-seeding torch with the same value here is harmless. numpy / random
    # are not seeded here — TODO confirm whether the runner covers them.
    torch.manual_seed(seed)

    # Create FairGNN model; alpha comes from the current run's params.
    fair_gnn = FairGNN(
        num_features=dataset.features.shape[1],
        alpha=params["alpha"],
        device=device,
    ).to(device)

    # Load the pre-trained estimator shipped in the dataset directory.
    fair_gnn.load_estimator(experiment.data_path / "GCN_sens_nba_ns_50")
    print("Created FairGNN model with 1 sensitive class")

    # Create fair gnn trainer.
    # NOTE(review): min_acc / min_roc look like the minimum accuracy / ROC-AUC
    # an epoch must reach before its metrics are recorded (the logged
    # "updated to" lines all satisfy them) — confirm in Trainer.
    trainer = Trainer(
        dataset=dataset,
        fair_gnn=fair_gnn,
        device=device,
        log_dir=log_dir,
        min_acc=0.65,
        min_roc=0.69,
    )

    print(f"Created trainer with GCN model, using LOG_DIR: {log_dir}")

    print("Starting training phase")
    # Train the model
    trainer.train(epochs=3000)


Running nba_fair_gnn_main using seed 40
Log directory: /home/fact21/fact_refactor/experiments/fair_gnn/logs/nba/nba_fair_gnn_main_40_alpha_10
Params: {'alpha': 10}
Created FairGNN model with 1 sensitive class
Created trainer with GCN model, using LOG_DIR: /home/fact21/fact_refactor/experiments/fair_gnn/logs/nba/nba_fair_gnn_main_40_alpha_10
Starting training phase


  0%|          | 0/3000 [00:00<?, ?it/s]

[505] updated to: Metrics(epoch=505, acc=0.6572769953051644, roc=0.6994708994708995, parity=0.2304948216340621, equality=0.2852969814995131, consistency=None)
[526] updated to: Metrics(epoch=526, acc=0.6525821596244131, roc=0.7044091710758377, parity=0.22658227848101264, equality=0.24634858812074, consistency=None)
[541] updated to: Metrics(epoch=541, acc=0.6713615023474179, roc=0.7088183421516756, parity=0.22658227848101264, equality=0.2205452775073029, consistency=None)
[547] updated to: Metrics(epoch=547, acc=0.6525821596244131, roc=0.7069664902998237, parity=0.07249712313003454, equality=0.10418695228821806, consistency=None)
[971] updated to: Metrics(epoch=971, acc=0.6525821596244131, roc=0.7627865961199295, parity=0.07410817031070194, equality=0.052580331061343744, consistency=None)
[981] updated to: Metrics(epoch=981, acc=0.6525821596244131, roc=0.7622574955908289, parity=0.001380897583429186, equality=0.02434274586173324, consistency=None)
[2106] updated to: Metrics(epoch=2106,

  0%|          | 0/3000 [00:00<?, ?it/s]

[429] updated to: Metrics(epoch=429, acc=0.6525821596244131, roc=0.696031746031746, parity=0.1285385500575374, equality=0.14410905550146058, consistency=None)
[585] updated to: Metrics(epoch=585, acc=0.6619718309859155, roc=0.7224867724867725, parity=0.10402761795166854, equality=0.10564751703992215, consistency=None)
[587] updated to: Metrics(epoch=587, acc=0.6525821596244131, roc=0.7194003527336863, parity=0.08354430379746836, equality=0.06621226874391428, consistency=None)
[636] updated to: Metrics(epoch=636, acc=0.676056338028169, roc=0.7283950617283952, parity=0.08826237054085162, equality=0.05355404089581306, consistency=None)
[654] updated to: Metrics(epoch=654, acc=0.6713615023474179, roc=0.7323633156966491, parity=0.08193325661680095, equality=0.04089581304771173, consistency=None)
[1207] updated to: Metrics(epoch=1207, acc=0.6525821596244131, roc=0.7728395061728395, parity=0.03693901035673186, equality=0.011197663096397248, consistency=None)
[1287] updated to: Metrics(epoch=1