In [1]:
import random
import torch
import torch.nn as nn
import torch.nn.parallel
import torch.optim as optim
import torchvision.transforms as transforms
import torch.nn.functional as F
from torch.autograd import Variable

import numpy as np
import matplotlib.pyplot as plt

random.seed(42)

### Dummy Data stuff

In [27]:
train_data = torch.rand(1, 3, 101, 101)
test_data = torch.rand(1, 3, 101, 101)

In [3]:
n = nn.Conv2d(3, 64, 3, 2, 1)

In [5]:
x= n(Variable(train_data[0, :, :, :].view(1, 3, 512, 512)))
x.size()

torch.Size([1, 64, 256, 256])

### Designing Encoder (E)

In [7]:
class resBlock(nn.Module):
    def __init__(self, in_channels=64, out_channels=64, k=3, s=1, p=1):
        super(resBlock, self).__init__()

        self.conv1 = nn.Conv2d(in_channels, out_channels, k, stride=s, padding=p)
        self.bn1 = nn.BatchNorm2d(out_channels)
        self.conv2 = nn.Conv2d(out_channels, out_channels, k, stride=s, padding=p)
        self.bn2 = nn.BatchNorm2d(out_channels)

    def forward(self, x):
        y = F.relu(self.bn1(self.conv1(x)))
        return self.bn2(self.conv2(y)) + x
    
class resTransposeBlock(nn.Module):
    def __init__(self, in_channels=64, out_channels=64, k=3, s=1, p=1):
        super(resTransposeBlock, self).__init__()

        self.conv1 = nn.ConvTranspose2d(in_channels, out_channels, k, stride=s, padding=p)
        self.bn1 = nn.BatchNorm2d(out_channels)
        self.conv2 = nn.ConvTranspose2d(out_channels, out_channels, k, stride=s, padding=p)
        self.bn2 = nn.BatchNorm2d(out_channels)

    def forward(self, x):
        y = F.relu(self.bn1(self.conv1(x)))
        return self.bn2(self.conv2(y)) + x

In [16]:
class Encoder(nn.Module):
    def __init__(self, n_res_blocks=5):
        super(Encoder, self).__init__()
        self.n_res_blocks = n_res_blocks
        self.conv1 = nn.Conv2d(3, 64, 3, stride=2, padding=1)
        for i in range(n_res_blocks):
            self.add_module('residual_block_1' + str(i+1), resBlock(in_channels=64, out_channels=64, k=3, s=1, p=1))
        self.conv2 = nn.Conv2d(64, 32, 3, stride=2, padding=1)
        for i in range(n_res_blocks):
            self.add_module('residual_block_2' + str(i+1), resBlock(in_channels=32, out_channels=32, k=3, s=1, p=1))
        self.conv3 = nn.Conv2d(32, 8, 3, stride=1, padding=1)
        for i in range(n_res_blocks):
            self.add_module('residual_block_3' + str(i+1), resBlock(in_channels=8, out_channels=8, k=3, s=1, p=1))
        self.conv4 = nn.Conv2d(8, 1, 3, stride=1, padding=1)
    
    def forward(self, x):
        y = self.conv1(x)
        for i in range(self.n_res_blocks):
            y = self.__getattr__('residual_block_1'+str(i+1))(y)
        y = self.conv2(y)
        for i in range(self.n_res_blocks):
            y = self.__getattr__('residual_block_2'+str(i+1))(y)
        y = self.conv3(y)
        for i in range(self.n_res_blocks):
            y = self.__getattr__('residual_block_3'+str(i+1))(y)
        y = self.conv4(y)
        return y

E1 = Encoder()

In [28]:
t1 = E1(Variable(train_data))

### Designing Decoder (D)

