In [None]:
import os
import gc
import itertools
from pprint import pprint
import random

import numpy as np
from PIL import Image
import matplotlib.pyplot as plt

import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision
from torchvision import transforms
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
from torch.autograd import Variable

In [None]:
# # This Python 3 environment comes with many helpful analytics libraries installed
# # It is defined by the kaggle/python Docker image: https://github.com/kaggle/docker-python
# # For example, here's several helpful packages to load

# import numpy as np # linear algebra
# import pandas as pd # data processing, CSV file I/O (e.g. pd.read_csv)

# # Input data files are available in the read-only "../input/" directory
# # For example, running this (by clicking run or pressing Shift+Enter) will list all files under the input directory

# import os
# for dirname, _, filenames in os.walk('/kaggle/input/gan-getting-started/monet_jpg'):
#     for filename in filenames:
#         print(os.path.join(dirname, filename))

# # You can write up to 20GB to the current directory (/kaggle/working/) that gets preserved as output when you create a version using "Save & Run All" 
# # You can also write temporary files to /kaggle/temp/, but they won't be saved outside of the current session

In [None]:
# Seed every RNG this notebook touches (Python's random, numpy, torch) so the
# random sampling done in the following cells is repeatable run to run.
SEED = 0
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)

def sample_images(paths, n_samples=30):
    """Draw ``n_samples`` distinct entries from ``paths`` without replacement.

    Randomness comes from the global numpy RNG (``np.random``), so the result is
    reproducible once ``np.random.seed`` has been called.

    Args:
        paths: sequence of paths (list, tuple or numpy array).
        n_samples: how many distinct entries to draw; must not exceed ``len(paths)``.

    Returns:
        numpy array of the selected entries, kept in the same relative order as
        they appear in ``paths``.

    Raises:
        ValueError: (from ``np.random.choice``) if ``n_samples > len(paths)``.
    """
    # Accept plain lists too: fancy indexing with an index array needs an ndarray.
    paths = np.asarray(paths)
    # Sorting the drawn indices preserves the original relative order of the picks.
    idxs = np.sort(np.random.choice(len(paths), n_samples, replace=False))
    return paths[idxs]

# Load the image paths

In [None]:
def load_images(path, image_shape=(256,256), n_style_samples=30, n_content_samples=7000):
    """Sample style (Monet) and content (photo) file names from a dataset root.

    Args:
        path: dataset root containing the ``monet_jpg`` and ``photo_jpg`` sub-directories.
        image_shape: currently unused; kept so existing callers keep working.
        n_style_samples: number of distinct files drawn from ``monet_jpg``.
        n_content_samples: number of distinct files drawn from ``photo_jpg``.

    Returns:
        ``(style_names, content_names)``: numpy arrays of bare file names (not full
        paths), sampled with ``sample_images`` from the global numpy RNG.
    """
    monet_path = os.path.join(path, 'monet_jpg')
    photo_path = os.path.join(path, 'photo_jpg')

    # os.listdir order is filesystem-dependent; sort it so a fixed numpy seed
    # always selects the same files.
    style_images_paths = np.array(sorted(os.listdir(monet_path)))
    content_images_paths = np.array(sorted(os.listdir(photo_path)))

    sampled_style_images_paths = sample_images(style_images_paths, n_samples=n_style_samples)
    sampled_content_images_paths = sample_images(content_images_paths, n_samples=n_content_samples)

    return sampled_style_images_paths, sampled_content_images_paths

# Root of the Kaggle input dataset (contains the monet_jpg/ and photo_jpg/ folders).
train_path = '/kaggle/input/gan-getting-started'
style_imgs, content_imgs = load_images(train_path)

# Sanity check: sample sizes, then the chosen style file names.
print(content_imgs.shape, style_imgs.shape, sep='\n')
pprint(style_imgs)

In [None]:
!git clone https://github.com/nspitzern/kaggle-monet-competition.git
    
!mkdir /kaggle/temp
!mv /kaggle/working/kaggle-monet-competition /kaggle/temp/kaggle-monet-competition

In [None]:
kaggle_working_dir = '/kaggle/working'
kaggle_working_output_dir = '/kaggle/working/output'
# exist_ok=True so re-running the cell does not raise FileExistsError.
os.makedirs(kaggle_working_output_dir, exist_ok=True)

kaggle_my_files = '/kaggle/temp/kaggle-monet-competition'

