In [None]:
# This Python 3 environment comes with many helpful analytics libraries installed
# It is defined by the kaggle/python Docker image: https://github.com/kaggle/docker-python
# For example, here are several helpful packages to load

# import numpy as np # linear algebra
# import pandas as pd # data processing, CSV file I/O (e.g. pd.read_csv)

# import tensorflow as tf
# import matplotlib.pyplot as plt
# # Input data files are available in the read-only "../input/" directory
# # For example, running this (by clicking run or pressing Shift+Enter) will list all files under the input directory

# import os
# for dirname, _, filenames in os.walk('/kaggle/input'):
#     for filename in filenames:
#         print(os.path.join(dirname, filename))

# You can write up to 20GB to the current directory (/kaggle/working/) that gets preserved as output when you create a version using "Save & Run All" 
# You can also write temporary files to /kaggle/temp/, but they won't be saved outside of the current session

In [None]:
pip install -q git+https://github.com/tensorflow/examples.git

In [None]:
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import tensorflow as tf
import tensorflow_datasets as tfds
import math
import random
import re
import cv2
from tensorflow import keras
from tensorflow.keras import layers
from tensorflow.keras.models import Sequential
from tensorflow_examples.models.pix2pix import pix2pix
import os
import PIL
import PIL.Image

import pathlib

In [None]:
# Locations of the competition data: the two JPEG folders and one Monet TFRecord shard.
photo_dir = '../input/gan-getting-started/photo_jpg/'
monet_dir = '../input/gan-getting-started/monet_jpg/'
monet_tfrec = '../input/gan-getting-started/monet_tfrec/monet00-60.tfrec'
# Batch size and shuffle-buffer size used when the input pipelines are built below.
BATCH_SIZE=1
BUFFER_SIZE=1000

In [None]:
# Collect every TFRecord shard for each domain; the Monet list is printed as a quick check.
monet_names = tf.io.gfile.glob(r"../input/gan-getting-started/monet_tfrec/*.tfrec")
print(monet_names)
photo_names = tf.io.gfile.glob(r"../input/gan-getting-started/photo_tfrec/*.tfrec")

In [None]:
def prepare_image(img, dim = 256):
    """Decode a JPEG byte string into a float32 image scaled to [-1, 1].

    Args:
        img: scalar string tensor holding JPEG-encoded bytes.
        dim: side length of the (square) image; the decoded tensor is reshaped
            to ``[dim, dim, 3]``.

    Returns:
        A float32 tensor of shape ``[dim, dim, 3]``.
    """
    decoded = tf.image.decode_jpeg(img, channels=3)
    # uint8 [0, 255] -> float [-1, 1]
    scaled = (tf.cast(decoded, tf.float32) * 2 / 255.0) - 1
    return tf.reshape(scaled, [dim, dim, 3])

def read_tfrecord(example):
    """Parse one serialized TFRecord example and return its decoded image.

    The record must contain the scalar string features ``image``,
    ``image_name`` and ``target``; only ``image`` is used.

    Args:
        example: a serialized ``tf.train.Example``.

    Returns:
        The image tensor produced by ``prepare_image``.
    """
    feature_spec = {
        key: tf.io.FixedLenFeature([], tf.string)
        for key in ('image', 'image_name', "target")
    }
    parsed = tf.io.parse_single_example(example, feature_spec)
    return prepare_image(parsed['image'])

In [None]:
def load_dataset(filenames, labeled=True, ordered=False):
    """Build a dataset of decoded images from TFRecord files.

    Args:
        filenames: TFRecord path(s) to read.
        labeled: accepted for interface compatibility; not used.
        ordered: accepted for interface compatibility; not used.

    Returns:
        A ``tf.data.Dataset`` whose elements are the tensors returned by
        ``read_tfrecord``.
    """
    records = tf.data.TFRecordDataset(filenames)
    return records.map(
        read_tfrecord,
        num_parallel_calls=tf.data.experimental.AUTOTUNE,
    )

In [None]:
# Unbatched image datasets for both domains (the `labeled` flag is ignored by load_dataset).
monet_ds = load_dataset(monet_names, labeled=True)
photo_ds = load_dataset(photo_names, labeled=True)

In [None]:
# Augmentation pipeline: a random horizontal flip is the only transformation.
data_augmentation = tf.keras.Sequential([
  layers.experimental.preprocessing.RandomFlip("horizontal"),
])

