In [1]:
import numpy as np
import os
from pathlib import Path
import sys
import torch
from models.fair import FairGNNTrainer, FairGNN
from dataset import NBA, PokecN, PokecZ, Recidivism, Credit

def setup_experiment(seed: int, data_path: str, log_dir: str, device: "int | str" = 0):
    """Resolve experiment directories, seed the RNGs and select the torch device.

    The project root is taken to be the parent of the current working
    directory. ``<root>/src`` is appended to ``sys.path`` (only once, even if
    this is called repeatedly) so project modules can be imported.

    NOTE(review): the ``models.fair`` / ``dataset`` imports in the first cell
    run *before* this function is called, so on a fresh kernel they only
    succeed if those modules are already importable (e.g. via PYTHONPATH).
    Presumably they live under ``<root>/src`` — confirm, and if so move the
    path setup ahead of those imports.

    Args:
        seed: Seed applied to Python's ``random``, NumPy and torch (CPU and
            all CUDA devices).
        data_path: Dataset directory relative to the project root (an absolute
            path is used as-is, per ``pathlib`` joining rules).
        log_dir: Log directory relative to the project root; created
            (with parents) if it does not exist.
        device: CUDA device index, or ``"cpu"`` to force the CPU. Falls back
            to the CPU when CUDA is not available.

    Returns:
        ``(root_dir, data_path, log_dir, device)`` — the three paths as
        ``pathlib.Path`` objects and the device as a ``torch.device``.
    """
    import random

    # Same as os.path.abspath("./.."): the parent of the working directory.
    root_dir = Path(os.path.abspath(os.pardir))

    src_dir = str((root_dir / "src").resolve())
    # Guard against duplicate sys.path entries when the cell is re-run.
    if src_dir not in sys.path:
        sys.path.append(src_dir)

    data_path = root_dir / data_path
    log_dir = root_dir / log_dir
    log_dir.mkdir(parents=True, exist_ok=True)

    if device == "cpu" or not torch.cuda.is_available():
        device = torch.device("cpu")
    else:
        device = torch.device(f"cuda:{device}")

    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    # Seeds every visible GPU, which includes the current one.
    torch.cuda.manual_seed_all(seed)

    print("Using directories:")
    print("root_dir:", root_dir)
    print("data_dir:", data_path)
    print("log_dir:", log_dir)
    print("========================================")
    print("device:", device)

    return root_dir, data_path, log_dir, device


In [2]:
# Dataset and log locations, relative to the project root.
NBA_PATH = "dataset/NBA"
NBA_LOG_DIR = "experiments/fair_gnn/logs/nba"

POKEC_Z_PATH = "dataset/pokec"
POKEC_Z_LOG_DIR = "experiments/fair_gnn/logs/pokec_z"

# Seed for the numpy/torch RNGs and the CUDA device index to run on.
SEED = 41
GPU_INDEX = 1

# To run on NBA instead, pass NBA_PATH / NBA_LOG_DIR here.
ROOT_DIR, DATA_PATH, LOG_DIR, DEVICE = setup_experiment(
    seed=SEED,
    data_path=POKEC_Z_PATH,
    log_dir=POKEC_Z_LOG_DIR,
    device=GPU_INDEX,
)

Using directories:
root_dir: /home/fact21/fact_refactor
data_dir: /home/fact21/fact_refactor/dataset/pokec
log_dir: /home/fact21/fact_refactor/experiments/fair_gnn/logs/pokec_z
device: cuda:1


In [3]:
# Load the Pokec-z graph. The commented-out call below is the NBA
# alternative; presumably it needs DATA_PATH set from NBA_PATH.
# dataset = NBA(
#     nodes_path=DATA_PATH / "nba.csv",
#     edges_path=DATA_PATH / "nba_relationship.txt",
#     embedding_path=DATA_PATH / "nba_embedding10.npy",
#     feat_drop_rate=0.3,
#     device=DEVICE,
#     data_seed=20
# )

nodes_file = DATA_PATH / "region_job.csv"
edges_file = DATA_PATH / "region_job_relationship.txt"
embedding_file = DATA_PATH / "pokec_z_embedding10.npy"

dataset = PokecZ(
    nodes_file,
    edges_file,
    embedding_file,
    feat_drop_rate=0.3,
    data_seed=42,
    device=DEVICE,
)

num_nodes, num_edges = dataset.graph.num_nodes(), dataset.graph.num_edges()
print(f"Loaded dataset with {num_nodes} nodes and {num_edges} edges")
print(f"Using feat_drop_rate: {dataset.feat_drop_rate}")

Loaded dataset with 67796 nodes and 1303712 edges
Using feat_drop_rate: 0.3


In [4]:
# Create the FairGNN model over the dataset's node features.
# NOTE(review): num_features reads features.shape[1], presumably the
# per-node feature dimension of a (num_nodes, num_features) matrix — confirm.
# NOTE(review): alpha/beta are passed straight to FairGNN; presumably they
# weight its fairness loss terms — confirm against the model definition.
fair_gnn = FairGNN(
    num_features=dataset.features.shape[1],
    alpha=10,
    beta=1,
)

# The model built here is FairGNN (the old message said "FairAC").
# NOTE(review): the sensitive-class count is a hard-coded label, not read from the model.
print("Created FairGNN model with 1 sensitive class")

Created FairAC model with 1 sensitive class


In [5]:
# Wrap the model in a trainer that writes its logs to LOG_DIR.
# NOTE(review): min_acc / min_roc are passed straight through to
# FairGNNTrainer; their exact role in model selection is not visible here — confirm.
trainer_kwargs = dict(
    fair_gnn=fair_gnn,
    dataset=dataset,
    device=DEVICE,
    log_dir=LOG_DIR,
    min_acc=0.65,
    min_roc=0.69,
)
trainer = FairGNNTrainer(**trainer_kwargs)

# NOTE(review): "GCN" is a hard-coded label, not read from the trainer.
print(f"Created trainer with GCN model, using LOG_DIR: {LOG_DIR}")

Created trainer with GCN model, using LOG_DIR: /home/fact21/fact_refactor/experiments/fair_gnn/logs/pokec_z


In [6]:
# Main training loop with GNN validation; judging by the output, the trainer
# prints a line each time the metrics it tracks are updated.
N_EPOCHS = 3000
trainer.train(epochs=N_EPOCHS)

  0%|          | 0/3000 [00:00<?, ?it/s]

[117] updated to: Metrics(acc=0.6508183943881528, roc=0.6789329685362517, parity=0.0031537090360619913, equality=0.01685304772405105)
[337] updated to: Metrics(acc=0.6325019485580671, roc=0.720202877662693, parity=0.004570135746606252, equality=0.036254387222808604)
[392] updated to: Metrics(acc=0.6492595479345284, roc=0.7165471711940591, parity=0.013840669134786765, equality=0.049912666937794836)
[444] updated to: Metrics(acc=0.6418550272798129, roc=0.7109133769786986, parity=0.0017784176607706925, equality=0.03205346711081547)
[871] updated to: Metrics(acc=0.6480904130943103, roc=0.7113866767637287, parity=0.0022857534622240383, equality=0.04916510532860974)
[877] updated to: Metrics(acc=0.6508183943881528, roc=0.7068399452804378, parity=0.011160016454134114, equality=0.05076836638981819)
[879] updated to: Metrics(acc=0.6484801247077163, roc=0.7045668848934923, parity=0.011181955299602309, equality=0.057237830378741084)
[990] updated to: Metrics(acc=0.6461418550272798, roc=0.70857435