In [3]:
import torch
from torch import nn, optim
import torch.nn.functional as F

from torchinfo import summary

In [87]:
class ResizeConv2d(nn.Module):

    def __init__(self, in_channels, out_channels, kernel_size, scale_factor, mode='nearest'):
        super().__init__()
        self.scale_factor = scale_factor
        self.mode = mode
        self.conv = nn.Conv2d(in_channels, out_channels, kernel_size, stride=1, padding=1)

    def forward(self, x):
        x = F.interpolate(x, scale_factor=self.scale_factor, mode=self.mode)
        x = self.conv(x)
        return x

class BasicBlockEnc(nn.Module):

    def __init__(self, in_planes, stride=1):
        super().__init__()

        planes = in_planes*stride

        self.conv1 = nn.Conv2d(in_planes, planes, kernel_size=3, stride=stride, padding=1, bias=False)
        self.bn1 = nn.BatchNorm2d(planes)
        self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=1, padding=1, bias=False)
        self.bn2 = nn.BatchNorm2d(planes)

        if stride == 1:
            self.shortcut = nn.Sequential()
        else:
            self.shortcut = nn.Sequential(
                nn.Conv2d(in_planes, planes, kernel_size=1, stride=stride, bias=False),
                nn.BatchNorm2d(planes)
            )

    def forward(self, x):
        out = torch.relu(self.bn1(self.conv1(x)))
        out = self.bn2(self.conv2(out))
        out += self.shortcut(x)
        out = torch.relu(out)
        return out

class BasicBlockDec(nn.Module):

    def __init__(self, in_planes, stride=1):
        super().__init__()

        planes = int(in_planes/stride)

        self.conv2 = nn.Conv2d(in_planes, in_planes, kernel_size=3, stride=1, padding=1, bias=False)
        self.bn2 = nn.BatchNorm2d(in_planes)
        # self.bn1 could have been placed here, but that messes up the order of the layers when printing the class

        if stride == 1:
            self.conv1 = nn.Conv2d(in_planes, planes, kernel_size=3, stride=1, padding=1, bias=False)
            self.bn1 = nn.BatchNorm2d(planes)
            self.shortcut = nn.Sequential()
        else:
            self.conv1 = ResizeConv2d(in_planes, planes, kernel_size=3, scale_factor=stride)
            self.bn1 = nn.BatchNorm2d(planes)
            self.shortcut = nn.Sequential(
                ResizeConv2d(in_planes, planes, kernel_size=3, scale_factor=stride),
                nn.BatchNorm2d(planes)
            )

    def forward(self, x):
        out = torch.relu(self.bn2(self.conv2(x)))
        out = self.bn1(self.conv1(out))
        out += self.shortcut(x)
        out = torch.relu(out)
        return out

class ResNet18Enc(nn.Module):

    def __init__(self, num_Blocks=[2,2,2,2], z_dim=10, nc=3):
        super().__init__()
        self.in_planes = 64
        self.z_dim = z_dim
        self.conv1 = nn.Conv2d(nc, 64, kernel_size=3, stride=2, padding=1, bias=False)
        self.bn1 = nn.BatchNorm2d(64)
        self.maxpool1 = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
        self.layer1 = self._make_layer(BasicBlockEnc, 64, num_Blocks[0], stride=1)
        self.layer2 = self._make_layer(BasicBlockEnc, 128, num_Blocks[1], stride=2)
        self.layer3 = self._make_layer(BasicBlockEnc, 256, num_Blocks[2], stride=2)
        self.layer4 = self._make_layer(BasicBlockEnc, 512, num_Blocks[3], stride=2)

    def _make_layer(self, BasicBlockEnc, planes, num_Blocks, stride):
        strides = [stride] + [1]*(num_Blocks-1)
        layers = []
        for stride in strides:
            layers += [BasicBlockEnc(self.in_planes, stride)]
            self.in_planes = planes
        return nn.Sequential(*layers)

    def forward(self, x):
        x = torch.relu(self.bn1(self.conv1(x)))
        x = self.maxpool1(x)
        x = self.layer1(x)
        x = self.layer2(x)
        x = self.layer3(x)
        x = self.layer4(x)
