### Data
Datasets used to train the neural network:

1. Style images: https://www.kaggle.com/kovalevvyu/painter-by-numbers-resized
2. Content images: https://www.kaggle.com/awsaf49/coco-2017-dataset

In [None]:
import torch.nn as nn
import torch.nn.functional as F
from torchvision.models import vgg19

In [None]:
import os
import glob
import numpy as np
from tqdm import tqdm
import torch
from torch.utils.data import Dataset
from torchvision import transforms
from skimage import io, transform
from PIL import Image

In [3]:
import warnings
warnings.simplefilter("ignore", UserWarning)
import os
import argparse
import matplotlib as mpl
import matplotlib.pyplot as plt
from tqdm import tqdm
import torch
from torch.optim import Adam
from torch.utils.data import DataLoader
from torchvision.utils import save_image

In [1]:
import os
import argparse
from PIL import Image
import torch
from torchvision import transforms
from torchvision.utils import save_image
import cv2
from PIL import UnidentifiedImageError
import time

In [None]:
def calc_mean_std(features):
    """
    Per-sample, per-channel mean and standard deviation over all positions
    that follow the channel axis.

    :param features: shape of features -> [batch_size, c, h, w]
    :return: features_mean, features_std: shape of mean/std -> [batch_size, c, 1, 1]
        (1e-6 is added to the std)
    """
    n_samples, n_channels = features.shape[0], features.shape[1]
    stat_shape = (n_samples, n_channels, 1, 1)

    # Collapse every axis after the channel axis into one, once, and reuse it.
    flat = features.reshape(n_samples, n_channels, -1)

    features_mean = flat.mean(dim=2).reshape(stat_shape)
    features_std = flat.std(dim=2).reshape(stat_shape) + 1e-6
    return features_mean, features_std


def adain(content_features, style_features):
    """
    Adaptive Instance Normalization: shift and scale the content features so
    that their per-channel statistics match those of the style features.

    :param content_features: shape -> [batch_size, c, h, w]
    :param style_features: shape -> [batch_size, c, h, w]
    :return: normalized_features shape -> [batch_size, c, h, w]
    """
    c_mean, c_std = calc_mean_std(content_features)
    s_mean, s_std = calc_mean_std(style_features)

    # Same operation order as a single-expression form
    # (multiply, then divide, then add) so results are bit-identical.
    centered = content_features - c_mean
    rescaled = s_std * centered / c_std
    return rescaled + s_mean


class VGGEncoder(nn.Module):
    """
    Frozen feature extractor built from the first 21 layers of torchvision's
    pretrained VGG-19 `features` stack, split into four consecutive slices.

    NOTE(review): the slice ends (indices 2, 7, 12, 21) presumably correspond
    to relu1_1, relu2_1, relu3_1 and relu4_1 - confirm against the torchvision
    layer order.
    """

    def __init__(self):
        super().__init__()
        backbone = vgg19(pretrained=True).features
        self.slice1 = backbone[:2]
        self.slice2 = backbone[2:7]
        self.slice3 = backbone[7:12]
        self.slice4 = backbone[12:21]
        # The encoder is never trained: switch off gradients for every weight.
        self.requires_grad_(False)

    def forward(self, images, output_last_feature=False):
        """
        :param images: input batch fed to the first slice
        :param output_last_feature: when True return only the output of the
            last slice, otherwise the outputs of all four slices
        :return: a single tensor, or a 4-tuple (h1, h2, h3, h4)
        """
        stage_outputs = []
        h = images
        for stage in (self.slice1, self.slice2, self.slice3, self.slice4):
            h = stage(h)
            stage_outputs.append(h)

        if output_last_feature:
            return stage_outputs[-1]
        return tuple(stage_outputs)


class RC(nn.Module):
    """Reflection padding, then a Conv2d, then an optional ReLU."""

    def __init__(self, in_channels, out_channels, kernel_size=3, pad_size=1, activated=True):
        """
        :param in_channels: channels of the input feature map
        :param out_channels: channels produced by the convolution
        :param kernel_size: convolution kernel size
        :param pad_size: reflection padding applied to each of the four borders
        :param activated: apply ReLU to the convolution output when True
        """
        super().__init__()
        # (left, right, top, bottom) all use the same amount of padding.
        self.pad = nn.ReflectionPad2d((pad_size,) * 4)
        self.conv = nn.Conv2d(in_channels, out_channels, kernel_size)
        self.activated = activated

    def forward(self, x):
        out = self.conv(self.pad(x))
        return F.relu(out) if self.activated else out


