In [None]:
import tensorflow as tf
import tensorflow_addons as tfa
import numpy as np
import os
import time
import matplotlib.pyplot as plt
from IPython.display import clear_output
import pathlib
from tensorflow import keras


In [None]:
# Let tf.data choose parallelism / prefetch depth at runtime.
AUTOTUNE = tf.data.AUTOTUNE
# Shuffle buffer size used when the input pipelines are assembled.
BUFFER_SIZE = 1000
# Number of images per batch.
BATCH_SIZE = 1
# Width and height (in pixels) of the patch taken by random_crop().
IMG_WIDTH = 256
IMG_HEIGHT = 256

# Directory image set pipeline

In [None]:
# Root folders of the two unpaired image sets; replace the placeholders with
# real paths. Images are looked up one directory level down: <root>/<subdir>/*.jpg
paintingset_url = 'Path/To/Dataset'
photoset_url = 'Path/To/Dataset'
painting_dir = pathlib.Path(paintingset_url)
photo_dir = pathlib.Path(photoset_url)
#painting_dir -> 'insert_variable_name'
#photo_dir -> 'insert_variable_name'2
# Number of .jpg files found in each set; used later to size the shuffle
# buffer and the train/validation splits.
image_count = len(list(painting_dir.glob('*/*.jpg')))
image_count2 = len(list(photo_dir.glob('*/*.jpg')))
print(image_count, image_count2)

In [None]:
# Paintings: list every file one level below painting_dir, shuffle the list once
# (reshuffle_each_iteration=False keeps the order fixed, so skip()/take() yield
# disjoint, stable subsets) and hold out 20% for validation.
list_ds = tf.data.Dataset.list_files(str(painting_dir/'*/*'), shuffle=False)
list_ds = list_ds.shuffle(image_count, reshuffle_each_iteration=False)

# Sub-directory names of the painting set, used by get_label().
class_names = np.array(sorted([item.name for item in painting_dir.glob('*') if item.name != "LICENSE.txt"]))
val_size = int(image_count * 0.2)
train_ds = list_ds.skip(val_size)
val_ds = list_ds.take(val_size)

# Photos: same procedure, but with a 50% hold-out.
list_ds2 = tf.data.Dataset.list_files(str(photo_dir/'*/*'), shuffle=False)
# The shuffle buffer must cover the photo set itself (image_count2, not the
# painting count); a smaller buffer only shuffles locally and biases the split.
list_ds2 = list_ds2.shuffle(image_count2, reshuffle_each_iteration=False)

class_names2 = np.array(sorted([item.name for item in photo_dir.glob('*') if item.name != "LICENSE.txt"]))
val_size2 = int(image_count2 * 0.5)
train_ds2 = list_ds2.skip(val_size2)
val_ds2 = list_ds2.take(val_size2)

In [None]:
def get_label(file_path):
  """Return the integer class index encoded in a file path.

  The class is the name of the directory that directly contains the file; its
  index is the position of that name inside `class_names`.

  NOTE(review): the lookup always uses `class_names` (built from the painting
  set), including when photo paths are processed. The label is discarded by the
  preprocess_* functions, so this looks harmless for now -- confirm before
  relying on the labels.
  """
  # Second-to-last path component is the class directory.
  class_dir = tf.strings.split(file_path, os.path.sep)[-2]
  # Boolean one-hot vector -> index of the matching entry.
  return tf.argmax(class_dir == class_names)

def decode_img(img):
  """Decode JPEG-compressed bytes into a 3-channel image tensor."""
  return tf.image.decode_jpeg(img, channels=3)

def process_path(file_path):
  """Load one image file and return `(decoded_image, label)`.

  The image comes from decode_img() applied to the raw file contents; the
  label comes from get_label() applied to the path.
  """
  raw_bytes = tf.io.read_file(file_path)
  return decode_img(raw_bytes), get_label(file_path)

In [None]:
# Turn each file-path dataset into a dataset of (image, label) pairs and let
# tf.data prefetch the decoded elements in the background.
train_ds = train_ds.map(
    process_path, num_parallel_calls=AUTOTUNE).prefetch(buffer_size=AUTOTUNE)
