## Introduction and Problem Statement
In this challenge, we use generative adversarial networks (GANs). While GANs have broad generative applications, ours learns to mimic a complex artistic style: it produces images in the style of Monet. A GAN has two parts: a generator that creates images and a discriminator that judges them. The goal of the generator is to fool the discriminator into accepting its images as real artworks. Once trained, the generator will be used to create 7,000 to 10,000 images that are convincing enough for the discriminator to take them for real Monet paintings.

## About the Data

The dataset includes 4 directories:
* monet_jpg which includes 300 Monet paintings sized 256x256 in JPEG format
* monet_tfrec which includes 300 Monet paintings sized 256x256 in TFRecord format
* photo_jpg which includes 7028 photos sized 256x256 in JPEG format
* photo_tfrec which includes 7028 photos sized 256x256 in TFRecord format

The monet directories contain Monet paintings and will be used to train the model.
The photo directories contain images that can be stylized by the GAN into the Monet style, and later submitted.

Source of information about the data: https://www.kaggle.com/competitions/gan-getting-started/data

In [None]:
import numpy as np 
import pandas as pd
import matplotlib.pyplot as plt
import cv2
import os
import math
from PIL import Image

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.nn.init as init
from torch.utils.data import Dataset, random_split, DataLoader

import torchvision.models as models
import torchvision.transforms as transforms

from tqdm.notebook import tqdm
import itertools
import time
import shutil

In [None]:
# Input directories with the JPEG Monet paintings and photos.
# The trailing slash matters: later cells build file paths by plain string concatenation.
path_monet = "/kaggle/input/gan-input/monet_jpg/"
path_photo = "/kaggle/input/gan-input/photo_jpg/"

## Exploratory Data Analysis (EDA)
This exploratory data analysis includes:
1. Displaying a sample of the Monet paintings and images in our directories.
2. Visualization of histogram distributions of the RGB color channels present in a few Monet paintings vs the photos. This may help identify whether there are any intrinsic differences in the color channel distribution of Monet paintings.
3. Average RGB color channel values of Monet paintings vs the photographs. This is a more aggregated view and could tell us if there is a significant difference in the color channel distribution of Monet paintings holistically.
4. Key takeaways from EDA

In [None]:
def display_images_in_grid(directory, rows=4, cols=4):
    """
    Show the first ``rows * cols`` images of a directory in a grid.

    Args:
        directory (str): Directory whose entries are read with OpenCV.
        rows (int): Number of grid rows.
        cols (int): Number of grid columns.

    Entries that OpenCV cannot decode leave their grid cell empty.
    """
    num_samples = rows * cols

    # os.listdir returns entries in arbitrary order; sorting makes the preview reproducible.
    image_files = sorted(os.listdir(directory))[:num_samples]

    # squeeze=False always yields a 2-D array of Axes, so flatten() also works for a 1x1 grid.
    fig, axes = plt.subplots(rows, cols, figsize=(cols * 3, rows * 3), squeeze=False)
    axes = axes.flatten()

    for ax, image_file in zip(axes, image_files):
        image = cv2.imread(os.path.join(directory, image_file))
        if image is not None:
            # OpenCV decodes to BGR; matplotlib expects RGB.
            ax.imshow(cv2.cvtColor(image, cv2.COLOR_BGR2RGB))

    # Hide the axes of every cell, including cells that received no image.
    for ax in axes:
        ax.axis('off')

    plt.tight_layout()
    plt.show()


In [None]:
# Preview Monet paintings using the default 4x4 grid.
display_images_in_grid(path_monet)

In [None]:
# Preview photographs using the default 4x4 grid.
display_images_in_grid(path_photo)

In [None]:
def plot_rgb_histograms(image_path):
    """
    Show an image next to the intensity histograms of its R, G and B channels.

    Args:
        image_path (str): Path of the image file to analyse.

    Raises:
        FileNotFoundError: If OpenCV cannot read ``image_path``.
    """
    img = cv2.imread(image_path)
    if img is None:
        # cv2.imread signals failure by returning None instead of raising.
        raise FileNotFoundError(f"Could not read image: {image_path}")
    img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)

    plt.figure(figsize=(12, 4))

    plt.subplot(1, 2, 1)
    plt.imshow(img)
    plt.axis("off")
    plt.title('Original Image')

    plt.subplot(1, 2, 2)
    colors = ['red', 'green', 'blue']
    for i, color in enumerate(colors):
        plt.hist(
            img[:, :, i].ravel(),
            bins=256,
            range=(0, 256),  # one bin per 8-bit intensity, independent of the image's own min/max
            density=True,    # normalise so the y axis really is a density, as labelled
            color=color,
            alpha=0.5,
            label=f'{color.capitalize()} Channel'
        )

    plt.title('Color Channel Histogram')
    plt.xlabel('Pixel Intensity')
    plt.ylabel('Density')
    plt.xlim(0, 255)
    plt.legend()

    plt.tight_layout()
    plt.show()

