# CycleGAN

Based on the paper [Unpaired Image-to-Image Translation using Cycle-Consistent Adversarial Networks](https://arxiv.org/abs/1703.10593).

In [10]:
import torch
from torch import nn
import math

In [90]:
class resblock(nn.Module):
    """Residual block used by the generator: downsample, then upsample back.

    The residual branch is a stride-2 convolution -> BatchNorm -> LeakyReLU,
    followed by an upsample back to the input's spatial size and a second
    BatchNorm. The branch is added to the block's input, so the output has
    the same shape as the input.

    Args:
        out_: Number of channels; used for both input and output.
        k_: Convolution kernel size. Padding is ``k_ // 2``.
    """

    def __init__(self, out_, k_):
        super().__init__()

        self.out_ = out_
        self.k_ = k_

        self.padding = self.k_ // 2
        self.conv1 = nn.Conv2d(self.out_, self.out_, self.k_, stride=2,
                               padding=self.padding, bias=False)
        self.bn1 = nn.BatchNorm2d(self.out_)
        self.leaky1 = nn.LeakyReLU()
        self.upconv = nn.Upsample(scale_factor=2)
        self.bn2 = nn.BatchNorm2d(self.out_)

    def forward(self, x):
        # Nothing below modifies ``x`` in place, so a plain reference is
        # enough for the skip connection (no clone needed).
        identity = x

        down = self.leaky1(self.bn1(self.conv1(x)))

        up = self.upconv(down)
        if up.shape[-2:] != identity.shape[-2:]:
            # A stride-2 conv followed by a x2 upsample only round-trips even
            # spatial sizes with odd kernels. Otherwise resize the branch
            # directly to the skip's size so the addition below is valid.
            up = nn.functional.interpolate(down, size=identity.shape[-2:],
                                           mode=self.upconv.mode)

        return self.bn2(up) + identity
        

In [91]:
# Sanity check: build a standalone residual block (256 channels, kernel size 5).
rs = resblock(out_=256, k_=5)

In [92]:
class generative(nn.Module):
    """Generator: a stack of conv stages followed by a 1x1 output projection.

    Stage ``i`` maps ``fn_list[i]`` to ``fn_list[i + 1]`` channels with a
    stride-1 convolution (padding ``k // 2``), BatchNorm and LeakyReLU. It then
    applies a ``resblock`` with its own BatchNorm and LeakyReLU. For odd kernel
    sizes the stage convolutions preserve height and width. A final 1x1
    convolution (``conv_out``) produces ``out_channels`` channels.

    Args:
        fn_list: Channel counts; ``fn_list[0]`` is the input channel count.
        k_list: Kernel size per stage (``len(fn_list) - 1`` entries).
            Defaults to 3 for every stage.
        out_channels: Channels produced by ``conv_out`` (default 3).

    Raises:
        ValueError: If ``k_list`` has fewer entries than there are stages.
    """

    def __init__(self, fn_list, k_list=None, out_channels=3):
        super().__init__()
        self.fn_list = fn_list
        n_stages = len(self.fn_list) - 1

        # ``is None`` instead of ``== None``: ``==`` may be overloaded
        # (elementwise) for array-like arguments and not yield a plain bool.
        if k_list is None:
            self.k_list = [3] * n_stages
        else:
            self.k_list = k_list
        if len(self.k_list) < n_stages:
            raise ValueError("k_list needs at least %d kernel sizes, got %d"
                             % (n_stages, len(self.k_list)))

        self.leaky = nn.LeakyReLU()
        # Layers are registered as attributes tran_<i>, bn_trans_<i>,
        # resblock_<i>, bn_res_<i>; these names are the state_dict prefixes.
        for i in range(n_stages):
            c_in = self.fn_list[i]
            c_out = self.fn_list[i + 1]
            k = self.k_list[i]
            setattr(self, "tran_%s" % i, nn.Conv2d(c_in, c_out,
                                                   kernel_size=k,
                                                   padding=self.k2pad(k),
                                                   bias=False))
            setattr(self, "bn_trans_%s" % i, nn.BatchNorm2d(c_out))

            setattr(self, "resblock_%s" % i, resblock(c_out, k))
            setattr(self, "bn_res_%s" % i, nn.BatchNorm2d(c_out))

        self.conv_out = nn.Conv2d(self.fn_list[-1], out_channels, 1, bias=False)

    def k2pad(self, k):
        """Return the padding ``k // 2`` used with kernel size ``k``."""
        return k // 2

    def forward(self, x):
        for i in range(len(self.fn_list) - 1):
            x = getattr(self, "tran_%s" % i)(x)
            x = getattr(self, "bn_trans_%s" % i)(x)
            x = self.leaky(x)
            x = getattr(self, "resblock_%s" % i)(x)
            x = getattr(self, "bn_res_%s" % i)(x)
            x = self.leaky(x)
        return self.conv_out(x)

In [118]:
class resblock_d(nn.Module):
    """Discriminator residual block with an optional output projection.

    The residual branch is two stride-1 ``k_ x k_`` convolutions, each followed
    by BatchNorm and LeakyReLU. The branch is added to the input. Then:

    * ``downsample=True``: a stride-2 convolution + BatchNorm + LeakyReLU
      maps ``in_`` to ``out_`` channels.
    * ``downsample=False`` and ``in_ != out_``: the same projection with
      stride 1, so the output still has ``out_`` channels.
    * ``downsample=False`` and ``in_ == out_``: no projection.

    Args:
        in_: Input channels (also the width of the residual branch).
        out_: Output channels.
        k_: Kernel size for every convolution; padding is ``k_ // 2``.
        downsample: Whether the projection uses stride 2.
    """

    def __init__(self, in_, out_, k_, downsample=True):
        super().__init__()

        self.in_ = in_
        self.out_ = out_
        self.k_ = k_
        self.downsample = downsample

        self.leaky = nn.LeakyReLU()

        pad = self.k2pad(self.k_)
        self.conv1 = nn.Conv2d(self.in_, self.in_, kernel_size=self.k_,
                               stride=1, padding=pad, bias=False)
        self.bn_1 = nn.BatchNorm2d(self.in_)

        self.conv2 = nn.Conv2d(self.in_, self.in_, kernel_size=self.k_,
                               stride=1, padding=pad, bias=False)
        self.bn_2 = nn.BatchNorm2d(self.in_)

        # The projection is built whenever the block must change resolution
        # or channel count, so the output always has ``out_`` channels.
        self.project = self.downsample or self.in_ != self.out_
        if self.project:
            self.out = nn.Conv2d(self.in_, self.out_, self.k_, padding=pad,
                                 stride=2 if self.downsample else 1,
                                 bias=False)
            self.bn_out = nn.BatchNorm2d(self.out_)

    def k2pad(self, k):
        """Return the padding ``k // 2`` used with kernel size ``k``."""
        return k // 2

    def forward(self, x):
        # Nothing below modifies ``x`` in place, so no clone is needed.
        identity = x

        x = self.leaky(self.bn_1(self.conv1(x)))
        x = self.leaky(self.bn_2(self.conv2(x)))
        x = x + identity

        if self.project:
            x = self.leaky(self.bn_out(self.out(x)))
        return x

class discriminative(nn.Module):
    """Discriminator: a conv stem followed by a chain of ``resblock_d`` stages.

    ``conv_in`` (3x3, stride 1, padding 1) and ``bn_in`` map the input to
    ``fn_list[0]`` channels. Stage ``i`` is a ``resblock_d`` from ``fn_list[i]``
    to ``fn_list[i + 1]`` channels with kernel ``k_list[i]``. It is built with
    that block's default ``downsample`` setting. ``forward`` returns the last
    stage's feature map.

    NOTE(review): no activation is applied after ``bn_in``. There is also no
    final scoring layer: the output is a ``fn_list[-1]``-channel feature map,
    not a real/fake logit. Confirm both are intended.

    Args:
        fn_list: Channel counts; ``fn_list[0]`` is the stem's output width.
        k_list: Kernel size per stage (``len(fn_list) - 1`` entries).
            Defaults to 3 for every stage.
        in_channels: Channels of the input image (default 3).

    Raises:
        ValueError: If ``k_list`` has fewer entries than there are stages.
    """

    def __init__(self, fn_list, k_list=None, in_channels=3):
        super().__init__()
        self.fn_list = fn_list
        n_stages = len(self.fn_list) - 1

        # ``is None`` instead of ``== None``: ``==`` may be overloaded
        # (elementwise) for array-like arguments and not yield a plain bool.
        if k_list is None:
            self.k_list = [3] * n_stages
        else:
            self.k_list = k_list
        if len(self.k_list) < n_stages:
            raise ValueError("k_list needs at least %d kernel sizes, got %d"
                             % (n_stages, len(self.k_list)))

        self.conv_in = nn.Conv2d(in_channels, self.fn_list[0], 3, padding=1,
                                 bias=False)
        self.bn_in = nn.BatchNorm2d(self.fn_list[0])

        # Stages are registered as attributes res_0, res_1, ...; these names
        # are the state_dict key prefixes.
        for i in range(n_stages):
            setattr(self, "res_%s" % i, resblock_d(self.fn_list[i],
                                                   self.fn_list[i + 1],
                                                   k_=self.k_list[i]))

    def forward(self, x):
        x = self.bn_in(self.conv_in(x))
        for i in range(len(self.fn_list) - 1):
            x = getattr(self, "res_%s" % i)(x)
        return x

In [123]:
# Discriminator with six resblock_d stages, widening from 64 to 256 channels.
D = discriminative(fn_list=[64, 64, 64, 128, 128, 256, 256])

In [125]:
# Show the discriminator's module tree (the notebook displays the last expression's repr).
D

discriminative(
  (conv_in): Conv2d(3, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
  (bn_in): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (res_0): resblock_d(
    (leaky): LeakyReLU(negative_slope=0.01)
    (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
    (bn_1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
    (bn_2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (out): Conv2d(64, 64, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)
    (bn_out): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  )
  (res_1): resblock_d(
    (leaky): LeakyReLU(negative_slope=0.01)
    (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
    (bn_1): BatchNorm2d(6

In [124]:
# Shape check: push a random batch of two 3x256x256 images through the discriminator.
D(torch.rand(2, 3, 256, 256)).shape

torch.Size([2, 256, 4, 4])

In [97]:
# Generator with five stages; fn_list[0] = 3 is the input channel count.
G = generative(fn_list=[3, 128, 128, 128, 256, 256])

In [98]:
# Show the generator's module tree (the notebook displays the last expression's repr).
G

generative(
  (leaky): LeakyReLU(negative_slope=0.01)
  (tran_0): Conv2d(3, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
  (bn_trans_0): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (resblock_0): resblock(
    (conv1): Conv2d(128, 128, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)
    (bn1): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (leaky1): LeakyReLU(negative_slope=0.01)
    (upconv): Upsample(scale_factor=2, mode=nearest)
    (bn2): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  )
  (bn_res_0): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (tran_1): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
  (bn_trans_1): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (resblock_1): resblock(
    (conv1): Conv2d(128, 128, kernel_size=(3

In [96]:
# Shape check: push a random batch of two 3x160x160 images through the generator.
G(torch.rand(2, 3, 160, 160)).shape

torch.Size([2, 3, 160, 160])