val_ds = val_ds.map(
    process_path, num_parallel_calls=AUTOTUNE).prefetch(buffer_size=AUTOTUNE)

train_ds2 = train_ds2.map(
    process_path, num_parallel_calls=AUTOTUNE).prefetch(buffer_size=AUTOTUNE)
val_ds2 = val_ds2.map(
    process_path, num_parallel_calls=AUTOTUNE).prefetch(buffer_size=AUTOTUNE)

In [None]:
def random_crop(image):
  """Return a random IMG_HEIGHT x IMG_WIDTH x 3 patch of `image`."""
  return tf.image.random_crop(image, size=[IMG_HEIGHT, IMG_WIDTH, 3])

In [None]:
# normalizing the images to [-1, 1]
def normalize(image):
  """Cast `image` to float32 and rescale it as `value / 127.5 - 1`.

  For 8-bit pixel values (0..255) this yields the range [-1, 1].
  """
  as_float = tf.cast(image, tf.float32)
  return (as_float / 127.5) - 1

In [None]:
def random_jitter(image):
  """Augment an image: resize to 286x286, random crop, random mirror.

  The image is resized with bicubic interpolation, a random patch is taken
  with random_crop(), and the patch is flipped left-right at random.
  """
  enlarged_side = 286
  enlarged = tf.image.resize(image, [enlarged_side, enlarged_side],
                             method=tf.image.ResizeMethod.BICUBIC)
  patch = random_crop(enlarged)
  return tf.image.random_flip_left_right(patch)

In [None]:
def random_jitter2(image):
  """Variant of random_jitter() that first resizes the image to 1024x1024.

  Because the crop size is unchanged, the random patch covers a much smaller
  portion of the resized image than in random_jitter().
  """
  enlarged = tf.image.resize(image, [1024, 1024],
                             method=tf.image.ResizeMethod.BICUBIC)
  patch = random_crop(enlarged)
  return tf.image.random_flip_left_right(patch)

In [None]:
def preprocess_image_train(image, label):
  """Training preprocessing: random jitter followed by normalisation.

  `label` is accepted so the function can be mapped over (image, label)
  datasets, but it is dropped; only the image is returned.
  """
  return normalize(random_jitter(image))

In [None]:
def resize_image(image):
    """Resize `image` to 256x256 using bicubic interpolation."""
    target_size = [256, 256]
    return tf.image.resize(image, target_size, method=tf.image.ResizeMethod.BICUBIC)

In [None]:
def preprocess_image_test(image, label):
  """Validation preprocessing: random jitter followed by normalisation.

  `label` is dropped; only the image is returned.

  NOTE(review): this applies the same random crop/flip as the training
  preprocessing, so validation inputs are not deterministic. resize_image()
  looks like the intended deterministic alternative -- confirm.
  """
  return normalize(random_jitter(image))

In [None]:
def preprocess_test_2(image, label):
    """Preprocess with random_jitter2() (1024x1024 resize) and then normalise.

    `label` is dropped; only the image is returned.
    """
    return normalize(random_jitter2(image))

# Dataset mapping

In [None]:
# Training sets: cache the decoded images *before* the random jitter. Caching
# after the map would store the first epoch's crops/flips and replay exactly the
# same ones every epoch, which defeats the augmentation.
# If the decoded images do not fit in memory, pass a file path to cache() or
# drop the cache() call.
train_ds = train_ds.cache().map(
    preprocess_image_train, num_parallel_calls=AUTOTUNE).shuffle(
    BUFFER_SIZE).batch(BATCH_SIZE)
train_ds2 = train_ds2.cache().map(
    preprocess_image_train, num_parallel_calls=AUTOTUNE).shuffle(
    BUFFER_SIZE).batch(BATCH_SIZE)

# Validation sets: preprocess first and cache the result, so the same fixed
# inputs are reused each time the set is iterated.
val_ds = val_ds.map(
    preprocess_image_test, num_parallel_calls=AUTOTUNE).cache().shuffle(
    BUFFER_SIZE).batch(BATCH_SIZE)
