In [None]:
# This Python 3 environment comes with many helpful analytics libraries installed
# It is defined by the kaggle/python Docker image: https://github.com/kaggle/docker-python
# For example, here are several helpful packages to load

import numpy as np # linear algebra
import pandas as pd # data processing, CSV file I/O (e.g. pd.read_csv)

# Input data files are available in the read-only "../input/" directory
# For example, running this (by clicking run or pressing Shift+Enter) will list all files under the input directory

import os
# for dirname, _, filenames in os.walk('/kaggle/input'):
#     for filename in filenames:
#         print(os.path.join(dirname, filename))

# You can write up to 20GB to the current directory (/kaggle/working/) that gets preserved as output when you create a version using "Save & Run All" 
# You can also write temporary files to /kaggle/temp/, but they won't be saved outside of the current session
# Output directory for the generated Monet-style images. Kaggle preserves
# files written under /kaggle/working/ when a notebook version is saved.
directory = "images"

# Parent directory path
parent_dir = "/kaggle/working/"

# Full path of the images directory
path = os.path.join(parent_dir, directory)

# Create the directory (and any missing parents). exist_ok=True makes the cell
# safe to re-run and avoids the check-then-create race of
# `if not os.path.exists(path): os.mkdir(path)`.
os.makedirs(path, exist_ok=True)

# Remove a submission archive left over from a previous run so the final
# zip cell starts from a clean slate.
if os.path.exists('/kaggle/working/images.zip'):
    os.remove('/kaggle/working/images.zip')




# **Config params**

In [None]:
import torch
import albumentations as A
from albumentations.pytorch import ToTensorV2

# ---- Hardware ---------------------------------------------------------
# Use the GPU when PyTorch can see one; otherwise everything runs on CPU.
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"

# ---- Paths ------------------------------------------------------------
BASE_DIR = "/kaggle/input/gan-getting-started"   # read-only competition data
OUTPUT_DIR = "/kaggle/working/images"            # generated images are written here
# NOTE: TRAIN_DIR / VAL_DIR are not referenced by the other cells shown here;
# the dataset cells read BASE_DIR + "/photo_jpg" and "/monet_jpg" instead.
TRAIN_DIR = f"{BASE_DIR}/data/train"
VAL_DIR = f"{BASE_DIR}/data/val"

# ---- Optimisation -----------------------------------------------------
BATCH_SIZE = 8
LEARNING_RATE = 2e-4
NUM_WORKERS = 2            # DataLoader worker processes
NUM_EPOCHS = 120

# ---- Loss weights -----------------------------------------------------
LAMBDA_IDENTITY = 0.1      # weight of the identity L1 terms
LAMBDA_CYCLE = 10          # weight of the cycle-consistency L1 terms

# ---- Checkpointing ----------------------------------------------------
LOAD_MODEL = False
SAVE_MODEL = False
# The file names keep their H/Z suffixes; in this notebook the "H" files hold
# the Monet-side networks (gen_M, disc_M) and the "Z" files the photo side.
CHECKPOINT_GEN_H = "genh.pth.tar"
CHECKPOINT_GEN_Z = "genz.pth.tar"
CHECKPOINT_CRITIC_H = "critich.pth.tar"
CHECKPOINT_CRITIC_Z = "criticz.pth.tar"



# **Transforms (augmentation and normalization)**

In [None]:
def _finish_pipeline(steps):
    """Append the shared normalize/to-tensor tail to `steps` and wrap in a Compose.

    Normalize uses mean=std=0.5 with max_pixel_value=255, so uint8 pixels end
    up in [-1, 1]. "image0" is registered as a second image target so a photo
    and a Monet painting can be transformed in a single call.
    NOTE(review): albumentations applies the same sampled random parameters
    to all targets of one call, so both images share each crop/flip.
    """
    tail = [
        A.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5], max_pixel_value=255),
        ToTensorV2(),
    ]
    return A.Compose(steps + tail, additional_targets={"image0": "image"})


