In [None]:
# Read the name of the execution environment from the first line of env.txt.
# strip() removes the line ending (and stray whitespace such as '\r') without
# cutting off the last character when the file has no trailing newline.
with open('env.txt') as f:
    ENVIRONMENT = f.readline().strip()
print(f'running on environment: "{ENVIRONMENT}"')

# Fail early on an unknown or empty name: later cells branch on ENVIRONMENT.
assert ENVIRONMENT in ['blaze', 'colab', 'local', 'cpom'], \
    f'unsupported environment "{ENVIRONMENT}" in env.txt'


In [None]:
# GPU selection for the 'blaze' environment only; skipped everywhere else.
if ENVIRONMENT == 'blaze':

    import subprocess
    import os

    # Run the CUDA visibility script in a csh subprocess and capture its output.
    command = 'source /usr/local/cuda/CUDA_VISIBILITY.csh'
    process = subprocess.Popen(command, shell=True, executable="/bin/csh", stdout=subprocess.PIPE, stderr=subprocess.PIPE)
    stdout, stderr = process.communicate()

    # The second-to-last character of the script output is used as the device id.
    # NOTE(review): assumes the output ends with a single-digit id followed by
    # a newline -- confirm for multi-digit ids or output without a final newline.
    os.environ['CUDA_VISIBLE_DEVICES'] = stdout.decode()[-2]
    # os.environ['CUDA_HOME'] = '/opt/cuda/cuda-10.0'

    print(stdout.decode())

    # NOTE(review): this script is sourced in a separate csh process; presumably
    # any variables it sets do not reach this kernel -- verify the intended effect.
    # stderr is captured for both commands but never inspected.
    command = 'source /server/opt/cuda/enable_cuda_11.0'
    process = subprocess.Popen(command, shell=True, executable="/bin/csh", stdout=subprocess.PIPE, stderr=subprocess.PIPE)
    stdout, stderr = process.communicate()

    # Echo the value that was written to os.environ above.
    !echo $CUDA_VISIBLE_DEVICES



In [None]:
# Colab only: mount Google Drive and make the project folder stored there
# importable, so the project-local modules can be found by later imports.
if ENVIRONMENT == 'colab':
    import sys
    from google.colab import drive

    drive.mount('/content/drive')
    sys.path.append('/content/drive/MyDrive/sis2/')
    

In [None]:
import datetime
import glob
import os
import random
import time

import numpy as np
import tensorflow as tf
from matplotlib import pyplot as plt

import sis_helper as helper
from dataset import reader
from dataset.reader import Reader
from models import pix2pix
from sis_helper import RGBProfile as rgb



In [None]:
from tensorflow.python.client import device_lib

def get_available_gpus():
    """Return the names of the local devices that TensorFlow reports as GPUs.

    Queries `device_lib.list_local_devices()` and keeps only the entries whose
    `device_type` is 'GPU'. The previous filter (`if True`) let every device
    through, CPU included, which contradicted the function name.

    Returns:
        list[str]: device names such as '/device:GPU:0'; empty when no GPU
        is visible.
    """
    local_device_protos = device_lib.list_local_devices()
    return [x.name for x in local_device_protos if x.device_type == 'GPU']

# Display the list returned by get_available_gpus() as the cell output.
get_available_gpus()

In [None]:
# Display every physical device TensorFlow can see (no device-type filter).
tf.config.list_physical_devices()

In [None]:
# Print how many physical devices of type 'GPU' TensorFlow reports.
print("Num GPUs Available: ", len(tf.config.list_physical_devices('GPU')))

In [None]:
# Ask TensorFlow for its GPU device name while a TF1-compatibility session is
# open; an empty string means no GPU device was reported.
with tf.compat.v1.Session() as sess:
    device_name = tf.test.gpu_device_name()
    if device_name == '':
        print('TensorFlow is not using GPU')
    else:
        print('TensorFlow is using GPU:', device_name)

In [None]:
# --- Image geometry ----------------------------------------------------------
TILESIZE = 256
IMG_WIDTH = 256
IMG_HEIGHT = 256

