In [1]:
import tensorflow as tf
import numpy as np
import os

In [9]:
class instance_norm(tf.keras.layers.Layer):
    """Instance normalization with a learned per-channel scale and shift.

    Each sample is normalized with its own mean/variance over axes 1 and 2
    (presumably height and width of an NHWC tensor -- the weights are sized
    from ``input_shape[3]``), then ``gamma * x_hat + beta`` is returned.
    """

    def __init__(self, epsilon=1e-3):
        super(instance_norm, self).__init__()
        # Added to the variance before the square root to avoid dividing by zero.
        self.epsilon = epsilon

    def build(self, input_shape):
        channels = input_shape[3]
        # Shift starts at 0 and scale at 1, so the layer starts as plain normalization.
        self.beta = tf.Variable(tf.zeros([channels]))
        self.gamma = tf.Variable(tf.ones([channels]))

    def call(self, inputs):
        mean, var = tf.nn.moments(inputs, axes=[1, 2], keepdims=True)
        normalized = (inputs - mean) / tf.sqrt(var + self.epsilon)
        return self.gamma * normalized + self.beta

class conv_2d(tf.keras.layers.Layer):
    """Reflect-padded convolution -> instance norm -> optional ReLU.

    The input is reflect-padded by ``kernel // 2`` on both spatial sides, then
    convolved with ``padding='valid'``. ``use_bias=False`` because a per-channel
    bias would be removed by the instance norm's mean subtraction anyway; the
    norm's ``beta`` provides the shift instead.
    """

    def __init__(self, filters, kernel, stride):
        super(conv_2d, self).__init__()
        half = kernel // 2
        self.paddings = tf.constant([[0, 0], [half, half], [half, half], [0, 0]])
        self.conv2d = tf.keras.layers.Conv2D(filters, kernel, stride,
                                             use_bias=False, padding='valid')
        self.instance_norm = instance_norm()

    def call(self, inputs, relu=True):
        padded = tf.pad(inputs, self.paddings, mode='REFLECT')
        normalized = self.instance_norm(self.conv2d(padded))
        return tf.nn.relu(normalized) if relu else normalized

class resize_conv_2d(tf.keras.layers.Layer):
    """Upsampling block: nearest-neighbour resize followed by a strided ``conv_2d``.

    The input is resized by a factor of ``stride * 2`` in each spatial
    dimension and then convolved with stride ``stride``. For odd ``kernel``
    sizes (reflect padding ``kernel // 2``), this gives a net 2x upsampling.

    Works with both static and unknown (``None``) spatial dimensions.
    """

    def __init__(self, filters, kernel, stride):
        super(resize_conv_2d, self).__init__()
        self.conv = conv_2d(filters, kernel, stride)
        # NOTE: conv_2d already applies instance norm + ReLU, so this second
        # norm is redundant. It owns trainable variables, though, so it is kept
        # to stay compatible with weights written by save_weights.
        self.instance_norm = instance_norm()
        self.stride = stride

    def call(self, inputs):
        scale = self.stride * 2
        height, width = inputs.shape[1], inputs.shape[2]
        if height is None or width is None:
            # Unknown static sizes (e.g. when traced with None height/width):
            # `None * scale` would raise, so read the runtime shape instead.
            runtime_shape = tf.shape(inputs)
            height = runtime_shape[1] if height is None else height
            width = runtime_shape[2] if width is None else width

        x = tf.image.resize(inputs, [height * scale, width * scale],
                            method=tf.image.ResizeMethod.NEAREST_NEIGHBOR)
        x = self.conv(x)
        # Redundant second normalization (see __init__); kept for weight compatibility.
        x = self.instance_norm(x)
        return tf.nn.relu(x)

class tran_conv_2d(tf.keras.layers.Layer):
    """Upsampling block: transposed convolution -> instance norm -> ReLU.

    This is an alternative to ``resize_conv_2d`` and is not used by
    ``feed_forward``. With ``padding='same'``, the transposed convolution
    scales the spatial size by ``stride``.
    """

    def __init__(self, filters, kernel, stride):
        super(tran_conv_2d, self).__init__()
        self.tran_conv = tf.keras.layers.Conv2DTranspose(
            filters, kernel, stride, padding='same')
        self.instance_norm = instance_norm()

    def call(self, inputs):
        upsampled = self.tran_conv(inputs)
        return tf.nn.relu(self.instance_norm(upsampled))