# Light training pipeline: resize to 256x256 and mirror left-right half the time.
transforms = _finish_pipeline([
    A.Resize(width=256, height=256),
    A.HorizontalFlip(p=0.5),
])

# Inference pipeline: deterministic resize only, no random augmentation.
transforms_val = _finish_pipeline([
    A.Resize(width=256, height=256),
])

# Heavier training pipeline: random 224x224 crop, occasional blur/distortion/
# noise (10%), and occasional flip or 90-degree rotation (30%).
transforms_2 = _finish_pipeline([
    A.Resize(width=256, height=256),
    A.RandomCrop(224, 224),
    A.OneOf([A.MotionBlur(p=1), A.OpticalDistortion(p=1), A.GaussNoise(p=1)], p=0.1),
    A.OneOf([A.HorizontalFlip(p=1), A.RandomRotate90(p=1), A.VerticalFlip(p=1)], p=0.3),
])

In [None]:
import torch
import torch.nn as nn

class Block(nn.Module):
    """Discriminator stage: 4x4 conv (reflect pad 1) -> InstanceNorm -> LeakyReLU(0.2).

    With kernel 4 and padding 1, stride 2 halves an even H/W, while stride 1
    shrinks each spatial dimension by one pixel (H -> H - 1).
    """

    def __init__(self, in_channels, out_channels, stride):
        super().__init__()
        conv = nn.Conv2d(
            in_channels,
            out_channels,
            kernel_size=4,
            stride=stride,
            padding=1,
            bias=True,
            padding_mode="reflect",
        )
        norm = nn.InstanceNorm2d(out_channels)
        act = nn.LeakyReLU(0.2, inplace=True)
        # A single Sequential named `conv` keeps state_dict keys as conv.0.*,
        # so existing checkpoints still load.
        self.conv = nn.Sequential(conv, norm, act)

    def forward(self, x):
        return self.conv(x)

In [None]:
class Discriminator(nn.Module):
    """PatchGAN-style discriminator producing a map of per-patch scores.

    Layout:
    - A 4x4/stride-2 conv + LeakyReLU stem with no normalization.
    - One `Block` per remaining entry of `features`. Each uses stride 2,
      except the final Block, which uses stride 1.
    - A 4x4/stride-1 conv down to one channel, followed by a sigmoid, so
      scores lie in (0, 1).

    With the default features, a 256x256 input yields an (N, 1, 30, 30) map.

    Args:
        in_channels: channels of the input image.
        features: channel widths; features[0] is the stem width and every
            following entry adds one Block. Any sequence of ints is accepted.
    """

    def __init__(self, in_channels=3, features=(64, 128, 256, 512)):
        super().__init__()
        # Tuple default avoids a shared mutable default; copy so callers'
        # sequences are never touched.
        features = list(features)
        self.initial = nn.Sequential(
            nn.Conv2d(
                in_channels,
                features[0],
                kernel_size=4,
                stride=2,
                padding=1,
                padding_mode="reflect",
            ),
            nn.LeakyReLU(0.2, inplace=True),
        )

        layers = []
        in_channels = features[0]
        last_index = len(features) - 1
        for i, feature in enumerate(features[1:], start=1):
            # Choose the stride by position, not by comparing widths: a width
            # equal to the last one may also appear earlier in `features`.
            stride = 1 if i == last_index else 2
            layers.append(Block(in_channels, feature, stride=stride))
            in_channels = feature
        layers.append(nn.Conv2d(in_channels, 1, kernel_size=4, stride=1, padding=1, padding_mode="reflect"))
        self.model = nn.Sequential(*layers)

    def forward(self, x):
        x = self.initial(x)
        return torch.sigmoid(self.model(x))

In [None]:
import torch
import torch.nn as nn

