In [1]:
%matplotlib inline
%load_ext autoreload
%autoreload 2

In [2]:
from pathlib import Path

Path().absolute()

WindowsPath('c:/Users/hecto/Codebase/artifact-ml/artifact-torch/demo')

In [3]:
import pandas as pd
from artifact_core.libs.resource_spec.tabular.spec import TabularDataSpec
from artifact_experiment.libs.tracking.filesystem.client import FilesystemTrackingClient
from artifact_torch.base.data.data_loader import DataLoader

from demo.config.constants import (
    BATCH_SIZE,
    BN_EPSILON,
    BN_MOMENTUM,
    DROP_LAST,
    DROPOUT_RATE,
    LEAKY_RELU_SLOPE,
    LOSS_BETA,
    LS_ENCODER_LAYER_SIZES,
    SHUFFLE,
)
from demo.data.dataset import TabularVAEDataset
from demo.data.feature_flattener import FeatureFlattener
from demo.model.synthesizer import TabularVAESynthesizer, TabularVAESynthesizerConfig
from demo.trainer.trainer import TabularVAETrainer
from demo.trainer.validation_routine import TabularVAEValidationRoutine


In [4]:
artifact_experiment_root = Path().absolute().parent

df_real = pd.read_csv(artifact_experiment_root / "assets/real.csv")
df_real

Unnamed: 0,Age,Sex,ChestPainType,RestingBP,Cholesterol,FastingBS,RestingECG,MaxHR,ExerciseAngina,Oldpeak,ST_Slope,HeartDisease
0,40,M,ATA,140,289,0,Normal,172,N,0.0,Up,0
1,49,F,NAP,160,180,0,Normal,156,N,1.0,Flat,1
2,37,M,ATA,130,283,0,ST,98,N,0.0,Up,0
3,48,F,ASY,138,214,0,Normal,108,Y,1.5,Flat,1
4,54,M,NAP,150,195,0,Normal,122,N,0.0,Up,0
...,...,...,...,...,...,...,...,...,...,...,...,...
913,45,M,TA,110,264,0,Normal,132,N,1.2,Flat,1
914,68,M,ASY,144,193,1,Normal,141,N,3.4,Flat,1
915,57,M,ASY,130,131,0,Normal,115,Y,1.2,Flat,1
916,57,F,ATA,130,236,0,LVH,174,N,0.0,Flat,1


In [5]:
ls_cts_features = ["Age", "RestingBP", "Cholesterol", "MaxHR", "Oldpeak"]

resource_spec = TabularDataSpec.from_df(
    df=df_real,
    ls_cts_features=ls_cts_features,
    ls_cat_features=[feature for feature in df_real.columns if feature not in ls_cts_features],
)

In [6]:
flattener = FeatureFlattener(data_spec=resource_spec)

flattener.fit(df=df_real)

array([[0.24489796, 0.7       , 0.47927032, ..., 0.        , 1.        ,
        1.        ],
       [0.42857143, 0.8       , 0.29850746, ..., 1.        , 0.        ,
        1.        ],
       [0.18367347, 0.65      , 0.46932007, ..., 0.        , 1.        ,
        1.        ],
       ...,
       [0.59183673, 0.65      , 0.2172471 , ..., 1.        , 0.        ,
        1.        ],
       [0.59183673, 0.65      , 0.39137645, ..., 1.        , 0.        ,
        1.        ],
       [0.20408163, 0.69      , 0.29021559, ..., 0.        , 1.        ,
        1.        ]], shape=(918, 21))

In [7]:
architecture_config = TabularVAESynthesizerConfig(
    ls_encoder_layer_sizes=[len(flattener.ls_flattened_column_names)] + LS_ENCODER_LAYER_SIZES,
    loss_beta=LOSS_BETA,
    leaky_relu_slope=LEAKY_RELU_SLOPE,
    bn_momentum=BN_MOMENTUM,
    bn_epsilon=BN_EPSILON,
    dropout_rate=DROPOUT_RATE,
)

model = TabularVAESynthesizer.build(config=architecture_config, flattener=flattener)

In [8]:
filesystem_tracker = FilesystemTrackingClient.build(experiment_id="demo")
dataset = TabularVAEDataset(df=df_real, flattener=flattener)
loader = DataLoader(dataset=dataset, batch_size=BATCH_SIZE, drop_last=DROP_LAST, shuffle=SHUFFLE)
validation_routine = TabularVAEValidationRoutine.build(
    df_real=df_real,
    tabular_data_spec=resource_spec,
    train_loader=loader,
    tracking_client=filesystem_tracker,
)
trainer = TabularVAETrainer.build(
    model=model, train_loader=loader, validation_routine=validation_routine
)

In [9]:
trainer.train()

Training on device: cpu


                                                                         

In [10]:
trainer.batch_cache.get_full_history("batch_loss")

[94.07298278808594,
 94.74506378173828,
 98.06695556640625,
 102.92822265625,
 99.40510559082031,
 95.68535614013672,
 101.68133544921875,
 107.3340835571289,
 96.94939422607422,
 110.16706848144531,
 107.57196044921875,
 105.56519317626953,
 117.22380065917969,
 106.9201431274414,
 111.93673706054688,
 100.39976501464844,
 115.59913635253906,
 92.35005187988281,
 104.1298828125,
 100.13412475585938,
 106.30201721191406,
 95.79492950439453,
 104.81558227539062,
 89.10783386230469,
 85.562744140625,
 86.21676635742188,
 95.08621215820312,
 76.32865142822266,
 90.26446533203125,
 83.37684631347656,
 77.298095703125,
 85.17749786376953,
 83.98896789550781,
 76.52367401123047,
 77.31805419921875,
 82.39067840576172,
 76.75288391113281,
 76.07357788085938,
 77.14584350585938,
 71.98766326904297,
 76.69206237792969,
 67.77980041503906,
 67.58712768554688,
 65.34146881103516,
 72.75856018066406,
 61.74360656738281,
 61.40653991699219,
 66.05590057373047,
 61.099876403808594,
 59.7756919860839

In [11]:
trainer.epoch_scores

Unnamed: 0,train_loss,MEAN_JS_DISTANCE,PAIRWISE_CORRELATION_DISTANCE
0,24.540452,,
1,21.94133,,
2,21.787362,,
3,21.667128,,
4,21.711225,0.541279,0.260978
5,21.742171,,
6,21.783869,,
7,21.768117,,
8,21.945376,,
9,22.00788,0.565137,0.281583
