In [None]:
import os

import numpy as np
import torch
from matplotlib import pyplot as plt

from survae import SurVAE, DEVICE
from survae.data import *
from survae.layer import *
from survae.calibrate import *

In [None]:
# Directory for artefacts of this run (referenced by the commented-out `save_path` training option).
SAVE_PATH = "./saves/sv_smnist"

In [None]:
# Number of 2-D points per image; this fixes the layer shapes and the latent size (2 * N_POINTS) used below.
N_POINTS = 50

In [None]:
# Spatial MNIST with N_POINTS points per sample.
# NOTE(review): flatten=False presumably keeps each sample as an (N_POINTS, 2) array — confirm against SpatialMNIST.
smnist_dataset = SpatialMNIST(k=N_POINTS, flatten=False)

In [None]:
class BijectiveLayer(Layer):
    def __init__(self, shape: tuple[int, ...] | int, hidden_sizes: list[int]) -> None:
        '''
        Standard bijective (affine coupling) block from normalizing flow architecture.

        The flattened input is split into a "skip" half, which passes through unchanged,
        and a "non-skip" half, which is scaled and shifted element-wise. Scale and shift
        are produced by a nested FFNN that only sees the skip half (plus an optional
        condition), which is what keeps the layer invertible.

        ### Inputs:
        * shape: Shape of input entries (without the batch dimension), which is the same for the output.
        * hidden_sizes: Sizes of hidden layers of the nested FFNN.
        '''
        super().__init__()

        # transform shape variable into a more usable form
        if isinstance(shape, int):
            shape = (shape,)

        self.shape = shape

        # total number of entries per sample
        self.size = torch.prod(torch.tensor(shape)).item()

        assert self.size > 1, "Bijective layer size must be at least 2!"

        # The size of the skip connection is half the input size, rounded down
        self.skip_size = self.size // 2
        self.non_skip_size = self.size - self.skip_size

        self.hidden_sizes = hidden_sizes

        # The FFNN emits 2 * non_skip_size values: the shift t followed by the pre-activation of the log-scale
        self.ffnn = FFNN(self.skip_size, self.hidden_sizes, 2 * self.non_skip_size)

    def _scale_and_shift(self, skip_connection: torch.Tensor, condition: torch.Tensor | None):
        '''
        Compute the coefficients of the affine map from the untouched half of the input.

        Shared by `forward` and `backward` so that both directions are guaranteed to use
        exactly the same transformation.

        ### Inputs:
        * skip_connection: Flattened skip half, shape (batch, skip_size).
        * condition: Optional conditional input; concatenated along dim 1, so it must be 2-D with the same batch size.

        ### Returns:
        * s_log: Log of the scale, bounded to (-1, 1) by tanh.
        * t: Shift.
        '''
        # add conditional input
        if condition is not None:
            ffnn_input = torch.cat((skip_connection, condition), dim=1)
        else:
            ffnn_input = skip_connection

        # compute coefficients for linear transformation
        coeffs = self.ffnn(ffnn_input)
        # split output into t and pre_s
        t = coeffs[:, :self.non_skip_size]
        pre_s = coeffs[:, self.non_skip_size:]

        # torch.tanh is referenced explicitly instead of relying on a bare `tanh`
        # leaking into the namespace through a star import
        s_log = torch.tanh(pre_s)

        return s_log, t

    def forward(self, X: torch.Tensor, condition: torch.Tensor | None = None, return_log_likelihood: bool = False):
        '''
        Apply the coupling transformation (data -> code direction).

        ### Inputs:
        * X: Batch of inputs; everything after the batch dimension is flattened internally.
        * condition: Optional conditional input fed to the FFNN alongside the skip half.
        * return_log_likelihood: If True, additionally return the sum of the log-scales.

        ### Returns:
        * Z reshaped to (-1, *shape), or the tuple (Z, log-likelihood term) if requested.
        '''
        # flatten input
        X = X.flatten(start_dim=1)

        # split input into skip and non-skip
        skip_connection = X[:, :self.skip_size]
        non_skip_connection = X[:, self.skip_size:]

        s_log, t = self._scale_and_shift(skip_connection, condition)
        s = torch.exp(s_log)

        # apply transformation
        new_connection = s * non_skip_connection + t
        # stack skip connection and transformed non-skip connection
        Z = torch.cat((skip_connection, new_connection), dim=1)

        # reshape output
        Z = Z.reshape(-1, *self.shape)

        if return_log_likelihood:
            # NOTE(review): this reduces over the batch dimension as well, yielding one scalar
            # for the whole batch — confirm the training loop expects a batch total rather
            # than one value per sample.
            return Z, torch.sum(s_log)
        else:
            return Z

    def backward(self, Z: torch.Tensor, condition: torch.Tensor | None = None):
        '''
        Invert `forward` (code -> data direction).

        ### Inputs:
        * Z: Batch of codes; everything after the batch dimension is flattened internally.
        * condition: Optional conditional input; must match the one used in `forward` for an exact inverse.

        ### Returns:
        * X reshaped to (-1, *shape).
        '''
        # flatten input
        Z = Z.flatten(start_dim=1)

        # split input into skip and non-skip
        skip_connection = Z[:, :self.skip_size]
        non_skip_connections = Z[:, self.skip_size:]

        # the skip half is unchanged by forward, so the same coefficients can be recomputed here
        s_log, t = self._scale_and_shift(skip_connection, condition)
        s = torch.exp(s_log)

        # apply inverse transformation
        new_connection = (non_skip_connections - t) / s
        # stack skip connection and transformed non-skip connection
        X = torch.cat((skip_connection, new_connection), dim=1)

        # reshape output
        X = X.reshape(-1, *self.shape)

        return X

    def make_conditional(self, size: int):
        '''
        Replace the nested FFNN by a new one that accepts `size` additional conditional inputs.

        The previous FFNN object (and its parameters) is discarded.
        '''
        self.ffnn = FFNN(self.skip_size + size, self.hidden_sizes, 2 * self.non_skip_size)

    def in_size(self) -> int | None:
        '''Number of entries per input sample.'''
        return self.size

    def out_size(self) -> int | None:
        '''Number of entries per output sample (identical to the input size).'''
        return self.size

