# Реализация CycleGAN: генератор на остаточных (ResNet) блоках и дискриминатор PatchGAN

## Imports

In [None]:
import torch
import torchvision as tv
from torch.utils.data import Dataset, DataLoader
from torchvision.utils import make_grid
import torchvision.utils as vutils
import torch.backends.cudnn as cudnn

import os
import numpy as np
from PIL import Image
import itertools
import random
import glob
import time
import errno
import shutil
import sys

from matplotlib import pyplot as plt
from math import sqrt
%matplotlib inline

## Глобальные параметры

In [None]:
# Compute device: use CUDA when PyTorch reports it as available, else the CPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# Echo the selected device as the cell output.
device

In [None]:
# Global configuration shared by every cell of the notebook.
PARAMS = {
    # --- data locations ---
    "data_dir": "../input/gan-getting-started",  # root directory of the dataset
    "datasetA": "monet_jpg",                     # sub-folder holding domain A
    "datasetB": "photo_jpg",                     # sub-folder holding domain B
    "out_dir": "output",                         # directory for run outputs
    # --- optimisation ---
    "epochs": 200,           # number of training epochs
    "batch_size": 5,         # batch size used for training
    "test_batch_size": 4,    # batch size used for inference
    "lr": 0.0002,            # base learning rate
    # --- model / image geometry ---
    "img_size": 256,         # image side length in pixels
    "channels": 3,           # number of image channels
    "num_blocks": 9,         # number of residual blocks in each generator
    # --- bookkeeping ---
    "log_interval": 20,      # number of batches between log lines
    "seed": 1,               # random seed (None disables seeding)
}

## Utils

In [None]:
def clear_folder(folder_path):
    """Remove everything inside ``folder_path``, creating the folder if needed.

    The folder itself is kept. Regular files and symbolic links are
    unlinked; real sub-directories are removed recursively. A failure on an
    individual entry is reported with ``print`` and does not stop the
    cleanup of the remaining entries.

    Args:
        folder_path: Path of the directory to empty.
    """
    # Make sure the directory exists (it may have been accidentally deleted)
    # so that listing it below cannot fail for that reason.
    os.makedirs(folder_path, exist_ok=True)
    for entry_name in os.listdir(folder_path):
        entry_path = os.path.join(folder_path, entry_name)
        try:
            # Symlinks are tested first: a link to a directory satisfies
            # os.path.isdir(), but shutil.rmtree() refuses to operate on it,
            # and a dangling link satisfies neither isfile() nor isdir().
            if os.path.islink(entry_path) or os.path.isfile(entry_path):
                os.unlink(entry_path)
            elif os.path.isdir(entry_path):
                shutil.rmtree(entry_path)
        except OSError as err:
            print(err)

def create_folder(folder_path):
    """Create ``folder_path`` (including missing parents) unless it exists.

    An "already exists" error from ``os.makedirs`` is ignored; every other
    ``OSError`` propagates to the caller.
    """
    try:
        os.makedirs(folder_path)
    except FileExistsError:
        # EEXIST: the path is already there, nothing to do.
        pass


## ImageBuffer

As suggested in the CycleGAN paper, the discriminators can be updated with an
image picked at random from a history of previously generated images, rather than
only with the fake samples produced in the current step.
The history of generated images is maintained by the `ImageBuffer` class defined
below. (In this notebook the calls that use the buffer are commented out in the
training loop.)

In [None]:
class ImageBuffer(object):
    """Bounded history of generated images with random replay.

    The buffer stores whatever objects are passed to :meth:`update`; it does
    not copy or detach them.

    Args:
        depth: Maximum number of stored items. A value of ``0`` (or less)
            disables the history: :meth:`update` then returns its argument
            unchanged and stores nothing.
    """

    def __init__(self, depth=50):
        self.depth = depth
        self.buffer = []

    def update(self, image):
        """Insert ``image`` into the history and return an item to train on.

        While the buffer is not full the image is appended; once it is full
        the image overwrites a randomly chosen slot. Afterwards, with
        probability 0.5 a random item of the buffer is returned, otherwise
        ``image`` itself is returned.

        Args:
            image: The newly generated item.

        Returns:
            Either ``image`` or an item currently held in the buffer.
        """
        if self.depth <= 0:
            # History disabled: avoids random.randint(0, -1) on an empty range.
            return image

        if len(self.buffer) >= self.depth:
            slot = random.randint(0, self.depth - 1)
            self.buffer[slot] = image
        else:
            self.buffer.append(image)

        if random.uniform(0, 1) > 0.5:
            return self.buffer[random.randint(0, len(self.buffer) - 1)]
        return image

