**Fitting the HMA synthesizer takes a long time.**

**Try loading the stored, already-trained model (`models/gtfs.pkl`) instead.**

In [1]:
import os
import time
import math
import random
import pickle
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
from sklearn import metrics
from scipy.stats import ks_2samp
from sdv.metadata import MultiTableMetadata
from sdv.multi_table import HMASynthesizer
from sdv.evaluation.multi_table import evaluate_quality

# Load Processed Data From Generation Stage

In [2]:
# Unpickle the real-data collection; it is indexed by table name further down.
# NOTE: pickle executes arbitrary code on load - only open files you trust.
with open('pkl/gtfs/real_data_collection.pkl', mode='rb') as real_file:
    real_data_collection = pickle.load(real_file)

In [3]:
# Unpickle the previously generated synthetic collection; it is scored against
# the real data in the last cell of this notebook.
with open('pkl/gtfs/synthetic_data_full_epoch.pkl', mode='rb') as synthetic_file:
    synthetic_data_collection = pickle.load(synthetic_file)

In [4]:
# Unpickle the stored SDV metadata.
# NOTE(review): `sdv_metadata` is rebound to a fresh MultiTableMetadata() in
# the Benchmark section, so this loaded object is replaced before any use.
with open('pkl/gtfs/sdv_metadata.pkl', mode='rb') as metadata_file:
    sdv_metadata = pickle.load(metadata_file)

# Benchmark

In [5]:
# Start from an empty MultiTableMetadata. This rebinds `sdv_metadata`, so the
# object unpickled from 'pkl/gtfs/sdv_metadata.pkl' earlier is no longer
# reachable through this name; the tables are re-described cell by cell below.
sdv_metadata = MultiTableMetadata()

In [6]:
# Describe the `agency` table: auto-detect its columns from the real dataframe,
# then pin the sdtype of each column explicitly.
sdv_metadata.detect_table_from_dataframe(table_name='agency', data=real_data_collection['agency'])

# Raw string: `\d` is a regex token, not a Python string escape. The pattern
# matches 8 characters drawn from lowercase letters and digits.
sdv_metadata.update_column(table_name='agency', column_name='agency_id', sdtype='id', regex_format=r'[a-z\d]{8}')

# Remaining descriptive columns are all treated as categorical.
for column_name in ('agency_name', 'agency_url', 'agency_timezone', 'agency_lang', 'agency_phone'):
    sdv_metadata.update_column(table_name='agency', column_name=column_name, sdtype='categorical')

sdv_metadata.set_primary_key(table_name='agency', column_name='agency_id')

In [7]:
# Describe the `calendar` table: auto-detect its columns from the real
# dataframe, then pin the sdtype of each column explicitly.
sdv_metadata.detect_table_from_dataframe(table_name='calendar', data=real_data_collection['calendar'])

# Raw string: `\d` is a regex token, not a Python string escape. The pattern
# matches 8 characters drawn from lowercase letters and digits.
sdv_metadata.update_column(table_name='calendar', column_name='service_id', sdtype='id', regex_format=r'[a-z\d]{8}')

# One flag column per weekday, each modelled as categorical.
for column_name in ('monday', 'tuesday', 'wednesday', 'thursday', 'friday', 'saturday', 'sunday'):
    sdv_metadata.update_column(table_name='calendar', column_name=column_name, sdtype='categorical')

# Both date bounds use the compact YYYYMMDD format.
for column_name in ('start_date', 'end_date'):
    sdv_metadata.update_column(table_name='calendar', column_name=column_name, sdtype='datetime', datetime_format='%Y%m%d')

sdv_metadata.set_primary_key(table_name='calendar', column_name='service_id')

In [8]:
# Describe the `calendar_dates` table: auto-detect its columns from the real
# dataframe, then pin the sdtype of each column explicitly. No primary key is
# set for this table in this cell.
sdv_metadata.detect_table_from_dataframe(table_name='calendar_dates', data=real_data_collection['calendar_dates'])

