In [None]:
# This Python 3 environment comes with many helpful analytics libraries installed
# It is defined by the kaggle/python Docker image: https://github.com/kaggle/docker-python
# For example, here are several helpful packages to load

import numpy as np # linear algebra
import pandas as pd # data processing, CSV file I/O (e.g. pd.read_csv)

# Input data files are available in the read-only "../input/" directory
# For example, running this (by clicking run or pressing Shift+Enter) will list all files under the input directory

import os
# Print each directory found under the Kaggle input mount.
# File names are intentionally not printed; only the directory paths are shown.
for dirname, _, filenames in os.walk('/kaggle/input'):
    print(dirname)

# You can write up to 20GB to the current directory (/kaggle/working/) that gets preserved as output when you create a version using "Save & Run All" 
# You can also write temporary files to /kaggle/temp/, but they won't be saved outside of the current session

In [None]:
import torch
import torchvision
from torch.utils.data import Dataset, DataLoader

In [None]:
from glob import glob
from dataclasses import dataclass
import PIL
from torchvision import transforms

@dataclass
class ImageDataset(Dataset):
    """Map-style dataset over every ``*.jpg`` file directly inside ``dirname``.

    Each item is the image converted to a 3-channel tensor and normalized
    with mean 0.5 / std 0.5 per channel, i.e. values in ``[-1, 1]``.

    Attributes:
        dirname: Directory that is globbed (non-recursively) for ``*.jpg``.
    """

    dirname: str

    def __post_init__(self):
        # glob() returns files in arbitrary order; sort so that an index
        # always refers to the same file across runs.
        self.filenames = sorted(glob(f'{self.dirname}/*.jpg'))
        self.transform = transforms.Compose([
            transforms.ToTensor(),
            transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
        ])

    def __len__(self):
        """Return the number of image files found."""
        return len(self.filenames)

    def __getitem__(self, idx):
        """Load, convert and normalize the image at position ``idx``."""
        filename = self.filenames[idx]
        # The context manager closes the file handle promptly (important with
        # many DataLoader workers). convert('RGB') guarantees three channels,
        # which the 3-channel Normalize above requires.
        with PIL.Image.open(filename) as img:
            return self.transform(img.convert('RGB'))
    
    

In [None]:
# Two unpaired image collections: one per domain (monet_jpg and photo_jpg).
monet_dataset = ImageDataset(dirname='/kaggle/input/gan-getting-started/monet_jpg')
photo_dataset = ImageDataset(dirname='/kaggle/input/gan-getting-started/photo_jpg')

In [None]:
# Number of images found in each domain, displayed as (monet, photo).
len(monet_dataset), len(photo_dataset)

In [None]:
import torch.nn as nn

# Channel count for generated images.
# NOTE(review): no code visible in this notebook reads this constant — confirm whether it is still needed.
OUTPUT_CHANNELS = 3

@dataclass(eq=False)
class Downsample(nn.Module):
    """Stride-2 convolution block: Conv2d -> (InstanceNorm2d) -> LeakyReLU.

    With the 4x4 kernels used in this notebook (stride 2, padding 1) the
    block halves the height and width of its input.

    Attributes:
        in_channels: Number of channels of the input tensor.
        filters: Number of channels produced by the convolution.
        size: Kernel size of the convolution.
        apply_instance_norm: Insert an ``InstanceNorm2d`` after the
            convolution when true.
    """
    in_channels: int
    filters: int
    size: int
    apply_instance_norm: bool = True

    def __post_init__(self):
        # The dataclass-generated __init__ does not call the base-class
        # constructor, so initialise nn.Module before registering submodules.
        nn.Module.__init__(self)
        layers = [
            nn.Conv2d(
                in_channels=self.in_channels,
                out_channels=self.filters,
                kernel_size=self.size,
                stride=2,
                padding=1,
                bias=False
            )
        ]
        if self.apply_instance_norm:
            # Normalisation runs on the convolution output, so it must be
            # sized to `filters` (was hard-coded to 3).
            layers.append(nn.InstanceNorm2d(
                num_features=self.filters
            ))
        layers.append(nn.LeakyReLU())
        self.network = nn.Sequential(*layers)

    def forward(self, inp):
        """Apply the block to ``inp`` and return the downsampled tensor."""
        return self.network(inp)
        
        
