In [199]:
%matplotlib inline
import numpy as np
import matplotlib
import imageio

In [200]:
# import pytorch modules
import torch
from torch.autograd import Variable
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.nn import init
import functools
from torch.optim import lr_scheduler


In [211]:
class Decoder3d(nn.Module):
    """Decode a latent code into a single-channel 3D voxel grid in (0, 1).

    Four ConvTranspose3d stages (kernel 4, stride 2) grow a 1x1x1 latent
    volume to 4 -> 8 -> 16 -> 32 voxels per side, and a final Sigmoid maps
    every voxel into (0, 1).

    Args:
        z_size: channels of the latent code (default 200, as before).
        cube_len: base channel width; the hidden stages have cube_len*4,
            cube_len*2 and cube_len channels (default 32, as before). Despite
            its name it does not change the spatial output size.
        bias: whether the transposed convolutions learn a bias (default True).
    """

    def __init__(self, z_size=200, cube_len=32, bias=True):
        super(Decoder3d, self).__init__()

        # hyperparameters; kept as attributes so existing code that reads
        # self.z_size / self.cube_len / self.bias keeps working
        self.z_size = z_size
        self.cube_len = cube_len
        self.bias = bias

        self.layer1 = nn.Sequential(
            nn.ConvTranspose3d(self.z_size, self.cube_len * 4, kernel_size=4, stride=2,
                               bias=self.bias, padding=(0, 0, 0)),
            nn.BatchNorm3d(self.cube_len * 4),
            nn.ReLU()
        )
        self.layer2 = nn.Sequential(
            nn.ConvTranspose3d(self.cube_len * 4, self.cube_len * 2, kernel_size=4, stride=2,
                               bias=self.bias, padding=(1, 1, 1)),
            nn.BatchNorm3d(self.cube_len * 2),
            nn.ReLU()
        )
        self.layer3 = nn.Sequential(
            nn.ConvTranspose3d(self.cube_len * 2, self.cube_len, kernel_size=4, stride=2,
                               bias=self.bias, padding=(1, 1, 1)),
            nn.BatchNorm3d(self.cube_len),
            nn.ReLU()
        )
        # Output stage: a single channel squashed into (0, 1).
        self.layer4 = nn.Sequential(
            nn.ConvTranspose3d(self.cube_len, 1, kernel_size=4, stride=2,
                               bias=self.bias, padding=(1, 1, 1)),
            nn.Sigmoid()
        )

    def forward(self, x):
        """Map a latent code to a voxel grid.

        Args:
            x: latent tensor whose dim 1 has z_size channels, e.g.
               (N, z_size), (N, z_size, 1, 1) as produced by Encoder2d, or
               (N, z_size, 1, 1, 1) as produced by Encoder3d.

        Returns:
            (N, 1, D, H, W) tensor in (0, 1); (N, 1, 32, 32, 32) for a 1x1x1 code.
        """
        out = x
        # ConvTranspose3d needs (N, C, D, H, W): append singleton spatial dims
        # until the tensor is 5D. A single unconditional unsqueeze would turn an
        # already-5D Encoder3d code into an invalid 6D tensor.
        while out.dim() < 5:
            out = out.unsqueeze(-1)
        out = self.layer1(out)
        out = self.layer2(out)
        out = self.layer3(out)
        out = self.layer4(out)
        return out


In [212]:
class Encoder3d(nn.Module):
    """Encode a single-channel voxel grid into a latent code in (-1, 1).

    Three stride-2 Conv3d stages (kernel 4, padding 1) halve the grid
    (32 -> 16 -> 8 -> 4 voxels per side), and a final kernel-4, stride-2,
    unpadded Conv3d collapses 4x4x4 to 1x1x1 with z_size channels, followed by Tanh.

    Args:
        z_size: channels of the latent code (default 200, as before).
        cube_len: base channel width; the hidden stages have cube_len,
            cube_len*2 and cube_len*4 channels (default 32, as before).
        bias: whether the convolutions learn a bias (default True).
    """

    def __init__(self, z_size=200, cube_len=32, bias=True):
        super(Encoder3d, self).__init__()

        # hyperparameters; kept as attributes for existing readers
        self.z_size = z_size
        self.cube_len = cube_len
        self.bias = bias

        self.layer1 = nn.Sequential(
            nn.Conv3d(1, self.cube_len, kernel_size=4, stride=2,
                      bias=self.bias, padding=(1, 1, 1)),
            nn.BatchNorm3d(self.cube_len),
            nn.ReLU()
        )
        self.layer2 = nn.Sequential(
            nn.Conv3d(self.cube_len, self.cube_len * 2, kernel_size=4, stride=2,
                      bias=self.bias, padding=(1, 1, 1)),
            nn.BatchNorm3d(self.cube_len * 2),
            nn.ReLU()
        )
        self.layer3 = nn.Sequential(
            nn.Conv3d(self.cube_len * 2, self.cube_len * 4, kernel_size=4, stride=2,
                      bias=self.bias, padding=(1, 1, 1)),
            nn.BatchNorm3d(self.cube_len * 4),
            nn.ReLU()
        )
        # No padding here: a 4x4x4 volume collapses to 1x1x1.
        self.layer4 = nn.Sequential(
            nn.Conv3d(self.cube_len * 4, self.z_size, kernel_size=4, stride=2,
                      bias=self.bias, padding=(0, 0, 0)),
            nn.Tanh()
        )

    def forward(self, x):
        """Encode an (N, 1, D, H, W) grid; (N, 1, 32, 32, 32) gives (N, z_size, 1, 1, 1)."""
        out = self.layer1(x)
        out = self.layer2(out)
        out = self.layer3(out)
        out = self.layer4(out)
        return out

