In [8]:
import torch
import inspect
import torch.optim as optim
import torch.utils.data
import torch.nn as nn
import torch.nn.functional as F
import torch.nn.modules.activation as A
from tqdm import tqdm
%pylab inline
import torchvision
import torchvision.transforms as transforms
from livelossplot import PlotLosses
import os, errno
from torch.distributions.multivariate_normal import MultivariateNormal

Populating the interactive namespace from numpy and matplotlib


In [9]:
# MNIST training split, downloaded on first use. Images only pass through
# ToTensor(); the Normalize((0.1307,), (0.3081,)) step is left switched off.
mnist_data = torchvision.datasets.MNIST(
    root='../datasets/mnist',
    train=True,
    download=True,
    transform=transforms.Compose([transforms.ToTensor()]),
)
# Reshuffled mini-batches of 32 samples.
mnist_loader = torch.utils.data.DataLoader(dataset=mnist_data, batch_size=32, shuffle=True)

In [10]:
# list(mnist_loader)[0][0][0]

In [11]:
class Encoder(nn.Module):
    """
    Convolutional encoder Q(z|X) = N(z | mu(X; theta), Sigma(X; theta)).

    Three stride-2 convolutions produce 256 feature maps that are flattened
    to 2304 values and fed to two independent linear heads of width nz.

    Input
    - X: batch of single-channel images. The fixed flatten size of 2304
      corresponds to 28x28 inputs (28 -> 14 -> 7 -> 3, 256 * 3 * 3 = 2304).

    Output (tuple)
    - mu_z(X): output of fc11, shape (N, nz)
    - Sigma_z(X): output of fc12, shape (N, nz). Labelled "variances" in the
      notation above, but sample_noise and KL_divergence_gaussians in this
      notebook consume it as a standard deviation. The value is an
      unconstrained linear output.
    """

    def __init__(self, nz=100):
        super().__init__()
        conv_stack = [
            # Project image: 1 -> 64 channels
            nn.Conv2d(1, 64, kernel_size=4, stride=2, padding=1, bias=False),
            nn.LeakyReLU(negative_slope=0.2, inplace=True),
            # Conv1: 64 -> 128 channels
            nn.Conv2d(64, 128, kernel_size=4, stride=2, padding=1, bias=False),
            nn.LeakyReLU(negative_slope=0.2, inplace=True),
            nn.BatchNorm2d(128),
            # Conv2: 128 -> 256 channels
            nn.Conv2d(128, 256, kernel_size=4, stride=2, padding=1, bias=False),
            nn.LeakyReLU(negative_slope=0.2, inplace=True),
            nn.BatchNorm2d(256),
        ]
        self.features = nn.Sequential(*conv_stack)
        # Two heads reading the same flattened features.
        self.fc11 = nn.Linear(2304, nz)
        self.fc12 = nn.Linear(2304, nz)

    def forward(self, X):
        """Return the pair (fc11(h), fc12(h)) where h is the flattened conv output."""
        feature_maps = self.features(X)
        flat = feature_maps.view(feature_maps.size(0), 2304)
        first_head = self.fc11(flat)
        second_head = self.fc12(flat)
        return first_head, second_head
     