In [37]:
class Decoder(nn.Module):
    def __init__(self, n_res_blocks=5):
        super(Decoder, self).__init__()
        self.n_res_blocks = n_res_blocks
        self.conv1 = nn.ConvTranspose2d(1, 8, 3, stride=1, padding=1)
        for i in range(n_res_blocks):
            self.add_module('residual_block_1' + str(i+1), resTransposeBlock(in_channels=8, out_channels=8, k=3, s=1, p=1))
        self.conv2 = nn.ConvTranspose2d(8, 32, 3, stride=1, padding=1)
        for i in range(n_res_blocks):
            self.add_module('residual_block_2' + str(i+1), resTransposeBlock(in_channels=32, out_channels=32, k=3, s=1, p=1))
        self.conv3 = nn.ConvTranspose2d(32, 64, 3, stride=2, padding=1)
        for i in range(n_res_blocks):
            self.add_module('residual_block_3' + str(i+1), resTransposeBlock(in_channels=64, out_channels=64, k=3, s=1, p=1))
        self.conv4 = nn.ConvTranspose2d(64, 3, 3, stride=2, padding=1)
    
    def forward(self, x):
        y = self.conv1(x)
        for i in range(self.n_res_blocks):
            y = self.__getattr__('residual_block_1'+str(i+1))(y)
        y = self.conv2(y)
        for i in range(self.n_res_blocks):
            y = self.__getattr__('residual_block_2'+str(i+1))(y)
        y = self.conv3(y)
        for i in range(self.n_res_blocks):
            y = self.__getattr__('residual_block_3'+str(i+1))(y)
        y = self.conv4(y)
        return y

D1 = Decoder()

In [38]:
print(t1.size())
a = D1(t1)
print(a.size())

torch.Size([1, 1, 26, 26])
torch.Size([1, 3, 101, 101])


### Putting it in box, VAE

In [42]:
class VAE(nn.Module):
    def __init__(self, encoder, decoder, batchsz):
        super(VAE, self).__init__()
        self.E = encoder
        self.D = decoder
        self.batchsz = batchsz
        self._enc_mu = nn.Linear(26*26, 128)
        self._enc_log_sigma = nn.Linear(26*26, 128)
        self._din_layer = nn.Linear(128, 26*26)
        
    def _sample_latent(self, h_enc):
        '''
        Return the latent normal sample z ~ N(mu, sigma^2)
        '''
        mu = self._enc_mu(h_enc)
        log_sigma = self._enc_log_sigma(h_enc)
        sigma = torch.exp(log_sigma)
        std_z = torch.from_numpy(np.random.normal(0, 1, size=sigma.size())).float()

        self.z_mean = mu
        self.z_sigma = sigma
        return mu + sigma * Variable(std_z, requires_grad=False)  # Reparameterization trick

    def forward(self, x):
        h_enc = self.E(x)
        h_enc = h_enc.view(self.batchsz, 1, -1)
        z = self._sample_latent(h_enc)
        z = self._din_layer(z)
        z = z.view(self.batchsz, 1, 26, 26)
        return self.D(z)

In [43]:
V = VAE(E1, D1, 1)

In [44]:
V(Variable(train_data))

Variable containing:
( 0 , 0 ,.,.) = 
 -2.4041e-02  1.0406e-01  4.8794e-02  ...   7.3427e-02  1.2258e-01  5.3513e-02
 -3.9071e-01  2.2520e-01  2.1490e-01  ...  -9.0915e-02  4.6186e-01  8.8965e-02
 -1.9377e-01  7.4536e-02  3.5039e-02  ...   1.6370e-01  7.4748e-02  3.5968e-02
                 ...                   ⋱                   ...                
  8.3426e-02  3.9855e-01 -1.5048e-01  ...  -1.2913e-01  8.7150e-03  2.2838e-01
  4.4554e-03 -1.0679e+00 -1.8006e-01  ...  -1.1145e-01 -6.5338e-01 -7.7800e-01
 -1.4854e-01 -5.5733e-02  7.4461e-02  ...   1.1889e-03  5.4561e-01  7.5355e-03

( 0 , 1 ,.,.) = 
 -1.3546e-01 -1.4763e-02 -3.8727e-01  ...  -2.6202e-01 -2.9772e-01 -1.2940e-01
  1.3664e-01  3.8997e-01  2.2444e-01  ...   2.4836e-01  4.8468e-01  2.1742e-02
  1.2602e-01  1.8277e-01  6.1524e-02  ...  -2.7510e-02  2.7225e-02 -1.5069e-01
                 ...                   ⋱                   ...                
  1.1608e-01 -6.1379e-02 -1.2789e-01  ...   1.2331e-01 -6.2389e-01 -2.8936e

### training thingy

In [45]:
def latent_loss(z_mean, z_stddev):
    mean_sq = z_mean * z_mean
    stddev_sq = z_stddev * z_stddev
    return 0.5 * torch.mean(mean_sq + stddev_sq - torch.log(stddev_sq) - 1)

In [None]:
def train(model, dataloader):
    