class residual(tf.keras.layers.Layer):
    """Residual block: ``inputs + conv2(conv1(inputs), relu=False)``.

    Both convolutions are ``conv_2d`` layers. The second one skips its ReLU,
    so the branch can add negative corrections to the identity path. The
    branch output must match ``inputs`` in shape, so presumably
    ``stride == 1`` and ``filters`` equals the input channel count.
    """

    def __init__(self, filters, kernel, stride):
        super(residual, self).__init__()
        self.conv1 = conv_2d(filters, kernel, stride)
        self.conv2 = conv_2d(filters, kernel, stride)

    def call(self, inputs):
        branch = self.conv2(self.conv1(inputs), relu=False)
        return inputs + branch
        

class feed_forward(tf.keras.models.Model):
    """Image transformation network: conv encoder -> 5 residual blocks -> decoder.

    Sub-layer constructor arguments are ``(filters, kernel, stride)``.
    TF-format checkpoints (``save_weights(..., save_format='tf')``) key
    variables by attribute path, so the attribute names below must not change.
    """

    def __init__(self):
        super(feed_forward, self).__init__()
        # Encoder: 9x9 stride-1 conv, then two stride-2 downsampling convs.
        self.conv1 = conv_2d(32, 9, 1)
        self.conv2 = conv_2d(64, 3, 2)
        self.conv3 = conv_2d(128, 3, 2)
        # Five residual blocks at 128 channels.
        self.resid1 = residual(128, 3, 1)
        self.resid2 = residual(128, 3, 1)
        self.resid3 = residual(128, 3, 1)
        self.resid4 = residual(128, 3, 1)
        self.resid5 = residual(128, 3, 1)
        # Decoder: resize-convolution upsampling (tran_conv_2d is an unused alternative).
        self.resize_conv1 = resize_conv_2d(64, 3, 2)
        self.resize_conv2 = resize_conv_2d(32, 3, 2)
        # Projection to 3 channels; its ReLU is skipped in call().
        self.conv4 = conv_2d(3, 9, 1)

    def call(self, inputs):
        x = inputs
        for stage in (self.conv1, self.conv2, self.conv3,
                      self.resid1, self.resid2, self.resid3,
                      self.resid4, self.resid5,
                      self.resize_conv1, self.resize_conv2):
            x = stage(x)
        x = self.conv4(x, relu=False)
        # tanh output is in (-1, 1), so the result lies in (127.5 - 150, 127.5 + 150),
        # slightly wider than [0, 255]; resolve_video clips it with clip_0_1.
        return (tf.nn.tanh(x) * 150 + 255. / 2)

In [None]:
import PIL.Image
import cv2
import os


def tensor_to_image(tensor):
    """Convert an image tensor/array to a ``PIL.Image``.

    Accepts a single image (3-D or fewer dims) or a 4-D batch holding exactly
    one image. Values are clipped to [0, 255] before the uint8 cast. Without
    the clip, out-of-range floats (``feed_forward`` can emit roughly
    [-22.5, 277.5]) would wrap around or be undefined after the cast.

    Raises:
        AssertionError: if a 4-D input has a leading dimension other than 1.
    """
    array = np.clip(np.asarray(tensor), 0, 255).astype(np.uint8)
    if np.ndim(array) > 3:
        # Only a single image can be converted; drop the batch axis.
        assert array.shape[0] == 1
        array = array[0]
    return PIL.Image.fromarray(array)
    

def load_img(path_to_img, max_dim=None, resize=True):
    """Read a JPEG image into a float32 batch of one with values in [0, 1].

    Args:
        path_to_img: path of the image (Python string or scalar string tensor,
            so it can be used directly in ``Dataset.map``).
        max_dim: if truthy, rescale so the longer side equals ``max_dim``,
            preserving aspect ratio (applied after the optional 256x256 resize).
        resize: if True, first resize to 256x256 (aspect ratio not preserved).

    Returns:
        A float32 tensor of shape ``(1, height, width, 3)``.
    """
    img = tf.io.read_file(path_to_img)
    img = tf.image.decode_jpeg(img, channels=3)
    # convert_image_dtype rescales uint8 [0, 255] to float32 [0, 1].
    img = tf.image.convert_image_dtype(img, tf.float32)

    if resize:
        img = tf.image.resize(img, [256, 256])

    if max_dim:
        shape = tf.cast(tf.shape(img)[:-1], tf.float32)
        # tf.reduce_max works in both eager and graph mode. The builtin max()
        # iterates the tensor, which symbolic tensors do not support.
        scale = max_dim / tf.reduce_max(shape)
        new_shape = tf.cast(shape * scale, tf.int32)
        img = tf.image.resize(img, new_shape)

    return img[tf.newaxis, :]