# Alternative, larger configuration kept for reference:
# TILESIZE = 960
# IMG_WIDTH = 1024
# IMG_HEIGHT = 1024

# Channel counts handed to the pix2pix model.
INPUT_CHANNELS = 21
OUTPUT_CHANNELS = 3

# --- Data locations ----------------------------------------------------------
# Data root per execution environment. The fallback is only used for a name
# missing from the table.
_DATA_ROOTS = {
    'blaze': '/cs/student/msc/aisd/2022/cboehm/projects/li1_data/',
    'colab': '/content/drive/MyDrive/sis2/data/',
    'local': '/Users/christianboehm/projects/sis2/data/',
    'cpom': '/home/cb/sis2/data/',
}
PATH_PREFIX = _DATA_ROOTS.get(ENVIRONMENT, '~/projects/sis2/data')

# NOTE(review): training and validation currently point at the same
# tfrecords directory -- confirm this is intended.
PATH_TRAIN = os.path.join(PATH_PREFIX, f'tfrecords{TILESIZE}/')
PATH_VAL = os.path.join(PATH_PREFIX, f'tfrecords{TILESIZE}/')
PATH_LOGS = os.path.join(PATH_PREFIX, 'logs/')
PATH_CKPT = os.path.join(PATH_PREFIX, 'checkpoints/')

# --- Training hyper-parameters -----------------------------------------------
# BUFFER_SIZE equals the number of entries found in the training directory.
number_of_files = len(glob.glob(os.path.join(PATH_TRAIN, '*')))
BUFFER_SIZE = number_of_files
# BUFFER_SIZE = 1077

# The original pix2pix experiment reported better U-Net results with a batch
# size of 1; this notebook uses 10.
BATCH_SIZE = 10
# Passed to pix2pix.Model; presumably the L1 loss weight -- confirm in models/pix2pix.
LAMBDA = 100


In [None]:
# Build the pix2pix model from the configuration constants; it also receives
# the log and checkpoint directories.
model = pix2pix.Model(IMG_WIDTH, IMG_HEIGHT, INPUT_CHANNELS, OUTPUT_CHANNELS, LAMBDA, PATH_LOGS, PATH_CKPT)


In [None]:
# import importlib
# importlib.reload(reader)

# Build the train/test input pipelines. The Reader class is used through the
# name bound by `from dataset.reader import Reader` in the import cell, so this
# cell does not depend on a separate module alias being defined.
dataset_reader = Reader(TILESIZE, IMG_HEIGHT, IMG_WIDTH, PATH_TRAIN, PATH_VAL, BUFFER_SIZE, BATCH_SIZE)
train_dataset = dataset_reader.train_dataset
test_dataset = dataset_reader.test_dataset


In [None]:
# def normalize_tensor(input_image, real_image):
#     return tf.nn.l2_normalize(input_image), tf.nn.l2_normalize(real_image)


In [None]:
# Pick one random entry from the validation directory and open it as a
# TFRecord dataset.
sample_file = random.choice(os.listdir(PATH_VAL))
sample_dataset = tf.data.TFRecordDataset(os.path.join(PATH_VAL, sample_file))

# Plot every record side by side and parse it into its two tensors. After the
# loop, s2_tensor / s3_tensor hold the last record; later cells reuse them.
for element in sample_dataset:
    helper.plot_tensor_sbs(element, TILESIZE)

    s2_tensor, s3_tensor = helper.parse_tfrecord(element, TILESIZE)
    # helper.plot_tensor(s2_tensor, rgb.S2)
    # helper.plot_tensor(s3_tensor, rgb.S3)


In [None]:
# def resize(image1, image2, height, width):
#     image1 = tf.image.resize(image1, [height, width],
#                              method=tf.image.ResizeMethod.NEAREST_NEIGHBOR)
#     image2 = tf.image.resize(image2, [height, width],
#                              method=tf.image.ResizeMethod.NEAREST_NEIGHBOR)

#     return image1, image2


In [None]:

# Resize the sampled tensor pair to IMG_HEIGHT x IMG_WIDTH with the reader
# module's resize(); the names are rebound to the resized tensors.
s2_tensor, s3_tensor = reader.resize(s2_tensor, s3_tensor, IMG_HEIGHT, IMG_WIDTH)


