In [None]:
# Experiment identifier. It is joined onto env.EXPERIMENT_ROOT to locate the
# experiment directory and is also passed to Experiment(...) in a later cell.
EXPERIMENT = 'aris-c/land/s3'
# Optional: uncomment to pin a specific run. While TIMESTAMP is undefined, a
# later cell falls back to the greatest entry name found in <experiment>/ckpt.
# TIMESTAMP = '0731-1633'


In [None]:
import sys
import os
import json
# Put the current working directory on the import path; presumably this is
# what makes the project-local imports below (environment, experiment,
# dataset, toolbox, models) resolvable — the notebook must then be started
# from the project root.
sys.path.append(os.getcwd())

from environment import Environment

# Project environment object; later cells read env.EXPERIMENT_ROOT from it.
# NOTE(review): it is instantiated before TensorFlow is imported — this may be
# deliberate (e.g. configuration that must happen first); confirm before
# reordering these statements.
env = Environment()

import tensorflow as tf

import importlib
import pandas as pd

from experiment import Experiment
from dataset import Reader
import toolbox as tbx

# NOTE(review): json, pandas and models.metrics are not referenced by the
# cells visible in this notebook.
from models import metrics


In [None]:
# Directory that holds all runs of the selected experiment.
experiment_root = os.path.join(env.EXPERIMENT_ROOT, EXPERIMENT)

# TIMESTAMP is an optional setting from the configuration cell. Only the
# "name is not defined" case is treated as "not configured"; a bare `except`
# would also swallow KeyboardInterrupt / SystemExit.
try:
    timestamp = TIMESTAMP
except NameError:
    timestamp = None

if timestamp is None:
    # No run pinned (TIMESTAMP undefined or explicitly None): use the greatest
    # entry name under <experiment>/ckpt. Names are compared as strings.
    ckpt_root = os.path.join(experiment_root, 'ckpt')
    available_runs = os.listdir(ckpt_root)
    if not available_runs:
        raise FileNotFoundError(
            f'No runs found in {ckpt_root!r}; set TIMESTAMP explicitly.')
    timestamp = max(available_runs)

exp = Experiment(experiment_root, EXPERIMENT, timestamp, restore=True)


In [None]:
# Select the module in <PROJECT_ROOT>/models that exp.MODEL_NAME refers to and
# build its GAN. Matching is done on module names *before* importing, so only
# the selected module is imported (unrelated or broken model files no longer
# run at import time, and the package's __init__.py is not re-imported as a
# submodule).
model_name = exp.MODEL_NAME
if not model_name:
    raise ValueError('exp.MODEL_NAME is empty; cannot select a model module.')

models_dir = os.path.join(exp.PROJECT_ROOT, 'models')
# Sorted because os.listdir returns entries in arbitrary order.
candidate_names = sorted(
    'models.' + filename[:-3]
    for filename in os.listdir(models_dir)
    if filename.endswith('.py') and filename != '__init__.py'
)

# A candidate matches when its qualified name ends with MODEL_NAME (the same
# suffix rule as before). An exact match wins over longer names that merely
# share the suffix; several inexact matches are an error rather than a silent
# "last one wins".
suffix_matches = [name for name in candidate_names if name.endswith(model_name)]
exact_name = 'models.' + model_name
if exact_name in suffix_matches:
    selected_name = exact_name
elif len(suffix_matches) == 1:
    selected_name = suffix_matches[0]
elif not suffix_matches:
    raise ModuleNotFoundError(
        f'No module in {models_dir!r} matches MODEL_NAME {model_name!r}; '
        f'candidates: {candidate_names}')
else:
    raise ValueError(
        f'MODEL_NAME {model_name!r} is ambiguous; it matches {suffix_matches}')

module = importlib.import_module(
    '.' + selected_name[len('models.'):], package='models')
model = module.GAN(exp)

# Annotated aliases for the two networks exposed by the model object.
generator:tf.keras.Model = model.generator
discriminator:tf.keras.Model = model.discriminator


In [None]:
# Build the dataset reader for this experiment and expose both splits.
# NOTE(review): the meaning of the second Reader argument is not visible in
# this notebook — confirm against dataset.Reader.
dataset_reader = Reader(exp, 'create_sample_images.py')
train_dataset, test_dataset = (
    dataset_reader.train_dataset,
    dataset_reader.test_dataset,
)


In [None]:
# Checkpoint object tracking both networks, both optimizers and the step
# counter, so that their saved values can be restored together.
checkpoint = tf.train.Checkpoint(
    generator_optimizer=model.generator_optimizer,
    discriminator_optimizer=model.discriminator_optimizer,
    generator=generator,
    discriminator=discriminator,
    step=tf.Variable(0, dtype=tf.int64))

ckpt_dir = os.fspath(exp.output.CKPT)
print('Trying to restore: ' + ckpt_dir)

# latest_checkpoint() returns None when the directory holds no checkpoint, and
# restore(None) would then silently leave the freshly initialised weights in
# place. Fail loudly instead of producing images from an untrained generator.
latest_checkpoint = tf.train.latest_checkpoint(ckpt_dir)
if latest_checkpoint is None:
    raise FileNotFoundError(f'No checkpoint found in {ckpt_dir!r}.')

# expect_partial(): silences warnings about checkpoint values that are not
# consumed by the objects restored here.
checkpoint.restore(latest_checkpoint).expect_partial()
stepoffset = int(checkpoint.step)
print("Loaded checkpoint:", latest_checkpoint)
print("Continue at step:", stepoffset)


In [None]:
from matplotlib import pyplot as plt

# Change to color profile S3 for Sentinel-3 input images
# Change to color profile S2 for Sentinel-2 input images
input_profile = tbx.RGBProfile.S3
# Upper bound on the number of examples (rows) drawn per figure.
max_images_per_figure = 5

# One PNG per test batch is written to <experiment>/test/<step>.png.
output_dir = os.path.join(env.EXPERIMENT_ROOT, EXPERIMENT, 'test/')
os.makedirs(output_dir, exist_ok=True)

for step, (example_target, example_input) in test_dataset.enumerate():
    # Never draw more rows than the batch holds: a batch with fewer than
    # max_images_per_figure examples (e.g. a final partial batch) would
    # otherwise be indexed out of bounds.
    # Assumes the leading axis of example_input is the batch axis, as the
    # per-example indexing below already does.
    num_images = min(max_images_per_figure, int(example_input.shape[0]))
    if num_images == 0:
        continue

    output_path = os.path.join(output_dir, f'{int(step)}.png')
    print(output_path)

    prediction = generator(example_input, training=False)

    # squeeze=False keeps `axes` two-dimensional even for a single row, so no
    # special case is needed when num_images == 1.
    fig, axes = plt.subplots(
        num_images, 3, figsize=(15, 5 * num_images), squeeze=False)

    for i in range(num_images):
        # Columns: input | ground truth | prediction.
        panels = (
            (example_input[i], input_profile, 'Input Image'),
            (example_target[i], tbx.RGBProfile.S2, 'Ground Truth'),
            (prediction[i], tbx.RGBProfile.S2, 'Predicted Image'),
        )
        for panel_ax, (tensor, profile, title) in zip(axes[i], panels):
            tbx.plot_tensor(tensor, profile, ax=panel_ax)
            panel_ax.set_title(title)
            panel_ax.axis('off')

    fig.tight_layout()
    fig.savefig(output_path)
    # Close this figure explicitly so figures do not accumulate across batches.
    plt.close(fig)
