<a href="https://colab.research.google.com/github/sushant-97/Style-Transform/blob/main/Solution_2_N_Style_Dashtoon_.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

## N Style Transform
Based on [A Learned Representation For Artistic Style](https://arxiv.org/abs/1610.07629).

In [None]:
# Extended from the PyTorch "fast_neural_style" example
# by adding Conditional Instance Normalization (CIN).
# https://github.com/pytorch/examples/tree/main/fast_neural_style


## Utility Functions

In [None]:
"""Utility Code."""

import torch
import torchvision
import torchvision.transforms as T
from PIL import Image

# Per-channel (R, G, B) mean and standard deviation used for normalization.
MEAN = (0.485, 0.456, 0.406)
STD = (0.229, 0.224, 0.225)

# Forward transform: (x - MEAN) / STD, channel by channel.
normalize = T.Normalize(MEAN, STD)

# Inverse transform, expressed as another Normalize:
# (y - (-MEAN / STD)) / (1 / STD) == y * STD + MEAN.
denormalize = T.Normalize(
    mean=[-mu / sigma for mu, sigma in zip(MEAN, STD)],
    std=[1 / sigma for sigma in STD],
)


def get_transforms(imsize=None, cropsize=None, cencrop=False):
    """Build the image preprocessing pipeline.

    The pipeline is, in order: an optional resize to ``imsize``, an optional
    crop to ``cropsize`` (center crop when ``cencrop`` is true, random crop
    otherwise), conversion to a tensor, and normalization with the
    module-level ``normalize`` transform.
    """
    steps = [T.Resize(imsize)] if imsize else []

    if cropsize:
        crop_cls = T.CenterCrop if cencrop else T.RandomCrop
        steps.append(crop_cls(cropsize))

    steps += [T.ToTensor(), normalize]
    return T.Compose(steps)


def imload(path, imsize=None, cropsize=None, cencrop=False):
    """Load the image at ``path`` as a normalized tensor with a batch axis.

    The file is opened with PIL, converted to RGB, run through the pipeline
    from ``get_transforms`` and given a leading dimension of size one.
    """
    pipeline = get_transforms(imsize=imsize, cropsize=cropsize, cencrop=cencrop)
    rgb = Image.open(path).convert("RGB")
    tensor = pipeline(rgb)
    return tensor.unsqueeze(0)


def imsave(image, save_path):
    """Write a (batch of) normalized image tensor(s) to ``save_path``.

    The input is tiled into a single grid, de-normalized, clamped in place
    to the [0, 1] range and then saved. Returns ``None``.
    """
    grid = torchvision.utils.make_grid(image)
    restored = denormalize(grid)
    restored.clamp_(0.0, 1.0)
    torchvision.utils.save_image(restored, save_path)


class ImageDataset:
    """Dataset over the ``*.jpg`` files found directly inside a directory."""

    def __init__(self, dir_path):
        """Collect the ``*.jpg`` paths under ``dir_path`` in sorted order.

        ``dir_path`` must provide a ``glob`` method (e.g. ``pathlib.Path``).
        """
        self.images = sorted(dir_path.glob('*.jpg'))

    def __len__(self):
        """Return the number of image files that were found."""
        return len(self.images)

    def __getitem__(self, index):
        """Return ``(image, index)`` with the image opened as RGB via PIL."""
        path = self.images[index]
        return Image.open(path).convert('RGB'), index


class DataProcessor:
    """Collate callable turning ``(image, index)`` samples into a batch."""

    def __init__(self, imsize=256, cropsize=240, cencrop=False):
        """Create the per-image transform pipeline via ``get_transforms``."""
        self.transforms = get_transforms(
            imsize=imsize, cropsize=cropsize, cencrop=cencrop)

    def __call__(self, batch):
        """Transform every image and stack the results.

        ``batch`` is a sequence of ``(image, index)`` pairs. Returns the
        stacked tensor of transformed images and a tuple of the indices.
        """
        images, indices = zip(*batch)
        tensors = [self.transforms(img) for img in images]
        return torch.stack(tensors), indices


## Loss Function

In [None]:
"""Loss Function Code."""
# based on Pytorch official tutorial
# https://pytorch.org/tutorials/advanced/neural_style_tutorial.html#:~:text=The%20style%20loss%20module%20is,_%7BXL%7D%20GXL.

import torch
from torch.nn.functional import mse_loss


def calc_content_loss(features, targets, nodes):
    """Sum the MSE between ``features`` and ``targets`` over ``nodes``.

    Both ``features`` and ``targets`` are indexed by each entry of
    ``nodes``. With no nodes the result is the integer ``0``.
    """
    return sum(mse_loss(features[name], targets[name]) for name in nodes)


def gram(x):
    """Return the per-sample Gram matrix of a 4-D feature map.

    ``x`` is unpacked as ``(b, c, h, w)``; the spatial axes are flattened
    and the ``(b, c, c)`` channel-by-channel inner products are divided
    by ``h * w``.
    """
    _, _, height, width = x.size()
    flat = torch.flatten(x, start_dim=2)
    products = torch.bmm(flat, flat.permute(0, 2, 1))
    return products / (height * width)


def calc_style_loss(features, targets, nodes):
    """Sum the MSE between Gram matrices of features and targets.

    For every entry of ``nodes`` the Gram matrices of ``features[node]``
    and ``targets[node]`` are compared. With no nodes the result is the
    integer ``0``.
    """
    return sum(
        mse_loss(gram(features[name]), gram(targets[name]))
        for name in nodes
    )


def calc_tv_loss(x):
    """Return the total variation of a 4-D tensor.

    This is the mean absolute difference between neighbours along the
    last axis plus the same quantity along the second-to-last axis.
    """
    along_last = x[:, :, :, :-1] - x[:, :, :, 1:]
    along_rows = x[:, :, :-1, :] - x[:, :, 1:, :]
    return along_last.abs().mean() + along_rows.abs().mean()


## Network

In [None]:
"""Network Code."""

import torch
import torch.nn as nn


class CIN(nn.Module):
    """Conditional instance normalization with per-style scale and offset."""

    def __init__(self, num_style, ch):
        """Create the norm layer and a ``(1, num_style, ch)`` table each for
        the offsets (near 0) and the scales (near 1)."""
        super().__init__()
        self.normalize = nn.InstanceNorm2d(ch, affine=False)

        table_shape = (1, num_style, ch)
        # Offsets start around zero and scales around one.
        self.offset = nn.Parameter(torch.randn(*table_shape) * 0.01)
        self.scale = nn.Parameter(torch.randn(*table_shape) * 0.01 + 1)

    def forward(self, x, style_codes):
        """Normalize ``x`` and apply the style-weighted affine transform.

        ``style_codes`` is multiplied against the parameter tables and
        summed over the style axis, so it must broadcast against shape
        ``(1, num_style, ch)`` and reduce to ``b * c`` values.
        """
        batch, channels, height, width = x.size()
        per_channel = (batch, channels, 1, 1)

        normed = self.normalize(x)

        # Weighted sum over the style axis selects/blends the style rows.
        gamma = (self.scale * style_codes).sum(dim=1).view(per_channel)
        beta = (self.offset * style_codes).sum(dim=1).view(per_channel)

        out = normed * gamma + beta
        return out.view(batch, channels, height, width)


class ConvWithCIN(nn.Module):
    """Reflection-padded convolution followed by CIN and an activation."""

    def __init__(self, num_style, in_ch, out_ch, stride, activation, ksize):
        """Build the padding, convolution, CIN and activation layers.

        Args:
            num_style: number of styles handled by the CIN layer.
            in_ch: number of input channels of the convolution.
            out_ch: number of output channels of the convolution.
            stride: stride of the convolution.
            activation: ``"relu"`` or ``"linear"`` (identity).
            ksize: kernel size; ``ksize // 2`` pixels of reflection
                padding are added on every side before the convolution.

        Raises:
            ValueError: if ``activation`` is not one of the supported names.
        """
        super(ConvWithCIN, self).__init__()
        self.padding = nn.ReflectionPad2d(ksize // 2)
        self.conv = nn.Conv2d(in_ch, out_ch, ksize, stride)

        self.cin = CIN(num_style, out_ch)

        # activation
        if activation == "relu":
            self.activation = nn.ReLU()

        elif activation == "linear":
            # nn.Identity has no parameters, so state_dict keys are
            # unchanged, and unlike a lambda it keeps the module picklable.
            self.activation = nn.Identity()

        else:
            # Fail at construction time instead of leaving the attribute
            # unset and hitting an AttributeError in forward().
            raise ValueError(
                f"Unsupported activation: {activation!r} "
                "(expected 'relu' or 'linear')")

    def forward(self, x, style_codes):
        """Apply padding, convolution, CIN (with ``style_codes``) and the
        activation, in that order."""
        x = self.padding(x)
        x = self.conv(x)
        x = self.cin(x, style_codes)
        x = self.activation(x)

        return x


class ResidualBlock(nn.Module):
    """Two 3x3 ConvWithCIN layers wrapped in a skip connection."""

    def __init__(self, num_style, in_ch, out_ch):
        """Create a ReLU-activated conv followed by a linear conv."""
        super().__init__()

        self.conv1 = ConvWithCIN(
            num_style, in_ch, out_ch, stride=1, activation="relu", ksize=3)
        self.conv2 = ConvWithCIN(
            num_style, out_ch, out_ch, stride=1, activation="linear", ksize=3)

    def forward(self, x, style_codes):
        """Return ``x`` plus the output of the two stacked convolutions."""
        residual = self.conv2(self.conv1(x, style_codes), style_codes)
        return x + residual


class UpsamleBlock(nn.Module):
    """Nearest-neighbour 2x upsampling followed by a 3x3 ConvWithCIN."""

    def __init__(self, num_style, in_ch, out_ch):
        """Create the ReLU-activated convolution and the upsampler."""
        super().__init__()
        self.conv = ConvWithCIN(
            num_style, in_ch, out_ch, stride=1, activation="relu", ksize=3)
        self.upsample = nn.Upsample(scale_factor=2, mode='nearest')

    def forward(self, x, style_codes):
        """Upsample ``x`` first, then convolve it with the style codes."""
        enlarged = self.upsample(x)
        return self.conv(enlarged, style_codes)


class StyleTransferNetwork(nn.Module):
    """Style transfer network whose every layer is conditioned on a style."""

    def __init__(self, num_style=16):
        """Build three convolutions, five residual blocks, two upsampling
        blocks and a final linear convolution back to 3 channels."""
        super().__init__()

        # Initial convolutions; conv2 and conv3 use stride 2.
        self.conv1 = ConvWithCIN(num_style, 3, 32, 1, 'relu', 9)
        self.conv2 = ConvWithCIN(num_style, 32, 64, 2, 'relu', 3)
        self.conv3 = ConvWithCIN(num_style, 64, 128, 2, 'relu', 3)

        # residual1 ... residual5, all at 128 channels.
        for idx in range(1, 6):
            setattr(self, f"residual{idx}",
                    ResidualBlock(num_style, 128, 128))

        # Two 2x upsampling stages.
        self.upsampling1 = UpsamleBlock(num_style, 128, 64)
        self.upsampling2 = UpsamleBlock(num_style, 64, 32)

        # Final projection to 3 channels without a non-linearity.
        self.conv4 = ConvWithCIN(num_style, 32, 3, 1, 'linear', 9)

    def forward(self, x, style_codes):
        """Run ``x`` through every stage, passing ``style_codes`` to each."""
        stages = (
            self.conv1, self.conv2, self.conv3,
            self.residual1, self.residual2, self.residual3,
            self.residual4, self.residual5,
            self.upsampling1, self.upsampling2,
            self.conv4,
        )
        for stage in stages:
            x = stage(x, style_codes)

        return x


## Main

In [None]:
"""Pytorch Implementation Code.

Reference: 'A Learned Representation for Artistic Style'
"""

import torch
import argparse
from pathlib import Path
from torch.optim import Adam
from network import StyleTransferNetwork
from torch.utils.data import DataLoader
from torchvision.models import vgg16, VGG16_Weights
from torchvision.models.feature_extraction import create_feature_extractor
from utils import ImageDataset, DataProcessor, imsave, imload
from loss import calc_content_loss, calc_style_loss, calc_tv_loss

NUM_STYLE = 16


def _cycle(dataloader):
    """Yield batches from ``dataloader`` forever, restarting it when exhausted."""
    while True:
        yield from dataloader


def train(style_path, content_path,
          style_weight=5.0, tv_weight=1e-5,
          lr=1e-4, batch_size=8, iterations=40_000):
    """Train the style transfer network and save it to ``model.ckpt``.

    Args:
        style_path: directory holding the style ``*.jpg`` images; the
            position of an image in the sorted listing is its style index.
        content_path: directory holding the content ``*.jpg`` images.
        style_weight: multiplier of the style loss in the total loss.
        tv_weight: multiplier of the total variation loss.
        lr: learning rate of the Adam optimizer.
        batch_size: number of content (and style) images per step.
        iterations: number of optimization steps.

    Raises:
        ValueError: if there are more style images than ``NUM_STYLE``, or
            if either dataset has fewer than ``batch_size`` images.
    """
    content_nodes = ['relu_3_3']
    style_nodes = ['relu_1_2', 'relu_2_2', 'relu_3_3', 'relu_4_2']
    # Keys are layer indices inside ``vgg16().features``.
    # NOTE(review): index 22 looks like the ReLU after conv4_3, so the
    # 'relu_4_2' label may be a misnomer -- confirm before relying on it.
    return_nodes = {3: 'relu_1_2',
                    8: 'relu_2_2',
                    15: 'relu_3_3',
                    22: 'relu_4_2'}
    # Use the GPU when there is one, otherwise fall back to the CPU.
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    # data
    content_dataset = ImageDataset(dir_path=Path(content_path))
    style_dataset = ImageDataset(dir_path=Path(style_path))

    # Validate up front instead of failing with an index/shape error
    # somewhere in the middle of training.
    if len(style_dataset) > NUM_STYLE:
        raise ValueError(
            f"Found {len(style_dataset)} style images but the network "
            f"only supports {NUM_STYLE} styles")
    for name, dataset in (('content', content_dataset),
                          ('style', style_dataset)):
        if len(dataset) < batch_size:
            raise ValueError(
                f"The {name} dataset has {len(dataset)} images, fewer "
                f"than batch_size={batch_size}")

    data_processor = DataProcessor(imsize=256,
                                   cropsize=240,
                                   cencrop=False)
    # drop_last keeps content and style batches the same size, which the
    # Gram-matrix comparison in the style loss requires.
    content_dataloader = DataLoader(dataset=content_dataset,
                                    batch_size=batch_size,
                                    shuffle=True,
                                    drop_last=True,
                                    collate_fn=data_processor)
    style_dataloader = DataLoader(dataset=style_dataset,
                                  batch_size=batch_size,
                                  shuffle=True,
                                  drop_last=True,
                                  collate_fn=data_processor)

    # loss network
    vgg = vgg16(weights=VGG16_Weights.IMAGENET1K_V1).features
    for param in vgg.parameters():
        param.requires_grad = False
    loss_network = create_feature_extractor(vgg, return_nodes).to(device)

    # network
    model = StyleTransferNetwork(num_style=NUM_STYLE)
    model.train()
    model = model.to(device)

    optimizer = Adam(model.parameters(), lr=lr)

    # Persistent iterators: the loaders are walked epoch after epoch
    # rather than being re-created (and re-shuffled) on every step.
    content_batches = _cycle(content_dataloader)
    style_batches = _cycle(style_dataloader)

    losses = {'content': [], 'style': [], 'tv': [], 'total': []}
    print("Start training...")
    for i in range(1, 1+iterations):
        content_images, _ = next(content_batches)
        style_images, style_indices = next(style_batches)

        # One-hot style codes of shape (batch, NUM_STYLE, 1), sized by
        # the batch that was actually drawn.
        style_codes = torch.zeros(len(style_indices), NUM_STYLE, 1)
        for b, s in enumerate(style_indices):
            style_codes[b, s] = 1

        content_images = content_images.to(device)
        style_images = style_images.to(device)
        style_codes = style_codes.to(device)

        output_images = model(content_images, style_codes)

        content_features = loss_network(content_images)
        style_features = loss_network(style_images)
        output_features = loss_network(output_images)

        style_loss = calc_style_loss(output_features,
                                     style_features,
                                     style_nodes)
        content_loss = calc_content_loss(output_features,
                                         content_features,
                                         content_nodes)
        tv_loss = calc_tv_loss(output_images)

        total_loss = content_loss \
            + style_loss * style_weight \
            + tv_loss * tv_weight

        optimizer.zero_grad()
        total_loss.backward()
        optimizer.step()

        losses['content'].append(content_loss.item())
        losses['style'].append(style_loss.item())
        losses['tv'].append(tv_loss.item())
        losses['total'].append(total_loss.item())

        if i % 100 == 0:
            log = f"iter.: {i}"
            for k, v in losses.items():
                # average over the 50 most recent steps
                avg = sum(v[-50:]) / 50
                log += f", {k}: {avg:1.4f}"
            print(log)

    torch.save({"state_dict": model.state_dict()}, "model.ckpt")


def evaluate(content_path, style_index):
    """Stylize one content image with the checkpoint in ``model.ckpt``.

    Args:
        content_path: path of the content image to stylize.
        style_index: index of the style to apply, in
            ``range(NUM_STYLE)``; ``-1`` renders every style.

    Returns:
        None. The result is written to ``stylized_images.jpg``.

    Raises:
        RuntimeError: if ``style_index`` is neither ``-1`` nor a valid
            style index.
    """
    device = torch.device('cpu')
    # NOTE: torch.load unpickles the file; only load trusted checkpoints.
    ckpt = torch.load('model.ckpt', map_location=device)

    model = StyleTransferNetwork(num_style=NUM_STYLE)
    model.load_state_dict(ckpt['state_dict'])
    model.eval()

    # Use the function argument, not the CLI-only global ``args``, so the
    # function also works when called from other code or a notebook cell.
    content_image = imload(content_path, imsize=256)
    # for all styles
    if style_index == -1:
        # One one-hot code per style, paired with a copy of the image.
        style_code = torch.eye(NUM_STYLE).unsqueeze(-1)
        content_image = content_image.repeat(NUM_STYLE, 1, 1, 1)

    # for specific style
    elif style_index in range(NUM_STYLE):
        style_code = torch.zeros(1, NUM_STYLE, 1)
        style_code[:, style_index, :] = 1

    else:
        raise RuntimeError("Not expected style index")

    # Inference only: no autograd graph is needed.
    with torch.no_grad():
        stylized_image = model(content_image, style_code)
    imsave(stylized_image, 'stylized_images.jpg')
    return None


if __name__ == '__main__':
    # Command-line entry point: parse the options, then dispatch on --mode.
    parser = argparse.ArgumentParser()

    # (flag, type, default, help) for every supported option.
    cli_options = (
        ('--mode', str, 'train', "'train' | 'eval'"),
        ('--style_path', str, None, "Path of style image."),
        ('--content_path', str, None, "Path of content image."),
        ('--style_index', int, 0, "Index for stylization, -1: all styles."),
    )
    for flag, arg_type, default, help_text in cli_options:
        parser.add_argument(flag, type=arg_type, default=default,
                            help=help_text)

    args = parser.parse_args()

    # Reject unknown modes first, then run the requested one.
    if args.mode not in ('train', 'eval'):
        raise RuntimeError("Not exepcted mode")

    if args.mode == 'train':
        train(args.style_path, args.content_path)
    else:
        evaluate(args.content_path, args.style_index)