class Decoder(nn.Module):
    """
    Transposed-convolution decoder f(z; theta) for P(X|z; theta).

    Input
    - z: latent codes with cz channels, given as (N, cz) or (N, cz, 1, 1)

    Output
    - f(z): tensor of shape (N, 1, 28, 28); the final Sigmoid keeps every
      value inside (0, 1)
    """

    def __init__(self, cz=100):
        super().__init__()
        upsampling_stack = [
            # Project the latent code: (?, cz, 1, 1) -> (?, 256, 4, 4)
            nn.ConvTranspose2d(cz, 256, kernel_size=4, stride=1, padding=0, bias=False),
            nn.ReLU(inplace=True),
            nn.BatchNorm2d(256),
            # Conv1: (?, 256, 4, 4) -> (?, 128, 8, 8)
            nn.ConvTranspose2d(256, 128, kernel_size=4, stride=2, padding=1, bias=False),
            nn.ReLU(inplace=True),
            nn.BatchNorm2d(128),
            # Conv2: (?, 128, 8, 8) -> (?, 64, 16, 16)
            nn.ConvTranspose2d(128, 64, kernel_size=4, stride=2, padding=1, bias=False),
            nn.ReLU(inplace=True),
            nn.BatchNorm2d(64),
            # Conv3: (?, 64, 16, 16) -> (?, 1, 28, 28); padding 3 trims 32 down to 28
            nn.ConvTranspose2d(64, 1, kernel_size=4, stride=2, padding=3, bias=False),
            nn.Sigmoid(),
        ]
        self.model = nn.Sequential(*upsampling_stack)

    def forward(self, z):
        """Reshape z to (N, C, 1, 1) and run it through the upsampling stack."""
        batch, channels = z.size(0), z.size(1)
        as_feature_map = z.view(batch, channels, 1, 1)
        return self.model(as_feature_map)

In [None]:
# Encoder (Q) and decoder (D) sharing a 100-dimensional latent space.
Q, D = Encoder(nz=100), Decoder(cz=100)

def KL_divergence_gaussians(noise_mus, noise_sigmas, eps=1e-8):
    """
    KL( N(mu, diag(sigma^2)) || N(0, I) ), summed over every element.

    Closed form from Appendix B of https://arxiv.org/pdf/1312.6114.pdf:

        KL = -1/2 * sum(1 + log(sigma^2) - mu^2 - sigma^2)

    The leading minus sign matters: the bracketed sum is the *negative*
    divergence, so returning it unnegated would make an optimiser that
    minimises this value push the posterior away from the prior.

    Input
    - noise_mus: means of the approximate posterior
    - noise_sigmas: standard deviations of the approximate posterior, same
      shape as noise_mus (the sign is irrelevant, only sigma^2 is used)
    - eps: floor applied to sigma^2 inside the log so that sigma == 0 yields
      a large finite penalty instead of +inf

    Output
    - scalar tensor, >= 0, equal to the sum of the per-element divergences
      (this matches torch.distributions.kl_divergence between
      Normal(mu, sigma) and Normal(0, 1), summed)
    """
    variances = noise_sigmas ** 2
    log_variances = torch.log(torch.clamp(variances, min=eps))
    return -0.5 * torch.sum(1 + log_variances - noise_mus ** 2 - variances)

def sample_noise(z_mus, z_sigmas):
    """
    Reparameterised draw: z = mu + sigma * eps with eps ~ N(0, I).

    Input
    - z_mus: location tensor; the noise is drawn with its shape, dtype and device
    - z_sigmas: scale tensor multiplied elementwise with the noise

    Output
    - tensor z_mus + z_sigmas * eps
    """
    standard_normal = torch.randn_like(z_mus)
    return z_mus + z_sigmas * standard_normal

In [None]:
import torchvision.utils as vutils

# Fixed latent codes, decoded periodically so preview images are comparable across steps.
fixed_noise = torch.randn([32, 100, 1, 1])
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

# The batches are moved to `device` below, so the models and the preview noise
# have to live there too; otherwise the first forward pass fails on CUDA.
Q.to(device)
D.to(device)
fixed_noise = fixed_noise.to(device)

