In [None]:
import os

import torch
import torch.nn as nn

from torch.utils.data import Dataset, DataLoader
from torchvision import transforms

import numpy as np

import matplotlib.pyplot as plt
import seaborn as sns

from PIL import Image
from tqdm.notebook import tqdm

import warnings
warnings.simplefilter('ignore')

In [None]:
class MonetPhotoDataset(Dataset):
    """Unpaired dataset that yields one Monet image and one photo per index.

    The two folders may hold different numbers of files. The dataset length is
    the larger of the two counts, and the shorter folder is cycled by taking
    the index modulo its file count.

    Args:
        root_monet: Directory containing the Monet images.
        root_photo: Directory containing the photo images.
        transform: Optional callable applied to each loaded PIL image. When
            ``None`` the RGB PIL images are returned unchanged.

    Raises:
        ValueError: If either directory contains no entries.
    """

    def __init__(self, root_monet, root_photo, transform=None):
        self.transform = transform
        self.root_monet = root_monet
        self.root_photo = root_photo

        # os.listdir returns entries in arbitrary order; sorting makes the
        # index -> file mapping reproducible across runs and machines.
        self.monet_images = sorted(os.listdir(root_monet))
        self.photo_images = sorted(os.listdir(root_photo))

        self.monet_len = len(self.monet_images)
        self.photo_len = len(self.photo_images)

        # Fail early with a clear message instead of a modulo-by-zero error
        # on the first __getitem__ call.
        if self.monet_len == 0 or self.photo_len == 0:
            raise ValueError(
                'Both image folders must contain at least one file '
                f'(monet: {self.monet_len}, photo: {self.photo_len})'
            )

        self.length_dataset = max(self.monet_len, self.photo_len)

    def __len__(self):
        """Return the size of the larger of the two image collections."""
        return self.length_dataset

    @staticmethod
    def _load_rgb(path):
        """Open the image at ``path`` and return an RGB copy, closing the file."""
        with Image.open(path) as img:
            return img.convert('RGB')

    def __getitem__(self, idx):
        """Return the ``(monet_img, photo_img)`` pair for ``idx``.

        Each index is wrapped with a modulo so the shorter collection repeats.
        """
        monet_name = self.monet_images[idx % self.monet_len]
        photo_name = self.photo_images[idx % self.photo_len]

        monet_img = self._load_rgb(os.path.join(self.root_monet, monet_name))
        photo_img = self._load_rgb(os.path.join(self.root_photo, photo_name))

        # The transform is optional (its default is None), so only apply it
        # when one was supplied.
        if self.transform is not None:
            monet_img = self.transform(monet_img)
            photo_img = self.transform(photo_img)

        return monet_img, photo_img

In [None]:
# Preprocessing shared by both image domains.
transform = transforms.Compose(
    [
        # An integer size rescales the shorter image edge to 256 pixels.
        transforms.Resize(size=256),
        # PIL image -> float tensor (C, H, W) with values in [0, 1].
        transforms.ToTensor(),
        # (x - 0.5) / 0.5 maps [0, 1] to [-1, 1].
        transforms.Normalize(mean=0.5, std=0.5),
    ]
)

In [None]:
# Build the unpaired Monet/photo dataset from the two input folders.
dataset = MonetPhotoDataset(
    root_monet='../input/gan-getting-started/monet_jpg',
    root_photo='../input/gan-getting-started/photo_jpg',
    transform=transform,
)

In [None]:
# Run on the GPU when CUDA is available, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device('cuda')
else:
    device = torch.device('cpu')

In [None]:
class ResBlock(nn.Module):
    """Residual block that keeps the channel count and spatial size.

    Computes ``relu(norm(conv(x)) + x)`` where ``conv`` is
    Conv3x3 -> InstanceNorm -> ReLU -> Conv3x3 and ``norm`` is a second
    instance normalisation applied to the branch before the skip addition.

    Args:
        f: Number of input and output channels.
    """

    def __init__(self, f):
        super().__init__()
        self.conv = nn.Sequential(
            nn.Conv2d(f, f, 3, 1, 1),
            nn.InstanceNorm2d(f),
            nn.ReLU(),

            nn.Conv2d(f, f, 3, 1, 1),
        )

        # Normalises the output of the second convolution.
        self.norm = nn.InstanceNorm2d(f)

    def forward(self, x):
        """Return the activated sum of the normalised branch and the input."""
        # The layer was constructed but previously never applied, leaving the
        # residual branch un-normalised before it was added to the input.
        branch = self.norm(self.conv(x))
        return torch.relu(branch + x)
    
