In [None]:
# Render matplotlib figures inline, directly below the cell that produced them.
%matplotlib inline
# Automatically reload edited imported modules before each cell executes,
# so changes to the project code are picked up without restarting the kernel.
%load_ext autoreload
%autoreload 2

In [None]:
from pathlib import Path

import pandas as pd
import seaborn as sns
from artifact_core.libs.resource_spec.tabular.spec import TabularDataSpec
from artifact_experiment.libs.tracking.filesystem.client import FilesystemTrackingClient
from demos.table_comparison.config.constants import (
    LS_CAT_FEATURES,
    LS_CTS_FEATURES,
    TRAINING_DATASET_PATH,
)
from demos.table_comparison.tabular_vae import TabularVAE
from matplotlib import pyplot as plt

In [None]:
# Global plot look for the notebook: grid on a white background, colour-blind-safe palette.
sns.set_theme(palette="colorblind", style="whitegrid")

In [None]:
# Resolve the repository root two directories above the kernel's working directory.
# NOTE(review): this presumably assumes the notebook is run from its own folder
# (two levels below the artifact-experiment root) — confirm if the layout changes.
artifact_experiment_root = Path.cwd().parent.parent
training_dataset_path = artifact_experiment_root / TRAINING_DATASET_PATH

# Fail early with an actionable message instead of a bare read_csv error when the
# working directory is not the one the relative path resolution expects.
if not training_dataset_path.is_file():
    raise FileNotFoundError(
        f"Training dataset not found at {training_dataset_path}. "
        f"The path is resolved two directories above the current working directory "
        f"({Path.cwd()}); run this notebook from its own folder."
    )

df_real = pd.read_csv(training_dataset_path)

# Preview only the first rows rather than dumping the whole frame into the output.
df_real.head()

In [None]:
# Describe the real table's schema — which columns are continuous and which are
# categorical, per the demo's configured feature lists — for the VAE to consume.
data_spec = TabularDataSpec.from_df(
    ls_cat_features=LS_CAT_FEATURES,
    ls_cts_features=LS_CTS_FEATURES,
    df=df_real,
)

In [None]:
# Identifier of the experiment this run is tracked under; the resulting client is
# handed to `model.fit` as `tracking_client`. Change here to keep runs separate.
EXPERIMENT_ID = "demo"
filesystem_tracker = FilesystemTrackingClient.build(experiment_id=EXPERIMENT_ID)

In [None]:
# Construct the tabular VAE from the schema derived from the real data.
model = TabularVAE.build(
    data_spec=data_spec,
)

In [None]:
epoch_scores = model.fit(df=df_real, data_spec=data_spec, tracking_client=filesystem_tracker)

In [None]:
# Training-loss curve from the scores returned by `model.fit`.
# NOTE(review): assumes `epoch_scores` is indexed by epoch and has a "TRAIN_LOSS"
# column — confirm against TabularVAE.fit.
fig, ax = plt.subplots(figsize=(10, 6))
epoch_scores["TRAIN_LOSS"].plot(
    ax=ax,
    color=sns.color_palette("colorblind")[2],  # third colour of the colour-blind palette
    linewidth=2,
)
ax.set_title("Train Loss", fontsize=16, fontweight="bold")
ax.set_xlabel("Epoch", fontsize=14)
ax.set_ylabel("Training Loss", fontsize=14)
ax.grid(True, linestyle="--", alpha=0.6)
# tick_params applies to all tick labels, including any created after this call,
# unlike plt.xticks/yticks(fontsize=...) which only restyle the current labels.
ax.tick_params(axis="both", labelsize=12)
fig.tight_layout()
plt.show()