class ConvBlock(nn.Module):
    """Generator unit: (transposed) conv -> InstanceNorm -> ReLU or identity.

    Args:
        in_channels: input channel count.
        out_channels: output channel count.
        down: True builds a Conv2d with reflect padding; False builds a
            ConvTranspose2d (used for upsampling).
        use_act: append ReLU when True, otherwise an Identity.
        **kwargs: forwarded verbatim to the conv layer (kernel_size, stride,
            padding, output_padding, ...).
    """

    def __init__(self, in_channels, out_channels, down=True, use_act=True, **kwargs):
        super().__init__()
        if down:
            layer = nn.Conv2d(in_channels, out_channels, padding_mode="reflect", **kwargs)
        else:
            layer = nn.ConvTranspose2d(in_channels, out_channels, **kwargs)
        norm = nn.InstanceNorm2d(out_channels)
        if use_act:
            activation = nn.ReLU(inplace=True)
        else:
            activation = nn.Identity()
        # Indices 0/1/2 inside `conv` are part of the state_dict key layout.
        self.conv = nn.Sequential(layer, norm, activation)

    def forward(self, x):
        return self.conv(x)

In [None]:
class ResidualBlock(nn.Module):
    """Shape-preserving residual unit: x + (conv3x3-IN-ReLU -> conv3x3-IN)(x)."""

    def __init__(self, channels):
        super().__init__()
        # kernel 3 with padding 1 (and default stride 1) keeps H and W unchanged,
        # which the skip connection requires.
        same_size = dict(kernel_size=3, padding=1)
        self.block = nn.Sequential(
            ConvBlock(channels, channels, **same_size),
            ConvBlock(channels, channels, use_act=False, **same_size),
        )

    def forward(self, x):
        residual = self.block(x)
        return x + residual

In [None]:
class Generator(nn.Module):
    """ResNet-style image-to-image generator.

    Pipeline:
    1. A 7x7 stem conv.
    2. Two stride-2 downsampling ConvBlocks: f -> 2f -> 4f channels.
    3. `num_residuals` ResidualBlocks at 4f channels.
    4. Two stride-2 transposed ConvBlocks back to f channels.
    5. A 7x7 conv followed by tanh, so outputs lie in (-1, 1).

    H and W are preserved when both are divisible by 4.

    Args:
        img_channels: channels of the input and output images.
        num_features: width f of the first stage.
        num_residuals: number of residual blocks in the bottleneck.
    """

    def __init__(self, img_channels, num_features=64, num_residuals=9):
        super().__init__()
        f = num_features
        self.initial = nn.Sequential(
            nn.Conv2d(img_channels, f, kernel_size=7, stride=1, padding=3, padding_mode="reflect"),
            nn.InstanceNorm2d(f),
            nn.ReLU(inplace=True),
        )
        # Encoder: each step doubles the channels and halves H and W.
        self.down_blocks = nn.ModuleList(
            ConvBlock(f * c_in, f * c_out, kernel_size=3, stride=2, padding=1)
            for c_in, c_out in ((1, 2), (2, 4))
        )
        self.res_blocks = nn.Sequential(
            *(ResidualBlock(f * 4) for _ in range(num_residuals))
        )
        # Decoder mirrors the encoder; output_padding=1 restores exact 2x sizes.
        self.up_blocks = nn.ModuleList(
            ConvBlock(f * c_in, f * c_out, down=False, kernel_size=3, stride=2, padding=1, output_padding=1)
            for c_in, c_out in ((4, 2), (2, 1))
        )
        self.last = nn.Conv2d(f, img_channels, kernel_size=7, stride=1, padding=3, padding_mode="reflect")

    def forward(self, x):
        h = self.initial(x)
        for stage in (*self.down_blocks, self.res_blocks, *self.up_blocks):
            h = stage(h)
        return torch.tanh(self.last(h))



In [None]:
from PIL import Image
import os
from torch.utils.data import Dataset
import numpy as np