## Images Dataset
A custom dataset reader that picks up unpaired images from separate
folders.

In [None]:
class ImageDataset(Dataset):
    """Dataset yielding one image of domain A and one of domain B per item.

    File lists are read from ``PARAMS['data_dir']/PARAMS['datasetA']`` and
    ``PARAMS['data_dir']/PARAMS['datasetB']`` (every file matching ``*.*``,
    sorted by path).

    Args:
        transform: Sequence of transforms that is wrapped in
            ``tv.transforms.Compose``, or an already callable transform that
            is used as is. ``None`` means "no transform".
        unaligned: If true, the domain-B image is drawn at random for each
            item; otherwise it is selected by index like the domain-A image.
        batch_size: If greater than 1, the reported length is rounded down
            to a multiple of this value.
        max_size: If greater than 0, upper bound for the reported length.

    Raises:
        FileNotFoundError: If either folder contains no matching files.
    """

    def __init__(self, transform=None, unaligned=True,
                 batch_size=0, max_size=0):
        dir_A = f'{PARAMS["data_dir"]}/{PARAMS["datasetA"]}'
        dir_B = f'{PARAMS["data_dir"]}/{PARAMS["datasetB"]}'
        if transform is None:
            # Compose(None) would only fail later, on the first item.
            transform = []
        if callable(transform):
            self.transform = transform
        else:
            self.transform = tv.transforms.Compose(transform)
        self.unaligned = unaligned
        self.files_A = sorted(glob.glob(dir_A + '/*.*'))
        self.files_B = sorted(glob.glob(dir_B + '/*.*'))
        for folder, found in ((dir_A, self.files_A), (dir_B, self.files_B)):
            if not found:
                # Fail early instead of with a modulo-by-zero in __getitem__.
                raise FileNotFoundError(f'No image files found in {folder}')
        self.batch_size = batch_size
        self.max_size = max_size

    def _load(self, path):
        """Open ``path``, force 3-channel RGB and apply the transform."""
        # The context manager closes the file handle; convert('RGB') makes
        # grayscale / RGBA / palette files compatible with a 3-channel
        # normalisation.
        with Image.open(path) as img:
            return self.transform(img.convert('RGB'))

    def __getitem__(self, index):
        """Return ``{'datasetA': ..., 'datasetB': ...}`` for ``index``.

        Indices wrap around the shorter file list via modulo.
        """
        item_A = self._load(self.files_A[index % len(self.files_A)])

        if self.unaligned:
            path_B = self.files_B[random.randint(0, len(self.files_B) - 1)]
        else:
            path_B = self.files_B[index % len(self.files_B)]
        item_B = self._load(path_B)

        return {'datasetA': item_A, 'datasetB': item_B}

    def __len__(self):
        """Length of the longer file list, capped and batch-aligned."""
        res = max(len(self.files_A), len(self.files_B))
        if self.max_size > 0:
            res = min(res, self.max_size)
        # Trim to a multiple of the batch size (when given) so that every
        # batch has the same number of items.
        if self.batch_size > 1:
            res = res - res % self.batch_size
        return res

## Image Transform

In [None]:
# Per-image preprocessing pipeline (a plain list; ImageDataset composes it).
# A resize + random-crop augmentation is kept for reference but disabled:
#     tv.transforms.Resize(int(PARAMS['img_size']*1.12), Image.BICUBIC),
#     tv.transforms.RandomCrop((PARAMS['img_size'], PARAMS['img_size'])),
transform = [
    tv.transforms.RandomHorizontalFlip(),
    tv.transforms.ToTensor(),
    # (x - 0.5) / 0.5 per channel.
    tv.transforms.Normalize(mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5)),
]


