In [12]:
# !pip install sdv fairlearn
import pandas as pd
import numpy as np
import os
from pathlib import Path
import pickle

from fairlearn.datasets import fetch_diabetes_hospital
from sklearn.model_selection import train_test_split
from sdv.metadata import Metadata
from sdv.single_table import GaussianCopulaSynthesizer
from sdv.single_table import CTGANSynthesizer
from sdv.single_table import TVAESynthesizer
from sdv.evaluation.single_table import run_diagnostic, evaluate_quality
from sdv.evaluation.single_table import get_column_plot

In [2]:
data = fetch_diabetes_hospital(as_frame=True)

X = data.data.copy()
y = data.target.copy()

X.shape, y.shape

((101766, 24), (101766,))

In [3]:
dropped_columns = ['readmitted', 'readmit_binary']
X = X.drop(columns=dropped_columns)

real_data = X.copy()
real_data['readmit_binary'] = (y == 1)
real_data['readmit_binary'] = real_data['readmit_binary'].astype(bool)
real_data.shape, real_data['readmit_binary'].dtype

((101766, 23), dtype('bool'))

In [4]:
real_train, real_test = train_test_split(
    real_data,
    test_size=0.2,
    random_state=66,
    stratify=real_data['readmit_binary']
)

real_train = real_train.reset_index(drop=True)
real_test = real_test.reset_index(drop=True)

real_train.shape, real_test.shape

((81412, 23), (20354, 23))

In [5]:
cat_cols = real_train.select_dtypes(include=['category']).columns
real_train[cat_cols] = real_train[cat_cols].astype('object')
real_test[cat_cols] = real_test[cat_cols].astype('object')

real_train['readmit_binary'] = real_train['readmit_binary'].astype(bool)
real_test['readmit_binary'] = real_test['readmit_binary'].astype(bool)

In [6]:
metadata = Metadata.detect_from_dataframe(
    data=real_train,
    table_name='diabetes'
)

In [7]:
sensitive_attributes = ['race', 'gender']
for col in sensitive_attributes:
    if col in real_train.columns:
        metadata.update_column(column_name=col, sdtype='categorical')
metadata.update_column(column_name='readmit_binary', sdtype='boolean')
metadata.validate()

In [13]:
def load_model_syn_data(model_path, sample_len):
    if model_path.exists():
        with model_path.open("rb") as f:
            model = pickle.load(f)
        synthetic_dataset = model.sample(num_rows=sample_len)
        return model, synthetic_dataset

In [16]:
gc_path = Path("../artifacts/gaussian_copuula_diabetes.pkl")
ct_path = Path("../artifacts/ctgan_diabetes.pkl")
tv_path = Path("../artifacts/tvae_diabetes.pkl")

In [17]:
sample_len = 81412
gc_model, gc_gendata = load_model_syn_data(gc_path, sample_len)
gc_gendata.head()

Unnamed: 0,race,gender,age,discharge_disposition_id,admission_source_id,time_in_hospital,medical_specialty,num_lab_procedures,num_procedures,num_medications,...,A1Cresult,insulin,change,diabetesMed,medicare,medicaid,had_emergency,had_inpatient_days,had_outpatient_days,readmit_binary
0,Caucasian,Male,'Over 60 years','Discharged to Home',Emergency,4,Missing,54,3,9,...,,Steady,Ch,Yes,True,False,False,True,False,False
1,Caucasian,Male,'30 years or younger','Discharged to Home',Emergency,6,Family/GeneralPractice,29,0,15,...,,No,Ch,Yes,False,False,False,False,True,False
2,Caucasian,Male,'30-60 years','Discharged to Home',Emergency,4,Emergency/Trauma,11,0,5,...,,Steady,No,Yes,True,False,False,False,False,False
3,Caucasian,Male,'30-60 years','Discharged to Home',Emergency,2,Missing,6,6,8,...,,No,Ch,No,False,False,False,True,False,True
4,Unknown,Female,'Over 60 years','Discharged to Home',Other,6,InternalMedicine,53,0,27,...,,No,Ch,Yes,False,False,False,False,True,False


In [18]:
ct_model, ct_gendata = load_model_syn_data(ct_path, sample_len)
ct_gendata.head()

https://scikit-learn.org/stable/model_persistence.html#security-maintainability-limitations


Unnamed: 0,race,gender,age,discharge_disposition_id,admission_source_id,time_in_hospital,medical_specialty,num_lab_procedures,num_procedures,num_medications,...,A1Cresult,insulin,change,diabetesMed,medicare,medicaid,had_emergency,had_inpatient_days,had_outpatient_days,readmit_binary
0,AfricanAmerican,Female,'Over 60 years','Discharged to Home',Referral,3,Missing,19,5,6,...,,Steady,No,Yes,True,False,False,False,False,False
1,Caucasian,Female,'Over 60 years',Other,Emergency,6,InternalMedicine,62,0,21,...,,No,No,Yes,True,False,False,False,False,False
2,Hispanic,Male,'30-60 years','Discharged to Home',Emergency,4,Missing,42,2,32,...,,Up,Ch,No,False,False,False,False,False,False
3,Caucasian,Female,'Over 60 years','Discharged to Home',Emergency,2,Missing,46,1,9,...,,No,No,Yes,True,False,True,False,False,True
4,Caucasian,Female,'30-60 years','Discharged to Home',Referral,4,Family/GeneralPractice,40,0,23,...,,No,Ch,Yes,True,False,False,False,False,False


In [19]:
tv_model, tv_gendata = load_model_syn_data(tv_path, sample_len)
tv_gendata.head()

https://scikit-learn.org/stable/model_persistence.html#security-maintainability-limitations


