In [None]:
import os
# Create the folders the notebook writes its outputs into.
# exist_ok=True keeps this cell re-runnable: plain os.mkdir raises
# FileExistsError when the folder is already there from an earlier run.
os.makedirs("/kaggle/working/images", exist_ok=True)
os.makedirs("./validation", exist_ok=True)


Configuration

In [None]:
import torch
import albumentations as A
from albumentations.pytorch import ToTensorV2

# Run on the GPU when PyTorch can see one, otherwise fall back to the CPU.
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
# Batch size of the training DataLoader built in data_loading().
BATCH_SIZE = 8
# Learning rate given to both Adam optimizers in initialize() and re-applied
# after a checkpoint is loaded.
LEARNING_RATE = 0.00021
# Weight of the identity L1 terms in the generator loss (see train_fn).
LAMBDA_IDENTITY = 0.1
# Weight of the cycle-consistency L1 terms in the generator loss (see train_fn).
LAMBDA_CYCLE = 15
# Worker processes used by the training DataLoader.
NUM_WORKERS = 4
# Number of passes main() makes over the training loader.
NUM_EPOCHS = 200
# When True, main() restores all four networks from the CHECKPOINT_* files.
LOAD_MODEL = False
# When True, train_fn writes the CHECKPOINT_* files on every 200th batch
# (skipping batch 0).
SAVE_MODEL = False
# When True, main() calls prediction() once training has finished.
PREDICTION = True
# Checkpoint locations, one per network (two generators, two discriminators).
CHECKPOINT_GEN_H = "../input/model-state/genh.pth.tar"
CHECKPOINT_GEN_Z = "../input/model-state/genz.pth.tar"
CHECKPOINT_CRITIC_H = "../input/model-state/critich.pth.tar"
CHECKPOINT_CRITIC_Z = "../input/model-state/criticz.pth.tar"
# Dataset roots; data_loading() appends "/monet_jpg" and "/photo_jpg".
# NOTE: "TRIAN" is a typo for "TRAIN" but the name is referenced by other cells.
TRIAN_PATH = "../input/gan-getting-started"
VALID_PATH = "../input/gan-getting-started"
# NOTE(review): OUTPUT_TRAIN_PATH is not referenced by any cell visible here.
OUTPUT_TRAIN_PATH="./train"
# Folder prediction() writes generated images into and then zips.
OUTPUT_PREDICT_PATH="./validation"
# Ask PyTorch to release cached, unused GPU memory before the models are built.
torch.cuda.empty_cache()


Objects for the model:
1. Generator: We create a generator from convolution blocks that shrink the image dimensions down to a compact representation of the image, and then expand that representation back into an image in the other domain.
2. Discriminator: We create a discriminator from convolution blocks that reduce the image to a low-dimensional representation, and then decide whether the image is original or fake by applying a sigmoid to that compact representation. This way the discriminator learns to represent the image according to the domain for which it needs to decide whether the image is real or fake.

In [None]:
import torch.nn as nn
import torchvision.models as models

class ConvBlock(nn.Module):
    """Convolution (or transposed convolution) -> InstanceNorm -> optional ReLU.

    Args:
        in_channels: channels of the incoming feature map.
        out_channels: channels produced by the convolution.
        down: True selects ``nn.Conv2d`` with reflect padding; False selects
            ``nn.ConvTranspose2d`` (built without a ``padding_mode`` argument).
        use_act: False replaces the trailing ReLU with ``nn.Identity``.
        **kwargs: forwarded unchanged to the convolution constructor
            (kernel_size, stride, padding, ...).
    """

    def __init__(self, in_channels, out_channels, down=True, use_act=True, **kwargs):
        super().__init__()
        if down:
            conv_layer = nn.Conv2d(in_channels, out_channels, padding_mode="reflect", **kwargs)
        else:
            conv_layer = nn.ConvTranspose2d(in_channels, out_channels, **kwargs)

        if use_act:
            activation = nn.ReLU(inplace=True)
        else:
            activation = nn.Identity()

        # All three stages stay inside one Sequential named ``conv`` so the
        # parameter keys remain ``conv.0.weight`` / ``conv.0.bias``.
        self.conv = nn.Sequential(
            conv_layer,
            nn.InstanceNorm2d(out_channels),
            activation,
        )

    def forward(self, x):
        """Apply the conv / norm / activation stack to ``x``."""
        out = self.conv(x)
        return out