## DataLoaders

In [None]:
# Each dataset is trimmed (see ImageDataset.__len__) to a multiple of the
# batch size of the loader that consumes it, so every delivered batch is full.
train_loader = DataLoader(
    ImageDataset(transform=transform,
                 batch_size=PARAMS['batch_size'],
                 max_size=350),
    batch_size=PARAMS['batch_size'],
    shuffle=True,
    num_workers=2,
    pin_memory=True,
)

# The test dataset must be aligned to the *test* batch size, not to the
# training batch size.
test_loader = DataLoader(
    ImageDataset(transform=transform,
                 batch_size=PARAMS['test_batch_size'],
                 max_size=300),
    batch_size=PARAMS['test_batch_size'],
    shuffle=True,
    num_workers=2,
)

## Визуализация изображений

In [None]:
# Fetch a single training batch and split it by domain.
data = next(iter(train_loader))
imgs_x = data['datasetA']
imgs_y = data['datasetB']

# 2x2 grid: top row shows domain A, bottom row shows domain B.
fig, axis = plt.subplots(2, 2, figsize=(10, 10))

for row, (batch, label) in enumerate(((imgs_x, 'Monet'), (imgs_y, 'Photo'))):
    for i in range(2):
        ax = axis[row, i]
        ax.axis('off')
        # Undo the (x - 0.5) / 0.5 normalisation and move channels last.
        ax.imshow(np.transpose((batch[i]+1)/2, (1, 2, 0)))
        ax.set_title(f'{label} {i+1}')

## Generator

In [None]:
class ResidualBlock(torch.nn.Module):
    """Two 3x3 convolutions with instance norm, wrapped in a skip connection.

    Reflection padding of 1 before each 3x3 convolution keeps the spatial
    size, so the block output has the same shape as its input.

    Args:
        channels: Number of input and output channels.
    """

    def __init__(self, channels):
        super().__init__()
        nn = torch.nn
        self.block = nn.Sequential(
            nn.ReflectionPad2d(1),
            nn.Conv2d(channels, channels, 3),
            nn.InstanceNorm2d(channels),
            nn.ReLU(inplace=True),
            nn.ReflectionPad2d(1),
            nn.Conv2d(channels, channels, 3),
            nn.InstanceNorm2d(channels),
        )

    def forward(self, x):
        """Return ``x`` plus the output of the convolutional branch."""
        residual = self.block(x)
        return x + residual

In [None]:
class Generator(torch.nn.Module):
    """Image-to-image generator: conv stem, 2x downsampling, residual
    blocks, 2x upsampling and a Tanh output layer.

    Args:
        channels: Number of channels of both the input and the output image.
        num_blocks: Number of ``ResidualBlock(256)`` modules in the middle.
    """

    def __init__(self, channels, num_blocks=9):
        super().__init__()
        self.channels = channels

        # 7x7 stem; reflection padding of 3 keeps the spatial size.
        stem = [torch.nn.ReflectionPad2d(3)]
        stem.extend(self._create_layer(self.channels, 64, 7, stride=1, padding=0))

        # Two stride-2 convolutions.
        encoder = []
        for size_in, size_out in ((64, 128), (128, 256)):
            encoder.extend(self._create_layer(size_in, size_out, 3, stride=2, padding=1))

        bottleneck = [ResidualBlock(256) for _ in range(num_blocks)]

        # Two stride-2 transposed convolutions mirror the encoder.
        decoder = []
        for size_in, size_out in ((256, 128), (128, 64)):
            decoder.extend(self._create_layer(size_in, size_out, 3, stride=2,
                                              padding=1, transposed=True))

        # 7x7 projection back to image channels, squashed by Tanh.
        head = [
            torch.nn.ReflectionPad2d(3),
            torch.nn.Conv2d(64, self.channels, 7),
            torch.nn.Tanh(),
        ]

        self.model = torch.nn.Sequential(*stem, *encoder, *bottleneck, *decoder, *head)

    def _create_layer(self, size_in, size_out, kernel_size, stride=2, padding=1, transposed=False):
        """Return ``[conv, InstanceNorm2d, ReLU]`` as a list of modules.

        ``conv`` is a ``ConvTranspose2d`` (with ``output_padding=1``) when
        ``transposed`` is true and a ``Conv2d`` otherwise.
        """
        if transposed:
            conv = torch.nn.ConvTranspose2d(size_in, size_out, kernel_size,
                                            stride=stride, padding=padding,
                                            output_padding=1)
        else:
            conv = torch.nn.Conv2d(size_in, size_out, kernel_size,
                                   stride=stride, padding=padding)
        return [conv, torch.nn.InstanceNorm2d(size_out), torch.nn.ReLU(inplace=True)]

    def forward(self, x):
        """Run ``x`` through the sequential model."""
        return self.model(x)


