In [3]:
import random
import torch
import torch.nn as nn
import torch.nn.parallel
import torch.optim as optim
import torchvision.transforms as transforms
import torch.nn.functional as F
from torch.autograd import Variable

import numpy as np
import matplotlib.pyplot as plt

# Seed every RNG the notebook draws from: Python's `random`, NumPy, and torch.
# torch is the one that matters here: `torch.rand` test inputs and the default
# weight initialisation of every nn.Conv2d / nn.ConvTranspose2d layer use it.
SEED = 42
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)

### Designing the Encoder (E)

In [20]:
class resBlock(nn.Module):
    """Residual block: conv -> BatchNorm -> ReLU -> conv -> BatchNorm, plus a skip path.

    No activation is applied after the addition. The skip path is the identity
    when the block preserves shape (in_channels == out_channels and s == 1).
    Otherwise a 1x1 convolution + BatchNorm projects the input, so the sum is
    well defined. Previously it raised a shape error, or silently broadcast when
    in_channels == 1.

    Args:
        in_channels: channels of the input feature map.
        out_channels: channels of the output feature map.
        k: kernel size of both 3x3-style convolutions.
        s: stride, applied once (first convolution and the projection).
        p: padding of both convolutions. The residual sum needs "same" padding,
            i.e. 2 * p == k - 1 (k=3/p=1 and k=5/p=2 both satisfy this).
    """
    def __init__(self, in_channels=64, out_channels=64, k=3, s=1, p=1):
        super(resBlock, self).__init__()

        self.conv1 = nn.Conv2d(in_channels, out_channels, k, stride=s, padding=p)
        self.bn1 = nn.BatchNorm2d(out_channels)
        # Stride only once: striding conv2 as well shrank the output twice, so it
        # could never be added back to the input.
        self.conv2 = nn.Conv2d(out_channels, out_channels, k, stride=1, padding=p)
        self.bn2 = nn.BatchNorm2d(out_channels)

        if in_channels != out_channels or s != 1:
            # Project the input to the output's channel count / spatial size.
            self.shortcut = nn.Sequential(
                nn.Conv2d(in_channels, out_channels, 1, stride=s, bias=False),
                nn.BatchNorm2d(out_channels),
            )
        else:
            # Identity skip: no extra parameters, so state_dict keys are unchanged.
            self.shortcut = None

    def forward(self, x):
        skip = x if self.shortcut is None else self.shortcut(x)
        y = F.relu(self.bn1(self.conv1(x)))
        return self.bn2(self.conv2(y)) + skip
    
class resTransposeBlock(nn.Module):
    """Residual block built from transposed convolutions.

    Path: convT -> BatchNorm -> ReLU -> convT -> BatchNorm, plus a skip path,
    with no activation after the addition. The skip path is the identity when
    in_channels == out_channels and s == 1. Otherwise a 1x1 transposed
    convolution + BatchNorm projects the input, so the sum is well defined.

    Args:
        in_channels: channels of the input feature map.
        out_channels: channels of the output feature map.
        k: kernel size of both transposed convolutions.
        s: stride, applied once (first transposed conv and the projection). With
            2 * p == k - 1 a spatial size H becomes (H - 1) * s + 1, e.g. 8 -> 15 for s=2.
        p: padding of both transposed convolutions. The residual sum needs
            2 * p == k - 1 (k=3/p=1 and k=5/p=2 both satisfy this).
    """
    def __init__(self, in_channels=64, out_channels=64, k=3, s=1, p=1):
        super(resTransposeBlock, self).__init__()

        self.conv1 = nn.ConvTranspose2d(in_channels, out_channels, k, stride=s, padding=p)
        self.bn1 = nn.BatchNorm2d(out_channels)
        # Stride only once, so the main path and the skip path end at the same size.
        self.conv2 = nn.ConvTranspose2d(out_channels, out_channels, k, stride=1, padding=p)
        self.bn2 = nn.BatchNorm2d(out_channels)

        if in_channels != out_channels or s != 1:
            # Project the input to the output's channel count / spatial size.
            self.shortcut = nn.Sequential(
                nn.ConvTranspose2d(in_channels, out_channels, 1, stride=s, bias=False),
                nn.BatchNorm2d(out_channels),
            )
        else:
            # Identity skip: no extra parameters, so state_dict keys are unchanged.
            self.shortcut = None

    def forward(self, x):
        skip = x if self.shortcut is None else self.shortcut(x)
        y = F.relu(self.bn1(self.conv1(x)))
        return self.bn2(self.conv2(y)) + skip

