In [52]:
from __future__ import print_function
import argparse
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import torchvision
from torchvision import datasets, transforms
from torch.optim.lr_scheduler import StepLR

from multiprocessing import cpu_count

from torch.utils.tensorboard import SummaryWriter
from torch.utils.data import Subset, DataLoader

from torch.distributions import *

import skorch
import numpy as np

%load_ext tensorboard
torch.autograd.set_detect_anomaly(True)

The tensorboard extension is already loaded. To reload it, use:
  %reload_ext tensorboard


<torch.autograd.anomaly_mode.set_detect_anomaly at 0x7fd1945420f0>

## Dataset

In [31]:
use_cuda = torch.cuda.is_available()

device = torch.device("cuda" if use_cuda else "cpu")

kwargs = {'num_workers': cpu_count(), 'pin_memory': True} if use_cuda else {}
train_loader = DataLoader(Subset(
    datasets.MNIST('/data', train=True, download=True,
                   transform=transforms.Compose([
                       transforms.ToTensor(),
                       transforms.Normalize((0.1307,), (0.3081,))
                   ])), indices=range(10000)),
    batch_size=64, shuffle=True, **kwargs)
test_loader = DataLoader(Subset(
    datasets.MNIST('/data', train=False, transform=transforms.Compose([
                       transforms.ToTensor(),
                       transforms.Normalize((0.1307,), (0.3081,))
                   ])), indices=range(10000)),
    batch_size=1000, shuffle=True, **kwargs)

## Generic autoencoder class

In [91]:
class Autoencoder(nn.Module):
    def __init__(self):
        super(Autoencoder, self).__init__()
        self.writer = SummaryWriter(log_dir='/data/runs')
    
    def trains(self, device, train_loader, optimizer, epoch):
        self.train()
        loss_sum = 0
        for batch_idx, (data, target) in enumerate(train_loader):
            data, target = data.to(device), target.to(device)
            optimizer.zero_grad()
            loss = self.compute_loss_train(data, target)
            loss.backward()
            optimizer.step()
            loss_sum += loss.item()
            if batch_idx % 10 == 0:
                print('Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}'.format(
                    epoch, batch_idx * len(data), len(train_loader.dataset),
                    100. * batch_idx / len(train_loader), loss.item()))
            self.writer.add_scalar('Loss/train', loss.item(), epoch*len(train_loader)+batch_idx)
            
    def tests(self, device, test_loader):
        self.eval()
        test_loss = 0
        correct = 0
        with torch.no_grad():
            for data, target in test_loader:
                data, target = data.to(device), target.to(device)
                loss, output = self.compute_loss_test(data, target)
                test_loss += loss
                l1 = F.l1_loss(output, data.view(-1, 784), reduction='sum')

        test_loss /= len(test_loader.dataset)

        print('\nTest set: Average loss: {:.4f}, Reconstruction error: {}\n'.format(
            test_loss, l1))

## Autoencoder

In [92]:
class SimpleAutoencoder(Autoencoder):
    def __init__(self):
        super(SimpleAutoencoder, self).__init__()
        self.conv1 = nn.Conv2d(1, 32, 3, 1)
        self.conv2 = nn.Conv2d(32, 64, 3, 1)
        self.dropout1 = nn.Dropout2d(0.25)
        self.dropout2 = nn.Dropout2d(0.5)
        self.fc1 = nn.Linear(9216, 128)
        self.fc2 = nn.Linear(128, 10)

    def forward(self, x):
        x = self.conv1(x)
        x = F.relu(x)
        x = self.conv2(x)
        x = F.max_pool2d(x, 2)
        x = self.dropout1(x)
        x = torch.flatten(x, 1)
        x = self.fc1(x)
        x = F.relu(x)
        x = self.dropout2(x)
        x = self.fc2(x)
        output = F.log_softmax(x, dim=1)
        return output       
    
    def compute_loss_train(self, data, target):
        output = self(data)
        return F.nll_loss(output, target)
    
    def compute_loss_test(self, data, target):
        output = self(data)
        return F.nll_loss(output, target, reduction='sum').item(), output  # sum up batch loss

In [80]:
model = SimpleAutoencoder().to(device)
optimizer = optim.Adadelta(model.parameters())

# plot model
dataiter = iter(train_loader)
images, labels = dataiter.next()

# create grid of images
img_grid = torchvision.utils.make_grid(images)

# write to tensorboard
#writer.add_image('mnist_images', img_grid)

scheduler = StepLR(optimizer, step_size=1)
for epoch in range(1, 14 + 1):
    model.trains(device, train_loader, optimizer, epoch)
    model.tests(device, test_loader)
    scheduler.step()



KeyboardInterrupt: 

## Gaussian Variational Autoencoder