class Generator(nn.Module):
    """Image-to-image generator: downsample, residual blocks, upsample.

    Args:
        f: Base number of feature channels.
        res_blocks: Number of ``ResBlock`` modules between the downsampling
            and upsampling stages.
    """

    def __init__(self, f=64, res_blocks=6):
        super().__init__()

        def norm_act(channels):
            # Instance normalisation followed by an in-place ReLU.
            return [nn.InstanceNorm2d(channels), nn.ReLU(True)]

        # 7x7 stem, then two stride-2 convolutions: 3 -> f -> 2f -> 4f channels.
        encoder = [
            nn.Conv2d(3, f, kernel_size=7, stride=1, padding=3),
            *norm_act(f),
            nn.Conv2d(f, f * 2, kernel_size=3, stride=2, padding=1),
            *norm_act(f * 2),
            nn.Conv2d(f * 2, f * 4, kernel_size=3, stride=2, padding=1),
            *norm_act(f * 4),
        ]

        bottleneck = [ResBlock(f * 4) for _ in range(res_blocks)]

        # Each upsampling stage widens the channels by 4x with a stride-1
        # transposed convolution; PixelShuffle(2) then trades those channels
        # for a 2x larger feature map.
        decoder = [
            nn.ConvTranspose2d(f * 4, f * 8, kernel_size=3, stride=1, padding=1),
            nn.PixelShuffle(2),
            *norm_act(f * 2),
            nn.ConvTranspose2d(f * 2, f * 4, kernel_size=3, stride=1, padding=1),
            nn.PixelShuffle(2),
            *norm_act(f),
            # Reflection padding plus an unpadded 7x7 conv keeps the size.
            nn.ReflectionPad2d(3),
            nn.Conv2d(f, 3, kernel_size=7, stride=1, padding=0),
            nn.Tanh(),
        ]

        self.conv = nn.Sequential(*encoder, *bottleneck, *decoder)

    def forward(self, x):
        """Run ``x`` through the whole stack; Tanh bounds the output to [-1, 1]."""
        return self.conv(x)

In [None]:
class Discriminator(nn.Module):
    """Convolutional discriminator producing a map of per-patch scores.

    Args:
        nc: Number of channels of the input image.
        ndf: Base number of feature channels.
    """

    def __init__(self, nc=3, ndf=64):
        super().__init__()

        # First stage has no normalisation layer.
        stages = [
            nn.Conv2d(nc, ndf, kernel_size=4, stride=2, padding=1, bias=False),
            nn.LeakyReLU(0.2, inplace=True),
        ]

        # (in_channels, out_channels, stride, bias) for the normalised stages.
        middle = (
            (ndf, ndf * 2, 2, False),
            (ndf * 2, ndf * 4, 2, False),
            (ndf * 4, ndf * 8, 1, True),
        )
        for c_in, c_out, stride, use_bias in middle:
            stages += [
                nn.Conv2d(c_in, c_out, kernel_size=4, stride=stride, padding=1, bias=use_bias),
                nn.InstanceNorm2d(c_out),
                nn.LeakyReLU(0.2, inplace=True),
            ]

        # Single-channel score map squashed to (0, 1).
        stages += [
            nn.Conv2d(ndf * 8, 1, kernel_size=4, stride=1, padding=1),
            nn.Sigmoid(),
        ]

        self.init = nn.Sequential(*stages)

    def forward(self, x):
        """Return the sigmoid score map for the image batch ``x``."""
        return self.init(x)

In [None]:
# One discriminator and one generator per domain (M = Monet, P = photo).
disc_M, disc_P = Discriminator().to(device), Discriminator().to(device)
gen_M, gen_P = Generator().to(device), Generator().to(device)

# Both optimizers share the same Adam hyper-parameters.
adam_kwargs = dict(lr=5e-4, betas=(0.5, 0.999))

# A single optimizer updates both discriminators together ...
opt_disc = torch.optim.Adam(
    [*disc_M.parameters(), *disc_P.parameters()],
    **adam_kwargs
)

# ... and a single optimizer updates both generators together.
opt_gen = torch.optim.Adam(
    [*gen_M.parameters(), *gen_P.parameters()],
    **adam_kwargs
)

# L1 is used for the cycle/identity terms, MSE for the adversarial terms.
l1, mse = nn.L1Loss(), nn.MSELoss()

In [None]:
# One (monet, photo) pair per step; shuffle=True reshuffles the indices every epoch.
loader = DataLoader(
    dataset=dataset,
    batch_size=1,
    shuffle=True,
)

In [None]:
num_epochs = 6

