Соревнование:
https://www.kaggle.com/c/gan-getting-started

Ноутбуки, ~~с которых я частично украл код~~ которыми я вдохновлялся:
* https://www.kaggle.com/code/nachiket273/cyclegan-pytorch/notebook
* https://www.kaggle.com/code/amyjang/monet-cyclegan-tutorial


Статейки и полезные ссылки:
* https://arxiv.org/pdf/1703.10593.pdf
* https://neptune.ai/blog/6-gan-architectures/amp
* https://www.tensorflow.org/tutorials/generative/cyclegan
* https://towardsdatascience.com/overview-of-cyclegan-architecture-and-training-afee31612a2f
* https://towardsdatascience.com/cyclegan-learning-to-translate-images-without-paired-training-data-5b4e93862c8d
* https://arxiv.org/pdf/1607.08022.pdf

In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader

from PIL import Image
import matplotlib.pyplot as plt
%matplotlib inline

import torchvision.transforms as transforms
import torchvision.models as models

from tqdm.autonotebook import tqdm
from PIL import Image
import os
import itertools

import copy
import numpy as np

In [None]:
# Seed every RNG the notebook may touch so runs are reproducible.
# (np.random.seed used to be called twice; the second call was presumably meant for
# Python's own `random` module, which was never seeded.)
import random

SEED = 42
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)
# Seeds every visible CUDA device, not only the current one.
torch.cuda.manual_seed_all(SEED)

In [None]:
# Run on the GPU when CUDA is available, otherwise fall back to the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
print(device)

In [None]:
# Mount Google Drive (Colab only): the dataset archive and the checkpoints live there.
import os
from google.colab import drive

drive.mount("/content/gdrive/")

In [None]:
# Работал на двух гугл аккаунтах, поэтому такая двойственность с путями:
# !ls /content/gdrive/MyDrive/'Colab Notebooks'/'Tinkoff DL'/'Занятие 6'/
!ls /content/gdrive/MyDrive/'Colab Notebooks'/

In [None]:
# !unzip -q /content/gdrive/MyDrive/'Colab Notebooks'/'Tinkoff DL'/'Занятие 6'/gan-getting-started.zip
!unzip -q /content/gdrive/MyDrive/'Colab Notebooks'/gan-getting-started.zip

In [None]:
def unnorm(img, mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5)):
    """Undo ``transforms.Normalize(mean, std)``: return ``img * std + mean`` per channel.

    Works for a single image ``(C, H, W)`` and for a batch ``(N, C, H, W)``: the
    per-channel statistics are broadcast over the last three dimensions. The input
    tensor is left untouched. (The previous version rescaled it in place and, for a
    batched input, iterated over the batch dimension instead of the channels.)

    Args:
        img: normalised image tensor.
        mean: per-channel means that were passed to ``Normalize``.
        std: per-channel standard deviations that were passed to ``Normalize``.

    Returns:
        A new tensor with the same shape, dtype and device as ``img``.
    """
    # Tuples instead of lists as defaults: immutable, so they cannot be shared/mutated.
    mean = torch.as_tensor(mean, dtype=img.dtype, device=img.device).view(-1, 1, 1)
    std = torch.as_tensor(std, dtype=img.dtype, device=img.device).view(-1, 1, 1)
    return img * std + mean

