# Monet Painting GANs

In [None]:
import os, glob, random, shutil
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
from PIL import Image
import torch
from torchvision import transforms
from torch import nn, optim
from torch.utils.data import DataLoader, Dataset
import pytorch_lightning as pl
from pytorch_lightning import Trainer
from pytorch_lightning.callbacks import EarlyStopping
from torch.utils.data import DataLoader
import torchvision.datasets as dset
from torchvision.utils import make_grid
from torchvision.io import read_image
import torch.nn.functional as F

## Introduction

In this project, we employ a GAN (Generative Adversarial Network) architecture, which comprises two neural networks: a generator and a discriminator. The two models are trained in opposition to each other. The generator tries to fool the discriminator with images that look like Monet paintings, while the discriminator tries to distinguish real Monet paintings from generated ones. This adversarial process drives the generator to keep improving until it produces realistic Monet-style images that are difficult to distinguish from genuine ones.

The dataset for this project consists of two categories of images: Monet paintings and photographs. The Monet category contains 300 paintings, and the photo category contains 7038 photographs. Every image is 256x256 pixels and is provided in both JPEG and TFRecord formats (the same images in each format). Using these images, we train a GAN to translate the photographs in the photo directory into images that resemble the Monet paintings in the Monet directory.
word count: 105, token 

## EDA

We use `ImgTransform` to preprocess and augment images as they are loaded. Augmentation adds variety during training, which is especially helpful given the limited number of Monet paintings. We apply basic transformations from `torchvision.transforms`: `Resize`, plus `RandomHorizontalFlip` and `RandomVerticalFlip`. The random flips are applied only during training, selected via the `stage` argument. Finally, every image is converted to a tensor and normalized from [0, 1] to [-1, 1] (`Normalize(mean=0.5, std=0.5)`), which matches the range of the generator's `Tanh` output and helps convergence.

In [None]:
class ImgTransform:
    """Callable image preprocessing with a training and a non-training variant.

    Both variants resize to ``img_size`` x ``img_size``, convert to a tensor and
    normalize each channel with mean 0.5 / std 0.5 (mapping [0, 1] to [-1, 1]).
    The training variant additionally applies a random horizontal flip followed
    by a random vertical flip.

    Args:
        img_size: Side length of the square output image.
    """

    def __init__(self, img_size=256):
        def build_pipeline(augment):
            # Fresh transform instances per pipeline; flips sit between resize and ToTensor.
            steps = [transforms.Resize((img_size, img_size))]
            if augment:
                steps += [transforms.RandomHorizontalFlip(), transforms.RandomVerticalFlip()]
            steps += [transforms.ToTensor(), transforms.Normalize(mean=[0.5], std=[0.5])]
            return transforms.Compose(steps)

        self.train_transform = build_pipeline(augment=True)
        self.test_transform = build_pipeline(augment=False)

    def __call__(self, img, stage="train"):
        """Apply the training pipeline when ``stage == "train"``, otherwise the test pipeline."""
        pipeline = self.train_transform if stage == "train" else self.test_transform
        return pipeline(img)

We create a custom `torch.utils.data.Dataset` class with `__init__`, `__len__`, and `__getitem__` methods that index the image files and load each image on demand.

We use the `stage` argument to distinguish training from prediction. In the training stage each item is a (photo, Monet painting) pair, while in the prediction stage each item is a single photo.

