In [None]:
У меня закончилось время на GPU, поэтому я не успела выполнить ноутбук до конца.

In [1]:
import os

import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms
from PIL import Image
import matplotlib.pyplot as plt
import albumentations as A
import numpy as np

# Data loading

In [2]:
class MonetPhotoDataset(Dataset):
    """Unpaired dataset that yields one Monet image and one photo per index.

    The two folders may hold a different number of files, so the dataset is
    as long as the larger folder and the smaller one is cycled through with a
    modulo index.

    Args:
        root_monet: Directory containing the Monet images.
        root_photo: Directory containing the photo images.
        transform: Optional callable invoked as
            ``transform(image=photo, image2=monet)`` that returns a mapping
            with the keys ``'image'`` and ``'image2'``. When ``None`` the raw
            RGB arrays are returned untouched.
    """

    def __init__(self, root_monet, root_photo, transform=None):
        self.transform = transform
        self.root_monet = root_monet
        self.root_photo = root_photo

        # os.listdir() returns entries in arbitrary order; sorting keeps the
        # index -> file mapping identical from run to run.
        self.monet_images = sorted(os.listdir(root_monet))
        self.photo_images = sorted(os.listdir(root_photo))

        self.monet_len = len(self.monet_images)
        self.photo_len = len(self.photo_images)

    def __len__(self):
        """Return the size of the larger of the two folders."""
        return max(self.monet_len, self.photo_len)

    @staticmethod
    def _load_rgb(path):
        """Read the image at ``path`` as an RGB numpy array and close the file."""
        with Image.open(path) as img:
            return np.array(img.convert('RGB'))

    def __getitem__(self, idx):
        """Return ``(monet_img, photo_img)`` for ``idx``.

        The shorter folder wraps around, so every index below ``len(self)``
        is valid for both domains.
        """
        monet_name = self.monet_images[idx % self.monet_len]
        photo_name = self.photo_images[idx % self.photo_len]

        monet_img = self._load_rgb(os.path.join(self.root_monet, monet_name))
        photo_img = self._load_rgb(os.path.join(self.root_photo, photo_name))

        # transform is optional (its default is None): skip it instead of
        # failing with "'NoneType' object is not callable".
        if self.transform is None:
            return monet_img, photo_img

        augmented = self.transform(image=photo_img, image2=monet_img)
        return augmented['image2'], augmented['image']


In [3]:
# Preprocessing shared by both domains: resize to 256x256, random horizontal
# flip, then scale uint8 pixels to [-1, 1] via (x / 255 - 0.5) / 0.5.
# 'image2' is registered as an extra image target so the Monet picture goes
# through the same pipeline as the photo passed as 'image'.
transform = A.Compose(
    transforms=[
        A.Resize(height=256, width=256),
        A.HorizontalFlip(p=0.5),
        A.Normalize(
            mean=[0.5, 0.5, 0.5],
            std=[0.5, 0.5, 0.5],
            max_pixel_value=255,
        ),
    ],
    additional_targets={'image2': 'image'},
)

In [4]:
# Kaggle "gan-getting-started" competition data: Monet paintings and photos.
dataset = MonetPhotoDataset(
    root_monet='/kaggle/input/gan-getting-started/monet_jpg',
    root_photo='/kaggle/input/gan-getting-started/photo_jpg',
    transform=transform,
)

In [5]:
# Sanity check: number of samples the dataset reports (size of the larger folder).
len(dataset)

In [6]:
# Train on the GPU when one is visible to PyTorch, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device('cuda')
else:
    device = torch.device('cpu')

In [7]:
# One (monet, photo) pair per step, drawn in random order every epoch.
loader = DataLoader(
    dataset=dataset,
    batch_size=1,
    shuffle=True,
)

# Useful blocks