class PhotoMonetDataset(Dataset):
    """Unpaired (photo, Monet painting) dataset for CycleGAN training.

    Both folders are listed in sorted order and truncated to at most
    `photo_len` / `monet_len` files. The dataset length is the larger of the
    two counts actually found. The shorter list is cycled with a modulo index,
    so every item pairs a photo with some Monet image.

    Args:
        root_photo: directory containing the photo images.
        root_monet: directory containing the Monet images.
        transform: optional callable invoked as
            `transform(image=photo, image0=monet)`. It must return a dict with
            "image" and "image0" keys.
        photo_len: maximum number of photos to use (None = all).
        monet_len: maximum number of Monet images to use (None = all).

    Raises:
        ValueError: if either folder yields no files after truncation.
    """

    def __init__(self, root_photo, root_monet, transform=None, photo_len=300, monet_len=30):
        self.root_photo = root_photo
        self.root_monet = root_monet
        self.transform = transform

        # os.listdir order is arbitrary; sorting makes the truncated subset
        # the same on every run.
        self.photo_images = sorted(os.listdir(root_photo))[:photo_len]
        self.monet_images = sorted(os.listdir(root_monet))[:monet_len]
        if not self.photo_images or not self.monet_images:
            raise ValueError(
                f"No images found (photo dir: {root_photo!r} -> {len(self.photo_images)}, "
                f"monet dir: {root_monet!r} -> {len(self.monet_images)})"
            )

        # Use the counts actually found. A folder with fewer files than
        # requested would otherwise make `index % requested_len` run past the
        # end of the list.
        self.photo_len = len(self.photo_images)
        self.monet_len = len(self.monet_images)
        self.length_dataset = max(self.photo_len, self.monet_len)

    def __len__(self):
        return self.length_dataset

    @staticmethod
    def _load_rgb(image_path):
        """Read `image_path` as an RGB numpy array, closing the file handle."""
        with Image.open(image_path) as img:
            return np.array(img.convert("RGB"))

    def __getitem__(self, index):
        photo_name = self.photo_images[index % self.photo_len]
        monet_name = self.monet_images[index % self.monet_len]

        photo_img = self._load_rgb(os.path.join(self.root_photo, photo_name))
        monet_img = self._load_rgb(os.path.join(self.root_monet, monet_name))

        if self.transform:
            augmentations = self.transform(image=photo_img, image0=monet_img)
            photo_img = augmentations["image"]
            monet_img = augmentations["image0"]

        return photo_img, monet_img

In [None]:
import random, torch, os, numpy as np
import sys

import torch.nn as nn

import copy

def save_checkpoint(model, optimizer, filename="my_checkpoint.pth.tar"):
    """Write model and optimizer state to `filename` via torch.save.

    The saved object is a dict with keys "state_dict" (model weights) and
    "optimizer" (optimizer state), the layout load_checkpoint reads back.
    """
    print("=> Saving checkpoint")
    payload = dict(
        state_dict=model.state_dict(),
        optimizer=optimizer.state_dict(),
    )
    torch.save(payload, filename)


def load_checkpoint(checkpoint_file, model, optimizer, lr):
    """Restore model and optimizer state written by save_checkpoint.

    Tensors are mapped onto the global DEVICE. After loading, every optimizer
    param group's learning rate is overwritten with `lr`.
    """
    print("=> Loading checkpoint")
    saved = torch.load(checkpoint_file, map_location=DEVICE)
    model.load_state_dict(saved["state_dict"])
    optimizer.load_state_dict(saved["optimizer"])

    # The restored optimizer state carries the LR of the run that produced the
    # checkpoint; force the requested one instead.
    for group in optimizer.param_groups:
        group["lr"] = lr