In [None]:
# Workaround: an OrthonormalLayer is repurposed as a fixed, non-trainable layer that
# swaps the x- and y-components of the spatial data, which is needed for each step
# in the SMNIST architecture. Its `o` attribute is overwritten with the 2x2 exchange matrix.
# NOTE(review): assumes `o` is the matrix OrthonormalLayer applies, and that double
# precision matches the layer's own dtype — confirm against survae.layer.
transposition_layer = OrthonormalLayer(2)
# torch.nn is referenced explicitly; a bare `nn` is only available if a star import leaks it.
transposition_layer.o = torch.nn.Parameter(
    torch.tensor([[0, 1], [1, 0]], dtype=torch.double),
    requires_grad=False,
)

In [None]:
# Assemble the layer stack: 32 identical flow blocks (each kept as its own sub-list),
# followed by a final reshape to a flat code of length 2 * N_POINTS.
layers = []
for block_index in range(32):
    flow_block = [
        PermuteAxesLayer((1, 0)),
        BijectiveLayer((2, N_POINTS), [200, 200]),
        PermuteAxesLayer((1, 0)),
        # the same transposition layer instance is shared by every block
        transposition_layer,
        BijectiveLayer((N_POINTS, 2), [200, 200]),
        PermutationLayer(),
    ]
    layers.append(flow_block)

layers.append(ReshapeLayer((2, N_POINTS), (2 * N_POINTS,)))

# The model is conditioned on a vector of size 10 (the one-hot digit label used further down).
sv_smnist = SurVAE(layers, name="SV_SMNIST", condition_size=10)

In [None]:
# Training settings gathered in one place so they are easy to inspect and tweak.
train_settings = dict(
    batch_size=1000,
    test_size=100,
    epochs=8_000,
    lr=1e-3,
    log_period=400,
    show_tqdm=True,
    lr_decay_params={'gamma': 0.95, 'step_size': 500},
    # save_path=SAVE_PATH,
)