#         x = F.adaptive_avg_pool2d(x, 1)
#         x = x.view(x.size(0), -1)
#         x = self.linear(x)
#         mu = x[:, :self.z_dim]
#         logvar = x[:, self.z_dim:]
#         return mu, logvar
        return x

class ResNet18Dec(nn.Module):

    def __init__(self, num_Blocks=[2,2,2,2], z_dim=10, nc=3):
        super().__init__()
        self.in_planes = 512


        self.layer4 = self._make_layer(BasicBlockDec, 512, num_Blocks[3], stride=2)
        self.layer3 = self._make_layer(BasicBlockDec, 256, num_Blocks[2], stride=2)
        self.layer2 = self._make_layer(BasicBlockDec, 128, num_Blocks[1], stride=2)
        self.layer1 = self._make_layer(BasicBlockDec, 64, num_Blocks[0], stride=1)
        self.conv1 = ResizeConv2d(64, nc, kernel_size=3, scale_factor=2)
        self.convtrans1 = nn.ConvTranspose2d(64, 3, kernel_size=3, stride=2, padding=1)

    def _make_layer(self, BasicBlockDec, planes, num_Blocks, stride):
        strides = [stride] + [1]*(num_Blocks-1)
        layers = []
        for stride in reversed(strides):
            layers += [BasicBlockDec(self.in_planes, stride)]
        self.in_planes = planes
        return nn.Sequential(*layers)

    def forward(self, z):
#         x = self.linear(z)
#         x = x.view(z.size(0), 512, 1, 1)
#         x = F.interpolate(z, scale_factor=4)
        x = self.layer4(z)
        x = self.layer3(x)
        x = self.layer2(x)
        x = self.layer1(x)
        x = torch.sigmoid(self.conv1(x))
#         x = self.convtrans1(x)
        x = x.view(x.size(0), 3, 256, 256)
        return x

class VAE(nn.Module):

    def __init__(self, z_dim):
        super().__init__()
        self.encoder = ResNet18Enc(z_dim=z_dim)
        self.decoder = ResNet18Dec(z_dim=z_dim)

    def forward(self, x):
#         mean, logvar = self.encoder(x)
        z = self.encoder(x)
#         z = self.reparameterize(mean, logvar)
        x = self.decoder(z)
        return x, z
    
#     @staticmethod
#     def reparameterize(mean, logvar):
#         std = torch.exp(logvar / 2) # in log-space, squareroot is divide by two
#         epsilon = torch.randn_like(std)
#         return epsilon * std + mean

In [88]:
vae = VAE(z_dim=10).cuda()

In [89]:
summary(vae, (1, 3, 256, 256), depth=6)

RuntimeError: Failed to run torchinfo. See above stack traces for more details. Executed layers up to: [ResNet18Enc: 1, Conv2d: 2, BatchNorm2d: 2, MaxPool2d: 2, Sequential: 2, BasicBlockEnc: 3, Conv2d: 4, BatchNorm2d: 4, Conv2d: 4, BatchNorm2d: 4, Sequential: 4, BasicBlockEnc: 3, Conv2d: 4, BatchNorm2d: 4, Conv2d: 4, BatchNorm2d: 4, Sequential: 4, Sequential: 2, BasicBlockEnc: 3, Conv2d: 4, BatchNorm2d: 4, Conv2d: 4, BatchNorm2d: 4, Sequential: 4, Conv2d: 5, BatchNorm2d: 5, BasicBlockEnc: 3, Conv2d: 4, BatchNorm2d: 4, Conv2d: 4, BatchNorm2d: 4, Sequential: 4, Sequential: 2, BasicBlockEnc: 3, Conv2d: 4, BatchNorm2d: 4, Conv2d: 4, BatchNorm2d: 4, Sequential: 4, Conv2d: 5, BatchNorm2d: 5, BasicBlockEnc: 3, Conv2d: 4, BatchNorm2d: 4, Conv2d: 4, BatchNorm2d: 4, Sequential: 4, Sequential: 2, BasicBlockEnc: 3, Conv2d: 4, BatchNorm2d: 4, Conv2d: 4, BatchNorm2d: 4, Sequential: 4, Conv2d: 5, BatchNorm2d: 5, BasicBlockEnc: 3, Conv2d: 4, BatchNorm2d: 4, Conv2d: 4, BatchNorm2d: 4, Sequential: 4, Sequential: 2, BasicBlockDec: 3, Conv2d: 4, BatchNorm2d: 4, Conv2d: 4, BatchNorm2d: 4, Sequential: 4, BasicBlockDec: 3, Conv2d: 4, BatchNorm2d: 4, ResizeConv2d: 4, Conv2d: 5, BatchNorm2d: 4, Sequential: 4, ResizeConv2d: 5, Conv2d: 6, BatchNorm2d: 5]

