In [None]:
# Widen the notebook container to the full browser width.
# `display`/`HTML` are imported from the public `IPython.display` module; importing
# `display` from `IPython.core.display` is deprecated.
from IPython.display import display, HTML
display(HTML("<style>.container { width:100% !important; }</style>"))

In [None]:
import tensorflow as tf
# Let TensorFlow allocate GPU memory on demand instead of reserving it all up front.
# This must run before any GPU is initialized. It is applied to every visible GPU
# because TF requires the memory-growth setting to match across devices. On a
# CPU-only machine the list is empty and this is a no-op (indexing [0] would raise).
physical_devices = tf.config.list_physical_devices('GPU')
for gpu in physical_devices:
    tf.config.experimental.set_memory_growth(gpu, True)

In [None]:
from tensorflow.keras.layers import Conv2D, Conv2DTranspose, BatchNormalization

activation = "tanh"

# Encoder: stacked stride-2 3x3 convolutions, each followed by batch norm.
encoder_filters = (16, 32, 64, 64, 128, 256)
encoder_base = tf.keras.Sequential([
    layer
    for filters in encoder_filters
    for layer in (
        Conv2D(filters, 3, 2, activation=activation),
        BatchNormalization(),
    )
])

# decoder_input_shape is fine-tuned so that the decoder output (and therefore the
# image size fed to the model) comes out approximately at the size we want.
decoder_input_shape = (5, 3, 256)  # HWC
# Decoder: batch norm + stride-2 3x3 transposed convolution per stage; the final
# stage maps to image channels with a sigmoid.
decoder_filters = (128, 64, 32, 16)
decoder_base = tf.keras.Sequential([
    *(
        layer
        for filters in decoder_filters
        for layer in (
            BatchNormalization(),
            Conv2DTranspose(filters, 3, 2, activation=activation),
        )
    ),
    BatchNormalization(),
    Conv2DTranspose(3, 3, 2, activation="sigmoid"),
])

# The images the pipeline must produce are whatever the decoder outputs for a
# 4-D (batch, H, W, C) input built from decoder_input_shape.
HEIGHT, WIDTH, CHANNELS = decoder_base.compute_output_shape((None, *decoder_input_shape))[1:]
input_shape = HEIGHT, WIDTH, CHANNELS
input_shape

In [None]:
from cvae import CVAE
# Assemble the project's CVAE (from cvae.py) around the encoder/decoder stacks
# and shapes defined in the previous cell.
model = CVAE(
    latent_dim=256,
    input_shape=input_shape,
    decoder_input_shape=decoder_input_shape,
    encoder_base=encoder_base,
    decoder_base=decoder_base,
)

In [None]:
# Alternative: plain pixel-wise reconstruction loss (swap in to compare):
# loss = tf.keras.losses.MeanSquaredError(name="MSE")

from perceptive_loss import PerceptionLoss
# Project-local loss parameterized by the image shape; presumably a perceptual
# (feature-space) loss — confirm in perceptive_loss.py.
loss = PerceptionLoss(input_shape)

In [None]:
from data import find_splitting_timestamp, make_dataset

sqlite_path = '../fast_images.db'
batch_size = 64
# Time-based split point: the training pipeline reads up to `split_ts` (end_ts)
# and the test pipeline reads from it (begin_ts). 0.7 is presumably the fraction
# of data assigned to training — confirm in data.py.
split_ts = find_splitting_timestamp(sqlite_path, 0.7)

def to_float(ts, image):
    """Dataset map fn: pass `ts` through unchanged and convert `image` to float32.

    `tf.image.convert_image_dtype` rescales integer images into [0, 1]; the stored
    images are presumably uint8 — confirm in data.py.
    """
    with tf.device("cpu"):
        image_f32 = tf.image.convert_image_dtype(image, "float32")
    return ts, image_f32

def _make_split(**ts_bounds):
    """Build the batched float32 image pipeline for one time-based split.

    `ts_bounds` is forwarded to `make_dataset` (`end_ts=` for train, `begin_ts=`
    for test), so both splits are guaranteed to be built the same way.
    """
    return (
        make_dataset(sqlite_path, input_shape, shuffle=True, **ts_bounds)
        .map(to_float, num_parallel_calls=tf.data.experimental.AUTOTUNE)
        .batch(batch_size)
        # Prepare the next batches while the current one is being trained on.
        .prefetch(tf.data.experimental.AUTOTUNE)
    )

train_dataset = _make_split(end_ts=split_ts)
test_dataset = _make_split(begin_ts=split_ts)

In [None]:
# Optimizer / loss scaling. The large loss weight presumably brings the
# perception loss into a useful magnitude for Adam — confirm empirically.
learning_rate = 1e-4
loss_weight = 1e4
model.compile(
    optimizer=tf.keras.optimizers.Adam(learning_rate=learning_rate),
    loss=loss,
    loss_weights=loss_weight,
)

In [None]:

import matplotlib.pyplot as plt

def ae_map(ts, image):
    """Turn a (timestamp, image) element into an autoencoder (input, target) pair.

    The timestamp is dropped and the target is the input image itself.
    """
    target = image
    return image, target

# Train as an autoencoder on (image, image) pairs. The History is kept so the
# loss/val_loss curves can be inspected afterwards (and its repr is not dumped
# as the cell output).
history = model.fit(
    train_dataset.map(ae_map),
    epochs=1000,
    validation_data=test_dataset.map(ae_map),
    validation_freq=1,
)

In [None]:
import numpy as np
import matplotlib.pyplot as plt

# First image of one test batch; `[:1]` keeps the batch axis for the model.
sample = next(iter(test_dataset))[1][:1]
predictions = model(sample)
mean, logvar = model.encoder(sample)
# Drop the batch axis and convert everything to NumPy for plotting.
predictions, sample, mean, logvar = [
    x[0].numpy() for x in [predictions, sample, mean, logvar]
]

# Pixel-wise mean squared reconstruction error for this image.
print(np.mean(np.square(predictions - sample)))

# One figure per row of panels; plt.subplots creates the figure itself, so no
# separate plt.figure() call (which left an extra empty figure) is needed.
fig, (ax_mean, ax_logvar) = plt.subplots(ncols=2)
ax_mean.hist(mean)
ax_mean.set_title("mean")
ax_logvar.hist(logvar)
ax_logvar.set_title("logvar")
plt.show()

fig, (ax_input, ax_recon) = plt.subplots(ncols=2)
ax_input.imshow(sample)
ax_input.set_title("input")
ax_recon.imshow(predictions)
ax_recon.set_title("reconstruction")
plt.show()