In [1]:
import pandas as pd
import numpy as np
import os

from fairlearn.datasets import fetch_diabetes_hospital
from sklearn.model_selection import train_test_split
from sdv.metadata import Metadata
from sdv.single_table import GaussianCopulaSynthesizer
from sdv.single_table import CTGANSynthesizer
from sdv.single_table import TVAESynthesizer
from sdv.evaluation.single_table import run_diagnostic, evaluate_quality
from sdv.evaluation.single_table import get_column_plot

In [2]:
data = fetch_diabetes_hospital(as_frame=True)

X = data.data.copy()
y = data.target.copy()

X.shape, y.shape

((101766, 24), (101766,))

In [3]:
X.head()
list(X.columns)

['race',
 'gender',
 'age',
 'discharge_disposition_id',
 'admission_source_id',
 'time_in_hospital',
 'medical_specialty',
 'num_lab_procedures',
 'num_procedures',
 'num_medications',
 'primary_diagnosis',
 'number_diagnoses',
 'max_glu_serum',
 'A1Cresult',
 'insulin',
 'change',
 'diabetesMed',
 'medicare',
 'medicaid',
 'had_emergency',
 'had_inpatient_days',
 'had_outpatient_days',
 'readmitted',
 'readmit_binary']

In [4]:
dropped_columns = ['readmitted', 'readmit_binary']
X = X.drop(columns=dropped_columns)

real_data = X.copy()
real_data['readmit_binary'] = (y == 1)
real_data['readmit_binary'] = real_data['readmit_binary'].astype(bool)
real_data.shape, real_data['readmit_binary'].dtype

((101766, 23), dtype('bool'))

In [5]:
real_train, real_test = train_test_split(
    real_data,
    test_size=0.2,
    random_state=66,
    stratify=real_data['readmit_binary']
)

real_train = real_train.reset_index(drop=True)
real_test = real_test.reset_index(drop=True)

real_train.shape, real_test.shape

((81412, 23), (20354, 23))

In [6]:
cat_cols = real_train.select_dtypes(include=['category']).columns
real_train[cat_cols] = real_train[cat_cols].astype('object')
real_test[cat_cols] = real_test[cat_cols].astype('object')

real_train['readmit_binary'] = real_train['readmit_binary'].astype(bool)
real_test['readmit_binary'] = real_test['readmit_binary'].astype(bool)

In [7]:
metadata = Metadata.detect_from_dataframe(
    data=real_train,
    table_name='diabetes'
)

In [8]:
meta_dict = metadata.to_dict()
columns = list(meta_dict['tables']['diabetes']['columns'].keys())
for col in columns[:]:
    print(col, "->", meta_dict['tables']['diabetes']['columns'][col].get('sdtype'))

race -> categorical
gender -> categorical
age -> categorical
discharge_disposition_id -> id
admission_source_id -> id
time_in_hospital -> numerical
medical_specialty -> categorical
num_lab_procedures -> numerical
num_procedures -> categorical
num_medications -> numerical
primary_diagnosis -> categorical
number_diagnoses -> numerical
max_glu_serum -> categorical
A1Cresult -> categorical
insulin -> categorical
change -> categorical
diabetesMed -> categorical
medicare -> categorical
medicaid -> categorical
had_emergency -> categorical
had_inpatient_days -> categorical
had_outpatient_days -> categorical
readmit_binary -> categorical


In [9]:
sensitive_attributes = ['race', 'gender']
for col in sensitive_attributes:
    if col in real_train.columns:
        metadata.update_column(column_name=col, sdtype='categorical')
metadata.update_column(column_name='readmit_binary', sdtype='boolean')
metadata.validate()

In [None]:
baseline_models = {
    # to correct the name
    "gaussian_copuula" : GaussianCopulaSynthesizer(
        metadata,
        enforce_min_max_values=True,
        enforce_rounding=True,
    ),
    "ctgan" : CTGANSynthesizer(
        metadata,
        epochs=500,
        verbose=True,
        enforce_rounding=False,
    ),
    "tvae" : TVAESynthesizer(
        metadata,
        epochs=500,
        verbose=True,
        enforce_rounding=False,
    ),
}



In [None]:
synthetic_train = {}
for name, model in baseline_models.items():
    print(f"\n --- training {name} ---")
    model.fit(real_train)

    synthetic_train[name] = model.sample(num_rows=len(real_train))

    os.makedirs("../artifacts", exist_ok=True)
    model.save(f"../artifacts/{name}_diabetes.pkl")
    print(f"Saved: ../artifacts/{name}_diabetes.pkl")