def resolve_video(network, path_to_video, result):
    """Stylize every frame of a video with ``network`` and write it to ``result``.

    Frames are shown in an OpenCV window while processing; press ``q`` to stop
    early. The output is XVID-encoded at the source frame rate, or 30 fps when
    the source reports none. The writer is sized from the first stylized frame,
    because the network output size can differ from the input size (the
    stride-2 encoder/decoder rounds spatial sizes up to a multiple of 4).

    Args:
        network: callable mapping a float32 (1, H, W, 3) batch in [0, 1] to an
            image batch on the [0, 255] scale (e.g. ``feed_forward``).
        path_to_video: input video path.
        result: output video path.
    """
    cap = cv2.VideoCapture(path_to_video)
    fps = cap.get(cv2.CAP_PROP_FPS) or 30.0
    fourcc = cv2.VideoWriter_fourcc(*'XVID')
    out = None

    print('Transfering video...')
    try:
        while cap.isOpened():
            ret, frame = cap.read()
            if not ret:
                # End of stream or read failure: frame is None here.
                break

            frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
            # The network is trained on [0, 1] inputs (see load_img).
            frame = tf.cast(frame[tf.newaxis, ...], tf.float32) / 255.0

            prediction = network(frame)
            prediction = clip_0_1(prediction)
            prediction = np.array(prediction).astype(np.uint8).squeeze()
            prediction = cv2.cvtColor(prediction, cv2.COLOR_RGB2BGR)

            if out is None:
                # VideoWriter silently drops frames whose size differs from the
                # one it was opened with, so size it from the actual output.
                height, width = prediction.shape[:2]
                out = cv2.VideoWriter(result, fourcc, fps, (width, height))
            out.write(prediction)
            cv2.imshow('prediction', prediction)

            if cv2.waitKey(1) & 0xFF == ord('q'):
                break
    finally:
        cap.release()
        if out is not None:
            out.release()
        cv2.destroyAllWindows()


def create_folder(diirname):
    """Create directory ``diirname`` (and any missing parents) if it is absent.

    Prints whether the directory was created or already existed. Attempting the
    creation and catching ``FileExistsError`` avoids the race between a separate
    existence check and the create call.

    Args:
        diirname: directory path. The misspelled parameter name is kept so
            existing keyword callers keep working.
    """
    try:
        os.makedirs(diirname)
    except FileExistsError:
        print('Directory ', diirname, ' already exists')
    else:
        print('Directory ', diirname, ' created')


def clip_0_1(image):
    """Clamp ``image`` to the pixel range [0, 255].

    Despite the name, the bounds are 0 and 255, not 0 and 1. This matches the
    [0, 255]-scale output of ``feed_forward``.
    """
    lower, upper = 0.0, 255.0
    return tf.clip_by_value(image, clip_value_min=lower, clip_value_max=upper)

In [2]:
def vgg_layers(layer_names):
    """Build a frozen ImageNet VGG19 (no classifier head) exposing chosen layers.

    Args:
        layer_names: VGG19 layer names, e.g. ``'block1_conv1'``.

    Returns:
        A ``tf.keras.Model`` on the VGG19 input whose outputs are the named
        layers' activations, in the order given.
    """
    base = tf.keras.applications.VGG19(include_top=False, weights='imagenet')
    base.trainable = False
    selected = [base.get_layer(layer_name).output for layer_name in layer_names]
    return tf.keras.Model([base.input], selected)

