In [None]:
import time

import cv2
import IPython.display as display
import matplotlib.pyplot as plt
import numpy as np
import PIL.Image
import tensorflow as tf

# Let TensorFlow grow GPU memory on demand instead of reserving it all up
# front. Iterating (rather than indexing [0]) keeps the notebook runnable on
# CPU-only machines, where the device list is empty.
physical_devices = tf.config.experimental.list_physical_devices('GPU')
for gpu in physical_devices:
    try:
        tf.config.experimental.set_memory_growth(gpu, True)
    except RuntimeError as err:
        # Memory growth must be configured before the GPU is initialised; this
        # is raised when the cell is re-run in a live kernel, and is harmless.
        print(err)

# Per-channel ImageNet means used by Keras' caffe-style VGG preprocessing,
# scaled by 1/255 to match load_img's float output in [0, 1]. Keras lists them
# in B, G, R order (103.939, 116.779, 123.68); the images here are RGB (from
# tf.image.decode_image), so they are ordered R, G, B to subtract the right
# mean from each channel.
# NOTE(review): the pretrained VGG16 weights expect BGR input in the 0-255
# range (see tf.keras.applications.vgg16.preprocess_input); feeding RGB in
# [0, 1] only approximates the original preprocessing.
RGB_MEAN = [123.68/255, 116.779/255, 103.939/255]

def show_img(*imgs, rgb=False):
    """Show each image in an OpenCV window, waiting for a key press between them.

    Args:
        *imgs: Images of shape (H, W, C) or (1, H, W, C); NumPy arrays or
            TensorFlow tensors (anything ``np.asarray`` accepts).
        rgb: If True, reverse the channel axis before display. ``cv2.imshow``
            interprets 3-channel data as BGR, whereas load_img returns RGB.
    """
    for img in imgs:
        # cv2.imshow expects a NumPy array, so convert tensors explicitly.
        img = np.asarray(img)
        if img.ndim == 4 and img.shape[0] == 1:
            # Drop the leading batch axis added by load_img / imshape.
            img = np.squeeze(img, axis=0)
        if rgb:
            # Contiguous copy so OpenCV receives a plain, positively-strided buffer.
            img = np.ascontiguousarray(img[..., ::-1])
        cv2.imshow('Image Window', img)
        cv2.waitKey(0)
    cv2.destroyAllWindows()

def rand_img(shape):
    """Return a float32 array of the given shape filled by ``cv2.randn``.

    The noise is requested with mean 128 and standard deviation 20, i.e. it
    is centred in the 0-255 range rather than the [0, 1] range of load_img.
    NOTE(review): OpenCV may expand a scalar mean/stddev to (v, 0, 0, 0), in
    which case only the first channel of a multi-channel shape gets noise —
    verify before relying on this for 3-channel images.
    """
    noise = np.zeros(shape, dtype=np.float32)
    cv2.randn(noise, 128, 20)  # fills `noise` in place
    return noise

def imshape(img):
    """Prepend a batch axis of size 1: (H, W, C) -> (1, H, W, C)."""
    height, width, channels = img.shape[0], img.shape[1], img.shape[2]
    return np.reshape(img, (1, height, width, channels))

def unshape(img):
    """Remove a leading batch axis of size 1: (1, H, W, C) -> (H, W, C).

    Inverse of imshape. Anything that is not a 4-D batch of one is returned
    unchanged, so an unbatched (1, W, C) image (one pixel tall) is not
    mistaken for a batch and squeezed down to 2-D.
    """
    if np.ndim(img) == 4 and img.shape[0] == 1:
        img = np.squeeze(img, axis=0)
    return img


