<a href="https://colab.research.google.com/github/qinliuliuqin/GenerativeModels/blob/main/vanilla_vae.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [1]:
import matplotlib.pyplot as plt
import matplotlib.gridspec as gridspec
import numpy as np
import os
import torch
import torchvision.datasets as datasets
import torchvision.transforms as transforms
import torch.optim as optim
import torch.nn as nn
import torch.nn.functional as F

# ---- hyperparameters --------------------------------------------------------
mb_size = 64   # mini-batch size used by both data loaders
Z_dim = 100    # latent dimensionality
X_dim = 28     # first spatial dimension of the (reshaped) image
Y_dim = 28     # second spatial dimension of the (reshaped) image
h_dim = 128    # hidden-layer width shared by encoder and decoder
c = 0          # running index used to name the saved sample grids
lr = 1e-3      # learning rate handed to the optimiser

# ---- datasets ---------------------------------------------------------------
# MNIST is downloaded into ./data on first use; images are converted to
# tensors by ToTensor().
mnist_trainset = datasets.MNIST(
    root='./data',
    train=True,
    download=True,
    transform=transforms.ToTensor(),
)
mnist_testset = datasets.MNIST(
    root='./data',
    train=False,
    download=True,
    transform=transforms.ToTensor(),
)

# ---- loaders ----------------------------------------------------------------
# Training batches are reshuffled each pass; the test loader keeps file order.
train_loader = torch.utils.data.DataLoader(
    dataset=mnist_trainset,
    batch_size=mb_size,
    shuffle=True,
    num_workers=2,
)
test_loader = torch.utils.data.DataLoader(
    dataset=mnist_testset,
    batch_size=mb_size,
    shuffle=False,
    num_workers=2,
)


def xavier_init(size):
    """Return a randomly initialised tensor of shape ``size``.

    Entries are drawn from a standard normal and scaled by
    ``1 / sqrt(size[0] / 2)``, so the spread depends only on the first
    (input) dimension.

    Args:
        size: Sequence of ints giving the tensor shape; ``size[0]`` is the
            fan-in used for scaling.

    Returns:
        A ``torch.Tensor`` with shape ``size``.
    """
    fan_in = size[0]
    scale = 1. / np.sqrt(fan_in / 2.)
    noise = torch.randn(*size)
    return noise * scale


# =============================== Q(z|X) ======================================
# Encoder parameters. Each weight/bias pair is created together; weights are
# randomly initialised, biases start at zero.

# Flattened image -> hidden layer
Wxh, bxh = xavier_init(size=(X_dim * Y_dim, h_dim)), torch.zeros(h_dim)

# Hidden layer -> latent mean
Whz_mu, bhz_mu = xavier_init(size=(h_dim, Z_dim)), torch.zeros(Z_dim)

# Hidden layer -> second latent head (used as a log-variance downstream)
Whz_var, bhz_var = xavier_init(size=(h_dim, Z_dim)), torch.zeros(Z_dim)


def Q(X):
    """Encode a batch of images into the parameters of q(z|X).

    The input is flattened from dimension 1 onwards, passed through one ReLU
    hidden layer and then projected by two independent linear heads.

    Args:
        X: Batch tensor whose trailing dimensions flatten to
            ``X_dim * Y_dim`` features per sample.

    Returns:
        A ``(z_mu, z_var)`` tuple of tensors with ``Z_dim`` columns each.
        Despite its name, ``z_var`` is consumed as a *log*-variance by
        ``sample_z`` and by the KL term of the training loss.
    """
    flat = X.flatten(start_dim=1)
    # The 1-D biases broadcast across the batch dimension, so there is no
    # need to tile them to the batch size first.
    h = torch.relu(flat @ Wxh + bxh)
    z_mu = h @ Whz_mu + bhz_mu
    z_var = h @ Whz_var + bhz_var
    return z_mu, z_var


def sample_z(mu, log_var):
    """Draw latent samples with the reparameterisation trick.

    Computes ``mu + exp(log_var / 2) * eps`` with ``eps ~ N(0, I)``.

    The noise takes its shape (and dtype/device) from ``mu`` rather than from
    global batch-size constants, so batches of any size work.

    Args:
        mu: Tensor of latent means.
        log_var: Tensor of latent log-variances, broadcastable with ``mu``.

    Returns:
        A tensor of samples with the same shape as ``mu``.
    """
    eps = torch.randn_like(mu)
    return mu + torch.exp(log_var / 2) * eps


# =============================== P(X|z) ======================================
# Decoder parameters. Each weight/bias pair is created together; weights are
# randomly initialised, biases start at zero.

# Latent code -> hidden layer
Wzh, bzh = xavier_init(size=(Z_dim, h_dim)), torch.zeros(h_dim)

# Hidden layer -> flattened image
Whx, bhx = xavier_init(size=(h_dim, X_dim * Y_dim)), torch.zeros(X_dim * Y_dim)

