In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision
import torchvision.transforms as transforms
from torch.autograd import Variable

In [None]:
# Preprocessing for every image: convert to a tensor, then normalise each of
# the three channels with mean 0.5 and std 0.5.
channel_stats = (0.5, 0.5, 0.5)
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize(mean=channel_stats, std=channel_stats),
])

In [None]:
# CIFAR-10 training split, downloaded into ./data when it is not already
# there; every image is passed through `transform`.
dataset = torchvision.datasets.CIFAR10(
    root="./data",
    train=True,
    download=True,
    transform=transform,
)

Downloading http://www.cs.toronto.edu/~kriz/cifar-10-python.tar.gz to ./data\cifar-10-python.tar.gz


In [None]:
# Shuffled mini-batches of 50 samples drawn from `dataset`.
dataloader = torch.utils.data.DataLoader(
    dataset,
    batch_size=50,
    shuffle=True,
)

In [None]:
import numpy as np
import matplotlib.pyplot as plt
%matplotlib inline

def imshow(img):
    """Display one image tensor with matplotlib.

    Args:
        img: either a single image shaped (C, H, W) or a batch shaped
            (N, C, H, W) / (N, H, W); for a batch only the first image is
            shown. Values are expected to be normalised with mean 0.5 and
            std 0.5 (the normalisation used by this notebook's `transform`).

    The previous implementation reshaped to a fixed (50, 32, 32) array in
    Fortran order, which cannot hold a (50, 3, 32, 32) batch and scrambles
    pixels; the shape is now taken from the tensor itself.
    """
    img = img.detach().cpu()

    # A 4-D tensor is a batch; a 3-D tensor whose leading size is not a valid
    # channel count (1, 3 or 4) is treated as a batch of 2-D images.
    if img.dim() == 4 or (img.dim() == 3 and img.size(0) not in (1, 3, 4)):
        img = img[0]

    # unnormalize: invert (x - 0.5) / 0.5, then clip rounding noise to [0, 1]
    npimg = np.clip(img.numpy() * 0.5 + 0.5, 0.0, 1.0)

    if npimg.ndim == 3:
        if npimg.shape[0] == 1:
            # single channel: matplotlib wants a plain (H, W) array
            npimg = npimg.squeeze(0)
        else:
            # torch stores channels first, matplotlib expects them last
            npimg = np.transpose(npimg, (1, 2, 0))

    plt.imshow(npimg)

In [None]:
class Encoder(nn.Module):
    """Fully connected VAE encoder producing the parameters of q(z | x).

    Args:
        height: image height in pixels.
        width: image width in pixels.
        z_dim: dimensionality of the latent vector.

    The first layer is sized ``width * height * 3``, i.e. inputs are flattened
    three-channel images.
    """

    def __init__(self, height, width, z_dim):
        super(Encoder, self).__init__()
        self.width = width
        self.height = height
        self.z_dim = z_dim
        self.lin1 = nn.Linear(width * height * 3, 400)

        # Two heads with the same architecture on top of the shared layer.
        # `var_out` keeps its name (it is part of the state-dict keys), but
        # its output is consumed as a log-variance by `sample`.
        self.mu_out = nn.Sequential(
            nn.Linear(400, 500),
            nn.ReLU(),
            nn.Linear(500, 500),
            nn.ReLU(),
            nn.Linear(500, z_dim))
        self.var_out = nn.Sequential(
            nn.Linear(400, 500),
            nn.ReLU(),
            nn.Linear(500, 500),
            nn.ReLU(),
            nn.Linear(500, z_dim))

    def forward(self, X):
        """Flatten `X` and return the pair ``(mu, log_var)``, each (batch, z_dim)."""
        X = X.view(-1, self.width * self.height * 3)
        out = F.relu(self.lin1(X))
        mu = self.mu_out(out)
        log_var = self.var_out(out)
        return mu, log_var

    def sample(self, batch_size, mu, log_var):
        """Draw ``z = mu + exp(log_var / 2) * eps`` with ``eps ~ N(0, I)``.

        Args:
            batch_size: number of noise rows to draw.
            mu: mean tensor, (batch_size, z_dim).
            log_var: log-variance tensor, same shape as `mu`.

        Returns:
            The sampled latent tensor, on the same device as `mu`.
        """
        # Create the noise where `mu` lives instead of forcing CUDA, so the
        # model also works on CPU-only machines and on non-default GPUs.
        eps = torch.randn(batch_size, self.z_dim, device=mu.device, dtype=mu.dtype)
        return mu + torch.exp(log_var / 2) * eps