# Replace the freshly sampled style list with the one stored in the cloned repo.
# NOTE(review): presumably this is the exact style set used at training time — confirm.
with open(os.path.join(kaggle_my_files, 'style_files_path.npy'), 'rb') as f:
    style_imgs = np.load(f)

In [None]:
columns = 6
rows = 5
fig, axes = plt.subplots(rows, columns, figsize=(20, 20))
fig.suptitle('Sampled Monet style images')
# axes.flat is row-major, matching the subplot numbering 1..rows*columns.
# zip() stops at the shorter sequence, so fewer than rows*columns style images
# leaves blank panels instead of raising IndexError.
for ax, style_name in zip(axes.flat, style_imgs):
    # np.asarray reads the pixels before the context manager closes the file.
    with Image.open(os.path.join(train_path, 'monet_jpg', style_name)) as img:
        ax.imshow(np.asarray(img))
for ax in axes.flat:
    ax.axis('off')
plt.show()

In [None]:
class MonetSampledDataset(Dataset):
    """Pairs each content photo with a randomly drawn style (Monet) image.

    ``__getitem__`` returns a dict with:
      * ``'photo'``     - the content image as a tensor normalized to [-1, 1];
      * ``'monet'``     - a randomly chosen style image, normalized the same way;
      * ``'monet_aug'`` - that same style image passed through ``style_transforms``.

    Args:
        root_path: dataset root directory.
        content_root_path: sub-directory of ``root_path`` holding the content photos.
        style_root_path: sub-directory of ``root_path`` holding the style images.
        content_paths: content file names (relative to the content directory);
            the dataset length is ``len(content_paths)``.
        style_paths: style file names (relative to the style directory) to sample from.
        style_transforms: callable applied to the style PIL image to build ``'monet_aug'``.
    """

    def __init__(self, root_path, content_root_path, style_root_path, content_paths, style_paths, style_transforms):
        self.root_path = root_path
        self.content_root_path = content_root_path
        self.style_root_path = style_root_path
        self.content_paths = content_paths
        self.style_paths = style_paths
        self.style_transforms = style_transforms
        # ToTensor gives [0, 1]; Normalize(0.5, 0.5) then maps it to [-1, 1].
        self.to_tensor = transforms.Compose([
                transforms.ToTensor(),
                transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)) # normalize the image
        ])

    def _load_rgb(self, sub_dir, file_name):
        """Read ``root_path/sub_dir/file_name`` fully into memory as an RGB PIL image."""
        with Image.open(os.path.join(self.root_path, sub_dir, file_name)) as img:
            # convert() loads the pixels into a new image, so the file can be closed on
            # exit; forcing RGB keeps the 3-channel Normalize valid for grayscale/RGBA files.
            return img.convert('RGB')

    def __getitem__(self, content_idx):
        # get the current content image
        content_image = self._load_rgb(self.content_root_path, self.content_paths[content_idx])

        # The style image is drawn from the global numpy RNG, independently of content_idx.
        # NOTE(review): with DataLoader num_workers > 0 each worker may start from a copy of
        # the same numpy RNG state and repeat the same style draws.
        style_idx = np.random.choice(len(self.style_paths))
        style_image = self._load_rgb(self.style_root_path, self.style_paths[style_idx])

        # convert to tensors
        content_image = self.to_tensor(content_image)
        original_style_image = self.to_tensor(style_image.copy())
        style_image = self.style_transforms(style_image)

        return {'photo': content_image, 'monet': original_style_image, 'monet_aug': style_image}

    def __len__(self):
        return len(self.content_paths)

    def resample_content_images(self, n_samples=1000):
        """Replace ``content_paths`` with ``n_samples`` files drawn from the content directory.

        Uses the module-level ``sample_images`` helper (global numpy RNG, no replacement).
        """
        photo_path = os.path.join(self.root_path, self.content_root_path)
        # os.listdir order is filesystem-dependent; sort so a fixed seed picks the same files.
        content_images_paths = np.array(sorted(os.listdir(photo_path)))
        self.content_paths = sample_images(content_images_paths, n_samples=n_samples)

In [None]:
# Augmentations applied only to the style (Monet) image. RandomRotation(15) is not used.
_style_aug_steps = [
    transforms.RandomHorizontalFlip(p=0.5),
    transforms.ColorJitter(brightness=0.05, contrast=0.05, saturation=0.05, hue=0.05),
    transforms.RandomResizedCrop((256, 256)),
]
# PIL image -> tensor in [0, 1] -> tensor in [-1, 1].
_style_tensor_steps = [
    transforms.ToTensor(),
    transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
]
style_transforms = transforms.Compose(_style_aug_steps + _style_tensor_steps)

