In [1]:
import os
import math
#import cv2
import tarfile
import numpy as np
import matplotlib
import matplotlib.pyplot as plt
from IPython.display import FileLink
from IPython.display import Image

import torch
import torchvision
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader
from torch.utils.data import TensorDataset
from torch.utils.data import random_split
from torch.utils.data.sampler import SubsetRandomSampler
from torchvision.utils import save_image
import torchvision.transforms as tt
from torchvision.transforms import ToTensor, Normalize, Compose
import torchvision.transforms as transforms
from torchvision.transforms import ToTensor
from torchvision.datasets import ImageFolder
from torchvision.datasets.utils import download_url

In [90]:
def relu():
    return nn.ReLU()

def lrelu(f=0.2):
    return nn.LeakyReLU(f)

def tanh():
    return nn.Tanh()

def batch_norm(ni):
    return nn.BatchNorm2d(ni)

def conv_2d(ni, nf, ks, stride=2):
    return nn.Conv2d(in_channels=ni, out_channels=nf, kernel_size=ks, stride=stride, padding=ks//2, bias=False)

def deconv_2d(ni, nf, ks, stride=2, padding=1, output_padding=1):
    return nn.ConvTranspose2d(in_channels=ni, out_channels=nf, 
                               kernel_size=ks, stride=stride, 
                               padding=padding, output_padding=output_padding)
    
def fc_nn(input_size, output_size):
    return nn.Sequential(nn.Flatten(), 
                          nn.Linear(input_size, output_size)
                         )

In [28]:
class ResBlock(nn.Module):
    def __init__(self, ni, ks=3, stride=1):
        super().__init__()
        self.conv = conv_2d(ni, ni, ks, stride)
        self.bn = batch_norm(ni)
        self.lrelu = lrelu()
        self.shortcut = lambda x: x

    def forward(self, x):
        r = self.shortcut(x)
        x = self.conv(x)
        x = self.bn(x)
        x = self.lrelu(x)
        x = self.conv(x)
        x = self.bn(x)
        x = self.lrelu(x.add_(r))
        return x

In [126]:
class GeneratorGlobal(nn.Module):
    def __init__(self):
        super().__init__()
        
        dim = [3, 64, 128, 256, 512]
        dec = [64, 32, 16, 8]
        
        #Encoder
        self.conv0 = nn.Sequential(
                    conv_2d(dim[0], dim[1], ks=7, stride=1),
                    lrelu(),
                    ResBlock(dim[1], ks=7))
        
        self.conv1 = nn.Sequential(
                    conv_2d(dim[1], dim[1], ks=5, stride=2),
                    batch_norm(dim[1]),
                    lrelu(),
                    ResBlock(dim[1], ks=5))
        
        self.conv2 = nn.Sequential(
                    conv_2d(dim[1], dim[2], ks=3, stride=2),
                    batch_norm(dim[2]),
                    lrelu(),
                    ResBlock(dim[2], ks=3))
        
        self.conv3 = nn.Sequential(
                    conv_2d(dim[2], dim[3], ks=3, stride=2),
                    batch_norm(dim[3]),
                    lrelu(),
                    ResBlock(dim[3], ks=3))
        
        self.conv4 = nn.Sequential(
                    conv_2d(dim[3], dim[4], ks=3, stride=2),
                    batch_norm(dim[4]),
                    lrelu(),
                    ResBlock(dim[4], ks=3),
                    ResBlock(dim[4], ks=3),
                    ResBlock(dim[4], ks=3),
                    ResBlock(dim[4], ks=3))
        
        self.fc1 = nn.Sequential(
                    fc_nn(dim[1]*dim[4], dim[4]))

        
        
        #Decoder
        
        #Layer-feat8 [bs, 64, 8, 8]
        self.feat8_ = nn.Sequential(
                    fc_nn(dim[4], dim[1]*8*8))
        self.feat8 = nn.Sequential(
                    relu())
        
        #Layer-feat32 [bs, 32, 32, 32]
        self.feat32 = nn.Sequential(
                    deconv_2d(dec[0], dec[1], 3, 4, 0, 1),
                    relu())
        
        #Layer-feat64 [bs, 16, 64, 64]
        self.feat64 = nn.Sequential(
                    deconv_2d(dec[1], dec[2], 3, 2, 1, 1),
                    relu())
        
        #Layer-feat128 [bs, 8, 128, 128]
        self.feat128 = nn.Sequential(
                    deconv_2d(dec[2], dec[3], 3, 2, 1, 1),
                    relu())
    
        #Layer - deconv0 [bs, 512, 16, 16]
        self.deconv0_16 = nn.Sequential(
                    ResBlock(ni=576),
                    ResBlock(ni=576),
                    ResBlock(ni=576),
                    deconv_2d(576, dim[4], 3, 2, 1, 1),
                    batch_norm(dim[4]),
                    relu()
                    )
        
        #Layer - deconv1 [bs, 256, 32, 32]
        
    select16_res_1 = res_block(conv3, ni=256) #Output: [bs, 256, 16, 16]
    dec16_res2_t = torch.cat((deconv0_16, select16_res_1), 1) #Output: [bs, 768, 16, 16]
    dec16_res2 = res_block(res_block(dec16_res2_t, ni=768, ks=3), ni=768, ks=3) #Output: [bs, 768, 16, 16]
    deconv1_32 = deconv_bn_relu(dec16_res2, 768, 256, 3, 2, 1, 1) #Output: [bs, 256, 32, 32]
        
    def forward(self, x, noise):
        conv0 = self.conv0(x)
        conv1 = self.conv1(conv0)
        conv2 = self.conv2(conv1)
        conv3 = self.conv3(conv2)
        conv4 = self.conv4(conv3)
        fc1 = self.fc1(conv4)
        fc2 = torch.maximum(fc1[:, 0:256], fc1[:, 256:])
        
        feat8_ = self.feat8_(torch.cat((fc2, noise), 1)).view(fc2.size()[0], 64, 8, 8) #Output: [bs, 64, 8, 8]
        feat8 = self.feat8(feat8_) #Output: [bs, 64, 8, 8]
        
        feat32 = self.feat32(feat8) #Output: [bs, 32, 32, 32]
        
        feat64 = self.feat64(feat32) #Output: [bs, 16, 64, 64]
        
        feat128 = self.feat128(feat64) #Output: [bs, 8, 128, 128]
        
        deconv0_16 = self.deconv0_16(torch.cat((feat8, conv4), 1)) #Output: [bs, 512, 16, 16]
        
        deconv1_32 = self.deconv1_32()
        
        #return conv0, conv1, conv2, conv3, conv4, fc2
        return deconv0_16

In [127]:
input1 = torch.randn(49, 3, 128, 128)
noi = torch.randn(49, 256)

model = GeneratorGlobal()
feats = model(input1, noi)

feats.shape

torch.Size([49, 512, 16, 16])

In [None]:
def generator_global_decoder(conv0, conv1, conv2, conv3, conv4, fc):
    
    '''
    conv0 : [batch_size, 64, 128, 128]
    conv1 : [batch_size, 64, 64, 64]
    conv2 : [batch_size, 128, 32, 32]
    conv3 : [batch_size, 256, 16, 16]
    conv4 : [batch_size, 512, 8, 8]
       fc : [batch_size, 256]
    '''
    
    batch_size = fc.shape[0]
    
    I_P_32 = torch.randn(batch_size, 32, 32, 32)
    I_P_64 = torch.randn(batch_size, 32, 64, 64)
    I_P_128 = torch.randn(batch_size, 32, 128, 128)
    
    #Layer - deconv2 [bs, 128, 64, 64]
    select32_res_1_t = torch.cat((conv2, feat32, I_P_32), 1) #Output: [bs, 192, 32, 32]
    select32_res_1 = res_block(select32_res_1_t, ni=192, ks=3) #Output: [bs, 192, 32, 32]
    dec32_res2_t = torch.cat((deconv1_32, select32_res_1), 1) #Output: [bs, 448, 32, 32]
    dec32_res2 = res_block(res_block(dec32_res2_t, ni=448, ks=3), ni=448, ks=3) #Output: [bs, 448, 32, 32]
    deconv2_64 = deconv_bn_relu(dec32_res2, 448, 128, 3, 2, 1, 1) #Output: [bs, 128, 64, 64]
    
    img32 = conv_tanh(dec32_res2, ni=448, nf=3, ks=3) #Output: [bs, 3, 32, 32]
    
    
    #Layer - deconv3 [bs, 64, 128, 128]
    select64_res_1_t = torch.cat((conv1, feat64, I_P_64), 1) #Output: [bs, 112, 64, 64]
    select64_res_1 = res_block(select64_res_1_t, ni=112, ks=5) #Output: [bs, 112, 64, 64]
    dec64_res2_t = torch.cat((deconv2_64, select64_res_1), 1) #Output: [bs, 240, 64, 64] 
    #Not concatenated img32
    dec64_res2 = res_block(res_block(dec64_res2_t, ni=240, ks=3), ni=240, ks=3) #Output: [bs, 240, 64, 64] 
    deconv3_128 = deconv_bn_relu(dec64_res2, 240, 64, 3, 2, 1, 1) #Output: [bs, 64, 128, 128]
    
    img64 = conv_tanh(dec64_res2, ni=240, nf=3, ks=3) #Output: [bs, 3, 64, 64]
    
    
    #Layer - conv5 [bs, 64, 128, 128]
    select128_res_1_t = torch.cat((conv0, feat128, I_P_128), 1) #Output: [bs, 104, 128, 128]
    select128_res_1 = res_block(select128_res_1_t, ni=104, ks=7) #Output: [bs, 104, 128, 128]
    dec128_res2_t = torch.cat((deconv3_128, select128_res_1), 1) #Output: [bs, 168, 128, 128] 
    #Not concatenated img64, eyel, eyer, nose, mouth, c_eyel, c_eyer, c_nose, c_mouth
    dec128_res2 = res_block(dec128_res2_t, ni=168, ks=5) #Output: [bs, 168, 128, 128]
    dec128_conv5 = conv_bn_lrelu(dec128_res2, ni=168, nf=64, ks=5, stride=1) #Output: [bs, 64, 128, 128]
    dec128_conv5_r = res_block(dec128_conv5, ni=64) #Output: [bs, 64, 128, 128]
    
    
    #Layer - conv6 [bs, 32, 128, 128]
    dec128_conv6 = conv_bn_lrelu(dec128_conv5_r, ni=64, nf=32, ks=3, stride=1) #Output: [bs, 32, 128, 128]
    
    
    #Layer - conv7 [bs, 3, 128, 128]
    img128 = conv_tanh(dec128_conv6, ni=32, nf=3) #Output: [bs, 3, 128, 128]
    
    return img128

In [78]:
def receive(a1, a2, a3, a4, a5, a6):
    print(a1.shape)
    print(a2.shape)
    print(a3.shape)
    print(a4.shape)
    print(a5.shape)
    print(a6.shape)

In [79]:
input1 = torch.randn(49, 3, 128, 128)
model = GeneratorGlobal()
feats = model(input1)

receive(*feats)

AttributeError: 'GeneratorGlobal' object has no attribute 'x'