References:
* Tazwar, S., Knobbout, M., Quesada, E., & Popa, M. (2024). Tab-VAE: A Novel VAE for Generating Synthetic Tabular Data. In *Proceedings of the 13th International Conference on Pattern Recognition Applications and Methods*, pp. 17–26. https://doi.org/10.5220/0012302400003654
* SDV (Synthetic Data Vault): https://github.com/sdv-dev/SDV

In [None]:
# Reload edited modules (e.g. random_neural_net_models) automatically before
# each cell executes, so library changes are picked up without a kernel restart.
%load_ext autoreload
%autoreload 2

In [None]:
import sklearn.datasets as sk_data
import sklearn.model_selection as sk_selection
import random_neural_net_models.tabular as rnnm_tab
import torch.optim as optim
import random_neural_net_models.learner as rnnm_learner
import random_neural_net_models.data as rnnm_data
import random_neural_net_models.utils as rnnm_utils
from pathlib import Path
import random_neural_net_models.tabular_vae as rnnm_tvae
from torch.utils.data import DataLoader

In [None]:
# Fix random seeds through the project helper (presumably Python/NumPy/torch
# RNGs — confirm in rnnm_utils) so notebook runs are reproducible.
seed = 42
rnnm_utils.make_deterministic(seed)

In [None]:
# Pick the compute device via the project helper and display it.
device = rnnm_utils.get_device()
device

In [None]:
# Toy dataset: 1 000 points in 3-D drawn around 2 blob centres; `y` holds the
# integer blob index of each point.
n_samples, n_features, n_classes = 1_000, 3, 2
X, y = sk_data.make_blobs(
    n_samples=n_samples,
    n_features=n_features,
    centers=n_classes,
    random_state=42,
)

## VAE on numerical columns only

In [None]:
# Hold out 20 % of the rows for validation. An explicit `random_state` keeps the
# split identical when this cell is re-run on its own. Without it, the result
# depends on how far the global NumPy RNG has already advanced.
X0, X1 = sk_selection.train_test_split(X, test_size=0.2, random_state=42)

In [None]:
# Sanity-check the sizes of the training and validation splits.
X0.shape, X1.shape

In [None]:
# Wrap the training and validation arrays in the project's NumPy dataset type.
ds_train, ds_valid = (rnnm_data.NumpyInferenceDataset(arr) for arr in (X0, X1))

In [None]:
# Inspect a single training example.
ds_train[2]

In [None]:
from torch.utils.data import RandomSampler
import torch

In [None]:
batch_size = 50

# Draw 100k training rows with replacement. Together with `drop_last=True`,
# this fixes the number of batches per "epoch" regardless of the dataset size.
# A dedicated generator is used instead of `torch.manual_seed(42)`. That call
# re-seeds *and returns* torch's global default generator, so the sampler
# would share (and advance) the RNG used by everything else, e.g. weight init.
sampler = RandomSampler(
    ds_train,
    replacement=True,
    num_samples=int(1e5),
    generator=torch.Generator().manual_seed(42),
)

dl_train = DataLoader(
    ds_train,
    batch_size=batch_size,
    sampler=sampler,
    collate_fn=rnnm_data.collate_numpy_dataset_to_xblock,
    drop_last=True,
)
# Validation loader: no sampler and no shuffling, so rows are visited once in order.
dl_valid = DataLoader(
    ds_valid,
    batch_size=batch_size,
    collate_fn=rnnm_data.collate_numpy_dataset_to_xblock,
)

In [None]:
# Peek at one collated training batch to check the collate function's output.
next(iter(dl_train))

In [None]:
# Per-feature (column-wise) mean of the training split; passed to the model below.
means = X0.mean(axis=0)
means

In [None]:
# Per-feature (column-wise) standard deviation of the training split
# (NumPy default ddof=0); passed to the model below.
stds = X0.std(axis=0)
stds

