**CAUTION: TAKES A LONG TIME.**  
Try using the stored trained model instead.

In [1]:
import os
import time
import math
import random
import pickle
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
from sklearn import metrics
from scipy.stats import ks_2samp
from sdv.metadata import MultiTableMetadata
from sdv.evaluation.single_table import evaluate_quality as st_evaluate_quality
from sdv.evaluation.multi_table import evaluate_quality as mt_evaluate_quality
from sdv.single_table import GaussianCopulaSynthesizer
from sdv.single_table import CTGANSynthesizer
from sdv.single_table import TVAESynthesizer
from sdv.multi_table import HMASynthesizer

In [2]:
def cos_test(df1, df2):
    cos_sim = metrics.pairwise.cosine_similarity(df1.values.T, df2.values.T)
    mean_cos_sim = np.mean(cos_sim)
    return mean_cos_sim

In [3]:
def batch_cos_test(collection1, collection2):
    test_dict = {}
    for df_name in collection1.keys():
        mean_cos_sim = cos_test(collection2[df_name], collection1[df_name])
        test_dict[df_name] = 1-mean_cos_sim
    return test_dict

In [4]:
def ks_test(df1, df2):
    ks_stats = []
    p_values = []
    for column in df1.columns:
        ks_stat, ks_p_value = ks_2samp(df1[column], df2[column])
        ks_stats.append(ks_stat)
        p_values.append(ks_p_value)
    mean_ks_stat = np.mean(ks_stats)
    mean_p_value = np.mean(p_values)
    return mean_ks_stat, mean_p_value

In [5]:
def batch_ks_test(collection1, collection2):
    stats_dict = {}
    for df_name in collection1:
        ks_results, p_value = ks_test(collection1[df_name], collection2[df_name])
        stats_dict[df_name] = 1-ks_results
    return stats_dict

# Load Processed Data From Generation Stage

In [6]:
with open('pkl/real_data_collection.pkl', 'rb') as f:
    real_data_collection = pickle.load(f)

In [7]:
with open('pkl/synthetic_data_collection_e100.pkl', 'rb') as f:
    synthetic_data_collection = pickle.load(f)

In [8]:
with open('pkl/sdv_metadata.pkl', 'rb') as f:
    sdv_metadata = pickle.load(f)

# Benchmark

In [9]:
generation_dict = {k:{'nrows':len(v)} for k,v in real_data_collection.items()}

In [10]:
generation_dict

{'agency': {'nrows': 15},
 'calendar': {'nrows': 121},
 'calendar_dates': {'nrows': 674},
 'routes': {'nrows': 215},
 'stops': {'nrows': 6714},
 'stop_times': {'nrows': 966790},
 'trips': {'nrows': 32403}}

## GaussianCopula

In [11]:
# def collection_gaussiancopula_training(data_collection, multi_metadata):
#     generator_dict = {}
#     for df_name, df in data_collection.items():
#         synthesizer = GaussianCopulaSynthesizer(
#             multi_metadata.tables[df_name],
#             enforce_min_max_values=True,
#             enforce_rounding=True,
#             default_distribution='norm',
#         )
#         synthesizer.fit(df)
#         synthesizer.save(
#             filepath='models_single/'+df_name+'_gc.pkl'
#         )

In [12]:
# %%time

# collection_gaussiancopula_training(real_data_collection, sdv_metadata)

In [13]:
%%time

gc_synthetic_data_collection = {}
for df_name, content in generation_dict.items():
    synthesizer = GaussianCopulaSynthesizer.load(
        filepath='models_single/'+df_name+'_gc.pkl'
    )
    gc_synthetic_data_collection[df_name] = synthesizer.sample(
        num_rows=content['nrows']
    )

CPU times: total: 14.4 s
Wall time: 9.63 s


In [14]:
batch_cos_test(gc_synthetic_data_collection, real_data_collection)

{'agency': 0.96,
 'calendar': 0.47202756917211053,
 'calendar_dates': 0.09751486150956068,
 'routes': 0.7069871996285897,
 'stops': 0.826732863027359,
 'stop_times': 0.5464731949556783,
 'trips': 0.5474426557179761}

In [15]:
batch_ks_test(gc_synthetic_data_collection, real_data_collection)

  ks_stat, ks_p_value = ks_2samp(df1[column], df2[column])


