[![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/zanussbaum/pitch2mocap/blob/master/train.ipynb)

In [None]:
import os

if not os.path.isdir('pitch2mocap'):    
    !git clone https://github.com/zanussbaum/pitch2mocap.git

%cd pitch2mocap 
!pip -q install boto3

In [None]:
import datetime
import tensorflow as tf
import boto3 as boto
import matplotlib.pyplot as plt
import zipfile
import glob
from pix2pix import Pix2Pix
%load_ext autoreload
%autoreload 2

In [None]:
# Prefer credentials exported in the environment (standard AWS variable names)
# so real keys never have to be typed into, and saved with, the notebook.
# The string literals are placeholders only: replace them or set the variables.
ACCESS_KEY = os.environ.get("AWS_ACCESS_KEY_ID", "access")
SECRET_KEY = os.environ.get("AWS_SECRET_ACCESS_KEY", "secret")
SESSION_TOKEN = os.environ.get("AWS_SESSION_TOKEN", "token")

In [None]:
# Build an S3 client that authenticates with the credentials defined above.
client = boto.client(
    's3',
    aws_access_key_id=ACCESS_KEY,
    aws_secret_access_key=SECRET_KEY,
    aws_session_token=SESSION_TOKEN
)

# Download through `client` so the explicit credentials are actually used.
# Creating a separate boto.resource('s3') here would ignore them and fall back
# to whatever default credentials the environment happens to provide.
client.download_file("motion-capture-data", "data.zip", "/tmp/data.zip")

In [None]:
# Unpack the downloaded archive into /tmp, next to the zip file itself.
with zipfile.ZipFile("/tmp/data.zip", "r") as archive:
    archive.extractall(path="/tmp")

In [None]:
# glob.glob already returns a list, so it can be sorted directly.
# NOTE: the sort is lexicographic ("source_10" sorts before "source_2"); the
# pairing below stays aligned only if source and target files share the same
# numbering scheme -- verify against the contents of data.zip.
source_imgs = sorted(glob.glob("/tmp/source*.png"))
target_imgs = sorted(glob.glob("/tmp/target*.png"))

# zip() silently truncates to the shorter list, which would hide missing files
# and could pair images that do not belong together.
if len(source_imgs) != len(target_imgs):
    raise ValueError(
        "Found %d source images but %d target images in /tmp"
        % (len(source_imgs), len(target_imgs))
    )

# One (source_path, target_path) tuple per training example.
all_imgs = list(zip(source_imgs, target_imgs))

In [None]:
# Settings shared by the data pipeline.
PATH = "/tmp/"                 # directory the archive was extracted into
BUFFER_SIZE = 400              # basis for the training shuffle buffer
BATCH_SIZE = 1                 # examples per batch
IMG_HEIGHT = IMG_WIDTH = 256   # crop/resize size used for the images

In [None]:
def load(image_file):
  """Read an image file from disk and return it as a float32 tensor.

  Args:
    image_file: path of the image to read (Python string or scalar string
      tensor).

  Returns:
    A float32 tensor with 3 channels. Pixel values are only cast, not
    rescaled, so they keep the decoded range.
  """
  image = tf.io.read_file(image_file)
  # The dataset consists of .png files, so decode them as PNG. Forcing three
  # channels drops any alpha channel and matches the 3-channel crop size that
  # random_crop() requests.
  image = tf.image.decode_png(image, channels=3)

  input_image = tf.cast(image, tf.float32)

  return input_image

In [None]:
# Load one example pair to sanity-check the data.
src = load(f"{PATH}source_100.png")
tar = load(f"{PATH}target_100.png")

# load() only casts to float32 without rescaling, so divide by 255 to get the
# [0, 1] range that imshow expects for float images.
plt.figure()
plt.imshow(src / 255.0)
plt.figure()
plt.imshow(tar / 255.0)

In [None]:
def resize(input_image, real_image, height, width):
  """Resize both images of a pair to the same spatial size.

  Args:
    input_image: source image tensor.
    real_image: target image tensor.
    height: output height in pixels.
    width: output width in pixels.

  Returns:
    (input_image, real_image), each resized with nearest-neighbor sampling.
  """
  target_size = [height, width]
  nearest = tf.image.ResizeMethod.NEAREST_NEIGHBOR

  resized_input = tf.image.resize(input_image, target_size, method=nearest)
  resized_real = tf.image.resize(real_image, target_size, method=nearest)

  return resized_input, resized_real

In [None]:
def random_crop(input_image, real_image):
  """Take one random IMG_HEIGHT x IMG_WIDTH crop from both images of a pair.

  The two images are stacked and cropped with a single call, so the crop is
  drawn once for the stacked pair rather than separately per image.

  Returns:
    (cropped_input, cropped_real).
  """
  pair = tf.stack([input_image, real_image], axis=0)
  crop_shape = [2, IMG_HEIGHT, IMG_WIDTH, 3]
  cropped_pair = tf.image.random_crop(pair, size=crop_shape)

  return cropped_pair[0], cropped_pair[1]

In [None]:
def normalize(input_image, real_image):
  """Rescale both images with x / 127.5 - 1.

  For pixel values in [0, 255] this maps them to [-1, 1].

  Returns:
    (input_image, real_image) after rescaling.
  """
  half_range = 127.5
  return (input_image / half_range) - 1, (real_image / half_range) - 1

In [None]:
@tf.function()
def random_jitter(input_image, real_image):
  """Apply the training-time augmentation to an image pair.

  Steps: enlarge both images to 286 x 286, take one random
  IMG_HEIGHT x IMG_WIDTH crop of the pair, then mirror both images
  left-right when a uniform random draw exceeds 0.5.

  Returns:
    (input_image, real_image) after augmentation.
  """
  # Enlarge first so the crop below has room to move around.
  enlarged_input, enlarged_real = resize(input_image, real_image, 286, 286)

  # Crop back down to the working size.
  input_image, real_image = random_crop(enlarged_input, enlarged_real)

  # Flip both images together so the pair stays aligned.
  mirror = tf.random.uniform(()) > 0.5
  if mirror:
    input_image = tf.image.flip_left_right(input_image)
    real_image = tf.image.flip_left_right(real_image)

  return input_image, real_image

In [None]:
def load_image_train(image_file):
  """Load one (source, target) path pair and prepare it for training.

  Args:
    image_file: indexable pair; element 0 is the source image path and
      element 1 the target image path.

  Returns:
    (input_image, target_image) after random jitter and normalization.
  """
  source_path = image_file[0]
  target_path = image_file[1]

  jittered_pair = random_jitter(load(source_path), load(target_path))

  return normalize(*jittered_pair)

In [None]:
def load_image_test(image_file):
  """Load one (source, target) path pair for evaluation.

  Unlike the training loader, no random augmentation is applied: the images
  are only resized to IMG_HEIGHT x IMG_WIDTH and normalized.

  Args:
    image_file: indexable pair; element 0 is the source image path and
      element 1 the target image path.

  Returns:
    (input_image, target_image) after resizing and normalization.
  """
  source_path = image_file[0]
  target_path = image_file[1]

  resized_pair = resize(load(source_path), load(target_path),
                        IMG_HEIGHT, IMG_WIDTH)

  return normalize(*resized_pair)

In [None]:
def generate_images(model, test_input, tar):
    """Plot input, ground truth and model prediction side by side.

    Only the first element of each batch is shown.

    Args:
        model: callable invoked as model(test_input, training=True).
        test_input: batch of input images.
        tar: batch of ground-truth images.
    """
    # NOTE(review): training=True is kept as in the original -- presumably
    # intentional so batch statistics are used at inference; confirm.
    prediction = model(test_input, training=True)
    plt.figure(figsize=(15,15))

    panels = (
        ('Input Image', test_input[0]),
        ('Ground Truth', tar[0]),
        ('Predicted Image', prediction[0]),
    )

    for position, (caption, image) in enumerate(panels, start=1):
        plt.subplot(1, 3, position)
        plt.title(caption)
        # Shift the pixel values into [0, 1] for plotting.
        plt.imshow(image * 0.5 + 0.5)
        plt.axis('off')
    plt.show()

In [None]:
train_val_dataset = tf.data.Dataset.from_tensor_slices(all_imgs)

# 75% / 25% split: out of every (split + 1) consecutive pairs, the first
# `split` go to training and the remaining one goes to validation.
split = 3

# Training: windows of `split` elements starting every (split + 1) elements,
# flattened back into a stream, augmented, shuffled and batched.
train_dataset = (
    train_val_dataset
    .window(split, split+1)
    .flat_map(lambda window: window)
    .map(load_image_train,
         num_parallel_calls=tf.data.experimental.AUTOTUNE)
    .shuffle(int(BUFFER_SIZE*.75))
    .batch(BATCH_SIZE)
)

# Validation: skip the first training window, then take single-element
# windows at the same (split + 1) spacing.
val_dataset = (
    train_val_dataset
    .skip(split)
    .window(1, split+1)
    .flat_map(lambda window: window)
    .map(load_image_test)
    .batch(BATCH_SIZE)
)

In [None]:
# Instantiate the Pix2Pix model imported from the pix2pix module.
gan = Pix2Pix()

In [None]:
# Generator and discriminator get separate Adam optimizers with identical
# settings.
adam_settings = dict(learning_rate=2e-4, beta_1=0.5)
generator_optimizer = tf.keras.optimizers.Adam(**adam_settings)
discriminator_optimizer = tf.keras.optimizers.Adam(**adam_settings)

# from_logits=True: the loss applies the sigmoid itself, so it must be fed raw
# scores rather than probabilities.
loss_object = tf.keras.losses.BinaryCrossentropy(from_logits=True)

In [None]:
# Hand the optimizers and the loss to the model.
# NOTE(review): the discriminator optimizer is passed first -- confirm this
# matches the parameter order of Pix2Pix.compile.
gan.compile(discriminator_optimizer, generator_optimizer, loss_object)

In [None]:
EPOCHS = 150
# Exact number of training pairs produced by the window-based split: one pair
# out of every (split + 1) goes to validation, the rest to training.
# int(len(all_imgs) * .75) undercounts by one whenever the total is not a
# multiple of (split + 1).
# NOTE(review): assumes Pix2Pix.fit expects `size` to be the training-set size.
SIZE = len(all_imgs) - len(all_imgs) // (split + 1)

In [None]:
# Show the computed training-set size as the cell output.
SIZE

In [None]:
# Load the TensorBoard notebook extension and open it on the "logs/" directory.
# NOTE(review): presumably gan.fit writes its summaries there -- confirm in
# pix2pix.py.
%load_ext tensorboard
%tensorboard --logdir "logs/"

In [None]:
# Train on train_dataset for EPOCHS epochs, passing val_dataset alongside.
# NOTE(review): `size` is assumed to be the number of training examples --
# confirm against Pix2Pix.fit.
gan.fit(train_dataset, epochs=EPOCHS, size=SIZE, verbose=True, val_ds=val_dataset)

In [None]:
# Plot input / ground truth / prediction for the first five validation batches.
# Note: this loop rebinds the notebook-level name `tar`.
for inp, tar in val_dataset.take(5):
    generate_images(gan.generator, inp, tar)