In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision
import torchvision.transforms as transforms
from torch.autograd import Variable

In [2]:
transform = torchvision.transforms.Compose([
    transforms.ToTensor()
    #transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
])

In [3]:
dataset = torchvision.datasets.MNIST(root="./mnist", train=True, download=True, transform=torchvision.transforms.ToTensor())
#dataset = torchvision.datasets.CIFAR10(root="./data", train=True, download=True, transform=transform)

Files already downloaded and verified


In [4]:
dataloader = torch.utils.data.DataLoader(dataset=dataset, shuffle=True, batch_size=200)

In [5]:
import numpy as np
import matplotlib.pyplot as plt
%matplotlib inline

def imshow(img):
    # unnormalize
    npimg = img.numpy()
    npimg = np.reshape(npimg, (50, 32, 32), order='F')
    plt.imshow(npimg[0])

In [9]:
class Decoder(nn.Module):
    
    def __init__(self, width, height, z_dim):
        super(Decoder, self).__init__()
        self.width = width
        self.height = height
        self.z_dim = z_dim
        
        self.lin1 = nn.Sequential(
            nn.Linear(z_dim, 400),
            nn.ReLU(),
            nn.Linear(400, 500),
            nn.ReLU(),
            nn.Linear(500, 400))
        self.out = nn.Linear(400, self.width * self.height * 1)
    
    def forward(self, X):
        out = F.relu(self.lin1(X))
        out = F.sigmoid(self.out(out))
        return out

In [10]:
encoder = Encoder(32, 32, 300)
decoder = Decoder(32, 32, 300)
params = list(encoder.parameters()) + list(decoder.parameters())

In [11]:
optim = torch.optim.Adam(params, lr=1e-3)
recon_loss = nn.BCELoss()

In [12]:
encoder.cuda()
decoder.cuda()
recon_loss.cuda()

BCELoss (
)

In [13]:
n_epochs = 10

for epoch in range(n_epochs):
    for img, label in dataloader:
        optim.zero_grad()

        X = Variable(img.cuda())

        z_mu, z_var = encoder(X)
        z = encoder.sample(200, z_mu, z_var).cuda()
        X_sample = decoder(z)

        rec_loss = recon_loss(X_sample, X)
        KL_loss = z_mu.pow(2).add_(z_var.exp()).mul_(-1).add_(1).add_(z_var)
        KL_loss = torch.sum(KL_loss).mul_(-0.5)
        total_loss = rec_loss + KL_loss

        total_loss.backward()
        optim.step()
        
    fake_images = X_sample.data.view(-1, 1, 32, 32)
    torchvision.utils.save_image(fake_images, filename="vae%03d.png" % epoch, normalize=True)
    print("Epoch: ", epoch, " total loss: ", total_loss.data[0])

Epoch:  0  total loss:  -782325.4375
Epoch:  1  total loss:  -701930.1875
Epoch:  2  total loss:  -820122.8125
Epoch:  3  total loss:  -802150.5
Epoch:  4  total loss:  -794040.0625
Epoch:  5  total loss:  -827085.1875
Epoch:  6  total loss:  -710032.625
Epoch:  7  total loss:  -796050.9375


KeyboardInterrupt: 

In [14]:
test_z = Variable(torch.FloatTensor(50, 300).normal_()).cuda()
result = decoder(test_z).data.view(-1, 3, 32, 32)
torchvision.utils.save_image(result, "pure-vae-results.png")