def load_img(path_to_img, max_dim=512):
  """Read an image file and return it as a float32 batch of one image.

  The image is decoded to 3 channels, converted to float32 in [0, 1], and
  resized (up or down, keeping the aspect ratio) so its longer side is
  ``max_dim`` pixels.

  Args:
    path_to_img: Path of an image file TensorFlow can decode.
    max_dim: Target length in pixels of the longer side. Defaults to 512.

  Returns:
    A float32 tensor of shape (1, H, W, 3).
  """
  img = tf.io.read_file(path_to_img)
  # expand_animations=False guarantees a 3-D (H, W, C) result even for
  # animated GIFs, so the shape arithmetic below sees exactly (H, W).
  img = tf.image.decode_image(img, channels=3, expand_animations=False)
  img = tf.image.convert_image_dtype(img, tf.float32)

  shape = tf.cast(tf.shape(img)[:-1], tf.float32)
  # Reduce in TensorFlow rather than iterating the tensor with Python's max().
  long_dim = tf.reduce_max(shape)
  scale = max_dim / long_dim

  new_shape = tf.cast(shape * scale, tf.int32)

  img = tf.image.resize(img, new_shape)
  return img[tf.newaxis, :]

In [None]:
# Importing Images: one style painting and one content photo.
PAINTING_PATH = 'Art/paintings/the_great_wave_off_kanagawa.jpg'
# Other paintings to try:
#   'Art/paintings/the_weeping_woman.png'
#   'Art/paintings/the_starry_night.jpg'
#   'Art/paintings/the_storm_on_the_sea_of_galilee.png'
#   'Art/paintings/wanderer_above_the_sea_of_fog.jpg'
#   tf.keras.utils.get_file('kandinsky5.jpg','https://storage.googleapis.com/download.tensorflow.org/example_images/Vassily_Kandinsky%2C_1913_-_Composition_7.jpg')

IMAGE_PATH = 'Art/images/taj.jpg'
# Other content photos to try:
#   'Art/images/greece.jpg'
#   'Art/images/earth.jpg'
#   'Art/images/uni.jpeg'

painting = load_img(PAINTING_PATH)
image = load_img(IMAGE_PATH)

def imshow(image, title=None):
  """Plot an image with matplotlib, dropping a leading batch axis if present.

  Args:
    image: Tensor/array of shape (H, W, C), or 4-D with a batch of one.
    title: Optional plot title; nothing is set when it is falsy.
  """
  to_plot = tf.squeeze(image, axis=0) if len(image.shape) > 3 else image
  plt.imshow(to_plot)
  if not title:
    return
  plt.title(title)
    

# Preview the style painting.
imshow(image=painting)

In [None]:
# Importing the Oxford VGG model: the ImageNet-pretrained VGG16 convolutional
# base without its fully-connected classifier head (include_top=False).
vgg = tf.keras.applications.VGG16(weights='imagenet', include_top=False)
vgg.summary()

In [None]:
# Zero-centre both images per channel by subtracting RGB_MEAN, which is scaled
# for load_img's [0, 1] output. (tf.keras.applications.vgg16.preprocess_input
# is the library alternative that was tried and commented out.)
ppd_image, ppd_painting = image - RGB_MEAN, painting - RGB_MEAN

# Sanity check on the input scale: expected to be at most ~1.0.
np.max(painting)

In [None]:
vgg.trainable = False

# Style layers follow Gatys et al. (conv1_1 ... conv5_1). NOTE: the paper takes
# content from conv4_2 (block4_conv2); this notebook uses block5_conv2 instead.
content_layers = ['block5_conv2']
style_layers = ['block1_conv1', 'block2_conv1', 'block3_conv1', 'block4_conv1', 'block5_conv1']

# One multi-output model: the style activations in order, then the single
# content activation last (get_style_content splits outputs on that ordering).
selected = style_layers + [content_layers[0]]
model_layers = [vgg.get_layer(name).output for name in selected]

model = tf.keras.Model([vgg.input], model_layers)
model.trainable = False

In [None]:
def gram_matrix(feature_map):
    """Per-example channel Gram matrix of a rank-4 feature map.

    For input of shape (batch, H, W, C) the result has shape (batch, C, C);
    entry [b, c, d] sums feature_map[b, i, j, c] * feature_map[b, i, j, d]
    over all spatial positions (i, j). The sum is not normalised by H * W.
    """
    equation = 'bijc,bijd->bcd'
    return tf.linalg.einsum(equation, feature_map, feature_map)

def get_style_content(outputs):
    """Split multi-layer outputs into (style Gram matrices, content features).

    Every output except the last is treated as a style activation and turned
    into a Gram matrix; the last one is the content activation.

    Returns:
        ``(grams, content)`` where ``grams`` is a list of gram_matrix results
        and ``content`` is the one-element slice ``outputs[-1:]``.
    """
    style_outputs = outputs[:-1]
    grams = list(map(gram_matrix, style_outputs))
    content_outputs = outputs[-1:]
    return grams, content_outputs

