## Imports

In [1]:
%reload_ext autoreload
%autoreload 2

In [2]:
import os
import torch
import torch.nn as nn
import torch.distributions as d
import torch.optim as optim

from core_vectorized import ConditionalNF

import numpy as np
from keras.datasets.mnist import load_data
import torch
from torch.utils.data import TensorDataset, DataLoader
from torch.utils.tensorboard import SummaryWriter
import tensorflow as tf
import tensorboard as tb
tf.io.gfile = tb.compat.tensorflow_stub.io.gfile
import torchvision

import pickle

  Referenced from: <8E4D9E61-A0A3-30A7-B778-23E3E702B690> /Users/zhihanyang/opt/miniconda3/envs/mlp/lib/python3.8/site-packages/torchvision/image.so
  Reason: tried: '/Users/zhihanyang/opt/miniconda3/envs/mlp/lib/libpng16.16.dylib' (no such file), '/System/Volumes/Preboot/Cryptexes/OS/Users/zhihanyang/opt/miniconda3/envs/mlp/lib/libpng16.16.dylib' (no such file), '/Users/zhihanyang/opt/miniconda3/envs/mlp/lib/libpng16.16.dylib' (no such file), '/System/Volumes/Preboot/Cryptexes/OS/Users/zhihanyang/opt/miniconda3/envs/mlp/lib/libpng16.16.dylib' (no such file), '/Users/zhihanyang/opt/miniconda3/envs/mlp/lib/python3.8/lib-dynload/../../libpng16.16.dylib' (no such file), '/Users/zhihanyang/opt/miniconda3/envs/mlp/lib/libpng16.16.dylib' (no such file), '/System/Volumes/Preboot/Cryptexes/OS/Users/zhihanyang/opt/miniconda3/envs/mlp/lib/libpng16.16.dylib' (no such file), '/Users/zhihanyang/opt/miniconda3/envs/mlp/bin/../lib/libpng16.16.dylib' (no such file), '/usr/local/lib/libpng16.16.dylib' 

## Decoder

In [3]:
class Decoder(nn.Module):
    """Bernoulli decoder p(x | z) parameterised by a two-hidden-layer ELU MLP.

    Args:
        z_dim: dimensionality of the latent code fed to the network.
        x_dim: dimensionality of the (flattened) observation being modelled.
    """

    def __init__(self, z_dim, x_dim):
        super().__init__()
        self.z_dim, self.x_dim = z_dim, x_dim
        hidden = 512
        # NOTE: despite its name, `mus` produces Bernoulli *logits*. The attribute
        # name is kept because it fixes the saved state_dict keys ("mus.0.weight", ...).
        self.mus = nn.Sequential(
            nn.Linear(z_dim, hidden), nn.ELU(),
            nn.Linear(hidden, hidden), nn.ELU(),
            nn.Linear(hidden, x_dim),
        )

    def forward(self, zs):
        """Return p(x | zs), treating the last axis as a single event.

        Leading axes of `zs` (e.g. batch and sample axes) become the batch shape
        of the returned distribution.
        """
        logits = self.mus(zs)
        per_pixel = d.Bernoulli(logits=logits)
        return d.Independent(per_pixel, 1)

In [4]:
# Sanity check: 64 inputs x 10 latent samples each should give one
# log-likelihood per (input, sample) pair; x broadcasts over the sample axis.
dec = Decoder(z_dim=40, x_dim=784)
log_px = dec(torch.randn(64, 10, 40)).log_prob(torch.zeros(64, 1, 784))
assert log_px.shape == (64, 10), log_px.shape
log_px.shape

torch.Size([64, 10])

In [5]:
# Sanity check: the standard-normal prior treats the last axis as one event,
# so a (32, 10, 2) batch of latents yields (32, 10) log-densities.
p_z = d.Independent(d.Normal(torch.zeros(2), torch.ones(2)), 1)
log_pz = p_z.log_prob(torch.randn(32, 10, 2))
assert log_pz.shape == (32, 10), log_pz.shape
log_pz.shape

torch.Size([32, 10])

## NFVAE: VAE with a conditional normalizing-flow posterior

