Log: 19/12/13, Zhihan Yang

The purpose of this notebook:
- With the help of `vae-designer-demo.ipynb`, I want to be able to construct arbitrary VAE architectures from various model-level and layer-level parameters.
- Specifically, I use this notebook to write a function that takes in these parameters and outputs the desired VAE for training.

Todos:
- Remove `DataParallel`, because I am unsure how multi-GPU training affects model convergence. (d)
- Instead of building the VAE from one class, build two classes (Encoder and Decoder) and let VAE inherit from them. The benefit is that we can then use `super(VAE, self).__init__` to initialize the encoder and decoder networks directly.
    - Within VAE's `__init__`, PyTorch only collects parameters from attributes of certain PyTorch types, which prevents me from assigning instances of Encoder and Decoder as attributes; instead, I will create two methods.

In [1]:
%reload_ext autoreload
%autoreload 2
%matplotlib inline

In [12]:
import torch.nn as nn
import torch.optim
from collections import OrderedDict 

In [3]:
class Flatten(nn.Module):
    """Collapse every non-batch dimension: (batch, *dims) -> (batch, prod(dims))."""

    def forward(self, input):
        batch_size = input.shape[0]
        # -1 lets view infer the flattened length of one example
        return input.view(batch_size, -1)

class UnFlatten(nn.Module):
    """
    Reshape a (batch, size) tensor into (batch, size, 1, 1) feature maps.

    nn.Sequential only ever calls forward(input), so the target channel count can
    now be set at construction time (default 512, the previously hard-coded value).
    An explicit `size` passed to forward still takes precedence.
    """

    def __init__(self, size=512):
        super(UnFlatten, self).__init__()
        self.size = size

    def forward(self, input, size=None):
        if size is None:
            size = self.size
        return input.view(input.size(0), size, 1, 1)

In [21]:
nn.Sequential??

In [19]:
# Demo: building nn.Sequential from an OrderedDict gives each layer a readable name.
nn.Sequential(
    OrderedDict(
        [('block1-conv1', nn.Conv2d(5, 64, kernel_size=4, stride=2))]
    )
)

Sequential(
  (block1-conv1): Conv2d(5, 64, kernel_size=(4, 4), stride=(2, 2))
)