In [None]:
class ImageDataset(Dataset):
    """Index-aligned pairs of (Monet painting, photo) as normalised tensors.

    Both directories are scanned recursively for ``.jpg`` files in *sorted* order.
    ``os.walk`` order depends on the filesystem, which made the pairing and the
    first batch differ between machines. Only as many photos as there are Monet
    images are loaded, because ``__len__`` never exposes more than that.

    Args:
        monet_path: directory with Monet ``.jpg`` images.
        photo_path: directory with photo ``.jpg`` images.
        size: value passed to ``transforms.Resize``.
    """

    def __init__(self, monet_path, photo_path, size=256):
        self.size = size
        self.monet_path = monet_path
        self.photo_path = photo_path

        # ToTensor gives values in [0, 1]; Normalize(0.5, 0.5) maps them to [-1, 1].
        self.transforms = transforms.Compose([
            transforms.Resize(self.size),
            transforms.ToTensor(),
            transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
        ])

        self.monet = self.load(monet_path)
        # `count` used to be ignored: every photo was decoded and kept in memory even
        # though only the first len(self.monet) of them are ever indexed.
        self.photo = self.load(photo_path, len(self.monet))

    def load(self, path, count=-1):
        """Load up to ``count`` ``.jpg`` images found under ``path`` (all if ``count < 0``).

        Unreadable files are skipped with a message instead of being silently ignored.
        Images are converted to RGB so the 3-channel Normalize always applies.
        """
        filenames = sorted(
            os.path.join(root, file)
            for root, _dirs, files in os.walk(path)
            for file in files
            if file.lower().endswith('.jpg')
        )

        images = []
        for filename in tqdm(filenames):
            if 0 <= count <= len(images):
                break
            try:
                with Image.open(filename) as image:
                    # convert() forces the pixel data to be read before the file is closed.
                    images.append(image.convert('RGB'))
            except (OSError, ValueError) as err:
                print(f"Skipping unreadable image {filename}: {err}")
        return images

    def __len__(self):
        return min(len(self.monet), len(self.photo))

    def get_photo(self):
        """Return the loaded photo images (at most as many as there are Monet images)."""
        return self.photo

    def __getitem__(self, idx):
        """Return the ``(monet, photo)`` tensor pair at ``idx``; negative indices count from the end.

        Raises:
            IndexError: if ``idx`` is out of range. This also lets plain ``for`` iteration
                over the dataset stop cleanly, which an ``assert`` did not.
        """
        n = len(self)
        if not -n <= idx < n:
            raise IndexError("Index exceeds dataset size!")
        # Normalise negative indices so both lists are indexed at the same position.
        idx %= n
        monet_img, photo_img = self.monet[idx], self.photo[idx]
        return self.transforms(monet_img), self.transforms(photo_img)

In [None]:
# Target size handed to ImageDataset's transforms.Resize.
RESCALE_SIZE = 256

dataset = ImageDataset(monet_path='monet_jpg', photo_path='photo_jpg', size=RESCALE_SIZE)

In [None]:
len(dataset)  # number of (monet, photo) pairs: the smaller of the two image counts

In [None]:
# One (monet, photo) pair per batch, in a fixed order (no shuffling): the training loop
# previews the loader's first batch, so a stable order keeps that preview comparable.
# Pinned memory only speeds up host->CUDA copies, so enable it only when CUDA exists
# (recent PyTorch versions warn when pin_memory=True is used without an accelerator).
loader = DataLoader(dataset, batch_size=1, pin_memory=torch.cuda.is_available())

In [None]:
len(loader)  # batches per epoch; equals len(dataset) because batch_size=1

In [None]:
# Preview the first (Monet, photo) pair the loader yields, side by side.
fig = plt.figure(figsize=(12, 12))

monet_img, photo_img = next(iter(loader))

for position, (caption, batch) in enumerate([('Monet', monet_img), ('Photo', photo_img)], start=1):
    fig.add_subplot(1, 2, position)
    plt.title(caption)
    # batch[0] drops the batch dimension; permute turns CHW into HWC for imshow.
    plt.imshow(unnorm(batch[0]).permute(1, 2, 0))
    plt.axis("off")
plt.show()

In [None]:
def ConvLayer(channels_in, channels_out, kernel_size=3, stride=2, pad=1, use_leaky=False):
    """Build a Conv2d -> InstanceNorm2d -> activation stack.

    Args:
        channels_in: input channel count.
        channels_out: output channel count.
        kernel_size, stride, pad: forwarded to ``nn.Conv2d``.
        use_leaky: use ``nn.LeakyReLU()`` (default slope) instead of ``nn.ReLU()``.

    Returns:
        ``nn.Sequential`` with the three layers at indices 0, 1, 2.
    """
    activation = nn.LeakyReLU() if use_leaky else nn.ReLU()
    layers = [
        nn.Conv2d(channels_in, channels_out, kernel_size, stride, pad),
        nn.InstanceNorm2d(channels_out),
        activation,
    ]
    return nn.Sequential(*layers)