In [98]:
INPUT_SHAPE = 256

In [147]:
class ResizeConv2d(nn.Module):

    def __init__(self, in_channels, out_channels, kernel_size, scale_factor, mode='nearest'):
        super().__init__()
        self.scale_factor = scale_factor
        self.mode = mode
        self.conv = nn.Conv2d(in_channels, out_channels, kernel_size, stride=1, padding=1)

    def forward(self, x):
        x = F.interpolate(x, scale_factor=self.scale_factor, mode=self.mode)
        x = self.conv(x)
        return x

class BasicBlockEnc(nn.Module):

    def __init__(self, in_planes, stride=1):
        super().__init__()

        planes = in_planes*stride

        self.conv1 = nn.Conv2d(in_planes, planes, kernel_size=3, stride=stride, padding=1, bias=False)
        self.bn1 = nn.BatchNorm2d(planes)
        self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=1, padding=1, bias=False)
        self.bn2 = nn.BatchNorm2d(planes)

        if stride == 1:
            self.shortcut = nn.Sequential()
        else:
            self.shortcut = nn.Sequential(
                nn.Conv2d(in_planes, planes, kernel_size=1, stride=stride, bias=False),
                nn.BatchNorm2d(planes)
            )

    def forward(self, x):
        out = torch.relu(self.bn1(self.conv1(x)))
        out = self.bn2(self.conv2(out))
        out += self.shortcut(x)
        out = torch.relu(out)
        return out

class BasicBlockDec(nn.Module):

    def __init__(self, in_planes, stride=1):
        super().__init__()
        
        self.convtrans1 = nn.ConvTranspose2d(512, 512, kernel_size=4, stride=2, padding=4)

        planes = int(in_planes/stride)

        self.conv2 = nn.Conv2d(in_planes, in_planes, kernel_size=3, stride=1, padding=1, bias=False)
        self.bn2 = nn.BatchNorm2d(in_planes)
        # self.bn1 could have been placed here, but that messes up the order of the layers when printing the class
 
        if stride == 1:
            self.conv1 = nn.Conv2d(in_planes, planes, kernel_size=3, stride=1, padding=1, bias=False)
            self.bn1 = nn.BatchNorm2d(planes)
            self.shortcut = nn.Sequential()
        else:
            self.conv1 = ResizeConv2d(in_planes, planes, kernel_size=3, scale_factor=stride)
            self.bn1 = nn.BatchNorm2d(planes)
            self.shortcut = nn.Sequential(
                ResizeConv2d(in_planes, planes, kernel_size=3, scale_factor=stride),
                nn.BatchNorm2d(planes)
            )

    def forward(self, x):
        out = torch.relu(self.bn2(self.conv2(x)))
        out = self.bn1(self.conv1(out))
        out += self.shortcut(x)
        out = torch.relu(out)
        return out

