# A PyTorch Implementation of [CycleGAN](https://arxiv.org/pdf/1703.10593)

In [1]:
import torch
from torch import nn
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms
from torchvision.utils import make_grid
import os
import random
from typing import Tuple, List
from itertools import chain
import matplotlib.pyplot as plt
from tqdm import tqdm

### Generator

The generator consists of the following modules:
1. a $7 \times 7$ Convolution-InstanceNorm-ReLU layer.
2. two $3 \times 3$ Convolution-InstanceNorm-ReLU layers.
3. six residual blocks, each containing two $3 \times 3$ convolutional layers, for $128 \times 128$ images, or nine residual blocks for $256 \times 256$ images.
4. two $3 \times 3$ fractional-strided-Convolution-InstanceNorm-ReLU layers.

In [2]:
def weights_init_normal(m):
    """Re-initialise the weights of convolution-like modules in place.

    Meant to be handed to ``nn.Module.apply``. A module is selected purely by
    its class name: any class whose name contains ``"Conv"`` has its ``weight``
    redrawn from a normal distribution with mean 0.0 and standard deviation
    0.02. Every other module, and the bias of a matching module, is left
    untouched.

    Args:
        m: The module visited by ``apply``.
    """
    is_conv_like = "Conv" in m.__class__.__name__
    if not is_conv_like:
        return
    torch.nn.init.normal_(m.weight.data, 0.0, 0.02)

In [3]:
class ResidualBlock(nn.Module):
    """Residual unit: two conv stages whose output is added to the input.

    Each stage is a reflect-padded 3x3 convolution, an ``InstanceNorm2d`` and
    an in-place ReLU. The channel count and (because of ``padding=1``) the
    spatial size are unchanged, so the result can be summed with the input.

    Args:
        in_features: Number of input (and output) channels.
    """

    def __init__(self, in_features: int) -> None:
        super().__init__()

        def conv_stage():
            # One Convolution-InstanceNorm-ReLU stage at constant width.
            return [
                nn.Conv2d(in_features, in_features, kernel_size=3, padding=1, padding_mode='reflect'),
                nn.InstanceNorm2d(in_features),
                nn.ReLU(inplace=True),
            ]

        self.block = nn.Sequential(*conv_stage(), *conv_stage())

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Return ``x`` plus the output of the two conv stages applied to ``x``."""
        residual = self.block(x)
        return x + residual

In [4]:
class GeneratorResNet(nn.Module):
    """ResNet-style image-to-image generator.

    Layer order inside ``self.layers``:

    1. a reflect-padded 7x7 convolution to 64 channels (+ InstanceNorm, ReLU);
    2. two stride-2 3x3 convolutions, 64 -> 128 -> 256 channels;
    3. ``n_residual_blocks`` instances of ``ResidualBlock(256)``;
    4. two ``Upsample(scale_factor=2)`` + 3x3 convolution stages,
       256 -> 128 -> 64 channels;
    5. a reflect-padded 7x7 convolution back to ``input_channels``.

    No activation follows the last convolution, so outputs are unbounded.
    After construction every layer is passed through ``weights_init_normal``.

    Args:
        input_channels: Channels of the input image; also the output channels.
        n_residual_blocks: Number of residual blocks in the middle of the net.
    """

    def __init__(self, input_channels: int, n_residual_blocks: int) -> None:
        super().__init__()

        # Stem: full-resolution 7x7 convolution.
        modules = [
            nn.Conv2d(input_channels, 64, kernel_size=7, padding=3, padding_mode='reflect'),
            nn.InstanceNorm2d(64),
            nn.ReLU(inplace=True),
        ]

        # Encoder: each stride-2 convolution doubles the channel count.
        for c_in, c_out in ((64, 128), (128, 256)):
            modules += [
                nn.Conv2d(c_in, c_out, kernel_size=3, stride=2, padding=1),
                nn.InstanceNorm2d(c_out),
                nn.ReLU(inplace=True),
            ]

        # Trunk: residual blocks at 256 channels.
        modules += [ResidualBlock(256) for _ in range(n_residual_blocks)]

        # Decoder: upsample by 2, then a stride-1 convolution halves the channels.
        for c_in, c_out in ((256, 128), (128, 64)):
            modules += [
                nn.Upsample(scale_factor=2),
                nn.Conv2d(c_in, c_out, kernel_size=3, stride=1, padding=1),
                nn.InstanceNorm2d(c_out),
                nn.ReLU(inplace=True),
            ]

        # Head: project back to the image channel count.
        modules.append(
            nn.Conv2d(64, input_channels, kernel_size=7, padding=3, padding_mode='reflect')
        )

        self.layers = nn.Sequential(*modules)
        self.layers.apply(weights_init_normal)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Run ``x`` through the full layer stack and return the result."""
        return self.layers(x)