class ResidualBlock(nn.Module):
    """Residual block computing ``ReLU(InstanceNorm(x + F(x)))``.

    ``F`` is conv3x3 -> InstanceNorm -> ReLU -> conv3x3 -> InstanceNorm. The channel
    count and spatial size are preserved (stride 1, padding 1).

    Args:
        channels_in: number of input (and output) channels.
    """

    def __init__(self, channels_in):
        super().__init__()
        self.channels_in = channels_in
        self.block = nn.Sequential(
            ConvLayer(channels_in, channels_in, kernel_size=3, stride=1, pad=1, use_leaky=False),
            nn.Conv2d(channels_in, channels_in, kernel_size=3, stride=1, padding=1),
            nn.InstanceNorm2d(channels_in),
        )
        # Post-addition normalisation/activation. These used to be constructed anew inside
        # forward() on every call. With the defaults (affine=False, track_running_stats=False)
        # they own no parameters or buffers, so building them once gives identical outputs
        # and leaves the state_dict (and therefore saved checkpoints) unchanged.
        self.out_norm = nn.InstanceNorm2d(channels_in)
        self.out_act = nn.ReLU()

    def forward(self, x):
        return self.out_act(self.out_norm(x + self.block(x)))

def Upsample(channels_in, channels_out, kernel_size=3, stride=2, pad=1, out_pad=1):
    """Build a ConvTranspose2d -> InstanceNorm2d -> ReLU upsampling stack.

    With the defaults, spatial size doubles:
    ``(in - 1) * 2 + 3 - 2 * 1 + 1 = 2 * in``.

    Returns:
        ``nn.Sequential`` with the three layers at indices 0, 1, 2.
    """
    transposed = nn.ConvTranspose2d(
        channels_in, channels_out, kernel_size, stride,
        padding=pad, output_padding=out_pad,
    )
    return nn.Sequential(transposed, nn.InstanceNorm2d(channels_out), nn.ReLU())

Используем такой генератор:

<img src="https://i0.wp.com/nttuan8.com/wp-content/uploads/2020/05/4-2.png?fit=1024%2C641&ssl=1">


In [None]:
class Generator(nn.Module):
    """Encoder / residual transformer / decoder image-to-image generator.

    For a 3x256x256 input the output is 3x256x256, squashed to (-1, 1) by Tanh. This
    matches the [-1, 1] range produced by the datasets' Normalize(0.5, 0.5).
    Attribute names (encoder/transformer/decoder) and layer order determine the
    state_dict keys used by saved checkpoints.
    """

    def __init__(self):
        super().__init__()

        # Conv output size           = 1 + (in - kernel + 2 * pad) / stride
        # ConvTranspose output size  = (in - 1) * stride + kernel - 2 * pad + output_pad

        # 3x256x256 -> 64x256x256 -> 128x128x128 -> 256x64x64
        self.encoder = nn.Sequential(
            ConvLayer(channels_in=3, channels_out=64, kernel_size=7, stride=1, pad=3, use_leaky=False),
            ConvLayer(channels_in=64, channels_out=128),
            ConvLayer(channels_in=128, channels_out=256),
        )

        # Six residual blocks, all at 256x64x64.
        self.transformer = nn.Sequential(*(ResidualBlock(256) for _ in range(6)))

        # 256x64x64 -> 128x128x128 -> 64x256x256 -> 3x256x256
        self.decoder = nn.Sequential(
            Upsample(256, 128),
            Upsample(128, 64),
            nn.Conv2d(64, 3, kernel_size=7, stride=1, padding=3),
            nn.Tanh(),
        )

    def forward(self, x):
        features = self.encoder(x)
        features = self.transformer(features)
        return self.decoder(features)

И такой дискриминатор:

<img src="https://images.viblo.asia/retina/1696ebe2-b162-41a8-8f0b-92fc8bc88fdf.png">


In [None]:
class Discriminator(nn.Module):
    """Convolutional discriminator producing a map of per-patch scores in (0, 1).

    For a 3x256x256 input the output is 1x16x16. Layer order (indices 0..6 of
    ``self.layers``) determines the state_dict keys used by saved checkpoints.
    NOTE(review): the CycleGAN trainer scores this output with an MSE loss; the final
    Sigmoid is kept so existing weights behave exactly as trained.
    """

    def __init__(self):
        super().__init__()

        # Conv output size = 1 + (in - kernel + 2 * pad) / stride
        # 3 -> 64 -> 128 -> 256 -> 512 channels, halving the spatial size each time (256 -> 16).
        widths = [3, 64, 128, 256, 512]
        stages = [
            ConvLayer(channels_in=c_in, channels_out=c_out, kernel_size=4, stride=2, pad=1, use_leaky=True)
            for c_in, c_out in zip(widths[:-1], widths[1:])
        ]
        # Asymmetric zero padding (left/top by 1) gives 512x17x17, so the k=4, s=1, p=1
        # convolution lands back on 1x16x16.
        stages += [
            nn.ZeroPad2d((1, 0, 1, 0)),
            nn.Conv2d(512, 1, kernel_size=4, stride=1, padding=1),
            nn.Sigmoid(),
        ]
        self.layers = nn.Sequential(*stages)

    def forward(self, x):
        return self.layers(x)