In [8]:
class ConvBlock(nn.Module):
    """Convolution (or transposed convolution) + optional InstanceNorm + optional activation.

    Args:
        input_size: Number of input channels.
        output_size: Number of output channels.
        downsampling: ``True`` builds ``nn.Conv2d``; ``False`` builds
            ``nn.ConvTranspose2d``.
        kernel_size, stride, padding: Passed to the convolution.
        activation: ``'relu'`` or ``'lrelu'`` (LeakyReLU with slope 0.2); any
            other value leaves the block without an activation.
        batch_norm: When truthy, an ``nn.InstanceNorm2d`` follows the
            convolution (despite the name, it is not batch normalisation).
        output_padding: Only used by the transposed convolution.
        padding_mode: Only used by the regular convolution.
    """

    def __init__(self, input_size, output_size, downsampling=True, kernel_size=3, stride=2, padding=1, activation='relu', batch_norm=True, output_padding=1, padding_mode='reflect'):
        super().__init__()

        if downsampling:
            conv_layer = nn.Conv2d(
                input_size, output_size, kernel_size, stride, padding,
                padding_mode=padding_mode,
            )
        else:
            conv_layer = nn.ConvTranspose2d(
                input_size, output_size, kernel_size, stride, padding,
                output_padding,
            )

        if batch_norm:
            norm_layer = nn.InstanceNorm2d(output_size)
        else:
            norm_layer = nn.Identity()

        # The Sequential always has four slots; the slots that do not apply
        # are filled with Identity so the layout never changes.
        relu_slot = nn.Identity()
        lrelu_slot = nn.Identity()
        if activation == 'relu':
            relu_slot = nn.ReLU(inplace=True)
        if activation == 'lrelu':
            lrelu_slot = nn.LeakyReLU(0.2, inplace=True)

        self.conv = nn.Sequential(conv_layer, norm_layer, relu_slot, lrelu_slot)

    def forward(self, x):
        """Run ``x`` through convolution, normalisation and activation."""
        out = self.conv(x)
        return out

In [9]:
class ResBlock(nn.Module):
    """Residual block: two stride-1 ConvBlocks with a skip connection and a final ReLU.

    NOTE(review): the first ConvBlock is built with its defaults and therefore
    already ends in InstanceNorm + ReLU; the explicit InstanceNorm2d/ReLU pair
    after it repeats that step. Kept as is to preserve behaviour - confirm
    whether the repetition is intended.

    Args:
        input_size: Number of channels; the block keeps it unchanged.
    """

    def __init__(self, input_size):
        super().__init__()
        first_conv = ConvBlock(input_size, input_size, stride=1)
        second_conv = ConvBlock(input_size, input_size, stride=1, activation='no activation')
        self.conv = nn.Sequential(
            first_conv,
            nn.InstanceNorm2d(input_size),
            nn.ReLU(),
            second_conv,
        )

    def forward(self, x):
        """Return ``relu(x + conv(x))``."""
        residual = self.conv(x)
        return torch.relu(x + residual)

# Generator

In [10]:
class Generator(nn.Module):
    """Encoder / residual / decoder image-to-image generator.

    Layout: a 7x7 stem, two stride-2 downsampling blocks, ``res_num``
    residual blocks, two transposed-convolution upsampling blocks and a 7x7
    output convolution followed by ``tanh``.

    Args:
        input_dim: Number of channels of the input and of the output image.
        res_num: Number of residual blocks in the bottleneck.
        num_filter: Channel width of the stem; deeper stages use 2x and 4x.
    """

    def __init__(self, input_dim=3, res_num=5, num_filter=64):
        super().__init__()
        layers = [
            ConvBlock(input_dim, num_filter, kernel_size=7, stride=1, padding=3, batch_norm=False),
            ConvBlock(num_filter, num_filter * 2),
            ConvBlock(2 * num_filter, 4 * num_filter),
        ]
        layers.extend(ResBlock(4 * num_filter) for _ in range(res_num))
        layers.extend([
            ConvBlock(4 * num_filter, 2 * num_filter, downsampling=False),
            ConvBlock(2 * num_filter, num_filter, downsampling=False),
            # No normalisation on the output layer: InstanceNorm right before
            # tanh would force every output channel to zero mean, so the
            # generator could not produce an overall bright or dark image.
            # The Discriminator's last layer disables it the same way.
            ConvBlock(
                num_filter, input_dim, kernel_size=7, stride=1, padding=3,
                activation='no activation', batch_norm=False, padding_mode='reflect',
            ),
        ])
        self.main = nn.Sequential(*layers)

    def forward(self, x):
        """Translate ``x``; ``tanh`` bounds the result to [-1, 1]."""
        return torch.tanh(self.main(x))

# Discriminator

