First, let's import all necessary libraries:

In [12]:
import collections
import logging
import numpy as np
import os
import pytorch_lightning as pl
import torch

from copy import deepcopy
from matplotlib import pyplot as plt
from pytorch_lightning.callbacks.early_stopping import EarlyStopping
from torch.nn import BCEWithLogitsLoss
from tqdm.notebook import tqdm

#os.chdir('..')

from attacks import influence_attack, anchoring_attack
from datamodules import GermanCreditDatamodule, CompasDatamodule, DrugConsumptionDatamodule
from fairness import FairnessLoss
from trainingmodule import BinaryClassifier

# Only surface WARNING-and-above messages from PyTorch Lightning to keep cell output short
_lightning_logger = logging.getLogger("pytorch_lightning")
_lightning_logger.setLevel(logging.WARNING)

Next, we define a **general attack function** that dispatches to each of the supported attack methods (IAF, RAA and NRAA) and returns a poisoned copy of the datamodule:

In [13]:
def attack(dm, model, eps, method):
    """Poison the training split of ``dm`` with one of the supported attacks.

    Args:
        dm: Datamodule holding the clean data. It is not modified; a deep copy
            is poisoned and returned instead.
        model: Model handed to ``influence_attack`` (only used by ``'IAF'``).
        eps: Poisoning strength, forwarded unchanged as ``eps`` to the attack.
        method: ``'IAF'`` (influence attack), ``'RAA'`` (anchoring attack with
            random sampling) or ``'NRAA'`` (anchoring attack with non-random
            sampling).

    Returns:
        A deep copy of ``dm`` whose training dataset has been replaced by the
        poisoned dataset.

    Raises:
        ValueError: If ``method`` is not one of the supported attacks.
    """
    if method == 'IAF':
        # Adversarial loss according to Mehrabi et al.: classification loss plus
        # the fairness loss weighted by 0.1.
        bce_loss, fairness_loss = BCEWithLogitsLoss(), FairnessLoss(dm.get_sensitive_index())

        def adv_loss(_model, X, y):
            return bce_loss(_model(X), y.float()) + 0.1 * fairness_loss(X, *_model.get_params())

        # Separate training pipeline that is handed to the influence attack
        trainer = pl.Trainer(
            max_epochs=100,
            gpus=1 if torch.cuda.is_available() else 0,
            enable_model_summary=False,
            enable_progress_bar=False,
            log_every_n_steps=1,
            callbacks=[EarlyStopping(monitor="train_acc", mode="max", patience=10)]
        )

        poisoned_dataset = influence_attack(
            model=model,
            datamodule=dm,
            trainer=trainer,
            adv_loss=adv_loss,
            eps=eps,
            eta=0.01,
            attack_iters=100,
        )
    elif method in ('RAA', 'NRAA'):
        poisoned_dataset = anchoring_attack(
            D_c=dm.get_train_dataset(),
            sensitive_idx=dm.get_sensitive_index(),
            eps=eps,
            tau=0,
            # Must compare `method`: comparing the function `attack` to 'RAA' is
            # always False and made RAA silently fall back to non-random sampling.
            sampling_method='random' if method == 'RAA' else 'non-random',
            attack_iters=1,
        )
    else:
        raise ValueError(f'Unknown attack {method}.')

    # Poison a deep copy so the caller's original (clean) datamodule is untouched
    poisoned_dm = deepcopy(dm)
    poisoned_dm.update_train_dataset(poisoned_dataset)

    return poisoned_dm

and a **nested dictionary**, which makes it convenient to store results for multiple datasets, metrics, attack methods and values of $\epsilon$:

In [14]:
def nested_dict():
    """Build an auto-vivifying dictionary.

    Looking up a missing key inserts (and returns) another dictionary of the
    same kind, so arbitrarily deep paths can be assigned in one statement.
    """
    factory = nested_dict  # every missing key gets another nested_dict
    return collections.defaultdict(factory)


# Root container for all experiment results: dataset -> metric -> method -> eps -> value
results = nested_dict()

Finally, we iterate over all combinations of dataset, attack method and $\epsilon$ shown in Figure 2 of Mehrabi et al.:

In [15]:
# Fix the seeds of python, numpy and torch so that re-running the notebook reproduces the results
pl.seed_everything(42)

# Create Datamodules for all datasets
german_credit_datamodule = GermanCreditDatamodule('data/', 200)
compas_datamodule = CompasDatamodule('data/', 200)
drug_consumption_datamodule = DrugConsumptionDatamodule('data/', 200)

METHODS = ['IAF', 'RAA', 'NRAA']
# eps = 0.0, 0.1, ..., 1.0. Built from integers because np.arange(0, 1.1, 0.1) yields keys such as
# 0.30000000000000004, and the auto-vivifying `results` dict would then silently return an empty
# entry for a later lookup with 0.3.
EPSILONS = [i / 10 for i in range(11)]