In [None]:
def weights_init(m):
    """Initialise a single module's parameters in place; meant for ``model.apply(weights_init)``.

    * Modules whose class name contains ``'Conv'`` (Conv2d, ConvTranspose2d, ...):
      weight ~ N(0, 0.02).
    * Modules whose class name contains ``'Instance'`` (InstanceNorm*):
      weight ~ N(1, 0.02), bias = 0.

    Modules without learnable affine tensors are left untouched instead of raising
    ``AttributeError``. For example, ``InstanceNorm2d`` with the default
    ``affine=False`` has ``weight``/``bias`` set to None.
    """
    classname = m.__class__.__name__
    weight = getattr(m, 'weight', None)
    bias = getattr(m, 'bias', None)
    if 'Conv' in classname:
        if weight is not None:
            nn.init.normal_(weight.data, 0.0, 0.02)
    elif 'Instance' in classname:
        if weight is not None:
            nn.init.normal_(weight.data, 1.0, 0.02)
        if bias is not None:
            nn.init.constant_(bias.data, 0)

In [None]:
class CycleGAN:
    """CycleGAN trainer for unpaired Monet <-> photo translation.

    Holds two generators (``gen_m2p``: Monet -> photo, ``gen_p2m``: photo -> Monet) and
    two discriminators (``disc_m`` judges Monet images, ``disc_p`` judges photos). It also
    holds one Adam optimizer per model pair, linear-decay LR schedulers and per-epoch
    average loss histories.

    Args:
        epochs: total number of training epochs.
        device: torch device that models and batches are moved to.
        lr: initial learning rate of both optimizers.
        lmbda: weight of the cycle-consistency loss.
        id_coef: identity-loss weight relative to ``lmbda``.
        decay_epoch: epoch after which the LR decays linearly towards 0 at ``epochs``;
            ``0`` (default) means ``epochs // 2``.
    """

    def __init__(self, epochs, device, lr=2e-4, lmbda=10, id_coef=0.5, decay_epoch=0):
        self.epochs = epochs
        self.start_epoch = 0
        self.decay_epoch = decay_epoch if decay_epoch > 0 else self.epochs // 2
        self.lmbda = lmbda
        self.id_coef = id_coef
        self.device = device

        # Models
        self.gen_m2p = Generator()
        self.gen_p2m = Generator()
        self.disc_m = Discriminator()
        self.disc_p = Discriminator()

        self.init_models()

        # Losses: least-squares (MSE) adversarial terms, L1 for cycle and identity terms.
        self.gen_criterion = nn.MSELoss()
        self.disc_criterion = nn.MSELoss()
        self.cycle_criterion = nn.L1Loss()
        self.identity_criterion = nn.L1Loss()

        # Optimizers: one for both generators, one for both discriminators.
        self.generator_optimizer = torch.optim.Adam(
            itertools.chain(self.gen_m2p.parameters(), self.gen_p2m.parameters()),
            lr=lr,
            betas=(0.5, 0.999),
        )
        self.discriminator_optimizer = torch.optim.Adam(
            itertools.chain(self.disc_m.parameters(), self.disc_p.parameters()),
            lr=lr,
            betas=(0.5, 0.999),
        )

        # LR multiplier: 1 up to self.decay_epoch, then linear decay to 0 at self.epochs.
        # The previous lambdas used the raw `decay_epoch` argument (0 by default), so the
        # LR decayed from epoch 0 and self.decay_epoch was never used. A plain function
        # (not a bound method) keeps LambdaLR.state_dict() picklable.
        def lr_factor(epoch):
            decay_span = max(1, self.epochs - self.decay_epoch)
            return max(0.0, 1 - max(0, epoch - self.decay_epoch) / decay_span)

        self.gen_lr_sched = torch.optim.lr_scheduler.LambdaLR(self.generator_optimizer, lr_lambda=lr_factor)
        self.disc_lr_sched = torch.optim.lr_scheduler.LambdaLR(self.discriminator_optimizer, lr_lambda=lr_factor)

        # Per-epoch average losses
        self.gen_history = []
        self.disc_history = []

    def init_models(self):
        """Initialise the weights of all four networks and move them to ``self.device``."""
        for name in ('gen_m2p', 'gen_p2m', 'disc_m', 'disc_p'):
            model = getattr(self, name)
            # weights_init inspects one module's class name, so it must be applied
            # recursively. Calling it on the top-level model only saw 'Generator' /
            # 'Discriminator' and initialised nothing.
            model.apply(self._init_module)
            setattr(self, name, model.to(self.device))

    @staticmethod
    def _init_module(module):
        """Apply weights_init only to modules that own a weight tensor (skips affine-free norms)."""
        if getattr(module, 'weight', None) is not None:
            weights_init(module)

    def update_req_grad(self, models, requires_grad=True):
        """Set ``requires_grad`` on every parameter of every model in ``models``."""
        for model in models:
            for param in model.parameters():
                param.requires_grad = requires_grad

    def upload_model(self, PATH):
        """Restore models, optimizers, LR schedulers and history from the checkpoint at ``PATH``.

        Returns:
            True if the checkpoint file existed and was loaded, otherwise False.
        """
        if not os.path.exists(PATH):
            return False

        # map_location allows resuming a GPU-saved checkpoint on a CPU runtime and vice versa.
        checkpoint = torch.load(PATH, map_location=self.device)
        if not checkpoint:
            return False

        self.start_epoch = checkpoint['last_epoch'] + 1
        self.gen_m2p.load_state_dict(checkpoint['gen_m2p'])
        self.gen_p2m.load_state_dict(checkpoint['gen_p2m'])
        self.disc_m.load_state_dict(checkpoint['disc_m'])
        self.disc_p.load_state_dict(checkpoint['disc_p'])
        self.generator_optimizer.load_state_dict(checkpoint['generator_optimizer'])
        self.discriminator_optimizer.load_state_dict(checkpoint['discriminator_optimizer'])
        self.gen_history = checkpoint['gen_history']
        self.disc_history = checkpoint['disc_history']

        if 'gen_lr_sched' in checkpoint and 'disc_lr_sched' in checkpoint:
            self.gen_lr_sched.load_state_dict(checkpoint['gen_lr_sched'])
            self.disc_lr_sched.load_state_dict(checkpoint['disc_lr_sched'])
        else:
            # Older checkpoints did not store scheduler state. Set the schedulers (and the
            # optimizers' LR) to start_epoch explicitly, which is idempotent across reloads.
            import warnings
            with warnings.catch_warnings():
                # step() before any optimizer.step() triggers a harmless ordering warning.
                warnings.simplefilter('ignore')
                for sched in (self.gen_lr_sched, self.disc_lr_sched):
                    sched.last_epoch = self.start_epoch - 1
                    sched.step()
        return True

    def train(self, pm_loader, PATH, fname='temp_checkpoint.pt'):
        """Train (or resume) for epochs ``start_epoch .. epochs - 1``.

        The checkpoint path is ``PATH + fname`` (plain concatenation, so ``PATH`` should
        end with a path separator). It is loaded first if present and rewritten after
        every epoch. Every second epoch a preview of the loader's first batch is plotted.

        Args:
            pm_loader: yields ``(monet_batch, photo_batch)`` tensor pairs.
            PATH: checkpoint directory prefix.
            fname: checkpoint file name.
        """
        checkpoint_path = PATH + fname

        if self.upload_model(checkpoint_path):
            for e, (g, h) in enumerate(zip(self.gen_history, self.disc_history)):
                print(f"Epoch: {e} | Generator Loss: {g} | Discriminator Loss: {h}")
            print(f"\nContinue training from {self.start_epoch} epoch")

        # Guard against ZeroDivisionError on an empty loader.
        n_batches = max(1, len(pm_loader))

        for epoch in range(self.start_epoch, self.epochs):
            sum_gen_loss, sum_disc_loss = 0.0, 0.0
            data = tqdm(pm_loader, leave=False, total=len(pm_loader))
            for real_monet, real_photo in data:
                real_monet = real_monet.to(self.device)
                real_photo = real_photo.to(self.device)

                gen_loss, fake_monet, fake_photo = self._generator_step(real_monet, real_photo)
                disc_loss = self._discriminator_step(real_monet, real_photo, fake_monet, fake_photo)

                sum_gen_loss += gen_loss
                sum_disc_loss += disc_loss
                data.set_postfix(gen_loss=gen_loss, disc_loss=disc_loss)

            avg_gen_loss = sum_gen_loss / n_batches
            avg_disc_loss = sum_disc_loss / n_batches
            self.gen_history.append(avg_gen_loss)
            self.disc_history.append(avg_disc_loss)

            # Step the schedulers before saving, so the checkpoint holds the LR and
            # scheduler state for the next epoch and resuming continues the schedule.
            self.gen_lr_sched.step()
            self.disc_lr_sched.step()
            self._save_checkpoint(epoch, checkpoint_path)

            print(f"Epoch: {epoch} | Generator Loss: {avg_gen_loss} | Discriminator Loss: {avg_disc_loss}")

            if epoch % 2 == 0 and len(pm_loader) > 0:
                self._show_samples(pm_loader)

    def _generator_step(self, real_monet, real_photo):
        """Run one update of both generators on a batch.

        Returns:
            ``(total generator loss as float, fake_monet, fake_photo)``. The fakes are
            detached so the discriminator update cannot backpropagate into the generators.
        """
        # Freeze the discriminators so the generator loss does not accumulate grads in them.
        self.update_req_grad([self.disc_m, self.disc_p], False)
        self.generator_optimizer.zero_grad()

        # monet -> photo -> monet
        fake_photo = self.gen_m2p(real_monet)
        cycled_monet = self.gen_p2m(fake_photo)

        # photo -> monet -> photo
        fake_monet = self.gen_p2m(real_photo)
        cycled_photo = self.gen_m2p(fake_monet)

        # Identity mappings: monet -> monet, photo -> photo
        id_monet = self.gen_p2m(real_monet)
        id_photo = self.gen_m2p(real_photo)

        # The generators want both discriminators to answer "real" (all ones).
        disc_monet = self.disc_m(fake_monet)
        disc_photo = self.disc_p(fake_photo)
        real = torch.ones_like(disc_monet)

        id_loss_monet = self.identity_criterion(id_monet, real_monet) * self.lmbda * self.id_coef
        id_loss_photo = self.identity_criterion(id_photo, real_photo) * self.lmbda * self.id_coef

        cycle_loss_monet = self.cycle_criterion(cycled_monet, real_monet) * self.lmbda
        cycle_loss_photo = self.cycle_criterion(cycled_photo, real_photo) * self.lmbda

        gen_loss_monet = self.gen_criterion(disc_monet, real)
        gen_loss_photo = self.gen_criterion(disc_photo, real)

        total_gen_loss = (cycle_loss_monet + cycle_loss_photo + gen_loss_monet + gen_loss_photo
                          + id_loss_monet + id_loss_photo)
        total_gen_loss.backward()
        self.generator_optimizer.step()

        return total_gen_loss.item(), fake_monet.detach(), fake_photo.detach()

    def _discriminator_step(self, real_monet, real_photo, fake_monet, fake_photo):
        """Run one update of both discriminators; returns the total discriminator loss as a float.

        ``fake_monet``/``fake_photo`` must already be detached from the generator graph.
        Detaching replaces the old numpy round trip that was used to avoid
        "backward through the graph a second time".
        """
        self.update_req_grad([self.disc_m, self.disc_p], True)
        self.discriminator_optimizer.zero_grad()

        disc_real_monet = self.disc_m(real_monet)
        disc_fake_monet = self.disc_m(fake_monet)
        disc_real_photo = self.disc_p(real_photo)
        disc_fake_photo = self.disc_p(fake_photo)

        real = torch.ones_like(disc_real_monet)
        fake = torch.zeros_like(disc_fake_monet)

        monet_disc_loss = 0.5 * (self.disc_criterion(disc_real_monet, real)
                                 + self.disc_criterion(disc_fake_monet, fake))
        photo_disc_loss = 0.5 * (self.disc_criterion(disc_real_photo, real)
                                 + self.disc_criterion(disc_fake_photo, fake))
        total_disc_loss = monet_disc_loss + photo_disc_loss

        # The two discriminators have disjoint graphs, so a single backward on the sum
        # yields the same gradients as two separate backward calls, without retain_graph.
        total_disc_loss.backward()
        self.discriminator_optimizer.step()

        return total_disc_loss.item()

    def _save_checkpoint(self, epoch, path):
        """Save models, optimizers, LR schedulers and loss history after ``epoch`` to ``path``."""
        torch.save({
            'last_epoch': epoch,
            'gen_m2p': self.gen_m2p.state_dict(),
            'gen_p2m': self.gen_p2m.state_dict(),
            'disc_m': self.disc_m.state_dict(),
            'disc_p': self.disc_p.state_dict(),
            'generator_optimizer': self.generator_optimizer.state_dict(),
            'discriminator_optimizer': self.discriminator_optimizer.state_dict(),
            'gen_lr_sched': self.gen_lr_sched.state_dict(),
            'disc_lr_sched': self.disc_lr_sched.state_dict(),
            'gen_history': self.gen_history,
            'disc_history': self.disc_history,
        }, path)

    def _show_samples(self, loader):
        """Plot the first batch of ``loader`` next to both generators' translations of it."""
        test_monet, test_photo = next(iter(loader))

        # Inference only: no autograd graph needed.
        with torch.no_grad():
            pred_monet = self.gen_p2m(test_photo.to(self.device)).cpu()
            pred_photo = self.gen_m2p(test_monet.to(self.device)).cpu()
        pred_monet, pred_photo = unnorm(pred_monet[0]), unnorm(pred_photo[0])
        test_monet, test_photo = unnorm(test_monet[0]), unnorm(test_photo[0])

        fig, ax = plt.subplots(2, 2, figsize=(10, 10))
        panels = (
            (ax[0, 0], test_monet, "Input monet"),
            (ax[0, 1], pred_photo, "Pred photo"),
            (ax[1, 0], test_photo, "Input photo"),
            (ax[1, 1], pred_monet, "Pred monet"),
        )
        for axis, image, title in panels:
            axis.imshow(image.permute(1, 2, 0))
            axis.set_title(title)
            axis.axis("off")

        plt.show()

