In [58]:
import itertools
from collections import defaultdict

import numpy as np
import torch
from torch import nn
from torch.nn import functional as F

from vae_helpers import HModule, get_1x1, get_3x3, draw_gaussian_diag_samples, gaussian_analytical_kl

class Decoder(nn.Module):
    """Dict-in / dict-out wrapper around a ``VDDecoder``."""

    def __init__(self, H):
        super().__init__()
        self.decoder = VDDecoder(H)

    def forward(self, data, **kwargs):
        """Decode ``data['activations']`` and store the results back into ``data``.

        The decoder output is written to ``data['y']`` and the per-block
        statistics to ``data['stats']``; the same (mutated) dict is returned.
        Extra keyword arguments are accepted and ignored.
        """
        data['y'], data['stats'] = self.decoder(data['activations'])
        return data

    def sample(self, N):
        """Return ``N`` unconditional samples from the wrapped decoder."""
        return self.decoder.forward_uncond(N)
        

def parse_layer_string(s):
    """Expand a comma-separated layer spec into a list of ``(res, arg)`` tuples.

    Token forms:
      * ``"RxN"`` -> ``N`` copies of ``(R, None)``
      * ``"RmK"`` -> ``(R, K)``  (``K`` is the mix-in resolution)
      * ``"RdK"`` -> ``(R, K)``  (``K`` is the down rate)
      * ``"R"``   -> ``(R, None)``
    """
    layers = []
    for token in s.split(','):
        if 'x' in token:
            res, repeats = token.split('x')
            n_repeat = int(repeats)
            layers.extend((int(res), None) for _ in range(n_repeat))
            continue
        # 'm' takes precedence over 'd', mirroring the branch order of the spec.
        sep = next((c for c in 'md' if c in token), None)
        if sep is None:
            layers.append((int(token), None))
        else:
            res, arg = map(int, token.split(sep))
            layers.append((res, arg))
    return layers

def get_width_settings(width, s):
    """Build a resolution -> channel-width lookup.

    ``s`` is an optional ``"res:width,res:width"`` override string. Any
    resolution not listed falls back to ``width`` (the result is a
    ``defaultdict``, so a lookup of a missing key also inserts it).
    """
    overrides = defaultdict(lambda: width)
    if not s:
        return overrides
    for key, value in (entry.split(':') for entry in s.split(',')):
        overrides[int(key)] = int(value)
    return overrides

class Block(nn.Module):
    """Four-layer bottleneck block: ``c1 -> c2 -> c3 -> c4`` with a GELU before each layer.

    ``c1``/``c4`` come from ``get_1x1``; ``c2``/``c3`` come from ``get_3x3``
    when ``use_3x3`` is true and from ``get_1x1`` otherwise. Optionally adds a
    residual connection and average-pools the output by ``down_rate``.
    """

    def __init__(self, in_width, middle_width, out_width, down_rate=None, residual=False, use_3x3=True, zero_last=False):
        super().__init__()
        make_middle = get_3x3 if use_3x3 else get_1x1
        self.down_rate = down_rate
        self.residual = residual
        # Layers are created in c1..c4 order so parameter registration is unchanged.
        self.c1 = get_1x1(in_width, middle_width)
        self.c2 = make_middle(middle_width, middle_width)
        self.c3 = make_middle(middle_width, middle_width)
        self.c4 = get_1x1(middle_width, out_width, zero_weights=zero_last)

    def forward(self, x):
        """Apply the four activated layers, then the residual add and pooling if enabled."""
        hidden = x
        for layer in (self.c1, self.c2, self.c3, self.c4):
            hidden = layer(F.gelu(hidden))
        if self.residual:
            hidden = x + hidden
        if self.down_rate is None:
            return hidden
        return F.avg_pool2d(hidden, kernel_size=self.down_rate, stride=self.down_rate)
    