## Discriminator

PatchGAN architecture

In [None]:
class Discriminator(torch.nn.Module):
    """Fully convolutional discriminator producing a map of per-patch scores.

    Four 4x4 convolution stages (strides 2, 2, 2, 1) are followed by a final
    4x4 convolution down to a single channel.

    Args:
        channels: Number of channels of the input image.
    """

    def __init__(self, channels=3):
        super().__init__()
        self.channels = channels

        # (size_in, size_out, stride, normalize) for each stage; the first
        # stage has no normalisation.
        stage_specs = (
            (self.channels, 64, 2, False),
            (64, 128, 2, True),
            (128, 256, 2, True),
            (256, 512, 1, True),
        )
        layers = []
        for size_in, size_out, stride, normalize in stage_specs:
            layers.extend(self._create_layer(size_in, size_out, stride, normalize=normalize))
        layers.append(torch.nn.Conv2d(512, 1, 4, stride=1, padding=1))

        self.model = torch.nn.Sequential(*layers)

    def _create_layer(self, size_in, size_out, stride, normalize=True):
        """Return ``[Conv2d(4x4), (InstanceNorm2d), LeakyReLU(0.2)]``."""
        conv = torch.nn.Conv2d(size_in, size_out, 4, stride=stride, padding=1)
        activation = torch.nn.LeakyReLU(0.2, inplace=True)
        if not normalize:
            return [conv, activation]
        return [conv, torch.nn.InstanceNorm2d(size_out), activation]

    def forward(self, x):
        """Return the raw (un-squashed) score map for ``x``."""
        return self.model(x)

## Model training and evaluation

In [None]:
# Function to initialize the weights
def _weights_init(m):
    classname = m.__class__.__name__
    if classname.find('Conv') != -1:
        torch.nn.init.normal_(m.weight.data, 0.0, 0.02)
    elif classname.find('BatchNorm') != -1:
        torch.nn.init.normal_(m.weight.data, 1.0, 0.02)
        torch.nn.init.constant_(m.bias.data, 0.0)

