In [46]:
# Mount Google Drive so the files under /content/drive used below are reachable.
from google.colab import drive
drive.mount('/content/drive')

Drive already mounted at /content/drive; to attempt to forcibly remount, call drive.mount("/content/drive", force_remount=True).


In [47]:
pip install sdv



In [58]:
import pandas as pd
import numpy as np

import sklearn.datasets as datasets
from sklearn.model_selection import train_test_split

from sdv.metadata import SingleTableMetadata
from sdv.single_table import CTGANSynthesizer
from sdv.evaluation.single_table import run_diagnostic
from sdv.evaluation.single_table import evaluate_quality
from sdv.evaluation.single_table import get_column_plot
from sdv.sampling import Condition

import matplotlib.pyplot as plt
import matplotlib.image as mpimg
from matplotlib.offsetbox import OffsetImage, AnnotationBbox

import plotly.graph_objs as go
from plotly.subplots import make_subplots

# NOTE(review): `seed` is not referenced by any other visible cell — confirm whether it
# was meant to seed numpy / the synthesizer for reproducible runs.
seed = 42

In [21]:
# import data: read the training table and discard its first column
data = (
    pd.read_csv('/content/drive/MyDrive/SC263/data/sig_feats/train_sigfeats.csv')
    .pipe(lambda frame: frame.drop(columns=frame.columns[0]))
)

# create metadata and detect it from the loaded table
metadata = SingleTableMetadata()
metadata.detect_from_dataframe(data)

In [22]:
# Preview the first rows of the loaded training table.
data.head()

Unnamed: 0,SEQN_new,SDDSRVYR,diabetes,age,female,race_ethnicity,education,us_born,pir,smoke,...,LBDHDDSI,LBDSTRSI,VNTOTHDRATIO,LBXMCVSI,LBXSGTSI,LBXGH,LBXSOSSI,LBXSCLSI,LBXBPB,LBXSNASI
0,C-21017,3,0,37,1,3,2,0,1,0,...,1.99,0.948,2.38961,95.9,17,5.1,271,105,7.0,137
1,C-21091,3,0,25,1,1,5,1,2,0,...,1.94,0.644,2.36,89.9,7,4.8,268,103,1.0,136
2,C-21142,3,0,31,1,1,2,1,3,0,...,1.86,0.734,2.263889,84.8,17,5.2,276,105,0.6,139
3,C-21205,3,0,40,1,2,4,1,1,1,...,1.58,0.948,2.967213,95.7,32,5.3,272,103,3.1,138
4,C-21223,3,0,34,1,1,3,1,1,1,...,1.6,0.497,3.096774,98.6,8,5.5,278,106,0.8,141


In [23]:
# reliable negative id: the text file is read whole and split into one entry per line
id_path = '/content/drive/MyDrive/SC263/data/reliable_negatives_id.txt'
with open(id_path) as id_file:
    negative_ids = id_file.read().splitlines()

In [24]:
# reliable negatives: rows whose SEQN_new is listed in negative_ids.
# Series.isin does the membership test in one vectorised pass instead of
# scanning the id list once per row.
negatives = data[data['SEQN_new'].isin(negative_ids)].copy()

# positives: rows labelled diabetes == 1
positives = data[data['diabetes'] == 1].copy()

In [25]:
# sampling conditions
def _race_diabetes_condition(race_ethnicity, diabetes, num_rows=1000):
    """Build a Condition for one (race_ethnicity, diabetes) combination.

    Args:
        race_ethnicity: Value to fix in the 'race_ethnicity' column.
        diabetes: Value to fix in the 'diabetes' column.
        num_rows: Number of rows requested for this combination.

    Returns:
        A Condition with the given row count and fixed column values.
    """
    return Condition(
        num_rows=num_rows,
        column_values={'race_ethnicity': race_ethnicity, 'diabetes': diabetes}
    )


# One condition per race_ethnicity value (1-3) and diabetes label (0 = neg, 1 = pos).
race_1_neg = _race_diabetes_condition(1, 0)
race_1_pos = _race_diabetes_condition(1, 1)

race_2_neg = _race_diabetes_condition(2, 0)
race_2_pos = _race_diabetes_condition(2, 1)

race_3_neg = _race_diabetes_condition(3, 0)
race_3_pos = _race_diabetes_condition(3, 1)


In [20]:
# Fit a CTGAN synthesizer on the loaded table using the detected metadata.
# NOTE(review): the fit uses the full `data` frame; the `negatives` / `positives`
# subsets built in an earlier cell are not used here — confirm this is intended.
# NOTE(review): the recorded output of this cell ends in KeyboardInterrupt, so a
# fresh run has to let the fit finish before the sampling cell can work.
synthesizer = CTGANSynthesizer(metadata)
synthesizer.fit(data)



