# imports

In [None]:
import tensorflow as tf
import numpy as np
import cv2
from matplotlib import pyplot as plt
from tensorflow.keras.applications import VGG19


# files

In [None]:
# Bare file names of the style and content images; the folder prefixes are
# prepended further down, after the output name has been derived from them.
style_path = 'style4.jpg'
content_path = 'eye.jpg'

# Output file name: the content name minus its last 4 characters (the
# extension), an underscore, then the full style file name.
morphed_img = '_'.join([content_path[:-4], style_path])

# Locations of the input images and of the folder for stylized results.
style_path = ''.join(['.../style/', style_path])
content_path = ''.join(['.../content/', content_path])
output_path = '.../stylized/'

# VGG19 layers whose activations are compared for content and for style.
content_layers = [
    'block5_conv2',
]
style_layers = [
    'block1_conv1',
    'block2_conv1',
    'block3_conv1',
    'block4_conv1',
    'block5_conv1',
]


# image processing

In [None]:
# max dims 512 x 512
# rescale style image to content dimensions
def rescale_image(img_path, dims = None):
  """Load an image from disk and resize it.

  Args:
    img_path: path of the image file, read with ``cv2.imread``.
    dims: optional shape-like sequence ``(height, width, ...)``, e.g. the
      ``.shape`` of another image. When given, the result is resized to
      exactly that height and width. When omitted, each side is capped at
      512 pixels independently, so the aspect ratio changes if only one
      side exceeds the cap.

  Returns:
    The resized image as returned by ``cv2.resize``.

  Raises:
    FileNotFoundError: if ``cv2.imread`` could not read ``img_path``.
  """
  max_side = 512
  img = cv2.imread(img_path)
  if img is None:
    # cv2.imread reports failure by returning None instead of raising
    raise FileNotFoundError("could not read image: " + str(img_path))

  if dims is not None:
    height, width = dims[0], dims[1]
  else:
    height, width = img.shape[:2]
    height, width = min(height, max_side), min(width, max_side)

  # cv2.resize expects dsize as (width, height), the reverse of ndarray.shape
  return cv2.resize(img, (width, height))

In [None]:
# Scale pixel values to lie in the range [0, 1].
# Although not strictly necessary, this speeds up learning.
def image_to_tensor(x):
  """Turn an image array into a float32 tensor with a leading batch axis.

  The values are divided by 255.0 and a new axis of length 1 is inserted at
  position 0 before the conversion to a ``tf.float32`` tensor.
  """
  batched = np.expand_dims(x / 255.0, axis=0)
  return tf.convert_to_tensor(batched, dtype=tf.float32)


In [None]:
def clip(image):
  """Clamp every value of ``image`` to the range [0, 1].

  Images are scaled to [0, 1] by ``image_to_tensor`` and multiplied back by
  255 inside ``ModelClass.call``, so [0, 1] is the valid range for the
  optimised image. An upper bound of 255 would never constrain it.
  """
  return tf.clip_by_value(image, clip_value_min=0.0, clip_value_max=1.0)

In [None]:
def deprocess_image(tensor):
  """Convert a batched image tensor to a uint8 array, display and return it.

  The tensor is scaled by 255, its first batch entry is taken, and the values
  are clipped to [0, 255] and cast to ``uint8``. The result is drawn with
  ``plt.imshow`` as a side effect.
  """
  scaled = np.array(tensor) * 255.0
  pixels = np.clip(scaled[0], 0, 255).astype("uint8")
  plt.imshow(pixels)

  return pixels

# model

In [None]:
def vgg_layers(layer_names):
  """Build a frozen VGG19 model that returns the named layers' outputs.

  Args:
    layer_names: names of the VGG19 layers whose outputs are wanted.

  Returns:
    A ``tf.keras.Model`` mapping the VGG19 input to a list with one output
    per entry of ``layer_names``, in the same order.
  """
  # ImageNet weights, without the classification head
  backbone = VGG19(include_top=False, weights='imagenet')
  backbone.trainable = False

  feature_maps = []
  for layer_name in layer_names:
    feature_maps.append(backbone.get_layer(layer_name).output)

  return tf.keras.Model([backbone.input], feature_maps)

In [None]:
class ModelClass(tf.keras.models.Model):
  """Keras model wrapping a frozen VGG19 feature extractor.

  VGG19 preprocessing is part of the forward pass, so callers feed images
  scaled to [0, 1].
  """

  def __init__(self, all_layers):
    """Create the extractor for the VGG19 layers named in ``all_layers``."""
    super().__init__()
    self.vgg = vgg_layers(all_layers)
    self.all_layers = all_layers
    self.vgg.trainable = False

  def call(self, inputs):
    """Return the selected VGG19 activations for ``inputs``.

    ``inputs`` is expected in the form of a (1, None, None, 3) tensor. It is
    multiplied by 255 and passed through the VGG19 ``preprocess_input``
    before the extractor is applied.
    """
    rescaled = inputs * 255.0
    vgg_ready = tf.keras.applications.vgg19.preprocess_input(rescaled)
    return self.vgg(vgg_ready)

# losses

