In [13]:
import pandas as pd
import numpy as np
import os

from fairlearn.datasets import fetch_diabetes_hospital
from sklearn.model_selection import train_test_split
from sdv.metadata import Metadata
from sdv.single_table import GaussianCopulaSynthesizer
from sdv.evaluation.single_table import run_diagnostic, evaluate_quality
from sdv.evaluation.single_table import get_column_plot

In [2]:
data = fetch_diabetes_hospital(as_frame=True)

X = data.data.copy()
y = data.target.copy()

X.shape, y.shape

((101766, 24), (101766,))

In [3]:
X.head()
list(X.columns)

['race',
 'gender',
 'age',
 'discharge_disposition_id',
 'admission_source_id',
 'time_in_hospital',
 'medical_specialty',
 'num_lab_procedures',
 'num_procedures',
 'num_medications',
 'primary_diagnosis',
 'number_diagnoses',
 'max_glu_serum',
 'A1Cresult',
 'insulin',
 'change',
 'diabetesMed',
 'medicare',
 'medicaid',
 'had_emergency',
 'had_inpatient_days',
 'had_outpatient_days',
 'readmitted',
 'readmit_binary']

In [4]:
dropped_columns = ['readmitted', 'readmit_binary']
X = X.drop(columns=dropped_columns)

real_data = X.copy()
real_data['readmit_binary'] = (y == 1)
real_data.shape, real_data['readmit_binary'].dtype

((101766, 23), dtype('bool'))

In [5]:
real_train, real_test = train_test_split(
    real_data,
    test_size=0.2,
    random_state=66,
    stratify=real_data['readmit_binary']
)

real_train = real_train.reset_index(drop=True)
real_test = real_test.reset_index(drop=True)

real_train.shape, real_test.shape

((81412, 23), (20354, 23))

In [6]:
metadata = Metadata.detect_from_dataframe(
    data=real_train,
    table_name='diabetes'
)

metadata.validate()

In [7]:
meta_dict = metadata.to_dict()
columns = list(meta_dict['tables']['diabetes']['columns'].keys())
for col in columns[:10]:
    print(col, "->", meta_dict['tables']['diabetes']['columns'][col].get('sdtype'))

race -> categorical
gender -> categorical
age -> categorical
discharge_disposition_id -> id
admission_source_id -> id
time_in_hospital -> numerical
medical_specialty -> categorical
num_lab_procedures -> numerical
num_procedures -> categorical
num_medications -> numerical


In [8]:
sensitive_attributes = ['race', 'gender']
for col in sensitive_attributes:
    if col in real_train.columns:
        metadata.update_column(column_name=col, sdtype='categorical')
metadata.update_column(column_name='readmit_binary', sdtype='boolean')
metadata.validate()

In [9]:
synthesizer = GaussianCopulaSynthesizer(
    metadata,
    enforce_min_max_values=True,
    enforce_rounding=True
)

synthesizer.fit(real_train)

  data = data.fillna(sentinel).replace({sentinel: None})
  data = data.fillna(sentinel).replace({sentinel: None})
  data = data.fillna(sentinel).replace({sentinel: None})
  data = data.fillna(sentinel).replace({sentinel: None})
  data = data.fillna(sentinel).replace({sentinel: None})
  data = data.fillna(sentinel).replace({sentinel: None})
  data = data.fillna(sentinel).replace({sentinel: None})
  data = data.fillna(sentinel).replace({sentinel: None})
  data = data.fillna(sentinel).replace({sentinel: None})
  data = data.fillna(sentinel).replace({sentinel: None})
  data = data.fillna(sentinel).replace({sentinel: None})
  data = data.fillna(sentinel).replace({sentinel: None})
  data = data.fillna(sentinel).replace({sentinel: None})
  data = data.fillna(sentinel).replace({sentinel: None})
  data = data.fillna(sentinel).replace({sentinel: None})


