In [1]:
# --- Experiment setup ---------------------------------------------------------
# Make the parent directory importable so that `base_experiment` can be found.
# NOTE(review): "../" is resolved against the kernel's current working
# directory, so this assumes the notebook is started from its own folder.
import sys
import os
sys.path.append(os.path.abspath("../"))
import base_experiment

# setup_experiment returns the project root, the dataset directory, the log
# directory and the compute device (device=2 is reported as cuda:2 in this
# cell's output). The seed is handed to it as well; presumably it seeds the
# RNGs used for training — confirm in base_experiment.
ROOT_DIR, DATA_PATH, LOG_DIR, DEVICE = base_experiment.setup_experiment(
    seed=41, 
    data_path="dataset/NBA", 
    log_dir="experiments/fair_gnn/logs/nba", 
    device=2
)

# These imports must stay AFTER setup_experiment: it sets up the path the
# project modules (models.*, dataset) are imported from. Do not hoist them.
import torch
import numpy as np  # NOTE(review): np is not used in the cells visible here
from models.fair.gnn import FairGNN, FairGNNTrainer
from dataset import NBA

Using directories:
root_dir: /home/fact21/fact_refactor
data_dir: /home/fact21/fact_refactor/dataset/NBA
log_dir: /home/fact21/fact_refactor/experiments/fair_gnn/logs/nba
device: cuda:2


In [2]:
# Load the NBA graph. The node table, edge list and embedding file all live
# under DATA_PATH; 30% feature dropout is configured on the dataset itself.
dataset = NBA(
    device=DEVICE,
    feat_drop_rate=0.3,
    nodes_path=DATA_PATH / "nba.csv",
    edges_path=DATA_PATH / "nba_relationship.txt",
    embedding_path=DATA_PATH / "nba_embedding10.npy",
)

num_nodes, num_edges = dataset.graph.num_nodes(), dataset.graph.num_edges()
print(f"Loaded dataset with {num_nodes} nodes and {num_edges} edges")
print(f"Using feat_drop_rate: {dataset.feat_drop_rate}")

Loaded dataset with 403 nodes and 21645 edges
Using feat_drop_rate: 0.3


In [3]:
# Create the FairGNN model on the experiment device.
fair_gnn = FairGNN(
    num_features=dataset.features.shape[1],
    alpha=10,  # for nba, alpha is set to 10
).to(DEVICE)

# Load the pre-trained estimator weights into fair_gnn.estimator.
# map_location=DEVICE remaps the stored tensors onto the device used here, so
# the checkpoint still loads when it was saved from a different GPU (or when
# that GPU is not present on this machine).
estimator_path = ROOT_DIR / "src/checkpoint/GCN_sens_nba_ns_50"
estimator_state = torch.load(str(estimator_path.resolve()), map_location=DEVICE)
fair_gnn.estimator.load_state_dict(estimator_state)

print("Created FairGNN model with 1 sensitive class")

Created FairGNN model with 1 sensitive class


In [4]:
# Wire the dataset and model into a FairGNN trainer that logs to LOG_DIR.
# NOTE(review): min_acc / min_roc look like the accuracy / ROC-AUC thresholds a
# model must reach before the trainer considers it — confirm in FairGNNTrainer.
trainer = FairGNNTrainer(
    dataset=dataset,
    fair_gnn=fair_gnn,
    device=DEVICE,
    log_dir=LOG_DIR,
    min_acc=0.65,
    min_roc=0.69,
)

print(f"Created trainer with GCN model, using LOG_DIR: {LOG_DIR}")

Created trainer with GCN model, using LOG_DIR: /home/fact21/fact_refactor/experiments/fair_gnn/logs/nba


In [5]:
# Run training for a fixed epoch budget; the trainer reports its own progress
# and the best model it found.
EPOCHS = 3000
trainer.train(epochs=EPOCHS)

  0%|          | 0/3000 [00:00<?, ?it/s]

[446] updated to: Metrics(acc=0.6835443037974683, roc=0.7913992297817716, parity=0.13095238095238093, equality=0.15555555555555556)
[488] updated to: Metrics(acc=0.6708860759493671, roc=0.8010269576379975, parity=0.1468253968253968, equality=0.15555555555555556)
[590] updated to: Metrics(acc=0.6835443037974683, roc=0.7991014120667522, parity=0.10515873015873012, equality=0.0888888888888888)
[2544] updated to: Metrics(acc=0.7468354430379747, roc=0.8369704749679076, parity=0.18253968253968256, equality=0.0888888888888888)
[2550] updated to: Metrics(acc=0.7468354430379747, roc=0.8376123234916559, parity=0.18253968253968256, equality=0.0888888888888888)
Finished training!

Best fair model:
	acc: 0.7468
	roc: 0.8376
	parity: 0.1825
	equality: 0.0889