val_ds2 = val_ds2.map(
    preprocess_image_test, num_parallel_calls=AUTOTUNE).cache().shuffle(
    BUFFER_SIZE).batch(BATCH_SIZE)
print(val_ds2)


# Data Visualization

In [None]:
# Pull one batch from each training set; these are reused for the plots below
# and for the per-epoch preview in the training loop.
sample_painting = next(iter(train_ds))
sample_photo = next(iter(train_ds2))

In [None]:
# Left panel: first image of the sampled painting batch.
plt.subplot(121)
plt.title('painting')
# `* 0.5 + 0.5` maps the normalised [-1, 1] values back to [0, 1] for imshow.
plt.imshow(sample_painting[0] * 0.5 + 0.5)

# Right panel: the same image after one more pass through random_jitter().
plt.subplot(122)
plt.title('Painting with random jitter')
plt.imshow(random_jitter(sample_painting[0]) * 0.5 + 0.5)

In [None]:
# Show the first image of the sampled photo batch, rescaled from [-1, 1] to [0, 1].
plt.subplot(121)
plt.title('photo')
plt.imshow(sample_photo[0] * 0.5 + 0.5)

In [None]:
def resnet_block(num_filters, input_layer):
    """Append one residual-style block to `input_layer` and return its output.

    The block is Conv3x3 -> InstanceNorm -> ReLU -> Conv3x3 -> InstanceNorm, and
    its result is *concatenated* with the block input along the channel axis
    (not added), so the output has `num_filters` more channels than the input.

    Args:
        num_filters: number of filters of both 3x3 convolutions.
        input_layer: Keras tensor the block is applied to.
    """
    kernel_init = tf.random_normal_initializer(0., 0.02)

    def conv_then_norm(tensor):
        # 3x3 'same' convolution followed by instance normalisation.
        conv = tf.keras.layers.Conv2D(num_filters, (3,3), padding='same', kernel_initializer=kernel_init)(tensor)
        return tfa.layers.InstanceNormalization(axis=-1)(conv)

    branch = conv_then_norm(input_layer)
    branch = tf.keras.layers.Activation('relu')(branch)
    branch = conv_then_norm(branch)
    # Skip connection: join the branch with the untouched block input.
    return tf.keras.layers.Concatenate()([branch, input_layer])

def Generator(num_resnet=9):
    """Build the generator model (encoder -> residual blocks -> decoder).

    The network takes a 256x256x3 image, downsamples it with two stride-2
    convolutions, applies `num_resnet` resnet_block()s, upsamples with two
    stride-2 transposed convolutions and ends with a tanh activation.

    The first transposed convolution is named "feature_map" so that its output
    can be fetched with `model.get_layer("feature_map")`.

    Args:
        num_resnet: number of residual blocks between encoder and decoder.

    Returns:
        An uncompiled `tf.keras.Model` mapping images to images.
    """
    initializer = tf.random_normal_initializer(0., 0.02)
    inputs = tf.keras.layers.Input(shape=[256, 256, 3])

    # first encoding layer
    x = tf.keras.layers.Conv2D(64, (7, 7), padding="same", kernel_initializer=initializer)(inputs)
    x = tfa.layers.InstanceNormalization(axis=-1)(x)
    x = tf.keras.layers.Activation('relu')(x)
    # second encoding layer (stride 2: halves the spatial size)
    x = tf.keras.layers.Conv2D(128, (3, 3), strides=2, padding="same", kernel_initializer=initializer)(x)
    x = tfa.layers.InstanceNormalization(axis=-1)(x)
    x = tf.keras.layers.Activation('relu')(x)
    # third encoding layer (stride 2: halves the spatial size again)
    x = tf.keras.layers.Conv2D(256, (3, 3), strides=2, padding="same", kernel_initializer=initializer)(x)
    x = tfa.layers.InstanceNormalization(axis=-1)(x)
    x = tf.keras.layers.Activation('relu')(x)
    # resnet blocks
    for _ in range(num_resnet):
        x = resnet_block(256, x)
    # first decoding layer (named so feature extractors can look it up)
    x = tf.keras.layers.Conv2DTranspose(128, (3, 3), strides=2, name="feature_map", padding="same", kernel_initializer=initializer)(x)
    x = tfa.layers.InstanceNormalization(axis=-1)(x)
    x = tf.keras.layers.Activation('relu')(x)
    # second decoding layer
    x = tf.keras.layers.Conv2DTranspose(64, (3, 3), strides=2, padding="same", kernel_initializer=initializer)(x)
    x = tfa.layers.InstanceNormalization(axis=-1)(x)
    x = tf.keras.layers.Activation('relu')(x)
    # third decoding layer: back to 3 channels, squashed to [-1, 1] by tanh
    x = tf.keras.layers.Conv2D(3, (7, 7), padding="same", kernel_initializer=initializer)(x)
    x = tfa.layers.InstanceNormalization(axis=-1)(x)
    output_image = tf.keras.layers.Activation('tanh')(x)

    model = tf.keras.Model(inputs, output_image)
    return model