In [None]:
#monet pictures
# List the directory once (sorted, so the same three paintings are used on every run).
monet_sample_files = sorted(os.listdir(path_monet))[:3]
for sample_file in monet_sample_files:
    plot_rgb_histograms(os.path.join(path_monet, sample_file))

In [None]:
#photographs
# List the directory once (sorted, so the same three photos are used on every run).
photo_sample_files = sorted(os.listdir(path_photo))[:3]
for sample_file in photo_sample_files:
    plot_rgb_histograms(os.path.join(path_photo, sample_file))

In [None]:
def compute_channel_averages(image_directory):
    """
    Compute the mean R, G and B pixel value over all images of a directory.

    Only images whose array shape is exactly (256, 256, 3) are counted;
    anything else is skipped.

    Args:
        image_directory (str): Directory containing the image files.

    Returns:
        numpy.ndarray: Three per-channel averages on the pixel-value scale,
        or zeros if no image qualified.
    """
    total_mean = np.zeros(3)
    count = 0

    for filename in os.listdir(image_directory):
        image_path = os.path.join(image_directory, filename)
        # Context manager closes the file handle once the pixels are copied out.
        with Image.open(image_path) as image:
            image_array = np.array(image)

        if image_array.shape == (256, 256, 3):
            # Average over the pixels of the image (a plain sum would scale the
            # result by the 65,536 pixels and no longer be a channel average).
            total_mean += np.mean(image_array, axis=(0, 1))
            count += 1

    return total_mean / count if count else np.zeros(3)

# Per-channel averages for both image collections, computed with
# compute_channel_averages rather than repeating the same loop per directory.
monet_dir = path_monet
average_monet_values = compute_channel_averages(monet_dir)

photo_dir = path_photo
average_photos_values = compute_channel_averages(photo_dir)

channels = ['Red', 'Green', 'Blue']
x = np.arange(len(channels))

# One line per collection, with the three colour channels on the x axis.
plt.figure(figsize=(8, 5))
plt.plot(x, average_monet_values, marker='o', linestyle='-', color='b', label='Monet Paintings', markersize=8)
plt.plot(x, average_photos_values, marker='s', linestyle='-', color='g', label='Photographs', markersize=8)

plt.xlabel('Color Channel')
plt.ylabel('Average Value')
plt.title('Average Channel Values: Monet Paintings vs. Photographs')
plt.xticks(x, channels)
plt.legend()
plt.grid(True)
plt.show()

### Key Takeaways from the EDA
The density of color pixels seems to be much higher for Monet paintings compared to the photographs. This can be seen in the histogram distributions of the RGB color channels for the subset of Monet paintings compared to the subset of photos. This can also be seen in the visualization of average RGB channel values of Monet paintings vs photographs which seems to suggest there is a significant difference in the color distributions of Monet paintings from photographs.
Monet's artistic style was known to be vibrant. He used pure, unmixed colors, and his paintings showed intensity of color. This aligns with the color channel distribution we observed in his work.

In [None]:
# Use the GPU when one is available; otherwise fall back to the CPU.
device = 'cuda' if torch.cuda.is_available() else 'cpu'

## Data Preprocessing
1. Initialize the training dataset for the model
2. Load images
3. **Data preprocessing**, which includes resizing and normalizing the images, and returns them as tensors.

In [None]:
class ImageDataset(Dataset):
    """
    Unpaired training set: each item is (random photo, indexed Monet painting).

    Args:
        path_monet (str): Directory containing the Monet paintings.
        path_photo (str): Directory containing the photos.
        size (tuple): Target size passed to ``transforms.Resize``.
        normalize (bool): If True, map each channel from [0, 1] to [-1, 1].
    """

    def __init__(self, path_monet, path_photo, size=(256, 256), normalize=True):
        super().__init__()
        self.monet_files = self._load_files(path_monet)
        self.photo_files = self._load_files(path_photo)
        self.transform = self._load_transform(size, normalize)

    def _load_files(self, directory):
        """Return the sorted image file paths (png/jpg/jpeg) found in a directory."""
        # The leading dot keeps names that merely end in "jpg" (no extension) out;
        # sorting makes index -> file stable across runs (os.listdir order is arbitrary).
        return sorted(
            os.path.join(directory, f)
            for f in os.listdir(directory)
            if f.lower().endswith(('.png', '.jpg', '.jpeg'))
        )

    def _load_transform(self, size, normalize):
        """Create the transformation pipeline for the images."""
        transformations = [transforms.Resize(size), transforms.ToTensor()]
        if normalize:
            transformations.append(transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)))
        return transforms.Compose(transformations)

    def _load_image(self, path):
        """Open an image, force 3-channel RGB, and return the transformed tensor."""
        # convert('RGB') guards the 3-channel Normalize against grayscale/RGBA files;
        # the context manager closes the file handle right after decoding.
        with Image.open(path) as img:
            return self.transform(img.convert('RGB'))

    def __getitem__(self, idx):
        """Return (photo, monet): the Monet painting at ``idx`` and a randomly drawn photo."""
        monet_path = self.monet_files[idx % len(self.monet_files)]  # Safe index wrapping
        # Draw from torch's RNG, which DataLoader seeds separately per worker process.
        photo_idx = torch.randint(len(self.photo_files), (1,)).item()
        photo_path = self.photo_files[photo_idx]

        monet_img = self._load_image(monet_path)
        photo_img = self._load_image(photo_path)

        return photo_img, monet_img

    def __len__(self):
        """Return the number of Monet paintings, assuming we always have more photos."""
        return len(self.monet_files)