class ResNet18Enc(nn.Module):

    def __init__(self, num_Blocks=[2,2,2,2], z_dim=10, nc=3):
        super().__init__()
        self.in_planes = 64
        self.z_dim = z_dim
        self.conv1 = nn.Conv2d(nc, 64, kernel_size=3, stride=2, padding=1, bias=False)
        self.bn1 = nn.BatchNorm2d(64)
        self.maxpool1 = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
        self.layer1 = self._make_layer(BasicBlockEnc, 64, num_Blocks[0], stride=1)
        self.layer2 = self._make_layer(BasicBlockEnc, 128, num_Blocks[1], stride=2)
        self.layer3 = self._make_layer(BasicBlockEnc, 256, num_Blocks[2], stride=2)
        self.layer4 = self._make_layer(BasicBlockEnc, 512, num_Blocks[3], stride=2)
        self.linear = nn.Linear(512, 2 * z_dim)

    def _make_layer(self, BasicBlockEnc, planes, num_Blocks, stride):
        strides = [stride] + [1]*(num_Blocks-1)
        layers = []
        for stride in strides:
            layers += [BasicBlockEnc(self.in_planes, stride)]
            self.in_planes = planes
        return nn.Sequential(*layers)

    def forward(self, x):
        x = torch.relu(self.bn1(self.conv1(x)))
        x = self.maxpool1(x)
        x = self.layer1(x)
        x = self.layer2(x)
        x = self.layer3(x)
        x = self.layer4(x)
#         x = F.adaptive_avg_pool2d(x, 1)
#         x = x.view(x.size(0), -1)
#         x = self.linear(x)
#         mu = x[:, :self.z_dim]
#         logvar = x[:, self.z_dim:]
#         return mu, logvar
        return x

class ResNet18Dec(nn.Module):

    def __init__(self, num_Blocks=[2,2,2,2], z_dim=10, nc=3):
        super().__init__()
        self.in_planes = 256

        self.linear = nn.Linear(z_dim, 512)

        self.layer4 = self._make_layer(BasicBlockDec, 256, num_Blocks[3], stride=2)
        self.layer3 = self._make_layer(BasicBlockDec, 128, num_Blocks[2], stride=2)
        self.layer2 = self._make_layer(BasicBlockDec, 64, num_Blocks[1], stride=2)
        self.layer1 = self._make_layer(BasicBlockDec, 32, num_Blocks[0], stride=1)
        self.conv1 = ResizeConv2d(64, nc, kernel_size=3, scale_factor=2)

    def _make_layer(self, BasicBlockDec, planes, num_Blocks, stride):
        strides = [stride] + [1]*(num_Blocks-1)
        layers = []
        for stride in reversed(strides):
            layers += [BasicBlockDec(self.in_planes, stride)]
        self.in_planes = planes
        return nn.Sequential(*layers)

    def forward(self, z):
#         x = self.linear(z)
#         x = x.view(z.size(0), 512, 1, 1)
#         x = F.interpolate(z, scale_factor=4)
        x = self.layer4(z)
        x = self.layer3(x)
        x = self.layer2(x)
        x = self.layer1(x)
        x = torch.sigmoid(self.conv1(x))
        x = x.view(x.size(0), 3, INPUT_SHAPE, INPUT_SHAPE)
        return x

class VAE(nn.Module):

    def __init__(self, z_dim):
        super().__init__()
        self.encoder = ResNet18Enc(z_dim=z_dim)
        self.decoder = ResNet18Dec(z_dim=z_dim)

    def forward(self, x):
#         mean, logvar = self.encoder(x)
        z = self.encoder(x)
#         z = self.reparameterize(mean, logvar)
        x = self.decoder(z)
        return x, z
    


In [148]:
vae = VAE(z_dim=10).cuda()
summary(vae, (1, 3, INPUT_SHAPE, INPUT_SHAPE), depth=6)