for epoch in range(num_epochs):
    loop = tqdm(loader, leave=True)
    for idx, (monet, photo) in enumerate(loop):
        monet = monet.to(device)
        photo = photo.to(device)

        # ------------------------------------------------------------------
        # Discriminator update
        # ------------------------------------------------------------------
        # Targets are created with ones_like / zeros_like so they always have
        # the same shape, dtype and device as the discriminator's score map,
        # whatever the batch size (a (batch,) target only broadcasts against a
        # (batch, 1, H, W) map when batch == 1).
        fake_photo = gen_P(monet)
        D_P_real = disc_P(photo)
        # detach(): the discriminator step must not back-propagate into the generator.
        D_P_fake = disc_P(fake_photo.detach())
        D_P_real_loss = mse(D_P_real, torch.ones_like(D_P_real))
        D_P_fake_loss = mse(D_P_fake, torch.zeros_like(D_P_fake))
        D_P_loss = D_P_real_loss + D_P_fake_loss

        fake_monet = gen_M(photo)
        D_M_real = disc_M(monet)
        D_M_fake = disc_M(fake_monet.detach())
        D_M_real_loss = mse(D_M_real, torch.ones_like(D_M_real))
        D_M_fake_loss = mse(D_M_fake, torch.zeros_like(D_M_fake))
        D_M_loss = D_M_real_loss + D_M_fake_loss

        D_loss = (D_P_loss + D_M_loss) / 2
        opt_disc.zero_grad()
        D_loss.backward()
        opt_disc.step()

        # ------------------------------------------------------------------
        # Generator update
        # ------------------------------------------------------------------
        # Adversarial term: the generators want their fakes scored as real (1).
        D_P_fake = disc_P(fake_photo)
        D_M_fake = disc_M(fake_monet)
        loss_G_P = mse(D_P_fake, torch.ones_like(D_P_fake))
        loss_G_M = mse(D_M_fake, torch.ones_like(D_M_fake))

        # Cycle consistency: translating there and back should reproduce the input.
        cycle_photo = gen_P(fake_monet)
        cycle_monet = gen_M(fake_photo)
        cycle_photo_loss = l1(photo, cycle_photo)
        cycle_monet_loss = l1(monet, cycle_monet)

        # Identity: a generator fed an image of its own target domain should
        # leave it unchanged.
        identity_photo = gen_P(photo)
        identity_monet = gen_M(monet)
        identity_photo_loss = l1(photo, identity_photo)
        identity_monet_loss = l1(monet, identity_monet)

        # Cycle terms are weighted 10x relative to the other terms.
        G_loss = (
            loss_G_P
            + loss_G_M
            + cycle_photo_loss * 10
            + cycle_monet_loss * 10
            + identity_photo_loss
            + identity_monet_loss
        )

        opt_gen.zero_grad()
        G_loss.backward()
        opt_gen.step()

        # Report the adversarial losses every 200 steps; .item() already
        # returns a detached Python float.
        if idx % 200 == 0:
            print(
                'G_P_loss:', loss_G_P.item(),
                'G_M_loss:', loss_G_M.item(),
                'D_M_loss:', D_M_loss.item(),
                'D_P_loss:', D_P_loss.item()
            )

In [None]:
# Show a few photos next to their Monet-style translation.
NUM_PREVIEWS = 5

# Inference only: no_grad() avoids building autograd graphs for the forward passes.
with torch.no_grad():
    for idx, (monet, photo) in enumerate(loader):
        photo = photo.to(device)
        # Only the photo -> Monet direction is displayed, so the other
        # generator is not run here.
        fake_monet = gen_M(photo).cpu()[0]

        fig, (ax_photo, ax_monet) = plt.subplots(1, 2)
        # (C, H, W) -> (H, W, C) for matplotlib; * 0.5 + 0.5 undoes the
        # Normalize(0.5, 0.5) applied when the images were loaded.
        ax_photo.imshow(photo[0].cpu().permute(1, 2, 0) * 0.5 + 0.5)
        ax_photo.set_title('Photo')
        ax_monet.imshow(fake_monet.permute(1, 2, 0) * 0.5 + 0.5)
        ax_monet.set_title('Generated Monet')

        plt.show()

        # Stop after the last preview without fetching another batch.
        if idx == NUM_PREVIEWS - 1:
            break

In [None]:
import PIL
# Create the output folder; exist_ok makes re-running this cell harmless.
os.makedirs('./images', exist_ok=True)

In [None]:
# Translate photos to Monet style and save them as ./images/1.jpg ... ./images/7029.jpg.
NUM_EXPORT_IMAGES = 7029

# Make sure the target folder exists even if this cell is run on its own.
os.makedirs('./images', exist_ok=True)

# Inference only: no_grad() avoids building autograd graphs for every image.
with torch.no_grad():
    # File numbers are 1-based. The dataset wraps indices with a modulo, so an
    # explicit bounded range replaces the open-ended iteration over `dataset`.
    for i in tqdm(range(1, NUM_EXPORT_IMAGES + 1)):
        photo = dataset[i - 1][1].to(device)

        # Add a batch dimension for the generator and drop it again afterwards.
        prediction = gen_M(photo[None, :]).cpu()[0]

        # Map [-1, 1] to [0, 255]; round (instead of truncating) and clamp
        # before the uint8 cast, then move channels last for PIL.
        prediction = (127.5 + 127.5 * prediction).round().clamp(0, 255)
        prediction = prediction.permute(1, 2, 0).numpy().astype(np.uint8)

        im = Image.fromarray(prediction)
        im.save("./images/" + str(i) + ".jpg")

In [None]:
import shutil

# The archive is written to base_name + '.zip'. base_name must therefore not
# end with a slash, otherwise the result can be a file literally named '.zip'
# inside the folder that is being archived.
shutil.make_archive('/kaggle/working/images', 'zip', '/kaggle/working/images/')