In [14]:
synthetic_train = synthesizer.sample(num_rows=len(real_train))
synthetic_train.head()

Unnamed: 0,race,gender,age,discharge_disposition_id,admission_source_id,time_in_hospital,medical_specialty,num_lab_procedures,num_procedures,num_medications,...,A1Cresult,insulin,change,diabetesMed,medicare,medicaid,had_emergency,had_inpatient_days,had_outpatient_days,readmit_binary
0,Caucasian,Female,'30-60 years','Discharged to Home',Emergency,2,Other,59,1,27,...,,Steady,No,No,False,False,False,True,True,False
1,Caucasian,Female,'Over 60 years',Other,Emergency,5,Missing,15,1,12,...,,Up,Ch,No,True,False,True,False,False,True
2,Caucasian,Female,'Over 60 years','Discharged to Home',Emergency,9,Missing,55,0,33,...,>8,Down,Ch,Yes,True,False,True,True,True,False
3,Caucasian,Female,'Over 60 years','Discharged to Home',Other,11,Family/GeneralPractice,56,4,29,...,,No,Ch,Yes,True,False,True,True,False,False
4,AfricanAmerican,Male,'Over 60 years','Discharged to Home',Emergency,3,Missing,14,0,5,...,,Steady,Ch,Yes,False,False,True,True,False,False


In [15]:
diagnostic = run_diagnostic(
    real_data=real_train,
    synthetic_data=synthetic_train,
    metadata=metadata
)
diagnostic.get_score()

Generating report ...

(1/2) Evaluating Data Validity: |██████████| 23/23 [00:00<00:00, 56.16it/s]|
Data Validity Score: 100.0%

(2/2) Evaluating Data Structure: |██████████| 1/1 [00:00<00:00, 274.12it/s]|
Data Structure Score: 100.0%

Overall Score (Average): 100.0%



np.float64(1.0)

In [16]:
quality_report = evaluate_quality(
    real_data=real_train,
    synthetic_data=synthetic_train,
    metadata=metadata
)
quality_report.get_score

Generating report ...

(1/2) Evaluating Column Shapes: |██████████| 23/23 [00:01<00:00, 14.93it/s]|
Column Shapes Score: 93.16%

(2/2) Evaluating Column Pair Trends: |▊         | 21/253 [00:00<00:09, 25.41it/s]|

  contingency_real_counts = real.groupby(list(columns), dropna=False).size()
  contingency_real_counts = real.groupby(list(columns), dropna=False).size()
  contingency_real_counts = real.groupby(list(columns), dropna=False).size()
  contingency_real_counts = real.groupby(list(columns), dropna=False).size()
  contingency_real_counts = real.groupby(list(columns), dropna=False).size()


(2/2) Evaluating Column Pair Trends: |█▋        | 42/253 [00:01<00:07, 26.65it/s]|

  contingency_real_counts = real.groupby(list(columns), dropna=False).size()
  contingency_real_counts = real.groupby(list(columns), dropna=False).size()
  contingency_real_counts = real.groupby(list(columns), dropna=False).size()
  contingency_real_counts = real.groupby(list(columns), dropna=False).size()
  contingency_real_counts = real.groupby(list(columns), dropna=False).size()