In [92]:
def conv_sampler(
    in_channels:int, 
    layer_num:int, 
    kernel_nums:tuple, 
    kernel_sizes:tuple, 
    strides:tuple, 
    paddings:tuple,
    final_activation:nn.Module=None,
    up_sample:bool=False,
    final_batch_norm:bool=True,
)->nn.Sequential:
    """
    Return a convolutional sampler (nn.Sequential) made of `layer_num` conv blocks.

    Each block is a convolution (nn.ConvTranspose2d for up-samplers, nn.Conv2d for
    down-samplers) followed by a batch-normalization and an activation: ReLU for
    up-samplers, LeakyReLU(0.2) for down-samplers. The last block uses
    `final_activation` instead, or no activation when it is None.

    The DCGAN paper recommends that kernel sizes should be greater than 3, that strides should be 
    greater than 1, and batch-normalization should be used to guarantee a healthy gradient-flow.

    :param in_channels: number of channels of the input feature map
    :param layer_num: number of conv blocks
    :param kernel_nums: out_channels of each block (needs at least `layer_num` entries)
    :param kernel_sizes: kernel size of each block (needs at least `layer_num` entries)
    :param strides: stride of each block (needs at least `layer_num` entries)
    :param paddings: padding of each block (needs at least `layer_num` entries)
    :param final_activation: module appended after the last block; None means no activation
    :param up_sample: whether the returned sampler is an up-sampler (default: False)
    :param final_batch_norm: whether the last block gets a batch-normalization (default: True).
        Pass False for an output layer, as in the hand-written VAE decoder; the last
        convolution then keeps its bias because no batch-norm follows it.
    :raises ValueError: if layer_num is negative or a per-layer tuple is too short
    """

    HYPERPARAMS = {
        'conv2d-bias': False,  # a batch-norm right after the conv makes a conv bias redundant
        'lrelu-negslope': 0.2,
    }

    if layer_num < 0:
        raise ValueError(f'layer_num must be non-negative, got {layer_num}')
    per_layer = {
        'kernel_nums': kernel_nums,
        'kernel_sizes': kernel_sizes,
        'strides': strides,
        'paddings': paddings,
    }
    for arg_name, values in per_layer.items():
        if len(values) < layer_num:
            raise ValueError(
                f'{arg_name} has {len(values)} entries but layer_num is {layer_num}'
            )

    # this insight comes from the dcgan paper
    if up_sample:
        core_layer = nn.ConvTranspose2d
        # (sic) the misspelling is kept: it is part of the state_dict keys
        core_layer_name = 'convtranpose2d'
        activation_name = 'relu'
        activation_cls, activation_args = nn.ReLU, ()
    else:
        core_layer = nn.Conv2d
        core_layer_name = 'conv2d'
        activation_name = 'lrelu'
        activation_cls, activation_args = nn.LeakyReLU, (HYPERPARAMS['lrelu-negslope'],)

    layers = OrderedDict()
    for i in range(layer_num):
        is_last = i == layer_num - 1
        use_bn = final_batch_norm or not is_last

        layers[f'block{i}-{core_layer_name}'] = core_layer(
            in_channels=in_channels,
            out_channels=kernel_nums[i],
            kernel_size=kernel_sizes[i],
            stride=strides[i],
            padding=paddings[i],
            # only drop the bias when a batch-norm follows and re-centres the output
            bias=HYPERPARAMS['conv2d-bias'] if use_bn else True,
        )
        if use_bn:
            layers[f'block{i}-bn'] = nn.BatchNorm2d(kernel_nums[i])

        if not is_last:
            # a fresh instance per block, so each Sequential entry is its own module
            layers[f'block{i}-{activation_name}'] = activation_cls(*activation_args)
        elif final_activation is not None:
            layers[f'block{i}-{type(final_activation).__name__.lower()}'] = final_activation

        in_channels = kernel_nums[i]

    return nn.Sequential(layers)

In [93]:
# Down-sampler with the same conv/bn/leaky-ReLU layout as the VAE encoder below
# (but without conv biases). For a (n, 17, 16, 16) input the spatial size goes
# 16 -> 7 -> 2, so the output is (n, 128, 2, 2).
down_sampler = conv_sampler(
    up_sample=False,
    in_channels=17,
    layer_num=2,
    kernel_nums=(64, 128),
    kernel_sizes=(4, 4),
    strides=(2, 2),
    paddings=(0, 0),
    final_activation=nn.LeakyReLU(negative_slope=0.2),
)

In [94]:
# Up-sampler with the same transposed-conv layout as the VAE decoder below.
# Differences: in_channels is 64 rather than h_dim, conv biases are off, and the
# last block also gets a batch-norm. For a (n, 64, 1, 1) input the spatial size
# goes 1 -> 4 -> 8 -> 16, so the output is (n, 17, 16, 16).
up_sampler = conv_sampler(
    up_sample=True,
    in_channels=64,
    layer_num=3,
    kernel_nums=(128, 64, 17),
    kernel_sizes=(4, 4, 4),
    strides=(1, 2, 2),
    paddings=(0, 1, 1),
    final_activation=nn.Sigmoid(),
)

