# imports

In [0]:
import tensorflow as tf
import tensorflow_addons as tfa
import numpy as np
import cv2
from matplotlib import pyplot as plt
from tensorflow.keras.applications import VGG19

from glob import glob
import os


# data

In [0]:
# remove directories that are not needed by this notebook
# (presumably to free disk space on the Colab VM before the large download -- confirm)
!rm -rf /usr/local/lib/python2.7
!rm -rf /swift
!rm -rf /tensorflow-1.15.2/python2.7

# download the COCO train2014 archive to /tmp
!wget http://msvocds.blob.core.windows.net/coco2014/train2014.zip -O /tmp/train2014.zip

# unpack quietly into /tmp/data, then delete the archive
!unzip -q /tmp/train2014.zip -d /tmp/data
!rm /tmp/train2014.zip 

# files

In [0]:
# file name of the style image the network is trained on
style_img = 'style1.jpg'

# file name of the content image used to check progress visually
content_img = 'dancing.jpg'

# folders holding the style / content inputs and the stylized results
style_file_path = '.../style/'
content_file_path = '.../content/'
output_file_path = '.../stylized2/'

# training images (the folder the data cell unpacks the archive into)
train_file_path = '/tmp/data/train2014/'

# folder the weights of the transform network are saved to
saved_weights_file_path = '.../saved_weights/'

# VGG19 layers compared by the loss: one content layer and the first
# convolution of each of the five blocks for style
content_layers = ['block4_conv2']
style_layers = ['block%d_conv1' % block for block in range(1, 6)]


# image processing

In [0]:
# load an image from disk and resize it
# training uses a fixed 256 x 256 size (passed in through dims)
# without dims, each side is capped at 512 pixels
def rescale_image(img_path, dims = None):
  """Read an image with OpenCV and resize it.

  Args:
    img_path: path of the image file to load.
    dims: optional sequence laid out as (batch, height, width). When given,
      the image is resized to exactly dims[1] x dims[2]. When omitted, each
      side is capped at 512 pixels independently, so the aspect ratio of an
      image with a side longer than 512 is not preserved.

  Returns:
    The resized image as returned by cv2.resize, indexed (height, width,
    channel). cv2.imread delivers the channels in BGR order.

  Raises:
    FileNotFoundError: if OpenCV cannot read img_path.
  """
  max_side = 512
  img = cv2.imread(img_path)
  if img is None:
    # cv2.imread reports a missing/unreadable file by returning None
    raise FileNotFoundError("could not read image: " + str(img_path))

  height, width = img.shape[:2]
  if dims is not None:
    height, width = dims[1], dims[2]
  else:
    height, width = min(height, max_side), min(width, max_side)

  # cv2.resize takes dsize as (width, height) -- the reverse of ndarray.shape;
  # passing (height, width) swaps the axes of every non-square image
  return cv2.resize(img, (width, height))

In [0]:
# scale pixel values to lie in the range [0, 1]
# this is not strictly necessary, but it speeds up learning
def image_to_tensor(x):
  """Turn an image array into a float32 tensor scaled by 1/255.

  Args:
    x: array of pixel values; a single image or an already batched array.

  Returns:
    A tf.float32 tensor holding x / 255.0. If the scaled array has fewer
    than four dimensions, one leading axis is added first.
  """
  scaled = x / 255.0
  needs_batch_axis = len(scaled.shape) < 4
  batch = np.expand_dims(scaled, axis=0) if needs_batch_axis else scaled
  return tf.convert_to_tensor(batch, dtype=tf.float32)


In [0]:
def deprocess_image(tensor):
  """Convert the first image of a [0, 1] batch to uint8 and display it.

  Args:
    tensor: batched image tensor/array; only element 0 is used.

  Returns:
    The first image as a uint8 array with values clipped to [0, 255].
    The channel order of the input is kept, so the result can be handed
    to cv2.imwrite as the notebook does.
  """
  x = np.array(tensor) * 255.0
  x = x[0]
  x = np.clip(x, 0, 255).astype("uint8")

  # Images enter this notebook through cv2.imread (BGR) and leave through
  # cv2.imwrite (BGR), while matplotlib interprets arrays as RGB. Flip the
  # channel axis for the preview only, otherwise red and blue are swapped.
  plt.imshow(x[..., ::-1])
  plt.show()

  return x

# model

