In [None]:
# This Python 3 environment comes with many helpful analytics libraries installed
# It is defined by the kaggle/python Docker image: https://github.com/kaggle/docker-python
# For example, here are several helpful packages to load

import numpy as np # linear algebra
import pandas as pd # data processing, CSV file I/O (e.g. pd.read_csv)

# Input data files are available in the read-only "../input/" directory
# For example, running this (by clicking run or pressing Shift+Enter) will list all files under the input directory

import os
# List only the first file found in each input directory so the output stays short
# even when a dataset directory holds thousands of images.
for root, _subdirs, files in os.walk('/kaggle/input'):
    if files:
        print(os.path.join(root, files[0]))

# You can write up to 20GB to the current directory (/kaggle/working/) that gets preserved as output when you create a version using "Save & Run All" 
# You can also write temporary files to /kaggle/temp/, but they won't be saved outside of the current session

Neural Style Transfer with AdaIN (Adaptive Instance Normalization) for a Single Content/Style Image Pair

In [None]:
import tensorflow as tf
from tensorflow.keras import layers
from tensorflow.keras.models import Model, Sequential
from tensorflow.keras.optimizers import Adam
import numpy as np
from tensorflow.keras.applications import VGG19
from tensorflow.keras.applications.vgg19 import preprocess_input
from tensorflow.keras.preprocessing.image import load_img, img_to_array, array_to_img
import matplotlib.pyplot as plt
import cv2
import gzip
import zipfile
import pandas as pd
import skimage.io as sio
from PIL import Image
from io import StringIO, BytesIO

In [None]:
# (height, width) every content/style image is resized to when loaded.
IMG_SIZE = (256,) * 2

In [None]:
# Stand-alone VGG19 convolutional base (no classifier head) at the notebook's image size.
# NOTE(review): `vgg19` is not referenced by any later cell; AdaInStyleTransfer_S
# builds its own VGG19 instance, so this cell could likely be dropped.
vgg_input_shape = (IMG_SIZE[0], IMG_SIZE[1], 3)
vgg19 = VGG19(input_shape=vgg_input_shape, include_top=False)

In [None]:
class AdaInLayer(layers.Layer):
  """Adaptive Instance Normalization (AdaIN) layer.

  Re-normalizes the content feature maps so that, per sample and per channel,
  their mean and standard deviation over axes 1 and 2 match those of the style
  feature maps:

      out = (c - mean(c)) / (std(c) + epsilon) * std(s) + mean(s)

  Args:
    epsilon: added to the content standard deviation to avoid division by zero
      for constant channels. Defaults to 1e-6 (the previously hard-coded value).
    **kwargs: forwarded to `layers.Layer` (e.g. `name`, `dtype`); forwarding is
      also required for Keras to rebuild the layer from `get_config()`.
  """

  def __init__(self, epsilon=1e-6, **kwargs):
    super(AdaInLayer, self).__init__(**kwargs)
    self.epsilon = epsilon

  def call(self, inputs):
    """Apply AdaIN to a `[content_maps, style_maps]` pair.

    In this notebook both inputs are (batch, 32, 32, 512) VGG feature maps;
    statistics are taken independently for each input, so only the batch and
    channel sizes need to agree.

    Returns:
      A tensor with the same shape as `content_maps`.
    """
    cmaps = inputs[0]
    smaps = inputs[1]

    # keepdims=True keeps the reduced axes as size 1, so the statistics broadcast
    # back over the feature maps (same result as the former pairs of expand_dims).
    mean_c = tf.math.reduce_mean(cmaps, axis=[1, 2], keepdims=True)
    std_c = tf.math.reduce_std(cmaps, axis=[1, 2], keepdims=True) + self.epsilon
    mean_s = tf.math.reduce_mean(smaps, axis=[1, 2], keepdims=True)
    std_s = tf.math.reduce_std(smaps, axis=[1, 2], keepdims=True)

    norm_c = tf.divide(tf.math.subtract(cmaps, mean_c), std_c)
    return tf.multiply(norm_c, std_s) + mean_s

  def get_config(self):
    """Serialize `epsilon` alongside the base Layer config."""
    config = super(AdaInLayer, self).get_config()
    config.update({"epsilon": self.epsilon})
    return config
  