In [None]:
# JPEG folders, used by batch_visualization to preview raw images.
base_path = '../input/gan-getting-started/'
monet_path = os.path.join(base_path, 'monet_jpg')
photo_path = os.path.join(base_path, 'photo_jpg')

In [None]:
def batch_visualization(path, n_images, is_random=True, figsize=(16, 16)):
    """Display image files from a directory in a near-square grid.

    Args:
        path: directory containing the image files.
        n_images: number of images to show; capped at the number of files
            found in ``path``.
        is_random: if True, pick the files at random; otherwise take the first
            ``n_images`` entries returned by ``os.listdir``.
        figsize: size of the matplotlib figure.

    Raises:
        ValueError: if ``n_images`` is smaller than 1.
    """
    # Validate before touching matplotlib so a bad argument does not leave an
    # empty figure behind (and 0 would otherwise divide by zero below).
    if n_images < 1:
        raise ValueError(
            "n_images must be a positive integer, got {!r}".format(n_images))

    all_names = os.listdir(path)
    # random.sample raises when asked for more items than exist, so cap the request.
    n_images = min(n_images, len(all_names))
    if n_images == 0:
        print('No files found in {}'.format(path))
        return

    if is_random:
        image_names = random.sample(all_names, n_images)
    else:
        image_names = all_names[:n_images]

    # Grid layout: roughly square, with enough rows to hold every image.
    w = int(n_images ** .5)
    h = math.ceil(n_images / w)

    plt.figure(figsize=figsize)
    for ind, image_name in enumerate(image_names):
        img = cv2.imread(os.path.join(path, image_name))
        if img is None:
            # cv2.imread returns None for files it cannot decode; leave the slot empty.
            continue
        # OpenCV loads BGR; matplotlib expects RGB.
        img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
        plt.subplot(h, w, ind + 1)
        plt.imshow(img)
        plt.axis('off')

    plt.show()

In [None]:
# Preview three randomly chosen Monet JPEGs (is_random defaults to True).
batch_visualization(monet_path,3)

In [None]:
# Preview the augmentation: each of the first three Monet images fills one row,
# and each column is an independent pass through data_augmentation.
figsize=(16, 16)
n=3
# Create a single figure for the whole grid.
plt.figure(figsize=figsize)
for j, raw_record in enumerate(monet_ds.take(3)):
    for i in range(n):
        # The dataset is unbatched here, so add a batch axis for the Keras layer;
        # training=True is passed explicitly, as in the training pipelines.
        img = data_augmentation(tf.expand_dims(raw_record, 0), training=True)
        plt.subplot(3, n, j*n+i +1)
        # Images are stored in [-1, 1]; rescale to [0, 1] for imshow.
        plt.imshow(img[0]/2+1/2)
plt.show()

In [None]:
# Show the dataset's repr as the cell output.
monet_ds

In [None]:
# Shuffle and batch each domain. take(0) yields an empty dataset and skip(0)
# skips nothing, so the test_* datasets are empty and the train_* datasets
# contain every record.
test_monet = monet_ds.take(0).shuffle(BUFFER_SIZE).batch(BATCH_SIZE) 
train_monet = monet_ds.skip(0).shuffle(BUFFER_SIZE).batch(BATCH_SIZE) 

test_photo = photo_ds.take(0).shuffle(BUFFER_SIZE).batch(BATCH_SIZE,)  
train_photo = photo_ds.skip(0).shuffle(BUFFER_SIZE).batch(BATCH_SIZE) 

In [None]:
# Apply the augmentation pipeline to every training batch of both domains.
# NOTE: this rebinds train_monet / train_photo, so re-running the cell stacks
# another augmentation map on top of the previous one.
train_monet = train_monet.map(
  lambda x: (data_augmentation(x, training=True)))

train_photo = train_photo.map(
  lambda x: (data_augmentation(x, training=True)))

In [None]:
# Show the first two decoded Monet images side by side.
plt.figure()
for idx, raw_record in enumerate(monet_ds.take(2)):
    # Give each image its own axes; drawing both on one axes hides the first.
    plt.subplot(1, 2, idx + 1)
    # prepare_image scales pixels to [-1, 1]; map them back to [0, 1] so that
    # imshow does not clip the upper half of the range.
    plt.imshow(raw_record.numpy() / 2 + 1/2)
    plt.axis('off')
plt.show()

In [None]:
# One batch from each training dataset, reused for the previews below and for
# the per-epoch progress image.
sample_monet = next(iter(train_monet))
sample_photo = next(iter(train_photo))