# A single optimiser updates encoder and decoder jointly.
opt = optim.Adam(list(Q.parameters()) + list(D.parameters()), lr=0.0002, betas=(.5, .999))
for epoch in range(5):
    # Wrapping the loader (rather than the enumerate object) lets tqdm know the total.
    for batch_idx, X in enumerate(tqdm(mnist_loader)):
        X = X[0]  # keep the images, drop the labels
        X = X.to(device, dtype=torch.float32)
        batch_size = X.shape[0]

        opt.zero_grad()  # clears gradients of both Q and D

        z_mus, z_sigmas = Q(X)
        z = sample_noise(z_mus, z_sigmas)  # (batch_size, 100)

        # Regulariser: divergence between the encoder's posterior and the prior.
        Q_loss = KL_divergence_gaussians(z_mus, z_sigmas)

        # Reconstruction term. KL_divergence_gaussians sums over the batch and
        # latent dimensions, so the reconstruction error is summed as well; with
        # the default 'mean' reduction it would be divided by batch_size * 784
        # and end up negligible next to the KL term.
        D_loss = F.binary_cross_entropy(D(z), X, reduction="sum")

        # Per-sample average keeps the step size independent of the batch size.
        ((Q_loss + D_loss) / batch_size).backward()

        opt.step()
        if batch_idx % 100 == 0:
            # Preview samples: eval mode so BatchNorm uses (and does not update)
            # its running statistics, no_grad so no autograd graph is built.
            D.eval()
            with torch.no_grad():
                gen_img = D(fixed_noise)
            D.train()
            vutils.save_image(gen_img, F"fake_samples-{epoch}-{batch_idx}.png", normalize=True)


0it [00:00, ?it/s][A
1it [00:00,  2.76it/s][A
2it [00:00,  3.14it/s][A
3it [00:00,  3.29it/s][A
4it [00:01,  3.41it/s][A
5it [00:01,  3.51it/s][A
6it [00:01,  3.52it/s][A
7it [00:02,  3.46it/s][A
8it [00:02,  3.52it/s][A
Exception in thread Thread-4:
Traceback (most recent call last):
  File "/Users/richard/anaconda3/lib/python3.6/threading.py", line 916, in _bootstrap_inner
    self.run()
  File "/Users/richard/anaconda3/lib/python3.6/site-packages/tqdm/_tqdm.py", line 144, in run
    for instance in self.tqdm_cls._instances:
  File "/Users/richard/anaconda3/lib/python3.6/_weakrefset.py", line 60, in __iter__
    for itemref in self.data:
RuntimeError: Set changed size during iteration

1875it [09:42,  3.22it/s]
1875it [09:11,  3.40it/s]
210it [01:00,  3.48it/s]

In [None]:
# First two entries along dim 0 (start 0, length 2).
torch.Tensor([1, 2, 3])[0:2]

In [None]:
# Three 2x2 slices filled with 1, 2 and 3 respectively -> shape (3, 2, 2).
m = torch.Tensor([[[fill, fill], [fill, fill]] for fill in (1, 2, 3)])
m

In [None]:
# Flatten everything after the first dimension: (3, 2, 2) -> (3, 4).
# A view has to account for every element, so -1 lets PyTorch infer the
# remaining size (a bare (3,) target is rejected for a 12-element tensor).
n = m.view(m.shape[0], -1)

In [None]:
# NOTE(review): n is a view of m, which holds 3 * 2 * 2 = 12 elements; a (3, 3)
# view needs exactly 9, so this call is expected to raise -- confirm the intended shape.
n.view(3, 3)

In [None]:
# Sanity check of torch's closed-form KL: both distributions have mean 1 and
# identity covariance in 3 dimensions, so the divergence between them is zero.
Q_z_X = MultivariateNormal(loc=torch.ones(3), covariance_matrix=torch.eye(3))
P_z = MultivariateNormal(loc=torch.ones(3), covariance_matrix=torch.eye(3))
Q_loss = torch.distributions.kl.kl_divergence(Q_z_X, P_z)

In [None]:
# Display Q_loss. Run directly after the cell above, this is the KL between
# Q_z_X and P_z; note the training loop also assigns a variable named Q_loss.
Q_loss

In [None]:
# `*` on tensors is the elementwise product, i.e. torch.mul.
torch.mul(torch.Tensor([1, 2]), torch.Tensor([2, 3]))