In [None]:
n_features = X0.shape[1]
# Network widths; the first entry equals the number of input features
# (presumably encoder layer sizes — confirm in rnnm_tvae).
n_hidden = [n_features, 3, 3]
do_impute = False
# NOTE(review): `impute_bias_source` is not passed to the model below and is
# currently unused; confirm whether the constructor should receive it.
impute_bias_source = rnnm_tab.BiasSources.zero
n_latent = 2

# `means`/`stds` are the training-split statistics computed above
# (presumably used to standardise inputs inside the model).
model = rnnm_tvae.TabularVariationalAutoEncoderNumerical(
    n_hidden=n_hidden,
    n_latent=n_latent,
    means=means,
    stds=stds,
    do_impute=do_impute,  # previously hard-coded, silently ignoring the setting above
    use_batch_norm=True,
)

In [None]:
save_dir = Path("./models")

# Adam over all model parameters, starting at lr=0.1, with the
# numerical-only KL-based VAE loss.
learning_rate = 0.1
optimizer = optim.Adam(model.parameters(), lr=learning_rate)
loss = rnnm_tvae.KullbackLeiblerNumericalOnlyLoss()

# Record training losses so they can be plotted after fitting.
loss_callback = rnnm_learner.TrainLossCallback()
callbacks = [loss_callback]

In [None]:
# Bundle model, optimiser and loss into the project's training loop driver.
learner = rnnm_learner.Learner(
    model,
    optimizer,
    loss,
    device=device,
    save_dir=save_dir,
    callbacks=callbacks,
    show_epoch_progress=True,
)

In [None]:
# Learning-rate range test. NOTE(review): the meaning of the positional
# arguments (1e-5, 100, 100) is defined by rnnm_learner.LRFinderCallback —
# presumably the starting LR plus sweep parameters; confirm there.
lr_find_callback = rnnm_learner.LRFinderCallback(1e-5, 100, 100)

# NOTE(review): confirm whether find_learning_rate restores model/optimizer
# state afterwards; otherwise `fit` below starts from the weights the sweep
# left behind.
learner.find_learning_rate(
    dl_train, n_epochs=10, lr_find_callback=lr_find_callback
)

In [None]:
# Plot the LR sweep with a log-scaled y-axis clipped to a fixed range.
lr_find_callback.plot(yscale="log", ylim=(4e3, 1e5))

In [None]:
# One-cycle LR schedule for the real training run. The callback wrapper
# presumably calls `scheduler.step()` after every batch, matching
# `steps_per_epoch=len(dl_train)`.
n_epochs = 5
learning_rate = 2e-1
steps_per_epoch = len(dl_train)

scheduler = optim.lr_scheduler.OneCycleLR(
    optimizer=optimizer,
    max_lr=learning_rate,
    epochs=n_epochs,
    steps_per_epoch=steps_per_epoch,
)
scheduler_callback = rnnm_learner.EveryBatchSchedulerCallback(scheduler)
learner.update_callback(scheduler_callback)

In [None]:
# Train for `n_epochs` epochs over dl_train, passing dl_valid as validation data.
learner.fit(dl_train, n_epochs=n_epochs, dataloader_valid=dl_valid)

In [None]:
# Plot the recorded training losses on a log scale.
loss_callback.plot(yscale="log")

In [None]:
# Model outputs (component 0) for the validation set. `.cpu()` is needed
# before `.numpy()` in case predict returns tensors on a GPU `device`; it is a
# no-op for CPU tensors.
X_preds = learner.predict(dl_valid, component=0).detach().cpu().numpy()
X_preds[:3]

In [None]:
import seaborn as sns

# Overlay validation data (black) and VAE output (orange) on the first two
# features; the first call creates the Axes, the second reuses it.
ax = None
for layer_data, layer_color, layer_label in (
    (X1, "black", "orig"),
    (X_preds, "orange", "vae"),
):
    ax = sns.scatterplot(
        x=layer_data[:, 0],
        y=layer_data[:, 1],
        alpha=0.3,
        color=layer_color,
        label=layer_label,
        ax=ax,
    )
ax.legend()

## VAE on numerical and categorical columns