# Raw string: `\d` is a regex token, not a Python string escape. The pattern
# matches 8 characters drawn from lowercase letters and digits.
sdv_metadata.update_column(table_name='calendar_dates', column_name='service_id', sdtype='id', regex_format=r'[a-z\d]{8}')
sdv_metadata.update_column(table_name='calendar_dates', column_name='date', sdtype='datetime', datetime_format='%Y%m%d')
sdv_metadata.update_column(table_name='calendar_dates', column_name='exception_type', sdtype='categorical')

In [9]:
# Describe the `routes` table: auto-detect its columns from the real dataframe,
# then pin the sdtype of each column explicitly.
sdv_metadata.detect_table_from_dataframe(table_name='routes', data=real_data_collection['routes'])

# Raw strings: `\d` is a regex token, not a Python string escape. The pattern
# matches 8 characters drawn from lowercase letters and digits.
for column_name in ('route_id', 'agency_id'):
    sdv_metadata.update_column(table_name='routes', column_name=column_name, sdtype='id', regex_format=r'[a-z\d]{8}')

# Descriptive route attributes are all treated as categorical.
for column_name in ('route_short_name', 'route_long_name', 'route_type', 'route_color', 'route_text_color'):
    sdv_metadata.update_column(table_name='routes', column_name=column_name, sdtype='categorical')

sdv_metadata.update_column(table_name='routes', column_name='contract_id', sdtype='id', regex_format=r'[a-z\d]{8}')
sdv_metadata.set_primary_key(table_name='routes', column_name='route_id')

In [10]:
# Describe the `stops` table: auto-detect its columns from the real dataframe,
# then pin the sdtype of each column explicitly.
sdv_metadata.detect_table_from_dataframe(table_name='stops', data=real_data_collection['stops'])

# Raw string: `\d` is a regex token, not a Python string escape. The pattern
# matches 8 characters drawn from lowercase letters and digits.
sdv_metadata.update_column(table_name='stops', column_name='stop_id', sdtype='id', regex_format=r'[a-z\d]{8}')
sdv_metadata.update_column(table_name='stops', column_name='stop_code', sdtype='numerical', computer_representation='Int32')
sdv_metadata.update_column(table_name='stops', column_name='stop_name', sdtype='categorical')

# Coordinates are floating-point numerical columns.
for column_name in ('stop_lat', 'stop_lon'):
    sdv_metadata.update_column(table_name='stops', column_name=column_name, sdtype='numerical', computer_representation='Float')

for column_name in ('location_type', 'parent_station', 'platform_code', 'wheelchair_boarding'):
    sdv_metadata.update_column(table_name='stops', column_name=column_name, sdtype='categorical')

# Both date bounds use the compact YYYYMMDD format.
for column_name in ('start_date', 'end_date'):
    sdv_metadata.update_column(table_name='stops', column_name=column_name, sdtype='datetime', datetime_format='%Y%m%d')

sdv_metadata.set_primary_key(table_name='stops', column_name='stop_id')

In [11]:
# Describe the `stop_times` table: auto-detect its columns from the real
# dataframe, then pin the sdtype of each column explicitly. No primary key is
# set for this table in this cell.
sdv_metadata.detect_table_from_dataframe(table_name='stop_times', data=real_data_collection['stop_times'])

# Raw strings: `\d` is a regex token, not a Python string escape. The pattern
# matches 8 characters drawn from lowercase letters and digits.
sdv_metadata.update_column(table_name='stop_times', column_name='trip_id', sdtype='id', regex_format=r'[a-z\d]{8}')

# Arrival and departure are time-of-day values in HH:MM:SS form.
for column_name in ('arrival_time', 'departure_time'):
    sdv_metadata.update_column(table_name='stop_times', column_name=column_name, sdtype='datetime', datetime_format='%H:%M:%S')