class DecBlock(nn.Module):
    """Decoder block operating at a single spatial resolution ``res``.

    In the conditional path (``forward``) the block combines the running
    feature map with the encoder activations of the same resolution, derives a
    latent ``z``, zeroes out all latent entries past a per-sample cut-off
    (``dropout_index``), and folds the result back into the feature map.
    In the unconditional path (``forward_uncond``) the latent is either given
    explicitly or drawn uniformly from [0, 1).

    NOTE: the Gaussian posterior/prior path (``draw_gaussian_diag_samples`` /
    ``gaussian_analytical_kl``) is not used by this block; the latent is the
    tanh-squashed encoder mean and the "kl" statistic is the distance loss
    computed by ``get_kl``.
    NOTE(review): ``tanh`` yields values in (-1, 1) while ``get_kl`` and
    ``sample_uncond`` draw references/latents from [0, 1) — confirm that this
    range mismatch is intended.
    """

    def __init__(self, H, res, mixin, n_blocks):
        """
        Args:
            H: hyperparameter object; reads ``width``, ``custom_width_str``,
                ``bottleneck_multiple``, ``zdim`` and ``M``.
            res: spatial resolution this block works at (key into ``xs`` /
                ``activations``).
            mixin: lower resolution whose feature map is upsampled and added
                in, or ``None``.
            n_blocks: total number of decoder blocks; used to scale the
                initial weights of ``z_proj`` and of the resnet's last layer
                by ``sqrt(1 / n_blocks)``.
        """
        super().__init__()
        self.base = res
        self.mixin = mixin
        self.H = H
        self.widths = get_width_settings(H.width, H.custom_width_str)
        width = self.widths[res]
        use_3x3 = res > 2
        cond_width = int(width * H.bottleneck_multiple)
        self.zdim = H.zdim
        self.enc = Block(width * 2, cond_width, H.zdim * 2, residual=False, use_3x3=use_3x3)
        self.prior = Block(width, cond_width, H.zdim * 2 + width, residual=False, use_3x3=use_3x3, zero_last=True)
        self.z_proj = get_1x1(H.zdim, width)
        self.z_proj.weight.data *= np.sqrt(1 / n_blocks)
        self.resnet = Block(width, cond_width, width, residual=True, use_3x3=use_3x3)
        self.resnet.c4.weight.data *= np.sqrt(1 / n_blocks)
        self.z_fn = lambda x: self.z_proj(x)
        # Number of uniform reference draws used by get_kl.
        self.M = H.M

    def get_kl(self, z):
        """Return a scalar distance loss between ``z`` and uniform reference draws.

        Despite the name this is not a KL divergence. ``M`` tensors are drawn
        uniformly from [0, 1); for every element of every draw the minimum
        squared distance to the corresponding element of any latent in the
        batch is taken, and the result is averaged.

        Args:
            z: latent tensor of shape ``(n, z_dim, h, w)``.
        """
        n, z_dim, h, w = z.size()
        # z: (n, 1, z_dim, h, w)
        z = z.unsqueeze(1)
        # e: (1, M, z_dim, h, w), created directly on z's device.
        e = torch.rand(1, self.M, z_dim, h, w, device=z.device)
        # (n, M, z_dim, h, w)
        distance = (z - e) ** 2
        # Minimum over the batch dimension -> (M, z_dim, h, w)
        min_distance = torch.min(distance, dim=0).values
        # Average over the M draws -> (z_dim, h, w), then over everything -> ()
        loss = torch.mean(min_distance, dim=0)
        return torch.mean(loss)

    def sample(self, x, acts, dropout_index):
        """Compute the conditional latent and apply prefix dropout to it.

        Args:
            x: running feature map for this resolution.
            acts: encoder activations for this resolution.
            dropout_index: per-sample 1-D tensor; flattened latent positions
                ``>= dropout_index`` are zeroed.

        Returns:
            ``(z, x, kl, remaining)`` where ``z`` is the masked latent, ``x``
            is the feature map plus the prior's residual output, ``kl`` is the
            scalar from ``get_kl`` (computed before masking) and ``remaining``
            is ``dropout_index`` minus the number of latent entries per sample
            of this block.
        """
        # Only the first half of the encoder output (the mean) is used.
        qm, _ = self.enc(torch.cat([x, acts], dim=1)).chunk(2, dim=1)
        qm = torch.tanh(qm)
        # Only the trailing `width` channels of the prior output are used here.
        xpp = self.prior(x)[:, self.zdim * 2:, ...]
        x = x + xpp
        z = qm
        kl = self.get_kl(z)
        n, zdim, h, w = z.size()
        n_latents = zdim * h * w
        # Flat position of every latent entry, one row per sample: (n, n_latents)
        positions = torch.arange(n_latents, device=z.device).expand(n, n_latents)
        keep = (positions < dropout_index[:, None]).reshape(n, zdim, h, w)
        z = z * keep
        return z, x, kl, dropout_index - n_latents

    def sample_uncond(self, x, t=None, lvs=None):
        """Return ``(z, x)`` for unconditional generation.

        ``z`` is ``lvs`` when given, otherwise uniform noise in [0, 1) with
        the shape of the prior's first ``zdim`` channels. ``t`` is accepted
        for interface compatibility but has no effect, because the latent is
        not drawn from the prior's mean/variance.
        """
        feats = self.prior(x)
        # Used only as a shape/dtype/device template for the uniform draw.
        pm = feats[:, :self.zdim, ...]
        xpp = feats[:, self.zdim * 2:, ...]
        x = x + xpp
        if lvs is not None:
            z = lvs
        else:
            z = torch.rand_like(pm)
        return z, x

    def get_inputs(self, xs, activations):
        """Fetch this resolution's feature map and encoder activations.

        A missing feature map is replaced by zeros shaped like the
        activations; a feature map whose batch size differs from the
        activations' is repeated along the batch dimension.
        """
        acts = activations[self.base]
        try:
            x = xs[self.base]
        except KeyError:
            x = torch.zeros_like(acts)
        if acts.shape[0] != x.shape[0]:
            x = x.repeat(acts.shape[0], 1, 1, 1)
        return x, acts

    def forward(self, xs, activations, dropout_index, get_latents=False):
        """Run the conditional path and update ``xs[self.base]`` in place.

        Returns:
            ``(xs, stats, dropout_index)``. ``stats`` always contains ``kl``
            and additionally the detached latent ``z`` when ``get_latents``
            is true. The returned ``dropout_index`` is the remaining budget
            for subsequent blocks. Three values are returned regardless of
            ``get_latents`` so callers can unpack uniformly.
        """
        x, acts = self.get_inputs(xs, activations)
        if self.mixin is not None:
            # Upsample the lower-resolution map, truncated to x's channel count.
            x = x + F.interpolate(xs[self.mixin][:, :x.shape[1], ...], scale_factor=self.base // self.mixin)
        z, x, kl, dropout_index = self.sample(x, acts, dropout_index)
        x = x + self.z_fn(z)
        x = self.resnet(x)
        xs[self.base] = x
        stats = dict(z=z.detach(), kl=kl) if get_latents else dict(kl=kl)
        return xs, stats, dropout_index

    def forward_uncond(self, xs, t=None, lvs=None):
        """Run the unconditional path and update ``xs[self.base]`` in place.

        When no feature map exists for this resolution yet, a zero map is
        created with the batch size, dtype and device of the first entry of
        ``xs``.
        """
        try:
            x = xs[self.base]
        except KeyError:
            ref = next(iter(xs.values()))
            x = torch.zeros(dtype=ref.dtype, size=(ref.shape[0], self.widths[self.base], self.base, self.base), device=ref.device)
        if self.mixin is not None:
            x = x + F.interpolate(xs[self.mixin][:, :x.shape[1], ...], scale_factor=self.base // self.mixin)
        z, x = self.sample_uncond(x, t, lvs=lvs)
        x = x + self.z_fn(z)
        x = self.resnet(x)
        xs[self.base] = x
        return xs
    
class VDDecoder(HModule):
    """Stack of ``DecBlock`` modules that decodes from coarse to fine resolution.

    Reads ``width``, ``custom_width_str``, ``dec_blocks``, ``zdim``,
    ``no_bias_above`` and ``image_size`` from ``self.H``.
    NOTE(review): ``build`` is presumably invoked by ``HModule``'s
    constructor — confirm against ``vae_helpers``.
    """

    def get_zcount(self, blocks, zdim):
        """Return the total latent count, ``sum(res * res * zdim)`` over ``blocks``."""
        return sum(res * res * zdim for res, _ in blocks)

    def build(self):
        """Create the decoder blocks, per-resolution bias maps and output gain/bias."""
        H = self.H
        self.widths = get_width_settings(H.width, H.custom_width_str)
        blocks = parse_layer_string(H.dec_blocks)
        self.zcount = self.get_zcount(blocks, H.zdim)
        dec_blocks = [DecBlock(H, res, mixin, n_blocks=len(blocks)) for res, mixin in blocks]
        self.resolutions = sorted({res for res, _ in blocks})
        self.dec_blocks = nn.ModuleList(dec_blocks)
        # One learned starting feature map per resolution up to H.no_bias_above.
        self.bias_xs = nn.ParameterList([nn.Parameter(torch.zeros(1, self.widths[res], res, res)) for res in self.resolutions if res <= H.no_bias_above])
        self.gain = nn.Parameter(torch.ones(1, H.width, 1, 1))
        self.bias = nn.Parameter(torch.zeros(1, H.width, 1, 1))
        self.final_fn = lambda x: x * self.gain + self.bias

    def forward(self, activations, get_latents=False):
        """Decode conditioned on encoder ``activations`` (dict keyed by resolution).

        A per-sample ``dropout_index`` is drawn uniformly from
        ``[0, self.zcount)`` and threaded through the blocks; each block keeps
        only the latent entries below the remaining index and passes on the
        reduced index.

        Returns:
            ``(output, stats)`` where ``output`` is the gain/bias-adjusted
            feature map at ``H.image_size`` and ``stats`` is the list of
            per-block statistics dicts.
        """
        stats = []
        xs = {a.shape[2]: a for a in self.bias_xs}
        # Any activation works as a reference for batch size and device.
        activation = next(iter(activations.values()))
        dropout_index = torch.randint(0, self.zcount, size=(len(activation),), device=activation.device)
        for block in self.dec_blocks:
            xs, block_stats, dropout_index = block(xs, activations, dropout_index, get_latents=get_latents)
            stats.append(block_stats)
        xs[self.H.image_size] = self.final_fn(xs[self.H.image_size])
        return xs[self.H.image_size], stats

    def forward_uncond(self, n, t=None, y=None):
        """Generate ``n`` unconditional outputs.

        ``t`` may be a single value (used for every block) or an indexable
        sequence with one entry per block. ``y`` is accepted but unused.
        """
        xs = {}
        for bias in self.bias_xs:
            xs[bias.shape[2]] = bias.repeat(n, 1, 1, 1)
        for idx, block in enumerate(self.dec_blocks):
            try:
                temp = t[idx]
            except TypeError:
                # t is None or a scalar: share it across all blocks.
                temp = t
            xs = block.forward_uncond(xs, temp)
        xs[self.H.image_size] = self.final_fn(xs[self.H.image_size])
        return xs[self.H.image_size]

    def forward_manual_latents(self, n, latents, t=None):
        """Generate ``n`` outputs using caller-supplied latents.

        ``latents`` is paired with the blocks in order; blocks beyond the end
        of ``latents`` receive ``None`` and therefore sample their own latent.
        """
        xs = {}
        for bias in self.bias_xs:
            xs[bias.shape[2]] = bias.repeat(n, 1, 1, 1)
        for block, lvs in itertools.zip_longest(self.dec_blocks, latents):
            xs = block.forward_uncond(xs, t, lvs=lvs)
        xs[self.H.image_size] = self.final_fn(xs[self.H.image_size])
        return xs[self.H.image_size]


In [59]:
from easydict import EasyDict
# Hyperparameters for a 32x32 decoder; all fields are read by Decoder / VDDecoder / DecBlock.
H = EasyDict({
    'image_size': 32,
    'image_channels': 3,
    'width': 384,
    'zdim': 16,
    'custom_width_str': "",
    'dec_blocks': "1x1,4m1,4x2,8m4,8x5,16m8,16x10,32m16,32x21",
    'bottleneck_multiple': 0.25,
    'no_bias_above': 64,
    'M': 1,
})
decoder = Decoder(H)

def get_size(model):
    """Return the memory taken by ``model``'s trainable parameters, in MiB.

    Each parameter is counted with its actual element size, so the result is
    also correct for half- or double-precision parameters. Parameters with
    ``requires_grad=False`` are excluded.
    """
    n_bytes = sum(p.numel() * p.element_size() for p in model.parameters() if p.requires_grad)
    return n_bytes / (1024 * 1024)

# Display the decoder's trainable-parameter size as reported by get_size.
get_size(decoder)

120.5799560546875

In [60]:
# Smoke test: feed random activations for every decoder resolution through the decoder.
N = 2
data = {
    'activations': {
        res: torch.randn(N, H.width, res, res) for res in (1, 4, 8, 16, 32)
    }
}
data = decoder(data)
print(data.keys())

dict_keys(['activations', 'y', 'stats'])


In [19]:
# Scratch computation of a latent count over the parsed decoder layout.
# NOTE: this accumulates res * zdim * zdim, which is NOT the same quantity as
# get_zcount's res * res * zdim; the value displayed below reflects this formula.
blocks = parse_layer_string(H.dec_blocks)
zcount = 0
for res, mixin in blocks:
    zcount += res*H.zdim*H.zdim
    
zcount

240896

In [22]:
def get_zcount(blocks, zdim):
    """Return the total number of latent entries: ``res * res * zdim`` summed over ``blocks``.

    ``blocks`` is an iterable of ``(res, mixin)`` pairs; the second element is ignored.
    """
    total = 0
    for res, _ in blocks:
        total += res * res * zdim
    return total

# Display the latent count for the parsed decoder layout and the configured zdim.
get_zcount(blocks, H.zdim)

412432

In [32]:
# Shape of the first activation tensor (in dict insertion order).
next(iter(data['activations'].values())).shape

torch.Size([16, 384, 1, 1])

In [30]:
# First key of the activations dict (in insertion order).
next(iter(data['activations']))

1

In [40]:
# Shape check: a 100-step ramp tiled across 4 rows.
torch.linspace(0, 99, steps=100)[None, :].repeat(4, 1).shape

torch.Size([4, 100])

In [37]:
# Draw two random integers in [0, 100).
torch.randint(low=0, high=100, size=(2,))

tensor([66, 86])