In [None]:
# Trainer
class Model(object):
    """CycleGAN trainer owning two generators and two discriminators.

    ``netG_AB`` translates domain A to domain B and ``netG_BA`` translates B
    to A; ``netD_A`` and ``netD_B`` score images of domain A and B. The
    adversarial terms use ``MSELoss``; the cycle-consistency and identity
    terms use ``L1Loss``.

    Args:
        device: Device all four networks are moved to.
        data_loader: Loader yielding dicts with keys ``'datasetA'`` and
            ``'datasetB'`` for training.
        test_data_loader: Loader with the same item layout, used by
            :meth:`eval`.
        channels: Number of image channels.
        img_size: Image side length (stored, not otherwise used here).
        num_blocks: Number of residual blocks in each generator.
    """

    def __init__(self,
                 device,
                 data_loader,
                 test_data_loader,
                 channels,
                 img_size,
                 num_blocks):
        self.name = 'CycleGAN'
        self.device = device
        self.data_loader = data_loader
        self.test_data_loader = test_data_loader
        self.channels = channels
        self.img_size = img_size
        self.num_blocks = num_blocks
        self.netG_AB = self._prepare(Generator(self.channels, self.num_blocks))
        self.netG_BA = self._prepare(Generator(self.channels, self.num_blocks))
        self.netD_A = self._prepare(Discriminator(self.channels))
        self.netD_B = self._prepare(Discriminator(self.channels))

        # Optimizers are created lazily by create_optim().
        self.optim_G = None
        self.optim_D_A = None
        self.optim_D_B = None

        self.loss_adv = torch.nn.MSELoss()
        self.loss_cyc = torch.nn.L1Loss()
        self.loss_iden = torch.nn.L1Loss()

    def _prepare(self, net):
        """Apply the weight initialisation to ``net`` and move it to the device."""
        net.apply(_weights_init)
        net.to(self.device)
        return net

    def _named_nets(self):
        """Return ``(file suffix, network)`` pairs used for checkpoints."""
        return (('G_AB', self.netG_AB),
                ('G_BA', self.netG_BA),
                ('D_A', self.netD_A),
                ('D_B', self.netD_B))

    def _adv_loss(self, prediction, is_real):
        """Adversarial loss of ``prediction`` against an all-ones/zeros target.

        The target is built with the shape of ``prediction`` so the loss works
        for any image size and for a smaller final batch.
        """
        if is_real:
            target = torch.ones_like(prediction)
        else:
            target = torch.zeros_like(prediction)
        return self.loss_adv(prediction, target)

    @property
    def generator_AB(self):
        return self.netG_AB

    @property
    def generator_BA(self):
        return self.netG_BA

    @property
    def discriminator_A(self):
        return self.netD_A

    @property
    def discriminator_B(self):
        return self.netD_B

    def create_optim(self, lr, alpha=0.5, beta=0.999):
        """Create the Adam optimizers.

        Both generators share one optimizer with learning rate ``lr``; each
        discriminator gets its own optimizer with learning rate ``lr / 2``.

        Args:
            lr: Generator learning rate.
            alpha: First Adam beta coefficient.
            beta: Second Adam beta coefficient.
        """
        self.optim_G = torch.optim.Adam(itertools.chain(self.netG_AB.parameters(),
                                                        self.netG_BA.parameters()),
                                        lr=lr,
                                        betas=(alpha, beta))
        self.optim_D_A = torch.optim.Adam(self.netD_A.parameters(),
                                          lr=lr/2,
                                          betas=(alpha, beta))
        self.optim_D_B = torch.optim.Adam(self.netD_B.parameters(),
                                          lr=lr/2,
                                          betas=(alpha, beta))

    def train(self,
              epochs,
              log_interval=100,
              out_dir='',
              verbose=True,
              decay_epoch=110):
        """Train all four networks.

        Per batch the generators are updated first (adversarial + 10 * cycle
        + 5 * identity loss), then each discriminator on real images and on
        the detached fakes of the same step. The learning rate is constant
        until ``decay_epoch`` and then decays linearly to zero at ``epochs``.

        Args:
            epochs: Number of passes over ``data_loader``.
            log_interval: A log line is printed every this many batches
                (never for batch 0).
            out_dir: Currently unused; kept for interface compatibility.
            verbose: Enables the per-interval and final log lines.
            decay_epoch: Epoch at which the linear learning-rate decay starts.

        Raises:
            RuntimeError: If :meth:`create_optim` has not been called.
        """
        if None in (self.optim_G, self.optim_D_A, self.optim_D_B):
            raise RuntimeError('create_optim() must be called before train()')

        for _, net in self._named_nets():
            net.train()
        lambda_cyc = 10
        lambda_iden = 5
        total_time = time.time()
        # NOTE: the ImageBuffer history is not used here; the discriminators
        # are trained on the fakes generated in the current step.

        # Linear learning-rate decay. The span is clamped to at least one
        # epoch so that epochs <= decay_epoch does not divide by zero.
        decay_span = max(1, epochs - decay_epoch)

        def lambda_func(epoch):
            return 1 - max(0, epoch - decay_epoch) / decay_span

        lr_scheduler_G = torch.optim.lr_scheduler.LambdaLR(self.optim_G, lr_lambda=lambda_func)
        lr_scheduler_D_A = torch.optim.lr_scheduler.LambdaLR(self.optim_D_A, lr_lambda=lambda_func)
        lr_scheduler_D_B = torch.optim.lr_scheduler.LambdaLR(self.optim_D_B, lr_lambda=lambda_func)

        for epoch in range(epochs):
            batch_time = time.time()
            print(f'Epoch: {epoch}, lr = {lr_scheduler_G.get_last_lr()}')
            for batch_idx, data in enumerate(self.data_loader):
                real_A = data['datasetA'].to(self.device)
                real_B = data['datasetB'].to(self.device)

                # ---- Train G ----
                self.optim_G.zero_grad()

                # Adversarial loss: each generator tries to make the
                # discriminator of its target domain output "real".
                fake_B = self.netG_AB(real_A)
                _loss_adv_AB = self._adv_loss(self.netD_B(fake_B), True)
                fake_A = self.netG_BA(real_B)
                _loss_adv_BA = self._adv_loss(self.netD_A(fake_A), True)
                adv_loss = (_loss_adv_AB + _loss_adv_BA) / 2

                # Cycle loss: A -> B -> A and B -> A -> B must reproduce the input.
                recov_A = self.netG_BA(fake_B)
                _loss_cyc_A = self.loss_cyc(recov_A, real_A)
                recov_B = self.netG_AB(fake_A)
                _loss_cyc_B = self.loss_cyc(recov_B, real_B)
                cycle_loss = (_loss_cyc_A + _loss_cyc_B) / 2

                # Identity loss: a generator fed an image of its own target
                # domain should return it unchanged.
                _loss_iden_A = self.loss_iden(self.netG_BA(real_A), real_A)
                _loss_iden_B = self.loss_iden(self.netG_AB(real_B), real_B)
                iden_loss = (_loss_iden_A + _loss_iden_B) / 2

                g_loss = adv_loss + lambda_cyc * cycle_loss + lambda_iden * iden_loss
                g_loss.backward()
                self.optim_G.step()

                # ---- Train D_A ----
                self.optim_D_A.zero_grad()

                _loss_real = self._adv_loss(self.netD_A(real_A), True)
                # detach(): no gradient flows back into the generator.
                _loss_fake = self._adv_loss(self.netD_A(fake_A.detach()), False)
                d_loss_A = (_loss_real + _loss_fake) / 2

                d_loss_A.backward()
                self.optim_D_A.step()

                # ---- Train D_B ----
                self.optim_D_B.zero_grad()

                _loss_real = self._adv_loss(self.netD_B(real_B), True)
                _loss_fake = self._adv_loss(self.netD_B(fake_B.detach()), False)
                d_loss_B = (_loss_real + _loss_fake) / 2

                d_loss_B.backward()
                self.optim_D_B.step()

                if verbose and batch_idx % log_interval == 0 and batch_idx > 0:
                    d_loss = (d_loss_A + d_loss_B) / 2
                    print('Epoch {} [{}/{}] loss_D: {:.4f} loss_G: {:.4f} time: {:.2f}'.format(
                          epoch, batch_idx, len(self.data_loader),
                          d_loss.item(),
                          g_loss.item(),
                          time.time() - batch_time))
                    batch_time = time.time()
            lr_scheduler_G.step()
            lr_scheduler_D_A.step()
            lr_scheduler_D_B.step()

        if verbose:
            print('Total train time: {:.2f}'.format(time.time() - total_time))

    def eval(self,
             batch_size=None):
        """Translate every test batch and save comparison grids.

        For each batch of ``test_data_loader`` the file ``img_<batch index>.png``
        is written to the current working directory. It contains, stacked in
        this order: real A, fake B, real B, fake A. All networks are switched
        to evaluation mode and left in it.

        Args:
            batch_size: Number of images per grid row; defaults to the batch
                size of ``test_data_loader``.
        """
        for _, net in self._named_nets():
            net.eval()
        if batch_size is None:
            batch_size = self.test_data_loader.batch_size

        with torch.no_grad():
            for batch_idx, data in enumerate(self.test_data_loader):
                # Keys match what the dataset emits and what train() reads.
                _real_A = data['datasetA'].to(self.device)
                _fake_B = self.netG_AB(_real_A)
                _real_B = data['datasetB'].to(self.device)
                _fake_A = self.netG_BA(_real_B)
                viz_sample = torch.cat((_real_A, _fake_B, _real_B, _fake_A), 0)
                vutils.save_image(viz_sample,
                                  'img_{}.png'.format(batch_idx),
                                  nrow=batch_size,
                                  normalize=True)

    def save_to(self,
                path='',
                name=None,
                verbose=True):
        """Save the four state dicts as ``<name>_{G_AB,G_BA,D_A,D_B}.pt``.

        Args:
            path: Directory the files are written to.
            name: File name prefix; defaults to ``self.name``.
            verbose: Print a short message before saving.
        """
        if name is None:
            name = self.name
        if verbose:
            print('\nSaving models to {}_G_AB.pt and such ...'.format(name))
        for suffix, net in self._named_nets():
            torch.save(net.state_dict(), os.path.join(path, '{}_{}.pt'.format(name, suffix)))

    def load_from(self,
                  path='',
                  name=None,
                  verbose=True):
        """Load the four networks from ``<name>_{G_AB,G_BA,D_A,D_B}.pt``.

        Each file may hold either a bare state dict or a dict with a
        ``'state_dict'`` entry. Tensors are mapped onto ``self.device`` so a
        checkpoint written on a GPU can be restored on a CPU-only machine.

        Args:
            path: Directory the files are read from.
            name: File name prefix; defaults to ``self.name``.
            verbose: Print a short message before loading.
        """
        if name is None:
            name = self.name
        if verbose:
            print('\nLoading models from {}_G_AB.pt and such ...'.format(name))
        for suffix, net in self._named_nets():
            ckpt = torch.load(os.path.join(path, '{}_{}.pt'.format(name, suffix)),
                              map_location=self.device)
            if isinstance(ckpt, dict) and 'state_dict' in ckpt:
                ckpt = ckpt['state_dict']
            net.load_state_dict(ckpt, strict=True)