In [None]:
def deprocess_image(x, img_size=None):
    """Convert a VGG19-preprocessed image back into a displayable RGB uint8 array.

    Reverses the transform applied by `vgg19.preprocess_input` in this notebook:
    the per-channel mean pixel (BGR order) is added back, channels are flipped
    BGR -> RGB, and values are clipped to [0, 255].

    Args:
        x: array-like with height * width * 3 elements in mean-centred BGR form
           (e.g. one image of the decoder output). It is copied first, so the
           caller's array is never modified.
        img_size: (height, width) to reshape to; defaults to the global IMG_SIZE.

    Returns:
        np.ndarray of shape (height, width, 3), dtype uint8, RGB channel order.
    """
    if img_size is None:
        img_size = IMG_SIZE
    # Private copy: adding the means in place on a reshape view would otherwise
    # corrupt the caller's buffer (e.g. an array returned by Tensor.numpy()).
    img = np.array(x, copy=True)
    if not np.issubdtype(img.dtype, np.floating):
        # Integer arrays cannot absorb the fractional means with in-place `+=`.
        img = img.astype(np.float32)
    img = img.reshape((img_size[0], img_size[1], 3))
    # Remove zero-center by mean pixel (BGR channel order).
    img[:, :, 0] += 103.939
    img[:, :, 1] += 116.779
    img[:, :, 2] += 123.68
    # 'BGR'->'RGB'
    img = img[:, :, ::-1]
    return np.clip(img, 0, 255).astype("uint8")

In [None]:
# Optimized for Single Image Stylization

