<a href="https://colab.research.google.com/github/rsippy/CKD-CKDu/blob/master/CCAIM_Tutorial_Innovative_Uses_of_Synthetic_Data_Tutorial_Synthcity_Hands_On.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

<img src="https://github.com/vanderschaarlab/synthcity/raw/main/docs/logo.png" alt="Synthcity Logo" width="200"/>

# CCAIM Tutorial: Innovative Uses of Synthetic Data

## Set up runtime
We need to set up the runtime and environment and fetch the files for the tutorial. First, select "Runtime > Change runtime type" from the menu, then set the hardware accelerator to a standard GPU (T4 GPU).
Now we are ready to run some code. Clone the lab repo, cd into it, and install the dependencies. This may require you to restart the runtime session.

In [None]:
%%capture
!git clone https://github.com/vanderschaarlab/synthetic-data-lab
%cd synthetic-data-lab/
!pip install -r requirements.txt;

In [None]:
import torch

print("PyTorch version:", torch.__version__)
print("CUDA version:", torch.version.cuda)
print("Is CUDA available?", torch.cuda.is_available())

if torch.cuda.is_available():
    print("Number of GPUs available:", torch.cuda.device_count())
    print("Current GPU:", torch.cuda.get_device_name(0))
else:
    print("No GPU detected.")


Before you do the next step, please restart your runtime with Runtime > Restart session. This will ensure the dependencies are resolved properly in Colab. You do not need to re-run the first cell once the runtime is restarted.

Now cd into the Tutorials directory.

In [None]:
%cd synthetic-data-lab/
%cd Tutorials/
!pwd

We are now set up and ready to start the tutorial.

# Synthcity: an open-source library to facilitate innovative uses of synthetic data


In order to enable wider adoption of synthetic data and facilitate translational research in this promising area, the community needs a software platform that implements a large collection of state-of-the-art generators in a modular, reusable and composable way.

Synthcity is an open-source software package for innovative use cases of synthetic data in ML fairness, privacy and augmentation across diverse tabular data modalities, including static data, regular and irregular time series, data with censoring, multi-source data, composite data, and more. Synthcity provides the practitioners with a single access point to cutting edge research and tools in synthetic data. It also offers the community a playground for rapid experimentation and prototyping, a one-stop-shop for SOTA benchmarks, and an opportunity for extending research impact.

For researchers:
 - A playground for rapid experimentation and prototyping
 - A one-stop-shop for SOTA benchmarks
 - An opportunity for extending research impact and translational research

For practitioners:
 - A way to bring cutting-edge research to practical problems
 - A unified solution to diverse problems


Learn more about Synthcity by visiting [GitHub](https://github.com/vanderschaarlab/synthcity), reading our [White Paper](https://arxiv.org/abs/2301.07573), or reading our paper from [NeurIPS 2023](https://openreview.net/pdf?id=uIppiU2JKP).

# Introduction to synthetic data workflow

Today we will use Synthcity to illustrate the different use cases of synthetic data in ML and data analytics.

![Standard workflow of generating and evaluating synthetic data with synthcity.](https://drive.google.com/uc?export=view&id=1tuRfTk9WQZRD9Be_Szzna8mk4zskAR7_)

The synthcity library captures the entire workflow of synthetic data generation and evaluation. The typical workflow contains the following steps, as illustrated above.

1. **Loading the dataset using a DataLoader**. The DataLoader class provides a consistent interface for loading and storing different types of input data (e.g. tabular, time series, and survival data). The user can also provide meta-data to inform downstream algorithms (e.g. specifying the sensitive columns for privacy-preserving algorithms).
2. **Training the generator using a Plugin**. In synthcity, the users instantiate, train, and apply different data generators via the Plugin class. Each Plugin represents a specific data generation algorithm. The generator can be trained using the fit() method of a Plugin.
3. **Generating synthetic data**. After the Plugin is trained, the user can use the generate() method to generate synthetic data. Some plugins also allow for conditional generation.
4. **Evaluating synthetic data**. Synthcity provides a large set of metrics for evaluating the fidelity, utility, and privacy of synthetic data. The Metrics class allows users to perform evaluation.

In addition, synthcity also has a Benchmark class that wraps around all the four steps, which is helpful for comparing and evaluating different generators.
After the synthetic data is evaluated, it can then be used in various downstream tasks.


# 1. Case Study 1 - Data Modality

## 1.1 Introduction
![Categorization of data modalities](https://drive.google.com/uc?export=view&id=1nTejenRyLAXv2mwaZdUE63kQc3H2J822)

"Tabular data" is a general category that encompasses many different data modalities. In this section, we introduce how to categorize these diverse modalities and how synthcity can be used to handle them.

### Single dataset

We start by introducing the most fundamental case where there is a single training dataset (e.g. a single DataFrame in Pandas). We characterize the data modalities by two axes: the observation pattern and the feature type.

The observation pattern describes whether and how the data are collected over time. There are three most prominent patterns, all supported by synthcity:

1. Static. All features are observed in a snapshot. There is no temporal ordering.
2. Regular time series. Observations are made at regular intervals, t = 1, 2, 3, ... Note that different series may have different numbers of observations.
3. Irregular time series. Observations are made at irregular intervals, t = t1, t2, t3, ... Note that, for different series, the observation times may vary.
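
As a minimal sketch of how the two time series patterns differ in practice, the toy example below stores each record as its own DataFrame alongside a per-record list of observation times. The feature names and values are hypothetical and only illustrate the shape of the data; for a regular series, each list of times would simply be `list(range(n))`.

```python
import pandas as pd

# Hypothetical toy data: two series observed at irregular, series-specific times.
# Each DataFrame holds every time point for one record; its rows line up with
# the matching entry of observation_times.
temporal_data = [
    pd.DataFrame({"heart_rate": [72.0, 75.0, 71.0], "temperature": [36.6, 37.1, 36.8]}),
    pd.DataFrame({"heart_rate": [88.0, 84.0], "temperature": [38.2, 37.9]}),
]
observation_times = [
    [0.0, 1.5, 4.0],  # record 0: three visits, unevenly spaced
    [0.0, 7.0],       # record 1: two visits
]

# Every record needs exactly one observation time per row.
for frame, times in zip(temporal_data, observation_times):
    assert len(frame) == len(times)

n_observations = [len(times) for times in observation_times]
print(n_observations)
```

Note that the two records have different numbers of observations, which both time series patterns allow.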

The feature type describes the domain of individual features. Synthcity supports the following four types. It also supports multivariate cases with a mixture of different feature types.

1. Continuous feature
2. Categorical feature
3. Integer feature
4. Censored feature: survival time and censoring indicator

The combination of observation patterns and feature types gives rise to an array of data modalities. Synthcity supports all combinations.

### Composite dataset

A composite dataset involves multiple sub datasets. For instance, it may contain datasets collected from different sources or domains (e.g. from different countries). It may also contain both static and time series data. Such composite data are quite often seen in practice. For example, a patient's medical record may contain both static demographic information and longitudinal follow up data.

synthcity can handle the generation of different classes of composite datasets. Currently, it supports (1) multiple static datasets, (2) a static and a regular time series dataset, and (3) a static and an irregular time series dataset.

### Metadata

Very often we have access to metadata that describes the properties of the underlying data. Synthcity can make use of this information to guide the generation and evaluation process. It supports the following types of metadata:

1. sensitive features: indicator of sensitive features that should be protected for privacy.
2. outcome features: indicator of the outcome features that will be used as the target in downstream prediction tasks.
3. domain: information about the data type and allowed value range.



## 1.2 The Task
In this first exercise, we will get used to loading datasets with the library and generating synthetic data from them, whatever the modality of the real data.

## 1.3 Imports
Let's get the imports out of the way. We import the required standard and third-party libraries and the relevant Synthcity modules. We can also set the logging level here, using Synthcity's bespoke logger.

In [None]:
# Standard
import sys
import warnings
from pathlib import Path

# 3rd party
import numpy as np
import pandas as pd

# synthcity
import synthcity.logger as log
from synthcity.plugins import Plugins
from synthcity.plugins.core.dataloader import (GenericDataLoader, SurvivalAnalysisDataLoader, TimeSeriesDataLoader, TimeSeriesSurvivalDataLoader)

# Configure warnings and logging
warnings.filterwarnings("ignore")

# Set the level for the logging
# log.add(sink=sys.stderr, level="INFO")
log.remove()

## 1.4 Loading data of different modalities

In this notebook we will load different datasets into synthcity and show that data of many different modalities can be used to generate synthetic data using this module.


### 1.4.1 Static Data
Now we will start with the simplest example, static tabular data. For this, we will use the diabetes dataset from sklearn. First, we need to load the dataset.

In [None]:
from sklearn.datasets import load_diabetes

X, y = load_diabetes(return_X_y=True, as_frame=True)
X["target"] = y
display(X)

Then we pass it to the `GenericDataLoader` object from `synthcity`.

In [None]:
loader = GenericDataLoader(
    X,
    target_column="target",
    sensitive_features=["sex"],
)

We can print out different methods that are compatible with our data by calling `Plugins().list()` with a relevant list passed to the categories parameter.

In [None]:
print(Plugins(categories=["generic"]).list())

There is no need to worry about the code in the next block; we will go into detail on how to generate synthetic data in the case studies to come. It is here purely to demonstrate that our dataset can be used to generate synthetic data with the synthcity module. We use the `marginal_distributions` method, one of the available debugging methods, to generate the synthetic data.

In [None]:
syn_model = Plugins().get("marginal_distributions")
syn_model.fit(loader)
syn_model.generate(count=10).dataframe()

### 1.4.2 Static survival
Next, let's look at censored data. Censoring is a form of missing-data problem in which the time to event is not observed, for example because the study ended before all recruited subjects had experienced the event of interest, or because a subject left the study before experiencing it. Censoring is common in survival analysis. For our next example we will load a static survival dataset. Our dataset this time is a veteran lung cancer dataset provided by scikit-survival.
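
To make the idea of censoring concrete, here is a small sketch on a hypothetical toy cohort (the numbers are invented for illustration). For each subject we only ever observe the earlier of the event time and the end of follow-up, together with an indicator saying which of the two it was.

```python
# Hypothetical toy cohort: days from enrolment until the event would occur,
# and days until each subject's follow-up ended (study end or drop-out).
true_event_day = [30, 200, 95, 400]
follow_up_end_day = [365, 120, 365, 365]

records = []
for event_day, end_day in zip(true_event_day, follow_up_end_day):
    event_observed = event_day <= end_day    # True: event seen; False: censored
    time_to_event = min(event_day, end_day)  # only the earlier of the two is observed
    records.append((time_to_event, event_observed))

print(records)
```

These two values per subject play the same role as the time-to-event and status columns in the survival dataset loaded below.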

First, load the dataset.

In [None]:
from sksurv.datasets import load_veterans_lung_cancer # If this causes an error you may need to restart the runtime

data_x, data_y = load_veterans_lung_cancer()
data_x["status"], data_x["survival_in_days"] = [record[0] for record in data_y], [record[1] for record in data_y]
display(data_x)

Pass it to the DataLoader. This time we will use the `SurvivalAnalysisDataLoader`. We need to pass it the data, the name of the column that contains our labels or targets as `target_column`, and the name of the column containing the time elapsed when the event occurred (the event defined by the target column) as `time_to_event_column`. Calling `info()` on the loader object allows us to see the information about the dataset we have just prepared.

In [None]:
loader = SurvivalAnalysisDataLoader(
    data_x,
    target_column="status",
    time_to_event_column="survival_in_days",
)
display(loader.info())


If we get the `marginal_distributions` plugin again and fit it to the `loader` object, we can then call `generate` to produce the synthetic data.

In [None]:
syn_model = Plugins().get("marginal_distributions")
syn_model.fit(loader)
syn_model.generate(count=10)

### 1.4.3 Regular Time Series

In this next example we will load up a simple regular time series dataset and show that it is compatible with Synthcity. The temporal data must be passed to the loader as a list of DataFrames, where each DataFrame in the list refers to a different record and contains all time points for the record. So, there is a small amount of pre-processing to get our data into the right shape. As it is a regular time series we can simply pass a sequential list for each record.

The dataset we will use here is the BasicMotions dataset, which the next cell loads with the `UCR_UEA_datasets` helper from tslearn. So, we need to import the library.

In [None]:
from tslearn.datasets import UCR_UEA_datasets

# Initialize the dataset loader
ucr = UCR_UEA_datasets()

# Load the "BasicMotions" dataset
X_train, y_train, X_test, y_test = ucr.load_dataset("BasicMotions")

# Combine train and test splits
X = np.concatenate((X_train, X_test), axis=0)
y = np.concatenate((y_train, y_test), axis=0)

print(X.shape)
print(y.shape)

Load the data and re-format it into a list of DataFrames, where each DataFrame in the list refers to a different record and contains all time points for the record. We also need the outcomes as a DataFrame and the observation times as a list of time steps for each record. As this is a regular time series our time steps can simply be a sequential list of integers. We will also print some of the data once it is in the correct shape.
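
Before working through the full cell, it may help to see the target shape on a small array. The sketch below uses a hypothetical stand-in array (`X_toy`) rather than the real dataset, and builds the list of DataFrames directly by iterating over the first axis. It is an alternative illustration of the same reshaping, not the code used in this notebook.

```python
import numpy as np
import pandas as pd

# Hypothetical stand-in for the real array: 4 records, 5 time steps, 2 features.
X_toy = np.arange(4 * 5 * 2, dtype=float).reshape(4, 5, 2)
n_instances, n_timesteps, n_features = X_toy.shape
columns = [f"feature_{i}" for i in range(n_features)]

# One DataFrame per record, one row per time step.
temporal_toy = [pd.DataFrame(record, columns=columns) for record in X_toy]

# Regular sampling: every record shares the same integer time grid.
times_toy = [list(range(n_timesteps)) for _ in range(n_instances)]

print(len(temporal_toy), temporal_toy[0].shape, times_toy[0])
```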

In [None]:
# Convert the data to a multi-index DataFrame format
def convert_to_multiindex_dataframe(X):
    """Converts a 3D numpy array (instances, time steps, features) to a multi-index DataFrame."""
    # Read the dimensions from the array itself rather than relying on globals
    n_instances, n_timesteps, n_features = X.shape

    # Create multi-index: first level is instance, second level is time step
    index = pd.MultiIndex.from_product([range(n_instances), range(n_timesteps)],
                                       names=["instance", "time_step"])

    # Reshape the data into a long format
    data_reshaped = X.reshape(n_instances * n_timesteps, n_features)

    # Create the DataFrame with multi-index
    df = pd.DataFrame(data_reshaped, index=index, columns=[f"feature_{i}" for i in range(n_features)])
    return df

n_instances, n_timesteps, n_features = X.shape

# Convert X to multi-index DataFrame
X_multiindex_df = convert_to_multiindex_dataframe(X)

# Convert y to a simple DataFrame for consistency
y_df = pd.DataFrame(y, columns=["label"], index=pd.Index(range(len(y)), name="instance"))

# Now, X_multiindex_df and y_df are in the desired multi-index format
print(X_multiindex_df.head())
print(y_df.head())

# Convert multi-index dataframe into list of dataframes
temporal_data = [X_multiindex_df.loc[i] for i in range(len(X))]  # Slice rows by instance
y = pd.DataFrame(y, columns=["label"])
observation_times = [list(range(X.shape[1])) for _ in range(X.shape[0])]

print("The first 3 dataframes in the list, `temporal_data`. They refer to the first 3 instances in the dataset. Each instance contains all time steps for all features.")
for i in range(3):
    display(temporal_data[i])
print("\nThe first 3 label values, `y`.")
display(y[0:3])

Pass the data we just prepared to the DataLoader. Here we will use the `TimeSeriesDataLoader`. Then we will print out the loader info to check everything looks correct.

In [None]:
loader = TimeSeriesDataLoader(
    temporal_data=temporal_data,
    observation_times=observation_times,
    outcome=y,
)
display(loader.dataframe())
print(loader.info())

Now we are ready to produce the synthetic data.

In [None]:
syn_model = Plugins().get("marginal_distributions")
syn_model.fit(loader)
syn_model.generate(count=10)

### 1.4.4 Irregular Time Series

Now let's load an irregular time series dataset and show that it is also compatible with Synthcity. The dataset we will use here is a Google stocks dataset provided by the synthcity module itself.

In [None]:
from synthcity.utils.datasets.time_series.google_stocks import GoogleStocksDataloader

Try to generate synthetic data for the Google stocks data yourself. Use the previous example as a clue. If you need help with `GoogleStocksDataloader`, it is defined [here](https://github.com/vanderschaarlab/synthcity/blob/main/src/synthcity/utils/datasets/time_series/google_stocks.py). The docs for the Synthcity DataLoaders are available [here](https://synthcity.readthedocs.io/en/latest/generated/synthcity.plugins.core.dataloader.html#synthcity.plugins.core.dataloader.TimeSeriesDataLoader). Once you have your answer, compare it to the code hidden in the cell labelled 1.4.4 (i) below.

Please use the empty cell immediately below for your answer.

In [None]:
#@title (i) Create Synthetic Data for Irregular Time Series
static_data, temporal_data, observation_times, outcome = GoogleStocksDataloader().load()

print(f"static_data | type:{type(static_data)} | shape {static_data.shape}")
print(f"temporal_data | type:{type(temporal_data)} | type:{type(temporal_data[0])} |shape {temporal_data[0].shape}")
print(f"observation_times | type:{type(observation_times)} | type:{type(observation_times[0])} | shape {len(observation_times)}")
print(f"outcome | type:{type(outcome)} | shape {outcome.shape}")
print()

loader = TimeSeriesDataLoader(
    temporal_data=temporal_data,
    observation_times=observation_times,
    static_data=static_data,
    outcome=outcome,
)
print(loader.info())
display(loader.dataframe())

# Exactly as for the regular time series, we can now generate synthetic data,
# by selecting our time series compatible plugin, then calling `fit()` and
# `generate()`.
syn_model = Plugins().get("marginal_distributions")
syn_model.fit(loader)
syn_model.generate(count=5)

## 1.5 Extension

Use the code block below as a space to complete the extension exercises below.

### 1.5.1 Create synthetic datasets

Generate synthetic data for another dataset of your choice using the methods described above. You can use any of the other datasets from the sources we have used above: [SKLearn](https://scikit-learn.org/stable/datasets/toy_dataset.html), [SKTime](https://www.sktime.org/en/stable/api_reference/datasets.html), [SKSurv](https://scikit-survival.readthedocs.io/en/stable/api/datasets.html) or [synthcity](https://github.com/vanderschaarlab/synthcity/tree/main/src/synthcity/utils/datasets) itself. Why not try a composite dataset including both static and temporal features?

### 1.5.2 What is the best value for n_iter?
<details>
<summary>Show answer</summary>
It depends on your use case. The larger the value, the longer training will take to run. However, the plugins are equipped with early stopping, so when a specified metric converges, the GAN stops training at that point. Setting an arbitrarily large value is therefore often a good option.
</details>
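
To illustrate why a large `n_iter` is usually harmless when early stopping is in place, here is a generic patience-based sketch. It is not synthcity's implementation: the function name, the `patience` parameter and the list of validation losses are all hypothetical, chosen only to show the mechanism.

```python
def train_with_early_stopping(validation_losses, n_iter, patience=3):
    """Return the number of iterations actually run.

    validation_losses is a hypothetical stand-in for a metric computed
    after each training iteration (lower is better).
    """
    best = float("inf")
    stale = 0  # consecutive iterations without improvement
    for step in range(n_iter):
        loss = validation_losses[step]
        if loss < best:
            best, stale = loss, 0
        else:
            stale += 1
        if stale >= patience:
            return step + 1  # stop early: the metric has stopped improving
    return n_iter


# The metric improves for three iterations, then plateaus.
losses = [1.0, 0.8, 0.7, 0.71, 0.72, 0.73] + [0.74] * 994
steps_run = train_with_early_stopping(losses, n_iter=1000, patience=3)
print(steps_run)
```

In this toy run the budget is 1000 iterations, but training stops as soon as the metric has failed to improve for three iterations in a row.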

# 2. Case Study 2 - Fairness
If you are running this case study without having also run the previous ones, make sure you have set up the runtime correctly, by running the "Set up runtime" and "Install requirements" cells at the top of the notebook.

## 2.1 Introduction
One common problem with machine learning models is that unfair bias in the training data leads to models that systematically perform worse for some populations. In this case study we will address the issue of fairness by reducing the bias in a generated synthetic dataset.



## 2.2 The Task
Train a fair prognostic classifier for COVID-19 patients in Brazil.

## 2.3 Imports
Let's import all the modules we need for the second case study. Some we have already imported, but repeating them means you can work through the case studies in any order if you want to revisit them after the lab. We import the required standard and third-party libraries and the relevant Synthcity modules. We can also set the logging level here, using Synthcity's bespoke logger.

In [None]:
# Standard
import sys
import warnings
from pathlib import Path
from typing import Any, Tuple
import itertools

# 3rd party
import numpy as np
import pandas as pd
import pickle
import networkx as nx
import xgboost as xgb
import matplotlib.pyplot as plt
import seaborn as sns
from graphviz import Digraph
from sklearn.model_selection import train_test_split
from sklearn.metrics import f1_score, accuracy_score, confusion_matrix, ConfusionMatrixDisplay
from sklearn.ensemble import RandomForestClassifier

# synthcity
import synthcity
import synthcity.logger as log
from synthcity.utils import serialization
from synthcity.plugins import Plugins
from synthcity.metrics import Metrics
from synthcity.plugins.core.dataloader import (GenericDataLoader, SurvivalAnalysisDataLoader)
from synthcity.plugins.privacy.plugin_decaf import plugin as decaf_plugin
from synthcity.plugins.core.constraints import Constraints

# # Synthetic-data-lab
from utils import fairness_scores
from utils import plot_dag

# Configure warnings and logging
warnings.filterwarnings("ignore")

# Set the level for the logging
log.remove()
# log.add(sink=sys.stderr, level="INFO")


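# Set up paths to resources. The resources/ directory sits beside Tutorials/
# in the cloned repo (assumed from the model path used later in this notebook).
# Note: joining with a string that starts with "/" would discard the cwd prefix.
FAIR_RES_PATH = Path.cwd().parent / "resources" / "fairness"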

## 2.4 Load the data
Next, we can load the data from file and formulate it as a classification problem. To do this we can simply set a time horizon and create an "is_dead_at_time_horizon" column.
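
The labelling rule can be seen in isolation on a few hypothetical records that reuse the dataset's column names. This is a compact sketch of the same logic as the cell below, written as a single boolean expression: a record is positive only if the patient died and the death fell within the time horizon.

```python
import pandas as pd

# Hypothetical toy records using the same column names as the dataset.
toy = pd.DataFrame({
    "Days_hospital_to_outcome": [5, 20, 14, 3],
    "is_dead": [1, 1, 1, 0],
})
horizon = 14

# Positive only if the patient died AND the death fell within the horizon.
died_within_horizon = (toy["is_dead"] == 1) & (toy["Days_hospital_to_outcome"] <= horizon)
toy[f"is_dead_at_time_horizon={horizon}"] = died_within_horizon.astype(int)

print(toy[f"is_dead_at_time_horizon={horizon}"].tolist())
```

The second record died after the horizon and the fourth survived, so both are labelled 0.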

In [None]:
time_horizon = 14
X = pd.read_csv("../data/Brazil_COVID/covid_normalised_numericalised.csv")

X.loc[(X["Days_hospital_to_outcome"] <= time_horizon) & (X["is_dead"] == 1), f"is_dead_at_time_horizon={time_horizon}"] = 1
X.loc[(X["Days_hospital_to_outcome"] > time_horizon), f"is_dead_at_time_horizon={time_horizon}"] = 0
X.loc[(X["is_dead"] == 0), f"is_dead_at_time_horizon={time_horizon}"] = 0
X[f"is_dead_at_time_horizon={time_horizon}"] = X[f"is_dead_at_time_horizon={time_horizon}"].astype(int)

X.drop(columns=["is_dead", "Days_hospital_to_outcome"], inplace=True) # drop survival columns as they are not needed for a classification problem
display(X)

# Define the mappings of the encoded values in the Ethnicity column to the understandable values
ethnicity_mapper = {
    0: "Mixed",
    1: "White",
    2: "Black",
    3: "East Asian",
    4: "Indigenous",
}

## 2.5 Potential issue

The Brazilian population is made up of people of different ethnicities in different proportions. We should check the frequency for each ethnicity to see how evenly distributed our data is across ethnicity. Let's create a plot.

In [None]:
# Count the records per ethnicity, relabelling the encoded values with readable names
ethnicity_counts = X["Ethnicity"].value_counts().rename(ethnicity_mapper)
display(ethnicity_counts.to_frame())

# Long-format table with explicit "Ethnicity" and "count" columns for plotting
ethnicity_frequency_data = ethnicity_counts.rename_axis("Ethnicity").reset_index(name="count")
sns.barplot(data=ethnicity_frequency_data, x="Ethnicity", y="count")
plt.show()


The population in our dataset is overwhelmingly white and mixed, with little representation of black, East Asian and indigenous people. This poses a problem for us.

## 2.6 The Problem
We need a prognostic classifier for the whole population. Having little representation from some parts of the population means that any classifier we train on this data is going to be susceptible to bias. Let's train a `RandomForestClassifier` on the whole dataset, then test it on each ethnicity. This will show us the extent of the problem, as we will be able to see any disparity in model performance across the different groups.
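
The core of a per-group check is small enough to sketch on toy data. In the example below the `group` column stands in for the Ethnicity column and the labels and predictions are invented, so the numbers say nothing about the real dataset. The point is that a reasonable overall score can hide a much weaker score for a small group.

```python
import pandas as pd

# Hypothetical toy predictions; "group" stands in for the Ethnicity column.
results = pd.DataFrame({
    "group":  ["A", "A", "A", "A", "B", "B"],
    "y_true": [1, 0, 1, 0, 1, 0],
    "y_pred": [1, 0, 1, 0, 0, 0],
})

# Accuracy is the mean of the per-row "prediction was correct" indicator.
results["correct"] = results["y_true"] == results["y_pred"]
overall_accuracy = results["correct"].mean()
accuracy_by_group = results.groupby("group")["correct"].mean()

print(f"overall: {overall_accuracy:.3f}")
print(accuracy_by_group)
```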



### 2.6.1 Set up the data
Set up the data splits, using train_test_split from sklearn, for a prognostic classifier.

In [None]:
y = X["is_dead_at_time_horizon=14"]
X_in = X.drop(columns=["is_dead_at_time_horizon=14"])

X_train, X_test, y_train, y_test = train_test_split(X_in, y, random_state=4, train_size=0.8)

X_train.reset_index(drop=True, inplace=True)
X_test.reset_index(drop=True, inplace=True)
y_train.reset_index(drop=True, inplace=True)
y_test.reset_index(drop=True, inplace=True)

### 2.6.2 Load the classifier
Train a `RandomForestClassifier` on the training split of the real data and save it to file. The commented-out lines at the end of the cell load a previously saved model instead.

In [None]:
prognostic_model = RandomForestClassifier(
    n_estimators=100,
    criterion='gini',
    max_depth=None,
    random_state=42,
    verbose=0,
    warm_start=False,
)

saved_model_path = "/content/synthetic-data-lab/resources/fairness/fairness_cond_aug_random_forest_real_data.sav"

# Train the model on the training split and save it to file
prognostic_model.fit(X_train, y_train)
with open(saved_model_path, 'wb') as f:
    pickle.dump(prognostic_model, f)

# Alternatively, load a previously saved model instead of retraining
# with open(saved_model_path, "rb") as f:
#     prognostic_model = pickle.load(f)

### 2.6.3 Evaluate the classifier
Evaluate the overall accuracy of the classifier on the whole dataset. We can see the accuracy on both the train and test sets. This notebook is set up for you to select a preferred performance score of either sklearn's `accuracy_score` or `f1_score`. Feel free to set `performance_score` to either in the line below, or to a metric of your choice!

In [None]:
performance_score = accuracy_score
calculated_performance_score = performance_score(y_train, prognostic_model.predict(X_train))
print(f"Evaluating accuracy on train set: {calculated_performance_score:0.4f}")

# Predicted values for whole dataset
y_pred = prognostic_model.predict(X_test)

calculated_performance_score = performance_score(y_test, y_pred)
print(f"Evaluating accuracy on test set: {calculated_performance_score:0.4f}")

### 2.6.4 Confusion Matrices
Create the confusion matrix for each of the ethnicities and the whole dataset to compare.

In [None]:
# Setup the figure axis
f, axes = plt.subplots(2, 3, figsize=(20, 10))

# Create the whole dataset confusion matrix and add it to the figure
cm = confusion_matrix(y_test, y_pred)
disp = ConfusionMatrixDisplay(confusion_matrix=cm, display_labels=prognostic_model.classes_)
disp.plot(ax=axes[0,0])
disp.ax_.set_title(f"Whole test dataset | Performance: {calculated_performance_score:0.4f}")

# Get the indices to loop through
ethnicity_idxes = X_in["Ethnicity"].unique()
ethnicity_idxes.sort()

# for each ethnicity create a confusion matrix
for ethnicity_idx in ethnicity_idxes:
    # Get the slice of the dataset for each ethnicity
    X_test_per_ethnicity = X_test.loc[X_test["Ethnicity"] == ethnicity_idx]
    test_records_per_ethnicity_indices = X_test_per_ethnicity.index
    y_true = y_test.iloc[test_records_per_ethnicity_indices]

    # Generate prediction values for each ethnicity subpopulation
    y_pred_per_ethnicity = prognostic_model.predict(X_test_per_ethnicity)

    # Calculate the model performance for each ethnicity subpopulation
    calculated_performance_score = performance_score(y_true, y_pred_per_ethnicity)

    # Generate the confusion matrix and add it to the figure
    cm = confusion_matrix(y_true, y_pred_per_ethnicity)
    disp = ConfusionMatrixDisplay(confusion_matrix=cm, display_labels=prognostic_model.classes_)
    ax_index = [0, ethnicity_idx + 1] if ethnicity_idx <= 1 else [1, (ethnicity_idx + 1) % 3]
    disp.plot(ax=axes[ax_index[0], ax_index[1]])
    disp.ax_.set_title(f"Ethnicity: {ethnicity_mapper[ethnicity_idx]} | Performance: {calculated_performance_score:0.4f}")
plt.show()

As you can see, the performance of the model on the black population is significantly worse than the overall performance. Interestingly, however, the model performs better on the East Asian subpopulation. This is likely to be due to random chance, i.e. it happens that the East Asian patients in this sample had features that are good predictors of the outcome, but this would not necessarily be true for a bigger sample from the same population. The Indigenous population is so poorly represented in the dataset, with only 3 records, that it is difficult to even accurately assess performance. However, the indication we have from these three records suggests performance may be poor.
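
A rough way to see why a handful of records cannot pin down performance is the binomial standard error of an accuracy estimate, sqrt(p(1 - p) / n). The sketch below assumes a hypothetical accuracy of 0.8 and varies only the number of test records. The normal approximation behind this formula is itself crude for very small n, so treat the output as an order-of-magnitude guide.

```python
import math


def accuracy_standard_error(accuracy, n_records):
    """Binomial standard error of an accuracy estimated from n_records samples."""
    return math.sqrt(accuracy * (1.0 - accuracy) / n_records)


# Hypothetical accuracy of 0.8; only the sample sizes differ.
standard_errors = {n: accuracy_standard_error(0.8, n) for n in (3, 30, 3000)}
for n_records, se in standard_errors.items():
    print(n_records, round(se, 3))
```

With three records the uncertainty is a large fraction of the score itself, which is why per-group results for very small groups should be read with caution.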

This confirms that, with a naive method like the one above, we would create a model that systematically performs worse for people of one ethnicity than for another. This unfairness must be addressed.

## 2.7 The solution - Augment the dataset to improve the fairness

### 2.7.1 Load the data with the synthcity module
First we load the data with the `GenericDataLoader`. For this we need to pass the name of our target column as `target_column` and the sensitive column as `sensitive_features`. We can then view the data by calling `loader.dataframe()`, and we could also get information about the data loader object with `loader.info()`.

In [None]:
X_train["is_dead_at_time_horizon=14"] = y_train
loader = GenericDataLoader(
    X_train,
    target_column=f"is_dead_at_time_horizon={time_horizon}",
    sensitive_features=["Ethnicity"],
    random_state=42,
)

display(loader.dataframe())

### 2.7.2 Load/Create the synthetic data model
We are now going to generate synthetic data with a condition such that the new dataset is more balanced with regard to ethnicity. We will first define some values which we will use below.

In [None]:
model = "ctgan"
prefix = "fairness.conditional_augmentation"
random_state = 42

We will now create and fit a synthetic data model. The loading bar may suggest this will take a long time, but early stopping should limit it to approximately 7 minutes.

In [None]:
# Define saved model name
save_file = Path("saved_models") / f"{prefix}_{model}_numericalised_rnd={random_state}.bkp"

syn_model = Plugins().get(model, n_iter=100, random_state=random_state)
syn_model.fit(loader, cond=loader["Ethnicity"])
serialization.save_to_file(save_file, syn_model)

### 2.7.3 Generate fairer data
Use the synthetic data model to generate data using the `cond` argument to try and make the data evenly distributed across ethnicity. We will then augment the original, real dataset with the synthetic records from the under-represented ethnicities.
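
The conditioning vector is simply a list with one requested ethnicity index per generated record. As a small sketch, the snippet below builds such a list by cycling through the three minority indices (2, 3 and 4 in `ethnicity_mapper`) and counts how often each is requested. The variable names here are illustrative rather than taken from the notebook.

```python
from collections import Counter

n_requested = 10000
minority_indices = [2, 3, 4]  # Black, East Asian, Indigenous in ethnicity_mapper

# Cycle through the minority indices so each is requested about equally often.
toy_cond = [minority_indices[i % len(minority_indices)] for i in range(n_requested)]

requested_counts = Counter(toy_cond)
print(requested_counts)
```

Because 10000 is not divisible by 3, the split is near-equal rather than exact.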

In [None]:
count = 10000
cond = [(i % 3) + 2 for i in range(count)] # set cond to an equal proportion of each  minority index
syn_data = syn_model.generate(count=count, cond=cond, random_state=random_state).dataframe()
augmented_data = pd.concat([
    X_train,
    syn_data.loc[syn_data["Ethnicity"] >= 2],
])

display(augmented_data)

print("Here is the ethnicity breakdown for the real dataset:")
print(loader["Ethnicity"].value_counts().rename(ethnicity_mapper))
print("\nHere is the ethnicity breakdown for the synthetic dataset:")
print(syn_data["Ethnicity"].value_counts().rename(ethnicity_mapper))
print("\nHere is the ethnicity breakdown for the augmented dataset:")
print(augmented_data["Ethnicity"].value_counts().rename(ethnicity_mapper))

Check the ethnicity breakdown again to confirm that we have augmented the under-represented groups properly. This is important because the condition only guides how the GAN is trained; it does not guarantee that generated samples perfectly meet that condition. If you require rules to be strictly adhered to, use `Constraints` instead.
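
One simple way to run this check is to compare the requested condition with the generated column and then keep only the rows that fall in the under-represented groups, as the filter before the concatenation does. The sketch below uses a hypothetical six-row generator output instead of real synthcity output.

```python
import pandas as pd

# Hypothetical generator output: the requested condition next to what was produced.
requested = pd.Series([2, 3, 4, 2, 3, 4], name="requested")
generated = pd.DataFrame({"Ethnicity": [2, 3, 1, 2, 0, 4]})

# Fraction of rows whose generated value matches the requested condition.
match_rate = (generated["Ethnicity"].to_numpy() == requested.to_numpy()).mean()

# Keep only rows that belong to the under-represented groups (index >= 2).
kept = generated.loc[generated["Ethnicity"] >= 2]

print(f"match rate: {match_rate:.2f}, kept {len(kept)} of {len(generated)} rows")
```

A post-hoc filter like this discards off-condition rows, so fewer records than requested may remain.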

### 2.7.4 Re-evaluate the classifier on the augmented dataset
Let's train our classifier again, this time on the augmented dataset.

In [None]:
augmented_y = augmented_data["is_dead_at_time_horizon=14"]
augmented_X = augmented_data.drop(columns=["is_dead_at_time_horizon=14"])
augmented_y.reset_index(drop=True, inplace=True)
augmented_X.reset_index(drop=True, inplace=True)


We need a model trained on the new data. As before, the cell below trains and saves it; the commented-out lines load a previously saved model instead.

In [None]:
saved_model_path = "/content/synthetic-data-lab/resources/fairness/fairness_cond_aug_random_forest_augmented_data.sav"

# Train the model on the augmented data and save it under its own file name,
# so that the model trained on the real data is not overwritten
prognostic_model.fit(augmented_X, augmented_y)
with open(saved_model_path, 'wb') as f:
    pickle.dump(prognostic_model, f)


# Alternatively, load a previously saved model instead of retraining
# with open(saved_model_path, "rb") as f:
#     prognostic_model = pickle.load(f)

Evaluate the performance of the model on the real test set, following the "train-on-synthetic, test-on-real" rule (here the training data is the augmented dataset).

In [None]:
y_pred = prognostic_model.predict(X_test)
calculated_performance_score = performance_score(y_test, y_pred)
print(f"evaluating test set: {calculated_performance_score:0.4f}")

### 2.7.5 New confusion matrices
Create the confusion matrices for the model trained on the augmented dataset, again evaluated on the real test set.

In [None]:
# Setup the figure axis
f, axes = plt.subplots(2, 3, figsize=(20, 10))

# Create the whole dataset confusion matrix and add it to the figure
cm = confusion_matrix(y_test, y_pred)
disp = ConfusionMatrixDisplay(confusion_matrix=cm, display_labels=prognostic_model.classes_)
disp.plot(ax=axes[0,0])
# disp.plot(ax=axes[0])
disp.ax_.set_title(f"Whole Test Dataset | Performance: {calculated_performance_score:0.4f}")

# Get the indices to loop through
ethnicity_idxes = augmented_X["Ethnicity"].unique()
ethnicity_idxes.sort()

# for each ethnicity create a confusion matrix
for ethnicity_idx in ethnicity_idxes:
    # Get the slice of the dataset for each ethnicity
    X_test_per_ethnicity = X_test.loc[X_test["Ethnicity"] == ethnicity_idx]
    test_records_per_ethnicity_indicies = X_test_per_ethnicity.index
    y_true_per_ethnicity = y_test.iloc[test_records_per_ethnicity_indicies]

    # Generate prediction values for each ethnicity subpopulation
    y_pred_per_ethnicity = prognostic_model.predict(X_test_per_ethnicity)

    # Calculate the model performance for each ethnicity subpopulation
    calculated_performance_score = performance_score(y_true_per_ethnicity, y_pred_per_ethnicity)

    # Generate the confusion matrix and add it to the figure
    cm = confusion_matrix(y_true_per_ethnicity, y_pred_per_ethnicity)
    disp = ConfusionMatrixDisplay(confusion_matrix=cm, display_labels=prognostic_model.classes_)
    # disp.plot(ax=axes[ethnicity_idx + 1])
    try:
      ax_index = [0, ethnicity_idx + 1] if ethnicity_idx <= 1 else [1, (ethnicity_idx + 1) % 3]
      disp.plot(ax=axes[ax_index[0], ax_index[1]])
      disp.ax_.set_title(f"Ethnicity: {ethnicity_mapper[ethnicity_idx]} | Performance: {calculated_performance_score:0.4f}")
    except ValueError as e:
      print(f"Not all quadrants contain values for {ethnicity_mapper[ethnicity_idx]}\n\n")
plt.show()
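The subplot position used in the loop above reserves the first slot of the 2x3 grid for the whole-dataset matrix and places each ethnicity in slot `ethnicity_idx + 1`. A minimal sketch of the same mapping using `divmod`, assuming the five ethnicity indices 0 to 4 used in this case study (`grid_position` and `original_position` are hypothetical helper names):

```python
def grid_position(ethnicity_idx, n_cols=3):
    """Map an ethnicity index to (row, col) in a grid whose first slot is reserved."""
    return divmod(ethnicity_idx + 1, n_cols)


def original_position(ethnicity_idx):
    """The expression used in the plotting loop, kept here for comparison."""
    if ethnicity_idx <= 1:
        return (0, ethnicity_idx + 1)
    return (1, (ethnicity_idx + 1) % 3)


# Both expressions give the same (row, col) for every ethnicity index
for idx in range(5):
    print(idx, grid_position(idx), original_position(idx))
```

The `divmod` form generalizes to other grid widths by changing `n_cols`, which may be handy if you re-use the plot for a different number of groups.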

As you can hopefully see, the new model trained on the augmented data performs more similarly across the different populations.

Why not play with the different synthetic data generation methods and their parameters to see if you can achieve the same improvement in fairness, but with a higher performance? If you need help identifying the right methods, remember that you can list the available plugins with `Plugins().list()`, and to learn what they do you can refer to the [docs](https://synthcity.readthedocs.io/en/latest/generators.html).

## 2.8 Extension - Using inference-time debiasing with DECAF

This section (2.8) involves fitting the DECAF model, which can take more than 30 minutes, so it is included as an extension exercise to try after the session if you run out of time.

DECAF is an inference-time de-biasing method, where the data-generating process is embedded explicitly as a structural causal model in the input layers of the generator, allowing each variable to be reconstructed conditioned on its causal parents. We will use this method to create a different solution to the issue of fairness in this Brazilian COVID-19 dataset.

In this section, we consider the fairness issue that may arise when one operationalizes a clinical prognostic tool as a triaging system (to decide who is admitted to the ICU first). In this context, the notion of **[demographic parity](https://en.wikipedia.org/wiki/Fairness_(machine_learning)#Definitions_based_on_predicted_outcome)** is likely to be of interest.
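Demographic parity asks that the rate of positive predictions is the same for every group. As a minimal sketch of how this could be measured, the block below computes the positive-prediction rate per group and the gap between the highest and lowest rate. The data and the helper name `demographic_parity_difference` are hypothetical, and plain Python is used to keep the example dependency-free.

```python
def demographic_parity_difference(predictions, groups):
    """Return (gap, rates): the spread of positive-prediction rates across groups."""
    by_group = {}
    for pred, group in zip(predictions, groups):
        by_group.setdefault(group, []).append(pred)
    # Positive rate = share of records predicted as class 1 within each group
    rates = {group: sum(preds) / len(preds) for group, preds in by_group.items()}
    return max(rates.values()) - min(rates.values()), rates


# Hypothetical binary predictions for a majority (0) and a minority (1) group
predictions = [1, 0, 1, 1, 0, 0, 1, 0]
groups = [0, 0, 0, 0, 1, 1, 1, 1]

gap, rates = demographic_parity_difference(predictions, groups)
print("Positive rate per group:", rates)
print(f"Demographic parity difference: {gap:.2f}")
```

A gap of 0 means both groups receive positive predictions at the same rate; the larger the gap, the further the predictions are from demographic parity.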

### 2.8.1 Load the data
Let's load the data from file again to make sure we are working with the correct data and nothing has changed. As before, we will construct it as a classification problem. For this exercise, however, we will set the problem up as a classic two-class fairness task, with one class representing the majority ethnic groups (White and Mixed) and the other the minority groups (Black, East-Asian, and Indigenous).

In [None]:
# Load the data from file
X = pd.read_csv(f"../data/Brazil_COVID/covid_normalised_numericalised.csv")

# Set it up as classification task
time_horizon = 14
X.loc[(X["Days_hospital_to_outcome"] <= time_horizon) & (X["is_dead"] == 1), f"is_dead_at_time_horizon={time_horizon}"] = 1
X.loc[(X["Days_hospital_to_outcome"] > time_horizon), f"is_dead_at_time_horizon={time_horizon}"] = 0
X.loc[(X["is_dead"] == 0), f"is_dead_at_time_horizon={time_horizon}"] = 0
X[f"is_dead_at_time_horizon={time_horizon}"] = X[f"is_dead_at_time_horizon={time_horizon}"].astype(int)

X.drop(columns=["is_dead", "Days_hospital_to_outcome"], inplace=True) # drop survival columns as they are not needed for a classification problem

# Set up ethnicity as two classes, minority and majority
X.loc[(X["Ethnicity"] == 0) | (X["Ethnicity"] == 1), "Ethnicity"] = 0
X.loc[(X["Ethnicity"] == 2) | (X["Ethnicity"] == 3) | (X["Ethnicity"] == 4), "Ethnicity"] = 1
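The three `X.loc` assignments that build the target column above reduce to a single rule: the label is 1 only if the patient died on or before the time horizon, and 0 otherwise. A minimal sketch of that rule as a standalone function, with hypothetical example records (`label_at_time_horizon` is not part of the notebook):

```python
def label_at_time_horizon(days_to_outcome, is_dead, time_horizon=14):
    """Return 1 if the patient died within the time horizon, otherwise 0."""
    return int(is_dead == 1 and days_to_outcome <= time_horizon)


# Hypothetical (days_to_outcome, is_dead) records
examples = [
    (5, 1),   # died on day 5, within the horizon
    (20, 1),  # died on day 20, after the horizon
    (5, 0),   # outcome on day 5, but the patient survived
    (30, 0),  # outcome on day 30, and the patient survived
]

labels = [label_at_time_horizon(days, dead) for days, dead in examples]
print(labels)
```

Only the first record is labelled 1, because it is the only one where death occurred within 14 days.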

Pass the data into the synthcity data loader. We will just use the `GenericDataLoader` here.

In [None]:
loader = GenericDataLoader(
    X,
    target_column="is_dead_at_time_horizon=14",
    sensitive_features=["Ethnicity"],
    random_state=42,
)

display(loader.dataframe())

### 2.8.2 Load/Create the synthetic data model using DECAF
First, we define some useful variables.

In [None]:
prefix = "fairness.causal_generation"
model = "decaf"
n_iter = 101
count = 6882 # set the count equal to the number of rows in the original dataset
random_state=6

Then we create the synthetic model, fit it, and save it to file. If you want to try something a little different, you can change the variables above and fit your own model.

In [None]:
# Define saved model name
save_file = Path("saved_models") / f"{prefix}_{model}_n_iter={n_iter}_rnd={random_state}.bkp"

syn_model = decaf_plugin(struct_learning_enabled=True, n_iter=n_iter, random_state=random_state) # Pass struct_learning_enabled=True in order for the syn_model to learn the Dag
dag_before = syn_model.get_dag(loader.dataframe())
syn_model.fit(loader, dag=dag_before)
serialization.save_to_file(save_file, syn_model)

### 2.8.3 Generate the data
We can generate the de-biased dataset simply by passing the biased edges we wish to remove to `generate`.

In [None]:
bias={"Ethnicity": ["is_dead_at_time_horizon=14"]}
decaf_syn_data = syn_model.generate(count, biased_edges=bias, random_state=14)
display(decaf_syn_data.dataframe())

### 2.8.4 DECAF fairness tests

We will now check the DAG for the synthetic data to see if the biased edge has been removed. You may also wish to test the demographic parity or fairness-through-unawareness metrics, which are defined in section 4.1 of the [DECAF paper](https://arxiv.org/abs/2110.12884).

In [None]:
print(syn_model.get_dag(decaf_syn_data.dataframe()))
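Rather than reading the printed DAG by eye, the check can be automated. The sketch below assumes the DAG is available as a list of `(parent, child)` pairs; that representation is an assumption, so inspect the printed output first and adapt accordingly. `remaining_biased_edges` and the example edge lists are hypothetical; the `bias` dictionary uses the same `{parent: [children]}` layout that was passed to `generate`.

```python
def remaining_biased_edges(dag_edges, biased_edges):
    """Return the biased (parent, child) edges that are still present in the DAG."""
    edge_set = {tuple(edge) for edge in dag_edges}
    return [
        (parent, child)
        for parent, children in biased_edges.items()
        for child in children
        if (parent, child) in edge_set
    ]


bias = {"Ethnicity": ["is_dead_at_time_horizon=14"]}

# Hypothetical DAGs, assumed to be lists of (parent, child) pairs
dag_with_bias = [("Age", "is_dead_at_time_horizon=14"), ("Ethnicity", "is_dead_at_time_horizon=14")]
dag_without_bias = [("Age", "is_dead_at_time_horizon=14"), ("Ethnicity", "Obesity")]

print(remaining_biased_edges(dag_with_bias, bias))
print(remaining_biased_edges(dag_without_bias, bias))
```

An empty list means none of the edges we asked DECAF to remove were found in the DAG.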

## 2.9 Extension to the extension
Use the code block below as a space to complete the extension exercises.

### 2.9.1 Plot the DECAF synthetic data to show fairness

What is a simple plot we could make to show that the generated data is fair?

In [None]:
#@title (i) Plot to show fairness in data generated by DECAF

# 2 class ethnicity mapper
ethnicity_mapper_2_class = {
    0: "Majority",
    1: "Minority"
}

# Define the model
xgb_model = xgb.XGBClassifier(
    n_estimators=2000,
    learning_rate=0.01,
    max_depth=5,
    subsample=0.8,
    colsample_bytree=1,
    gamma=1,
    objective="binary:logistic",
    random_state=42,
)

# Load the model trained on the whole dataset
saved_model_path = "../resources/fairness/fairness_causal_gen_decaf_synthetic_data.json"
xgb_model.load_model(saved_model_path)
synth_data_to_predict = decaf_syn_data.dataframe().drop(columns=["is_dead_at_time_horizon=14"])
targets_synth_data_to_predict = decaf_syn_data["is_dead_at_time_horizon=14"]

# # The saved model was trained with the following code
# xgb_model.fit(synth_data_to_predict, targets_synth_data_to_predict)
# xgb_model.save_model(saved_model_path)

# Get the indices to loop through
ethnicity_idxes = decaf_syn_data["Ethnicity"].unique()
ethnicity_idxes.sort()

predictions = {}
for ethnicity_idx in ethnicity_idxes:
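    # Select only the synthetic records for this ethnicity group
    synth_data_to_predict_per_ethnicity = synth_data_to_predict.loc[synth_data_to_predict["Ethnicity"] == ethnicity_idx]

    # Predict on this group's records only, so that the counts are per group
    synthetic_predictions = xgb_model.predict(synth_data_to_predict_per_ethnicity)

    # Count the predictions per class, defaulting to 0 if a class is never predicted
    unique, counts = np.unique(synthetic_predictions, return_counts=True)
    counts_by_class = dict(zip(unique, counts))
    prediction_counts = {0: counts_by_class.get(0, 0), 1: counts_by_class.get(1, 0)}
    predictions[ethnicity_mapper_2_class[ethnicity_idx]] = prediction_counts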

prediction_frequency_data = pd.DataFrame(data={
    "Ethnicity": predictions.keys(),
    "0": [p_c[0] for p, p_c in predictions.items()],
    "1": [p_c[1] for p, p_c in predictions.items()],
})

prediction_frequency_data_m = pd.melt(prediction_frequency_data, id_vars="Ethnicity")
prediction_frequency_data_m = prediction_frequency_data_m.rename(
    columns={"variable": "is_dead_at_time_horizon", "value": "Prediction count"}
)


sns.catplot(
    data=prediction_frequency_data_m,
    x="Ethnicity",
    y="Prediction count",
    hue="is_dead_at_time_horizon",
    kind="bar"
)
plt.show()

# 3. Case Study 3 - Privacy
If you are running this case study without having also run the previous ones, make sure you have set up the runtime correctly, by running the "Set up runtime" and "Install requirements" cells at the top of the notebook.

## 3.1 Introduction
Machine learning (ML) is empowering more and more communities by using their historical datasets. Unfortunately, some sectors and use cases have been precluded from the benefits of ML, due to the requirement of their data to remain private. In this case study we will look at methods that aim to solve this problem by creating synthetic datasets that are not bound by the constraints of privacy.

## 3.2 The Task
Make a private version of the Brazil COVID-19 dataset, that could safely be used by anyone to create a COVID-19 survival analysis model, without the risk of (re-)identification of individuals.

## 3.3 Imports
Let's import all the modules we need for the third case study. Some we have already imported, but this means you can work through the case studies in any order if you want to revisit them after the lab.

In [None]:
# Standard
import sys
import warnings
from pathlib import Path

# 3rd party
import numpy as np
import pandas as pd
from sklearn.model_selection import train_test_split
from sklearn.neural_network import MLPClassifier
from sklearn.model_selection import GridSearchCV
from sklearn.metrics import classification_report

# synthcity
import synthcity.logger as log
from synthcity.utils import serialization
from synthcity.plugins import Plugins
from synthcity.plugins.core.dataloader import (GenericDataLoader, SurvivalAnalysisDataLoader)
from synthcity.metrics import Metrics

# Configure warnings and logging
warnings.filterwarnings("ignore")

# Set the level for the logging
# log.add(sink=sys.stderr, level="DEBUG")
log.remove()

## 3.4 Load the data

Load the data from file into a `SurvivalAnalysisDataLoader` object. For this we need to pass the names of our `target_column` and our `time_to_event_column` to the data loader. Then we can see the data by calling `loader.dataframe()` and get the information about the data loader object with `loader.info()`.

In [None]:
X = pd.read_csv(f"../data/Brazil_COVID/covid_normalised_numericalised.csv")
X = X[X.Days_hospital_to_outcome != 0]
loader = SurvivalAnalysisDataLoader(
    X,
    target_column="is_dead",
    time_to_event_column="Days_hospital_to_outcome",
    sensitive_features=["Age", "Sex", "Ethnicity", "Region"],
    random_state=42,
)

print(loader.info())
display(loader.dataframe())

## 3.5 Load/Create synthetic datasets

We can list the available synthetic generators by calling list() on the Plugins object.

In [None]:
print(Plugins().list())

From the above list we are going to select the synthetic generation models designed for privacy: "adsgan" and "pategan". We will then create and fit each synthetic model before using it to generate a synthetic dataset.

In [None]:
outdir = Path("saved_models")
prefix = "privacy"
n_iter = 100
random_state=1
models={
    "adsgan":None,
    "pategan":None,
}

## 3.6 Evaluate the generated synthetic dataset in terms of privacy
We can now select the metrics to evaluate. The full list of available metrics can be seen by calling `Metrics().list()`. We are going to use the metrics associated with detection of the synthetic data, data privacy, and performance. Then we will display them in a DataFrame to look at the results.
<img src="https://drive.google.com/uc?export=view&id=1Mt9_dQuhGFGl1XhEL3XADqHr-E0V6NE2" alt="Evaluating Different Aspects of Synthetic Data and Their Metrics">


In [None]:
eval_results = {}
for model in models:
    syn_model = Plugins().get(model, n_iter=n_iter, random_state=random_state)
    syn_model.fit(loader)
    models[model] = syn_model
    save_file = outdir / f"{prefix}.{model}_n_iter={n_iter}_rnd={random_state}.bkp"
    serialization.save_to_file(save_file, syn_model)
    selected_metrics = {
        "detection": ["detection_xgb", "detection_mlp", "detection_gmm"],
        "privacy": ["k-anonymization", "k-map", "distinct l-diversity", "identifiability_score"],
        'performance': ['linear_model', 'mlp', 'xgb'],
    }
    my_metrics = Metrics()
    selected_metrics_in_my_metrics = {k: my_metrics.list()[k] for k in my_metrics.list().keys() & selected_metrics.keys()}
    X_syn = syn_model.generate(count=6882, random_state=1)
    evaluation = my_metrics.evaluate(
        loader,
        X_syn,
        task_type="survival_analysis",
        metrics=selected_metrics_in_my_metrics,
        workspace="workspace",
    )
    # Select the metrics that we need
    display_metrics = [
      "performance.xgb.syn_ood.c_index",
      "performance.linear_model.syn_ood.c_index",
      "performance.mlp.syn_ood.c_index",
      "detection.detection_xgb.mean",
      "detection.detection_mlp.mean",
      "detection.detection_gmm.mean",
      "detection.detection_linear.mean",
      "privacy.k-anonymization.syn",
      "privacy.k-map.score",
      "privacy.distinct l-diversity.syn",
      "privacy.identifiability_score.score",
    ]

    evaluation = evaluation.loc[display_metrics]
    display(evaluation)
    eval_results[model] = evaluation

### 3.6.1 Display the evaluation results
The above tables contain all the information we need to evaluate the methods, but let's convert them to a format where it is easier to compare the methods.

In [None]:
means = []
for plugin in eval_results:
    data = eval_results[plugin]["mean"]
    directions = eval_results[plugin]["direction"].to_dict()
    means.append(data)

out = pd.concat(means, axis=1)
out = out.set_axis(eval_results.keys(), axis=1)

bad_highlight = "background-color: lightcoral;"
ok_highlight = "background-color: green;"
default = ""


def highlights(row):
    metric = row.name
    if directions[metric] == "minimize":
        best_val = np.min(row.values)
        worst_val = np.max(row)
    else:
        best_val = np.max(row.values)
        worst_val = np.min(row)

    styles = []
    for val in row.values:
        if val == best_val:
            styles.append(ok_highlight)
        elif val == worst_val:
            styles.append(bad_highlight)
        else:
            styles.append(default)

    return styles


out.style.apply(highlights, axis=1)

## 3.7 Results of evaluation
We are using three types of metric here: performance, detection and privacy. Performance metrics explain the utility of a synthetic dataset. Detection metrics measure the ability to identify the real data compared to the synthetic data. The privacy metrics measure how easy it would be to re-identify a patient given the quasi-identifying fields in the dataset.
Generally, ADSGAN performs best in synthetic data detection and performance tasks, then PATEGAN.
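The performance rows in the table report a concordance index (`c_index`). As a rough illustration of what that number means, here is a simplified, Harrell-style concordance index in plain Python. It is a sketch for intuition only and is not the exact implementation used by synthcity; the example data and the function name are hypothetical.

```python
def concordance_index(times, events, risk_scores):
    """Fraction of comparable pairs in which the earlier event has the higher risk score."""
    concordant = 0.0
    comparable = 0
    n = len(times)
    for i in range(n):
        if events[i] != 1:
            continue  # a pair is only comparable if the earlier record had the event
        for j in range(n):
            if times[i] < times[j]:
                comparable += 1
                if risk_scores[i] > risk_scores[j]:
                    concordant += 1.0
                elif risk_scores[i] == risk_scores[j]:
                    concordant += 0.5  # ties in risk count as half
    if comparable == 0:
        raise ValueError("no comparable pairs")
    return concordant / comparable


# Hypothetical survival data: time to outcome, event indicator, predicted risk
times = [2, 5, 7, 10]
events = [1, 1, 0, 1]
risk_scores = [0.9, 0.4, 0.6, 0.1]

c_index = concordance_index(times, events, risk_scores)
print(f"c-index: {c_index:.2f}")
```

A value of 0.5 corresponds to random ranking and 1.0 to perfect ranking, so higher is better for the performance rows.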

k-anonymization - risk of re-identification is approximately 1/k according to [this paper](https://www.ncbi.nlm.nih.gov/pmc/articles/PMC2528029/). Therefore the risk of re-identification is lowest for ADSGAN then PATEGAN.

k-map - is a metric where every combination of values for the quasi-identifiers appears at least k times in the synthetic dataset. ADSGAN performs worse than PATEGAN.

l-diversity - is a similar metric to k-anonymization, but it is also concerned with the diversity of the generalized block. We see broadly the same pattern as for k-anonymization.

identifiability_score - Risk of re-identification as defined in [this paper](https://ieeexplore.ieee.org/document/9034117).
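To make the k-anonymity and l-diversity definitions above concrete, here is a minimal sketch that computes both on a toy table by grouping records on their quasi-identifiers. This is the textbook group-size definition, written in plain Python; synthcity's metrics may be computed differently, and the records and function names below are hypothetical.

```python
from collections import Counter, defaultdict


def k_anonymity(records, quasi_identifiers):
    """Size of the smallest group of records sharing the same quasi-identifier values."""
    groups = Counter(tuple(r[q] for q in quasi_identifiers) for r in records)
    return min(groups.values())


def distinct_l_diversity(records, quasi_identifiers, sensitive):
    """Smallest number of distinct sensitive values found within any group."""
    values = defaultdict(set)
    for r in records:
        values[tuple(r[q] for q in quasi_identifiers)].add(r[sensitive])
    return min(len(v) for v in values.values())


# Hypothetical records with two quasi-identifiers and one sensitive column
records = [
    {"Age": 1, "Sex": 0, "is_dead": 0},
    {"Age": 1, "Sex": 0, "is_dead": 1},
    {"Age": 1, "Sex": 0, "is_dead": 0},
    {"Age": 2, "Sex": 1, "is_dead": 1},
    {"Age": 2, "Sex": 1, "is_dead": 1},
]

k = k_anonymity(records, ["Age", "Sex"])
l = distinct_l_diversity(records, ["Age", "Sex"], "is_dead")
print(f"k = {k}, approximate re-identification risk = {1 / k:.2f}, l = {l}")
```

Here the smallest group has two records, so k is 2, but that group contains only one distinct outcome, so l is 1: every member of the group shares the same sensitive value even though the group itself is not unique.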


## 3.8 Synthetic Data Quality
To get a good sense of the quality of the synthetic datasets and validate our previous conclusion, let's plot the correlation/strength-of-association of the features in the dataset, which contains both categorical and continuous features, using:
- Pearson's R for continuous-continuous cases
- Correlation Ratio for categorical-continuous cases
- Cramer's V or Theil's U for categorical-categorical cases

In each of the following plots we are looking for the synthetic data to be as similar to the real data as possible. That is, we want minimal values for the Jensen-Shannon distance and the pairwise correlation distance, and t-SNE plots with similar-looking distributions in the representation space.

In [None]:
import matplotlib.pyplot as plt
for model in models:
    models[model].plot(plt, loader, plots=["associations","marginal", "tsne"])
    plt.show()
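The Jensen-Shannon distance mentioned above compares the marginal distribution of a feature in the real and synthetic data. A minimal sketch of the calculation for one discrete feature follows, using base-2 logarithms so that the distance lies between 0 and 1. The example values and helper names are hypothetical, and the plotting code may bin or normalise the data differently.

```python
import math
from collections import Counter


def distribution(values, support):
    """Relative frequency of each value in `support`."""
    counts = Counter(values)
    return [counts[v] / len(values) for v in support]


def kl_divergence(p, q):
    """Kullback-Leibler divergence in bits; terms with p_i == 0 contribute nothing."""
    return sum(pi * math.log2(pi / qi) for pi, qi in zip(p, q) if pi > 0)


def jensen_shannon_distance(p, q):
    """Square root of the Jensen-Shannon divergence between two distributions."""
    m = [(pi + qi) / 2 for pi, qi in zip(p, q)]
    divergence = 0.5 * kl_divergence(p, m) + 0.5 * kl_divergence(q, m)
    return math.sqrt(max(divergence, 0.0))


# Hypothetical values of one categorical feature in the real and synthetic data
real_values = [0, 0, 1, 1, 1, 2]
synthetic_values = [0, 1, 1, 1, 2, 2]
support = sorted(set(real_values) | set(synthetic_values))

p = distribution(real_values, support)
q = distribution(synthetic_values, support)
js_distance = jensen_shannon_distance(p, q)
print(f"Jensen-Shannon distance: {js_distance:.4f}")
```

Identical distributions give a distance of 0 and completely disjoint ones give 1, so smaller values indicate synthetic marginals that are closer to the real ones.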

## 3.9 Extension
Use the code block below as a space to complete the following extension exercises.

### 3.9.1 Training models on both sets of data

1) Use the metrics to get the performance of a model trained on the real dataset, to put our performance scores in context.

2) Train your own downstream model on both the original dataset and each of the private datasets we have generated to see if you reach the same conclusion. Which privacy method provides the best performance and what are the trade-offs?

# 4. Case Study 4 - Augmentation
If you are running this case study without having also run the previous ones, make sure you have set up the runtime correctly, by running the "Set up runtime" and "Install requirements" cells at the top of the notebook.

## 4.1 Introduction
One of the most common problems for machine learning practitioners is having only a small dataset for the specific problem they are working on. Traditionally, this can lead to dead ends or hold-ups in projects while more data is collected. However, if you have a different dataset that shares common features, then Synthcity may be able to solve the issue for you with "augmentation by domain adaptation".

## 4.2 The Task
Augment a small dataset using the concept of domain adaptation (or transfer learning). For this we will be using a RadialGAN as discussed in [this paper](https://arxiv.org/pdf/1802.06403.pdf).

## 4.3 Imports
Let's import all the modules we need for the fourth case study. Some we have already imported, but this means you can work through the case studies in any order if you want to revisit them after the lab.

In [None]:
# stdlib
import pickle
import warnings
from pathlib import Path

# 3rd Party
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import seaborn as sns
import xgboost as xgb
from sklearn.metrics import auc, accuracy_score
from sklearn.model_selection import train_test_split

# synthcity absolute
import synthcity.logger as log
from synthcity.plugins import Plugins
from synthcity.plugins.core.dataloader import GenericDataLoader
from synthcity.utils import serialization

warnings.filterwarnings("ignore")
# Set the level for the logging
# log.add(sink=sys.stderr, level="DEBUG")
log.remove()

# Set up paths to resources
AUG_RES_PATH = Path("../resources/augmentation/")

## 4.4 The Scenario
Brazil is divided geopolitically into five macro-regions: North, North-East, Central-West, South-East, and South. For this case study, we will be acting as government officials in the Central-West Region of Brazil. Central-West Brazil is the smallest region in the country by population. It is also one of the larger and more rural regions. This means the number of COVID-19 patient records is significantly smaller compared to the larger regions.

<img src="https://drive.google.com/uc?export=view&id=1880Ux6qEb3MQSAirr4_908p20mvqQGot" alt="Brazil Region Map" width="500"/>

COVID-19 hit different regions at different times. Cases peaked later in the Central-West than in the more densely populated and well-connected regions. This gives us the problem of scarce COVID-19 patient data in our region, but also the potential lifeline of larger datasets from the other regions, which we can learn from in order to augment our dataset. We cannot simply train our model on the data from all regions, because there is significant co-variate shift between the regions; we will achieve a better classifier by training solely on Central-West data, even if some of it is synthetic.


## 4.5 Load the data
Let's set it up as a classification task with a death-at-time-horizon column, as we did in a previous case study.

In [None]:
time_horizon = 14
X = pd.read_csv(f"../data/Brazil_COVID/covid_normalised_numericalised.csv")

X.loc[(X["Days_hospital_to_outcome"] <= time_horizon) & (X["is_dead"] == 1), f"is_dead_at_time_horizon={time_horizon}"] = 1
X.loc[(X["Days_hospital_to_outcome"] > time_horizon), f"is_dead_at_time_horizon={time_horizon}"] = 0
X.loc[(X["is_dead"] == 0), f"is_dead_at_time_horizon={time_horizon}"] = 0
X[f"is_dead_at_time_horizon={time_horizon}"] = X[f"is_dead_at_time_horizon={time_horizon}"].astype(int)

X.drop(columns=["is_dead", "Days_hospital_to_outcome"], inplace=True) # drop survival columns as they are not needed for a classification problem

Here we define a region mapper, which maps the region encoding to the real values. These can be found in `synthetic-data-lab/data/Brazil_COVID/Brazil_COVID_data.md`.

In [None]:
# Define the mappings from region index to region
region_mapper = {
    0: "Central-West",
    1: "North",
    2: "North-East",
    3: "South",
    4: "South-East",
}

As we are acting as officials from Central-West Brazil, we need to split the data into data from our region and data from other regions. Later, when we build the concatenation baseline, we will drop the `Region` column to simulate not knowing which region each record is from: a record is either in our region's dataset or in the dataset for the other regions.

In [None]:
our_region_index = 0

X_our_region_only = X.loc[X["Region"] == our_region_index].copy()
X_other_regions = X.loc[X["Region"] != our_region_index].copy()
X_all_regions = X.copy()

display(X_our_region_only)

## 4.6 The problem

Let's see how a model trained just on our data from the Central-West region performs.

### 4.6.1 Set up the data splits using train_test_split from sklearn.

In [None]:
our_region_y = X_our_region_only["is_dead_at_time_horizon=14"]
our_region_X = X_our_region_only.drop(columns=["is_dead_at_time_horizon=14"])

X_train, X_test, y_train, y_test = train_test_split(our_region_X, our_region_y, random_state=4)
X_train.reset_index(drop=True, inplace=True)
X_test.reset_index(drop=True, inplace=True)
y_train.reset_index(drop=True, inplace=True)
y_test.reset_index(drop=True, inplace=True)

### 4.6.2 Load classifier
Load the trained xgboost classifier, which has been trained on Central-West data only.

In [None]:
# Define the model
xgb_model = xgb.XGBClassifier(
    n_estimators=2000,
    learning_rate=0.01,
    max_depth=5,
    subsample=0.8,
    colsample_bytree=1,
    gamma=1,
    objective="binary:logistic",
    random_state=42,
)

# Load the model trained on the Central-West data only
saved_model_path = AUG_RES_PATH / f"augmentation_xgboost_real_{region_mapper[our_region_index]}_data.json"
xgb_model.load_model(saved_model_path)

# # # The saved model was trained with the following code:
# xgb_model.fit(X_train, y_train)
# xgb_model.save_model(saved_model_path)

### 4.6.3 Evaluate classifier
Now print the performance of the model trained on only Central-West data.

In [None]:
calculated_accuracy_score_train = accuracy_score(y_train, xgb_model.predict(X_train))
y_pred = xgb_model.predict(X_test)
calculated_accuracy_score_test = accuracy_score(y_test, y_pred)
print(f"Evaluating accuracy: train set: {calculated_accuracy_score_train:0.4f} | test set: {calculated_accuracy_score_test:0.4f}")

As you can see, we are significantly over-fitting due to the very small dataset: the accuracy on the train set is much higher than on the test set. The performance does not look as good as it could be.

## 4.7 A simple solution? Concatenation, not Augmentation

One simple solution is to just concatenate the dataset with data from the other regions, but this will not account for any difference in the populations.

### 4.7.1 Comparing the populations
Let's examine how similar (or not) the populations are by plotting the distributions of a few of the data fields for each region.

In [None]:
# List columns by type to plot differently
continuous_columns = ["Age"]
descrete_cols = [
    "Sex",
    "Fever",
    "Cough",
    "Sore_throat",
    "Shortness_of_breath",
    "Respiratory_discomfort",
    "SPO2",
    "Dihareea",
    "Vomitting",
    "Cardiovascular",
    "Asthma",
    "Diabetis",
    "Pulmonary",
    "Immunosuppresion",
    "Obesity",
    "Liver",
    "Neurologic",
    "Renal",
]

# Plot continuous columns
for column in continuous_columns:
    X_all_regions.groupby("Region")[column].plot(kind="kde")
    plt.legend(region_mapper.values(), title="Region")
    plt.title(f"{column}")
    plt.show()

# Plot Ethnicity separately as it is our focus in this fairness exercise
column = "Ethnicity"
dft = X_all_regions.replace({"Region": region_mapper}).melt(id_vars=["Region"]).loc[
    X_all_regions.melt()["variable"] == column
]
dfa = dft.value_counts().reset_index()
# Rename columns for clarity, include the original column name
dfa.columns = ['Region', 'variable', 'value', 'count_a']

dfb = dft.value_counts(subset=["Region"]).reset_index()
dfb.columns = ['Region', 'count_b']  # Rename columns for clarity

dfc = dfa.merge(dfb, on='Region', how='inner')
dfc['prob'] = dfc['count_a'] / dfc['count_b']

sns.barplot(
    data=dfc,
    x="value",
    y="prob",
    hue="Region",
)
plt.title(f"{column}")
plt.show()

# Plot all other discrete columns in grid
fig, ax = plt.subplots(nrows=3, ncols=6, figsize=(35, 20))
for idx, column in enumerate(descrete_cols):
    dft = X_all_regions.replace({"Region": region_mapper}).melt(id_vars=["Region"]).loc[
        X_all_regions.melt()["variable"] == column
    ]
    dfa = dft.value_counts().reset_index()
    dfa.columns = ['Region', 'variable', 'value', 'count_a']  # Rename columns for clarity

    dfb = dft.value_counts(subset=["Region"]).reset_index()
    dfb.columns = ['Region', 'count_b']  # Rename columns for clarity

    dfc = dfa.merge(dfb, on='Region', how='inner')
    dfc['prob'] = dfc['count_a'] / dfc['count_b']

    sns.barplot(
        data=dfc,
        x='value',
        y='prob',
        ax=ax[idx // 6][idx % 6],
        hue="Region",
    )
    ax[idx // 6][idx % 6].title.set_text(f"{column}")

plt.show()


Analyzing the plots above, it appears that there are some differences in the populations. The distributions for the symptoms and co-morbidities seem to have different shapes in many cases. To pick a few examples, having a renal co-morbidity is much more common in the North or North-East than in the South or Central-West regions; the symptom of cough is more common in the North than in other regions; and the Central-West region seems to be, on average, younger than other regions. There are other examples we could have listed as well. Spend a moment reviewing these plots to spot differences for yourself.
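To put a number on differences like these, one simple option is to compare the prevalence of a binary feature across regions. The sketch below does this in plain Python on a small, made-up set of records; the values are hypothetical and do not come from the dataset, and `prevalence_by_region` is an illustrative helper name.

```python
from collections import defaultdict


def prevalence_by_region(records, feature):
    """Share of records in each region where the binary `feature` equals 1."""
    totals = defaultdict(int)
    positives = defaultdict(int)
    for record in records:
        totals[record["Region"]] += 1
        positives[record["Region"]] += record[feature]
    return {region: positives[region] / totals[region] for region in totals}


# Hypothetical records: region name and a binary co-morbidity flag
records = [
    {"Region": "North", "Renal": 1},
    {"Region": "North", "Renal": 1},
    {"Region": "North", "Renal": 0},
    {"Region": "North", "Renal": 1},
    {"Region": "Central-West", "Renal": 0},
    {"Region": "Central-West", "Renal": 0},
    {"Region": "Central-West", "Renal": 1},
    {"Region": "Central-West", "Renal": 0},
    {"Region": "South", "Renal": 0},
    {"Region": "South", "Renal": 1},
]

prevalence = prevalence_by_region(records, "Renal")
largest_gap = max(prevalence.values()) - min(prevalence.values())
print(prevalence)
print(f"Largest gap between regions: {largest_gap:.2f}")
```

With the notebook's DataFrame, the same per-region prevalence can be obtained with `X_all_regions.groupby("Region")["Renal"].mean()`.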

### 4.7.2 Set up the training and testing data sets for the model

Make sure the training sets come from the all-regions dataset, but the test sets come from our region only.

In [None]:
# drop Region column to simulate a simple concatenation of two datasets, one from the Central-West, one from the rest of Brazil.
X_our_region_only_for_baseline = X_our_region_only.drop(columns=["Region"])
X_all_regions_for_baseline = X_all_regions.drop(columns=["Region"])

concat_y = X_all_regions_for_baseline["is_dead_at_time_horizon=14"]
concat_X = X_all_regions_for_baseline.drop(columns=["is_dead_at_time_horizon=14"])

X_train, _, y_train, _ = train_test_split(concat_X, concat_y, random_state=4)
X_train.reset_index(drop=True, inplace=True)
y_train.reset_index(drop=True, inplace=True)

our_region_y = X_our_region_only_for_baseline["is_dead_at_time_horizon=14"]
our_region_X = X_our_region_only_for_baseline.drop(columns=["is_dead_at_time_horizon=14"])

_, X_test, _, y_test = train_test_split(our_region_X, our_region_y, random_state=4)
X_test.reset_index(drop=True, inplace=True)
y_test.reset_index(drop=True, inplace=True)

### 4.7.3 Load the model trained on data from all regions.

In [None]:
# Define the model
xgb_model = xgb.XGBClassifier(
    n_estimators=2000,
    learning_rate=0.01,
    max_depth=5,
    subsample=0.8,
    colsample_bytree=1,
    gamma=1,
    objective="binary:logistic",
    random_state=42,
)

# Path where the model trained on the whole dataset is saved
saved_model_path = AUG_RES_PATH / f"augmentation_xgboost_real_all_data.json"

# Train the model on data from all regions and save it to file
xgb_model.fit(X_train, y_train)
xgb_model.save_model(saved_model_path)

# # Alternatively, load a model that has already been trained on the whole dataset
# xgb_model.load_model(saved_model_path)

### 4.7.4 Show the performance of the model.

In [None]:
calculated_accuracy_score_train = accuracy_score(y_train, xgb_model.predict(X_train))
y_pred = xgb_model.predict(X_test)
calculated_accuracy_score_test = accuracy_score(y_test, y_pred)
print(f"Evaluating accuracy: train set: {calculated_accuracy_score_train:0.4f} | test set: {calculated_accuracy_score_test:0.4f}")

As you can see our accuracy does improve, but we can do better!

Why else might this not be the best, or even a possible, approach?
<details>
<summary>Show answer</summary>
There may well be cases where a greater co-variate shift impacts the accuracy to a much greater extent, i.e. the population of one region may differ much more from the overall population in some settings. It is also worth bearing in mind that there are contexts where the above approach is not even an option, such as when features only partially overlap (or are missing). You cannot concatenate datasets that don't share columns, but using a RadialGAN to augment a dataset like this is still possible.
</details>
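Before attempting a simple concatenation, it can help to check how far the columns of two datasets actually overlap. A minimal sketch, using hypothetical column lists and a hypothetical helper name rather than the columns of the real data:

```python
def compare_columns(columns_a, columns_b):
    """Return the columns shared by both datasets and those unique to each."""
    a, b = set(columns_a), set(columns_b)
    return sorted(a & b), sorted(a - b), sorted(b - a)


# Hypothetical column lists for two regional datasets
our_region_columns = ["Age", "Sex", "Fever", "Cough", "SPO2"]
other_region_columns = ["Age", "Sex", "Fever", "Obesity"]

shared, only_ours, only_theirs = compare_columns(our_region_columns, other_region_columns)
print("Shared columns:", shared)
print("Only in our region:", only_ours)
print("Only in other regions:", only_theirs)
```

If either of the "only in" lists is non-empty, a naive concatenation would leave missing values in those columns for one of the datasets.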

## 4.8 The Solution

Augment this dataset with the use of a RadialGAN.

### 4.8.1 Load the data

First, let's load the superset of data from all regions into the `GenericDataLoader` object.

In [None]:
loader = GenericDataLoader(
    X, # X is the dataframe which is a superset of all region data
    target_column="is_dead_at_time_horizon=14", # The column containing the labels to predict
    sensitive_features=["Ethnicity"], # The sensitive features in the dataset
    domain_column="Region", # This labels the domain that each record is from. Where it is `0` it is from our small dataset.
    random_state=42,
)

### 4.8.2 Load/Create the synthetic data model
Let's use a RadialGAN to augment the data. We need to load the plugin and then fit it to the data loader object.

In [None]:
outdir = Path("saved_models")
prefix = "augmentation"
model="radialgan"
n_iter = 49
random_state = 8

We will now either create and fit a synthetic data model then save that model to file, or load one we have already saved from file.

In [None]:
# Define saved model name
save_file = outdir / f"{prefix}.{model}_numericalised_{region_mapper[our_region_index]}_n_iter={n_iter}_rnd={random_state}_final.bkp"
# Load if available
if Path(save_file).exists():
    syn_model = serialization.load_from_file(save_file)
# Create if not available
else:
    syn_model = Plugins().get(model, n_iter=n_iter, random_state=random_state)
    syn_model.fit(loader)
    serialization.save_to_file(save_file, syn_model)
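The load-if-available, otherwise-fit-and-save pattern in the cell above is a simple form of caching. Here is a self-contained sketch of the same idea, using the standard library's `pickle` as a stand-in for synthcity's `serialization` helpers; `load_or_create` and `expensive_fit` are hypothetical names used only for illustration.

```python
import pickle
import tempfile
from pathlib import Path


def load_or_create(path, create_fn):
    """Load a pickled object from `path`, or create it with `create_fn` and save it.

    Returns (object, loaded_from_file).
    """
    path = Path(path)
    if path.exists():
        with path.open("rb") as f:
            return pickle.load(f), True
    obj = create_fn()
    path.parent.mkdir(parents=True, exist_ok=True)
    with path.open("wb") as f:
        pickle.dump(obj, f)
    return obj, False


calls = []


def expensive_fit():
    """Stand-in for a slow model fit; records each call so we can count them."""
    calls.append(1)
    return {"weights": [0.1, 0.2, 0.3]}


with tempfile.TemporaryDirectory() as tmp:
    cache_file = Path(tmp) / "saved_models" / "toy_model.bkp"
    first, first_from_file = load_or_create(cache_file, expensive_fit)
    second, second_from_file = load_or_create(cache_file, expensive_fit)

print(f"Fit was called {len(calls)} time(s); second call loaded from file: {second_from_file}")
```

Including the hyperparameters in the file name, as the cell above does with `n_iter` and `random_state`, helps avoid loading a stale model after the settings change.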

### 4.8.3 Augment the dataset

Let's use our synthetic model to generate some data and use it to augment our original dataset.

In [None]:
n_gen_records = 500

synth_data = syn_model.generate(n_gen_records, domains=[our_region_index], random_state=random_state)

# Now we can augment our original dataset with our new synthetic data
augmented_data = pd.concat([
    synth_data.dataframe(),
    X_our_region_only,
])

print(f"{len(synth_data['Region'])} synthetic records generated.")
print(f"{len(X_our_region_only['Region'])} original real records.")
print(f"{len(X_our_region_only['Region']) + len(synth_data['Region'])} records in the augmented dataset.")

### 4.8.4 Set up the data for the classifier
Now we need to test a model trained on the augmented dataset, so we need to set up our data splits again.

In [None]:
augmented_y = augmented_data["is_dead_at_time_horizon=14"]
augmented_X = augmented_data.drop(columns=["is_dead_at_time_horizon=14"])

augmented_y.reset_index(drop=True, inplace=True)
augmented_X.reset_index(drop=True, inplace=True)

our_region_y = X_our_region_only["is_dead_at_time_horizon=14"]
our_region_X = X_our_region_only.drop(columns=["is_dead_at_time_horizon=14"])

_, X_test, _, y_test = train_test_split(our_region_X, our_region_y, random_state=4)
X_test.reset_index(drop=True, inplace=True)
y_test.reset_index(drop=True, inplace=True)
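
Note that `augmented_data` contains every row of `X_our_region_only`, so the rows in `X_test` also appear in `augmented_X`. One variation, which is not part of the original tutorial, is to set the real test rows aside before augmenting. The sketch below shows the idea on made-up dataframes; `real`, `synthetic` and their columns are hypothetical stand-ins.

```python
import pandas as pd

# Hypothetical stand-ins for the real and synthetic data
real = pd.DataFrame({"feature": range(8), "label": [0, 1] * 4})
synthetic = pd.DataFrame({"feature": range(100, 104), "label": [1, 0, 1, 0]})

# Hold out 25% of the real rows for testing before augmenting
test = real.sample(frac=0.25, random_state=4)
real_train = real.drop(index=test.index)

# Augment only the remaining real rows with the synthetic data
train = pd.concat([synthetic, real_train], ignore_index=True)

# Check that no held-out row ended up in the training set
overlap = set(test["feature"]) & set(train["feature"])
print(len(train), len(test), len(overlap))
```

With this split, the accuracy on `test` is measured only on real records the classifier has not seen during training.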

### 4.8.5 Load the trained model

In [None]:
# Define the model
xgb_model = xgb.XGBClassifier(
    n_estimators=2000,
    learning_rate=0.01,
    max_depth=5,
    subsample=0.8,
    colsample_bytree=1,
    gamma=1,
    objective="binary:logistic",
    random_state=42,
)

# Path where the model trained on the augmented data is saved
saved_model_path = AUG_RES_PATH / "augmentation_xgboost_augmented_data.json"

# To load the previously saved model instead of training, uncomment:
# xgb_model.load_model(saved_model_path)

# The saved model was trained with the following code:
xgb_model.fit(augmented_X, augmented_y)
xgb_model.save_model(saved_model_path)

In [None]:
# Accuracy on the data the model was trained on
y_pred_train = xgb_model.predict(augmented_X)
calculated_accuracy_score_train = accuracy_score(augmented_y, y_pred_train)
print(f"Accuracy on training set (augmented data): {calculated_accuracy_score_train:0.4f}")

# Accuracy on the real test set from our region
y_pred = xgb_model.predict(X_test)
calculated_accuracy_score_test = accuracy_score(y_test, y_pred)
print(f"Accuracy on real data test set: {calculated_accuracy_score_test:0.4f}")

### 4.8.7 Results

Over-fitting to the training data is significantly reduced, and the accuracy is much higher than for the model trained on the small dataset comprised solely of data from the Central-West region. We also see a significant improvement over training the model on the superset of the real data.

## 4.9 Extension
Use the code block below to complete the following extension exercises.

### 4.9.1 Can you generate some more augmented datasets to answer the following questions?
1) How much synthetic data should you create for best results? Is there a minimum value that you would use? Is there a maximum value? In each case, why does that minimum/maximum occur?
<br/>2) How much does changing the RadialGAN plugin parameter `n_iter` change the quality of the generated data?

In [None]:
#@title (i) How many synthetic records should you create?
accuracies = []
generated_records = [
    100,
    300,
    500,
    700,
    900,
    1000,
    2000,
    3000,
    4000,
    5000,
    8000,
    10000,
    # 15000,
    # 20000,
    # 50000,
    # 100000,
    # 1000000,
] # Larger values take longer to run
repeats = 1 # Can be set higher to reduce the variance by using mean value
for n_gen_records in generated_records:
    rep_vals = []
    for i in range(repeats):
        synth_data = syn_model.generate(n_gen_records, domains=[our_region_index])

        # Now we can augment our original dataset with our new synthetic data
        augmented_data = pd.concat([
            synth_data.dataframe(),
            X_our_region_only,
        ])

        augmented_y = augmented_data["is_dead_at_time_horizon=14"]
        augmented_X_in = augmented_data.drop(columns=["is_dead_at_time_horizon=14"])

        X_train, X_test, y_train, y_test = train_test_split(augmented_X_in, augmented_y, random_state=4)
        X_train.reset_index(drop=True, inplace=True)
        X_test.reset_index(drop=True, inplace=True)
        y_train.reset_index(drop=True, inplace=True)
        y_test.reset_index(drop=True, inplace=True)

        # Train the model on the training split of the augmented dataset
        xgb_model = xgb.XGBClassifier(
            n_estimators=2000,
            learning_rate=0.01,
            max_depth=5,
            subsample=0.8,
            colsample_bytree=1,
            gamma=1,
            objective="binary:logistic",
            random_state=42,
        )
        xgb_model.fit(X_train, y_train)

        y_pred = xgb_model.predict(X_test)
        calculated_accuracy_score_train = accuracy_score(y_train, xgb_model.predict(X_train))
        calculated_accuracy_score_test = accuracy_score(y_test, y_pred)
        # print(f"Evaluating accuracy: n_gen_records: {n_gen_records} train set: {calculated_accuracy_score_train}| test set: {calculated_accuracy_score_test}")
        rep_vals.append(calculated_accuracy_score_test)
    accuracies.append(np.mean(rep_vals))

d = {"generated_records": generated_records, "accuracies": accuracies}
accuracy_data = pd.DataFrame(d)
# The title reports the n_iter used to fit the generator; the original data is included in each augmented set
plot = sns.lineplot(
    y="accuracies",
    x="generated_records",
    data=accuracy_data
).set(title=f"Augmenting {region_mapper[our_region_index]}, n_iter={n_iter}, with original data")
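
The `repeats` setting above averages the test accuracy over several runs. As a minimal sketch of how the spread across repeats could be reported alongside the mean, the example below uses made-up accuracy values (the numbers and the array layout are assumptions for illustration, not results from this tutorial).

```python
import numpy as np

# Hypothetical accuracies: one row per value of n_gen_records, one column per repeat
generated_records = [100, 500, 1000]
rep_vals = np.array([
    [0.70, 0.74, 0.72],
    [0.78, 0.80, 0.79],
    [0.81, 0.79, 0.80],
])

# Mean accuracy across repeats for each number of generated records
mean_accuracy = rep_vals.mean(axis=1)
# Sample standard deviation across repeats (ddof=1)
std_accuracy = rep_vals.std(axis=1, ddof=1)

for n, m, s in zip(generated_records, mean_accuracy, std_accuracy):
    print(f"n_gen_records={n}: accuracy {m:.3f} +/- {s:.3f}")
```

Reporting the standard deviation next to the mean makes it easier to judge whether a difference between two values of `n_gen_records` is larger than the run-to-run noise.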


## 4.10 Benchmarking augmentation
Use the benchmarking interface (documented [here](https://synthcity.readthedocs.io/en/latest/generated/synthcity.benchmark.html) and covered in [tutorial 2](https://colab.research.google.com/github/vanderschaarlab/synthcity/blob/main/tutorials/tutorial2_benchmarks.ipynb)) to see if you can improve the performance of a downstream classifier. You may have to use different models and parameters. There is some code to start you off below, but you are encouraged to change it and experiment.

In [None]:
# @title Helpful code to start you off
# stdlib
import sys
import warnings

warnings.filterwarnings("ignore")

# third party
from sklearn.datasets import load_iris

# synthcity absolute
import synthcity.logger as log
from synthcity.plugins import Plugins
from synthcity.plugins.core.dataloader import GenericDataLoader
from synthcity.benchmark import Benchmarks

""" You can limit this to just the augmentation metrics by passing a dictionary
of the desired metrics, as we did in section 3.6 of this notebook"""

score = Benchmarks.evaluate(
    [("ctgan", "ctgan", {})],
    loader,
    synthetic_size=len(X),
    augmentation_rule="equal",
    strict_augmentation=True,
    repeats=1,
)
Benchmarks.print(score)