In [3]:
def gram_matrix(features, normalize=True):
    """Per-example Gram matrix (channel-by-channel feature correlations).

    Args:
        features: activations whose first axis is the batch and last axis the
            channels, typically ``(batch, height, width, channels)``. All axes
            in between are flattened into "locations".
        normalize: if True, divide by the number of locations (height * width).

    Returns:
        Tensor of shape ``(batch, channels, channels)`` equal to ``F^T F``,
        where ``F`` is each example's ``(locations, channels)`` feature matrix.
    """
    channels = features.shape[-1]
    if channels is None:
        channels = tf.shape(features)[-1]
    # Use the runtime batch size: the static one is None when traced with an
    # unknown batch dimension, and a None entry cannot be passed to reshape.
    flat = tf.reshape(features, (tf.shape(features)[0], -1, channels))
    # transpose_a yields (channels x channels). Multiplying F by F untransposed
    # only type-checks when locations == channels and is not a Gram matrix.
    gram = tf.matmul(flat, flat, transpose_a=True)
    if normalize:
        num_locations = tf.shape(flat)[1]
        gram = gram / tf.cast(num_locations, gram.dtype)
    return gram

In [4]:
def style_loss(style_outputs, style_target):
    """Average over layers of the per-layer mean squared difference.

    Args:
        style_outputs: dict mapping layer name -> tensor for the generated
            image (Gram matrices, as produced by StyleContentModel).
        style_target: dict containing the same keys, for the style image.

    Returns:
        Scalar tensor: sum of per-layer MSEs divided by the number of layers.
    """
    per_layer = []
    for layer_name in style_outputs.keys():
        diff = style_outputs[layer_name] - style_target[layer_name]
        per_layer.append(tf.reduce_mean(diff ** 2))
    return tf.add_n(per_layer) / len(style_outputs)

In [5]:
def content_loss(content_outputs, content_target):
    """Average over layers of the per-layer mean squared feature difference.

    Mirrors ``style_loss``: per-layer MSE terms are summed and divided by the
    number of layers (with a single content layer this is just that layer's MSE).

    Args:
        content_outputs: dict mapping layer name -> features of the generated image.
        content_target: dict with the same keys -> features of the content image.

    Returns:
        Scalar tensor. It must be returned explicitly, because callers scale it
        (``content_weight * content_loss(...)``).
    """
    per_layer = [tf.reduce_mean((content_outputs[name] - content_target[name]) ** 2)
                 for name in content_outputs.keys()]
    return tf.add_n(per_layer) / len(content_outputs)

In [10]:
def total_variation_loss(image):
    """Total-variation smoothness penalty on a rank-4 image batch.

    Returns the mean squared difference between neighbours along axis 2, plus
    the same along axis 1. These are presumably the width and height axes of
    an NHWC batch.
    """
    def _mean_square(delta):
        return tf.reduce_mean(tf.square(delta))

    along_width = image[:, :, 1:, :] - image[:, :, :-1, :]
    along_height = image[:, 1:, :, :] - image[:, :-1, :, :]
    return _mean_square(along_width) + _mean_square(along_height)

In [7]:
class StyleContentModel(tf.keras.models.Model):
    """Frozen VGG19 extractor returning style Gram matrices and content features.

    Call it on an RGB image batch on the [0, 255] scale. It returns
    ``{'style': {layer: gram_matrix}, 'content': {layer: activations}}``.
    """

    def __init__(self, style_layers, content_layers):
        super(StyleContentModel, self).__init__()
        # Style layers come first in the VGG outputs; call() splits on that count.
        self.vgg = vgg_layers(style_layers + content_layers)
        self.style_layers = style_layers
        self.content_layers = content_layers
        self.num_style_layers = len(style_layers)
        self.vgg.trainable = False

    def call(self, inputs):
        # Keras' ImageNet VGG19 weights expect "caffe" preprocessing
        # (RGB -> BGR, then ImageNet mean subtraction, no scaling), which is what
        # vgg19.preprocess_input applies. Using a distinct local name avoids
        # shadowing the function being called.
        preprocessed = tf.keras.applications.vgg19.preprocess_input(inputs)
        outputs = self.vgg(preprocessed)
        style_outputs = outputs[:self.num_style_layers]
        content_outputs = outputs[self.num_style_layers:]

        style_dict = {name: gram_matrix(value)
                      for name, value in zip(self.style_layers, style_outputs)}
        content_dict = {name: value
                        for name, value in zip(self.content_layers, content_outputs)}

        return {'style': style_dict, 'content': content_dict}