In [5]:
class VAE(nn.Module):
    """
    Convolutional VAE for (n, nc, 16, 16) inputs (see the shape comments below).

    Encoder: two stride-2 convs, then flatten to (n, 512). Two linear heads give the
    posterior mean and log-variance. Decoder: a linear layer back to h_dim, then a
    DCGAN-style stack of transposed convs ending in a Sigmoid.

    NOTE: h_dim must equal the flattened encoder output (128 * 2 * 2 = 512 for a 16x16
    input). The decoder's UnFlatten() reshapes to its default of 512 channels, so other
    values of h_dim fail in decode.

    :param dev: device the model is meant to run on (stored as self.dev)
    :param nc: number of input/output channels
    :param h_dim: size of the flattened encoder output
    :param z_dim: size of the latent code
    """

    def __init__(self, dev, nc=17, h_dim=512, z_dim=64):
        super(VAE, self).__init__()
        # Kept for API compatibility; sampling now follows mu's own device and dtype.
        self.dev = dev

        self.encoder = nn.Sequential(
            # input shape: n, 17, 16, 16
            nn.Conv2d(nc, 64, kernel_size=4, stride=2),
            nn.BatchNorm2d(64),
            nn.LeakyReLU(0.2),
            # output shape: n, 64, 7, 7

            # input shape: n, 64, 7, 7
            nn.Conv2d(64, 128, kernel_size=4, stride=2),
            nn.BatchNorm2d(128),
            nn.LeakyReLU(0.2),
            # output shape: n, 128, 2, 2

            # input shape: n, 128, 2, 2
            Flatten()
            # output shape: n, 128 * 2 * 2 = 512
        )

        self.fc1 = nn.Linear(h_dim, z_dim)  # get means
        self.fc2 = nn.Linear(h_dim, z_dim)  # get logvars

        self.fc3 = nn.Linear(z_dim, h_dim)  # process the samples

        # similar to generator in DCGAN
        self.decoder = nn.Sequential(
            UnFlatten(),
            nn.ConvTranspose2d(h_dim, 128, kernel_size=4, stride=1),
            nn.BatchNorm2d(128),
            nn.ReLU(),
            nn.ConvTranspose2d(128, 64, kernel_size=4, stride=2, padding=1),
            nn.BatchNorm2d(64),
            nn.ReLU(),
            nn.ConvTranspose2d(64, nc, kernel_size=4, stride=2, padding=1),
            nn.Sigmoid()
        )

    def reparametrize(self, mu, logvar):
        """
        Draw z = mu + exp(0.5 * logvar) * eps with eps ~ N(0, I).

        eps is created with the same shape, dtype and device as `mu`. The noise is
        therefore no longer drawn on the CPU and forced to float64, which used to
        break float32 models.
        """
        std = torch.exp(0.5 * logvar)
        eps = torch.randn_like(mu)
        return mu + std * eps

    def bottleneck(self, h):
        """Map encoder features h to (z, mu, logvar)."""
        mu, logvar = self.fc1(h), self.fc2(h)
        z = self.reparametrize(mu, logvar)
        return z, mu, logvar

    def encode(self, x):
        """Encode x into a sampled latent z plus the posterior mean and log-variance."""
        h = self.encoder(x)
        z, mu, logvar = self.bottleneck(h)
        return z, mu, logvar

    def decode(self, z):
        """Decode latent codes z into reconstructions in [0, 1] (Sigmoid output)."""
        z = self.fc3(z)
        z = self.decoder(z)
        return z

    def forward(self, x):
        """Return (reconstruction, mu, logvar) for input x."""
        z, mu, logvar = self.encode(x)
        z = self.decode(z)
        return z, mu, logvar

In [6]:
def get_model(dev, z_dim, nc, lr=1e-3):
    """
    Build a float64 VAE on `dev` together with an Adam optimizer over its parameters.

    :param dev: torch device to place the model on
    :param z_dim: size of the latent code
    :param nc: number of input/output channels
    :param lr: Adam learning rate (default: 1e-3, the previously hard-coded value)
    :return: (vae, opt)
    """
    vae = VAE(dev=dev, z_dim=z_dim, nc=nc)
    vae = vae.to(dev).double()
    opt = torch.optim.Adam(vae.parameters(), lr=lr)
    return vae, opt

def load_model(path, nc, dev=torch.device('cpu'), z_dim=64):
    """
    Build a float64 VAE on `dev` and load a saved state_dict into it.

    :param path: file whose contents torch.load turns into the VAE's state_dict
    :param nc: number of input/output channels the checkpoint was built with
    :param dev: device to map the weights to and place the model on (default: CPU)
    :param z_dim: latent size the checkpoint was built with. It must match the value
        given to get_model (default: 64, the VAE default).
    :return: the VAE with the loaded weights
    """
    vae = VAE(nc=nc, dev=dev, z_dim=z_dim).double().to(dev)
    # NOTE(review): torch.load unpickles the file; only load checkpoints from trusted sources.
    vae.load_state_dict(torch.load(path, map_location=dev))
    return vae