In [6]:
import numpy as np
import time

import torch
import torch.nn.functional as F
import torch.nn as nn
from torch.distributions.normal import Normal

from torchvision import datasets, transforms
from torchvision.utils import save_image
from torch.distributions.normal import Normal
from torch.distributions import kl_divergence

In [8]:
def to_scalar(arr):
    if type(arr) == list:
        return [x.item() for x in arr]
    else:
        return arr.item()


def weights_init(m):
    classname = m.__class__.__name__
    if classname.find('Conv') != -1:
        try:
            nn.init.xavier_uniform_(m.weight.data)
            m.bias.data.fill_(0)
        except AttributeError:
            print("Skipping initialization of ", classname)


class VAE(nn.Module):
    def __init__(self, input_dim, dim, z_dim):
        super().__init__()
        self.encoder = nn.Sequential(
            nn.Conv2d(input_dim, dim, 4, 2, 1),
            nn.BatchNorm2d(dim),
            nn.ReLU(True),
            nn.Conv2d(dim, dim, 4, 2, 1),
            nn.BatchNorm2d(dim),
            nn.ReLU(True),
            nn.Conv2d(dim, dim, 5, 1, 0),
            nn.BatchNorm2d(dim),
            nn.ReLU(True),
            nn.Conv2d(dim, z_dim * 2, 3, 1, 0),
            nn.BatchNorm2d(z_dim * 2)
        )

        self.decoder = nn.Sequential(
            nn.ConvTranspose2d(z_dim, dim, 3, 1, 0),
            nn.BatchNorm2d(dim),
            nn.ReLU(True),
            nn.ConvTranspose2d(dim, dim, 5, 1, 0),
            nn.BatchNorm2d(dim),
            nn.ReLU(True),
            nn.ConvTranspose2d(dim, dim, 4, 2, 1),
            nn.BatchNorm2d(dim),
            nn.ReLU(True),
            nn.ConvTranspose2d(dim, input_dim, 4, 2, 1),
            nn.Tanh()
        )

        self.apply(weights_init)

    def forward(self, x):
        mu, logvar = self.encoder(x).chunk(2, dim=1)

        q_z_x = Normal(mu, logvar.mul(.5).exp())
        p_z = Normal(torch.zeros_like(mu), torch.ones_like(logvar))
        kl_div = kl_divergence(q_z_x, p_z).sum(1).mean()

        x_tilde = self.decoder(q_z_x.rsample())
        return x_tilde, kl_div

In [18]:
BATCH_SIZE = 32
N_EPOCHS = 100
PRINT_INTERVAL = 500
DATASET = 'CIFAR10'  # CIFAR10 | MNIST | FashionMNIST
NUM_WORKERS = 4

INPUT_DIM = 3
DIM = 256
Z_DIM = 128
LR = 1e-3

In [19]:
preproc_transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
])
train_loader = torch.utils.data.DataLoader(
    eval('datasets.'+DATASET)(
        '../data/{}/'.format(DATASET), train=True, download=True,
        transform=preproc_transform,
    ), batch_size=BATCH_SIZE, shuffle=False,
    num_workers=NUM_WORKERS, pin_memory=True
)
test_loader = torch.utils.data.DataLoader(
    eval('datasets.'+DATASET)(
        '../data/{}/'.format(DATASET), train=False,
        transform=preproc_transform
    ), batch_size=BATCH_SIZE, shuffle=False,
    num_workers=NUM_WORKERS, pin_memory=True
)

Files already downloaded and verified


In [20]:
model = VAE(INPUT_DIM, DIM, Z_DIM).cuda()
print(model)
opt = torch.optim.Adam(model.parameters(), lr=LR, amsgrad=True)

