# Case Study 1 - Data Modality

## The Task
Get used to loading datasets with the library and generating synthetic data from them, whatever the modality of the real data.

### Imports
Let's get the imports out of the way. We import the required standard and third-party libraries and the relevant Synthcity modules. We can also set the logging level here, using Synthcity's bespoke logger.

In [4]:
# Standard
import sys
import warnings
from pathlib import Path

# 3rd party
import numpy as np
import pandas as pd

# synthcity
import synthcity.logger as log
from synthcity.plugins import Plugins
from synthcity.plugins.core.dataloader import (
    GenericDataLoader,
    SurvivalAnalysisDataLoader,
    TimeSeriesDataLoader,
    TimeSeriesSurvivalDataLoader,
)

# Configure warnings and logging
warnings.filterwarnings("ignore")

# Set the level for the logging
# log.add(sink=sys.stderr, level="DEBUG")
log.remove()



# Loading data of different modalities

In this notebook we will load different datasets into synthcity and show that data of many different modalities can be used to generate synthetic data using this module.

### Static Data
Now we will start with the simplest example, static tabular data. For this, we will use the diabetes dataset from sklearn. First, we need to load the dataset.

In [None]:
from sklearn.datasets import load_diabetes

X, y = load_diabetes(return_X_y=True, as_frame=True)
X["target"] = y
display(X)

Then we pass it to the `GenericDataLoader` object from `synthcity`.

In [None]:
loader = GenericDataLoader(
    X,
    target_column="target",
    sensitive_columns=["sex"],
)

We can print out different methods that are compatible with our data by calling `Plugins().list()` with a relevant list passed to the categories parameter.

In [None]:
print(Plugins(categories=["generic"]).list())

There is no need to worry about the code in the next block; we will go into detail on how to generate synthetic data in the case studies to come. It is here purely to demonstrate that our dataset can be used to generate synthetic data with the synthcity module. We use the `marginal_distributions` method, which is one of the available debugging methods.

In [None]:
syn_model = Plugins().get("marginal_distributions")
syn_model.fit(loader)
syn_model.generate(count=10).dataframe()

### Static Survival
Next, let's look at censored data. Censoring is a form of missing-data problem in which the time to an event is not observed, for example because the study ended before all recruited subjects had experienced the event of interest, or because a subject left the study before experiencing it. Censoring is common in survival analysis. For our next example we will load a static survival dataset: the veterans' lung cancer dataset provided by scikit-survival.
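
To make the idea concrete, here is a minimal sketch of right-censoring on made-up numbers. The subjects, times and the `study_end` value are hypothetical and are not taken from the veterans dataset. For each subject we keep the time we actually observed and a flag saying whether the event was seen.

```python
# Hypothetical toy example of right-censoring; none of these values come from a real dataset.
study_end = 100  # day on which the (made-up) study stops

# True day of the event for each subject (in practice unknown when the subject is censored)
true_event_times = [30, 150, 80, 220]
# Day on which a subject dropped out, or None if they stayed until the end of the study
dropout_times = [None, None, 50, None]

records = []
for true_time, dropout in zip(true_event_times, dropout_times):
    # Follow-up ends at drop-out or at the end of the study, whichever comes first
    follow_up_end = study_end if dropout is None else min(dropout, study_end)
    event_observed = true_time <= follow_up_end
    # If the event was not seen, all we know is that it happened after follow-up ended
    observed_time = true_time if event_observed else follow_up_end
    records.append({"time": observed_time, "event": event_observed})

for record in records:
    print(record)
```

Only the first subject has an observed event; the other three are censored, so their `time` is a lower bound on the true time to event.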

First, load the dataset.

In [5]:
from sksurv.datasets import load_veterans_lung_cancer

data_x, data_y = load_veterans_lung_cancer()
data_x["status"] = [record[0] for record in data_y]
data_x["survival_in_days"] = [record[1] for record in data_y]
display(data_x)

Unnamed: 0,Age_in_years,Celltype,Karnofsky_score,Months_from_Diagnosis,Prior_therapy,Treatment,status,survival_in_days
0,69.0,squamous,60.0,7.0,no,standard,True,72.0
1,64.0,squamous,70.0,5.0,yes,standard,True,411.0
2,38.0,squamous,60.0,3.0,no,standard,True,228.0
3,63.0,squamous,60.0,9.0,yes,standard,True,126.0
4,65.0,squamous,70.0,11.0,yes,standard,True,118.0
...,...,...,...,...,...,...,...,...
132,65.0,large,75.0,1.0,no,test,True,133.0
133,64.0,large,60.0,5.0,no,test,True,111.0
134,67.0,large,70.0,18.0,yes,test,True,231.0
135,65.0,large,80.0,4.0,no,test,True,378.0
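
The list comprehensions above unpack each record of `data_y` by position. Assuming `data_y` is a NumPy structured array, the same values could also be read by field name. The toy array below, including its field names, is an illustrative stand-in and not the real dataset.

```python
import numpy as np

# Hypothetical stand-in for `data_y`: a structured array with an event flag and a time.
# The field names are assumptions made for this sketch.
toy_y = np.array(
    [(True, 72.0), (False, 411.0)],
    dtype=[("Status", "?"), ("Survival_in_days", "<f8")],
)

# Access by position, as in the cell above
status_by_position = [bool(record[0]) for record in toy_y]

# Access by field name, which avoids relying on the order of the fields
status = toy_y["Status"].tolist()
days = toy_y["Survival_in_days"].tolist()

print(status, days)
```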


Pass it to the DataLoader. This time we will use the `SurvivalAnalysisDataLoader`. We need to pass it the data, the name of the column that contains our labels or targets as `target_column`, and the name of the column containing the time elapsed when the event occurred (the event defined by the target column) as `time_to_event_column`. Calling `info()` on the loader object shows the information about the dataset we have just prepared.

In [6]:

loader = SurvivalAnalysisDataLoader(
    data_x,
    target_column="status",
    time_to_event_column="survival_in_days",
)
print(loader.info())


{'data_type': 'survival_analysis', 'len': 137, 'static_features': ['Age_in_years', 'Celltype', 'Karnofsky_score', 'Months_from_Diagnosis', 'Prior_therapy', 'Treatment', 'status', 'survival_in_days'], 'sensitive_features': [], 'important_features': [], 'outcome_features': ['status'], 'target_column': 'status', 'time_to_event_column': 'survival_in_days', 'time_horizons': [250.5, 500.0, 749.5], 'train_size': 0.8}
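
We did not pass any `time_horizons`, yet the loader reports three of them. The reported values are consistent with evenly spaced points strictly between the shortest and longest survival time. The sketch below reproduces the numbers under two assumptions of ours: that the survival times in this dataset range from 1 to 999 days, and that the default is computed with something like `np.linspace`. Neither is confirmed by the output above.

```python
import numpy as np

# Assumed range of `survival_in_days` (the truncated display does not show the extremes)
t_min, t_max = 1.0, 999.0

# Five evenly spaced points including both ends; keep only the three interior ones
time_horizons = np.linspace(t_min, t_max, num=5)[1:-1].tolist()
print(time_horizons)  # [250.5, 500.0, 749.5]
```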


If we get the `marginal_distributions` plugin again and fit it to the `loader` object, we can then call `generate` to produce the synthetic data.

In [8]:
syn_model = Plugins().get("marginal_distributions")
syn_model.fit(loader)
syn_model.generate(count=10)

Unnamed: 0,Age_in_years,Celltype,Karnofsky_score,Months_from_Diagnosis,Prior_therapy,Treatment,status,survival_in_days
0,59.794235,smallcell,50.0,48.197961,no,standard,True,548.715877
1,67.6139,large,60.0,62.506286,yes,test,False,714.758988
2,62.329879,squamous,40.0,52.83765,yes,test,False,602.557849
3,59.60951,smallcell,10.0,47.859954,no,standard,True,544.793417
4,53.911776,large,40.0,37.434313,yes,test,False,423.80749
5,64.357023,large,75.0,56.546894,yes,test,False,645.602325
6,54.566599,large,85.0,38.6325,yes,test,False,437.712037
7,75.913331,large,40.0,77.692478,yes,test,False,890.989455
8,79.29215,squamous,50.0,83.874997,yes,test,False,962.735435
9,52.021751,large,20.0,33.975971,yes,test,False,383.674636


### Regular Time Series

In this next example we will load up a simple regular time series dataset and show that it is compatible with Synthcity. The temporal data must be passed to the loader as a list of dataframes, where each dataframe in the list refers to a different record and contains all time points for the record. So, there is a small amount of pre-processing to get our data into the right shape. As it is a regular time series we can simply pass a sequential list for each record.

The dataset we will use here is the basic motions dataset provided by SKTime. So, we need to import the library.

In [9]:
from sktime.datasets import load_basic_motions

Load the data and re-format it into a list of dataframes, where each dataframe in the list refers to a different record and contains all time points for the record. We also need the outcomes as a dataframe and the observation times as a list of time steps for each record. As this is a regular time series our time steps can simply be a sequential list of integers.

In [14]:

X, y = load_basic_motions(
    split="TRAIN", return_X_y=True, return_type="pd-multiindex"
)
num_instances = len(set((x[0] for x in X.index)))
num_time_steps = len(set((x[1] for x in X.index)))

temporal_data = [X.loc[i] for i in range(num_instances)]
y = pd.DataFrame(y, columns=["label"])
observation_times = [list(range(num_time_steps)) for i in range(num_instances)]
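
The same reshaping can be checked on a much smaller, made-up example. The sketch below builds a hypothetical multi-indexed dataframe with 2 records, 3 time steps and 2 features (all names and values are invented) and turns it into the list-of-dataframes format described above.

```python
import pandas as pd

# Hypothetical toy data: 2 records, 3 time steps each, 2 features
index = pd.MultiIndex.from_product([range(2), range(3)], names=["instance", "time"])
toy_X = pd.DataFrame(
    {"dim_0": [0.1, 0.2, 0.3, 1.1, 1.2, 1.3], "dim_1": [5, 6, 7, 8, 9, 10]},
    index=index,
)

num_toy_instances = toy_X.index.get_level_values(0).nunique()
num_toy_time_steps = toy_X.index.get_level_values(1).nunique()

# One dataframe per record, indexed by time step
toy_temporal = [toy_X.loc[i] for i in range(num_toy_instances)]
# Regular sampling: every record is observed at steps 0, 1, 2, ...
toy_times = [list(range(num_toy_time_steps)) for _ in range(num_toy_instances)]

print(len(toy_temporal), toy_temporal[0].shape, toy_times)
```

Each element of `toy_temporal` has one row per time step and one column per feature, and `toy_times` has one list of time steps per record.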

Pass the data we just prepared to the DataLoader. Here we will use the `TimeSeriesDataLoader`. Then we will print out the loader info to check everything looks correct.

In [21]:
loader = TimeSeriesDataLoader(
    temporal_data=temporal_data,
    observation_times=observation_times,
    outcome=y,
)
display(loader.dataframe())
print(loader.info())

Unnamed: 0,seq_id,seq_time_id,seq_temporal_dim_0,seq_temporal_dim_1,seq_temporal_dim_2,seq_temporal_dim_3,seq_temporal_dim_4,seq_temporal_dim_5,seq_out_label
0,0,0,0.079106,0.394032,0.551444,0.351565,0.023970,0.633883,standing
1,0,1,0.079106,0.394032,0.551444,0.351565,0.023970,0.633883,standing
2,0,2,-0.903497,-3.666397,-0.282844,-0.095881,-0.319605,0.972131,standing
3,0,3,1.116125,-0.656101,0.333118,1.624657,-0.569962,1.209171,standing
4,0,4,1.638200,1.405135,0.393875,1.187864,-0.271664,1.739182,standing
...,...,...,...,...,...,...,...,...,...
3995,39,95,1.239144,-6.142442,0.028264,-2.309144,1.472845,-0.998765,badminton
3996,39,96,0.261434,0.205915,-0.224944,-0.524684,0.769715,0.157139,badminton
3997,39,97,2.490353,-0.878765,-0.597296,0.111862,-0.117188,-0.050604,badminton
3998,39,98,4.122120,0.911620,-0.465409,0.535338,0.197090,0.442120,badminton


{'data_type': 'time_series', 'len': 4000, 'static_features': [], 'temporal_features': ['dim_0', 'dim_1', 'dim_2', 'dim_3', 'dim_4', 'dim_5'], 'outcome_features': ['label'], 'outcome_len': 1.0, 'window_len': 100, 'sensitive_features': [], 'important_features': [], 'random_state': 0, 'train_size': 0.8, 'fill': nan, 'seq_static_features': [], 'seq_temporal_features': ['seq_temporal_dim_0', 'seq_temporal_dim_1', 'seq_temporal_dim_2', 'seq_temporal_dim_3', 'seq_temporal_dim_4', 'seq_temporal_dim_5'], 'seq_outcome_features': ['seq_out_label'], 'seq_offset': 0, 'seq_id_feature': 'seq_id', 'seq_time_id_feature': 'seq_time_id', 'seq_features': ['seq_id', 'seq_time_id', 'seq_temporal_dim_0', 'seq_temporal_dim_1', 'seq_temporal_dim_2', 'seq_temporal_dim_3', 'seq_temporal_dim_4', 'seq_temporal_dim_5', 'seq_out_label']}


Now we are ready to produce the synthetic data. We will use the `timegan` plugin to handle the time series data. We don't care about the quality of the synthetic data here; we only want to check that the dataset is compatible and to practice loading datasets. So we can pass `n_iter=1` to limit the number of training iterations.

In [25]:
syn_model = Plugins().get("timegan", n_iter=1)
syn_model.fit(loader)
syn_model.generate(count=10)

100%|██████████| 1/1 [00:00<00:00,  5.53it/s]


Unnamed: 0,seq_id,seq_time_id,seq_temporal_dim_0,seq_temporal_dim_1,seq_temporal_dim_2,seq_temporal_dim_3,seq_temporal_dim_4,seq_temporal_dim_5,seq_out_label
0,0,94,12.530062,-3.094555,-3.035400,2.310466,1.820244,-0.957999,badminton
1,0,93,3.869224,3.163741,3.006886,-6.300015,-0.001746,0.382490,badminton
2,0,93,10.422528,-1.176147,-0.621660,1.920073,2.263130,-3.279178,badminton
3,0,79,11.544957,3.167509,-4.428284,0.383934,-0.936623,-10.698672,badminton
4,0,98,-9.961639,2.388232,0.129205,-0.619688,1.673140,6.145846,badminton
...,...,...,...,...,...,...,...,...,...
657,9,98,0.485499,2.347785,-4.107735,-1.964277,-4.986314,5.072160,walking
658,9,96,1.374680,0.041658,-4.252335,1.341734,1.113546,2.830677,walking
659,9,98,-0.014014,-0.146475,-0.650388,-1.520942,-0.438627,-6.482527,walking
660,9,94,11.141310,-10.039393,0.874606,-1.724200,-0.522479,2.099312,walking


### Irregular Time Series

Now let's load an irregular time series dataset and show that it is also compatible with Synthcity. The dataset we will use here is a Google stocks dataset provided by the synthcity module itself.

In [48]:
import numpy as np
from synthcity.utils.datasets.time_series.google_stocks import GoogleStocksDataloader

static_data, temporal_data, observation_times, outcome = GoogleStocksDataloader().load()

As the dataset is wrapped by synthcity, it is already provided in the correct format, and the requirements are the same as before. The temporal data is a list of dataframes, where each dataframe refers to a different record and contains all time points for that record. The outcomes are all in one dataframe, and the observation times are a list of time steps for each record. The main difference is that the observation times are now floats at irregular intervals rather than a sequence of consecutive integers, as the `seq_time_id` column in the output of the next cell shows.
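
As a small illustration of what separates regular from irregular sampling, the sketch below checks whether the gaps between consecutive observations are constant. The observation times are made up, the helper `is_regular` is hypothetical and not part of synthcity, and each entry is treated as the time at which an observation was made, which is an assumption of this sketch.

```python
# Hypothetical observation times for two records (values made up for illustration)
toy_observation_times = [
    [0.0, 0.5, 0.6, 2.0],  # record 0: four irregularly spaced observations
    [0.1, 0.4, 1.7],       # record 1: three irregularly spaced observations
]


def is_regular(times, tol=1e-9):
    """Return True if consecutive observations are equally spaced."""
    gaps = [later - earlier for earlier, later in zip(times, times[1:])]
    return all(abs(gap - gaps[0]) <= tol for gap in gaps)


print([is_regular(times) for times in toy_observation_times])  # [False, False]
print(is_regular(list(range(5))))  # True
```

Note that records in an irregular series may also have different numbers of observations, as in the toy lists above.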

In [49]:
loader = TimeSeriesDataLoader(
    temporal_data=temporal_data,
    observation_times=observation_times,
    static_data=static_data,
    outcome=outcome,
)
print(loader.info())
display(loader.dataframe())

{'data_type': 'time_series', 'len': 500, 'static_features': [], 'temporal_features': ['Close', 'High', 'Low', 'Open', 'Volume'], 'outcome_features': ['Open_next'], 'outcome_len': 1.0, 'window_len': 10, 'sensitive_features': [], 'important_features': [], 'random_state': 0, 'train_size': 0.8, 'fill': nan, 'seq_static_features': [], 'seq_temporal_features': ['seq_temporal_Close', 'seq_temporal_High', 'seq_temporal_Low', 'seq_temporal_Open', 'seq_temporal_Volume'], 'seq_outcome_features': ['seq_out_Open_next'], 'seq_offset': 0, 'seq_id_feature': 'seq_id', 'seq_time_id_feature': 'seq_time_id', 'seq_features': ['seq_id', 'seq_time_id', 'seq_temporal_Close', 'seq_temporal_High', 'seq_temporal_Low', 'seq_temporal_Open', 'seq_temporal_Volume', 'seq_out_Open_next']}


Unnamed: 0,seq_id,seq_time_id,seq_temporal_Close,seq_temporal_High,seq_temporal_Low,seq_temporal_Open,seq_temporal_Volume,seq_out_Open_next
0,0,0.420455,0.795143,0.750354,0.712517,0.697940,0.334331,0.710852
1,0,0.409091,0.648194,0.610592,0.618136,0.597390,0.321474,0.710852
2,0,0.397727,0.591407,0.539365,0.434550,0.390659,0.490578,0.710852
3,0,0.363636,0.413450,0.341687,0.407823,0.385989,0.392094,0.710852
4,0,0.352273,0.378020,0.378079,0.409257,0.361401,0.110766,0.710852
...,...,...,...,...,...,...,...,...
495,49,0.920455,0.889539,0.879778,0.900782,0.885578,0.413625,0.246566
496,49,0.909091,0.832628,0.838572,0.831813,0.785577,0.244291,0.246566
497,49,0.897727,0.791407,0.784055,0.800261,0.751374,0.140203,0.246566
498,49,0.886364,0.748318,0.716935,0.731552,0.667446,0.150912,0.246566


Exactly as for the regular time series, we can now generate synthetic data, by selecting our time series compatible plugin, then calling `fit()` and `generate()`.

In [51]:
syn_model = Plugins().get("timegan", n_iter=1)
syn_model.fit(loader)
syn_model.generate(count=5)

100%|██████████| 1/1 [00:00<00:00, 20.39it/s]


Unnamed: 0,seq_id,seq_time_id,seq_temporal_Close,seq_temporal_High,seq_temporal_Low,seq_temporal_Open,seq_temporal_Volume,seq_out_Open_next
0,0,0.227674,0.740012,0.363054,0.216962,0.68521,0.426429,0.3279
1,0,0.633709,0.294282,0.620159,0.250119,0.62885,0.436839,0.3279
2,0,0.281716,0.729902,0.385716,0.261911,0.233747,0.645109,0.3279
3,0,0.093432,0.411217,0.229563,0.677443,0.638325,0.555954,0.3279
4,1,0.443256,0.679591,0.62985,0.417286,0.395919,0.448475,0.326633
5,1,0.379312,0.423856,0.638886,0.416107,0.371163,0.694256,0.326633
6,1,0.286197,0.30815,0.747453,0.419901,0.234233,0.407084,0.326633
7,1,0.328184,0.712702,0.371477,0.273448,0.236638,0.435504,0.326633
8,1,0.629596,0.685311,0.654764,0.273487,0.163162,0.230441,0.326633
9,2,0.640672,0.739825,0.693935,0.670616,0.742992,0.250959,0.326177


### Composite Irregular Time Series Survival Analysis

In this final example we will look at composite data while adding all the other more complex elements we have looked at so far. This next dataset is a composite irregular time series survival analysis dataset. 

Again this dataset is provided by synthcity, so there is little to do in terms of pre-processing as everything is in the right format to begin with.

In [52]:
from synthcity.utils.datasets.time_series.pbc import PBCDataloader
(
    static_surv,
    temporal_surv,
    temporal_surv_horizons,
    outcome_surv,
) = PBCDataloader().load()

Even complex datasets such as this are compatible with Synthcity. We can load this data using the `TimeSeriesSurvivalDataLoader`. Then, by calling `loader.info()`, we can check the information about the dataset. It contains one static feature (`sex`) and 14 temporal features, which makes it a composite dataset. The `seq_time_id` field shows the irregular time sampling, which comes from the values we pass to the `observation_times` parameter of the `TimeSeriesSurvivalDataLoader` object. Finally, we are formulating this data as a survival analysis problem, which is indicated by the presence of a `time_to_event` field.

In [53]:
T, E = outcome_surv

horizons = [0.25, 0.5, 0.75]
time_horizons = np.quantile(T, horizons).tolist()

loader = TimeSeriesSurvivalDataLoader(
    temporal_data=temporal_surv,
    observation_times=temporal_surv_horizons,
    static_data=static_surv,
    T=T,
    E=E,
    time_horizons=time_horizons,
)

print(loader.info())

{'data_type': 'time_series_survival', 'len': 1945, 'static_features': ['sex'], 'temporal_features': ['SGOT', 'age', 'albumin', 'alkaline', 'ascites', 'drug', 'edema', 'hepatomegaly', 'histologic', 'platelets', 'prothrombin', 'serBilir', 'serChol', 'spiders'], 'outcome_features': ['time_to_event', 'event'], 'outcome_len': 2.0, 'window_len': 16, 'sensitive_features': [], 'important_features': [], 'random_state': 0, 'train_size': 0.8, 'fill': nan, 'seq_static_features': ['seq_static_sex'], 'seq_temporal_features': ['seq_temporal_SGOT', 'seq_temporal_age', 'seq_temporal_albumin', 'seq_temporal_alkaline', 'seq_temporal_ascites', 'seq_temporal_drug', 'seq_temporal_edema', 'seq_temporal_hepatomegaly', 'seq_temporal_histologic', 'seq_temporal_platelets', 'seq_temporal_prothrombin', 'seq_temporal_serBilir', 'seq_temporal_serChol', 'seq_temporal_spiders'], 'seq_outcome_features': ['seq_out_time_to_event', 'seq_out_event'], 'seq_offset': 0, 'seq_id_feature': 'seq_id', 'seq_time_id_feature': 'seq_
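
The `time_horizons` passed to the loader are the 25%, 50% and 75% quantiles of the times to event. The sketch below shows the same `np.quantile` call on a made-up array, `toy_T`, whose values are hypothetical and chosen so the result is easy to verify by hand.

```python
import numpy as np

# Hypothetical times to event for five subjects (made-up values)
toy_T = np.array([2.0, 4.0, 6.0, 8.0, 10.0])

toy_horizons = [0.25, 0.5, 0.75]
# With five sorted values, these quantiles fall exactly on the 2nd, 3rd and 4th value
toy_time_horizons = np.quantile(toy_T, toy_horizons).tolist()
print(toy_time_horizons)  # [4.0, 6.0, 8.0]
```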

We can now generate synthetic data, in the way we are now well familiar with. We select our time series compatible plugin, then call `fit()` and `generate()`.

In [55]:
syn_model = Plugins().get("timegan", n_iter=1)
syn_model.fit(loader)
syn_model.generate(count=5)

100%|██████████| 1/1 [00:00<00:00,  2.91it/s]


Unnamed: 0,seq_id,seq_time_id,seq_static_sex,seq_temporal_SGOT,seq_temporal_age,seq_temporal_albumin,seq_temporal_alkaline,seq_temporal_ascites,seq_temporal_drug,seq_temporal_edema,seq_temporal_hepatomegaly,seq_temporal_histologic,seq_temporal_platelets,seq_temporal_prothrombin,seq_temporal_serBilir,seq_temporal_serChol,seq_temporal_spiders,seq_out_time_to_event,seq_out_event
0,0,12.368695,1.0,-0.777904,-0.267299,-0.556327,0.00136,2.0,0.0,2.0,2.0,0.0,-0.000178,-0.658516,1.678653,0.364817,2.0,6.346922,1.0
1,1,13.374039,0.0,-0.271371,0.783349,0.200952,-0.298128,1.0,0.0,1.0,0.0,3.0,-1.085052,0.92424,-0.000171,-0.511754,1.0,9.175191,1.0


### Create synthetic datasets
 1) Above we generated data with the debugging method `marginal_distributions` for tabular data and with `timegan` for time series data. Now, using `Plugins().list()` or the documentation, find another method that is compatible with some of the datasets and see if you can generate your own synthetic data.
 2) Generate synthetic data for another dataset of your choice using the methods described above. You can use any of the other datasets from the sources we have used above: [SKLearn](https://scikit-learn.org/stable/datasets/toy_dataset.html), [SKTime](https://www.sktime.org/en/stable/api_reference/datasets.html), [SKSurv](https://scikit-survival.readthedocs.io/en/stable/api/datasets.html) or [synthcity](https://github.com/vanderschaarlab/synthcity/tree/main/src/synthcity/utils/datasets) itself.