In [0]:
def vgg_layers(layer_names):
  """Build a frozen VGG19 model that outputs the requested layers.

  Args:
    layer_names: names of the VGG19 layers whose activations are wanted.

  Returns:
    A tf.keras.Model mapping the VGG19 input to the outputs of the named
    layers, in the order given.
  """
  backbone = VGG19(include_top=False, weights='imagenet')
  backbone.trainable = False

  selected = []
  for layer_name in layer_names:
    selected.append(backbone.get_layer(layer_name).output)

  return tf.keras.Model([backbone.input], selected)

In [0]:
# instance normalization normalizes each channel of each image separately, over its spatial dimensions
# this makes the network insensitive to the contrast of the individual input image
# for style transfer it works better than batch normalization

# conv -> instanceNorm -> activation
class ConvLayer(tf.keras.layers.Layer):
  """Bias-free convolution, then instance normalization, then optional ReLU.

  Args:
    filters: number of output channels of the convolution.
    kernel: kernel size of the convolution.
    padding: padding mode of the convolution.
    strides: strides of the convolution.
    activate: when False the ReLU is skipped (and never created).
    name: name given to the inner Conv2D layer; it is not passed on to
      this wrapper layer itself.
    weight_initializer: kernel initializer of the convolution.
  """
  def __init__(self, filters, 
               kernel=(3,3), padding='same', 
               strides=(1,1), activate=True, name="", 
               weight_initializer="glorot_uniform"
               ):
    super().__init__()
    self.activate = activate

    # no bias: instance normalization below has its own offset (center=True)
    self.conv = tf.keras.layers.Conv2D(
        filters,
        kernel_size=kernel,
        padding=padding,
        strides=strides,
        name=name,
        trainable=True,
        use_bias=False,
        kernel_initializer=weight_initializer)

    # normalize over the channel axis of NHWC tensors, with learned scale/offset
    self.inst_norm = tfa.layers.InstanceNormalization(
        axis=3,
        center=True,
        scale=True,
        beta_initializer="zeros",
        gamma_initializer="ones",
        trainable=True)

    if self.activate:
      self.relu_layer = tf.keras.layers.Activation('relu')

  def call(self, x):
    """Apply conv -> instance norm (-> ReLU when activate is set)."""
    normalized = self.inst_norm(self.conv(x))
    if not self.activate:
      return normalized
    return self.relu_layer(normalized)

# convT -> instanceNorm -> activation
class ConvTLayer(tf.keras.layers.Layer):
  """Bias-free transposed convolution, instance normalization, optional ReLU.

  Args:
    filters: number of output channels of the transposed convolution.
    kernel: kernel size of the transposed convolution.
    padding: padding mode of the transposed convolution.
    strides: strides of the transposed convolution.
    activate: when False the ReLU is skipped (and never created).
    name: name given to the inner Conv2DTranspose layer; it is not passed
      on to this wrapper layer itself.
    weight_initializer: kernel initializer of the transposed convolution.
  """
  def __init__(self, filters, kernel=(3,3), padding='same',
               strides=(1,1), activate=True, name="",
               weight_initializer="glorot_uniform" 
               ):
    super().__init__()
    self.activate = activate

    # no bias: instance normalization below has its own offset (center=True)
    self.conv_t = tf.keras.layers.Conv2DTranspose(
        filters,
        kernel_size=kernel,
        padding=padding,
        strides=strides,
        name=name,
        use_bias=False,
        kernel_initializer=weight_initializer)

    # normalize over the channel axis of NHWC tensors, with learned scale/offset
    self.inst_norm = tfa.layers.InstanceNormalization(
        axis=3,
        center=True,
        scale=True,
        beta_initializer="zeros",
        gamma_initializer="ones",
        trainable=True)

    if self.activate:
      self.relu_layer = tf.keras.layers.Activation('relu')

  def call(self, x):
    """Apply convT -> instance norm (-> ReLU when activate is set)."""
    normalized = self.inst_norm(self.conv_t(x))
    if not self.activate:
      return normalized
    return self.relu_layer(normalized)

