<a href="https://colab.research.google.com/github/tamirmal/tau_dl_proj/blob/master/Adaptive_Style_Transfer.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [1]:
#!pip install ipdb
#import ipdb

import datetime, os

import torch
import torch.nn as nn
import torch.nn.functional as F
import matplotlib.pyplot as plt
import numpy as np
import torchvision
import torchvision.transforms as transforms
from torch.utils.data import DataLoader

# Colab-only magic: select the preinstalled TensorFlow 2.x (intended for TensorBoard).
# NOTE(review): TensorBoard is not used anywhere in the visible notebook yet.
%tensorflow_version 2.x

# Mount the GCS bucket `adaptive_style_transfer` at /content/adaptive_style_transfer/
# with gcsfuse, so the dataset images can be opened as local files.
# 1) Authenticate this Colab session against Google Cloud and select the project;
#    `gsutil ls` is a sanity check that the bucket is visible.
from google.colab import auth
auth.authenticate_user()
project_id = 'tau-dl'
!gcloud config set project {project_id}
!gsutil ls
# 2) Register Google's gcsfuse apt repository and signing key, then install gcsfuse.
!echo "deb http://packages.cloud.google.com/apt gcsfuse-bionic main" > /etc/apt/sources.list.d/gcsfuse.list
!curl https://packages.cloud.google.com/apt/doc/apt-key.gpg | apt-key add -
!apt -qq update
!apt -qq install gcsfuse
# 3) Mount the bucket. --implicit-dirs exposes "directories" that exist only as
#    object-name prefixes in the bucket.
!mkdir -p /content/adaptive_style_transfer/
!gcsfuse --implicit-dirs adaptive_style_transfer /content/adaptive_style_transfer/

TensorFlow 2.x selected.
Updated property [core/project].
gs://adaptive_style_transfer/
  % Total    % Received % Xferd  Average Speed   Time    Time     Time  Current
                                 Dload  Upload   Total   Spent    Left  Speed
100   659  100   659    0     0  24407      0 --:--:-- --:--:-- --:--:-- 24407
OK
68 packages can be upgraded. Run 'apt list --upgradable' to see them.
The following package was automatically installed and is no longer required:
  libnvidia-common-430
Use 'apt autoremove' to remove it.
The following NEW packages will be installed:
  gcsfuse
0 upgraded, 1 newly installed, 0 to remove and 68 not upgraded.
Need to get 4,274 kB of archives.
After this operation, 12.8 MB of additional disk space will be used.
Selecting previously unselected package gcsfuse.
(Reading database ... 145605 files and directories currently installed.)
Preparing to unpack .../gcsfuse_0.28.1_amd64.deb ...
Unpacking gcsfuse (0.28.1) ...
Setting up gcsfuse (0.28.1) ...
Using 

![adain_net](https://i.imgur.com/jAyz9hY.jpg)

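This notebook implements *Arbitrary Style Transfer in Real-time with Adaptive Instance Normalization* (Huang & Belongie, 2017): https://arxiv.org/pdf/1703.06868.pdf
The official reference implementation (Torch / Lua) is at https://github.com/xunhuang1995/AdaIN-style/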



---


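The AdaIN layer implements the following. Here $x$ are the content features, $y$ are the style features, and $\mu(\cdot)$, $\sigma(\cdot)$ are the per-channel mean and standard deviation taken over the spatial dimensions:

$$\mathrm{AdaIN}(x, y) = \sigma(y)\left(\frac{x - \mu(x)}{\sigma(x)}\right) + \mu(y)$$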

![adain_layer](https://i.imgur.com/OiqyfkN.png)



In [17]:
def get_mu_and_sigma(features, epsilon=1e-6):
    """Per-sample, per-channel mean and standard deviation of a feature map.

    Args:
        features: tensor of shape [minibatch_size, channels, h, w].
        epsilon: added to the variance before the square root so sigma is never 0.
            AdaIN divides by sigma, so this keeps both the division and the
            gradient of the square root finite on constant channels.

    Returns:
        (mu, sigma): two tensors of shape [minibatch_size, channels, 1, 1], ready to
        broadcast against `features`.
    """
    minibatch_size, channels = features.size()[:2]

    # Flatten the spatial dims: [minibatch_size, channels, h*w].
    features_channels_stacked = features.reshape(minibatch_size, channels, -1)

    features_mean_per_channel = features_channels_stacked.mean(dim=2)
    features_mean_per_channel = features_mean_per_channel.reshape(minibatch_size, channels, 1, 1)  # set dim as tensor

    # epsilon used to be defined here but never applied. A constant channel then gave
    # sigma == 0, and AdaIN produced inf/NaN. var() is unbiased by default, like std().
    # NOTE(review): the unbiased variance of a single element is NaN, so h*w must be > 1.
    features_var_per_channel = features_channels_stacked.var(dim=2) + epsilon
    features_sigma_per_channel = features_var_per_channel.sqrt()
    features_sigma_per_channel = features_sigma_per_channel.reshape(minibatch_size, channels, 1, 1)  # set dim as tensor

    return features_mean_per_channel, features_sigma_per_channel

class vgg19_encoder(nn.Module):
    """Frozen VGG19 feature extractor split into four stages.

    The stages end at relu1_1, relu2_1, relu3_1 and relu4_1. Each stage's output is
    both an intermediate result and the next stage's input.
    forward(x) returns only the relu4_1 activations. forward(x, last_only=False)
    returns the list [relu1_1, relu2_1, relu3_1, relu4_1].
    """

    def __init__(self):
        super(vgg19_encoder, self).__init__()

        vgg = torchvision.models.vgg19(pretrained=True, progress=True)
        print(vgg)  # dump the architecture to double-check the slice indices below
        layers = list(vgg.features.children())

        # Exclusive slice ends inside vgg.features: each stage stops right after the named ReLU.
        RELU1_1, RELU2_1, RELU3_1, RELU4_1 = 2, 7, 12, 21

        self.encoder_1 = nn.Sequential(*layers[:RELU1_1])          # input   -> relu1_1
        self.encoder_2 = nn.Sequential(*layers[RELU1_1:RELU2_1])   # relu1_1 -> relu2_1
        self.encoder_3 = nn.Sequential(*layers[RELU2_1:RELU3_1])   # relu2_1 -> relu3_1
        self.encoder_4 = nn.Sequential(*layers[RELU3_1:RELU4_1])   # relu3_1 -> relu4_1

        # The encoder is a fixed feature extractor: none of its weights are trained.
        self.requires_grad_(False)

    def forward(self, x, last_only=True):
        """Run x through the four stages.

        Returns relu4_1 when `last_only is True`. Otherwise returns the list of all
        four stage outputs.
        """
        activations = []
        out = x
        for stage in (self.encoder_1, self.encoder_2, self.encoder_3, self.encoder_4):
            out = stage(out)
            activations.append(out)

        if last_only is True:
            return out
        return activations


def _decoder_conv(in_channels, out_channels):
    """3x3 convolution preceded by 1-pixel reflection padding (keeps H and W unchanged)."""
    return [nn.ReflectionPad2d(1), nn.Conv2d(in_channels, out_channels, kernel_size=3)]


# Decoder: mirror of the encoder up to relu4_1. Three nearest-neighbour upsamplings undo
# the three max-pools, taking 512 x h/8 x w/8 features back to a 3 x h x w image.
# The AdaIN paper uses reflection padding (instead of zero padding) in the decoder to
# avoid border artifacts, hence _decoder_conv.
vgg19_decoder = nn.Sequential(
    *_decoder_conv(512, 256),
    nn.ReLU(inplace=True),
    nn.Upsample(scale_factor=2, mode='nearest'),
    *_decoder_conv(256, 256),
    nn.ReLU(inplace=True),
    *_decoder_conv(256, 256),
    nn.ReLU(inplace=True),
    *_decoder_conv(256, 256),
    nn.ReLU(inplace=True),
    *_decoder_conv(256, 128),
    nn.ReLU(inplace=True),
    nn.Upsample(scale_factor=2, mode='nearest'),
    *_decoder_conv(128, 128),
    nn.ReLU(inplace=True),
    *_decoder_conv(128, 64),
    nn.ReLU(inplace=True),
    nn.Upsample(scale_factor=2, mode='nearest'),
    *_decoder_conv(64, 64),
    nn.ReLU(inplace=True),
    *_decoder_conv(64, 3),
)

class style_transfer_net(nn.Module):
    """AdaIN style-transfer network: frozen encoder + trainable decoder.

    forward(content, style, alpha) returns the (content_loss, style_loss) pair used
    for training.
    """

    def __init__(self):
        super(style_transfer_net, self).__init__()

        ## using VGG as encoder/decoder
        ## TODO : consider using other architectures as suggested in the article
        ##        such as resnet34 etc. which are deep BUT have good convergence due to skip-connection (residuals)

        encoder_t = 'VGG19'  # TODO in the future this will be external argument
        if encoder_t == 'VGG19':
            self.encoder = vgg19_encoder()

        decoder_t = 'VGG19'
        if decoder_t == 'VGG19':
            # NOTE: this is the module-level `vgg19_decoder` object itself, not a copy, so
            # every style_transfer_net built from the same definition shares its weights.
            self.decoder = vgg19_decoder

    @staticmethod
    def adain_layer(content_features, style_features):
        """Adaptive instance normalization.

        Re-scales the content features so that each channel takes the style features'
        per-channel mean and std.

        Args:
            content_features: encoder relu4_1 output for the content image,
                [batch_size, 512, h/8, w/8].
            style_features: encoder relu4_1 output for the style image, same batch size
                and channel count. Statistics are pooled over the spatial dims, so the
                spatial size does not have to match.

        Returns:
            Tensor shaped like content_features.
        """
        content_mu, content_sigma = get_mu_and_sigma(content_features)
        style_mu, style_sigma = get_mu_and_sigma(style_features)

        normalized_content_features = (content_features - content_mu) / content_sigma
        return style_sigma * normalized_content_features + style_mu

    @staticmethod
    def calc_content_loss(out_content, adain_content):
        """MSE between the output image's relu4_1 features and the AdaIN target."""
        return F.mse_loss(out_content, adain_content)

    @staticmethod
    def calc_style_loss(out_style, in_style):
        """Sum over layers of the MSE between per-channel means plus the MSE between per-channel stds."""
        loss = 0
        for a, b in zip(out_style, in_style):
            a_mu, a_sigma = get_mu_and_sigma(a)
            b_mu, b_sigma = get_mu_and_sigma(b)
            loss += F.mse_loss(a_mu, b_mu) + F.mse_loss(a_sigma, b_sigma)
        return loss

    def forward(self, content, style, alpha=1.0):
        """Stylize `content` with `style` and return (content_loss, style_loss).

        alpha in [0, 1] trades content (0) against style (1): the decoder receives
        alpha * AdaIN(content, style) + (1 - alpha) * content features.
        """
        assert alpha >= 0
        assert alpha <= 1
        # TODO - add asserts that encoders are NOT trainable !!!

        ###########################################
        # Encoder pass of content and style images
        ###########################################
        style_features = self.encoder(style, last_only=False)     # for VGG19 [relu1_1, relu2_1, relu3_1, relu4_1]
        # Bug fix: this previously encoded `style`, so the content image was never used.
        content_features = self.encoder(content, last_only=True)  # for VGG19 relu4_1

        ###########################################
        # AdaIn step
        ###########################################
        style_norm_content = self.adain_layer(content_features, style_features[-1])
        style_norm_content = alpha * style_norm_content + (1 - alpha) * content_features

        ###########################################
        # Apply the style transfer
        ###########################################
        out = self.decoder(style_norm_content)

        ###########################################
        # Loss calculation
        ###########################################
        # A single encoder pass over the output serves both losses: its last entry is the
        # relu4_1 activation that last_only=True would have returned.
        out_style_features = self.encoder(out, last_only=False)
        content_loss = self.calc_content_loss(out_style_features[-1], style_norm_content)
        style_loss = self.calc_style_loss(out_style_features, style_features)
        return content_loss, style_loss

# Build the full network. Constructing vgg19_encoder loads the pretrained VGG19 weights
# and prints the raw torchvision VGG19 architecture (the "VGG(" output below) before
# the full model is printed.
model = style_transfer_net()
print(model)

VGG(
  (features): Sequential(
    (0): Conv2d(3, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (1): ReLU(inplace=True)
    (2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (3): ReLU(inplace=True)
    (4): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
    (5): Conv2d(64, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (6): ReLU(inplace=True)
    (7): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (8): ReLU(inplace=True)
    (9): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
    (10): Conv2d(128, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (11): ReLU(inplace=True)
    (12): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (13): ReLU(inplace=True)
    (14): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (15): ReLU(inplace=True)
    (16): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padd

In [3]:
from torch.utils import data
from PIL import Image, ImageFile

class DatasetFolders(data.Dataset):
    """Image dataset whose samples are listed, one path per line, in a text file.

    Paths starting with 'gs://' are rewritten to '/content/', so bucket objects are
    read through the local gcsfuse mount.
    """

    def __init__(self, files_paths, transform):
        super(DatasetFolders, self).__init__()
        # Use `f`, not `F`, for the file handle: `F` is torch.nn.functional in this notebook.
        with open(files_paths) as f:
            # Skip blank lines (e.g. a trailing empty line) so they never reach Image.open('').
            self.paths = [line.strip() for line in f if line.strip()]
        self.transform = transform

    def __getitem__(self, index):
        path = self.paths[index].replace('gs://', '/content/')
        # convert() loads the pixels into a new image, so the file can be closed right away.
        with Image.open(path) as raw_img:
            img = raw_img.convert('RGB')
        return self.transform(img)

    def __len__(self):
        return len(self.paths)

    def name(self):
        return 'DatasetFolders'


# Training-time preprocessing, following the paper: "first resize the smallest dimension
# of both images to 512 while preserving the aspect ratio, then randomly crop regions of
# size 256 x 256". An int `size` makes Resize match the shorter edge to 512 and keep the
# aspect ratio. The previous (512, 512) tuple squashed every image into a square.
trans = [
    transforms.Resize(512),
    transforms.RandomCrop(256),
    transforms.ToTensor(),
    # presumably the ImageNet statistics expected by the pretrained VGG19 encoder
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
]
trans = transforms.Compose(trans)

# Dataset sizes, i.e. the line counts of the path-list files as reported by `wc -l` below:
## Number of content examples:
## 82783 /root/content_examples_paths.lst
## Number of style examples:
## 79433 /root/style_examples_paths.lst

# Content images (MS-COCO splits in the bucket). The commented-out command shows how the
# path list was generated. The list is copied to /root/; each line is one gs:// image path.
#!gsutil ls gs://adaptive_style_transfer/COCO_SPLITS/*/*.jpg > /content/adaptive_style_transfer/content_examples_paths.lst
!gsutil cp gs://adaptive_style_transfer/content_examples_paths.lst /root/
!echo "Number of content examples:"
!wc -l /root/content_examples_paths.lst
content_dataset = DatasetFolders('/root/content_examples_paths.lst', trans)

# Style images (WikiArt splits in the bucket), prepared the same way.
#!gsutil ls gs://adaptive_style_transfer/WIKIART_SPLITS/*/*.jpg > /content/adaptive_style_transfer/style_examples_paths.lst
!gsutil cp gs://adaptive_style_transfer/style_examples_paths.lst /root/
!echo "Number of style examples:"
!wc -l /root/style_examples_paths.lst
style_dataset = DatasetFolders('/root/style_examples_paths.lst', trans)


Copying gs://adaptive_style_transfer/content_examples_paths.lst...
/ [1 files][  7.0 MiB/  7.0 MiB]                                                
Operation completed over 1 objects/7.0 MiB.                                      
Number of content examples:
82783 /root/content_examples_paths.lst
Copying gs://adaptive_style_transfer/style_examples_paths.lst...
/ [1 files][  4.7 MiB/  4.7 MiB]                                                
Operation completed over 1 objects/4.7 MiB.                                      
Number of style examples:
79433 /root/style_examples_paths.lst


"We train our network using **MS-COCO [36] as content
images** and a dataset of paintings mostly collected from
**WikiArt [39] as style images**, following the setting of [6].
Each dataset contains roughly 80,000 training examples.
We use the adam optimizer [26] and a **batch size of 8**
content-style image pairs. During training, we **first resize
the smallest dimension of both images to 512 while preserving the aspect ratio, then randomly crop regions of size
256 × 256**. Since our network is fully convolutional, it can
be applied to images of any size during testing."

In [0]:
sd = 999  # notebook-wide random seed


def torch_seed():
    """Reset torch's RNGs (CPU, plus every CUDA device when one exists) to the seed `sd`."""
    torch.manual_seed(sd)
    if not torch.cuda.is_available():
        return
    torch.cuda.manual_seed_all(sd)

## Training is counted in iterations, not epochs, as in the article / Lua-Torch code:
## 160,000 iterations with a batch size of 8. The data loaders must therefore never run
## out, but PyTorch's built-in samplers stop after one pass, so we define our own.
def InfiniteSampler(dataset_len, seed=None):
    """Yield dataset indices forever, reshuffling after each full pass.

    Args:
        dataset_len: number of items in the dataset; must be positive.
        seed: seed for this sampler's private torch.Generator. Defaults to the
            notebook-wide seed `sd`.

    Yields:
        int indices in [0, dataset_len). Every consecutive run of `dataset_len` indices
        is a permutation of range(dataset_len).

    Raises:
        ValueError: on the first draw, if dataset_len is not positive. Previously this
            surfaced as a RuntimeError from a StopIteration inside the generator.
    """
    if dataset_len <= 0:
        raise ValueError('InfiniteSampler needs a non-empty dataset, got length {}'.format(dataset_len))

    # A private generator keeps the shuffle reproducible without touching the global
    # torch RNG. Calling torch_seed() here re-seeded the global RNG every time a
    # DataLoader iterator was created.
    generator = torch.Generator()
    generator.manual_seed(sd if seed is None else seed)
    while True:
        yield from torch.randperm(dataset_len, generator=generator).tolist()

class InfiniteSamplerWrapper(data.sampler.Sampler):
    """DataLoader sampler that streams indices from InfiniteSampler without end."""

    def __init__(self, data_source):
        # Only the dataset size is needed; the dataset object itself is not kept.
        self.num_samples = len(data_source)

    def __iter__(self):
        index_stream = InfiniteSampler(self.num_samples)
        return iter(index_stream)

    def __len__(self):
        # Nominal, very large length: the stream itself never ends.
        return 1 << 31

# NOTE: needs GLOBAL_ARGS, which is defined in the hyper-parameter cell further down;
# run that cell before this one.
def _make_infinite_loader_iter(dataset):
    """Return an endless iterator of batches over `dataset`, sampled via InfiniteSamplerWrapper."""
    loader = data.DataLoader(
        dataset,
        batch_size=GLOBAL_ARGS['batch_size'],
        sampler=InfiniteSamplerWrapper(dataset),
        num_workers=4,
    )
    return iter(loader)


content_iter = _make_infinite_loader_iter(content_dataset)
style_iter = _make_infinite_loader_iter(style_dataset)

In [0]:
# Disabled debugging snippet. It is kept inside a string literal so it does not run; it
# builds batch-size-1 loaders and plots one content batch.
# If re-enabled, it rebinds `content_iter` / `style_iter`, replacing the training
# iterators created in the cell above.
# NOTE(review): `content_iter.next()` depends on DataLoader iterators exposing `.next()`;
# `next(content_iter)` is the portable form. Verify before re-enabling.
"""
content_iter = iter(data.DataLoader(
    content_dataset, batch_size=1,
    sampler=InfiniteSamplerWrapper(content_dataset),
    num_workers=1))
style_iter = iter(data.DataLoader(
    style_dataset, batch_size=1,
    sampler=InfiniteSamplerWrapper(style_dataset),
    num_workers=1))

images = content_iter.next()
print('images shape on batch size = {}'.format(images.size()))
grid = torchvision.utils.make_grid(images)
plt.imshow(grid.numpy().transpose((1, 2, 0)))
print(images[0])
"""

In [0]:
from torch.optim import Adam
from torch.optim import SGD

# Inverse-time learning-rate decay, as in the original Lua implementation.
def adjust_lr(optimizer, iter_i):
    """Set every param group's lr to lr / (1 + lr_decay * iter_i), using GLOBAL_ARGS."""
    base_lr = GLOBAL_ARGS['lr']
    decay = GLOBAL_ARGS['lr_decay']
    new_lr = base_lr / (1.0 + decay * iter_i)
    for group in optimizer.param_groups:
        group['lr'] = new_lr

# The VGG encoder is frozen (requires_grad=False), so only the decoder's parameters
# are handed to the optimizer.
trainable_params = [p for p in model.parameters() if p.requires_grad]
if GLOBAL_ARGS['optimizer'] == 'adam':
    optimizer = Adam(trainable_params, lr=GLOBAL_ARGS['lr'])
elif GLOBAL_ARGS['optimizer'] == 'sgd':
    optimizer = SGD(trainable_params, lr=GLOBAL_ARGS['lr'], momentum=GLOBAL_ARGS['momentum'])
else:
    # Previously any value other than 'adam' (including typos) silently fell back to SGD.
    raise ValueError("Unknown optimizer {!r}; expected 'adam' or 'sgd'".format(GLOBAL_ARGS['optimizer']))


In [0]:
from tqdm import tqdm

if torch.cuda.is_available():
    device = torch.device("cuda:0")
else:
    # Training is GPU-only. Fail with a clear message; `assert 0` is stripped under -O
    # and gave no explanation.
    raise RuntimeError('No CUDA device available: switch to a GPU runtime before training.')

## Train the network
# Device placement and train mode do not change between iterations, so set them once.
model.to(device)
model.train()
for i in tqdm(range(GLOBAL_ARGS['max_iters'])):
    adjust_lr(optimizer, iter_i=i)
    content_images = next(content_iter).to(device)
    style_images = next(style_iter).to(device)
    loss_c, loss_s = model(content_images, style_images)
    loss_c = GLOBAL_ARGS['loss_content_w'] * loss_c
    loss_s = GLOBAL_ARGS['loss_style_w'] * loss_s
    loss = loss_c + loss_s

    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
# end of for

  0%|          | 108/160000 [27:11<593:09:19, 13.36s/it]

In [5]:
# Training hyper-parameters, mirroring the defaults of the official Lua implementation
# (its option list is reproduced verbatim in the string below).
GLOBAL_ARGS = {
    'lr': 1e-4,             # initial learning rate
    'lr_decay': 5e-5,       # inverse-time decay factor used by adjust_lr
    'momentum': 0.9,        # only used when the optimizer is SGD
    'optimizer': 'adam',
    'batch_size': 8,
    'max_iters': 160000,
    'loss_style_w': 1e-2,   # weight of the style loss
    'loss_content_w': 1,    # weight of the content loss
}

"""
-- Training options
cmd:option('-resume', false, 'If true, resume training from the last checkpoint')
cmd:option('-optimizer', 'adam', 'Optimizer used, adam|sgd')
cmd:option('-learningRate', 1e-4, 'Learning rate')
cmd:option('-learningRateDecay', 5e-5, 'Learning rate decay')
cmd:option('-momentum', 0.9, 'Momentum')
cmd:option('-weightDecay', 0, 'Weight decay')
cmd:option('-batchSize', 8, 'Batch size')
cmd:option('-maxIter', 160000, 'Maximum number of iterations')
cmd:option('-targetContentLayer', 'relu4_1', 'Target content layer used to compute the loss')
cmd:option('-targetStyleLayers', 'relu1_1,relu2_1,relu3_1,relu4_1', 'Target style layers used to compute the loss')
cmd:option('-tvWeight', 0, 'Weight of TV loss')
cmd:option('-styleWeight', 1e-2, 'Weight of style loss')
cmd:option('-contentWeight', 1, 'Weight of content loss')
cmd:option('-reconStyle', false, 'If true, the decoder is also trained to reconstruct style images')
cmd:option('-normalize', false, 'If true, gradients at the loss function are normalized')
"""

"\n-- Training options\ncmd:option('-resume', false, 'If true, resume training from the last checkpoint')\ncmd:option('-optimizer', 'adam', 'Optimizer used, adam|sgd')\ncmd:option('-learningRate', 1e-4, 'Learning rate')\ncmd:option('-learningRateDecay', 5e-5, 'Learning rate decay')\ncmd:option('-momentum', 0.9, 'Momentum')\ncmd:option('-weightDecay', 0, 'Weight decay')\ncmd:option('-batchSize', 8, 'Batch size')\ncmd:option('-maxIter', 160000, 'Maximum number of iterations')\ncmd:option('-targetContentLayer', 'relu4_1', 'Target content layer used to compute the loss')\ncmd:option('-targetStyleLayers', 'relu1_1,relu2_1,relu3_1,relu4_1', 'Target style layers used to compute the loss')\ncmd:option('-tvWeight', 0, 'Weight of TV loss')\ncmd:option('-styleWeight', 1e-2, 'Weight of style loss')\ncmd:option('-contentWeight', 1, 'Weight of content loss')\ncmd:option('-reconStyle', false, 'If true, the decoder is also trained to reconstruct style images')\ncmd:option('-normalize', false, 'If true