class ResidualBlock(nn.Module):
    """Two 3x3 ``ConvBlock`` layers whose output is added back onto the input.

    Args:
        channels: channel count of both the input and the output.
    """

    def __init__(self, channels):
        super().__init__()
        conv_kwargs = {"kernel_size": 3, "padding": 1}
        first = ConvBlock(channels, channels, **conv_kwargs)
        # The second block has no activation.
        second = ConvBlock(channels, channels, use_act=False, **conv_kwargs)
        self.block = nn.Sequential(first, second)

    def forward(self, x):
        """Return ``x`` plus the output of the two stacked conv blocks."""
        residual = self.block(x)
        return x + residual

class Generator(nn.Module):
    """Image-to-image generator: 7x7 stem, two stride-2 down blocks,
    a stack of residual blocks, two stride-2 transposed-conv up blocks,
    and a 7x7 output convolution followed by ``tanh``.

    Args:
        img_channels: channels of the input image and of the generated image.
        num_features: channel width after the stem (doubled by each down block).
        num_residuals: number of ``ResidualBlock`` layers in the bottleneck.
    """

    def __init__(self, img_channels, num_features = 64, num_residuals=9):
        super().__init__()
        base = num_features

        # Stem: keeps spatial size (stride 1, padding 3 for a 7x7 kernel).
        self.initial = nn.Sequential(
            nn.Conv2d(img_channels, base, kernel_size=7, stride=1, padding=3, padding_mode="reflect"),
            nn.InstanceNorm2d(base),
            nn.ReLU(inplace=True),
        )

        down_kwargs = dict(kernel_size=3, stride=2, padding=1)
        self.down_blocks = nn.ModuleList([
            ConvBlock(base, base * 2, **down_kwargs),
            ConvBlock(base * 2, base * 4, **down_kwargs),
        ])

        residuals = []
        for _ in range(num_residuals):
            residuals.append(ResidualBlock(base * 4))
        self.res_blocks = nn.Sequential(*residuals)

        up_kwargs = dict(down=False, kernel_size=3, stride=2, padding=1, output_padding=1)
        self.up_blocks = nn.ModuleList([
            ConvBlock(base * 4, base * 2, **up_kwargs),
            ConvBlock(base * 2, base * 1, **up_kwargs),
        ])

        self.last = nn.Conv2d(base * 1, img_channels, kernel_size=7, stride=1, padding=3, padding_mode="reflect")

    def forward(self, x):
        """Translate a batch of images; ``tanh`` bounds the output to [-1, 1]."""
        features = self.initial(x)
        for down in self.down_blocks:
            features = down(features)
        features = self.res_blocks(features)
        for up in self.up_blocks:
            features = up(features)
        logits = self.last(features)
        return torch.tanh(logits)


In [None]:
import torch.nn as nn

class Block(nn.Module):
    """Discriminator stage: 4x4 reflect-padded conv -> InstanceNorm -> LeakyReLU(0.2).

    Args:
        in_channels: channels of the incoming feature map.
        out_channels: channels produced by the convolution.
        stride: stride of the 4x4 convolution.
    """

    def __init__(self, in_channels, out_channels, stride):
        super().__init__()
        conv_layer = nn.Conv2d(
            in_channels,
            out_channels,
            kernel_size=4,
            stride=stride,
            padding=1,
            bias=True,
            padding_mode="reflect",
        )
        norm_layer = nn.InstanceNorm2d(out_channels)
        activation = nn.LeakyReLU(0.2, inplace=True)
        # Kept as a Sequential named ``conv`` so parameter keys stay ``conv.0.*``.
        self.conv = nn.Sequential(conv_layer, norm_layer, activation)

    def forward(self, x):
        """Apply the conv / norm / activation stack to ``x``."""
        out = self.conv(x)
        return out