In [None]:
# def random_crop(s2_image, s3_image):
#     #TODO: Crop not to 3channels only!
#     stacked_image = tf.concat([s2_image, s3_image], axis=2)
#     cropped_image = tf.image.random_crop(stacked_image, size=[IMG_HEIGHT, IMG_WIDTH, 24])
    
#     return cropped_image[:,:,:3], cropped_image[:,:,3:]
#     # return resize(input_image, real_image, IMG_HEIGHT, IMG_WIDTH)

In [None]:
# # Normalizing the images to [-1, 1]
# def normalize(input_image, real_image):
#   # TODO: reflect which normalization is needed... divide by 1024????
#   input_image = (input_image / 127.5) - 1
#   real_image = (real_image / 127.5) - 1

#   return input_image, real_image

In [None]:
# @tf.function()
# def random_jitter(s2_image, s3_image):
#     # Resizing to 286x286
#     s2_image, s3_image = resize(s2_image, s3_image, int(IMG_HEIGHT * 1.11), int(IMG_WIDTH * 1.11))
    
#     # Random cropping back to 256x256
#     s2_image, s3_image = random_crop(s2_image, s3_image)
    
#     if tf.random.uniform(()) > 0.5:
#         # Random mirroring
#         s2_image = tf.image.flip_left_right(s2_image)
#         s3_image = tf.image.flip_left_right(s3_image)
        
#     return s2_image, s3_image


In [None]:
# def load_image_train(tfrecord):
#     s2_image, s3_image = helper.parse_tfrecord(tfrecord, TILESIZE)
#     s2_image, s3_image = resize(s2_image, s3_image, IMG_HEIGHT, IMG_WIDTH)
#     s2_image, s3_image = random_jitter(s2_image, s3_image)
#     s2_image, s3_image = normalize_tensor(s2_image, s3_image)
    
#     return s2_image, s3_image


In [None]:
# def load_image_test(image_file):
#     s2_image, s3_image = helper.parse_tfrecord(image_file, TILESIZE)
#     s2_image, s3_image = resize(s2_image, s3_image, IMG_HEIGHT, IMG_WIDTH)
#     s2_image, s3_image = normalize_tensor(s2_image, s3_image)
    
#     return s2_image, s3_image


In [None]:
# train_file_list = [os.path.join(PATH_TRAIN, file) for file in os.listdir(PATH_TRAIN) if file.endswith('.tfrecord')]
# train_dataset = tf.data.TFRecordDataset(train_file_list)

# # train_dataset = tf.data.Dataset.list_files(str(f'{PATH_TRAIN}/*.tfrecords'))
# train_dataset = train_dataset.map(load_image_train,
#                                   num_parallel_calls=tf.data.AUTOTUNE)
# train_dataset = train_dataset.shuffle(BUFFER_SIZE)
# train_dataset = train_dataset.batch(BATCH_SIZE)

In [None]:
# test_file_list = [os.path.join(PATH_VAL, file) for file in os.listdir(PATH_VAL) if file.endswith('.tfrecord')]

# try:
#     test_dataset = tf.data.TFRecordDataset(test_file_list)
# except tf.errors.InvalidArgumentError:
#     test_dataset = tf.data.TFRecordDataset(train_file_list)
# test_dataset = test_dataset.map(load_image_test)
# #TODO: check if shuffling is helpful (added for validation)
# test_dataset = test_dataset.shuffle(BUFFER_SIZE)
# test_dataset = test_dataset.batch(BATCH_SIZE)


In [None]:
# # down_model = downsample(64, 4)
# down_result = downsample(64, 4)(tf.expand_dims(s3_tensor, 0))
# down_result = downsample(128, 4)(down_result)
# print (down_result.shape)

In [None]:
# up_model = upsample(21, 4)
# up_result = up_model(down_result)
# print (up_result.shape)

In [None]:
# Take the generator network from the model and draw its layer graph,
# including tensor shapes.
generator = model.generator
tf.keras.utils.plot_model(generator, show_shapes=True, dpi=64)


