<a href="https://colab.research.google.com/github/tamirmal/tau_dl_proj/blob/master/Adaptive_Style_Transfer.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [2]:
!pip install ipdb
import ipdb
import datetime, os

import torch
import torch.nn as nn
import torch.nn.functional as F
import matplotlib.pyplot as plt
import numpy as np
import torchvision
from torchvision.datasets.mnist import FashionMNIST
import torchvision.transforms as transforms
from torch.utils.data import DataLoader

# For tensorboard use TF 2.x+
%tensorflow_version 2.x

from google.colab import drive
drive.mount('/content/drive')
GOOGLE_DRIVE_PATH='/content/drive/My Drive/Colab Notebooks/TAU_DL_PROJ/STYLE_TRANSFER/'

# Sentinel file kept in the project folder: if it is visible, the Drive mount worked.
exists_file_path = os.path.join(GOOGLE_DRIVE_PATH, 'exists.file')
if not os.path.isfile(exists_file_path):
  # Raise explicitly: an `assert` would be stripped when Python runs with -O.
  raise FileNotFoundError(
      "problem mounting drive FS, failed to access file {}".format(exists_file_path))
print("successfully accessed drive FS")


Collecting ipdb
  Downloading https://files.pythonhosted.org/packages/df/78/3d0d7253dc85549db182cbe4b43b30c506c84008fcd39898122c9b6306a9/ipdb-0.12.2.tar.gz
Building wheels for collected packages: ipdb
  Building wheel for ipdb (setup.py) ... [?25l[?25hdone
  Created wheel for ipdb: filename=ipdb-0.12.2-cp36-none-any.whl size=9171 sha256=afe1acdf7ca1831d28d350db43f605e65267c8da3afd9257f68cd5d36146710d
  Stored in directory: /root/.cache/pip/wheels/7a/00/07/c906eaf1b90367fbb81bd840e56bf8859dbd3efe3838c0b4ba
Successfully built ipdb
Installing collected packages: ipdb
Successfully installed ipdb-0.12.2
TensorFlow 2.x selected.
Go to this URL in a browser: https://accounts.google.com/o/oauth2/auth?client_id=947318989803-6bn6qk8qdgf4n4g3pfee6491hc0brc4i.apps.googleusercontent.com&redirect_uri=urn%3aietf%3awg%3aoauth%3a2.0%3aoob&response_type=code&scope=email%20https%3a%2f%2fwww.googleapis.com%2fauth%2fdocs.test%20https%3a%2f%2fwww.googleapis.com%2fauth%2fdrive%20https%3a%2f%2fwww.googleapi

![adain_net](https://i.imgur.com/jAyz9hY.jpg)

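This notebook implements the paper *Arbitrary Style Transfer in Real-time with Adaptive Instance Normalization* (Huang & Belongie, 2017): https://arxiv.org/pdf/1703.06868.pdf.
The authors' official reference implementation, written in Torch/Lua, is available at https://github.com/xunhuang1995/AdaIN-style/.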



---


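The AdaIN layer implements the following operation. Here $x$ denotes the content features and $y$ the style features. $\mu(\cdot)$ and $\sigma(\cdot)$ are the per-channel mean and standard deviation, computed over the spatial dimensions:

$$\mathrm{AdaIN}(x, y) = \sigma(y)\left(\frac{x - \mu(x)}{\sigma(x)}\right) + \mu(y)$$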

![adain_layer](https://i.imgur.com/OiqyfkN.png)



In [0]:
def get_mu_and_sigma(features, epsilon=1e-6):
    """Return the per-channel mean and standard deviation of a feature map.

    Statistics are computed over all spatial positions of each channel,
    independently for every sample in the minibatch.

    Args:
        features: tensor of shape [minibatch_size, channels, h, w].
        epsilon: small constant added to the variance before the square root.
            It keeps the returned sigma strictly positive even for a constant
            channel, because the AdaIN layer divides by sigma.

    Returns:
        (mu, sigma): two tensors of shape [minibatch_size, channels, 1, 1],
        shaped so they broadcast against ``features``.

    Note:
        The variance is the unbiased estimate (the same default ``std()`` used).
        A feature map with a single spatial position therefore yields NaN.
    """
    minibatch_size, channels = features.size()[:2]

    # Flatten the spatial dimensions: [minibatch_size, channels, h*w]
    features_channels_stacked = features.reshape(minibatch_size, channels, -1)

    features_mean_per_channel = features_channels_stacked.mean(dim=2)
    features_mean_per_channel = features_mean_per_channel.reshape(minibatch_size, channels, 1, 1)

    # sqrt(var + eps) instead of std(): avoids sigma == 0 (and a division by
    # zero in AdaIN) when a channel is constant.
    features_var_per_channel = features_channels_stacked.var(dim=2) + epsilon
    features_sigma_per_channel = features_var_per_channel.sqrt()
    features_sigma_per_channel = features_sigma_per_channel.reshape(minibatch_size, channels, 1, 1)

    return features_mean_per_channel, features_sigma_per_channel

class vgg19_encoder(nn.Module):
    """Frozen, pretrained VGG-19 feature extractor with intermediate outputs.

    The VGG-19 convolutional stack is split into four consecutive stages that
    end at relu1_1, relu2_1, relu3_1 and relu4_1 respectively. Each stage's
    output can be returned, and is also fed into the next stage.
    """

    def __init__(self):
        super(vgg19_encoder, self).__init__()

        encoder = torchvision.models.vgg19(pretrained=True, progress=True)
        # The slice ends below index the convolutional stack (``encoder.features``).
        # torchvision's VGG top-level children are only [features, avgpool,
        # classifier]. Slicing those would put the whole network (features +
        # avgpool) into encoder_1, the classifier into encoder_2, and leave
        # encoder_3 and encoder_4 empty.
        encoder_layers = list(encoder.features.children())
        relu1_1 = 2
        relu2_1 = 7
        relu3_1 = 12
        relu4_1 = 21

        # Style encoders: we need intermediate features from SEVERAL layers. By
        # splitting the model into parts, we can take each part's output AND
        # feed it into the next part.
        self.encoder_1 = nn.Sequential(*encoder_layers[:relu1_1])         # input -> relu1_1
        self.encoder_2 = nn.Sequential(*encoder_layers[relu1_1:relu2_1])  # relu1_1 -> relu2_1
        self.encoder_3 = nn.Sequential(*encoder_layers[relu2_1:relu3_1])  # relu2_1 -> relu3_1
        self.encoder_4 = nn.Sequential(*encoder_layers[relu3_1:relu4_1])  # relu3_1 -> relu4_1

        # The encoder IS NOT trainable - freeze it
        for stage in (self.encoder_1, self.encoder_2, self.encoder_3, self.encoder_4):
            for p in stage.parameters():
                p.requires_grad = False

    def forward(self, x, last_only = True):
        """Run ``x`` through the four encoder stages.

        Args:
            x: input batch.
            last_only: if true, return only the output of the last stage
                (relu4_1). Otherwise return the list
                [relu1_1, relu2_1, relu3_1, relu4_1] of stage outputs.
        """
        features = []
        for stage in (self.encoder_1, self.encoder_2, self.encoder_3, self.encoder_4):
            x = stage(x)
            features.append(x)

        if last_only:
            return features[-1]
        return features

class style_transfer_net(nn.Module):
    """AdaIN style-transfer network (Huang & Belongie, arXiv:1703.06868).

    The network has three parts:
    - a frozen, pretrained VGG-19 encoder (up to relu4_1) that embeds the
      content and style images;
    - an AdaIN layer that re-normalizes the content features with the
      style features' per-channel statistics;
    - a trainable decoder that maps the result back to a 3-channel image.
    """

    def __init__(self):
        super(style_transfer_net, self).__init__()

        ## Using VGG as encoder/decoder.
        ## TODO : consider other architectures suggested in the article, such as
        ##        resnet34, which are deep BUT converge well thanks to skip-connections (residuals)

        encoder_t = 'VGG19'  # TODO in the future this will be an external argument
        if encoder_t == 'VGG19':
            encoder = torchvision.models.vgg19(pretrained=True, progress=True)
            # The slice ends below index the convolutional stack
            # (``encoder.features``). torchvision's VGG top-level children are
            # only [features, avgpool, classifier], so those indices do not
            # apply to them.
            encoder_layers = list(encoder.features.children())
            relu1_1 = 2
            relu2_1 = 7
            relu3_1 = 12
            relu4_1 = 21

            # Style encoders: the style loss needs features from SEVERAL layers, so
            # the encoder is split into stages whose outputs also feed the next stage.
            self.encoder_style1 = nn.Sequential(*encoder_layers[:relu1_1])         # input -> relu1_1
            self.encoder_style2 = nn.Sequential(*encoder_layers[relu1_1:relu2_1])  # relu1_1 -> relu2_1
            self.encoder_style3 = nn.Sequential(*encoder_layers[relu2_1:relu3_1])  # relu2_1 -> relu3_1
            self.encoder_style4 = nn.Sequential(*encoder_layers[relu3_1:relu4_1])  # relu3_1 -> relu4_1
            # Content encoder: final stage of the content path (relu3_1 -> relu4_1),
            # run after encoder_style1..3. It wraps the same layer objects as
            # encoder_style4, so the weights are shared.
            self.encoder_content = nn.Sequential(*encoder_layers[relu3_1:relu4_1])

            # The encoder IS NOT trainable - freeze it
            for e in [self.encoder_style1, self.encoder_style2, self.encoder_style3,
                      self.encoder_style4, self.encoder_content]:
                for p in e.parameters():
                    p.requires_grad = False

        decoder_t = 'VGG19'
        if decoder_t == 'VGG19':
            # Decoder from 512-channel relu4_1 features to a 3-channel image.
            # Nearest-neighbour upsampling is used where the encoder pooled.
            self.decoder = nn.Sequential(
                nn.Conv2d(512, 256, kernel_size=3, padding=1),
                nn.ReLU(inplace=True),
                nn.Upsample(scale_factor=2, mode='nearest'),
                nn.Conv2d(256, 256, kernel_size=3, padding=1),
                nn.ReLU(inplace=True),
                nn.Conv2d(256, 256, kernel_size=3, padding=1),
                nn.ReLU(inplace=True),
                nn.Conv2d(256, 256, kernel_size=3, padding=1),
                nn.ReLU(inplace=True),
                nn.Conv2d(256, 128, kernel_size=3, padding=1),
                nn.ReLU(inplace=True),
                nn.Upsample(scale_factor=2, mode='nearest'),
                nn.Conv2d(128, 128, kernel_size=3, padding=1),
                nn.ReLU(inplace=True),
                nn.Conv2d(128, 64, kernel_size=3, padding=1),
                nn.ReLU(inplace=True),
                nn.Upsample(scale_factor=2, mode='nearest'),
                nn.Conv2d(64, 64, kernel_size=3, padding=1),
                nn.ReLU(inplace=True),
                nn.Conv2d(64, 3, kernel_size=3, padding=1),
            )

    def _encode_with_intermediate(self, x):
        """Return the [relu1_1, relu2_1, relu3_1, relu4_1] features of ``x``."""
        features = []
        for stage in (self.encoder_style1, self.encoder_style2,
                      self.encoder_style3, self.encoder_style4):
            x = stage(x)
            features.append(x)
        return features

    @staticmethod
    def adain_layer(content_features, style_features):
        """Adaptive instance normalization.

        Normalizes the content features per channel, then scales and shifts
        them so each channel has the style features' per-channel mean and std.

        Args:
            content_features: content-image encoder output (relu4_1),
                presumably [batch_size, 512, h/8, w/8].
            style_features: style-image encoder output (relu4_1), same layout.

        Returns:
            Tensor shaped like ``content_features``.
        """
        content_mu, content_sigma = get_mu_and_sigma(content_features)
        style_mu, style_sigma = get_mu_and_sigma(style_features)

        normalized_content_features = (content_features - content_mu) / content_sigma
        return style_sigma * normalized_content_features + style_mu

    @staticmethod
    def calc_content_loss(out_features, t):
        """MSE between the output image's features and the AdaIN target ``t``."""
        return F.mse_loss(out_features, t)

    @staticmethod
    def calc_style_loss(content_middle_features, style_middle_features):
        """Sum over layers of the MSE between per-channel means and between stds.

        Both arguments are sequences of feature maps, paired up layer by layer.
        """
        loss = 0
        for c, s in zip(content_middle_features, style_middle_features):
            c_mean, c_std = get_mu_and_sigma(c)
            s_mean, s_std = get_mu_and_sigma(s)
            loss += F.mse_loss(c_mean, s_mean) + F.mse_loss(c_std, s_std)
        return loss

    def forward(self, content, style, alpha=1.0):
        """Stylize ``content`` with ``style`` and compute the AdaIN training losses.

        Args:
            content: content image batch.
            style: style image batch. It is assumed to have a batch size
                compatible with ``content``, since the statistics are
                broadcast per sample.
            alpha: content/style trade-off in [0, 1]; 1 means full stylization.

        Returns:
            (out, content_loss, style_loss):
            - out: the decoded image;
            - content_loss: MSE between the relu4_1 features of ``out`` and
              the AdaIN target;
            - style_loss: mean/std mismatch summed over relu1_1..relu4_1.

        Raises:
            ValueError: if ``alpha`` is outside [0, 1].
        """
        if not 0 <= alpha <= 1:
            raise ValueError("alpha must be in [0, 1], got {}".format(alpha))
        # TODO - add checks that the encoders are NOT trainable !!!

        # Encoder pass: all intermediate style features, last-layer content features.
        style_features = self._encode_with_intermediate(style)
        content_features = self.encoder_style1(content)
        content_features = self.encoder_style2(content_features)
        content_features = self.encoder_style3(content_features)
        content_features = self.encoder_content(content_features)

        # AdaIN step, then blend with the raw content features (content/style trade-off).
        style_norm_content = self.adain_layer(content_features, style_features[-1])
        style_norm_content = alpha * style_norm_content + (1 - alpha) * content_features

        # Decode into the stylized image.
        out = self.decoder(style_norm_content)

        # Losses: re-encode the output and compare it to the AdaIN target (content)
        # and to the style image's feature statistics (style).
        out_features = self._encode_with_intermediate(out)
        content_loss = self.calc_content_loss(out_features[-1], style_norm_content)
        style_loss = self.calc_style_loss(out_features, style_features)

        return out, content_loss, style_loss

# Build the full AdaIN model. Its __init__ creates torchvision's VGG-19 with
# pretrained=True, so the pretrained weights must be available (torchvision
# may need to download them on first use).
model = style_transfer_net()