First, let's import all necessary libraries:

In [None]:
import collections
import numpy as np
import os
import pytorch_lightning as pl
import torch

from copy import deepcopy
from matplotlib import pyplot as plt
from pytorch_lightning.callbacks import TQDMProgressBar
from pytorch_lightning.callbacks.early_stopping import EarlyStopping
from torch.nn import BCEWithLogitsLoss

os.chdir('..')

from attacks import influence_attack, anchoring_attack
from datamodules import GermanCreditDatamodule, CompasDatamodule, DrugConsumptionDatamodule
from fairness import FairnessLoss
from trainingmodule import BinaryClassifier

Now we create a **general attack function** that handles all of the different attack methods:

In [None]:
def attack(dm, model, eps, method):
    """Return a copy of ``dm`` whose training set has been poisoned by ``method``.

    Parameters
    ----------
    dm :
        Clean datamodule. It is never modified; a deep copy is poisoned instead.
    model :
        Model used by the influence attack ('IAF'). It is deep-copied before the
        attack, so the caller's model is left untouched. Unused by 'RAA'/'NRAA'.
    eps : float
        Poisoning budget forwarded unchanged to the attack implementation
        (presumably the fraction of poisoned points -- confirm against `attacks`).
    method : str
        One of 'IAF' (influence attack on fairness), 'RAA' (anchoring attack
        with random sampling) or 'NRAA' (anchoring attack with non-random
        sampling).

    Returns
    -------
    A deep copy of ``dm`` with its training dataset replaced by the poisoned one.

    Raises
    ------
    ValueError
        If ``method`` is not a supported attack name.
    """
    # Validate first so an unknown method fails before any expensive setup.
    if method not in ('IAF', 'RAA', 'NRAA'):
        raise ValueError(f'Unknown attack {method}.')

    if method == 'IAF':
        # Adversarial loss according to Mehrabi et al.: classification loss
        # plus 0.1 x the fairness loss evaluated on the model's parameters.
        bce_loss, fairness_loss = BCEWithLogitsLoss(), FairnessLoss(dm.get_sensitive_index())
        adv_loss = lambda _model, X, y: (
                bce_loss(_model(X), y.float()) + 0.1 * fairness_loss(X, *_model.get_params())
        )

        # Separate, quiet training pipeline used inside the influence attack.
        # NOTE(review): this monitors train_acc with mode="min", i.e. it stops
        # once accuracy stops *decreasing*; the main experiment monitors
        # train_loss instead -- confirm which is intended.
        trainer = pl.Trainer(
            max_epochs=100,
            gpus=1 if torch.cuda.is_available() else 0,
            enable_model_summary=False,
            enable_progress_bar=False,
            callbacks=[EarlyStopping(monitor="train_acc", mode="min", patience=10)]
        )

        # Work on a deep copy so the caller's model is not altered by the attack.
        model = deepcopy(model)

        poisoned_dataset = influence_attack(
            model=model,
            datamodule=dm,
            trainer=trainer,
            adv_loss=adv_loss,
            eps=eps,
            eta=0.01,
            attack_iters=100,
        )
    else:
        poisoned_dataset = anchoring_attack(
            D_c=dm.get_train_dataset(),
            sensitive_idx=dm.get_sensitive_index(),
            eps=eps,
            tau=0,
            # Compare the *method name*; comparing the function object `attack`
            # to 'RAA' was always False and made RAA behave like NRAA.
            sampling_method='random' if method == 'RAA' else 'non-random',
            attack_iters=100,
        )

    # Deep-copy the original datamodule and poison only the copy.
    dm = deepcopy(dm)
    dm.update_train_dataset(poisoned_dataset)

    return dm

and a **nested dictionary**, which is convenient for storing results for multiple datasets and metrics:

In [None]:
def nested_dict():
    """Return an auto-vivifying dictionary of arbitrary depth.

    Accessing a missing key creates another ``nested_dict`` in its place, so
    ``d['a']['b']['c'] = 1`` works without pre-creating intermediate levels.
    """
    return collections.defaultdict(nested_dict)


# results[dataset][metric][method][eps] -> metric value
results = nested_dict()

Finally, we iterate over all combinations of dataset, attack method, and poisoning budget ε shown in Figure 2 of Mehrabi et al.:

In [None]:
# Create Datamodules for all datasets
german_credit_datamodule = GermanCreditDatamodule('data/', 100)
compas_datamodule = CompasDatamodule('data/', 100)
drug_consumption_datamodule = DrugConsumptionDatamodule('data/', 100)

# Poisoning budgets 0.0, 0.1, ..., 1.0; rounded so they are clean dict keys
# (np.arange(0, 1.1, 0.1) yields values such as 0.30000000000000004).
EPS_GRID = [round(0.1 * k, 1) for k in range(11)]


def _make_trainer():
    """Return a fresh Trainer.

    A new Trainer is built for every run because a Trainer (and its
    EarlyStopping callback) keeps epoch counters and best-score state across
    `fit` calls, which would leak from one run into the next.
    """
    return pl.Trainer(
        max_epochs=1000,
        gpus=1 if torch.cuda.is_available() else 0,
        callbacks=[TQDMProgressBar(), EarlyStopping(monitor="train_loss", mode="min", patience=10)]
    )


for dm in [german_credit_datamodule, compas_datamodule, drug_consumption_datamodule]:
    for method in ['IAF', 'RAA', 'NRAA']:
        for eps in EPS_GRID:
            # Fresh model per run so results are not warm-started from the
            # previous (method, eps) combination.
            model = BinaryClassifier('linear_regression', dm.get_input_size(), lr=1e-3)

            # `attack` returns a deep copy of `dm` with a poisoned training set;
            # the clean `dm` is left unchanged for the next run.
            poisoned_dm = attack(dm, model, eps, method)

            # Train on the poisoned data
            trainer = _make_trainer()
            trainer.fit(model, datamodule=poisoned_dm)

            # `Trainer.test` returns one metrics dict per test dataloader; only
            # the test split is used (attack() replaces just the training set).
            metrics = trainer.test(model, datamodule=poisoned_dm)[0]

            # Save Accuracy and Fairness metrics
            results[dm.get_dataset_name()]['Test Error'][method][eps] = metrics['test_error']
            results[dm.get_dataset_name()]['Statistical Parity'][method][eps] = metrics['SPD']
            results[dm.get_dataset_name()]['Equality of Opportunity'][method][eps] = metrics['EOD']

and plot the results:

In [None]:
# One row per dataset, one column per metric, one line per attack method.
fig, ax = plt.subplots(3, 3, figsize=(15, 12))
for i, dataset in enumerate(['German Credit', 'COMPAS', 'Drug Consumption']):
    for j, metric in enumerate(['Test Error', 'Statistical Parity', 'Equality of Opportunity']):
        for method in ['IAF', 'RAA', 'NRAA']:
            # NOTE(review): these dataset names must match dm.get_dataset_name();
            # a mismatch silently yields an empty (auto-created) series.
            series = results[dataset][metric][method]
            # Sort by eps so lines are drawn left-to-right.
            eps_values = sorted(series)
            # Axes.plot takes x and y positionally; `x=`/`y=` keywords are rejected.
            ax[i, j].plot(eps_values, [series[e] for e in eps_values], marker='o', label=method)
        ax[i, j].set_title(f'{dataset}: {metric}')
        ax[i, j].set_xlabel('eps')
        ax[i, j].set_ylabel(metric)
        ax[i, j].legend()
fig.tight_layout()
plt.show()