class Discriminator(nn.Module):
    """Convolutional discriminator that outputs a grid of real/fake scores.

    The network is a 4x4 stride-2 stem (no normalisation), one ``Block`` per
    remaining entry of ``features``, and a final 4x4 convolution down to a
    single channel. ``forward`` applies a sigmoid, so every score is in (0, 1).

    Args:
        in_channels: channels of the input image.
        features: channel widths of the successive stages. The first entry is
            the stem width; the last stage uses stride 1, all others stride 2.
            Any sequence is accepted; the default is an immutable tuple so it
            cannot be shared and mutated between instances.
    """

    def __init__(self, in_channels=3, features=(64, 128, 256, 512)):
        super().__init__()
        features = list(features)
        self.initial = nn.Sequential(
            nn.Conv2d(
                in_channels,
                features[0],
                kernel_size=4,
                stride=2,
                padding=1,
                padding_mode="reflect",
            ),
            nn.LeakyReLU(0.2, inplace=True),
        )

        layers = []
        in_channels = features[0]
        last_index = len(features) - 1
        for index, feature in enumerate(features[1:], start=1):
            # Pick the stride-1 stage by position, not by comparing widths:
            # a width repeated earlier in the list must still use stride 2.
            stride = 1 if index == last_index else 2
            layers.append(Block(in_channels, feature, stride=stride))
            in_channels = feature
        layers.append(nn.Conv2d(in_channels, 1, kernel_size=4, stride=1, padding=1, padding_mode="reflect"))
        self.model = nn.Sequential(*layers)

    def forward(self, x):
        """Return the sigmoid-squashed score map for a batch of images."""
        x = self.initial(x)
        return torch.sigmoid(self.model(x))


1. Transforms: we create transforms that generate a different photo on each call; the transforms are probabilistic and have a couple of different options for the output. This way we can get good performance even when we use only 30 photos in our data set.
2. MonetRealDataset: We create a dataset for this project; on each access it provides one Monet painting and one real photo, together with the file names of both images.

In [None]:
from torch.utils.data import Dataset
import numpy as np


def _normalize_and_tensorize():
    """Return fresh tail steps shared by both pipelines.

    Normalisation uses mean 0.5 / std 0.5 per channel on 0-255 pixels, followed
    by conversion to a tensor.
    """
    return [
        A.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5], max_pixel_value=255),
        ToTensorV2(),
    ]


# Light pipeline: resize, random horizontal flip, normalise.
# NOTE: despite its name, data_loading() attaches this one to the validation
# dataset; the name is kept because other cells reference it.
trianTransforms = A.Compose(
    [
        A.Resize(width=256, height=256),
        A.HorizontalFlip(p=0.5),
        *_normalize_and_tensorize(),
    ],
    additional_targets={"image0": "image"},
)

# Heavy pipeline: two optional "pick one" groups (geometric / blur-or-colour),
# then a random horizontal flip and normalisation.
# NOTE: despite its name, data_loading() attaches this one to the training dataset.
_geometric_choices = A.OneOf(
    [
        A.Flip(p=1),
        A.OpticalDistortion(p=1),
        A.GlassBlur(p=1),
    ],
    p=0.35,
)
_appearance_choices = A.OneOf(
    [
        A.GaussianBlur(p=1),
        A.FancyPCA(p=1),
    ],
    p=0.35,
)
predictTransforms = A.Compose(
    [
        A.Resize(width=256, height=256),
        _geometric_choices,
        _appearance_choices,
        A.HorizontalFlip(p=0.5),
        *_normalize_and_tensorize(),
    ],
    additional_targets={"image0": "image"},
)