In [349]:
class Encoder2d(nn.Module):
    """Encode an image into a latent code in (-1, 1) with output_nc channels.

    Layout: a padded 5x5 stem conv, four stride-2 5x5 convs that double the
    channel count, two stride-2 3x3 convs at constant width, and a final
    2x2 conv to output_nc channels followed by Tanh. By the conv arithmetic, a
    128x128 input shrinks 128 -> 63 -> 31 -> 15 -> 7 -> 4 -> 2 -> 1, giving an
    (N, output_nc, 1, 1) code; other input sizes can give a larger final map.

    Args:
        input_nc: channels of the input image.
        output_nc: channels of the latent code.
        ngf: channel width of the stem conv.
        norm_layer: normalization layer class (or functools.partial of one).
            Convs get a bias only when it is InstanceNorm2d.
        use_dropout: currently unused; kept for signature compatibility.
        padding_type: stem padding, one of 'reflect' (default), 'replicate'
            or 'zero'.

    Raises:
        ValueError: if padding_type is not one of the supported values.
    """

    def __init__(self, input_nc=1, output_nc=200, ngf=8, norm_layer=nn.BatchNorm2d,
                 use_dropout=False, padding_type='reflect'):
        super(Encoder2d, self).__init__()
        self.input_nc = input_nc
        self.output_nc = output_nc
        self.ngf = ngf
        if type(norm_layer) == functools.partial:
            use_bias = norm_layer.func == nn.InstanceNorm2d
        else:
            use_bias = norm_layer == nn.InstanceNorm2d

        # padding_type used to be accepted but ignored (always reflection).
        pad_layers = {
            'reflect': nn.ReflectionPad2d,
            'replicate': nn.ReplicationPad2d,
            'zero': nn.ZeroPad2d,
        }
        if padding_type not in pad_layers:
            raise ValueError('padding [%s] is not implemented' % padding_type)

        # Stem: pad 2 + unpadded 5x5 conv keeps the spatial size.
        model = [pad_layers[padding_type](2),
                 nn.Conv2d(input_nc, ngf, kernel_size=5, padding=0, bias=use_bias),
                 norm_layer(ngf),
                 nn.ReLU(True)]

        n_downsampling = 4
        for i in range(n_downsampling):
            mult = 2 ** i
            model += [nn.Conv2d(ngf * mult, ngf * mult * 2, kernel_size=5,
                                stride=2, padding=1, bias=use_bias),
                      norm_layer(ngf * mult * 2),
                      nn.ReLU(True)]

        # Width after the downsampling stack, computed explicitly instead of
        # relying on the loop variable leaking out of the loop above.
        top_nc = ngf * 2 ** n_downsampling
        for _ in range(2):
            model += [nn.Conv2d(top_nc, top_nc, kernel_size=3,
                                stride=2, padding=1, bias=use_bias),
                      norm_layer(top_nc),
                      nn.ReLU(True)]

        # 2x2 head: turns the final 2x2 map (for 128x128 input) into 1x1.
        model += [nn.Conv2d(top_nc, output_nc, kernel_size=2,
                            stride=1, padding=0, bias=use_bias),
                  nn.Tanh()]

        self.model = nn.ModuleList(model)

    def forward(self, x):
        """Run x, assumed (N, input_nc, H, W), through every layer in order."""
        for layer in self.model:
            x = layer(x)
        return x
        #return self.model(input)