In [21]:
# Shape-check instances: 3 -> 3 channels, 5x5 kernels, stride 1, padding 2.
A = resBlock(3, 3, k=5, s=1, p=2)
At = resTransposeBlock(3, 3, k=5, s=1, p=2)

In [24]:
i = torch.rand(1, 3, 32, 32)
print(i.size())
a = Variable(i)
# Push the test input through A and then At, reporting the shape after each block.
for layer in (A, At):
    a = layer(a)
    print(a.size())

torch.Size([1, 3, 32, 32])
torch.Size([1, 3, 32, 32])
torch.Size([1, 3, 32, 32])


In [44]:
class Encoder(nn.Module):
    """Convolutional encoder: four convolutions with residual stages in between.

    Layout (channels): conv1 3->64, n_res_blocks x resBlock(64), conv2 64->32,
    n_res_blocks x resBlock(32), conv3 32->8, n_res_blocks x resBlock(8), conv4 8->1.
    No activation sits between a stage and the next convolution other than what
    the residual blocks apply internally. With the defaults (k=3, s=1, p=1) the
    spatial size is preserved, e.g. a 1x3x32x32 input gives a 1x1x32x32 output.

    Args:
        k: kernel size of every convolution.
        s: stride of conv1..conv4. Residual blocks always use stride 1, because
            their skip connection adds the input to the output, which requires
            matching shapes.
        p: padding of every convolution (residual blocks need 2 * p == k - 1).
        n_res_blocks: residual blocks per stage (three stages).
    """

    def __init__(self, k=3, s=1, p=1, n_res_blocks=16):
        super(Encoder, self).__init__()
        self.n_res_blocks = n_res_blocks
        # Registration order matches the original layout (conv1, stage 1, conv2, ...),
        # so parameter order, seeded initialisation and state_dict keys are unchanged.
        self.conv1 = nn.Conv2d(3, 64, k, stride=s, padding=p)
        self._add_stage(1, 64, k, p)
        self.conv2 = nn.Conv2d(64, 32, k, stride=s, padding=p)
        self._add_stage(2, 32, k, p)
        self.conv3 = nn.Conv2d(32, 8, k, stride=s, padding=p)
        self._add_stage(3, 8, k, p)
        self.conv4 = nn.Conv2d(8, 1, k, stride=s, padding=p)

    @staticmethod
    def _block_name(stage, index):
        """Submodule name of the 0-based index-th residual block of a stage, e.g. 'residual_block_11'."""
        return 'residual_block_' + str(stage) + str(index + 1)

    def _add_stage(self, stage, channels, k, p):
        """Register n_res_blocks shape-preserving (stride 1) resBlocks for one stage."""
        for index in range(self.n_res_blocks):
            self.add_module(self._block_name(stage, index),
                            resBlock(in_channels=channels, out_channels=channels, k=k, s=1, p=p))

    def _run_stage(self, stage, y):
        """Apply the residual blocks of one stage in registration order."""
        for index in range(self.n_res_blocks):
            y = getattr(self, self._block_name(stage, index))(y)
        return y

    def forward(self, x):
        y = self._run_stage(1, self.conv1(x))
        y = self._run_stage(2, self.conv2(y))
        y = self._run_stage(3, self.conv3(y))
        return self.conv4(y)

In [45]:
# Encoder with its default configuration spelled out explicitly.
E1 = Encoder(k=3, s=1, p=1, n_res_blocks=16)

In [46]:
# Encode the random test image with the freshly initialised E1; the cell displays the result.
E1(Variable(i, requires_grad=False))

Variable containing:
(0 ,0 ,.,.) = 
  0.9031 -0.3610  1.2018  ...   0.7953  0.0952 -0.7730
  0.3095 -0.8629  0.6415  ...   0.8822  0.2406 -0.5124
  0.8194  0.3846  0.2768  ...   0.0295  1.2961  0.0140
           ...             ⋱             ...          
  0.8982 -1.0420 -0.1356  ...   0.1664  0.0194 -1.0014
  0.8625  0.2871 -1.4505  ...  -0.6737  1.0635 -0.8387
  0.9252 -0.0480 -0.1523  ...  -0.8271  0.7355  0.8571
[torch.FloatTensor of size 1x1x32x32]

### Designing the Decoder (D)