class MonetRealDataset(Dataset):
    """Unpaired two-folder image dataset.

    Each item pairs one image from ``root_monet`` with one from ``root_real``.
    The dataset is as long as the larger folder; the shorter folder is cycled
    through with a modulo index.

    Args:
        root_monet: directory whose images are returned as ``monet_img``.
        root_real: directory whose images are returned as ``real_img``.
        transform: optional callable invoked as
            ``transform(image=monet_img, image0=real_img)`` that returns a
            mapping with keys ``"image"`` and ``"image0"``.
        limit: when truthy, only the first 30 files of ``root_monet`` are used.
    """

    def __init__(self, root_monet, root_real, transform=None,limit =False):
        self.root_monet = root_monet
        self.root_real = root_real
        self.transform = transform

        # os.listdir returns entries in arbitrary order. Sorting makes the
        # ``limit`` subset and the index -> file mapping reproducible.
        self.monet_images = sorted(os.listdir(root_monet))
        if limit:
            self.monet_images = self.monet_images[:30]
        self.real_images = sorted(os.listdir(root_real))
        self.length_dataset = max(len(self.monet_images), len(self.real_images))
        self.monet_len = len(self.monet_images)
        self.real_len = len(self.real_images)

    def __len__(self):
        """Return the size of the larger of the two image lists."""
        return self.length_dataset

    def __getitem__(self, index):
        """Load, optionally transform, and return one image from each folder.

        Returns:
            ``(real_img, monet_img, real_name, monet_name)`` -- note that the
            image from ``root_real`` comes first.
        """
        # Modulo lets the shorter list wrap around.
        monet_name = self.monet_images[index % self.monet_len]
        real_name = self.real_images[index % self.real_len]

        monet_path = os.path.join(self.root_monet, monet_name)
        real_path = os.path.join(self.root_real, real_name)

        monet_img = np.array(Image.open(monet_path).convert("RGB"))
        real_img = np.array(Image.open(real_path).convert("RGB"))

        if self.transform:
            augmentations = self.transform(image=monet_img, image0=real_img)
            monet_img = augmentations["image"]
            real_img = augmentations["image0"]

        return real_img, monet_img,real_name, monet_name


* Data_loading: We create a dataset for training and one for validation, and create a data loader for each of them.
* Initialize: The function initializes all parts of the model - the discriminators and the generators - and the loss functions that we will use in the model.
* Save_checkpoint: The function saves the model for future use.
* Load_checkpoint: The function loads a saved model so that we can use it again.

In [None]:
def data_loading():
    """Build the training and validation datasets and their DataLoaders.

    In both datasets the "monet_jpg" folder is passed as ``root_real`` and the
    "photo_jpg" folder as ``root_monet``. The training dataset uses
    ``predictTransforms`` with ``limit=True``; the validation dataset uses
    ``trianTransforms``.

    Returns:
        ``(dataset, val_dataset, val_loader, loader)``.
    """
    train_kwargs = {
        "root_real": TRIAN_PATH + "/monet_jpg",
        "root_monet": TRIAN_PATH+"/photo_jpg",
        "transform": predictTransforms,
        "limit": True,
    }
    dataset = MonetRealDataset(**train_kwargs)

    valid_kwargs = {
        "root_real": VALID_PATH + "/monet_jpg",
        "root_monet": VALID_PATH + "/photo_jpg",
        "transform": trianTransforms,
    }
    val_dataset = MonetRealDataset(**valid_kwargs)

    # Validation: one image at a time, fixed order, loaded in the main process.
    val_loader = DataLoader(val_dataset, batch_size=1, shuffle=False, pin_memory=True)

    # Training: shuffled mini-batches loaded by worker processes.
    loader = DataLoader(
        dataset,
        batch_size=BATCH_SIZE,
        shuffle=True,
        num_workers=NUM_WORKERS,
        pin_memory=True,
    )

    return dataset, val_dataset, val_loader, loader



def initialize():
    """Create the four networks, their two optimizers and the loss functions.

    Returns:
        ``(disc_H, disc_Z, gen_Z, gen_H, opt_disc, opt_gen, L1, mse)`` where
        ``opt_disc`` updates both discriminators and ``opt_gen`` updates both
        generators.
    """
    # Construction order is kept (disc_H, disc_Z, gen_Z, gen_H) so random
    # weight initialisation draws from the RNG in the same sequence.
    disc_H = Discriminator(in_channels=3).to(DEVICE)
    disc_Z = Discriminator(in_channels=3).to(DEVICE)
    gen_Z = Generator(img_channels=3, num_residuals=9).to(DEVICE)
    gen_H = Generator(img_channels=3, num_residuals=9).to(DEVICE)

    adam_betas = (0.5, 0.999)

    disc_params = [*disc_H.parameters(), *disc_Z.parameters()]
    opt_disc = optim.Adam(disc_params, lr=LEARNING_RATE, betas=adam_betas)

    gen_params = [*gen_Z.parameters(), *gen_H.parameters()]
    opt_gen = optim.Adam(gen_params, lr=LEARNING_RATE, betas=adam_betas)

    # Loss functions handed back to the caller.
    L1 = nn.L1Loss()
    mse = nn.MSELoss()

    return disc_H, disc_Z, gen_Z, gen_H, opt_disc, opt_gen, L1, mse