In [11]:
"""VGG19 model for Keras.
# Reference
- [Very Deep Convolutional Networks for Large-Scale Image Recognition](
    https://arxiv.org/abs/1409.1556) (ICLR 2015)
"""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import tensorflow as tf
import numpy as np
import warnings
import os

# Short aliases for the tf.keras submodules used by the VGG19 builder in this cell.
backend, layers, models, keras_utils = (
    tf.keras.backend, tf.keras.layers, tf.keras.models, tf.keras.utils)

# Download URLs for the pretrained VGG19 weights, with and without the classifier head.
WEIGHTS_PATH = ('https://github.com/fchollet/deep-learning-models/releases/'
                'download/v0.1/vgg19_weights_tf_dim_ordering_tf_kernels.h5')
WEIGHTS_PATH_NO_TOP = ('https://github.com/fchollet/deep-learning-models/releases/'
                       'download/v0.1/vgg19_weights_tf_dim_ordering_tf_kernels_notop.h5')

# Not referenced by the code in this cell.
_IMAGENET_MEAN = None


def VGG19(include_top=True,
          weights='imagenet',
          input_tensor=None,
          input_shape=None,
          pooling=None,
          classes=1000,
          **kwargs):
    """Instantiate the VGG19 architecture, optionally with ImageNet weights.

    # Arguments
        include_top: whether to include the three fully-connected layers.
        weights: `None` (random init), `'imagenet'`, or a path to a weights file.
        input_tensor: optional tensor to use as the image input.
        input_shape: optional `(rows, cols, channels)` tuple, validated by
            `_obtain_input_shape` (default side 224, minimum side 32).
        pooling: `'avg'` or `'max'` for global pooling when `include_top` is
            False; any other value keeps the 4-D block5 output.
        classes: number of output classes when `include_top` is True.
        **kwargs: accepted for signature compatibility and ignored.
    # Returns
        A `tf.keras.Model` named `'vgg19'`.
    # Raises
        ValueError: for an invalid `weights` argument, `classes != 1000` with
            ImageNet top weights, or an invalid `input_shape`.
    """
    if not (weights in {'imagenet', None} or os.path.exists(weights)):
        raise ValueError('The `weights` argument should be either '
                         '`None` (random initialization), `imagenet` '
                         '(pre-training on ImageNet), '
                         'or the path to the weights file to be loaded.')

    if weights == 'imagenet' and include_top and classes != 1000:
        raise ValueError('If using `weights` as `"imagenet"` with `include_top`'
                         ' as true, `classes` should be 1000')

    # Determine proper input shape
    input_shape = _obtain_input_shape(input_shape,
                                      default_size=224,
                                      min_size=32,
                                      require_flatten=include_top,
                                      weights=weights)

    if input_tensor is None:
        img_input = layers.Input(shape=input_shape)
    elif not backend.is_keras_tensor(input_tensor):
        img_input = layers.Input(tensor=input_tensor, shape=input_shape)
    else:
        img_input = input_tensor

    # Five blocks of 3x3 'same' ReLU convolutions, each followed by 2x2 max
    # pooling, as (filters, number of convs). Layer names ('blockB_convC',
    # 'blockB_pool') and creation order match the reference architecture,
    # because callers look layers up by name (e.g. 'block4_conv2').
    vgg19_blocks = ((64, 2), (128, 2), (256, 4), (512, 4), (512, 4))
    x = img_input
    for block_id, (filters, num_convs) in enumerate(vgg19_blocks, start=1):
        for conv_id in range(1, num_convs + 1):
            x = layers.Conv2D(filters, (3, 3),
                              activation='relu',
                              padding='same',
                              name='block%d_conv%d' % (block_id, conv_id))(x)
        x = layers.MaxPooling2D((2, 2), strides=(2, 2),
                                name='block%d_pool' % block_id)(x)

    if include_top:
        # Classification block
        x = layers.Flatten(name='flatten')(x)
        x = layers.Dense(4096, activation='relu', name='fc1')(x)
        x = layers.Dense(4096, activation='relu', name='fc2')(x)
        x = layers.Dense(classes, activation='softmax', name='predictions')(x)
    elif pooling == 'avg':
        x = layers.GlobalAveragePooling2D()(x)
    elif pooling == 'max':
        x = layers.GlobalMaxPooling2D()(x)

    # Ensure that the model takes into account
    # any potential predecessors of `input_tensor`.
    if input_tensor is not None:
        inputs = keras_utils.get_source_inputs(input_tensor)
    else:
        inputs = img_input

    model = models.Model(inputs, x, name='vgg19')

    # Load weights. tf.keras always runs on the TensorFlow backend, so no
    # Theano kernel conversion is needed.
    if weights == 'imagenet':
        if include_top:
            weights_path = keras_utils.get_file(
                'vgg19_weights_tf_dim_ordering_tf_kernels.h5',
                WEIGHTS_PATH,
                cache_subdir='models',
                file_hash='cbe5617147190e668d6c5d5026f83318')
        else:
            weights_path = keras_utils.get_file(
                'vgg19_weights_tf_dim_ordering_tf_kernels_notop.h5',
                WEIGHTS_PATH_NO_TOP,
                cache_subdir='models',
                file_hash='253f8cb515780f3b799900260a226db6')
        model.load_weights(weights_path)
    elif weights is not None:
        model.load_weights(weights)

    return model