In [None]:
class Decoder(nn.Module):
    """Fully connected VAE decoder mapping a latent vector to a flat image.

    Args:
        width: image width in pixels.
        height: image height in pixels.
        z_dim: dimensionality of the latent vector.

    NOTE(review): the argument order here is (width, height, ...), whereas
    Encoder takes (height, width, ...); this is harmless for square images,
    but pass keyword arguments when the two differ.
    """

    def __init__(self, width, height, z_dim):
        super(Decoder, self).__init__()
        self.width = width
        self.height = height
        self.z_dim = z_dim

        self.lin1 = nn.Sequential(
            nn.Linear(z_dim, 400),
            nn.ReLU(),
            nn.Linear(400, 500),
            nn.ReLU(),
            nn.Linear(500, 400))
        # Final projection to one value per pixel and channel (3 channels).
        self.out = nn.Linear(400, self.width * self.height * 3)

    def forward(self, X):
        """Decode latents `X` (batch, z_dim) to (batch, width * height * 3) values in [0, 1]."""
        out = F.relu(self.lin1(X))
        # torch.sigmoid replaces the deprecated F.sigmoid; same result.
        out = torch.sigmoid(self.out(out))
        return out

In [None]:
# Build both halves of the VAE for 32x32 inputs with a 300-dimensional latent
# space, and gather all of their weights for a single optimiser.
encoder = Encoder(height=32, width=32, z_dim=300)
decoder = Decoder(width=32, height=32, z_dim=300)
params = [*encoder.parameters(), *decoder.parameters()]

In [None]:
# One Adam optimiser updates the encoder and decoder weights together.
optim = torch.optim.Adam(params, lr=1e-3)
# Summed (not averaged) binary cross-entropy, so the reconstruction term is on
# the same scale as the summed KL term. `reduction="sum"` is the current
# spelling of the deprecated `size_average=False`.
recon_loss = nn.BCELoss(reduction="sum")

In [None]:
# Move both networks and the loss module onto the GPU (a CUDA device is required).
for network in (encoder, decoder):
    network.cuda()
recon_loss.cuda()

In [None]:
n_epochs = 15

# Follow whatever device the encoder's weights are on rather than hard-coding CUDA.
device = next(encoder.parameters()).device

for epoch in range(n_epochs):
    for img, label in dataloader:
        optim.zero_grad()

        X = img.to(device)
        # Take the size from the data instead of assuming 50, so a smaller
        # final batch (or a different loader setting) still works.
        batch_size = X.size(0)

        # Encode to the parameters of q(z|x), then sample z from them.
        z_mu, z_log_var = encoder(X)
        z = encoder.sample(batch_size, z_mu, z_log_var)
        X_sample = decoder(z)

        # BCELoss needs a target with the decoder's (batch, features) shape and
        # values in [0, 1]. The input pipeline normalised pixels with
        # mean 0.5 / std 0.5 (range [-1, 1]), so undo that for the target only;
        # with negative targets the loss is unbounded below.
        target = (X.view(batch_size, -1) * 0.5 + 0.5).clamp(0, 1)
        rec_loss = recon_loss(X_sample, target)

        # KL(q(z|x) || N(0, I)) = -0.5 * sum(1 + log_var - mu^2 - exp(log_var))
        KL_loss = -0.5 * torch.sum(1 + z_log_var - z_mu.pow(2) - z_log_var.exp())
        total_loss = rec_loss + KL_loss

        total_loss.backward()
        optim.step()

    # Once per epoch: save the reconstructions of the last batch and report
    # the loss of that batch.
    fake_images = X_sample.detach().view(-1, 3, 32, 32)
    # The path is passed positionally because its keyword name changed
    # between torchvision releases.
    torchvision.utils.save_image(fake_images, "vae%03d.png" % epoch, normalize=True)
    # .item() extracts the Python number; indexing a 0-dim tensor with [0]
    # raises on current PyTorch.
    print("Epoch: ", epoch, " total loss: ", total_loss.item())

In [None]:
# Decode 50 latent vectors drawn from a standard normal and save the outputs
# as 3x32x32 images in a single file.
prior_noise = torch.FloatTensor(50, 300).normal_()
test_z = Variable(prior_noise).cuda()
decoded = decoder(test_z)
result = decoded.data.view(-1, 3, 32, 32)
torchvision.utils.save_image(result, "vae-results.png")