def save_checkpoint(model, optimizer, filename="my_checkpoint.pth.tar"):
    """Write the model and optimizer state dicts to ``filename``.

    The file holds a dict with the keys ``"state_dict"`` (model) and
    ``"optimizer"``.
    """
    print("=> Saving checkpoint")
    state = dict(
        state_dict=model.state_dict(),
        optimizer=optimizer.state_dict(),
    )
    torch.save(state, filename)


def load_checkpoint(checkpoint_file, model, optimizer, lr):
    """Restore ``model`` and ``optimizer`` in place from a saved checkpoint.

    Args:
        checkpoint_file: path to a file holding ``"state_dict"`` and
            ``"optimizer"`` entries.
        model: module that receives the saved weights.
        optimizer: optimizer that receives the saved state.
        lr: learning rate written into every parameter group after loading.
    """
    print("=> Loading checkpoint")
    state = torch.load(checkpoint_file, map_location=DEVICE)
    model.load_state_dict(state["state_dict"])
    optimizer.load_state_dict(state["optimizer"])

    # The restored optimizer state carries the checkpoint's learning rate;
    # overwrite it so the rate requested by the caller is the one used.
    for group in optimizer.param_groups:
        group.update({"lr": lr})


* Prediction: The function gets the model and the data loader and, for each photo in the data loader, creates the fake Monet version of the photo and saves it. At the end of the function we create a zip file with all the images that the model produced.
* Train: The function gets the model and the data loader, and runs one epoch of the data loader on the model. We first train the discriminators, and after this we train the generators; in the end we complete the cycle by taking the fake photo and using the other generator to return it to the original photo, and we use the L1 (mean absolute error) loss as the loss function for the cycle.
* End_of_epoch: At the end of the epoch, we choose 4 photos, run the generator on them, and present the results to the user so that the model output at the end of the epoch can be inspected visually.

In [None]:
import matplotlib.pyplot as plt
import matplotlib.image as mpimg
import shutil
import PIL
def prediction(gen_Z,loader):
    """Generate one image per batch of ``loader`` and zip the results.

    For every batch, the second tensor of the batch (bound to ``real``) is
    passed through ``gen_Z``. The output is rescaled with ``* 0.5 + 0.5`` and
    saved as ``{OUTPUT_PREDICT_PATH}/{idx}.png``. Afterwards the whole folder
    is archived to ``/kaggle/working/images.zip``.

    Args:
        gen_Z: generator applied to the ``real`` tensor of each batch.
        loader: iterable yielding ``(monet, real, monet_name, real_name)``.
    """
    # Make sure the target folder exists even if the setup cell was skipped.
    os.makedirs(OUTPUT_PREDICT_PATH, exist_ok=True)

    loop = tqdm(loader, leave=True)
    # Inference only: no_grad avoids recording an autograd graph per batch.
    with torch.no_grad():
        for idx, (monet, real,monet_name,real_name) in enumerate(loop):
            real = real.to(DEVICE)
            with torch.cuda.amp.autocast():
                fake_monet = gen_Z(real)
                save_image(fake_monet*0.5+0.5, f"{OUTPUT_PREDICT_PATH}/{idx}.png")
    # make zip for submit
    shutil.make_archive("/kaggle/working/images", 'zip', OUTPUT_PREDICT_PATH)

