In [1]:
import os
import numpy as np

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.optim.lr_scheduler import ReduceLROnPlateau
from torch.distributions import Normal, Categorical

from models.blocks import ConvBlock, TransposeConvBlock, ResConvBlock, CategoricalStraightThrough

import torchvision
from torchvision import transforms

In [2]:
# Spatial size produced by ONE strided convolution (kernel 4, stride 2,
# padding 1) applied to a 128 x 128 input, using the usual formula
#   out = (in + 2 * padding - kernel_size) // stride + 1
kernel_size, padding, stride = 4, 1, 2

height, width = (
    (side + 2 * padding - kernel_size) // stride + 1
    for side in (128, 128)
)

height, width

(64, 64)

In [3]:
# Dry run of the encoder geometry: follow the feature-map shape through a
# stack of stride-2 convolutions without building any layers.

# image size
height, width = 128, 128

# convolution settings shared by every block
kernel_size, padding, stride = 3, 1, 2

# channel plan: one entry per block
input_channels = 1
channels = [16, 32, 64, 128, 256, 512, 1024]
blocks = len(channels)

for i, out_channels in enumerate(channels):
    # out = (in + 2 * padding - kernel_size) // stride + 1, applied per axis
    height = (height + 2 * padding - kernel_size) // stride + 1
    width = (width + 2 * padding - kernel_size) // stride + 1

    print(f"adding ConvBlock({(input_channels, out_channels)}) ==> output shape:{(out_channels, height, width)} ==> prod: {out_channels * height * width}")

    # the next block consumes what this one produced
    input_channels = out_channels

adding ConvBlock((1, 16)) ==> output shape:(16, 64, 64) ==> prod: 65536
adding ConvBlock((16, 32)) ==> output shape:(32, 32, 32) ==> prod: 32768
adding ConvBlock((32, 64)) ==> output shape:(64, 16, 16) ==> prod: 16384
adding ConvBlock((64, 128)) ==> output shape:(128, 8, 8) ==> prod: 8192
adding ConvBlock((128, 256)) ==> output shape:(256, 4, 4) ==> prod: 4096
adding ConvBlock((256, 512)) ==> output shape:(512, 2, 2) ==> prod: 2048
adding ConvBlock((512, 1024)) ==> output shape:(1024, 1, 1) ==> prod: 1024


In [4]:
# Channel plan, doubling per block: 16, 32, 64, 128, 256, 512, 1024
channels = [2 ** exponent for exponent in range(4, 11)]

In [5]:
# Walk the channel plan from the last entry to the first, printing the
# position in the reversed sequence next to each channel count.
for i, c in enumerate(channels[::-1]):
    print(i, c)

0 1024
1 512
2 256
3 128
4 64
5 32
6 16


In [6]:
# Number of entries in the channel plan (displayed as the cell output).
len(channels)

7

