In [3]:
import random
import torch
import torch.nn as nn
import torch.nn.parallel
import torch.optim as optim
import torchvision.transforms as transforms
import torch.nn.functional as F
from torch.autograd import Variable

import numpy as np
import matplotlib.pyplot as plt

random.seed(42)

### Designing Encoder (E)

In [20]:
class resBlock(nn.Module):
    def __init__(self, in_channels=64, out_channels=64, k=3, s=1, p=1):
        super(resBlock, self).__init__()

        self.conv1 = nn.Conv2d(in_channels, out_channels, k, stride=s, padding=p)
        self.bn1 = nn.BatchNorm2d(out_channels)
        self.conv2 = nn.Conv2d(out_channels, out_channels, k, stride=s, padding=p)
        self.bn2 = nn.BatchNorm2d(out_channels)

    def forward(self, x):
        y = F.relu(self.bn1(self.conv1(x)))
        return self.bn2(self.conv2(y)) + x
    
class resTransposeBlock(nn.Module):
    def __init__(self, in_channels=64, out_channels=64, k=3, s=1, p=1):
        super(resTransposeBlock, self).__init__()

        self.conv1 = nn.ConvTranspose2d(in_channels, out_channels, k, stride=s, padding=p)
        self.bn1 = nn.BatchNorm2d(out_channels)
        self.conv2 = nn.ConvTranspose2d(out_channels, out_channels, k, stride=s, padding=p)
        self.bn2 = nn.BatchNorm2d(out_channels)

    def forward(self, x):
        y = F.relu(self.bn1(self.conv1(x)))
        return self.bn2(self.conv2(y)) + x

In [21]:
A = resBlock(in_channels=3, out_channels=3, k=5, s=1, p=2)
At = resTransposeBlock(in_channels=3, out_channels=3, k=5, s=1, p=2)

In [24]:
i = torch.rand(1, 3, 32, 32)
print(i.size())
a = A(Variable(i))
print(a.size())
a = At(a)
print(a.size())

torch.Size([1, 3, 32, 32])
torch.Size([1, 3, 32, 32])
torch.Size([1, 3, 32, 32])


In [33]:
class E(nn.Module):
    def __init__(self, k=3, s=1, p=1, n_res_blocks=16):
        super(E, self).__init__()
        self.n_res_blocks = n_res_blocks
        self.conv1 = nn.Conv2d(3, 64, k, stride=s, padding=p)
        for i in range(n_res_blocks):
            self.add_module('residual_block_1' + str(i+1), resBlock(in_channels=64, out_channels=64, k=k, s=s, p=p))
        self.conv2 = nn.Conv2d(64, 32, k, stride=s, padding=p)
        for i in range(n_res_blocks):
            self.add_module('residual_block_2' + str(i+1), resBlock(in_channels=32, out_channels=32, k=k, s=s, p=p))
        self.conv3 = nn.Conv2d(32, 8, k, stride=s, padding=p)
        for i in range(n_res_blocks):
            self.add_module('residual_block_3' + str(i+1), resBlock(in_channels=8, out_channels=8, k=k, s=s, p=p))
        self.conv4 = nn.Conv2d(8, 1, k, stride=s, padding=p)
    
    def forward(self, x):
        y = self.conv1(x)
        for i in range(self.n_res_blocks):
            y = self.__getattr__('residual_block_1'+str(i+1))(y)
        y = self.conv2(y)
        for i in range(self.n_res_blocks):
            y = self.__getattr__('residual_block_2'+str(i+1))(y)
        y = self.conv3(y)
        for i in range(self.n_res_blocks):
            y = self.__getattr__('residual_block_3'+str(i+1))(y)
        y = self.conv4(y)
        return y

In [34]:
E1 = E()

In [35]:
E1(Variable(i))