## Running model

In [None]:
PARAMS['cuda'] = torch.cuda.is_available()

if PARAMS['seed'] is not None:
    # Seed every RNG this notebook draws from. Python's `random` module is
    # used directly (random.randint / random.uniform in ImageDataset and
    # ImageBuffer), so it has to be seeded alongside NumPy and PyTorch.
    random.seed(PARAMS['seed'])
    np.random.seed(PARAMS['seed'])
    torch.manual_seed(PARAMS['seed'])
    if PARAMS['cuda']:
        torch.cuda.manual_seed(PARAMS['seed'])

# NOTE(review): cuDNN benchmark mode picks convolution algorithms at run time
# and presumably trades exact run-to-run reproducibility for speed — confirm
# this is acceptable together with the fixed seed above.
cudnn.benchmark = True

# Start every run from an empty output directory.
clear_folder(PARAMS['out_dir'])

log_file = os.path.join(PARAMS['out_dir'], 'log.txt')
print("Logging to {}\n".format(log_file))
# sys.stdout = StdOut(log_file)

print("PyTorch version: {}".format(torch.__version__))
print("CUDA version: {}\n".format(torch.version.cuda))

# Dump the effective configuration.
for key, value in PARAMS.items():
    print(f'{key} = {value}')

In [None]:
# Build the trainer on the selected device and attach its optimizers.
print('Creating model...\n')
model = Model(
    device=device,
    data_loader=train_loader,
    test_data_loader=test_loader,
    channels=PARAMS['channels'],
    img_size=PARAMS['img_size'],
    num_blocks=PARAMS['num_blocks'],
)
model.create_optim(lr=PARAMS['lr'])