In [None]:
def Discriminator():
    """Build the patch discriminator.

    A 256x256x3 image goes through five Conv4x4 -> (InstanceNorm) -> LeakyReLU
    stages and a final 1-filter Conv4x4 with no activation, so the model
    outputs a map of raw scores rather than a single value.

    The LeakyReLU layers of the last two stages are named "attention_map" and
    "features_map" so their outputs can be fetched with `get_layer()`.
    """
    kernel_init = tf.random_normal_initializer(0., 0.02)
    image_in = tf.keras.layers.Input(shape=[256, 256, 3])

    # (filters, stride, apply instance norm?, name of the LeakyReLU layer)
    stages = [
        (64, 2, False, None),
        (128, 2, True, None),
        (256, 2, True, None),
        (512, 2, True, "attention_map"),
        (512, 1, True, "features_map"),
    ]

    hidden = image_in
    for filters, stride, use_norm, activation_name in stages:
        hidden = tf.keras.layers.Conv2D(filters, (4,4), strides=stride, padding="same", kernel_initializer=kernel_init)(hidden)
        if use_norm:
            hidden = tfa.layers.InstanceNormalization(axis=-1)(hidden)
        hidden = tf.keras.layers.LeakyReLU(alpha=0.2, name=activation_name)(hidden)

    patch_scores = tf.keras.layers.Conv2D(1, (4,4), padding="same", kernel_initializer=kernel_init)(hidden)
    return tf.keras.Model(image_in, patch_scores)
    

In [None]:
# Two generators and two discriminators, one pair per translation direction.
# In train_step: generator_g translates X -> Y, generator_f translates Y -> X;
# discriminator_x scores X images and discriminator_y scores Y images.
generator_f = Generator()
generator_g = Generator()
discriminator_y = Discriminator()
discriminator_x = Discriminator()

In [None]:
# Weight applied to the cycle, identity and feature-map losses.
LAMBDA = 10
# One Adam optimizer per network (learning rate 2e-4, beta_1 = 0.5).
generator_g_optimizer = tf.keras.optimizers.Adam(2e-4, beta_1=0.5)
generator_f_optimizer = tf.keras.optimizers.Adam(2e-4, beta_1=0.5)

discriminator_x_optimizer = tf.keras.optimizers.Adam(2e-4, beta_1=0.5)
discriminator_y_optimizer = tf.keras.optimizers.Adam(2e-4, beta_1=0.5)

In [None]:
# Sub-models that share weights with the networks above and expose an
# intermediate activation instead of the final output.

# Discriminator activations used by train_step to build the attention maps.
x_extractor = keras.Model(inputs=discriminator_x.inputs,
                        outputs=discriminator_x.get_layer('attention_map').output)

y_extractor = keras.Model(inputs=discriminator_y.inputs,
                        outputs=discriminator_y.get_layer('attention_map').output)

# Generator activations used by the feature-map loss. The generator's only
# named layer is 'feature_map' (its first decoding layer); the previous name
# 'l0' does not exist in the model and made get_layer() raise ValueError.
f_extractor = keras.Model(inputs=generator_f.inputs,
                        outputs=generator_f.get_layer('feature_map').output)

