In [None]:
%matplotlib inline
%load_ext autoreload
%autoreload 2

In [None]:
from pathlib import Path

# Echo the kernel's current working directory as the cell output; the data path
# built further down is derived from this location.
Path.cwd()

In [None]:
import pandas as pd
from artifact_core.libs.resource_spec.tabular.spec import TabularDataSpec
from artifact_experiment.libs.tracking.filesystem.client import FilesystemTrackingClient
from artifact_torch.base.data.data_loader import DataLoader

from demo.config.constants import (
    BATCH_SIZE,
    BN_EPSILON,
    BN_MOMENTUM,
    DROP_LAST,
    DROPOUT_RATE,
    LEAKY_RELU_SLOPE,
    LOSS_BETA,
    LS_ENCODER_LAYER_SIZES,
    SHUFFLE,
)
from demo.data.dataset import TabularVAEDataset
from demo.data.feature_flattener import FeatureFlattener
from demo.model.synthesizer import TabularVAESynthesizer, TabularVAESynthesizerConfig
from demo.trainer.trainer import TabularVAETrainer
from demo.trainer.validation_routine import TabularVAEValidationRoutine


In [None]:
# The project root is taken to be the parent of the kernel's working directory,
# so this cell expects the notebook to be launched from a direct child folder.
artifact_experiment_root = Path.cwd().parent
real_csv_path = artifact_experiment_root / "assets/real.csv"

# Read the real table and leave it as the last expression so it is displayed.
df_real = pd.read_csv(real_csv_path)
df_real

In [None]:
# Columns handled as continuous; every other column of df_real is passed on as
# categorical.
ls_cts_features = ["Age", "RestingBP", "Cholesterol", "MaxHR", "Oldpeak"]
ls_cat_features = [column for column in df_real.columns if column not in ls_cts_features]

# Build the tabular spec from the real frame and the two feature lists.
resource_spec = TabularDataSpec.from_df(
    df=df_real,
    ls_cts_features=ls_cts_features,
    ls_cat_features=ls_cat_features,
)

In [None]:
# Create the feature flattener from the tabular spec built above.
flattener = FeatureFlattener(data_spec=resource_spec)

# Fit the flattener on the real data. This call is the cell's last expression,
# so whatever `fit` returns (if anything) is echoed as the cell output.
flattener.fit(df=df_real)

In [None]:
# Encoder layer sizes: the number of flattened column names reported by the
# flattener comes first, followed by the sizes configured in the demo constants.
n_flattened_columns = len(flattener.ls_flattened_column_names)
ls_encoder_layer_sizes = [n_flattened_columns] + LS_ENCODER_LAYER_SIZES

# Collect the remaining hyper-parameters from the demo constants.
architecture_config = TabularVAESynthesizerConfig(
    ls_encoder_layer_sizes=ls_encoder_layer_sizes,
    loss_beta=LOSS_BETA,
    leaky_relu_slope=LEAKY_RELU_SLOPE,
    bn_momentum=BN_MOMENTUM,
    bn_epsilon=BN_EPSILON,
    dropout_rate=DROPOUT_RATE,
)

# Build the synthesizer from the config and the flattener.
model = TabularVAESynthesizer.build(
    config=architecture_config,
    flattener=flattener,
)

In [None]:
# Tracking client shared by the validation routine and the trainer.
filesystem_tracker = FilesystemTrackingClient.build(experiment_id="demo")

# Training data: wrap the real frame in the demo dataset and batch it with the
# loader settings taken from the demo constants.
dataset = TabularVAEDataset(df=df_real, flattener=flattener)
loader = DataLoader(
    dataset=dataset,
    batch_size=BATCH_SIZE,
    drop_last=DROP_LAST,
    shuffle=SHUFFLE,
)

# Validation routine built from the real frame, its spec, the training loader
# and the tracker.
validation_routine = TabularVAEValidationRoutine.build(
    df_real=df_real,
    tabular_data_spec=resource_spec,
    train_loader=loader,
    tracking_client=filesystem_tracker,
)

# Trainer wiring together the model, the loader, the validation routine and the
# tracker.
trainer = TabularVAETrainer.build(
    model=model,
    train_loader=loader,
    validation_routine=validation_routine,
    tracking_client=filesystem_tracker,
)

In [None]:
# Run training with the trainer built in the previous cell. The call is the
# cell's last expression, so any non-None return value is displayed.
# NOTE(review): no random seed is set anywhere in this notebook — confirm
# whether the trainer or the demo config seeds the run.
trainer.train()

In [None]:
# Display the trainer's `epoch_scores` attribute as the cell output.
# Presumably the scores recorded during training — TODO confirm against
# TabularVAETrainer.
trainer.epoch_scores