In [None]:
img_ds = ImageDataset(path_monet, path_photo)
# Shuffle so the Monet paintings are visited in a different order each epoch;
# pinned host memory only helps when batches are copied to a CUDA device.
img_dl = DataLoader(img_ds, batch_size=1, shuffle=True, pin_memory=torch.cuda.is_available())

In [None]:
def reverse_normalize(img, mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5)):
    """
    Undo a per-channel ``Normalize(mean, std)``, i.e. compute ``img * std + mean``.

    Args:
        img (Tensor): Image tensor whose channel axis is third from last,
            e.g. (C, H, W) or (N, C, H, W).
        mean (sequence or Tensor): Per-channel mean used for normalization.
        std (sequence or Tensor): Per-channel std used for normalization.

    Returns:
        Tensor: De-normalized tensor with the same shape, dtype and device as ``img``.
    """
    # Immutable tuple defaults avoid the shared-mutable-default pitfall;
    # as_tensor also accepts stats that are already tensors.
    mean = torch.as_tensor(mean, dtype=img.dtype, device=img.device)
    std = torch.as_tensor(std, dtype=img.dtype, device=img.device)

    # Reshape the stats to (C, 1, 1) so they broadcast over height and width.
    return img * std[:, None, None] + mean[:, None, None]

## Building the Model
Constructing the GAN model involves a few steps:
1. Upsampling functionality
2. Build the convolutional layer
3. Build residual block for the generator
4. Build the core generator and discriminator functionality
5. Initialize network weights
6. Sample fake (generated) images
7. Build a Cycle GAN

In [None]:
def Upsample(in_channels, out_channels, use_dropout=True, dropout_ratio=0.5):
    """
    Build a 2x upsampling stage: transposed conv -> instance norm -> [dropout] -> GELU.

    Args:
        in_channels (int): Channels entering the stage.
        out_channels (int): Channels leaving the stage.
        use_dropout (bool): Place a dropout layer between the norm and the activation.
        dropout_ratio (float): Drop probability for that dropout layer.

    Returns:
        nn.Sequential: The assembled stage.
    """
    # kernel 3, stride 2, padding 1, output_padding 1 doubles the spatial size.
    deconv = nn.ConvTranspose2d(in_channels, out_channels, 3, stride=2, padding=1, output_padding=1)
    stages = [deconv, nn.InstanceNorm2d(out_channels)]

    if use_dropout:
        # Dropout sits right before the activation.
        stages.append(nn.Dropout(dropout_ratio))

    stages.append(nn.GELU())
    return nn.Sequential(*stages)


In [None]:
def ConvLayer(in_channels, out_channels, kernel_size=3, stride=2, use_leaky=True, use_inst_norm=True, use_pad=True):
    """
    Build a conv -> norm -> activation stage.

    Args:
        in_channels (int): Channels entering the convolution.
        out_channels (int): Channels produced by the convolution.
        kernel_size (int): Side length of the convolution kernel.
        stride (int): Convolution stride.
        use_leaky (bool): LeakyReLU(0.2) when True, GELU otherwise.
        use_inst_norm (bool): InstanceNorm2d when True, BatchNorm2d otherwise.
        use_pad (bool): Zero-pad by ``kernel_size // 2`` when True, no padding otherwise.

    Returns:
        nn.Sequential: The three layers in order (conv, norm, activation).
    """
    if use_pad:
        pad = kernel_size // 2
    else:
        pad = 0
    convolution = nn.Conv2d(in_channels, out_channels, kernel_size, stride, pad, bias=True)

    if use_inst_norm:
        normalization = nn.InstanceNorm2d(out_channels)
    else:
        normalization = nn.BatchNorm2d(out_channels)

    if use_leaky:
        nonlinearity = nn.LeakyReLU(negative_slope=0.2, inplace=True)
    else:
        nonlinearity = nn.GELU()

    return nn.Sequential(convolution, normalization, nonlinearity)