In [None]:
# 31 epochs; lr, lmbda, id_coef and decay_epoch keep their defaults.
gan = CycleGAN(epochs=31, device=device)

In [None]:
# Alternate location used on the other Google account:
# PATH = "/content/gdrive/MyDrive/Colab Notebooks/Tinkoff DL/Занятие 6/"
# Checkpoint directory; keep the trailing "/" because train() builds PATH + fname.
PATH = "/content/gdrive/MyDrive/Colab Notebooks/"

In [None]:
# Training was done under a different Google account, where the DataLoader yielded the
# images in a different order. So the first images shown during training (used to monitor
# the generators' and discriminators' progress) differ from the first images the loader yields now.
gan.train(pm_loader=loader, PATH=PATH)

In [None]:
# Load the model weights: train() restores the saved checkpoint first; if the checkpoint
# already covers all epochs, no further training happens.
gan.train(pm_loader=loader, PATH=PATH)

In [None]:
# Per-epoch average losses recorded by CycleGAN.train.
plt.xlabel("Epochs")
plt.ylabel("Losses")
plt.plot(gan.gen_history, 'r', label='Generator Loss')
# Legend label typo fixed ("Descriminator" -> "Discriminator").
plt.plot(gan.disc_history, 'b', label='Discriminator Loss')
plt.legend()
plt.show()

