## CycleGAN

* task: image translation: map from zebra/horse space to horse/zebra space
* model: $G$ model: downsample - resnet block * N - upsample, $D$ model: patchGAN
* loss term: LSGAN loss + cycle loss + identity loss

In [2]:
import torch
import torch.nn as nn
import torchvision

import torch.optim as optim
from torch.utils.data import DataLoader,Dataset

import torchvision.transforms as transforms
from PIL import Image

from tqdm import tqdm
import os
import time
import numpy as np
import sys
sys.path.append("./../utils")
from utils import BasicConv, init_weight, set_seed
import matplotlib.pyplot as plt
set_seed(2022)
%matplotlib inline

## define model

### G's model
1. the model of G is the model used in style transfer
2. the model follow the struct of downsample -> ResBlock --> upsample

In [3]:
class ResidualBlock(nn.Module):
    def __init__(self, in_features, k_s, s, p):
        super(ResidualBlock, self).__init__()

        conv_block =  [ nn.Conv2d(in_features, in_features, k_s, s, p, padding_mode="reflect"),
                        nn.InstanceNorm2d(in_features),
                        nn.ReLU(inplace=True),
                        nn.Conv2d(in_features, in_features, k_s, s, p, padding_mode="reflect"),
                        nn.InstanceNorm2d(in_features)  ]

        self.conv_block = nn.Sequential(*conv_block)

    def forward(self, x):
        return x + self.conv_block(x)

In [5]:
class Gen(nn.Module):
    """
    follow structure: downsample -> ResNetBlock * N -> upsample
    input: (N, 3, 256, 256)
    output: (N, 3, 256, 256)
    """
    def __init__(self, img_channel=3, num_resblock=8) -> None:
        super(Gen, self).__init__()
        self.head = nn.Sequential(
            nn.Conv2d(img_channel, 64, 7, 1, 3, padding_mode="reflect"),
            nn.InstanceNorm2d(64),
            nn.ReLU(inplace=True),
            nn.Conv2d(64, 128, 3, 2, 1, padding_mode="reflect"),
            nn.InstanceNorm2d(129),
            nn.ReLU(inplace=True),
            nn.Conv2d(128, 256, 3, 2, 1, padding_mode="reflect"),
            nn.InstanceNorm2d(64),
            nn.ReLU(inplace=True)
        )

        reslayers = []
        for _ in range(num_resblock):
            reslayers.append(ResidualBlock(256, 3, 1, 1))
        self.res_layers = nn.Sequential(*reslayers)

        self.tail = nn.Sequential(
            nn.ConvTranspose2d(256, 128, 3, 2, 1, 1),
            nn.InstanceNorm2d(128),
            nn.ReLU(inplace=True),
            nn.ConvTranspose2d(128, 64, 3, 2, 1, 1),
            nn.InstanceNorm2d(128),
            nn.ReLU(inplace=True),
            nn.Conv2d(64, img_channel, 7, 1, 3, padding_mode="reflect")
        )
 
    def forward(self, x):
        x = self.head(x)
        x = self.res_layers(x)
        x = self.tail(x)
        return x
        
def test_gen():
    model = Gen()
    x = torch.rand(8, 3, 256, 256)
    print(model)
    print(f"model params num:{sum(p.numel() for p in model.parameters() if p.requires_grad==True)}")
    print(f"input size{x.shape}, output size: {model(x).shape}")
test_gen()