class Decoder(nn.Module):
    """
    Decoder mapping 512-channel feature maps to a 3-channel output.

    Nine RC layers (rc1 .. rc9) are applied in order; the spatial size is
    doubled with F.interpolate after rc1, rc5 and rc7. rc9 is the only layer
    built without the ReLU.
    """

    # (in_channels, out_channels, activated) for rc1 .. rc9, in order.
    _LAYER_SPECS = (
        (512, 256, True),
        (256, 256, True),
        (256, 256, True),
        (256, 256, True),
        (256, 128, True),
        (128, 128, True),
        (128, 64, True),
        (64, 64, True),
        (64, 3, False),
    )
    # 1-based numbers of the layers that are followed by a x2 upsampling.
    _UPSAMPLE_AFTER = (1, 5, 7)

    def __init__(self):
        super().__init__()
        # Layers are registered as rc1 .. rc9 so checkpoint keys stay the same.
        for number, (in_ch, out_ch, activated) in enumerate(self._LAYER_SPECS, start=1):
            setattr(self, f'rc{number}', RC(in_ch, out_ch, 3, 1, activated))

    def forward(self, features):
        h = features
        for number in range(1, len(self._LAYER_SPECS) + 1):
            h = getattr(self, f'rc{number}')(h)
            if number in self._UPSAMPLE_AFTER:
                h = F.interpolate(h, scale_factor=2)
        return h


class Model(nn.Module):
    """
    AdaIN style-transfer network: a VGGEncoder followed by a Decoder.

    `generate` produces stylized images; `forward` returns the training loss
    (content loss + lam * style loss).
    """

    def __init__(self):
        super().__init__()
        self.vgg_encoder = VGGEncoder()
        self.decoder = Decoder()

    @staticmethod
    def _blend(content_features, style_features, alpha):
        """Apply AdaIN, then mix the result with the raw content features."""
        t = adain(content_features, style_features)
        return alpha * t + (1 - alpha) * content_features

    def generate(self, content_images, style_images, alpha=1.0):
        """
        Stylize `content_images` with the statistics of `style_images`.

        :param content_images: batch of content images
        :param style_images: batch of style images
        :param alpha: weight of the AdaIN output; (1 - alpha) is the weight of
            the unmodified content features
        :return: output of the decoder
        """
        content_features = self.vgg_encoder(content_images, output_last_feature=True)
        style_features = self.vgg_encoder(style_images, output_last_feature=True)
        t = self._blend(content_features, style_features, alpha)
        return self.decoder(t)

    @staticmethod
    def calc_content_loss(out_features, t):
        """Mean squared error between the output features and the target `t`."""
        return F.mse_loss(out_features, t)

    @staticmethod
    def calc_style_loss(content_middle_features, style_middle_features):
        """
        Sum, over the paired feature maps, of the MSE between their channel
        means plus the MSE between their channel stds.

        :param content_middle_features: iterable of feature maps (in `forward`
            these are the features of the generated image)
        :param style_middle_features: iterable of style feature maps
        :return: accumulated loss
        """
        loss = 0
        for c, s in zip(content_middle_features, style_middle_features):
            c_mean, c_std = calc_mean_std(c)
            s_mean, s_std = calc_mean_std(s)
            loss += F.mse_loss(c_mean, s_mean) + F.mse_loss(c_std, s_std)
        return loss

    def forward(self, content_images, style_images, alpha=1.0, lam=10):
        """
        Compute the training loss for a batch.

        :param content_images: batch of content images
        :param style_images: batch of style images
        :param alpha: AdaIN mixing weight (see `generate`)
        :param lam: multiplier applied to the style loss
        :return: loss_c + lam * loss_s
        """
        content_features = self.vgg_encoder(content_images, output_last_feature=True)

        # One encoder pass per image batch: the last entry of the intermediate
        # features is exactly what output_last_feature=True would return, so
        # there is no need to encode the style images or the output twice.
        style_middle_features = self.vgg_encoder(style_images, output_last_feature=False)
        t = self._blend(content_features, style_middle_features[-1], alpha)
        out = self.decoder(t)

        output_middle_features = self.vgg_encoder(out, output_last_feature=False)

        loss_c = self.calc_content_loss(output_middle_features[-1], t)
        loss_s = self.calc_style_loss(output_middle_features, style_middle_features)
        return loss_c + lam * loss_s