In [None]:
# Number of channels the generators output.
OUTPUT_CHANNELS = 3

# Generators are named after the domain they produce: train_step feeds photos
# to generator_monet and Monet paintings to generator_photo.
generator_monet = pix2pix.unet_generator(OUTPUT_CHANNELS, norm_type='instancenorm')
generator_photo = pix2pix.unet_generator(OUTPUT_CHANNELS, norm_type='instancenorm')

# Discriminators are named after the domain they judge.
discriminator_monet = pix2pix.discriminator(norm_type='instancenorm', target=False)
discriminator_photo = pix2pix.discriminator(norm_type='instancenorm', target=False)

In [None]:
# Sanity-check the generators on one sample from each domain.
# Use the same directions as train_step: generator_photo turns a Monet into a
# photo, generator_monet turns a photo into a Monet.
to_photo = generator_photo(sample_monet)
to_monet = generator_monet(sample_photo)
plt.figure(figsize=(8, 8))
# Multiplier applied to generated images only; raise it to make faint outputs visible.
contrast = 1

imgs = [sample_monet, to_photo, sample_photo, to_monet]
title = ['MONET', 'TO PHOTO', 'PHOTO', 'TO MONET']

for i in range(len(imgs)):
    plt.subplot(2, 2, i+1)
    plt.title(title[i])
    if i % 2 == 0:
        # Real samples: rescale from [-1, 1] to [0, 1].
        plt.imshow(imgs[i][0]/2 + 1/2)
    else:
        # Generated samples: same rescale, with the optional contrast boost.
        plt.imshow(imgs[i][0]/2  * contrast + 1/2)
plt.show()

In [None]:
# Plot the last output channel of each discriminator for its own domain's sample.
plt.figure(figsize=(8, 8))

panels = [
    (121, 'Is a real monet?', discriminator_monet, sample_monet),
    (122, 'Is a real photo?', discriminator_photo, sample_photo),
]
for position, heading, discriminator, sample in panels:
    plt.subplot(position)
    plt.title(heading)
    plt.imshow(discriminator(sample)[0, ..., -1], cmap='RdBu_r')

plt.show()

In [None]:
# Weight of the cycle-consistency loss; the identity loss uses half of it.
LAMBDA = 10

In [None]:
# Shared adversarial loss; from_logits=True means predictions are treated as raw logits.
loss_obj = tf.keras.losses.BinaryCrossentropy(from_logits=True)

In [None]:
def discriminator_loss(real, generated):
    """Adversarial loss for a discriminator.

    Args:
        real: discriminator output for real images (target: all ones).
        generated: discriminator output for generated images (target: all zeros).

    Returns:
        Half the sum of the two binary cross-entropy terms.
    """
    loss_on_real = loss_obj(tf.ones_like(real), real)
    loss_on_generated = loss_obj(tf.zeros_like(generated), generated)
    return (loss_on_real + loss_on_generated) * 0.5

In [None]:
def generator_loss(generated):
    """Adversarial loss for a generator.

    Args:
        generated: discriminator output for the generator's images.

    Returns:
        Binary cross-entropy of ``generated`` against an all-ones target.
    """
    all_real = tf.ones_like(generated)
    return loss_obj(all_real, generated)

In [None]:
def calc_cycle_loss(real_image, cycled_image):
    """Cycle-consistency loss: ``LAMBDA`` times the mean absolute difference.

    Args:
        real_image: the original image.
        cycled_image: the image after a round trip through both generators.

    Returns:
        A scalar tensor.
    """
    mean_abs_diff = tf.reduce_mean(tf.abs(real_image - cycled_image))
    return LAMBDA * mean_abs_diff

In [None]:
def identity_loss(real_image, same_image):
    """Identity loss: ``LAMBDA * 0.5`` times the mean absolute difference.

    Args:
        real_image: an image already in the generator's output domain.
        same_image: the generator's output for ``real_image``.

    Returns:
        A scalar tensor.
    """
    mean_abs_diff = tf.reduce_mean(tf.abs(real_image - same_image))
    return LAMBDA * 0.5 * mean_abs_diff

In [None]:
# One Adam optimizer per network, all with learning rate 2e-4 and beta_1 = 0.5.
generator_monet_optimizer = tf.keras.optimizers.Adam(2e-4, beta_1=0.5)
generator_photo_optimizer = tf.keras.optimizers.Adam(2e-4, beta_1=0.5)

