# Case Study 2 - Fairness
These notebooks are also available on Google Colab, which lets you run them without setting up a local environment and gives you access to GPUs.

[![Run in Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/drive/1JOstMJmhI2wcufyBqZ1iV3YqOdThJ-_U?usp=sharing#scrollTo=mCX2hPceiAet)

## 1. Introduction
One common problem with some machine learning models is that unfair bias in the training data leads to models that systematically perform worse for some populations. In this case study, we address the issue of fairness by reducing the bias in a generated synthetic dataset.

### 1.1 The Task
Train a fair prognostic classifier for COVID-19 patients in Brazil.

## 2. Imports
Let's get the imports out of the way. We import the required standard and third-party libraries and the relevant Synthcity modules. We can also set the logging level here, using Synthcity's bespoke logger.

In [None]:
# Standard
import sys
import warnings
import pickle
from pathlib import Path
from typing import Any, Tuple
import itertools

# 3rd party
import numpy as np
import pandas as pd
import networkx as nx
import xgboost as xgb
import matplotlib.pyplot as plt
import seaborn as sns
from graphviz import Digraph
from sklearn.model_selection import train_test_split
from sklearn.metrics import f1_score, accuracy_score, confusion_matrix, ConfusionMatrixDisplay
from sklearn.ensemble import RandomForestClassifier

# synthcity
import synthcity
import synthcity.logger as log
from synthcity.utils import serialization
from synthcity.plugins import Plugins
from synthcity.metrics import Metrics
from synthcity.plugins.core.dataloader import (GenericDataLoader, SurvivalAnalysisDataLoader)
from synthcity.plugins.privacy.plugin_decaf import plugin as decaf_plugin
from synthcity.plugins.core.constraints import Constraints

# Synthetic-data-lab
from utils import fairness_scores
from utils import plot_dag

# Configure warnings and logging
warnings.filterwarnings("ignore")

# Set the level for the logging
log.remove()
# log.add(sink=sys.stderr, level="INFO")


# Set up paths to resources
FAIR_RES_PATH = Path("../resources/fairness/")

## 3. Load the data
Next, we load the data from file and formulate it as a classification problem. To do this, we set a time horizon of 14 days and create an `is_dead_at_time_horizon=14` column, which is 1 if the patient died within `time_horizon` days of hospitalisation and 0 otherwise.
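To make the labelling rule concrete, here is a minimal sketch on a hypothetical four-patient toy DataFrame (all values are invented). It expresses the three `.loc` assignments in the cell below as a single `numpy.where`, assuming no missing values in `Days_hospital_to_outcome`: a record is labelled 1 only if the patient died and the death happened within the horizon.

```python
import numpy as np
import pandas as pd

time_horizon = 14

# Hypothetical toy records, invented for illustration only
toy = pd.DataFrame({
    "Days_hospital_to_outcome": [3, 20, 10, 30],
    "is_dead": [1, 1, 0, 0],
})

label = f"is_dead_at_time_horizon={time_horizon}"
# 1 only when the patient died AND the outcome occurred within the horizon
toy[label] = np.where(
    (toy["is_dead"] == 1) & (toy["Days_hospital_to_outcome"] <= time_horizon), 1, 0
)
print(toy[label].tolist())  # [1, 0, 0, 0]
```

The second patient died, but after 20 days, so they count as a negative at a 14-day horizon.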

In [None]:
time_horizon = 14
X = pd.read_csv(f"../data/Brazil_COVID/covid_normalised_numericalised.csv")

X.loc[(X["Days_hospital_to_outcome"] <= time_horizon) & (X["is_dead"] == 1), f"is_dead_at_time_horizon={time_horizon}"] = 1
X.loc[(X["Days_hospital_to_outcome"] > time_horizon), f"is_dead_at_time_horizon={time_horizon}"] = 0
X.loc[(X["is_dead"] == 0), f"is_dead_at_time_horizon={time_horizon}"] = 0
X[f"is_dead_at_time_horizon={time_horizon}"] = X[f"is_dead_at_time_horizon={time_horizon}"].astype(int)

X.drop(columns=["is_dead", "Days_hospital_to_outcome"], inplace=True) # drop survival columns as they are not needed for a classification problem
display(X)

# Define the mappings of the encoded values in the Ethnicity column to the understandable values
ethnicity_mapper = {
    0: "Mixed",
    1: "White",
    2: "Black",
    3: "East Asian",
    4: "Indigenous",
}

## 4. Potential issue

The Brazilian population is made up of people of different ethnicities in different proportions. We should check the frequency of each ethnicity to see how evenly our data is distributed across ethnicities. Let's create a plot.

In [None]:
# Count records per ethnicity (column names are stable across pandas 1.x and 2.x)
ethnicity_frequency_data = X["Ethnicity"].map(ethnicity_mapper).value_counts().rename_axis("Ethnicity").reset_index(name="Ethnicity count")
sns.barplot(data=ethnicity_frequency_data, x="Ethnicity", y="Ethnicity count")
plt.show()

The population in our dataset is overwhelmingly white and mixed, with little representation of black, East Asian and indigenous people. This poses a problem for us.
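Raw counts are easier to compare across datasets when expressed as shares. A minimal sketch, using a hypothetical toy column in place of `X["Ethnicity"]`: `value_counts(normalize=True)` returns the fraction of records in each group rather than the count.

```python
import pandas as pd

ethnicity_mapper = {0: "Mixed", 1: "White", 2: "Black", 3: "East Asian", 4: "Indigenous"}

# Hypothetical toy column; in the notebook this would be X["Ethnicity"]
ethnicity = pd.Series([0, 0, 0, 1, 1, 1, 1, 2, 3, 4])

# Fraction of records per ethnicity instead of absolute counts
shares = ethnicity.map(ethnicity_mapper).value_counts(normalize=True)
print(shares.round(2).to_dict())
```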

## 5. The Problem
We need a prognostic classifier for the whole population. Having little representation from some parts of the population means that any classifier we train on this data is susceptible to bias. Let's train a `RandomForestClassifier` on data from all ethnicities and then test it on each ethnicity separately. This will show us the extent of the problem, as we will be able to see any disparity in model performance across the different groups.

### 5.1 Set up the data
Set up the data splits for the prognostic classifier using `train_test_split` from sklearn.

In [None]:
y = X["is_dead_at_time_horizon=14"]
X_in = X.drop(columns=["is_dead_at_time_horizon=14"])

X_train, X_test, y_train, y_test = train_test_split(X_in, y, random_state=4)
X_train.reset_index(drop=True, inplace=True)
X_test.reset_index(drop=True, inplace=True)
y_train.reset_index(drop=True, inplace=True)
y_test.reset_index(drop=True, inplace=True)

### 5.2 Load the classifier
Load the trained `RandomForestClassifier`. It was trained on the training split of the whole dataset, i.e. on records from all ethnicities together.

In [None]:
prognostic_model = RandomForestClassifier(
    n_estimators=100,
    criterion='gini',
    max_depth=None,
    random_state=42,
    verbose=0,
    warm_start=False,
)

saved_model_path = FAIR_RES_PATH / "fairness_cond_aug_random_forest_real_data.sav"

# # The saved model was trained with the following code
# prognostic_model.fit(X_train, y_train)
# with open(saved_model_path, 'wb') as f:
#     pickle.dump(prognostic_model, f)

# Load the model trained on the whole dataset
with open(saved_model_path, "rb") as f:
    prognostic_model = pickle.load(f)


### 5.3 Evaluate the classifier
Evaluate the overall performance of the classifier on both the train and test sets. This notebook lets you choose the performance score: set `performance_score` in the cell below to either sklearn's `accuracy_score` or `f1_score`, or to a metric of your choice with the same `(y_true, y_pred)` signature.

In [None]:
performance_score = accuracy_score
calculated_performance_score = performance_score(y_train, prognostic_model.predict(X_train))
print(f"Evaluating {performance_score.__name__} on train set: {calculated_performance_score:0.4f}")

# Predicted values for whole dataset
y_pred = prognostic_model.predict(X_test)

calculated_performance_score = performance_score(y_test, y_pred)
print(f"Evaluating {performance_score.__name__} on test set: {calculated_performance_score:0.4f}")

### 5.4 Confusion Matrices
Create the confusion matrix for each of the ethnicities and the whole dataset to compare.
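The loop in the cell below is a per-group evaluation: it slices the test set by `Ethnicity` and applies the same `performance_score` to each slice. The same idea, without the plotting, is sketched here in plain Python. `per_group_scores` is a hypothetical helper, and any `(y_true, y_pred) -> float` function, such as sklearn's `accuracy_score`, could be passed as `score_fn`.

```python
from collections import defaultdict


def per_group_scores(groups, y_true, y_pred, score_fn):
    """Apply score_fn separately to the records of each group.

    groups, y_true and y_pred are equal-length sequences aligned by position.
    """
    buckets = defaultdict(lambda: ([], []))
    for g, t, p in zip(groups, y_true, y_pred):
        buckets[g][0].append(t)
        buckets[g][1].append(p)
    return {g: score_fn(ts, ps) for g, (ts, ps) in sorted(buckets.items())}


def accuracy(y_true, y_pred):
    return sum(t == p for t, p in zip(y_true, y_pred)) / len(y_true)


# Hypothetical toy data: group 0 is served well, group 1 is not
groups = [0, 0, 0, 0, 1, 1]
y_true = [1, 0, 1, 0, 1, 0]
y_pred = [1, 0, 1, 0, 0, 1]
print(per_group_scores(groups, y_true, y_pred, accuracy))  # {0: 1.0, 1: 0.0}
```

A large gap between the best- and worst-scoring groups is the kind of disparity the confusion matrices are meant to reveal.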

In [None]:
# Setup the figure axis
f, axes = plt.subplots(2, 3, figsize=(20, 10))

# Create the whole dataset confusion matrix and add it to the figure
cm = confusion_matrix(y_test, y_pred, labels=prognostic_model.classes_)
disp = ConfusionMatrixDisplay(confusion_matrix=cm, display_labels=prognostic_model.classes_)
disp.plot(ax=axes[0,0])
disp.ax_.set_title(f"Whole Test Dataset | Performance: {calculated_performance_score:0.4f}")

# Get the indices to loop through
ethnicity_idxes = X_in["Ethnicity"].unique()
ethnicity_idxes.sort()

# for each ethnicity create a confusion matrix
for ethnicity_idx in ethnicity_idxes:
    # Get the slice of the dataset for each ethnicity
    X_test_per_ethnicity = X_test.loc[X_test["Ethnicity"] == ethnicity_idx]
    test_records_per_ethnicity_indices = X_test_per_ethnicity.index
    y_true = y_test.iloc[test_records_per_ethnicity_indices]

    # Generate prediction values for each ethnicity subpopulation
    y_pred_per_ethnicity = prognostic_model.predict(X_test_per_ethnicity)

    # Calculate the model performance for each ethnicity subpopulation
    calculated_performance_score = performance_score(y_true, y_pred_per_ethnicity)

    # Generate the confusion matrix and add it to the figure
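    # Pass labels so the matrix is always 2x2, even if a small subgroup contains only one class
    cm = confusion_matrix(y_true, y_pred_per_ethnicity, labels=prognostic_model.classes_)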
    disp = ConfusionMatrixDisplay(confusion_matrix=cm, display_labels=prognostic_model.classes_)
    ax_index = [0, ethnicity_idx + 1] if ethnicity_idx <= 1 else [1, (ethnicity_idx + 1) % 3]
    disp.plot(ax=axes[ax_index[0], ax_index[1]])
    disp.ax_.set_title(f"Ethnicity: {ethnicity_mapper[ethnicity_idx]} | Performance: {calculated_performance_score:0.4f}")
plt.show()

As you can see, the model performs significantly worse on the Black population than overall. Interestingly, however, it performs better on the East Asian subpopulation. This is most likely due to chance: the East Asian patients in this sample happen to have features that are good predictors of the outcome, which would not necessarily hold for a larger sample from the same population. The Indigenous population is so poorly represented in the dataset, with only 3 records, that it is difficult to assess performance reliably at all, although these three records suggest that performance may be poor.

This confirms that a naive approach like the one above produces a model that systematically performs worse for people of some ethnicities than for others. This unfairness must be addressed.

## 6. The solution - Augment the dataset to improve the fairness

### 6.1 Load the data with the synthcity module
First, we load the data with `GenericDataLoader`. We pass the name of our target column via `target_column` and mark `Ethnicity` as a sensitive feature via `sensitive_features`. We can then view the data with `loader.dataframe()`, and we could also get information about the data loader object with `loader.info()`.


In [None]:
loader = GenericDataLoader(
    X,
    target_column=f"is_dead_at_time_horizon={time_horizon}",
    sensitive_features=["Ethnicity"],
    random_state=42,
)

display(loader.dataframe())

### 6.2 Load/Create the synthetic data model
We are now going to generate synthetic data with a condition such that the new dataset is more balanced with regard to ethnicity. We will first define some values which we will use below.

In [None]:
model = "ctgan"
prefix = "fairness.conditional_augmentation"
random_state = 42

We will now either create and fit a synthetic data model then save that model to file, or load one we have already saved from file.

In [None]:
# Define saved model name
save_file = Path("saved_models") / f"{prefix}_{model}_numericalised_rnd={random_state}.bkp"
print(save_file)
# Load if available
if Path(save_file).exists():
    syn_model = serialization.load_from_file(save_file)
# create and fit if not available
else:
    syn_model = Plugins().get(model, random_state=random_state)
    syn_model.fit(loader, cond=loader["Ethnicity"])
    serialization.save_to_file(save_file, syn_model)

### 6.3 Generate fairer data
Use the synthetic data model to generate data, passing the `cond` argument to try to make the data evenly distributed across ethnicities. We then augment the original, real dataset with the synthetic records from the under-represented ethnicities (Black, East Asian and Indigenous).
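The `cond` list in the cell below cycles through the five ethnicity codes, so it requests almost exactly one fifth of the synthetic records per group. The short calculation below, in plain Python, shows the requested counts and how many records would be appended (codes 2, 3 and 4) if the generator followed the condition perfectly; the actual number depends on how closely the generator honours it.

```python
from collections import Counter

count = 6882  # same value as in the cell below
cond = [(i % 5) for i in range(count)]

# Requested number of records per ethnicity code
requested = Counter(cond)
print(sorted(requested.items()))
# [(0, 1377), (1, 1377), (2, 1376), (3, 1376), (4, 1376)]

# Only the minority codes (2, 3, 4) are appended to the real data
minority_requested = sum(n for eth, n in requested.items() if eth >= 2)
print(minority_requested)  # 4128
```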

In [None]:
count = 6882 # set the count equal to the number of rows in the original dataset for a fair comparison
cond = [(i % 5) for i in range(count)] # request an equal share of each of the 5 ethnicity codes
syn_data = syn_model.generate(count=count, cond=cond, random_state=random_state).dataframe()
augmented_data = pd.concat([
    X,
    syn_data.loc[syn_data["Ethnicity"] >= 2],
])

display(augmented_data)

print("Here is the ethnicity breakdown for the real dataset:")
print(loader["Ethnicity"].value_counts().rename(ethnicity_mapper))
print("\nHere is the ethnicity breakdown for the synthetic dataset:")
print(syn_data["Ethnicity"].value_counts().rename(ethnicity_mapper))
print("\nHere is the ethnicity breakdown for the augmented dataset:")
print(augmented_data["Ethnicity"].value_counts().rename(ethnicity_mapper))

Check the ethnicity breakdowns above to confirm that the under-represented groups have been augmented properly. This is important because the condition only steers the GAN; it does not guarantee that the generated samples meet that condition exactly. If you require rules to be strictly adhered to, use `Constraints` instead.
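One way to quantify how closely the condition was followed is to compare requested and generated counts per ethnicity code. `condition_adherence` below is a hypothetical helper written in plain Python; in the notebook you could call it as `condition_adherence(cond, syn_data["Ethnicity"].tolist())`.

```python
from collections import Counter


def condition_adherence(requested, generated):
    """Ratio of generated to requested records for each condition value.

    A ratio of 1.0 means the group was produced exactly as often as requested.
    """
    req = Counter(requested)
    gen = Counter(generated)
    return {k: gen.get(k, 0) / n for k, n in sorted(req.items())}


# Hypothetical example: 4 records requested per group, group 4 under-produced
requested = [0, 1, 2, 3, 4] * 4
generated = [0] * 5 + [1] * 5 + [2] * 4 + [3] * 4 + [4] * 2
print(condition_adherence(requested, generated))
# {0: 1.25, 1: 1.25, 2: 1.0, 3: 1.0, 4: 0.5}
```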

### 6.4 Re-evaluate the classifier on the new, fairer dataset
Let's try our classifier again with the augmented dataset. First, we split the augmented data into features and target, as we did before.

In [None]:
augmented_y = augmented_data["is_dead_at_time_horizon=14"]
augmented_X = augmented_data.drop(columns=["is_dead_at_time_horizon=14"])
augmented_y.reset_index(drop=True, inplace=True)
augmented_X.reset_index(drop=True, inplace=True)


We need a model trained on the new data. We can load this from file, as before.

In [None]:
saved_model_path = FAIR_RES_PATH / "fairness_cond_aug_random_forest_augmented_data.sav"

# # The saved model was trained with the following code
# prognostic_model.fit(augmented_X, augmented_y)
# with open(saved_model_path, 'wb') as f:
#     pickle.dump(prognostic_model, f)

    
# Load the model trained on the augmented dataset
with open(saved_model_path, "rb") as f:
    prognostic_model = pickle.load(f)


Evaluate the performance of the model on the real test set, in the spirit of the "train-on-synthetic, test-on-real" (TSTR) rule; here the training data is real data augmented with synthetic records.
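Note that `augmented_data` was built from all of `X`, so the real records in `X_test` are also part of the augmented training data (and of the data the CTGAN model was fitted on); the test score is therefore likely optimistic. A stricter setup augments only the real training split. The sketch below shows that pattern on small hypothetical toy frames (the `Age` column and all values are invented); applying it in the notebook would mean retraining the classifier rather than loading the saved model.

```python
import pandas as pd

# Hypothetical toy frames standing in for X_train / y_train and the synthetic data
X_train = pd.DataFrame({"Ethnicity": [0, 1, 0], "Age": [0.2, 0.5, 0.9]})
y_train = pd.Series([0, 1, 0], name="is_dead_at_time_horizon=14")
syn_data = pd.DataFrame({
    "Ethnicity": [2, 3, 1],
    "Age": [0.4, 0.7, 0.1],
    "is_dead_at_time_horizon=14": [1, 0, 0],
})

# Rebuild the real training frame with its target column
train_real = pd.concat([X_train, y_train], axis=1)

# Append only minority synthetic rows; the real test split is never touched
augmented_train = pd.concat(
    [train_real, syn_data.loc[syn_data["Ethnicity"] >= 2]],
    ignore_index=True,
)
print(len(augmented_train))                   # 5
print(augmented_train["Ethnicity"].tolist())  # [0, 1, 0, 2, 3]
```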

In [None]:
y_pred = prognostic_model.predict(X_test)
calculated_performance_score = performance_score(y_test, y_pred)
print(f"Evaluating {performance_score.__name__} on test set: {calculated_performance_score:0.4f}")

### 6.5 New confusion matrices
Create the confusion matrices for the model trained on the augmented dataset, evaluated on the real test set.

In [None]:
# Setup the figure axis
f, axes = plt.subplots(2, 3, figsize=(20, 10))

# Create the whole dataset confusion matrix and add it to the figure
cm = confusion_matrix(y_test, y_pred, labels=prognostic_model.classes_)
disp = ConfusionMatrixDisplay(confusion_matrix=cm, display_labels=prognostic_model.classes_)
disp.plot(ax=axes[0,0])
disp.ax_.set_title(f"Whole Test Dataset | Performance: {calculated_performance_score:0.4f}")

# Get the indices to loop through
ethnicity_idxes = augmented_X["Ethnicity"].unique()
ethnicity_idxes.sort()

# for each ethnicity create a confusion matrix
for ethnicity_idx in ethnicity_idxes:
    # Get the slice of the dataset for each ethnicity
    X_test_per_ethnicity = X_test.loc[X_test["Ethnicity"] == ethnicity_idx]
    test_records_per_ethnicity_indices = X_test_per_ethnicity.index
    y_true_per_ethnicity = y_test.iloc[test_records_per_ethnicity_indices]

    # Generate prediction values for each ethnicity subpopulation
    y_pred_per_ethnicity = prognostic_model.predict(X_test_per_ethnicity)

    # Calculate the model performance for each ethnicity subpopulation
    calculated_performance_score = performance_score(y_true_per_ethnicity, y_pred_per_ethnicity)

    # Generate the confusion matrix and add it to the figure
    # Pass labels so the matrix is always 2x2, even if a small subgroup contains only one class
    cm = confusion_matrix(y_true_per_ethnicity, y_pred_per_ethnicity, labels=prognostic_model.classes_)
    disp = ConfusionMatrixDisplay(confusion_matrix=cm, display_labels=prognostic_model.classes_)
    ax_index = [0, ethnicity_idx + 1] if ethnicity_idx <= 1 else [1, (ethnicity_idx + 1) % 3]
    disp.plot(ax=axes[ax_index[0], ax_index[1]])
    disp.ax_.set_title(f"Ethnicity: {ethnicity_mapper[ethnicity_idx]} | Performance: {calculated_performance_score:0.4f}")
plt.show()

Hopefully, you can see that the new model, trained on the augmented data, performs more consistently across the different populations.

Why not experiment with the different synthetic data generation methods and their parameters to see whether you can achieve the same improvement in fairness with higher performance? If you need help identifying suitable methods, remember that you can list the available plugins with `Plugins().list()`, and refer to the [docs](https://synthcity.readthedocs.io/en/latest/generators.html) to learn what they do.

## 7. Removing bias via causal generation with DECAF

DECAF is an inference-time de-biasing method, where the data-generating process is embedded explicitly as a structural causal model in the input layers of the generator, allowing each variable to be reconstructed conditioned on its causal parents. We will use this method to create a different solution to the issue of fairness in this Brazilian COVID-19 dataset.
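To build intuition for what removing a "biased edge" means, here is a toy structural causal model in plain Python. It is not DECAF, which learns the structural equations with neural networks inside a GAN; it only illustrates the idea that each variable is generated from its causal parents, so dropping the direct edge A -> Y at generation time weakens the outcome's dependence on the sensitive attribute A. All coefficients are invented. In this toy model a small gap remains after removal, because A still influences Y indirectly through X.

```python
import random


def sample_scm(n, keep_biased_edge, seed=0):
    """Sample (a, y) pairs from a toy SCM with edges A -> X, X -> Y and optionally A -> Y.

    A is a binary sensitive attribute, X a continuous covariate, Y a binary outcome.
    Every coefficient here is invented for illustration.
    """
    rng = random.Random(seed)
    rows = []
    for _ in range(n):
        # Variables are generated in topological order, each from its parents
        a = 1 if rng.random() < 0.2 else 0
        x = rng.gauss(0.5 * a, 1.0)
        p_y = 0.1 + 0.1 * (x > 0)
        if keep_biased_edge:
            p_y += 0.4 * a  # the direct A -> Y edge
        y = 1 if rng.random() < p_y else 0
        rows.append((a, y))
    return rows


def parity_gap(rows):
    """|P(Y=1 | A=1) - P(Y=1 | A=0)| estimated from (a, y) pairs."""
    rates = []
    for group in (0, 1):
        ys = [y for a, y in rows if a == group]
        rates.append(sum(ys) / len(ys))
    return abs(rates[1] - rates[0])


biased = sample_scm(20_000, keep_biased_edge=True)
debiased = sample_scm(20_000, keep_biased_edge=False)
print(f"gap with A -> Y edge:    {parity_gap(biased):.3f}")
print(f"gap without A -> Y edge: {parity_gap(debiased):.3f}")
```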
### 7.1 Load the data
Let's load the data from file again to make sure we are working with the correct data and that nothing has changed. As before, we will frame it as a classification problem. This time, however, we will set it up as a classic two-class fairness task, with one class representing the majority ethnic groups (White and Mixed) and the other the minority groups (Black, East Asian and Indigenous).

In [None]:
# Load the data from file
X = pd.read_csv(f"../data/Brazil_COVID/covid_normalised_numericalised.csv")

# Set it up as classification task
time_horizon = 14
X.loc[(X["Days_hospital_to_outcome"] <= time_horizon) & (X["is_dead"] == 1), f"is_dead_at_time_horizon={time_horizon}"] = 1
X.loc[(X["Days_hospital_to_outcome"] > time_horizon), f"is_dead_at_time_horizon={time_horizon}"] = 0
X.loc[(X["is_dead"] == 0), f"is_dead_at_time_horizon={time_horizon}"] = 0
X[f"is_dead_at_time_horizon={time_horizon}"] = X[f"is_dead_at_time_horizon={time_horizon}"].astype(int)

X.drop(columns=["is_dead", "Days_hospital_to_outcome"], inplace=True) # drop survival columns as they are not needed for a classification problem

# Set up ethnicity as two classes, minority and majority
X.loc[(X["Ethnicity"] == 0) | (X["Ethnicity"] == 1), "Ethnicity"] = 0 
X.loc[(X["Ethnicity"] == 2) | (X["Ethnicity"] == 3) | (X["Ethnicity"] == 4), "Ethnicity"] = 1

Pass the data into the synthcity data loader. We will just use the `GenericDataLoader` here.

In [None]:
loader = GenericDataLoader(
    X,
    target_column="is_dead_at_time_horizon=14",
    sensitive_features=["Ethnicity"],
    random_state=42,
)

display(loader.dataframe())

### 7.2 Load/Create the synthetic data model using DECAF
First, we define some useful variables.

In [None]:
prefix = "fairness.causal_generation"
model = "decaf"
n_iter = 101
count = 6882 # set the count equal to the number of rows in the original dataset
random_state=6

Then load the synthetic model from file. If you want to try something a little different you can also fit your own model.

In [None]:
# Define saved model name
save_file = Path("saved_models") / f"{prefix}_{model}_n_iter={n_iter}_rnd={random_state}.bkp"

# Load from file if available
if Path(save_file).exists():
    syn_model = serialization.load_from_file(save_file)
    dag = syn_model.get_dag(loader.dataframe())
    print(f"DAG before biased edges are removed:")
    display(plot_dag.get_dag_plot(dag))
    
# create and fit if not available
else:
    syn_model = decaf_plugin(struct_learning_enabled=True, n_iter=n_iter, random_state=random_state) # Pass struct_learning_enabled=True in order for the syn_model to learn the Dag
    dag_before = syn_model.get_dag(loader.dataframe())
    syn_model.fit(loader, dag=dag_before, random_state=random_state)
    serialization.save_to_file(save_file, syn_model)
    print(f"DAG before biased edges are removed:")
    dag = syn_model.get_dag(loader.dataframe())
    display(plot_dag.get_dag_plot(dag))

### 7.3 Generate the data
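We can now generate the de-biased dataset by passing the biased edges we wish to remove to `generate` through the `biased_edges` argument. Here we remove the edge from `Ethnicity` to the outcome `is_dead_at_time_horizon=14`.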
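The `bias` dictionary maps a source column to the list of target columns whose edge from that source should be dropped. Conceptually, this deletes those (source, target) pairs from the learned DAG before generation. The sketch below illustrates that on a hypothetical edge list of (parent, child) pairs; the `Age` column, the list representation and the `remove_biased_edges` helper are assumptions, and the DECAF plugin handles the removal internally. Note that indirect paths, such as `Ethnicity -> Age -> is_dead_at_time_horizon=14`, are left in place.

```python
def remove_biased_edges(dag_edges, biased_edges):
    """Drop every (source, target) edge listed in a {source: [targets]} mapping.

    dag_edges is assumed to be a list of (parent, child) column-name pairs.
    """
    to_drop = {(src, dst) for src, dsts in biased_edges.items() for dst in dsts}
    return [edge for edge in dag_edges if tuple(edge) not in to_drop]


# Hypothetical DAG over three columns
dag_edges = [
    ("Ethnicity", "is_dead_at_time_horizon=14"),
    ("Ethnicity", "Age"),
    ("Age", "is_dead_at_time_horizon=14"),
]
bias = {"Ethnicity": ["is_dead_at_time_horizon=14"]}
print(remove_biased_edges(dag_edges, bias))
# [('Ethnicity', 'Age'), ('Age', 'is_dead_at_time_horizon=14')]
```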

In [None]:
bias={"Ethnicity": ["is_dead_at_time_horizon=14"]}
decaf_syn_data = syn_model.generate(count, biased_edges=bias, random_state=random_state)
display(decaf_syn_data.dataframe())

### 7.4 DECAF fairness tests

We will now check that our synthetic data is fairer than the original, real data by measuring demographic parity, which is defined in section 4.1 of the [DECAF paper](https://arxiv.org/abs/2110.12884).
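One common way to quantify demographic parity for a binary protected attribute A and a binary outcome or prediction Ŷ is the gap |P(Ŷ = 1 | A = 0) − P(Ŷ = 1 | A = 1)|, where 0 means perfect parity. The sketch below computes it from two aligned lists of toy values. It is an assumption that `fairness_scores.demographic_parity_score` computes exactly this quantity (it may, for example, score the target column directly or a downstream classifier's predictions), so check its source in `utils` before comparing numbers.

```python
def demographic_parity_gap(sensitive, outcome):
    """Return (rate per group, |rate_0 - rate_1|) for a binary sensitive attribute.

    sensitive: 0/1 group labels; outcome: 0/1 outcomes or predictions, aligned by position.
    """
    rates = {}
    for group in (0, 1):
        values = [o for s, o in zip(sensitive, outcome) if s == group]
        rates[group] = sum(values) / len(values)
    return rates, abs(rates[0] - rates[1])


# Hypothetical toy data: group 0 = majority, group 1 = minority
sensitive = [0, 0, 0, 0, 1, 1]
outcome = [1, 0, 0, 0, 1, 1]
rates, gap = demographic_parity_gap(sensitive, outcome)
print(rates, gap)  # {0: 0.25, 1: 1.0} 0.75
```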

In [None]:
demographic_parity_score_gt = fairness_scores.demographic_parity_score(loader)
demographic_parity_score_syn = fairness_scores.demographic_parity_score(decaf_syn_data)

print(f"Demographic Parity scores \nreal data: {demographic_parity_score_gt} | synthetic data: {demographic_parity_score_syn}")

## 8. Extension
Use the space below to complete the extension exercise.

### 8.1 Plot the DECAF synthetic data to show fairness

What is a simple plot we could make to show that the generated data is fair?


### 8.2 Our solution

In [None]:
#@title (i) Plot to show fairness in data generated by DECAF

# 2 class ethnicity mapper
ethnicity_mapper_2_class = {
    0: "Majority",
    1: "Minority"
}

# Define the model
xgb_model = xgb.XGBClassifier(
    n_estimators=2000,
    learning_rate=0.01,
    max_depth=5,
    subsample=0.8, 
    colsample_bytree=1, 
    gamma=1, 
    objective="binary:logistic",
    random_state=42,
)

# Load the XGBoost model trained on the DECAF synthetic data
saved_model_path = FAIR_RES_PATH / "fairness_causal_gen_decaf_synthetic_data.json"
xgb_model.load_model(saved_model_path)
synth_data_to_predict = decaf_syn_data.dataframe().drop(columns=["is_dead_at_time_horizon=14"])
targets_synth_data_to_predict = decaf_syn_data["is_dead_at_time_horizon=14"]

# # The saved model was trained with the following code
# xgb_model.fit(synth_data_to_predict, targets_synth_data_to_predict)
# xgb_model.save_model(saved_model_path)

# Get the indices to loop through
ethnicity_idxes = decaf_syn_data["Ethnicity"].unique()
ethnicity_idxes.sort()

predictions = {}
for ethnicity_idx in ethnicity_idxes:
    synth_data_to_predict_per_ethnicity = synth_data_to_predict.loc[synth_data_to_predict["Ethnicity"] == ethnicity_idx]
    # display(synth_data_to_predict)

    synthetic_predictions = xgb_model.predict(synth_data_to_predict)

    unique, counts = np.unique(synthetic_predictions, return_counts=True)
    prediction_counts = {unique[0]: counts[0], unique[1]: counts[1]}
    predictions[ethnicity_mapper_2_class[ethnicity_idx]] = prediction_counts

prediction_frequency_data = pd.DataFrame(data={
    "Ethnicity": list(predictions.keys()),
    "0": [p_c[0] for p_c in predictions.values()],
    "1": [p_c[1] for p_c in predictions.values()],
})

prediction_frequency_data_m = pd.melt(prediction_frequency_data, id_vars="Ethnicity")
prediction_frequency_data_m = prediction_frequency_data_m.rename(
    columns={"variable": "is_dead_at_time_horizon", "value": "Prediction count"}
)


sns.catplot(
    data=prediction_frequency_data_m,
    x="Ethnicity",
    y="Prediction count",
    hue="is_dead_at_time_horizon",
    kind="bar"
)
plt.show()