g_extractor = keras.Model(inputs=generator_g.inputs,
                        outputs=generator_g.get_layer('feature_map').output)

In [None]:
# Binary cross-entropy on raw scores (from_logits=True): the discriminators'
# final convolution has no activation.
loss_obj = tf.keras.losses.BinaryCrossentropy(from_logits=True)

In [None]:
def discriminator_loss(real, generated):
  """Discriminator loss for one batch of real and generated scores.

  Real scores are compared against a target of 0.9 (instead of 1) and
  generated scores against 0; the sum of both terms is scaled by 0.8.

  Args:
    real: discriminator output for real images.
    generated: discriminator output for generated images.
  """
  smoothed_real_target = 0.9 * tf.ones_like(real)
  fake_target = tf.zeros_like(generated)

  loss_on_real = loss_obj(smoothed_real_target, real)
  loss_on_generated = loss_obj(fake_target, generated)

  return (loss_on_real + loss_on_generated) * 0.8

In [None]:
def generator_loss(generated):
  """Adversarial loss of a generator.

  Compares the discriminator scores of generated images against a target of
  0.9, i.e. the generator is rewarded when its output is scored as real.
  """
  smoothed_target = 0.9 * tf.ones_like(generated)
  return loss_obj(smoothed_target, generated)

In [None]:
def calc_cycle_loss(real_image, cycled_image):
  """Cycle-consistency loss: LAMBDA times the mean absolute difference
  between an image and its reconstruction after a full translation cycle."""
  mean_abs_error = tf.reduce_mean(tf.abs(real_image - cycled_image))
  return LAMBDA * mean_abs_error

In [None]:
def identity_loss(real_image, same_image):
  """Identity loss: mean absolute difference between `real_image` and
  `same_image`, weighted by `LAMBDA * 0.1`."""
  mean_abs_error = tf.reduce_mean(tf.abs(real_image - same_image))
  return LAMBDA * 0.1 * mean_abs_error

In [None]:
def feature_map_loss(map_real, map_generated):
    """Feature-map loss: LAMBDA times the mean absolute difference between
    two intermediate activations."""
    mean_abs_error = tf.reduce_mean(tf.abs(map_real - map_generated))
    return LAMBDA * mean_abs_error

In [None]:
# Directory in which training checkpoints are written.
checkpoint_path = "./checkpoints/selfie2painting/train"

# Track all four networks together with their optimizers so that training can
# resume from a saved state.
ckpt = tf.train.Checkpoint(generator_g=generator_g,
                           generator_f=generator_f,
                           discriminator_x=discriminator_x,
                           discriminator_y=discriminator_y,
                           generator_g_optimizer=generator_g_optimizer,
                           generator_f_optimizer=generator_f_optimizer,
                           discriminator_x_optimizer=discriminator_x_optimizer,
                           discriminator_y_optimizer=discriminator_y_optimizer)

# Keep only the five most recent checkpoints on disk.
ckpt_manager = tf.train.CheckpointManager(ckpt, checkpoint_path, max_to_keep=5)

# if a checkpoint exists, restore the latest checkpoint.
if ckpt_manager.latest_checkpoint:
  ckpt.restore(ckpt_manager.latest_checkpoint)
  print ('Latest checkpoint restored!!')

In [None]:
# Number of passes over the zipped training sets in the training loop.
EPOCHS = 100

In [None]:
def generate_images(model, test_input):
    """Plot the first image of `test_input` next to the model's prediction.

    Args:
        model: object with a `predict` method that accepts `test_input`.
        test_input: batch of images; only element 0 is displayed.
    """
    prediction = model.predict(test_input)

    plt.figure(figsize=(12, 12))

    panels = (('Input Image', test_input[0]),
              ('Predicted Image', prediction[0]))

    for position, (caption, picture) in enumerate(panels, start=1):
        plt.subplot(1, 2, position)
        plt.title(caption)
        # bring the pixel values from [-1, 1] into [0, 1] for plotting
        plt.imshow(picture * 0.5 + 0.5)
        plt.axis('off')
    plt.show()

