In [None]:
# Disable IPython's Jedi-based tab completion (presumably a workaround for
# slow or broken autocompletion in this environment).
%config Completer.use_jedi = False

In [None]:
import os
import datetime
import glob
import time
import cv2
import itertools
from tqdm.notebook import tqdm
import shutil

from PIL import Image

import numpy as np
import matplotlib.pyplot as plt

import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, Dataset

import torchvision.transforms as transforms
from torchvision.utils import make_grid

In [None]:
img_path = '../input/gan-getting-started/'
# glob returns matches in an arbitrary, filesystem-dependent order. Sort them so
# that the test-set order, and with it the numbered output images, is
# reproducible across runs and machines.
monet_path = sorted(glob.glob(os.path.join(img_path, 'monet_jpg', '*')))
photo_path = sorted(glob.glob(os.path.join(img_path, 'photo_jpg', '*')))

print('Dataset')
print(f'- monet data : {len(monet_path)}\n- photo data : {len(photo_path)}')

In [None]:
class Custom_dataset(Dataset):
    '''
    Dataset over Monet paintings and photos given as lists of image file paths.

    Args:
        img_path: two-element sequence ``[monet_paths, photo_paths]``.
        transforms: callable applied to every loaded RGB PIL image (e.g. a
            torchvision ``Compose``). If None, the PIL image is returned as-is.
        mode: 'train' -> each item is ``(monet_img, photo_img)``, where the photo
                  is drawn uniformly at random (unpaired sampling); the length is
                  the number of Monet images.
              'test'  -> each item is a single photo, in list order; the length
                  is the number of photos.

    Raises:
        ValueError: if ``mode`` is not 'train' or 'test'.
    '''
    _MODES = ('train', 'test')

    def __init__(self, img_path : list , transforms = None, mode = 'train'):
        super().__init__()
        if mode not in self._MODES:
            # Fail loudly: an unknown mode would otherwise make __len__ return
            # None and __getitem__ silently yield None.
            raise ValueError(f"mode must be one of {self._MODES}, got {mode!r}")

        self.path_monet = img_path[0]
        self.path_photo = img_path[1]
        self.transforms = transforms
        self.mode = mode

    def _load(self, path):
        '''Open ``path`` as an RGB image (closing the file) and apply ``self.transforms`` if set.'''
        with Image.open(path) as raw:
            img = raw.convert('RGB')
        if self.transforms is not None:
            img = self.transforms(img)
        return img

    def __getitem__(self, idx):
        if self.mode == 'train':
            monet_img = self._load(self.path_monet[idx])

            # Unpaired training: pick a random photo for each Monet painting.
            photo_idx = np.random.randint(0, len(self.path_photo))
            photo_img = self._load(self.path_photo[photo_idx])

            return monet_img, photo_img

        # mode == 'test' (validated in __init__)
        return self._load(self.path_photo[idx])

    def __len__(self):
        if self.mode == 'train':
            return len(self.path_monet)
        return len(self.path_photo)

In [None]:
# Data loader for inference.
# ToTensor gives values in [0, 1]; Normalize with mean = std = 0.5 per channel
# maps them to [-1, 1], the same range as the generator's Tanh output.
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.5] * 3, std=[0.5] * 3),
])

test_dataset = Custom_dataset(
    [monet_path, photo_path],
    transforms=transform,
    mode='test',
)
test_loader = DataLoader(test_dataset, batch_size=1, shuffle=False)

## Inference
- Translate each test photo into a Monet-style image (photo → Monet) using the pretrained generator.

In [None]:
# Define Conv Block
'''
1. Conv_up
2. Conv_down
3. Residual_block
'''

