In [1]:
import os
import time
import math
import random
import pickle
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
from sklearn import metrics
from scipy.stats import ks_2samp
from sdv.metadata import SingleTableMetadata
from sdv.metadata import MultiTableMetadata
from sdv.evaluation.single_table import evaluate_quality as st_evaluate_quality
from sdv.evaluation.single_table import run_diagnostic as st_run_diagnostic
from sdv.evaluation.multi_table import evaluate_quality as mt_evaluate_quality
from sdv.evaluation.multi_table import run_diagnostic as mt_run_diagnostic

# Load Processed Data From Generation Stage

In [2]:
# Real GTFS tables saved by the generation stage; indexed by table name later
# in this notebook (e.g. real_data_collection['routes']).
# NOTE: pickle.load can execute arbitrary code -- only load files you created yourself.
with open('pkl/gtfs/real_data_collection.pkl', 'rb') as f:
    real_data_collection = pickle.load(f)

In [3]:
# Synthetic tables to compare against the real ones (the "full epoch" file, per its name).
# NOTE: pickle.load can execute arbitrary code -- only load files you created yourself.
with open('pkl/gtfs/synthetic_data_full_epoch.pkl', 'rb') as f:
    synthetic_data_collection = pickle.load(f)

In [4]:
# Disabled alternative input: uncomment to bind synthetic_data_collection to the
# '10epoch' pickle instead of the full-epoch one.
# with open('pkl/gtfs/synthetic_data_10epoch.pkl', 'rb') as f:
#     synthetic_data_collection = pickle.load(f)

In [5]:
# Pickled metadata object; passed as `metadata=` to the multi-table quality report below.
# NOTE: pickle.load can execute arbitrary code -- only load files you created yourself.
with open('pkl/gtfs/sdv_metadata.pkl', 'rb') as f:
    sdv_metadata = pickle.load(f)

# Metrics

In [6]:
# Build the SDV multi-table quality report for real vs. synthetic data.
# The report object is queried per property (get_details) in the cells that follow.
mt_quality_report = mt_evaluate_quality(
    real_data=real_data_collection,
    synthetic_data=synthetic_data_collection,
    metadata=sdv_metadata)

Creating report: 100%|███████████████████████████████████████████████████████████████████| 5/5 [00:24<00:00,  4.91s/it]



Overall Quality Score: 76.16%

Properties:
Column Shapes: 76.22%
Column Pair Trends: 60.11%
Parent Child Relationships: 92.16%


## Column Shapes Score Per Table

In [7]:
# Mean 'Column Shapes' quality score for each table, rounded to 3 decimals.
(
    mt_quality_report
    .get_details(property_name='Column Shapes')
    .groupby('Table')['Quality Score']
    .mean()
    .round(3)
)

Table
agency            0.920
calendar          0.877
calendar_dates    0.458
routes            0.589
stop_times        0.643
stops             0.776
trips             0.920
Name: Quality Score, dtype: float64

## Column Shapes Score Per Table and Metric

In [8]:
# Mean 'Column Shapes' quality score for each (table, metric) pair, rounded to 3 decimals.
(
    mt_quality_report
    .get_details(property_name='Column Shapes')
    .groupby(['Table', 'Metric'])['Quality Score']
    .mean()
    .round(3)
)

Table           Metric      
agency          TVComplement    0.920
calendar        KSComplement    0.562
                TVComplement    0.967
calendar_dates  KSComplement    0.386
                TVComplement    0.530
routes          TVComplement    0.589
stop_times      KSComplement    0.662
                TVComplement    0.612
stops           KSComplement    0.680
                TVComplement    0.871
trips           TVComplement    0.920
Name: Quality Score, dtype: float64

## Trend/Correlation-wise Score Per Table

In [9]:
# Mean 'Column Pair Trends' quality score for each table, rounded to 3 decimals.
(
    mt_quality_report
    .get_details(property_name='Column Pair Trends')
    .groupby('Table')['Quality Score']
    .mean()
    .round(3)
)

