First, let's import all necessary libraries:

In [None]:
import collections
import logging
import numpy as np
import os
import pytorch_lightning as pl
import torch

from copy import deepcopy
from matplotlib import pyplot as plt
from pytorch_lightning.callbacks.early_stopping import EarlyStopping
from torch.nn import BCEWithLogitsLoss
from tqdm.notebook import tqdm

#os.chdir('..')

from attacks import influence_attack
from datamodules import GermanCreditDatamodule, CompasDatamodule, DrugConsumptionDatamodule
from fairness import FairnessLoss
from trainingmodule import BinaryClassifier

# Only let warnings and errors from PyTorch Lightning through, so the many
# trainer runs below don't flood the notebook output.
_PL_LOG_LEVEL = logging.WARNING
logging.getLogger(name="pytorch_lightning").setLevel(_PL_LOG_LEVEL)

Now we create a **general attack function** that poisons a copy of a datamodule's training data with the influence attack, using the adversarial loss of Mehrabi et al.:

In [None]:
def attack(dm, model, eps, lamda, eta=0.01, attack_iters=100, max_epochs=100):
    """Poison the training split of ``dm`` with the influence attack.

    The adversarial objective follows Mehrabi et al.: binary cross-entropy
    plus ``lamda`` times a fairness loss on the datamodule's sensitive
    attribute.

    Args:
        dm: Datamodule to attack. This function does not modify it; the
            poisoned training set is installed on a deep copy.
        model: Model handed to ``influence_attack`` along with the trainer
            built here.
        eps: Forwarded to ``influence_attack`` (presumably the poisoning
            budget -- confirm against ``attacks.influence_attack``).
        lamda: Weight of the fairness term in the adversarial loss.
        eta: Forwarded to ``influence_attack`` (presumably the attack step
            size). Defaults to the previously hard-coded 0.01.
        attack_iters: Forwarded to ``influence_attack``. Defaults to the
            previously hard-coded 100.
        max_epochs: Epoch cap of the trainer used inside the attack.
            Defaults to the previously hard-coded 100.

    Returns:
        A deep copy of ``dm`` whose training dataset has been replaced by the
        dataset returned from ``influence_attack``.
    """
    bce_loss = BCEWithLogitsLoss()
    fairness_loss = FairnessLoss(dm.get_sensitive_index())

    def adv_loss(_model, X, y):
        # Adversarial loss according to Mehrabi et al.: BCE + lamda * fairness.
        # Targets are cast to float because BCEWithLogitsLoss needs
        # floating-point targets.
        return bce_loss(_model(X), y.float()) + lamda * fairness_loss(X, *_model.get_params())

    # Dedicated training pipeline used by the influence attack
    trainer = pl.Trainer(
        max_epochs=max_epochs,
        gpus=1 if torch.cuda.is_available() else 0,
        enable_model_summary=False,
        enable_progress_bar=False,
        log_every_n_steps=1,
        callbacks=[EarlyStopping(monitor="train_acc", mode="max", patience=10)],
    )

    poisoned_dataset = influence_attack(
        model=model,
        datamodule=dm,
        trainer=trainer,
        adv_loss=adv_loss,
        eps=eps,
        eta=eta,
        attack_iters=attack_iters,
    )

    # Install the poisoned training set on a deep copy so the caller's
    # datamodule keeps its clean data.
    poisoned_dm = deepcopy(dm)
    poisoned_dm.update_train_dataset(poisoned_dataset)
    return poisoned_dm

and a **nested dictionary**, which is convenient for storing results for multiple datasets and metrics:

In [None]:
def nested_dict():
    """Return an auto-vivifying dictionary of arbitrary depth.

    Accessing a missing key creates and stores another such dictionary, so
    deep assignments like ``d['a']['b'][0.1][0.2] = value`` work without
    creating the intermediate levels by hand. Note that merely *reading* a
    missing key also inserts an empty level.
    """
    return collections.defaultdict(nested_dict)


# Experiment results, filled as results[dataset][metric][eps][lamda] = value
results = nested_dict()

