# Conditional Variational Autoencoders and Exploring Autoencoder Latent Spaces

This small notebook experiment shows how to use conditional generation and how to optimise an autoencoder's latent space with respect to a target feature (here, the class assigned to decoded samples by an external classifier).

This is partly based on the blogpost [Conditional generation via Bayesian optimization in latent space by Martin Krasser](https://krasserm.github.io/2018/04/07/latent-space-optimization/).

## Configuration

### Environment

We use [miniconda](https://docs.conda.io/en/latest/miniconda.html) (and we recommend also having [mamba](https://mamba.readthedocs.io/en/latest/installation.html#installation)) to set up the environment, and this is the only setup we have tested. If you want to install the packages through pip, you are on your own.

We provide a simple setup script that checks your system for GPU and CUDA versions using `nvidia-smi`, (re-)creates the environment using `conda`, and installs the packages appropriate for your system using either `mamba` or `conda`. To set up your environment with this script, you only need to run:
```sh
chmod a+x cond_vae_setup.sh
./cond_vae_setup.sh
```

Note that we have only tested our environment on an NVIDIA RTX A4000 on Ubuntu 22.04.02, with driver version 515.105.01, NVIDIA-SMI 515.105.01, and CUDA version 11.7. If you find any problems with other setups, feel free to raise an issue.

All our `.yml` files pin the major and minor versions of the libraries, and for some libraries also the patch (bug-fix) version. If you can't find a combination that works on your system, try [relaxing the versions](https://docs.conda.io/projects/conda/en/latest/user-guide/tasks/manage-environments.html#create-env-file-manually). However, this notebook might not work with some of the relaxed versions.

### Folders and settings

The strings below define where we search for and store data, results, images, and models. Change them if you want to use folders other than the defaults. Some of the libraries we use have their own library-specific default locations, which may be shared among environments; we keep those separate from the folders defined here.

We also define some variables that control behaviour throughout the notebook, for example which image formats to save and whether to show images in the notebook.

In [None]:
# Folders to use
data_folder = "~/data"
results_folder = "./cond_vae/results"
models_folder = "./cond_vae/models"
images_folder = "./cond_vae/images"

# Which dataset to use (matched case-insensitively): "MNIST", "KMNIST", "FashionMNIST" or "OrganAMNIST"
DATASET_TO_USE = "MNIST"

# Autoencoder training hyperparameters
N_EPOCHS = 16
BATCH_SIZE = 64
AE_LATENT_DIM = 2
# Validate with an explicit raise: `assert` is skipped under `python -O`, and passing an exception
# object as the assert message would only have produced an AssertionError.
if AE_LATENT_DIM != 2:
    raise NotImplementedError("The notebook has not been adapted to dimensions other than 2 yet")

# External classifier (For bayesian exploration) training hyperparameters
CLF_N_EPOCHS = 16
CLF_BATCH_SIZE = 64
BAYESIAN_OPTIMISATION_STEPS = 64

# Point-grid and KDE generative exploration hyperparameters
NUM_POINTS_GRID = 128  # grid resolution per latent axis
NUMBER_OF_SAMPLES_PER_CLASS = 4
TOP_PCT_TO_SAMPLE_FROM = 0.01
SOFTMAX_REPARAM_TEMPERATURE = 1

# Image variables
YLABEL_FONTSIZE = 6
COORDS_FONTSIZE = 8
SHOW_IMAGES = True
SAVE_IMAGES = True
# I do not recommend saving vectorial images, as they become quite large with the amount of points being plotted.
IMAGE_FORMATS = ["jpg", "png"]
SHOW_ACQUISITIONS = False

We create the folders this notebook might use.

In [None]:
import os
import os.path as osp

# Resolve environment variables first and then "~" in each configured path.
data_folder, results_folder, models_folder, images_folder = (
    osp.expanduser(osp.expandvars(folder))
    for folder in (data_folder, results_folder, models_folder, images_folder)
)

# Create the data/results/models folders, then one image sub-folder per requested format
# (os.makedirs creates images_folder itself as the parent of those sub-folders).
for folder in (
    data_folder,
    results_folder,
    models_folder,
    *(osp.join(images_folder, fmt) for fmt in IMAGE_FORMATS),
):
    os.makedirs(folder, exist_ok=True)

### Imports

In [None]:
from more_itertools import interleave, take
from itertools import chain

In [None]:
from tqdm.autonotebook import tqdm

In [None]:
import pandas as pd
import numpy as np
import scipy as sp
import scipy.stats as sps

In [None]:
# GPyOpt (imported on the next line) still refers to the `np.bool` alias, which newer NumPy
# versions no longer provide; restore it as the builtin `bool` before importing GPyOpt.
np.bool = bool
import GPyOpt

In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from torch.utils.data import DataLoader
from torchvision.datasets import MNIST, FashionMNIST, KMNIST
from torchvision.transforms import ToTensor
from medmnist import OrganAMNIST

In [None]:
import matplotlib.pyplot as plt
import seaborn as sns

### Downloading the datasets

In [None]:
match DATASET_TO_USE.lower():
    case "mnist":
        mnist_train = MNIST(root=data_folder, download=True, train=True, transform=ToTensor(),)
        mnist_test = MNIST(root=data_folder, download=True, train=False, transform=ToTensor(),)
        mnist_n_channels = 1
    case "kmnist":
        mnist_train = KMNIST(root=data_folder, download=True, train=True, transform=ToTensor(),)
        mnist_test = KMNIST(root=data_folder, download=True, train=False, transform=ToTensor(),)
        mnist_n_channels = 1
    case "fashionmnist":
        mnist_train = FashionMNIST(root=data_folder, download=True, train=True, transform=ToTensor(),)
        mnist_test = FashionMNIST(root=data_folder, download=True, train=True, transform=ToTensor(),)
        mnist_n_channels = 1
    case "organamnist":
        mnist_train = OrganAMNIST(root=data_folder, download=True, split="train", transform=ToTensor(),)
        mnist_test = OrganAMNIST(root=data_folder, download=True, split="val", transform=ToTensor(),)
        for dset in [mnist_train,mnist_test]:
            dset.data = dset.imgs
            dset.labels = dset.labels.squeeze()
            dset.targets = torch.tensor(dset.labels)
            dset.classes = [dset.info["label"][str(i)] for i in range(len(dset.info["label"]))]
        mnist_n_channels = 1
    case _:
        raise ValueError(f"Specified dataset {DATASET_TO_USE} is not available")
    
tuple(map(lambda x: (x.data.shape, x.targets.shape), (mnist_train, mnist_test)))

In [None]:
# Shape of one transformed sample (assumed (channels, height, width) from ToTensor — confirm for new datasets).
data_shape = next(iter(mnist_train))[0].shape
# The models below size their layers from dset_img_dim // 2 and // 4, so images must be square
# with a side divisible by 4. Raise explicitly: `assert` is skipped under `python -O`.
if data_shape[1] != data_shape[2]:
    raise NotImplementedError(f"You tried plugging in a dataset that has non-square images (shape {data_shape}), which will break the code.")
dset_img_dim = data_shape[1]
if dset_img_dim % 4 != 0:
    raise NotImplementedError(f"Dataset dimensionality {dset_img_dim} must be divisible by 4")
dset_img_dim_half = dset_img_dim // 2
dset_img_dim_fourth = dset_img_dim // 4

In [None]:
# Human-readable class labels as an array, so it can be indexed with an array of integer targets.
class_names = np.array(mnist_train.classes)
mnist_n_classes = class_names.shape[0]
class_names

### Bayesian Optimisation Helper

In [None]:
def nll_optimizer_for(
                target:int,
                decoder:nn.Module,
                classifier:nn.Module,
                bounds:list[tuple[float,float]],
                normal_means:list[float]|np.ndarray = [0],
                normal_covs:str|list[list[float]]|np.ndarray = "eye",
                ):
    """
    Returns a GPyOpt bayesian optimiser that searches inside our latent space for values which,
    after being reconstructed by our decoder, are correctly classified by our optimiser. 
    """
    
    bounds = [
        {"name": f"t{i+1}", "type": "continuous", "domain": (i_min, i_max)}
        for i, (i_min, i_max) in enumerate(bounds)
    ]

    if isinstance(normal_means, list) and len(normal_means)==1:
        normal_means = normal_means * len(bounds)
    else:
        try:
            normal_means = np.asarray(normal_means)
        except (ValueError, NotImplementedError) as e:
            raise ValueError(f"Could not convert normal_means of type {type(normal_means)} to a numpy array. Original Exception: {e}")
        if len(normal_means.shape)==0 or (np.prod(normal_means.shape)==1):
            normal_means = [normal_means.item()]
        elif np.prod(normal_means.shape)==len(bounds):
            normal_means = normal_means.squeeze()
        else:
            raise ValueError(f"normal_means array has shape {normal_means.shape} which is incompatible with the provided number of bounds {len(bounds)}")

    if isinstance(normal_covs, str):
        match normal_covs:
            case "eye":
                normal_covs = np.eye(len(bounds))
            case _:
                raise ValueError(f"normal_covs string method {normal_covs} unrecognised. See the function definition for available types")
    elif isinstance(normal_covs, list):
        if len(normal_covs)==len(bounds):
            for i, row in enumerate(normal_covs):
                assert isinstance(row,list), ValueError(f"normal_covs must be 2 dimensional, but its {i}th value was not a list.")
                assert len(row)==len(bounds), ValueError(f"normal_covs must be 2 dimensional, but its {i}th value had length {len(row)} instead of {len(bounds)}.")
        else:
            raise ValueError(f"normal_covs has incompatible length {len(normal_covs)} with the bounds {len(bounds)}")
    else:
        try:
            normal_covs = np.asarray(normal_covs)
        except (ValueError, NotImplementedError) as e:
            raise ValueError(f"Could not convert normal_means of type {type(normal_covs)} to a numpy array. Original Exception: {e}")
        assert (len(normal_covs.shape) == 2 and all((s == len(bounds) for s in normal_covs.shape))), ValueError(f"normal_covs must a {len(bounds)}×{len(bounds)} matrix, but its shape was {normal_covs.shape}.")
        
    mvn = sps.multivariate_normal(mean=normal_means, cov=normal_covs)
    found_device = next(decoder.parameters()).device
    found_dtype = next(decoder.parameters()).dtype

    # Create a negative_log_likelihood function for our target as a closure
    def negative_log_likelihood(t):
        """
        Our Bayesian optimisation objective which is a negative likelihood function that
        uses our decoder to decode the latent space being explored and then uses our
        classifier to classify it.
        """
        with torch.no_grad():
            # Decode latent vector into image
            decoded = decoder(torch.tensor(t, device=found_device, dtype=found_dtype))
            # Predict probabilities with separate classifier
            c_probs = torch.softmax(classifier(decoded), 1).detach().cpu().numpy()

        nll_prior = -mvn.logpdf(t).reshape(-1, 1)
        # Get the negative log likelihood for our predicted probability (plus an epsilon for stability)
        nll_pred = -np.log(c_probs[:,target] + 1e-8).reshape(-1, 1)
        
        return nll_prior + nll_pred

    return GPyOpt.methods.BayesianOptimization(f=negative_log_likelihood, 
                                domain=bounds,
                                model_type='GP',
                                acquisition_type ='EI',
                                acquisition_jitter = 0.01,
                                initial_design_numdata = 2,
                                exact_feval=False)

### Helper Classes

In [None]:
import math

class Conv2dKeras(torch.nn.Conv2d):
    """
    Conv2d with keras behaviour for padding="same".

    Before convolving, the input is padded (zeros, or ``padding_mode``) so that each
    spatial output size is ``ceil(input_size / stride)``. When the total padding is odd,
    the extra row/column goes to the bottom/right. The layer's own ``padding`` attribute
    is ignored by ``forward``, which calls ``F.conv2d`` without padding.
    From: https://stackoverflow.com/a/73332370
    """
    def calc_same_pad(self, i: int, k: int, s: int, d: int) -> int:
        """Total padding along one axis of size ``i`` for kernel ``k``, stride ``s`` and dilation ``d``."""
        return max((math.ceil(i / s) - 1) * s + (k - 1) * d + 1 - i, 0)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        ih, iw = x.size()[-2:]

        pad_h = self.calc_same_pad(i=ih, k=self.kernel_size[0], s=self.stride[0], d=self.dilation[0])
        pad_w = self.calc_same_pad(i=iw, k=self.kernel_size[1], s=self.stride[1], d=self.dilation[1])

        if pad_h > 0 or pad_w > 0:
            # F.pad takes (left, right, top, bottom) for the last two dims.
            x = F.pad(
                x,
                [pad_w // 2, pad_w - pad_w // 2, pad_h // 2, pad_h - pad_h // 2],
                mode=("constant" if self.padding_mode == "zeros" else self.padding_mode),
            )
        return F.conv2d(
            x,
            self.weight,
            self.bias,
            self.stride,
            dilation=self.dilation,
            groups=self.groups,
        )

import math

class ConvTranspose2dKeras(torch.nn.ConvTranspose2d):
    """
    ConvTranspose2d with keras behaviour for padding="same": each spatial output size is
    exactly ``input_size * stride``.

    ``forward`` computes the un-padded transposed convolution and then crops it to the target
    size. When the crop is odd, the extra row/column is removed from the bottom/right. If the
    effective kernel is shorter than the stride, the output is instead zero-extended at the
    bottom/right. The bias is added after cropping so every output position receives it.
    The layer's own ``padding``/``output_padding`` attributes are ignored.
    Adapted from: https://stackoverflow.com/a/73332370
    """
    def calc_same_pad(self, i: int, k: int, s: int, d: int) -> int:
        """
        Total number of rows/columns to crop along one axis from the un-padded transposed
        convolution output (length ``(i - 1) * s + (k - 1) * d + 1``) to reach ``i * s``.
        ``i`` is kept for signature compatibility with ``Conv2dKeras.calc_same_pad``; the
        result does not depend on it.
        """
        return max((k - 1) * d + 1 - s, 0)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        ih, iw = x.size()[-2:]
        out_h, out_w = ih * self.stride[0], iw * self.stride[1]

        crop_h = self.calc_same_pad(i=ih, k=self.kernel_size[0], s=self.stride[0], d=self.dilation[0])
        crop_w = self.calc_same_pad(i=iw, k=self.kernel_size[1], s=self.stride[1], d=self.dilation[1])

        # Full transposed convolution without padding or bias (weight layout is
        # (in_channels, out_channels // groups, kH, kW), as conv_transpose2d expects).
        y = F.conv_transpose2d(
            x,
            self.weight,
            None,
            self.stride,
            groups=self.groups,
            dilation=self.dilation,
        )
        top, left = crop_h // 2, crop_w // 2
        y = y[..., top:top + out_h, left:left + out_w]

        # Kernel shorter than the stride: the full output is too short, so extend it with zeros.
        missing_h, missing_w = out_h - y.size(-2), out_w - y.size(-1)
        if missing_h > 0 or missing_w > 0:
            y = F.pad(y, [0, missing_w, 0, missing_h])

        if self.bias is not None:
            # [:, None, None] broadcasts over both batched and unbatched inputs.
            y = y + self.bias[:, None, None]
        return y

In [None]:
class NopLayer(nn.Module):
    """Identity module: accepts (and discards) any constructor arguments and returns its input unchanged."""

    def __init__(self, *args, **kwargs) -> None:
        # Arbitrary arguments are accepted so this class can be built by factories that always
        # forward positional/keyword arguments (e.g. the per-layer construction in CNN).
        super().__init__()

    def forward(self, x):
        output = x
        return output

class MLP(nn.Module):
    def __init__(self,
                 in_features:int,
                 features_defs:list[int],
                 nonlinearities:list[nn.Module] = [nn.ReLU],
                 drop_last_nonlinearity = False,
                 ):
        """
        Stack of ``nn.Linear`` layers, each followed by a non-linearity.

        Args:
            in_features: Size of the input feature dimension.
            features_defs: Output size of each linear layer, in order.
            nonlinearities: Module classes (instantiated without arguments) placed after each
                layer. A single entry is reused for every layer; otherwise one per layer.
            drop_last_nonlinearity: If True, the last layer is followed by ``NopLayer``
                (identity) instead of its non-linearity, e.g. to output raw logits.

        Raises:
            ValueError: If ``nonlinearities`` has neither 1 nor ``len(features_defs)`` entries.
        """
        super(MLP,self).__init__()
        features_defs = list(features_defs)
        n_layers = len(features_defs)
        in_features_defs = [in_features] + features_defs[:-1]
        out_features_defs = features_defs
        # Work on a copy: the in-place replacement below must not mutate the caller's list.
        nonlinearities = list(nonlinearities)
        if len(nonlinearities) == 1:
            nonlinearities = nonlinearities * n_layers
        if len(nonlinearities) != n_layers:
            raise ValueError(f"nonlinearities must have 1 or {n_layers} entries, but it has {len(nonlinearities)}")

        if n_layers > 0 and drop_last_nonlinearity:
            nonlinearities[-1] = NopLayer

        # Linear_1, Act_1, Linear_2, Act_2, ...
        self.mlp = nn.Sequential(
            *interleave(
                (
                    nn.Linear(
                        in_features=d_i,
                        out_features=d_o,
                    ) for d_i, d_o in zip(
                        in_features_defs,
                        out_features_defs,
                    )
                ),
                [Nl() for Nl in nonlinearities],
            )
        )

    def forward(self, x):
        return self.mlp(x)


In [None]:
class CNN(nn.Module):
    def __init__(self,
                 in_channels:int,
                 channel_defs:list[int],
                 kernel_defs:list[int|tuple[int]],
                 stride_defs:list[int|tuple[int]] = [1],
                 padding_defs:list[str|int|tuple[int]] = ["same"],
                 padding_mode_defs:list[str] = ["zeros"],
                 nonlinearities:list[nn.Module] = [nn.ReLU],
                 post_nonlinearity_defs:list[nn.Module] = [NopLayer],
                 post_nonlinearity_defs_args:list[list] = [[]],
                 post_nonlinearity_defs_kwargs:list[dict] = [{}],
                 ):
        """
        Stack of ``nn.Conv2d`` layers, each followed by a non-linearity and a post-non-linearity
        module (e.g. pooling or flattening).

        Every ``*_defs`` list, and ``nonlinearities``, holds one entry per layer. A
        single-entry list is reused for every layer.

        Args:
            in_channels: Channels of the input image.
            channel_defs: Output channels of each conv layer (this defines the number of layers).
            kernel_defs, stride_defs, padding_defs, padding_mode_defs: Per-layer ``nn.Conv2d``
                ``kernel_size``/``stride``/``padding``/``padding_mode`` arguments.
            nonlinearities: Module classes instantiated without arguments.
            post_nonlinearity_defs: Module classes instantiated with the matching entries of
                ``post_nonlinearity_defs_args`` (positional) and ``post_nonlinearity_defs_kwargs``.

        Raises:
            ValueError: If a per-layer list has neither 1 nor ``len(channel_defs)`` entries.
        """
        super(CNN,self).__init__()
        channel_defs = list(channel_defs)
        n_convs = len(channel_defs)
        in_channel_defs = [in_channels] + channel_defs[:-1]
        out_channel_defs = channel_defs

        def _per_layer(defs, name):
            # Copy, broadcast a single entry to all layers, and reject other lengths
            # (zip/interleave below would otherwise silently drop layers).
            defs = list(defs)
            if len(defs) == 1:
                defs = defs * n_convs
            if len(defs) != n_convs:
                raise ValueError(f"{name} must have 1 or {n_convs} entries, but it has {len(defs)}")
            return defs

        kernel_defs = _per_layer(kernel_defs, "kernel_defs")
        stride_defs = _per_layer(stride_defs, "stride_defs")
        padding_defs = _per_layer(padding_defs, "padding_defs")
        padding_mode_defs = _per_layer(padding_mode_defs, "padding_mode_defs")
        nonlinearities = _per_layer(nonlinearities, "nonlinearities")
        post_nonlinearity_defs = _per_layer(post_nonlinearity_defs, "post_nonlinearity_defs")
        post_nonlinearity_defs_args = _per_layer(post_nonlinearity_defs_args, "post_nonlinearity_defs_args")
        post_nonlinearity_defs_kwargs = _per_layer(post_nonlinearity_defs_kwargs, "post_nonlinearity_defs_kwargs")

        # Conv_1, Act_1, Post_1, Conv_2, Act_2, Post_2, ...
        self.cnn = nn.Sequential(
            *interleave(
                (
                    nn.Conv2d(
                        in_channels=d_i,
                        out_channels=d_o,
                        kernel_size=k,
                        stride=s,
                        padding=p,
                        padding_mode=pm
                    ) for d_i, d_o, k, s, p, pm in zip(
                        in_channel_defs,
                        out_channel_defs,
                        kernel_defs,
                        stride_defs,
                        padding_defs,
                        padding_mode_defs,
                    )
                ),
                [Nl() for Nl in nonlinearities],
                [
                    Layer(
                        *layer_args,
                        **layer_kwargs
                    ) for Layer, layer_args, layer_kwargs in zip(
                        post_nonlinearity_defs,
                        post_nonlinearity_defs_args,
                        post_nonlinearity_defs_kwargs,
                    )
                ],
            )
        )

    def forward(self, img):
        return self.cnn(img)

In [None]:
class DeCNN(nn.Module):
    def __init__(self,
                 in_channels:int,
                 channel_defs:list[int],
                 kernel_defs:list[int|tuple[int]],
                 stride_defs:list[int|tuple[int]] = [1],
                 padding_defs:list[str|int|tuple[int]] = ["same"],
                 output_padding_defs:list[str|int|tuple[int]] = [0],
                 padding_mode_defs:list[str] = ["zeros"],
                 nonlinearities:list[nn.Module] = [nn.ReLU],
                 ):
        """
        Stack of ``nn.ConvTranspose2d`` layers, each followed by a non-linearity.

        Every ``*_defs`` list, and ``nonlinearities``, holds one entry per layer. A
        single-entry list (including the default ``output_padding_defs``) is reused for
        every layer.

        Args:
            in_channels: Channels of the input feature map.
            channel_defs: Output channels of each transposed conv (this defines the number of layers).
            kernel_defs, stride_defs, padding_defs, output_padding_defs, padding_mode_defs:
                Per-layer ``nn.ConvTranspose2d`` arguments. NOTE(review): entries are passed
                straight to ``nn.ConvTranspose2d``, which does not appear to accept string
                paddings such as the default ``"same"``; pass integer paddings — confirm.
            nonlinearities: Module classes instantiated without arguments.

        Raises:
            ValueError: If a per-layer list has neither 1 nor ``len(channel_defs)`` entries.
        """
        super(DeCNN,self).__init__()
        channel_defs = list(channel_defs)
        n_convs = len(channel_defs)
        in_channel_defs = [in_channels] + channel_defs[:-1]
        out_channel_defs = channel_defs

        def _per_layer(defs, name):
            # Copy, broadcast a single entry to all layers, and reject other lengths
            # (zip/interleave below would otherwise silently drop layers).
            defs = list(defs)
            if len(defs) == 1:
                defs = defs * n_convs
            if len(defs) != n_convs:
                raise ValueError(f"{name} must have 1 or {n_convs} entries, but it has {len(defs)}")
            return defs

        kernel_defs = _per_layer(kernel_defs, "kernel_defs")
        nonlinearities = _per_layer(nonlinearities, "nonlinearities")
        stride_defs = _per_layer(stride_defs, "stride_defs")
        padding_defs = _per_layer(padding_defs, "padding_defs")
        # Previously not broadcast, so a single output padding truncated the stack to one layer.
        output_padding_defs = _per_layer(output_padding_defs, "output_padding_defs")
        padding_mode_defs = _per_layer(padding_mode_defs, "padding_mode_defs")

        # DeConv_1, Act_1, DeConv_2, Act_2, ...
        self.decnn = nn.Sequential(
            *interleave(
                (
                    nn.ConvTranspose2d(
                        in_channels=d_i,
                        out_channels=d_o,
                        kernel_size=k,
                        stride=s,
                        padding=p,
                        output_padding=op,
                        padding_mode=pm
                    ) for d_i, d_o, k, s, p, op, pm in zip(
                        in_channel_defs,
                        out_channel_defs,
                        kernel_defs,
                        stride_defs,
                        padding_defs,
                        output_padding_defs,
                        padding_mode_defs,
                    )
                ),
                [Nl() for Nl in nonlinearities],
            )
        )

    def forward(self, img):
        return self.decnn(img)

## CNN Classifier

In [None]:
class Classifier(nn.Module):
    def __init__(self, in_channels = 1, n_classes=None, img_dim=None):
        """
        Small CNN classifier.

        Architecture: three 3x3 conv + ReLU layers (32, 64, 64 channels; the first two
        followed by 2x2 max-pooling, the third by a flatten), then an MLP with a 64-unit
        hidden layer and one raw logit per class.

        Args:
            in_channels: Number of input image channels.
            n_classes: Number of output logits. ``None`` (default) uses the notebook-level
                ``mnist_n_classes``.
            img_dim: Side length of the square input images. ``None`` (default) uses the
                notebook-level ``dset_img_dim``.
        """
        super(Classifier, self).__init__()
        if n_classes is None:
            n_classes = mnist_n_classes
        if img_dim is None:
            img_dim = dset_img_dim
        # The 3x3 convs use stride 1 and padding 1 (size-preserving); each 2x2 max-pool floors the side in half.
        pooled_dim = img_dim // 2 // 2

        self.clf = nn.Sequential(
            CNN(
                in_channels,
                [32,64,64,],
                [3,3,3,],
                [1,1,1,],
                [1,1,1,],
                post_nonlinearity_defs=[
                    nn.MaxPool2d,
                    nn.MaxPool2d,
                    nn.Flatten
                ],
                post_nonlinearity_defs_args=[
                    [(2,2)],
                    [(2,2)],
                    [],
                ]
            ),
            MLP(
                64*pooled_dim*pooled_dim,
                [64,n_classes],
                drop_last_nonlinearity=True
            )
        )

    def forward(self, x):
        return self.clf(x)

In [None]:
model_name = "ext_clf"

# External classifier, trained separately from the autoencoders, later used to score decoded samples.
ext_clf = Classifier(mnist_n_channels)
if torch.cuda.is_available():
    ext_clf.cuda()

ext_clf_opt = optim.Adam(ext_clf.parameters())

xe_loss = nn.CrossEntropyLoss()
# The classifier is trained on cross-entropy alone; kept as a function for symmetry with the AE cells.
total_loss = lambda xe_l: xe_l

train_history = {
    "epoch": [],
    "batch": [],
    "clf_xe_loss": [],
    "clf_acc": [],
    "total_loss": [],
}
test_history = {
    "epoch": [],
    "batch": [],
    "clf_xe_loss": [],
    "clf_acc": [],
    "total_loss": [],
}

for e in tqdm(range(CLF_N_EPOCHS)):
    # Test-set evaluation at the start of each epoch (before that epoch's updates).
    ext_clf.eval()
    with torch.no_grad():
        for b, (x, y) in enumerate(DataLoader(mnist_test, batch_size=CLF_BATCH_SIZE, shuffle=False)):
            if torch.cuda.is_available():
                x = x.cuda()
                y = y.cuda()
            y_hat = ext_clf(x)
            xe_l = xe_loss(y_hat,y)
            total_l = total_loss(xe_l)
            acc = (y_hat.argmax(dim=1)==y).to(torch.float).mean()
            test_history["epoch"].append(e)
            test_history["batch"].append(b)
            test_history["clf_xe_loss"].append(xe_l.item())
            test_history["clf_acc"].append(acc.item())
            test_history["total_loss"].append(total_l.item())

    # One epoch of mini-batch training on the shuffled training set.
    ext_clf.train()
    for b, (x, y) in enumerate(DataLoader(mnist_train, batch_size=CLF_BATCH_SIZE, shuffle=True)):
        if torch.cuda.is_available():
            x = x.cuda()
            y = y.cuda()
        ext_clf_opt.zero_grad()
        y_hat = ext_clf(x)
        xe_l = xe_loss(y_hat,y)
        total_l = total_loss(xe_l)
        total_l.backward()
        ext_clf_opt.step()
        with torch.no_grad():
            acc = (y_hat.argmax(dim=1)==y).to(torch.float).mean()
        train_history["epoch"].append(e)
        train_history["batch"].append(b)
        train_history["clf_xe_loss"].append(xe_l.item())
        train_history["clf_acc"].append(acc.item())
        train_history["total_loss"].append(total_l.item())

# Final test evaluation after the last epoch. It is logged under this classifier's own epoch
# count (CLF_N_EPOCHS), not the autoencoder's N_EPOCHS.
ext_clf.eval()
with torch.no_grad():
    for b, (x, y) in enumerate(DataLoader(mnist_test, batch_size=CLF_BATCH_SIZE, shuffle=False)):
        if torch.cuda.is_available():
            x = x.cuda()
            y = y.cuda()
        y_hat = ext_clf(x)
        xe_l = xe_loss(y_hat,y)
        total_l = total_loss(xe_l)
        acc = (y_hat.argmax(dim=1)==y).to(torch.float).mean()
        test_history["epoch"].append(CLF_N_EPOCHS)
        test_history["batch"].append(b)
        test_history["clf_xe_loss"].append(xe_l.item())
        test_history["clf_acc"].append(acc.item())
        test_history["total_loss"].append(total_l.item())

In [None]:
# Display the labels of the last test batch seen by the final evaluation loop (quick sanity check).
y

In [None]:
train_df = pd.DataFrame(train_history)
# Fractional epoch position of each training batch, in [epoch, epoch + 1). Dividing by the number
# of batches (max index + 1) keeps the last batch of epoch e from landing on e + 1, which is
# where the next epoch's first batch (and its start-of-epoch test evaluation) sits.
train_df["batch_in_epoch"] = train_df["epoch"] + train_df["batch"] / (train_df["batch"].max() + 1)
test_df = pd.DataFrame(test_history)

In [None]:
# Per-batch training curve (blue) against the per-epoch mean test curve (red).
loss_type = "total_loss"
train_curve = train_df.set_index("batch_in_epoch")[loss_type]
test_curve = test_df.groupby("epoch")[loss_type].mean()

train_curve.plot(c="b")
ax: plt.Axes = plt.gca()
# Twin axis for the test curve, with its y-limits shared with the training axis.
nax = ax.twinx()
ax.sharey(nax)
test_curve.plot(c="r")
plt.title(f"{model_name} {loss_type}")

# Value at the last training batch and at the last evaluated epoch.
final_train = train_df.set_index('batch_in_epoch').sort_index(inplace=False, ascending=True)[loss_type].iloc[-1]
final_test = test_curve.sort_index(inplace=False, ascending=True).iloc[-1]
print(f"{model_name} Final {loss_type} tr/val = {final_train:.4f}/{final_test:.4f}")

if SAVE_IMAGES:
    for fmt in IMAGE_FORMATS:
        out_path = osp.join(images_folder, fmt, f"{model_name}-{loss_type}.{fmt}")
        plt.savefig(out_path)
if SHOW_IMAGES:
    plt.show()
plt.close()

In [None]:
# Per-batch training curve (blue) against the per-epoch mean test curve (red).
loss_type = "clf_xe_loss"
train_curve = train_df.set_index("batch_in_epoch")[loss_type]
test_curve = test_df.groupby("epoch")[loss_type].mean()

train_curve.plot(c="b")
ax: plt.Axes = plt.gca()
# Twin axis for the test curve, with its y-limits shared with the training axis.
nax = ax.twinx()
ax.sharey(nax)
test_curve.plot(c="r")
plt.title(f"{model_name} {loss_type}")

# Value at the last training batch and at the last evaluated epoch.
final_train = train_df.set_index('batch_in_epoch').sort_index(inplace=False, ascending=True)[loss_type].iloc[-1]
final_test = test_curve.sort_index(inplace=False, ascending=True).iloc[-1]
print(f"{model_name} final {loss_type} tr/val = {final_train:.4f}/{final_test:.4f}")

if SAVE_IMAGES:
    for fmt in IMAGE_FORMATS:
        out_path = osp.join(images_folder, fmt, f"{model_name}-{loss_type}.{fmt}")
        plt.savefig(out_path)
if SHOW_IMAGES:
    plt.show()
plt.close()

In [None]:
# Per-batch training accuracy (blue) against the per-epoch mean test accuracy (red).
loss_type = "clf_acc"
train_curve = train_df.set_index("batch_in_epoch")[loss_type]
test_curve = test_df.groupby("epoch")[loss_type].mean()

train_curve.plot(c="b")
ax: plt.Axes = plt.gca()
# Twin axis for the test curve, with its y-limits shared with the training axis.
nax = ax.twinx()
ax.sharey(nax)
test_curve.plot(c="r")
plt.title(f"{model_name} {loss_type}")

# Value at the last training batch and at the last evaluated epoch.
final_train = train_df.set_index('batch_in_epoch').sort_index(inplace=False, ascending=True)[loss_type].iloc[-1]
final_test = test_curve.sort_index(inplace=False, ascending=True).iloc[-1]
print(f"{model_name} Final {loss_type} tr/val = {final_train:.4f}/{final_test:.4f}")

if SAVE_IMAGES:
    for fmt in IMAGE_FORMATS:
        out_path = osp.join(images_folder, fmt, f"{model_name}-{loss_type}.{fmt}")
        plt.savefig(out_path)
if SHOW_IMAGES:
    plt.show()
plt.close()

## Vanilla Autoencoder

In [None]:
class MNIST_AE(nn.Module):
    """
    Convolutional autoencoder with a ``hidden_channels``-dimensional latent vector.

    Encoder: four 3x3 conv + ReLU layers (32, 64, 64, 64 channels; the second has stride 2,
    halving each side), flatten, Linear -> 32, ReLU, Linear -> ``hidden_channels``.
    Decoder: Linear + ReLU back to 64 x (dim/2) x (dim/2), a stride-2 transposed conv
    (padding 1, output padding 1) back to full size, then a 3x3 conv with a sigmoid so
    outputs lie in [0, 1]. Spatial sizes come from the notebook-level ``dset_img_dim_half``.
    """
    def __init__(self, in_channels = 1, hidden_channels = 2):
        super(MNIST_AE, self).__init__()

        self.encoder = nn.Sequential(
            CNN(in_channels,[32,64,64,64],[3,3,3,3],[1,2,1,1],[1,1,1,1]),
            nn.Flatten(),
            nn.Linear(64*dset_img_dim_half*dset_img_dim_half,32),
            nn.ReLU(),
            nn.Linear(32,hidden_channels),
        )

        self.decoder = nn.Sequential(
            nn.Linear(hidden_channels,64*dset_img_dim_half*dset_img_dim_half),
            nn.ReLU(),
            # Must match the Linear above; previously hard-coded to 14 (only valid for 28x28 images).
            nn.Unflatten(1,(64,dset_img_dim_half,dset_img_dim_half)),
            DeCNN(64,[32],[3],[2],[1],[1]),
            CNN(32,[in_channels],[3],[1],[1],nonlinearities=[nn.Sigmoid])
        )

    def forward(self, x):
        return self.decoder(self.encoder(x))


### Unconstrained Autoencoder

In [None]:
model_name = "U-AE"

# Unconstrained autoencoder: no term acts on the latent space; the objective is reconstruction BCE only.
model = MNIST_AE(mnist_n_channels,AE_LATENT_DIM)
if torch.cuda.is_available(): model.cuda()

opt = optim.Adam(model.parameters())

# Per-sample reconstruction loss: element-wise BCE summed over every non-batch dimension,
# giving one value per image.
bce_loss = lambda inp,tgt: torch.sum(F.binary_cross_entropy(inp,tgt,reduction="none"), dim=list(range(1,len(inp.shape))))
# Batch objective: mean of the per-sample BCE values.
total_loss = lambda bce_l: torch.mean(torch.concat([bce_l,]))

train_history = {
    "epoch": [],
    "batch": [],
    "rec_bce_loss": [],
    "total_loss": [],
}
test_history = {
    "epoch": [],
    "batch": [],
    "rec_bce_loss": [],
    "total_loss": [],
}

for e in tqdm(range(N_EPOCHS)):
    # Test-set evaluation at the start of each epoch (before that epoch's updates).
    with torch.no_grad():
        for b, (x, y) in enumerate(DataLoader(mnist_test, batch_size=BATCH_SIZE, shuffle=False)):
            if torch.cuda.is_available():
                x = x.cuda()
                y = y.cuda()
            x_hat = model(x)
            bce_l = bce_loss(x_hat,x)
            total_l = total_loss(bce_l)
            test_history["epoch"].append(e)
            test_history["batch"].append(b)
            test_history["rec_bce_loss"].append(bce_l.detach().cpu().numpy().mean().item())
            test_history["total_loss"].append(total_l.detach().cpu().numpy().mean().item())

    # One epoch of mini-batch training on the shuffled training set (labels y are unused).
    for b, (x, y) in enumerate(DataLoader(mnist_train, batch_size=BATCH_SIZE, shuffle=True)):
        if torch.cuda.is_available():
            x = x.cuda()
            y = y.cuda()
        opt.zero_grad()
        x_hat = model(x)
        bce_l = bce_loss(x_hat,x)
        total_l = total_loss(bce_l)
        total_l.backward()
        opt.step()
        train_history["epoch"].append(e)
        train_history["batch"].append(b)
        # bce_l holds one value per sample; log its batch mean.
        train_history["rec_bce_loss"].append(bce_l.detach().cpu().numpy().mean().item())
        train_history["total_loss"].append(total_l.detach().cpu().numpy().mean().item())

# Final test evaluation after the last epoch, logged under epoch index N_EPOCHS.
with torch.no_grad():
    for b, (x, y) in enumerate(DataLoader(mnist_test, batch_size=BATCH_SIZE, shuffle=False)):
        if torch.cuda.is_available():
            x = x.cuda()
            y = y.cuda()
        x_hat = model(x)
        bce_l = bce_loss(x_hat,x)
        total_l = total_loss(bce_l)
        test_history["epoch"].append(N_EPOCHS)
        test_history["batch"].append(b)
        test_history["rec_bce_loss"].append(bce_l.detach().cpu().numpy().mean().item())
        test_history["total_loss"].append(total_l.detach().cpu().numpy().mean().item())

In [None]:
train_df = pd.DataFrame(train_history)
# Fractional epoch position of each training batch, in [epoch, epoch + 1). Dividing by the number
# of batches (max index + 1) keeps the last batch of epoch e from landing on e + 1, which is
# where the next epoch's first batch (and its start-of-epoch test evaluation) sits.
train_df["batch_in_epoch"] = train_df["epoch"] + train_df["batch"] / (train_df["batch"].max() + 1)
test_df = pd.DataFrame(test_history)

In [None]:
# Per-batch training curve (blue) against the per-epoch mean test curve (red).
loss_type = "total_loss"
train_curve = train_df.set_index("batch_in_epoch")[loss_type]
test_curve = test_df.groupby("epoch")[loss_type].mean()

train_curve.plot(c="b")
ax: plt.Axes = plt.gca()
# Twin axis for the test curve, with its y-limits shared with the training axis.
nax = ax.twinx()
ax.sharey(nax)
test_curve.plot(c="r")
plt.title(f"{model_name} {loss_type}")

# Value at the last training batch and at the last evaluated epoch.
final_train = train_df.set_index('batch_in_epoch').sort_index(inplace=False, ascending=True)[loss_type].iloc[-1]
final_test = test_curve.sort_index(inplace=False, ascending=True).iloc[-1]
print(f"{model_name} Final {loss_type} tr/val = {final_train:.4f}/{final_test:.4f}")

if SAVE_IMAGES:
    for fmt in IMAGE_FORMATS:
        out_path = osp.join(images_folder, fmt, f"{model_name}-{loss_type}.{fmt}")
        plt.savefig(out_path)
if SHOW_IMAGES:
    plt.show()
plt.close()

In [None]:
# Per-batch training curve (blue) against the per-epoch mean test curve (red).
loss_type = "rec_bce_loss"
train_curve = train_df.set_index("batch_in_epoch")[loss_type]
test_curve = test_df.groupby("epoch")[loss_type].mean()

train_curve.plot(c="b")
ax: plt.Axes = plt.gca()
# Twin axis for the test curve, with its y-limits shared with the training axis.
nax = ax.twinx()
ax.sharey(nax)
test_curve.plot(c="r")
plt.title(f"{model_name} {loss_type}")

# Value at the last training batch and at the last evaluated epoch.
final_train = train_df.set_index('batch_in_epoch').sort_index(inplace=False, ascending=True)[loss_type].iloc[-1]
final_test = test_curve.sort_index(inplace=False, ascending=True).iloc[-1]
print(f"{model_name} final {loss_type} tr/val = {final_train:.4f}/{final_test:.4f}")

if SAVE_IMAGES:
    for fmt in IMAGE_FORMATS:
        out_path = osp.join(images_folder, fmt, f"{model_name}-{loss_type}.{fmt}")
        plt.savefig(out_path)
if SHOW_IMAGES:
    plt.show()
plt.close()

In [None]:
# Encode the whole training set, in dataset order, into latent vectors.
zs = []
use_cuda = torch.cuda.is_available()
with torch.no_grad():
    train_loader = DataLoader(mnist_train, batch_size=BATCH_SIZE, shuffle=False)
    for b, (x, y) in enumerate(train_loader):
        if use_cuda:
            x, y = x.cuda(), y.cuda()
        batch_z = model.encoder(x)
        zs.append(batch_z.detach().cpu().numpy())
z_train = np.concatenate(zs, axis=0)
z_train.shape, mnist_train.targets.numpy().shape

In [None]:
# Encode the whole test set, in dataset order, into latent vectors.
zs = []
use_cuda = torch.cuda.is_available()
with torch.no_grad():
    test_loader = DataLoader(mnist_test, batch_size=BATCH_SIZE, shuffle=False)
    for b, (x, y) in enumerate(test_loader):
        if use_cuda:
            x, y = x.cuda(), y.cuda()
        batch_z = model.encoder(x)
        zs.append(batch_z.detach().cpu().numpy())
z_test = np.concatenate(zs, axis=0)
z_test.shape, mnist_test.targets.numpy().shape

In [None]:
# Side-by-side view of the train (left) and test (right) latent codes, coloured by class.
# Only the first two latent coordinates are plotted (the notebook requires AE_LATENT_DIM == 2).
fig, axes = plt.subplots(1,2, sharex=True, sharey=True)

x_train, y_train = z_train[:,0], z_train[:,1]
hue_train = mnist_train.targets.numpy()

x_test, y_test = z_test[:,0], z_test[:,1]
hue_test = mnist_test.targets.numpy()

# Each panel overlays a scatter, a 2-D histogram and KDE contours per class; only the test panel keeps a legend.
sns.scatterplot(x=x_train, y=y_train, hue=class_names[hue_train], hue_order=class_names, marker="x", s=4, ax=axes[0], legend=False,)
sns.histplot(x=x_train, y=y_train, bins=64, pthresh=0.01, hue=class_names[hue_train], hue_order=class_names, ax=axes[0], legend=False,)
sns.kdeplot(x=x_train, y=y_train, levels=5, hue=class_names[hue_train], hue_order=class_names, linewidths=1, ax=axes[0], legend=False,)

sns.scatterplot(x=x_test, y=y_test, hue=class_names[hue_test], hue_order=class_names, marker="x", s=4, ax=axes[1],)
sns.histplot(x=x_test, y=y_test, bins=64, pthresh=0.01, hue=class_names[hue_test], hue_order=class_names, ax=axes[1],)
sns.kdeplot(x=x_test, y=y_test, levels=5, hue=class_names[hue_test], hue_order=class_names, linewidths=1, ax=axes[1],)

# Move the legend outside the right-hand panel.
sns.move_legend(axes[1], "upper left", bbox_to_anchor=(1, 0.75),)

plt.suptitle(f"{model_name} latent space distribution")
axes[0].set_title(f"Train")
axes[1].set_title(f"Test")

if SAVE_IMAGES:
    # First save a version clipped to at most [-20, 20] on both axes (limits are shared between
    # the panels), then restore the original limits and save the full view.
    xlim = axes[0].get_xlim()
    ylim = axes[0].get_ylim()
    axes[0].set_xlim(max(xlim[0],-20),min(xlim[1],20))
    axes[0].set_ylim(max(ylim[0],-20),min(ylim[1],20))
    for fmt in IMAGE_FORMATS:
        plt.savefig(osp.join(images_folder,fmt,f"{model_name}-z-zoom.{fmt}"), bbox_inches="tight",)
    axes[0].set_xlim(xlim)
    axes[0].set_ylim(ylim)
    for fmt in IMAGE_FORMATS:
        plt.savefig(osp.join(images_folder,fmt,f"{model_name}-z.{fmt}"), bbox_inches="tight",)

if SHOW_IMAGES:
    plt.show()
plt.close()

In [None]:
# One Gaussian KDE per digit class, fitted on that class's training codes.
# scipy's gaussian_kde takes data as (n_dims, n_points), hence the transpose.
kdes = {}
for ci, c in enumerate(class_names):
    class_codes = z_train[hue_train == ci]
    kdes[c] = sps.gaussian_kde(class_codes.T)

In [None]:
# Regular NUM_POINTS_GRID x NUM_POINTS_GRID grid over the bounding box of the
# training codes; `points` holds one grid point per column, shape (2, n_points).
mins = z_train.min(axis=0, keepdims=True)
maxs = z_train.max(axis=0, keepdims=True)
# NOTE(review): mins[:, i] is kept as a (1,)-shaped array slice on purpose;
# switching to a scalar index may change np.linspace's output dtype under
# NumPy 1.x promotion rules, and the grid is later fed to the decoder.
axis_ticks = [np.linspace(mins[:, i], maxs[:, i], NUM_POINTS_GRID) for i in range(2)]
xx, yy = np.meshgrid(*axis_ticks)
points = np.vstack([xx.ravel(), yy.ravel()])

In [None]:
# Density of every class KDE at every grid point: probs[p, k] is the density of
# class k at grid point p, shape (n_points, n_classes).
class_densities = [kdes[c](points) for c in class_names]
probs = np.stack(class_densities, axis=1)

In [None]:
# Decode grid points selected from each class's KDE. Column 0 ("Best") is the
# grid point with the highest density for the class; the other columns are
# random grid points from the class's top-density region. Each panel is
# annotated with its latent coordinates.
number_of_classes = len(class_names)
fig, axes = plt.subplots(number_of_classes, NUMBER_OF_SAMPLES_PER_CLASS)
fig.suptitle(f"{model_name} Samples from point grid 1-class KDEs")
axes[0,0].set_title("Best")
for ci, c in enumerate(class_names):
    c_prob = probs[:,ci]
    # Density threshold: the TOP_PCT_TO_SAMPLE_FROM quantile over the grid
    # points where this class has non-zero density.
    top_threshold = np.quantile(c_prob[c_prob>0],TOP_PCT_TO_SAMPLE_FROM)
    # Candidate pool = the highest-density points, never fewer than one row.
    supersample_size = max((c_prob>top_threshold).sum(),NUMBER_OF_SAMPLES_PER_CLASS)
    c_prob_argsort = c_prob.argsort()
    # Best point first, then NUMBER_OF_SAMPLES_PER_CLASS-1 random picks from the
    # pool. Picks are without replacement and exclude the best point, unless the
    # pool is exactly one row wide, in which case duplicates are allowed.
    best_and_rest_from_top_pct = [
        c_prob_argsort[-1],
        *take(
            NUMBER_OF_SAMPLES_PER_CLASS-1,
            [
                x for x in np.random.choice(
                    c_prob_argsort[-supersample_size:],
                    NUMBER_OF_SAMPLES_PER_CLASS,
                    replace=supersample_size==NUMBER_OF_SAMPLES_PER_CLASS,
                ) if x != c_prob_argsort[-1] or supersample_size==NUMBER_OF_SAMPLES_PER_CLASS
            ]
        )
    ]
    z_np = points.T[best_and_rest_from_top_pct,:]
    z = torch.tensor(z_np)
    if torch.cuda.is_available(): z = z.cuda()
    x_hat = model.decoder(z).detach().cpu().numpy()
    # Move the channel axis (axis 1) to the end for imshow. A reshape only
    # coincides with this transpose for a single channel; with C > 1 it
    # scrambles pixels across channels.
    x_hat = np.moveaxis(x_hat, 1, -1)
    axes[ci,0].set_ylabel(c, size=YLABEL_FONTSIZE)
    for pi in range(NUMBER_OF_SAMPLES_PER_CLASS):
        ax:plt.Axes = axes[ci,pi]
        ax.imshow(x_hat[pi], vmin=0, vmax=1, cmap="gray")
        ax.xaxis.set_ticklabels([])
        ax.yaxis.set_ticklabels([])
        ax.text(2.1, 0.5, ", ".join(map("{:.2f}".format,z_np[pi])),
            size=COORDS_FONTSIZE, ha='center', va='center', transform=ax.transAxes)
if SAVE_IMAGES:
    for fmt in IMAGE_FORMATS:
        plt.savefig(osp.join(images_folder,fmt,f"{model_name}-samples-pointgrid-1c.{fmt}"))
if SHOW_IMAGES:
    plt.show()
plt.close()

In [None]:
# Assign every grid point to one class. The class densities are re-weighted with
# a temperature softmax over their log-densities, and the point goes to the
# class maximising softmax_prob * density. log(0) = -inf is expected for
# zero-density points (the softmax maps it to weight 0), so NumPy's
# divide-by-zero warning is silenced.
# NOTE(review): rows where every class density is 0 give NaN softmax values,
# and np.argmax then returns index 0 for them.
with np.errstate(divide="ignore"):
    log_probs = np.log(probs)
softmax_probs = torch.softmax(torch.tensor(log_probs)/SOFTMAX_REPARAM_TEMPERATURE,1).numpy()
softmax_classes = class_names[np.argmax(softmax_probs*probs, axis=1)]
# pd.value_counts is deprecated (pandas 2.1); Series.value_counts is equivalent.
pd.Series(softmax_classes).value_counts()

In [None]:
# For each class, decode up to NUMBER_OF_SAMPLES_PER_CLASS random grid points
# that the softmax assignment gives to that class and where that class's KDE is
# non-zero. Panels beyond the available points are drawn blank.
fig, axes = plt.subplots(number_of_classes, NUMBER_OF_SAMPLES_PER_CLASS)
fig.suptitle(f"{model_name} Random samples from softmax of point grid 1-class KDE")
for ci, c in enumerate(class_names):
    indexes_of_this_class = np.logical_and(
        softmax_classes==c,
        probs[:,ci]>0
    )
    num_points_in_this_class = indexes_of_this_class.sum()
    num_show = min(NUMBER_OF_SAMPLES_PER_CLASS,num_points_in_this_class)
    # With no points, x_hat / z_np are not recomputed; they are never read then,
    # because num_show == 0 makes every panel blank.
    if num_points_in_this_class>0:
        points_in_this_class = points.T[indexes_of_this_class,:]
        random_points_from_this_class = points_in_this_class[
            np.random.choice(
                num_points_in_this_class,
                num_show,
                replace=False,
            )
        ]
        z_np = random_points_from_this_class
        z = torch.tensor(z_np)
        if torch.cuda.is_available(): z = z.cuda()
        x_hat = model.decoder(z).detach().cpu().numpy()
        # Channel axis (axis 1) last for imshow; unlike a reshape this is a true
        # transpose and stays correct for any channel count.
        x_hat = np.moveaxis(x_hat, 1, -1)
    axes[ci,0].set_ylabel(c, size=YLABEL_FONTSIZE)
    for pi in range(NUMBER_OF_SAMPLES_PER_CLASS):
        ax:plt.Axes = axes[ci,pi]
        # Real decoded image where available, otherwise a blank single-channel
        # placeholder.
        ax.imshow(
            (
                x_hat[pi]
                if pi<num_show else
                np.zeros((dset_img_dim,dset_img_dim,1))
            ),
            vmin=0, vmax=1, cmap="gray"
        )
        ax.xaxis.set_ticklabels([])
        ax.yaxis.set_ticklabels([])
        if pi<num_show:
            ax.text(2.1, 0.5, ", ".join(map("{:.2f}".format,z_np[pi])),
                size=COORDS_FONTSIZE, ha='center', va='center', transform=ax.transAxes)

if SAVE_IMAGES:
    for fmt in IMAGE_FORMATS:
        plt.savefig(osp.join(images_folder,fmt,f"{model_name}-samples-pointgrid-1c-softmax.{fmt}"))
if SHOW_IMAGES:
    plt.show()
plt.close()

In [None]:
# Draw NUMBER_OF_SAMPLES_PER_CLASS latent codes from each class's KDE, decode
# them and show the images annotated with their latent coordinates.
number_of_classes = len(class_names)
fig, axes = plt.subplots(number_of_classes, NUMBER_OF_SAMPLES_PER_CLASS)
fig.suptitle(f"{model_name} Random samples from the KDE estimators")
for ci, c in enumerate(class_names):
    # resample returns (n_dims, n_samples); transpose to one code per row and
    # cast to float32 for the decoder.
    z_np = kdes[c].resample(NUMBER_OF_SAMPLES_PER_CLASS).T.astype(np.float32)
    z = torch.tensor(z_np)
    if torch.cuda.is_available(): z = z.cuda()
    x_hat = model.decoder(z).detach().cpu().numpy()
    # Channel axis (axis 1) last for imshow; a true transpose, correct for any
    # channel count (a reshape is only equivalent when C == 1).
    x_hat = np.moveaxis(x_hat, 1, -1)
    axes[ci,0].set_ylabel(c, size=YLABEL_FONTSIZE)
    for pi in range(NUMBER_OF_SAMPLES_PER_CLASS):
        ax:plt.Axes = axes[ci,pi]
        ax.imshow(x_hat[pi], vmin=0, vmax=1, cmap="gray")
        ax.xaxis.set_ticklabels([])
        ax.yaxis.set_ticklabels([])
        ax.text(2.1, 0.5, ", ".join(map("{:.2f}".format,z_np[pi])),
            size=COORDS_FONTSIZE, ha='center', va='center', transform=ax.transAxes)
if SAVE_IMAGES:
    for fmt in IMAGE_FORMATS:
        plt.savefig(osp.join(images_folder,fmt,f"{model_name}-samples-kde-1c.{fmt}"))
if SHOW_IMAGES:
    plt.show()
plt.close()

In [None]:
# One Bayesian optimiser per class, searching the latent space via the decoder
# and the external classifier (objective built by nll_optimizer_for, defined
# elsewhere; presumably the classifier's negative log-likelihood of class ci).
# Search box: the training-code bounding box widened to at least [-4, 4] in
# every latent dimension. It does not depend on the class, so it is built once.
latent_bounds = [
    (min(mins[0,i],-4.0), max(maxs[0,i],4.0))
    for i in range(mins.shape[1])
]
bayopts = {}
for ci, c in enumerate(class_names):
    # Each optimiser gets its own copy of the bounds list.
    bayopts[ci] = nll_optimizer_for(ci, model.decoder, ext_clf, bounds=list(latent_bounds))
    bayopts[ci].run_optimization(max_iter=BAYESIAN_OPTIMISATION_STEPS)

In [None]:
# Save and/or display each class's Bayesian-optimisation acquisition plot.
for ci, c in enumerate(class_names):
    if SAVE_IMAGES:
        # plot_acquisition builds its own figure and is given a filename per
        # image format. NOTE(review): the original author observed that it also
        # displays the figure itself, hence the plt.close() right after each call.
        for fmt in IMAGE_FORMATS:
            bayopts[ci].plot_acquisition(filename=osp.join(images_folder,fmt,f"{model_name}-bayes-acquisition-{ci}.{fmt}"))
            plt.close()
    if SHOW_ACQUISITIONS:
        # Print the class name as a caption before the plot.
        print(c)
        bayopts[ci].plot_acquisition()
        plt.close()

In [None]:
# Decode, for each class, the best-scoring latent points evaluated by that
# class's Bayesian optimiser, annotated with their latent coordinates.
number_of_classes = len(class_names)
fig, axes = plt.subplots(number_of_classes, NUMBER_OF_SAMPLES_PER_CLASS)
fig.suptitle(f"{model_name} Random samples from the Bayesian Optimisation")
for ci, c in enumerate(class_names):
    # X: evaluated latent points; Y: objective per point (one column). Ascending
    # order puts the lowest objective first, assuming the optimiser minimises.
    top_idx = np.argsort(bayopts[ci].Y, axis=0).ravel()
    # Only the first NUMBER_OF_SAMPLES_PER_CLASS are shown, so decode just those
    # instead of every evaluated point.
    best_idx = top_idx[:NUMBER_OF_SAMPLES_PER_CLASS]
    z_np = bayopts[ci].X[best_idx].astype(np.float32)
    z = torch.tensor(z_np)
    if torch.cuda.is_available(): z = z.cuda()
    x_hat = model.decoder(z).detach().cpu().numpy()
    # Channel axis (axis 1) last for imshow; a true transpose, correct for any
    # channel count (a reshape is only equivalent when C == 1).
    x_hat = np.moveaxis(x_hat, 1, -1)
    axes[ci,0].set_ylabel(c, size=YLABEL_FONTSIZE)
    for pi in range(NUMBER_OF_SAMPLES_PER_CLASS):
        ax:plt.Axes = axes[ci,pi]
        ax.imshow(x_hat[pi], vmin=0, vmax=1, cmap="gray")
        ax.xaxis.set_ticklabels([])
        ax.yaxis.set_ticklabels([])
        ax.text(2.1, 0.5, ", ".join(map("{:.2f}".format,z_np[pi])),
            size=COORDS_FONTSIZE, ha='center', va='center', transform=ax.transAxes)
if SAVE_IMAGES:
    for fmt in IMAGE_FORMATS:
        plt.savefig(osp.join(images_folder,fmt,f"{model_name}-samples-bayesian-extclf.{fmt}"))
if SHOW_IMAGES:
    plt.show()
plt.close()

### Constrained Autoencoder

In [None]:
model_name = "C-AE"

# "Constrained" autoencoder: a deterministic AE whose latent codes also feed an
# MLP classifier; both are trained jointly, so the labels shape the latent space.
model = MNIST_AE(mnist_n_channels,AE_LATENT_DIM)
clf = MLP(AE_LATENT_DIM,[128,128,mnist_n_classes], drop_last_nonlinearity=True)
if torch.cuda.is_available():
    model.cuda()
    clf.cuda()

# One optimiser updates the autoencoder and the classifier together.
opt = optim.Adam(chain(model.parameters(),clf.parameters()))

# Per-sample reconstruction BCE, summed over every non-batch dimension -> (batch,).
bce_loss = lambda inp,tgt: torch.sum(F.binary_cross_entropy(inp,tgt,reduction="none"), dim=list(range(1,len(inp.shape))))
xe_loss = nn.CrossEntropyLoss()
# 20 x classification cross-entropy + mean per-sample reconstruction BCE.
# The parameter is named xe_l so it does not shadow the xe_loss criterion above.
total_loss = lambda bce_l, xe_l: 20*xe_l + torch.mean(bce_l)

# Per-batch logs; "epoch" for test rows is the epoch *before* whose training
# pass the evaluation ran (N_EPOCHS marks the post-training evaluation).
train_history = {
    "epoch": [],
    "batch": [],
    "rec_bce_loss": [],
    "clf_xe_loss": [],
    "total_loss": [],
}
test_history = {
    "epoch": [],
    "batch": [],
    "rec_bce_loss": [],
    "clf_xe_loss": [],
    "total_loss": [],
}

for e in tqdm(range(N_EPOCHS)):
    # Evaluate on the test set before this epoch's training pass.
    with torch.no_grad():
        for b, (x, y) in enumerate(DataLoader(mnist_test, batch_size=BATCH_SIZE, shuffle=False)):
            if torch.cuda.is_available():
                x = x.cuda()
                y = y.cuda()
            z = model.encoder(x)
            y_hat = clf(z)
            x_hat = model.decoder(z)
            bce_l = bce_loss(x_hat,x)
            xe_l = xe_loss(y_hat,y)
            total_l = total_loss(bce_l,xe_l)
            test_history["epoch"].append(e)
            test_history["batch"].append(b)
            test_history["rec_bce_loss"].append(bce_l.detach().cpu().numpy().mean().item())
            test_history["clf_xe_loss"].append(xe_l.detach().cpu().numpy().mean().item())
            test_history["total_loss"].append(total_l.detach().cpu().numpy().mean().item())

    # One training pass over the shuffled training set.
    for b, (x, y) in enumerate(DataLoader(mnist_train, batch_size=BATCH_SIZE, shuffle=True)):
        if torch.cuda.is_available():
            x = x.cuda()
            y = y.cuda()
        opt.zero_grad()
        z = model.encoder(x)
        y_hat = clf(z)
        x_hat = model.decoder(z)
        bce_l = bce_loss(x_hat,x)
        xe_l = xe_loss(y_hat,y)
        total_l = total_loss(bce_l,xe_l)
        total_l.backward()
        opt.step()
        train_history["epoch"].append(e)
        train_history["batch"].append(b)
        train_history["rec_bce_loss"].append(bce_l.detach().cpu().numpy().mean().item())
        train_history["clf_xe_loss"].append(xe_l.detach().cpu().numpy().mean().item())
        train_history["total_loss"].append(total_l.detach().cpu().numpy().mean().item())

# Final test evaluation after the last epoch, logged under epoch N_EPOCHS.
with torch.no_grad():
    for b, (x, y) in enumerate(DataLoader(mnist_test, batch_size=BATCH_SIZE, shuffle=False)):
        if torch.cuda.is_available():
            x = x.cuda()
            y = y.cuda()
        z = model.encoder(x)
        y_hat = clf(z)
        x_hat = model.decoder(z)
        bce_l = bce_loss(x_hat,x)
        xe_l = xe_loss(y_hat,y)
        total_l = total_loss(bce_l,xe_l)
        test_history["epoch"].append(N_EPOCHS)
        test_history["batch"].append(b)
        test_history["rec_bce_loss"].append(bce_l.detach().cpu().numpy().mean().item())
        test_history["clf_xe_loss"].append(xe_l.detach().cpu().numpy().mean().item())
        test_history["total_loss"].append(total_l.detach().cpu().numpy().mean().item())

In [None]:
# Training and test histories as DataFrames (one row per batch).
train_df = pd.DataFrame(train_history)
# Fractional epoch for each training batch: batch b of epoch e maps to
# e + b / n_batches, so every epoch occupies [e, e + 1) and lines up with the
# test losses plotted at x = e. Dividing by the max batch *index* instead would
# make the last batch of epoch e collide with the first batch of epoch e + 1,
# and would give 0/0 = NaN when an epoch has a single batch.
train_df["batch_in_epoch"] = train_df["epoch"] + train_df["batch"]/(train_df["batch"].max()+1)
test_df = pd.DataFrame(test_history)

In [None]:
# Per-batch training curve (blue) vs per-epoch mean test loss (red) for
# `loss_type`. Also prints the last training batch's value and the mean test
# value of the final evaluation epoch.
loss_type="total_loss"
train_df.set_index("batch_in_epoch")[loss_type].plot(c="b")
ax:plt.Axes = plt.gca()
# Twin axes that share x; sharey then puts both curves on the same y-scale.
# NOTE(review): pandas .plot() draws on the current axes, and twinx() is
# expected to make the twin current, so the red test curve lands on nax.
nax = ax.twinx()
ax.sharey(nax)
test_df.groupby("epoch")[loss_type].mean().plot(c="r")
plt.title(f"{model_name} {loss_type}")
print(f"{model_name} final {loss_type} tr/val = {train_df.set_index('batch_in_epoch').sort_index(inplace=False,ascending=True)[loss_type].iloc[-1]:.4f}/{test_df.groupby('epoch')[loss_type].mean().sort_index(inplace=False,ascending=True).iloc[-1]:.4f}")
if SAVE_IMAGES:
    for fmt in IMAGE_FORMATS:
        plt.savefig(osp.join(images_folder,fmt,f"{model_name}-{loss_type}.{fmt}"))
if SHOW_IMAGES:
    plt.show()
plt.close()

In [None]:
# Per-batch training curve (blue) vs per-epoch mean test loss (red) for the
# reconstruction BCE. Also prints the last training batch's value and the mean
# test value of the final evaluation epoch.
loss_type="rec_bce_loss"
train_df.set_index("batch_in_epoch")[loss_type].plot(c="b")
ax:plt.Axes = plt.gca()
# Twin axes that share x; sharey then puts both curves on the same y-scale.
# NOTE(review): pandas .plot() draws on the current axes, and twinx() is
# expected to make the twin current, so the red test curve lands on nax.
nax = ax.twinx()
ax.sharey(nax)
test_df.groupby("epoch")[loss_type].mean().plot(c="r")
plt.title(f"{model_name} {loss_type}")
print(f"{model_name} final {loss_type} tr/val = {train_df.set_index('batch_in_epoch').sort_index(inplace=False,ascending=True)[loss_type].iloc[-1]:.4f}/{test_df.groupby('epoch')[loss_type].mean().sort_index(inplace=False,ascending=True).iloc[-1]:.4f}")
if SAVE_IMAGES:
    for fmt in IMAGE_FORMATS:
        plt.savefig(osp.join(images_folder,fmt,f"{model_name}-{loss_type}.{fmt}"))
if SHOW_IMAGES:
    plt.show()
plt.close()

In [None]:
# Per-batch training curve (blue) vs per-epoch mean test loss (red) for the
# classifier cross-entropy. Also prints the last training batch's value and the
# mean test value of the final evaluation epoch.
loss_type="clf_xe_loss"
train_df.set_index("batch_in_epoch")[loss_type].plot(c="b")
ax:plt.Axes = plt.gca()
# Twin axes that share x; sharey then puts both curves on the same y-scale.
# NOTE(review): pandas .plot() draws on the current axes, and twinx() is
# expected to make the twin current, so the red test curve lands on nax.
nax = ax.twinx()
ax.sharey(nax)
test_df.groupby("epoch")[loss_type].mean().plot(c="r")
plt.title(f"{model_name} {loss_type}")
print(f"{model_name} final {loss_type} tr/val = {train_df.set_index('batch_in_epoch').sort_index(inplace=False,ascending=True)[loss_type].iloc[-1]:.4f}/{test_df.groupby('epoch')[loss_type].mean().sort_index(inplace=False,ascending=True).iloc[-1]:.4f}")
if SAVE_IMAGES:
    for fmt in IMAGE_FORMATS:
        plt.savefig(osp.join(images_folder,fmt,f"{model_name}-{loss_type}.{fmt}"))
if SHOW_IMAGES:
    plt.show()
plt.close()

In [None]:
# Encode the whole training split into latent codes. shuffle=False keeps row i
# of z_train aligned with mnist_train.targets[i].
zs = []
with torch.no_grad():
    for b, (x, y) in enumerate(DataLoader(mnist_train, batch_size=BATCH_SIZE, shuffle=False)):
        # Only the images go through the encoder; the labels are not used in
        # this loop, so they are not copied to the GPU.
        if torch.cuda.is_available():
            x = x.cuda()
        zs.append(model.encoder(x).cpu().numpy())
z_train = np.concatenate(tuple(zs),axis=0)
z_train.shape, mnist_train.targets.numpy().shape

In [None]:
# Encode the whole test split into latent codes. shuffle=False keeps row i of
# z_test aligned with mnist_test.targets[i].
zs = []
with torch.no_grad():
    for b, (x, y) in enumerate(DataLoader(mnist_test, batch_size=BATCH_SIZE, shuffle=False)):
        # Only the images go through the encoder; the labels are not used in
        # this loop, so they are not copied to the GPU.
        if torch.cuda.is_available():
            x = x.cuda()
        zs.append(model.encoder(x).cpu().numpy())
z_test = np.concatenate(tuple(zs),axis=0)
z_test.shape, mnist_test.targets.numpy().shape

In [None]:
# Latent-space overview: training codes (left) and test codes (right), coloured
# by digit class and drawn as scatter points, a 2-D histogram and KDE contours.
# Only the first two latent coordinates are plotted.
fig, axes = plt.subplots(1,2, sharex=True, sharey=True)

# NOTE: x_* / y_* are the two latent coordinates, not images/labels; the class
# labels used for colouring are hue_*.
x_train, y_train = z_train[:,0], z_train[:,1]
hue_train = mnist_train.targets.numpy()

x_test, y_test = z_test[:,0], z_test[:,1]
hue_test = mnist_test.targets.numpy()

# Train panel: legends suppressed, the test panel carries the class legend.
sns.scatterplot(x=x_train, y=y_train, hue=class_names[hue_train], hue_order=class_names, marker="x", s=4, ax=axes[0], legend=False,)
sns.histplot(x=x_train, y=y_train, bins=64, pthresh=0.01, hue=class_names[hue_train], hue_order=class_names, ax=axes[0], legend=False,)
# kdeplot returns the axes it drew on; kept as kde_train / kde_test.
kde_train = sns.kdeplot(x=x_train, y=y_train, levels=5, hue=class_names[hue_train], hue_order=class_names, linewidths=1, ax=axes[0], legend=False,)

# Test panel: same three layers with default legends.
sns.scatterplot(x=x_test, y=y_test, hue=class_names[hue_test], hue_order=class_names, marker="x", s=4, ax=axes[1],)
sns.histplot(x=x_test, y=y_test, bins=64, pthresh=0.01, hue=class_names[hue_test], hue_order=class_names, ax=axes[1],)
kde_test = sns.kdeplot(x=x_test, y=y_test, levels=5, hue=class_names[hue_test], hue_order=class_names, linewidths=1, ax=axes[1],)

# Put the test panel's legend outside the axes, to the right.
sns.move_legend(axes[1], "upper left", bbox_to_anchor=(1, 0.75),)

plt.suptitle(f"{model_name} latent space distribution")
axes[0].set_title(f"Train")
axes[1].set_title(f"Test")

if SAVE_IMAGES:
    # First save a zoomed view with the limits clipped to [-20, 20] (the panels
    # share x and y, so this zooms both), then restore the original limits and
    # save the full view.
    xlim = axes[0].get_xlim()
    ylim = axes[0].get_ylim()
    axes[0].set_xlim(max(xlim[0],-20),min(xlim[1],20))
    axes[0].set_ylim(max(ylim[0],-20),min(ylim[1],20))
    for fmt in IMAGE_FORMATS:
        plt.savefig(osp.join(images_folder,fmt,f"{model_name}-z-zoom.{fmt}"), bbox_inches="tight",)
    axes[0].set_xlim(xlim)
    axes[0].set_ylim(ylim)
    for fmt in IMAGE_FORMATS:
        plt.savefig(osp.join(images_folder,fmt,f"{model_name}-z.{fmt}"), bbox_inches="tight",)

if SHOW_IMAGES:
    plt.show()
plt.close()

In [None]:
# One Gaussian KDE per digit class, fitted on that class's training codes.
# scipy's gaussian_kde takes data as (n_dims, n_points), hence the transpose.
kdes = {}
for ci, c in enumerate(class_names):
    class_codes = z_train[hue_train == ci]
    kdes[c] = sps.gaussian_kde(class_codes.T)

In [None]:
# Regular NUM_POINTS_GRID x NUM_POINTS_GRID grid over the bounding box of the
# training codes; `points` holds one grid point per column, shape (2, n_points).
mins = z_train.min(axis=0, keepdims=True)
maxs = z_train.max(axis=0, keepdims=True)
# NOTE(review): mins[:, i] is kept as a (1,)-shaped array slice on purpose;
# switching to a scalar index may change np.linspace's output dtype under
# NumPy 1.x promotion rules, and the grid is later fed to the decoder.
axis_ticks = [np.linspace(mins[:, i], maxs[:, i], NUM_POINTS_GRID) for i in range(2)]
xx, yy = np.meshgrid(*axis_ticks)
points = np.vstack([xx.ravel(), yy.ravel()])

In [None]:
# Density of every class KDE at every grid point: probs[p, k] is the density of
# class k at grid point p, shape (n_points, n_classes).
class_densities = [kdes[c](points) for c in class_names]
probs = np.stack(class_densities, axis=1)

In [None]:
# Decode grid points selected from each class's KDE. Column 0 ("Best") is the
# grid point with the highest density for the class; the other columns are
# random grid points from the class's top-density region. Each panel is
# annotated with its latent coordinates.
number_of_classes = len(class_names)
fig, axes = plt.subplots(number_of_classes, NUMBER_OF_SAMPLES_PER_CLASS)
fig.suptitle(f"{model_name} Samples from point grid 1-class KDEs")
axes[0,0].set_title("Best")
for ci, c in enumerate(class_names):
    c_prob = probs[:,ci]
    # Density threshold: the TOP_PCT_TO_SAMPLE_FROM quantile over the grid
    # points where this class has non-zero density.
    top_threshold = np.quantile(c_prob[c_prob>0],TOP_PCT_TO_SAMPLE_FROM)
    # Candidate pool = the highest-density points, never fewer than one row.
    supersample_size = max((c_prob>top_threshold).sum(),NUMBER_OF_SAMPLES_PER_CLASS)
    c_prob_argsort = c_prob.argsort()
    # Best point first, then NUMBER_OF_SAMPLES_PER_CLASS-1 random picks from the
    # pool. Picks are without replacement and exclude the best point, unless the
    # pool is exactly one row wide, in which case duplicates are allowed.
    best_and_rest_from_top_pct = [
        c_prob_argsort[-1],
        *take(
            NUMBER_OF_SAMPLES_PER_CLASS-1,
            [
                x for x in np.random.choice(
                    c_prob_argsort[-supersample_size:],
                    NUMBER_OF_SAMPLES_PER_CLASS,
                    replace=supersample_size==NUMBER_OF_SAMPLES_PER_CLASS,
                ) if x != c_prob_argsort[-1] or supersample_size==NUMBER_OF_SAMPLES_PER_CLASS
            ]
        )
    ]
    z_np = points.T[best_and_rest_from_top_pct,:]
    z = torch.tensor(z_np)
    if torch.cuda.is_available(): z = z.cuda()
    x_hat = model.decoder(z).detach().cpu().numpy()
    # Move the channel axis (axis 1) to the end for imshow. A reshape only
    # coincides with this transpose for a single channel; with C > 1 it
    # scrambles pixels across channels.
    x_hat = np.moveaxis(x_hat, 1, -1)
    axes[ci,0].set_ylabel(c, size=YLABEL_FONTSIZE)
    for pi in range(NUMBER_OF_SAMPLES_PER_CLASS):
        ax:plt.Axes = axes[ci,pi]
        ax.imshow(x_hat[pi], vmin=0, vmax=1, cmap="gray")
        ax.xaxis.set_ticklabels([])
        ax.yaxis.set_ticklabels([])
        ax.text(2.1, 0.5, ", ".join(map("{:.2f}".format,z_np[pi])),
            size=COORDS_FONTSIZE, ha='center', va='center', transform=ax.transAxes)
if SAVE_IMAGES:
    for fmt in IMAGE_FORMATS:
        plt.savefig(osp.join(images_folder,fmt,f"{model_name}-samples-pointgrid-1c.{fmt}"))
if SHOW_IMAGES:
    plt.show()
plt.close()

In [None]:
# Assign every grid point to one class. The class densities are re-weighted with
# a temperature softmax over their log-densities, and the point goes to the
# class maximising softmax_prob * density. log(0) = -inf is expected for
# zero-density points (the softmax maps it to weight 0), so NumPy's
# divide-by-zero warning is silenced.
# NOTE(review): rows where every class density is 0 give NaN softmax values,
# and np.argmax then returns index 0 for them.
with np.errstate(divide="ignore"):
    log_probs = np.log(probs)
softmax_probs = torch.softmax(torch.tensor(log_probs)/SOFTMAX_REPARAM_TEMPERATURE,1).numpy()
softmax_classes = class_names[np.argmax(softmax_probs*probs, axis=1)]
# pd.value_counts is deprecated (pandas 2.1); Series.value_counts is equivalent.
pd.Series(softmax_classes).value_counts()

In [None]:
# For each class, decode up to NUMBER_OF_SAMPLES_PER_CLASS random grid points
# that the softmax assignment gives to that class and where that class's KDE is
# non-zero. Panels beyond the available points are drawn blank.
fig, axes = plt.subplots(number_of_classes, NUMBER_OF_SAMPLES_PER_CLASS)
fig.suptitle(f"{model_name} Random samples from softmax of point grid 1-class KDE")
for ci, c in enumerate(class_names):
    indexes_of_this_class = np.logical_and(
        softmax_classes==c,
        probs[:,ci]>0
    )
    num_points_in_this_class = indexes_of_this_class.sum()
    num_show = min(NUMBER_OF_SAMPLES_PER_CLASS,num_points_in_this_class)
    # With no points, x_hat / z_np are not recomputed; they are never read then,
    # because num_show == 0 makes every panel blank.
    if num_points_in_this_class>0:
        points_in_this_class = points.T[indexes_of_this_class,:]
        random_points_from_this_class = points_in_this_class[
            np.random.choice(
                num_points_in_this_class,
                num_show,
                replace=False,
            )
        ]
        z_np = random_points_from_this_class
        z = torch.tensor(z_np)
        if torch.cuda.is_available(): z = z.cuda()
        x_hat = model.decoder(z).detach().cpu().numpy()
        # Channel axis (axis 1) last for imshow; unlike a reshape this is a true
        # transpose and stays correct for any channel count.
        x_hat = np.moveaxis(x_hat, 1, -1)
    axes[ci,0].set_ylabel(c, size=YLABEL_FONTSIZE)
    for pi in range(NUMBER_OF_SAMPLES_PER_CLASS):
        ax:plt.Axes = axes[ci,pi]
        # Real decoded image where available, otherwise a blank single-channel
        # placeholder.
        ax.imshow(
            (
                x_hat[pi]
                if pi<num_show else
                np.zeros((dset_img_dim,dset_img_dim,1))
            ),
            vmin=0, vmax=1, cmap="gray"
        )
        ax.xaxis.set_ticklabels([])
        ax.yaxis.set_ticklabels([])
        if pi<num_show:
            ax.text(2.1, 0.5, ", ".join(map("{:.2f}".format,z_np[pi])),
                size=COORDS_FONTSIZE, ha='center', va='center', transform=ax.transAxes)
if SAVE_IMAGES:
    for fmt in IMAGE_FORMATS:
        plt.savefig(osp.join(images_folder,fmt,f"{model_name}-samples-pointgrid-1c-softmax.{fmt}"))
if SHOW_IMAGES:
    plt.show()
plt.close()

In [None]:
# Draw NUMBER_OF_SAMPLES_PER_CLASS latent codes from each class's KDE, decode
# them and show the images annotated with their latent coordinates.
number_of_classes = len(class_names)
fig, axes = plt.subplots(number_of_classes, NUMBER_OF_SAMPLES_PER_CLASS)
fig.suptitle(f"{model_name} Random samples from the KDE estimators")
for ci, c in enumerate(class_names):
    # resample returns (n_dims, n_samples); transpose to one code per row and
    # cast to float32 for the decoder.
    z_np = kdes[c].resample(NUMBER_OF_SAMPLES_PER_CLASS).T.astype(np.float32)
    z = torch.tensor(z_np)
    if torch.cuda.is_available(): z = z.cuda()
    x_hat = model.decoder(z).detach().cpu().numpy()
    # Channel axis (axis 1) last for imshow; a true transpose, correct for any
    # channel count (a reshape is only equivalent when C == 1).
    x_hat = np.moveaxis(x_hat, 1, -1)
    axes[ci,0].set_ylabel(c, size=YLABEL_FONTSIZE)
    for pi in range(NUMBER_OF_SAMPLES_PER_CLASS):
        ax:plt.Axes = axes[ci,pi]
        ax.imshow(x_hat[pi], vmin=0, vmax=1, cmap="gray")
        ax.xaxis.set_ticklabels([])
        ax.yaxis.set_ticklabels([])
        ax.text(2.1, 0.5, ", ".join(map("{:.2f}".format,z_np[pi])),
            size=COORDS_FONTSIZE, ha='center', va='center', transform=ax.transAxes)
if SAVE_IMAGES:
    for fmt in IMAGE_FORMATS:
        plt.savefig(osp.join(images_folder,fmt,f"{model_name}-samples-kde-1c.{fmt}"))
if SHOW_IMAGES:
    plt.show()
plt.close()

In [None]:
# One Bayesian optimiser per class, searching the latent space via the decoder
# and the external classifier (objective built by nll_optimizer_for, defined
# elsewhere; presumably the classifier's negative log-likelihood of class ci).
# Search box: the training-code bounding box widened to at least [-4, 4] in
# every latent dimension. It does not depend on the class, so it is built once.
latent_bounds = [
    (min(mins[0,i],-4.0), max(maxs[0,i],4.0))
    for i in range(mins.shape[1])
]
bayopts = {}
for ci, c in enumerate(class_names):
    # Each optimiser gets its own copy of the bounds list.
    bayopts[ci] = nll_optimizer_for(ci, model.decoder, ext_clf, bounds=list(latent_bounds))
    bayopts[ci].run_optimization(max_iter=BAYESIAN_OPTIMISATION_STEPS)

In [None]:
# Save and/or display each class's Bayesian-optimisation acquisition plot.
for ci, c in enumerate(class_names):
    if SAVE_IMAGES:
        # plot_acquisition builds its own figure and is given a filename per
        # image format. NOTE(review): the original author observed that it also
        # displays the figure itself, hence the plt.close() right after each call.
        for fmt in IMAGE_FORMATS:
            bayopts[ci].plot_acquisition(filename=osp.join(images_folder,fmt,f"{model_name}-bayes-acquisition-{ci}.{fmt}"))
            plt.close()
    if SHOW_ACQUISITIONS:
        # Print the class name as a caption before the plot.
        print(c)
        bayopts[ci].plot_acquisition()
        plt.close()

In [None]:
# Decode, for each class, the best-scoring latent points evaluated by that
# class's Bayesian optimiser, annotated with their latent coordinates.
number_of_classes = len(class_names)
fig, axes = plt.subplots(number_of_classes, NUMBER_OF_SAMPLES_PER_CLASS)
fig.suptitle(f"{model_name} Random samples from the Bayesian Optimisation")
for ci, c in enumerate(class_names):
    # X: evaluated latent points; Y: objective per point (one column). Ascending
    # order puts the lowest objective first, assuming the optimiser minimises.
    top_idx = np.argsort(bayopts[ci].Y, axis=0).ravel()
    # Only the first NUMBER_OF_SAMPLES_PER_CLASS are shown, so decode just those
    # instead of every evaluated point.
    best_idx = top_idx[:NUMBER_OF_SAMPLES_PER_CLASS]
    z_np = bayopts[ci].X[best_idx].astype(np.float32)
    z = torch.tensor(z_np)
    if torch.cuda.is_available(): z = z.cuda()
    x_hat = model.decoder(z).detach().cpu().numpy()
    # Channel axis (axis 1) last for imshow; a true transpose, correct for any
    # channel count (a reshape is only equivalent when C == 1).
    x_hat = np.moveaxis(x_hat, 1, -1)
    axes[ci,0].set_ylabel(c, size=YLABEL_FONTSIZE)
    for pi in range(NUMBER_OF_SAMPLES_PER_CLASS):
        ax:plt.Axes = axes[ci,pi]
        ax.imshow(x_hat[pi], vmin=0, vmax=1, cmap="gray")
        ax.xaxis.set_ticklabels([])
        ax.yaxis.set_ticklabels([])
        ax.text(2.1, 0.5, ", ".join(map("{:.2f}".format,z_np[pi])),
            size=COORDS_FONTSIZE, ha='center', va='center', transform=ax.transAxes)
if SAVE_IMAGES:
    for fmt in IMAGE_FORMATS:
        plt.savefig(osp.join(images_folder,fmt,f"{model_name}-samples-bayesian-extclf.{fmt}"))
if SHOW_IMAGES:
    plt.show()
plt.close()

## Variational Autoencoder

In [None]:
def sample(mean:torch.Tensor, log_var:torch.Tensor) -> torch.Tensor:
    """Draw ``z ~ N(mean, exp(log_var))`` with the reparameterisation trick.

    Computes ``z = mean + exp(0.5 * log_var) * eps`` with ``eps ~ N(0, I)``, so
    gradients flow back into ``mean`` and ``log_var``.

    Args:
        mean: Mean of the diagonal Gaussian.
        log_var: Log-variance, same shape as ``mean``.

    Returns:
        One sample with the shape of ``mean``.
    """
    # exp(0.5 * log_var) == sqrt(exp(log_var)), but it does not overflow to inf
    # for large log-variances (exp alone overflows float32 once log_var > ~88).
    sigma = torch.exp(0.5 * log_var)
    # Standard normal noise matching mean's shape, dtype and device.
    epsilon = torch.randn_like(mean)
    return mean + sigma * epsilon


def kl_div(mean:torch.Tensor, log_var:torch.Tensor) -> torch.Tensor:
    """KL(N(mean, diag(exp(log_var))) || N(0, I)), the VAE regulariser.

    The per-dimension terms are summed over the last axis, so inputs of shape
    ``(..., d)`` give a result of shape ``(...)``.
    """
    per_dim = 1 + log_var - mean.square() - log_var.exp()
    return -0.5 * per_dim.sum(dim=-1)

class MNIST_VAE(nn.Module):
    """Convolutional VAE for square images of side ``dset_img_dim``.

    The encoder emits ``2 * hidden_channels`` values per image. The first half
    is the posterior mean and the second half is the log-variance consumed by
    ``sample`` / ``kl_div`` (the method names still say "sigma").

    Args:
        in_channels: Number of image channels.
        hidden_channels: Dimensionality of the latent space.
    """
    def __init__(self, in_channels = 1, hidden_channels = 2):
        super().__init__()

        self.hidden_channels = hidden_channels

        # The single stride-2 conv is presumably what halves the spatial size to
        # dset_img_dim_half, which the flattened Linear input relies on.
        self.encoder = nn.Sequential(
            CNN(in_channels,[32,64,64,64],[3,3,3,3],[1,2,1,1],[1,1,1,1]),
            nn.Flatten(),
            nn.Linear(64*dset_img_dim_half*dset_img_dim_half,32),
            nn.ReLU(),
            nn.Linear(32,2*hidden_channels),
        )

        self.decoder = nn.Sequential(
            nn.Linear(hidden_channels,64*dset_img_dim_half*dset_img_dim_half),
            nn.ReLU(),
            # Must match the Linear above; derived from dset_img_dim_half rather
            # than hard-coding MNIST's 14x14 so other image sizes also work.
            nn.Unflatten(1,(64,dset_img_dim_half,dset_img_dim_half)),
            DeCNN(64,[32],[3],[2],[1],[1]),
            # Sigmoid output so pixels lie in [0, 1] for the BCE reconstruction loss.
            CNN(32,[in_channels],[3],[1],[1],nonlinearities=[nn.Sigmoid])
        )

    def get_mu_and_sigma(self, x):
        """Encode ``x`` and split the output into ``(mean, log_var)``.

        Despite the name, the second tensor is a log-variance, not a standard
        deviation.
        """
        mu_and_sigma = self.encoder(x)
        mu, sigma = mu_and_sigma[:,:self.hidden_channels], mu_and_sigma[:,self.hidden_channels:]
        return mu, sigma

    def sample(self, mu, sigma):
        """Reparameterised latent sample; ``sigma`` is the log-variance."""
        z = sample(mu, sigma)
        return z

    def decode(self, z):
        """Map latent codes to reconstructed images."""
        img = self.decoder(z)
        return img

    def forward(self, x):
        """Encode, sample a latent code and decode it."""
        return self.decode(
            self.sample(
                *self.get_mu_and_sigma(x)
            )
        )

### Unconstrained

In [None]:
model_name = "U-VAE"

# Unconstrained VAE: trained on reconstruction + KL only, without labels.
model = MNIST_VAE(mnist_n_channels,AE_LATENT_DIM)
if torch.cuda.is_available(): model.cuda()

opt = optim.Adam(model.parameters())

# Per-sample reconstruction BCE, summed over every non-batch dimension -> (batch,).
bce_loss = lambda inp,tgt: torch.sum(F.binary_cross_entropy(inp,tgt,reduction="none"), dim=list(range(1,len(inp.shape)))) #nn.BCELoss(reduction="sum")
# Per-sample KL term; the second argument is the encoder's log-variance half.
kl_loss = lambda mu, sigma: kl_div(mu, sigma)
# Negative ELBO averaged over the batch: mean over samples of (rec + KL).
# Averaging the concatenation of the two (batch,) vectors instead gives exactly
# half this value, so the logged total would not equal rec + KL. (With Adam the
# constant factor barely changes training; it mainly fixes the logged value.)
total_loss = lambda bce_l, kl_l: torch.mean(bce_l + kl_l)

# Per-batch logs; "epoch" for test rows is the epoch *before* whose training
# pass the evaluation ran (N_EPOCHS marks the post-training evaluation).
train_history = {
    "epoch": [],
    "batch": [],
    "rec_bce_loss": [],
    "kl_loss": [],
    "total_loss": [],
}
test_history = {
    "epoch": [],
    "batch": [],
    "rec_bce_loss": [],
    "kl_loss": [],
    "total_loss": [],
}

for e in tqdm(range(N_EPOCHS)):
    # Evaluate on the test set before this epoch's training pass.
    with torch.no_grad():
        for b, (x, y) in enumerate(DataLoader(mnist_test, batch_size=BATCH_SIZE, shuffle=False)):
            if torch.cuda.is_available():
                x = x.cuda()
                y = y.cuda()
            mu, sigma = model.get_mu_and_sigma(x)
            z = model.sample(mu, sigma)
            x_hat = model.decode(z)
            bce_l = bce_loss(x_hat,x)
            kl_l = kl_loss(mu, sigma)
            total_l = total_loss(bce_l,kl_l)
            test_history["epoch"].append(e)
            test_history["batch"].append(b)
            test_history["rec_bce_loss"].append(bce_l.detach().cpu().numpy().mean().item())
            test_history["kl_loss"].append(kl_l.detach().cpu().numpy().mean().item())
            test_history["total_loss"].append(total_l.detach().cpu().numpy().mean().item())

    # One training pass over the shuffled training set.
    for b, (x, y) in enumerate(DataLoader(mnist_train, batch_size=BATCH_SIZE, shuffle=True)):
        if torch.cuda.is_available():
            x = x.cuda()
            y = y.cuda()
        opt.zero_grad()
        mu, sigma = model.get_mu_and_sigma(x)
        z = model.sample(mu, sigma)
        x_hat = model.decode(z)
        bce_l = bce_loss(x_hat,x)
        kl_l = kl_loss(mu, sigma)
        total_l = total_loss(bce_l,kl_l)
        total_l.backward()
        opt.step()
        train_history["epoch"].append(e)
        train_history["batch"].append(b)
        train_history["rec_bce_loss"].append(bce_l.detach().cpu().numpy().mean().item())
        train_history["kl_loss"].append(kl_l.detach().cpu().numpy().mean().item())
        train_history["total_loss"].append(total_l.detach().cpu().numpy().mean().item())

# Final test evaluation after the last epoch, logged under epoch N_EPOCHS.
with torch.no_grad():
    for b, (x, y) in enumerate(DataLoader(mnist_test, batch_size=BATCH_SIZE, shuffle=False)):
        if torch.cuda.is_available():
            x = x.cuda()
            y = y.cuda()
        mu, sigma = model.get_mu_and_sigma(x)
        z = model.sample(mu, sigma)
        x_hat = model.decode(z)
        bce_l = bce_loss(x_hat,x)
        kl_l = kl_loss(mu, sigma)
        total_l = total_loss(bce_l,kl_l)
        test_history["epoch"].append(N_EPOCHS)
        test_history["batch"].append(b)
        test_history["rec_bce_loss"].append(bce_l.detach().cpu().numpy().mean().item())
        test_history["kl_loss"].append(kl_l.detach().cpu().numpy().mean().item())
        test_history["total_loss"].append(total_l.detach().cpu().numpy().mean().item())

In [None]:
# Training and test histories as DataFrames (one row per batch).
train_df = pd.DataFrame(train_history)
# Fractional epoch for each training batch: batch b of epoch e maps to
# e + b / n_batches, so every epoch occupies [e, e + 1) and lines up with the
# test losses plotted at x = e. Dividing by the max batch *index* instead would
# make the last batch of epoch e collide with the first batch of epoch e + 1,
# and would give 0/0 = NaN when an epoch has a single batch.
train_df["batch_in_epoch"] = train_df["epoch"] + train_df["batch"]/(train_df["batch"].max()+1)
test_df = pd.DataFrame(test_history)

In [None]:
# Per-batch training curve (blue) vs per-epoch mean test loss (red) for
# `loss_type`. Also prints the last training batch's value and the mean test
# value of the final evaluation epoch.
loss_type="total_loss"
train_df.set_index("batch_in_epoch")[loss_type].plot(c="b")
ax:plt.Axes = plt.gca()
# Twin axes that share x; sharey then puts both curves on the same y-scale.
# NOTE(review): pandas .plot() draws on the current axes, and twinx() is
# expected to make the twin current, so the red test curve lands on nax.
nax = ax.twinx()
ax.sharey(nax)
test_df.groupby("epoch")[loss_type].mean().plot(c="r")
plt.title(f"{model_name} {loss_type}")
print(f"{model_name} Final {loss_type} tr/val = {train_df.set_index('batch_in_epoch').sort_index(inplace=False,ascending=True)[loss_type].iloc[-1]:.4f}/{test_df.groupby('epoch')[loss_type].mean().sort_index(inplace=False,ascending=True).iloc[-1]:.4f}")
if SAVE_IMAGES:
    for fmt in IMAGE_FORMATS:
        plt.savefig(osp.join(images_folder,fmt,f"{model_name}-{loss_type}.{fmt}"))
if SHOW_IMAGES:
    plt.show()
plt.close()

In [None]:
# Per-batch training curve (blue) vs per-epoch mean test loss (red) for the
# reconstruction BCE. Also prints the last training batch's value and the mean
# test value of the final evaluation epoch.
loss_type="rec_bce_loss"
train_df.set_index("batch_in_epoch")[loss_type].plot(c="b")
ax:plt.Axes = plt.gca()
# Twin axes that share x; sharey then puts both curves on the same y-scale.
# NOTE(review): pandas .plot() draws on the current axes, and twinx() is
# expected to make the twin current, so the red test curve lands on nax.
nax = ax.twinx()
ax.sharey(nax)
test_df.groupby("epoch")[loss_type].mean().plot(c="r")
plt.title(f"{model_name} {loss_type}")
print(f"{model_name} final {loss_type} tr/val = {train_df.set_index('batch_in_epoch').sort_index(inplace=False,ascending=True)[loss_type].iloc[-1]:.4f}/{test_df.groupby('epoch')[loss_type].mean().sort_index(inplace=False,ascending=True).iloc[-1]:.4f}")
if SAVE_IMAGES:
    for fmt in IMAGE_FORMATS:
        plt.savefig(osp.join(images_folder,fmt,f"{model_name}-{loss_type}.{fmt}"))
if SHOW_IMAGES:
    plt.show()
plt.close()

In [None]:
# Per-batch training curve (blue) vs per-epoch mean test loss (red) for the KL
# term. Also prints the last training batch's value and the mean test value of
# the final evaluation epoch.
loss_type="kl_loss"
train_df.set_index("batch_in_epoch")[loss_type].plot(c="b")
ax:plt.Axes = plt.gca()
# Twin axes that share x; sharey then puts both curves on the same y-scale.
# NOTE(review): pandas .plot() draws on the current axes, and twinx() is
# expected to make the twin current, so the red test curve lands on nax.
nax = ax.twinx()
ax.sharey(nax)
test_df.groupby("epoch")[loss_type].mean().plot(c="r")
plt.title(f"{model_name} {loss_type}")
print(f"{model_name} final {loss_type} tr/val = {train_df.set_index('batch_in_epoch').sort_index(inplace=False,ascending=True)[loss_type].iloc[-1]:.4f}/{test_df.groupby('epoch')[loss_type].mean().sort_index(inplace=False,ascending=True).iloc[-1]:.4f}")
if SAVE_IMAGES:
    for fmt in IMAGE_FORMATS:
        plt.savefig(osp.join(images_folder,fmt,f"{model_name}-{loss_type}.{fmt}"))
if SHOW_IMAGES:
    plt.show()
plt.close()


In [None]:
# Encode the whole training split into latent codes, using the posterior mean
# (no sampling). shuffle=False keeps row i of z_train aligned with
# mnist_train.targets[i].
zs = []
with torch.no_grad():
    for b, (x, y) in enumerate(DataLoader(mnist_train, batch_size=BATCH_SIZE, shuffle=False)):
        # Only the images go through the encoder; the labels are not used in
        # this loop, so they are not copied to the GPU.
        if torch.cuda.is_available():
            x = x.cuda()
        # get_mu_and_sigma splits the encoder output at model.hidden_channels,
        # so this works for any latent size, not only a 2-D latent space.
        z_mean = model.get_mu_and_sigma(x)[0]
        zs.append(z_mean.cpu().numpy())
z_train = np.concatenate(tuple(zs),axis=0)
z_train.shape, mnist_train.targets.numpy().shape

In [None]:
# Encode the whole test split into latent codes, using the posterior mean
# (no sampling). shuffle=False keeps row i of z_test aligned with
# mnist_test.targets[i].
zs = []
with torch.no_grad():
    for b, (x, y) in enumerate(DataLoader(mnist_test, batch_size=BATCH_SIZE, shuffle=False)):
        # Only the images go through the encoder; the labels are not used in
        # this loop, so they are not copied to the GPU.
        if torch.cuda.is_available():
            x = x.cuda()
        # get_mu_and_sigma splits the encoder output at model.hidden_channels,
        # so this works for any latent size, not only a 2-D latent space.
        z_mean = model.get_mu_and_sigma(x)[0]
        zs.append(z_mean.cpu().numpy())
z_test = np.concatenate(tuple(zs),axis=0)
z_test.shape, mnist_test.targets.numpy().shape

In [None]:
# Latent-space overview: training codes (left) and test codes (right), coloured
# by digit class and drawn as scatter points, a 2-D histogram and KDE contours.
# Only the first two latent coordinates are plotted.
fig, axes = plt.subplots(1,2, sharex=True, sharey=True)

# NOTE: x_* / y_* are the two latent coordinates, not images/labels; the class
# labels used for colouring are hue_*.
x_train, y_train = z_train[:,0], z_train[:,1]
hue_train = mnist_train.targets.numpy()

x_test, y_test = z_test[:,0], z_test[:,1]
hue_test = mnist_test.targets.numpy()

# Train panel: legends suppressed, the test panel carries the class legend.
sns.scatterplot(x=x_train, y=y_train, hue=class_names[hue_train], hue_order=class_names, marker="x", s=4, ax=axes[0], legend=False,)
sns.histplot(x=x_train, y=y_train, bins=64, pthresh=0.01, hue=class_names[hue_train], hue_order=class_names, ax=axes[0], legend=False,)
sns.kdeplot(x=x_train, y=y_train, levels=5, hue=class_names[hue_train], hue_order=class_names, linewidths=1, ax=axes[0], legend=False,)

# Test panel: same three layers with default legends.
sns.scatterplot(x=x_test, y=y_test, hue=class_names[hue_test], hue_order=class_names, marker="x", s=4, ax=axes[1],)
sns.histplot(x=x_test, y=y_test, bins=64, pthresh=0.01, hue=class_names[hue_test], hue_order=class_names, ax=axes[1],)
sns.kdeplot(x=x_test, y=y_test, levels=5, hue=class_names[hue_test], hue_order=class_names, linewidths=1, ax=axes[1],)

# Put the test panel's legend outside the axes, to the right.
sns.move_legend(axes[1], "upper left", bbox_to_anchor=(1, 0.75),)

plt.suptitle(f"{model_name} latent space distribution")
axes[0].set_title(f"Train")
axes[1].set_title(f"Test")

if SAVE_IMAGES:
    # First save a zoomed view with the limits clipped to [-1, 1] (tighter than
    # the AE plots' [-20, 20]; the panels share x and y, so this zooms both),
    # then restore the original limits and save the full view.
    xlim = axes[0].get_xlim()
    ylim = axes[0].get_ylim()
    axes[0].set_xlim(max(xlim[0],-1),min(xlim[1],1))
    axes[0].set_ylim(max(ylim[0],-1),min(ylim[1],1))
    for fmt in IMAGE_FORMATS:
        plt.savefig(osp.join(images_folder,fmt,f"{model_name}-z-zoom.{fmt}"), bbox_inches="tight",)
    axes[0].set_xlim(xlim)
    axes[0].set_ylim(ylim)
    for fmt in IMAGE_FORMATS:
        plt.savefig(osp.join(images_folder,fmt,f"{model_name}-z.{fmt}"), bbox_inches="tight",)

if SHOW_IMAGES:
    plt.show()
plt.close()

In [None]:
# One Gaussian KDE per class, fitted on that class's 2-D training latents
# (gaussian_kde expects shape (dims, samples), hence the transpose).
kdes = {
    c: sps.gaussian_kde(z_train[hue_train == ci].T)
    for ci, c in enumerate(class_names)
}

In [None]:
# Regular NUM_POINTS_GRID x NUM_POINTS_GRID grid spanning the training latents'
# bounding box; `points` has shape (2, NUM_POINTS_GRID**2).
mins = z_train.min(axis=0, keepdims=True)
maxs = z_train.max(axis=0, keepdims=True)
axis_ticks = [np.linspace(mins[:, dim], maxs[:, dim], NUM_POINTS_GRID) for dim in range(2)]
xx, yy = np.meshgrid(*axis_ticks)
points = np.vstack((xx.ravel(), yy.ravel()))

In [None]:
# Density of every grid point under each class KDE: shape (num_points, num_classes).
probs = np.column_stack([kdes[c](points) for c in class_names])

In [None]:
# For each class: decode the most likely grid point (column 0) plus distinct
# random grid points from that class's high-density region.
number_of_classes = len(class_names)
fig, axes = plt.subplots(number_of_classes, NUMBER_OF_SAMPLES_PER_CLASS)
fig.suptitle(f"{model_name} Samples from point grid 1-class KDEs")
axes[0,0].set_title("Best")
for ci, c in enumerate(class_names):
    c_prob = probs[:,ci]
    # Candidate pool: grid points whose density exceeds the TOP_PCT_TO_SAMPLE_FROM
    # quantile of this class's non-zero densities, padded to at least
    # NUMBER_OF_SAMPLES_PER_CLASS points.
    top_threshold = np.quantile(c_prob[c_prob>0], TOP_PCT_TO_SAMPLE_FROM)
    supersample_size = max((c_prob>top_threshold).sum(), NUMBER_OF_SAMPLES_PER_CLASS)
    c_prob_argsort = c_prob.argsort()
    best_idx = c_prob_argsort[-1]
    # Draw without replacement and drop the best point, so no image repeats.
    # The pool holds >= NUMBER_OF_SAMPLES_PER_CLASS points, so sampling without
    # replacement is always possible and at least NUMBER_OF_SAMPLES_PER_CLASS - 1
    # draws survive the filter.
    draws = np.random.choice(
        c_prob_argsort[-supersample_size:],
        NUMBER_OF_SAMPLES_PER_CLASS,
        replace=False,
    )
    rest = [idx for idx in draws if idx != best_idx][:NUMBER_OF_SAMPLES_PER_CLASS-1]
    best_and_rest_from_top_pct = [best_idx, *rest]
    z_np = points.T[best_and_rest_from_top_pct,:].astype(np.float32)
    z = torch.tensor(z_np)
    if torch.cuda.is_available():
        z = z.cuda()
    with torch.no_grad():
        x_hat = model.decoder(z).cpu().numpy()
    # Presumably channels-first (N, C, H, W): move channels last for imshow.
    # A plain reshape only gives the right image when C == 1.
    x_hat = np.moveaxis(x_hat, 1, -1)
    axes[ci,0].set_ylabel(c, size=YLABEL_FONTSIZE)
    for pi in range(NUMBER_OF_SAMPLES_PER_CLASS):
        ax:plt.Axes = axes[ci,pi]
        ax.imshow(x_hat[pi], vmin=0, vmax=1, cmap="gray")
        ax.xaxis.set_ticklabels([])
        ax.yaxis.set_ticklabels([])
        # Latent coordinates of this sample, printed to the right of the tile.
        ax.text(2.1, 0.5, ", ".join(map("{:.2f}".format,z_np[pi])),
            size=COORDS_FONTSIZE, ha='center', va='center', transform=ax.transAxes)
if SAVE_IMAGES:
    for fmt in IMAGE_FORMATS:
        plt.savefig(osp.join(images_folder,fmt,f"{model_name}-samples-pointgrid-1c.{fmt}"))
if SHOW_IMAGES:
    plt.show()
plt.close()

In [None]:
# Assign each grid point a class from the temperature-scaled softmax of the
# per-class log-densities, weighted by the raw densities.
# log(0) = -inf where a KDE underflows to zero; that is expected, so the
# divide-by-zero warning is silenced.
# NOTE(review): a row whose densities are all zero becomes NaN after softmax,
# and argmax then returns index 0.
with np.errstate(divide="ignore"):
    log_probs = np.log(probs)
softmax_probs = torch.softmax(torch.tensor(log_probs)/SOFTMAX_REPARAM_TEMPERATURE,1).numpy()
softmax_classes = np.array(class_names)[np.argmax(softmax_probs*probs, axis=1)]
# Top-level pd.value_counts is deprecated (pandas 2.1); use the Series method.
pd.Series(softmax_classes).value_counts()

In [None]:
# For each class: decode up to NUMBER_OF_SAMPLES_PER_CLASS random grid points
# that the softmax assigned to it; missing columns are shown as blank tiles.
fig, axes = plt.subplots(number_of_classes, NUMBER_OF_SAMPLES_PER_CLASS)
fig.suptitle(f"{model_name} Random samples from softmax of point grid 1-class KDE")
for ci, c in enumerate(class_names):
    # Points labelled with this class that also have non-zero density under its KDE.
    indexes_of_this_class = np.logical_and(
        softmax_classes==c,
        probs[:,ci]>0
    )
    num_points_in_this_class = indexes_of_this_class.sum()
    num_show = min(NUMBER_OF_SAMPLES_PER_CLASS,num_points_in_this_class)
    if num_points_in_this_class>0:
        points_in_this_class = points.T[indexes_of_this_class,:]
        random_points_from_this_class = points_in_this_class[
            np.random.choice(num_points_in_this_class, num_show, replace=False)
        ]
        z_np = random_points_from_this_class.astype(np.float32)
        z = torch.tensor(z_np)
        if torch.cuda.is_available():
            z = z.cuda()
        with torch.no_grad():
            x_hat = model.decoder(z).cpu().numpy()
        # Presumably channels-first (N, C, H, W): move channels last for imshow.
        x_hat = np.moveaxis(x_hat, 1, -1)
    axes[ci,0].set_ylabel(c, size=YLABEL_FONTSIZE)
    for pi in range(NUMBER_OF_SAMPLES_PER_CLASS):
        ax:plt.Axes = axes[ci,pi]
        # x_hat/z_np are only read for pi < num_show, i.e. when set for this class.
        ax.imshow(
            (
                x_hat[pi]
                if pi<num_show else
                np.zeros((dset_img_dim,dset_img_dim,1))
            ),
            vmin=0, vmax=1, cmap="gray"
        )
        ax.xaxis.set_ticklabels([])
        ax.yaxis.set_ticklabels([])
        if pi<num_show:
            ax.text(2.1, 0.5, ", ".join(map("{:.2f}".format,z_np[pi])),
                size=COORDS_FONTSIZE, ha='center', va='center', transform=ax.transAxes)
if SAVE_IMAGES:
    for fmt in IMAGE_FORMATS:
        plt.savefig(osp.join(images_folder,fmt,f"{model_name}-samples-pointgrid-1c-softmax.{fmt}"))
if SHOW_IMAGES:
    plt.show()
plt.close()

In [None]:
# For each class: decode NUMBER_OF_SAMPLES_PER_CLASS random draws from its KDE.
number_of_classes = len(class_names)
fig, axes = plt.subplots(number_of_classes, NUMBER_OF_SAMPLES_PER_CLASS)
fig.suptitle(f"{model_name} Random samples from the KDE estimators")
for ci, c in enumerate(class_names):
    # Keep the draws in z_np so the coordinate labels below describe THESE
    # samples rather than leftover values from an earlier cell.
    # resample() returns (dims, n) float64, so transpose and cast to float32.
    z_np = kdes[c].resample(NUMBER_OF_SAMPLES_PER_CLASS).T.astype(np.float32)
    z = torch.tensor(z_np)
    if torch.cuda.is_available():
        z = z.cuda()
    with torch.no_grad():
        x_hat = model.decoder(z).cpu().numpy()
    # Presumably channels-first (N, C, H, W): move channels last for imshow.
    x_hat = np.moveaxis(x_hat, 1, -1)
    axes[ci,0].set_ylabel(c, size=YLABEL_FONTSIZE)
    for pi in range(NUMBER_OF_SAMPLES_PER_CLASS):
        ax:plt.Axes = axes[ci,pi]
        ax.imshow(x_hat[pi], vmin=0, vmax=1, cmap="gray")
        ax.xaxis.set_ticklabels([])
        ax.yaxis.set_ticklabels([])
        ax.text(2.1, 0.5, ", ".join(map("{:.2f}".format,z_np[pi])),
            size=COORDS_FONTSIZE, ha='center', va='center', transform=ax.transAxes)
if SAVE_IMAGES:
    for fmt in IMAGE_FORMATS:
        plt.savefig(osp.join(images_folder,fmt,f"{model_name}-samples-kde-1c.{fmt}"))
if SHOW_IMAGES:
    plt.show()
plt.close()

In [None]:
# Search box per latent dimension: the observed training range, widened so it
# covers at least [-4, 4]. The box is the same for every class.
search_bounds = [(min(lo, -4.0), max(hi, 4.0)) for lo, hi in zip(mins[0], maxs[0])]

# One optimiser per class index, each minimising that class's NLL under
# ext_clf for decoded latents.
bayopts = {}
for ci in range(len(class_names)):
    bayopts[ci] = optimiser = nll_optimizer_for(
        ci,
        model.decoder,
        ext_clf,
        bounds=list(search_bounds),  # fresh list per optimiser
    )
    optimiser.run_optimization(max_iter=BAYESIAN_OPTIMISATION_STEPS)

In [None]:
# Save and/or show each class's acquisition plot.
# NOTE(review): plot_acquisition appears to render the figure itself when no
# filename is given; every call is followed by plt.close() to discard the figure.
for ci, c in enumerate(class_names):
    optimiser = bayopts[ci]
    if SAVE_IMAGES:
        for fmt in IMAGE_FORMATS:
            target = osp.join(images_folder,fmt,f"{model_name}-bayes-acquisition-{ci}.{fmt}")
            optimiser.plot_acquisition(filename=target)
            plt.close()
    if not SHOW_ACQUISITIONS:
        continue
    print(c)
    optimiser.plot_acquisition()
    plt.close()

In [None]:
# For each class: decode the evaluated latents with the lowest objective values
# (best first) found by that class's Bayesian optimiser.
number_of_classes = len(class_names)
fig, axes = plt.subplots(number_of_classes, NUMBER_OF_SAMPLES_PER_CLASS)
fig.suptitle(f"{model_name} Random samples from the Bayesian Optimisation")
for ci, c in enumerate(class_names):
    # Only the first NUMBER_OF_SAMPLES_PER_CLASS are displayed, so only those are decoded.
    top_idx = np.argsort(bayopts[ci].Y, axis=0).ravel()[:NUMBER_OF_SAMPLES_PER_CLASS]
    z_np = bayopts[ci].X[top_idx].astype(np.float32)
    num_show = len(z_np)
    z = torch.tensor(z_np)
    if torch.cuda.is_available():
        z = z.cuda()
    with torch.no_grad():
        x_hat = model.decoder(z).cpu().numpy()
    # Presumably channels-first (N, C, H, W): move channels last for imshow.
    x_hat = np.moveaxis(x_hat, 1, -1)
    axes[ci,0].set_ylabel(c, size=YLABEL_FONTSIZE)
    for pi in range(NUMBER_OF_SAMPLES_PER_CLASS):
        ax:plt.Axes = axes[ci,pi]
        # Blank tile if the optimiser evaluated fewer points than there are columns.
        ax.imshow(x_hat[pi] if pi<num_show else np.zeros(x_hat.shape[1:]), vmin=0, vmax=1, cmap="gray")
        ax.xaxis.set_ticklabels([])
        ax.yaxis.set_ticklabels([])
        if pi<num_show:
            ax.text(2.1, 0.5, ", ".join(map("{:.2f}".format,z_np[pi])),
                size=COORDS_FONTSIZE, ha='center', va='center', transform=ax.transAxes)
if SAVE_IMAGES:
    for fmt in IMAGE_FORMATS:
        plt.savefig(osp.join(images_folder,fmt,f"{model_name}-samples-bayesian-extclf.{fmt}"))
if SHOW_IMAGES:
    plt.show()
plt.close()

### Constrained latent space (C-VAE with a jointly trained latent classifier)

In [None]:
model_name = "C-VAE"
N_EPOCHS = 16
BATCH_SIZE = 64
AE_LATENT_DIM = 2

# VAE with a 2-D latent plus an MLP classifier that predicts the digit from the
# latent mean. Both are trained jointly, so the classification loss shapes the
# latent space.
model = MNIST_VAE(mnist_n_channels,AE_LATENT_DIM)
clf = MLP(AE_LATENT_DIM,[128,128,mnist_n_classes], drop_last_nonlinearity=True)
if torch.cuda.is_available():
    model.cuda()
    clf.cuda()

# A single optimiser updates the VAE and the latent classifier together.
opt = optim.Adam(chain(model.parameters(),clf.parameters()))


def bce_loss(inp, tgt):
    """Per-sample reconstruction loss: element-wise BCE summed over all non-batch dims."""
    return torch.sum(F.binary_cross_entropy(inp, tgt, reduction="none"), dim=list(range(1, len(inp.shape))))


def kl_loss(mu, sigma):
    """KL term from the project's ``kl_div`` helper.

    NOTE(review): it is concatenated with the per-sample BCE in ``total_loss``,
    so it presumably returns one value per sample; confirm.
    """
    return kl_div(mu, sigma)


xe_loss = nn.CrossEntropyLoss()


def total_loss(bce_l, kl_l, xe_l):
    """Return 20 * classification loss + mean of the concatenated BCE and KL terms.

    With one BCE and one KL value per sample, the mean over the concatenation is
    (mean BCE + mean KL) / 2.
    """
    return 20*xe_l + torch.mean(torch.concat([bce_l, kl_l]))


_LOSS_KEYS = ("rec_bce_loss", "kl_loss", "clf_xe_loss", "total_loss")
train_history = {key: [] for key in ("epoch", "batch", *_LOSS_KEYS)}
test_history = {key: [] for key in ("epoch", "batch", *_LOSS_KEYS)}


def _cvae_losses(x, y):
    """Run one forward pass on batch (x, y); return (bce, kl, xe, total) loss tensors."""
    if torch.cuda.is_available():
        x = x.cuda()
        y = y.cuda()
    mu, sigma = model.get_mu_and_sigma(x)
    y_hat = clf(mu)  # the classifier sees the latent mean, not a sample
    z = model.sample(mu, sigma)
    x_hat = model.decode(z)
    bce_l = bce_loss(x_hat, x)
    kl_l = kl_loss(mu, sigma)
    xe_l = xe_loss(y_hat, y)
    return bce_l, kl_l, xe_l, total_loss(bce_l, kl_l, xe_l)


def _record(history, epoch, batch, losses):
    """Append epoch, batch and the mean of each loss tensor to ``history``."""
    history["epoch"].append(epoch)
    history["batch"].append(batch)
    for key, value in zip(_LOSS_KEYS, losses):
        history[key].append(value.detach().cpu().numpy().mean().item())


def _evaluate(epoch):
    """Log test-set losses (without gradients) into ``test_history`` under ``epoch``."""
    with torch.no_grad():
        for b, (x, y) in enumerate(DataLoader(mnist_test, batch_size=BATCH_SIZE, shuffle=False)):
            _record(test_history, epoch, b, _cvae_losses(x, y))


for e in tqdm(range(N_EPOCHS)):
    # Test losses for epoch e are measured before that epoch's updates.
    _evaluate(e)
    for b, (x, y) in enumerate(DataLoader(mnist_train, batch_size=BATCH_SIZE, shuffle=True)):
        opt.zero_grad()
        losses = _cvae_losses(x, y)
        losses[-1].backward()
        opt.step()
        _record(train_history, e, b, losses)

# Final evaluation after training, logged as epoch N_EPOCHS.
_evaluate(N_EPOCHS)

In [None]:
train_df = pd.DataFrame(train_history)
# Fractional-epoch x position of each training batch: batch b of epoch e is
# placed at e + b / batches_per_epoch. Dividing by the batch count (not the
# largest batch index) keeps the last batch of epoch e from colliding with the
# first batch of epoch e + 1, and avoids 0/0 when an epoch has a single batch.
batches_per_epoch = train_df["batch"].max() + 1
train_df["batch_in_epoch"] = train_df["epoch"] + train_df["batch"]/batches_per_epoch
test_df = pd.DataFrame(test_history)

In [None]:
loss_type = "total_loss"
# Training curve: one point per batch (blue).
train_series = train_df.set_index("batch_in_epoch")[loss_type]
train_series.plot(c="b")
ax: plt.Axes = plt.gca()
# Validation curve: per-epoch mean (red), on a twin axis sharing the y-scale.
nax = ax.twinx()
ax.sharey(nax)
val_series = test_df.groupby("epoch")[loss_type].mean()
val_series.plot(c="r")
plt.title(f"{model_name} {loss_type}")
final_train = train_series.sort_index(ascending=True).iloc[-1]
final_val = val_series.sort_index(ascending=True).iloc[-1]
print(f"{model_name} final {loss_type} tr/val = {final_train:.4f}/{final_val:.4f}")
if SAVE_IMAGES:
    for fmt in IMAGE_FORMATS:
        plt.savefig(osp.join(images_folder, fmt, f"{model_name}-{loss_type}.{fmt}"))
if SHOW_IMAGES:
    plt.show()
plt.close()

In [None]:
loss_type = "rec_bce_loss"
# Training curve: one point per batch (blue).
train_series = train_df.set_index("batch_in_epoch")[loss_type]
train_series.plot(c="b")
ax: plt.Axes = plt.gca()
# Validation curve: per-epoch mean (red), on a twin axis sharing the y-scale.
nax = ax.twinx()
ax.sharey(nax)
val_series = test_df.groupby("epoch")[loss_type].mean()
val_series.plot(c="r")
plt.title(f"{model_name} {loss_type}")
final_train = train_series.sort_index(ascending=True).iloc[-1]
final_val = val_series.sort_index(ascending=True).iloc[-1]
print(f"{model_name} final {loss_type} tr/val = {final_train:.4f}/{final_val:.4f}")
if SAVE_IMAGES:
    for fmt in IMAGE_FORMATS:
        plt.savefig(osp.join(images_folder, fmt, f"{model_name}-{loss_type}.{fmt}"))
if SHOW_IMAGES:
    plt.show()
plt.close()

In [None]:
loss_type = "kl_loss"
# Training curve: one point per batch (blue).
train_series = train_df.set_index("batch_in_epoch")[loss_type]
train_series.plot(c="b")
ax: plt.Axes = plt.gca()
# Validation curve: per-epoch mean (red), on a twin axis sharing the y-scale.
nax = ax.twinx()
ax.sharey(nax)
val_series = test_df.groupby("epoch")[loss_type].mean()
val_series.plot(c="r")
plt.title(f"{model_name} {loss_type}")
final_train = train_series.sort_index(ascending=True).iloc[-1]
final_val = val_series.sort_index(ascending=True).iloc[-1]
print(f"{model_name} final {loss_type} tr/val = {final_train:.4f}/{final_val:.4f}")
if SAVE_IMAGES:
    for fmt in IMAGE_FORMATS:
        plt.savefig(osp.join(images_folder, fmt, f"{model_name}-{loss_type}.{fmt}"))
if SHOW_IMAGES:
    plt.show()
plt.close()

In [None]:
loss_type = "clf_xe_loss"
# Training curve: one point per batch (blue).
train_series = train_df.set_index("batch_in_epoch")[loss_type]
train_series.plot(c="b")
ax: plt.Axes = plt.gca()
# Validation curve: per-epoch mean (red), on a twin axis sharing the y-scale.
nax = ax.twinx()
ax.sharey(nax)
val_series = test_df.groupby("epoch")[loss_type].mean()
val_series.plot(c="r")
plt.title(f"{model_name} {loss_type}")
final_train = train_series.sort_index(ascending=True).iloc[-1]
final_val = val_series.sort_index(ascending=True).iloc[-1]
print(f"{model_name} final {loss_type} tr/val = {final_train:.4f}/{final_val:.4f}")
if SAVE_IMAGES:
    for fmt in IMAGE_FORMATS:
        plt.savefig(osp.join(images_folder, fmt, f"{model_name}-{loss_type}.{fmt}"))
if SHOW_IMAGES:
    plt.show()
plt.close()

In [None]:
# Encode the whole training set in a fixed order (so rows line up with
# mnist_train.targets) and keep the first two encoder outputs per image as
# 2-D latent coordinates.
# NOTE(review): presumably these two columns are the latent mean; confirm
# against the encoder's output layout.
zs = []
with torch.no_grad():
    # Labels are not needed to encode, so only the images are moved to the GPU.
    for x, _ in DataLoader(mnist_train, batch_size=BATCH_SIZE, shuffle=False):
        if torch.cuda.is_available():
            x = x.cuda()
        # Under no_grad the output has no autograd graph, so detach() is unnecessary.
        zs.append(model.encoder(x).cpu().numpy()[:, :2])
z_train = np.concatenate(zs, axis=0)
# Sanity check: one latent row per training target.
z_train.shape, mnist_train.targets.numpy().shape

In [None]:
# Encode the whole test set in a fixed order (so rows line up with
# mnist_test.targets) and keep the first two encoder outputs per image as
# 2-D latent coordinates.
# NOTE(review): presumably these two columns are the latent mean; confirm
# against the encoder's output layout.
zs = []
with torch.no_grad():
    # Labels are not needed to encode, so only the images are moved to the GPU.
    for x, _ in DataLoader(mnist_test, batch_size=BATCH_SIZE, shuffle=False):
        if torch.cuda.is_available():
            x = x.cuda()
        # Under no_grad the output has no autograd graph, so detach() is unnecessary.
        zs.append(model.encoder(x).cpu().numpy()[:, :2])
z_test = np.concatenate(zs, axis=0)
# Sanity check: one latent row per test target.
z_test.shape, mnist_test.targets.numpy().shape

In [None]:
fig, axes = plt.subplots(1, 2, sharex=True, sharey=True)

# Latent coordinates and digit labels for both splits.
x_train, y_train = z_train[:, 0], z_train[:, 1]
hue_train = mnist_train.targets.numpy()
x_test, y_test = z_test[:, 0], z_test[:, 1]
hue_test = mnist_test.targets.numpy()

# Left panel: train split (legend suppressed). Right panel: test split (legend kept).
# Each panel layers per-class scatter points, a 2-D histogram and KDE contours;
# the Axes returned by each kdeplot call is kept as kde_train / kde_test.
panels = (
    (axes[0], x_train, y_train, hue_train, dict(legend=False)),
    (axes[1], x_test, y_test, hue_test, dict()),
)
kde_axes = []
for panel_ax, px, py, phue, legend_kw in panels:
    labels = class_names[phue]
    sns.scatterplot(x=px, y=py, hue=labels, hue_order=class_names, marker="x", s=4, ax=panel_ax, **legend_kw)
    sns.histplot(x=px, y=py, bins=64, pthresh=0.01, hue=labels, hue_order=class_names, ax=panel_ax, **legend_kw)
    kde_axes.append(
        sns.kdeplot(x=px, y=py, levels=5, hue=labels, hue_order=class_names, linewidths=1, ax=panel_ax, **legend_kw)
    )
kde_train, kde_test = kde_axes

# Move the test legend outside the right-hand panel.
sns.move_legend(axes[1], "upper left", bbox_to_anchor=(1, 0.75))

plt.suptitle(f"{model_name} latent space distribution")
for panel_ax, title in zip(axes, ("Train", "Test")):
    panel_ax.set_title(title)

if SAVE_IMAGES:
    # Save a zoomed view first. Limits are clipped to [-1, 1] and never widened;
    # the axes are shared, so both panels follow. Then restore the full limits
    # and save the complete view.
    xlim, ylim = axes[0].get_xlim(), axes[0].get_ylim()
    axes[0].set_xlim(max(xlim[0], -1), min(xlim[1], 1))
    axes[0].set_ylim(max(ylim[0], -1), min(ylim[1], 1))
    for fmt in IMAGE_FORMATS:
        plt.savefig(osp.join(images_folder, fmt, f"{model_name}-z-zoom.{fmt}"), bbox_inches="tight")
    axes[0].set_xlim(xlim)
    axes[0].set_ylim(ylim)
    for fmt in IMAGE_FORMATS:
        plt.savefig(osp.join(images_folder, fmt, f"{model_name}-z.{fmt}"), bbox_inches="tight")

if SHOW_IMAGES:
    plt.show()
plt.close()

In [None]:
# One Gaussian KDE per class, fitted on that class's 2-D training latents
# (gaussian_kde expects shape (dims, samples), hence the transpose).
kdes = {
    c: sps.gaussian_kde(z_train[hue_train == ci].T)
    for ci, c in enumerate(class_names)
}

In [None]:
# Regular NUM_POINTS_GRID x NUM_POINTS_GRID grid spanning the training latents'
# bounding box; `points` has shape (2, NUM_POINTS_GRID**2).
mins = z_train.min(axis=0, keepdims=True)
maxs = z_train.max(axis=0, keepdims=True)
axis_ticks = [np.linspace(mins[:, dim], maxs[:, dim], NUM_POINTS_GRID) for dim in range(2)]
xx, yy = np.meshgrid(*axis_ticks)
points = np.vstack((xx.ravel(), yy.ravel()))

In [None]:
# Density of every grid point under each class KDE: shape (num_points, num_classes).
probs = np.column_stack([kdes[c](points) for c in class_names])

In [None]:
# For each class: decode the most likely grid point (column 0) plus distinct
# random grid points from that class's high-density region.
number_of_classes = len(class_names)
fig, axes = plt.subplots(number_of_classes, NUMBER_OF_SAMPLES_PER_CLASS)
fig.suptitle(f"{model_name} Samples from point grid 1-class KDEs")
axes[0,0].set_title("Best")
for ci, c in enumerate(class_names):
    c_prob = probs[:,ci]
    # Candidate pool: grid points whose density exceeds the TOP_PCT_TO_SAMPLE_FROM
    # quantile of this class's non-zero densities, padded to at least
    # NUMBER_OF_SAMPLES_PER_CLASS points.
    top_threshold = np.quantile(c_prob[c_prob>0], TOP_PCT_TO_SAMPLE_FROM)
    supersample_size = max((c_prob>top_threshold).sum(), NUMBER_OF_SAMPLES_PER_CLASS)
    c_prob_argsort = c_prob.argsort()
    best_idx = c_prob_argsort[-1]
    # Draw without replacement and drop the best point, so no image repeats.
    # The pool holds >= NUMBER_OF_SAMPLES_PER_CLASS points, so sampling without
    # replacement is always possible and at least NUMBER_OF_SAMPLES_PER_CLASS - 1
    # draws survive the filter.
    draws = np.random.choice(
        c_prob_argsort[-supersample_size:],
        NUMBER_OF_SAMPLES_PER_CLASS,
        replace=False,
    )
    rest = [idx for idx in draws if idx != best_idx][:NUMBER_OF_SAMPLES_PER_CLASS-1]
    best_and_rest_from_top_pct = [best_idx, *rest]
    z_np = points.T[best_and_rest_from_top_pct,:].astype(np.float32)
    z = torch.tensor(z_np)
    if torch.cuda.is_available():
        z = z.cuda()
    with torch.no_grad():
        x_hat = model.decoder(z).cpu().numpy()
    # Presumably channels-first (N, C, H, W): move channels last for imshow.
    # A plain reshape only gives the right image when C == 1.
    x_hat = np.moveaxis(x_hat, 1, -1)
    axes[ci,0].set_ylabel(c, size=YLABEL_FONTSIZE)
    for pi in range(NUMBER_OF_SAMPLES_PER_CLASS):
        ax:plt.Axes = axes[ci,pi]
        ax.imshow(x_hat[pi], vmin=0, vmax=1, cmap="gray")
        ax.xaxis.set_ticklabels([])
        ax.yaxis.set_ticklabels([])
        # Latent coordinates of this sample, printed to the right of the tile.
        ax.text(2.1, 0.5, ", ".join(map("{:.2f}".format,z_np[pi])),
            size=COORDS_FONTSIZE, ha='center', va='center', transform=ax.transAxes)
if SAVE_IMAGES:
    for fmt in IMAGE_FORMATS:
        plt.savefig(osp.join(images_folder,fmt,f"{model_name}-samples-pointgrid-1c.{fmt}"))
if SHOW_IMAGES:
    plt.show()
plt.close()

In [None]:
# Assign each grid point a class from the temperature-scaled softmax of the
# per-class log-densities, weighted by the raw densities.
# log(0) = -inf where a KDE underflows to zero; that is expected, so the
# divide-by-zero warning is silenced.
# NOTE(review): a row whose densities are all zero becomes NaN after softmax,
# and argmax then returns index 0.
with np.errstate(divide="ignore"):
    log_probs = np.log(probs)
softmax_probs = torch.softmax(torch.tensor(log_probs)/SOFTMAX_REPARAM_TEMPERATURE,1).numpy()
softmax_classes = np.array(class_names)[np.argmax(softmax_probs*probs, axis=1)]
# Top-level pd.value_counts is deprecated (pandas 2.1); use the Series method.
pd.Series(softmax_classes).value_counts()

In [None]:
# For each class: decode up to NUMBER_OF_SAMPLES_PER_CLASS random grid points
# that the softmax assigned to it; missing columns are shown as blank tiles.
fig, axes = plt.subplots(number_of_classes, NUMBER_OF_SAMPLES_PER_CLASS)
fig.suptitle(f"{model_name} Random samples from softmax of point grid 1-class KDE")
for ci, c in enumerate(class_names):
    # Points labelled with this class that also have non-zero density under its KDE.
    indexes_of_this_class = np.logical_and(
        softmax_classes==c,
        probs[:,ci]>0
    )
    num_points_in_this_class = indexes_of_this_class.sum()
    num_show = min(NUMBER_OF_SAMPLES_PER_CLASS,num_points_in_this_class)
    if num_points_in_this_class>0:
        points_in_this_class = points.T[indexes_of_this_class,:]
        random_points_from_this_class = points_in_this_class[
            np.random.choice(num_points_in_this_class, num_show, replace=False)
        ]
        z_np = random_points_from_this_class.astype(np.float32)
        z = torch.tensor(z_np)
        if torch.cuda.is_available():
            z = z.cuda()
        with torch.no_grad():
            x_hat = model.decoder(z).cpu().numpy()
        # Presumably channels-first (N, C, H, W): move channels last for imshow.
        x_hat = np.moveaxis(x_hat, 1, -1)
    axes[ci,0].set_ylabel(c, size=YLABEL_FONTSIZE)
    for pi in range(NUMBER_OF_SAMPLES_PER_CLASS):
        ax:plt.Axes = axes[ci,pi]
        # x_hat/z_np are only read for pi < num_show, i.e. when set for this class.
        ax.imshow(
            (
                x_hat[pi]
                if pi<num_show else
                np.zeros((dset_img_dim,dset_img_dim,1))
            ),
            vmin=0, vmax=1, cmap="gray"
        )
        ax.xaxis.set_ticklabels([])
        ax.yaxis.set_ticklabels([])
        if pi<num_show:
            ax.text(2.1, 0.5, ", ".join(map("{:.2f}".format,z_np[pi])),
                size=COORDS_FONTSIZE, ha='center', va='center', transform=ax.transAxes)
if SAVE_IMAGES:
    for fmt in IMAGE_FORMATS:
        plt.savefig(osp.join(images_folder,fmt,f"{model_name}-samples-pointgrid-1c-softmax.{fmt}"))
if SHOW_IMAGES:
    plt.show()
plt.close()

In [None]:
# For each class: decode NUMBER_OF_SAMPLES_PER_CLASS random draws from its KDE.
number_of_classes = len(class_names)
fig, axes = plt.subplots(number_of_classes, NUMBER_OF_SAMPLES_PER_CLASS)
fig.suptitle(f"{model_name} Random samples from the KDE estimators")
for ci, c in enumerate(class_names):
    # resample() returns (dims, n) float64, so transpose and cast to float32.
    z_np = kdes[c].resample(NUMBER_OF_SAMPLES_PER_CLASS).T.astype(np.float32)
    z = torch.tensor(z_np)
    if torch.cuda.is_available():
        z = z.cuda()
    with torch.no_grad():
        x_hat = model.decoder(z).cpu().numpy()
    # Presumably channels-first (N, C, H, W): move channels last for imshow.
    # A plain reshape only gives the right image when C == 1.
    x_hat = np.moveaxis(x_hat, 1, -1)
    axes[ci,0].set_ylabel(c, size=YLABEL_FONTSIZE)
    for pi in range(NUMBER_OF_SAMPLES_PER_CLASS):
        ax:plt.Axes = axes[ci,pi]
        ax.imshow(x_hat[pi], vmin=0, vmax=1, cmap="gray")
        ax.xaxis.set_ticklabels([])
        ax.yaxis.set_ticklabels([])
        ax.text(2.1, 0.5, ", ".join(map("{:.2f}".format,z_np[pi])),
            size=COORDS_FONTSIZE, ha='center', va='center', transform=ax.transAxes)
if SAVE_IMAGES:
    for fmt in IMAGE_FORMATS:
        plt.savefig(osp.join(images_folder,fmt,f"{model_name}-samples-kde-1c.{fmt}"))
if SHOW_IMAGES:
    plt.show()
plt.close()

In [None]:
# Search box per latent dimension: the observed training range, widened so it
# covers at least [-4, 4]. The box is the same for every class.
search_bounds = [(min(lo, -4.0), max(hi, 4.0)) for lo, hi in zip(mins[0], maxs[0])]

# One optimiser per class index, each minimising that class's NLL under
# ext_clf for decoded latents.
bayopts = {}
for ci in range(len(class_names)):
    bayopts[ci] = optimiser = nll_optimizer_for(
        ci,
        model.decoder,
        ext_clf,
        bounds=list(search_bounds),  # fresh list per optimiser
    )
    optimiser.run_optimization(max_iter=BAYESIAN_OPTIMISATION_STEPS)

In [None]:
# Save and/or show each class's acquisition plot.
# NOTE(review): plot_acquisition appears to render the figure itself when no
# filename is given; every call is followed by plt.close() to discard the figure.
for ci, c in enumerate(class_names):
    optimiser = bayopts[ci]
    if SAVE_IMAGES:
        for fmt in IMAGE_FORMATS:
            target = osp.join(images_folder,fmt,f"{model_name}-bayes-acquisition-{ci}.{fmt}")
            optimiser.plot_acquisition(filename=target)
            plt.close()
    if not SHOW_ACQUISITIONS:
        continue
    print(c)
    optimiser.plot_acquisition()
    plt.close()

In [None]:
# For each class: decode the evaluated latents with the lowest objective values
# (best first) found by that class's Bayesian optimiser.
number_of_classes = len(class_names)
fig, axes = plt.subplots(number_of_classes, NUMBER_OF_SAMPLES_PER_CLASS)
fig.suptitle(f"{model_name} Random samples from the Bayesian Optimisation")
for ci, c in enumerate(class_names):
    # Only the first NUMBER_OF_SAMPLES_PER_CLASS are displayed, so only those are decoded.
    top_idx = np.argsort(bayopts[ci].Y, axis=0).ravel()[:NUMBER_OF_SAMPLES_PER_CLASS]
    z_np = bayopts[ci].X[top_idx].astype(np.float32)
    num_show = len(z_np)
    z = torch.tensor(z_np)
    if torch.cuda.is_available():
        z = z.cuda()
    with torch.no_grad():
        x_hat = model.decoder(z).cpu().numpy()
    # Presumably channels-first (N, C, H, W): move channels last for imshow.
    x_hat = np.moveaxis(x_hat, 1, -1)
    axes[ci,0].set_ylabel(c, size=YLABEL_FONTSIZE)
    for pi in range(NUMBER_OF_SAMPLES_PER_CLASS):
        ax:plt.Axes = axes[ci,pi]
        # Blank tile if the optimiser evaluated fewer points than there are columns.
        ax.imshow(x_hat[pi] if pi<num_show else np.zeros(x_hat.shape[1:]), vmin=0, vmax=1, cmap="gray")
        ax.xaxis.set_ticklabels([])
        ax.yaxis.set_ticklabels([])
        if pi<num_show:
            ax.text(2.1, 0.5, ", ".join(map("{:.2f}".format,z_np[pi])),
                size=COORDS_FONTSIZE, ha='center', va='center', transform=ax.transAxes)
if SAVE_IMAGES:
    for fmt in IMAGE_FORMATS:
        plt.savefig(osp.join(images_folder,fmt,f"{model_name}-samples-bayesian-extclf.{fmt}"))
if SHOW_IMAGES:
    plt.show()
plt.close()