class AdaInStyleTransfer_S():
  """AdaIN style transfer fitted to a single content/style image pair.

  A frozen VGG19 truncated at `block4_conv1` encodes the content and style
  images. An `AdaInLayer` transfers the style statistics onto the content
  features, and a trainable convolutional decoder maps the blended features back
  to an image. The decoder is trained on two terms:
    * content loss: MSE between the AdaIN target features and the re-encoded
      decoder output;
    * style loss (weighted): MSE between per-channel means/stds of the style
      image and of the output on block1..block4 `conv1` activations.

  Encoder inputs are `preprocess_input`-transformed images. The decoder output is
  compared against features of such inputs, so it is interpreted in the same
  space (the notebook converts it back with `deprocess_image`).
  """

  def __init__(self, decoder_inp_size,
               decoder_conv_filters,
               decoder_num_conv_layers,
               style_loss_weight,
               style_weight,
               total_variation_weight,
               img_size,
               batch_size,
               content_feat_shape,
               style_feat_shape,
               single_image=True):
    """Build the encoder, style-loss extractor, decoder and training graph.

    Args:
      decoder_inp_size: shape of the decoder input feature map, e.g. (32, 32, 512).
      decoder_conv_filters: filters per decoder stage. The last entry is the
        number of output channels.
      decoder_num_conv_layers: conv layers per decoder stage; the last entry is
        not used. In the decoder that `__init__` builds, a stage with exactly one
        conv layer is not upsampled.
      style_loss_weight: multiplier on the style loss in the total loss.
      style_weight: interpolation between the content features (0) and the
        AdaIN output (1); presumably meant to lie in [0, 1] (not enforced).
      total_variation_weight: stored only; the TV term is not part of the loss.
      img_size: (height, width, channels) of encoder input images.
      batch_size: stored only; not used by the training loop.
      content_feat_shape, style_feat_shape: per-sample encoder output shapes.
      single_image: accepted for API compatibility; unused.
    """
    self.img_size = img_size
    self.decoder_inp_size = decoder_inp_size
    self.decoder_conv_filters = decoder_conv_filters
    self.decoder_num_conv_layers = decoder_num_conv_layers
    self.style_loss_weight = style_loss_weight
    self.style_weight = style_weight
    self.total_variation_weight = total_variation_weight
    self.content_feat_shape = content_feat_shape
    self.style_feat_shape = style_feat_shape
    self.batch_size = batch_size
    # nst_model outputs ([content_output, out_img]) recorded during training.
    self.op_images = []
    self.optimizer = Adam(learning_rate=0.0001)
    self._create_base_model()
    self._create_model_for_style_loss()
    self._create_encoder()
    self._create_decoder_v2()
    self.create_core()

  def _create_base_model(self):
    """Create the shared VGG19 convolutional base and freeze all of its layers."""
    self.base_model = VGG19(include_top=False)
    for layer in self.base_model.layers:
      layer.trainable = False

  def _create_model_for_style_loss(self):
    """Build a frozen model returning block1..block4 `conv1` activations."""
    input_img = layers.Input(shape=self.img_size, name='style_latent_input')
    style_layer_names = ['block1_conv1', 'block2_conv1', 'block3_conv1', 'block4_conv1']
    style_latent_base = Model(
        inputs=self.base_model.input,
        outputs=[self.base_model.get_layer(name).output for name in style_layer_names])
    style_latent = style_latent_base(input_img)
    self.style_loss_model = Model(inputs=input_img, outputs=style_latent)
    for layer in self.style_loss_model.layers:
      layer.trainable = False

  def _create_encoder(self):
    """Build the frozen encoder: VGG19 up to `block4_conv1`."""
    input_img = layers.Input(shape=self.img_size, name='encoder_input')
    model = Model(inputs=self.base_model.input, outputs=self.base_model.get_layer('block4_conv1').output)
    fmaps = model(input_img)
    self.encoder = Model(inputs=input_img, outputs=fmaps)
    for layer in self.encoder.layers:
      layer.trainable = False

  def _decoder_stack(self, x, upsample_single_conv_stages):
    """Apply the decoder stages to tensor `x` and return the result.

    One stage per entry of `decoder_conv_filters` except the last. Stage i:
      * optionally doubles height/width with nearest-neighbour upsampling
        (always when `upsample_single_conv_stages`, otherwise only when the
        stage has more than one conv layer);
      * applies `decoder_num_conv_layers[i]` 3x3 'valid' convolutions, each
        followed by 1-pixel reflect padding, so the spatial size is unchanged.
        The stage's last conv switches to `decoder_conv_filters[i + 1]` filters;
      * adds a ReLU after each conv, except in the final stage (linear output).
    """
    n_stages = len(self.decoder_conv_filters) - 1
    paddings = tf.constant([[0, 0], [1, 1], [1, 1], [0, 0]])
    for i in range(n_stages):
      n_convs = self.decoder_num_conv_layers[i]
      if upsample_single_conv_stages or n_convs != 1:
        x = layers.UpSampling2D(interpolation='nearest')(x)
      for j in range(n_convs):
        if j == n_convs - 1:
          filters = self.decoder_conv_filters[i + 1]
        else:
          filters = self.decoder_conv_filters[i]
        x = layers.Conv2D(filters=filters, kernel_size=(3, 3), padding='valid')(x)
        x = tf.pad(x, paddings, mode='REFLECT')
        if i != n_stages - 1:
          x = layers.ReLU()(x)
    return x

  def _create_decoder(self):
    """Legacy decoder builder that upsamples before every stage.

    NOTE: this calls `self._deprocess_decoder_output`, which this class does not
    define, so it raises AttributeError if invoked. `__init__` uses
    `_create_decoder_v2` instead.
    """
    input_tensor = layers.Input(shape=self.decoder_inp_size, name='decoder_input')
    x = self._decoder_stack(input_tensor, upsample_single_conv_stages=True)
    deprocess_output = self._deprocess_decoder_output(x)
    self.decoder = Model(inputs=input_tensor, outputs=deprocess_output)

  def _create_decoder_v2(self):
    """Build `self.decoder`; stages with a single conv layer are not upsampled."""
    input_tensor = layers.Input(shape=self.decoder_inp_size, name='decoder_input')
    x = self._decoder_stack(input_tensor, upsample_single_conv_stages=False)
    self.decoder = Model(inputs=input_tensor, outputs=x)

  def create_core(self):
    """Build `self.nst_model`: [content_feats, style_feats] -> [content_output, out_img].

    `content_output` stacks the AdaIN target features and the re-encoded decoder
    output along a new leading axis, so `content_output[0]` / `content_output[1]`
    are the two arguments of `_custom_content_loss`.
    """
    c_input = layers.Input(shape=(self.content_feat_shape), dtype=tf.float32, name='content_image')
    s_input = layers.Input(shape=(self.style_feat_shape), dtype=tf.float32, name='style_image')

    t = AdaInLayer()([c_input, s_input])
    # Blend content features with the stylized ones (style strength control).
    weighted_t = (1 - self.style_weight) * c_input + self.style_weight * t
    out_img = self.decoder(weighted_t)
    out_t = self.encoder(out_img)

    content_output = tf.concat([tf.expand_dims(weighted_t, axis=0), tf.expand_dims(out_t, axis=0)], axis=0)
    self.nst_model = Model(inputs=[c_input, s_input], outputs=[content_output, out_img])

  def calculate_loss(self, model, X, y, training=True):
    """Return the per-sample total loss: content + style_loss_weight * style.

    Args:
      model: the model to evaluate (normally `self.nst_model`).
      X: [content_feats, style_feats] model inputs.
      y: style-layer activations of the style image (from `style_loss_model`).
      training: unused; kept for API compatibility.
    """
    outputs = model(X)
    y1 = outputs[0]
    y2 = outputs[1]

    content_loss = self._custom_content_loss(y1[0], y1[1])
    style_loss = self.style_loss(y, y2)
    # NOTE: total_variation_weight is not applied; the TV term is disabled.
    return content_loss + self.style_loss_weight * style_loss

  @tf.function
  def _get_grads(self, model, X, y):
    """Compute the loss and its gradients w.r.t. `model.trainable_variables`."""
    with tf.GradientTape() as tape:
      loss = self.calculate_loss(model, X, y, True)
    return loss, tape.gradient(loss, model.trainable_variables)

  def custom_training_for_single_image(self, content_image, style_image, epochs=100, log_every=20):
    """Fit the decoder to one content/style pair.

    Args:
      content_image, style_image: image batches in RGB, values 0..255, with the
        encoder's input size. They are not modified.
      epochs: number of optimisation steps.
      log_every: every `log_every` steps, print the loss and append the
        current `nst_model` outputs to `self.op_images` (0/None disables).
    """
    # preprocess_input may modify float numpy input in place. Work on float32
    # copies, and preprocess the style image once instead of twice.
    content_pre = preprocess_input(np.array(content_image, dtype=np.float32))
    style_pre = preprocess_input(np.array(style_image, dtype=np.float32))

    content_feats = self.encoder(content_pre)
    style_feats = self.encoder(style_pre)
    style_feats_for_loss = self.style_loss_model(style_pre)
    for epoch in range(1, epochs + 1):
      loss_value, grads = self._get_grads(self.nst_model, [content_feats, style_feats], style_feats_for_loss)
      self.optimizer.apply_gradients(zip(grads, self.nst_model.trainable_variables))
      # Log on epochs log_every, 2*log_every, ... (the old `(i+1) % 20` was off by one).
      if log_every and epoch % log_every == 0:
        print("Epoch: {}, Total Loss: {:.3f}".format(epoch, float(tf.reduce_mean(loss_value))))
        self.op_images.append(self.nst_model([content_feats, style_feats]))

  def _custom_content_loss(self, y, y_pred):
    """Per-sample MSE between two feature-map batches, each flattened per sample."""
    sh = tf.shape(y)
    content_fmaps = tf.reshape(y, (sh[0], sh[1] * sh[2] * sh[3]))
    encoded_output = tf.reshape(y_pred, (sh[0], sh[1] * sh[2] * sh[3]))
    return tf.keras.losses.MSE(content_fmaps, encoded_output)

  def style_loss(self, y_true, y_pred):
    """Sum over style layers of MSE(mean_s, mean_o) + MSE(std_s, std_o).

    Args:
      y_true: list of style-image activations, already encoded by `style_loss_model`.
      y_pred: generated image batch; encoded here with `style_loss_model`.
    """
    output_style_encoded = self.style_loss_model(y_pred)
    loss = 0
    for style_fmap, out_fmap in zip(y_true, output_style_encoded):
      mu_s = tf.math.reduce_mean(style_fmap, axis=[1, 2])
      mu_o = tf.math.reduce_mean(out_fmap, axis=[1, 2])
      std_s = tf.math.reduce_std(style_fmap, axis=[1, 2])
      # Fix: this used reduce_mean, so the style std was compared with the output mean.
      std_o = tf.math.reduce_std(out_fmap, axis=[1, 2])
      loss = loss + tf.keras.losses.MSE(mu_s, mu_o) + tf.keras.losses.MSE(std_s, std_o)
    return loss

  def content_loss(self, y_true, y_pred):
    """Keras-style content loss; `y_true` is ignored.

    `y_pred` is the pair (AdaIN target features, re-encoded output features),
    i.e. the first output of `nst_model`.
    """
    return self._custom_content_loss(y_pred[0], y_pred[1])

  def variation_loss(self, y_true, y_pred):
    """Mean total variation of `y_pred` (currently not used in training)."""
    return tf.reduce_mean(tf.image.total_variation(y_pred))

    