{'agency': 1.0,
 'calendar': 0.9049586776859504,
 'calendar_dates': 0.6666666666666667,
 'routes': 0.6093023255813954,
 'stops': 0.7977029755403303,
 'stop_times': 0.7459531025351938,
 'trips': 0.4921226429651575}

In [16]:
individual_report_collection = {}
for df_name in real_data_collection.keys():
    print(f'[{df_name}]:')
    individual_report_collection[df_name] = st_evaluate_quality(
        real_data=real_data_collection[df_name],
        synthetic_data=gc_synthetic_data_collection[df_name],
        metadata=sdv_metadata.tables[df_name])
    print()
    print('--------------------')
    print()

[agency]:


Creating report: 100%|███████████████████████████████████████████████████████████████████| 4/4 [00:00<00:00, 62.49it/s]



Overall Quality Score: 100.0%

Properties:
Column Shapes: 100.0%
Column Pair Trends: 100.0%

--------------------

[calendar]:


Creating report: 100%|███████████████████████████████████████████████████████████████████| 4/4 [00:00<00:00, 19.38it/s]



Overall Quality Score: 82.03%

Properties:
Column Shapes: 89.44%
Column Pair Trends: 74.62%

--------------------

[calendar_dates]:


Creating report: 100%|█████████████████████████████████████████████████████████████████| 4/4 [00:00<00:00, 1999.43it/s]



Overall Quality Score: 78.75%

Properties:
Column Shapes: 91.02%
Column Pair Trends: 66.47%

--------------------

[routes]:


Creating report: 100%|██████████████████████████████████████████████████████████████████| 4/4 [00:00<00:00, 105.24it/s]



Overall Quality Score: 65.74%

Properties:
Column Shapes: 74.57%
Column Pair Trends: 56.9%

--------------------

[stops]:


Creating report: 100%|███████████████████████████████████████████████████████████████████| 4/4 [00:00<00:00, 15.47it/s]



Overall Quality Score: 72.68%

Properties:
Column Shapes: 76.93%
Column Pair Trends: 68.42%

--------------------

[stop_times]:


Creating report: 100%|███████████████████████████████████████████████████████████████████| 4/4 [00:13<00:00,  3.45s/it]



Overall Quality Score: 82.07%

Properties:
Column Shapes: 87.16%
Column Pair Trends: 76.98%

--------------------

[trips]:


Creating report: 100%|███████████████████████████████████████████████████████████████████| 4/4 [00:00<00:00, 19.13it/s]


Overall Quality Score: 76.13%

Properties:
Column Shapes: 84.09%
Column Pair Trends: 68.17%

--------------------






## CTGAN

In [17]:
# training_parameter_dict = {'agency': {'epochs': 10},
#                            'calendar': {'epochs': 10},
#                            'calendar_dates': {'epochs': 10},
#                            'routes': {'epochs': 10},
#                            'stops': {'epochs': 10},
#                            'stop_times': {'epochs': 10},
#                            'trips': {'epochs': 10}}

In [18]:
# def collection_gtgan_training(data_collection, training_parameter_dict, multi_metadata):
#     generator_dict = {}
#     for df_name, df in data_collection.items():
#         print(f'[{df_name}]:')
#         synthesizer = CTGANSynthesizer(
#             multi_metadata.tables[df_name],
#             enforce_rounding=True,
#             epochs=training_parameter_dict[df_name]['epochs'],
#             verbose=True,
#             batch_size=400
#         )
#         synthesizer.fit(df)
#         synthesizer.save(
#             filepath='models_single/'+df_name+'_ctgan.pkl'
#         )
#         print()

In [19]:
# %%time

# collection_gtgan_training(real_data_collection, training_parameter_dict, sdv_metadata)

In [20]:
%%time

ctgan_synthetic_data_collection = {}
for df_name, content in generation_dict.items():
    synthesizer = GaussianCopulaSynthesizer.load(
        filepath='models_single/'+df_name+'_ctgan.pkl'
    )
    ctgan_synthetic_data_collection[df_name] = synthesizer.sample(
        num_rows=content['nrows']
    )

CPU times: total: 3min 6s
Wall time: 50.7 s


In [21]:
batch_cos_test(ctgan_synthetic_data_collection, real_data_collection)

{'agency': 0.96,
 'calendar': 0.45433335433075717,
 'calendar_dates': 0.09894015760429198,
 'routes': 0.6368036019915946,
 'stops': 0.8347906714470074,
 'stop_times': 0.5269788145343832,
 'trips': 0.54789669766831}

