# Rotating MNIST Notebook

In [1]:
%matplotlib inline
import matplotlib.pyplot as plt

import sys
sys.path.append('../')

import torch
import torch.nn as nn
import numpy as np
import json
from pprint import pprint

from argparse import Namespace

from core.models import (
    RotatingMNISTRecogNetwork, 
    RotatingMNISTReconNetwork, 
    SOnPathDistributionEncoder, 
    PathToBernoulliDecoder, 
    default_SOnPathDistributionEncoder)

Load experiment **log file**:

In [None]:
# Experiment to inspect: its log is read from ../logs/<dataset>_<experiment_id>.json.
dataset = "rotmnist"
experiment_id = 80345
log_file = f"../logs/{dataset}_{experiment_id}.json"

# JSON exchanged between systems is UTF-8 (RFC 8259); state the encoding
# explicitly so the read does not depend on the platform's default locale.
with open(log_file, encoding="utf-8") as f:
    logs = json.load(f)

Print **experiment configuration** (args):

In [None]:
# Show the configuration (command-line arguments) the experiment was run with.
experiment_args = logs["args"]
pprint(experiment_args)

Print final **training/evaluation/testing statistics**:

In [None]:
# Show the final statistics recorded under the log's 'final' entry.
final_stats = logs["final"]
pprint(final_stats)

In [None]:
def plot_stat(logs: dict, stat: str, modes: "list[str] | tuple[str, ...]" = ('trn', 'tst', 'val')):
    """Plot the per-epoch history of one logged statistic for several splits.

    Args:
        logs: Parsed experiment log. Each series is read from
            ``logs['all'][f"{mode}_{stat}"]``.
        stat: Statistic name suffix, e.g. ``'loss'`` or ``'mse_trgt'``.
        modes: Split prefixes to plot (e.g. ``'trn'``, ``'tst'``, ``'val'``).
            One line is drawn per mode. The default is an immutable tuple, so it
            cannot be shared and mutated across calls.

    Returns:
        The ``(fig, ax)`` pair, so callers can customize the plot further.

    Raises:
        KeyError: if ``logs['all']`` has no series for some ``mode``/``stat`` pair.
    """
    fig, ax = plt.subplots(figsize=(8, 3))
    for mode in modes:
        series = logs['all'][f"{mode}_{stat}"]
        ax.plot(series, label=mode)
    ax.set_xlabel('training epochs')
    ax.set_ylabel(stat)
    ax.grid()
    # Each line is labelled with its mode. Without a legend those labels are
    # never displayed. Skip it when nothing was plotted, to avoid an empty-legend warning.
    if ax.lines:
        ax.legend()
    return fig, ax

E.g., one could look at the **loss evolution** over all training epochs:

In [None]:
# Loss curves for all default splits over the training epochs.
plot_stat(logs, stat="loss")

... or compare (for this experiment) the **MSE** on the left-out target for the test and validation splits:

In [None]:
# Target MSE on the test and validation splits only, with the y-axis zoomed
# to the [0.01, 0.02] band for this experiment.
fig, ax = plot_stat(logs, "mse_trgt", modes=["tst", "val"])
ax.set_ylim(0.01, 0.02)

## Loading a checkpoint & Sampling from the posterior

First, **instantiate** the model using the configuration that was passed via command-line arguments (stored in the log file):

In [None]:
# Rebuild the experiment configuration (the command-line arguments stored in
# the log) as a namespace with attribute access.
args = Namespace(**logs['args'])

# Recognition and reconstruction networks. The reconstruction network is
# built with twice the recognition network's filter count.
recog_net = RotatingMNISTRecogNetwork(n_filters=args.n_filters)
recon_net = RotatingMNISTReconNetwork(z_dim=args.z_dim, n_filters=2 * args.n_filters)

# Encoder producing the approximate posterior over latent paths (sampled below).
# NOTE(review): h_dim=256 and the time range [0.0, 20.0] are hard-coded here
# rather than read from args; confirm they match the training configuration.
qzx_net = default_SOnPathDistributionEncoder(
    h_dim=256,
    z_dim=args.z_dim,
    n_deg=args.n_deg,
    time_min=0.0,
    time_max=20.0,
)
# Decoder applied to sampled latent paths, using recon_net as its logit map.
pxz_net = PathToBernoulliDecoder(logit_map=recon_net)

# Collect the sub-networks under fixed names. A checkpoint's state dict is
# restored into this container, so the key names must match the saved ones.
modules = nn.ModuleDict(dict(
    recog_net=recog_net,
    recon_net=recon_net,
    pxz_net=pxz_net,
    qzx_net=qzx_net,
)).to(args.device)

Next, we load the checkpoint saved at the specified epoch:

In [None]:
# Training epoch whose saved checkpoint should be restored.
epoch = 990
# NOTE(review): the log is read from "../logs", but the checkpoint path is
# relative to the notebook directory. Confirm the checkpoint location.
checkpoint_path = f"checkpoints/checkpoint_{experiment_id}_{epoch}.h5"
# Load tensors onto the CPU first, so a checkpoint written on a GPU/device that
# is unavailable here still loads. load_state_dict copies the values onto the
# modules' current device.
checkpoint = torch.load(checkpoint_path, map_location="cpu")
modules.load_state_dict(checkpoint['modules'])

Get the data loaders for training/validation/testing:

In [None]:
from data.mnist_provider import RotatingMNISTProvider

# NOTE(review): random_state presumably seeds the provider's data split; confirm.
provider = RotatingMNISTProvider(args.data_dir, random_state=133)

# All three loaders share the experiment's batch size. Only the training
# loader shuffles.
batch_size = args.batch_size
dl_trn = provider.get_train_loader(batch_size=batch_size, shuffle=True)
dl_val = provider.get_val_loader(batch_size=batch_size, shuffle=False)
dl_tst = provider.get_test_loader(batch_size=batch_size, shuffle=False)

# First batch of the (unshuffled) test loader.
batch = next(iter(dl_tst))

In the example below, we run one batch of testing data through the model. We can, e.g., look at samples from the approximate posterior (i.e., **latent paths**), or at the reconstructions (i.e., paths in the input space, reconstructed from the latent paths).

In [None]:
# Run a single test batch through the model on the CPU.
dl = dl_tst
device = 'cpu'
modules = modules.to(device)
# Inference only: disable training-mode behaviour (e.g. dropout or batch-norm
# statistic updates, if the networks have any).
modules.eval()
# provider.num_timepoints evenly spaced query times on [0, 0.99].
desired_t = torch.linspace(0, 0.99, provider.num_timepoints, device=device)

batch = next(iter(dl))
parts = {key: val.to(device) for key, val in batch.items()}
inp = (parts["inp_obs"], parts["inp_msk"], parts["inp_tps"])
# No gradients are needed here, so skip building the autograd graph.
with torch.no_grad():
    h = modules["recog_net"](inp)
    qzx, pz = modules["qzx_net"](h, desired_t)
    # Draw args.mc_eval_samples latent paths from the approximate posterior.
    zis = qzx.rsample((args.mc_eval_samples,))
    pxz = modules["pxz_net"](zis)

In [None]:
from mpl_toolkits.axes_grid1 import ImageGrid

# Decoder mean for the first MC sample and first batch element. The rank-4
# result is permuted with axes (0, 2, 3, 1), presumably taking
# (time, channel, H, W) to (time, H, W, channel) for imshow; confirm.
rec = np.transpose(pxz.mean[0, 0].detach().cpu().numpy(), (0, 2, 3, 1))

fig = plt.figure(figsize=(4.0, 4.0))
grid = ImageGrid(fig, 111, nrows_ncols=(4, 4), axes_pad=0.1)

# One frame per grid cell. Extra frames beyond the 16 cells are not drawn.
for cell_ax, frame in zip(grid, rec):
    cell_ax.imshow(frame)
    cell_ax.axis('off')

plt.show()

## Extrapolation

In [None]:
from mpl_toolkits.axes_grid1 import ImageGrid

# Extrapolation: query the latent path over a horizon k times longer than the
# observed sequence, then decode and display every frame.
k = 4
n_tps = provider.num_timepoints
# NOTE(review): this grid has k*n_tps points on [0, 0.99*k], so its spacing
# (0.99*k / (k*n_tps - 1)) differs slightly from the 0.99 / (n_tps - 1) spacing
# of the interpolation grid. Confirm this is intended.
desired_t = torch.linspace(0, 0.99 * k, k * n_tps, device=device)

# Inference only: eval mode and no autograd graph.
modules.eval()
batch = next(iter(dl))
parts = {key: val.to(device) for key, val in batch.items()}
inp = (parts["inp_obs"], parts["inp_msk"], parts["inp_tps"])
with torch.no_grad():
    h = modules["recog_net"](inp)
    qzx, pz = modules["qzx_net"](h, desired_t)
    zis = qzx.rsample((args.mc_eval_samples,))
    pxz = modules["pxz_net"](zis)

# Decoder mean for the first MC sample and first batch element, permuted for
# imshow (presumably (time, channel, H, W) -> (time, H, W, channel); confirm).
rec = pxz.mean[0, 0].detach().cpu()
rec = np.array(rec).transpose(0, 2, 3, 1)

# One row per horizon segment and one column per time point. The grid width
# follows the sequence length instead of a hard-coded 16.
fig = plt.figure(figsize=(float(n_tps), 1.0 * k))
grid = ImageGrid(fig, 111, nrows_ncols=(k, n_tps), axes_pad=0.1)

for cell_ax, frame in zip(grid, rec):
    cell_ax.imshow(frame)
    cell_ax.axis('off')
plt.show()