# Neural Style Transfer
An implementation of neural style transfer following the original *Gatys et al.* paper, applied to celestial images in the style of some of my favorite artists.

In [1]:
import warnings
warnings.filterwarnings("ignore")

In [3]:
import os

import numpy as np
import tensorflow as tf
from tensorflow import keras

## Helper Functions
As in *Gatys et al.*, we use the VGG19 convnet and preprocess the images accordingly in order to extract their style and content.

In [4]:
def preprocess_image(image_path):
    """Load an image file and return it as a VGG19-ready batch of one.

    The image is resized on load to the module-level ``img_height`` x
    ``img_width``, converted to an array, given a leading batch axis and
    finally run through the VGG19 ``preprocess_input`` routine.
    """
    loaded = keras.utils.load_img(
        image_path, target_size=(img_height, img_width)
    )
    pixels = keras.utils.img_to_array(loaded)

    # add the batch axis first, then apply the VGG19-specific preprocessing
    batch = np.expand_dims(pixels, axis=0)
    return keras.applications.vgg19.preprocess_input(batch)

def deprocess_image(img):
    """Turn a VGG19-preprocessed array back into a displayable image.

    Args:
        img: Array holding exactly ``img_height * img_width * 3`` values
            (for example a batch of one image). It is not modified.

    Returns:
        A ``uint8`` array of shape ``(img_height, img_width, 3)`` with
        values clipped to the range 0-255.
    """
    # Work on a private copy: reshape returns a view, so without the copy
    # the in-place channel offsets below would leak into the caller's array.
    img = np.array(img, copy=True).reshape((img_height, img_width, 3))

    # Add back the per-channel means that the VGG19 preprocessing
    # subtracted (ImageNet statistics), undoing the zero-centering.
    img[:, :, 0] += 103.939
    img[:, :, 1] += 116.779
    img[:, :, 2] += 123.68

    # Reverse the channel axis (presumably BGR -> RGB, mirroring the
    # channel swap done by the VGG19 preprocessing).
    img = img[:, :, ::-1]
    img = np.clip(img, 0, 255).astype("uint8")
    return img

def content_loss(base_img, combination_img):
    """Return the sum of squared element-wise differences of the inputs."""
    difference = combination_img - base_img
    return tf.reduce_sum(tf.square(difference))

def gram_matrix(x):
    """Return the Gram matrix of a rank-3 tensor.

    The last axis of ``x`` is moved to the front and the two remaining axes
    are flattened, giving one row per entry of that last axis (presumably
    one row per feature channel). The result is that matrix multiplied by
    its own transpose.
    """
    last_axis_first = tf.transpose(x, (2, 0, 1))
    num_rows = tf.shape(last_axis_first)[0]
    flattened = tf.reshape(last_axis_first, (num_rows, -1))
    return tf.matmul(flattened, tf.transpose(flattened))

def style_loss(style_img, combination_img):
    """Return the scaled squared distance between two Gram matrices.

    Both inputs are passed through ``gram_matrix``; the squared differences
    of the two results are summed and divided by a fixed normalizer.
    """
    style_gram = gram_matrix(style_img)
    combination_gram = gram_matrix(combination_img)

    # NOTE(review): the normalizer uses 3 channels and the full image size
    # even though compute_loss calls this on convnet feature maps. The loss
    # weights are tuned against this scaling, so it is kept as-is; confirm
    # before switching to per-layer channel/size values.
    channels = 3
    size = img_height * img_width
    normalizer = 4.0 * (channels ** 2) * (size ** 2)

    squared_distance = tf.reduce_sum(tf.square(style_gram - combination_gram))
    return squared_distance / normalizer

def total_variation_loss(x):
    """Return a penalty on differences between neighbouring positions.

    ``x`` is indexed as a rank-4 tensor. Each position in the top-left
    ``(img_height - 1) x (img_width - 1)`` window is compared with its
    neighbour one step along axis 1 and one step along axis 2; the squared
    differences are added, raised to the power 1.25 and summed.
    """
    rows = img_height - 1
    cols = img_width - 1

    anchor = x[:, :rows, :cols, :]
    along_height = tf.square(anchor - x[:, 1:, :cols, :])
    along_width = tf.square(anchor - x[:, :rows, 1:, :])
    return tf.reduce_sum(tf.pow(along_height + along_width, 1.25))

## Defining the Final Loss

In [5]:
# Relative weights of the three terms summed in compute_loss.
content_weight = 2.5e-9
style_weight = 1e-6
total_variation_weight = 1e-6

# Layer whose activations feed the content term.
content_layer_name = "block5_conv2"

# Layers whose activations feed the style term; compute_loss gives each
# of them an equal share of style_weight.
style_layer_names = [
    "block1_conv1", "block2_conv1", "block3_conv1",
    "block4_conv1", "block5_conv1",
]