def preprocess_input(x):
    """Zero-center an image batch by subtracting fixed per-channel means.

    # Arguments
        x: Numpy array or tensor whose last axis holds 3 channels, presumably
            RGB on the [0, 255] scale. The subtracted means
            [123.68, 116.779, 103.939] are the ImageNet R, G, B means.
    # Returns
        A new array/tensor `x - mean`, broadcast over the last axis. The input
        is not modified.

    NOTE(review): unlike Keras' "caffe" mode, the channels are NOT reordered
    to BGR and nothing is scaled. Keras' ImageNet VGG19 weights expect BGR
    input, so `tf.keras.applications.vgg19.preprocess_input` is the matching
    preprocessing for those weights.
    """
    mean = np.array([123.68, 116.779, 103.939])
    return x - mean





def _obtain_input_shape(input_shape,
                        default_size,
                        min_size,
                        require_flatten,
                        weights=None):
    """Internal utility to compute/validate a model's input shape.
    # Arguments
        input_shape: Either None (will return the default network input shape),
            or a user-provided shape to be validated.
        default_size: Default input width/height for the model.
        min_size: Minimum input width/height accepted by the model.
        data_format: Image data format to use.
        require_flatten: Whether the model is expected to
            be linked to a classifier via a Flatten layer.
        weights: One of `None` (random initialization)
            or 'imagenet' (pre-training on ImageNet).
            If weights='imagenet' input channels must be equal to 3.
    # Returns
        An integer shape tuple (may include None entries).
    # Raises
        ValueError: In case of invalid argument values.
    """
    if weights != 'imagenet' and input_shape and len(input_shape) == 3:
        if input_shape[-1] not in {1, 3}:
            warnings.warn(
                'This model usually expects 1 or 3 input channels. '
                'However, it was passed an input_shape with ' +
                str(input_shape[-1]) + ' input channels.')
        default_shape = (default_size, default_size, input_shape[-1])
    else:
        default_shape = (default_size, default_size, 3)

    if weights == 'imagenet' and require_flatten:
        if input_shape is not None:
            if input_shape != default_shape:
                raise ValueError('When setting `include_top=True` '
                                 'and loading `imagenet` weights, '
                                 '`input_shape` should be ' +
                                 str(default_shape) + '.')
        return default_shape

    if input_shape:
        if input_shape is not None:
            if len(input_shape) != 3:
                raise ValueError(
                    '`input_shape` must be a tuple of three integers.')
            if input_shape[-1] != 3 and weights == 'imagenet':
                raise ValueError('The input must have 3 channels; got '
                                    '`input_shape=' + str(input_shape) + '`')
            if ((input_shape[0] is not None and input_shape[0] < min_size) or
                (input_shape[1] is not None and input_shape[1] < min_size)):
                raise ValueError('Input size must be at least ' +
                                    str(min_size) + 'x' + str(min_size) +
                                    '; got `input_shape=' +
                                    str(input_shape) + '`')
    else:
        if require_flatten:
            input_shape = default_shape
        else:
            input_shape = (None, None, 3)

    if require_flatten:
        if None in input_shape:
            raise ValueError('If `include_top` is True, '
                             'you should specify a static `input_shape`. '
                             'Got `input_shape=' + str(input_shape) + '`')
    return input_shape

