## Dependencies

In [None]:
!pip install tensorflow-addons

In [None]:
import os, random, json, PIL, shutil, re
import time
import numpy as np
import pandas as pd
import seaborn as sns
import matplotlib.pyplot as plt
from PIL import Image
import tensorflow as tf
import tensorflow.keras.layers as layers
import tensorflow_addons as tfa
from tensorflow.keras import Model, losses, optimizers

## TPU configuration

In [None]:
# Try to locate a TPU; otherwise fall back to the default distribution strategy.
try:
    tpu = tf.distribute.cluster_resolver.TPUClusterResolver()
    print(f'Running on TPU {tpu.master()}')
except ValueError:
    # Presumably raised when no TPU can be resolved in this environment.
    tpu = None

if not tpu:
    strategy = tf.distribute.get_strategy()
else:
    # The TPU system must be connected and initialised before the strategy is built.
    tf.config.experimental_connect_to_cluster(tpu)
    tf.tpu.experimental.initialize_tpu_system(tpu)
    strategy = tf.distribute.experimental.TPUStrategy(tpu)

# Let tf.data pick parallelism / buffer sizes dynamically.
AUTO = tf.data.experimental.AUTOTUNE

# Number of devices a global batch is split across.
REPLICAS = strategy.num_replicas_in_sync
print(f'REPLICAS: {REPLICAS}')

# Model parameters

In [None]:
# Image geometry; decode_image reshapes every decoded JPEG to exactly this shape.
HEIGHT, WIDTH, CHANNELS = 256, 256, 3
# Run settings.
EPOCHS, BATCH_SIZE = 50, 32

# Load data

In [None]:
# Resolve the GCS bucket holding the competition TFRecords.
try:
    # kaggle_datasets is presumably only importable inside a Kaggle kernel.
    from kaggle_datasets import KaggleDatasets
    GCS_PATH = KaggleDatasets().get_gcs_path('gan-getting-started')
except Exception:
    # Not a bare `except:` so KeyboardInterrupt/SystemExit still propagate.
    # NOTE(review): hard-coded bucket from an earlier session; kds-* paths may expire — confirm.
    GCS_PATH = "gs://kds-2ee06126c50f46e241ae426668de2fce51b526beb231d382373580fa"

MONET_FILENAMES = tf.io.gfile.glob(GCS_PATH + '/monet_tfrec/*.tfrec')
PHOTO_FILENAMES = tf.io.gfile.glob(GCS_PATH + '/photo_tfrec/*.tfrec')

# An empty glob would otherwise yield empty datasets and silently produce no submission images.
if not MONET_FILENAMES:
    raise FileNotFoundError(f'No Monet TFRecords found under {GCS_PATH}/monet_tfrec')
if not PHOTO_FILENAMES:
    raise FileNotFoundError(f'No photo TFRecords found under {GCS_PATH}/photo_tfrec')


## Auxiliary functions

In [None]:
def decode_image(image):
    """Decode a JPEG byte string into a float32 image tensor scaled to [-1, 1].

    The result is reshaped to [HEIGHT, WIDTH, CHANNELS]; the reshape errors if
    the decoded pixel count does not match that shape.
    """
    pixels = tf.image.decode_jpeg(image, channels=CHANNELS)
    # Map the 0..255 pixel range onto [-1, 1].
    scaled = tf.cast(pixels, tf.float32) / 127.5 - 1
    return tf.reshape(scaled, [HEIGHT, WIDTH, CHANNELS])

def read_tfrecord(example):
    """Parse one serialized tf.train.Example and return its decoded image.

    The spec requires 'image_name', 'image' and 'target' as scalar strings;
    only 'image' is used.
    """
    feature_spec = {
        name: tf.io.FixedLenFeature([], tf.string)
        for name in ('image_name', 'image', 'target')
    }
    parsed = tf.io.parse_single_example(example, feature_spec)
    return decode_image(parsed['image'])

def load_dataset(filenames):
    """Return a tf.data.Dataset of decoded images read from the given TFRecord files."""
    return tf.data.TFRecordDataset(filenames).map(read_tfrecord, num_parallel_calls=AUTO)

## Generating datasets