RuntimeError: Failed to run torchinfo. See above stack traces for more details. Executed layers up to: [ResNet18Enc: 1, Conv2d: 2, BatchNorm2d: 2, MaxPool2d: 2, Sequential: 2, BasicBlockEnc: 3, Conv2d: 4, BatchNorm2d: 4, Conv2d: 4, BatchNorm2d: 4, Sequential: 4, BasicBlockEnc: 3, Conv2d: 4, BatchNorm2d: 4, Conv2d: 4, BatchNorm2d: 4, Sequential: 4, Sequential: 2, BasicBlockEnc: 3, Conv2d: 4, BatchNorm2d: 4, Conv2d: 4, BatchNorm2d: 4, Sequential: 4, Conv2d: 5, BatchNorm2d: 5, BasicBlockEnc: 3, Conv2d: 4, BatchNorm2d: 4, Conv2d: 4, BatchNorm2d: 4, Sequential: 4, Sequential: 2, BasicBlockEnc: 3, Conv2d: 4, BatchNorm2d: 4, Conv2d: 4, BatchNorm2d: 4, Sequential: 4, Conv2d: 5, BatchNorm2d: 5, BasicBlockEnc: 3, Conv2d: 4, BatchNorm2d: 4, Conv2d: 4, BatchNorm2d: 4, Sequential: 4, Sequential: 2, BasicBlockEnc: 3, Conv2d: 4, BatchNorm2d: 4, Conv2d: 4, BatchNorm2d: 4, Sequential: 4, Conv2d: 5, BatchNorm2d: 5, BasicBlockEnc: 3, Conv2d: 4, BatchNorm2d: 4, Conv2d: 4, BatchNorm2d: 4, Sequential: 4]

In [138]:
enc = ResNet18Enc().cuda()
summary(enc, (1, 3, 256, 256), depth=6)

Layer (type:depth-idx)                   Output Shape              Param #
ResNet18Enc                              --                        --
├─Conv2d: 1-1                            [1, 64, 128, 128]         1,728
├─BatchNorm2d: 1-2                       [1, 64, 128, 128]         128
├─MaxPool2d: 1-3                         [1, 64, 64, 64]           --
├─Sequential: 1-4                        [1, 64, 64, 64]           --
│    └─BasicBlockEnc: 2-1                [1, 64, 64, 64]           --
│    │    └─Conv2d: 3-1                  [1, 64, 64, 64]           36,864
│    │    └─BatchNorm2d: 3-2             [1, 64, 64, 64]           128
│    │    └─Conv2d: 3-3                  [1, 64, 64, 64]           36,864
│    │    └─BatchNorm2d: 3-4             [1, 64, 64, 64]           128
│    │    └─Sequential: 3-5              [1, 64, 64, 64]           --
│    └─BasicBlockEnc: 2-2                [1, 64, 64, 64]           --
│    │    └─Conv2d: 3-6                  [1, 64, 64, 64]           36,8

In [176]:
class BasicBlockDec(nn.Module):

    def __init__(self, stride=1):
        super().__init__()
        
        self.convtrans1 = nn.ConvTranspose2d(512, 512, kernel_size=4, stride=2, padding=1)
        self.bn2 = nn.BatchNorm2d(512)
        self.convtrans2 = nn.ConvTranspose2d(512, 512, kernel_size=3, stride=1, padding=1)



    def forward(self, x):
        out1 = self.convtrans1(x)
        out2 = torch.relu(self.bn2(out1))
        out2 = self.convtrans2(out2)
        out2 = torch.relu(self.bn2(out2))
        final = torch.add(out1, out2)
        
        return final

In [180]:
basic = BasicBlockDec().cuda()
print(type(basic))
summary(basic, (1, 512, 8, 8), depth=6)

<class '__main__.BasicBlockDec'>


Layer (type:depth-idx)                   Output Shape              Param #
BasicBlockDec                            --                        --
├─ConvTranspose2d: 1-1                   [1, 512, 16, 16]          4,194,816
├─BatchNorm2d: 1-2                       [1, 512, 16, 16]          1,024
├─ConvTranspose2d: 1-3                   [1, 512, 16, 16]          2,359,808
├─BatchNorm2d: 1-4                       [1, 512, 16, 16]          (recursive)
Total params: 6,555,648
Trainable params: 6,555,648
Non-trainable params: 0
Total mult-adds (G): 1.68
Input size (MB): 0.13
Forward/backward pass size (MB): 3.15
Params size (MB): 26.22
Estimated Total Size (MB): 29.50