def seed_everything(seed=42):
    """Seed Python, NumPy and PyTorch (CPU + all CUDA devices) and pin cuDNN.

    PYTHONHASHSEED is exported too; it can only affect interpreters started
    afterwards, since the running interpreter's hash seed is fixed at startup.
    """
    os.environ["PYTHONHASHSEED"] = str(seed)
    seeders = (
        random.seed,
        np.random.seed,
        torch.manual_seed,
        torch.cuda.manual_seed,
        torch.cuda.manual_seed_all,
    )
    for seed_fn in seeders:
        seed_fn(seed)
    # Trade cuDNN autotuning speed for run-to-run reproducibility.
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False


class LambdaLR():
    """Linear learning-rate decay multiplier for torch's LambdaLR scheduler.

    `step(epoch)` returns 1.0 until `decay_start_epoch`. It then falls linearly
    to 0.0 at `n_epochs` and stays at 0.0 afterwards, never going negative.

    Args:
        n_epochs: epoch at which the multiplier reaches 0.
        offset: epochs already completed (e.g. when resuming training).
        decay_start_epoch: epoch at which the linear decay begins.

    Raises:
        ValueError: if `decay_start_epoch` is not strictly before `n_epochs`.
    """

    def __init__(self, n_epochs, offset, decay_start_epoch):
        # Explicit check rather than `assert`, which `python -O` strips.
        if n_epochs - decay_start_epoch <= 0:
            raise ValueError("Decay must start before the training session ends!")
        self.n_epochs = n_epochs
        self.offset = offset
        self.decay_start_epoch = decay_start_epoch

    def step(self, epoch):
        """Return the LR multiplier for the epoch index supplied by the scheduler."""
        decay_span = self.n_epochs - self.decay_start_epoch
        progress = max(0, epoch + self.offset - self.decay_start_epoch) / decay_span
        # Clamp so stepping past n_epochs never yields a negative learning rate.
        return max(0.0, 1.0 - progress)




In [None]:
import numpy
import numpy as np
import torch
import matplotlib.pyplot as plt
import sys
from torch.utils.data import DataLoader
import torch.nn as nn
import torch.optim as optim

from tqdm import tqdm
from torchvision.utils import save_image