In [None]:
# Undo Normalize(0.5, 0.5): (x - (-1)) / 2 == (x + 1) / 2 maps [-1, 1] back to [0, 1],
# then convert the tensor to a PIL image.
_denormalize = transforms.Normalize(mean=(-1, -1, -1), std=(2, 2, 2))
tensor2img = transforms.Compose([_denormalize, transforms.ToPILImage()])


def tensor2image(image):
    """Convert a [-1, 1]-normalized image tensor (presumably (C, H, W)) to a PIL image."""
    converted = tensor2img(image)
    return converted

## Create Dataset and Dataloader

In [None]:
# Dataset over the sampled photos; each item also carries a randomly drawn Monet image.
monet_dataset = MonetSampledDataset(
    root_path=train_path,
    content_paths=content_imgs,
    content_root_path='photo_jpg',
    style_paths=style_imgs,
    style_root_path='monet_jpg',
    style_transforms=style_transforms,
)

In [None]:
# Fetch one sample only: every indexing call re-reads the files from disk and
# draws a new random style image.
first_sample = monet_dataset[0]
print(len(monet_dataset))
print(first_sample['photo'].shape)
print(first_sample['monet'].shape)

# Define Generator

In [None]:
class ResidualBlock(nn.Module):
    """Residual unit computing ``x + F(x)``.

    ``F`` is two (reflection-pad 1, 3x3 conv, instance norm) stages with a ReLU
    between them. Channels stay at ``in_features`` and the padding preserves the
    spatial size, so the input can be added directly.

    The layer order inside ``conv_block`` must not change: the state_dict keys
    (``conv_block.1.*`` / ``conv_block.5.*``) have to match saved checkpoints.
    """

    def __init__(self, in_features):
        super().__init__()
        layers = []
        for stage in range(2):
            layers.append(nn.ReflectionPad2d(1))
            layers.append(nn.Conv2d(in_features, in_features, 3))
            layers.append(nn.InstanceNorm2d(in_features))
            if stage == 0:
                # Only the first stage is followed by a non-linearity.
                layers.append(nn.ReLU(inplace=True))
        self.conv_block = nn.Sequential(*layers)

    def forward(self, x):
        residual = self.conv_block(x)
        return x + residual

In [None]:
class Generator(nn.Module):
    """Image-to-image generator with a convolutional encoder / residual / decoder layout.

    Layers, in order:
      * 7x7 reflection-padded stem conv to ``base_features`` channels;
      * two stride-2 convs, each doubling the channel count;
      * ``n_residual_blocks`` ``ResidualBlock`` units;
      * two stride-2 transposed convs, each halving the channel count;
      * 7x7 reflection-padded conv to ``output_size`` channels, then ``tanh``
        (outputs lie in [-1, 1]).

    The submodule order inside ``self.model`` determines the state_dict keys
    (``model.<index>.*``), so it must stay stable for saved checkpoints to load.

    Args:
        input_size: number of input image channels.
        output_size: number of output image channels.
        n_residual_blocks: number of residual blocks at the bottleneck.
        base_features: channels after the stem conv (64 reproduces the original model).
    """

    def __init__(self, input_size, output_size, n_residual_blocks=9, base_features=64):
        super().__init__()

        # Initial convolution block
        model = [
            nn.ReflectionPad2d(3),
            nn.Conv2d(input_size, base_features, kernel_size=7),
            nn.InstanceNorm2d(base_features),
            nn.ReLU(inplace=True)
        ]

        # Downsampling
        in_features = base_features
        for _ in range(2):
            out_features = in_features * 2
            model += [
                nn.Conv2d(in_features, out_features, kernel_size=3, stride=2, padding=1),
                nn.InstanceNorm2d(out_features),
                nn.ReLU(inplace=True)
            ]
            in_features = out_features

        # Residual blocks
        model += [ResidualBlock(in_features=in_features) for _ in range(n_residual_blocks)]

        # Upsampling
        for _ in range(2):
            out_features = in_features // 2
            model += [
                nn.ConvTranspose2d(in_features, out_features, kernel_size=3, stride=2, padding=1, output_padding=1),
                nn.InstanceNorm2d(out_features),
                nn.ReLU(inplace=True)
            ]
            in_features = out_features

        # Output layer: in_features is back to base_features here (it was hard-coded
        # as 64 before, which only matched the default configuration).
        model += [
            nn.ReflectionPad2d(3),
            nn.Conv2d(in_features, output_size, 7),
            nn.Tanh()
        ]

        self.model = nn.Sequential(*model)

    def forward(self, x):
        return self.model(x)