In [None]:
content_fnames = ["photo_jpg/00dcf0f1e3.jpg", "photo_jpg/047da870f6.jpg"]
style_fnames = ["monet_jpg/6043aadea0.jpg", "monet_jpg/3d13fe022e.jpg"]
path_prefix = "/kaggle/input/gan-getting-started/"
# Pre-allocated uint8 batches of shape (N, H, W, 3); each image is resized to IMG_SIZE on load.
x_cont = np.zeros(shape=(len(content_fnames), IMG_SIZE[0], IMG_SIZE[1], 3), dtype=np.uint8)
x_styl = np.zeros(shape=(len(style_fnames), IMG_SIZE[0], IMG_SIZE[1], 3), dtype=np.uint8)
for k, (cname, sname) in enumerate(zip(content_fnames, style_fnames)):
  x_cont[k, :, :, :] = img_to_array(load_img(path_prefix + cname, target_size=IMG_SIZE))
  x_styl[k, :, :, :] = img_to_array(load_img(path_prefix + sname, target_size=IMG_SIZE))

In [None]:
# Show the first content image and the first style image, each in its own figure.
preview_panels = [(x_cont[0], "Content Image"), (x_styl[0], "Style Image")]
for idx, (img, caption) in enumerate(preview_panels):
    if idx > 0:
        plt.figure()
    plt.imshow(img.astype(np.uint8))
    plt.title(caption)