In [47]:
class Decoder(nn.Module):
    """Convolutional decoder mirroring Encoder's channel widths in reverse.

    Layout (channels): conv1 1->8, n_res_blocks x resBlock(8), conv2 8->32,
    n_res_blocks x resBlock(32), conv3 32->64, n_res_blocks x resBlock(64), conv4 64->3.
    With the defaults (k=3, s=1, p=1) the spatial size is preserved, e.g. the
    1x1x32x32 Encoder output decodes to a 1x3x32x32 tensor.

    Args:
        k: kernel size of every convolution.
        s: stride of conv1..conv4. Residual blocks always use stride 1, because
            their skip connection requires matching shapes.
            NOTE(review): conv1..conv4 are plain Conv2d layers, so s > 1 shrinks
            the output instead of upsampling it. Confirm this is intended.
        p: padding of every convolution (residual blocks need 2 * p == k - 1).
        n_res_blocks: residual blocks per stage (three stages).
    """

    def __init__(self, k=3, s=1, p=1, n_res_blocks=16):
        super(Decoder, self).__init__()
        self.n_res_blocks = n_res_blocks
        # Registration order matches the original layout (conv1, stage 1, conv2, ...),
        # so parameter order, seeded initialisation and state_dict keys are unchanged.
        self.conv1 = nn.Conv2d(1, 8, k, stride=s, padding=p)
        self._add_stage(1, 8, k, p)
        self.conv2 = nn.Conv2d(8, 32, k, stride=s, padding=p)
        self._add_stage(2, 32, k, p)
        self.conv3 = nn.Conv2d(32, 64, k, stride=s, padding=p)
        self._add_stage(3, 64, k, p)
        self.conv4 = nn.Conv2d(64, 3, k, stride=s, padding=p)

    @staticmethod
    def _block_name(stage, index):
        """Submodule name of the 0-based index-th residual block of a stage, e.g. 'residual_block_11'."""
        return 'residual_block_' + str(stage) + str(index + 1)

    def _add_stage(self, stage, channels, k, p):
        """Register n_res_blocks shape-preserving (stride 1) resBlocks for one stage."""
        for index in range(self.n_res_blocks):
            self.add_module(self._block_name(stage, index),
                            resBlock(in_channels=channels, out_channels=channels, k=k, s=1, p=p))

    def _run_stage(self, stage, y):
        """Apply the residual blocks of one stage in registration order."""
        for index in range(self.n_res_blocks):
            y = getattr(self, self._block_name(stage, index))(y)
        return y

    def forward(self, x):
        y = self._run_stage(1, self.conv1(x))
        y = self._run_stage(2, self.conv2(y))
        y = self._run_stage(3, self.conv3(y))
        return self.conv4(y)

In [48]:
# Decoder with its default configuration spelled out explicitly.
D1 = Decoder(k=3, s=1, p=1, n_res_blocks=16)

In [49]:
# Round trip: encode the test image with E1, decode it with D1, and display the result.
D1(E1(Variable(i, requires_grad=False)))

Variable containing:
(0 ,0 ,.,.) = 
  1.0827 -0.0593 -0.0030  ...   1.0478  0.1732  0.2410
  0.5288  0.6761 -0.2841  ...   1.1802  0.3283 -1.0316
  0.6399 -1.9798 -0.0073  ...  -0.7257  0.4547 -0.4186
           ...             ⋱             ...          
  0.1903 -0.7642 -2.1776  ...  -0.6689  2.0894 -1.6200
 -0.3739 -1.4837 -1.6083  ...   0.9097 -0.8422 -0.9819
  0.1924 -0.2091 -0.6234  ...  -0.0026 -0.5014 -0.2059

(0 ,1 ,.,.) = 
 -0.5666  0.0458  0.0823  ...  -0.1285 -0.1468  0.3786
 -0.0863  1.0000  0.1036  ...   0.5963  0.1359  0.7577
  0.7004  0.3091 -1.3917  ...  -1.7849  1.3269 -0.1839
           ...             ⋱             ...          
  0.0722  0.7759  1.3589  ...  -1.3497  0.8703  0.8669
  0.1839  0.4355 -0.0603  ...   1.7978  2.4161  1.7603
  0.1993 -0.4282  0.1741  ...   1.2430  0.0179  0.5846

(0 ,2 ,.,.) = 
 -0.0520 -0.3354 -0.3922  ...   0.2921 -0.1066  0.3495
  1.3898  0.7022  0.2357  ...  -0.5213  2.0845  0.8894
 -0.3942  0.3375  0.3673  ...  -0.8229 -1.2347  0.83