In [None]:
def print_ex(test_monet, test_photo):
    """Plot a Monet image and a photo next to the global ``gan``'s translations of them.

    Args:
        test_monet: normalised Monet tensor (presumably (3, H, W) as returned by
            ImageDataset; TODO confirm for other callers).
        test_photo: normalised photo tensor of the same form.

    Note: ``unnorm`` is applied to the inputs; whether they are modified depends on
    ``unnorm``'s implementation.
    """
    with torch.no_grad():
        photo_on_device = test_photo.to(device)
        monet_on_device = test_monet.to(device)
        pred_monet = gan.gen_p2m(photo_on_device).cpu().detach().to(torch.float32)
        pred_photo = gan.gen_m2p(monet_on_device).cpu().detach().to(torch.float32)

    # (image, title) in row-major order of the 2x2 grid.
    panels = [
        (unnorm(test_monet), "Input monet"),
        (unnorm(pred_photo), "Pred photo"),
        (unnorm(test_photo), "Input photo"),
        (unnorm(pred_monet), "Pred monet"),
    ]

    fig, axes = plt.subplots(2, 2, figsize=(8, 8))
    for axis, (image, title) in zip(axes.flat, panels):
        axis.imshow(image.permute(1, 2, 0))
        axis.set_title(title)
        axis.axis("off")

    plt.show()