In [8]:
def _build_train_dataset(dataset_path, batch_size):
    """Return a shuffled, batched, prefetched dataset of 256x256 training images."""
    dataset = tf.data.Dataset.list_files(dataset_path + '/*.jpg')
    # Shuffle file names before decoding. A 1024-element buffer of decoded
    # 256x256x3 float32 images would hold ~768 MiB; paths are tiny.
    dataset = dataset.shuffle(1024)
    dataset = dataset.map(load_img,
                          num_parallel_calls=tf.data.experimental.AUTOTUNE)
    dataset = dataset.batch(batch_size, drop_remainder=True)
    return dataset.prefetch(tf.data.experimental.AUTOTUNE)


def _print_banner(message):
    """Print ``message`` framed by separator lines."""
    print('=====================================')
    print(message)
    print('=====================================\n')


def trainer(style_file, dataset_path, weights_path, content_weight, style_weight,
            tv_weight, learning_rate, batch_size, epochs, debug):
    """Train a ``feed_forward`` style-transfer network and save its weights.

    Args:
        style_file: path to the style image (JPEG, loaded at its native size).
        dataset_path: directory containing ``*.jpg`` content images.
        weights_path: target for ``network.save_weights``; written every 3000
            steps of an epoch and once more at the end.
        content_weight, style_weight, tv_weight: weights of the three loss terms.
        learning_rate: Adam learning rate.
        batch_size: images per step; the final partial batch is dropped.
        epochs: number of passes over the dataset.
        debug: if True, print running loss averages at each checkpoint.
    """
    import time

    # Setup the given layers
    content_layers = ['block4_conv2']

    style_layers = ['block1_conv1',
                    'block2_conv1',
                    'block3_conv1',
                    'block4_conv1',
                    'block5_conv1']

    # Build Feed-forward transformer
    network = feed_forward()

    # Build VGG-19 Loss network
    extractor = StyleContentModel(style_layers, content_layers)

    # Style target: load_img yields [0, 1]; the extractor expects [0, 255].
    style_image = load_img(style_file, resize=False)
    style_target = extractor(style_image * 255.0)['style']

    batch_shape = (batch_size, 256, 256, 3)

    # Build optimizer
    opt = tf.keras.optimizers.Adam(learning_rate=learning_rate)

    loss_metric = tf.keras.metrics.Mean()
    sloss_metric = tf.keras.metrics.Mean()
    closs_metric = tf.keras.metrics.Mean()
    tloss_metric = tf.keras.metrics.Mean()

    @tf.function()
    def train_step(X_batch):
        # Content targets do not depend on the network's variables, so they
        # are computed outside the tape and not recorded.
        content_target = extractor(X_batch * 255.0)['content']
        with tf.GradientTape() as tape:
            image = network(X_batch)
            outputs = extractor(image)

            s_loss = style_weight * style_loss(outputs['style'], style_target)
            c_loss = content_weight * content_loss(outputs['content'], content_target)
            t_loss = tv_weight * total_variation_loss(image)
            loss = s_loss + c_loss + t_loss

        grad = tape.gradient(loss, network.trainable_variables)
        opt.apply_gradients(zip(grad, network.trainable_variables))

        loss_metric(loss)
        sloss_metric(s_loss)
        closs_metric(c_loss)
        tloss_metric(t_loss)

    train_dataset = _build_train_dataset(dataset_path, batch_size)

    start = time.time()

    for e in range(epochs):
        print('Epoch {}'.format(e))

        for iteration, img in enumerate(train_dataset, start=1):
            # load_img prepends a batch axis, so batches arrive as
            # (B, 1, 256, 256, 3); fold them into (B, 256, 256, 3).
            X_batch = tf.reshape(img, batch_shape)
            train_step(X_batch)

            if iteration % 3000 == 0:
                # Save checkpoints
                network.save_weights(weights_path, save_format='tf')
                _print_banner('            Weights saved!           ')

                if debug:
                    print('step %s: loss = %s' % (iteration, loss_metric.result()))
                    print('s_loss={}, c_loss={}, t_loss={}'.format(
                        sloss_metric.result(), closs_metric.result(), tloss_metric.result()))

    end = time.time()
    print("Total time: {:.1f}".format(end - start))

    # Training is done !
    network.save_weights(weights_path, save_format='tf')
    _print_banner('             All saved!              ')