<a href="https://colab.research.google.com/github/sushant-97/Style-Transform-Dashtoon-Assignment-GenAI/blob/main/Final_dash_toon_style_transfer.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

#### In this notebook, AdaIN style transfer is implemented following the original paper [Arbitrary Style Transfer in Real-time with Adaptive Instance Normalization](https://openaccess.thecvf.com/content_ICCV_2017/papers/Huang_Arbitrary_Style_Transfer_ICCV_2017_paper.pdf). The implementation is based on this [GitHub repository](https://github.com/irasin/Pytorch_AdaIN).

Neural style transfer (NST) originated with [Gatys et al.](https://www.cv-foundation.org/openaccess/content_cvpr_2016/papers/Gatys_Image_Style_Transfer_CVPR_2016_paper.pdf). Its central idea is that statistics of the feature maps produced by the internal convolutional layers of a network pre-trained for image recognition capture the style of an image (Gatys et al. use Gram matrices of these feature maps; AdaIN uses their channel-wise mean and variance).

Early approaches relied on iteratively optimizing the output image, but subsequent research shifted toward feed-forward generator networks that stylize an image in a single pass. The DeNA article provides comprehensive details and comparisons of notable methods, so an in-depth discussion is omitted here.

For this implementation I chose AdaIN, which is relatively straightforward, and I'll give a brief explanation of it here.

## Arbitrary Style Transfer in Real-time with Adaptive Instance Normalization


This method is built around Adaptive Instance Normalization (AdaIN). To understand AdaIN, it helps to first understand Instance Normalization.

Instance Normalization and Batch Normalization are often compared, so what are the key differences?

Consider the output of an intermediate layer in a typical CNN, with shape (b, c, h, w). For each normalization method, the axes over which the mean and variance are computed, and the shape of the resulting statistics, are:

1. Batch Normalization (BN):
    1. Compute the mean and variance over the b, h, w axes of (b, c, h, w).
    2. The result is a mean vector and a variance vector of shape (c,).

2. Instance Normalization (IN):
    1. Compute the mean and variance over the h, w axes of (b, c, h, w).
    2. The result is a mean matrix and a variance matrix of shape (b, c).

3. Layer Normalization (LN, included for comparison):
    1. Compute the mean and variance over the c, h, w axes of (b, c, h, w).
    2. The result is a mean vector and a variance vector of shape (b,).
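To make the axes concrete, here is a minimal PyTorch sketch that computes each method's statistics by hand on a random tensor, checks the normalized results against PyTorch's `F.batch_norm`, `F.instance_norm`, and `F.layer_norm` (all without affine parameters), and prints the shape of each method's mean. The tensor size and the helper name `stats` are illustrative assumptions. These built-ins normalize with the biased variance, whereas `calc_mean_std` later in this notebook uses `Tensor.std`, which is unbiased by default; for feature maps of this size the difference is negligible.

```python
import torch
import torch.nn.functional as F


def stats(x, dims):
    """Biased mean and variance over `dims`, kept broadcastable against x."""
    mean = x.mean(dim=dims, keepdim=True)
    var = ((x - mean) ** 2).mean(dim=dims, keepdim=True)
    return mean, var


torch.manual_seed(0)
x = torch.randn(4, 8, 16, 16)  # hypothetical activations with shape (b, c, h, w)
eps = 1e-5

bn_mean, bn_var = stats(x, (0, 2, 3))  # BN: one statistic per channel
in_mean, in_var = stats(x, (2, 3))     # IN: one statistic per (sample, channel)
ln_mean, ln_var = stats(x, (1, 2, 3))  # LN: one statistic per sample

bn = (x - bn_mean) / torch.sqrt(bn_var + eps)
inn = (x - in_mean) / torch.sqrt(in_var + eps)
ln = (x - ln_mean) / torch.sqrt(ln_var + eps)

# Cross-check against PyTorch's functional ops (no affine weight/bias).
assert torch.allclose(bn, F.batch_norm(x, None, None, training=True, eps=eps), atol=1e-4)
assert torch.allclose(inn, F.instance_norm(x, eps=eps), atol=1e-4)
assert torch.allclose(ln, F.layer_norm(x, x.shape[1:], eps=eps), atol=1e-4)

print(bn_mean.squeeze().shape, in_mean.squeeze().shape, ln_mean.squeeze().shape)
# torch.Size([8]) torch.Size([4, 8]) torch.Size([4])
```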


The method proposed in this paper uses an encoder-decoder network with an AdaIN layer in the middle: the encoder is a fixed VGG-19 pre-trained on ImageNet (truncated at relu4_1), and only the decoder is trained. The content and style images are fed into the encoder, and the resulting feature maps are called C_feature and S_feature, respectively. For each, the channel-wise mean and standard deviation are computed as in IN. C_feature is then normalized with its own mean and standard deviation and de-normalized (scaled and shifted) with the mean and standard deviation of S_feature. In other words, the statistics of C_feature are transformed to match those of S_feature. This follows the idea from Gatys et al. that the statistics of feature maps capture artistic style.
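Written as a formula, the operation described above is $\mathrm{AdaIN}(x, y) = \sigma(y)\,\frac{x - \mu(x)}{\sigma(x)} + \mu(y)$, where $\mu$ and $\sigma$ are computed per sample and per channel over the spatial axes. The following minimal sketch, independent of the notebook's classes (the names `channel_mean_std` and `adain_sketch` and the tensor shapes are illustrative assumptions), applies it to random tensors and checks that the output keeps the content's shape while taking on the style's channel-wise mean and standard deviation. Note that content and style feature maps may have different spatial sizes.

```python
import torch


def channel_mean_std(feat, eps=1e-6):
    """Per-sample, per-channel mean/std over the spatial axes (h, w), as in IN."""
    b, c = feat.shape[:2]
    flat = feat.reshape(b, c, -1)
    mean = flat.mean(dim=2).reshape(b, c, 1, 1)
    std = flat.std(dim=2).reshape(b, c, 1, 1) + eps
    return mean, std


def adain_sketch(content_feat, style_feat):
    """Normalize with the content statistics, then rescale/shift with the style statistics."""
    c_mean, c_std = channel_mean_std(content_feat)
    s_mean, s_std = channel_mean_std(style_feat)
    return s_std * (content_feat - c_mean) / c_std + s_mean


torch.manual_seed(0)
content = torch.randn(2, 512, 32, 32)            # e.g. relu4_1 features of a 256x256 image
style = 3.0 * torch.randn(2, 512, 24, 40) + 1.5  # different spatial size and statistics

out = adain_sketch(content, style)
o_mean, o_std = channel_mean_std(out)
s_mean, s_std = channel_mean_std(style)

print(out.shape)  # torch.Size([2, 512, 32, 32])
print(torch.allclose(o_mean, s_mean, atol=1e-4), torch.allclose(o_std, s_std, atol=1e-4))  # True True
```

The spatial layout (the "content") comes entirely from `content`; only the per-channel first- and second-order statistics are borrowed from `style`.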

The original paper's implementation is in (Lua) Torch; here AdaIN is implemented in PyTorch. The code is available on GitHub. Feel free to check it out and give it a star if you find it helpful.

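Results:
I'll share some result images below. I've also included a way to adjust the strength of the artistic style: the `alpha` argument of `Model.generate` (`alpha=1.0` applies the full style, smaller values keep more of the content), so feel free to try that as well.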
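The style-strength control is a linear interpolation in feature space applied before decoding, `t = alpha * AdaIN(f(c), f(s)) + (1 - alpha) * f(c)`, as in `Model.generate`. The sketch below isolates that step on random stand-in tensors; `blend_features` is a hypothetical helper name, and the range check on `alpha` is an added assumption that the notebook itself does not enforce.

```python
import torch


def blend_features(content_feat, stylized_feat, alpha):
    """Interpolate in feature space before decoding.

    alpha=0 keeps the content features unchanged; alpha=1 uses the fully stylized features.
    """
    if not 0.0 <= alpha <= 1.0:  # assumption: the notebook itself does not check this
        raise ValueError("alpha must lie in [0, 1]")
    return alpha * stylized_feat + (1.0 - alpha) * content_feat


torch.manual_seed(0)
content_feat = torch.randn(1, 512, 32, 32)               # stand-in for relu4_1 content features
stylized_feat = 2.0 * torch.randn(1, 512, 32, 32) + 1.0  # stand-in for the AdaIN output

for alpha in (0.0, 0.25, 0.5, 0.75, 1.0):
    t = blend_features(content_feat, stylized_feat, alpha)
    # Because the blend is linear, the mean moves linearly from the content mean to the stylized mean.
    print(f"alpha={alpha:.2f}  mean={t.mean().item():.3f}")
```

With the trained model, the equivalent call is `model.generate(content, style, alpha=0.5)`. Since the decoder is nonlinear, the interpolation is linear only in feature space, not in the decoded pixels.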


In [None]:
# imports

import warnings
warnings.simplefilter("ignore", UserWarning)

import os
import glob
import numpy as np
import matplotlib
import matplotlib.pyplot as plt

from pathlib import Path
from PIL import Image
from tqdm.notebook import tqdm

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.optim import Adam
from torch.utils.data import Dataset, DataLoader

from torchvision.utils import save_image, make_grid
from torchvision import transforms
from torchvision import models

In [None]:
# utility functions and model implementation
def calc_mean_std(features):
    """
    :param features: shape of features -> [batch_size, c, h, w]
    :return: features_mean, features_std: shape of mean/std -> [batch_size, c, 1, 1]
    """

    batch_size, c = features.size()[:2]
    features_mean = features.reshape(batch_size, c, -1).mean(dim=2).reshape(batch_size, c, 1, 1)
    features_std = features.reshape(batch_size, c, -1).std(dim=2).reshape(batch_size, c, 1, 1) + 1e-6
    return features_mean, features_std


def adain(content_features, style_features):
    """
    Adaptive Instance Normalization
    :param content_features: shape -> [batch_size, c, h, w]
    :param style_features: shape -> [batch_size, c, h, w]
    :return: normalized_features shape -> [batch_size, c, h, w]
    """
    content_mean, content_std = calc_mean_std(content_features)
    style_mean, style_std = calc_mean_std(style_features)
    normalized_features = style_std * (content_features - content_mean) / content_std + style_mean
    return normalized_features


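class VGGEncoder(nn.Module):
    """Frozen VGG-19 (ImageNet weights) truncated at relu4_1."""
    def __init__(self):
        super().__init__()
        vgg = models.vgg19(weights=models.VGG19_Weights.DEFAULT).features
        self.slice1 = vgg[: 2]     # conv1_1 -> relu1_1
        self.slice2 = vgg[2: 7]    # conv1_2 -> relu2_1
        self.slice3 = vgg[7: 12]   # conv2_2 -> relu3_1
        self.slice4 = vgg[12: 21]  # conv3_2 -> relu4_1
        # The encoder stays fixed; only the decoder is trained.
        for p in self.parameters():
            p.requires_grad = False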

    def forward(self, images, output_last_feature=False):
        h1 = self.slice1(images)
        h2 = self.slice2(h1)
        h3 = self.slice3(h2)
        h4 = self.slice4(h3)
        if output_last_feature:
            return h4
        else:
            return h1, h2, h3, h4


class RC(nn.Module):
    """A wrapper of ReflectionPad2d and Conv2d"""
    def __init__(self, in_channels, out_channels, kernel_size=3, pad_size=1, activated=True):
        super().__init__()
        self.pad = nn.ReflectionPad2d((pad_size, pad_size, pad_size, pad_size))
        self.conv = nn.Conv2d(in_channels, out_channels, kernel_size)
        self.activated = activated

    def forward(self, x):
        h = self.pad(x)
        h = self.conv(h)
        if self.activated:
            return F.relu(h)
        else:
            return h


class Decoder(nn.Module):
    def __init__(self):
        super().__init__()
        self.rc1 = RC(512, 256, 3, 1)
        self.rc2 = RC(256, 256, 3, 1)
        self.rc3 = RC(256, 256, 3, 1)
        self.rc4 = RC(256, 256, 3, 1)
        self.rc5 = RC(256, 128, 3, 1)
        self.rc6 = RC(128, 128, 3, 1)
        self.rc7 = RC(128, 64, 3, 1)
        self.rc8 = RC(64, 64, 3, 1)
        self.rc9 = RC(64, 3, 3, 1, False)

    def forward(self, features):
        h = self.rc1(features)
        h = F.interpolate(h, scale_factor=2)
        h = self.rc2(h)
        h = self.rc3(h)
        h = self.rc4(h)
        h = self.rc5(h)
        h = F.interpolate(h, scale_factor=2)
        h = self.rc6(h)
        h = self.rc7(h)
        h = F.interpolate(h, scale_factor=2)
        h = self.rc8(h)
        h = self.rc9(h)
        return h


class Model(nn.Module):
    def __init__(self):
        super().__init__()
        self.vgg_encoder = VGGEncoder()
        self.decoder = Decoder()

    def generate(self, content_images, style_images, alpha=1.0):
        content_features = self.vgg_encoder(content_images, output_last_feature=True)
        style_features = self.vgg_encoder(style_images, output_last_feature=True)
        t = adain(content_features, style_features)
        t = alpha * t + (1 - alpha) * content_features
        out = self.decoder(t)
        return out

    @staticmethod
    def calc_content_loss(out_features, t):
        return F.mse_loss(out_features, t)

    @staticmethod
    def calc_style_loss(out_middle_features, style_middle_features):
        # Match channel-wise mean and std of the output and style images at relu1_1..relu4_1.
        loss = 0
        for o, s in zip(out_middle_features, style_middle_features):
            o_mean, o_std = calc_mean_std(o)
            s_mean, s_std = calc_mean_std(s)
            loss += F.mse_loss(o_mean, s_mean) + F.mse_loss(o_std, s_std)
        return loss

    def forward(self, content_images, style_images, alpha=1.0, lam=10):
        # Encode content once; encode style once at all four levels (the last level is relu4_1).
        content_features = self.vgg_encoder(content_images, output_last_feature=True)
        style_middle_features = self.vgg_encoder(style_images, output_last_feature=False)
        style_features = style_middle_features[-1]
        t = adain(content_features, style_features)
        t = alpha * t + (1 - alpha) * content_features
        out = self.decoder(t)

        # Re-encode the decoded image once and reuse its relu4_1 features for the content loss.
        output_middle_features = self.vgg_encoder(out, output_last_feature=False)
        output_features = output_middle_features[-1]

        # Content loss against the AdaIN target t, plus the style loss weighted by lam.
        loss_c = self.calc_content_loss(output_features, t)
        loss_s = self.calc_style_loss(output_middle_features, style_middle_features)
        loss = loss_c + lam * loss_s
        return loss

In [None]:
# dataset

transform = transforms.Compose([
    transforms.Resize(size=512),
    transforms.RandomCrop(256),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406],
                         std=[0.229, 0.224, 0.225])
])


def denorm(tensor, device):
    std = torch.Tensor([0.229, 0.224, 0.225]).reshape(-1, 1, 1).to(device)
    mean = torch.Tensor([0.485, 0.456, 0.406]).reshape(-1, 1, 1).to(device)
    res = torch.clamp(tensor * std + mean, 0, 1)
    return res


class ContentStyleDataset(Dataset):
    def __init__(self, content_dir, style_dir, num_range, transform=transform):
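        # glob order is arbitrary; sort so the index-based pairing and the train/test ranges are reproducible
        content_images = sorted(glob.glob(os.path.join(content_dir, "*.jpg")))
        style_images = sorted(glob.glob(os.path.join(style_dir, "*.jpg")))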

        self.images_pairs = list(zip(content_images, style_images))[num_range[0]:num_range[1]]
        self.transform = transform

    def __len__(self):
        return len(self.images_pairs)

    def __getitem__(self, index):
        content_image, style_image = self.images_pairs[index]
        content_image = Image.open(content_image).convert("RGB")
        style_image = Image.open(style_image).convert("RGB")

        if self.transform:
            content_image = self.transform(content_image)
            style_image = self.transform(style_image)
        return content_image, style_image

In [None]:
# training

batch_size = 32
epochs = 20
learning_rate = 5e-5
train_content_dir = '/kaggle/input/coco-wikiart-nst-dataset-512-100000/content'
train_style_dir = '/kaggle/input/coco-wikiart-nst-dataset-512-100000/style'
test_content_dir = '/kaggle/input/coco-wikiart-nst-dataset-512-100000/content'
test_style_dir = '/kaggle/input/coco-wikiart-nst-dataset-512-100000/style'


loss_dir = "loss"
model_state_dir = "model_state"

os.makedirs(loss_dir, exist_ok=True)
os.makedirs(model_state_dir, exist_ok=True)

# set device on GPU if available, else CPU
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# to speed up training, only the first 32000 content/style image pairs are used
num_train_range = (0, 32000)

# prepare dataset and dataLoader
train_dataset = ContentStyleDataset(train_content_dir, train_style_dir, num_range=num_train_range)
iters = len(train_dataset)
print(f'Length of train image pairs: {iters}')

train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)

# set model and optimizer
model = Model()
model = model.to(device)

# only the decoder is trained; the VGG encoder is frozen
optimizer = Adam(model.decoder.parameters(), lr=learning_rate)

# start training
loss_list = []
for epoch in range(1, epochs + 1):
    print(f'Start {epoch} epoch')
    for i, (content, style) in tqdm(enumerate(train_loader, 1), total=len(train_loader)):
        content = content.to(device)
        style = style.to(device)
        loss = model(content, style)
        loss_list.append(loss.item())

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()


    print(f'[Epoch {epoch}] Loss (last batch): {loss.item():.4f}')
    torch.save(model.state_dict(), f'{model_state_dir}/{epoch}_epoch.pth')

# plot training loss
plt.plot(range(len(loss_list)), loss_list)
plt.xlabel('iteration')
plt.ylabel('loss')
plt.title('train loss')
plt.savefig(f'{loss_dir}/train_loss.png')
with open(f'{loss_dir}/loss_log.txt', 'w') as f:
    for l in loss_list:
        f.write(f'{l}\n')
print(f'Loss saved in {loss_dir}')

Length of train image pairs: 32000


Downloading: "https://download.pytorch.org/models/vgg19-dcbb9e9d.pth" to /root/.cache/torch/hub/checkpoints/vgg19-dcbb9e9d.pth

In [None]:
# test

model = model.eval()

# take 5 image pairs for the test set
num_test_range = (32001, 32006)

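# Resize and center-crop to 512x512 so all test images can be batched together and
# h, w are divisible by 8 (the decoder upsamples three times by 2x).
test_transform = transforms.Compose([
                 transforms.Resize(512),
                 transforms.CenterCrop(512),
                 transforms.ToTensor(),
                 transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                      std=[0.229, 0.224, 0.225])])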

test_dataset = ContentStyleDataset(test_content_dir, test_style_dir, transform=test_transform, num_range=num_test_range)
iters = len(test_dataset)
print(f'Length of test image pairs: {iters}')

test_loader = DataLoader(test_dataset, batch_size=iters, shuffle=False)

for i, (content, style) in tqdm(enumerate(test_loader, 1)):
    content = content.to(device)
    style = style.to(device)
    with torch.no_grad():
        out = model.generate(content, style)
        content = denorm(content, device).detach().cpu()
        style = denorm(style, device).detach().cpu()
        out = denorm(out, device).detach().cpu()
        res = torch.cat([content, style, out], dim=0)
        grid_img = make_grid(res, nrow=iters)
        plt.figure(figsize=(20, 12))
        plt.imshow(grid_img.permute(1, 2, 0))
        plt.axis("off")
        plt.show()