In [None]:
import numpy as np
import tensorflow as tf
import tensorflow_addons as tfa
from tensorflow import keras
from tensorflow.keras.applications import vgg19
from tensorflow.keras.models import load_model,Model
from PIL import Image
import time
import matplotlib.pyplot as plt
import matplotlib
import os
from pathlib import Path

# Notebook-wide plotting defaults: large square figures with the axes grid turned off.
matplotlib.rcParams.update({
    'figure.figsize': (12, 12),
    'axes.grid': False,
})

In [None]:
def load_image(image_path, dim=None, resize=False):
    """Load an image file as an RGB NumPy array.

    Args:
        image_path: Path (or file object) accepted by ``Image.open``.
        dim: Optional size passed to ``Image.resize`` or ``Image.thumbnail``.
            When falsy the image keeps its original size.
        resize: If True the image is resized to ``dim``; otherwise
            ``Image.thumbnail`` is applied with ``dim``.

    Returns:
        ``np.ndarray`` built from the image after conversion to RGB.
    """
    # The context manager closes the underlying file handle once the pixel
    # data has been materialised by resize()/convert().
    with Image.open(image_path) as img:
        if dim:
            if resize:
                img = img.resize(dim)
            else:
                img.thumbnail(dim)
        # convert() returns a new image, so it stays valid after the file closes.
        rgb = img.convert("RGB")
    return np.array(rgb)

def array_to_img(array):
    """Turn an array of pixel values into a PIL image.

    The input is cast to ``uint8``. Inputs with more than three dimensions
    must have a leading dimension of size 1, which is stripped before the
    conversion.
    """
    pixels = np.array(array, dtype=np.uint8)
    if pixels.ndim <= 3:
        return Image.fromarray(pixels)
    # Only a single-image batch can be turned into one picture.
    assert pixels.shape[0] == 1
    return Image.fromarray(pixels[0])

def show_image(image,title=None):
    """Display an image with matplotlib.

    Args:
        image: Array or tensor exposing ``.shape``. If it has more than three
            dimensions, axis 0 is squeezed away before plotting.
        title: Optional title drawn above the image.
    """
    if len(image.shape)>3:
        image=tf.squeeze(image,axis=0)
    plt.imshow(image)
    # The title argument was previously accepted but silently ignored.
    if title is not None:
        plt.title(title)

In [None]:
# VGG19 with ImageNet weights and without its top (classifier) layers.
vgg = vgg19.VGG19(include_top=False, weights='imagenet')
vgg.summary()

# Layer whose activations define the content target.
content_layers = ['block4_conv2']

# First convolution of each of the five VGG blocks defines the style targets.
style_layers = [f'block{block}_conv1' for block in range(1, 6)]

In [None]:
class LossModel:
    """Wraps a pretrained network and exposes selected layer activations.

    The wrapped network is frozen and a secondary model is built whose outputs
    are the style layers followed by the content layers.
    """

    def __init__(self,pretrained_model,content_layers,style_layers):
        self.model=pretrained_model
        self.content_layers=content_layers
        self.style_layers=style_layers
        self.loss_model=self.get_model()

    def get_model(self):
        """Freeze the wrapped network and build the multi-output model."""
        self.model.trainable=False
        tapped_outputs=[]
        # Style layers come first; get_activations relies on this ordering.
        for layer_name in self.style_layers + self.content_layers:
            tapped_outputs.append(self.model.get_layer(layer_name).output)
        return Model(inputs=self.model.input,outputs=tapped_outputs)

    def get_activations(self,inputs):
        """Return ``{'content': {...}, 'style': {...}}`` activations for ``inputs``.

        ``inputs`` is multiplied by 255 and passed through
        ``vgg19.preprocess_input`` before being fed to the network.
        """
        preprocessed=vgg19.preprocess_input(inputs*255.0)
        features=self.loss_model(preprocessed)
        split_at=len(self.style_layers)
        return {
            'content':dict(zip(self.content_layers,features[split_at:])),
            'style':dict(zip(self.style_layers,features[:split_at])),
        }

In [None]:
# Feature extractor used to compute the content and style terms of the loss.
loss_model = LossModel(
    pretrained_model=vgg,
    content_layers=content_layers,
    style_layers=style_layers,
)

In [None]:
def content_loss(placeholder,content,weight):
    """Weighted mean squared difference between two equally shaped feature maps."""
    assert placeholder.shape == content.shape
    squared_error = tf.square(placeholder - content)
    return weight * tf.reduce_mean(squared_error)