In [None]:

# Batch-of-one datasets, used for per-image visualisation.
monet_ds = load_dataset(MONET_FILENAMES).batch(1)
photo_ds = load_dataset(PHOTO_FILENAMES).batch(1)

# Bulk-inference batch: BATCH_SIZE images for each replica of the strategy.
GLOBAL_BATCH_SIZE = BATCH_SIZE * REPLICAS
# Photo subset size for the fid_* datasets (presumably FID evaluation, per the name).
FID_PHOTO_SAMPLES = 1024

fast_photo_ds = load_dataset(PHOTO_FILENAMES).batch(GLOBAL_BATCH_SIZE).prefetch(AUTO)
fid_photo_ds = (
    load_dataset(PHOTO_FILENAMES)
    .take(FID_PHOTO_SAMPLES)
    .batch(GLOBAL_BATCH_SIZE)
    .prefetch(AUTO)
)
fid_monet_ds = load_dataset(MONET_FILENAMES).batch(GLOBAL_BATCH_SIZE).prefetch(AUTO)

# Visualize predictions

In [None]:
# Generator checkpoint to evaluate.
# Alternative checkpoint: '../input/d0eb4546-f602-4898-9e06-29a68eddf64a/model.h5'
model_path = '../input/300-pics-model/model.h5'

# Only inference is done here, so the model is loaded uncompiled.
loaded_model = tf.keras.models.load_model(model_path, compile=False)

# Plot `row * col` (input, generated) pairs; each grid row holds `col` pairs.
row = 4
col = 2
fig, axes = plt.subplots(row, col * 2, figsize=(24, 24))
# axes.reshape(-1, 2) pairs neighbouring subplots in row-major order: (input, generated).
for (input_ax, output_ax), example_sample in zip(axes.reshape(-1, 2), photo_ds.take(row * col)):
    # training=False matches the submission cell and keeps any Dropout/normalisation in inference mode.
    generated_sample = loaded_model(example_sample, training=False)
    panels = (
        (input_ax, 'Input image', example_sample[0]),
        (output_ax, 'Generated image', generated_sample[0]),
    )
    for ax, title, image in panels:
        ax.set_title(title)
        # Undo the [-1, 1] scaling so imshow receives floats in [0, 1].
        ax.imshow(image * 0.5 + 0.5)
        ax.axis('off')
plt.show()

## Submission

In [None]:
# Output folders for generated and original JPEGs. exist_ok keeps the cell safe to re-run;
# `!mkdir` fails if the folder already exists.
os.makedirs("../images", exist_ok=True)
os.makedirs("../original_images", exist_ok=True)

In [None]:
def to_uint8(images):
    """Convert images scaled to [-1, 1] back to uint8 pixels in [0, 255].

    Rounds to the nearest integer before casting. A plain astype truncates, so float
    error (e.g. 254.9999) would drop a pixel by one. Values are clipped so anything
    out of range saturates instead of wrapping around.
    """
    return np.clip(np.rint(images * 127.5 + 127.5), 0, 255).astype(np.uint8)


# Translate every photo and save each output as ../images/<n>.jpg, numbered from 1.
n_generated = 0
for photo_batch in fast_photo_ds:
    predictions = to_uint8(loaded_model(photo_batch, training=False).numpy())
    for pred in predictions:
        n_generated += 1
        Image.fromarray(pred).save(f"../images/{n_generated}.jpg")
print(f'{n_generated} prediction images created')

# Save the inputs with matching numbering. This assumes fast_photo_ds yields photos in the
# same order on every pass (TFRecordDataset + map are deterministic by default).
n_original = 0
for n_original, photo in enumerate(fast_photo_ds.unbatch(), start=1):
    Image.fromarray(to_uint8(photo.numpy())).save(f"../original_images/{n_original}.jpg")
print(f'{n_original} original images created')

In [None]:
# Zip both output folders into /kaggle/working. base_name must not end with '/':
# make_archive appends '.zip' to it, so "images/" would produce "images/.zip" instead of "images.zip".
shutil.make_archive("/kaggle/working/images", 'zip', "../images")
shutil.make_archive("/kaggle/working/original_images", 'zip', "../original_images")