In [11]:
class Discriminator(nn.Module):
    """Convolutional discriminator that outputs a map of real/fake scores.

    Five 4x4 ConvBlocks: three with the default stride of 2, then two with
    stride 1. The first and the last block skip normalisation, and the last
    block has no activation of its own; ``forward`` applies a sigmoid.

    Args:
        input_dim: Number of channels of the input image.
        num_filter: Channel width of the first block; later blocks use
            2x, 4x and 8x.
    """

    def __init__(self, input_dim=3, num_filter=64):
        super().__init__()
        stem = ConvBlock(input_dim, num_filter, kernel_size=4, activation='lrelu', batch_norm=False)
        body = [
            ConvBlock(num_filter, num_filter * 2, kernel_size=4, activation='lrelu'),
            ConvBlock(2 * num_filter, num_filter * 4, kernel_size=4, activation='lrelu'),
            ConvBlock(4 * num_filter, num_filter * 8, kernel_size=4, stride=1, activation='lrelu'),
        ]
        head = ConvBlock(
            8 * num_filter, 1, kernel_size=4, stride=1,
            batch_norm=False, activation='no activation', padding_mode='reflect',
        )
        self.main = nn.Sequential(stem, *body, head)

    def forward(self, x):
        """Return per-location scores squashed to (0, 1) by a sigmoid."""
        scores = self.main(x)
        return torch.sigmoid(scores)

In [12]:
# Show which device was selected for training.
device

# Training

In [15]:
# Two generators, one per translation direction.
# monet_generator: photo -> Monet, photo_generator: Monet -> photo.
monet_generator = Generator().to(device)
photo_generator = Generator().to(device)

# One discriminator per image domain.
monet_discriminator = Discriminator().to(device)
photo_discriminator = Discriminator().to(device)

# Both generators share one optimizer and one scheduler.
# StepLR multiplies the learning rate by 0.7 on every scheduler step.
gen_opt = optim.Adam(
    [*monet_generator.parameters(), *photo_generator.parameters()],
    lr=2e-4,
    betas=(0.5, 0.999),
)
schedule_g = optim.lr_scheduler.StepLR(gen_opt, step_size=1, gamma=0.7)

# Same arrangement for the two discriminators.
disc_opt = optim.Adam(
    [*monet_discriminator.parameters(), *photo_discriminator.parameters()],
    lr=2e-4,
    betas=(0.5, 0.999),
)
schedule_d = optim.lr_scheduler.StepLR(disc_opt, step_size=1, gamma=0.7)

# L1 for the cycle / identity terms, MSE for the adversarial terms.
l1 = nn.L1Loss().to(device)
mse = nn.MSELoss().to(device)

EPOCHS = 5          # passes over the dataset
CYCLE_LAMBDA = 10   # weight of the cycle-consistency term in the generator loss

In [None]:
LOG_EVERY = 500       # iterations between progress reports
PREVIEW_SAMPLES = 2   # photo -> Monet translations shown at each report


def _to_nchw(batch):
    """Turn a channels-last loader batch into a float channels-first tensor on `device`."""
    return batch.permute(0, 3, 1, 2).float().to(device)


def _preview_translations(num_samples):
    """Plot `num_samples` photos from `loader` next to their Monet-style translation."""
    with torch.no_grad():  # inference only: do not build autograd graphs
        for idx, (_, sample_photo) in enumerate(loader):
            fig, (ax_photo, ax_monet) = plt.subplots(1, 2)

            # x * 0.5 + 0.5 undoes Normalize(mean=0.5, std=0.5); clamp guards
            # against tiny excursions outside [0, 1] that imshow would clip.
            ax_photo.imshow((sample_photo[0] * 0.5 + 0.5).clamp(0, 1).numpy())
            ax_photo.set_title('photo')
            ax_photo.axis('off')

            translated = monet_generator(_to_nchw(sample_photo))[0].cpu()
            ax_monet.imshow((translated.permute(1, 2, 0) * 0.5 + 0.5).clamp(0, 1).numpy())
            ax_monet.set_title('generated Monet')
            ax_monet.axis('off')

            plt.show()
            plt.close(fig)

            if idx + 1 >= num_samples:
                break


