In [31]:
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from vae_helpers import get_1x1, get_3x3, HModule
from collections import defaultdict

In [32]:
def parse_layer_string(s):
    """Expand a comma-separated layer spec into a list of ``(res, arg)`` tuples.

    Each comma-separated token takes one of four forms, tested in this order:

    * ``"<res>x<count>"``     -> ``count`` copies of ``(res, None)``
    * ``"<res>m<mixin>"``     -> ``(res, mixin)``
    * ``"<res>d<down_rate>"`` -> ``(res, down_rate)``
    * ``"<res>"``             -> ``(res, None)``

    Every numeric field is converted with ``int``, so a malformed token
    raises ``ValueError``.
    """
    parsed = []
    for token in s.split(','):
        if 'x' in token:
            res_text, repeat_text = token.split('x')
            repeats = int(repeat_text)
            parsed.extend((int(res_text), None) for _ in range(repeats))
            continue

        # 'm' is looked for before 'd'; both forms yield the same tuple shape.
        marker = next((c for c in 'md' if c in token), None)
        if marker is None:
            parsed.append((int(token), None))
        else:
            res, value = [int(part) for part in token.split(marker)]
            parsed.append((res, value))
    return parsed


def pad_channels(t, width):
    """Zero-pad a 4-D tensor along dim 1 so that it has ``width`` channels.

    Args:
        t: tensor with exactly four dimensions; dim 1 is the one padded.
        width: size of dim 1 in the result. Expected to be >= ``t.shape[1]``;
            a smaller value makes the slice assignment below fail.

    Returns:
        A new tensor of shape ``(d1, width, d3, d4)`` with the same device and
        dtype as ``t``, holding ``t`` in its first ``t.shape[1]`` channels and
        zeros in the rest.
    """
    d1, d2, d3, d4 = t.shape
    # new_zeros inherits dtype as well as device from t, so half/double inputs
    # are not silently converted to the default float dtype.
    padded = t.new_zeros((d1, width, d3, d4))
    padded[:, :d2, :, :] = t
    return padded


def get_width_settings(width, s):
    """Build a resolution -> width lookup that falls back to ``width``.

    Args:
        width: value returned for any key without an explicit override.
        s: comma-separated ``"key:value"`` pairs (both parsed with ``int``),
            or a falsy value for no overrides.

    Returns:
        A ``defaultdict`` holding the parsed overrides. Looking up a missing
        key yields ``width`` (and, as with any defaultdict, stores it).
    """
    overrides = defaultdict(lambda: width)
    if not s:
        return overrides
    for entry in s.split(','):
        res_text, width_text = entry.split(':')
        overrides[int(res_text)] = int(width_text)
    return overrides

In [33]:
class Block(nn.Module):
    """Four-layer bottleneck block with optional skip connection and pooling.

    The layers ``c1``..``c4`` are built with ``get_1x1`` / ``get_3x3`` from
    ``vae_helpers`` (presumably 1x1 and 3x3 convolution factories -- confirm
    there).  ``c1`` maps ``in_width -> middle_width``, ``c2`` and ``c3`` keep
    ``middle_width``, and ``c4`` maps ``middle_width -> out_width``.

    Args:
        in_width: channel count handed to ``c1``.
        middle_width: channel count used between ``c1`` and ``c4``.
        out_width: channel count produced by ``c4``.
        down_rate: if not ``None``, kernel size and stride of the average
            pooling applied to the block output.
        residual: if true, the input is added to the output of ``c4``; the two
            must then have compatible shapes.
        use_3x3: pick ``get_3x3`` (true) or ``get_1x1`` (false) for ``c2``/``c3``.
        zero_last: forwarded to ``get_1x1`` as ``zero_weights`` for ``c4``.
    """

    def __init__(self, in_width, middle_width, out_width,
                 down_rate=None, residual=False, use_3x3=True, zero_last=False):
        super().__init__()
        self.down_rate = down_rate
        self.residual = residual

        # Both inner layers share the same factory.
        make_inner = get_3x3 if use_3x3 else get_1x1
        self.c1 = get_1x1(in_width, middle_width)
        self.c2 = make_inner(middle_width, middle_width)
        self.c3 = make_inner(middle_width, middle_width)
        self.c4 = get_1x1(middle_width, out_width, zero_weights=zero_last)

    def forward(self, x):
        """Apply GELU followed by each layer in turn, then skip-add and pool."""
        hidden = x
        for layer in (self.c1, self.c2, self.c3, self.c4):
            # The activation comes *before* each layer.
            hidden = layer(F.gelu(hidden))

        if self.residual:
            hidden = x + hidden
        if self.down_rate is None:
            return hidden
        return F.avg_pool2d(hidden, kernel_size=self.down_rate, stride=self.down_rate)

In [34]:
# Smoke test: a residual Block with down_rate=2 applied to a random
# (2, 64, 128, 128) batch; the printed shape shows the 64 channels kept and
# the 128x128 spatial size halved.
block = Block(64, 128, 64, down_rate=2, residual=True)
x = torch.randn(2, 64, 128, 128)
y = block(x)
print(y.shape)