# Running count of training batches processed across all epochs.
GlobalIdx =0
def train_fn(disc_H, disc_Z, gen_Z, gen_H, loader, opt_disc, opt_gen, l1, mse, d_scaler, g_scaler, num_epoch):
    """Run one training epoch over ``loader`` and show sample outputs.

    Per batch, the two discriminators are updated first (on real images and on
    detached generator outputs). Then both generators are updated with an
    adversarial + cycle-consistency + identity loss.

    Args:
        disc_H: discriminator scoring ``real`` images and ``gen_H`` outputs.
        disc_Z: discriminator scoring ``monet`` images and ``gen_Z`` outputs.
        gen_Z: generator applied to ``real`` (produces ``fake_monet``).
        gen_H: generator applied to ``monet`` (produces ``fake_real``).
        loader: iterable yielding ``(monet, real, monet_name, real_name)``.
        opt_disc: optimizer stepping both discriminators.
        opt_gen: optimizer stepping both generators.
        l1: loss used for the cycle and identity terms.
        mse: loss used for the adversarial terms.
        d_scaler: gradient scaler used for the discriminator update.
        g_scaler: gradient scaler used for the generator update.
        num_epoch: epoch index, forwarded to ``end_of_epoch``.

    Returns:
        The generator loss of the LAST batch as a detached CPU tensor. It is
        detached and moved off the GPU so callers can log or plot it without
        holding autograd history or hitting a device/NumPy conversion error.
    """
    global GlobalIdx
    H_reals = 0
    H_fakes = 0
    loop = tqdm(loader, leave=True)

    for idx, (monet, real, monet_name, real_name ) in enumerate(loop):
        monet = monet.to(DEVICE)
        real = real.to(DEVICE)

        # Train Discriminators H and Z
        with torch.cuda.amp.autocast():
            fake_real = gen_H(monet)
            D_H_real = disc_H(real)
            # detach(): the discriminator step must not back-propagate into gen_H.
            D_H_fake = disc_H(fake_real.detach())
            H_reals += D_H_real.mean().item()
            H_fakes += D_H_fake.mean().item()
            D_H_real_loss = mse(D_H_real, torch.ones_like(D_H_real))
            D_H_fake_loss = mse(D_H_fake, torch.zeros_like(D_H_fake))
            D_H_loss = D_H_real_loss + D_H_fake_loss

            fake_monet = gen_Z(real)
            D_Z_real = disc_Z(monet)
            D_Z_fake = disc_Z(fake_monet.detach())
            D_Z_real_loss = mse(D_Z_real, torch.ones_like(D_Z_real))
            D_Z_fake_loss = mse(D_Z_fake, torch.zeros_like(D_Z_fake))
            D_Z_loss = D_Z_real_loss + D_Z_fake_loss

            # Average the two discriminator losses.
            D_loss = (D_H_loss + D_Z_loss)/2

        opt_disc.zero_grad()
        d_scaler.scale(D_loss).backward()
        d_scaler.step(opt_disc)
        d_scaler.update()

        # Train Generators H and Z
        with torch.cuda.amp.autocast():
            # Adversarial loss: generators want their fakes scored as 1.
            D_H_fake = disc_H(fake_real)
            D_Z_fake = disc_Z(fake_monet)
            loss_G_H = mse(D_H_fake, torch.ones_like(D_H_fake))
            loss_G_Z = mse(D_Z_fake, torch.ones_like(D_Z_fake))

            # Cycle loss: translating there and back should reproduce the input.
            cycle_monet = gen_Z(fake_real)
            cycle_real = gen_H(fake_monet)
            cycle_monet_loss = l1(monet, cycle_monet)
            cycle_real_loss = l1(real, cycle_real)

            # identity loss (remove these for efficiency if you set lambda_identity=0)
            identity_monet = gen_Z(monet)
            identity_real = gen_H(real)
            identity_monet_loss = l1(monet, identity_monet)
            identity_real_loss = l1(real, identity_real)

            # Weighted sum of all generator terms.
            G_loss = (
                loss_G_Z
                + loss_G_H
                + cycle_monet_loss * LAMBDA_CYCLE
                + cycle_real_loss * LAMBDA_CYCLE
                + identity_real_loss * LAMBDA_IDENTITY
                + identity_monet_loss * LAMBDA_IDENTITY
            )

        opt_gen.zero_grad()
        g_scaler.scale(G_loss).backward()
        g_scaler.step(opt_gen)
        g_scaler.update()
        # Checkpoint on every 200th batch of the epoch, skipping batch 0.
        if idx % 200 == 0:
            if SAVE_MODEL and idx !=0:
                save_checkpoint(gen_H, opt_gen, filename=CHECKPOINT_GEN_H)
                save_checkpoint(gen_Z, opt_gen, filename=CHECKPOINT_GEN_Z)
                save_checkpoint(disc_H, opt_disc, filename=CHECKPOINT_CRITIC_H)
                save_checkpoint(disc_Z, opt_disc, filename=CHECKPOINT_CRITIC_Z)
        # Progress bar shows the running mean of disc_H's scores.
        loop.set_postfix(H_real=H_reals/(idx+1), H_fake=H_fakes/(idx+1))
        GlobalIdx += 1
    end_of_epoch(num_epoch,loader,gen_Z)
    return G_loss.detach().cpu()
    