In [208]:
class ResizeConv2d(nn.Module):

    def __init__(self, in_channels, out_channels, kernel_size, scale_factor, mode='nearest'):
        super().__init__()
        self.scale_factor = scale_factor
        self.mode = mode
        self.conv = nn.Conv2d(in_channels, out_channels, kernel_size, stride=1, padding=1)

    def forward(self, x):
        x = F.interpolate(x, scale_factor=self.scale_factor, mode=self.mode)
        x = self.conv(x)
        return x

class BasicBlockEnc(nn.Module):

    def __init__(self, in_planes, stride=1):
        super().__init__()

        planes = in_planes*stride

        self.conv1 = nn.Conv2d(in_planes, planes, kernel_size=3, stride=stride, padding=1, bias=False)
        self.bn1 = nn.BatchNorm2d(planes)
        self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=1, padding=1, bias=False)
        self.bn2 = nn.BatchNorm2d(planes)

        if stride == 1:
            self.shortcut = nn.Sequential()
        else:
            self.shortcut = nn.Sequential(
                nn.Conv2d(in_planes, planes, kernel_size=1, stride=stride, bias=False),
                nn.BatchNorm2d(planes)
            )

    def forward(self, x):
        out = torch.relu(self.bn1(self.conv1(x)))
        out = self.bn2(self.conv2(out))
        out += self.shortcut(x)
        out = torch.relu(out)
        return out

class BasicBlockDec(nn.Module):

    def __init__(self, shape):
        super().__init__()
        if shape == 512:
            shape2 = 512
        else:
            shape2 = int(shape * 2)
        
        self.convtrans1 = nn.ConvTranspose2d(shape2, shape, kernel_size=4, stride=2, padding=1)
        self.bn2 = nn.BatchNorm2d(shape)
        self.convtrans2 = nn.ConvTranspose2d(shape, shape, kernel_size=3, stride=1, padding=1)

    def forward(self, x):
        out1 = self.convtrans1(x)
        out2 = torch.relu(self.bn2(out1))
        out2 = self.convtrans2(out2)
        out2 = torch.relu(self.bn2(out2))
        final = torch.add(out1, out2)
        
        return final

class ResNet18Enc(nn.Module):

    def __init__(self, num_Blocks=[2,2,2,2], z_dim=10, nc=3):
        super().__init__()
        self.in_planes = 64
        self.z_dim = z_dim
        self.conv1 = nn.Conv2d(nc, 64, kernel_size=3, stride=2, padding=1, bias=False)
        self.bn1 = nn.BatchNorm2d(64)
        self.maxpool1 = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
        self.layer1 = self._make_layer(BasicBlockEnc, 64, num_Blocks[0], stride=1)
        self.layer2 = self._make_layer(BasicBlockEnc, 128, num_Blocks[1], stride=2)
        self.layer3 = self._make_layer(BasicBlockEnc, 256, num_Blocks[2], stride=2)
        self.layer4 = self._make_layer(BasicBlockEnc, 512, num_Blocks[3], stride=2)
        self.linear = nn.Linear(512, 2 * z_dim)

    def _make_layer(self, BasicBlockEnc, planes, num_Blocks, stride):
        strides = [stride] + [1]*(num_Blocks-1)
        layers = []
        for stride in strides:
            layers += [BasicBlockEnc(self.in_planes, stride)]
            self.in_planes = planes
        return nn.Sequential(*layers)

    def forward(self, x):
        x = torch.relu(self.bn1(self.conv1(x)))
        x = self.maxpool1(x)
        x = self.layer1(x)
        x = self.layer2(x)
        x = self.layer3(x)
        x = self.layer4(x)
#         x = F.adaptive_avg_pool2d(x, 1)
#         x = x.view(x.size(0), -1)
#         x = self.linear(x)
#         mu = x[:, :self.z_dim]
#         logvar = x[:, self.z_dim:]
#         return mu, logvar
        return x