In [None]:
# Targets are computed once, on the mean-subtracted inputs: content features
# from the photo, style Gram matrices from the painting.
image_outputs = model(ppd_image)
target_content = get_style_content(image_outputs)[1]

painting_outputs = model(ppd_painting)
target_style = get_style_content(painting_outputs)[0]

In [None]:
# Loss weights. The content term dominates numerically — presumably to offset
# the unnormalised Gram matrices, which make the raw style loss very large
# (TODO confirm when retuning).
content_weight = 1.25e8  # == 1e8 + 2.5e7
style_weight = 1e-2
opt = tf.optimizers.Adam(learning_rate=0.02)

def total_loss(out):
    """Weighted style + content loss for one forward pass of `model`.

    Args:
        out: Model outputs (style activations first, content activation last).

    Returns:
        content_weight * content_loss + style_weight * style_loss, where each
        loss sums, over layers, the mean squared difference between the current
        features and the global targets (target_content / target_style).
    """
    final_style, final_content = get_style_content(out)

    style_terms = [tf.reduce_mean((fs - ts)**2)
                   for fs, ts in zip(final_style, target_style)]
    content_terms = [tf.reduce_mean((fc - tc)**2)
                     for fc, tc in zip(final_content, target_content)]

    style_loss = tf.add_n(style_terms)
    content_loss = tf.add_n(content_terms)
    return content_weight * content_loss + style_weight * style_loss

In [None]:
@tf.function()
def train_step(image):
    """Run one optimiser step on `image` in place, then clip it to [0, 1].

    `image` is a tf.Variable holding pixels in load_img's [0, 1] space. The
    style/content targets were computed on mean-subtracted inputs
    (``x - RGB_MEAN``), so the same subtraction is applied before the forward
    pass; otherwise the image would be compared against targets from a
    different input space, and the loss would be non-zero even for an
    unchanged image.

    Args:
        image: tf.Variable updated in place by `opt` and the clip.
    """
    with tf.GradientTape() as tape:
        final_out = model(image - RGB_MEAN)
        loss = total_loss(final_out)

    grad = tape.gradient(loss, image)
    opt.apply_gradients([(grad, image)])
    # Keep pixels displayable and in the range RGB_MEAN was scaled for.
    image.assign(tf.clip_by_value(image, clip_value_min=0.0, clip_value_max=1.0))

In [None]:
def tensor_to_image(tensor):
  """Convert a float image with values in [0, 1] to a PIL image.

  Args:
    tensor: Array/tensor of shape (H, W, C), or 4-D with a batch of one.

  Returns:
    A PIL image built from uint8 pixels.

  Raises:
    AssertionError: if a 4-D input has a batch size other than 1.
  """
  pixels = np.asarray(tensor, dtype=np.float32) * 255
  # Round (not truncate) and clip before casting: a plain uint8 cast floors
  # every value and can wrap out-of-range values around.
  pixels = np.clip(np.rint(pixels), 0, 255).astype(np.uint8)
  if np.ndim(pixels) > 3:
    assert pixels.shape[0] == 1
    pixels = pixels[0]
  return PIL.Image.fromarray(pixels)

In [None]:
# Smoke test: a single optimisation step starting from the content photo,
# rendered so the effect can be eyeballed before the full run.
im = tf.Variable(initial_value=image)
train_step(im)

tensor_to_image(im)

In [None]:
# Optimise a fresh copy of the content photo, redisplaying it after each epoch.
im = tf.Variable(image)

# perf_counter is monotonic and intended for measuring intervals, unlike
# time.time(), which follows (adjustable) wall-clock time.
start = time.perf_counter()

epochs = 1
steps_per_epoch = 100

step = 0
for epoch in range(epochs):
  for _ in range(steps_per_epoch):
    step += 1
    train_step(im)
    print(".", end='')
  display.clear_output(wait=True)
  display.display(tensor_to_image(im))
  print("Train step: {}".format(step))

end = time.perf_counter()
print("Total time: {:.1f}".format(end-start))