In [None]:
import gc
import os
import sys

import torch

sys.path.append(os.path.abspath("../"))
from base_experiment import ExperimentRunner

In [None]:
# Configure the runner with every seed and parameter set we want to sweep over.
# NOTE(review): the experiment name mentions "alpha", but the only parameter swept
# here is feat_drop_rate — confirm the name is intended.
experiment = ExperimentRunner(
    experiment_name="pokec_z_fair_ac_alpha_experiments",
    seeds=[40, 41, 42],
    data_path="dataset/pokec",
    log_dir="experiments/fair_ac/logs/pokec_z",
    device=3,
    # One parameter set per feature-drop rate.
    params=[dict(feat_drop_rate=rate) for rate in (0.1, 0.3, 0.5, 0.8)],
)

# The remaining project imports are deliberately placed after the runner has been
# constructed; keep this ordering.
from dataset import PokecZ
from models.gnn import WrappedGNNConfig
from models.fair.ac import FairAC, Trainer

In [None]:
# GNN settings handed to the FairAC Trainer: a GCN with a 128-wide hidden layer,
# Adam-style lr/weight-decay values, and dropout forwarded via `kwargs`.
gnn_config = WrappedGNNConfig(
    kind="GCN",
    hidden_dim=128,
    lr=1e-3,
    weight_decay=1e-5,
    kwargs=dict(dropout=0.5),
)

In [None]:
# Run every (seed, params) combination yielded by the runner. Each run builds the
# dataset, trains FairAC on it, and then releases GPU memory before the next run.
for (seed, log_dir, device, params) in experiment.runs():
    print("===========================")
    print(f"Running {experiment.experiment_name} using seed {seed}")
    print(f"Log directory: {log_dir}")
    print(f"Params: {params}")
    print("===========================")

    # Pre-bind the per-run objects so the cleanup in `finally` can always `del`
    # them, even if an earlier step of this run raised before assigning them.
    dataset = fair_ac = trainer = None
    try:
        # Load in the dataset; feat_drop_rate is the parameter being swept.
        dataset = PokecZ(
            nodes_path=experiment.data_path / "region_job.csv",
            edges_path=experiment.data_path / "region_job_relationship.txt",
            embedding_path=experiment.data_path / "pokec_z_embedding10.npy",
            feat_drop_rate=params["feat_drop_rate"],
            device=device,
        )

        print(f"Loaded dataset with {dataset.graph.num_nodes()} nodes and {dataset.graph.num_edges()} edges")
        print(f"Using feat_drop_rate: {dataset.feat_drop_rate}")

        # Create FairAC model. Keep the class count in one place so the log
        # message below cannot drift from the value actually passed to the model.
        num_sensitive_classes = 1
        fair_ac = FairAC(
            feature_dim=dataset.features.shape[1],
            transformed_feature_dim=128,
            emb_dim=dataset.embeddings.shape[1],
            attn_vec_dim=128,
            attn_num_heads=1,
            dropout=0.5,
            num_sensitive_classes=num_sensitive_classes,
        ).to(device)
        print(f"Created FairAC model with {num_sensitive_classes} sensitive class")

        # Create FairAC trainer
        trainer = Trainer(
            ac_model=fair_ac,
            lambda1=1.0,
            lambda2=1.0,
            dataset=dataset,
            device=device,
            gnn_config=gnn_config,
            log_dir=log_dir,
            min_acc=0.65,
            min_roc=0.69,
        )
        # Report the configured GNN kind rather than a hard-coded label.
        print(f"Created trainer with {gnn_config.kind} model, using log_dir: {log_dir}")

        print("Starting pre-training phase")
        trainer.pretrain(epochs=200)
        print("Finished pretraining")

        # Main training loop, with GNN validation
        print("Starting main training...")
        trainer.train(val_start_epoch=800, val_epoch_interval=200, epochs=2800)
    finally:
        # The entire dataset is allocated on the GPU, so drop every reference to it
        # (and to the model/trainer) before the next run. This also happens when a
        # run fails, so re-running the cell does not start with leaked GPU memory.
        del dataset, trainer, fair_ac
        # gc.collect() first frees objects kept alive only by reference cycles;
        # empty_cache() can then return the now-unused cached blocks to the driver.
        gc.collect()
        torch.cuda.empty_cache()
        print("Cleared cuda cache")