In [None]:
def compute_content_cost(a_C, a_G):
  """Content cost between content activations a_C and generated activations a_G.

  Computes the sum of squared differences divided by ``4 * n_C * n_H * n_W``,
  where the sizes are read from the static 4-D shape of ``a_G``.
  """
  _, n_H, n_W, n_C = a_G.get_shape().as_list()
  squared_error = tf.square(tf.subtract(a_C, a_G))
  normaliser = 4 * n_C * n_H * n_W

  return tf.reduce_sum(squared_error) / normaliser

def gram_matrix(A):
  """Return the product of the 2-D tensor ``A`` with its own transpose."""
  A_transposed = tf.transpose(A, perm=[1, 0])
  return tf.matmul(A, A_transposed)

def compute_layer_style_cost(a_S, a_G):
  """Style cost of one layer from style activations a_S and generated a_G.

  Both activations are flattened to ``(n_H*n_W, n_C)`` using the static shape
  of ``a_G`` (so ``a_S`` must hold the same number of elements), transposed to
  channels-first, and turned into Gram matrices. The cost is the sum of
  squared Gram differences divided by ``4 * (n_C * n_H * n_W) ** 2``.
  """
  _, n_H, n_W, n_C = a_G.get_shape().as_list()
  flat_shape = [n_H*n_W, n_C]

  def channel_gram(activations):
    # rows become channels, columns become spatial positions
    unrolled = tf.transpose(tf.reshape(activations, flat_shape), [1,0])
    return gram_matrix(unrolled)

  gram_style = channel_gram(a_S)
  gram_generated = channel_gram(a_G)
  gram_gap = tf.subtract(gram_style, gram_generated)

  return tf.reduce_sum(tf.square(gram_gap))/(4*(n_C*n_H*n_W)**2)

def total_cost(content_outputs, style_outputs, outputs, alpha = 10000, beta = 3):
  """Combine the content and style costs into the total cost.

  Args:
    content_outputs: target content activations.
    style_outputs: target style activations, one per style layer.
    outputs: activations of the generated image; entry 0 is compared against
      the content target and entries 1.. against the style targets in order.
    alpha: weight of the content cost.
    beta: weight of the style cost.

  Returns:
    The tuple ``(content, style, beta*style + content*alpha)``.
  """
  generated_content, generated_style = outputs[0], outputs[1:]
  content = compute_content_cost(content_outputs, generated_content)

  # every style layer contributes with a fixed weight of 0.2
  style = 0
  for idx, a_G in enumerate(generated_style):
    style = style + 0.2*compute_layer_style_cost(style_outputs[idx], a_G)

  total = beta*style + content*alpha
  return content, style, total


# training

In [None]:
def StyleTransfer(style_path, content_path, epochs = 10, steps_per_epoch = 100):
  """Optimise the content image so that it takes on the style image's style.

  Args:
    style_path: path of the style image.
    content_path: path of the content image; the optimisation starts from it.
    epochs: number of outer iterations; progress is printed after each one.
    steps_per_epoch: number of optimiser steps per epoch.

  Returns:
    The stylized image as produced by ``deprocess_image``.

  Side effects:
    Prints progress, plots the content, style and total loss curves, and
    displays the final image (through ``deprocess_image``).
  """
  content_image = rescale_image(content_path)
  style_image = rescale_image(style_path, dims=content_image.shape)

  content_image = image_to_tensor(content_image)
  style_image = image_to_tensor(style_image)

  # Targets are computed once; they do not change during optimisation.
  style_extractor = ModelClass(style_layers)
  style_outputs = style_extractor(style_image)

  content_extractor = ModelClass(content_layers)
  content_outputs = content_extractor(content_image)

  # actual model: content layers first, so output 0 is the content activation
  model = ModelClass(content_layers + style_layers)
  model.trainable = False

  # The image itself is the only variable being optimised.
  img = tf.Variable(content_image)
  optimizer = tf.optimizers.Adam(learning_rate=0.02, beta_1=0.99, epsilon=1e-1)

  def train_step(image):
    """Apply one optimiser step to ``image`` and return the three losses."""
    with tf.GradientTape() as tape:
      outputs = model(image)
      content_loss, style_loss, loss = total_cost(content_outputs, style_outputs, outputs)
    grads = tape.gradient(loss, image)
    optimizer.apply_gradients([(grads, image)])
    image.assign(clip(image))
    return content_loss, style_loss, loss

  total_steps = epochs*steps_per_epoch
  # One loss curve per label, in plotting order.
  history = {
      'content_loss': np.zeros([total_steps]),
      'style_loss': np.zeros([total_steps]),
      'loss': np.zeros([total_steps]),
  }
  step = 0

  for _ in range(epochs):
    for _ in range(steps_per_epoch):
      (history['content_loss'][step],
       history['style_loss'][step],
       history['loss'][step]) = train_step(img)
      step += 1
      print(".", end='')
    print("Train step: " + str(step))

  # Each curve goes in its own figure.
  for label, values in history.items():
    plt.plot(values, label = label)
    plt.legend()
    plt.show()

  final_img = deprocess_image(img.value())
  return final_img


# execute

In [None]:
# run on colab gpu
# Pin the whole run to the first GPU.
with tf.device('/gpu:0'):
  image = StyleTransfer(style_path, content_path)
  # NOTE(review): output_path is extended in place, so running this cell a
  # second time appends the file name again.
  output_path += morphed_img
  cv2.imwrite(output_path, image)