<a href="https://colab.research.google.com/github/tamirmal/tau_dl_proj/blob/master/Adaptive_Style_Transfer.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [1]:
!pip install ipdb
import ipdb
import datetime, os

import torch
import torch.nn as nn
import torch.nn.functional as F
import matplotlib.pyplot as plt
import numpy as np
import torchvision
from torchvision.datasets.mnist import FashionMNIST
import torchvision.transforms as transforms
from torch.utils.data import DataLoader

# For tensorboard use TF 2.x+
%tensorflow_version 2.x

from google.colab import drive
drive.mount('/content/drive')
GOOGLE_DRIVE_PATH='/content/drive/My Drive/Colab Notebooks/TAU_DL_PROJ/STYLE_TRANSFER/'

exists_file_path=GOOGLE_DRIVE_PATH+'exists.file'
if not os.path.isfile(GOOGLE_DRIVE_PATH+'exists.file'):
  print("problem mounting drive FS, failed to access file {}".format(exists_file_path))
  assert 0
else:
  print("successfully accessed drive FS")


Collecting ipdb
  Downloading https://files.pythonhosted.org/packages/df/78/3d0d7253dc85549db182cbe4b43b30c506c84008fcd39898122c9b6306a9/ipdb-0.12.2.tar.gz
Building wheels for collected packages: ipdb
  Building wheel for ipdb (setup.py) ... [?25l[?25hdone
  Created wheel for ipdb: filename=ipdb-0.12.2-cp36-none-any.whl size=9171 sha256=8ff441cfc07d5162481b8aa00b0adb2bae74888a5af4500b10b2593f92e1797d
  Stored in directory: /root/.cache/pip/wheels/7a/00/07/c906eaf1b90367fbb81bd840e56bf8859dbd3efe3838c0b4ba
Successfully built ipdb
Installing collected packages: ipdb
Successfully installed ipdb-0.12.2
TensorFlow 2.x selected.
Go to this URL in a browser: https://accounts.google.com/o/oauth2/auth?client_id=947318989803-6bn6qk8qdgf4n4g3pfee6491hc0brc4i.apps.googleusercontent.com&redirect_uri=urn%3aietf%3awg%3aoauth%3a2.0%3aoob&response_type=code&scope=email%20https%3a%2f%2fwww.googleapis.com%2fauth%2fdocs.test%20https%3a%2f%2fwww.googleapis.com%2fauth%2fdrive%20https%3a%2f%2fwww.googleapi

Implementing https://arxiv.org/pdf/1703.06868.pdf
There is an official reference in Torch / Lua @ https://github.com/xunhuang1995/AdaIN-style/


In [5]:
class style_transfer_net(nn.Module):
      def __init__(self):
        super(style_transfer_net, self).__init__()

        ## using VGG as encoder/decoder
        ## TODO : consider using other architectures as suggested in the article
        ##        such as resnet34 etc. which are deep BUT have good convergence due to skip-connection (residuals)
        
        encoder_t = 'VGG19' # TODO in the future this will be external argument
        if encoder_t == 'VGG19':
          encoder = torchvision.models.vgg19(pretrained=True, progress=True)
          targetContentLayer = [12]         # relu4_1
          targetStyleLayers = [1, 4, 7, 12] # relu1_1,relu2_1,relu3_1,relu4_1
          
          encoder_layers = list(encoder.children())
          # style encoders - we need to extract intermediate features from SEVERAL layers
          # so i'm splitting the model and concatenating it in the FWD PASS
          self.encoder_style1 = nn.Sequential(*encoder_layers[:targetStyleLayers[0]])                      # input -> relu1_1
          self.encoder_style2 = nn.Sequential(*encoder_layers[targetStyleLayers[0]:targetStyleLayers[1]])  # relu1_1 -> relu2_1
          self.encoder_style3 = nn.Sequential(*encoder_layers[targetStyleLayers[1]:targetStyleLayers[2]])  # relu2_1 -> relu3_1
          self.encoder_style4 = nn.Sequential(*encoder_layers[targetStyleLayers[2]:targetStyleLayers[3]])  # relu3_1 -> relu4_1
          # content encoder
          self.encoder_content = nn.Sequential(*encoder_layers[18:targetStyleLayers[3]])                    # relu3_1 -> relu4_1

          # Encoder IS NOT trainable - freeze it
          for e in [encoder_style1, encoder_style_2, encoder_style_3, encoder_style_4, encoder_content]:
              for p in e.parameters():
                  p.requires_grad = False

        decoder_t = 'VGG19'
        if decoder_t == 'VGG19':
          self.decoder = nn.Sequential(
            nn.Conv2d(512, 256, kernel_size=3, padding=1),
            nn.ReLU(inplace=True),
            nn.Upsample(scale_factor=2, mode='nearest'),
            nn.Conv2d(256, 256, kernel_size=3, padding=1),
            nn.ReLU(inplace=True),
            nn.Conv2d(256, 256, kernel_size=3, padding=1),
            nn.ReLU(inplace=True),
            nn.Conv2d(256, 256, kernel_size=3, padding=1),
            nn.ReLU(inplace=True),
            nn.Conv2d(256, 128, kernel_size=3, padding=1),
            nn.ReLU(inplace=True),
            nn.Upsample(scale_factor=2, mode='nearest'),
            nn.Conv2d(128, 128, kernel_size=3, padding=1),
            nn.ReLU(inplace=True),
            nn.Conv2d(128, 64, kernel_size=3, padding=1),
            nn.ReLU(inplace=True),
            nn.Upsample(scale_factor=2, mode='nearest'),
            nn.Conv2d(64, 64, kernel_size=3, padding=1),
            nn.ReLU(inplace=True),
            nn.Conv2d(64, 3, kernel_size=3, padding=1),
          )

      # End

      def forward(self, content, style, alpha=1.0):
        assert alpha >= 0
        assert alpha <= 1

        # TODO - add asserts that encoders are NOT trainable !!!

        # FWD pass to extract style features
        #
        #  ENC1 --- ENC2 --- ENC3 --- ENC4 ---
        #        |        |        |        |
        #       feat1    feat2    feat3    feat4
        style_features = [ style ]
        style1 = self.encoder_style1(style)
        style2 = self.encoder_style1(style1)
        style3 = self.encoder_style1(style2)
        style4 = self.encoder_style1(style3)     
        style_features.append(style1)
        style_features.append(style2)
        style_features.append(style3)
        style_features.append(style4)
        style_features = style_features[1:]

        # content features extracted from feat4 (but content is input, not style)
        content_features = self.encoder_style1(content)
        content_features = self.encoder_style2(content_features)
        content_features = self.encoder_style3(content_features)
        content_features = self.encoder_style4(content_features)

      # End

IndentationError: ignored