In [6]:
class NFVAE:
    """Variational auto-encoder whose approximate posterior q(z | x) is a
    conditional normalizing flow (``ConditionalNF``).

    The generative model is a standard-normal prior p(z) and a ``Decoder``
    (Bernoulli) likelihood p(x | z). Encoder and decoder are trained jointly by
    maximising a Monte-Carlo estimate of the ELBO, each with its own Adam
    optimizer.

    Args:
        z_dim: latent dimensionality.
        x_dim: observation dimensionality (flattened).
        lr: learning rate shared by both Adam optimizers.
        num_flows: number of flow layers, passed to ``ConditionalNF`` as ``L``.
        num_samples_per_x: posterior samples drawn per input in ``fit``.
    """

    def __init__(self, z_dim, x_dim, lr, num_flows, num_samples_per_x=1):

        # hyper-parameters
        self.z_dim = z_dim
        self.x_dim = x_dim
        self.lr = lr
        self.num_flows = num_flows
        self.num_samples_per_x = num_samples_per_x

        # describes the generative process
        self.p_z = d.Independent(d.Normal(torch.zeros(z_dim), torch.ones(z_dim)), 1)
        self.p_x_given_z = Decoder(z_dim=z_dim, x_dim=x_dim)

        # required for approximate posterior inference
        self.q_z_given_x = ConditionalNF(x_dim=x_dim, D=z_dim, L=self.num_flows)

        # gradient-based optimizers
        self.p_x_given_z_opt = optim.Adam(self.p_x_given_z.parameters(), lr=lr)
        self.q_z_given_x_opt = optim.Adam(self.q_z_given_x.parameters(), lr=lr)

    def fit(self, xs):
        """Take one gradient step on a batch and return the ELBO statistics.

        Args:
            xs: batch of observations. It presumably has shape (batch, x_dim) with
                binary values; the sample axis is inserted at position 1 below.

        Returns:
            dict with float ``"kl"``, ``"rec"`` and ``"elbo"``, each averaged over
            the batch and over the ``num_samples_per_x`` posterior samples.
        """
        self.p_x_given_z.train()
        self.q_z_given_x.train()
        # Assumed (per the original note): zs is (batch, num_samples_per_x, z_dim)
        # and log_probs = log q(zs | xs) is (batch, num_samples_per_x).
        zs, log_probs = self.q_z_given_x.rsample(xs, num_samples_per_x=self.num_samples_per_x)
        # Monte-Carlo KL estimate: mean of log q(z|x) - log p(z) over batch and samples.
        kl = (log_probs - self.p_z.log_prob(zs)).mean()
        # Insert a sample axis so each x broadcasts against every one of its latent samples.
        rec = self.p_x_given_z(zs).log_prob(xs.unsqueeze(1)).mean()
        # ELBO estimated with num_samples_per_x posterior samples per input.
        elbo = rec - kl
        loss = -elbo
        self.p_x_given_z_opt.zero_grad()
        self.q_z_given_x_opt.zero_grad()
        loss.backward()
        self.p_x_given_z_opt.step()
        self.q_z_given_x_opt.step()
        return {
            "kl": float(kl),  # lower is better; the true KL is >= 0 but this mini-batch estimate can dip below 0
            "rec": float(rec),  # the larger the better
            "elbo": float(elbo)  # the larger the better
        }

    def encode(self, xs):
        """Return the posterior mean for ``xs`` without tracking gradients.

        NOTE(review): assumes calling ``ConditionalNF`` returns an object with a
        ``.mean`` attribute; ``fit`` only uses ``rsample``, so confirm against
        core_vectorized.
        """
        # Inference mode, consistent with `sample`; `fit` switches back to train().
        self.q_z_given_x.eval()
        with torch.no_grad():
            return self.q_z_given_x(xs).mean

    def sample(self, n):
        """Draw ``n`` latents from the prior and return the decoder's mean for each.

        The decoder mean is the per-pixel Bernoulli probability.
        """
        self.p_x_given_z.eval()
        self.q_z_given_x.eval()
        with torch.no_grad():
            return self.p_x_given_z(self.p_z.sample((n, ))).mean  # or .sample()

    def save(self, save_dir):
        """Write decoder and encoder state dicts into ``save_dir``, creating it if needed."""
        # torch.save does not create parent directories.
        os.makedirs(save_dir, exist_ok=True)
        torch.save(self.p_x_given_z.state_dict(), os.path.join(save_dir, "p_x_given_z.pth"))
        torch.save(self.q_z_given_x.state_dict(), os.path.join(save_dir, "q_z_given_x.pth"))

    def load(self, save_dir):
        """Load decoder and encoder state dicts previously written by ``save``.

        Tensors are mapped to CPU, so GPU-trained checkpoints load on CPU-only machines.
        """
        self.p_x_given_z.load_state_dict(
            torch.load(os.path.join(save_dir, "p_x_given_z.pth"), map_location=torch.device("cpu"))
        )
        self.q_z_given_x.load_state_dict(
            torch.load(os.path.join(save_dir, "q_z_given_x.pth"), map_location=torch.device("cpu"))
        )

## Data preprocessing

In [7]:
BATCH_SIZE = 128