PerformanceAlert: Using the CTGANSynthesizer on this data is not recommended. To model this data, CTGAN will generate a large number of columns.

Original Column Name   Est # of Columns (CTGAN)
SDDSRVYR               8
diabetes               2
age                    11
female                 2
race_ethnicity         3
education              5
us_born                2
pir                    3
smoke                  2
alcohol_consumption    2
cholesterol_total      11
high_bp                2
DRXTKCAL               11
DRXTPROT               11
DRXTCARB               11
DRXTTFAT               11
DRXTSFAT               11
DRXTMFAT               11
DRXTPFAT               11
DRXTCHOL               11
DRXTFIBE               11
DRXTVARA               11
DRXTVB1                11
DRXTVB2                11
DRXTNIAC               11
DRXTVB6                11
DRXTFOLA               11
DRXTVB12               11
DRXTVC                 11
DRXTATOC               11
DRXTCALC               11
DRXTPHOS  

KeyboardInterrupt: 

In [14]:
# Sample every race/diabetes condition from the synthesizer and write the rows to Drive.
synthetic_output_path = '/content/drive/MyDrive/SC263/data/sig_feats/CTGAN_synthetic_train_sigfeats_corrected.csv'
race_conditions = [
    race_1_neg, race_1_pos,
    race_2_neg, race_2_pos,
    race_3_neg, race_3_pos,
]
synthesizer.sample_from_conditions(
    conditions=race_conditions,
    output_file_path=synthetic_output_path,
)

Sampling conditions: 100%|██████████| 6000/6000 [00:28<00:00, 206.93it/s]


Unnamed: 0,SEQN_new,SDDSRVYR,diabetes,age,female,race_ethnicity,education,us_born,pir,smoke,...,LBDHDDSI,LBDSTRSI,VNTOTHDRATIO,LBXMCVSI,LBXSGTSI,LBXGH,LBXSOSSI,LBXSCLSI,LBXBPB,LBXSNASI
0,sdv-id-kvVrQJ,4,0,36,1,1,5,1,3,0,...,0.88,1.681,5.917737,83.8,13,5.2,278,113,1.65,137
1,sdv-id-lTrYpd,6,0,36,1,1,2,1,2,0,...,0.82,2.188,9.074801,89.4,33,5.7,281,106,4.19,137
2,sdv-id-HwwjPg,7,0,25,0,1,4,1,3,0,...,1.93,0.479,3.698585,94.8,30,5.0,274,109,0.61,140
3,sdv-id-sASnMZ,6,0,79,0,1,1,1,3,0,...,1.97,1.746,3.381838,91.7,27,5.1,275,103,0.80,140
4,sdv-id-UkSwwU,7,0,36,0,1,4,1,2,0,...,1.59,1.720,3.106183,84.3,15,5.3,284,105,0.72,139
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
5995,sdv-id-bZgbaH,7,1,65,0,3,1,1,2,0,...,0.64,2.762,4.223402,91.7,31,5.5,273,102,0.59,138
5996,sdv-id-xWsmpe,6,1,25,0,3,4,1,2,0,...,1.31,3.100,5.120936,84.8,18,6.4,279,106,0.58,138
5997,sdv-id-msiKGS,4,1,60,0,3,3,1,2,1,...,0.88,1.781,5.567385,83.4,46,6.3,271,103,2.21,138
5998,sdv-id-yCskIh,7,1,27,1,3,2,1,2,0,...,1.59,0.679,2.535740,84.4,19,5.8,270,104,2.88,140


## Quality Assessment

In [26]:
# import data: re-read the real training table and drop its first column
data = (
    pd.read_csv('/content/drive/MyDrive/SC263/data/sig_feats/train_sigfeats.csv')
    .pipe(lambda frame: frame.drop(columns=frame.columns[0]))
)

# create metadata and detect it from the real table
metadata = SingleTableMetadata()
metadata.detect_from_dataframe(data)

In [28]:
# load synthetic data from the CSV path that the sample_from_conditions cell writes to
synthetic_data = pd.read_csv('/content/drive/MyDrive/SC263/data/sig_feats/CTGAN_synthetic_train_sigfeats_corrected.csv')

In [29]:
# diagnosis: run the SDV diagnostic report on the synthetic table against the real one;
# the recorded output reports Data Validity and Data Structure scores.
diagnostic = run_diagnostic(
    real_data=data,
    synthetic_data=synthetic_data,
    metadata=metadata
)