def print_exs(idxs, source):
    """Call ``print_ex`` for each index in ``idxs``, separated by a line of underscores.

    Args:
        idxs: iterable of indices into ``source``.
        source: indexable whose items are ``(monet, photo)`` pairs (e.g. ImageDataset).

    Raises:
        AssertionError: if an item is not a pair.
    """
    for i in idxs:
        # Fetch once: indexing an ImageDataset re-runs its transforms, so the old code
        # (index for the check, index again for the call) did the work twice.
        sample = source[i]
        assert len(sample) == 2, "Not two photos passed"
        print_ex(*sample)
        print('_' * 100)

In [None]:
# Five sample pairs: indices 8, 28, 48, 68, 88.
print_exs(idxs=range(8, 108, 20), source=dataset)

Мне кажется, что получилось хорошо!

In [None]:
class PhotoDataset(Dataset):
    """Photos only, as normalised tensors, for generating Monet-style outputs.

    The directory is scanned recursively for ``.jpg`` files in sorted order, so the
    item order (and therefore output numbering) is reproducible across machines.

    Args:
        photo_path: directory with photo ``.jpg`` images.
        size: value passed to ``transforms.Resize``.
    """

    def __init__(self, photo_path, size=256):
        self.size = size
        self.photo_path = photo_path

        # ToTensor gives values in [0, 1]; Normalize(0.5, 0.5) maps them to [-1, 1].
        self.transforms = transforms.Compose([
            transforms.Resize(self.size),
            transforms.ToTensor(),
            transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
        ])

        self.photo = self.load(photo_path)

    def load(self, path, count=-1):
        """Load up to ``count`` ``.jpg`` images found under ``path`` (all if ``count < 0``).

        Unreadable files are skipped with a message instead of being silently ignored.
        Images are converted to RGB so the 3-channel Normalize always applies.
        """
        filenames = sorted(
            os.path.join(root, file)
            for root, _dirs, files in os.walk(path)
            for file in files
            if file.lower().endswith('.jpg')
        )

        images = []
        for filename in tqdm(filenames):
            if 0 <= count <= len(images):
                break
            try:
                with Image.open(filename) as image:
                    # convert() forces the pixel data to be read before the file is closed.
                    images.append(image.convert('RGB'))
            except (OSError, ValueError) as err:
                print(f"Skipping unreadable image {filename}: {err}")
        return images

    def __len__(self):
        return len(self.photo)

    def __getitem__(self, idx):
        return self.transforms(self.photo[idx])