In [None]:
class ResBlock(nn.Module):
    """Residual block of the generator: two reflection-padded 3x3 convs plus a skip connection."""

    def __init__(self, in_features, use_dropout=True, dropout_ratio=0.5):
        """
        Assemble the residual branch.

        Args:
            in_features (int): Channel count, identical at input and output.
            use_dropout (bool): Put a dropout layer between the two conv stages.
            dropout_ratio (float): Drop probability for that dropout layer.
        """
        super().__init__()

        # First stage: conv + instance norm + GELU (padding is done by the reflection pad).
        first_stage = ConvLayer(in_features, in_features, kernel_size=3, stride=1, use_leaky=False, use_pad=False, use_inst_norm=True)
        # Second stage: conv + instance norm, no activation before the skip addition.
        second_conv = nn.Conv2d(in_features, in_features, 3, 1, padding=0, bias=True)

        branch = [nn.ReflectionPad2d(1), first_stage]
        if use_dropout:
            branch.append(nn.Dropout(dropout_ratio))
        branch.extend([nn.ReflectionPad2d(1), second_conv, nn.InstanceNorm2d(in_features)])

        self.res = nn.Sequential(*branch)

    def forward(self, x):
        """
        Apply the residual branch and add the skip connection.

        Args:
            x (Tensor): Block input.

        Returns:
            Tensor: ``x`` plus the output of the residual branch.
        """
        residual = self.res(x)
        return x + residual


In [None]:
class Generator(nn.Module):
    """Image-to-image generator: 7x7 stem, two downsamplings, residual trunk, two upsamplings, tanh head."""

    def __init__(self, in_channels, out_channels, num_res_blocks=6):
        """
        Assemble the generator as one sequential pipeline.

        Args:
            in_channels (int): Channels of the input image.
            out_channels (int): Channels of the generated image.
            num_res_blocks (int): How many residual blocks form the trunk.
        """
        super().__init__()

        # Stem: reflection-padded 7x7 conv at full resolution.
        pipeline = [
            nn.ReflectionPad2d(3),
            ConvLayer(in_channels, 64, kernel_size=7, stride=1, use_leaky=False, use_pad=False, use_inst_norm=True),
        ]

        # Encoder: two stride-2 convs, 64 -> 128 -> 256 channels.
        for c_in, c_out in ((64, 128), (128, 256)):
            pipeline.append(ConvLayer(c_in, c_out, kernel_size=3, stride=2, use_leaky=False))

        # Trunk: residual blocks at 256 channels.
        for _ in range(num_res_blocks):
            pipeline.append(ResBlock(256))

        # Decoder: two upsampling stages, 256 -> 128 -> 64 channels.
        for c_in, c_out in ((256, 128), (128, 64)):
            pipeline.append(Upsample(c_in, c_out))

        # Head: reflection-padded 7x7 conv followed by tanh.
        pipeline += [
            nn.ReflectionPad2d(3),
            nn.Conv2d(64, out_channels, kernel_size=7, padding=0),
            nn.Tanh(),
        ]

        self.gen = nn.Sequential(*pipeline)

    def forward(self, x):
        """
        Run the input through the full pipeline.

        Args:
            x (Tensor): Input image batch.

        Returns:
            Tensor: Generated image batch.
        """
        return self.gen(x)


In [None]:
class Discriminator(nn.Module):
    """Convolutional discriminator that outputs a map of per-patch scores rather than one scalar."""

    def __init__(self, in_channels, num_layers=4):
        """
        Initializes the Discriminator with convolutional layers and LeakyReLU activation.

        Args:
            in_channels (int): Number of channels in the input image.
            num_layers (int): Number of convolutional layers before the final
                1-channel convolution (the first one has no normalization).
        """
        super(Discriminator, self).__init__()

        layers = [
            # Initial convolution layer
            nn.Conv2d(in_channels, 64, 4, stride=2, padding=1),
            nn.LeakyReLU(0.2, inplace=True)
        ]

        # Channel count feeding the final convolution. Starting at the width of the
        # initial layer keeps num_layers=1 valid (the loop below then never runs).
        out_chs = 64

        # Build the remaining convolutional layers, doubling the channels each time
        for i in range(1, num_layers):
            in_chs = out_chs
            out_chs = in_chs * 2
            # The last of these layers keeps the resolution (stride 1); the others halve it.
            stride = 1 if i == num_layers - 1 else 2
            layers.append(ConvLayer(in_chs, out_chs, kernel_size=4, stride=stride, use_leaky=True))

        # Final convolution layer
        layers.append(nn.Conv2d(out_chs, 1, kernel_size=4, stride=1, padding=1))

        # Sequential model
        self.disc = nn.Sequential(*layers)

    def forward(self, x):
        """
        Defines the forward pass of the discriminator.

        Args:
            x (Tensor): Input tensor to the discriminator.

        Returns:
            Tensor: Output tensor from the discriminator.
        """
        return self.disc(x)