# (conv1 -> conv2) + x
class ResBlock(tf.keras.layers.Layer):
  """Residual block: two ConvLayers whose result is added to the input.

  The first ConvLayer ends in a ReLU, the second one does not.

  Args:
    filters: number of output channels of both convolutions; must match the
      channel count of the input for the addition to work.
    kernel: kernel size of both convolutions.
    padding: padding mode of both convolutions.
    weight_initializer: kernel initializer of both convolutions.
    prefix: prefix (followed by "_") for the names of the inner layers.
  """
  def __init__(self, filters, kernel=(3,3), padding='same', weight_initializer="glorot_uniform", prefix=""):
    super().__init__()
    self.prefix_name = prefix + "_"

    # settings shared by both convolutions of the block
    shared = dict(filters=filters,
                  kernel=kernel,
                  padding=padding,
                  weight_initializer=weight_initializer)

    self.conv1 = ConvLayer(name=self.prefix_name + "conv_1", **shared)
    self.conv2 = ConvLayer(activate=False, name=self.prefix_name + "conv_2", **shared)
    self.add = tf.keras.layers.Add(name=self.prefix_name + "add")

  def call(self, x):
    """Return x + conv2(conv1(x))."""
    residual = self.conv2(self.conv1(x))
    return self.add([x, residual])


In [0]:
class TransformNet(tf.keras.models.Model):
  """Image transformation network that produces the stylized image.

  Layout: three ConvLayers (the last two with stride 2), five residual
  blocks with 128 channels, two stride-2 ConvTLayers, and a final 9x9
  ConvLayer with 3 output channels. The result is passed through tanh and
  mapped to the range [0, 1].

  NOTE(review): the weights are written with save_weights elsewhere in this
  notebook; renaming the layer attributes below presumably changes how a
  saved checkpoint maps onto the model -- confirm before renaming.
  """
  def __init__(self):
    super().__init__()
    # encoder
    self.conv1 = ConvLayer(filters=32, kernel=(9,9), strides=(1,1), padding='same', name="conv_1")
    self.conv2 = ConvLayer(filters=64, kernel=(3,3), strides=(2,2), padding='same', name="conv_2")
    self.conv3 = ConvLayer(filters=128, kernel=(3,3), strides=(2,2), padding='same', name="conv_3")
    # residual trunk
    self.res1 = ResBlock(filters=128, prefix="res_1")
    self.res2 = ResBlock(filters=128, prefix="res_2")
    self.res3 = ResBlock(filters=128, prefix="res_3")
    self.res4 = ResBlock(filters=128, prefix="res_4")
    self.res5 = ResBlock(filters=128, prefix="res_5")
    # decoder
    self.convt1 = ConvTLayer(filters=64, kernel=(3,3), strides=(2,2), padding='same', name="conv_t_1")
    self.convt2 = ConvTLayer(filters=32, kernel=(3,3), strides=(2,2), padding='same', name="conv_t_2")
    self.conv4 = ConvLayer(filters=3, kernel=(9,9), strides=(1,1), padding='same', activate=False, name="conv_4")
    self.tanh = tf.keras.layers.Activation('tanh')

  def call(self, inputs):
    """Stylize a batch of images.

    Args:
      inputs: tensor of form (None, None, None, 3).

    Returns:
      Tensor with values in [0, 1] (tanh output shifted and halved).
    """
    # local tuple only -- the layers stay registered under their attribute names
    pipeline = (self.conv1, self.conv2, self.conv3,
                self.res1, self.res2, self.res3, self.res4, self.res5,
                self.convt1, self.convt2, self.conv4,
                self.tanh)

    x = inputs
    for layer in pipeline:
      x = layer(x)

    # tanh yields [-1, 1]; shift and halve to get [0, 1]
    return (x + 1) / 2

In [0]:
class ModelClass(tf.keras.models.Model):
  """Frozen VGG19 feature extractor with the input preprocessing built in.

  Args:
    all_layers: names of the VGG19 layers whose activations are returned.
    input_is_bgr: set to True (the default) when the incoming tensors carry
      their channels in BGR order, as everything loaded through cv2.imread
      in this notebook does. The channels are then reversed to RGB before
      the VGG19 preprocessing, which expects RGB input. Pass False for
      tensors that are already RGB.
  """
  def __init__(self, all_layers, input_is_bgr=True):
    super(ModelClass, self).__init__()
    self.vgg =  vgg_layers(all_layers)
    self.all_layers = all_layers
    self.input_is_bgr = input_is_bgr
    self.vgg.trainable = False

  def call(self, inputs):
    """Return the activations of the configured layers.

    Args:
      inputs: tensor of form (None, None, None, 3); the values are multiplied
        by 255 before preprocessing, i.e. a [0, 1] range is assumed.

    Returns:
      Whatever the wrapped VGG model returns for the preprocessed input
      (the activations of `all_layers`).
    """
    # preprocessing is a part of the model
    if self.input_is_bgr:
      # vgg19.preprocess_input converts RGB -> BGR itself; feeding it BGR
      # data would hand VGG channel-swapped images, so undo the OpenCV
      # channel order first
      inputs = tf.reverse(inputs, axis=[-1])
    inputs = inputs * 255.0
    preprocessed_input = tf.keras.applications.vgg19.preprocess_input(inputs)
    outputs = self.vgg(preprocessed_input)

    return outputs