metadata.save_to_json('../artifacts/diabetes_metadata.json')


 --- training gaussian_copuula ---
Saved: ../artifacts/gaussian_copuula_diabetes.pkl

 --- training ctgan ---


Gen. (-03.12) | Discrim. (-00.13):   4%|▍         | 20/500 [1:08:06<26:58:09, 202.27s/it]

In [None]:
name = "ctgan"

diagnostic = run_diagnostic(
    real_data=real_train,
    synthetic_data=synthetic_train[name],
    metadata=metadata
)
diagnostic.get_score()

Unnamed: 0,race,gender,age,discharge_disposition_id,admission_source_id,time_in_hospital,medical_specialty,num_lab_procedures,num_procedures,num_medications,...,A1Cresult,insulin,change,diabetesMed,medicare,medicaid,had_emergency,had_inpatient_days,had_outpatient_days,readmit_binary
0,Caucasian,Female,'Over 60 years',Other,Emergency,5,InternalMedicine,50,1,32,...,>8,No,Ch,Yes,False,False,True,True,False,False
1,Caucasian,Male,'Over 60 years','Discharged to Home',Emergency,2,Missing,62,0,8,...,,No,No,Yes,True,False,True,True,True,False
2,Caucasian,Male,'Over 60 years',Other,Other,5,Other,41,2,24,...,,No,No,Yes,False,False,True,True,False,False
3,Caucasian,Male,'Over 60 years','Discharged to Home',Emergency,3,Missing,20,3,22,...,,Steady,Ch,Yes,False,False,True,False,False,False
4,Caucasian,Female,'30-60 years','Discharged to Home',Referral,5,Missing,14,0,12,...,,No,Ch,Yes,True,False,False,False,False,True


In [None]:
quality_report = evaluate_quality(
    real_data=real_train,
    synthetic_data=synthetic_train[name],
    metadata=metadata
)
quality_report.get_score

Generating report ...

(1/2) Evaluating Column Shapes: |██████████| 23/23 [00:01<00:00, 15.06it/s]|
Column Shapes Score: 93.15%

(2/2) Evaluating Column Pair Trends: |▊         | 21/253 [00:00<00:09, 25.39it/s]|

  contingency_real_counts = real.groupby(list(columns), dropna=False).size()
  contingency_real_counts = real.groupby(list(columns), dropna=False).size()
  contingency_real_counts = real.groupby(list(columns), dropna=False).size()
  contingency_real_counts = real.groupby(list(columns), dropna=False).size()
  contingency_real_counts = real.groupby(list(columns), dropna=False).size()


(2/2) Evaluating Column Pair Trends: |█▋        | 42/253 [00:01<00:07, 26.63it/s]|

  contingency_real_counts = real.groupby(list(columns), dropna=False).size()
  contingency_real_counts = real.groupby(list(columns), dropna=False).size()
  contingency_real_counts = real.groupby(list(columns), dropna=False).size()
  contingency_real_counts = real.groupby(list(columns), dropna=False).size()
  contingency_real_counts = real.groupby(list(columns), dropna=False).size()