StyleTransferModel(
  (head): Sequential(
    (0): Conv2d(3, 64, kernel_size=(7, 7), stride=(1, 1), padding=(3, 3), padding_mode=reflect)
    (1): InstanceNorm2d(64, eps=1e-05, momentum=0.1, affine=False, track_running_stats=False)
    (2): ReLU(inplace=True)
    (3): Conv2d(64, 128, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), padding_mode=reflect)
    (4): InstanceNorm2d(129, eps=1e-05, momentum=0.1, affine=False, track_running_stats=False)
    (5): ReLU(inplace=True)
    (6): Conv2d(128, 256, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), padding_mode=reflect)
    (7): InstanceNorm2d(64, eps=1e-05, momentum=0.1, affine=False, track_running_stats=False)
    (8): ReLU(inplace=True)
  )
  (res_layers): Sequential(
    (0): ResidualBlock(
      (conv_block): Sequential(
        (0): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), padding_mode=reflect)
        (1): InstanceNorm2d(256, eps=1e-05, momentum=0.1, affine=False, track_running_stats=False)
        

In [7]:
class ConvBlock(nn.Module):
    def __init__(self, in_channel, out_channel, k_s, s, p, norm=True):
        super(ConvBlock, self).__init__()

        conv_block =  [ nn.Conv2d(in_channel, out_channel, k_s, s, p),
                        nn.InstanceNorm2d(in_channel) if norm==True else nn.Identity(),
                        nn.LeakyReLU(0.2, inplace=True)]

        self.conv_block = nn.Sequential(*conv_block)

    def forward(self, x):
        return self.conv_block(x)


class Disc(nn.Module):
    """
    input: (N, 3, 256, 256)
    output: (N, 1, 8, 8)
    """
    def __init__(self, input_channel):
        super(Disc, self).__init__()
        main = []
        main.append(ConvBlock(input_channel, 64, 5, 2, 1, norm=False))
        main.append(ConvBlock(64, 128, 3, 2, 1))
        main.append(ConvBlock(128, 256, 3, 2, 1))
        main.append(ConvBlock(256, 512, 3, 1, 1))
        main.append(nn.Conv2d(512, 1, 3, 1, 1))
        self.model = nn.Sequential(*main)

    def forward(self, x):
        return self.model(x)
def test_disc():
    model = Disc(6)
    x = torch.rand(8, 6, 256, 256)
    print(f"model params num:{sum(p.numel() for p in model.parameters() if p.requires_grad==True)}")
    print(f"input size{x.shape}, output size: {model(x).shape}")
test_disc()

model params num:1563457
input sizetorch.Size([8, 6, 256, 256]), output size: torch.Size([8, 1, 32, 32])


## define some params

In [None]:
LEARNING_RATE_GEN = 2e-4
LEARNING_RATE_DISC = 2e-4
LAMBDA_CYCLE = 10
LAMBDA_INDENTITY = 5
IMAGE_CHANNEL = 3
NUM_EPOCH = 10
BATCH_SIZE = 4
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"

## prepare data

In [None]:
# download and unzip dataset
! wget https://people.eecs.berkeley.edu/~taesung_park/CycleGAN/datasets/horse2zebra.zip && unzip horse2zebra.zip

In [None]:
class HorseZebraData(Dataset):
    def __init__(self, transform, path_horse="./horse2zebra/trainA/",path_zebra="./horse2zebra/trainB/"):
        self.path_horse = path_horse
        self.path_zebra = path_zebra

        self.transform = transform

        self.lst_horse = os.listdir(self.path_horse)
        self.lst_zebra = os.listdir(self.path_zebra)

        # one thing about the dataset:
        # the num of horse pic != num of zebra pic
        self.length = max(len(self.lst_horse), len(self.lst_zebra))

    def __len__(self):
        return self.length

    def __getitem__(self, idx):
        # some pic on the dataset just have one channel
        # so need to use function: convert("RGB")
        horse = Image.open(self.path_horse + self.lst_horse[torch.randint(len(self.lst_horse), size=(1,))]).convert("RGB")
        zebra = Image.open(self.path_zebra + self.lst_zebra[idx]).convert("RGB")
        if self.transform:
            horse = self.transform(horse)
            zebra = self.transform(zebra)
        return horse, zebra



In [None]:
transform = transforms.Compose([
    transforms.Resize((256,256)),
    transforms.ToTensor()
])

In [None]:
dataset = HorseZebraData(transform)
dataloader = DataLoader(dataset, shuffle=True, batch_size=BATCH_SIZE)

In [None]:
# get test samples
# here, i choose to pick images in training set to visiualize whether the generator is learning
fix_horse, fix_zebra = next(iter(dataloader))
fix_horse_grid = torchvision.utils.make_grid(fix_horse, nrow=2)
fix_zebra_grid = torchvision.utils.make_grid(fix_zebra, nrow=2)
plt.subplot(1,2,1)
plt.imshow(fix_horse_grid.numpy())
plt.title("horse")
plt.axis("off")
plt.subplot(1,2,2)
plt.imshow(fix_zebra_grid.numpy())
plt.title("zebra")
plt.axis("off")
plt.tight_layout()
plt.show()

## keep the image buffer

in the paper, they do not put the image newly generated by Gen into Disc, instead, they keep a buffer of generated image and then sample image from the buffer to Disc.

In [None]:
class ImageBuffer:
    def __init__(self, max_size=50):
        self.max_size = max_size
        self.data = []

    def get_image(self, data):
        # data is what Gen generated and send in func
        # we will return a tensor of the same shape as data send in
        tensor_return = []
        for each in data:
            each = each.unsqueeze(0)
            if len(self.data) < self.max_size:
                self.data.append(each)
                tensor_return.append(each)
            else:
                # with prob == 0.5
                # we will update the buffer
                if torch.randn(1) < 0.5:
                    choose = torch.randint(self.max_size, size=(1,))
                    tensor_return.append(self.data[choose])
                    self.data[choose] = each
                else:
                    tensor_return.append(each)
        return torch.cat(tensor_return, dim=0)


## define optim and loss

In [None]:
G_H_Z = Gen(IMAGE_CHANNEL).to(DEVICE)
G_Z_H = Gen(IMAGE_CHANNEL).to(DEVICE)

D_H = Disc(IMAGE_CHANNEL).to(DEVICE)
D_Z = Disc(IMAGE_CHANNEL).to(DEVICE)

G_H_Z.apply(init_weight)
G_Z_H.apply(init_weight)

D_H.apply(init_weight)
D_Z.apply(init_weight)

# optim two networks' params at the same time
opt_gen = optim.Adam(list(G_Z_H.parameters()) + list(G_H_Z.parameters()), lr=LEARNING_RATE_GEN, betas=(0.5, 0.999))
opt_d_h = optim.Adam(D_H.parameters(), lr=LEARNING_RATE_DISC, betas=(0.5, 0.999))
opt_d_z = optim.Adam(D_Z.parameters(), lr=LEARNING_RATE_DISC, betas=(0.5, 0.999))

gan_criterion = nn.MSELoss()
cycle_criterion = nn.L1Loss()
identity_criterion = nn.L1Loss()

model_checkpoint = {
    'G_H_Z':None,
    'G_Z_H':None,
    'D_Z':None,
    'D_H':None
}


## set up the training loop

In [None]:
# function to sample images
def sample_image(fix_horse, fix_zebra):
    fake_zebra = G_H_Z(fix_horse)
    fake_horse = G_Z_H(fix_zebra)
    cycle_horse = G_Z_H(fake_zebra)
    cycle_zebra = G_H_Z(fake_horse)

    grid_real_horse = torchvision.utils.make_grid(fix_horse, nrow=1, normalize=True)
    grid_fake_zebra = torchvision.utils.make_grid(fake_zebra, nrow=1, normalize=True)
    grid_cycle_horse = torchvision.utils.make_grid(cycle_horse, nrow=1, normalize=True)
    grid_horse = torch.cat((grid_real_horse, grid_fake_zebra, grid_cycle_horse), dim=2)

    grid_real_zebra = torchvision.utils.make_grid(fix_zebra, nrow=1, normalize=True)
    grid_fake_horse = torchvision.utils.make_grid(fake_horse, nrow=1, normalize=True)
    grid_cycle_zebra = torchvision.utils.make_grid(cycle_zebra, nrow=1, normalize=True)
    grid_zebra = torch.cat((grid_real_zebra, grid_fake_horse, grid_cycle_zebra), dim=2)
    torchvision.utils.save_image(grid_horse, f"./horse_output/horse_epoch{epoch}_batch_{batch_idx}.png", normalize=False)
    torchvision.utils.save_image(grid_zebra, f"./zebra_output/zebra_epoch{epoch}_batch_{batch_idx}.png", normalize=False)
    print("saving samples... done!")


def save_checkpoint(dict, path="model.pth.tar"):
    """
    save the model and optim through a dictionary
     {"model": model.state_dict(), "optim": optim.state_dict()} (*)


    params:
        data: a dict of the structure (*) storing the model and optim

    return:
        None
    """
    now = time.strftime("%D_%H:%M")
    print(f"saving checkpoint at {now}, path is '{path}'")
    torch.save(dict, path)
    print("saving model... done!")

In [None]:
fake_zebra_buffer = ImageBuffer()
fake_horse_buffer = ImageBuffer()

In [None]:
for epoch in range(NUM_EPOCH):
    loop = tqdm(dataloader, leave=True)
    for batch_idx, (horse, zebra) in enumerate(loop):
        G_H_Z.train()
        G_Z_H.train()

        horse = horse.to(DEVICE)
        zebra = zebra.to(DEVICE)

        #### train Gen ####
        opt_gen.zero_grad()

        # train generator with the lastest generated image
        fake_horse = G_Z_H(zebra)
        d_h_fake = D_H(fake_horse)
        fake_zebra = G_H_Z(horse)
        d_z_fake = D_Z(fake_zebra)
        # gan loss
        gan_loss = gan_criterion(d_z_fake, torch.ones_like((d_z_fake)).to(DEVICE)) + gan_criterion(d_h_fake, torch.ones_like(d_h_fake))

        # identity loss
        identity_zebra = G_H_Z(zebra)
        identity_horse = G_Z_H(horse)
        identity_loss = (identity_criterion(identity_zebra, zebra) + identity_criterion(identity_horse, horse)) * LAMBDA_INDENTITY

        # cycle loss
        cycle_horse = G_Z_H(fake_zebra)
        cycle_zebra = G_H_Z(fake_horse)
        cycle_loss = (cycle_criterion(cycle_horse, horse) + cycle_criterion(cycle_zebra, zebra)) * LAMBDA_CYCLE

        Gen_loss = gan_loss + identity_loss + cycle_loss
        Gen_loss.backward()
        opt_gen.step()

        #### train discriminator ####
        opt_d_h.zero_grad()
        # train discriminator with the image from buffer

        d_h_real = D_H(horse)
        # get image from buffer
        fake_horse_ = fake_horse_buffer.get_image(fake_horse)
        d_h_fake = D_H(fake_horse_.detach())
        d_h_loss = gan_criterion(d_h_real, torch.ones_like(d_h_real).to(DEVICE)) + \
            gan_criterion(d_h_fake, torch.zeros_like(d_h_fake).to(DEVICE))
        d_h_loss.backward()
        opt_d_h.step()

        opt_d_z.zero_grad()
        d_z_real = D_Z(zebra)
        fake_zebra_ = fake_zebra_buffer.get_image(fake_zebra)
        d_z_fake = D_Z(fake_zebra_.detach())
        d_z_loss = gan_criterion(d_z_real, torch.ones_like((d_z_real).to(DEVICE))) + \
            gan_criterion(d_z_fake, torch.zeros_like(d_z_fake).to(DEVICE))
        d_z_loss.backward(retain_graph=True)
        opt_d_z.step()

        if batch_idx % 40 == 0:
          G_Z_H.eval()
          G_H_Z.eval()
          with torch.no_grad():
                # sample gans output
                sample_image(fix_horse.to(DEVICE), fix_zebra.to(DEVICE))

    # saving models
    model_checkpoint['G_H_Z']=G_H_Z.state_dict()
    model_checkpoint['G_Z_H']=G_Z_H.state_dict()
    model_checkpoint['D_Z']=D_Z.state_dict()
    model_checkpoint['D_H']=D_H.state_dict()
    save_checkpoint(model_checkpoint)

In [None]:
!ls ./horse_output

In [None]:
! zip -r zebraoutput.zip zebra_output