Table
agency            0.840
calendar          0.673
calendar_dates    0.000
routes            0.376
stop_times        0.421
stops             0.673
trips             0.840
Name: Quality Score, dtype: float64

In [10]:
# Mean 'Column Pair Trends' quality score for each (table, metric) pair, rounded to 3 decimals.
(
    mt_quality_report
    .get_details(property_name='Column Pair Trends')
    .groupby(['Table', 'Metric'])['Quality Score']
    .mean()
    .round(3)
)

Table           Metric               
agency          ContingencySimilarity    0.840
calendar        ContingencySimilarity    0.665
                CorrelationSimilarity    0.944
calendar_dates  ContingencySimilarity    0.000
routes          ContingencySimilarity    0.376
stop_times      ContingencySimilarity    0.270
                CorrelationSimilarity    0.843
stops           ContingencySimilarity    0.608
                CorrelationSimilarity    0.964
trips           ContingencySimilarity    0.840
Name: Quality Score, dtype: float64

In [11]:
# Disabled: optional SDV multi-table diagnostic on the same real/synthetic/metadata
# inputs as the quality report. Uncomment to run it.
# mt_run_diagnostic(
#     real_data=real_data_collection,
#     synthetic_data=synthetic_data_collection,
#     metadata=sdv_metadata)

# Join Test

## Join Data

In [12]:
def auto_join(df_list, how='inner'):
    """Merge a sequence of DataFrames left to right on their shared column names.

    Starting from ``df_list[0]``, each following frame is merged into the
    running result using every column name the two have in common as join
    keys. A frame that shares no column name with the running result is
    skipped rather than cross-joined.

    Parameters
    ----------
    df_list : sequence of pandas.DataFrame
        Frames to join, in order. Must contain at least one frame.
    how : str, default 'inner'
        Join type forwarded unchanged to ``DataFrame.merge``.

    Returns
    -------
    pandas.DataFrame
        The merged frame; ``df_list[0]`` itself when no merge took place.
    """
    result = df_list[0]
    for df in df_list[1:]:
        # Take the keys in the running result's column order. Building the key
        # list from a bare set intersection gives an order that can differ from
        # run to run (string hashing is randomised), and key order can influence
        # row ordering for join types that sort by key.
        other_columns = set(df.columns)
        shared_columns = list(
            dict.fromkeys(col for col in result.columns if col in other_columns)
        )
        if not shared_columns:
            # No common column name: nothing to join on, leave the result untouched.
            continue
        result = result.merge(df, on=shared_columns, how=how)
    return result

In [13]:
# Join the real routes and trips tables on the column names they share.
df_list = [real_data_collection[table] for table in ('routes', 'trips')]
real_routes_trips = auto_join(df_list)

In [14]:
# Join the synthetic routes and trips tables on the column names they share.
df_list = [synthetic_data_collection[table] for table in ('routes', 'trips')]
fake_routes_trips = auto_join(df_list)

In [15]:
# Start from metadata auto-detected on the joined real table; selected column
# sdtypes are overridden explicitly in the next cell.
routes_trips_meta = SingleTableMetadata()
routes_trips_meta.detect_from_dataframe(data=real_routes_trips)

In [16]:
# Explicit sdtype overrides for the joined routes/trips table, applied in this order.
# NOTE(review): the GTFS reference defines wheelchair_accessible and bikes_allowed
# as 0/1/2 enums; 'boolean' presumably relies on this feed using at most two
# values -- confirm against the data.
_SDTYPE_OVERRIDES = (
    ('route_id', 'id'),
    ('agency_id', 'id'),
    ('route_type', 'categorical'),
    ('contract_id', 'id'),
    ('service_id', 'id'),
    ('trip_id', 'id'),
    ('direction_id', 'categorical'),
    ('shape_id', 'id'),
    ('wheelchair_accessible', 'boolean'),
    ('bikes_allowed', 'boolean'),
)
for _column, _sdtype in _SDTYPE_OVERRIDES:
    routes_trips_meta.update_column(column_name=_column, sdtype=_sdtype)