(2/2) Evaluating Column Pair Trends: |██▎       | 60/253 [00:02<00:07, 24.86it/s]|

  contingency_real_counts = real.groupby(list(columns), dropna=False).size()
  contingency_synthetic = synthetic.groupby(list(columns), dropna=False).size() / len(
  contingency_real_counts = real.groupby(list(columns), dropna=False).size()
  contingency_real_counts = real.groupby(list(columns), dropna=False).size()
  contingency_real_counts = real.groupby(list(columns), dropna=False).size()
  contingency_real_counts = real.groupby(list(columns), dropna=False).size()


(2/2) Evaluating Column Pair Trends: |████▌     | 117/253 [00:02<00:01, 76.05it/s]| 

  contingency_real_counts = real.groupby(list(columns), dropna=False).size()
  contingency_real_counts = real.groupby(list(columns), dropna=False).size()
  contingency_real_counts = real.groupby(list(columns), dropna=False).size()
  contingency_real_counts = real.groupby(list(columns), dropna=False).size()
  contingency_real_counts = real.groupby(list(columns), dropna=False).size()


(2/2) Evaluating Column Pair Trends: |█████     | 127/253 [00:03<00:02, 50.17it/s]|

  contingency_real_counts = real.groupby(list(columns), dropna=False).size()
  contingency_real_counts = real.groupby(list(columns), dropna=False).size()
  contingency_real_counts = real.groupby(list(columns), dropna=False).size()
  contingency_real_counts = real.groupby(list(columns), dropna=False).size()
  contingency_real_counts = real.groupby(list(columns), dropna=False).size()


(2/2) Evaluating Column Pair Trends: |█████▊    | 147/253 [00:03<00:02, 40.99it/s]|

  contingency_real_counts = real.groupby(list(columns), dropna=False).size()
  contingency_real_counts = real.groupby(list(columns), dropna=False).size()
  contingency_real_counts = real.groupby(list(columns), dropna=False).size()
  contingency_real_counts = real.groupby(list(columns), dropna=False).size()
  contingency_real_counts = real.groupby(list(columns), dropna=False).size()


(2/2) Evaluating Column Pair Trends: |██████▍   | 162/253 [00:04<00:02, 39.68it/s]|

  contingency_real_counts = real.groupby(list(columns), dropna=False).size()
  contingency_real_counts = real.groupby(list(columns), dropna=False).size()
  contingency_real_counts = real.groupby(list(columns), dropna=False).size()
  contingency_real_counts = real.groupby(list(columns), dropna=False).size()
  contingency_real_counts = real.groupby(list(columns), dropna=False).size()


(2/2) Evaluating Column Pair Trends: |██████▉   | 176/253 [00:04<00:02, 38.20it/s]|

  contingency_real_counts = real.groupby(list(columns), dropna=False).size()
  contingency_real_counts = real.groupby(list(columns), dropna=False).size()
  contingency_real_counts = real.groupby(list(columns), dropna=False).size()
  contingency_real_counts = real.groupby(list(columns), dropna=False).size()
  contingency_real_counts = real.groupby(list(columns), dropna=False).size()


(2/2) Evaluating Column Pair Trends: |███████▎  | 184/253 [00:05<00:02, 30.13it/s]|

  contingency_real_counts = real.groupby(list(columns), dropna=False).size()
  contingency_real_counts = real.groupby(list(columns), dropna=False).size()
  contingency_real_counts = real.groupby(list(columns), dropna=False).size()
  contingency_real_counts = real.groupby(list(columns), dropna=False).size()
  contingency_real_counts = real.groupby(list(columns), dropna=False).size()


(2/2) Evaluating Column Pair Trends: |███████▊  | 197/253 [00:05<00:01, 33.60it/s]|

  contingency_real_counts = real.groupby(list(columns), dropna=False).size()
  contingency_real_counts = real.groupby(list(columns), dropna=False).size()
  contingency_real_counts = real.groupby(list(columns), dropna=False).size()
  contingency_real_counts = real.groupby(list(columns), dropna=False).size()
  contingency_real_counts = real.groupby(list(columns), dropna=False).size()


(2/2) Evaluating Column Pair Trends: |████████  | 205/253 [00:05<00:01, 29.74it/s]|

  contingency_real_counts = real.groupby(list(columns), dropna=False).size()
  contingency_real_counts = real.groupby(list(columns), dropna=False).size()
  contingency_real_counts = real.groupby(list(columns), dropna=False).size()
  contingency_real_counts = real.groupby(list(columns), dropna=False).size()
  contingency_real_counts = real.groupby(list(columns), dropna=False).size()


(2/2) Evaluating Column Pair Trends: |████████▍ | 215/253 [00:06<00:01, 27.97it/s]|

  contingency_real_counts = real.groupby(list(columns), dropna=False).size()
  contingency_real_counts = real.groupby(list(columns), dropna=False).size()
  contingency_real_counts = real.groupby(list(columns), dropna=False).size()
  contingency_real_counts = real.groupby(list(columns), dropna=False).size()
  contingency_real_counts = real.groupby(list(columns), dropna=False).size()


(2/2) Evaluating Column Pair Trends: |████████▊ | 224/253 [00:06<00:01, 25.73it/s]|

  contingency_real_counts = real.groupby(list(columns), dropna=False).size()
  contingency_real_counts = real.groupby(list(columns), dropna=False).size()
  contingency_real_counts = real.groupby(list(columns), dropna=False).size()
  contingency_real_counts = real.groupby(list(columns), dropna=False).size()
  contingency_real_counts = real.groupby(list(columns), dropna=False).size()


(2/2) Evaluating Column Pair Trends: |█████████ | 230/253 [00:06<00:00, 25.85it/s]|

  contingency_real_counts = real.groupby(list(columns), dropna=False).size()
  contingency_real_counts = real.groupby(list(columns), dropna=False).size()
  contingency_real_counts = real.groupby(list(columns), dropna=False).size()
  contingency_real_counts = real.groupby(list(columns), dropna=False).size()
  contingency_real_counts = real.groupby(list(columns), dropna=False).size()
  contingency_real_counts = real.groupby(list(columns), dropna=False).size()


(2/2) Evaluating Column Pair Trends: |█████████▍| 240/253 [00:07<00:00, 30.36it/s]|

  contingency_real_counts = real.groupby(list(columns), dropna=False).size()
  contingency_real_counts = real.groupby(list(columns), dropna=False).size()
  contingency_real_counts = real.groupby(list(columns), dropna=False).size()
  contingency_real_counts = real.groupby(list(columns), dropna=False).size()
  contingency_real_counts = real.groupby(list(columns), dropna=False).size()
  contingency_real_counts = real.groupby(list(columns), dropna=False).size()
  contingency_real_counts = real.groupby(list(columns), dropna=False).size()


(2/2) Evaluating Column Pair Trends: |█████████▉| 251/253 [00:07<00:00, 40.48it/s]|

  contingency_real_counts = real.groupby(list(columns), dropna=False).size()
  contingency_real_counts = real.groupby(list(columns), dropna=False).size()
  contingency_real_counts = real.groupby(list(columns), dropna=False).size()
  contingency_real_counts = real.groupby(list(columns), dropna=False).size()
  contingency_real_counts = real.groupby(list(columns), dropna=False).size()
  contingency_real_counts = real.groupby(list(columns), dropna=False).size()
  contingency_real_counts = real.groupby(list(columns), dropna=False).size()
  contingency_real_counts = real.groupby(list(columns), dropna=False).size()
  contingency_real_counts = real.groupby(list(columns), dropna=False).size()
  contingency_real_counts = real.groupby(list(columns), dropna=False).size()
  contingency_real_counts = real.groupby(list(columns), dropna=False).size()


(2/2) Evaluating Column Pair Trends: |██████████| 253/253 [00:07<00:00, 34.77it/s]|
Column Pair Trends Score: 81.71%

Overall Score (Average): 87.43%



  contingency_real_counts = real.groupby(list(columns), dropna=False).size()


<bound method BaseReport.get_score of <sdmetrics.reports.single_table.quality_report.QualityReport object at 0x7f5649076d50>>

In [13]:
fig = get_column_plot(
    real_data=real_train,
    synthetic_data=synthetic_train,
    metadata=metadata,
    column_name='race'
)
fig.show()

In [14]:
os.makedirs('../artifacts', exist_ok=True)

metadata.save_to_json('../artifacts/diabetes_metadata.json')
synthesizer.save('../artifacts/gaussian_copula_diabetes.pkl')

ValueError: A file named 'diabetes_metadata.json' already exists in this folder. Please specify a different filename.

In [15]:
def pct(series):
    return (series.value_counts(normalize=True) * 100).round(2)

print("Real race %:\n", pct(real_train['race']).head(10))
print("\nSynthetic race %:\n", pct(synthetic_train['race']).head(10))

print("Real gender %:\n", pct(real_train['gender']).head(10))
print("\nSynthetic gender %:\n", pct(synthetic_train['gender']).head(10))

Real race %:
 race
Caucasian          74.88
AfricanAmerican    18.74
Unknown             2.29
Hispanic            2.00
Other               1.48
Asian               0.61
Name: proportion, dtype: float64

Synthetic race %:
 race
Caucasian          75.10
AfricanAmerican    18.52
Unknown             2.34
Hispanic            2.03
Other               1.41
Asian               0.61
Name: proportion, dtype: float64
Real gender %:
 gender
Female             53.76
Male               46.24
Unknown/Invalid     0.00
Name: proportion, dtype: float64

Synthetic gender %:
 gender
Female             53.82
Male               46.18
Unknown/Invalid     0.00
Name: proportion, dtype: float64


In [12]:
ctgen = CTGANSynthesizer(
    metadata,
    epochs=500,
    verbose=True,
    enforce_rounding=False
)
ctgen.fit(real_train)



InvalidDataTypeError: Columns ['medicare', 'medicaid', 'had_emergency', 'had_inpatient_days', 'had_outpatient_days'] are stored as a 'category' type, which is not supported. Please cast these columns to an 'object' to continue.