In [93]:
class VAE(Autoencoder):
    def __init__(self):
        super(VAE, self).__init__()
        
        self.conv1 = nn.Conv2d(1, 32, 3, 1)
        self.conv2 = nn.Conv2d(32, 64, 3, 1)
        self.dropout1 = nn.Dropout2d(0.25)
        self.dropout2 = nn.Dropout2d(0.5)

        self.fc1 = nn.Linear(9216, 400)
        self.fc21 = nn.Linear(400, 20)
        self.fc22 = nn.Linear(400, 20)
        self.fc3 = nn.Linear(20, 400)
        self.fc4 = nn.Linear(400, 784)

    def encode(self, x):
        x = self.conv1(x)
        x = F.relu(x)
        x = self.conv2(x)
        x = F.max_pool2d(x, 2)
        x = self.dropout1(x)
        x = torch.flatten(x, 1)
        h1 = F.relu(self.fc1(x))
        return self.fc21(h1), self.fc22(h1)

    def reparameterize(self, mu, logvar):
        std = torch.exp(0.5*logvar)
        eps = torch.randn_like(std)
        return mu + eps*std

    def decode(self, z):
        h3 = F.relu(self.fc3(z))
        return torch.sigmoid(self.fc4(h3))

    def forward(self, x):
        mu, logvar = self.encode(x)
        z = self.reparameterize(mu, logvar)
        return self.decode(z), mu, logvar
    
    def loss_function(self, recon_x, x, mu, logvar):
        BCE = F.binary_cross_entropy(recon_x, x.view(-1, 784), reduction='sum')

        # see Appendix B from VAE paper:
        # Kingma and Welling. Auto-Encoding Variational Bayes. ICLR, 2014
        # https://arxiv.org/abs/1312.6114
        # 0.5 * sum(1 + log(sigma^2) - mu^2 - sigma^2)
        KLD = -0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp())

        return BCE + KLD
    
    def compute_loss_train(self, data, target):
        recon_batch, mu, logvar = self(data)
        return self.loss_function(recon_batch, data, mu, logvar)
    
    def compute_loss_test(self, data, target):
        recon_batch, mu, logvar = self(data)
        return self.loss_function(recon_batch, data, mu, logvar).item(), recon_batch  # sum up batch loss

In [94]:
model = VAE().to(device)
optimizer = optim.Adam(model.parameters(), lr=1e-3)


scheduler = StepLR(optimizer, step_size=1)
for epoch in range(1, 14 + 1):
    model.trains(device, train_loader, optimizer, epoch)
    model.tests(device, test_loader)
    scheduler.step()


Test set: Average loss: -13556.3616, Reconstruction error: 495748.84375


Test set: Average loss: -14095.1669, Reconstruction error: 491923.34375


Test set: Average loss: -14156.5123, Reconstruction error: 488775.40625



KeyboardInterrupt: 

## Stick-breaking process

In [141]:
def stickbreakingprocess(k, a, b):
    batch_size = a.size()[0]
    uniform_samples = Uniform(torch.tensor([0.0]), torch.tensor([1.0])).rsample(torch.tensor([batch_size,k])).view(-1,k)
    exp_a = 1/a
    exp_b = 1/b
    km = (1- uniform_samples.pow(exp_b)).pow(exp_a)
    
    #no Nans are allowed in the matrix
    #assert not torch.isnan(km).any().item()
    
    sticks = []
    remaining_sticks = torch.ones_like(km[:,0])
    for i in range(0,k-1):
        with torch.no_grad():
            sticks.append(remaining_sticks * km[:,i])
        remaining_sticks *= (1-km[:,i])
    latent_variables = torch.stack((*sticks, remaining_sticks))

    #all stick segments must sum to 1
    #assert torch.allclose(latent_variables.sum(axis=1), torch.ones([batch_size]))
    
    return latent_variables.T

In [147]:
stickbreakingprocess(20, torch.rand(10,20), torch.rand(10,20))

