# Case Study 1 - Data Modality

## The Task
Get familiar with loading datasets into Synthcity and generating synthetic data from them, whatever the modality of the real data.

### Imports
Let's get the imports out of the way. We import the required standard and third-party libraries and the relevant Synthcity modules. We also set the logging level here, using Synthcity's own logger.

```python
# Standard
import sys
import warnings
from pathlib import Path

# 3rd party
import numpy as np
import pandas as pd

# synthcity
import synthcity.logger as log
from synthcity.plugins import Plugins
from synthcity.plugins.core.dataloader import (
    GenericDataLoader,
    SurvivalAnalysisDataLoader,
    TimeSeriesDataLoader,
    TimeSeriesSurvivalDataLoader,
)

# Suppress library warnings to keep the notebook output readable
warnings.filterwarnings("ignore")

# Remove all logging sinks to silence Synthcity's logger.
# To see DEBUG messages instead, uncomment the log.add(...) line;
# it must come after log.remove(), otherwise the new sink is removed again.
log.remove()
# log.add(sink=sys.stderr, level="DEBUG")
```

## Synthetic generators

We can list all available synthetic generators, across every data modality, by calling `list()` on a `Plugins` object.

```python
print(Plugins().list())
```

```
['pategan', 'timevae', 'nflow', 'rtvae', 'ctgan', 'timegan', 'decaf', 'adsgan', 'survival_ctgan', 'survival_nflow', 'survival_gan', 'probabilistic_ar', 'radialgan', 'fflows', 'survae', 'tvae', 'bayesian_network', 'privbayes', 'dpgan']
```


## Loading data of different modalities
### Static Data
Let's start with the simplest example: static tabular data. First, we load the dataset.

```python
from sklearn.datasets import load_diabetes

X, y = load_diabetes(return_X_y=True, as_frame=True)
X["target"] = y
display(X)
```


| | age | sex | bmi | bp | s1 | s2 | s3 | s4 | s5 | s6 | target |
|---|---|---|---|---|---|---|---|---|---|---|---|
| 0 | 0.038076 | 0.050680 | 0.061696 | 0.021872 | -0.044223 | -0.034821 | -0.043401 | -0.002592 | 0.019907 | -0.017646 | 151.0 |
| 1 | -0.001882 | -0.044642 | -0.051474 | -0.026328 | -0.008449 | -0.019163 | 0.074412 | -0.039493 | -0.068332 | -0.092204 | 75.0 |
| 2 | 0.085299 | 0.050680 | 0.044451 | -0.005670 | -0.045599 | -0.034194 | -0.032356 | -0.002592 | 0.002861 | -0.025930 | 141.0 |
| 3 | -0.089063 | -0.044642 | -0.011595 | -0.036656 | 0.012191 | 0.024991 | -0.036038 | 0.034309 | 0.022688 | -0.009362 | 206.0 |
| 4 | 0.005383 | -0.044642 | -0.036385 | 0.021872 | 0.003935 | 0.015596 | 0.008142 | -0.002592 | -0.031988 | -0.046641 | 135.0 |
| ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... |
| 437 | 0.041708 | 0.050680 | 0.019662 | 0.059744 | -0.005697 | -0.002566 | -0.028674 | -0.002592 | 0.031193 | 0.007207 | 178.0 |
| 438 | -0.005515 | 0.050680 | -0.015906 | -0.067642 | 0.049341 | 0.079165 | -0.028674 | 0.034309 | -0.018114 | 0.044485 | 104.0 |
| 439 | 0.041708 | 0.050680 | -0.015906 | 0.017293 | -0.037344 | -0.013840 | -0.024993 | -0.011080 | -0.046883 | 0.015491 | 132.0 |
| 440 | -0.045472 | -0.044642 | 0.039062 | 0.001215 | 0.016318 | 0.015283 | -0.028674 | 0.026560 | 0.044529 | -0.025930 | 220.0 |


Then we pass it to a `GenericDataLoader`, naming the target column and any sensitive columns.

```python
loader = GenericDataLoader(
    X,
    target_column="target",
    sensitive_columns=["sex"],
)
```

To list only the generators compatible with our data, we pass the relevant categories to the `categories` parameter of `Plugins` and then call `.list()`.

```python
print(Plugins(categories=["generic"]).list())
```

```
['nflow', 'rtvae', 'tvae', 'bayesian_network', 'ctgan']
```
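The `categories` filter can be pictured as a registry that maps each plugin name to a category. The sketch below is purely illustrative: `PLUGIN_CATEGORIES` and `list_plugins` are hypothetical names, the mapping is written by hand for a handful of the plugins listed above, and Synthcity's `Plugins` class discovers its plugins and their categories on its own.

```python
# Hypothetical, hand-written registry for illustration only; this is not how
# Synthcity's Plugins class is implemented.
PLUGIN_CATEGORIES = {
    "ctgan": "generic",
    "tvae": "generic",
    "nflow": "generic",
    "timegan": "time_series",
    "timevae": "time_series",
    "survival_gan": "survival_analysis",
    "survival_ctgan": "survival_analysis",
}


def list_plugins(categories=None):
    """Return sorted plugin names, keeping only the given categories if any."""
    if categories is None:
        return sorted(PLUGIN_CATEGORIES)
    wanted = set(categories)
    return sorted(
        name for name, category in PLUGIN_CATEGORIES.items() if category in wanted
    )


print(list_plugins())                         # every registered plugin
print(list_plugins(categories=["generic"]))   # only the static/tabular ones
print(list_plugins(categories=["time_series", "survival_analysis"]))
```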


Don't worry about the code in the next block; we will cover synthetic data generation in detail in the case studies to come. It is here only to show that our dataset can be used to generate synthetic data with Synthcity.

```python
syn_model = Plugins().get("marginal_distributions")
syn_model.fit(loader)
syn_model.generate(count=10).dataframe()
```

| | age | sex | bmi | bp | s1 | s2 | s3 | s4 | s5 | s6 | target |
|---|---|---|---|---|---|---|---|---|---|---|---|
| 0 | 0.01239 | -0.044642 | 0.052872 | 0.021754 | 0.027268 | 0.056934 | 0.053274 | 0.067191 | 0.016427 | 0.012267 | 201.169135 |
| 1 | 0.048652 | 0.05068 | 0.096268 | 0.062424 | 0.073969 | 0.109243 | 0.100439 | 0.11072 | 0.059634 | 0.057751 | 254.575787 |
| 2 | 0.024148 | 0.05068 | 0.066944 | 0.034942 | 0.042412 | 0.073896 | 0.068568 | 0.081306 | 0.030437 | 0.027016 | 218.487044 |
| 3 | 0.011533 | -0.044642 | 0.051847 | 0.020794 | 0.026165 | 0.055699 | 0.05216 | 0.066163 | 0.015406 | 0.011192 | 199.907502 |
| 4 | -0.014889 | 0.05068 | 0.020227 | -0.00884 | -0.007863 | 0.017584 | 0.017793 | 0.034446 | -0.016076 | -0.021949 | 160.993191 |
| 5 | 0.033548 | 0.05068 | 0.078194 | 0.045485 | 0.054518 | 0.087457 | 0.080795 | 0.09259 | 0.041638 | 0.038807 | 232.33201 |
| 6 | -0.011852 | 0.05068 | 0.023861 | -0.005434 | -0.003952 | 0.021965 | 0.021743 | 0.038091 | -0.012458 | -0.01814 | 165.465495 |
| 7 | 0.087138 | 0.05068 | 0.142326 | 0.105588 | 0.123535 | 0.164761 | 0.150498 | 0.156919 | 0.105491 | 0.106025 | 311.259133 |
| 8 | 0.102807 | 0.05068 | 0.161077 | 0.123161 | 0.143714 | 0.187364 | 0.170878 | 0.175728 | 0.124161 | 0.125678 | 334.335746 |
| 9 | -0.023654 | 0.05068 | 0.009738 | -0.018669 | -0.019151 | 0.004941 | 0.006393 | 0.023925 | -0.02652 | -0.032942 | 148.084728 |
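The `marginal_distributions` plugin is a debugging baseline. As a rough mental model of marginal-only sampling in general, the sketch below stores the observed values of each column and draws every column on its own through the empirical inverse CDF. It is not Synthcity's implementation: `fit_marginals` and `sample_marginals` are hypothetical helpers, and the five rows are the `bmi` and `target` values from the first rows of the diabetes table above.

```python
import random


def fit_marginals(rows):
    """Store the sorted observed values of each column (its empirical marginal)."""
    columns = rows[0].keys()
    return {col: sorted(row[col] for row in rows) for col in columns}


def sample_marginals(marginals, count, seed=0):
    """Draw each column independently using the empirical inverse CDF."""
    rng = random.Random(seed)
    samples = []
    for _ in range(count):
        row = {}
        for col, values in marginals.items():
            u = rng.random()  # uniform draw in [0, 1)
            row[col] = values[int(u * len(values))]  # empirical quantile of u
        samples.append(row)
    return samples


# bmi and target of rows 0-4 of the diabetes dataset shown above
real = [
    {"bmi": 0.061696, "target": 151.0},
    {"bmi": -0.051474, "target": 75.0},
    {"bmi": 0.044451, "target": 141.0},
    {"bmi": -0.011595, "target": 206.0},
    {"bmi": -0.036385, "target": 135.0},
]

marginals = fit_marginals(real)
synthetic = sample_marginals(marginals, count=10, seed=0)
for row in synthetic:
    print(row)
```

Because each column is drawn separately here, every synthetic value is plausible for its own column, but relationships between columns (for example between `bmi` and `target`) are not learned, which is presumably why this kind of sampler suits pipeline checks rather than realistic data.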


### Regular Time Series

Now let's load a time series dataset and show that it is also compatible with Synthcity.

```python
from synthcity.utils.datasets.time_series.google_stocks import GoogleStocksDataloader

# Load the static, temporal and outcome components of the dataset
static_data, temporal_data, horizons, outcome = GoogleStocksDataloader().load()
loader = TimeSeriesDataLoader(
    temporal_data=temporal_data,
    observation_times=horizons,
    static_data=static_data,
    outcome=outcome,
)
print(loader.info())

# Fit the debugging plugin and display 10 synthetic samples
syn_model = Plugins().get("marginal_distributions", n_iter=1)
syn_model.fit(loader)
display(syn_model.generate(count=10).dataframe())
```

```
{'data_type': 'time_series', 'len': 500, 'static_features': [], 'temporal_features': ['Close', 'High', 'Low', 'Open', 'Volume'], 'outcome_features': ['Open_next'], 'outcome_len': 1.0, 'window_len': 10, 'sensitive_features': [], 'important_features': [], 'random_state': 0, 'train_size': 0.8, 'fill': nan, 'seq_static_features': [], 'seq_temporal_features': ['seq_temporal_Close', 'seq_temporal_High', 'seq_temporal_Low', 'seq_temporal_Open', 'seq_temporal_Volume'], 'seq_outcome_features': ['seq_out_Open_next'], 'seq_offset': 0, 'seq_id_feature': 'seq_id', 'seq_time_id_feature': 'seq_time_id', 'seq_features': ['seq_id', 'seq_time_id', 'seq_temporal_Close', 'seq_temporal_High', 'seq_temporal_Low', 'seq_temporal_Open', 'seq_temporal_Volume', 'seq_out_Open_next']}
```


| | seq_id | seq_time_id | seq_temporal_Close | seq_temporal_High | seq_temporal_Low | seq_temporal_Open | seq_temporal_Volume | seq_out_Open_next |
|---|---|---|---|---|---|---|---|---|
| 0 | 0 | 0.559068 | 0.548814 | 0.548814 | 0.548814 | 0.548814 | 0.548814 | 0.44478 |
| 1 | 0 | 0.721662 | 0.715189 | 0.715189 | 0.715189 | 0.715189 | 0.715189 | 0.44478 |
| 2 | 0 | 0.611791 | 0.602763 | 0.602763 | 0.602763 | 0.602763 | 0.602763 | 0.44478 |
| 3 | 0 | 0.555227 | 0.544883 | 0.544883 | 0.544883 | 0.544883 | 0.544883 | 0.44478 |
| 4 | 0 | 0.436754 | 0.423655 | 0.423655 | 0.423655 | 0.423655 | 0.423655 | 0.44478 |
| 5 | 0 | 0.653942 | 0.645894 | 0.645894 | 0.645894 | 0.645894 | 0.645894 | 0.44478 |
| 6 | 0 | 0.450369 | 0.437587 | 0.437587 | 0.437587 | 0.437587 | 0.437587 | 0.44478 |
| 7 | 0 | 0.894233 | 0.891773 | 0.891773 | 0.891773 | 0.891773 | 0.891773 | 0.44478 |
| 8 | 0 | 0.964489 | 0.963663 | 0.963663 | 0.963663 | 0.963663 | 0.963663 | 0.44478 |
| 9 | 0 | 0.397454 | 0.383442 | 0.383442 | 0.383442 | 0.383442 | 0.383442 | 0.44478 |
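The `loader.info()` output shows that the time series is presented in a long format: one row per observation, identified by `seq_id` and `seq_time_id`, with temporal features prefixed `seq_temporal_` and outcomes prefixed `seq_out_`. The sketch below reproduces that layout from nested toy data. The helper `to_long_format` and the toy values are hypothetical; the layout is read off the output above, not taken from Synthcity's source.

```python
def to_long_format(temporal, outcomes):
    """Flatten per-sequence observations into one row per (sequence, time step).

    temporal: list of sequences, each a list of (time, {feature: value}) pairs.
    outcomes: one {outcome_name: value} dict per sequence, repeated on every row.
    """
    rows = []
    for seq_id, (sequence, outcome) in enumerate(zip(temporal, outcomes)):
        for time, features in sequence:
            row = {"seq_id": seq_id, "seq_time_id": time}
            row.update({f"seq_temporal_{name}": v for name, v in features.items()})
            row.update({f"seq_out_{name}": v for name, v in outcome.items()})
            rows.append(row)
    return rows


# Toy data: two sequences of different lengths, one outcome per sequence
temporal = [
    [(0, {"Close": 0.50, "Open": 0.49}),
     (1, {"Close": 0.52, "Open": 0.50}),
     (2, {"Close": 0.51, "Open": 0.52})],
    [(0, {"Close": 0.60, "Open": 0.61}),
     (1, {"Close": 0.62, "Open": 0.60})],
]
outcomes = [{"Open_next": 0.53}, {"Open_next": 0.63}]

rows = to_long_format(temporal, outcomes)
for row in rows:
    print(row)
```

Nothing in this layout requires equally spaced times, since each row carries its own `seq_time_id`, so irregularly sampled sequences fit the same structure.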


### Composite Irregular Time Series Survival Analysis

This next dataset is a composite, irregular time series dataset formulated as a survival analysis problem. Even complex datasets such as this are compatible with Synthcity. From `loader.info()`, or from the columns of the generated data, we can see that it contains one static feature (`sex`) and 14 temporal features, which makes it a composite dataset. The `seq_time_id` field reflects the irregular time sampling, which comes from the values passed to the `observation_times` parameter of the `TimeSeriesSurvivalDataLoader`. Finally, the survival analysis formulation is indicated by the `time_to_event` and `event` outcome fields.
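The feature counts quoted above can be checked by grouping the column names of the generated dataframe (copied from the output of the next cell) by their prefixes. `group_columns` is a hypothetical helper, and the prefix convention is inferred from the printed columns rather than from Synthcity's documentation.

```python
# Column names copied from the generated PBC dataframe
COLUMNS = [
    "seq_id", "seq_time_id", "seq_static_sex",
    "seq_temporal_SGOT", "seq_temporal_age", "seq_temporal_albumin",
    "seq_temporal_alkaline", "seq_temporal_ascites", "seq_temporal_drug",
    "seq_temporal_edema", "seq_temporal_hepatomegaly", "seq_temporal_histologic",
    "seq_temporal_platelets", "seq_temporal_prothrombin", "seq_temporal_serBilir",
    "seq_temporal_serChol", "seq_temporal_spiders",
    "seq_out_time_to_event", "seq_out_event",
]

PREFIXES = {"static": "seq_static_", "temporal": "seq_temporal_", "outcome": "seq_out_"}


def group_columns(columns):
    """Group feature names by prefix; seq_id and seq_time_id match no prefix."""
    groups = {kind: [] for kind in PREFIXES}
    for col in columns:
        for kind, prefix in PREFIXES.items():
            if col.startswith(prefix):
                groups[kind].append(col[len(prefix):])
                break
    return groups


groups = group_columns(COLUMNS)
for kind, names in groups.items():
    print(kind, len(names), names)
```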

```python
import numpy as np
from synthcity.utils.datasets.time_series.pbc import PBCDataloader

# Load the static, temporal and survival outcome components of the PBC dataset
(
    static_surv,
    temporal_surv,
    temporal_surv_horizons,
    outcome_surv,
) = PBCDataloader().load()
T, E = outcome_surv  # time to event and event indicator

# Use the 25th, 50th and 75th percentiles of T as time horizons
horizons = [0.25, 0.5, 0.75]
time_horizons = np.quantile(T, horizons).tolist()

loader = TimeSeriesSurvivalDataLoader(
    temporal_data=temporal_surv,
    observation_times=temporal_surv_horizons,
    static_data=static_surv,
    T=T,
    E=E,
    time_horizons=time_horizons,
)

print(loader.info())

# Fit the debugging plugin and display 10 synthetic samples
syn_model = Plugins().get("marginal_distributions", n_iter=1)
syn_model.fit(loader)
display(syn_model.generate(count=10).dataframe())
```


| | seq_id | seq_time_id | seq_static_sex | seq_temporal_SGOT | seq_temporal_age | seq_temporal_albumin | seq_temporal_alkaline | seq_temporal_ascites | seq_temporal_drug | seq_temporal_edema | seq_temporal_hepatomegaly | seq_temporal_histologic | seq_temporal_platelets | seq_temporal_prothrombin | seq_temporal_serBilir | seq_temporal_serChol | seq_temporal_spiders | seq_out_time_to_event | seq_out_event |
|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
| 0 | 0 | 7.852376 | 0.0 | 6.904687 | 0.363324 | 3.050269 | 5.318721 | 0.0 | 0.0 | 0.0 | 0.0 | 3.0 | 3.426769 | 8.671106 | 3.513956 | 5.355897 | 0.0 | 7.852376 | 0.0 |
| 1 | 0 | 10.232037 | 0.0 | 9.448147 | 1.226531 | 5.313141 | 7.26833 | 1.0 | 1.0 | 2.0 | 1.0 | 0.0 | 5.078592 | 11.709433 | 4.780858 | 7.61486 | 1.0 | 7.852376 | 0.0 |
| 2 | 0 | 8.624017 | 0.0 | 7.729442 | 0.643232 | 3.78404 | 5.950911 | 0.0 | 1.0 | 0.0 | 0.0 | 2.0 | 3.962398 | 9.656329 | 3.924768 | 6.0884 | 0.0 | 7.852376 | 0.0 |
| 3 | 0 | 7.796161 | 0.0 | 6.844603 | 0.342933 | 2.996813 | 5.272665 | 1.0 | 0.0 | 2.0 | 1.0 | 3.0 | 3.387748 | 8.599331 | 3.484028 | 5.302533 | 1.0 | 7.852376 | 0.0 |
| 4 | 0 | 6.06224 | 0.0 | 4.991332 | -0.286036 | 1.34799 | 3.852098 | 1.0 | 1.0 | 2.0 | 1.0 | 0.0 | 2.184161 | 6.385479 | 2.560911 | 3.656559 | 1.0 | 7.852376 | 0.0 |
| 5 | 0 | 9.240912 | 0.0 | 8.388801 | 0.867007 | 4.370659 | 6.456321 | 2.0 | 1.0 | 1.0 | 2.0 | 0.0 | 4.390611 | 10.443975 | 4.253196 | 6.674006 | 2.0 | 7.852376 | 0.0 |
| 6 | 0 | 6.261515 | 0.0 | 5.204323 | -0.21375 | 1.537485 | 4.015359 | 0.0 | 1.0 | 0.0 | 0.0 | 0.0 | 2.322486 | 6.63991 | 2.667002 | 3.845726 | 0.0 | 7.852376 | 0.0 |
| 7 | 0 | 12.757699 | 0.0 | 12.147658 | 2.142699 | 7.714849 | 9.337556 | 2.0 | 1.0 | 1.0 | 2.0 | 0.0 | 6.83176 | 14.934173 | 6.125489 | 10.012418 | 2.0 | 7.852376 | 0.0 |
| 8 | 0 | 13.785933 | 0.0 | 13.246668 | 2.515684 | 8.692619 | 10.179967 | 0.0 | 1.0 | 0.0 | 0.0 | 2.0 | 7.545501 | 16.247012 | 6.672908 | 10.988498 | 0.0 | 7.852376 | 0.0 |
| 9 | 0 | 5.487073 | 0.0 | 4.376574 | -0.494674 | 0.801051 | 3.380874 | 0.0 | 1.0 | 0.0 | 0.0 | 0.0 | 1.784913 | 5.651111 | 2.254699 | 3.110565 | 0.0 | 7.852376 | 0.0 |
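The `time_horizons` passed to `TimeSeriesSurvivalDataLoader` are the 25th, 50th and 75th percentiles of the event times `T`, computed with `np.quantile` and its default linear interpolation. The pure-Python sketch below spells out that interpolation rule on toy times; `linear_quantile` is a hypothetical helper, not part of Synthcity or NumPy.

```python
import math


def linear_quantile(values, q):
    """Quantile by linear interpolation between order statistics
    (the rule used by NumPy's default method="linear")."""
    xs = sorted(values)
    pos = (len(xs) - 1) * q          # fractional index into the sorted data
    lo = math.floor(pos)
    hi = min(lo + 1, len(xs) - 1)
    return xs[lo] + (pos - lo) * (xs[hi] - xs[lo])


# Toy event times, for illustration only
T = [1.0, 2.0, 4.0, 8.0]
horizons = [0.25, 0.5, 0.75]
time_horizons = [linear_quantile(T, q) for q in horizons]
print(time_horizons)  # [1.75, 3.0, 5.0]
```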


### Create synthetic datasets
 - Above we generated data with the debugging plugin `"marginal_distributions"`. Now, using `Plugins().list()` or the documentation, find another plugin that is compatible with one of the datasets and see if you can generate your own synthetic data.