In [None]:
# Per-channel normalisation applied after ToTensor.
# NOTE(review): these look like the standard ImageNet statistics expected by
# torchvision's pretrained models - confirm.
normalize = transforms.Normalize(
    mean=[0.485, 0.456, 0.406],
    std=[0.229, 0.224, 0.225],
)

# Training-time pipeline: random 256x256 crop -> tensor -> normalise.
trans = transforms.Compose([
    transforms.RandomCrop(256),
    transforms.ToTensor(),
    normalize,
])


def denorm(tensor, device):
    """
    Undo the `normalize` transform and clamp the result to [0, 1].

    :param tensor: normalised image tensor whose channel axis is third from
        the end, e.g. [batch_size, 3, h, w] or [3, h, w]
    :param device: device the mean/std constants are moved to; must match the
        device of `tensor`
    :return: tensor * std + mean, clamped to the range [0, 1]
    """
    # Shaped [3, 1, 1] so the constants broadcast over height and width.
    channel_std = torch.tensor([0.229, 0.224, 0.225]).view(3, 1, 1).to(device)
    channel_mean = torch.tensor([0.485, 0.456, 0.406]).view(3, 1, 1).to(device)
    restored = tensor * channel_std + channel_mean
    return restored.clamp(0, 1)

In [None]:
# Switch read by the next cell: when True, the output directories for losses,
# model checkpoints and per-epoch images are created.
make_dir = True

In [None]:
if make_dir:
    # exist_ok=True keeps this cell safe to re-run: a plain os.mkdir raises
    # FileExistsError as soon as a directory is already there.
    for out_dir in ('/kaggle/working/los_dir/',
                    '/kaggle/working/model_dir/',
                    '/kaggle/working/image_dir_epoch/'):
        os.makedirs(out_dir, exist_ok=True)

In [None]:
class PreprocessDataset(Dataset):
    """
    Dataset of (content image, style image) pairs.

    On construction, resized copies of both source directories are written
    under /kaggle/working/<dir name>_resized (only if that directory does not
    exist yet). The resized files are shuffled and zipped into pairs, so the
    dataset length is the size of the smaller of the two collections.
    """

    def __init__(self, content_dir, style_dir, transforms=trans):
        """
        :param content_dir: directory holding the original content images
        :param style_dir: directory holding the original style images
        :param transforms: callable applied to each loaded PIL image; a falsy
            value returns the PIL images unchanged
        """
        content_dir_resized = self._resized_dir(content_dir)
        style_dir_resized = self._resized_dir(style_dir)

        # Prepare each directory on its own: if only one of the two already
        # exists, the other one is still created and filled instead of the
        # whole constructor failing on mkdir of the existing directory.
        for source_dir, target_dir in ((content_dir, content_dir_resized),
                                       (style_dir, style_dir_resized)):
            if not os.path.exists(target_dir):
                os.makedirs(target_dir)
                self._resize(source_dir, target_dir)

        content_images = glob.glob(content_dir_resized + '/*')
        np.random.shuffle(content_images)
        print('content img: ', len(content_images))
        style_images = glob.glob(style_dir_resized + '/*')
        np.random.shuffle(style_images)
        print('style img: ', len(style_images))
        self.images_pairs = list(zip(content_images, style_images))
        print('pairs: ', len(self.images_pairs))
        self.transforms = transforms

    @staticmethod
    def _resized_dir(source_dir):
        """Return the /kaggle/working directory used for resized copies."""
        # normpath drops a trailing slash, which would otherwise yield an
        # empty directory name.
        name = os.path.basename(os.path.normpath(source_dir))
        return "/kaggle/working/" + name + '_resized'

    @staticmethod
    def _resize(source_dir, target_dir, max_images=7500):
        """
        Write resized copies of 3-channel images from source_dir to target_dir.

        The shorter side becomes 512 pixels and the aspect ratio is kept.
        Files that are not 3-channel images, or that fail to load, resize or
        save, are skipped.

        :param source_dir: directory to read images from
        :param target_dir: existing directory to write the resized images to
        :param max_images: at most this many directory entries are processed;
            None processes all of them
        """
        print(f'Start resizing {source_dir} ')
        for filename in os.listdir(source_dir)[:max_images]:
            try:
                image = io.imread(os.path.join(source_dir, filename))
                if image.ndim != 3 or image.shape[-1] != 3:
                    continue
                H, W, _ = image.shape
                if H < W:
                    ratio = W / H
                    H = 512
                    W = int(ratio * H)
                else:
                    ratio = H / W
                    W = 512
                    H = int(ratio * W)
                image = transform.resize(image, (H, W), mode='reflect', anti_aliasing=True)
                io.imsave(os.path.join(target_dir, filename), image)
            except Exception as err:
                # Best effort: one bad file must not abort the whole pass, but
                # KeyboardInterrupt/SystemExit are no longer swallowed.
                print(f'Skipping {filename}: {err}')

    @staticmethod
    def _load_rgb(path):
        """Open an image, force 3-channel RGB and release the file handle."""
        with Image.open(path) as image:
            return image.convert('RGB')

    def __len__(self):
        return len(self.images_pairs)

    def __getitem__(self, index):
        """
        Return the (content, style) pair at `index`.

        If an image of that pair cannot be identified by PIL, the following
        pairs are tried in order (wrapping around) so that a single corrupt
        file does not yield None and break batch collation.

        :raises IndexError: if `index` is out of range
        :raises RuntimeError: if no pair in the dataset can be read
        """
        # Plain indexing first so an out-of-range index still raises IndexError.
        self.images_pairs[index]

        total = len(self.images_pairs)
        for offset in range(total):
            content_path, style_path = self.images_pairs[(index + offset) % total]
            try:
                content_image = self._load_rgb(content_path)
                style_image = self._load_rgb(style_path)
            except UnidentifiedImageError:
                continue
            if self.transforms:
                content_image = self.transforms(content_image)
                style_image = self.transforms(style_image)
            return content_image, style_image

        raise RuntimeError('no readable image pair found in the dataset')