In [22]:
batch_ks_test(ctgan_synthetic_data_collection, real_data_collection)

  ks_stat, ks_p_value = ks_2samp(df1[column], df2[column])


{'agency': 1.0,
 'calendar': 0.8801652892561983,
 'calendar_dates': 0.6582591493570722,
 'routes': 0.5031007751937984,
 'stops': 0.8891040280673882,
 'stop_times': 0.7199040122467134,
 'trips': 0.5018362497299633}

In [23]:
individual_report_collection = {}
for df_name in real_data_collection.keys():
    print(f'[{df_name}]:')
    individual_report_collection[df_name] = st_evaluate_quality(
        real_data=real_data_collection[df_name],
        synthetic_data=ctgan_synthetic_data_collection[df_name],
        metadata=sdv_metadata.tables[df_name])
    print('--------------------')
    print()

[agency]:


Creating report: 100%|███████████████████████████████████████████████████████████████████| 4/4 [00:00<00:00, 67.78it/s]



Overall Quality Score: 100.0%

Properties:
Column Shapes: 100.0%
Column Pair Trends: 100.0%
--------------------

[calendar]:


Creating report: 100%|███████████████████████████████████████████████████████████████████| 4/4 [00:00<00:00, 18.51it/s]



Overall Quality Score: 78.05%

Properties:
Column Shapes: 86.69%
Column Pair Trends: 69.42%
--------------------

[calendar_dates]:


Creating report: 100%|█████████████████████████████████████████████████████████████████| 4/4 [00:00<00:00, 1999.91it/s]



Overall Quality Score: 82.94%

Properties:
Column Shapes: 89.76%
Column Pair Trends: 76.11%
--------------------

[routes]:


Creating report: 100%|██████████████████████████████████████████████████████████████████| 4/4 [00:00<00:00, 114.25it/s]



Overall Quality Score: 36.98%

Properties:
Column Shapes: 50.23%
Column Pair Trends: 23.72%
--------------------

[stops]:


Creating report: 100%|███████████████████████████████████████████████████████████████████| 4/4 [00:00<00:00, 20.00it/s]



Overall Quality Score: 83.15%

Properties:
Column Shapes: 87.52%
Column Pair Trends: 78.77%
--------------------

[stop_times]:


Creating report: 100%|███████████████████████████████████████████████████████████████████| 4/4 [00:14<00:00,  3.53s/it]



Overall Quality Score: 81.32%

Properties:
Column Shapes: 86.01%
Column Pair Trends: 76.62%
--------------------

[trips]:


Creating report: 100%|███████████████████████████████████████████████████████████████████| 4/4 [00:00<00:00, 19.04it/s]


Overall Quality Score: 88.35%

Properties:
Column Shapes: 92.23%
Column Pair Trends: 84.46%
--------------------






## TVAE

In [24]:
# training_parameter_dict = {'agency': {'epochs': 10},
#                            'calendar': {'epochs': 10},
#                            'calendar_dates': {'epochs': 10},
#                            'routes': {'epochs': 10},
#                            'stops': {'epochs': 10},
#                            'stop_times': {'epochs': 10},
#                            'trips': {'epochs': 10}}

In [25]:
# def collection_tvae_training(data_collection, training_parameter_dict, multi_metadata):
#     generator_dict = {}
#     for df_name, df in data_collection.items():
#         synthesizer = TVAESynthesizer(
#             multi_metadata.tables[df_name],
#             enforce_min_max_values=True,
#             enforce_rounding=True,
#             epochs=training_parameter_dict[df_name]['epochs'],
#             batch_size=400
#         )
#         synthesizer.fit(df)
#         synthesizer.save(
#             filepath='models_single/'+df_name+'_tvae.pkl'
#         )

In [26]:
# %%time

# collection_tvae_training(real_data_collection, training_parameter_dict, sdv_metadata)

In [27]:
%%time

tvae_synthetic_data_collection = {}
for df_name, content in generation_dict.items():
    synthesizer = TVAESynthesizer.load(
        filepath='models_single/'+df_name+'_tvae.pkl'
    )
    tvae_synthetic_data_collection[df_name] = synthesizer.sample(
        num_rows=content['nrows']
    )

CPU times: total: 1min 5s
Wall time: 33 s


In [28]:
batch_cos_test(tvae_synthetic_data_collection, real_data_collection)