discriminator_monet_optimizer = tf.keras.optimizers.Adam(2e-4, beta_1=0.5)
discriminator_photo_optimizer = tf.keras.optimizers.Adam(2e-4, beta_1=0.5)

In [None]:
# Checkpoints are written to the current working directory.
checkpoint_path = "./"

# Track all four networks and their optimizers so training can be resumed.
ckpt = tf.train.Checkpoint(generator_monet=generator_monet,
                           generator_photo=generator_photo,
                           discriminator_monet=discriminator_monet,
                           discriminator_photo=discriminator_photo,
                           generator_monet_optimizer=generator_monet_optimizer,
                           generator_photo_optimizer=generator_photo_optimizer,
                           discriminator_monet_optimizer=discriminator_monet_optimizer,
                           discriminator_photo_optimizer=discriminator_photo_optimizer)

# Keep only the five most recent checkpoints.
ckpt_manager = tf.train.CheckpointManager(ckpt, checkpoint_path, max_to_keep=5)

# if a checkpoint exists, restore the latest checkpoint.
if ckpt_manager.latest_checkpoint:
    ckpt.restore(ckpt_manager.latest_checkpoint)
    print ('Latest checkpoint restored!!')

In [None]:
# Number of epochs the training loop runs.
EPOCHS = 200
def generate_images(model, test_input):
    """Plot the first input image next to the model's output for it.

    Args:
        model: callable applied to ``test_input``.
        test_input: batch of images; only element 0 is displayed.
    """
    prediction = model(test_input)
    plt.figure(figsize=(12, 12))

    panels = (
        ('Input Image', test_input[0]),
        ('Predicted Image', prediction[0]),
    )
    for position, (heading, image) in enumerate(panels, start=1):
        plt.subplot(1, 2, position)
        plt.title(heading)
        # Rescale the pixel values into [0, 1] for plotting.
        plt.imshow(image * 1/2 + 1/2)
        plt.axis('off')
    plt.show()

In [None]:
@tf.function
def train_step(real_photo, real_monet):
    """Run one optimisation step of all four networks on one pair of batches.

    Args:
        real_photo: batch of real photos.
        real_monet: batch of real Monet paintings.

    Returns:
        None. The generators and discriminators are updated in place through
        their optimizers.
    """
    # persistent=True because the tape is queried four times below, once per network.
    with tf.GradientTape(persistent=True) as tape:
        # photo -> Monet -> photo cycle
        fake_monet = generator_monet(real_photo, training=True)
        cycled_photo = generator_photo(fake_monet, training=True)

        # Monet -> photo -> Monet cycle
        fake_photo = generator_photo(real_monet, training=True)
        cycled_monet = generator_monet(fake_photo, training=True)

        # same_photo and same_monet are used for the identity loss: each
        # generator is fed an image that is already in its output domain.
        same_photo = generator_photo(real_photo, training=True)
        same_monet = generator_monet(real_monet, training=True)

        disc_real_photo = discriminator_photo(real_photo, training=True)
        disc_real_monet = discriminator_monet(real_monet, training=True)

        disc_fake_photo = discriminator_photo(fake_photo, training=True)
        disc_fake_monet = discriminator_monet(fake_monet, training=True)

        # calculate the loss
        gen_monet_loss = generator_loss(disc_fake_monet)
        gen_photo_loss = generator_loss(disc_fake_photo)

        # Both cycles contribute to the loss of both generators.
        total_cycle_loss = calc_cycle_loss(real_photo, cycled_photo) + calc_cycle_loss(real_monet, cycled_monet)

        # Total generator loss = adversarial loss + cycle loss + identity loss
        total_gen_monet_loss = gen_monet_loss + total_cycle_loss + identity_loss(real_monet, same_monet)
        total_gen_photo_loss = gen_photo_loss + total_cycle_loss + identity_loss(real_photo, same_photo)

        disc_photo_loss = discriminator_loss(disc_real_photo, disc_fake_photo)
        disc_monet_loss = discriminator_loss(disc_real_monet, disc_fake_monet)

    # Calculate the gradients for generator and discriminator
    generator_monet_gradients = tape.gradient(total_gen_monet_loss, 
                                        generator_monet.trainable_variables)
    generator_photo_gradients = tape.gradient(total_gen_photo_loss, 
                                        generator_photo.trainable_variables)

    discriminator_photo_gradients = tape.gradient(disc_photo_loss, 
                                            discriminator_photo.trainable_variables)
    discriminator_monet_gradients = tape.gradient(disc_monet_loss, 
                                            discriminator_monet.trainable_variables)

    # Apply the gradients to the optimizer
    generator_monet_optimizer.apply_gradients(zip(generator_monet_gradients, 
                                            generator_monet.trainable_variables))

    generator_photo_optimizer.apply_gradients(zip(generator_photo_gradients, 
                                            generator_photo.trainable_variables))

    discriminator_photo_optimizer.apply_gradients(zip(discriminator_photo_gradients,
                                                discriminator_photo.trainable_variables))

    discriminator_monet_optimizer.apply_gradients(zip(discriminator_monet_gradients,
                                                discriminator_monet.trainable_variables))