In [None]:
# Split the numerical features and the blob label into aligned train/validation
# parts. The label is reshaped to a single categorical column of shape
# (n_samples, 1). An explicit `random_state` makes the split reproducible when
# the cell is re-run.
X0_num, X1_num, X0_cat, X1_cat = sk_selection.train_test_split(
    X, y.reshape((-1, 1)), test_size=0.2, random_state=42
)

In [None]:
# Inspect the first rows of the numerical and categorical training arrays.
X0_num[:3], X0_cat[:3]

In [None]:
# Pair each numerical block with its categorical block and wrap both splits.
ds_train, ds_valid = (
    rnnm_data.NumpyNumCatTrainingDatasetXOnly(num_part, cat_part)
    for num_part, cat_part in ((X0_num, X0_cat), (X1_num, X1_cat))
)

In [None]:
# Inspect a single training example.
ds_train[0]

In [None]:
from torch.utils.data import RandomSampler
import torch

In [None]:
batch_size = 50

# Draw 100k training rows with replacement. Together with `drop_last=True`,
# this fixes the number of batches per "epoch" regardless of the dataset size.
# A dedicated generator is used instead of `torch.manual_seed(42)`. That call
# re-seeds *and returns* torch's global default generator, so the sampler
# would share (and advance) the RNG used by everything else, e.g. weight init.
sampler = RandomSampler(
    ds_train,
    replacement=True,
    num_samples=int(1e5),
    generator=torch.Generator().manual_seed(42),
)

dl_train = DataLoader(
    ds_train,
    batch_size=batch_size,
    sampler=sampler,
    collate_fn=rnnm_data.collate_numpy_numcat_dataset_to_xblock,
    drop_last=True,
)
# Validation loader: no sampler and no shuffling, so rows are visited once in order.
dl_valid = DataLoader(
    ds_valid,
    batch_size=batch_size,
    collate_fn=rnnm_data.collate_numpy_numcat_dataset_to_xblock,
)

In [None]:
# Peek at one collated training batch to check the num/cat collate output.
next(iter(dl_train))

In [None]:
# Column-wise mean of the numerical training features; passed to the model below.
means = X0_num.mean(axis=0)
means

In [None]:
# Column-wise standard deviation (NumPy default ddof=0) of the numerical
# training features; passed to the model below.
stds = X0_num.std(axis=0)
stds

In [None]:
# Category counts for each categorical column, computed from the training split
# (presumably distinct values per column — confirm in rnnm_data).
n_categories_per_column = rnnm_data.calc_n_categories_per_column(X0_cat)
n_categories_per_column

In [None]:
# One input slot per numerical column plus one per categorical column.
n_features = X0_num.shape[1] + X0_cat.shape[1]
# Network widths; the first entry equals the input width computed above
# (presumably encoder layer sizes — confirm in rnnm_tvae).
n_hidden = [n_features, 3, 3]
do_impute = False
# NOTE(review): `impute_bias_source` is not passed to the model below and is
# currently unused; confirm whether the constructor should receive it.
impute_bias_source = rnnm_tab.BiasSources.zero
n_latent = 2

model = rnnm_tvae.TabularVariationalAutoEncoderNumericalAndCategorical(
    n_hidden=n_hidden,
    n_categories_per_column=n_categories_per_column,
    n_latent=n_latent,
    means=means,
    stds=stds,
    do_impute=do_impute,  # previously hard-coded, silently ignoring the setting above
    use_batch_norm=True,
)

In [None]:
# Display the decoder sub-module to inspect its layers.
model.decoder

In [None]:
save_dir = Path("./models")

# Adam over all model parameters, starting at lr=0.1. The loss handles both
# numerical and categorical outputs and needs the per-column category counts.
learning_rate = 0.1
optimizer = optim.Adam(model.parameters(), lr=learning_rate)
loss = rnnm_tvae.KullbackLeiblerNumericalAndCategoricalLoss(
    n_categories_per_column
)