In [None]:
# Decoder maps 32x32x512 block4_conv1 features back to a 256x256x3 image.
adain_config = dict(
    decoder_inp_size=(32, 32, 512),
    decoder_conv_filters=[256, 256, 128, 64, 3],
    decoder_num_conv_layers=[1, 4, 2, 2, 1],
    style_loss_weight=2,
    style_weight=0.8,
    total_variation_weight=0.00001,
    img_size=(IMG_SIZE[0], IMG_SIZE[1], 3),
    batch_size=1,
    content_feat_shape=(32, 32, 512),
    style_feat_shape=(32, 32, 512),
)
newObj_s = AdaInStyleTransfer_S(**adain_config)

In [None]:
# Fit the decoder on the first content/style pair (slices keep the batch axis).
newObj_s.custom_training_for_single_image(x_cont[:1], x_styl[:1], epochs=1000)

In [None]:
n_snapshots = len(newObj_s.op_images)
print(n_snapshots)
# 1000 epochs with one snapshot every 20 epochs gives 50 snapshots, so the old
# hard-coded index 54 was out of range. Show the most recent snapshot instead.
assert n_snapshots > 0, "training recorded no snapshots; increase the number of epochs"
# Each snapshot is the nst_model output [content_output, out_img]; take image 0 of
# out_img and copy it so deprocessing cannot alter the stored tensor.
stylized = np.array(newObj_s.op_images[-1][1][0])
plt.imshow(deprocess_image(stylized))
plt.title("Stylized Image (latest snapshot)");