Generating report ...

(1/2) Evaluating Data Validity: |██████████| 30/30 [00:00<00:00, 186.70it/s]|
Data Validity Score: 100.0%

(2/2) Evaluating Data Structure: |██████████| 1/1 [00:00<00:00, 48.21it/s]|
Data Structure Score: 100.0%

Overall Score (Average): 100.0%



In [30]:
# Build the quality report for the synthetic table against the real one.
# Arguments are passed by keyword, mirroring the run_diagnostic call.
quality_report = evaluate_quality(
    real_data=data,
    synthetic_data=synthetic_data,
    metadata=metadata,
)

Generating report ...

(1/2) Evaluating Column Shapes: |██████████| 30/30 [00:04<00:00,  6.93it/s]|
Column Shapes Score: 84.45%

(2/2) Evaluating Column Pair Trends: |██████████| 435/435 [00:13<00:00, 31.66it/s]|
Column Pair Trends Score: 75.56%

Overall Score (Average): 80.01%



In [31]:
# Per-column details behind the 'Column Shapes' property of the quality report.
quality_report.get_details(property_name='Column Shapes')

Unnamed: 0,Column,Metric,Score
0,SDDSRVYR,TVComplement,0.793748
1,diabetes,TVComplement,0.525792
2,age,KSComplement,0.887988
3,female,TVComplement,0.907621
4,race_ethnicity,TVComplement,0.803493
5,education,TVComplement,0.922115
6,us_born,TVComplement,0.95187
7,pir,TVComplement,0.843961
8,smoke,TVComplement,0.881201
9,alcohol_consumption,TVComplement,0.902118


In [41]:
# Per-pair details behind the 'Column Pair Trends' property; kept in a variable
# because later cells add a column to it and sort it.
column_pair_trends = quality_report.get_details(property_name='Column Pair Trends')
column_pair_trends

Unnamed: 0,Column 1,Column 2,Metric,Score,Real Correlation,Synthetic Correlation
0,SDDSRVYR,diabetes,ContingencySimilarity,0.522276,,
1,SDDSRVYR,age,ContingencySimilarity,0.739204,,
2,SDDSRVYR,female,ContingencySimilarity,0.766481,,
3,SDDSRVYR,race_ethnicity,ContingencySimilarity,0.757808,,
4,SDDSRVYR,education,ContingencySimilarity,0.778384,,
...,...,...,...,...,...,...
401,LBXSOSSI,LBXBPB,CorrelationSimilarity,0.989839,0.072679,0.052357
402,LBXSOSSI,LBXSNASI,CorrelationSimilarity,0.606978,0.802365,0.016321
403,LBXSCLSI,LBXBPB,CorrelationSimilarity,0.971685,-0.063939,-0.007310
404,LBXSCLSI,LBXSNASI,CorrelationSimilarity,0.718513,0.555363,-0.007610


In [42]:
# get the most different pairs: rank by the absolute gap between real and synthetic correlation
column_pair_trends = (
    column_pair_trends
    .assign(
        abs_correlation_diff=lambda pairs: (
            pairs['Real Correlation'] - pairs['Synthetic Correlation']
        ).abs()
    )
    .sort_values(by='abs_correlation_diff', ascending=False)
)
column_pair_trends

Unnamed: 0,Column 1,Column 2,Metric,Score,Real Correlation,Synthetic Correlation,abs_correlation_diff
402,LBXSOSSI,LBXSNASI,CorrelationSimilarity,0.606978,0.802365,0.016321,0.786044
404,LBXSCLSI,LBXSNASI,CorrelationSimilarity,0.718513,0.555363,-0.007610,0.562974
317,BMXARMC,BMXWT,CorrelationSimilarity,0.767781,0.896693,0.432256,0.464437
340,BMXWAIST,BMXWT,CorrelationSimilarity,0.781000,0.906851,0.468851,0.437999
400,LBXSOSSI,LBXSCLSI,CorrelationSimilarity,0.796564,0.391529,-0.015344,0.406873
...,...,...,...,...,...,...,...
281,MCQ300C,LBXGH,ContingencySimilarity,0.866873,,,
282,MCQ300C,LBXSOSSI,ContingencySimilarity,0.117292,,,
283,MCQ300C,LBXSCLSI,ContingencySimilarity,0.545005,,,
284,MCQ300C,LBXBPB,ContingencySimilarity,0.616386,,,


In [45]:
# Average absolute correlation gap; pandas' mean() skips NaN by default, so pairs
# without correlation values do not contribute to this figure.
column_pair_trends.abs_correlation_diff.mean()