def gram_matrix(x):
    """Per-sample Gram matrix over the last axis of a 4-D tensor.

    The einsum sums products of last-axis entries over axes 1 and 2; the result
    is divided by the product of the sizes of axes 1, 2 and 3.
    """
    normalizer = tf.cast(x.shape[1] * x.shape[2] * x.shape[3], tf.float32)
    correlations = tf.linalg.einsum('bijc,bijd->bcd', x, x)
    return correlations / normalizer

def style_loss(placeholder,style, weight):
    """Weighted mean squared difference between the Gram matrices of two feature maps."""
    assert placeholder.shape == style.shape
    target_gram = gram_matrix(style)
    generated_gram = gram_matrix(placeholder)
    gram_error = tf.square(target_gram - generated_gram)
    return weight * tf.reduce_mean(gram_error)

def perceptual_loss(predicted_activations,content_activations,
                    style_activations,content_weight,style_weight,
                    content_layers_weights,style_layer_weights):
    """Combine per-layer content and style losses into one weighted total.

    ``predicted_activations`` is a dict with ``"content"`` and ``"style"``
    sub-dicts keyed by layer name. The i-th layer of each sub-dict is weighted
    by the i-th entry of the matching per-layer weight list, the layer terms are
    summed, and the two sums are scaled by ``content_weight`` and
    ``style_weight`` respectively.
    """
    content_terms = []
    for index, layer in enumerate(predicted_activations["content"].keys()):
        content_terms.append(content_loss(predicted_activations["content"][layer],
                                          content_activations[layer],
                                          content_layers_weights[index]))
    weighted_content = tf.add_n(content_terms) * content_weight

    style_terms = []
    for index, layer in enumerate(predicted_activations["style"].keys()):
        style_terms.append(style_loss(predicted_activations["style"][layer],
                                      style_activations[layer],
                                      style_layer_weights[index]))
    weighted_style = tf.add_n(style_terms) * style_weight

    return weighted_content + weighted_style

In [None]:
def resnet_block(num_filters, input_layer):
    """Two 3x3 conv + instance-norm stages whose result is concatenated with the input.

    A ReLU follows the first stage only. Both convolutions share one
    random-normal initializer (mean 0, stddev 0.02).

    NOTE(review): the skip connection uses ``Concatenate`` rather than an
    element-wise add, so the returned tensor carries the block's
    ``num_filters`` channels in addition to the input's channels - confirm this
    is intended.
    """
    layers = tf.keras.layers
    weight_init = tf.random_normal_initializer(0., 0.02)

    def conv_norm(tensor):
        # 3x3 same-padded convolution followed by instance normalization.
        convolved = layers.Conv2D(num_filters, (3,3), padding='same',
                                  kernel_initializer=weight_init)(tensor)
        return tfa.layers.InstanceNormalization(axis=-1)(convolved)

    branch = conv_norm(input_layer)
    branch = layers.Activation('relu')(branch)
    branch = conv_norm(branch)
    return layers.Concatenate()([branch, input_layer])
def generator():
    """Build the image-transformation network.

    Architecture, in order: a 9x9 convolution and two stride-2 3x3 convolutions
    (32, 64, 128 filters), five ``resnet_block`` stages, two stride-2 transposed
    convolutions (64, 32 filters) and a final 9x9 convolution with 3 filters.
    Every convolution is followed by instance normalization and an activation
    (ReLU everywhere, tanh at the end). The tanh output is shifted and scaled
    by ``(t + 1) * 127.5``.

    Returns:
        A ``tf.keras.Model`` taking inputs of shape (256, 256, 3).
    """
    layers = tf.keras.layers
    initializer = tf.random_normal_initializer(0., 0.02)

    def norm_act(tensor, activation):
        # Instance normalization followed by the requested activation.
        normalized = tfa.layers.InstanceNormalization(axis=-1)(tensor)
        return layers.Activation(activation)(normalized)

    inputs = layers.Input(shape=[256, 256, 3])

    # Encoder: one full-resolution stage, then two strided downsampling stages.
    x = layers.Conv2D(32, (9, 9), padding="same", kernel_initializer=initializer)(inputs)
    x = norm_act(x, 'relu')
    for filters in (64, 128):
        x = layers.Conv2D(filters, (3, 3), strides=2, padding="same",
                          kernel_initializer=initializer)(x)
        x = norm_act(x, 'relu')

    # Bottleneck blocks.
    for _ in range(5):
        x = resnet_block(128, x)

    # Decoder: two strided transposed convolutions; the first is named "feature_map".
    x = layers.Conv2DTranspose(64, (3, 3), strides=2, name="feature_map", padding="same",
                               kernel_initializer=initializer)(x)
    x = norm_act(x, 'relu')
    x = layers.Conv2DTranspose(32, (3, 3), strides=2, padding="same",
                               kernel_initializer=initializer)(x)
    x = norm_act(x, 'relu')

    # Output stage: project to 3 channels, squash with tanh, rescale.
    x = layers.Conv2D(3, (9, 9), padding="same", kernel_initializer=initializer)(x)
    squashed = norm_act(x, 'tanh')
    output_image = (squashed + 1) * (255.0 / 2)

    return tf.keras.Model(inputs, output_image)
