In [None]:
# Notebook setup: draw matplotlib figures inline in the cell output, and load
# the autoreload extension in mode 2 so edited project modules are re-imported
# before each cell runs.
%matplotlib inline
%load_ext autoreload
%autoreload 2

In [None]:
import os
from pathlib import Path

# Resolve the repository root only once per kernel session. Deriving it from
# the current directory on every run makes this cell non-idempotent: after the
# first os.chdir a re-run would climb two more levels and silently break the
# root-relative paths built from artifact_torch_root later in the notebook.
if "artifact_torch_root" not in globals():
    # On a fresh kernel the working directory is assumed to sit two levels
    # below the root -- TODO confirm if the kernel is launched from elsewhere.
    artifact_torch_root = Path().absolute().parent.parent

# Work from the root for the remainder of the notebook.
os.chdir(artifact_torch_root)

# Show the active working directory as the cell output.
Path().absolute()

In [None]:
import pandas as pd
import seaborn as sns
from artifact_core.table_comparison import TabularDataSpec
from artifact_experiment.tracking import DataSplit, FilesystemTrackingClient
from matplotlib import pyplot as plt

from demos.table_comparison.config.constants import (
    EXPERIMENT_ID,
    LS_CAT_FEATURES,
    LS_CTS_FEATURES,
    N_BINS_CTS,
    TRAINING_DATASET_PATH,
)
from demos.table_comparison.data.utils import DemoDataUtils
from demos.table_comparison.experiment.experiment import DemoTabularSynthesisExperiment
from demos.table_comparison.libs.transformers.discretizer import Discretizer
from demos.table_comparison.libs.transformers.encoder import Encoder
from demos.table_comparison.model.synthesizer import TabularVAESynthesizer

In [None]:
# Apply seaborn's "whitegrid" style and "colorblind" palette to the plots
# drawn later in this notebook.
sns.set_theme(style="whitegrid", palette="colorblind")

In [None]:
# Read the training CSV from its location under the repository root.
training_csv_path = artifact_torch_root / TRAINING_DATASET_PATH
df_real = pd.read_csv(training_csv_path)

# Describe the raw table using the feature lists from the demo constants.
raw_data_spec = TabularDataSpec.from_df(
    df=df_real,
    cts_features=LS_CTS_FEATURES,
    cat_features=LS_CAT_FEATURES,
)

# Show the loaded table as the cell output.
df_real

In [None]:
# Build the project's Discretizer with N_BINS_CTS bins over the `cts_features`
# listed in raw_data_spec.
# NOTE(review): presumably this bins each of those columns into categories -- confirm in Discretizer.
discretizer = Discretizer(n_bins=N_BINS_CTS, ls_cts_features=raw_data_spec.cts_features)

# Fit on the real table first; the transform below is applied to that same table.
discretizer.fit(df=df_real)

df_discretized = discretizer.transform(df=df_real)

# Show the transformed table as the cell output.
df_discretized

In [None]:
# Fit the project's Encoder on the discretized table, declaring every one of
# its columns as a categorical feature, then encode that same table.
encoder = Encoder()
discretized_columns = list(df_discretized.columns)
encoder.fit(df=df_discretized, ls_cat_features=discretized_columns)
df_encoded = encoder.transform(df=df_discretized)

# Describe the encoded table, again passing all of its columns as categorical.
encoded_columns = list(df_encoded.columns)
encoded_data_spec = TabularDataSpec.from_df(df=df_encoded, cat_features=encoded_columns)

# Show the encoded table as the cell output.
df_encoded

In [None]:
# Training-split loader, built from the raw table plus the fitted transformers.
train_loader = DemoDataUtils.build_data_loader(
    df=df_real,
    discretizer=discretizer,
    encoder=encoder,
)

# Loaders are keyed by data split; only the TRAIN split is provided here.
data_loaders = {DataSplit.TRAIN: train_loader}

In [None]:
# Artifact-routine input for the TRAIN split, built from the real table.
train_routine_data = DemoDataUtils.build_artifact_routine_data(df_real=df_real)

# Keyed by data split, mirroring the loaders; only TRAIN is provided here.
artifact_routine_data = {DataSplit.TRAIN: train_routine_data}

In [None]:
# Build the VAE synthesizer from the encoded-data spec together with the
# discretizer and encoder fitted in the earlier cells.
# NOTE(review): presumably the transformers are kept so generated samples can be
# mapped back to the raw schema -- confirm in TabularVAESynthesizer.
model = TabularVAESynthesizer.build(
    data_spec=encoded_data_spec, discretizer=discretizer, encoder=encoder
)

In [None]:
# Create the filesystem tracking client for this demo's EXPERIMENT_ID.
# NOTE(review): where results are written is decided inside the client -- not visible here.
tracking_client = FilesystemTrackingClient.build(experiment_id=EXPERIMENT_ID)

In [None]:
# Assemble the experiment from the model, the per-split loaders and artifact
# routine data, and the tracking client.
# Note that the artifact routine is given the RAW data spec, whereas the model
# was built from the encoded one.
experiment = DemoTabularSynthesisExperiment.build(
    model=model,
    data_loaders=data_loaders,
    artifact_routine_data=artifact_routine_data,
    artifact_routine_data_spec=raw_data_spec,
    tracking_client=tracking_client,
)

In [None]:
# Execute the experiment.
# NOTE(review): presumably this trains the model and records results through the
# tracking client -- confirm. No random seed is set in the cells visible here;
# verify whether the project seeds internally if reproducible runs are required.
experiment.run()

In [None]:
# Display the scores collected by the experiment; the plotting cell reads its
# "LOSS_TRAIN" entry.
experiment.epoch_scores

In [None]:
# Plot the recorded training loss on a fresh 10x6 figure.
plt.figure(figsize=(10, 6))

# Third colour of the colorblind palette; missing entries are dropped before plotting.
loss_color = sns.color_palette("colorblind")[2]
train_loss = experiment.epoch_scores["LOSS_TRAIN"].dropna()
train_loss.plot(color=loss_color, linewidth=2)

# Decorate the axes the series was just drawn on (the current axes).
ax = plt.gca()
ax.set_title("Train Loss", fontsize=16, fontweight="bold")
ax.set_xlabel("Epoch", fontsize=14)
ax.set_ylabel("Training Loss", fontsize=14)
ax.grid(True, linestyle="--", alpha=0.6)

# Tick label sizes, then tighten the layout; the final call returns None so the
# cell shows only the figure.
plt.xticks(fontsize=12)
plt.yticks(fontsize=12)
plt.tight_layout()