class ResNet18Dec(nn.Module):

    def __init__(self, num_Blocks=[2,2,2,2], z_dim=10, nc=3):
        super().__init__()
        self.layer1 = BasicBlockDec(512)
        self.layer2 = BasicBlockDec(256)
        self.layer3 = BasicBlockDec(128)
        self.layer4 = BasicBlockDec(64)

        self.conv1 = ResizeConv2d(64, nc, kernel_size=3, scale_factor=2)

    def _make_layer(self, BasicBlockDec, shape):
        return 
        strides = [stride] + [1]*(num_Blocks-1)
        layers = []
        for stride in reversed(strides):
            layers += [BasicBlockDec(self.in_planes, stride)]
        self.in_planes = planes
        return nn.Sequential(*layers)

    def forward(self, x):
#         x = self.linear(z)
#         x = x.view(z.size(0), 512, 1, 1)
#         x = F.interpolate(z, scale_factor=4)
        x = self.layer1(x)
        x = self.layer2(x)
        x = self.layer3(x)
        x = self.layer4(x)

#         x = torch.sigmoid(self.conv1(x))
#         x = x.view(x.size(0), 3, INPUT_SHAPE, INPUT_SHAPE)
        return x

class VAE(nn.Module):

    def __init__(self, z_dim):
        super().__init__()
        self.encoder = ResNet18Enc(z_dim=z_dim)
        self.decoder = ResNet18Dec(z_dim=z_dim)

    def forward(self, x):
#         mean, logvar = self.encoder(x)
        x = self.encoder(x)
#         z = self.reparameterize(mean, logvar)
        x = self.decoder(x)
        return x
    


In [209]:
dec = ResNet18Dec().cuda()
summary(dec, (1, 512, 8, 8), depth=6)

Layer (type:depth-idx)                   Output Shape              Param #
ResNet18Dec                              --                        --
├─BasicBlockDec: 1-1                     [1, 512, 16, 16]          --
│    └─ConvTranspose2d: 2-1              [1, 512, 16, 16]          4,194,816
│    └─BatchNorm2d: 2-2                  [1, 512, 16, 16]          1,024
│    └─ConvTranspose2d: 2-3              [1, 512, 16, 16]          2,359,808
│    └─BatchNorm2d: 2-4                  [1, 512, 16, 16]          (recursive)
├─BasicBlockDec: 1-2                     [1, 256, 32, 32]          --
│    └─ConvTranspose2d: 2-5              [1, 256, 32, 32]          2,097,408
│    └─BatchNorm2d: 2-6                  [1, 256, 32, 32]          512
│    └─ConvTranspose2d: 2-7              [1, 256, 32, 32]          590,080
│    └─BatchNorm2d: 2-8                  [1, 256, 32, 32]          (recursive)
├─BasicBlockDec: 1-3                     [1, 128, 64, 64]          --
│    └─ConvTranspose2d: 2-9          

In [210]:
vae = VAE(z_dim=10).cuda()
summary(vae, (1, 3, INPUT_SHAPE, INPUT_SHAPE), depth=6)

Layer (type:depth-idx)                        Output Shape              Param #
VAE                                           --                        --
├─ResNet18Enc: 1-1                            [1, 512, 8, 8]            --
│    └─Conv2d: 2-1                            [1, 64, 128, 128]         1,728
│    └─BatchNorm2d: 2-2                       [1, 64, 128, 128]         128
│    └─MaxPool2d: 2-3                         [1, 64, 64, 64]           --
│    └─Sequential: 2-4                        [1, 64, 64, 64]           --
│    │    └─BasicBlockEnc: 3-1                [1, 64, 64, 64]           --
│    │    │    └─Conv2d: 4-1                  [1, 64, 64, 64]           36,864
│    │    │    └─BatchNorm2d: 4-2             [1, 64, 64, 64]           128
│    │    │    └─Conv2d: 4-3                  [1, 64, 64, 64]           36,864
│    │    │    └─BatchNorm2d: 4-4             [1, 64, 64, 64]           128
│    │    │    └─Sequential: 4-5              [1, 64, 64, 64]           --
│    │