## Network Architecture

In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision
from torchvision.utils import save_image
from torchvision.transforms import ToTensor, Normalize, Compose
import torchvision.transforms as transforms
from torchvision.transforms import ToTensor
from torch.utils.data import DataLoader
from torch.utils.data import TensorDataset
from torch.utils.data.sampler import SubsetRandomSampler
import torchvision.transforms as tt
from torchvision.datasets import ImageFolder
from torch.utils.data import random_split
from torchvision.datasets.utils import download_url
import os
#import cv2
import math
import tarfile
import numpy as np
import matplotlib
import matplotlib.pyplot as plt
from IPython.display import FileLink
from IPython.display import Image

In [2]:
def relu(_input):
    """Apply an element-wise ReLU (max(x, 0)) to ``_input`` and return the result."""
    return nn.ReLU()(_input)

def lrelu(_input):
    """Apply LeakyReLU with negative slope 0.2 to ``_input`` and return the result."""
    return nn.LeakyReLU(negative_slope=0.2)(_input)

def batch_norm():
    """Placeholder stub: not implemented and not called by the cells shown; returns None."""
    return None

def conv_2d(_input, ni, nf, ks, stride=2):
    """Apply a freshly initialised, bias-free ``nn.Conv2d`` to ``_input``.

    Padding is ``ks // 2``. For odd ``ks`` the spatial size is therefore kept at
    ``stride=1`` and ceil-divided by ``stride`` otherwise.

    Args:
        _input: tensor shaped (batch, ni, H, W).
        ni: number of input channels.
        nf: number of output channels.
        ks: square kernel size.
        stride: convolution stride (default 2).

    Returns:
        Tensor (batch, nf, H', W') on the same device and with the same dtype as
        ``_input``.

    NOTE: the layer is built inside the call and discarded afterwards. Its
    weights are re-randomised on every call and are not trainable.
    """
    layer = nn.Conv2d(in_channels=ni, out_channels=nf, kernel_size=ks,
                      stride=stride, padding=ks // 2, bias=False)
    # New modules default to float32 on CPU. Match the input so CUDA or
    # float64 tensors do not raise a device/dtype mismatch.
    layer = layer.to(device=_input.device, dtype=_input.dtype)
    return layer(_input)

def conv_bn(_input, ni, nf, ks, stride):
    """Apply Conv2d (padding ``ks // 2``, with bias) followed by BatchNorm2d.

    Args:
        _input: tensor shaped (batch, ni, H, W).
        ni: number of input channels.
        nf: number of output channels (also the BatchNorm feature count).
        ks: square kernel size.
        stride: convolution stride.

    Returns:
        Tensor (batch, nf, H', W') on the same device and with the same dtype as
        ``_input``.

    NOTE: both layers are created per call. They get fresh random weights, and
    BatchNorm is in its default training mode, so it normalises with the current
    batch statistics. Nothing here is trainable or persistent.
    """
    block = nn.Sequential(
        nn.Conv2d(ni, nf, kernel_size=ks, stride=stride, padding=ks // 2),
        nn.BatchNorm2d(nf),
    )
    # Match the input's device/dtype. Fresh modules are float32 on CPU.
    block = block.to(device=_input.device, dtype=_input.dtype)
    return block(_input)

def conv_bn_lrelu(_input, ni, nf, ks, stride):
    """Conv2d -> BatchNorm2d (via ``conv_bn``) followed by LeakyReLU(0.2) (via ``lrelu``).

    Arguments are forwarded unchanged to ``conv_bn``. Output has ``nf`` channels.
    """
    return lrelu(conv_bn(_input, ni, nf, ks, stride))

def conv_tanh(_input, ni, nf=3, ks=3, stride=1):
    """Apply Conv2d (padding ``ks // 2``) followed by tanh, so outputs lie in [-1, 1].

    Args:
        _input: tensor shaped (batch, ni, H, W).
        ni: number of input channels.
        nf: number of output channels (default 3).
        ks: square kernel size (default 3).
        stride: convolution stride (default 1).

    Returns:
        Tensor (batch, nf, H', W') on the same device and with the same dtype as
        ``_input``.

    NOTE: the layers are created per call, so the weights are re-randomised each
    time and are not trainable.
    """
    head = nn.Sequential(
        nn.Conv2d(ni, nf, kernel_size=ks, stride=stride, padding=ks // 2),
        nn.Tanh(),
    )
    # Match the input's device/dtype. Fresh modules are float32 on CPU.
    head = head.to(device=_input.device, dtype=_input.dtype)
    return head(_input)
    
def fc_nn(_input, input_size, output_size):
    """Flatten all non-batch dims of ``_input`` and apply a fresh ``nn.Linear``.

    Args:
        _input: tensor whose non-batch dims multiply to ``input_size``.
        input_size: flattened feature count per sample.
        output_size: number of output features.

    Returns:
        Tensor (batch, output_size) on the same device and with the same dtype as
        ``_input``.

    NOTE: the Linear layer is created per call, so its weights are re-randomised
    each time and are not trainable.
    """
    head = nn.Sequential(
        nn.Flatten(),
        nn.Linear(input_size, output_size),
    )
    # Match the input's device/dtype. Fresh modules are float32 on CPU.
    head = head.to(device=_input.device, dtype=_input.dtype)
    return head(_input)

In [3]:
def res_block(_input, ni, ks=3, stride=1):
    """Residual block: lrelu(x + conv_bn(conv_bn_lrelu(x))), keeping ``ni`` channels.

    Both convolutions use ``stride``. The residual sum only matches shapes when
    ``stride == 1`` (the default).
    """
    branch = conv_bn_lrelu(_input, ni, ni, ks, stride)
    branch = conv_bn(branch, ni, ni, ks, stride)
    return lrelu(_input + branch)

In [4]:
def generator_global_encoder(_input):
    """Global-pathway encoder: four stride-2 stages with residual blocks, then a maxout FC.

    Args:
        _input: image batch with 3 channels. The shape comments elsewhere in the
            notebook assume (batch, 3, 128, 128); other sizes are intended to work
            now that the FC width is inferred (TODO confirm the intended range).

    Returns:
        (conv0r, conv1r, conv2r, conv3r, conv4r4, fc2):
        - five skip feature maps with 64, 64, 128, 256 and 512 channels, at
          1, 1/2, 1/4, 1/8 and 1/16 of the input resolution (stages 1-4 use
          stride 2);
        - fc2, a maxout vector of half the FC width (256 features).

    NOTE: every helper builds fresh randomly initialised layers on each call.
    Repeated calls give different outputs, and nothing here is trainable.
    """
    ni = 3
    nf = 64

    # Stage 0: full resolution, 7x7 kernels.
    conv0 = lrelu(conv_2d(_input, ni, nf, ks=7, stride=1))
    conv0r = res_block(conv0, nf, ks=7)

    # Stage 1: 1/2 resolution, 5x5 kernels.
    conv1 = conv_bn_lrelu(conv0r, nf * 1, nf * 1, ks=5, stride=2)
    conv1r = res_block(conv1, nf * 1, ks=5)

    # Stage 2: 1/4 resolution.
    conv2 = conv_bn_lrelu(conv1r, nf * 1, nf * 2, ks=3, stride=2)
    conv2r = res_block(conv2, nf * 2, ks=3)

    # Stage 3: 1/8 resolution.
    conv3 = conv_bn_lrelu(conv2r, nf * 2, nf * 4, ks=3, stride=2)
    conv3r = res_block(conv3, nf * 4, ks=3)

    # Stage 4: 1/16 resolution, followed by four residual blocks.
    conv4 = conv_bn_lrelu(conv3r, nf * 4, nf * 8, ks=3, stride=2)
    conv4r1 = res_block(conv4, nf * 8, ks=3)
    conv4r2 = res_block(conv4r1, nf * 8, ks=3)
    conv4r3 = res_block(conv4r2, nf * 8, ks=3)
    conv4r4 = res_block(conv4r3, nf * 8, ks=3)

    # The flatten width used to be hard-coded as 64*512 (= 8*8*512, i.e. only a
    # 128x128 input). Derive it from the actual feature map instead.
    flat_size = conv4r4.shape[1] * conv4r4.shape[2] * conv4r4.shape[3]
    fc1 = fc_nn(conv4r4, flat_size, 512)
    # Maxout: element-wise max of the two halves of the FC output (512 -> 256).
    half = fc1.shape[1] // 2
    fc2 = torch.maximum(fc1[:, :half], fc1[:, half:])

    return conv0r, conv1r, conv2r, conv3r, conv4r4, fc2

In [5]:
def deconv_2d(_input, ni, nf, ks, stride=2, padding=1, output_padding=1):
    """Apply a freshly initialised ``nn.ConvTranspose2d`` (with bias) to ``_input``.

    Output spatial size is ``(H - 1) * stride - 2 * padding + ks + output_padding``.
    With the defaults and ``ks=3`` this doubles H and W.

    Args:
        _input: tensor shaped (batch, ni, H, W).
        ni: number of input channels.
        nf: number of output channels.
        ks: square kernel size.
        stride, padding, output_padding: forwarded to ``nn.ConvTranspose2d``.

    Returns:
        Tensor (batch, nf, H', W') on the same device and with the same dtype as
        ``_input``.

    NOTE: the layer is created per call, so its weights are re-randomised each
    time and are not trainable.
    """
    layer = nn.ConvTranspose2d(in_channels=ni, out_channels=nf,
                               kernel_size=ks, stride=stride,
                               padding=padding, output_padding=output_padding)
    # Match the input's device/dtype. Fresh modules are float32 on CPU.
    layer = layer.to(device=_input.device, dtype=_input.dtype)
    return layer(_input)

def deconv_bn_relu(_input, ni, nf, ks, stride=2, padding=1, output_padding=1):
    """Apply ConvTranspose2d -> BatchNorm2d -> ReLU using freshly created layers.

    Args:
        _input: tensor shaped (batch, ni, H, W).
        ni: number of input channels.
        nf: number of output channels.
        ks: square kernel size.
        stride, padding, output_padding: forwarded to ``nn.ConvTranspose2d``
            (the defaults double H and W for ``ks=3``).

    Returns:
        Non-negative tensor (batch, nf, H', W') on the same device and with the
        same dtype as ``_input``.

    NOTE: the layers are created per call. They get fresh random weights, and
    BatchNorm is in training mode, so it uses batch statistics. They are not
    trainable.
    """
    block = nn.Sequential(
        nn.ConvTranspose2d(in_channels=ni, out_channels=nf,
                           kernel_size=ks, stride=stride,
                           padding=padding, output_padding=output_padding),
        nn.BatchNorm2d(nf),
        nn.ReLU(),
    )
    # Match the input's device/dtype. Fresh modules are float32 on CPU.
    block = block.to(device=_input.device, dtype=_input.dtype)
    return block(_input)

def deconv_bn_lrelu(_input, ni, nf, ks, stride=2, padding=1, output_padding=1):
    """Apply ConvTranspose2d -> BatchNorm2d -> LeakyReLU(0.2) using freshly created layers.

    Args:
        _input: tensor shaped (batch, ni, H, W).
        ni: number of input channels.
        nf: number of output channels.
        ks: square kernel size.
        stride, padding, output_padding: forwarded to ``nn.ConvTranspose2d``
            (the defaults double H and W for ``ks=3``).

    Returns:
        Tensor (batch, nf, H', W') on the same device and with the same dtype as
        ``_input``.

    NOTE: the layers are created per call. They get fresh random weights, and
    BatchNorm is in training mode, so it uses batch statistics. They are not
    trainable.
    """
    block = nn.Sequential(
        nn.ConvTranspose2d(in_channels=ni, out_channels=nf,
                           kernel_size=ks, stride=stride,
                           padding=padding, output_padding=output_padding),
        nn.BatchNorm2d(nf),
        nn.LeakyReLU(0.2),
    )
    # Match the input's device/dtype. Fresh modules are float32 on CPU.
    block = block.to(device=_input.device, dtype=_input.dtype)
    return block(_input)

In [6]:
#Implementation
def generator_global_decoder(conv0, conv1, conv2, conv3, conv4, fc, return_multiscale=False):
    """Global-pathway decoder: fuse encoder skips and upsample to a 3-channel image.

    Args:
        conv0, conv1, conv2, conv3, conv4: encoder skip features. The shape
            comments below assume (bs, 64, 128, 128), (bs, 64, 64, 64),
            (bs, 128, 32, 32), (bs, 256, 16, 16) and (bs, 512, 8, 8).
        fc: (bs, 256) feature vector. It is concatenated with 256-d Gaussian
            noise to form the 512-d input of the first FC layer.
        return_multiscale: if True, also return the 64x64 and 32x32 tanh side
            outputs. These were previously computed and then discarded.

    Returns:
        img128 (bs, 3, 128, 128), or (img128, img64, img32) when
        ``return_multiscale`` is True.

    NOTE: the helpers build new randomly initialised layers on every call.
    Outputs differ between calls, and nothing here is trainable.
    """
    batch_size = fc.shape[0]
    # Create all random tensors on fc's device/dtype so the torch.cat calls
    # below never mix CPU and GPU tensors.
    rand_kw = {"device": fc.device, "dtype": fc.dtype}

    # Extra 32-channel maps concatenated into the 32/64/128 branches. They are
    # pure Gaussian noise here -- presumably placeholders (TODO confirm intent).
    I_P_32 = torch.randn(batch_size, 32, 32, 32, **rand_kw)
    I_P_64 = torch.randn(batch_size, 32, 64, 64, **rand_kw)
    I_P_128 = torch.randn(batch_size, 32, 128, 128, **rand_kw)

    #Layer-feat8: [fc | noise] -> FC -> reshape to a 64x8x8 map
    noise = torch.randn(batch_size, 256, **rand_kw)
    _input = torch.cat((fc, noise), 1)  #Output: [bs, 512]
    feat8 = relu(fc_nn(_input, 512, 64*8*8).reshape([batch_size, 64, 8, 8])) #Output: [bs, 64, 8, 8]

    #Layer-feat32 (stride-4 transposed conv: 8 -> 32)
    feat32 = relu(deconv_2d(feat8, 64, 32, 3, 4, 0, 1))  #Output: [bs, 32, 32, 32]

    #Layer-feat64
    feat64 = relu(deconv_2d(feat32, 32, 16, 3, 2, 1, 1)) #Output: [bs, 16, 64, 64]

    #Layer-feat128
    feat128 = relu(deconv_2d(feat64, 16, 8, 3, 2, 1, 1)) #Output: [bs, 8, 128, 128]

    #Layer - deconv0: fuse feat8 with the 8x8 skip, then upsample to 16x16
    select8_res_1_t = torch.cat((feat8, conv4), 1) #Output: [bs, 576, 8, 8]
    select8_res_1 = res_block(select8_res_1_t, ni=576, ks=3) #Output: [bs, 576, 8, 8]
    dec8_res2 = res_block(res_block(select8_res_1, ni=576, ks=3), ni=576, ks=3) #Output: [bs, 576, 8, 8]
    deconv0_16 = deconv_bn_relu(dec8_res2, 576, 512, 3, 2, 1, 1) #Output: [bs, 512, 16, 16]

    #Layer - deconv1
    select16_res_1 = res_block(conv3, ni=256) #Output: [bs, 256, 16, 16]
    dec16_res2_t = torch.cat((deconv0_16, select16_res_1), 1) #Output: [bs, 768, 16, 16]
    dec16_res2 = res_block(res_block(dec16_res2_t, ni=768, ks=3), ni=768, ks=3) #Output: [bs, 768, 16, 16]
    deconv1_32 = deconv_bn_relu(dec16_res2, 768, 256, 3, 2, 1, 1) #Output: [bs, 256, 32, 32]

    #Layer - deconv2
    select32_res_1_t = torch.cat((conv2, feat32, I_P_32), 1) #Output: [bs, 192, 32, 32]
    select32_res_1 = res_block(select32_res_1_t, ni=192, ks=3) #Output: [bs, 192, 32, 32]
    dec32_res2_t = torch.cat((deconv1_32, select32_res_1), 1) #Output: [bs, 448, 32, 32]
    dec32_res2 = res_block(res_block(dec32_res2_t, ni=448, ks=3), ni=448, ks=3) #Output: [bs, 448, 32, 32]
    deconv2_64 = deconv_bn_relu(dec32_res2, 448, 128, 3, 2, 1, 1) #Output: [bs, 128, 64, 64]

    # 32x32 side output (returned only when return_multiscale=True)
    img32 = conv_tanh(dec32_res2, ni=448, nf=3, ks=3) #Output: [bs, 3, 32, 32]

    #Layer - deconv3
    select64_res_1_t = torch.cat((conv1, feat64, I_P_64), 1) #Output: [bs, 112, 64, 64]
    select64_res_1 = res_block(select64_res_1_t, ni=112, ks=5) #Output: [bs, 112, 64, 64]
    dec64_res2_t = torch.cat((deconv2_64, select64_res_1), 1) #Output: [bs, 240, 64, 64]
    #Not concatenated img32
    dec64_res2 = res_block(res_block(dec64_res2_t, ni=240, ks=3), ni=240, ks=3) #Output: [bs, 240, 64, 64]
    deconv3_128 = deconv_bn_relu(dec64_res2, 240, 64, 3, 2, 1, 1) #Output: [bs, 64, 128, 128]

    # 64x64 side output (returned only when return_multiscale=True)
    img64 = conv_tanh(dec64_res2, ni=240, nf=3, ks=3) #Output: [bs, 3, 64, 64]

    #Layer - conv5
    select128_res_1_t = torch.cat((conv0, feat128, I_P_128), 1) #Output: [bs, 104, 128, 128]
    select128_res_1 = res_block(select128_res_1_t, ni=104, ks=7) #Output: [bs, 104, 128, 128]
    dec128_res2_t = torch.cat((deconv3_128, select128_res_1), 1) #Output: [bs, 168, 128, 128]
    #Not concatenated img64, eyel, eyer, nose, mouth, c_eyel, c_eyer, c_nose, c_mouth
    dec128_res2 = res_block(dec128_res2_t, ni=168, ks=5) #Output: [bs, 168, 128, 128]
    dec128_conv5 = conv_bn_lrelu(dec128_res2, ni=168, nf=64, ks=5, stride=1) #Output: [bs, 64, 128, 128]
    dec128_conv5_r = res_block(dec128_conv5, ni=64) #Output: [bs, 64, 128, 128]

    #Layer - conv6
    dec128_conv6 = conv_bn_lrelu(dec128_conv5_r, ni=64, nf=32, ks=3, stride=1) #Output: [bs, 32, 128, 128]

    #Layer - conv7
    img128 = conv_tanh(dec128_conv6, ni=32, nf=3) #Output: [bs, 3, 128, 128]

    if return_multiscale:
        return img128, img64, img32
    return img128

In [None]:
# Smoke test: run the decoder on random stand-ins for the encoder outputs (batch 49).
conv0, conv1, conv2, conv3, conv4 = (
    torch.randn(49, channels, size, size)
    for channels, size in ((64, 128), (64, 64), (128, 32), (256, 16), (512, 8))
)
fc = torch.randn(49, 256)

output = generator_global_decoder(conv0, conv1, conv2, conv3, conv4, fc)
print(output.shape)

In [None]:
input1 = torch.randn((49, 3, 128, 128))  # random 3-channel 128x128 batch for the encoder

In [None]:
# Run the encoder once on the random batch. It returns a 6-tuple: five skip
# feature maps plus the 256-d maxout vector.
feats = generator_global_encoder(input1)

In [None]:
def receive(a1, a2, a3, a4, a5, a6):
    """Print the ``.shape`` of each of the six arguments, one per line, in argument order."""
    for tensor in (a1, a2, a3, a4, a5, a6):
        print(tensor.shape)

In [None]:
# Print the shape of each of the six encoder outputs.
receive(*feats)

In [None]:
feats[5].shape  # maxout feature vector: (49, 256)

In [None]:
# Re-runs the encoder. The helpers build fresh random layers, so the values differ
# from `feats`; len(result) shows the number of returned tensors (6).
result = generator_global_encoder(input1)
len(result)

In [None]:
result[5][:, :256].shape  # first 256 features of the 256-d maxout vector, i.e. all of them

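## Scratch cells (please ignore): the code below was only used to check that individual pieces run. Note that running these cells overwrites notebook variables such as `input1`, `output`, and `m`, and redefines `generator_global_decoder`.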

### 7

In [None]:
# A residual block keeps the shape: (49, 128, 32, 32) in -> (49, 128, 32, 32) out.
a = res_block(conv2, ni=128)
a.shape

torch.Size([49, 128, 32, 32])

### 6

In [None]:
# Concatenating along dim 0 stacks batches: 48 + 56 -> 104 samples.
# Cell-local names keep the encoder input `input1` (and `input2`) from being
# clobbered by these unrelated tensors.
cat_lhs = torch.randn(48, 64, 8, 8)
cat_rhs = torch.randn(56, 64, 8, 8)

a = torch.cat((cat_lhs, cat_rhs), 0)
a.shape

torch.Size([104, 64, 8, 8])

### 5

In [None]:
aa = 49  # batch size used throughout the notebook
a = torch.randn((aa, 5, 1))
a.shape

torch.Size([49, 5, 1])

### 4

In [None]:
m = nn.LeakyReLU(negative_slope=0.1)  # standalone LeakyReLU to inspect its behaviour

In [None]:
m  # display the module repr

LeakyReLU(negative_slope=0.1)

In [None]:
input1 = torch.randn((2,))  # two random values to feed through the LeakyReLU

In [None]:
input1  # display the random input

tensor([-0.9483, -1.0663])

In [None]:
# Apply LeakyReLU(0.1). Note that this overwrites the decoder smoke-test `output`.
output = m(input1)

In [None]:
output  # negative inputs are scaled by 0.1

tensor([-0.0948, -0.1066])

### 3

In [None]:
# Reference: a plain nn.Conv2d (16 -> 33 channels, k=3, stride 2, padding 1).
m = nn.Conv2d(in_channels=16, out_channels=33, kernel_size=3, stride=2, padding=1)
input2 = torch.randn((20, 16, 600, 50))
output = m(input2)

In [None]:
input2.shape  # (20, 16, 600, 50)

torch.Size([20, 16, 600, 50])

In [None]:
output.shape  # k=3, stride 2, padding 1 halves H and W: (20, 33, 300, 25)

torch.Size([20, 33, 300, 25])

In [None]:
m  # display the Conv2d configuration

Conv2d(16, 33, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1))

### 2

In [None]:
# Same shapes via the conv_2d helper (default stride=2, padding=ks//2).
input2 = torch.randn((20, 16, 600, 50))
a = conv_2d(input2, ni=16, nf=33, ks=3)
a.shape

torch.Size([20, 33, 300, 25])

### 1

In [None]:
b = conv_bn_lrelu(input2, ni=16, nf=33, ks=3, stride=2)  # conv -> BN -> LeakyReLU, stride 2

In [None]:
b.shape  # expected (20, 33, 300, 25)

torch.Size([20, 33, 300, 25])

In [None]:
def generator_global_decoder(feat128, feat64, feat32, feat16, feat8, featvec, return_multiscale=False):
    """Scratch variant of the global-pathway decoder (renamed parameters).

    NOTE(review): running this cell redefines ``generator_global_decoder`` and
    shadows the earlier definition, whose parameters are named conv0..conv4, fc.
    Keyword callers of those names would then break.

    Args:
        feat128, feat64, feat32, feat16, feat8: encoder skip features. The shape
            comments below assume 64, 64, 128, 256 and 512 channels at 128, 64,
            32, 16 and 8 pixels.
        featvec: (bs, 256) feature vector, concatenated with 256-d Gaussian noise.
        return_multiscale: if True, also return the 64x64 and 32x32 tanh side
            outputs (previously computed and discarded).

    Returns:
        img128 (bs, 3, 128, 128), or (img128, img64, img32) when
        ``return_multiscale`` is True.

    NOTE: the helpers build new randomly initialised layers on every call, so
    nothing here is trainable.
    """
    batch_size = featvec.shape[0]
    # Create all random tensors on featvec's device/dtype so torch.cat never
    # mixes CPU and GPU tensors.
    rand_kw = {"device": featvec.device, "dtype": featvec.dtype}
    noise = torch.randn(batch_size, 256, **rand_kw)
    _input = torch.cat((featvec, noise), 1)  #Output: [bs, 512]

    # Pure-noise 32-channel maps for the 32/64/128 branches; presumably
    # placeholders (TODO confirm intent).
    I_P_32 = torch.randn(batch_size, 32, 32, 32, **rand_kw)
    I_P_64 = torch.randn(batch_size, 32, 64, 64, **rand_kw)
    I_P_128 = torch.randn(batch_size, 32, 128, 128, **rand_kw)

    initial_8 = relu(fc_nn(_input, 512, 64*8*8).reshape([batch_size, 64, 8, 8])) #Output: [bs, 64, 8, 8]
    initial_32 = relu(deconv_2d(initial_8, 64, 32, 3, 4, 0, 1))  #Output: [bs, 32, 32, 32]
    initial_64 = relu(deconv_2d(initial_32, 32, 16, 3, 2, 1, 1)) #Output: [bs, 16, 64, 64]
    initial_128 = relu(deconv_2d(initial_64, 16, 8, 3, 2, 1, 1)) #Output: [bs, 8, 128, 128]

    before_select8_t = torch.cat((initial_8, feat8), 1) #Output: [bs, 576, 8, 8]
    before_select8 = res_block(before_select8_t, ni=576, ks=3) #Output: [bs, 576, 8, 8]
    reconstruct8 = res_block(res_block(before_select8, ni=576, ks=3), ni=576, ks=3) #Output: [bs, 576, 8, 8]

    reconstruct16_deconv = deconv_bn_relu(reconstruct8, 576, 512, 3, 2, 1, 1) #Output: [bs, 512, 16, 16]
    before_select16 = res_block(feat16, 256) #Output: [bs, 256, 16, 16]
    reconstruct16_t = torch.cat((reconstruct16_deconv, before_select16), 1) #Output: [bs, 768, 16, 16]
    reconstruct16 = res_block(res_block(reconstruct16_t, ni=768, ks=3), ni=768, ks=3) #Output: [bs, 768, 16, 16]

    reconstruct32_deconv = deconv_bn_relu(reconstruct16, 768, 256, 3, 2, 1, 1) #Output: [bs, 256, 32, 32]
    before_select32_t = torch.cat((feat32, initial_32, I_P_32), 1) #Output: [bs, 192, 32, 32]
    before_select32 = res_block(before_select32_t, ni=192, ks=3) #Output: [bs, 192, 32, 32]
    reconstruct32_t = torch.cat((reconstruct32_deconv, before_select32), 1) #Output: [bs, 448, 32, 32]
    reconstruct32 = res_block(res_block(reconstruct32_t, ni=448, ks=3), ni=448, ks=3) #Output: [bs, 448, 32, 32]
    img32 = conv_tanh(reconstruct32, ni=448, nf=3, ks=3) #Output: [bs, 3, 32, 32]

    reconstruct64_deconv = deconv_bn_relu(reconstruct32, 448, 128, 3, 2, 1, 1) #Output: [bs, 128, 64, 64]
    before_select64_t = torch.cat((feat64, initial_64, I_P_64), 1) #Output: [bs, 112, 64, 64]
    before_select64 = res_block(before_select64_t, ni=112, ks=5) #Output: [bs, 112, 64, 64]
    reconstruct64_t = torch.cat((reconstruct64_deconv, before_select64), 1) #Output: [bs, 240, 64, 64]
    #Not concatenated img32
    reconstruct64 = res_block(res_block(reconstruct64_t, ni=240, ks=3), ni=240, ks=3) #Output: [bs, 240, 64, 64]
    img64 = conv_tanh(reconstruct64, ni=240, nf=3, ks=3) #Output: [bs, 3, 64, 64]

    reconstruct128_deconv = deconv_bn_relu(reconstruct64, 240, 64, 3, 2, 1, 1) #Output: [bs, 64, 128, 128]
    before_select128_t = torch.cat((feat128, initial_128, I_P_128), 1) #Output: [bs, 104, 128, 128]
    before_select128 = res_block(before_select128_t, ni=104, ks=7) #Output: [bs, 104, 128, 128]
    reconstruct128_t = torch.cat((reconstruct128_deconv, before_select128), 1) #Output: [bs, 168, 128, 128]
    #Not concatenated img64, eyel, eyer, nose, mouth, c_eyel, c_eyer, c_nose, c_mouth
    reconstruct128 = res_block(reconstruct128_t, ni=168, ks=5) #Output: [bs, 168, 128, 128]

    reconstruct128_1 = conv_bn_lrelu(reconstruct128, ni=168, nf=64, ks=5, stride=1) #Output: [bs, 64, 128, 128]
    reconstruct128_1_r = res_block(reconstruct128_1, ni=64) #Output: [bs, 64, 128, 128]
    reconstruct128_2 = conv_bn_lrelu(reconstruct128_1_r, ni=64, nf=32, ks=3, stride=1) #Output: [bs, 32, 128, 128]
    img128 = conv_tanh(reconstruct128_2, ni=32, nf=3) #Output: [bs, 3, 128, 128]

    if return_multiscale:
        return img128, img64, img32
    return img128

In [None]:
def generator_global_decoder(feat128, feat64, feat32, feat16, feat8, featvec, return_multiscale=False):
    """Second scratch variant of the global-pathway decoder (renamed parameters).

    NOTE(review): running this cell redefines ``generator_global_decoder`` again
    and shadows any earlier definition. The earlier main definition names its
    parameters conv0..conv4, fc.

    Args:
        feat128, feat64, feat32, feat16, feat8: encoder skip features. The shape
            comments below assume 64, 64, 128, 256 and 512 channels at 128, 64,
            32, 16 and 8 pixels.
        featvec: (bs, 256) feature vector, concatenated with 256-d Gaussian noise.
        return_multiscale: if True, also return the 64x64 and 32x32 tanh side
            outputs (previously computed and discarded).

    Returns:
        img128 (bs, 3, 128, 128), or (img128, img64, img32) when
        ``return_multiscale`` is True.

    NOTE: the helpers build new randomly initialised layers on every call, so
    nothing here is trainable.
    """
    batch_size = featvec.shape[0]
    # Create all random tensors on featvec's device/dtype so torch.cat never
    # mixes CPU and GPU tensors.
    rand_kw = {"device": featvec.device, "dtype": featvec.dtype}

    # Pure-noise 32-channel maps for the 32/64/128 branches; presumably
    # placeholders (TODO confirm intent).
    I_P_32 = torch.randn(batch_size, 32, 32, 32, **rand_kw)
    I_P_64 = torch.randn(batch_size, 32, 64, 64, **rand_kw)
    I_P_128 = torch.randn(batch_size, 32, 128, 128, **rand_kw)

    #Layer-feat8
    noise = torch.randn(batch_size, 256, **rand_kw)
    _input = torch.cat((featvec, noise), 1)  #Output: [bs, 512]
    initial_8 = relu(fc_nn(_input, 512, 64*8*8).reshape([batch_size, 64, 8, 8])) #Output: [bs, 64, 8, 8]

    #Layer-feat32
    initial_32 = relu(deconv_2d(initial_8, 64, 32, 3, 4, 0, 1))  #Output: [bs, 32, 32, 32]

    #Layer-feat64
    initial_64 = relu(deconv_2d(initial_32, 32, 16, 3, 2, 1, 1)) #Output: [bs, 16, 64, 64]

    #Layer-feat128
    initial_128 = relu(deconv_2d(initial_64, 16, 8, 3, 2, 1, 1)) #Output: [bs, 8, 128, 128]

    #Layer - deconv0
    before_select8_t = torch.cat((initial_8, feat8), 1) #Output: [bs, 576, 8, 8]
    before_select8 = res_block(before_select8_t, ni=576, ks=3) #Output: [bs, 576, 8, 8]
    reconstruct8 = res_block(res_block(before_select8, ni=576, ks=3), ni=576, ks=3) #Output: [bs, 576, 8, 8]
    reconstruct16_deconv = deconv_bn_relu(reconstruct8, 576, 512, 3, 2, 1, 1) #Output: [bs, 512, 16, 16]

    #Layer - deconv1
    before_select16 = res_block(feat16, ni=256) #Output: [bs, 256, 16, 16]
    reconstruct16_t = torch.cat((reconstruct16_deconv, before_select16), 1) #Output: [bs, 768, 16, 16]
    reconstruct16 = res_block(res_block(reconstruct16_t, ni=768, ks=3), ni=768, ks=3) #Output: [bs, 768, 16, 16]
    reconstruct32_deconv = deconv_bn_relu(reconstruct16, 768, 256, 3, 2, 1, 1) #Output: [bs, 256, 32, 32]

    #Layer - deconv2
    before_select32_t = torch.cat((feat32, initial_32, I_P_32), 1) #Output: [bs, 192, 32, 32]
    before_select32 = res_block(before_select32_t, ni=192, ks=3) #Output: [bs, 192, 32, 32]
    reconstruct32_t = torch.cat((reconstruct32_deconv, before_select32), 1) #Output: [bs, 448, 32, 32]
    reconstruct32 = res_block(res_block(reconstruct32_t, ni=448, ks=3), ni=448, ks=3) #Output: [bs, 448, 32, 32]
    reconstruct64_deconv = deconv_bn_relu(reconstruct32, 448, 128, 3, 2, 1, 1) #Output: [bs, 128, 64, 64]

    # 32x32 side output (returned only when return_multiscale=True)
    img32 = conv_tanh(reconstruct32, ni=448, nf=3, ks=3) #Output: [bs, 3, 32, 32]

    #Layer - deconv3
    before_select64_t = torch.cat((feat64, initial_64, I_P_64), 1) #Output: [bs, 112, 64, 64]
    before_select64 = res_block(before_select64_t, ni=112, ks=5) #Output: [bs, 112, 64, 64]
    reconstruct64_t = torch.cat((reconstruct64_deconv, before_select64), 1) #Output: [bs, 240, 64, 64]
    #Not concatenated img32
    reconstruct64 = res_block(res_block(reconstruct64_t, ni=240, ks=3), ni=240, ks=3) #Output: [bs, 240, 64, 64]
    reconstruct128_deconv = deconv_bn_relu(reconstruct64, 240, 64, 3, 2, 1, 1) #Output: [bs, 64, 128, 128]

    # 64x64 side output (returned only when return_multiscale=True)
    img64 = conv_tanh(reconstruct64, ni=240, nf=3, ks=3) #Output: [bs, 3, 64, 64]

    #Layer - conv5
    before_select128_t = torch.cat((feat128, initial_128, I_P_128), 1) #Output: [bs, 104, 128, 128]
    before_select128 = res_block(before_select128_t, ni=104, ks=7) #Output: [bs, 104, 128, 128]
    reconstruct128_t = torch.cat((reconstruct128_deconv, before_select128), 1) #Output: [bs, 168, 128, 128]
    #Not concatenated img64, eyel, eyer, nose, mouth, c_eyel, c_eyer, c_nose, c_mouth
    reconstruct128 = res_block(reconstruct128_t, ni=168, ks=5) #Output: [bs, 168, 128, 128]
    reconstruct128_1 = conv_bn_lrelu(reconstruct128, ni=168, nf=64, ks=5, stride=1) #Output: [bs, 64, 128, 128]
    reconstruct128_1_r = res_block(reconstruct128_1, ni=64) #Output: [bs, 64, 128, 128]

    #Layer - conv6
    reconstruct128_2 = conv_bn_lrelu(reconstruct128_1_r, ni=64, nf=32, ks=3, stride=1) #Output: [bs, 32, 128, 128]

    #Layer - conv7
    img128 = conv_tanh(reconstruct128_2, ni=32, nf=3) #Output: [bs, 3, 128, 128]

    if return_multiscale:
        return img128, img64, img32
    return img128

In [None]:
# TP-GAN reference repository. A bare URL in a code cell is a SyntaxError,
# so it is kept as a string constant instead.
TP_GAN_REFERENCE_URL = "https://github.com/UnrealLink/TP-GAN"