In [30]:
class CategoricalVAE(nn.Module):
    """Convolutional autoencoder with a categorical straight-through bottleneck.

    The encoder stacks seven stride-2 ``ConvBlock``s; by the size arithmetic
    in ``__init__`` a 128 x 128 input ends up as a (1024, 1, 1) feature map.
    That map is viewed as ``NUM_CATEGORICALS`` groups of ``NUM_CLASSES``
    logits and passed through ``CategoricalStraightThrough``. The decoder
    mirrors the encoder with seven transpose ``ConvBlock``s and ends with a
    sigmoid.

    Args:
        grayscale: If True the model uses 1 image channel, otherwise 3.
        vae_ent_coeff: Stored as ``self.vae_ent_coeff``; no method of this
            class reads it.
        verbose: If True (default), print one summary line per block while
            the network is assembled. Pass False to build silently.

    Raises:
        ValueError: If the encoder's final feature map does not contain
            exactly ``NUM_CATEGORICALS * NUM_CLASSES`` values.
    """

    # Latent layout: the flattened encoder output is interpreted as
    # NUM_CATEGORICALS groups of NUM_CLASSES logits each (32 * 32 = 1024).
    NUM_CATEGORICALS = 32
    NUM_CLASSES = 32

    def __init__(self, grayscale=True, vae_ent_coeff=0.0, verbose=True):
        super().__init__()

        self.input_channels = 1 if grayscale else 3
        self.vae_ent_coeff = vae_ent_coeff

        self.encoder = nn.Sequential()
        self.decoder = nn.Sequential()
        self.categorical = CategoricalStraightThrough(num_classes=self.NUM_CLASSES)

        # settings shared by every encoder and decoder block
        kernel_size = 3
        stride = 2
        padding = 1

        # channels
        channels = [16, 32, 64, 128, 256, 512, 1024]
        latent_size = self.NUM_CATEGORICALS * self.NUM_CLASSES

        def log(message):
            """Print ``message`` only when the caller asked for a summary."""
            if verbose:
                print(message)

        log("Initializing encoder:")
        input_channels = self.input_channels
        height, width = 128, 128
        for i, out_channels in enumerate(channels):
            # Strided convolution: out = (in + 2*padding - kernel) // stride + 1
            height = (height + 2 * padding - kernel_size) // stride + 1
            width = (width + 2 * padding - kernel_size) // stride + 1

            # Adjacent f-strings are joined by the parser, so the message
            # carries no source indentation (a backslash continuation would).
            log(f"- adding ConvBlock({input_channels}, {out_channels}) "
                f"==> output shape: ({out_channels}, {height}, {width}) "
                f"==> prod: {out_channels * height * width}")
            conv_block = ConvBlock(input_channels, out_channels, kernel_size, stride,
                                   padding, height, width)
            self.encoder.add_module(f"conv_block_{i}", conv_block)

            input_channels = out_channels

        # encode() views the encoder output as (-1, NUM_CATEGORICALS, NUM_CLASSES);
        # with a mismatched size that view would silently mix samples of a batch.
        encoder_out_size = input_channels * height * width
        if encoder_out_size != latent_size:
            raise ValueError(
                f"encoder produces {encoder_out_size} values per sample, "
                f"expected {latent_size}"
            )

        log("\nInitializing decoder:")
        # decode() feeds the latent in as a (latent_size, 1, 1) feature map.
        input_channels = latent_size
        height, width = 1, 1
        for i, out_channels in enumerate(reversed(channels)):
            # Transpose-convolution size formula; the trailing +1 presumably
            # matches an output_padding of 1 inside ConvBlock -- TODO confirm.
            height = (height - 1) * stride - 2 * padding + kernel_size + 1
            width = (width - 1) * stride - 2 * padding + kernel_size + 1

            # The last layer maps back to the image channel count instead of
            # the smallest entry of the channel plan.
            if i == len(channels) - 1:
                out_channels = self.input_channels

            log(f"- adding transpose ConvBlock({input_channels}, {out_channels}) "
                f"==> output shape: ({out_channels}, {height}, {width}) "
                f"==> prod: {out_channels * height * width}")
            transpose_conv_block = ConvBlock(input_channels, out_channels, kernel_size, stride,
                                             padding, height, width, transpose_conv=True)
            self.decoder.add_module(f"transpose_conv_block_{i}", transpose_conv_block)

            input_channels = out_channels

        self.decoder.add_module("output_activation", nn.Sigmoid())

    def encode(self, x):
        """Run the encoder on ``x`` and pass the grouped logits through the categorical layer.

        The encoder output is viewed as (-1, NUM_CATEGORICALS, NUM_CLASSES)
        before it is handed to ``self.categorical``; that layer's output is
        returned unchanged.
        """
        logits = self.encoder(x).view(-1, self.NUM_CATEGORICALS, self.NUM_CLASSES)
        return self.categorical(logits)

    def decode(self, z):
        """Decode a latent ``z`` by viewing it as a (latent_size, 1, 1) map and running the decoder."""
        latent_size = self.NUM_CATEGORICALS * self.NUM_CLASSES
        return self.decoder(z.view(-1, latent_size, 1, 1))

    def forward(self, x):
        """Encode ``x`` and decode the result; returns the decoder output."""
        # decode() reshapes the latent itself, so no intermediate view is needed.
        return self.decode(self.encode(x))