def train_fn(disc_M, disc_P, gen_P, gen_M, loader, opt_disc, opt_gen, l1, mse, d_scaler, g_scaler):
    """Run one CycleGAN training epoch over `loader`.

    Naming: M = Monet domain, P = photo domain.
    - gen_M maps photo -> Monet; gen_P maps Monet -> photo.
    - disc_M / disc_P score real vs generated images of their own domain.

    Each batch performs two updates:
    1. A least-squares (MSE) discriminator update.
    2. A generator update combining three loss groups: adversarial terms,
       L1 cycle-consistency terms weighted by LAMBDA_CYCLE, and L1 identity
       terms weighted by LAMBDA_IDENTITY.

    Both updates run under autocast with their own GradScaler.

    Args:
        disc_M, disc_P: discriminators for the Monet / photo domains.
        gen_P, gen_M: generators producing photos / Monet images.
        loader: iterable of (photo, monet) batches.
        opt_disc, opt_gen: optimizers over the discriminators / generators.
        l1, mse: L1 and MSE loss modules.
        d_scaler, g_scaler: GradScalers for the two updates.

    Returns:
        dict mapping each loss name to a one-element list holding that loss
        averaged over all batches of the epoch ([0.0] for an empty loader).
        The keys are:
        - "D_loss"
        - "loss_G_P"
        - "loss_G_H" (the adversarial loss of gen_M; the key name is kept for
          compatibility)
        - "cycle_photo_loss", "cycle_monet_loss"
        - "identity_monet_loss", "identity_photo_loss"
        - "G_loss"
    """
    M_reals = 0
    M_fakes = 0
    loop = tqdm(loader, leave=True)

    # Running sums of per-batch losses, averaged after the loop.
    loss_keys = (
        "D_loss",
        "loss_G_P",
        "loss_G_H",
        "cycle_photo_loss",
        "cycle_monet_loss",
        "identity_monet_loss",
        "identity_photo_loss",
        "G_loss",
    )
    loss_totals = dict.fromkeys(loss_keys, 0.0)
    num_batches = 0

    for idx, (photo, monet) in enumerate(loop):
        photo = photo.to(DEVICE)
        monet = monet.to(DEVICE)

        # ---- Discriminator update (disc_M and disc_P) ----
        with torch.cuda.amp.autocast():
            fake_monet = gen_M(photo)
            D_M_real = disc_M(monet)
            # detach() so this update does not backpropagate into gen_M.
            D_M_fake = disc_M(fake_monet.detach())
            M_reals += D_M_real.mean().item()
            M_fakes += D_M_fake.mean().item()
            # Least-squares targets: real -> 1, fake -> 0.
            D_M_real_loss = mse(D_M_real, torch.ones_like(D_M_real))
            D_M_fake_loss = mse(D_M_fake, torch.zeros_like(D_M_fake))
            D_M_loss = D_M_real_loss + D_M_fake_loss

            fake_photo = gen_P(monet)
            D_P_real = disc_P(photo)
            D_P_fake = disc_P(fake_photo.detach())
            D_P_real_loss = mse(D_P_real, torch.ones_like(D_P_real))
            D_P_fake_loss = mse(D_P_fake, torch.zeros_like(D_P_fake))
            D_P_loss = D_P_real_loss + D_P_fake_loss

            D_loss = (D_M_loss + D_P_loss) / 2

        opt_disc.zero_grad()
        d_scaler.scale(D_loss).backward()
        d_scaler.step(opt_disc)
        d_scaler.update()

        # ---- Generator update (gen_M and gen_P) ----
        with torch.cuda.amp.autocast():
            # Adversarial terms: generators want their fakes scored as real (1).
            D_M_on_fake = disc_M(fake_monet)
            D_P_on_fake = disc_P(fake_photo)
            loss_G_M = mse(D_M_on_fake, torch.ones_like(D_M_on_fake))
            loss_G_P = mse(D_P_on_fake, torch.ones_like(D_P_on_fake))

            # Cycle consistency: photo -> Monet -> photo and Monet -> photo -> Monet.
            cycle_photo = gen_P(fake_monet)
            cycle_monet = gen_M(fake_photo)
            cycle_photo_loss = l1(photo, cycle_photo)
            cycle_monet_loss = l1(monet, cycle_monet)

            # Identity: a generator given an image already in its output domain
            # should return it unchanged (could be skipped when LAMBDA_IDENTITY is 0).
            identity_photo = gen_P(photo)
            identity_monet = gen_M(monet)
            identity_photo_loss = l1(photo, identity_photo)
            identity_monet_loss = l1(monet, identity_monet)

            G_loss = (
                    loss_G_P
                    + loss_G_M
                    + cycle_photo_loss * LAMBDA_CYCLE
                    + cycle_monet_loss * LAMBDA_CYCLE
                    + identity_monet_loss * LAMBDA_IDENTITY
                    + identity_photo_loss * LAMBDA_IDENTITY
            )

        opt_gen.zero_grad()
        g_scaler.scale(G_loss).backward()
        g_scaler.step(opt_gen)
        g_scaler.update()
        # G_loss.backward() also writes gradients into the discriminator
        # parameters; opt_disc.zero_grad() clears them before the next D update.

        batch_losses = (
            D_loss,
            loss_G_P,
            loss_G_M,
            cycle_photo_loss,
            cycle_monet_loss,
            identity_monet_loss,
            identity_photo_loss,
            G_loss,
        )
        for key, value in zip(loss_keys, batch_losses):
            loss_totals[key] += float(value)
        num_batches += 1

        loop.set_postfix(M_real=M_reals / (idx + 1), M_fake=M_fakes / (idx + 1))

    # Epoch means; max(..., 1) guards against an empty loader.
    denom = max(num_batches, 1)
    return {key: [loss_totals[key] / denom] for key in loss_keys}

