# CycleGAN

Based on the paper [Unpaired Image-to-Image Translation using Cycle-Consistent Adversarial Networks](https://arxiv.org/abs/1703.10593).

In [1]:
import math
import os

import numpy as np
import torch
from torch import nn

In [2]:
# Every image is resized to HEIGHT x WIDTH pixels by the dataset transform.
HEIGHT = 160
WIDTH = 160
# Dataset root containing trainA/ and trainB/. Set the CYCLEGAN_DATA
# environment variable to point elsewhere instead of editing this
# machine-specific default. A trailing separator is guaranteed because the
# code below appends sub-directory names by plain string concatenation.
DATA = os.path.join(
    os.environ.get(
        "CYCLEGAN_DATA",
        "/Users/zhangxiaochen/github/CycleGAN/datasets/summer2winter_yosemite/",
    ),
    "",
)

In [3]:
from models import generative,discriminative

In [4]:
# NOTE(review): exploratory standalone networks. The visible training code
# builds its own networks inside ``cycle`` and never reads G_1/D_1/G_2/D_2
# (only the scratch Adam cell below uses G_1/G_2) -- consider removing them to
# save memory. Statement order is kept as-is because each constructor may
# consume the global RNG for weight initialisation.
G_1 = generative([3,128,128,256,128])
D_1 = discriminative([64,64,64,128,128,256,256])

G_2 = generative([3,128,128,256,128]) 
D_2 = discriminative([64,64,64,128,128,256,256])

In [5]:
# NOTE(review): another standalone discriminator; the visible training code
# does not use it (``cycle`` creates its own D_x / D_y).
D = discriminative([64,64,64,128,128,256,256])

In [6]:
from torch.optim import Adam

# Scratch check: construct an Adam optimizer over both scratch generators. The
# object is only displayed (its repr is the cell output below), never stored.
Adam(list(G_1.parameters())+list(G_2.parameters()))

Adam (
Parameter Group 0
    amsgrad: False
    betas: (0.9, 0.999)
    eps: 1e-08
    lr: 0.001
    weight_decay: 0
)

In [7]:
from torch.optim import Adam
class cycle(nn.Module):
    """CycleGAN model: two generators, two discriminators and their optimizers.

    In this notebook's training code ``G_x`` translates domain X -> Y,
    ``G_y`` translates Y -> X, and ``D_x`` / ``D_y`` score images from
    domain X / Y respectively.

    Args:
        g_fn: configuration list passed unchanged to ``generative`` for both
            generators (its meaning is defined in ``models``). Defaults to
            ``[3, 128, 128, 256, 128]``.
        d_fn: configuration list passed unchanged to ``discriminative`` for
            both discriminators. Defaults to
            ``[64, 64, 64, 128, 128, 256, 256]``.
        lr: Adam learning rate shared by all three optimizers.
        betas: Adam ``betas`` shared by all three optimizers. ``lr`` and
            ``betas`` default to ``torch.optim.Adam``'s own defaults.
    """

    def __init__(self, g_fn=None, d_fn=None, lr=1e-3, betas=(0.9, 0.999)):
        super(cycle, self).__init__()
        # None defaults avoid one mutable list being shared by every instance.
        self.g_fn = [3, 128, 128, 256, 128] if g_fn is None else g_fn
        self.d_fn = [64, 64, 64, 128, 128, 256, 256] if d_fn is None else d_fn

        # Keep construction order: each constructor may draw from the global
        # RNG for weight initialisation.
        self.G_x = generative(self.g_fn)
        self.D_x = discriminative(self.d_fn)
        self.G_y = generative(self.g_fn)
        self.D_y = discriminative(self.d_fn)

        self.opt_Dx = Adam(self.D_x.parameters(), lr=lr, betas=betas)
        self.opt_Dy = Adam(self.D_y.parameters(), lr=lr, betas=betas)

        # A single optimizer must cover BOTH generators; listing G_x twice
        # left G_y's parameters untrained.
        self.opt_G = Adam(
            list(self.G_x.parameters()) + list(self.G_y.parameters()),
            lr=lr,
            betas=betas,
        )

    def zero_grad(self, *args, **kwargs):
        """Zero gradients through all three optimizers.

        Together the optimizers hold every parameter of G_x, G_y, D_x and D_y.
        Arguments (e.g. ``set_to_none``) are forwarded to each
        ``Optimizer.zero_grad``, so this override accepts the same call forms
        as ``nn.Module.zero_grad``.
        """
        for opt in (self.opt_Dx, self.opt_Dy, self.opt_G):
            opt.zero_grad(*args, **kwargs)

In [8]:
# Global model instance; train_D, train_G and action below read it as ``c``.
c = cycle()

### Loss Function

#### GAN Loss

$\large L_{GAN}(G,D_{Y},X,Y) = \mathbb{E}_{y \sim p_{data}(y)}[\log D_{Y}(y)] + \mathbb{E}_{x \sim p_{data}(x)}[\log(1-D_{Y}(G(x)))]$

$\large L_{GAN}(F,D_{X},Y,X) = \mathbb{E}_{x \sim p_{data}(x)}[\log D_{X}(x)] + \mathbb{E}_{y \sim p_{data}(y)}[\log(1-D_{X}(F(y)))]$

#### Cycle Consistency Loss

$\large L_{cyc}(G,F) = \mathbb{E}_{x \sim p_{data}(x)}[\|F(G(x))-x\|_{1}] + \mathbb{E}_{y \sim p_{data}(y)}[\|G(F(y))-y\|_{1}]$

#### Full Objective

$\large L(G,F,D_{X},D_{Y}) =  L_{GAN}(G,D_{Y},X,Y) + L_{GAN}(F,D_{X},Y,X) + \lambda L_{cyc}(G,F)$

### Data Generator

In [9]:
from torch.utils.data.dataset import Dataset
from torch.utils.data import DataLoader
from torchvision import transforms
from PIL import Image

In [10]:
from glob import glob
class data_cg(Dataset):
    """Unpaired two-domain image dataset for CycleGAN training.

    At construction, ``iters`` file paths are drawn uniformly *with
    replacement* from each of ``dir_X`` and ``dir_Y``, independently, so item
    ``idx`` pairs a random X image with an unrelated random Y image.

    Args:
        dir_X: directory holding domain-X images.
        dir_Y: directory holding domain-Y images.
        iters: number of (X, Y) pairs, i.e. ``len(dataset)``.
        size: ``(height, width)`` every image is resized to. Defaults to the
            module-level ``(HEIGHT, WIDTH)``.

    Raises:
        ValueError: if either directory contains no files.
    """

    def __init__(self, dir_X, dir_Y, iters=5000, size=None):
        self.iters = iters
        self.dir_X = dir_X
        self.dir_Y = dir_Y
        self.X_list = self.glob_list(self.dir_X)
        self.Y_list = self.glob_list(self.dir_Y)
        # Fail early with a clear message; np.random.choice would otherwise
        # raise a generic error on an empty list (e.g. a mistyped path).
        if not self.X_list:
            raise ValueError("no files found in dir_X={!r}".format(self.dir_X))
        if not self.Y_list:
            raise ValueError("no files found in dir_Y={!r}".format(self.dir_Y))
        # Independent draws per domain are what make the pairs "unpaired".
        self.X_urls = np.random.choice(self.X_list, self.iters).tolist()
        self.Y_urls = np.random.choice(self.Y_list, self.iters).tolist()
        if size is None:
            size = (HEIGHT, WIDTH)
        self.transform = transforms.Compose([
            transforms.Resize(size),
            transforms.ToTensor(),
            # (v - 0.5) / 0.5 per channel: maps ToTensor's [0, 1] to [-1, 1].
            transforms.Normalize([.5, .5, .5], [.5, .5, .5]),
        ])

    def __len__(self):
        return self.iters

    def __getitem__(self, idx):
        """Return the transformed ``(X, Y)`` image tensors for pair ``idx``."""
        X = self.transform(self._load_rgb(self.X_urls[idx]))
        Y = self.transform(self._load_rgb(self.Y_urls[idx]))
        return X, Y

    @staticmethod
    def _load_rgb(path):
        """Open ``path`` and return an RGB copy, closing the file handle."""
        # convert() loads the pixels into a new image, so it stays valid after
        # the context manager closes the underlying file.
        with Image.open(path) as img:
            return img.convert("RGB")

    def glob_list(self, dir):
        """Return the regular files directly inside ``dir``, sorted by path.

        Works with or without a trailing separator on ``dir``. Sub-directories
        are skipped (they cannot be opened as images), and sorting makes the
        random sampling in ``__init__`` reproducible under a fixed seed.
        """
        # ``dir`` shadows the builtin but is kept for call compatibility.
        pattern = os.path.join(dir, "*")
        return sorted(p for p in glob(pattern) if os.path.isfile(p))

### Dataset Test

In [11]:
# Smoke test of the input pipeline: build the training dataset, wrap it in a
# shuffling loader and pull a single batch.
ds = data_cg(dir_X=DATA + "trainA/", dir_Y=DATA + "trainB/")
dl = DataLoader(dataset=ds, shuffle=True, batch_size=2)
gen = iter(dl)
a = next(gen)  # one collated batch built from data_cg's (X, Y) items

### Training

In [12]:
from torch.nn import MSELoss

In [13]:
from torch import cuda

In [14]:
# True when a GPU is usable; action() then moves each batch onto it.
CUDA = torch.cuda.is_available()

In [15]:
# Mean squared error averaged over all elements (MSELoss's default reduction).
mse = MSELoss(reduction="mean")

#### GAN Loss

$\large L_{GAN}(G,D_{Y},X,Y) = \mathbb{E}_{y \sim p_{data}(y)}[\log D_{Y}(y)] + \mathbb{E}_{x \sim p_{data}(x)}[\log(1-D_{Y}(G(x)))]$

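$\large L_{GAN}(F,D_{X},Y,X) = \mathbb{E}_{x \sim p_{data}(x)}[\log D_{X}(x)] + \mathbb{E}_{y \sim p_{data}(y)}[\log(1-D_{X}(F(y)))]$

As in the paper's implementation details, the discriminators below are trained with a least-squares objective (mean squared error against constant real/fake labels) rather than the log-likelihood form above.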

#### Cycle Consistency Loss

$\large L_{cyc}(G,F) = \mathbb{E}_{x \sim p_{data}(x)}[\|F(G(x))-x\|_{1}] + \mathbb{E}_{y \sim p_{data}(y)}[\|G(F(y))-y\|_{1}]$

#### Full Objective

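$\large L(G,F,D_{X},D_{Y}) = L_{GAN}(G,D_{Y},X,Y) + L_{GAN}(F,D_{X},Y,X) + \lambda L_{cyc}(G,F)$

In the code below, $G$ is `c.G_x` (X → Y), $F$ is `c.G_y` (Y → X), $D_{Y}$ is `c.D_y`, $D_{X}$ is `c.D_x`, and $\lambda$ is the `lbd` argument of `train_G`.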

In [16]:
def _lsgan_loss(pred, target_is_real):
    """Least-squares GAN loss of discriminator scores ``pred`` vs a constant label.

    Label convention (shared with ``loss_D``): 1 means "real", 0 means "fake".
    The target takes ``pred``'s shape, dtype and device, so no explicit
    ``.cuda()`` is needed and any discriminator output shape is accepted.
    """
    fill = torch.ones_like if target_is_real else torch.zeros_like
    return nn.functional.mse_loss(pred, fill(pred))


def loss_D(D, real, fake):
    """Least-squares discriminator loss on one real batch and one fake batch.

    ``real`` and ``fake`` are concatenated along dim 0 and scored in a single
    forward pass. Scores for the real rows are regressed towards 1 and those
    for the fake rows towards 0. The target is built from ``D``'s output
    (batch on dim 0), so per-patch outputs work as well as one score per image.
    """
    n_real = real.size(0)
    pred = D(torch.cat([real, fake], dim=0))
    target = torch.zeros_like(pred)
    target[:n_real] = 1.0  # real -> 1, fake -> 0
    return nn.functional.mse_loss(pred, target)


def train_D(X, Y):
    """Take one optimizer step for both discriminators.

    Returns:
        ``(loss_D_x, loss_D_y)`` as scalar tensors.
    """
    # Fakes are only discriminator inputs here: generate them without an
    # autograd graph so backward stops at D instead of running through G.
    with torch.no_grad():
        Y_ = c.G_x(X)  # X -> Y
        X_ = c.G_y(Y)  # Y -> X
    loss_D_x = loss_D(D=c.D_x, real=X, fake=X_)
    loss_D_y = loss_D(D=c.D_y, real=Y, fake=Y_)

    loss_D_x.backward()
    loss_D_y.backward()

    c.opt_Dx.step()
    c.opt_Dy.step()

    return loss_D_x, loss_D_y


def train_G(X, Y, lbd=1e1):
    """Take one optimizer step for both generators.

    Minimizes the least-squares adversarial losses, where each generator
    pushes the opposite domain's discriminator towards the "real" label 1,
    plus ``lbd`` times the per-element mean L1 cycle-consistency loss.

    Returns:
        ``(Loss, L_gan_y, L_gan_x, L_cycle)`` as scalar tensors.
    """
    Y_ = c.G_x(X)  # X -> Y, reused by the adversarial and cycle terms
    X_ = c.G_y(Y)  # Y -> X

    L_gan_y = _lsgan_loss(c.D_y(Y_), True)
    L_gan_x = _lsgan_loss(c.D_x(X_), True)
    # Mean (not sum) keeps the L1 term on the same per-element scale as the
    # adversarial terms, so ``lbd`` actually balances them.
    L_cycle = (c.G_y(Y_) - X).abs().mean() + (c.G_x(X_) - Y).abs().mean()

    Loss = L_gan_y + L_gan_x + lbd * L_cycle

    # Backward also fills D_x/D_y gradients; they are not stepped here and
    # action() zeros them before the next discriminator update.
    Loss.backward()
    c.opt_G.step()

    return Loss, L_gan_y, L_gan_x, L_cycle


def action(*args, **kwargs):
    """Run one training step on a batch: a D update followed by a G update.

    ``args[0]`` is unpacked as the ``(X, Y)`` batch (presumably supplied by
    the Trainer from ``data_cg`` -- confirm in matchbox). Returns the scalar
    losses as Python floats keyed by name.
    """
    X, Y = args[0]
    if CUDA:
        X, Y = X.cuda(), Y.cuda()

    c.zero_grad()
    loss_D_x, loss_D_y = train_D(X, Y)

    c.zero_grad()
    Loss, L_gan_y, L_gan_x, L_cycle = train_G(X, Y, lbd=1e1)

    return {"loss_D_x": loss_D_x.item(),
            "loss_D_y": loss_D_y.item(),
            "Loss": Loss.item(),
            "L_gan_y": L_gan_y.item(),
            "L_gan_x": L_gan_x.item(),
            "L_cycle": L_cycle.item(), }
    

In [17]:
from p3self.matchbox import Trainer

In [18]:
# Trainer comes from the external p3self.matchbox package; its semantics are
# not visible here. The progress bar below (2500 steps) is consistent with
# 5000 dataset items split into batches of 2 -- confirm in matchbox.
trainer=Trainer(data_cg(DATA+"trainA/",DATA+"trainB/"),batch_size=2)

In [19]:
# Install the per-batch training step defined above as the trainer's hook
# (presumably called with each batch as its first argument -- confirm).
trainer.action = action

In [21]:
# Start training; the argument is presumably the number of epochs (confirm in
# matchbox). The captured run below took ~240 s per iteration and was
# interrupted after 3 steps.
trainer.train(1)


  0%|          | 0/2500 [00:00<?, ?it/s]
  0%|          | 1/2500 [04:10<173:41:57, 250.23s/it]
  0%|          | 2/2500 [08:02<167:24:11, 241.25s/it]
  0%|          | 3/2500 [11:52<164:45:53, 237.55s/it]

KeyboardInterrupt: 