In [31]:
# Build the model with its default arguments (grayscale, i.e. 1 input channel).
# Construction prints one summary line per encoder/decoder block.
vae = CategoricalVAE()

Initializing encoder:
- adding ConvBlock((1, 16))                   ==> output shape: (16, 64, 64) ==> prod: 65536
- adding ConvBlock((16, 32))                   ==> output shape: (32, 32, 32) ==> prod: 32768
- adding ConvBlock((32, 64))                   ==> output shape: (64, 16, 16) ==> prod: 16384
- adding ConvBlock((64, 128))                   ==> output shape: (128, 8, 8) ==> prod: 8192
- adding ConvBlock((128, 256))                   ==> output shape: (256, 4, 4) ==> prod: 4096
- adding ConvBlock((256, 512))                   ==> output shape: (512, 2, 2) ==> prod: 2048
- adding ConvBlock((512, 1024))                   ==> output shape: (1024, 1, 1) ==> prod: 1024

Initializing decoder:
- adding transpose ConvBlock(1024, 1024)                   ==> output shape: (1024, 2, 2) ==> prod: 4096
- adding transpose ConvBlock(1024, 512)                   ==> output shape: (512, 4, 4) ==> prod: 8192
- adding transpose ConvBlock(512, 256)                   ==> output shape: (256, 8, 8) ==

In [19]:
# Scratch objects for trying out a single transpose convolution on a batch of
# eight 1 x 1 single-channel maps.
# NOTE(review): kernel_size, stride and padding are not defined in this cell;
# they come from whichever earlier cell ran last -- confirm the intended values.
x = torch.rand((8, 1, 1, 1))
upsample = nn.ConvTranspose2d(
    in_channels=1,
    out_channels=1,
    kernel_size=kernel_size,
    stride=stride,
    padding=padding,
)

In [32]:
# Smoke test: encode a random batch of eight 1 x 128 x 128 images and show the
# shape of the latent that comes back.
images = torch.rand(8, 1, 128, 128)
latent = vae.encode(images)
latent.shape

Output shape: torch.Size([8, 16, 64, 64])
Output shape: torch.Size([8, 32, 32, 32])
Output shape: torch.Size([8, 64, 16, 16])
Output shape: torch.Size([8, 128, 8, 8])
Output shape: torch.Size([8, 256, 4, 4])
Output shape: torch.Size([8, 512, 2, 2])
Output shape: torch.Size([8, 1024, 1, 1])


torch.Size([8, 32, 32])

In [34]:
# Smoke test: push a random batch of eight 1 x 128 x 128 images through the
# full encode/decode round trip and show the shape of the reconstruction.
images = torch.rand(8, 1, 128, 128)
reconstruction = vae(images)
reconstruction.shape

Output shape: torch.Size([8, 16, 64, 64])
Output shape: torch.Size([8, 32, 32, 32])
Output shape: torch.Size([8, 64, 16, 16])
Output shape: torch.Size([8, 128, 8, 8])
Output shape: torch.Size([8, 256, 4, 4])
Output shape: torch.Size([8, 512, 2, 2])
Output shape: torch.Size([8, 1024, 1, 1])
Output shape: torch.Size([8, 1024, 2, 2])
Output shape: torch.Size([8, 512, 4, 4])
Output shape: torch.Size([8, 256, 8, 8])
Output shape: torch.Size([8, 128, 16, 16])
Output shape: torch.Size([8, 64, 32, 32])
Output shape: torch.Size([8, 32, 64, 64])
Output shape: torch.Size([8, 1, 128, 128])


torch.Size([8, 1, 128, 128])