In [None]:
import os
import re
import pathlib
import time
import itertools
import glob
import datetime
import matplotlib.pyplot as plt
import tensorflow as tf
import pandas as pd
import numpy as np
from IPython import display
from sklearn.model_selection import train_test_split

from deep_shadow import *
from utils import *

%load_ext autoreload
%autoreload 2

## Loading dataset

In [None]:
# Experiment configuration: which tiles to load, where they live on disk, and
# where training checkpoints go. Everything a reader might tune lives here.
cities = ['la', 'bos', 'nyc', 'chi', 'sp', 'bue', 'joh', 'syd', 'tok', 'par', 'mex', 'sea', 'aus']
dates = ['winter', 'spring', 'summer']
zoom = 16  # presumably the web-map tile zoom level of the data — TODO confirm
shadow_path = 'data/shadows/'
height_path = 'data/heights/'
checkpoint_name = 'all-all'
checkpoint_path = 'training_checkpoints/%s' % checkpoint_name

TILES_PER_CITY = 270
BATCH_SIZE = 2

# Seed the global NumPy and TensorFlow RNGs so that Restart & Run All gives
# repeatable results for any code relying on them (presumably random_jitter
# and model weight initialisation — not verifiable from this notebook).
# Note: this is best-effort; some GPU kernels remain non-deterministic.
SEED = 42
np.random.seed(SEED)
tf.random.set_seed(SEED)

In [None]:
# Build the train/test datasets from the height and shadow tiles selected in the
# config cell (both are consumed below via `.take(...)`).
train_dataset, test_dataset = get_train_test(
    height_path,
    shadow_path,
    cities,
    dates,
    zoom,
    TILES_PER_CITY,
    BATCH_SIZE,
)

In [None]:
# Preview one test sample: the height input, its target shadow, and four
# random_jitter augmentations of the input.
# next(iter(...)) raises StopIteration on an empty dataset instead of silently
# re-plotting whatever `inp`/`real` an earlier run left in the kernel.
inp, real, lat, dat = next(iter(test_dataset.take(1)))
# Elements are batches (BATCH_SIZE is passed to get_train_test); keep the first.
inp, real, lat, dat = inp[0], real[0], lat[0], dat[0]

# `x * 0.5 + 0.5` undoes an assumed [-1, 1] normalisation (TODO confirm against
# utils). The extra gain on heights is presumably there only to make the height
# map bright enough to read.
HEIGHT_DISPLAY_GAIN = 10

fig, ax = plt.subplots()
ax.imshow((inp * 0.5 + 0.5) * HEIGHT_DISPLAY_GAIN)
ax.set_title('Input heights')
plt.show()

fig, ax = plt.subplots()
ax.imshow(real * 0.5 + 0.5)
ax.set_title('Target shadows')
plt.show()

fig, axes = plt.subplots(2, 2, figsize=(6, 6))
for ax in axes.flat:
    # Only the jittered input is displayed; the other returned values are unused.
    jittered_inp = random_jitter(inp, real, lat, dat)[0]
    ax.imshow((jittered_inp * 0.5 + 0.5) * HEIGHT_DISPLAY_GAIN)
fig.suptitle('random_jitter applied to the input')
plt.show()

## Build the model and plot the generator

In [None]:
# Instantiate the model (the two 512s are presumably the input image
# dimensions — TODO confirm against DeepShadow) and render the generator graph.
deep_shadow = DeepShadow(512, 512)
tf.keras.utils.plot_model(
    deep_shadow.generator,
    show_shapes=True,
    dpi=64,
)

## Plot the discriminator

In [None]:
# Render the discriminator's layer graph with tensor shapes annotated.
tf.keras.utils.plot_model(
    deep_shadow.discriminator,
    show_shapes=True,
    dpi=64,
)

## Generate images

In [None]:
# Call generate_images on each of the first five test batches.
# Note the argument order: lat and date come before the target shadows.
for batch_heights, batch_shadows, batch_lat, batch_date in test_dataset.take(5):
    generate_images(
        deep_shadow.generator,
        batch_heights,
        batch_lat,
        batch_date,
        batch_shadows,
    )

## Training

In [None]:
# Load the TensorBoard notebook extension and launch it on port 8089, reading
# from logs/ (presumably where DeepShadow.fit writes summaries — TODO confirm).
%load_ext tensorboard
%tensorboard --logdir logs/ --port 8089

In [None]:
# Train the model, passing checkpoint_path and both datasets. Whether the final
# argument counts steps or epochs is defined by DeepShadow.fit — TODO confirm.
TRAIN_LENGTH = 100000
deep_shadow.fit(checkpoint_path, train_dataset, test_dataset, TRAIN_LENGTH)