def _make_classifier(dm):
    """Return a freshly initialised logistic-regression classifier sized for ``dm``'s inputs."""
    return BinaryClassifier('LogisticRegression', dm.get_input_size(), lr=1e-3)


for dm in [german_credit_datamodule, compas_datamodule, drug_consumption_datamodule]:
    dataset_name = dm.get_dataset_name()
    for method in METHODS:
        print(f'Poisoning {dataset_name} dataset with {method} attack:')
        for eps in tqdm(EPSILONS):
            # Poison a copy of the datamodule; `dm` itself keeps the clean data used for testing
            poisoned_dm = attack(dm, _make_classifier(dm), eps, method)

            # Train a fresh victim model on the poisoned data, so every method starts from an
            # untrained model even if the influence attack (IAF) trains the model handed to it.
            # NOTE(review): confirm against attacks.influence_attack.
            model = _make_classifier(dm)
            trainer = pl.Trainer(
                max_epochs=300,
                gpus=1 if torch.cuda.is_available() else 0,
                enable_model_summary=False,
                enable_progress_bar=False,
                log_every_n_steps=1,
                callbacks=[EarlyStopping(monitor="train_acc", mode="max", patience=20)]
            )
            trainer.fit(model, poisoned_dm)

            # Evaluate on the clean datamodule and save accuracy and fairness metrics
            metrics = trainer.test(model, dm)[0]
            results[dataset_name]['Test Error'][method][eps] = metrics['test_error']
            results[dataset_name]['Statistical Parity'][method][eps] = metrics['SPD']
            results[dataset_name]['Equality of Opportunity'][method][eps] = metrics['EOD']

Poisoning German Credit dataset with IAF attack:


  0%|          | 0/11 [00:00<?, ?it/s]

  rank_zero_deprecation(


--------------------------------------------------------------------------------
DATALOADER:0 TEST RESULTS
{'EOD': 0.01527780294418335,
 'SPD': 0.2081504762172699,
 'test_error': 0.29500001668930054}
--------------------------------------------------------------------------------


	https://www.cvxpy.org/tutorial/advanced/index.html#disciplined-parametrized-programming
  rank_zero_warn(f"Checkpoint directory {dirpath} exists and is not empty.")
  rank_zero_deprecation(


--------------------------------------------------------------------------------
DATALOADER:0 TEST RESULTS
{'EOD': 0.020833313465118408,
 'SPD': 0.24639499187469482,
 'test_error': 0.2799999713897705}
--------------------------------------------------------------------------------


  rank_zero_warn(f"Checkpoint directory {dirpath} exists and is not empty.")


--------------------------------------------------------------------------------
DATALOADER:0 TEST RESULTS
{'EOD': 0.04583334922790527,
 'SPD': 0.24576802551746368,
 'test_error': 0.3050000071525574}
--------------------------------------------------------------------------------


  rank_zero_warn(f"Checkpoint directory {dirpath} exists and is not empty.")
Exception ignored in: <function _releaseLock at 0x7ffb187335e0>
Traceback (most recent call last):
  File "/home/john/anaconda3/envs/pytorchEnv/lib/python3.9/logging/__init__.py", line 227, in _releaseLock
    def _releaseLock():
KeyboardInterrupt: 


We then plot the results, reproducing the styling of Figure 2 in Mehrabi et al.:

In [None]:
# Reproduce plot styling of the original paper: one colour / marker per attack method
colors, markers = ['b', 'r', 'g'], ['s', 'x', '^']
plot_datasets = ['German Credit', 'COMPAS', 'Drug Consumption']
plot_metrics = ['Test Error', 'Statistical Parity', 'Equality of Opportunity']
plot_methods = ['IAF', 'RAA', 'NRAA']

fig, axes = plt.subplots(3, 3, figsize=(20, 10))
for i, dataset in enumerate(plot_datasets):
    for j, metric in enumerate(plot_metrics):
        ax = axes[i, j]
        for method, color, marker in zip(plot_methods, colors, markers):
            # .get() instead of [] so the auto-vivifying `results` is not filled with empty
            # entries for runs that never happened (e.g. after an interrupted experiment)
            series = results.get(dataset, {}).get(metric, {}).get(method, {})
            eps_values = sorted(series)
            ax.plot(
                eps_values,
                [series[eps] for eps in eps_values],
                c=color,
                marker=marker,
                label=method,
            )

        # Raw string: '\e' is an invalid escape sequence in a normal string literal
        ax.set_xlabel(r'$\epsilon$', fontweight='bold')
        ax.set_ylabel(metric, fontweight='bold')
        ax.set_title(dataset, fontweight='bold')
        ax.legend(loc='upper left', ncol=3)

fig.subplots_adjust(left=0.1, bottom=0.1, right=0.9, top=0.9, wspace=0.4, hspace=0.4)
plt.show()