# losses

In [0]:
def compute_content_cost(a_C, a_G):
  """Content loss between two batches of feature maps.

  Computes, averaged over the batch, the sum of squared differences of each
  sample divided by 4 * n_H * n_W * n_C. That is the mean squared difference
  over all elements divided by 4, which is what is evaluated here: one
  vectorized op instead of a Python loop over the batch, so it also works
  when the batch size is not known statically.

  Args:
    a_C: activations of the content image(s), shape (m, n_H, n_W, n_C).
    a_G: activations of the generated image(s), same shape as a_C.

  Returns:
    Scalar tensor holding the content cost.
  """
  squared_error = tf.square(tf.subtract(a_C, a_G))
  return tf.reduce_mean(squared_error) / 4

def variational_loss(A):
  """Total variation of a batch of images, normalized by batch * width * height.

  Args:
    A: image batch of shape (m, H, W, C) with statically known dimensions.

  Returns:
    Scalar tensor: the summed total variation of all images divided by
    m * W * H.
  """
  batch, height, width, _ = A.get_shape().as_list()
  per_image_tv = tf.image.total_variation(A)
  return tf.reduce_sum(per_image_tv) / (batch * width * height)

def gram_matrix(A):
  """Return the Gram matrix A * A^T of a 2-D matrix A."""
  A_transposed = tf.transpose(A, [1, 0])
  return tf.matmul(A, A_transposed)

def compute_layer_style_cost(a_S, a_G):
  """Style loss of one layer, averaged over the batch.

  For every sample the feature maps are unrolled to a (n_C, n_H*n_W) matrix,
  the Gram matrices of style and generated activations are compared, and
  the summed squared difference is divided by 4 * (n_C*n_H*n_W)**2.

  Args:
    a_S: style activations, shape (m, n_H, n_W, n_C).
    a_G: generated activations, same shape; its static shape is used.

  Returns:
    Scalar tensor holding the mean per-sample style cost of this layer.
  """
  batch, height, width, channels = a_G.get_shape().as_list()
  scale = 4 * (channels * height * width) ** 2

  def unroll(feature_map):
    # (n_H, n_W, n_C) -> (n_C, n_H*n_W): one row per channel
    return tf.transpose(tf.reshape(feature_map, [height * width, channels]), [1, 0])

  per_sample = []
  for idx in range(batch):
    gram_style = gram_matrix(unroll(a_S[idx, :, :, :]))
    gram_generated = gram_matrix(unroll(a_G[idx, :, :, :]))
    squared_gap = tf.square(tf.subtract(gram_style, gram_generated))
    per_sample.append(tf.reduce_sum(squared_gap) / scale)

  return sum(per_sample) / batch

def total_cost(content_outputs, style_outputs, outputs, transformed, alpha = 100000, beta = 40, gamma = 20):
  """Weighted sum of content, style and total-variation losses.

  Args:
    content_outputs: content-layer activations of the content image.
    style_outputs: list of style-layer activations of the style image.
    outputs: activations of the generated image; element 0 is the content
      layer, elements 1.. are the style layers in the order of style_outputs.
    transformed: the generated image itself (input of the variation loss).
    alpha: weight of the content loss.
    beta: weight of the style loss.
    gamma: weight of the total-variation loss.

  Returns:
    beta * style + alpha * content + gamma * variational.
  """
  content = compute_content_cost(content_outputs, outputs[0])
  variational = variational_loss(transformed)

  # every style layer contributes with the same fixed weight of 0.2
  style = 0
  for idx, generated_layer in enumerate(outputs[1:]):
    layer_cost = compute_layer_style_cost(style_outputs[idx], generated_layer)
    style = style + 0.2*layer_cost

  return beta*style + content*alpha + gamma*variational


# training

In [0]:
# frozen VGG19 feature extractors used by the loss; each ModelClass builds
# its own VGG19 instance
style_extractor = ModelClass(style_layers)
content_extractor = ModelClass(content_layers)
# content layer first, then the style layers: total_cost reads outputs[0]
# as the content activation and outputs[1:] as the style activations
extractor = ModelClass(content_layers + style_layers)

# actual model
model = TransformNet()