{'agency': 0.96,
 'calendar': 0.6958804743808202,
 'calendar_dates': 0.08489484015640736,
 'routes': 0.834312248944935,
 'stops': 0.963467587426582,
 'stop_times': 0.5970759661458913,
 'trips': 0.5492417285637832}

In [29]:
batch_ks_test(tvae_synthetic_data_collection, real_data_collection)

{'agency': 1.0,
 'calendar': 0.7380165289256199,
 'calendar_dates': 0.4920870425321464,
 'routes': 0.6751937984496124,
 'stops': 0.9025585013073842,
 'stop_times': 0.7580405258639415,
 'trips': 0.4854257321852915}

In [30]:
individual_report_collection = {}
for df_name in real_data_collection.keys():
    print(f'[{df_name}]:')
    individual_report_collection[df_name] = st_evaluate_quality(
        real_data=real_data_collection[df_name],
        synthetic_data=tvae_synthetic_data_collection[df_name],
        metadata=sdv_metadata.tables[df_name])
    print()
    print('--------------------')
    print()

[agency]:


Creating report: 100%|███████████████████████████████████████████████████████████████████| 4/4 [00:00<00:00, 64.50it/s]



Overall Quality Score: 100.0%

Properties:
Column Shapes: 100.0%
Column Pair Trends: 100.0%

--------------------

[calendar]:


Creating report: 100%|███████████████████████████████████████████████████████████████████| 4/4 [00:00<00:00, 19.13it/s]



Overall Quality Score: 63.59%

Properties:
Column Shapes: 70.89%
Column Pair Trends: 56.29%

--------------------

[calendar_dates]:


Creating report: 100%|█████████████████████████████████████████████████████████████████| 4/4 [00:00<00:00, 1998.72it/s]



Overall Quality Score: 55.86%

Properties:
Column Shapes: 64.84%
Column Pair Trends: 46.88%

--------------------

[routes]:


Creating report: 100%|███████████████████████████████████████████████████████████████████| 4/4 [00:00<00:00, 86.94it/s]



Overall Quality Score: 87.52%

Properties:
Column Shapes: 88.68%
Column Pair Trends: 86.36%

--------------------

[stops]:


Creating report: 100%|███████████████████████████████████████████████████████████████████| 4/4 [00:00<00:00, 13.33it/s]



Overall Quality Score: 86.75%

Properties:
Column Shapes: 89.04%
Column Pair Trends: 84.47%

--------------------

[stop_times]:


Creating report: 100%|███████████████████████████████████████████████████████████████████| 4/4 [00:13<00:00,  3.47s/it]



Overall Quality Score: 86.25%

Properties:
Column Shapes: 89.96%
Column Pair Trends: 82.54%

--------------------

[trips]:


Creating report: 100%|███████████████████████████████████████████████████████████████████| 4/4 [00:00<00:00,  8.49it/s]


Overall Quality Score: 60.13%

Properties:
Column Shapes: 73.42%
Column Pair Trends: 46.84%

--------------------






## HMA

In [31]:
# synthesizer = HMASynthesizer(sdv_metadata)

In [32]:
# %%time

# synthesizer.fit(real_data_collection)

# KS Comparison Table

In [33]:
def create_comparison_table(real_collection, fake_collections):
    comparison_dict = {}
    for model_name, fake_collection in fake_collections.items():
        ks_test_scores = batch_ks_test(real_collection, fake_collection)
        comparison_dict[model_name] = ks_test_scores
    comparison_table = pd.DataFrame(comparison_dict)
    return comparison_table

In [34]:
benchmark_collection = {
    'GaussianCopula':gc_synthetic_data_collection,
    'CTGAN':ctgan_synthetic_data_collection,
    'TVAE':tvae_synthetic_data_collection,
    'New Approach':synthetic_data_collection
}

create_comparison_table(real_data_collection, benchmark_collection).round(3)

  ks_stat, ks_p_value = ks_2samp(df1[column], df2[column])


Unnamed: 0,GaussianCopula,CTGAN,TVAE,New Approach
agency,1.0,1.0,1.0,1.0
calendar,0.905,0.88,0.738,0.713
calendar_dates,0.667,0.658,0.492,0.731
routes,0.609,0.503,0.675,0.681
stops,0.798,0.889,0.903,0.596
stop_times,0.746,0.72,0.758,0.797
trips,0.492,0.502,0.485,0.673