@dataclass(eq=False)
class Upsample(nn.Module):
    """Stride-2 transposed-convolution block: ConvTranspose2d -> (Dropout) -> ReLU.

    Attributes:
        in_channels: Number of channels of the input tensor.
        filters: Number of channels produced by the transposed convolution.
        size: Kernel size of the transposed convolution.
        apply_dropout: Insert ``Dropout(p=0.5)`` before the activation when true.
    """
    in_channels: int
    filters: int
    size: int
    apply_dropout: bool = False

    def __post_init__(self):
        # The generated dataclass __init__ skips the base constructor, so run it here.
        nn.Module.__init__(self)
        deconv = nn.ConvTranspose2d(
            self.in_channels,
            self.filters,
            self.size,
            stride=2,
            padding=1,
            bias=False,
        )
        if self.apply_dropout:
            stages = [deconv, nn.Dropout(p=0.5), nn.ReLU()]
        else:
            stages = [deconv, nn.ReLU()]
        self.network = nn.Sequential(*stages)

    def forward(self, inp):
        """Run ``inp`` through the block and return the upsampled tensor."""
        return self.network(inp)
        
    

In [None]:
@dataclass(eq=False)
class Generator(nn.Module):
    """U-Net style image-to-image generator with skip connections.

    Eight ``Downsample`` blocks encode the input, seven ``Upsample`` blocks
    decode it while concatenating the matching encoder output on the channel
    axis, and a final transposed convolution followed by ``tanh`` produces a
    3-channel image with values in ``[-1, 1]``.

    Shape comments below are ``(batch, channels, height, width)`` and assume a
    3x256x256 input.
    """

    def __init__(self):
        nn.Module.__init__(self)
        self.down_stack = nn.ModuleList([
            Downsample(3, 64, 4, apply_instance_norm=False), # (bs, 64, 128, 128)
            Downsample(64, 128, 4), # (bs, 128, 64, 64)
            Downsample(128, 256, 4), # (bs, 256, 32, 32)
            Downsample(256, 512, 4), # (bs, 512, 16, 16)
            Downsample(512, 512, 4), # (bs, 512, 8, 8)
            Downsample(512, 512, 4), # (bs, 512, 4, 4)
            Downsample(512, 512, 4), # (bs, 512, 2, 2)
            Downsample(512, 512, 4, apply_instance_norm=False), # (bs, 512, 1, 1)
        ])

        # Must be an nn.ModuleList, not a plain list: otherwise these layers
        # are not registered, so their weights are absent from parameters()
        # and state_dict(), are never optimised, and are not moved by .to().
        # Shapes are after concatenation with the skip connection.
        self.up_stack = nn.ModuleList([
            Upsample(512, 512, 4, apply_dropout=True), # (bs, 1024, 2, 2)
            Upsample(1024, 512, 4, apply_dropout=True), # (bs, 1024, 4, 4)
            Upsample(1024, 512, 4, apply_dropout=True), # (bs, 1024, 8, 8)
            Upsample(1024, 512, 4), # (bs, 1024, 16, 16)
            Upsample(1024, 256, 4), # (bs, 512, 32, 32)
            Upsample(512, 128, 4), # (bs, 256, 64, 64)
            Upsample(256, 64, 4), # (bs, 128, 128, 128)
        ])

        self.last = nn.ConvTranspose2d(
            in_channels=128,
            kernel_size=4,
            out_channels=3,
            stride=2,
            padding=1,
            bias=False
        )
        self.output_activation = nn.Tanh()

    def forward(self, x):
        """Translate a batch of images.

        Args:
            x: Image batch shaped ``(batch, 3, height, width)``.

        Returns:
            Tensor with 3 channels and values in ``[-1, 1]``.
        """
        skips = []
        for layer in self.down_stack:
            x = layer(x)
            skips.append(x)
        # The bottleneck output feeds the decoder directly, so it is not a
        # skip; the remaining encoder outputs are consumed deepest-first.
        for layer, skip in zip(self.up_stack, reversed(skips[:-1])):
            x = layer(x)
            x = torch.cat([x, skip], dim=1)
        x = self.last(x)
        return self.output_activation(x)

In [None]:
# Smoke test: run the first Monet image (with a batch axis added) through an untrained generator.
generator = Generator()

out = generator(monet_dataset[0].unsqueeze(0))

In [None]:
# Display the shape of the generator output from the smoke test.
out.shape