(2/2) Evaluating Column Pair Trends: |██▍       | 61/253 [00:02<00:07, 25.74it/s]|

  contingency_real_counts = real.groupby(list(columns), dropna=False).size()
  contingency_synthetic = synthetic.groupby(list(columns), dropna=False).size() / len(
  contingency_real_counts = real.groupby(list(columns), dropna=False).size()
  contingency_real_counts = real.groupby(list(columns), dropna=False).size()
  contingency_real_counts = real.groupby(list(columns), dropna=False).size()
  contingency_real_counts = real.groupby(list(columns), dropna=False).size()


(2/2) Evaluating Column Pair Trends: |████▌     | 115/253 [00:02<00:01, 72.14it/s]| 

  contingency_real_counts = real.groupby(list(columns), dropna=False).size()
  contingency_real_counts = real.groupby(list(columns), dropna=False).size()
  contingency_real_counts = real.groupby(list(columns), dropna=False).size()
  contingency_real_counts = real.groupby(list(columns), dropna=False).size()
  contingency_real_counts = real.groupby(list(columns), dropna=False).size()


(2/2) Evaluating Column Pair Trends: |████▉     | 126/253 [00:03<00:02, 51.06it/s]|

  contingency_real_counts = real.groupby(list(columns), dropna=False).size()
  contingency_real_counts = real.groupby(list(columns), dropna=False).size()
  contingency_real_counts = real.groupby(list(columns), dropna=False).size()
  contingency_real_counts = real.groupby(list(columns), dropna=False).size()
  contingency_real_counts = real.groupby(list(columns), dropna=False).size()


(2/2) Evaluating Column Pair Trends: |█████▊    | 148/253 [00:03<00:02, 41.97it/s]|

  contingency_real_counts = real.groupby(list(columns), dropna=False).size()
  contingency_real_counts = real.groupby(list(columns), dropna=False).size()
  contingency_real_counts = real.groupby(list(columns), dropna=False).size()
  contingency_real_counts = real.groupby(list(columns), dropna=False).size()
  contingency_real_counts = real.groupby(list(columns), dropna=False).size()


(2/2) Evaluating Column Pair Trends: |██████▍   | 164/253 [00:04<00:02, 39.78it/s]|

  contingency_real_counts = real.groupby(list(columns), dropna=False).size()
  contingency_real_counts = real.groupby(list(columns), dropna=False).size()
  contingency_real_counts = real.groupby(list(columns), dropna=False).size()
  contingency_real_counts = real.groupby(list(columns), dropna=False).size()
  contingency_real_counts = real.groupby(list(columns), dropna=False).size()


(2/2) Evaluating Column Pair Trends: |██████▉   | 174/253 [00:04<00:02, 38.87it/s]|

  contingency_real_counts = real.groupby(list(columns), dropna=False).size()
  contingency_real_counts = real.groupby(list(columns), dropna=False).size()
  contingency_real_counts = real.groupby(list(columns), dropna=False).size()
  contingency_real_counts = real.groupby(list(columns), dropna=False).size()
  contingency_real_counts = real.groupby(list(columns), dropna=False).size()


(2/2) Evaluating Column Pair Trends: |███████▍  | 187/253 [00:05<00:02, 30.17it/s]|

  contingency_real_counts = real.groupby(list(columns), dropna=False).size()
  contingency_real_counts = real.groupby(list(columns), dropna=False).size()
  contingency_real_counts = real.groupby(list(columns), dropna=False).size()
  contingency_real_counts = real.groupby(list(columns), dropna=False).size()
  contingency_real_counts = real.groupby(list(columns), dropna=False).size()


(2/2) Evaluating Column Pair Trends: |███████▋  | 196/253 [00:05<00:01, 33.08it/s]|

  contingency_real_counts = real.groupby(list(columns), dropna=False).size()
  contingency_real_counts = real.groupby(list(columns), dropna=False).size()
  contingency_real_counts = real.groupby(list(columns), dropna=False).size()
  contingency_real_counts = real.groupby(list(columns), dropna=False).size()
  contingency_real_counts = real.groupby(list(columns), dropna=False).size()


(2/2) Evaluating Column Pair Trends: |████████▏ | 208/253 [00:05<00:01, 30.02it/s]|

  contingency_real_counts = real.groupby(list(columns), dropna=False).size()
  contingency_real_counts = real.groupby(list(columns), dropna=False).size()
  contingency_real_counts = real.groupby(list(columns), dropna=False).size()
  contingency_real_counts = real.groupby(list(columns), dropna=False).size()
  contingency_real_counts = real.groupby(list(columns), dropna=False).size()


(2/2) Evaluating Column Pair Trends: |████████▍ | 215/253 [00:06<00:01, 28.23it/s]|

  contingency_real_counts = real.groupby(list(columns), dropna=False).size()
  contingency_real_counts = real.groupby(list(columns), dropna=False).size()
  contingency_real_counts = real.groupby(list(columns), dropna=False).size()
  contingency_real_counts = real.groupby(list(columns), dropna=False).size()
  contingency_real_counts = real.groupby(list(columns), dropna=False).size()


(2/2) Evaluating Column Pair Trends: |████████▊ | 224/253 [00:06<00:01, 25.71it/s]|

  contingency_real_counts = real.groupby(list(columns), dropna=False).size()
  contingency_real_counts = real.groupby(list(columns), dropna=False).size()
  contingency_real_counts = real.groupby(list(columns), dropna=False).size()
  contingency_real_counts = real.groupby(list(columns), dropna=False).size()
  contingency_real_counts = real.groupby(list(columns), dropna=False).size()


(2/2) Evaluating Column Pair Trends: |█████████▏| 231/253 [00:06<00:00, 26.26it/s]|

  contingency_real_counts = real.groupby(list(columns), dropna=False).size()
  contingency_real_counts = real.groupby(list(columns), dropna=False).size()
  contingency_real_counts = real.groupby(list(columns), dropna=False).size()
  contingency_real_counts = real.groupby(list(columns), dropna=False).size()
  contingency_real_counts = real.groupby(list(columns), dropna=False).size()
  contingency_real_counts = real.groupby(list(columns), dropna=False).size()


(2/2) Evaluating Column Pair Trends: |█████████▎| 237/253 [00:06<00:00, 27.83it/s]|

  contingency_real_counts = real.groupby(list(columns), dropna=False).size()
  contingency_real_counts = real.groupby(list(columns), dropna=False).size()
  contingency_real_counts = real.groupby(list(columns), dropna=False).size()
  contingency_real_counts = real.groupby(list(columns), dropna=False).size()
  contingency_real_counts = real.groupby(list(columns), dropna=False).size()
  contingency_real_counts = real.groupby(list(columns), dropna=False).size()
  contingency_real_counts = real.groupby(list(columns), dropna=False).size()


(2/2) Evaluating Column Pair Trends: |█████████▊| 247/253 [00:07<00:00, 37.48it/s]|

  contingency_real_counts = real.groupby(list(columns), dropna=False).size()
  contingency_real_counts = real.groupby(list(columns), dropna=False).size()
  contingency_real_counts = real.groupby(list(columns), dropna=False).size()
  contingency_real_counts = real.groupby(list(columns), dropna=False).size()
  contingency_real_counts = real.groupby(list(columns), dropna=False).size()
  contingency_real_counts = real.groupby(list(columns), dropna=False).size()
  contingency_real_counts = real.groupby(list(columns), dropna=False).size()
  contingency_real_counts = real.groupby(list(columns), dropna=False).size()
  contingency_real_counts = real.groupby(list(columns), dropna=False).size()
  contingency_real_counts = real.groupby(list(columns), dropna=False).size()
  contingency_real_counts = real.groupby(list(columns), dropna=False).size()


(2/2) Evaluating Column Pair Trends: |██████████| 253/253 [00:07<00:00, 35.09it/s]|
Column Pair Trends Score: 81.62%

Overall Score (Average): 87.39%



  contingency_real_counts = real.groupby(list(columns), dropna=False).size()


<bound method BaseReport.get_score of <sdmetrics.reports.single_table.quality_report.QualityReport object at 0x7f253908c850>>

In [17]:
fig = get_column_plot(
    real_data=real_train,
    synthetic_data=synthetic_train,
    metadata=metadata,
    column_name='race'
)
fig.show()

In [21]:
os.makedirs('../artifacts', exist_ok=True)

metadata.save_to_json('../artifacts/diabetes_metadata.json')
synthesizer.save('../artifacts/gaussian_copula_diabetes.pkl')