for epoch in range(EPOCHS):
    for i, (monet, photo) in enumerate(loader):
        photo = _to_nchw(photo)
        monet = _to_nchw(monet)

        # One forward pass per generator; the results are reused by the
        # discriminator step (detached) and by the generator step.
        fake_monet = monet_generator(photo)   # photo -> Monet
        fake_photo = photo_generator(monet)   # Monet -> photo

        # ---------------- discriminators ----------------
        # Monet discriminator: real Monet -> 1, generated Monet -> 0.
        real_monet_pred = monet_discriminator(monet)
        fake_monet_pred = monet_discriminator(fake_monet.detach())
        disc_monet_loss = (
            mse(real_monet_pred, torch.ones_like(real_monet_pred))
            + mse(fake_monet_pred, torch.zeros_like(fake_monet_pred))
        )
        # Photo discriminator: real photo -> 1, generated photo -> 0.
        # Generated photos must be scored by the *photo* discriminator.
        real_photo_pred = photo_discriminator(photo)
        fake_photo_pred = photo_discriminator(fake_photo.detach())
        disc_photo_loss = (
            mse(real_photo_pred, torch.ones_like(real_photo_pred))
            + mse(fake_photo_pred, torch.zeros_like(fake_photo_pred))
        )
        D_loss = (disc_monet_loss + disc_photo_loss) / 2

        disc_opt.zero_grad()
        D_loss.backward()
        disc_opt.step()

        # ---------------- generators ----------------
        # Adversarial term: each generator tries to make its discriminator answer 1.
        fake_monet_labels = monet_discriminator(fake_monet)
        fake_photo_labels = photo_discriminator(fake_photo)
        gen_loss = (
            mse(fake_monet_labels, torch.ones_like(fake_monet_labels))
            + mse(fake_photo_labels, torch.ones_like(fake_photo_labels))
        )
        # Cycle consistency: translating there and back must recover the
        # input. No detach here, so both generators of a cycle get gradients.
        cycle_photo = photo_generator(fake_monet)   # photo -> Monet -> photo
        cycle_monet = monet_generator(fake_photo)   # Monet -> photo -> Monet
        cycle_loss = l1(cycle_monet, monet) + l1(cycle_photo, photo)
        # Identity: a generator fed an image that is already in its target
        # domain should leave it unchanged.
        identity_monet = monet_generator(monet)
        identity_photo = photo_generator(photo)
        identity_loss = l1(identity_monet, monet) + l1(identity_photo, photo)

        G_loss = gen_loss + CYCLE_LAMBDA * cycle_loss + identity_loss

        gen_opt.zero_grad()
        G_loss.backward()
        gen_opt.step()

        if i % LOG_EVERY == 0:
            # Report the total generator loss that was actually optimised.
            print(
                'epoch: %d | G_loss: %.4f | D_loss %.4f'
                % (epoch + 1, G_loss.item(), D_loss.item())
            )
            _preview_translations(PREVIEW_SAMPLES)

    # Decay both learning rates once per epoch.
    schedule_d.step()
    schedule_g.step()

In [None]:
# Final qualitative check after training.
# This used to be gated on `i % 500 == 0`, where `i` is whatever index the
# training loop happened to stop at, so the cell could silently show nothing.
# It now always reports the last losses and shows six translations.
print(
    'epoch: %d | G_loss: %.4f | D_loss %.4f'
    % (epoch + 1, gen_loss.item(), D_loss.item())
)
with torch.no_grad():  # inference only: do not build autograd graphs
    for idx, (_, sample_photo) in enumerate(loader):
        fig, (ax_photo, ax_monet) = plt.subplots(1, 2)

        # x * 0.5 + 0.5 undoes Normalize(mean=0.5, std=0.5) for display.
        ax_photo.imshow((sample_photo[0] * 0.5 + 0.5).clamp(0, 1).numpy())
        ax_photo.set_title('photo')
        ax_photo.axis('off')

        # The loader yields channels-last batches; the generator expects channels-first.
        sample_input = sample_photo.permute(0, 3, 1, 2).float().to(device)
        translated = monet_generator(sample_input)[0].cpu()
        ax_monet.imshow((translated.permute(1, 2, 0) * 0.5 + 0.5).clamp(0, 1).numpy())
        ax_monet.set_title('generated Monet')
        ax_monet.axis('off')

        plt.show()
        plt.close(fig)

        if idx == 5:
            break