In [None]:
@dataclass(eq=False)
class Discriminator(nn.Module):
    """Convolutional discriminator that outputs a map of raw scores.

    Three ``Downsample`` blocks are followed by a 512-channel convolution,
    instance normalisation, LeakyReLU and a final 1-channel convolution.
    No sigmoid is applied, so the output values are logits.
    """

    def __post_init__(self):
        # The generated dataclass __init__ skips the base constructor, so run it here.
        nn.Module.__init__(self)
        self.down_stack = nn.ModuleList([
            Downsample(3, 64, 4, apply_instance_norm=False),
            Downsample(64, 128, 4),
            Downsample(128, 256, 4),
        ])

        self.conv1 = nn.Conv2d(
            in_channels=256,
            kernel_size=4,
            out_channels=512,
            stride=1,
            padding=1,
            bias=False
        )
        # Normalises the output of conv1, so it must match its 512 channels
        # (was hard-coded to 3).
        self.norm = nn.InstanceNorm2d(
            num_features=512
        )
        self.activation = nn.LeakyReLU()

        self.conv2 = nn.Conv2d(
            in_channels=512,
            kernel_size=4,
            out_channels=1,
            stride=1,
            padding=1
        )

    def forward(self, x):
        """Score a batch of images.

        Args:
            x: Image batch shaped ``(batch, 3, height, width)``.

        Returns:
            Single-channel tensor of logits, one value per spatial location.
        """
        for layer in self.down_stack:
            x = layer(x)
        x = self.conv1(x)
        x = self.norm(x)
        x = self.activation(x)
        x = self.conv2(x)
        return x
        

In [None]:
# Smoke test: score the first Monet image with an untrained discriminator and display the output shape.
disc = Discriminator()

disc(monet_dataset[0].unsqueeze(0)).shape

In [None]:
from matplotlib.pyplot import imshow

def show_image(img_tensor, denormalize=True):
    """Display a single ``(channels, height, width)`` image tensor.

    Args:
        img_tensor: Image tensor in channel-first layout.
        denormalize: When true (default), map values from ``[-1, 1]`` back to
            ``[0, 1]`` before plotting. The datasets in this notebook
            normalise with mean 0.5 / std 0.5 and the generator ends in
            ``tanh``, so both produce ``[-1, 1]``; matplotlib clips float
            images to ``[0, 1]``, which would otherwise black out every
            negative value. Pass ``False`` for tensors already in ``[0, 1]``.
    """
    # detach/cpu lets this also accept model outputs that require grad or live on a GPU.
    img = img_tensor.detach().cpu()
    if denormalize:
        img = (img * 0.5 + 0.5).clamp(0, 1)
    # matplotlib expects channel-last (height, width, channels).
    imshow(img.permute(1, 2, 0))

In [None]:
# Preview one sample from the Monet dataset (index 99).
show_image(monet_dataset[99])

In [None]:
import pytorch_lightning as pl

In [None]:
from torch.utils.data import IterableDataset
from random import randint
import torch.optim as optim
import itertools

class CycleGANDataset(Dataset):
    """Pairs each Monet item with a randomly chosen photo item.

    The two collections are unpaired, so the photo is drawn uniformly at
    random on every access. The dataset length is the Monet dataset length.

    Args:
        monet_dataset: Indexable, sized collection of Monet items.
        photo_dataset: Indexable, sized collection of photo items.
    """

    def __init__(self, monet_dataset, photo_dataset):
        self.monet_dataset = monet_dataset
        self.photo_dataset = photo_dataset

    def __getitem__(self, idx):
        """Return ``(monet_item, random_photo_item)`` for Monet index ``idx``."""
        monet = self.monet_dataset[idx]
        # randint is inclusive on both ends, so the upper bound must be
        # len - 1; using len could index one past the end (IndexError).
        random_photo_idx = randint(0, len(self.photo_dataset) - 1)
        photo = self.photo_dataset[random_photo_idx]
        return monet, photo

    def __len__(self):
        """Return the number of Monet items."""
        return len(self.monet_dataset)
    