In [None]:
# `time` is only used to report the wall-clock duration of each epoch.
import time
for epoch in range(EPOCHS):
    start = time.time()

    n = 0
    # Dataset.zip pairs one photo batch with one Monet batch and stops as soon
    # as the shorter of the two datasets is exhausted.
    for image_photo, image_monet in tf.data.Dataset.zip((train_photo, train_monet)):
        train_step(image_photo, image_monet)
        # Print a progress dot every 10 steps.
        if n % 10 == 0:
            print ('.', end='')
        n += 1

    #clear_output(wait=True)
    # Using a consistent image (sample_photo) so that the progress of the model
    # is clearly visible.
    generate_images(generator_monet, sample_photo)

    # Save a checkpoint every 5 epochs.
    if (epoch + 1) % 5 == 0:
        ckpt_save_path = ckpt_manager.save()
        print ('Saving checkpoint for epoch {} at {}'.format(epoch+1,
                                                             ckpt_save_path))

    print ('Time taken for epoch {} is {} sec\n'.format(epoch + 1,
                                                      time.time()-start))

In [None]:
# # Run the trained model on the test dataset
# for inp in test_photo.take(20):
#     generate_images(generator_monet, inp)

In [None]:
# Run the trained photo -> Monet generator on 20 batches of the training photos
# (the test_* datasets are empty, so the training data is used for the preview).
for inp in train_photo.take(20):
    generate_images(generator_monet, inp)

In [None]:
# for inp in test_monet.take(5):
#     generate_images(generator_photo, inp)

In [None]:
# for inp in train_monet.take(5):
#     generate_images(generator_photo, inp)

In [None]:
# for inp in test_photo.take(5):
#     generate_images(generator_photo, inp)

In [None]:
# for inp in train_photo.take(5):
#     generate_images(generator_photo, inp)

In [None]:
# for inp in test_monet.take(5):
#     generate_images(generator_monet, inp)

In [None]:
# for inp in train_monet.take(5):
#     generate_images(generator_monet, inp)

In [None]:
# os.remove("./submission")

In [None]:
# Create the output folder for the generated images. exist_ok=True keeps the
# cell re-runnable: os.mkdir raises FileExistsError when the folder already exists.
os.makedirs("../images", exist_ok=True)

In [None]:
# Translate photos into Monet-style JPEGs for the submission archive.
N_SUBMISSION_IMAGES = 7000  # number of photos to translate
savepath="../images/"
# Make sure the target folder exists so this cell can run on its own.
os.makedirs(savepath, exist_ok=True)

for n, image in enumerate(photo_ds.take(N_SUBMISSION_IMAGES), start=1):
    # photo_ds is unbatched: add the batch axis and run the generator in inference mode.
    image = generator_monet(tf.expand_dims(image, 0), training=False)
    # Map to the [0, 255] pixel scale, then clip and round before the uint8 cast
    # so out-of-range values cannot wrap around and values are not truncated.
    image = tf.clip_by_value((image[0] * 1/2 + 1/2) * 255, 0, 255)
    pixels = tf.cast(tf.round(image), tf.uint8).numpy()
    # A uint8 height x width x 3 array is interpreted as an RGB image.
    img = PIL.Image.fromarray(pixels)
    name = savepath+'image_'+str(n)+'.jpg'
    img.save(name)
    # Progress report every 300 images.
    if n % 300 == 0:
        print(n)

In [None]:
import shutil
# make_archive appends ".zip" to base_name, so base_name must not end with a
# slash: with one, the result depends on whether the path gets normalised and
# can end up as "/kaggle/working/images/.zip" instead of "/kaggle/working/images.zip".
shutil.make_archive('/kaggle/working/images', 'zip', '../images')

In [None]:
# n=1
# img = PIL.Image.fromarray(data, 'RGB')
# name='image'+str(n)+'.jpg',
# img.save(name)
# img.show()