In [None]:
# Run the generator on the sampled s3_tensor (a leading batch axis is added)
# with training=False, then plot the first output using the S2 RGB profile.
gen_output = generator(s3_tensor[tf.newaxis, ...], training=False)
helper.plot_tensor(gen_output[0], rgb.S2)


In [None]:
# Take the discriminator network from the model and draw its layer graph,
# including tensor shapes.
discriminator = model.discriminator
tf.keras.utils.plot_model(discriminator, show_shapes=True, dpi=64)


In [None]:
# Feed the batched input together with the generator output to the
# discriminator (training=False) and show the last channel of its first
# output as a heatmap on a fixed [-20, 20] colour scale.
disc_out = discriminator([s3_tensor[tf.newaxis, ...], gen_output], training=False)
plt.imshow(disc_out[0, ..., -1], vmin=-20, vmax=20, cmap='RdBu_r')
plt.colorbar()


In [None]:
def generate_images(model, example_input, example_target, num_images=5):
    """Plot input, ground truth and prediction side by side for a batch.

    Args:
        model: callable invoked as `model(example_input, training=False)`;
            its result is indexed per example.
        example_input: batch of input images, plotted with the S3 profile.
        example_target: batch of target images, plotted with the S2 profile.
        num_images: upper bound on the number of examples shown; fewer are
            shown when the batch is smaller.
    """
    #TODO: training True or False!?
    prediction = model(example_input, training=False)

    shown = min(num_images, len(example_input))
    for idx in range(shown):
        # One row of three panels per example: (tensor, RGB profile, title).
        panels = (
            (example_input[idx], rgb.S3, 'Input Image'),
            (example_target[idx], rgb.S2, 'Ground Truth'),
            (prediction[idx], rgb.S2, 'Predicted Image'),
        )

        fig, axes = plt.subplots(1, 3, figsize=(15,10))
        for panel_ax, (tensor, profile, title) in zip(axes, panels):
            helper.plot_tensor(tensor, profile, ax=panel_ax)
            panel_ax.set_title(title)
            panel_ax.axis('off')

        plt.tight_layout()
        plt.show()


In [None]:
# Preview the generator on the first batch of the test dataset; each element
# is unpacked as (target, input).
for example_target, example_input in test_dataset.take(1):
    generate_images(generator, example_input, example_target)
    

In [None]:
def fit(train_ds, test_ds, steps):
    """Train the model for a fixed number of steps with periodic previews.

    Uses the notebook-level `model`, `generator` and `generate_images`.

    Args:
        train_ds: dataset whose elements unpack as (target, input_image); it
            is repeated so that exactly `steps` elements are drawn.
        test_ds: dataset whose elements unpack as (target, input); its first
            element is plotted every 1000 steps.
        steps: total number of training steps to run.
    """
    start = time.time()

    for step, (target, input_image) in train_ds.repeat().take(steps).enumerate():
        if step % 1000 == 0:
            if step != 0:
                print(f'Time taken for 1000 steps: {time.time()-start:.2f} sec\n')
                start = time.time()

            # Preview on the dataset passed in as `test_ds`. The global
            # `test_dataset` was read here before, which silently ignored
            # the argument.
            for example_target, example_input in test_ds.take(1):
                generate_images(generator, example_input, example_target)

            print(f"Step: {step // 1000}k")

        # Training step
        model.train_step(input_image, target, step)

        # Progress marker: one dot per 10 steps.
        if (step + 1) % 10 == 0:
            print('.', end='', flush=True)

        # Save (checkpoint) the model every 5k steps
        if (step + 1) % 5000 == 0:
            model.save()

In [None]:
# Train for 5000 steps; fit() saves a checkpoint whenever the step count
# reaches a multiple of 5000, so this run ends with one save.
fit(train_dataset, test_dataset, steps=5000)

In [None]:
# Upload the logs under PATH_LOGS with the TensorBoard CLI.
# NOTE(review): the hosted TensorBoard.dev service has reportedly been shut
# down -- confirm this command still works, or view the logs locally instead.
!tensorboard dev upload --logdir {PATH_LOGS}