VAE(
  (encoder): Sequential(
    (0): Conv2d(3, 256, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1))
    (1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (2): ReLU(inplace)
    (3): Conv2d(256, 256, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1))
    (4): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (5): ReLU(inplace)
    (6): Conv2d(256, 256, kernel_size=(5, 5), stride=(1, 1))
    (7): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (8): ReLU(inplace)
    (9): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1))
    (10): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  )
  (decoder): Sequential(
    (0): ConvTranspose2d(128, 256, kernel_size=(3, 3), stride=(1, 1))
    (1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (2): ReLU(inplace)
    (3): ConvTranspose2d(256, 256, kernel_size=(5, 5),

In [21]:
def train():
    train_loss = []
    model.train()
    for batch_idx, (x, _) in enumerate(train_loader):
        start_time = time.time()
        x = x.cuda()

        x_tilde, kl_d = model(x)
        loss_recons = F.mse_loss(x_tilde, x, size_average=False) / x.size(0)
        loss = loss_recons + kl_d

        nll = -Normal(x_tilde, torch.ones_like(x_tilde)).log_prob(x)
        log_px = nll.mean().item() - np.log(128) + kl_d.item()
        log_px /= np.log(2)

        opt.zero_grad()
        loss.backward()
        opt.step()

        train_loss.append([log_px, loss.item()])

        if (batch_idx + 1) % PRINT_INTERVAL == 0:
            print('\tIter [{}/{} ({:.0f}%)]\tLoss: {} Time: {:5.3f} ms/batch'.format(
                batch_idx * len(x), len(train_loader.dataset),
                PRINT_INTERVAL * batch_idx / len(train_loader),
                np.asarray(train_loss)[-PRINT_INTERVAL:].mean(0),
                1000 * (time.time() - start_time)
            ))

In [22]:
def test():
    start_time = time.time()
    val_loss = []
    model.eval()
    with torch.no_grad():
        for batch_idx, (x, _) in enumerate(test_loader):
            x = x.cuda()
            x_tilde, kl_d = model(x)
            loss_recons = F.mse_loss(x_tilde, x, size_average=False) / x.size(0)
            loss = loss_recons + kl_d
            val_loss.append(loss.item())

    print('\nValidation Completed!\tLoss: {:5.4f} Time: {:5.3f} s'.format(
        np.asarray(val_loss).mean(0),
        time.time() - start_time
    ))
    return np.asarray(val_loss).mean(0)

In [23]:
def generate_reconstructions():
    model.eval()
    x, _ = test_loader.__iter__().next()
    x = x[:32].cuda()
    x_tilde, kl_div = model(x)

    x_cat = torch.cat([x, x_tilde], 0)
    images = (x_cat.cpu().data + 1) / 2

    save_image(
        images,
        'samples/vae_reconstructions_{}.png'.format(DATASET),
        nrow=8
    )


def generate_samples():
    model.eval()
    z_e_x = torch.randn(64, Z_DIM, 1, 1).cuda()
    x_tilde = model.decoder(z_e_x)

    images = (x_tilde.cpu().data + 1) / 2

    save_image(
        images,
        'samples/vae_samples_{}.png'.format(DATASET),
        nrow=8
    )

In [26]:
BEST_LOSS = 99999
LAST_SAVED = -1
for epoch in range(1, N_EPOCHS):
    print("Epoch {}:".format(epoch))
    train()
    cur_loss = test()

    if cur_loss <= BEST_LOSS:
        BEST_LOSS = cur_loss
        LAST_SAVED = epoch
        print("Saving model!")
        torch.save(model.state_dict(), 'models/{}_vae.pt'.format(DATASET))
    else:
        print("Not saving model! Last saved: {}".format(LAST_SAVED))

    generate_reconstructions()
    generate_samples()


Epoch 1:





Validation Completed!	Loss: 144.3787 Time: 7.425 s
Saving model!
Epoch 2:

Validation Completed!	Loss: 135.0530 Time: 7.384 s
Saving model!
Epoch 3:

Validation Completed!	Loss: 129.9874 Time: 7.598 s
Saving model!
Epoch 4:

Validation Completed!	Loss: 127.5721 Time: 7.197 s
Saving model!
Epoch 5:

Validation Completed!	Loss: 124.8113 Time: 7.565 s
Saving model!
Epoch 6:

Validation Completed!	Loss: 123.4451 Time: 7.396 s
Saving model!
Epoch 7:

Validation Completed!	Loss: 122.2854 Time: 7.809 s
Saving model!
Epoch 8:

Validation Completed!	Loss: 120.7163 Time: 7.236 s
Saving model!
Epoch 9:

Validation Completed!	Loss: 120.8584 Time: 7.502 s
Not saving model! Last saved: 8
Epoch 10:

Validation Completed!	Loss: 119.3024 Time: 7.281 s
Saving model!
Epoch 11:

Validation Completed!	Loss: 119.2188 Time: 7.283 s
Saving model!
Epoch 12:

Validation Completed!	Loss: 118.0684 Time: 7.231 s
Saving model!
Epoch 13:

Validation Completed!	Loss: 117.7397 Time: 7.325 s
Saving model!
Epoch 14:

V

Not saving model! Last saved: 24
Epoch 26:

Validation Completed!	Loss: 114.8647 Time: 7.182 s
Saving model!
Epoch 27:

Validation Completed!	Loss: 114.6178 Time: 7.061 s
Saving model!
Epoch 28:

Validation Completed!	Loss: 115.2309 Time: 7.311 s
Not saving model! Last saved: 27
Epoch 29:

Validation Completed!	Loss: 114.4311 Time: 8.273 s
Saving model!
Epoch 30:

Validation Completed!	Loss: 114.4705 Time: 9.851 s
Not saving model! Last saved: 29
Epoch 31:

Validation Completed!	Loss: 114.1936 Time: 7.124 s
Saving model!
Epoch 32:

Validation Completed!	Loss: 113.8678 Time: 9.378 s
Saving model!
Epoch 33:

Validation Completed!	Loss: 114.1099 Time: 6.971 s
Not saving model! Last saved: 32
Epoch 34:

Validation Completed!	Loss: 113.7880 Time: 7.043 s
Saving model!
Epoch 35:

Validation Completed!	Loss: 113.3419 Time: 7.241 s
Saving model!
Epoch 36:

Validation Completed!	Loss: 113.9036 Time: 7.194 s
Not saving model! Last saved: 35
Epoch 37:

Validation Completed!	Loss: 113.7269 Time: 7


Validation Completed!	Loss: 112.6075 Time: 10.564 s
Saving model!
Epoch 51:

Validation Completed!	Loss: 112.8351 Time: 6.944 s
Not saving model! Last saved: 50
Epoch 52:

Validation Completed!	Loss: 113.0245 Time: 9.071 s
Not saving model! Last saved: 50
Epoch 53:

Validation Completed!	Loss: 112.9773 Time: 7.024 s
Not saving model! Last saved: 50
Epoch 54:

Validation Completed!	Loss: 113.0638 Time: 7.764 s
Not saving model! Last saved: 50
Epoch 55:

Validation Completed!	Loss: 112.7679 Time: 7.019 s
Not saving model! Last saved: 50
Epoch 56:

Validation Completed!	Loss: 113.0661 Time: 11.363 s
Not saving model! Last saved: 50
Epoch 57:

Validation Completed!	Loss: 112.9423 Time: 7.028 s
Not saving model! Last saved: 50
Epoch 58:

Validation Completed!	Loss: 112.8368 Time: 7.012 s
Not saving model! Last saved: 50
Epoch 59:

Validation Completed!	Loss: 113.0769 Time: 7.066 s
Not saving model! Last saved: 50
Epoch 60:

Validation Completed!	Loss: 113.1030 Time: 8.111 s
Not saving mode


Validation Completed!	Loss: 112.1091 Time: 7.025 s
Not saving model! Last saved: 73
Epoch 75:

Validation Completed!	Loss: 112.2387 Time: 7.477 s
Not saving model! Last saved: 73
Epoch 76:

Validation Completed!	Loss: 111.8316 Time: 8.203 s
Saving model!
Epoch 77:

Validation Completed!	Loss: 111.7218 Time: 7.097 s
Saving model!
Epoch 78:

Validation Completed!	Loss: 111.7063 Time: 10.570 s
Saving model!
Epoch 79:

Validation Completed!	Loss: 111.7937 Time: 6.997 s
Not saving model! Last saved: 78
Epoch 80:

Validation Completed!	Loss: 111.8056 Time: 7.278 s
Not saving model! Last saved: 78
Epoch 81:

Validation Completed!	Loss: 111.7322 Time: 9.974 s
Not saving model! Last saved: 78
Epoch 82:

Validation Completed!	Loss: 111.7064 Time: 7.035 s
Not saving model! Last saved: 78
Epoch 83:

Validation Completed!	Loss: 111.9459 Time: 7.073 s
Not saving model! Last saved: 78
Epoch 84:

Validation Completed!	Loss: 111.7600 Time: 7.201 s
Not saving model! Last saved: 78
Epoch 85:

Validation


Validation Completed!	Loss: 112.2460 Time: 7.047 s
Not saving model! Last saved: 87
Epoch 99:

Validation Completed!	Loss: 113.0274 Time: 7.861 s
Not saving model! Last saved: 87