In [None]:
def val_fn(gen_H, loader):
    """Translate every photo in `loader` with `gen_H` and save the results.

    Generated images (tanh output in [-1, 1]) are mapped back to [0, 1] with
    x * 0.5 + 0.5. Each batch is written to OUTPUT_DIR/fakehorse_{idx}.png,
    where idx is the batch index; with batch_size=1 that is one file per photo.
    NOTE(review): with a larger batch size save_image would tile the batch
    into one grid image per file.
    The second element of each batch (the Monet image) is ignored.

    Args:
        gen_H: photo -> Monet generator (this notebook passes gen_M).
        loader: iterable of (photo, monet) batches.
    """
    loop = tqdm(loader, leave=True)

    # Inference only: no_grad avoids building an autograd graph for every image.
    with torch.no_grad():
        for idx, (photo, _monet) in enumerate(loop):
            photo = photo.to(DEVICE)
            with torch.cuda.amp.autocast():
                fake_monet = gen_H(photo)
            save_image(fake_monet * 0.5 + 0.5, OUTPUT_DIR + f"/fakehorse_{idx}.png")

In [None]:
# Seed every RNG before building the networks so weight initialisation and
# data shuffling are reproducible on Restart & Run All. (seed_everything also
# makes cuDNN deterministic, which can slow training somewhat.)
seed_everything(42)

# M = Monet domain, P = photo domain.
# disc_M / disc_P judge real vs generated images of their domain;
# gen_M translates photo -> Monet and gen_P translates Monet -> photo.
disc_M = Discriminator(in_channels=3).to(DEVICE)
disc_P = Discriminator(in_channels=3).to(DEVICE)
gen_P = Generator(img_channels=3, num_residuals=4).to(DEVICE)
gen_M = Generator(img_channels=3, num_residuals=4).to(DEVICE)

# One optimizer drives both discriminators, another both generators.
opt_disc = optim.Adam(
    list(disc_M.parameters()) + list(disc_P.parameters()),
    lr=LEARNING_RATE,
    betas=(0.5, 0.999),
)

opt_gen = optim.Adam(
    list(gen_P.parameters()) + list(gen_M.parameters()),
    lr=LEARNING_RATE,
    betas=(0.5, 0.999),
)


def _linear_decay_scheduler(optimizer):
    """Constant LR for the first half of NUM_EPOCHS, then linear decay to 0."""
    schedule = LambdaLR(NUM_EPOCHS, 0, NUM_EPOCHS / 2)
    return torch.optim.lr_scheduler.LambdaLR(optimizer, lr_lambda=schedule.step)


lr_scheduler_G = _linear_decay_scheduler(opt_gen)
lr_scheduler_D = _linear_decay_scheduler(opt_disc)

L1 = nn.L1Loss()
mse = nn.MSELoss()


In [None]:
if LOAD_MODEL:
    # The training cell saves checkpoints as
    # "<LEARNING_RATE>_<LAMBDA_IDENTITY>_<LAMBDA_CYCLE>" + base name.
    # Prefer that file and fall back to the bare base name for checkpoints
    # produced some other way.
    ckpt_prefix = str(LEARNING_RATE) + "_" + str(LAMBDA_IDENTITY) + "_" + str(LAMBDA_CYCLE)

    # NOTE: both generators share opt_gen (and both discriminators opt_disc),
    # so the last load into each optimizer determines its restored state.
    for ckpt_name, net, net_optimizer in (
        (CHECKPOINT_GEN_H, gen_M, opt_gen),
        (CHECKPOINT_GEN_Z, gen_P, opt_gen),
        (CHECKPOINT_CRITIC_H, disc_M, opt_disc),
        (CHECKPOINT_CRITIC_Z, disc_P, opt_disc),
    ):
        prefixed_name = ckpt_prefix + ckpt_name
        ckpt_path = prefixed_name if os.path.exists(prefixed_name) else ckpt_name
        load_checkpoint(ckpt_path, net, net_optimizer, LEARNING_RATE)