def compute_loss(combination_image, base_image, style_reference_image):
    """Return the weighted sum of content, style and total-variation losses.

    The three images are concatenated along axis 0 in the order base, style
    reference, combination, so one call to ``feature_extractor`` produces
    the activations of all of them.
    """
    # positions of the three images along axis 0 of the stacked tensor
    base_idx, style_idx, combination_idx = 0, 1, 2

    stacked = tf.concat(
        [base_image, style_reference_image, combination_image], axis=0
    )
    activations = feature_extractor(stacked)

    # content term: base image vs. combination image on a single layer
    total = tf.zeros(shape=())
    content_activations = activations[content_layer_name]
    total = total + content_weight * content_loss(
        content_activations[base_idx, :, :, :],
        content_activations[combination_idx, :, :, :],
    )

    # style term: every style layer contributes with the same weight
    per_layer_weight = style_weight / len(style_layer_names)
    for name in style_layer_names:
        layer_activations = activations[name]
        layer_style_loss = style_loss(
            layer_activations[style_idx, :, :, :],
            layer_activations[combination_idx, :, :, :],
        )
        total += per_layer_weight * layer_style_loss

    # regularization term computed on the combination image itself
    total += total_variation_weight * total_variation_loss(combination_image)
    return total

## Gradient Descent

In [9]:
@tf.function
def compute_loss_and_grads(combination_image, base_image, style_reference_image):
    """Return the total loss and its gradient w.r.t. ``combination_image``.

    ``compute_loss`` runs under a gradient tape so the loss can be
    differentiated with respect to the combination image afterwards.
    """
    with tf.GradientTape() as recorder:
        loss_value = compute_loss(
            combination_image, base_image, style_reference_image
        )
    gradients = recorder.gradient(loss_value, combination_image)
    return loss_value, gradients

# Plain SGD whose learning rate starts at 100.0 and follows an
# exponential-decay schedule (decay rate 0.8 per 1000 steps).
optimizer = keras.optimizers.SGD(
    learning_rate=keras.optimizers.schedules.ExponentialDecay(
        initial_learning_rate=100.0,
        decay_steps=1000,
        decay_rate=0.8,
    )
)

## Training/Generating

In [14]:
# Names (without extension) of the content image and the style reference.
base_image_name = "pluto"
style_reference_image_name = "signac1"
base_image_path = f"images/original_images/{base_image_name}.jpg"
style_reference_image_path = f"images/reference_images/{style_reference_image_name}.jpg"

# Fix the working height at 800 and scale the width so that the base
# image keeps its original aspect ratio.
img_height = 800
original_width, original_height = keras.utils.load_img(base_image_path).size
img_width = round(original_width * img_height / original_height)

# VGG19 with ImageNet weights and without its top layers; the extractor
# returns the output of every layer, keyed by layer name.
model = keras.applications.vgg19.VGG19(weights="imagenet", include_top=False)
outputs_dict = {layer.name: layer.output for layer in model.layers}
feature_extractor = keras.Model(inputs=model.inputs, outputs=outputs_dict)

In [11]:
base_image = preprocess_image(base_image_path)
style_reference_image = preprocess_image(style_reference_image_path)
# The generated image starts out as the content image and is the only
# tensor the optimizer updates.
combination_image = tf.Variable(preprocess_image(base_image_path))

# Create the output directories up front so that a missing folder cannot
# abort a long run at the first save.
training_dir = "images/training_images"
combination_dir = "images/combination_images"
os.makedirs(training_dir, exist_ok=True)
os.makedirs(combination_dir, exist_ok=True)

iterations = 4000
for i in range(1, iterations + 1):
    loss, grads = compute_loss_and_grads(
        combination_image, base_image, style_reference_image
    )
    optimizer.apply_gradients([(grads, combination_image)])
    if i % 100 == 0:
        print(f"Iteration {i}: loss={loss:.2f}")
    if i % 500 == 0:
        # periodic snapshot of the work in progress
        img = deprocess_image(combination_image.numpy())
        fname = f"{training_dir}/combination_image_at_iteration_{i}.png"
        keras.utils.save_img(fname, img)

# Save the final result from the current state of the variable rather than
# reusing the last snapshot: that snapshot is stale (or does not exist at
# all) whenever `iterations` is not a multiple of 500.
img = deprocess_image(combination_image.numpy())
fname = f"{combination_dir}/{base_image_name}_{style_reference_image_name}.png"
keras.utils.save_img(fname, img)

