<a href="https://colab.research.google.com/github/pooriaazami/deep_learning_class_notebooks/blob/main/19_Arbitrary_Style_Transfer_(AdaIN).ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [1]:
!git clone https://github.com/xunhuang1995/AdaIN-style.git

!mkdir dataset

!mv AdaIN-style/input/style dataset/style
!mv AdaIN-style/input/content/ dataset/content

!rm -rf AdaIN-style/

Cloning into 'AdaIN-style'...
remote: Enumerating objects: 221, done.[K
remote: Counting objects: 100% (48/48), done.[K
remote: Compressing objects: 100% (7/7), done.[K
remote: Total 221 (delta 43), reused 41 (delta 41), pack-reused 173[K
Receiving objects: 100% (221/221), 22.83 MiB | 24.33 MiB/s, done.
Resolving deltas: 100% (78/78), done.


In [2]:
import os
import glob
from collections import defaultdict

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim

from torch.utils.data import Dataset, DataLoader

import torchvision as tv
import torchvision.transforms as T

from PIL import Image

import numpy as np
import matplotlib.pyplot as plt

from tqdm.notebook import tqdm

In [3]:
LR = 1e-4
BETA_1 = .5
BETA_2 = .99

IMAGE_SIZE = 256
DATASET_PATH = 'dataset'

VGG_PREPROCESSING_MEAN = (.485, .486, .406)
VGG_PREPROCESSING_STD = (.229, .224, .225)

DEVICE = 'cuda' if torch.cuda.is_available() else 'cpu'

BATCH_SIZE = 8
EPOCHS = 10

EPS = 1e-7
LAMBDA = .1

In [4]:
class StyleTransferDataset(Dataset):
  def __init__(self, root, transforms=None):
    super().__init__()

    style_path = os.path.join(root, 'style')
    content_path = os.path.join(root, 'content')

    self.style_images = glob.glob(style_path + '/*.jpg')
    self.content_images = glob.glob(content_path + '/*.jpg')
    self.transforms = transforms

  def __len__(self):
    return len(self.style_images) * len(self.content_images)

  def __getitem__(self, idx):
    style_idx, content_idx = divmod(idx, len(self.content_images))

    style = Image.open(self.style_images[style_idx])
    content = Image.open(self.content_images[content_idx])

    if self.transforms:
      style = self.transforms(style)
      content = self.transforms(content)

    return style, content

In [5]:
class VggUPBlock(nn.Module):
  def __init__(self, in_channels, out_channels, upsample=True):
    super().__init__()

    if upsample:
      self.layers = nn.Sequential(
          nn.ConvTranspose2d(in_channels, out_channels, kernel_size=2, stride=2),
          nn.ReLU(),
          nn.Conv2d(out_channels, out_channels, kernel_size=3, padding=1),
          nn.ReLU(),
          nn.Conv2d(out_channels, out_channels, kernel_size=3, padding=1),
          nn.ReLU(),
      )
    else:
      self.layers = nn.Sequential(
          nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1),
          nn.ReLU(),
          nn.Conv2d(out_channels, out_channels, kernel_size=3, padding=1),
          nn.ReLU(),
      )

  def forward(self, x):
    return self.layers(x)

In [6]:
class Decoder(nn.Module):
  def __init__(self):
    super().__init__()

    self.layers = nn.Sequential(
       VggUPBlock(512, 256),
       VggUPBlock(256, 256, False),

       VggUPBlock(256, 128),
       VggUPBlock(128, 128, False),

       VggUPBlock(128, 64),
       VggUPBlock(64, 64, False),

       VggUPBlock(64, 32),

       VggUPBlock(32, 16),
       VggUPBlock(16, 3, False),
    )

  def forward(self, x):
    return self.layers(x)

In [7]:
class AdaIN(nn.Module):
  def __init__(self):
    super().__init__()

  def forward(self, content, style):
    N, C, W, H = content.size()

    flattened_content = content.reshape(N, C, -1)
    flattened_style = style.reshape(N, C, -1)

    content_mean = flattened_content.mean(axis=-1, keepdims=True)
    content_std = flattened_content.std(axis=-1, keepdims=True)

    style_mean = flattened_style.mean(axis=-1, keepdims=True)
    style_std = flattened_style.std(axis=-1, keepdims=True)

    output = style_std * (flattened_content - content_mean) / (content_std + EPS) + style_mean
    output = output.reshape(N, C, W, H)
    return output

In [8]:
def get_layer_name(layer):
  if isinstance(layer, nn.MaxPool2d):
    return 'pool'
  if isinstance(layer, nn.Conv2d):
    return 'conv'
  if isinstance(layer, nn.ReLU):
    return 'relu'

  return 'UNK'

In [9]:
def build_layer_map(model):
  block_counter = 1
  cache = defaultdict(int)
  layer_map = {}

  for name, layer in model.named_children():
    layer_name = get_layer_name(layer)

    cache[layer_name] += 1
    layer_num = cache[layer_name]
    key = f'{layer_name}{block_counter}_{layer_num}'
    layer_map[name] = key

    if layer_name == 'pool':
      cache = defaultdict(int)
      block_counter += 1

  return layer_map

In [10]:
def cache_model_activations(model, layer_map, target_layers, x):
  cache = {}

  for name, layer in model.named_children():
    x = layer(x)

    if layer_map[name] in target_layers:
      cache[layer_map[name]] = x

  return cache, x

In [11]:
def calculate_style_loss(style_cache, generated_image_cache, target_layers):
  style_loss = .0

  for layer in target_layers:
    style = style_cache[layer]
    generated = generated_image_cache[layer]
    N, C, W, H = style.size()

    flattened_style = style.reshape(N, C, -1)
    flattened_generated = generated.reshape(N, C, -1)

    style_mean = flattened_style.mean(axis=-1, keepdims=True)
    style_std = flattened_style.std(axis=-1, keepdims=True)

    generated_mean = flattened_generated.mean(axis=-1, keepdims=True)
    generated_std = flattened_generated.std(axis=-1, keepdims=True)

    layer_loss = criterion(style_mean, generated_mean) + \
                  criterion(style_std, generated_std)

    style_loss += layer_loss

  return style_loss

In [12]:
weights = tv.models.VGG19_Weights.IMAGENET1K_V1
vgg = tv.models.vgg19(weights=weights).features.to(DEVICE)

decoder = Decoder().to(DEVICE)
adain = AdaIN()

Downloading: "https://download.pytorch.org/models/vgg19-dcbb9e9d.pth" to /root/.cache/torch/hub/checkpoints/vgg19-dcbb9e9d.pth
100%|██████████| 548M/548M [00:04<00:00, 118MB/s]


In [13]:
layer_map = build_layer_map(vgg)

In [14]:
target_layers = {'relu1_1', 'relu2_1', 'relu3_1', 'relu4_1'}

In [15]:
optimizer = optim.Adam(decoder.parameters(), lr=LR, betas=(BETA_1, BETA_2))
criterion = nn.MSELoss()

In [16]:
transforms = T.Compose([
    T.ToTensor(),
    T.Resize((IMAGE_SIZE, IMAGE_SIZE)),
    T.Normalize(VGG_PREPROCESSING_MEAN, VGG_PREPROCESSING_STD)
])


dataset = StyleTransferDataset(DATASET_PATH, transforms=transforms)
dataloader = DataLoader(dataset, batch_size=BATCH_SIZE, shuffle=True)

In [18]:
for epoch in range(1, EPOCHS + 1):
  print(f'Epoch {epoch} / {EPOCHS}')

  total_loss = .0
  for style, content in tqdm(dataloader):
    style = style.to(DEVICE)
    content = content.to(DEVICE)

    optimizer.zero_grad()

    style_cache, style_latent = cache_model_activations(vgg, layer_map, target_layers, style)
    content_latent = vgg(content)

    mixed_latent = adain(style_latent, content_latent)
    generated_image = decoder(mixed_latent)

    generated_image_cache, generated_image_latent = cache_model_activations(vgg, layer_map, target_layers, generated_image)

    content_loss = criterion(mixed_latent, generated_image_latent)
    style_loss = calculate_style_loss(style_cache, generated_image_cache, target_layers)
    loss = content_loss + LAMBDA * style_loss

    loss.backward()
    optimizer.step()

    total_loss += loss.detach().cpu().item()

  print(f'Total loss: {total_loss:.2f}')

Epoch 1 / 10


  0%|          | 0/27 [00:00<?, ?it/s]

Total loss: 105.76
Epoch 2 / 10


  0%|          | 0/27 [00:00<?, ?it/s]

Total loss: 93.87
Epoch 3 / 10


  0%|          | 0/27 [00:00<?, ?it/s]

Total loss: 89.80
Epoch 4 / 10


  0%|          | 0/27 [00:00<?, ?it/s]

Total loss: 83.61
Epoch 5 / 10


  0%|          | 0/27 [00:00<?, ?it/s]

Total loss: 77.35
Epoch 6 / 10


  0%|          | 0/27 [00:00<?, ?it/s]

Total loss: 76.58
Epoch 7 / 10


  0%|          | 0/27 [00:00<?, ?it/s]

Total loss: 74.08
Epoch 8 / 10


  0%|          | 0/27 [00:00<?, ?it/s]

Total loss: 72.43
Epoch 9 / 10


  0%|          | 0/27 [00:00<?, ?it/s]

Total loss: 74.77
Epoch 10 / 10


  0%|          | 0/27 [00:00<?, ?it/s]

Total loss: 73.59