In [None]:
# ---- Training ---------------------------------------------------------
dataset = PhotoMonetDataset(
    root_monet=BASE_DIR + "/monet_jpg", root_photo=BASE_DIR + "/photo_jpg",
    transform=transforms_2
)

loader = DataLoader(
    dataset,
    batch_size=BATCH_SIZE,
    shuffle=True,
    num_workers=NUM_WORKERS,
    pin_memory=True
)

g_scaler = torch.cuda.amp.GradScaler()
d_scaler = torch.cuda.amp.GradScaler()

# Per-epoch loss history: each list gains the values train_fn reports per epoch.
losses = {"D_loss": [], "loss_G_P": [],
          "loss_G_H": [],
          "cycle_photo_loss": [],
          "cycle_monet_loss": [],
          "identity_monet_loss": [],
          "identity_photo_loss": [],
          "G_loss": []}

# Checkpoints are named with a hyper-parameter prefix so runs with different
# settings do not overwrite each other.
ckpt_prefix = str(LEARNING_RATE) + "_" + str(LAMBDA_IDENTITY) + "_" + str(LAMBDA_CYCLE)

for epoch in range(NUM_EPOCHS):
    print("-" * 4 + "Epoch number: " + str(epoch) + "-" * 4)
    d = train_fn(disc_M, disc_P, gen_P, gen_M, loader, opt_disc, opt_gen, L1, mse, d_scaler, g_scaler)

    # train_fn returns one-element lists per loss; append them to the history.
    for key, val in d.items():
        losses.setdefault(key, []).extend(val)

    lr_scheduler_G.step()
    lr_scheduler_D.step()

    if SAVE_MODEL:
        for net, net_optimizer, ckpt_name in (
            (gen_M, opt_gen, CHECKPOINT_GEN_H),
            (gen_P, opt_gen, CHECKPOINT_GEN_Z),
            (disc_M, opt_disc, CHECKPOINT_CRITIC_H),
            (disc_P, opt_disc, CHECKPOINT_CRITIC_Z),
        ):
            save_checkpoint(net, net_optimizer, filename=ckpt_prefix + ckpt_name)


# ---- Inference: Monet-style versions of the photos ----------------------
val_dataset = PhotoMonetDataset(
    root_monet=BASE_DIR + "/monet_jpg", root_photo=BASE_DIR + "/photo_jpg",
    transform=transforms_val, photo_len=7000
)
val_loader = DataLoader(
    val_dataset,
    batch_size=1,
    shuffle=False,
    pin_memory=True,
)

val_fn(gen_M, val_loader)




In [None]:
from zipfile import ZipFile
from os.path import basename

# create a ZipFile object
import shutil

# Package every generated image into images.zip next to the images folder
# (i.e. /kaggle/working/images.zip). Files are stored flat, under their base
# names.
zip_path = os.path.join(os.path.dirname(OUTPUT_DIR), "images.zip")
with ZipFile(zip_path, 'w') as zipObj:
    for folderName, subfolders, filenames in os.walk(OUTPUT_DIR):
        for filename in filenames:
            filePath = os.path.join(folderName, filename)
            zipObj.write(filePath, basename(filePath))
    n_archived = len(zipObj.namelist())

# The loose images are no longer needed once archived. rmtree also copes with
# nested folders, which os.rmdir on a non-empty directory would not. Print one
# summary line instead of one line per deleted file.
shutil.rmtree(OUTPUT_DIR)
print(f"Archived {n_archived} files into {zip_path} and removed {OUTPUT_DIR}")
# os.remove('/kaggle/working/images.zip')