In [None]:
def load_model(path, generator_photo2monet, device):
    """Load ``generator_photo2monet.pth`` from ``path`` into the generator and move it to ``device``.

    The generator is modified in place and also returned for convenience.

    Args:
        path: directory containing ``generator_photo2monet.pth`` (a state_dict).
        generator_photo2monet: module whose parameters are overwritten.
        device: target device, e.g. ``'cuda'`` or ``'cpu'``.

    Returns:
        The same ``generator_photo2monet`` module, now on ``device``.
    """
    checkpoint_path = os.path.join(path, 'generator_photo2monet.pth')
    # map_location lets a checkpoint saved from a GPU be loaded on a CPU-only machine.
    # NOTE(review): torch.load unpickles the file; only load checkpoints from trusted
    # sources (consider weights_only=True on torch >= 1.13).
    state_dict = torch.load(checkpoint_path, map_location=device)
    generator_photo2monet.load_state_dict(state_dict)

    generator_photo2monet.to(device)
    return generator_photo2monet

In [None]:
in_channels = 3
out_channels = 3
# is_available must be *called*: the bare function object is always truthy and
# would select 'cuda' even on CPU-only machines.
device = 'cuda' if torch.cuda.is_available() else 'cpu'

generator_photo2monet = Generator(in_channels, out_channels).to(device)
load_model(os.path.join(kaggle_my_files, 'pretrain_train/models'), generator_photo2monet, device)

In [None]:
BATCH_SIZE = 1

# dataset loader: unshuffled, so generated images are produced in content_imgs order.
monet_dataloder = DataLoader(monet_dataset, batch_size=BATCH_SIZE, shuffle=False)

# Generate fake Monet-style images

In [None]:
def save_image(output_path, image, i):
    """Save ``image`` (any object with a ``save(path)`` method) as ``<output_path>/<i>.jpg``."""
    image.save(os.path.join(output_path, f'{i}.jpg'))

In [None]:
# for l in os.listdir(kaggle_working_dir):
#     if l.endswith('.jpg'):
#         os.remove(os.path.join(kaggle_working_dir, l))

In [None]:
generator_photo2monet.eval()

# Running counter for output file names: file k is the k-th photo in loader order.
image_idx = 0
# Inference only: no_grad skips building the autograd graph (less memory, faster).
with torch.no_grad():
    for batch in monet_dataloder:
        # The loader already yields tensors; just move them to the model's device.
        real_photo = batch['photo'].to(device)
        fake_monet = generator_photo2monet(real_photo).cpu()

        # Iterate over the batch dimension (instead of squeeze()) so BATCH_SIZE > 1 works.
        for fake in fake_monet:
            save_image(kaggle_working_output_dir, tensor2image(fake), image_idx)
            image_idx += 1
            # Periodic progress instead of one line per image.
            if image_idx % 500 == 0:
                print(f'{image_idx} images created...')

print(f'Done: {image_idx} images written to {kaggle_working_output_dir}')

In [None]:
import shutil

# The cloned repo's files (weights, style list) were read by earlier cells; remove it.
# The isdir guard keeps the cell safe to re-run after the directory is gone.
if os.path.isdir(kaggle_my_files):
    shutil.rmtree(kaggle_my_files)

In [None]:
from zipfile import ZipFile

# Bundle every generated JPEG into images.zip with flat archive names (bare file names).
# The handle is called `archive`, not `zip`, so the builtin zip() stays usable in the
# notebook namespace after this cell.
with ZipFile('images.zip', 'w') as archive:
    for file_name in sorted(os.listdir(kaggle_working_output_dir)):
        if file_name.endswith('.jpg'):
            archive.write(os.path.join(kaggle_working_output_dir, file_name), file_name)
# Report only after the archive has been closed and fully written.
print('zip created')

In [None]:
# Remove the loose JPEGs now that they are bundled in images.zip; the guard keeps
# the cell safe to re-run once the directory is gone.
if os.path.isdir(kaggle_working_output_dir):
    shutil.rmtree(kaggle_working_output_dir)

In [None]:
# os.remove('images.zip')