def end_of_epoch(epoch,loader, gen_monet):
    """Show four input images next to the generator's output for each.

    Four batches are drawn from ``loader``. For each, the first image of the
    ``real`` tensor and the generator output for it are rescaled with
    ``* 127.5 + 127.5`` to 0-255 and plotted side by side in a 4x2 grid.

    Args:
        epoch: epoch index (currently not used inside the function).
        loader: iterable yielding ``(monet, real, monet_name, real_name)``.
        gen_monet: generator applied to the ``real`` tensor.
    """
    _, axel = plt.subplots(4, 2, figsize=(10, 15))
    # One iterator for all four previews: building a fresh one per preview
    # would restart the loader (and its worker processes) every time.
    batches = iter(loader)
    for i in range(4):
        try:
            monet, real, monet_name, real_name = next(batches)
        except StopIteration:
            # Fewer than four batches available: start over.
            batches = iter(loader)
            monet, real, monet_name, real_name = next(batches)

        # Preview only: no gradients are needed.
        with torch.no_grad(), torch.cuda.amp.autocast():
            generated = gen_monet(real.to(DEVICE))[0]

        generated = generated.cpu().numpy()
        generated = (generated * 127.5 + 127.5).astype(np.uint8)
        # Channels-first -> channels-last for imshow.
        generated = np.moveaxis(generated, 0, -1)

        source = (real[0] * 127.5 + 127.5).numpy().astype(np.uint8)
        source = np.moveaxis(source, 0, -1)

        axel[i, 0].imshow(source)
        axel[i, 1].imshow(generated)
        axel[i, 0].set_title("Input image")
        axel[i, 1].set_title("Monet")
        axel[i, 0].axis("off")
        axel[i, 1].axis("off")

    plt.show()
    plt.close()
 

In [None]:
import sys
import torch
#import tensorflow.Tensor as Tensor
from fastai.vision.all import show_image
#import torch.Tensor as Tensor
from torch.utils.data import DataLoader
import torch.nn as nn
import torch.optim as optim
from tqdm import tqdm
from torchvision.utils import save_image
from PIL import Image
import os
from torch.utils.data import Dataset
import numpy as np
#from IPython.display import display, Image as display, ImageD

def main():
    """Build the models, optionally restore them, train, plot, and predict.

    After every epoch the loss returned by ``train_fn`` is appended to a
    history and the whole history is re-plotted. When ``PREDICTION`` is set,
    ``prediction`` is run on the validation loader after training.
    """
    # initialize
    disc_H, disc_Z, gen_Z, gen_H, opt_disc, opt_gen, L1, mse  = initialize()

    # loading params
    if LOAD_MODEL:
        restore_plan = (
            (CHECKPOINT_GEN_H, gen_H, opt_gen),
            (CHECKPOINT_GEN_Z, gen_Z, opt_gen),
            (CHECKPOINT_CRITIC_H, disc_H, opt_disc),
            (CHECKPOINT_CRITIC_Z, disc_Z, opt_disc),
        )
        for checkpoint_path, model, optimizer in restore_plan:
            load_checkpoint(checkpoint_path, model, optimizer, LEARNING_RATE)

    dataset, val_dataset, val_loader, loader = data_loading()

    g_scaler = torch.cuda.amp.GradScaler()
    d_scaler = torch.cuda.amp.GradScaler()
    total_loss = []
    # start training
    for epoch in range(NUM_EPOCHS):
        loss = train_fn(disc_H, disc_Z, gen_Z, gen_H, loader, opt_disc, opt_gen, L1, mse, d_scaler, g_scaler,epoch)
        # Store a plain float: a tensor with autograd history (or on the GPU)
        # cannot be converted by matplotlib, and would keep memory alive.
        total_loss.append(float(loss))
        epochs_seen = list(range(len(total_loss)))
        plt.plot(epochs_seen, total_loss, color='black', linestyle='dashed', linewidth=3,
                marker='o', markerfacecolor='gray', markersize=4)
        plt.xlabel("Iteration Number")
        plt.ylabel("Loss")
        plt.title("Loss Rate")
        plt.show()
    if PREDICTION:
        prediction(gen_Z, val_loader)


main()