In [None]:
def main():
    """
    Train the style-transfer Model.

    Builds the train/test datasets, trains for a fixed number of epochs with
    Adam, periodically saves a grid of (content, style, output) test images,
    stores a checkpoint after every epoch and finally writes the loss curve
    and a loss log to disk.
    """
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    image_dir = '/kaggle/working/image_dir_epoch'
    model_dir = '/kaggle/working/model_dir'
    loss_dir = '/kaggle/working/los_dir'
    # Make sure every output location exists before hours of training start.
    for out_dir in (image_dir, model_dir, loss_dir):
        os.makedirs(out_dir, exist_ok=True)

    # prepare dataset and dataLoader
    train_dataset = PreprocessDataset('../input/coco-2017-dataset/coco2017/train2017', '../input/painter-by-numbers-resized')
    test_dataset = PreprocessDataset('../input/testing/dataset_test/content', '../input/testing/dataset_test/style')
    iters = len(train_dataset)
    print(f'Length of train image pairs: {iters}')

    batch_size = 32

    train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
    test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False)
    test_iter = iter(test_loader)
    # Exact number of batches per epoch (includes the last, partial batch).
    total_iterations = len(train_loader)

    model = Model().to(device)

    learning_rate = 5e-5
    optimizer = Adam(model.parameters(), lr=learning_rate)

    epoch = 30
    snapshot_interval = 1000
    loss_list = []
    for e in range(1, epoch + 1):
        print(f'Start {e} epoch')
        for i, (content, style) in tqdm(enumerate(train_loader, 1), total=total_iterations):
            content = content.to(device)
            style = style.to(device)
            loss = model(content, style)
            loss_value = loss.item()
            loss_list.append(loss_value)

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            print(f'[{e}/total {epoch} epoch],[{i} /'
                  f'total {total_iterations} iteration]: {loss_value}')

            if i % snapshot_interval == 0 and len(test_loader) > 0:
                try:
                    test_content, test_style = next(test_iter)
                except StopIteration:
                    # The test loader is much smaller than the number of
                    # snapshots taken over the run: start over when it is
                    # exhausted instead of crashing the training loop.
                    test_iter = iter(test_loader)
                    test_content, test_style = next(test_iter)
                test_content = test_content.to(device)
                test_style = test_style.to(device)
                with torch.no_grad():
                    out = model.generate(test_content, test_style)
                test_content = denorm(test_content, device)
                test_style = denorm(test_style, device)
                out = denorm(out, device)
                res = torch.cat([test_content, test_style, out], dim=0)
                res = res.to('cpu')
                # One grid row per group (content / style / output): the row
                # width must be the actual batch size, which can be smaller
                # than batch_size for the last test batch.
                save_image(res, f'{image_dir}/{e}_epoch_{i}_iteration.png',
                           nrow=test_content.size(0))
        torch.save(model.state_dict(), f'{model_dir}/{e}_epoch.pth')
        print(f'{model_dir}/{e}_epoch.pth')

    plt.plot(range(len(loss_list)), loss_list)
    plt.xlabel('iteration')
    plt.ylabel('loss')
    plt.title('train loss')
    plt.savefig(f'{loss_dir}/train_loss.png')
    plt.close()
    with open(f'{loss_dir}/loss_log.txt', 'w') as f:
        f.writelines(f'{value}\n' for value in loss_list)
    print(f'Loss saved in {loss_dir}')

