# Batch notebook for training models

In [None]:
import copy
import os
from datetime import datetime

import tensorflow as tf
from keras.utils import to_categorical

import src.CLVAE as CLVAE

# Define Model and Hyperparameters

In [2]:
# Per-dataset presets, keyed by dataset name. Each preset records the dataset
# name again under "dataset" plus the network's latent size and its list of
# intermediate layer specs.
# NOTE(review): each intermediate_dims entry is presumably
# [filters, kernel_size, stride] for one conv layer — confirm in src/CLVAE.
config_options = {
    name: {
        "dataset": name,
        "network": {
            "latent_dim": latent_dim,
            "intermediate_dims": dims,
        },
    }
    for name, latent_dim, dims in (
        ("mnist", 2, [[32, 3, 2], [64, 3, 2]]),
        ("fashion_mnist", 3, [[16, 3, 1], [32, 3, 2], [64, 3, 2]]),
        ("cifar10", 3, [[64, 3, 2], [128, 3, 2], [256, 3, 2]]),
    )
}

In [3]:
config_edit = config_options['mnist']
config_edit['loss'] = { "alpha": 1/6 }

print("Loading Data")
x_train, y_train, x_test, y_test, x_anom, y_anom, config = CLVAE.load_data(config_edit)
print("Trainin set", x_train.shape, y_train.shape)
print("Inlier test set", x_test.shape, y_test.shape)
print("Outlier test set", x_anom.shape, y_anom.shape)
print(config)

anomalous_digit = config['anomalous_digit']


Loading Data
Trainin set (54077, 28, 28, 1) (54077,)
Inlier test set (9020, 28, 28, 1) (9020,)
Outlier test set (6903, 28, 28, 1) (6903,)
{'dataset': 'mnist', 'anomalous_digit': 0, 'num_classes': 9, 'loss': {'alpha': 0.16666666666666666, 'type': 'normal'}, 'network': {'optimizer': 'adam', 'img_dim': 28, 'color_dim': 1, 'original_dim': 784, 'intermediate_dims': [[32, 3, 2], [64, 3, 2]], 'latent_dim': 2}, 'gaussian_config': {'mode': 'fixed', 'num_classes': 9, 'random_order': True, 'r': 3}, 'run_info': {'batch_size': 100, 'epochs': 50}}


# Batch Train Models

In [None]:
# Stop when validation loss has not improved for 3 epochs and roll the model
# back to the weights from its best validation epoch.
callbacks = [tf.keras.callbacks.EarlyStopping(
    monitor='val_loss',
    patience=3,
    restore_best_weights=True
)]

save_models = True
model_name = "clvae"  # one of "clvae", "vae", "cnn"

runs = 1
# Training hyperparameters (loop-invariant, so defined once).
# NOTE(review): config['run_info'] also carries batch_size/epochs (printed in
# the previous cell) but is not used here — confirm which values are intended.
epochs = 50
batch_size = 128

# One-hot encode the labels once, with a shared width. to_categorical would
# otherwise infer the width separately from each array's max label, so the
# training and validation label tensors could end up with different shapes.
num_classes = int(max(y_train.max(), y_test.max())) + 1
y_train_onehot = to_categorical(y_train, num_classes=num_classes)
y_test_onehot = to_categorical(y_test, num_classes=num_classes)

# Name of the output folder for this dataset; 3-D latent models get a suffix.
if config['network']['latent_dim'] == 3:
    data_type = config['dataset'] + "_dim3"
else:
    data_type = config['dataset']

for r in range(runs):
    # Build a fresh model for every run and fit it; `history` holds the
    # Keras History returned by fit for every model type.
    if model_name == "clvae":
        model, encoder, decoder = CLVAE.build_clvae(config)
        history = model.fit([x_train, y_train_onehot], epochs=epochs, batch_size=batch_size, validation_data=([x_test, y_test_onehot], None), callbacks=callbacks)
    elif model_name == "vae":
        model, encoder, decoder = CLVAE.build_vae(config)
        history = model.fit(x_train, epochs=epochs, batch_size=batch_size, validation_data=(x_test, None), callbacks=callbacks)
    elif model_name == "cnn":
        model = CLVAE.build_cnn(config)
        history = model.fit(x_train, y_train, epochs=epochs, batch_size=batch_size, validation_data=(x_test, y_test), callbacks=callbacks)
    else:
        # Fail loudly instead of saving a stale `model` left over from an earlier cell.
        raise ValueError(f"Unknown model_name {model_name!r}; expected 'clvae', 'vae' or 'cnn'")

    if save_models:
        # One timestamped folder per run: models/<data_type>/<model_name>/<date>
        date = datetime.now().strftime("%Y-%m-%d %H-%M-%S")
        model_folder = os.path.join("models", data_type, model_name, date)
        os.makedirs(model_folder, exist_ok=True)
        print("saving to " + model_folder)

        model.save_weights(os.path.join(model_folder, "model"))
        # Only the autoencoder variants have separate encoder/decoder submodels.
        if model_name != "cnn":
            encoder.save_weights(os.path.join(model_folder, "encoder"))
            decoder.save_weights(os.path.join(model_folder, "decoder"))