sdv_metadata.update_column(table_name='stop_times', column_name='stop_id', sdtype='id', regex_format=r'[a-z\d]{8}')
sdv_metadata.update_column(table_name='stop_times', column_name='stop_sequence', sdtype='numerical', computer_representation='Int32')

for column_name in ('stop_headsign', 'pickup_type', 'drop_off_type'):
    sdv_metadata.update_column(table_name='stop_times', column_name=column_name, sdtype='categorical')

sdv_metadata.update_column(table_name='stop_times', column_name='shape_dist_traveled', sdtype='numerical', computer_representation='Float')
sdv_metadata.update_column(table_name='stop_times', column_name='timepoint', sdtype='numerical', computer_representation='Int32')

In [12]:
# Describe the `trips` table: auto-detect its columns from the real dataframe,
# then pin the sdtype of each column explicitly.
sdv_metadata.detect_table_from_dataframe(table_name='trips', data=real_data_collection['trips'])

# Raw strings: `\d` is a regex token, not a Python string escape. The pattern
# matches 8 characters drawn from lowercase letters and digits.
for column_name in ('route_id', 'service_id', 'trip_id'):
    sdv_metadata.update_column(table_name='trips', column_name=column_name, sdtype='id', regex_format=r'[a-z\d]{8}')

sdv_metadata.update_column(table_name='trips', column_name='trip_headsign', sdtype='categorical')

# NOTE(review): the GTFS reference defines `direction_id` as a 0/1 enum, yet it
# is declared here as a regex-backed id like `shape_id`. Kept as-is because the
# processed data may encode it differently - confirm against the dataframe.
for column_name in ('direction_id', 'shape_id'):
    sdv_metadata.update_column(table_name='trips', column_name=column_name, sdtype='id', regex_format=r'[a-z\d]{8}')

for column_name in ('wheelchair_accessible', 'bikes_allowed'):
    sdv_metadata.update_column(table_name='trips', column_name=column_name, sdtype='categorical')

sdv_metadata.set_primary_key(table_name='trips', column_name='trip_id')

In [13]:
# Parent/child links between the tables, registered in the order listed.
# Each entry: (parent table, child table, parent primary key, child foreign key).
table_links = (
    ('agency', 'routes', 'agency_id', 'agency_id'),
    ('calendar', 'calendar_dates', 'service_id', 'service_id'),
    ('calendar', 'trips', 'service_id', 'service_id'),
    ('routes', 'trips', 'route_id', 'route_id'),
    ('stops', 'stop_times', 'stop_id', 'stop_id'),
    ('trips', 'stop_times', 'trip_id', 'trip_id'),
)

for parent_table, child_table, parent_key, child_key in table_links:
    sdv_metadata.add_relationship(
        parent_table_name=parent_table,
        child_table_name=child_table,
        parent_primary_key=parent_key,
        child_foreign_key=child_key,
    )

## HMA

In [14]:
# Instantiate the HMA synthesizer with the multi-table metadata assembled
# above; fitting happens in the next cell.
synthesizer = HMASynthesizer(sdv_metadata)

In [None]:
%%time
%%capture

synthesizer.fit(real_data_collection)
synthesizer.save(
    filepath='models/gtfs.pkl'
)

In [None]:
%%time
%%capture

synthesizer = HMASynthesizer.load(
    filepath='models/gtfs.pkl'
)

hma_synthetic_data_collection = synthesizer.sample(
    scale=1
)

In [None]:
# Score the HMA-sampled collection against the real data.
hma_quality_report = evaluate_quality(
    metadata=sdv_metadata,
    real_data=real_data_collection,
    synthetic_data=hma_synthetic_data_collection,
)

In [None]:
# Score the synthetic collection loaded from
# 'pkl/gtfs/synthetic_data_full_epoch.pkl' against the real data.
prototype_quality_report = evaluate_quality(
    metadata=sdv_metadata,
    real_data=real_data_collection,
    synthetic_data=synthetic_data_collection,
)