## FairGNN Experiments on the NBA dataset
This notebook can be used to run FairGNN experiments on the NBA dataset.

It is currently configured to perform a full training run for each of three seeds (40, 41 and 42).

In [1]:
import sys
import os
import torch
sys.path.append(os.path.abspath("../"))
from base_experiment import ExperimentRunner

### Set up the experiment runner
First we create an experiment runner, which sets the random seeds and provides the parameters and logging directory for each run.

In [2]:
experiment = ExperimentRunner(
    experiment_name = "nba_fair_gnn_main",
    seeds = [40, 41, 42],
    data_path = "dataset/NBA",
    log_dir="experiments/fair_gnn/logs/nba", 
    device=0,
    params=[{"alpha": 10, "beta": 1}] # alpha is set to 10 for NBA
)

# after we set up the experiment, we can import the rest
from dataset import NBA
from models.gnn import WrappedGNNConfig
from models.fair.gnn import FairGNN, Trainer

### Run the experiments
Now we're ready to run the experiments!

We can do this by iterating over the `ExperimentRunner.runs()` method. This method returns a generator that yields the seed, logging directory, device and params for the current experiment run.
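To make the shape of these tuples concrete, here is a minimal sketch of a generator with the same outputs. It is not the actual `ExperimentRunner` implementation: `iter_runs` is a hypothetical stand-in, the directory naming is inferred from the log directories printed in the output of this notebook, and only Python's built-in `random` module is seeded here (the real runner presumably also seeds `torch` and `numpy`).

```python
import random
from pathlib import Path
from typing import Dict, Iterator, List, Tuple


def iter_runs(
    experiment_name: str,
    seeds: List[int],
    log_dir: Path,
    device: int,
    params: List[Dict[str, float]],
) -> Iterator[Tuple[int, Path, int, Dict[str, float]]]:
    """Hypothetical stand-in for ExperimentRunner.runs()."""
    for run_params in params:
        # Build a suffix such as "alpha_10_beta_1" from the parameter dict.
        suffix = "_".join(f"{key}_{value}" for key, value in run_params.items())
        for seed in seeds:
            # Assumption: the real runner seeds every RNG it uses at this point.
            random.seed(seed)
            run_dir = log_dir / f"{experiment_name}_{seed}_{suffix}"
            yield seed, run_dir, device, run_params


runs = list(
    iter_runs(
        experiment_name="nba_fair_gnn_main",
        seeds=[40, 41, 42],
        log_dir=Path("experiments/fair_gnn/logs/nba"),
        device=0,
        params=[{"alpha": 10, "beta": 1}],
    )
)
for seed, run_dir, device, run_params in runs:
    print(seed, run_dir.name, device, run_params)
```

Yielding one tuple per (params, seed) pair keeps the loop body in the next cell free of any bookkeeping about seeds or directory names.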


For each experiment run we first:
1. Load in the dataset
2. Create the FairGNN model instance
3. Create the FairGNN trainer


Once everything is initialised, we can run the training using `Trainer.train()`. This trains the full FairGNN model for the specified number of epochs (3000 here).

In [3]:
for (seed, log_dir, device, params) in experiment.runs():
    print("===========================")
    print(f"Running {experiment.experiment_name} using seed {seed}")
    print(f"Log directory: {log_dir}")
    print(f"Params: {params}")
    print("===========================")
    
    # Load in the dataset
    dataset = NBA(
        nodes_path=experiment.data_path / "nba.csv",
        edges_path=experiment.data_path / "nba_relationship.txt",
        embedding_path=experiment.data_path / "nba_embedding10.npy",
        feat_drop_rate=0.3,
        device=experiment.device,
    )
    
    print(f"Loaded dataset with {dataset.graph.num_nodes()} nodes and {dataset.graph.num_edges()} edges")
    print(f"Using feat_drop_rate: {dataset.feat_drop_rate}")
    
    # Create FairGNN model
    fair_gnn = FairGNN(
        num_features=dataset.features.shape[1],
        device = device,
    ).to(device)
    
    # load pre-trained estimator
    fair_gnn.load_estimator(experiment.data_path / "GCN_sens_nba_ns_50")
    print(f"Created FairGNN model with {1} sensitive class")
    
    # Create fair gnn trainer
    trainer = Trainer(
        dataset=dataset,
        fair_gnn=fair_gnn,
        device=device,
        log_dir=log_dir,
        alpha=params["alpha"],
        beta=params["beta"],
        min_acc=0.65,
        min_roc=0.69,
    )
    
    print(f"Created trainer with {'GCN'} model, using LOG_DIR: {log_dir}")

    print("Starting training phase")
    # Train the model
    trainer.train(epochs=3000)


Running nba_fair_gnn_main using seed 40
Log directory: /home/fact21/fact_refactor/experiments/fair_gnn/logs/nba/nba_fair_gnn_main_40_alpha_10_beta_1
Params: {'alpha': 10, 'beta': 1}
Loaded dataset with 403 nodes and 21645 edges
Using feat_drop_rate: 0.3
Created FairGNN model with 1 sensitive class
Created trainer with GCN model, using LOG_DIR: /home/fact21/fact_refactor/experiments/fair_gnn/logs/nba/nba_fair_gnn_main_40_alpha_10_beta_1
Starting training phase



Finished training!
Best fair model:
	acc: 0.6526
	roc: 0.7332
	parity: 0.0124
	equality: 0.0399
	consistency: 0.0264
Running nba_fair_gnn_main using seed 41
Log directory: /home/fact21/fact_refactor/experiments/fair_gnn/logs/nba/nba_fair_gnn_main_41_alpha_10_beta_1
Params: {'alpha': 10, 'beta': 1}
Loaded dataset with 403 nodes and 21645 edges
Using feat_drop_rate: 0.3
Created FairGNN model with 1 sensitive class
Created trainer with GCN model, using LOG_DIR: /home/fact21/fact_refactor/experiments/fair_gnn/logs/nba/nba_fair_gnn_main_41_alpha_10_beta_1
Starting training phase



Finished training!
Best fair model:
	acc: 0.7230
	roc: 0.7882
	parity: 0.0005
	equality: 0.0024
	consistency: 0.0264
Running nba_fair_gnn_main using seed 42
Log directory: /home/fact21/fact_refactor/experiments/fair_gnn/logs/nba/nba_fair_gnn_main_42_alpha_10_beta_1
Params: {'alpha': 10, 'beta': 1}
Loaded dataset with 403 nodes and 21645 edges
Using feat_drop_rate: 0.3
Created FairGNN model with 1 sensitive class
Created trainer with GCN model, using LOG_DIR: /home/fact21/fact_refactor/experiments/fair_gnn/logs/nba/nba_fair_gnn_main_42_alpha_10_beta_1
Starting training phase



Finished training!
Best fair model:
	acc: 0.6667
	roc: 0.7694
	parity: 0.0045
	equality: 0.0019
	consistency: 0.0264
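
Results over several seeds are usually reported as a mean and standard deviation. The sketch below does this for the three runs above. The numbers are copied by hand from the printed output; the `results` dictionary and its layout are assumptions made for illustration, not something the `Trainer` returns.

```python
from statistics import mean, stdev

# Metrics copied from the "Best fair model" output of each seed above.
results = {
    40: {"acc": 0.6526, "roc": 0.7332, "parity": 0.0124, "equality": 0.0399, "consistency": 0.0264},
    41: {"acc": 0.7230, "roc": 0.7882, "parity": 0.0005, "equality": 0.0024, "consistency": 0.0264},
    42: {"acc": 0.6667, "roc": 0.7694, "parity": 0.0045, "equality": 0.0019, "consistency": 0.0264},
}

summary = {}
for metric in ("acc", "roc", "parity", "equality", "consistency"):
    values = [run[metric] for run in results.values()]
    # stdev() is the sample standard deviation (n - 1 in the denominator).
    summary[metric] = (mean(values), stdev(values))

for metric, (avg, std) in summary.items():
    print(f"{metric}: {avg:.4f} +/- {std:.4f}")
```

With only three seeds the standard deviation is a rough indication of run-to-run variation rather than a precise estimate.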
