In [43]:
import torch
import torchvision.datasets as dset
import torchvision.transforms as transforms
from torch.autograd import Variable
import torch.nn as nn

import pyro
import pyro.optim as optim
from pyro.optim import Adam
from pyro.infer import SVI
import pyro.distributions as dist
from pyro.util import ng_zeros, ng_ones


# for loading and batching MNIST dataset
def setup_data_loaders(batch_size=128, use_cuda=False, root='../data/raw',
                       download=True, num_workers=8):
    """Build the MNIST training and test data loaders.

    Args:
        batch_size: number of images per mini-batch for both loaders.
        use_cuda: forwarded as ``pin_memory`` to both loaders.
        root: directory handed to ``dset.MNIST`` as the dataset root.
        download: forwarded to the training ``dset.MNIST`` constructor.
        num_workers: number of worker processes used by each loader.

    Returns:
        A ``(train_loader, test_loader)`` tuple; only the training loader
        shuffles its data.
    """
    trans = transforms.ToTensor()
    train_set = dset.MNIST(root=root, train=True, transform=trans,
                           download=download)
    test_set = dset.MNIST(root=root, train=False, transform=trans)

    # options shared by both loaders
    kwargs = {'num_workers': num_workers, 'pin_memory': use_cuda}
    train_loader = torch.utils.data.DataLoader(dataset=train_set,
        batch_size=batch_size, shuffle=True, **kwargs)
    test_loader = torch.utils.data.DataLoader(dataset=test_set,
        batch_size=batch_size, shuffle=False, **kwargs)
    return train_loader, test_loader

In [47]:
class Decoder(nn.Module):
    """Maps a latent code z to the per-pixel Bernoulli means of an image.

    Args:
        z_dim: size of the latent code fed to the network.
        hidden_dim: number of hidden units.
        fudge: margin keeping the output means inside
            ``[fudge, 1 - fudge]`` instead of the full ``[0, 1]`` range.
    """

    def __init__(self, z_dim, hidden_dim, fudge=1e-3):

        super(Decoder, self).__init__()

        # setup the two linear transformations used
        self.fc1 = nn.Linear(z_dim, hidden_dim)
        self.fc21 = nn.Linear(hidden_dim, 784)

        # setup the non-linearities
        self.softplus = nn.Softplus()
        self.sigmoid = nn.Sigmoid()

        self.fudge = fudge


    def forward(self, z):
        """Return the Bernoulli means, one row of 784 values per row of z."""
        # define the forward computation on the latent z
        # first compute the hidden units
        hidden = self.softplus(self.fc1(z))

        # return the parameter for the output Bernoulli
        # each is of size batch_size x 784
        # fixing numerical instabilities of sigmoid with a fudge:
        # affinely squeeze [0, 1] onto [fudge, 1 - fudge] so the mean never
        # reaches exactly 0 or 1 (scale first, then shift, which keeps the
        # mapping symmetric around 0.5)
        mu_img = self.sigmoid(self.fc21(hidden)) * (1 - 2 * self.fudge) + self.fudge

        return mu_img

In [48]:
class Encoder(nn.Module):
    """Maps a mini-batch of images to a mean and a positive scale per latent dimension.

    Args:
        z_dim: size of the latent code.
        hidden_dim: number of hidden units.
    """

    def __init__(self, z_dim, hidden_dim):
        super(Encoder, self).__init__()

        # three affine layers: pixels -> hidden, hidden -> mean,
        # hidden -> log of the scale
        self.fc1 = nn.Linear(784, hidden_dim)
        self.fc21 = nn.Linear(hidden_dim, z_dim)
        self.fc22 = nn.Linear(hidden_dim, z_dim)

        # activation modules (relu is registered but forward() does not call it)
        self.softplus = nn.Softplus()
        self.relu = nn.ReLU()

    def forward(self, x):
        """Return ``(z_mu, z_sigma)``, each with one row of z_dim values per image."""
        # flatten every image so the 784 pixels sit in the rightmost dimension
        flat = x.view(-1, 784)

        # hidden representation
        hidden = self.softplus(self.fc1(flat))

        # mean vector, and a scale made strictly positive by exponentiation
        loc = self.fc21(hidden)
        scale = self.fc22(hidden).exp()
        return loc, scale