def binarize_mnist(images):
    """Scale pixel intensities (0-255) to [0, 1], threshold at 0.5, and flatten.

    Returns a float64 array of 0/1 values with one 28*28 row per image.
    """
    return (images / 255 >= 0.5).astype(float).reshape(-1, 28 * 28)


(x_train, y_train), (x_test, y_test) = load_data()

x_train = binarize_mnist(x_train)
x_test = binarize_mnist(x_test)

train_ds = TensorDataset(torch.from_numpy(x_train).float())
# shuffle=True: without it every epoch visits identical mini-batches in the same order.
train_dl = DataLoader(train_ds, batch_size=BATCH_SIZE, shuffle=True)

## Training

In [8]:
num_flows = 16
num_samples_per_x = 1
num_epochs = 100
seeds = [1, 2, 3]

for seed in seeds:

    torch.manual_seed(seed)

    model = NFVAE(z_dim=32, x_dim=28*28, lr=1e-4, num_flows=num_flows, num_samples_per_x=num_samples_per_x)

    run_name = f"{num_flows}_flows_{num_samples_per_x}_samples_{seed}_seed"
    save_path = f"./trained_models/{run_name}"

    elbos_per_epoch = []

    # `comment` tags the default TensorBoard log dir so the seeds are distinguishable;
    # the context manager closes (and flushes) the writer even if training raises.
    with SummaryWriter(comment=f"_{run_name}") as writer:

        for epoch in range(num_epochs):

            print(f"Epoch {epoch+1} | Seed {seed}")

            elbos, kls, recs = [], [], []

            # TensorDataset yields a 1-tuple per batch; unpack the single tensor.
            for (xb,) in train_dl:
                stats_dict = model.fit(xb)
                elbos.append(stats_dict["elbo"])
                kls.append(stats_dict["kl"])
                recs.append(stats_dict["rec"])

            epoch_elbo = np.mean(elbos)
            elbos_per_epoch.append(epoch_elbo)

            writer.add_scalar("Stat/elbo", epoch_elbo, epoch)
            writer.add_scalar("Stat/kl", np.mean(kls), epoch)
            writer.add_scalar("Stat/rec", np.mean(recs), epoch)

            samples = model.sample(n=100).reshape(-1, 1, 28, 28)
            grid = torchvision.utils.make_grid(samples, nrow=10)
            writer.add_image('Viz/samples', grid, epoch)

    # The output directory may not exist yet; torch.save and open() will not create it.
    os.makedirs(save_path, exist_ok=True)
    model.save(save_dir=save_path)

    with open(os.path.join(save_path, "elbos"), "wb") as fp:
        pickle.dump(elbos_per_epoch, fp)

Epoch 1 | Seed 1
Epoch 2 | Seed 1
Epoch 3 | Seed 1
Epoch 4 | Seed 1
Epoch 5 | Seed 1
Epoch 6 | Seed 1
Epoch 7 | Seed 1
Epoch 8 | Seed 1
Epoch 9 | Seed 1
Epoch 10 | Seed 1
Epoch 11 | Seed 1
Epoch 12 | Seed 1
Epoch 13 | Seed 1
Epoch 14 | Seed 1
Epoch 15 | Seed 1
Epoch 16 | Seed 1
Epoch 17 | Seed 1
Epoch 18 | Seed 1
Epoch 19 | Seed 1
Epoch 20 | Seed 1
Epoch 21 | Seed 1
Epoch 22 | Seed 1
Epoch 23 | Seed 1
Epoch 24 | Seed 1
Epoch 25 | Seed 1
Epoch 26 | Seed 1
Epoch 27 | Seed 1
Epoch 28 | Seed 1
Epoch 29 | Seed 1
Epoch 30 | Seed 1
Epoch 31 | Seed 1
Epoch 32 | Seed 1
Epoch 33 | Seed 1
Epoch 34 | Seed 1
Epoch 35 | Seed 1
Epoch 36 | Seed 1
Epoch 37 | Seed 1
Epoch 38 | Seed 1
Epoch 39 | Seed 1
Epoch 40 | Seed 1
Epoch 41 | Seed 1
Epoch 42 | Seed 1
Epoch 43 | Seed 1
Epoch 44 | Seed 1
Epoch 45 | Seed 1
Epoch 46 | Seed 1
Epoch 47 | Seed 1
Epoch 48 | Seed 1
Epoch 49 | Seed 1
Epoch 50 | Seed 1
Epoch 51 | Seed 1
Epoch 52 | Seed 1
Epoch 53 | Seed 1
Epoch 54 | Seed 1
Epoch 55 | Seed 1
Epoch 56 | Seed 1
E