# pix2pix for Map-to-Aerial Image Translation

In [None]:
import tensorflow as tf
from tensorflow.keras.optimizers import Adam
from tensorflow.keras.utils import split_dataset
import matplotlib.pyplot as plt
import numpy as np
import pixutils as pxu
from pix2pix import UNet, PatchGAN, fit

In [None]:
# Glob patterns for the paired map/aerial JPEGs, relative to the dataset root.
train_path = "pix2pix-maps/train/*.jpg"
val_path = "pix2pix-maps/val/*.jpg"

# Number of elements held in the training shuffle buffer.
buffer_size = 256
# Side length that validation images are resized to (also the model input size).
resize_to = 256
# Side length of the square patches cut from each training image.
patch_size = 256
# NOTE(review): side length of one image of a pair. Presumably the maps dataset
# stores 600x600 map/aerial pairs side by side; confirm against the data.
image_size = 600
# Patches per image: a grid of (image_size // patch_size) patches per side.
# Floor division means a partial patch along an edge is not counted.
num_of_patches = (image_size // patch_size) ** 2

config = {
    "batch_size": 1,
}

## Run the following cell only on Kaggle (it prefixes the dataset paths with `/kaggle/input/`)

In [None]:
# Point the dataset globs at Kaggle's input mount. The prefix is added only
# when missing, so re-running this cell in the same kernel does not produce
# "/kaggle/input//kaggle/input/..." paths that match no files.
kaggle_root = "/kaggle/input/"
if not train_path.startswith(kaggle_root):
    train_path = f"{kaggle_root}{train_path}"
if not val_path.startswith(kaggle_root):
    val_path = f"{kaggle_root}{val_path}"

## Load dataset

In [None]:
# Dataset.list_files shuffles the matched filenames by default. Keep that for
# training, but list validation files in a deterministic order so every pass
# (and every sample visualization) sees the validation images in the same
# order.
train = tf.data.Dataset.list_files(train_path)
val = tf.data.Dataset.list_files(val_path, shuffle=False)

In [None]:
# --- Training pipeline ------------------------------------------------------
# Each file yields an (input_image, real_image) pair. After extract_patches,
# each element is a group of `num_of_patches` patches of side `patch_size`.
# Parallel map() keeps element order by default, so AUTOTUNE changes
# throughput only, not the sequence of examples.
train = train.map(pxu.load_image, num_parallel_calls=tf.data.AUTOTUNE)
train = train.map(
    lambda input_image, real_image: pxu.extract_patches(
        input_image, real_image, patch_size, num_of_patches
    ),
    num_parallel_calls=tf.data.AUTOTUNE,
)
# NOTE(review): 286 is the jitter size given to pxu.random_jitter. Presumably
# the patches are upscaled to 286 and randomly cropped back to patch_size;
# confirm in pixutils.
train = train.map(
    lambda input_patches, real_patches: pxu.random_jitter(
        input_patches, real_patches, 286
    ),
    num_parallel_calls=tf.data.AUTOTUNE,
)
train = train.map(pxu.rescale_images, num_parallel_calls=tf.data.AUTOTUNE)
# The shuffle runs before unbatch(), so each shuffled element is one image's
# whole group of `num_of_patches` patches. Patches from the same image
# therefore stay adjacent after unbatch().
train = train.shuffle(buffer_size)
train = train.unbatch()
train = train.batch(config["batch_size"])
# Prepare the next batch while the current training step runs.
train = train.prefetch(tf.data.AUTOTUNE)

# --- Validation pipeline ----------------------------------------------------
# Whole images resized to `resize_to`, with no patching or augmentation.
val = val.map(pxu.load_image, num_parallel_calls=tf.data.AUTOTUNE)
val = val.map(
    lambda input_image, real_image: pxu.resize_images(
        input_image, real_image, resize_to
    ),
    num_parallel_calls=tf.data.AUTOTUNE,
)
val = val.map(pxu.rescale_images, num_parallel_calls=tf.data.AUTOTUNE)

# Optional, currently disabled: split the validation files into equal
# validation and test halves with keras.utils.split_dataset.
#size = int(val.cardinality())
#val_size = int(0.5 * size)
#test_size = size - val_size

#val, test = split_dataset(val, left_size=val_size, right_size=test_size, shuffle=True)

val = val.batch(config["batch_size"])
#test = test.batch(config["batch_size"])
val = val.prefetch(tf.data.AUTOTUNE)

## Visualize a few images

In [None]:
# Show the first two training batches. Each element is a batch, so [0] picks
# its first example. Unlike reshaping to shape[1:], indexing also works when
# config["batch_size"] > 1.
for index, (input_image, real_image) in enumerate(train.take(2), start=1):
    pxu.show(
        input_image[0],
        real_image[0],
        index,
        "train"
    )

In [None]:
# Show the first two validation batches. Each element is a batch, so [0]
# picks its first example. Unlike reshaping to shape[1:], indexing also works
# when config["batch_size"] > 1.
for index, (input_image, real_image) in enumerate(val.take(2), start=1):
    pxu.show(
        input_image[0],
        real_image[0],
        index,
        "val"
    )

## Create models

In [None]:
# Both networks are constructed with a fixed input_shape of side `resize_to`,
# but training examples are patches of side `patch_size`. Fail fast with a
# clear message rather than with a shape error in the middle of training.
if patch_size != resize_to:
    raise ValueError(
        f"patch_size ({patch_size}) must equal resize_to ({resize_to}): "
        "G and D are built with input_shape=(resize_to, resize_to, 3)."
    )

G = UNet(input_shape=(resize_to, resize_to, 3))
D = PatchGAN(input_shape=(resize_to, resize_to, 3))

# Generator and discriminator use separate Adam optimizers with identical
# hyperparameters (lr = 2e-4, beta_1 = 0.5).
g_optim = Adam(learning_rate=0.0002, beta_1=0.5)
d_optim = Adam(learning_rate=0.0002, beta_1=0.5)

## Train

In [None]:
# Train the GAN. The argument order is defined by pix2pix.fit, which is not
# shown here.
# NOTE(review): the integers 150 and 100 are presumably the epoch count and
# the L1-loss weight (lambda); confirm against pix2pix.fit.
fit(
    train,
    val,
    150,
    G,
    D,
    g_optim,
    d_optim,
    100,
)