In [None]:
%%time
# Train
model.train(PARAMS['epochs'], PARAMS['log_interval'], PARAMS['out_dir'], True)

# model.save_to('')

## Визуализация результатов

In [None]:
from torch.functional import Tensor
def sample_images(real_A, real_B, model, figside=5):
    """Translate two batches in both directions and plot the results.

    The figure has four rows, top to bottom: ``real_A``, the translation of
    ``real_A`` by ``model.generator_AB``, ``real_B``, and the translation of
    ``real_B`` by ``model.generator_BA``. Both generators are switched to
    evaluation mode and left in it.

    Args:
        real_A: Batch of domain-A images.
        real_B: Batch of domain-B images; must have the same size as ``real_A``.
        model: Object exposing ``generator_AB`` and ``generator_BA`` modules.
        figside: Size of one image cell in the figure, in inches.
    """
    assert real_A.size() == real_B.size(), 'The image size for two domains must be the same'

    netG = model.generator_AB
    netF = model.generator_BA

    netG.eval()
    netF.eval()

    # Run on whatever device the generator lives on instead of assuming CUDA,
    # so the function also works on CPU-only machines.
    target_device = next(netG.parameters()).device

    # Inference only: no autograd graph is needed, and the results can be
    # converted for plotting without detaching.
    with torch.no_grad():
        real_A = real_A.to(device=target_device, dtype=torch.float32)
        fake_B = netG(real_A)
        real_B = real_B.to(device=target_device, dtype=torch.float32)
        fake_A = netF(real_B)

    # One grid row per tensor, one column per image of the batch.
    nrows = real_A.size(0)
    real_A = make_grid(real_A, nrow=nrows, normalize=True)
    fake_B = make_grid(fake_B, nrow=nrows, normalize=True)
    real_B = make_grid(real_B, nrow=nrows, normalize=True)
    fake_A = make_grid(fake_A, nrow=nrows, normalize=True)

    # Stack the four grids vertically and move channels last for imshow.
    image_grid = torch.cat((real_A, fake_B, real_B, fake_A), 1).cpu().permute(1, 2, 0)

    plt.figure(figsize=(figside*nrows, figside*4))
    plt.imshow(image_grid)
    plt.axis('off')
    plt.show()