In [0]:
def StyleTransfer(style_img, style_path, train_file_path,
                  epochs=2, minibatch_size=10, check=100, save_check=1000):
  """Train the notebook-level transform network `model` on one style image.

  Besides its arguments the function uses these notebook globals: `model`,
  `content_extractor`, `style_extractor`, `extractor`, `content_file_path`,
  `content_img` and `saved_weights_file_path`.

  Args:
    style_img: file name of the style image; without its last four
      characters it names the saved weights.
    style_path: full path of the style image.
    train_file_path: folder searched for "*.jpg" training images.
    epochs: number of passes over the training images.
    minibatch_size: images per training step; images that do not fill a
      whole batch are dropped.
    check: record and report the loss every `check` steps of an epoch.
    save_check: every `save_check` steps of an epoch, stylize and show the
      test image and save the weights.

  Raises:
    ValueError: if the folder does not hold at least one full batch of images.
  """
  optimizer = tf.optimizers.Adam(learning_rate=0.001)

  train_images = glob(os.path.join(train_file_path, "*.jpg"))
  num_images = len(train_images) - (len(train_images) % minibatch_size)
  if num_images == 0:
    # otherwise the loops below do nothing and untrained weights get saved
    raise ValueError("need at least " + str(minibatch_size) +
                     " '*.jpg' images in " + str(train_file_path) +
                     ", found " + str(len(train_images)))

  print(num_images, 'training images')

  # The style batch is identical for every step (all batches are full), so
  # read the image and extract its VGG features once instead of per step.
  style_image = rescale_image(style_path, dims=[1, 256, 256])
  style_tensor = image_to_tensor(np.array([style_image] * minibatch_size))
  style_outputs = style_extractor(style_tensor)

  @tf.function
  def train_step(content_tensor, style_targets):
    with tf.device('/GPU:0'):
      # targets are computed outside the tape: only `model` is trained
      content_outputs = content_extractor(content_tensor)
      with tf.GradientTape() as tape:
        x = model(content_tensor)
        outputs = extractor(x)
        loss = total_cost(content_outputs, style_targets, outputs, x)
      grads = tape.gradient(loss, model.trainable_variables)
      optimizer.apply_gradients(zip(grads, model.trainable_variables))
    return loss

  weights_path = saved_weights_file_path + str(style_img[:-4])
  loss_history = []

  for n in range(epochs):
    # the step counter restarts with every epoch
    for step, i in enumerate(range(0, num_images, minibatch_size)):
      batch_paths = train_images[i:i+minibatch_size]
      content_imgs = np.array([rescale_image(img_path, dims=[1, 256, 256])
                               for img_path in batch_paths])
      content_tensor = image_to_tensor(content_imgs)

      loss_ = train_step(content_tensor, style_outputs)

      if step % check == 0:
        print(i, 'images processed')
        loss_history.append(float(loss_))

      if step % save_check == 0:
        # visual progress check on the test image, then checkpoint
        content_path = content_file_path + content_img
        test_image = image_to_tensor(rescale_image(content_path))
        deprocess_image(model(test_image))
        model.save_weights(weights_path)
        print('checkpoint... model saved')

    print("epochs : " + str(n))

  plt.plot(np.array(loss_history), label = 'loss')
  plt.legend()
  plt.show()

  model.save_weights(weights_path)

# execute

In [47]:
# full path of the style image: style folder + style file name
style_path = style_file_path + style_img
# use of colab gpu
# training
# (trains the notebook-level `model` and saves its weights)
StyleTransfer(style_img, style_path, train_file_path)


Output hidden; open in https://colab.research.google.com to view.

# testing

In [0]:
# trained style
style_img = 'style1.jpg'
# image to be transformed
content_img = 'dancing.jpg'

# fresh network, restored from the weights saved for this style
new_model = TransformNet()
new_model.load_weights(saved_weights_file_path + str(style_img[:-4]))

# load the content image (each side capped at 512) and scale it to [0, 1]
content_path = content_file_path + content_img
test_image = rescale_image(content_path)
test_image = image_to_tensor(test_image)

# stylize, convert back to uint8 and show the result
x = new_model(test_image)
x = deprocess_image(x)

# e.g. 'dancing_style1.jpg'
morphed_img = content_img[:-4] + '_' + style_img
output_path = output_file_path + morphed_img
# cv2.imwrite does not raise on failure (e.g. missing output folder),
# it returns False -- surface that instead of silently losing the result
if not cv2.imwrite(output_path, x):
  raise IOError("could not write stylized image to " + output_path)