In [None]:
class GANDataset(Dataset):
    """Dataset over photo and Monet image files.

    In the ``'train'`` stage each item is a ``(photo, monet)`` pair taken at the
    same index of the two file lists, and the length is the shorter list's length.
    In any other stage each item is a single transformed photo and the length is
    the number of photo files.

    Args:
        photo_files: Paths of photo images.
        monet_files: Paths of Monet painting images (only read in the ``'train'`` stage).
        transform: Callable invoked as ``transform(img, stage)`` on each loaded PIL image.
        stage: ``'train'`` for paired items; anything else for photo-only items.
    """

    def __init__(self, photo_files, monet_files, transform, stage='train'):
        self.photo_files = photo_files
        self.monet_files = monet_files
        self.transform = transform
        self.stage = stage

    def __len__(self):
        if self.stage == 'train':
            # Photos and paintings are paired by index, so only the first
            # min(len) files of each list are reachable.
            # NOTE(review): with far fewer paintings than photos, most photos are
            # never seen during training; consider random photo sampling.
            return min(len(self.photo_files), len(self.monet_files))
        return len(self.photo_files)

    def _load(self, path):
        """Open ``path`` as 3-channel RGB, release the file handle, then apply the transform."""
        # PIL opens files lazily and keeps the handle open until the image is
        # garbage-collected; the `with` block closes it right away. convert()
        # returns a fully loaded copy, so the result stays valid after closing,
        # and forcing RGB keeps grayscale/palette/RGBA files at 3 channels.
        with Image.open(path) as img:
            rgb = img.convert("RGB")
        return self.transform(rgb, self.stage)

    def __getitem__(self, idx):
        if self.stage == 'train':
            # Monet is transformed before the photo, as in the original code, so
            # the random flips consume the RNG in the same order.
            monet_img = self._load(self.monet_files[idx])
            photo_img = self._load(self.photo_files[idx])
            return photo_img, monet_img
        return self._load(self.photo_files[idx])
    

To load and iterate through datasets, we use `torch.utils.data.DataLoader`. To organize data processing, we create a datamodule using `pl.LightningDataModule`. Its important methods are:

- `setup`, which creates the datasets.
- `train_dataloader`, which returns the training dataloader.
- `test_dataloader`, which returns the testing dataloader.

Other hooks are described in the PyTorch Lightning `LightningDataModule` documentation.


In [None]:
class GANDataModule(pl.LightningDataModule):
    """LightningDataModule serving the photo / Monet JPEG files.

    Args:
        batch_size: Batch size for both dataloaders.
        transform: Callable ``transform(img, stage)`` passed to ``GANDataset``;
            defaults to ``ImgTransform()`` when omitted.
        data_dir: Directory containing the ``monet_jpg/`` and ``photo_jpg/`` folders.
    """

    def __init__(self, batch_size=8, transform=None,
                 data_dir="/kaggle/input/gan-getting-started"):
        super().__init__()
        self.monet_files = sorted(glob.glob(os.path.join(data_dir, "monet_jpg", "*.jpg")))
        self.photo_files = sorted(glob.glob(os.path.join(data_dir, "photo_jpg", "*.jpg")))
        self.batch_size = batch_size
        # GANDataset calls transform(img, stage) unconditionally, so a None
        # transform would crash on the first item; use the default pipeline instead.
        self.transform = transform if transform is not None else ImgTransform()

    def setup(self, stage=None):
        """Build the dataset(s) for ``stage``.

        ``"train"`` (used directly in this notebook) and ``"fit"`` (the stage name
        Lightning's ``Trainer.fit`` passes) build the paired training dataset. Any
        other stage builds the photo-only dataset, and ``None`` builds both.
        """
        train_stages = ("train", "fit")
        if stage is None or stage in train_stages:
            self.train = GANDataset(monet_files=self.monet_files, photo_files=self.photo_files,
                                    transform=self.transform, stage="train")
        if stage is None or stage not in train_stages:
            self.test = GANDataset(monet_files=self.monet_files, photo_files=self.photo_files,
                                   transform=self.transform,
                                   stage="test" if stage is None else stage)

    def train_dataloader(self):
        """Shuffled loader over (photo, monet) pairs."""
        return DataLoader(self.train, batch_size=self.batch_size, shuffle=True)

    def test_dataloader(self):
        """Unshuffled loader over photos, in sorted file order."""
        return DataLoader(self.test, batch_size=self.batch_size, shuffle=False)

Below are some sample photos and Monet paintings. Our goal is to add Monet's style to the photos.

