In [9]:
import os
import sys

# Put the parent directory on the import path so that `base_experiment` resolves.
sys.path.append(os.path.abspath("../"))
import base_experiment

# setup_experiment hands back the project root, the data directory, the log
# directory and the compute device, and prints them (see the output below).
# `device=2` ends up as cuda:2 according to that output.
# NOTE(review): `seed` presumably seeds the RNGs used by the experiment — confirm
# in base_experiment.setup_experiment.
ROOT_DIR, DATA_PATH, LOG_DIR, DEVICE = base_experiment.setup_experiment(
    seed=20,
    data_path="dataset/NBA",
    log_dir="experiments/fair_gnn/logs/nba",
    device=2,
)

# These imports must run only after setup_experiment, because it sets up the
# path from which the project-local `models` and `dataset` packages are imported.
import torch
import numpy as np
from models.fair import FairGNN, FairGNNTrainer
from dataset import NBA

Using directories:
root_dir: /home/fact21/fact_refactor
data_dir: /home/fact21/fact_refactor/dataset/NBA
log_dir: /home/fact21/fact_refactor/experiments/fair_gnn/logs/nba
device: cuda:2


In [10]:
# Load the NBA graph dataset from the files under DATA_PATH.
# NOTE(review): feat_drop_rate=0.3 presumably drops a fraction of node features —
# confirm its exact semantics in dataset.NBA.
nodes_file = DATA_PATH / "nba.csv"
edges_file = DATA_PATH / "nba_relationship.txt"
embedding_file = DATA_PATH / "nba_embedding10.npy"

dataset = NBA(
    nodes_path=nodes_file,
    edges_path=edges_file,
    embedding_path=embedding_file,
    feat_drop_rate=0.3,
    device=DEVICE,
)

# Quick sanity summary of what was loaded.
print(f"Loaded dataset with {dataset.graph.num_nodes()} nodes and {dataset.graph.num_edges()} edges")
print(f"Using feat_drop_rate: {dataset.feat_drop_rate}")

Loaded dataset with 403 nodes and 21645 edges
Using feat_drop_rate: 0.3


In [18]:
# Create the FairGNN model; its input width is the node-feature dimension of the dataset.
fair_gnn = FairGNN(
    num_features=dataset.features.shape[1],
).to(DEVICE)

# Load the pre-trained estimator weights into the model's `estimator` sub-module.
# NOTE(review): the checkpoint name (`GCN_sens_...`) suggests a GCN that predicts the
# sensitive attribute — confirm against the script that produced it.
# map_location remaps tensors saved on another device (e.g. a different GPU index)
# straight onto DEVICE instead of first materialising them on their original device.
estimator_path = ROOT_DIR / "src/checkpoint/GCN_sens_nba_ns_50"
estimator_state = torch.load(str(estimator_path.resolve()), map_location=DEVICE)
fair_gnn.estimator.load_state_dict(estimator_state)

print("Created FairGNN model with 1 sensitive class")

Created FairGNN model with 1 sensitive class


In [19]:
# Hyper-parameters passed to FairGNNTrainer.
# NOTE(review): MIN_ACC / MIN_ROC look like thresholds a checkpoint must reach to be
# kept as the "best fair model", and ALPHA a loss weight — confirm in FairGNNTrainer.
ALPHA = 100
MIN_ACC = 0.65
MIN_ROC = 0.69

trainer = FairGNNTrainer(
    dataset=dataset,
    fair_gnn=fair_gnn,
    device=DEVICE,
    log_dir=LOG_DIR,
    alpha=ALPHA,
    min_acc=MIN_ACC,
    min_roc=MIN_ROC,
)

print(f"Created trainer with GCN model, using LOG_DIR: {LOG_DIR}")

Created trainer with GCN model, using LOG_DIR: /home/fact21/fact_refactor/experiments/fair_gnn/logs/nba


In [21]:
# Train for a fixed number of epochs; the trainer prints a progress bar and the
# best fair / acc / auc model metrics when it finishes (see output below).
# NOTE(review): the trainer's progress label prints "Partity" instead of "Parity" —
# fix the string in FairGNNTrainer.
N_EPOCHS = 3000
trainer.train(epochs=N_EPOCHS)

Epoch 2999: 100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████| 3000/3000 [01:05<00:00, 45.75it/s, Acc: 0.7468, Roc: 0.8543, Partity: 0.1954, Equality: 0.1222]

Finished training!

Best fair model:
	acc: 0.7342
	roc: 0.8273
	parity: 0.2917
	equality: 0.2611

Best acc model:
	acc: 0.7342
	roc: 0.8209
	parity: 0.2917
	equality: 0.2611

Best auc model:
	acc: 0.7342
	roc: 0.8273
	parity: 0.2917
	equality: 0.2611