In [None]:
# Start training when this cell/script is executed directly.
if __name__ == '__main__':
    main()

In [None]:
# Per-channel normalisation applied after ToTensor.
# NOTE(review): these look like the standard ImageNet statistics expected by
# torchvision's pretrained models - confirm.
normalize = transforms.Normalize(
    mean=[0.485, 0.456, 0.406],
    std=[0.229, 0.224, 0.225],
)

# Inference-time pipeline: no cropping, the whole image is converted and
# normalised.
trans = transforms.Compose([
    transforms.ToTensor(),
    normalize,
])


def denorm(tensor, device):
    """
    Reverse the `normalize` transform and clamp the values to [0, 1].

    :param tensor: normalised image tensor whose channel axis is third from
        the end, e.g. [batch_size, 3, h, w] or [3, h, w]
    :param device: device the mean/std constants are moved to; must match the
        device of `tensor`
    :return: tensor * std + mean, clamped to the range [0, 1]
    """
    # [3, 1, 1] constants broadcast across the spatial axes.
    scale = torch.tensor([0.229, 0.224, 0.225]).view(3, 1, 1).to(device)
    offset = torch.tensor([0.485, 0.456, 0.406]).view(3, 1, 1).to(device)
    return (tensor * scale + offset).clamp(0, 1)


def main_test():
    """
    Stylize one hard-coded content image with one hard-coded style image.

    Loads the checkpoint at ./model_dir/20_epoch.pth, runs `Model.generate`
    and saves the result to ./test/test_3.jpg. The elapsed time that is
    printed covers the whole function, including model construction and
    checkpoint loading.
    """
    start_time = time.time()

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    model_state_path = './model_dir/20_epoch.pth'

    model = Model()
    # Load onto the CPU first so a checkpoint saved on a GPU machine can be
    # restored anywhere; the model is moved to `device` right after.
    model.load_state_dict(torch.load(model_state_path, map_location='cpu'))
    model = model.to(device)
    model.eval()

    # convert('RGB') guards against grayscale / RGBA / palette inputs, which
    # the 3-channel normalisation cannot handle; `with` closes the files.
    with Image.open('../input/testing/dataset_test/content/doggie.jpg') as c:
        c_tensor = trans(c.convert('RGB')).unsqueeze(0).to(device)
    with Image.open('../input/painter-by-numbers-resized/0.jpg') as s:
        s_tensor = trans(s.convert('RGB')).unsqueeze(0).to(device)

    alpha = 1
    with torch.no_grad():
        out = model.generate(c_tensor, s_tensor, alpha)

    out = denorm(out, device)

    output_name = './test/test_3'
    # save_image does not create missing directories.
    os.makedirs(os.path.dirname(output_name), exist_ok=True)
    save_image(out, f'{output_name}.jpg', nrow=1)

    print(f'result saved into files starting with {output_name}')

    end_time = time.time()

    print(f'inference neural network is {end_time-start_time}')


In [None]:
# Run the single-image inference demo when this cell/script is executed directly.
if __name__ == '__main__':
    main_test()

![Train loss](train_loss.png)