Finally, we iterate over all combinations of datasets, $\epsilon$ and $\lambda$ shown in Figure 2 of Mehrabi et al.:

In [None]:
# Create Datamodules for all datasets
german_credit_datamodule = GermanCreditDatamodule('data/', 10)
compas_datamodule = CompasDatamodule('data/', 50)
drug_consumption_datamodule = DrugConsumptionDatamodule('data/', 10)

# Weights of the fairness term in the adversarial loss
lamdas = [0, 0.1, 0.2, 0.4, 0.6, 0.8, 1]

for dm in [german_credit_datamodule, compas_datamodule, drug_consumption_datamodule]:
    for eps in [0, 0.1, 1]:
        print(f'Poisoning {dm.get_dataset_name()} dataset with eps = {eps}:')
        # eps = 0 is the clean baseline: without an attack lamda is unused, so
        # only lamda = 0 is run (and the progress bar's total reflects that).
        for lamda in tqdm([0] if eps == 0 else lamdas):
            pl.seed_everything(123)

            # Create poisoned dataset. The attack gets its own surrogate model
            # so that whatever it does to that model's weights cannot leak into
            # the victim model trained below.
            if eps == 0:
                dm_poisoned = dm
            else:
                surrogate = BinaryClassifier('LogisticRegression', dm.get_input_size(), lr=1e-3)
                dm_poisoned = attack(dm, surrogate, eps, lamda)
                # Re-seed so the victim is initialized exactly like the
                # clean (eps = 0) baseline model.
                pl.seed_everything(123)

            # Fresh victim Binary Classifier for every (dataset, eps, lamda)
            model = BinaryClassifier('LogisticRegression', dm.get_input_size(), lr=1e-3)
            trainer = pl.Trainer(
                max_epochs=300,
                gpus=1 if torch.cuda.is_available() else 0,
                enable_model_summary=False,
                enable_progress_bar=False,
                log_every_n_steps=1,
                callbacks=[EarlyStopping(monitor="train_acc", mode="max", patience=10)],
            )

            # Train on the poisoned dataset
            trainer.fit(model, dm_poisoned)

            # Evaluate on the original (unpoisoned) datamodule and store the
            # accuracy and fairness metrics
            metrics = trainer.test(model, dm)[0]
            results[dm.get_dataset_name()]['Test Error'][eps][lamda] = metrics['test_error']
            results[dm.get_dataset_name()]['Statistical Parity'][eps][lamda] = metrics['SPD']
            results[dm.get_dataset_name()]['Equality of Opportunity'][eps][lamda] = metrics['EOD']

and plot the results:

In [None]:
# Reproduce plot styling of the original paper
colors, markers = ['r', 'b', 'g'], ['*', 's', '^']

fig, ax = plt.subplots(3, 3, figsize=(20, 10))
# NOTE(review): these titles must match dm.get_dataset_name(); a mismatch
# silently yields empty curves because `results` auto-creates missing keys.
for i, dataset in enumerate(['German Credit', 'COMPAS', 'Drug Consumption']):
    for j, metric in enumerate(['Test Error', 'Statistical Parity', 'Equality of Opportunity']):
        for k, eps in enumerate([0, 0.1, 1]):
            # Maps lamda -> metric value for this (dataset, metric, eps)
            curve = results[dataset][metric][eps]
            ax[i, j].plot(
                list(curve.keys()),
                list(curve.values()),
                c=colors[k],
                marker=markers[k],
                # Raw string: '\e' is an invalid escape sequence in a normal
                # string literal (SyntaxWarning on Python 3.12+).
                label=r'$\epsilon$' + f'={eps}',
            )

        ax[i, j].set_xlabel(r'$\lambda$', fontweight='bold')
        ax[i, j].set_ylabel(metric, fontweight='bold')
        ax[i, j].set_title(dataset, fontweight='bold')
        ax[i, j].legend(loc='upper left', ncol=3)

plt.subplots_adjust(left=0.1, bottom=0.1, right=0.9, top=0.9, wspace=0.4, hspace=0.4)