In [None]:
def attention_map(image, features):
    """Weight `image` by a spatial attention mask derived from `features`.

    The mask is the channel-wise sum of absolute activations, divided by its
    maximum, resized to 256x256 and multiplied into the image (the single mask
    channel broadcasts over the image channels).

    Args:
        image: batch of images with 256x256 spatial size.
        features: 4-D activation tensor; channels are summed over axis 3.

    Returns:
        `image` multiplied element-wise by the attention mask.
    """
    # Total absolute activation per spatial location (channel axis kept as size 1).
    activation_energy = tf.math.reduce_sum(tf.math.abs(features), axis=3, keepdims=True)
    # Maximum over the whole tensor, i.e. across the batch as well.
    peak = tf.math.reduce_max(activation_energy)
    # divide_no_nan returns 0 instead of NaN/inf when every activation is zero.
    normalized = tf.math.divide_no_nan(activation_energy, peak)
    attention = tf.image.resize(normalized, [256,256], method=tf.image.ResizeMethod.BILINEAR)
    return tf.math.multiply(image, attention)
    

In [None]:
@tf.function
def train_step(real_x, real_y, config_tensor):
  """Run one optimisation step of both generators and both discriminators.

  Generator G translates X -> Y and generator F translates Y -> X. Each real
  image is first weighted by an attention mask computed from its own
  discriminator's 'attention_map' activation before being translated.

  Args:
    real_x: batch of domain-X images.
    real_y: batch of domain-Y images.
    config_tensor: tensor whose element 0 is `balance`, the weight that mixes
      the pixel cycle loss (1 - balance) with the feature-map loss (balance).
      Element 1 (the cycle decay factor) is accepted but currently unused.

  Returns:
    Dict of scalar loss tensors for monitoring. Callers may ignore it.
  """
  balance = config_tensor[0]

  # persistent is set to True because the tape is used more than
  # once to calculate the gradients.
  with tf.GradientTape(persistent=True) as tape:
    # Attention masks from the discriminators' intermediate activations.
    x_a = x_extractor(real_x)
    y_a = y_extractor(real_y)

    # X -> Y -> X cycle, starting from the attended X image.
    real_x_attd = attention_map(real_x, x_a)
    fake_y = generator_g(real_x_attd, training=True)
    cycled_x = generator_f(fake_y, training=True)

    # Y -> X -> Y cycle, starting from the attended Y image.
    real_y_attd = attention_map(real_y, y_a)
    fake_x = generator_f(real_y_attd, training=True)
    cycled_y = generator_g(fake_x, training=True)

    # same_x and same_y are used for identity loss.
    same_x = generator_f(real_x, training=True)
    same_y = generator_g(real_y, training=True)

    disc_real_x = discriminator_x(real_x, training=True)
    disc_real_y = discriminator_y(real_y, training=True)

    disc_fake_x = discriminator_x(fake_x, training=True)
    disc_fake_y = discriminator_y(fake_y, training=True)

    # calculate the loss
    gen_g_loss = generator_loss(disc_fake_y)
    gen_f_loss = generator_loss(disc_fake_x)

    # Feature-map losses compare generator activations for the attended real
    # image with those for the translated image.
    real_y_fm = f_extractor(real_y_attd)
    fake_y_fm = f_extractor(fake_y)
    fm_g_loss = feature_map_loss(real_y_fm, fake_y_fm)

    real_x_fm = g_extractor(real_x_attd)
    fake_x_fm = g_extractor(fake_x)
    fm_f_loss = feature_map_loss(real_x_fm, fake_x_fm)

    total_cycle_loss = (((1-balance) * calc_cycle_loss(real_x, cycled_x) + balance * fm_f_loss) + ((1-balance) * calc_cycle_loss(real_y, cycled_y) + balance * fm_g_loss))

    # Total generator loss = adversarial + cycle + identity + feature-map loss.
    # NOTE(review): `balance` is applied twice -- once inside total_cycle_loss
    # and again here -- and each feature-map loss is counted both inside
    # total_cycle_loss and as its own term. Kept as-is; confirm this is intended.
    total_gen_g_loss = gen_g_loss + ((1-balance) * total_cycle_loss) + identity_loss(real_y, same_y) + (balance * fm_g_loss)
    total_gen_f_loss = gen_f_loss + ((1-balance) * total_cycle_loss) + identity_loss(real_x, same_x) + (balance * fm_f_loss)

    disc_x_loss = discriminator_loss(disc_real_x, disc_fake_x)
    disc_y_loss = discriminator_loss(disc_real_y, disc_fake_y)

  # Calculate the gradients for generator and discriminator
  generator_g_gradients = tape.gradient(total_gen_g_loss,
                                        generator_g.trainable_variables)
  generator_f_gradients = tape.gradient(total_gen_f_loss,
                                        generator_f.trainable_variables)

  discriminator_x_gradients = tape.gradient(disc_x_loss,
                                            discriminator_x.trainable_variables)
  discriminator_y_gradients = tape.gradient(disc_y_loss,
                                            discriminator_y.trainable_variables)

  # A persistent tape holds its resources until it is dropped explicitly.
  del tape

  # Apply the gradients to the optimizer
  generator_g_optimizer.apply_gradients(zip(generator_g_gradients,
                                            generator_g.trainable_variables))

  generator_f_optimizer.apply_gradients(zip(generator_f_gradients,
                                            generator_f.trainable_variables))

  discriminator_x_optimizer.apply_gradients(zip(discriminator_x_gradients,
                                                discriminator_x.trainable_variables))

  discriminator_y_optimizer.apply_gradients(zip(discriminator_y_gradients,
                                                discriminator_y.trainable_variables))

  return {
      'gen_g_loss': total_gen_g_loss,
      'gen_f_loss': total_gen_f_loss,
      'disc_x_loss': disc_x_loss,
      'disc_y_loss': disc_y_loss,
  }