class CycleGAN(pl.LightningModule):
    """CycleGAN with two generators and two discriminators.

    ``m_gen`` turns photos into Monet-style images and ``p_gen`` turns Monet
    images into photos; ``m_disc`` and ``p_disc`` score images of the Monet
    and photo domain respectively.

    Training uses two optimizers (see ``configure_optimizers``): index 0
    updates both generators, index 1 updates both discriminators.

    NOTE(review): ``training_step`` takes ``optimizer_idx``, which relies on
    Lightning's automatic optimization with multiple optimizers. Lightning 2.x
    removed that argument — confirm the installed version.

    Args:
        lambda_cycle: Weight of the cycle-consistency loss. The identity loss
            is weighted by ``0.5 * lambda_cycle``.
    """

    def __init__(self, lambda_cycle=10):
        super().__init__()
        self.save_hyperparameters()
        self.m_gen = Generator()
        self.p_gen = Generator()
        self.m_disc = Discriminator()
        self.p_disc = Discriminator()
        self.lambda_cycle = lambda_cycle

    @staticmethod
    def _adversarial_loss(logits, is_real):
        """Binary cross-entropy of discriminator logits against an all-real or all-fake target."""
        target = torch.ones_like(logits) if is_real else torch.zeros_like(logits)
        return nn.functional.binary_cross_entropy_with_logits(logits, target)

    def _generator_loss(self, real_monet, real_photo):
        """Adversarial + cycle-consistency + identity loss for both generators."""
        # photo to monet back to photo
        fake_monet = self.m_gen(real_photo)
        cycled_photo = self.p_gen(fake_monet)

        # monet to photo back to monet
        fake_photo = self.p_gen(real_monet)
        cycled_monet = self.m_gen(fake_photo)

        # a generator fed an image of its own target domain should return it unchanged
        same_monet = self.m_gen(real_monet)
        same_photo = self.p_gen(real_photo)

        # generators want the discriminators to label their fakes as real
        adversarial = (
            self._adversarial_loss(self.m_disc(fake_monet), True)
            + self._adversarial_loss(self.p_disc(fake_photo), True)
        )
        cycle = (
            nn.functional.l1_loss(cycled_monet, real_monet)
            + nn.functional.l1_loss(cycled_photo, real_photo)
        )
        identity = (
            nn.functional.l1_loss(same_monet, real_monet)
            + nn.functional.l1_loss(same_photo, real_photo)
        )
        return adversarial + self.lambda_cycle * cycle + 0.5 * self.lambda_cycle * identity

    def _discriminator_loss(self, real_monet, real_photo):
        """Real-vs-fake loss for both discriminators, averaged over the two targets."""
        # Fakes are only inputs here; no gradient should reach the generators.
        with torch.no_grad():
            fake_monet = self.m_gen(real_photo)
            fake_photo = self.p_gen(real_monet)

        monet_loss = 0.5 * (
            self._adversarial_loss(self.m_disc(real_monet), True)
            + self._adversarial_loss(self.m_disc(fake_monet), False)
        )
        photo_loss = 0.5 * (
            self._adversarial_loss(self.p_disc(real_photo), True)
            + self._adversarial_loss(self.p_disc(fake_photo), False)
        )
        return monet_loss + photo_loss

    def training_step(self, batch, batch_idx, optimizer_idx):
        """Compute the loss for the optimizer currently being stepped.

        Args:
            batch: ``(real_monet, real_photo)`` image batches.
            batch_idx: Index of the batch (unused).
            optimizer_idx: 0 for the generator optimizer, 1 for the
                discriminator optimizer (order of ``configure_optimizers``).

        Returns:
            Scalar loss tensor. Previously nothing was returned, so no
            optimization step could take place.
        """
        real_monet, real_photo = batch

        if optimizer_idx == 0:
            loss = self._generator_loss(real_monet, real_photo)
            self.log('gen_loss', loss, prog_bar=True)
        else:
            loss = self._discriminator_loss(real_monet, real_photo)
            self.log('disc_loss', loss, prog_bar=True)
        return loss

    def configure_optimizers(self):
        """Return ``[generator_optimizer, discriminator_optimizer]``, both AdamW with lr 1e-4."""
        return [
            optim.AdamW(
                itertools.chain(
                    self.m_gen.parameters(),
                    self.p_gen.parameters(),
                ),
                lr=1e-4
            ),
            optim.AdamW(
                itertools.chain(
                    self.m_disc.parameters(),
                    self.p_disc.parameters(),
                ),
                lr=1e-4
            )
        ]

    def train_dataloader(self):
        """Shuffled loader over the notebook-level Monet/photo datasets (default batch size)."""
        return DataLoader(CycleGANDataset(monet_dataset, photo_dataset), shuffle=True)

In [None]:
# fast_dev_run=True requests Lightning's quick debugging run instead of full training (see the Trainer docs).
trainer = pl.Trainer(fast_dev_run=True)

In [None]:
# Build the model with default hyperparameters and start the run; data comes from CycleGAN.train_dataloader().
module = CycleGAN()
trainer.fit(module)