## FairGNN Experiments on the Pokec-N dataset
This notebook can be used to run FairGNN experiments on the Pokec-N dataset.

It is currently configured to run full training on three different seeds.

In [None]:
import sys
import os
import torch
sys.path.append(os.path.abspath("../"))
from base_experiment import ExperimentRunner

### Set up the experiment runner
First, we create an experiment runner, which sets the random seeds and provides the parameters and logging directory for each run.
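`ExperimentRunner`'s implementation is not part of this notebook. As a minimal sketch of what per-run seeding usually involves, assuming the runner seeds Python's `random`, NumPy and PyTorch, it might look like this. The `set_seed` helper is hypothetical and not part of `base_experiment`; NumPy and PyTorch are only seeded if they are installed.

```python
import random

try:
    import numpy as np  # optional: seeded only if installed
except ImportError:
    np = None

try:
    import torch  # optional: seeded only if installed
except ImportError:
    torch = None


def set_seed(seed: int) -> None:
    """Hypothetical helper: seed every RNG this sketch knows about."""
    random.seed(seed)
    if np is not None:
        np.random.seed(seed)
    if torch is not None:
        torch.manual_seed(seed)
        if torch.cuda.is_available():
            torch.cuda.manual_seed_all(seed)
```

Calling `set_seed(40)` twice and drawing from `random` in between gives identical sequences. Seeding alone does not make GPU training fully deterministic, since some CUDA kernels are non-deterministic unless configured otherwise.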

In [None]:
experiment = ExperimentRunner(
    experiment_name="pokec_n_fair_gnn_main",
    seeds=[40, 41, 42],
    data_path="dataset/pokec",
    log_dir="experiments/fair_gnn/logs/pokec_n",
    device=0,
    params=[{"alpha": 50, "beta": 1}],  # alpha is set to 50 for Pokec-N
)

# after we set up the experiment, we can import the rest
from dataset import PokecN
from models.gnn import WrappedGNNConfig
from models.fair.gnn import FairGNN, Trainer
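The `params` entry sets `alpha` and `beta`, which are passed to `Trainer`. Assuming this implementation mirrors the loss in the original FairGNN reference code (check `Trainer` to confirm the exact form), `alpha` weights the absolute covariance between the estimated sensitive attribute and the predicted label, and `beta` weights the adversary's loss, which the GNN tries to maximise. A plain-Python sketch of that objective, with hypothetical function names:

```python
from statistics import fmean
from typing import Sequence


def covariance_penalty(s_score: Sequence[float], y_score: Sequence[float]) -> float:
    """|mean((s - mean(s)) * (y - mean(y)))| over all nodes.

    s_score: estimated sensitive-attribute probabilities (hypothetical input)
    y_score: predicted label probabilities (hypothetical input)
    """
    s_mean, y_mean = fmean(s_score), fmean(y_score)
    return abs(fmean((s - s_mean) * (y - y_mean) for s, y in zip(s_score, y_score)))


def generator_loss(cls_loss: float, cov: float, adv_loss: float,
                   alpha: float, beta: float) -> float:
    """Fit the labels, decorrelate predictions from s, and fool the adversary."""
    return cls_loss + alpha * cov - beta * adv_loss
```

With `alpha=50`, the covariance term dominates quickly, which is consistent with the comment that Pokec-N uses a larger `alpha`.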

### Run the experiments
Now we're ready to run the experiments!

We do this by iterating over the `ExperimentRunner.runs()` method, which returns a generator that yields the seed, logging directory, device and params for each experiment run.
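A generator like `runs()` can be sketched as below. This assumes one run per (params, seed) combination and a `params_<i>/seed_<seed>` directory layout; both are illustrative guesses, not the actual behaviour of `ExperimentRunner`.

```python
import itertools
import random
from pathlib import Path
from typing import Any, Dict, Iterator, List, Tuple


def runs(
    seeds: List[int],
    params: List[Dict[str, Any]],
    log_dir: str,
    device: Any,
) -> Iterator[Tuple[int, Path, Any, Dict[str, Any]]]:
    """Yield (seed, run_log_dir, device, params) for every params/seed pair."""
    for (param_idx, run_params), seed in itertools.product(enumerate(params), seeds):
        # A real runner would presumably seed all RNGs here, not just `random`
        random.seed(seed)
        run_dir = Path(log_dir) / f"params_{param_idx}" / f"seed_{seed}"
        yield seed, run_dir, device, run_params
```

Because it is a generator, each run is only set up when the loop asks for it, so the seeding for one run happens right before that run's code executes.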


For each experiment run, we:
1. Load the dataset
2. Create the FairGNN model instance and load its pre-trained estimator
3. Create the FairGNN trainer
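The dataset is built with `PokecN(..., feat_drop_rate=0.3)`. Its implementation is not shown here; assuming it follows the FairAC-style setting, where the attributes of a random fraction of nodes are treated as missing, the masking step could be sketched as follows. `drop_node_features` is a hypothetical helper that works on plain lists rather than tensors.

```python
import random
from typing import List, Tuple


def drop_node_features(
    features: List[List[float]], drop_rate: float, seed: int = 0
) -> Tuple[List[List[float]], List[int]]:
    """Zero out the feature rows of a random `drop_rate` fraction of nodes.

    Returns a masked copy (the input is not modified) and the sorted
    indices of the nodes whose features were dropped.
    """
    rng = random.Random(seed)
    num_nodes = len(features)
    num_drop = int(round(drop_rate * num_nodes))
    dropped = sorted(rng.sample(range(num_nodes), num_drop))
    dropped_set = set(dropped)
    masked = [
        [0.0] * len(row) if i in dropped_set else list(row)
        for i, row in enumerate(features)
    ]
    return masked, dropped
```

Zeroing rows is only one option; an implementation could instead keep the rows intact and carry a separate missing-attribute mask.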


Once everything is initialised, we train the model with `Trainer.train()`. This notebook does not run a separate pretraining phase: the estimator is loaded from a pre-trained checkpoint via `load_estimator()`, and `Trainer.train()` then runs the full FairGNN training loop for 3000 epochs.
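`Trainer` is given `min_acc=0.65` and `min_roc=0.69`. Its internals are not shown in this notebook, but in the original FairGNN reference code the equivalent `--acc`/`--roc` arguments drive model selection: only epochs whose validation accuracy and ROC AUC exceed both thresholds are considered, and among those the one with the lowest ΔSP + ΔEO (statistical parity plus equal opportunity difference) is kept. Assuming this `Trainer` behaves similarly, the rule can be sketched in plain Python; every name below is hypothetical.

```python
from dataclasses import dataclass
from typing import Iterable, List, Optional


def _positive_rate(preds: List[int], mask: List[bool]) -> float:
    """Fraction of predicted positives among masked nodes (mask must select >= 1 node)."""
    selected = [p for p, keep in zip(preds, mask) if keep]
    return sum(selected) / len(selected)


def statistical_parity(preds: List[int], sens: List[int]) -> float:
    """ΔSP = |P(ŷ=1 | s=0) - P(ŷ=1 | s=1)|."""
    return abs(
        _positive_rate(preds, [s == 0 for s in sens])
        - _positive_rate(preds, [s == 1 for s in sens])
    )


def equal_opportunity(preds: List[int], labels: List[int], sens: List[int]) -> float:
    """ΔEO = |P(ŷ=1 | y=1, s=0) - P(ŷ=1 | y=1, s=1)|."""
    pos_s0 = [y == 1 and s == 0 for y, s in zip(labels, sens)]
    pos_s1 = [y == 1 and s == 1 for y, s in zip(labels, sens)]
    return abs(_positive_rate(preds, pos_s0) - _positive_rate(preds, pos_s1))


@dataclass
class EpochMetrics:
    epoch: int
    val_acc: float
    val_roc: float
    val_parity: float    # ΔSP on the validation split
    val_equality: float  # ΔEO on the validation split


def select_best_epoch(
    history: Iterable[EpochMetrics], min_acc: float, min_roc: float
) -> Optional[EpochMetrics]:
    """Among epochs above both thresholds, keep the one with the lowest ΔSP + ΔEO."""
    best, best_fair = None, float("inf")
    for m in history:
        if m.val_acc > min_acc and m.val_roc > min_roc:
            fair = m.val_parity + m.val_equality
            if fair < best_fair:
                best, best_fair = m, fair
    return best
```

If no epoch clears both thresholds, the sketch returns `None`, so the thresholds effectively act as a floor on predictive performance before fairness is compared.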

In [None]:
for (seed, log_dir, device, params) in experiment.runs():
    print("===========================")
    print(f"Running {experiment.experiment_name} using seed {seed}")
    print(f"Log directory: {log_dir}")
    print(f"Params: {params}")
    print("===========================")
    
    # Load the dataset
    dataset = PokecN(
        nodes_path=experiment.data_path / "region_job_2.csv",
        edges_path=experiment.data_path / "region_job_2_relationship.txt",
        embedding_path=experiment.data_path / "pokec_n_embedding10.npy",
        feat_drop_rate=0.3,
        device=device,  # use the device yielded for this run
    )

    print(f"Loaded dataset with {dataset.graph.num_nodes()} nodes and {dataset.graph.num_edges()} edges")
    print(f"Using feat_drop_rate: {dataset.feat_drop_rate}")
    
    # Create the FairGNN model
    fair_gnn = FairGNN(
        num_features=dataset.features.shape[1],
        device=device,
    ).to(device)

    # Load the pre-trained estimator
    # NOTE: this checkpoint name refers to the NBA dataset; check that it is the intended estimator for Pokec-N
    fair_gnn.load_estimator(experiment.data_path / "GCN_sens_nba_ns_50")
    print("Created FairGNN model with 1 sensitive class")
    
    # Create the FairGNN trainer
    trainer = Trainer(
        dataset=dataset,
        fair_gnn=fair_gnn,
        device=device,
        log_dir=log_dir,
        alpha=params["alpha"],
        beta=params["beta"],
        min_acc=0.65,
        min_roc=0.69,
    )

    print(f"Created trainer with GCN model, using LOG_DIR: {log_dir}")

    print("Starting training phase")
    # Train the model
    trainer.train(epochs=3000)
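The loop above only trains; it does not combine results across the three seeds. If you collect per-seed test metrics (for example from each run's `log_dir`, in whatever form `Trainer` writes them), a common convention is to report mean and standard deviation. A minimal sketch, where `summarise_runs` is a hypothetical helper:

```python
from statistics import mean, stdev
from typing import Dict, List, Tuple


def summarise_runs(per_seed: List[Dict[str, float]]) -> Dict[str, Tuple[float, float]]:
    """Return {metric: (mean, sample std)} across seeds; needs at least two runs."""
    summary = {}
    for name in per_seed[0]:
        values = [run[name] for run in per_seed]
        summary[name] = (mean(values), stdev(values))
    return summary
```

`stdev` is the sample standard deviation (n - 1 denominator); swap in `statistics.pstdev` if you prefer the population form.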