In [None]:
def init_weights(net, init_type='normal', init_gain=0.02):
    """
    Initialize network weights in place.

    Conv*/Linear weights are drawn according to ``init_type`` and their biases
    are zeroed; affine BatchNorm2d weights are drawn from N(1, init_gain) with
    zero bias. Other modules are left untouched.

    Args:
        net (nn.Module): Network to initialize.
        init_type (str): The name of an initialization method: 'normal' or 'xavier'.
        init_gain (float): Standard deviation for 'normal'; gain factor for 'xavier'.

    Raises:
        ValueError: If ``init_type`` is not one of the supported methods.
    """
    # Fail loudly on a typo instead of silently leaving the default initialization.
    if init_type not in ('normal', 'xavier'):
        raise ValueError(f"Unsupported init_type {init_type!r}; expected 'normal' or 'xavier'")

    def init_func(m):
        classname = m.__class__.__name__
        # 'Conv' also matches ConvTranspose2d.
        if hasattr(m, 'weight') and ('Conv' in classname or 'Linear' in classname):
            if init_type == 'normal':
                init.normal_(m.weight.data, 0.0, init_gain)
            else:
                init.xavier_normal_(m.weight.data, gain=init_gain)

            if hasattr(m, 'bias') and m.bias is not None:
                init.constant_(m.bias.data, 0.0)

        elif 'BatchNorm2d' in classname:
            # weight/bias are None when the layer was built with affine=False.
            if m.weight is not None:
                init.normal_(m.weight.data, 1.0, init_gain)
            if m.bias is not None:
                init.constant_(m.bias.data, 0.0)

    net.apply(init_func)


### Helper functions for the Cycle GAN model

In [None]:
def update_req_grad(models, trainable=True):
    """
    Switch gradient tracking on or off for every parameter of the given models.

    Args:
        models (list of nn.Module): Models whose parameters are updated in place.
        trainable (bool): True to make the parameters require gradients, False to freeze them.
    """
    for net in models:
        for weight in net.parameters():
            weight.requires_grad_(trainable)


In [None]:
class sample_fake:
    """History buffer of generated images that mixes past fakes into the returned batch."""

    def __init__(self, max_images=50):
        """
        Initializes the SampleFake object.

        Args:
            max_images (int): Maximum number of images to store. A value of 0
                (or less) disables the buffer: images are returned unchanged.
        """
        self.max_images = max_images
        self.current_image_index = 0  # number of images stored so far
        self.images = []

    def __call__(self, images):
        """
        Samples fake images and updates the stored images.

        While the buffer is filling, every incoming image is stored and
        returned. Once it is full, each incoming image is, with probability
        0.5, swapped with a randomly chosen stored image (the stored one is
        returned); otherwise it is returned as is.

        Args:
            images (iterable): Generated fake images; the elements are stored
                and returned as-is, so any element type works.

        Returns:
            list: One image per input image.
        """
        # Without this guard an empty buffer would reach randint(0, 0), which raises.
        if self.max_images <= 0:
            return list(images)

        sampled_images = []
        for image in images:
            if self.current_image_index < self.max_images:
                self.images.append(image)
                sampled_images.append(image)
                self.current_image_index += 1
            elif np.random.rand() > 0.5:
                index = np.random.randint(0, self.max_images)
                sampled_images.append(self.images[index])
                self.images[index] = image
            else:
                sampled_images.append(image)
        return sampled_images


In [None]:
class lr_sched:
    """Learning-rate multiplier: constant 1.0, then linear decay to 0.0."""

    def __init__(self, decay_start_epoch=100, total_epochs=200):
        """
        Initializes a learning rate scheduler with a linear decay after a specified start.

        Args:
            decay_start_epoch (int): Epoch from which the learning rate starts to decay.
            total_epochs (int): Total number of epochs after which the learning rate will reach zero.
        """
        self.decay_start_epoch = decay_start_epoch
        self.total_epochs = total_epochs

    def step(self, current_epoch):
        """
        Calculates the learning rate multiplier for the current epoch.

        Args:
            current_epoch (int): The current epoch number.

        Returns:
            float: Multiplier in [0.0, 1.0] to apply to the base learning rate.
        """
        if current_epoch <= self.decay_start_epoch:
            return 1.0

        decay_span = self.total_epochs - self.decay_start_epoch
        if decay_span <= 0:
            # No decay window configured: past the start epoch the rate is simply zero
            # (avoids a division by zero when total_epochs == decay_start_epoch).
            return 0.0

        decay_progress = (current_epoch - self.decay_start_epoch) / decay_span
        # Clamp so epochs beyond total_epochs never yield a negative learning rate.
        return max(0.0, 1.0 - decay_progress)


In [None]:
class AvgStats:
    """Two parallel histories: recorded loss values and the value logged alongside each."""

    def __init__(self):
        self.reset()

    def reset(self):
        """Discard everything recorded so far."""
        self.losses, self.its = [], []

    def append(self, loss, it):
        """Record one ``loss`` together with its companion value ``it``."""
        self.its.append(it)
        self.losses.append(loss)

### Cycle GAN Model Build

