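Based on `tabular-variational-auto-encoder.ipynb`: trains a tabular variational auto-encoder (VAE) on synthetic blob data, then re-uses its trained encoder for classification.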

In [None]:
# IPython magics: load the autoreload extension and switch it to mode 2, which
# re-imports modules before each cell execution so that edits to the project
# packages are picked up without restarting the kernel.
%load_ext autoreload
%autoreload 2

In [None]:
import sklearn.datasets as sk_data
import sklearn.model_selection as sk_selection
import random_neural_net_models.tabular as rnnm_tab
import torch.optim as optim
import random_neural_net_models.learner as rnnm_learner
import random_neural_net_models.data as rnnm_data
import random_neural_net_models.utils as rnnm_utils
from pathlib import Path
import random_neural_net_models.tabular_vae as rnnm_tvae
from torch.utils.data import DataLoader
import random_neural_net_models.losses as rnnm_loss
import sklearn.metrics as sk_metrics

In [None]:
# Presumably seeds the random number generators with 42 for reproducibility;
# which RNGs are covered is defined in rnnm_utils — confirm there.
rnnm_utils.make_deterministic(42)

In [None]:
# Ask the project helper for the device to run on; the bare name displays it.
# NOTE(review): the selection logic lives in rnnm_utils.get_device and is not
# visible here.
device = rnnm_utils.get_device()
device

In [None]:
# Synthetic dataset: 1_000 points with 3 features drawn from 2 blob centers.
# `y` is the blob membership of each point and is used as the class label in
# the classification part of the notebook.
n_samples, n_features, n_classes = 1_000, 3, 2
X, y = sk_data.make_blobs(
    n_samples, n_features, centers=n_classes, random_state=42
)

In [None]:
# Hold out 20% of the samples for validation
# (X0 / y0: training part, X1 / y1: validation part).
# The split is seeded explicitly, like `make_blobs` above, so re-running this
# cell on its own yields the same partition regardless of how far the global
# RNG state has advanced.
X0, X1, y0, y1 = sk_selection.train_test_split(
    X, y, test_size=0.2, random_state=42
)

In [None]:
# Display the shapes of the training and validation feature arrays.
X0.shape, X1.shape

## Training the VAE

In [None]:
# Feature-only datasets for the VAE: one for the training split, one for the
# validation split (the labels are not needed for auto-encoding).
ds_train_vae, ds_valid_vae = (
    rnnm_data.NumpyInferenceDataset(features) for features in (X0, X1)
)

In [None]:
# Inspect a single item (index 2) of the VAE training dataset.
ds_train_vae[2]

In [None]:
from torch.utils.data import RandomSampler
import torch

In [None]:
batch_size_vae = 50

# Sample training rows with replacement so that one pass over the loader is a
# fixed 100_000 draws, independent of the size of the training set.
# A dedicated generator is used instead of `torch.manual_seed(42)`: the latter
# re-seeds (and returns) the global default generator, which would silently
# reset the random state of everything else in the notebook.
sampler_vae = RandomSampler(
    ds_train_vae,
    replacement=True,
    num_samples=100_000,
    generator=torch.Generator().manual_seed(42),
)

# Training loader: batches are drawn through the sampler; `drop_last=True`
# discards a trailing batch that would be smaller than `batch_size_vae`.
dl_train_vae = DataLoader(
    ds_train_vae,
    batch_size=batch_size_vae,
    sampler=sampler_vae,
    collate_fn=rnnm_data.collate_numpy_dataset_to_xblock,
    drop_last=True,
)
# Validation loader: sequential pass over the held-out samples, no sampler.
dl_valid_vae = DataLoader(
    ds_valid_vae,
    batch_size=batch_size_vae,
    collate_fn=rnnm_data.collate_numpy_dataset_to_xblock,
)

In [None]:
# Pull one batch from the VAE training loader to inspect what the collate
# function produces.
next(iter(dl_train_vae))

In [None]:
# Per-feature mean of the training data (reduced over the sample axis);
# passed to the VAE constructor below. The bare name displays the values.
means = X0.mean(axis=0)
means

In [None]:
# Per-feature standard deviation of the training data (reduced over the sample
# axis); passed to the VAE constructor below. The bare name displays the values.
stds = X0.std(axis=0)
stds

In [None]:
# Hyper-parameters of the VAE.
n_features = X0.shape[1]  # taken from the data rather than hard-coded
# NOTE(review): presumably the layer widths of the encoder, starting at the
# input width — confirm against TabularVariationalAutoEncoderNumerical.
n_hidden = [n_features, 3, 3]
do_impute = False
# NOTE(review): defined but not passed to the model below; the constructor
# signature is not visible here, so it is left untouched — confirm whether it
# should be forwarded or removed.
impute_bias_source = rnnm_tab.BiasSources.zero
n_latent = 2

# The training-set means / stds computed above are handed to the model.
# `do_impute` is forwarded from the variable above instead of repeating the
# literal, so changing the setting in one place actually takes effect.
model_vae = rnnm_tvae.TabularVariationalAutoEncoderNumerical(
    n_hidden=n_hidden,
    n_latent=n_latent,
    means=means,
    stds=stds,
    do_impute=do_impute,
    use_batch_norm=True,
)

In [None]:
# Optimizer, loss and loss-tracking callback for the VAE.
# The learning rate is named after the VAE (it was previously stored under the
# classifier's name, `learning_rate_classy`, although it configures
# `optimizer_vae`).
learning_rate_vae = 0.1
optimizer_vae = optim.Adam(model_vae.parameters(), lr=learning_rate_vae)
loss_vae = rnnm_tvae.KullbackLeiblerNumericalOnlyLoss()
loss_callback = rnnm_learner.TrainLossCallback()

# Directory handed to the learner as `save_dir`.
save_dir = Path("./models")

callbacks = [loss_callback]

In [None]:
# Bundle the VAE, its optimizer, loss and callbacks into a Learner that runs
# on `device`; `save_dir` and `show_epoch_progress` are passed through as-is.
learner_vae = rnnm_learner.Learner(
    model_vae,
    optimizer_vae,
    loss_vae,
    callbacks=callbacks,
    save_dir=save_dir,
    device=device,
    show_epoch_progress=True,
)

In [None]:
# Learning-rate search for the VAE on the training loader (10 epochs).
# NOTE(review): the positional arguments (1e-5, 100, 100) are presumably
# start LR, end LR and number of steps — confirm against LRFinderCallback.
lr_find_callback = rnnm_learner.LRFinderCallback(1e-5, 100, 100)

learner_vae.find_learning_rate(
    dl_train_vae, n_epochs=10, lr_find_callback=lr_find_callback
)

In [None]:
# Plot the LR-finder result with a log-scaled y-axis limited to [1e3, 5e4].
lr_find_callback.plot(yscale="log", ylim=(1e3, 5e4))

In [None]:
# One-cycle learning-rate schedule for the VAE: the peak rate is
# `learning_rate_vae` and the cycle spans n_epochs_vae * len(dl_train_vae)
# optimizer steps.
n_epochs_vae = 5
learning_rate_vae = 2e-1

scheduler_vae = optim.lr_scheduler.OneCycleLR(
    optimizer_vae,
    learning_rate_vae,
    epochs=n_epochs_vae,
    steps_per_epoch=len(dl_train_vae),
)

# Wrap the scheduler in a callback and register it with the learner.
# NOTE(review): presumably steps the scheduler once per batch — confirm in
# EveryBatchSchedulerCallback.
scheduler_callback = rnnm_learner.EveryBatchSchedulerCallback(scheduler_vae)
learner_vae.update_callback(scheduler_callback)

In [None]:
# Train the VAE for `n_epochs_vae` epochs, passing the validation loader along.
learner_vae.fit(
    dl_train_vae, n_epochs=n_epochs_vae, dataloader_valid=dl_valid_vae
)

In [None]:
# Plot what `loss_callback` tracked during VAE training, on a log y-axis.
loss_callback.plot(yscale="log")

In [None]:
# Run the trained VAE on the validation loader and convert the result to a
# NumPy array (`component=0` presumably selects the reconstruction — confirm
# in Learner.predict).
# `.cpu()` moves the tensor to host memory first, because `Tensor.numpy()`
# raises for tensors living on an accelerator (`device` is chosen at runtime);
# on a CPU tensor it is a no-op.
X_preds = (
    learner_vae.predict(dl_valid_vae, component=0).detach().cpu().numpy()
)
X_preds[:3]

In [None]:
import seaborn as sns

# Overlay the first two features of the original validation data (black) and
# of the VAE output (orange) in one scatter plot. The first call creates the
# axes (ax=None), the second one draws onto the same axes.
ax = None
for _points, _color, _label in (
    (X1, "black", "orig"),
    (X_preds, "orange", "vae"),
):
    ax = sns.scatterplot(
        x=_points[:, 0],
        y=_points[:, 1],
        alpha=0.3,
        color=_color,
        label=_label,
        ax=ax,
    )
ax.legend()

## Re-using the trained Encoder for classification

In [None]:
# Supervised datasets for the classifier: features paired with the blob labels
# cast to int, once for the training split and once for the validation split.
ds_train_classy, ds_valid_classy = (
    rnnm_data.NumpyTrainingDataset(features, labels.astype(int))
    for features, labels in ((X0, y0), (X1, y1))
)

In [None]:
batch_size_classy = 50

# Training and validation loaders for the classifier. Both are configured
# identically: same batch size, same collate function, default (sequential)
# ordering.
dl_train_classy, dl_valid_classy = (
    DataLoader(
        dataset,
        batch_size=batch_size_classy,
        collate_fn=rnnm_data.collate_numpy_dataset_to_xyblock_keep_orig_y,
    )
    for dataset in (ds_train_classy, ds_valid_classy)
)

In [None]:
# Pull one batch from the classifier training loader to inspect what the
# collate function produces.
next(iter(dl_train_classy))

In [None]:
# Classifier built on top of the encoder of the trained VAE, with one output
# per class. The weight comparison at the end of the notebook expects the
# encoder weights to be unchanged by the classifier training.
model_classy = rnnm_tvae.TabularModelReusingTrainedEncoder(
    pretrained_encoder=model_vae.encoder, n_out=n_classes, use_batch_norm=False
)

In [None]:
# Optimizer, loss and bookkeeping for the classifier.
save_dir = Path("./models")
learning_rate_classy = 0.1

loss_classy = rnnm_loss.CrossEntropyXy()
optimizer_classy = optim.Adam(model_classy.parameters(), learning_rate_classy)

# Fresh loss tracker for the classifier run; this rebinds `loss_callback`,
# which previously referred to the tracker used for the VAE.
loss_callback = rnnm_learner.TrainLossCallback()
callbacks = [loss_callback]

In [None]:
# Bundle the classifier, its optimizer, loss and callbacks into a Learner that
# runs on `device`; `save_dir` and `show_epoch_progress` are passed through.
learner_classy = rnnm_learner.Learner(
    model_classy,
    optimizer_classy,
    loss_classy,
    callbacks=callbacks,
    save_dir=save_dir,
    device=device,
    show_epoch_progress=True,
)

In [None]:
import copy

# Snapshot (deep copy, on CPU) of the weight matrix of the first layer of the
# pretrained encoder, taken before any classifier training, for comparison
# with the weights afterwards. The bare name displays the values.
orig_vae_weights = copy.deepcopy(
    model_classy.pretrained_encoder.net[0].lin.weight.detach().cpu()
)
orig_vae_weights

In [None]:
# Learning-rate search for the classifier on its training loader (10 epochs).
# NOTE(review): the positional arguments (1e-5, 100, 100) are presumably
# start LR, end LR and number of steps — confirm against LRFinderCallback.
lr_find_callback = rnnm_learner.LRFinderCallback(1e-5, 100, 100)

learner_classy.find_learning_rate(
    dl_train_classy, n_epochs=10, lr_find_callback=lr_find_callback
)

In [None]:
# Plot the LR-finder result for the classifier with a log-scaled y-axis.
lr_find_callback.plot(yscale="log")

In [None]:
# One-cycle learning-rate schedule for the classifier: the peak rate is
# `learning_rate_classy` and the cycle spans
# n_epochs_classy * len(dl_train_classy) optimizer steps.
n_epochs_classy = 5
learning_rate_classy = 0.1

scheduler_classy = optim.lr_scheduler.OneCycleLR(
    optimizer_classy,
    learning_rate_classy,
    epochs=n_epochs_classy,
    steps_per_epoch=len(dl_train_classy),
)

# Wrap the scheduler in a callback and register it with the learner.
# NOTE(review): presumably steps the scheduler once per batch — confirm in
# EveryBatchSchedulerCallback.
scheduler_callback = rnnm_learner.EveryBatchSchedulerCallback(scheduler_classy)
learner_classy.update_callback(scheduler_callback)

In [None]:
# Train the classifier for `n_epochs_classy` epochs, passing the validation
# loader along.
learner_classy.fit(
    dl_train_classy, n_epochs=n_epochs_classy, dataloader_valid=dl_valid_classy
)

In [None]:
# Plot what `loss_callback` tracked during classifier training.
loss_callback.plot()

In [None]:
# Predicted probability of class 1 for every validation sample: softmax over
# the class dimension of the model output, then take column 1.
# `.cpu()` moves the tensor to host memory before `.numpy()`, which raises for
# tensors living on an accelerator (`device` is chosen at runtime); on a CPU
# tensor it is a no-op.
probs = (
    learner_classy.predict(dl_valid_classy)
    .detach()
    .cpu()
    .softmax(dim=1)
    .numpy()[:, 1]
)
probs[:3]

In [None]:
# ROC AUC of the predicted class-1 probabilities against the validation labels.
print(sk_metrics.roc_auc_score(y1, probs))

In [None]:
# Display the encoder weights captured before classifier training again, for
# visual comparison with the weights after training shown in the next cell.
orig_vae_weights

In [None]:
# Snapshot (deep copy, on CPU) of the same first-layer encoder weight matrix,
# this time after classifier training. The bare name displays the values.
final_vae_weights = copy.deepcopy(
    model_classy.pretrained_encoder.net[0].lin.weight.detach().cpu()
)
final_vae_weights

In [None]:
# Sanity check: the first encoder layer's weights must be numerically equal
# before and after classifier training.
assert torch.allclose(orig_vae_weights, final_vae_weights)