tensor([[1.7893e-01, 2.9643e-01, 2.9868e-01, 1.2843e-01, 9.0161e-03, 8.1781e-07,
         7.2528e-04, 8.7332e-03, 7.9053e-02, 3.4905e-06, 1.7160e-06, 7.2626e-07,
         2.0973e-06, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00,
         0.0000e+00, 0.0000e+00],
        [9.9991e-01, 4.2863e-05, 2.2989e-34, 4.6840e-05, 7.2359e-07, 1.1992e-13,
         9.4005e-09, 1.4730e-08, 2.8938e-08, 1.8523e-10, 2.1799e-17, 2.6697e-19,
         1.1168e-17, 1.4743e-18, 1.5319e-18, 7.6615e-18, 2.5432e-19, 0.0000e+00,
         5.5224e-21, 1.2506e-22],
        [2.1631e-04, 9.5246e-03, 9.8635e-01, 3.2056e-03, 6.6717e-14, 5.9279e-04,
         6.0618e-05, 2.6082e-09, 4.8095e-05, 5.4345e-09, 1.3581e-10, 4.9258e-09,
         2.0890e-10, 3.9888e-11, 1.6006e-11, 1.9535e-14, 1.2496e-13, 2.5209e-16,
         2.0741e-19, 1.0658e-22],
        [5.7463e-01, 4.2537e-01, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00,
         0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00,
       

## Stick-breaking Autoencoder

In [143]:
class SBVAE(Autoencoder):
    def __init__(self, k):
        super(SBVAE, self).__init__()
        self.k = k
        
        self.conv1 = nn.Conv2d(1, 32, 3, 1)
        self.conv2 = nn.Conv2d(32, 64, 3, 1)
        self.dropout1 = nn.Dropout2d(0.25)
        self.dropout2 = nn.Dropout2d(0.5)

        self.fc1 = nn.Linear(9216, 400)
        self.fc21 = nn.Linear(400, k)
        self.fc22 = nn.Linear(400, k)
        
        
        self.fc3 = nn.Linear(20, 400)
        self.fc4 = nn.Linear(400, 784)

    def encode(self, x):
        x = self.conv1(x)
        x = F.relu(x)
        x = self.conv2(x)
        x = F.max_pool2d(x, 2)
        x = self.dropout1(x)
        x = torch.flatten(x, 1)
        h1 = F.relu(self.fc1(x))
        return F.softplus(self.fc21(h1)), F.softplus(self.fc22(h1))

    def reparameterize(self, a, b):
        return stickbreakingprocess(20, a, b)

    def decode(self, z):
        h3 = F.relu(self.fc3(z))
        return torch.sigmoid(self.fc4(h3))

    def forward(self, x):
        a, b = self.encode(x)
        z = self.reparameterize(a, b)
        return self.decode(z), a, b
    
    def Beta(self, a,b):
        return torch.exp(torch.lgamma(a) + torch.lgamma(b) - torch.lgamma(a+b))
    
    def loss_function(self, recon_x, x, a, b, prior_alpha, prior_beta):
        BCE = F.binary_cross_entropy(recon_x, x.view(-1, 784), reduction='sum')
        kl = 1./(1+a*b) * self.Beta(1./a, b)
        kl += 1./(2+a*b) * self.Beta(2./a, b)
        kl += 1./(3+a*b) * self.Beta(3./a, b)
        kl += 1./(4+a*b) * self.Beta(4./a, b)
        kl += 1./(5+a*b) * self.Beta(5./a, b)
        kl += 1./(6+a*b) * self.Beta(6./a, b)
        kl += 1./(7+a*b) * self.Beta(7./a, b)
        kl += 1./(8+a*b) * self.Beta(8./a, b)
        kl += 1./(9+a*b) * self.Beta(9./a, b)
        kl += 1./(10+a*b) * self.Beta(10./a, b)
        kl *= (prior_beta-1)*b
                                                                                                                                            
        kl += (a-prior_alpha)/a * (-np.euler_gamma - torch.digamma(b) - 1/b) #T.psi(self.posterior_b)                                                                                        

        # add normalization constants                                                                                                                                                                
        kl += torch.log(a*b) #+ torch.log(self.Beta(prior_alpha, prior_beta))

        # final term                                                                                                                                                                                 
        kl += -(b-1)/b 

        return BCE + kl.sum()
    
    def compute_loss_train(self, data, target):
        recon_batch, a, b = self(data)
        return self.loss_function(recon_batch, data, a, b, torch.Tensor([1]), torch.Tensor([5]))
    
    def compute_loss_test(self, data, target):
        recon_batch, a, b = self(data)
        return self.loss_function(recon_batch, data, a, b, 1, 5).item(), recon_batch  # sum up batch loss

In [156]:
next(iter(train_loader))[0].size()

torch.Size([64, 1, 28, 28])

In [153]:
model = SBVAE(k=20).to(device)
optimizer = optim.Adam(model.parameters(), lr=1e-3)
model.writer.add_graph(model, next(iter(train_loader))[0])


scheduler = StepLR(optimizer, step_size=1)
for epoch in range(1, 14 + 1): 
    model.trains(device, train_loader, optimizer, epoch)
    model.tests(device, test_loader)
    scheduler.step()

  This is separate from the ipykernel package so we can avoid doing imports until
  This is separate from the ipykernel package so we can avoid doing imports until
	%rand : Float(64, 20, 1) = aten::rand(%148, %149, %150, %151, %152), scope: SBVAE # /opt/conda/lib/python3.6/site-packages/torch/distributions/uniform.py:67:0
This may cause errors in trace checking. To disable trace checking, pass check_trace=False to torch.jit.trace()
  check_tolerance, _force_outplace, True, _module_class)
Not within tolerance rtol=1e-05 atol=1e-05 at input[37, 276] (0.5113456845283508 vs. 0.5626243948936462) and 49735 other locations (99.00%)
  check_tolerance, _force_outplace, True, _module_class)




KeyboardInterrupt: 