# Build the generator once; later cells train, checkpoint and export this instance.
test_model = generator()
test_model.summary()

In [None]:
# input_shape's first two entries size the style image; batch_size sets how many
# copies of the style image are stacked and how many steps make up an epoch.
input_shape=(256,256,3)
batch_size=4

In [None]:
# Adam optimizer (learning rate 1e-3) handed to train_step to update the generator.
optimizer = tf.keras.optimizers.Adam(learning_rate=1e-3)

In [None]:
def train_step(dataset,style_activations,steps_per_epoch,style_model,loss_model,optimizer,
               checkpoint_path="./",content_weight=1e4, style_weight=1e-2,
               total_variation_weight=0.004):
    """Run one epoch of style-transfer training and return the mean step loss.

    Args:
        dataset: Iterable yielding input image batches; iteration stops after
            ``steps_per_epoch`` batches.
        style_activations: Mapping of style-layer name to target style
            activations, forwarded to ``perceptual_loss``.
        steps_per_epoch: Number of optimisation steps to perform.
        style_model: Generator being trained. Its trainable variables are
            updated and its weights are checkpointed every 1000 steps.
        loss_model: Object whose ``get_activations`` returns a dict with
            ``"content"`` and ``"style"`` entries.
        optimizer: Optimizer exposing ``apply_gradients``.
        checkpoint_path: Directory in which ``model_checkpoint.ckpt`` is written.
        content_weight: Multiplier for the content term of the perceptual loss.
        style_weight: Multiplier for the style term of the perceptual loss.
        total_variation_weight: Multiplier for the total-variation term.

    Returns:
        Mean of the scalar losses recorded for each step.
    """
    content_layers_weights=[1]
    style_layers_weights=[1,1,1,1,1]
    batch_losses=[]
    save_path=os.path.join(checkpoint_path, "model_checkpoint.ckpt")
    print("Model Checkpoint Path: ",save_path)
    for steps, input_image_batch in enumerate(dataset, start=1):
        if steps > steps_per_epoch:
            break
        # Content targets do not depend on the generator, so they are computed
        # outside the gradient tape.
        content_activations=loss_model.get_activations(input_image_batch)["content"]
        with tf.GradientTape() as tape:
            outputs=style_model(input_image_batch)
            outputs=tf.clip_by_value(outputs, 0, 255)
            # get_activations receives the generated images rescaled by 1/255.
            pred_activations=loss_model.get_activations(outputs/255.0)
            curr_loss=perceptual_loss(pred_activations,content_activations,style_activations,content_weight,
                                      style_weight,content_layers_weights,style_layers_weights)
            # tf.image.total_variation yields one value per image; averaging keeps
            # the loss a scalar. The reported value equals the mean of the former
            # per-image loss vector, and the gradient is the former one divided
            # by the batch size.
            tv_loss=tf.reduce_mean(tf.image.total_variation(outputs))
            curr_loss += total_variation_weight*tv_loss
        batch_losses.append(curr_loss)
        grad = tape.gradient(curr_loss,style_model.trainable_variables)
        optimizer.apply_gradients(zip(grad,style_model.trainable_variables))
        if steps % 1000==0:
            print("checkpoint saved ",end=" ")
            # Checkpoint the model passed in, not a notebook-level global.
            style_model.save_weights(save_path)
            print(f"Loss: {tf.reduce_mean(batch_losses).numpy()}")
    return tf.reduce_mean(batch_losses)