In [None]:
def display_img(img, nrow=4, title=""):
    """Plot a batch of [-1, 1]-normalized images as a single grid.

    Args:
        img: Image batch tensor; it is detached and moved to CPU before plotting.
            Presumably shaped (N, C, H, W) as produced by the dataloaders — confirm.
        nrow: Number of images per grid row.
        title: Figure title.
    """
    # Invert Normalize(mean=0.5, std=0.5): x * 0.5 + 0.5 maps [-1, 1] back to [0, 1].
    unnormalized = img.detach().cpu() * 0.5 + 0.5
    grid = make_grid(unnormalized, nrow=nrow)
    hwc_grid = grid.permute(1, 2, 0)  # CHW -> HWC, the layout matplotlib expects

    plt.figure(figsize=(18, 8))
    plt.imshow(hwc_grid)
    plt.axis("off")
    plt.title(title)
    plt.show()

# Build the datamodule (256x256 images, batch size 8) and create both datasets.
dm = GANDataModule(batch_size=8,transform=ImgTransform(img_size=256))
dm.setup("train")
dm.setup("test")

# Grab one shuffled training batch. GANDataset yields (photo, monet) pairs, so the
# photo batch comes first. `photo` is reused later as the sample input for the
# generator/discriminator smoke tests and as CycleGAN's default `photo_samples`.
dataloader = dm.train_dataloader()
photo, monet = next(iter(dataloader))

In [None]:
# Preview the sample batch of Monet paintings (4 images per grid row by default).
display_img(monet,title='Monet Pictures')

In [None]:
# Preview the photo batch that was paired with those paintings.
display_img(photo,title='Photos')

## Model Building

### Generator

Our CycleGAN generator is built using the `U-Net` architecture: a U-shaped network of downsampling (encoder) and upsampling (decoder) blocks joined by skip connections. Each skip connection passes the encoder features of one resolution directly to the decoder layer of the same resolution. This helps the generator preserve the structure of the input photo while changing its style. (The original CycleGAN paper uses a ResNet-based generator; the U-Net variant comes from pix2pix.)

For instance, the generator architecture looks like this:

![U-Net architecture](https://www.researchgate.net/publication/342456648/figure/fig1/AS:906478919626752@1593132824738/U-Net-architecture-diagram-modified-from-the-original-study-27-Green-yellow-boxes_W640.jpg)

In [None]:
# Define the generator network - based on the U-net architecture
class Generator(nn.Module):
    """U-Net generator mapping a 3-channel image to a 3-channel image in [-1, 1].

    Eight stride-2 encoder stages halve the spatial size each time, and eight
    stride-2 decoder stages double it back. Decoder stages 2-8 receive the
    previous decoder output concatenated (channel axis) with the encoder output
    of the same resolution, so input height and width must be divisible by 256
    (a 256x256 input reaches a 1x1 bottleneck).

    Submodules are registered, in this order, as ``downsample_1`` ... ``downsample_8``
    and ``upsample_1`` ... ``upsample_8``.
    """

    def __init__(self):
        super().__init__()

        # Encoder stages: (in_channels, out_channels, use_instance_norm).
        # The outermost stage and the 1x1 bottleneck skip normalization.
        encoder_spec = [
            (3, 64, False),
            (64, 128, True),
            (128, 256, True),
            (256, 512, True),
            (512, 512, True),
            (512, 512, True),
            (512, 512, True),
            (512, 512, False),
        ]
        # Decoder stages 1-7: (in_channels, out_channels, use_dropout). From
        # stage 2 on, the input channels are doubled by the concatenated skip.
        decoder_spec = [
            (512, 512, True),
            (1024, 512, True),
            (1024, 512, True),
            (1024, 512, False),
            (1024, 256, False),
            (512, 128, False),
            (256, 64, False),
        ]

        for n, (c_in, c_out, norm) in enumerate(encoder_spec, start=1):
            stage = [nn.Conv2d(c_in, c_out, 4, stride=2, padding=1)]
            if norm:
                stage.append(nn.InstanceNorm2d(c_out))
            stage.append(nn.LeakyReLU(0.2))
            setattr(self, f"downsample_{n}", nn.Sequential(*stage))

        for n, (c_in, c_out, drop) in enumerate(decoder_spec, start=1):
            stage = [
                nn.ConvTranspose2d(c_in, c_out, 4, stride=2, padding=1),
                nn.InstanceNorm2d(c_out),
            ]
            if drop:
                stage.append(nn.Dropout(0.5))
            stage.append(nn.ReLU())
            setattr(self, f"upsample_{n}", nn.Sequential(*stage))

        # Output head: back to 3 channels, squashed into [-1, 1] by Tanh.
        self.upsample_8 = nn.Sequential(
            nn.ConvTranspose2d(128, 3, 4, stride=2, padding=1),
            nn.Tanh(),
        )

    def forward(self, x):
        # Encoder: keep every stage's output for the skip connections.
        skips = []
        for n in range(1, 9):
            x = getattr(self, f"downsample_{n}")(x)
            skips.append(x)

        # Decoder: start from the bottleneck, then pair each stage with the
        # deepest encoder output that has not been used yet.
        x = self.upsample_1(skips.pop())
        for n in range(2, 9):
            x = getattr(self, f"upsample_{n}")(torch.cat([x, skips.pop()], 1))
        return x


In [None]:
# Smoke test: build a Generator and run the sample photo batch through it.
# `out` is not inspected; the bare `gen_net` expression displays the module
# structure as the cell output.
gen_net = Generator()
out = gen_net(photo)
gen_net

### Discriminator

CycleGAN uses a `PatchGAN` discriminator. Unlike a conventional classifier, it outputs a matrix of values instead of a single probability that the whole image is real. Each output value corresponds to an overlapping patch of the input image (for a 256x256 input, the network below outputs a 30x30 map), so realism is judged locally. Values closer to 1 indicate a "real" classification, while values closer to 0 indicate "fake". Judging realism patch by patch encourages sharper, more locally realistic textures in the generated images.

For instance, the discriminator architecture looks like this:

![Discriminator architecture](https://www.researchgate.net/profile/Marija-Jegorova-2/publication/353016853/figure/fig2/AS:1042566166355968@1625578552453/Example-of-CycleGAN-architecture-Discriminator-PatchGAN-It-is-a-fully-convolutional_W640.jpg)


In [None]:
class Discriminator(nn.Module):
    """PatchGAN discriminator producing a grid of per-patch probabilities.

    Four Conv(k=4) + LeakyReLU(0.2) stages (64, 128, 256, 512 channels; the
    first three with stride 2, the last with stride 1; InstanceNorm on all but
    the first) are followed by a 1-channel Conv(k=4) and a Sigmoid. A 256x256
    input yields a 30x30 map of values in [0, 1].

    Args:
        in_channels: Number of channels of the input image.
    """

    def __init__(self, in_channels=3):
        super().__init__()
        # (out_channels, stride, use_instance_norm) for each feature stage.
        stage_spec = [(64, 2, False), (128, 2, True), (256, 2, True), (512, 1, True)]

        layers = []
        c_prev = in_channels
        for c_out, stride, norm in stage_spec:
            layers.append(nn.Conv2d(c_prev, c_out, kernel_size=4, stride=stride, padding=1))
            if norm:
                layers.append(nn.InstanceNorm2d(c_out, affine=True, track_running_stats=True))
            layers.append(nn.LeakyReLU(0.2, inplace=True))
            c_prev = c_out

        # Patch scoring head: one value per receptive-field patch, mapped to (0, 1).
        layers.append(nn.Conv2d(c_prev, 1, kernel_size=4, stride=1, padding=1, bias=False))
        layers.append(nn.Sigmoid())
        self.model = nn.Sequential(*layers)

    def forward(self, x):
        return self.model(x)


In [None]:
# Smoke test: build a Discriminator and score the sample photo batch.
# `out` is not inspected; the bare `dis_net` expression displays the module
# structure as the cell output.
dis_net = Discriminator()
out = dis_net(photo)
dis_net

### CycleGAN

We build CycleGAN from the `Generator` and `Discriminator` network components above:

1. Photo2Monet Generator
2. Monet2Photo Generator
3. Monet Discriminator
4. Photo Discriminator


![CycleGAN architecture](https://1.bp.blogspot.com/-F7lW08tA2t0/X5424ebzBqI/AAAAAAAAKiA/zq24RKnPc5wwDvjqS7EUuXDzMPb_2rJaACLcBGAsYHQ/s806/Google%2BChromeScreenSnapz098.jpg)

In [None]:
class CycleGAN(pl.LightningModule):
    """CycleGAN for unpaired photo <-> Monet translation.

    Holds two U-Net generators (photo->Monet, Monet->photo) and two PatchGAN
    discriminators (Monet, photo), each trained by its own Adam optimizer.
    ``forward`` runs the photo->Monet generator.

    Args:
        lr: Adam learning rate shared by all four optimizers.
        betas: Adam betas shared by all four optimizers.
        lambda_w: Weight of the cycle-consistency loss; the identity loss is
            weighted by ``0.5 * lambda_w``.
        display_epochs: Show sample translations every ``display_epochs`` epochs
            (and after the first epoch).
        photo_samples: Photo batch translated for those visualizations. The
            default is the global ``photo`` batch bound when the class is defined.
    """

    def __init__(self, lr=2e-4, betas=(0.5, 0.999), lambda_w=10, display_epochs=30, photo_samples=photo):
        super().__init__()
        self.lr = lr
        self.betas = betas
        self.lambda_w = lambda_w
        self.display_epochs = display_epochs
        self.photo_samples = photo_samples
        self.loss_history = []  # one [gen_PM, gen_MP, disc_M, disc_P] list per epoch
        self.epoch_count = 0

        self.gan_photo_monet = Generator()
        self.gan_monet_photo = Generator()
        self.dis_monet = Discriminator()
        self.dis_photo = Discriminator()

    def forward(self, z):
        """Translate a photo batch into the Monet domain."""
        return self.gan_photo_monet(z)

    def adv_criterion(self, y_hat, y):
        """Binary cross-entropy between discriminator outputs and 0/1 targets.

        The discriminators end in ``nn.Sigmoid``, so ``y_hat`` already holds
        probabilities. The previous ``binary_cross_entropy_with_logits`` applied a
        second sigmoid, squashing predictions into [0.5, 0.73] and badly weakening
        the adversarial gradients.
        NOTE: plain BCE is not autocast-safe; if mixed precision is enabled, drop
        the discriminator Sigmoid and switch back to the logits variant.
        """
        return F.binary_cross_entropy(y_hat, y)

    def recon_criterion(self, y_hat, y):
        """Mean absolute (L1) reconstruction error."""
        return F.l1_loss(y_hat, y)

    def adv_loss(self, real_X, disc_Y, gen_XY):
        """Generator adversarial loss: ``disc_Y`` should judge ``gen_XY(real_X)`` as real.

        Returns:
            ``(adv_loss_XY, fake_Y)``, so the translated batch can be reused for the cycle loss.
        """
        fake_Y = gen_XY(real_X)
        disc_fake_Y_hat = disc_Y(fake_Y)
        adv_loss_XY = self.adv_criterion(disc_fake_Y_hat, torch.ones_like(disc_fake_Y_hat))
        return adv_loss_XY, fake_Y

    def id_loss(self, real_X, gen_YX):
        """Identity loss: ``gen_YX`` applied to an image already in its output domain should return it unchanged."""
        id_X = gen_YX(real_X)
        return self.recon_criterion(id_X, real_X)

    def cycle_loss(self, real_X, fake_Y, gen_YX):
        """Cycle-consistency loss: mapping ``fake_Y`` back with ``gen_YX`` should recover ``real_X``."""
        cycle_X = gen_YX(fake_Y)
        return self.recon_criterion(cycle_X, real_X)

    def gen_loss(self, real_X, real_Y, gen_XY, gen_YX, disc_Y):
        """Total loss for generator ``gen_XY`` (domain X -> domain Y).

        adversarial + 0.5 * lambda_w * identity(Y) + lambda_w * (cycle X->Y->X + cycle Y->X->Y)
        """
        adv_loss_XY, fake_Y = self.adv_loss(real_X, disc_Y, gen_XY)

        id_loss_Y = self.id_loss(real_Y, gen_XY)

        cycle_loss_X = self.cycle_loss(real_X, fake_Y, gen_YX)
        cycle_loss_Y = self.cycle_loss(real_Y, gen_YX(real_Y), gen_XY)
        cycle_loss = cycle_loss_X + cycle_loss_Y

        return adv_loss_XY + 0.5 * self.lambda_w * id_loss_Y + self.lambda_w * cycle_loss

    def disc_loss(self, real_X, fake_X, disc_X):
        """Discriminator loss: average of BCE on fakes (target 0) and reals (target 1)."""
        # detach() keeps this loss from back-propagating into the generator.
        disc_fake_hat = disc_X(fake_X.detach())
        disc_fake_loss = self.adv_criterion(disc_fake_hat, torch.zeros_like(disc_fake_hat))

        disc_real_hat = disc_X(real_X)
        disc_real_loss = self.adv_criterion(disc_real_hat, torch.ones_like(disc_real_hat))

        return (disc_fake_loss + disc_real_loss) / 2

    def configure_optimizers(self):
        """One Adam optimizer per network; the index order is relied on by ``training_step``."""
        params = {
            "lr": self.lr,
            "betas": self.betas,
        }
        opt_gan_photo_monet = torch.optim.Adam(self.gan_photo_monet.parameters(), **params)
        opt_gan_monet_photo = torch.optim.Adam(self.gan_monet_photo.parameters(), **params)

        opt_dis_monet = torch.optim.Adam(self.dis_monet.parameters(), **params)
        opt_dis_photo = torch.optim.Adam(self.dis_photo.parameters(), **params)

        return [opt_gan_photo_monet, opt_gan_monet_photo, opt_dis_monet, opt_dis_photo], []

    def training_step(self, batch, batch_idx, optimizer_idx):
        """Compute the loss for the network owned by optimizer ``optimizer_idx``.

        0: photo->Monet generator, 1: Monet->photo generator,
        2: Monet discriminator, 3: photo discriminator.
        """
        # GANDataset yields (photo, monet) pairs, so the photo batch comes first.
        # (Unpacking these in the opposite order trained every network in the
        # wrong direction.)
        real_P, real_M = batch
        if optimizer_idx == 0:
            return self.gen_loss(real_P, real_M, self.gan_photo_monet, self.gan_monet_photo, self.dis_monet)
        if optimizer_idx == 1:
            return self.gen_loss(real_M, real_P, self.gan_monet_photo, self.gan_photo_monet, self.dis_photo)
        if optimizer_idx == 2:
            return self.disc_loss(real_M, self.gan_photo_monet(real_P), self.dis_monet)
        if optimizer_idx == 3:
            return self.disc_loss(real_P, self.gan_monet_photo(real_M), self.dis_photo)

    def training_epoch_end(self, outputs):
        """Record mean per-optimizer losses, print every 10 epochs and show samples periodically."""
        self.epoch_count += 1

        losses = []
        for j in range(4):
            loss = np.mean([out[j]["loss"].item() for out in outputs])
            losses.append(loss)
        self.loss_history.append(losses)

        if self.epoch_count % 10 == 0:
            print(
                f"Epoch {self.epoch_count} -",
                f"gen_loss_PM: {losses[0]:.5f} -",
                f"gen_loss_MP: {losses[1]:.5f} -",
                f"disc_loss_M: {losses[2]:.5f} -",
                f"disc_loss_P: {losses[3]:.5f}",
            )

        if self.epoch_count % self.display_epochs == 0 or self.epoch_count == 1:
            samples = self.photo_samples.cpu()
            # Visualization only: no autograd graph is needed.
            with torch.no_grad():
                gen_monets = self(samples.to(self.device)).detach().cpu()
            display_img(
                torch.cat([samples, gen_monets]),
                nrow=4,
                title=f"Epoch {self.epoch_count}: Photo-to-Monet Translation",
            )

    def predict_step(self, batch, batch_idx):
        """Translate a batch of photos into Monet style."""
        return self(batch)
    

In [None]:
# Instantiate once to display the full architecture (two generators, two
# discriminators). A fresh instance is created for training further below.
model = CycleGAN()
model

### Training Process

In [None]:
# Trainer: one GPU, no logger, no checkpoint files, fixed 120-epoch budget.
# NOTE(review): accelerator="gpu" presumably fails on CPU-only machines — confirm
# and consider "auto" if this must run elsewhere.
trainer = pl.Trainer(
    accelerator="gpu",
    devices=1,
    logger=False,
    enable_checkpointing=False,
    max_epochs=120,
)


# Fresh datamodule; setup("train") builds the paired (photo, Monet) dataset
# that train_dataloader() serves.
dm = GANDataModule(batch_size=8,transform=ImgTransform(img_size=256))
dm.setup("train")


# New CycleGAN instance, so training starts from freshly initialized networks.
model = CycleGAN()
trainer.fit(model, datamodule=dm)

In [None]:
# Plot the per-epoch mean losses in model.loss_history. Each entry is
# [gen_PM, gen_MP, disc_M, disc_P]: generators go in the left panel,
# discriminators in the right.
labels = ["gen_loss_photo2monet", "gen_loss_monet2photo", "disc_loss_monet", "disc_loss_photo"]
titles = ["Generator Loss Curves", "Discriminator Loss Curves"]
num_epochs = len(model.loss_history)
plt.figure(figsize=(18, 4.5))
for j in range(4):
    if j%2 == 0:
        # Open a new panel for each pair of curves (j=0,1 -> panel 1; j=2,3 -> panel 2).
        plt.subplot(1, 2, (j//2)+1)
        plt.title(titles[j//2])
        plt.ylabel("Loss")
        plt.xlabel("Epoch")
    plt.plot(
        np.arange(1, num_epochs+1),
        [losses[j] for losses in model.loss_history],
        label=labels[j],
    )
    # Called after every curve; the second call per panel simply redraws the legend.
    plt.legend(loc="upper right")

In [None]:
# Print the module structure of the trained model.
print(model)

## Submission

In [None]:
# Separate datamodule for inference. setup("test") builds the photo-only dataset
# served by test_dataloader(), which does not shuffle, so outputs follow the
# sorted photo file order.
dm_test = GANDataModule(batch_size=8,transform=ImgTransform(img_size=256))
dm_test.setup("test")

In [None]:
# Run CycleGAN.predict_step (the photo -> Monet generator) over every test batch;
# the per-batch outputs are iterated batch by batch when writing the submission.
predictions = trainer.predict(model, (dm_test.test_dataloader()))


In [None]:
 os.makedirs("../images", exist_ok=True)

idx = 0
for tensor in predictions:
    for monet in tensor:

        monet = monet.squeeze()
        monet = monet * 0.5 + 0.5
        monet = monet * 255
        monet = monet.detach().cpu().numpy().astype(np.uint8)
        
        monet = np.transpose(monet, [1,2,0])
        
        monet = Image.fromarray(monet)
        monet.save(f"../images/{idx}.jpg")
        #save_image((monet.squeeze()*0.5+0.5), fp=f"../images/{idx}.jpg")
        idx += 1

shutil.make_archive("/kaggle/working/images", "zip", "/kaggle/images")

## Conclusion

In conclusion, we trained a CycleGAN model to translate photographs into images in Monet's art style. The Kaggle dataset of photos and Monet paintings, together with simple preprocessing and augmentation, provided the training data. Throughout training, we monitored the generator and discriminator losses and periodically visualized photo-to-Monet translations of a fixed batch of sample photos. The generated images show that the model can emulate aspects of Monet's style, with room for further improvement. Future work can focus on:

- enhancing the model's performance;
- adopting objective evaluation metrics, such as the MiFID score used by the competition;
- applying the generated images to other use cases, such as art curation and creative design.