# 1.Conv_up
class Conv_up(nn.Module):
    '''
    Upsampling block: ConvTranspose2d - InstanceNorm2d - ReLU - (Dropout2d, p=0.5).

    Args:
        in_ch, out_ch: input / output channel counts.
        kernel_size, stride, padding, output_padding: forwarded to
            ``nn.ConvTranspose2d`` (no bias). The output height is
            ``(H - 1) * stride - 2 * padding + kernel_size + output_padding``,
            so e.g. kernel 3 / stride 2 / padding 1 / output_padding 1 doubles H.
        drop_out: if True, apply channel-wise dropout (p=0.5) after the ReLU,
            only while the module is in training mode.
    '''
    def __init__(self, in_ch, out_ch, kernel_size = 4, stride = 2, 
                 padding = 1, output_padding = 1, drop_out = True):
        super().__init__()
        
        self.convT = nn.ConvTranspose2d(in_ch, out_ch,
                                       kernel_size = kernel_size,
                                       stride = stride,
                                       padding = padding,
                                       output_padding = output_padding,
                                       bias = False)
        self.instance_norm = nn.InstanceNorm2d(out_ch)
        self.relu = nn.ReLU()
        self.drop_out = drop_out
        
    def forward(self, x):
        x = self.convT(x)
        x = self.instance_norm(x)
        x = self.relu(x)
        if self.drop_out:
            # Tie dropout to self.training. A freshly built nn.Dropout2d module
            # is always in training mode, so it kept dropping channels after
            # model.eval(). The functional form also registers no new
            # submodule, so previously pickled models still load and run.
            x = nn.functional.dropout2d(x, p = 0.5, training = self.training)

        return x

# 2. Conv_down
class Conv_down(nn.Module):
    '''
    Convolution block: Conv2d - (InstanceNorm2d) - ReLU.

    Notes:
        * The activation is a plain ``nn.ReLU`` (not LeakyReLU).
        * Despite its name, ``batch_Norm`` toggles *instance* normalization
          (``nn.InstanceNorm2d``); when False the norm layer is skipped.

    Args:
        in_ch, out_ch: input / output channel counts.
        kernel_size, stride, padding: forwarded to ``nn.Conv2d`` (with bias).
        batch_Norm: apply ``instance_norm`` between the conv and the ReLU.
    '''
    def __init__(self, in_ch, out_ch,
                 kernel_size = 4,
                 stride = 2,
                 padding = 1,
                 batch_Norm = True):
        super().__init__()
        
        self.conv = nn.Conv2d(in_ch, out_ch,
                             kernel_size = kernel_size,
                             stride = stride,
                             padding = padding,
                             bias = True)
        self.instance_norm = nn.InstanceNorm2d(out_ch)
        self.relu = nn.ReLU()
        self.batch = batch_Norm
        
    def forward(self, x):
        x = self.conv(x)
        if self.batch:
            x = self.instance_norm(x)
        x = self.relu(x)
        
        return x
    
# 3. Residual_block
class Residual_block(nn.Module):
    '''
    Residual unit computing ``x + IN(Conv(ReLU(IN(Conv(x)))))``.

    Both convolutions share kernel_size / stride / padding, use reflection
    padding and have no bias; IN is ``nn.InstanceNorm2d``. The skip connection
    adds ``x`` unchanged, so the sum needs in_ch == out_ch and size-preserving
    convolutions (the defaults kernel_size=3, stride=1, padding=1 are).
    '''
    def __init__(self, in_ch, out_ch, kernel_size = 3, stride = 1, padding = 1):
        super().__init__()

        conv_opts = dict(kernel_size = kernel_size,
                         stride = stride,
                         padding = padding,
                         padding_mode = 'reflect',
                         bias = False)
        # Layer order fixes the state_dict keys (res.0 / res.3 are the convs).
        layers = [
            nn.Conv2d(in_ch, out_ch, **conv_opts),
            nn.InstanceNorm2d(out_ch),
            nn.ReLU(),
            nn.Conv2d(out_ch, out_ch, **conv_opts),
            nn.InstanceNorm2d(out_ch),
        ]
        self.res = nn.Sequential(*layers)

    def forward(self, x):
        residual = self.res(x)
        return x + residual