Variable containing:
(0 ,0 ,.,.) = 
  1.7116 -0.1909  0.2748  ...   2.3927 -0.5373  0.0954
  0.6214 -2.0303 -1.8899  ...  -0.2308 -1.1154 -0.0165
  0.2596  0.2444 -1.9987  ...  -1.8513 -0.9425  0.3133
           ...             ⋱             ...          
  0.6772  1.4138 -0.5727  ...   1.5234  0.8912  0.0269
  0.6772 -0.1078 -2.5094  ...   0.0764 -0.3250 -1.0620
  1.6443  1.4955  0.7867  ...   0.0401 -0.4311 -0.2909
[torch.FloatTensor of size 1x1x32x32]

### Designing Decoder

In [41]:
class D(nn.Module):
    def __init__(self, k=3, s=1, p=1, n_res_blocks=16):
        super(D, self).__init__()
        self.n_res_blocks = n_res_blocks
        self.conv1 = nn.Conv2d(1, 8, k, stride=s, padding=p)
        for i in range(n_res_blocks):
            self.add_module('residual_block_1' + str(i+1), resBlock(in_channels=8, out_channels=8, k=k, s=s, p=p))
        self.conv2 = nn.Conv2d(8, 32, k, stride=s, padding=p)
        for i in range(n_res_blocks):
            self.add_module('residual_block_2' + str(i+1), resBlock(in_channels=32, out_channels=32, k=k, s=s, p=p))
        self.conv3 = nn.Conv2d(32, 64, k, stride=s, padding=p)
        for i in range(n_res_blocks):
            self.add_module('residual_block_3' + str(i+1), resBlock(in_channels=64, out_channels=64, k=k, s=s, p=p))
        self.conv4 = nn.Conv2d(64, 3, k, stride=s, padding=p)
    
    def forward(self, x):
        y = self.conv1(x)
        for i in range(self.n_res_blocks):
            y = self.__getattr__('residual_block_1'+str(i+1))(y)
        y = self.conv2(y)
        for i in range(self.n_res_blocks):
            y = self.__getattr__('residual_block_2'+str(i+1))(y)
        y = self.conv3(y)
        for i in range(self.n_res_blocks):
            y = self.__getattr__('residual_block_3'+str(i+1))(y)
        y = self.conv4(y)
        return y

In [42]:
D1 = D()

In [43]:
D1(E1(Variable(i)))

Variable containing:
(0 ,0 ,.,.) = 
  0.4352  1.3713  2.1673  ...  -0.2854 -0.5704  0.5093
 -0.1276  2.6184 -1.0822  ...   3.1426  2.7637 -0.5459
 -1.4393 -1.0901 -1.0960  ...   0.1455 -0.9688 -0.4661
           ...             ⋱             ...          
  0.2317  1.0995  0.6195  ...   2.5308 -0.3304 -0.5466
  0.2395  2.0903 -2.6870  ...  -0.4819 -0.3337 -1.3960
  0.4094  1.9413 -0.6637  ...   0.5250  1.0122  0.2618

(0 ,1 ,.,.) = 
 -1.2391  0.4855 -1.4408  ...   1.4412 -0.4702  0.5231
 -0.8185  0.3060  1.1245  ...  -0.9661  0.2471  1.5891
 -0.3648 -1.8255  0.1077  ...   0.3283  1.2629 -0.6357
           ...             ⋱             ...          
 -1.4454 -0.8811  0.5056  ...  -0.5971 -0.5191  0.0851
  0.9704 -1.2024 -0.1784  ...  -0.7417 -1.3928 -0.5833
 -0.7808 -0.0429 -1.0189  ...  -0.1904 -0.5911 -0.1119

(0 ,2 ,.,.) = 
 -0.4048 -0.5498  0.8187  ...   1.3329  0.0170  0.1075
  2.0861  1.7548  1.5858  ...   1.4168  1.8486  1.5469
  0.9424 -0.4886  4.6423  ...  -0.3712  1.4002  0.35