In [None]:
class CycleGAN(object):
    """
    CycleGAN trainer for unpaired photo <-> Monet translation.

    Holds two generators (``gen_mtp``: Monet -> photo, ``gen_ptm``: photo -> Monet),
    two discriminators (``desc_m`` judges Monet images, ``desc_p`` judges photos),
    one Adam optimizer per network family, linear-decay LR schedulers, image
    history buffers for the discriminators and the per-epoch loss history.
    """

    def __init__(self, in_ch, out_ch, epochs, device, start_lr=2e-4, lmbda=10, idt_coef=0.5, decay_epoch=0):
        """
        Args:
            in_ch (int): Channels of the input images (also used for the discriminators).
            out_ch (int): Channels of the generated images.
            epochs (int): Number of training epochs.
            device: Device the networks are moved to.
            start_lr (float): Initial learning rate of both optimizers.
            lmbda (float): Weight of the cycle-consistency loss.
            idt_coef (float): Extra factor (on top of ``lmbda``) for the identity loss.
            decay_epoch (int): Epoch at which LR decay starts; values <= 0 mean ``epochs // 2``.
        """
        self.epochs = epochs
        self.decay_epoch = decay_epoch if decay_epoch > 0 else int(self.epochs/2)
        self.lmbda = lmbda
        self.idt_coef = idt_coef
        self.device = device
        self.gen_mtp = Generator(in_ch, out_ch)
        self.gen_ptm = Generator(in_ch, out_ch)
        self.desc_m = Discriminator(in_ch)
        self.desc_p = Discriminator(in_ch)
        self.init_models()
        self.mse_loss = nn.MSELoss()
        self.l1_loss = nn.L1Loss()
        self.adam_gen = torch.optim.Adam(itertools.chain(self.gen_mtp.parameters(), self.gen_ptm.parameters()),
                                         lr = start_lr, betas=(0.5, 0.999))
        self.adam_desc = torch.optim.Adam(itertools.chain(self.desc_m.parameters(), self.desc_p.parameters()),
                                          lr=start_lr, betas=(0.5, 0.999))
        # History buffers of generated images, consulted in update_discriminators.
        self.sample_monet = sample_fake()
        self.sample_photo = sample_fake()
        gen_lr = lr_sched(self.decay_epoch, self.epochs)
        desc_lr = lr_sched(self.decay_epoch, self.epochs)
        self.gen_lr_sched = torch.optim.lr_scheduler.LambdaLR(self.adam_gen, gen_lr.step)
        self.desc_lr_sched = torch.optim.lr_scheduler.LambdaLR(self.adam_desc, desc_lr.step)
        self.gen_stats = AvgStats()
        self.desc_stats = AvgStats()

    def init_models(self):
        """Initialize the weights of all four networks and move them to ``self.device``."""
        init_weights(self.gen_mtp)
        init_weights(self.gen_ptm)
        init_weights(self.desc_m)
        init_weights(self.desc_p)
        self.gen_mtp = self.gen_mtp.to(self.device)
        self.gen_ptm = self.gen_ptm.to(self.device)
        self.desc_m = self.desc_m.to(self.device)
        self.desc_p = self.desc_p.to(self.device)

    def train(self, photo_dl):
        """Train for ``self.epochs`` epochs on a loader yielding (photo, monet) batches."""
        for epoch in range(self.epochs):
            start_time = time.time()
            avg_gen_loss, avg_desc_loss = self.train_one_epoch(photo_dl, epoch)
            self.log_training(epoch, avg_gen_loss, avg_desc_loss, start_time)

    def train_one_epoch(self, photo_dl, epoch):
        """Run one pass over ``photo_dl`` and return the mean (generator, discriminator) loss."""
        # Make sure dropout is active even if a network was switched to eval() for inference.
        for net in (self.gen_mtp, self.gen_ptm, self.desc_m, self.desc_p):
            net.train()

        avg_gen_loss = 0.0
        avg_desc_loss = 0.0
        num_batches = len(photo_dl)
        t = tqdm(photo_dl, leave=False, total=num_batches)

        for photo_real, monet_real in t:
            photo_real, monet_real = photo_real.to(self.device), monet_real.to(self.device)
            gen_loss, desc_loss = self.process_batch(photo_real, monet_real)
            avg_gen_loss += gen_loss
            avg_desc_loss += desc_loss
            t.set_postfix(gen_loss=gen_loss, desc_loss=desc_loss)

        if num_batches == 0:
            # Empty loader: nothing was trained, and there is nothing to average.
            return 0.0, 0.0
        return avg_gen_loss / num_batches, avg_desc_loss / num_batches

    def process_batch(self, photo_real, monet_real):
        """Update the generators, then the discriminators; return both losses as floats."""
        gen_loss = self.update_generators(photo_real, monet_real)
        desc_loss = self.update_discriminators(photo_real, monet_real)
        return gen_loss.item(), desc_loss.item()

    def update_generators(self, photo_real, monet_real):
        """Take one optimizer step on both generators and return the total generator loss tensor."""
        # Freeze the discriminators: gradients still flow through them to the
        # generators, but no gradients are computed for their own parameters.
        update_req_grad([self.desc_m, self.desc_p], False)
        self.adam_gen.zero_grad()

        # Generate fake images
        fake_photo = self.gen_mtp(monet_real)
        fake_monet = self.gen_ptm(photo_real)

        # Translate the fakes back to their source domain
        cycl_monet = self.gen_ptm(fake_photo)
        cycl_photo = self.gen_mtp(fake_monet)

        # Identity mapping: a generator fed an image of its target domain should leave it unchanged
        id_monet = self.gen_ptm(monet_real)
        id_photo = self.gen_mtp(photo_real)

        # Generator losses
        idt_loss_monet = self.l1_loss(id_monet, monet_real) * self.lmbda * self.idt_coef
        idt_loss_photo = self.l1_loss(id_photo, photo_real) * self.lmbda * self.idt_coef

        cycle_loss_monet = self.l1_loss(cycl_monet, monet_real) * self.lmbda
        cycle_loss_photo = self.l1_loss(cycl_photo, photo_real) * self.lmbda

        # Adversarial losses to trick the discriminators: each discriminator is
        # evaluated once and the "real" target is shaped after its own output.
        pred_fake_monet = self.desc_m(fake_monet)
        pred_fake_photo = self.desc_p(fake_photo)
        adv_loss_monet = self.mse_loss(pred_fake_monet, torch.ones_like(pred_fake_monet))
        adv_loss_photo = self.mse_loss(pred_fake_photo, torch.ones_like(pred_fake_photo))

        # Total generator loss
        total_gen_loss = cycle_loss_monet + adv_loss_monet + cycle_loss_photo + adv_loss_photo + idt_loss_monet + idt_loss_photo
        total_gen_loss.backward()
        self.adam_gen.step()

        return total_gen_loss

    def update_discriminators(self, photo_real, monet_real):
        """Take one optimizer step on both discriminators and return the discriminator loss tensor."""
        update_req_grad([self.desc_m, self.desc_p], True)
        self.adam_desc.zero_grad()

        # Generate fake images without building an autograd graph: the generators
        # are not updated here, so the graph would only cost memory.
        with torch.no_grad():
            fake_photo = self.gen_mtp(monet_real)
            fake_monet = self.gen_ptm(photo_real)

        # Route the fakes through the history buffers so the discriminators also
        # see images from earlier generator states, not only the newest ones.
        fake_photo = torch.stack(self.sample_photo(fake_photo))
        fake_monet = torch.stack(self.sample_monet(fake_monet))

        # One forward pass per (discriminator, input) pair
        pred_real_monet = self.desc_m(monet_real)
        pred_real_photo = self.desc_p(photo_real)
        pred_fake_monet = self.desc_m(fake_monet)
        pred_fake_photo = self.desc_p(fake_photo)

        # Ensure real is classified as real, and fake as fake
        monet_desc_real_loss = self.mse_loss(pred_real_monet, torch.ones_like(pred_real_monet))
        photo_desc_real_loss = self.mse_loss(pred_real_photo, torch.ones_like(pred_real_photo))

        monet_desc_fake_loss = self.mse_loss(pred_fake_monet, torch.zeros_like(pred_fake_monet))
        photo_desc_fake_loss = self.mse_loss(pred_fake_photo, torch.zeros_like(pred_fake_photo))

        # Total loss
        total_desc_loss = (monet_desc_real_loss + monet_desc_fake_loss + photo_desc_real_loss + photo_desc_fake_loss) / 4
        total_desc_loss.backward()
        self.adam_desc.step()

        return total_desc_loss

    def log_training(self, epoch, avg_gen_loss, avg_desc_loss, start_time):
        """Advance both LR schedulers and record the epoch's mean losses and duration (prints nothing)."""
        self.gen_lr_sched.step()
        self.desc_lr_sched.step()
        time_elapsed = time.time() - start_time
        self.gen_stats.append(avg_gen_loss, time_elapsed)
        self.desc_stats.append(avg_desc_loss, time_elapsed)