# The returned log is a mapping that the loss plot below reads.
train_log = sv_smnist.train(dataset=smnist_dataset, **train_settings)

In [None]:
# save the model
# torch.save does not create missing parent directories, and the `save_path` training
# option is disabled above, so make sure the target directory exists before writing.
os.makedirs(SAVE_PATH, exist_ok=True)
torch.save(sv_smnist.state_dict(), os.path.join(SAVE_PATH, "model.pt"))

### Plot loss

In [None]:
# Unpack the training log into parallel series: the log keys and the two losses stored per entry.
times, loss_train, loss_test = [], [], []
for logged_at, metrics in train_log.items():
    times.append(logged_at)
    loss_train.append(metrics.training_loss)
    loss_test.append(metrics.testing_loss)

In [None]:
# Training vs. validation loss over the logged epochs, drawn via the explicit Axes interface.
loss_fig, loss_ax = plt.subplots(figsize=(5, 4))

loss_ax.plot(times, loss_train, label='Training loss')
loss_ax.plot(times, loss_test, label='Validation loss')

loss_ax.set_title('Loss of MNIST network during training')
loss_ax.set_xlabel('Number of epochs')
loss_ax.set_ylabel('NLL Loss')

loss_ax.grid()
loss_ax.legend()
loss_fig.tight_layout()
plt.show()

## Calibration

In [None]:
# Draw a labelled batch of 1000 samples for calibration.
X, y = smnist_dataset.sample(1_000, labels=True)
# Labels are cast to integers and converted to one-hot vectors of length 10 for conditioning.
y = smnist_dataset.label_to_one_hot(y.long(), 10)

# Encode the batch without tracking gradients, then move the codes to the CPU for analysis.
with torch.no_grad():
    Z = sv_smnist(X, y)
Z = Z.cpu()

In [None]:
# Empirical standard deviation over all entries of the code batch, as a plain Python float.
sigma = float(Z.std())
print(f"Standard deviation of code distribution is measured to be {sigma:.4f}")

In [None]:
# Pick 6 code dimensions at random and plot their learned distribution.
n_shown = 6
rp = torch.randperm(Z.size(-1))[:n_shown].cpu()
plot_learned_distribution(Z[:, rp], "", axis_scale=1, sigma=sigma)

## Sampling

In [None]:
# Layout of the sample grid: one column per digit class, `nrows` samples per class.
ncols = 10
nrows = 4
# Scale factor for the figure size: each subplot gets plotsize x plotsize in figsize units.
plotsize = 1.5

In [None]:
# Alternative: condition every sample on the digit 1
# _y = torch.tensor([1]).expand(ncols * nrows)

# Sample every digit several times: labels 0..9 repeated `nrows` times, so that
# entry i * 10 + j carries label j. This yields 10 * nrows labels and therefore
# only matches the ncols * nrows grid while ncols == 10.
_y = torch.arange(10).repeat(nrows)
print(_y)

# One-hot encode the labels (10 classes) for use as the model's condition.
y = smnist_dataset.label_to_one_hot(_y, 10)

In [None]:
# Draw latent codes from N(0, sigma), using the standard deviation measured during
# calibration, and decode them conditioned on the labels.
latent_shape = (ncols * nrows, 2 * N_POINTS)
Z_hat = torch.normal(0, sigma, size=latent_shape, device=DEVICE)

with torch.no_grad():
    X_hat = sv_smnist.backward(Z_hat, y)

# Arrange the decoded samples on the (row, column) grid, N_POINTS 2-D points each.
X_hat = X_hat.reshape(nrows, ncols, N_POINTS, 2).cpu()

In [None]:
# One scatter plot per generated sample; column j holds samples conditioned on digit j.
fig, ax = plt.subplots(nrows, ncols, figsize=(plotsize * ncols, plotsize * nrows))

for (i, j), _ax in np.ndenumerate(ax):
    # only the top row carries the column title
    if i == 0:
        _ax.set_title(f"Sample {j}'s")
    points = X_hat[i, j]
    # the second coordinate is negated when plotting, as in the original cell
    _ax.scatter(points[:, 0], -points[:, 1])

plt.tight_layout()
plt.show()