In [None]:
for epoch in range(EPOCHS):
  start = time.time()

  # Per-epoch schedule handed to train_step as a tensor (a tensor argument
  # avoids retracing the tf.function when the values change):
  #   balance     grows from 0.2 and is capped at 0.9,
  #   cycle_decay shrinks from 1.0 and is floored at 0.2.
  balance = min(0.9, ((epoch+10)/50))
  cycle_decay = max(0.2, (1-epoch/50))
  config_tensor = tf.constant([balance, cycle_decay])

  # train_ds (paintings) is domain X, train_ds2 (photos) is domain Y.
  for n, (image_x, image_y) in enumerate(tf.data.Dataset.zip((train_ds, train_ds2))):
    train_step(image_x, image_y, config_tensor)
    if n % 10 == 0:
      print ('.', end='')

  clear_output(wait=True)
  # Using a consistent image (sample_photo) so that the progress of the model
  # is clearly visible. Photos are domain Y, so the preview must use
  # generator_f (Y -> X, photo -> painting); generator_g expects paintings.
  generate_images(generator_f, sample_photo)

  if (epoch + 1) % 5 == 0:
    ckpt_save_path = ckpt_manager.save()
    print ('Saving checkpoint for epoch {} at {}'.format(epoch+1,
                                                         ckpt_save_path))

  print ('Time taken for epoch {} is {} sec\n'.format(epoch + 1,
                                                      time.time()-start))

In [None]:
# Preview translations for ten validation photos. val_ds2 holds photos, which
# train_step treats as domain Y, so the matching generator is generator_f
# (Y -> X, photo -> painting) rather than generator_g.
for inp in val_ds2.take(10):
  generate_images(generator_f, inp)

In [None]:
# Export generator_f (the Y -> X generator in train_step) to the 'Painter' directory.
generator_f.save('./Van_Gogh_Painter_Unet3/Painter')

In [None]:
# Export generator_g (the X -> Y generator in train_step) and both discriminators.
generator_g.save('./Van_Gogh_Painter_Unet3/Photographer')
discriminator_x.save('./Van_Gogh_Painter_Unet3/dx')
discriminator_y.save('./Van_Gogh_Painter_Unet3/dy')