In [None]:
# Photo-only dataset for generating the output images. This line used to be commented out,
# leaving `photo_dataset` undefined (NameError) in a fresh session.
photo_dataset = PhotoDataset('photo_jpg', RESCALE_SIZE)
# Pinned memory only speeds up host->CUDA copies, so enable it only when CUDA exists.
photo_loader = DataLoader(photo_dataset, batch_size=1, pin_memory=torch.cuda.is_available())

In [None]:
len(photo_loader)  # number of photos to translate (batch_size=1)

In [None]:
# !mkdir /content/gdrive/MyDrive/'Colab Notebooks'/'Tinkoff DL'/'Занятие 6'/images
!mkdir /content/gdrive/MyDrive/'Colab Notebooks'/images

In [None]:
# Translate every photo to Monet style and save it as <n>.jpg (1-based, in loader order).
# OUTPUT_DIR = "/content/gdrive/MyDrive/Colab Notebooks/Tinkoff DL/Занятие 6/images/"
OUTPUT_DIR = "/content/gdrive/MyDrive/Colab Notebooks/images/"
# Create the folder from Python as well, so this cell does not depend on the `!mkdir` cell.
os.makedirs(OUTPUT_DIR, exist_ok=True)

to_pil = transforms.ToPILImage()
saved = 0
data = tqdm(photo_loader, leave=False, total=len(photo_loader))
for photo in data:
    with torch.no_grad():
        pred_monet = gan.gen_p2m(photo.to(device)).cpu()
    # Iterate over the batch so every image is saved, not only the first one, if
    # batch_size is ever raised. Each image is un-normalised as a single (C, H, W) tensor.
    for image in pred_monet:
        saved += 1
        img = to_pil(unnorm(image)).convert("RGB")
        img.save(os.path.join(OUTPUT_DIR, f"{saved}.jpg"))

In [None]:
!ls /content/gdrive/MyDrive/'Colab Notebooks'/images/

In [None]:
# ! ls ../input/hw6-results/images

In [None]:
# import shutil

# shutil.make_archive("./images", 'zip', "../input/hw6-results/images")