In [2]:
# https://github.com/sony/sqvae/blob/main/vision/networks/net_64.py

import numpy as np
import torch
from torch import nn
import torch.nn.functional as F

## Resblocks
class ResBlock(nn.Module):
    def __init__(self, dim, act="relu"):
        super().__init__()
        if act == "relu":
            activation = nn.ReLU()
        elif act == "elu":
            activation = nn.ELU()
        self.block = nn.Sequential(
            activation,
            nn.Conv2d(dim, dim, 3, 1, 1),
            nn.BatchNorm2d(dim),
            activation,
            nn.Conv2d(dim, dim, 1),
            nn.BatchNorm2d(dim)
        )

    def forward(self, x):
        return x + self.block(x)

class DecoderVqResnet64(nn.Module):
    def __init__(self, dim_z, cfgs, flg_bn=True):
        super(DecoderVqResnet64, self).__init__()
        # Resblocks
        num_rb = cfgs.num_rb
        layers_resblocks = []
        for i in range(num_rb-1):
            layers_resblocks.append(ResBlock(dim_z))
        self.res = nn.Sequential(*layers_resblocks)
        # Convolution layers
        layers_convt = []
        layers_convt.append(nn.ConvTranspose2d(dim_z, dim_z, 3, stride=1, padding=1))
        if flg_bn:
            layers_convt.append(nn.BatchNorm2d(dim_z))
        layers_convt.append(nn.ReLU())
        layers_convt.append(nn.ConvTranspose2d(dim_z, dim_z // 2, 4, stride=2, padding=1))
        if flg_bn:
            layers_convt.append(nn.BatchNorm2d(dim_z // 2))
        layers_convt.append(nn.ReLU())
        layers_convt.append(nn.ConvTranspose2d(dim_z // 2, 3, 4, stride=2, padding=1))
        layers_convt.append(nn.Sigmoid())
        self.convt = nn.Sequential(*layers_convt)
        
    def forward(self, z):
        out_res = self.res(z)
        out = self.convt(out_res)

        return out
    
class Decoder(nn.Module):
    def __init__(self, z_dim=512, n_resblocks=6, **kwargs):
        super().__init__()
        from easydict import EasyDict
        cfgs = EasyDict(num_rb=n_resblocks)
        self.decoder = DecoderVqResnet64(z_dim, cfgs)
        
    def forward(self, data, **kwargs):
        # x : (b, c, h, w)
        data['y'] = self.decoder(data['z'])
        data['recon_loss'] = F.mse_loss(data['y'], data['x'])
        return data


In [4]:
decoder = Decoder()
data = {'x': torch.randn(2, 3, 64, 64),
        'z': torch.randn(2, 512, 16, 16)}
decoder(data)

{'x': tensor([[[[-0.7920,  0.3415,  0.6648,  ...,  0.7160,  1.2360, -1.2533],
           [-1.9391,  0.6080,  0.2429,  ..., -0.3804, -0.4999, -1.6848],
           [-0.1709,  1.0713, -0.0678,  ..., -1.4061, -2.2036, -0.7827],
           ...,
           [ 1.9159,  0.1254,  0.8385,  ..., -0.1648,  3.6019, -1.5098],
           [-0.5258, -0.5435, -0.8465,  ..., -0.3113,  0.0685,  2.0186],
           [-1.3702, -0.7742,  0.6796,  ...,  0.3385, -0.7149,  0.3521]],
 
          [[ 0.8799,  0.7947, -0.5562,  ..., -0.3090, -0.1672,  1.8210],
           [ 0.7391, -0.2926,  0.4028,  ...,  0.0205,  2.6043,  0.2713],
           [ 1.3274, -0.7049,  0.7737,  ..., -1.2726,  0.3466, -1.2280],
           ...,
           [ 1.0983,  1.6681,  1.2170,  ...,  0.8827, -0.8025, -1.0963],
           [-1.0164,  0.0383,  1.0731,  ...,  0.1196, -0.4333, -0.3390],
           [ 0.7402,  2.1910, -0.1935,  ..., -0.9776,  1.2968,  0.5063]],
 
          [[ 0.9234,  1.2344, -0.0283,  ...,  1.4216,  1.5597,  0.5546],
        