Iteration 100: loss=16601.26
Iteration 200: loss=11607.11
Iteration 300: loss=9932.91
Iteration 400: loss=9100.65
Iteration 500: loss=8597.45
Iteration 600: loss=8257.75
Iteration 700: loss=8010.11
Iteration 800: loss=7819.84
Iteration 900: loss=7668.95
Iteration 1000: loss=7545.96
Iteration 1100: loss=7443.09
Iteration 1200: loss=7355.37
Iteration 1300: loss=7279.49
Iteration 1400: loss=7213.08
Iteration 1500: loss=7154.48
Iteration 1600: loss=7102.37
Iteration 1700: loss=7055.54
Iteration 1800: loss=7013.18
Iteration 1900: loss=6974.71
Iteration 2000: loss=6939.46
Iteration 2100: loss=6907.08
Iteration 2200: loss=6877.24
Iteration 2300: loss=6849.68
Iteration 2400: loss=6824.11
Iteration 2500: loss=6800.40
Iteration 2600: loss=6778.26
Iteration 2700: loss=6757.51
Iteration 2800: loss=6738.09
Iteration 2900: loss=6719.81
Iteration 3000: loss=6702.54
Iteration 3100: loss=6686.28
Iteration 3200: loss=6670.93
Iteration 3300: loss=6656.38
Iteration 3400: loss=6642.55
Iteration 3500: loss=

OSError: [Errno 36] File name too long: 'images/combination_images/[[[[-102.939 -115.779 -122.68 ]\n   [-102.939 -115.779 -122.68 ]\n   [-102.939 -115.779 -122.68 ]\n   ...\n   [-102.939 -115.779 -122.68 ]\n   [-102.939 -115.779 -122.68 ]\n   [-102.939 -115.779 -122.68 ]]\n\n  [[-102.939 -115.779 -122.68 ]\n   [-102.939 -115.779 -122.68 ]\n   [-102.939 -115.779 -122.68 ]\n   ...\n   [-102.939 -115.779 -122.68 ]\n   [-102.939 -115.779 -122.68 ]\n   [-102.939 -115.779 -122.68 ]]\n\n  [[-102.939 -115.779 -122.68 ]\n   [-102.939 -115.779 -122.68 ]\n   [-102.939 -115.779 -122.68 ]\n   ...\n   [-102.939 -115.779 -122.68 ]\n   [-102.939 -115.779 -122.68 ]\n   [-102.939 -115.779 -122.68 ]]\n\n  ...\n\n  [[-102.939 -115.779 -122.68 ]\n   [-102.939 -115.779 -122.68 ]\n   [-102.939 -115.779 -122.68 ]\n   ...\n   [-102.939 -115.779 -122.68 ]\n   [-102.939 -115.779 -122.68 ]\n   [-102.939 -115.779 -122.68 ]]\n\n  [[-102.939 -115.779 -122.68 ]\n   [-102.939 -115.779 -122.68 ]\n   [-102.939 -115.779 -122.68 ]\n   ...\n   [-102.939 -115.779 -122.68 ]\n   [-102.939 -115.779 -122.68 ]\n   [-102.939 -115.779 -122.68 ]]\n\n  [[-102.939 -115.779 -122.68 ]\n   [-102.939 -115.779 -122.68 ]\n   [-102.939 -115.779 -122.68 ]\n   ...\n   [-102.939 -115.779 -122.68 ]\n   [-102.939 -115.779 -122.68 ]\n   [-102.939 -115.779 -122.68 ]]]]_[[[[  25.060997    19.221        5.3199997]\n   [   7.060997    -5.7789993  -26.68     ]\n   [  53.060997    29.221       14.32     ]\n   ...\n   [ -13.939003    31.221       59.32     ]\n   [  23.060997    44.221       60.32     ]\n   [ -82.939      -46.779      -27.68     ]]\n\n  [[  -9.939003    36.221       25.32     ]\n   [  10.060997    49.221       54.32     ]\n   [  22.060997    36.221       26.32     ]\n   ...\n   [  34.060997    53.221       57.32     ]\n   [  54.060997    77.221       76.32     ]\n   [  25.060997    59.221       73.32     ]]\n\n  [[  28.060997    74.221       93.32     ]\n   [ -17.939003    32.221       75.32     ]\n   [  10.060997    
33.221       32.32     ]\n   ...\n   [ -18.939003    -0.7789993   -8.68     ]\n   [  27.060997    46.221       44.32     ]\n   [  35.060997    53.221       77.32     ]]\n\n  ...\n\n  [[  -0.939003     4.2210007  -15.68     ]\n   [  75.061       48.221       22.32     ]\n   [  74.061       55.221       53.32     ]\n   ...\n   [  31.060997    -9.778999   -10.68     ]\n   [ -81.939       -1.7789993  118.32     ]\n   [ -68.939      -14.778999   105.32     ]]\n\n  [[  33.060997     2.2210007   -5.6800003]\n   [  33.060997    12.221001     6.3199997]\n   [  59.060997    39.221       39.32     ]\n   ...\n   [  12.060997   -15.778999   -31.68     ]\n   [ -77.939       -7.7789993  101.32     ]\n   [ -55.939003     1.2210007  111.32     ]]\n\n  [[  46.060997    29.221       33.32     ]\n   [  62.060997    26.221        4.3199997]\n   [  72.061       29.221       17.32     ]\n   ...\n   [  37.060997    34.221       34.32     ]\n   [-103.939      -64.779        4.3199997]\n   [ -65.939      -51.779       31.32     ]]]].png'