In [None]:
# Build a CycleGAN with 3 input and 3 output channels, configured for a single epoch
# on `device`, then train it on the (photo, Monet) pairs served by img_dl.
gan = CycleGAN(3, 3, 1, device)
gan.train(img_dl)

## Model performance and results

In [None]:
def plot_gan_loss(model=None):
    """
    Plot the per-epoch generator and discriminator losses.

    Args:
        model: Object exposing ``gen_stats.losses`` and ``desc_stats.losses``.
            Defaults to the notebook-level ``gan``.
    """
    model = gan if model is None else model

    plt.xlabel("Epochs")
    plt.ylabel("Losses")
    # Markers keep the curves visible when only one epoch was recorded
    # (a line through a single point draws nothing).
    plt.plot(model.gen_stats.losses, 'r', marker='o', label='Generator Loss')
    plt.plot(model.desc_stats.losses, 'b', marker='o', label='Discriminator Loss')
    plt.legend()
    plt.show()

In [None]:
# Plot the loss history recorded by the trained `gan`.
plot_gan_loss()

In [None]:
# Inference: switch dropout off so the generator output is deterministic for a given photo.
gan.gen_ptm.eval()

_, ax = plt.subplots(3, 2, figsize=(12, 12))
for i in range(3):
    # A fresh iterator per row; the dataset pairs every item with a randomly drawn photo.
    photo_img, _ = next(iter(img_dl))
    with torch.no_grad():  # no autograd graph is needed for a preview
        pred_monet = gan.gen_ptm(photo_img.to(device)).cpu()
    photo_img = reverse_normalize(photo_img)
    pred_monet = reverse_normalize(pred_monet)

    # Tensors are channel-first; imshow expects channel-last.
    ax[i, 0].imshow(photo_img[0].permute(1, 2, 0))
    ax[i, 1].imshow(pred_monet[0].permute(1, 2, 0))
    ax[i, 0].set_title("Input Photo")
    ax[i, 1].set_title("Monet style Photo")
    ax[i, 0].axis("off")
    ax[i, 1].axis("off")