In [None]:
# Take one batch from the test loader and plot its translations.
data = next(iter(test_loader))
real_A = data['datasetA']
real_B = data['datasetB']
sample_images(real_A=real_A, real_B=real_B, model=model)

## Generate Images

In [None]:
# Full paths of all entries in the domain-B folder.
photo_dir = os.path.join(PARAMS['data_dir'], PARAMS['datasetB'])
files = list(map(lambda entry: os.path.join(photo_dir, entry), os.listdir(photo_dir)))

# Cell output: number of files found.
len(files)

In [None]:
# Directory that receives the generated images.
save_dir = '../images'
# exist_ok avoids the check-then-create race of a separate os.path.exists().
os.makedirs(save_dir, exist_ok=True)

In [None]:
# Deterministic preprocessing for generation (no random flip).
generate_transforms = tv.transforms.Compose([
    tv.transforms.ToTensor(),
    tv.transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
])

to_image = tv.transforms.ToPILImage()
batch_size = PARAMS['batch_size']
model.generator_BA.eval()

# Inference only: without no_grad() every batch would build an autograd graph.
with torch.no_grad():
    for i in range(0, len(files), batch_size):
        batch_files = files[i:i + batch_size]

        # read images
        imgs = []
        for path in batch_files:
            # Close each file after reading and force 3 channels so that
            # the 3-channel normalisation always applies.
            with Image.open(path) as img:
                imgs.append(generate_transforms(img.convert('RGB')))
        imgs = torch.stack(imgs, 0).float()

        # generate
        fake_imgs = model.generator_BA(imgs.to(device)).cpu()

        # save
        for path, fake in zip(batch_files, fake_imgs):
            img_arr = fake.permute(1, 2, 0).numpy()
            low, high = np.min(img_arr), np.max(img_arr)
            if high > low:
                # Per-image min-max stretch to the full 0..255 range.
                img_arr = (img_arr - low) * 255 / (high - low)
            else:
                # Constant image: min-max scaling would divide by zero, so
                # map the generator's Tanh range [-1, 1] to [0, 255] instead.
                img_arr = np.clip((img_arr + 1) * 127.5, 0, 255)
            img_arr = img_arr.astype(np.uint8)

            img = to_image(img_arr)
            # Keep the source file name for the generated image.
            name = os.path.basename(path)
            img.save(os.path.join(save_dir, name))

In [None]:
# Zip the directory the images were actually written to (save_dir) rather
# than a second hard-coded copy of its path; the archive is still created at
# /kaggle/working/images.zip.
shutil.make_archive("/kaggle/working/images", 'zip', save_dir)