In [57]:
class VAE(nn.Module):
    """Variational autoencoder: an Encoder/Decoder pair plus the Pyro model and guide.

    Args:
        z_dim: size of the latent code (50 by default).
        hidden_dim: hidden units in each network (400 by default).
        use_cuda: when true, the module's parameters are moved to the GPU.
    """

    def __init__(self, z_dim=50, hidden_dim=400, use_cuda=False):
        super(VAE, self).__init__()

        # remember the configuration for later use
        self.use_cuda = use_cuda
        self.z_dim = z_dim

        # build the two networks
        self.encoder = Encoder(z_dim, hidden_dim)
        self.decoder = Decoder(z_dim, hidden_dim)

        # cuda() moves every encoder/decoder parameter into gpu memory
        if use_cuda:
            self.cuda()

    def model(self, x):
        """Generative model: a unit-normal prior over z, Bernoulli pixels given z."""
        # make the decoder's parameters known to Pyro
        pyro.module("decoder", self.decoder)

        # prior hyperparameters, one row per image in the mini-batch;
        # type_as keeps them on the gpu whenever x lives there
        batch_size = x.size(0)
        prior_mu = ng_zeros([batch_size, self.z_dim], type_as=x.data)
        prior_sigma = ng_ones([batch_size, self.z_dim], type_as=x.data)

        # draw the latent code (the guide supplies the value when the ELBO
        # is computed)
        z = pyro.sample("latent", dist.normal, prior_mu, prior_sigma)

        # decode it into per-pixel means and score the flattened images
        img_probs = self.decoder(z)
        pyro.observe("obs", dist.bernoulli, x.view(-1, 784), img_probs)

    def guide(self, x):
        """Variational distribution q(z|x), parameterised by the encoder."""
        # make the encoder's parameters known to Pyro
        pyro.module("encoder", self.encoder)

        # the encoder returns (mean, scale); sample the latent code from them
        pyro.sample("latent", dist.normal, *self.encoder(x))

    def reconstruct_img(self, x):
        """Encode x, sample a latent code, and return the decoded pixel means."""
        # sample z from the encoder's (mean, scale) output
        z = dist.normal(*self.encoder(x))

        # return the decoder means directly; no sampling in image space
        return self.decoder(z)

In [50]:
# smoke test: check that each network can be instantiated on its own
_ = Encoder(z_dim=10, hidden_dim=10)
_ = Decoder(z_dim=10, hidden_dim=10)
# the bare expression displays the module summary of a default VAE
VAE()

VAE (
  (encoder): Encoder (
    (fc1): Linear (784 -> 400)
    (fc21): Linear (400 -> 50)
    (fc22): Linear (400 -> 50)
    (softplus): Softplus (beta=1, threshold=20)
    (relu): ReLU ()
  )
  (decoder): Decoder (
    (fc1): Linear (50 -> 400)
    (fc21): Linear (400 -> 784)
    (softplus): Softplus (beta=1, threshold=20)
    (sigmoid): Sigmoid ()
  )
)

In [55]:
class Bunch(object):
    """Lightweight namespace exposing the entries of a dict as attributes."""

    def __init__(self, adict):
        # copy every key/value pair into the instance namespace
        vars(self).update(adict)

# run configuration
# - cuda: use the GPU only when one is actually available, so the notebook
#   also runs on CPU-only machines
# - test_frequency: the training loop evaluates the test set whenever
#   epoch % test_frequency == 0; with num_epochs=10 and test_frequency=10000
#   that only happens at epoch 0
args = Bunch({'cuda': torch.cuda.is_available(),
              'num_epochs': 10,
              'test_frequency': 10000})

In [52]:
# build the MNIST loaders; use_cuda follows the run configuration
train_loader, test_loader = setup_data_loaders(use_cuda=args.cuda)

In [53]:
# instantiate the VAE (on the GPU when configured) and show its layers
vae = VAE(use_cuda=args.cuda)
print(vae)

# Adam optimizer with a learning rate of 1e-4
optimizer = Adam({'lr': 1e-4})

# inference object tying model, guide and optimizer together under the ELBO loss
svi = SVI(vae.model, vae.guide, optimizer, loss='ELBO')

VAE (
  (encoder): Encoder (
    (fc1): Linear (784 -> 400)
    (fc21): Linear (400 -> 50)
    (fc22): Linear (400 -> 50)
    (softplus): Softplus (beta=1, threshold=20)
    (relu): ReLU ()
  )
  (decoder): Decoder (
    (fc1): Linear (50 -> 400)
    (fc21): Linear (400 -> 784)
    (softplus): Softplus (beta=1, threshold=20)
    (sigmoid): Sigmoid ()
  )
)


In [56]:
for epoch in range(args.num_epochs):

    # running total of the training loss for this epoch
    epoch_loss = 0.

    # one optimisation step per mini-batch x from the training loader
    for x, _ in train_loader:

        # move the batch to CUDA memory when on GPU, then wrap it in a Variable
        x = Variable(x.cuda() if args.cuda else x)

        # take an ELBO gradient step and add its loss to the total
        epoch_loss += svi.step(x)

    # only evaluate on the test set every `test_frequency` epochs
    if epoch % args.test_frequency != 0:
        continue

    # running total of the loss over the entire test set
    test_loss = 0.

    for i, (x, _) in enumerate(test_loader):

        # same batch preparation as for training
        x = Variable(x.cuda() if args.cuda else x)

        # ELBO estimate only (no gradient step), added to the total
        test_loss += svi.evaluate_loss(x)

In [58]:
# show the summed (not per-example) losses left in the namespace by the
# training loop
(epoch_loss, test_loss)

(34684585.37890625, 5779326.51171875)