plt.show()

### Key Takeaways:
* The model is doing well: loss decreases over time/as the number of epochs increases
* The monet style images are actually quite decent!

### Image export for submission

In [None]:
class PhotoDataset(Dataset):
    """
    Dataset over the image files of one directory, used to feed photos to the generator.

    Args:
        photo_dir (str): Directory containing the photos.
        size (tuple): Target size passed to ``transforms.Resize``.
        normalize (bool): If True, map each channel from [0, 1] to [-1, 1].
    """

    IMAGE_EXTENSIONS = ('.png', '.jpg', '.jpeg')

    def __init__(self, photo_dir, size=(256, 256), normalize=True):
        super().__init__()
        self.photo_dir = photo_dir

        steps = [transforms.Resize(size), transforms.ToTensor()]
        if normalize:
            steps.append(transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)))
        self.transform = transforms.Compose(steps)

        # index -> file name
        self.photo_idx = dict(enumerate(self._list_images(self.photo_dir)))

    @staticmethod
    def _list_images(photo_dir):
        """Return the sorted names of the image files (png/jpg/jpeg) in ``photo_dir``."""
        # Sorted for a stable index -> file mapping (os.listdir order is arbitrary);
        # filtered so a stray non-image file cannot crash the loader.
        return sorted(
            name for name in os.listdir(photo_dir)
            if name.lower().endswith(PhotoDataset.IMAGE_EXTENSIONS)
        )

    def __getitem__(self, idx):
        """Load, convert to RGB and transform the photo at position ``idx``."""
        photo_path = os.path.join(self.photo_dir, self.photo_idx[idx])
        # convert('RGB') guards the 3-channel Normalize against grayscale/RGBA files;
        # the context manager closes the file handle right after decoding.
        with Image.open(photo_path) as photo_img:
            return self.transform(photo_img.convert('RGB'))

    def __len__(self):
        """Return the number of indexed photos."""
        return len(self.photo_idx)

In [None]:
ph_ds = PhotoDataset(path_photo)
# Pinned host memory only helps when batches are copied to a CUDA device.
ph_dl = DataLoader(ph_ds, batch_size=1, pin_memory=torch.cuda.is_available())
# exist_ok=True lets this cell be re-run without failing on the existing folder.
os.makedirs('/kaggle/working/images', exist_ok=True)
trans = transforms.ToPILImage()

In [None]:
# Inference: switch dropout off so every photo maps to one deterministic output image.
gan.gen_ptm.eval()

t = tqdm(ph_dl, leave=False, total=len(ph_dl))
for i, photo in enumerate(t):
    with torch.no_grad():
        pred_monet = gan.gen_ptm(photo.to(device)).cpu()

    # Undo the normalization so ToPILImage receives values in [0, 1].
    pred_monet = reverse_normalize(pred_monet)
    img = trans(pred_monet[0]).convert("RGB")

    # Files are numbered from 1 in loader order.
    img.save(f"/kaggle/working/images/{i + 1}.jpg")

In [None]:
# Zip only the generated images. root_dir must be the images folder itself: with
# "/kaggle/working" as root the archive (written to /kaggle/working/images.zip)
# would sit inside the directory being archived and the images would end up nested
# under an "images/" prefix alongside every other file in the working directory.
shutil.make_archive("/kaggle/working/images", 'zip', "/kaggle/working/images")

In [None]:
# Optimizer settings used for the run, shown as a one-row summary table.
learning_rate, beta_1, beta_2 = [2e-4], [0.5], [0.999]
column_names = ('Learning Rate', 'Beta 1', 'Beta 2')
column_values = (learning_rate, beta_1, beta_2)
df_results = pd.DataFrame(dict(zip(column_names, column_values)))
print(df_results)