In [None]:
class Generator(nn.Module):
    '''
    ResNet-style image-to-image generator, 3-channel input to 3-channel output.

    Layout with the default n_features = 64:
        layer    D64 - D128 - D256 - R256 x n_res - U128 - U64 - C3 + Tanh
        kernel   7x7 - 3x3  - 3x3  -     3x3      -  3x3 - 3x3 - 7x7
        stride    1     2      2          1          2     2     1
    D = Conv_down, R = Residual_block, U = Conv_up (dropout on by default),
    C = final bias-free Conv2d. Both 7x7 convolutions are preceded by
    ReflectionPad2d(3), and the first Conv_down skips normalization
    (batch_Norm=False).

    Args:
        n_features: channel width of the first layer; later widths are x2 / x4.
        n_res: number of residual blocks at the bottleneck.

    Input height and width should be divisible by 4 so the two stride-2
    downsamplings and the two upsamplings return the original size.
    Output values lie in [-1, 1] (Tanh).
    '''
    def __init__(self, n_features = 64, n_res = 9):
        super().__init__()
        
        self.main = nn.Sequential(
            nn.ReflectionPad2d(3),
            Conv_down(3, n_features, 7, 1, 0, False),
            Conv_down(n_features * 1, n_features * 2, 3, 2),
            Conv_down(n_features * 2, n_features * 4, 3, 2),
            *[
                Residual_block(n_features * 4, n_features * 4) for _ in range(n_res)
            ],
            Conv_up(n_features * 4, n_features * 2, 3, 2, 1),
            Conv_up(n_features * 2, n_features * 1, 3, 2, 1),
            nn.ReflectionPad2d(3),
            nn.Conv2d(n_features, 3, 7, 1, 0, bias = False),
            nn.Tanh()   
        )
        
    def forward(self, x):
        x = self.main(x)
        
        return x

In [None]:
# Load the pretrained photo -> Monet generator (netG_B), saved as a whole
# pickled nn.Module. Unpickling needs the model's classes to be defined in this
# session (presumably the blocks defined above -- the file name mentions "uNet",
# so confirm it matches Generator). torch.load runs pickle, so only load
# trusted checkpoints.
# NOTE(review): PyTorch >= 2.6 defaults to weights_only=True, which rejects a
# full pickled module; pass weights_only=False there if loading fails.
device = 'cuda' if torch.cuda.is_available() else 'cpu'
# map_location lets a checkpoint saved on a GPU be loaded on a CPU-only machine.
model = torch.load('../input/cyclegan-baseline/model_G/netG(uNet)_B100.pt',
                   map_location = device).to(device)

In [None]:
def unNormalize(img, mean = (0.5, 0.5, 0.5), std = (0.5, 0.5, 0.5)):
    '''
    Invert ``transforms.Normalize(mean, std)`` and convert a tensor to an 8-bit image.

    Args:
        img: tensor of shape (N, C, H, W) -- only the first item of the batch
            is converted -- or a single (C, H, W) image, on any device.
        mean, std: per-channel statistics that were used for normalization.

    Returns:
        np.ndarray of dtype uint8 and shape (H, W, C).
    '''
    img = img.detach().cpu()
    if img.dim() == 4:
        img = img[0]
    np_img = np.transpose(img.numpy(), (1, 2, 0))
    np_img = np_img * np.asarray(std) + np.asarray(mean)
    # Round instead of truncating, and clip so out-of-range values saturate
    # at 0 / 255 rather than wrapping around in the uint8 cast.
    np_img = np.clip(np.rint(np_img * 255), 0, 255).astype(np.uint8)
    return np_img

In [None]:
# Create the directory the inference loop writes into. exist_ok=True makes the
# cell safe to re-run, and the path is the same one used when saving images.
os.makedirs('../working/images', exist_ok = True)

In [None]:
# Translate every test photo and save it as a numbered JPEG (1.jpg, 2.jpg, ...)
# following test_loader's order.
# eval() switches layers that honour the train/eval flag (e.g. dropout) to
# inference behaviour, so re-running the cell reproduces the same images.
model.eval()
with torch.no_grad():
    for i, photo in tqdm(enumerate(test_loader), total = len(test_loader)):
        out = model(photo.to(device))
        img = Image.fromarray(unNormalize(out))
        img.save(os.path.join('../working/images', f'{i + 1}.jpg'))

In [None]:
# Bundle the generated images into ../working/images.zip; the returned archive
# path is displayed as the cell output.
shutil.make_archive(base_name = '../working/images',
                    format = 'zip',
                    root_dir = '../working/images')