In [350]:
class Decoder2d(nn.Module):
    """Decode a latent code into an image with output_nc channels.

    Layout (sizes for a 1x1 code): a 2x2 transposed conv (1 -> 2), two stride-2
    3x3 transposed convs (2 -> 3 -> 5), zero-pad to 7, four stride-2 5x5
    transposed convs halving the channel count (7 -> 15 -> 31 -> 63 -> 127),
    zero-pad right/bottom to 128, and a final size-preserving 5x5 transposed
    conv to output_nc channels. There is no output activation.

    Args:
        input_nc: channels of the latent code.
        output_nc: channels of the output image.
        ngf: channel width of the last hidden stage.
        norm_layer: normalization layer class (or functools.partial of one).
            Convs get a bias only when it is InstanceNorm2d.
        use_dropout: currently unused; kept for signature compatibility.
        padding_type: currently unused; kept for signature compatibility.
    """

    def __init__(self, input_nc=200, output_nc=1, ngf=8, norm_layer=nn.BatchNorm2d,
                 use_dropout=False, padding_type='reflect'):
        super(Decoder2d, self).__init__()
        self.input_nc = input_nc
        self.output_nc = output_nc
        self.ngf = ngf
        if type(norm_layer) == functools.partial:
            use_bias = norm_layer.func == nn.InstanceNorm2d
        else:
            use_bias = norm_layer == nn.InstanceNorm2d

        n_upsampling = 4
        # Widest stage; equals ngf * mult * 2 for the first upsampling step.
        top_nc = ngf * 2 ** n_upsampling

        model = [nn.ConvTranspose2d(input_nc, top_nc, kernel_size=2,
                                    stride=1, padding=0, bias=use_bias),
                 norm_layer(top_nc),
                 nn.ReLU(True)]

        for _ in range(2):
            model += [nn.ConvTranspose2d(top_nc, top_nc, kernel_size=3,
                                         stride=2, padding=1, bias=use_bias),
                      norm_layer(top_nc),
                      nn.ReLU(True)]

        # Pad 5x5 -> 7x7 so the upsampling stack below ends at 127x127.
        model += [nn.ZeroPad2d((1, 1, 1, 1))]

        for i in range(n_upsampling):
            mult = 2 ** (n_upsampling - 1 - i)
            model += [nn.ConvTranspose2d(ngf * mult * 2, ngf * mult, kernel_size=5,
                                         stride=2, padding=1, bias=use_bias),
                      norm_layer(ngf * mult),
                      nn.ReLU(True)]

        # Pad right/bottom only: 127x127 -> 128x128.
        model += [nn.ZeroPad2d((0, 1, 0, 1))]

        # kernel 5 with padding 2 (stride 1) preserves the spatial size.
        model += [nn.ConvTranspose2d(ngf, output_nc, kernel_size=5, padding=2,
                                     bias=use_bias)]

        self.model = nn.ModuleList(model)

    def forward(self, x):
        """Decode x into an (N, output_nc, H, W) image.

        Accepts (N, input_nc, 1, 1, 1) from Encoder3d, (N, input_nc, 1, 1) from
        Encoder2d, or a flat (N, input_nc) code.
        """
        out = x
        # Drop the trailing depth axis of a 5D (Encoder3d) code only. The old
        # unconditional squeeze(-1) also turned a 4D (N, C, 1, 1) code into an
        # invalid 3D tensor.
        if out.dim() == 5:
            out = out.squeeze(-1)
        while out.dim() < 4:
            out = out.unsqueeze(-1)
        for layer in self.model:
            out = layer(out)
        return out

In [351]:
# Image stack read from 'padded_gray_images.npy' (per the filename, presumably
# padded grayscale images — confirm); images[0] is the test input below.
images = np.load(file='padded_gray_images.npy')

In [352]:
# Build the test input: the first image as a float32 tensor with two leading
# singleton (batch, channel) dims — assumes images[0] is a 2D (H, W) array.
# Plain tensors carry autograd state (Variable is deprecated), and reshape,
# unlike view, also accepts non-contiguous data.
img_np = images[0]
img = torch.from_numpy(img_np).reshape(1, 1, *img_np.shape).float()

In [353]:
# Seed before construction so the randomly initialised weights (and every
# output computed from them) are reproducible on Restart & Run All.
torch.manual_seed(0)
_E2d = Encoder2d()
_D3d = Decoder3d()
_E3d = Encoder3d()
_D2d = Decoder2d()

In [354]:
# Shape smoke test of the 2D -> 3D -> 3D-code -> 2D round trip.
# NOTE(review): the modules are still in training mode, so every BatchNorm here
# normalises with batch statistics and updates its running stats as a side effect.
batch = img.size(0)

code = _E2d(img)
print(code.size())
assert code.shape == (batch, 200, 1, 1), code.shape

output = _D3d(code)
print(output.size())
assert output.shape == (batch, 1, 32, 32, 32), output.shape

code = _E3d(output)
print(code.size())
assert code.shape == (batch, 200, 1, 1, 1), code.shape

output = _D2d(code)
print(output.size())
assert output.shape == (batch, 1, 128, 128), output.shape

torch.Size([1, 200, 1, 1])
torch.Size([1, 1, 32, 32, 32])
torch.Size([1, 200, 1, 1, 1])
torch.Size([1, 1, 128, 128])