In [None]:
class TensorflowDatasetLoader:
    """Builds a batched, endlessly repeating ``tf.data`` pipeline over the ``*.jpg`` files of a folder.

    Attributes are set in ``__init__``:
        dataset: The batched, repeated and prefetched dataset.
        length: Number of image files actually fed into the pipeline.
    """

    def __init__(self,dataset_path,batch_size=4, image_size=(256, 256),num_images=None):
        """Collect image paths and assemble the input pipeline.

        Args:
            dataset_path: Directory searched (non-recursively) for ``*.jpg`` files.
            batch_size: Images per batch; incomplete final batches are dropped.
            image_size: Size passed to ``tf.image.resize`` for every image.
            num_images: Optional cap on how many of the found files are used.

        Raises:
            ValueError: If no image ends up selected.
        """
        images_paths = [str(path) for path in Path(dataset_path).glob("*.jpg")]
        if num_images is not None:
            images_paths = images_paths[0:num_images]
        if not images_paths:
            raise ValueError(f"No .jpg images selected from {dataset_path!r}")
        # Measured after truncation so len(loader) matches the data actually used.
        self.length=len(images_paths)
        dataset = tf.data.Dataset.from_tensor_slices(images_paths).map(
            lambda path: self.load_tf_image(path, dim=image_size),
            num_parallel_calls=tf.data.experimental.AUTOTUNE,
        )
        dataset = dataset.batch(batch_size,drop_remainder=True)
        dataset = dataset.repeat()
        dataset = dataset.prefetch(buffer_size=tf.data.experimental.AUTOTUNE)
        self.dataset=dataset

    def __len__(self):
        """Return the number of image files feeding the dataset."""
        return self.length

    def load_tf_image(self,image_path,dim):
        """Read a JPEG file, resize it to ``dim`` and divide the pixel values by 255."""
        image = tf.io.read_file(image_path)
        image = tf.image.decode_jpeg(image, channels=3)
        image= tf.image.resize(image,dim)
        image= image/255.0
        image = tf.image.convert_image_dtype(image, tf.float32)
        return image

In [None]:
# Training images are read from a local folder; adjust the path for your machine.
loader=TensorflowDatasetLoader("C:/Users/robin/Downloads/train2014/train2014",batch_size=batch_size)
# A bare expression in the middle of a cell displays nothing, so print it.
print(loader.dataset.element_spec)

# Preview one batch using matplotlib directly, one image per column.
sample_batch = next(iter(loader.dataset.take(1)))
fig, axes = plt.subplots(1, sample_batch.shape[0], squeeze=False)
for ax, sample in zip(axes[0], sample_batch):
    ax.imshow(sample.numpy())
    ax.axis("off")
plt.show()

In [None]:
# Location of the style image. load_image hands this to Image.open.
url = 'path/to/style/image'
# Example style image: https://www.edvardmunch.org/images/paintings/the-scream.jpg

# Load the style image at the first two dimensions of input_shape and divide by 255.
style_image = load_image(url, dim=input_shape[:2], resize=True) / 255.0
show_image(style_image)

# Style targets: the style image repeated batch_size times, passed through the loss network.
style_image = style_image.astype(np.float32)
style_image_batch = np.repeat([style_image], batch_size, axis=0)
style_activations = loss_model.get_activations(style_image_batch)["style"]

# Loss weights and schedule.
content_weight, style_weight, total_variation_weight = 1e1, 1e2, 0.004
epochs = 2
num_images = len(loader)
steps_per_epochs = num_images // batch_size
print(steps_per_epochs)

# Checkpoint directory; existing weights are loaded when an index file is present.
save_path = "./impressionism_painter3"
os.makedirs(save_path, exist_ok=True)

if not os.path.isfile(os.path.join(save_path, "model_checkpoint.ckpt.index")):
    print("training scratch ...")
else:
    test_model.load_weights(os.path.join(save_path, "model_checkpoint.ckpt"))
    print("resuming training ...")

In [None]:
# Train for `epochs` epochs, saving the generator weights after each one and
# collecting the mean loss returned by train_step.
epoch_losses = []
for epoch in range(1, epochs + 1):
    print(f"epoch: {epoch}")
    batch_loss = train_step(
        dataset=loader.dataset,
        style_activations=style_activations,
        steps_per_epoch=steps_per_epochs,
        style_model=test_model,
        loss_model=loss_model,
        optimizer=optimizer,
        checkpoint_path=save_path,
        content_weight=content_weight,
        style_weight=style_weight,
        total_variation_weight=total_variation_weight,
    )
    test_model.save_weights(os.path.join(save_path, "model_checkpoint.ckpt"))
    print("Model Checkpointed at: ", os.path.join(save_path, "model_checkpoint.ckpt"))
    print(f"loss: {batch_loss.numpy()}")
    epoch_losses.append(batch_loss)

In [None]:
# Save the trained generator to disk, then reload it without compiling.
test_model.save('./impressionism_style_transfer')

model = keras.models.load_model('impressionism_style_transfer',  compile=False)

# Export a second model that stops at the generator's final Activation layer.
# The layer is located structurally: an auto-generated name such as
# 'activation_10' shifts whenever other Activation layers were created earlier
# in the same kernel (for example when the generator cell is re-run), which
# would either raise or silently pick a different layer.
# NOTE(review): assumes reloaded layers are revived as keras.layers.Activation - confirm.
final_activation = [layer for layer in model.layers
                    if isinstance(layer, keras.layers.Activation)][-1]
val_extractor = keras.Model(model.inputs,
                        outputs=final_activation.output)
val_extractor.save('./impressionism_style_transfer01')