### Discriminator

The Discriminator consists of the following modules:
1. four $4 \times 4$ Convolution-InstanceNorm-LeakyReLU layers (the first layer omits InstanceNorm).
2. a final convolution that produces a 1-channel output.

In [5]:
class DiscriminatorBlock(nn.Module):
    """Down-sampling discriminator stage.

    A 4x4 convolution with stride 2 and padding 1, optionally followed by
    ``InstanceNorm2d``, then an in-place ``LeakyReLU`` with slope 0.2.

    Args:
        in_channels: Channels of the incoming feature map.
        out_channels: Channels produced by the convolution.
        normalize: When ``False`` the InstanceNorm layer is left out.
    """

    def __init__(self,
                 in_channels: int,
                 out_channels: int,
                 normalize: bool = True) -> None:
        super().__init__()
        conv = nn.Conv2d(in_channels, out_channels, kernel_size=4, stride=2, padding=1)
        if normalize:
            stack = (conv, nn.InstanceNorm2d(out_channels), nn.LeakyReLU(0.2, inplace=True))
        else:
            stack = (conv, nn.LeakyReLU(0.2, inplace=True))
        self.layers = nn.Sequential(*stack)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Apply the convolution, optional norm and activation to ``x``."""
        return self.layers(x)

In [6]:
class Discriminator(nn.Module):
    """Convolutional discriminator that emits a 1-channel score map.

    Four ``DiscriminatorBlock`` stages (the first without normalisation) widen
    the channels ``channels -> 64 -> 128 -> 256 -> 512``; a one-pixel zero pad
    on the left and top and a final 4x4 convolution reduce them to one channel.
    Every sub-module is passed through ``weights_init_normal`` on construction.

    Attributes:
        output_shape: ``(1, height // 16, width // 16)``, the per-sample shape
            the training code uses when building target label tensors.

    Args:
        input_shape: ``(channels, height, width)`` of the input images.
    """

    def __init__(self, input_shape: Tuple[int, int, int]) -> None:
        super().__init__()
        channels, height, width = input_shape
        self.output_shape = (1, height // 16, width // 16)

        # Consecutive pairs of this progression give each block's in/out channels.
        widths = (channels, 64, 128, 256, 512)
        stages = [
            # Only the very first block skips InstanceNorm.
            DiscriminatorBlock(c_in, c_out, normalize=(position > 0))
            for position, (c_in, c_out) in enumerate(zip(widths, widths[1:]))
        ]
        self.layers = nn.Sequential(
            *stages,
            nn.ZeroPad2d((1, 0, 1, 0)),
            nn.Conv2d(512, 1, kernel_size=4, padding=1),
        )
        self.apply(weights_init_normal)

    def forward(self, img: torch.Tensor) -> torch.Tensor:
        """Return the score map produced for ``img``."""
        return self.layers(img)

### Dataset

I use the first 200 images in the monet2photo dataset for training and validation.

In [7]:
class ImageDataset(Dataset):
    """Unpaired two-domain image dataset.

    Images are read from ``../data/<dataset_name>/<mode>A`` and
    ``../data/<dataset_name>/<mode>B``. At most ``max_examples`` file names
    are kept per domain. The names are sorted first, because ``os.listdir``
    returns entries in arbitrary order and an unsorted slice would select a
    different subset from one machine or run to the next.

    Args:
        dataset_name: Sub-directory of ``../data`` holding the dataset.
        transforms_: Transforms composed and applied to every loaded image.
        mode: Directory prefix, e.g. ``'train'`` or ``'test'``.
        max_examples: Upper bound on the number of files used per domain.
    """

    def __init__(self,
                 dataset_name: str,
                 transforms_: List[nn.Module],
                 mode: str,
                 max_examples: int = 200):
        data_dir = os.path.join('..', 'data', dataset_name)
        self.transform = transforms.Compose(transforms_)
        self.path_a = os.path.join(data_dir, f'{mode}A')
        self.path_b = os.path.join(data_dir, f'{mode}B')
        self.files_a = self._list_files(self.path_a, max_examples)
        self.files_b = self._list_files(self.path_b, max_examples)

    @staticmethod
    def _list_files(directory: str, max_examples: int) -> List[str]:
        """Return the first ``max_examples`` entries of ``directory`` in sorted order."""
        return sorted(os.listdir(directory))[:max_examples]

    def _load(self, directory: str, files: List[str], index: int) -> torch.Tensor:
        """Read and transform the image at ``index``, wrapping past the end of ``files``."""
        file_name = files[index % len(files)]
        return self.transform(plt.imread(os.path.join(directory, file_name)))

    def __getitem__(self,
                    index: int) -> Tuple[torch.Tensor, torch.Tensor]:
        """Return one transformed image from domain A and one from domain B.

        The index wraps independently per domain, so the shorter domain is
        cycled when the two domains hold different numbers of files.
        """
        return self._load(self.path_a, self.files_a, index), \
            self._load(self.path_b, self.files_b, index)

    def __len__(self):
        """Return the size of the larger of the two domains."""
        return max(len(self.files_a), len(self.files_b))
        

### ReplayBuffer

To reduce model oscillation, CycleGAN updates the discriminators using a history of generated images rather than only the images produced by the latest generators. To do so, it keeps a buffer that stores 50 previously generated images.

In [8]:
class ReplayBuffer:
    """Bounded pool of previously generated samples.

    Args:
        max_size: Maximum number of samples held in ``self.data``.
    """

    def __init__(self, max_size: int=50):
        self.max_size = max_size
        self.data = []

    def push_and_pop(self, data):
        """Feed a batch into the pool and return a batch of the same length.

        The batch is detached first. For each sample along the first dimension:

        * while the pool is not full, the sample is stored and returned as is;
        * once the pool is full, with probability one half a randomly chosen
          stored sample is returned (as a clone) and replaced by the new one;
          otherwise the new sample is returned and the pool is left unchanged.

        Args:
            data: Batch of generated samples, iterated along dimension 0.

        Returns:
            The selected samples stacked along a new first dimension.
        """
        batch = data.detach()
        selected = []
        for sample in batch:
            # Filling phase: keep everything until the pool reaches capacity.
            if len(self.data) < self.max_size:
                self.data.append(sample)
                selected.append(sample)
                continue
            # Coin flip: pass the fresh sample straight through.
            if random.uniform(0, 1) <= 0.5:
                selected.append(sample)
                continue
            # Otherwise swap it with a random stored sample.
            slot = random.randint(0, self.max_size - 1)
            selected.append(self.data[slot].clone())
            self.data[slot] = sample
        return torch.stack(selected)

### Config

The objective of CycleGAN contains three types of terms: adversarial losses, which match the distribution of generated images to the data distribution in the target domain; cycle-consistency losses, which prevent the learned mappings G and F from contradicting each other; and an identity loss, which encourages the mappings to preserve color composition between the input and the output.

1. For the adversarial loss, CycleGAN replaces the negative log-likelihood objective with a least-squares loss. This loss is more stable during training and generates higher-quality results.
2. The cycle-consistency loss forces the mappings to be cycle-consistent, because the adversarial loss alone does not guarantee that an individual input $x$ and output $y$ are paired up in a meaningful way, and it often leads to mode collapse.
3. Without the identity loss, the generators G and F are free to change the tint of input images when there is no need to.

In [9]:
# ---- Run configuration ------------------------------------------------------
device = 'cuda' if torch.cuda.is_available() else 'cpu'
epochs = 200
dataset_name = 'monet2photo'
batch_size = 1
lr = 0.0002
adam_betas = 0.5, 0.999
decay_start = 100  # number of scheduler steps taken before the LR starts to fall

# ---- Loss functions ---------------------------------------------------------
gan_loss = torch.nn.MSELoss()      # least-squares adversarial loss
cycle_loss = torch.nn.L1Loss()
identity_loss = torch.nn.L1Loss()

# ---- Image geometry ---------------------------------------------------------
img_height = 256
img_width = 256
img_channels = 3
input_shape = (img_channels, img_height, img_width)

# ---- Model and loss hyper-parameters ----------------------------------------
n_residual_blocks = 9
cyclic_loss_coefficient = 10.0
identity_loss_coefficient = 5.

# ---- Networks: two generators (X->Y, Y->X) and one discriminator per domain --
generator_xy = GeneratorResNet(img_channels, n_residual_blocks).to(device)
generator_yx = GeneratorResNet(img_channels, n_residual_blocks).to(device)
discriminator_x = Discriminator(input_shape).to(device)
discriminator_y = Discriminator(input_shape).to(device)

# ---- Optimizers: one shared by both generators, one by both discriminators ---
# NOTE: the spelling `generator_opitimizer` is kept because the training loop
# refers to the optimizer by this exact name.
generator_opitimizer = torch.optim.Adam(
    chain(generator_xy.parameters(), generator_yx.parameters()), lr, adam_betas
)
discriminator_optimizer = torch.optim.Adam(
    chain(discriminator_x.parameters(), discriminator_y.parameters()), lr, adam_betas
)

# ---- Learning-rate schedule -------------------------------------------------
decay_epochs = epochs - decay_start


def _linear_decay(step: int) -> float:
    """Return the LR multiplier for scheduler step ``step``.

    The multiplier stays at 1.0 for the first ``decay_start`` steps and then
    falls linearly, reaching 0.0 at step ``epochs``. The denominator is clamped
    to 1 so that ``epochs == decay_start`` does not divide by zero.
    """
    return 1.0 - max(0, step - decay_start) / max(1, decay_epochs)


generator_lr_scheduler = torch.optim.lr_scheduler.LambdaLR(
    generator_opitimizer, _linear_decay
)
discriminator_lr_scheduler = torch.optim.lr_scheduler.LambdaLR(
    discriminator_optimizer, _linear_decay
)

# ---- Data pipeline ----------------------------------------------------------
transforms_train = [
    transforms.ToTensor(),
    # Enlarge by 12 % so the random crop below has room to move.
    transforms.Resize(int(img_height * 1.12), transforms.InterpolationMode.BICUBIC),
    # Crop to (height, width); the width was previously ignored in favour of
    # the height, which is only correct for square targets.
    transforms.RandomCrop((img_height, img_width)),
    transforms.RandomHorizontalFlip(),
    transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
]

transforms_val = [
    transforms.ToTensor(),
    transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
]

dataloader = DataLoader(
    ImageDataset(dataset_name, transforms_train, 'train'),
    batch_size=batch_size,
    shuffle=True
)


valid_dataloader = DataLoader(
    ImageDataset(dataset_name, transforms_val, 'test'),
    batch_size=5,
    shuffle=True
)

In [10]:
# Directory that receives the sample image grids and the model checkpoints.
output_dir = 'output'
# exist_ok=True replaces the check-then-create pair: there is no window between
# the existence test and the creation, and missing parent directories are made.
os.makedirs(output_dir, exist_ok=True)

### Train

In [11]:
def plot_image(img: torch.Tensor,
               path: str) -> None:
    """Min-max rescale an image tensor and write it to ``path``.

    The tensor is moved to the CPU, shifted and scaled by its own minimum and
    maximum (a small epsilon keeps the division finite for constant images),
    and its first axis is moved last before saving.

    Args:
        img: 3-D image tensor, presumably channel-first ``(C, H, W)``; confirm
            against callers.
        path: Destination file handed to ``plt.imsave``.
    """
    host_img = img.cpu()
    lowest = host_img.min()
    highest = host_img.max()
    rescaled = (host_img - lowest) / (highest - lowest + 1e-5)
    plt.imsave(path, rescaled.permute(1, 2, 0))

In [12]:
def sample_images(n: int):
    """Translate one validation batch in both directions and save a grid.

    A batch is drawn from ``valid_dataloader``, both generators are switched to
    eval mode, and four image grids are stacked top to bottom: real X,
    generated Y, real Y, generated X. The result is written to
    ``<output_dir>/epoch_<n>.jpg`` with ``n`` zero-padded to three digits.

    Args:
        n: Number used in the output file name (the epoch, in the training loop).
    """
    real_x, real_y = next(iter(valid_dataloader))
    for net in (generator_xy, generator_yx):
        net.eval()
    with torch.no_grad():
        real_x = real_x.to(device)
        real_y = real_y.to(device)
        fake_y = generator_xy(real_x)
        fake_x = generator_yx(real_y)

        # Row order of the final picture: X, G(X), Y, F(Y).
        rows = [
            make_grid(images, nrow=5, normalize=True)
            for images in (real_x, fake_y, real_y, fake_x)
        ]
        image_grid = torch.cat(rows, 1)
    plot_image(image_grid, path=f'{output_dir}/epoch_{n:03d}.jpg')

In [13]:
# History buffers: the discriminators are shown generated images drawn from
# these pools instead of only the ones produced in the current step.
gen_x_buffer = ReplayBuffer()
gen_y_buffer = ReplayBuffer()

for epoch in range(1, epochs + 1):
    # Sums of per-sample losses and the number of samples they cover.
    loss_g = loss_d = 0
    count_g = count_d = 0
    # Iterating the loader directly lets tqdm know the number of batches.
    for X, Y in tqdm(dataloader):
        X, Y = X.to(device), Y.to(device)
        n_samples = X.size(0)
        # Targets shaped like the discriminator output: ones for "real", zeros for "fake".
        true_labels = torch.ones(n_samples, *discriminator_x.output_shape, device=device, requires_grad=False)
        false_labels = torch.zeros(n_samples, *discriminator_x.output_shape, device=device, requires_grad=False)

        # Train the generators
        # (train mode is re-enabled every step because sample_images() switches
        # the generators to eval mode).
        generator_xy.train()
        generator_yx.train()

        # Identity term: an image already in the target domain should map to itself.
        loss_identity = identity_loss(generator_yx(X), X) \
            + identity_loss(generator_xy(Y), Y)

        gen_y = generator_xy(X)
        gen_x = generator_yx(Y)

        # Adversarial term: generated images should be scored as real.
        loss_gan = gan_loss(discriminator_y(gen_y), true_labels) \
            + gan_loss(discriminator_x(gen_x), true_labels)

        # Cycle term: translating there and back should reproduce the input.
        loss_cycle = cycle_loss(generator_yx(gen_y), X) \
            + cycle_loss(generator_xy(gen_x), Y)

        loss_generator = loss_gan + cyclic_loss_coefficient * loss_cycle + identity_loss_coefficient * loss_identity

        generator_opitimizer.zero_grad()
        loss_generator.backward()
        generator_opitimizer.step()

        # The criteria return batch means, so weight by the batch size before
        # dividing by the sample count below; otherwise the reported average
        # shrinks by a factor of batch_size whenever batch_size > 1.
        loss_g += loss_generator.item() * n_samples
        count_g += n_samples

        # Train the discriminators
        # Real images are pushed towards 1; buffered (detached) fakes towards 0.
        loss_discriminator = (gan_loss(discriminator_x(X), true_labels) +
                              gan_loss(discriminator_x(gen_x_buffer.push_and_pop(gen_x)), false_labels) +
                              gan_loss(discriminator_y(Y), true_labels) +
                              gan_loss(discriminator_y(gen_y_buffer.push_and_pop(gen_y)), false_labels))

        discriminator_optimizer.zero_grad()
        loss_discriminator.backward()
        discriminator_optimizer.step()

        loss_d += loss_discriminator.item() * n_samples
        count_d += n_samples

    # Schedulers advance once per epoch.
    generator_lr_scheduler.step()
    discriminator_lr_scheduler.step()

    print(f'\nEpoch {epoch:03d}, Generator Loss: {loss_g / count_g:.4f}, Discriminator Loss: {loss_d / count_d:.4f}')
    if epoch % 25 == 0:
        sample_images(epoch)

  img = torch.from_numpy(pic.transpose((2, 0, 1))).contiguous()
25it [00:25,  1.03s/it]


KeyboardInterrupt: 

In [None]:
# Persist the weights of all four networks, one state-dict file per network.
checkpoints = (
    (generator_xy, f'{output_dir}/generator_xy.pth'),
    (generator_yx, f'{output_dir}/generator_yx.pth'),
    (discriminator_x, f'{output_dir}/discriminator_x.pth'),
    (discriminator_y, f'{output_dir}/discriminator_y.pth'),
)
for network, target_path in checkpoints:
    torch.save(network.state_dict(), target_path)

## Reference
1. https://nn.labml.ai/gan/cycle_gan/index.html
2. [Official Implementation](https://github.com/junyanz/CycleGAN)