Unnamed: 0,race,gender,age,discharge_disposition_id,admission_source_id,time_in_hospital,medical_specialty,num_lab_procedures,num_procedures,num_medications,...,A1Cresult,insulin,change,diabetesMed,medicare,medicaid,had_emergency,had_inpatient_days,had_outpatient_days,readmit_binary
0,Caucasian,Female,'30-60 years',Other,Emergency,6,Other,50,0,16,...,,No,No,No,True,False,False,False,False,False
1,Caucasian,Female,'Over 60 years',Other,Emergency,2,Missing,3,0,9,...,,No,No,No,False,False,False,False,False,False
2,AfricanAmerican,Male,'30-60 years','Discharged to Home',Emergency,5,Missing,50,3,10,...,,Steady,No,Yes,False,False,False,False,False,False
3,Caucasian,Female,'30-60 years','Discharged to Home',Emergency,8,Other,41,0,18,...,,Steady,No,Yes,False,False,False,False,False,False
4,Caucasian,Female,'30-60 years',Other,Emergency,2,Family/GeneralPractice,43,0,12,...,,No,No,No,False,False,False,True,False,False


In [22]:
print(type(metadata))
metadata_dict = metadata.to_dict()
print(type(metadata_dict))

<class 'sdv.metadata.metadata.Metadata'>
<class 'dict'>


In [25]:
print(metadata_dict.keys())

dict_keys(['tables', 'relationships', 'METADATA_SPEC_VERSION'])


In [23]:
print(os.getcwd())

/content


In [20]:
syn_datasets = {
    "GaussianCopula" : gc_gendata,
    "CTGAN" : ct_gendata,
    "TVAE" : tv_gendata,
}

In [21]:
from sdmetrics.reports.single_table import QualityReport, DiagnosticReport
from sdmetrics.single_column import (
    KSComplement,
    TVComplement,
    CategoryCoverage,
    RangeCoverage,
    MissingValueSimilarity,
    StatisticSimilarity,
    BoundaryAdherence,
    CategoryAdherence,
)
from sdmetrics.column_pairs import CorrelationSimilarity, ContingencySimilarity
from sdmetrics.single_table import NewRowSynthesis, TableStructure
from sdmetrics.visualization import get_column_plot, get_column_pair_plot

In [22]:
import warnings
from pathlib import Path


In [26]:
import kaleido

In [27]:
def save_plot_image(fig, path, width=1200, height=800, scale=2):
    path.parent.mkdir(parents=True, exist_ok=True)
    fig.write_image(str(path), width=width, height=height, scale=scale)

In [23]:
def save_plot(fig, path):
  try:
    if fig is None:
      return
    path.parent.mkdir(parents=True, exist_ok=True)
    fig.write_html(str(path), include_plotlyjs='cdn')
  except Exception as e:
    warnings.warn(f"Could not save plot to {path} : {e}")


In [None]:
"""
def quality_reports(real_data : pd.DataFrame, syn_data : pd.DataFrame, \
                metadata_dict, out_dir : Path,):
  out_dir.mkdir(parents=True, exist_ok=True)

  q = QualityReport()
  q.generate(real_data=real_data, synthetic_data=syn_data, \
             metadata=metadata_dict, verbose=False)

  q_score = float(q.get_score())
  q_props = q.get_properties()
  q_col_shapes = q.get_details(property_name="Column Shapes")
  q_col_pairs = q.get_details(property_name="Column Pair Trends")

  q.save(str(out_dir / "quality_report.pkl"))
  q_props.to_csv(out_dir / "quality_properties.csv", index=False)
  q_col_shapes.to_csv(out_dir / "quality_column_shapes.csv", index=False)
  q_col_pairs.to_csv(out_dir / "quality_column_pair_trends.csv", index=False)

  save_plot(q.get_visualization("Column Shapes"), out_dir / "plots" / "quality_column_shapes_scores.html")
  save_plot(q.get_visualization("Column Pair Trends"), out_dir / "plots" / "quality_column_pair_trends_scores.html")
  """


SyntaxError: incomplete input (1766906667.py, line 1)

In [33]:
def quality_reports(real_data : pd.DataFrame, syn_data : pd.DataFrame, metadata_dict, out_dir : Path,):
    out_dir.mkdir(parents=True, exist_ok=True)

    q = QualityReport()
    q.generate(real_data=real_data, synthetic_data=syn_data, \
                metadata=metadata_dict, verbose=False)

    q_score = float(q.get_score())
    q_props = q.get_properties()
    q_col_shapes = q.get_details(property_name="Column Shapes")
    q_col_pairs = q.get_details(property_name="Column Pair Trends")

    q.save(str(out_dir / "quality_report.pkl"))
    q_props.to_csv(out_dir / "quality_properties.csv", index=False)
    q_col_shapes.to_csv(out_dir / "quality_column_shapes.csv", index=False)
    q_col_pairs.to_csv(out_dir / "quality_column_pair_trends.csv", index=False)
    fig_shapes = q.get_visualization("Column Shapes")
    save_plot_image(fig_shapes, out_dir / "plots" / "quality_column_shapes_scores.png")

    fig_pair_trends = q.get_visualization("Column Pair Trends")
    save_plot_image(fig_pair_trends, out_dir / "plots" / "quality_column_pair_trends_scores.png")

In [34]:
save_dir = Path("../reports")
quality_reports(real_data=real_data, syn_data=syn_datasets["GaussianCopula"], metadata_dict=metadata.to_dict()["tables"]["diabetes"], out_dir=save_dir)























































































































































































ValueError: 
Image export using the "kaleido" engine requires the Kaleido package,
which can be installed using pip:

    $ pip install --upgrade kaleido