torch.Size([2, 64, 64, 64])


In [52]:
class VDEncoder(HModule):
    """Stack of residual ``Block``s that records activations per resolution.

    Hyperparameters are read from ``self.H`` (presumably stored by ``HModule``,
    which is also assumed to call ``build`` -- confirm in ``vae_helpers``):
    ``image_channels``, ``width``, ``custom_width_str``, ``enc_blocks`` and
    ``bottleneck_multiple``.
    """

    def build(self):
        """Create the input layer and the block stack described by ``H.enc_blocks``."""
        hps = self.H
        self.in_conv = get_3x3(hps.image_channels, hps.width)
        self.widths = get_width_settings(hps.width, hps.custom_width_str)
        layer_specs = parse_layer_string(hps.enc_blocks)

        stack = []
        for res, down_rate in layer_specs:
            block_width = self.widths[res]
            stack.append(
                Block(
                    block_width,
                    int(block_width * hps.bottleneck_multiple),
                    block_width,
                    down_rate=down_rate,
                    residual=True,
                    # Don't use 3x3s for 1x1, 2x2 patches
                    use_3x3=res > 2,
                )
            )

        # Shrink every block's last-layer weights by sqrt(1 / number of blocks).
        n_layers = len(layer_specs)
        for blk in stack:
            blk.c4.weight.data *= np.sqrt(1 / n_layers)
        self.enc_blocks = nn.ModuleList(stack)

    def forward(self, x):
        """Return a dict mapping ``shape[2]`` of each feature map to that map.

        The entry for a given size is overwritten by every later block that
        produces the same size, so each key ends up holding the output of the
        last block at that size.
        """
        feats = self.in_conv(x)
        activations = {feats.shape[2]: feats}
        for blk in self.enc_blocks:
            feats = blk(feats)
            res = feats.shape[2]
            wanted = self.widths[res]
            # Zero-pad dim 1 when the configured width for this size differs.
            if feats.shape[1] != wanted:
                feats = pad_channels(feats, wanted)
            activations[res] = feats
        return activations

In [53]:
class Encoder(nn.Module):
    """Dict-in / dict-out wrapper around ``VDEncoder``."""

    def __init__(self, H):
        super().__init__()
        self.encoder = VDEncoder(H)

    def forward(self, data, **kwargs):
        """Store the encoder output for ``data['x']`` under ``data['activations']``.

        ``data`` is modified in place and the same object is returned; extra
        keyword arguments are accepted but unused.
        """
        # x : (b, c, h, w)
        images = data['x']
        data['activations'] = self.encoder(images)
        return data
        

In [54]:
from easydict import EasyDict
# Hyperparameters read by VDEncoder.build.
H = EasyDict()
# Input channels of the first layer (in_conv).
H.image_channels = 3
# Default width for every resolution without an override.
H.width = 384
# Optional "res:width,..." overrides parsed by get_width_settings; empty = none.
H.custom_width_str = ""
# Block spec parsed by parse_layer_string: "RxN" = N blocks labelled with
# resolution R, "RdK" = one block labelled R that average-pools by K.
# NOTE(review): the labels start at 32 while the demo input further down is
# 64x64, so labels and actual feature-map sizes differ; the labels only pick
# the width and the 3x3-vs-1x1 choice -- confirm this is intended.
H.enc_blocks = "32x11,32d2,16x6,16d2,8x6,8d2,4x3,4d4,1x3"
# Each Block's middle width is int(width * bottleneck_multiple).
H.bottleneck_multiple = 0.25
encoder = Encoder(H)

In [57]:
# Run the encoder on a random batch of shape (2, 3, 64, 64).
# Encoder.forward adds 'activations' to the dict it receives and returns that
# same dict, so `data` is rebound to the object it already names.
data = {'x': torch.randn(2, 3, 64, 64)}
data = encoder(data)
# One entry per feature-map size (key = shape[2] of the stored tensor).
for key in data['activations'].keys():
    print(key, data['activations'][key].shape)

64 torch.Size([2, 384, 64, 64])
32 torch.Size([2, 384, 32, 32])
16 torch.Size([2, 384, 16, 16])
8 torch.Size([2, 384, 8, 8])
2 torch.Size([2, 384, 2, 2])


In [46]:
def get_size(model):
    """Return the size of ``model``'s trainable parameters in MiB.

    Only parameters with ``requires_grad=True`` are counted. Each parameter
    contributes ``numel() * element_size()`` bytes, so the result follows the
    parameters' actual dtype instead of assuming 4-byte float32 values
    (for float32 parameters the result is unchanged).

    Args:
        model: object exposing ``parameters()`` that yields tensors.

    Returns:
        Total bytes divided by ``1024 * 1024``, as a float.
    """
    total_bytes = sum(
        p.numel() * p.element_size()
        for p in model.parameters()
        if p.requires_grad
    )
    return total_bytes / (1024 * 1024)

# Size estimate for the encoder's trainable parameters, as computed by get_size.
get_size(encoder)

28.6021728515625