In [1]:
import os
import random
import sys
from pathlib import Path

import numpy as np
import torch

from models.fair import FairGNNTrainer, FairGNN
from dataset import NBA

def setup_experiment(seed: int, data_path: str, log_dir: str, device: int = 0):
    """Prepare paths, the compute device and RNG state for an experiment.

    The project root is taken to be the parent of the current working
    directory, so the notebook must be started from a directory one level
    below the root. ``<root>/src`` is added to ``sys.path`` (only once, even
    if this function is called repeatedly), ``log_dir`` is created when it is
    missing, and the Python, NumPy and PyTorch random generators are seeded.

    NOTE(review): project modules imported at the top of the notebook are
    resolved *before* this function runs, so on a fresh kernel they cannot
    rely on the ``sys.path`` entry added here -- confirm how those imports
    are expected to be found.

    Args:
        seed: Seed applied to ``random``, ``numpy.random`` and ``torch``.
        data_path: Dataset location, relative to the project root.
        log_dir: Log directory, relative to the project root. Created
            (including parents) if it does not exist.
        device: Index of the CUDA device to use. When CUDA is unavailable
            the CPU is used instead. An index outside the range of visible
            CUDA devices falls back to ``cuda:0`` with a printed warning.

    Returns:
        Tuple ``(root_dir, data_path, log_dir, device)``: three ``Path``
        objects rooted at the project root, and the selected
        ``torch.device``.
    """
    experiment_dir = Path.cwd()
    root_dir = experiment_dir.parent

    # Re-running the cell must not pile up duplicate sys.path entries.
    src_dir = str((root_dir / "src").resolve())
    if src_dir not in sys.path:
        sys.path.append(src_dir)

    data_path = root_dir / data_path
    log_dir = root_dir / log_dir
    log_dir.mkdir(parents=True, exist_ok=True)

    if torch.cuda.is_available():
        gpu_count = torch.cuda.device_count()
        if not 0 <= device < gpu_count:
            # An invalid ordinal would only fail later, at the first tensor
            # transfer, with a much less readable CUDA error.
            print(
                f"Warning: CUDA device {device} is not available "
                f"({gpu_count} visible); falling back to cuda:0"
            )
            device = 0
        torch_device = torch.device(f"cuda:{device}")
    else:
        torch_device = torch.device("cpu")

    # Seed every generator the training code may draw from.
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    # Covers every CUDA device; silently ignored when CUDA is unavailable.
    torch.cuda.manual_seed_all(seed)

    print("Using directories:")
    print("root_dir:", root_dir)
    print("data_dir:", data_path)
    print("log_dir:", log_dir)
    print("========================================")
    print("device:", torch_device)

    return root_dir, data_path, log_dir, torch_device


  from .autonotebook import tqdm as notebook_tqdm


In [2]:
# Resolve the project paths, seed the RNGs and select the compute device.
# `device` is the CUDA index to request; setup_experiment uses the CPU
# when CUDA is not available.
ROOT_DIR, DATA_PATH, LOG_DIR, DEVICE = setup_experiment(
    log_dir="experiments/fair_gnn/logs/nba",
    data_path="dataset/NBA",
    device=2,
    seed=20,
)

Using directories:
root_dir: /home/harold/repos/fact
data_dir: /home/harold/repos/fact/dataset/NBA
log_dir: /home/harold/repos/fact/experiments/fair_gnn/logs/nba
device: cpu


In [3]:
# Build the NBA dataset from the node table, the edge list and the
# embedding file that live under DATA_PATH.
dataset = NBA(
    device=DEVICE,
    feat_drop_rate=0.3,
    embedding_path=DATA_PATH / "nba_embedding10.npy",
    edges_path=DATA_PATH / "nba_relationship.txt",
    nodes_path=DATA_PATH / "nba.csv",
)

# Report the graph size and the feature drop rate stored on the dataset.
print(
    f"Loaded dataset with {dataset.graph.num_nodes()} nodes and {dataset.graph.num_edges()} edges"
)
print(
    f"Using feat_drop_rate: {dataset.feat_drop_rate}"
)

Loaded dataset with 403 nodes and 21645 edges
Using feat_drop_rate: 0.3


In [4]:
# Create the FairGNN model; its input width is the number of feature
# columns in the loaded dataset.
# NOTE(review): alpha and beta are presumably loss-weighting
# hyperparameters -- confirm against models.fair.FairGNN.
fair_gnn = FairGNN(
    num_features=dataset.features.shape[1],
    alpha=10,
    beta=1,
)

# The message previously named "FairAC", which is not the model built here.
print("Created FairGNN model with 1 sensitive class")

Created FairAC model with 1 sensitive class


In [5]:
# Wire the model, the dataset, the device and the log directory into the
# trainer.
trainer = FairGNNTrainer(
    fair_gnn=fair_gnn,
    dataset=dataset,
    device=DEVICE,
    log_dir=LOG_DIR,
    min_acc=0.65,
    min_roc=0.69,
    # Per the original author's note, alpha and beta "are not being used"
    # by the trainer (presumably the values given to FairGNN apply) --
    # TODO confirm in FairGNNTrainer.
    alpha=10,
    beta=1,
)

print(
    f"Created trainer with {'GCN'} model, using LOG_DIR: {LOG_DIR}"
)

Created trainer with GCN model, using LOG_DIR: /home/harold/repos/fact/experiments/fair_gnn/logs/nba


In [6]:
# Run the main training loop (with GNN validation) for 3000 epochs.
trainer.train(epochs=3_000)

Epoch 0:   0%|          | 0/3000 [00:00<?, ?it/s]

: 