0.09245353243938591

In [56]:
# Plot the LBXSNASI column of the real and synthetic tables with get_column_plot;
# LBXSNASI appears in the two pairs with the largest correlation gap in the sorted table.
fig = get_column_plot(
    real_data=data,
    synthetic_data=synthetic_data,
    column_name='LBXSNASI',
    metadata=metadata
)

fig.show()

In [57]:
# Debug check of the object returned by get_column_plot (the recorded output shows a plotly Figure).
# NOTE(review): leftover exploration — can be removed from the final notebook.
print(type(fig))

<class 'plotly.graph_objs._figure.Figure'>


In [66]:
def plot_multiple_columns_plotly(data, synthetic_data, column_names, metadata, nrows, ncols):
    """Show the per-column comparison plots of several columns on one subplot grid.

    Each column is rendered with ``get_column_plot`` and the traces of that
    figure are copied into one cell of an ``nrows`` x ``ncols`` grid, filled
    row by row. The combined figure is displayed with ``fig.show()``.

    Args:
        data: Real table, forwarded to ``get_column_plot`` as ``real_data``.
        synthetic_data: Synthetic table, forwarded as ``synthetic_data``.
        column_names: Names of the columns to plot, one per grid cell.
        metadata: Metadata object forwarded to ``get_column_plot``.
        nrows: Number of subplot rows.
        ncols: Number of subplot columns.

    Returns:
        None. The figure is only displayed.

    Raises:
        ValueError: If ``nrows`` or ``ncols`` is not positive, or if the grid
            has fewer cells than there are columns to plot.
    """
    column_names = list(column_names)

    # Fail early with a clear message instead of an error from inside plotly
    # after several (slow) per-column figures have already been generated.
    if nrows < 1 or ncols < 1:
        raise ValueError(
            f"nrows and ncols must be positive, got nrows={nrows}, ncols={ncols}"
        )
    if len(column_names) > nrows * ncols:
        raise ValueError(
            f"{len(column_names)} columns do not fit in a {nrows} x {ncols} grid"
        )

    # Create a subplot grid titled with the column names
    fig = make_subplots(rows=nrows, cols=ncols, subplot_titles=column_names)

    legend_names_shown = set()
    for position, column_name in enumerate(column_names):
        row_index, col_index = divmod(position, ncols)

        # Generate the plot for this column
        single_fig = get_column_plot(
            real_data=data,
            synthetic_data=synthetic_data,
            column_name=column_name,
            metadata=metadata
        )

        # Add traces from this figure to the appropriate subplot
        for trace in single_fig.data:
            if trace.name:
                # Every per-column figure brings its own legend entries; keep
                # one entry per trace name and group same-named traces so the
                # legend is not repeated once per subplot.
                trace.legendgroup = trace.name
                if trace.name in legend_names_shown:
                    trace.showlegend = False
                elif trace.showlegend is not False:
                    legend_names_shown.add(trace.name)
            fig.add_trace(trace, row=row_index + 1, col=col_index + 1)

    # Fixed overall size and title for the combined figure
    fig.update_layout(height=800, width=1200, title_text="Multiple Columns Data Comparison")
    fig.show()

# Columns to compare: 16 names, one per cell of the 4 x 4 grid requested below.
column_names = ['SDDSRVYR', 'diabetes', 'age', 'female', 'race_ethnicity',
       'education', 'us_born', 'LBXGH', 'smoke', 'alcohol_consumption',
       'cholesterol_total', 'high_bp', 'MCQ300C', 'URXUMA', 'BPXPLS',
       'BMXARMC']

# Draw all selected columns on a single figure.
plot_multiple_columns_plotly(data, synthetic_data, column_names, metadata, nrows=4, ncols=4)


In [64]:
# List the column names of the loaded real table.
data.columns

Index(['SEQN_new', 'SDDSRVYR', 'diabetes', 'age', 'female', 'race_ethnicity',
       'education', 'us_born', 'pir', 'smoke', 'alcohol_consumption',
       'cholesterol_total', 'high_bp', 'MCQ300C', 'URXUMA', 'BPXPLS',
       'BMXARMC', 'BMXLEG', 'BMXWAIST', 'BMXWT', 'LBDHDDSI', 'LBDSTRSI',
       'VNTOTHDRATIO', 'LBXMCVSI', 'LBXSGTSI', 'LBXGH', 'LBXSOSSI', 'LBXSCLSI',
       'LBXBPB', 'LBXSNASI'],
      dtype='object')