# Record training losses so they can be plotted after fitting.
loss_callback = rnnm_learner.TrainLossCallback()
callbacks = [loss_callback]

In [None]:
# Bundle model, optimiser and loss into the project's training loop driver.
learner = rnnm_learner.Learner(
    model,
    optimizer,
    loss,
    device=device,
    save_dir=save_dir,
    callbacks=callbacks,
    show_epoch_progress=True,
)

In [None]:
# Learning-rate range test. NOTE(review): the meaning of the positional
# arguments (1e-5, 100, 100) is defined by rnnm_learner.LRFinderCallback —
# presumably the starting LR plus sweep parameters; confirm there.
lr_find_callback = rnnm_learner.LRFinderCallback(1e-5, 100, 100)

# NOTE(review): confirm whether find_learning_rate restores model/optimizer
# state afterwards; otherwise `fit` below starts from the weights the sweep
# left behind.
learner.find_learning_rate(
    dl_train, n_epochs=10, lr_find_callback=lr_find_callback
)

In [None]:
# Plot the LR sweep on a log y-axis. NOTE(review): this ylim matches the
# numerical-only section; the combined num/cat loss may have a different
# scale, so check that the curve is not clipped.
lr_find_callback.plot(yscale="log", ylim=(4e3, 1e5))

In [None]:
# One-cycle LR schedule for the real training run. The callback wrapper
# presumably calls `scheduler.step()` after every batch, matching
# `steps_per_epoch=len(dl_train)`.
n_epochs = 5
learning_rate = 2e-1
steps_per_epoch = len(dl_train)

scheduler = optim.lr_scheduler.OneCycleLR(
    optimizer=optimizer,
    max_lr=learning_rate,
    epochs=n_epochs,
    steps_per_epoch=steps_per_epoch,
)
scheduler_callback = rnnm_learner.EveryBatchSchedulerCallback(scheduler)
learner.update_callback(scheduler_callback)

In [None]:
# Train for `n_epochs` epochs over dl_train, passing dl_valid as validation data.
learner.fit(dl_train, n_epochs=n_epochs, dataloader_valid=dl_valid)

In [None]:
# Plot the recorded training losses on a log scale.
loss_callback.plot(yscale="log")

In [None]:
# Request two model outputs for the validation set: component 0 (numerical
# values) and component 1 (presumably per-category probabilities). Both are
# moved to the CPU, since predict may return tensors on a GPU `device`
# (`.cpu()` is a no-op for CPU tensors).
X_preds_num, X_preds_cat_probs = learner.predict(dl_valid, component=[0, 1])
X_preds_num = X_preds_num.detach().cpu().numpy()
X_preds_cat_probs = X_preds_cat_probs.detach().cpu()
X_preds_num[:3], X_preds_cat_probs[:3]

In [None]:
# Convert the categorical probability outputs into class predictions
# (presumably one class per categorical column — confirm in rnnm_tvae).
# `.cpu()` guards against a tensor left on a GPU device before `.numpy()`.
X_preds_cat = rnnm_tvae.transform_X_cat_probs_to_classes(
    X_preds_cat_probs, n_categories_per_column
)
X_preds_cat = X_preds_cat.cpu().numpy()
X_preds_cat[:5]

In [None]:
import seaborn as sns

# Overlay validation data (black) and the VAE's numerical output (orange) on
# the first two features; the first call creates the Axes, the second reuses it.
ax = None
for layer_data, layer_color, layer_label in (
    (X1_num, "black", "orig"),
    (X_preds_num, "orange", "vae"),
):
    ax = sns.scatterplot(
        x=layer_data[:, 0],
        y=layer_data[:, 1],
        alpha=0.3,
        color=layer_color,
        label=layer_label,
        ax=ax,
    )
ax.legend()

In [None]:
# Colour the VAE's numerical outputs by the class predicted for categorical column 0.
sns.scatterplot(
    x=X_preds_num[:, 0],
    y=X_preds_num[:, 1],
    hue=X_preds_cat[:, 0],
    alpha=0.3,
)