## Check Join Size Difference Between Real and Fake Data

In [17]:
def size_difference_in_percentage(df1, df2):
    """Return a symmetric relative size difference between two sized objects.

    With ``lo`` and ``hi`` the smaller and larger of ``len(df1)`` and
    ``len(df2)``, the result is the mean of the gap measured against each
    length: ``((hi - lo) / hi + (hi - lo) / lo) / 2``.

    Despite the function name, the value is a plain ratio and is NOT
    multiplied by 100. ``0.0`` means both inputs have the same length.

    Parameters
    ----------
    df1, df2 : objects supporting ``len()``
        Typically the two DataFrames being compared; argument order does not
        matter.

    Returns
    -------
    float
        ``0.0`` for equal lengths (including two empty inputs) and
        ``float('inf')`` when exactly one input is empty, instead of raising
        ``ZeroDivisionError``.
    """
    smaller, larger = sorted((len(df1), len(df2)))
    gap = larger - smaller
    if gap == 0:
        # Covers the both-empty case, which previously divided by zero.
        return 0.0
    if smaller == 0:
        # The gap relative to an empty input is unbounded.
        return float('inf')
    return ((gap / larger) + (gap / smaller)) / 2

In [18]:
# Relative row-count gap between the synthetic and real joins; 0.0 means equal sizes.
size_difference_in_percentage(df1=fake_routes_trips, df2=real_routes_trips)

0.0

## Perform Metrics

In [19]:
# Single-table SDV quality report on the joined routes/trips tables, using the
# metadata built above for the joined real table.
st_quality_report = st_evaluate_quality(
    real_data=real_routes_trips,
    synthetic_data=fake_routes_trips,
    metadata=routes_trips_meta
)

Creating report: 100%|███████████████████████████████████████████████████████████████████| 4/4 [00:02<00:00,  1.59it/s]


Overall Quality Score: 55.82%

Properties:
Column Shapes: 64.8%
Column Pair Trends: 46.84%





In [20]:
# Per-column 'Column Shapes' scores for the joined routes/trips table.
st_quality_report.get_details(property_name='Column Shapes')

Unnamed: 0,Column,Metric,Quality Score
0,route_short_name,TVComplement,0.383761
1,route_long_name,TVComplement,0.379903
2,route_type,TVComplement,0.057742
3,route_color,TVComplement,0.506735
4,route_text_color,TVComplement,0.745303
5,trip_headsign,TVComplement,0.759714
6,direction_id,TVComplement,0.998519
7,wheelchair_accessible,TVComplement,1.0
8,bikes_allowed,TVComplement,1.0


In [21]:
# Per-column-pair 'Column Pair Trends' scores for the joined routes/trips table.
st_quality_report.get_details(property_name='Column Pair Trends')

Unnamed: 0,Column 1,Column 2,Metric,Quality Score,Real Correlation,Synthetic Correlation
0,route_long_name,route_short_name,ContingencySimilarity,0.375336,,
1,route_short_name,route_type,ContingencySimilarity,0.0,,
2,route_color,route_short_name,ContingencySimilarity,0.339876,,
3,route_short_name,route_text_color,ContingencySimilarity,0.352714,,
4,route_short_name,trip_headsign,ContingencySimilarity,0.0,,
5,direction_id,route_short_name,ContingencySimilarity,0.36867,,
6,route_short_name,wheelchair_accessible,ContingencySimilarity,0.383761,,
7,bikes_allowed,route_short_name,ContingencySimilarity,0.383761,,
8,route_long_name,route_type,ContingencySimilarity,0.0,,
9,route_color,route_long_name,ContingencySimilarity,0.344351,,