def P(z):
    """Decode latent codes into images.

    Applies one ReLU hidden layer followed by a sigmoid output layer, then
    reshapes each row into an ``X_dim`` x ``Y_dim`` grid.

    Args:
        z: Tensor of latent codes with ``Z_dim`` columns.

    Returns:
        Tensor of shape ``(-1, X_dim, Y_dim)`` with values in (0, 1) produced
        by the sigmoid.
    """
    # The 1-D biases broadcast across the batch dimension, so there is no
    # need to tile them to the batch size first.
    h = torch.relu(z @ Wzh + bzh)
    X = torch.sigmoid(h @ Whx + bhx)
    return X.reshape(-1, X_dim, Y_dim)


# =============================== TRAINING ====================================

params = [Wxh, bxh, Whz_mu, bhz_mu, Whz_var, bhz_var,
          Wzh, bzh, Whx, bhx]

# The parameters are plain tensors, so gradient tracking must be switched on
# explicitly before they are handed to the optimiser.
for param in params:
    param.requires_grad_(True)

solver = optim.Adam(params, lr=lr)

# Directory for the saved sample grids; created once, up front.
os.makedirs('out/', exist_ok=True)


def _full_batches(loader, batch_size):
    """Yield image batches of exactly ``batch_size`` rows from ``loader``, forever.

    The loader is iterated epoch after epoch with a single iterator per
    epoch, instead of building a fresh iterator (and fresh worker processes)
    for every training step. Batches with fewer than ``batch_size`` rows are
    skipped so that every step sees the batch size the rest of the script is
    sized for. Labels are discarded.

    Raises:
        ValueError: If a full pass over ``loader`` produces no batch of
            ``batch_size`` rows (this would otherwise loop forever).
    """
    while True:
        produced = False
        for images, _ in loader:
            if images.size(0) == batch_size:
                produced = True
                yield images
        if not produced:
            raise ValueError(
                'loader yields no batch of size {}'.format(batch_size))


# zip() stops after 1000 steps without pulling an extra batch, because the
# range is exhausted first.
for it, X in zip(range(1000), _full_batches(train_loader, mb_size)):
    # Drop the singleton channel dimension.
    X = X.squeeze(dim=1)

    # Forward
    z_mu, z_var = Q(X)
    z = sample_z(z_mu, z_var)
    X_sample = P(z)

    # Loss: summed reconstruction error averaged over the batch, plus the
    # analytic KL divergence between N(z_mu, exp(z_var)) and N(0, I).
    recon_loss = F.binary_cross_entropy(X_sample, X, reduction='sum') / mb_size
    kl_loss = torch.mean(0.5 * torch.sum(torch.exp(z_var) + z_mu**2 - 1. - z_var, 1))
    loss = recon_loss + kl_loss

    # Backward
    loss.backward()

    # Update
    solver.step()

    # Housekeeping: clear gradients so they do not accumulate across steps.
    solver.zero_grad()

    # Print and plot every now and then
    if it % 10 == 0:
        print('Iter-{}; Loss: {:.4}'.format(it, loss.item()))

        # Decode the current latent batch with the freshly updated decoder;
        # no autograd graph is needed for plotting.
        with torch.no_grad():
            samples = P(z).numpy()[:16]

        fig = plt.figure(figsize=(4, 4))
        gs = gridspec.GridSpec(4, 4)
        gs.update(wspace=0.05, hspace=0.05)

        for i, sample in enumerate(samples):
            ax = plt.subplot(gs[i])
            plt.axis('off')
            ax.set_xticklabels([])
            ax.set_yticklabels([])
            ax.set_aspect('equal')
            plt.imshow(sample.reshape(28, 28), cmap='Greys_r')

        plt.savefig('out/{}.png'.format(str(c).zfill(3)), bbox_inches='tight')
        c += 1
        plt.close(fig)



Iter-0; Loss: 802.9
Iter-10; Loss: 524.0
Iter-20; Loss: 420.7
Iter-30; Loss: 329.3
Iter-40; Loss: 264.9
Iter-50; Loss: 257.8
Iter-60; Loss: 255.3
Iter-70; Loss: 241.6
Iter-80; Loss: 240.0
Iter-90; Loss: 235.4
Iter-100; Loss: 226.1
Iter-110; Loss: 235.2
Iter-120; Loss: 221.7
Iter-130; Loss: 215.3
Iter-140; Loss: 216.4
Iter-150; Loss: 211.7
Iter-160; Loss: 218.2
Iter-170; Loss: 206.6
Iter-180; Loss: 202.9
Iter-190; Loss: 206.3
Iter-200; Loss: 198.6
Iter-210; Loss: 196.2
Iter-220; Loss: 179.7
Iter-230; Loss: 183.8
Iter-240; Loss: 191.4
Iter-250; Loss: 185.4
Iter-260; Loss: 182.2
Iter-270; Loss: 182.0
Iter-280; Loss: 177.9
Iter-290; Loss: 183.6
Iter-300; Loss: 187.0
Iter-310; Loss: 167.9
Iter-320; Loss: 173.8
Iter-330; Loss: 167.3
Iter-340; Loss: 170.9
Iter-350; Loss: 173.7
Iter-360; Loss: 179.0
Iter-370; Loss: 182.5
Iter-380; Loss: 168.4
Iter-390; Loss: 175.3
Iter-400; Loss: 170.4
Iter-410; Loss: 169.8
Iter-420; Loss: 160.1
Iter-430; Loss: 162.9
Iter-440; Loss: 169.7
Iter-450; Loss: 173.7