In [None]:
import src
import keras.backend as K
import os
import numpy as np
import sys
import re
import math
import io
import matplotlib.pyplot as plt
import pandas as pd
import seaborn as sns
from  matplotlib.animation import FuncAnimation
from matplotlib import colors
from netCDF4 import Dataset
from IPython.display import clear_output
# Register the trainings folder and the data folder on sys.path. Both are
# inserted at position 0, so afterwards sys.path[0] is the data folder and
# sys.path[1] is the trainings folder (later cells index sys.path this way).
# NOTE(review): hard-coded absolute Windows paths; they only exist on the author's machine.
sys.path.insert(0, 'C:/Users/pkicsiny/Desktop/TUM/3/ADL4CV/ADL4CV_project/trainings')

sys.path.insert(0, 'C:/Users/pkicsiny/Desktop/TUM/3/ADL4CV/data')
# Device selection: CUDA_VISIBLE_DEVICES="0" exposes the first GPU to TensorFlow
# (it does NOT force CPU usage; use "" or "-1" for that). These variables are set
# before TensorFlow is imported below, which is when CUDA reads them.
os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID"   # see issue #152
os.environ["CUDA_VISIBLE_DEVICES"] = "0" #"" or "-1" for CPU, "0" for GPU
import tensorflow as tf
# NOTE(review): `keras` is bound to tf.keras here, while `K` (keras.backend) and
# `load_model` come from the standalone keras package; mixing the two can be
# incompatible — confirm this is intended.
from tensorflow import keras
from keras.models import load_model
from tensorflow.python.client import device_lib
# Sanity check: list the devices TensorFlow can see after the env settings above.
print(device_lib.list_local_devices())

In [None]:
def sample_images(epoch, gan_test, gan_test_truth, past_input, n_samples=5, model=None, out_dir="Plots"):
    """Save a grid of input frames, true next frame and predicted next frame.

    Each row is one test sample: its ``past_input`` conditioning frames, the
    true frame t+1 and the model's predicted frame t+1. All panels of a row
    share one colour scale, from 0 to the max of that row's inputs and truth.

    Args:
        epoch: Training iteration; used only in the output file name.
        gan_test: Conditioning frames, indexed as ``[sample, :, :, frame]``.
        gan_test_truth: True next frames; channel 0 is plotted.
        past_input: Number of conditioning frames drawn per row.
        n_samples: Maximum number of test samples (rows) to draw.
        model: Model whose ``predict`` produces the next frame; defaults to the
            notebook-global ``generator``.
        out_dir: Directory the PNG is written to; created if missing.
    """
    if model is None:
        model = generator  # notebook-global generator (original behaviour)
    # Never ask for more rows than there are test samples.
    n = min(n_samples, len(gan_test))
    test_batch = gan_test[:n]
    test_truth = gan_test_truth[:n]
    gen_imgs = model.predict(test_batch)
    # squeeze=False keeps `axs` 2-D even when n == 1.
    fig, axs = plt.subplots(n, past_input + 2, figsize=(16, 16), squeeze=False)
    for i in range(n):
        vmax = np.max([np.max(test_batch[i]), np.max(test_truth[i])])
        vmin = 0
        for j in range(past_input):
            im = axs[i, j].imshow(test_batch[i, :, :, j], vmax=vmax, vmin=vmin)
            axs[i, j].axis('off')
            src.colorbar(im)
            # Last input frame is "t"; earlier ones are "t-1", "t-2", ...
            offset = "" if j == past_input - 1 else str(j - past_input + 1)
            axs[i, j].set_title("Frame t" + offset)
        im2 = axs[i, -2].imshow(test_truth[i, :, :, 0], vmax=vmax, vmin=vmin)
        axs[i, -2].axis('off')
        src.colorbar(im2)
        axs[i, -2].set_title("Frame t+1")
        im3 = axs[i, -1].imshow(gen_imgs[i, :, :, 0], vmax=vmax, vmin=vmin)
        axs[i, -1].axis('off')
        src.colorbar(im3)
        axs[i, -1].set_title("Prediction t+1")
    os.makedirs(out_dir, exist_ok=True)
    fig.savefig(os.path.join(out_dir, "epoch %d.png" % epoch))
    # Close this figure explicitly; plt.close() only closes the *current* figure.
    plt.close(fig)

In [None]:
# Number of past frames the generator is conditioned on; also encoded in the run name.
past = 2
name = "sgan_{}-1".format(past)

Load dataset.

In [None]:
# Load the train / cross-validation / test splits with `past` input frames per sample.
train, xval, test = src.load_datasets(past_frames=past)

In [None]:
# Split every set into conditioning frames and the frame to predict (with augmentation).
# NOTE(review): only the first 2000 training samples are used — presumably to keep
# training fast; confirm before a full run.
gan_train, gan_truth, gan_val, gan_val_truth, gan_test, gan_test_truth = src.split_datasets(
            train[:2000], xval, test, past_frames=past, augment=True)

Make discriminator labels.

In [None]:
batch_size = 64  # samples per train_on_batch call (for both D and G)

In [None]:
# Target vectors for train_on_batch (one label per sample):
#   real   -> discriminator target for true frames (1)
#   g_real -> generator target: make D output "real" for generated frames (1)
#   fake   -> discriminator target for generated frames (0)
real, g_real = (np.ones((batch_size, 1)) for _ in range(2))
fake = np.zeros_like(real)

Build the generator but don't compile it; it is trained only through the combined GAN model below.

In [None]:
# Generator network sized to one conditioning sample. It is not compiled on its own:
# it is trained only through the combined model built below.
generator = src.unet(gan_train.shape[1:], dropout=0, batchnorm=True, kernel_size=4)

In [None]:
# Print the generator's layer-by-layer summary.
generator.summary()

Make discriminator and compile.

In [None]:
# Conditional discriminator: scores (condition frames, candidate next frame) pairs.
# Compiled here while still trainable, so its own train_on_batch updates its weights.
discriminator = src.spatial_discriminator(
    condition_shape=gan_train.shape[1:],
    dropout=0.25,
    batchnorm=True,
)
discriminator.compile(
    loss=keras.losses.binary_crossentropy,
    optimizer=keras.optimizers.SGD(),
    metrics=[keras.metrics.binary_accuracy],
)

Inputs and outputs of the GAN.

In [None]:
# Symbolic GAN input: a batch of conditioning frames shaped like one gan_train sample.
frame_t = keras.layers.Input(shape=gan_train.shape[1:])

In [None]:
# Inspect the symbolic input shape (the batch dimension is left unspecified).
frame_t.shape

In [None]:
# Symbolic generator output for the conditioning frames.
generated = generator(frame_t)

In [None]:
# Discriminator score for the (conditioning frames, generated frame) pair.
score = discriminator([frame_t, generated])

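Freeze the discriminator weights inside the combined model. Keras reads `trainable` when a model is compiled, so the standalone discriminator (already compiled above) keeps training on its own batches, while the combined model compiled below updates only the generator.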

In [None]:
# Keras reads `trainable` at compile time: this freezes D only inside models compiled
# after this line (the combined model below). The already-compiled standalone
# discriminator keeps updating its weights in its own train_on_batch.
discriminator.trainable = False

Compile combined model.

In [None]:
loss_weights = [1, 1]  # [L1 reconstruction weight, adversarial BCE weight]

In [None]:
# Combined GAN: frame_t -> [generated frame, discriminator score]; used to train the generator.
combined = keras.models.Model(inputs=[frame_t], outputs=[generated, score])

In [None]:
# Generator objective = w0 * L1(generated, truth) + w1 * BCE(D score, "real").
combined.compile(
    loss=[src.custom_loss(loss="l1"), keras.losses.binary_crossentropy],
    optimizer=keras.optimizers.Adam(0.0002, 0.5),
    loss_weights=loss_weights,
    metrics=[src.relative_error_tensor, "accuracy"],
)

In [None]:

# Per-iteration training history; every key gets its own fresh empty list.
log = {
    key: []
    for key in (
        "g_loss", "d_loss", "g_metric", "d_metric",
        "d_loss_real", "d_loss_fake", "d_test_real", "d_test_fake",
    )
}

Train for `epochs` iterations. Each iteration draws one random batch; it is not a full pass over the training set.

In [None]:
epochs = 2_500  # training iterations; each one uses a single random batch

In [None]:
# Alternating GAN training. One "epoch" here is one iteration on one random batch.
for epoch in range(epochs):
    # --- Discriminator step: real pairs -> 1, generated pairs -> 0 ---
    # Keras reads `trainable` at compile time; the toggles keep D frozen whenever
    # `combined` is recompiled below.
    discriminator.trainable = True
    idx = np.random.randint(0, gan_truth.shape[0], batch_size)
    real_imgs = gan_truth[idx]
    training_batch = gan_train[idx]

    generated_imgs = generator.predict(training_batch)
    d_loss_real = discriminator.train_on_batch([training_batch, real_imgs], real)
    d_loss_fake = discriminator.train_on_batch([training_batch, generated_imgs], fake)
    d_loss = 0.5 * np.add(d_loss_real, d_loss_fake)  # [bce, binary_accuracy]
    discriminator.trainable = False

    # --- Generator step: reconstruct the truth and make D output "real" ---
    idx = np.random.randint(0, gan_train.shape[0], batch_size)
    training_batch = gan_train[idx]
    training_truth = gan_truth[idx]

    g_loss = combined.train_on_batch(training_batch, [training_truth, g_real])

    # Once the L1 term is small, halve its weight (down to 2**-4) to give the
    # adversarial term more influence.
    if g_loss[1] < 0.11 and loss_weights[0] > 2**(-4):
        loss_weights[0] /= 2
        # Recompile with the SAME metrics as the original compile, so that
        # train_on_batch keeps returning the same number of values. Without
        # them, log["g_loss"] becomes ragged and np.array(log["g_loss"])[:, 0]
        # fails in the plotting cells. Reuse the existing optimizer rather than
        # constructing a fresh Adam on every weight change.
        combined.compile(loss=[src.custom_loss(loss="l1"), keras.losses.binary_crossentropy],
                         optimizer=combined.optimizer,
                         loss_weights=loss_weights,
                         metrics=[src.relative_error_tensor, "accuracy"])

    log["g_loss"].append(g_loss)  # [total, L1 objective, bce, metrics...]
    log["d_loss"].append(d_loss)  # [bce, binary_accuracy]
    log["g_metric"].append(g_loss[1])
    log["d_metric"].append(d_loss[1])
    log["d_loss_real"].append(d_loss_real)
    log["d_loss_fake"].append(d_loss_fake)

    print(f"\033[1m {epoch} [D loss: {d_loss[0]}, acc.: {100*d_loss[1]}]\033[0m \n"+
          f"\033[1m {epoch} [G loss: {g_loss[0]}, G obj.: {g_loss[1]}, G bce.: {g_loss[2]}]\033[0m \n"+
          f"\033[1m {epoch} [real loss: {d_loss_real}, fake loss: {d_loss_fake}]\033[0m")
    if epoch % 100 == 0:
        sample_images(epoch, gan_test, gan_test_truth, past)

In [None]:
#%matplotlib notebook
g_hist = np.array(log["g_loss"])  # per-iteration generator losses (total first, then L1 objective, ...)
d_hist = np.array(log["d_loss"])  # per-iteration discriminator losses (total first)

# Faint raw totals for generator and discriminator.
plt.plot(g_hist[:, 0], alpha=0.3, c="b")
plt.plot(d_hist[:, 0], alpha=0.3, c="orange")
# L1 reconstruction term of the generator objective.
plt.plot(g_hist[:, 1], alpha=0.9, c="green", label="L1 objective")
# Smoothed totals on top, labelled for the legend.
plt.plot(src.smooth(g_hist[:, 0]), c="b", label="generator")
plt.plot(src.smooth(d_hist[:, 0]), c="orange", label="discriminator")

plt.grid()
plt.xlabel("Iterations")
plt.ylabel("Loss")
plt.legend()
plt.savefig("sGAN_training_curves_r")

In [None]:
# Two stacked panels: adversarial losses (top, tall) and the L1 objective (bottom).
g_hist = np.array(log["g_loss"])
d_hist = np.array(log["d_loss"])
total_g_loss = g_hist[:, 0]
total_d_loss = d_hist[:, 0]
objective_loss = g_hist[:, 1]
smoothed_tgl = src.smooth(total_g_loss)
smoothed_tdl = src.smooth(total_d_loss)

f, (a0, a1) = plt.subplots(2, 1, gridspec_kw={"height_ratios": [5, 2]})

# Top panel: faint raw totals, then labelled smoothed totals.
a0.plot(total_g_loss, alpha=0.3, c="b")
a0.plot(total_d_loss, alpha=0.3, c="orange")
a0.plot(smoothed_tgl, c="b", label="generator")
a0.plot(smoothed_tdl, c="orange", label="discriminator")
a0.grid()
a0.legend()

# Bottom panel: L1 reconstruction term.
a1.plot(objective_loss, alpha=0.9, c="green", label="L1 objective")
a1.grid()
a1.legend()

# Shared axis labels placed on the figure rather than on either panel.
f.text(0.5, 0, 'Iterations', ha='center', va='center')
f.text(0, 0.5, 'Loss', ha='center', va='center', rotation='vertical')

f.tight_layout()
f.savefig(name + '_curves.png')

## Save training artifacts

Save model history

In [None]:
np.save(f"{name}_log", log)  # dict is pickled into a 0-d object array (.npy appended)

Save model weights

In [None]:
combined.save_weights(f"{name}_model.h5")  # generator + discriminator weights, relative to the CWD

Load saved model weights (predictions follow in the next section)

In [None]:
# Restore the combined model's (generator + discriminator) weights.
# NOTE(review): absolute path to the author's project folder, while the save cell
# writes name+"_model.h5" relative to the working directory; the two only match
# when the notebook runs from that folder.
combined.load_weights("C:/Users/pkicsiny/Desktop/TUM/3/ADL4CV/ADL4CV_project/"+name+"_model.h5")

Predict future frames

In [None]:
# Reload the splits with more frames per sample. The roll-out below slices channels
# up to past+4 of test_data (inputs plus 4 reference frames).
# NOTE(review): this overwrites train/xval/test from above; assumes 8 past frames
# provide at least past+5 channels — confirm against src.load_datasets.
train, xval, test = src.load_datasets(past_frames=8)

In [None]:
# Augment the first 100 test samples.
# NOTE(review): per the note in the norms cell, augment_data concatenates rotated
# copies, so test_data presumably holds 4 x 100 samples — verify with the shape below.
test_data = src.augment_data(test[:100])

In [None]:
# Sanity check of the augmented test set's shape.
test_data.shape

In [None]:
#test
predictions = {}
past_frames = test_data[...,0:past]
test_truth = test_data[...,past:past+1]
for t in range(4): #predict 4 next
    future = combined.predict(past_frames, batch_size=64)
    predictions[f"{t}"] = future[0]
    predictions[f"{t}_labels"] = future[1]
    past_frames = np.concatenate((past_frames[:,:,:,1:], predictions[f"{t}"]), axis=-1)
    test_truth = test_data[...,past+1+t:past+2+t]

Save example predictions

In [None]:
def save_examples(name, test, predictions_dict, past, samples=0, n_future=4, out_dir="Plots"):
    """Save a figure comparing predicted roll-out frames with reference frames.

    For every selected sample there are two rows. Both rows start with the
    ``past`` input frames. The first row then shows the ``n_future`` predicted
    frames; the second shows the matching reference frames from ``test``. All
    panels of a sample share one colour scale, from 0 to the max of that
    sample's input frames.

    Args:
        name: Run name; used as prefix of the output file name.
        test: Test frames, indexed as ``[sample, :, :, frame]``.
        predictions_dict: Maps ``"0"``, ``"1"``, ... to predicted frames, indexed
            as ``[sample, :, :, 0]``.
        past: Number of input frames per sample.
        samples: Index or sequence of indices into ``test`` to plot.
        n_future: Number of predicted / reference frames per sample.
        out_dir: Directory the PNG is written to; created if missing.
    """
    # Accept a single index (including the old default 0, which used to crash on len()).
    if np.isscalar(samples):
        samples = [samples]
    samples = list(samples)

    # squeeze=False keeps `axs` 2-D for any number of samples/columns.
    fig, axs = plt.subplots(len(samples) * 2, past + n_future, figsize=(32, 32), squeeze=False)
    fig.subplots_adjust(wspace=0.3, hspace=0.0)
    for row, sample_idx in enumerate(samples):
        pred_row, ref_row = 2 * row, 2 * row + 1
        # Colour scale from THIS sample's input frames (was test[row], i.e. a different sample).
        vmax = np.max(test[sample_idx, :, :, :past])
        vmin = 0
        # Input frames are repeated on both rows so each row reads as a full sequence.
        for i in range(past):
            for r in (pred_row, ref_row):
                im = axs[r, i].imshow(test[sample_idx, :, :, i], vmax=vmax, vmin=vmin)
                axs[r, i].axis('off')
                axs[r, i].set_title(f"Past frame {i+1}")
                src.colorbar(im)
        for k in range(n_future):
            col = past + k
            im = axs[pred_row, col].imshow(predictions_dict[f"{k}"][sample_idx, :, :, 0], vmax=vmax, vmin=vmin)
            axs[pred_row, col].axis('off')
            axs[pred_row, col].set_title(f"Predicted frame {k+1}")
            src.colorbar(im)
            im = axs[ref_row, col].imshow(test[sample_idx, :, :, col], vmax=vmax, vmin=vmin)
            axs[ref_row, col].axis('off')
            axs[ref_row, col].set_title(f"Reference frame {k+1}")
            src.colorbar(im)
    os.makedirs(out_dir, exist_ok=True)
    fig.savefig(os.path.join(out_dir, f"{name}_sequence_prediction.png"))
    plt.close(fig)


In [None]:
# Save the roll-out figure for three hand-picked samples (indices into test_data).
save_examples(name, test_data, predictions, past, samples=[33,46,54])

Calculate scores

In [None]:
[*predictions]  # keys: "0".."3" (frames) and "0_labels".."3_labels" (D scores)

In [None]:
# Per-sample normalization factors ("arr_1") from the compressed npz in the data
# folder (sys.path[0]). The NpzFile is used as a context manager so its file handle
# is closed; indexing it reads the array fully into memory first.
with np.load(sys.path[0] + "/5min_norms_compressed.npz") as norms_file:
    norms = norms_file["arr_1"]

In [None]:
# Norms used to undo the per-sample normalization of test_data.
# augment_data concatenates 4 rotated copies: for N source samples, the copies of
# sample k sit at k, k+N, k+2N and k+3N. test_data was built from only the first
# len(test_data)//4 test samples, so tile exactly those N norms 4 times. Tiling all
# of norms[9000:] gave every copy after the first block the wrong norm.
# NOTE(review): assumes the test split starts at index 9000 of the norms array.
n_augment = 4
n_source = len(test_data) // n_augment
test_norms = list(norms[9000:9000 + n_source]) * n_augment
assert len(test_norms) == len(test_data), "norms do not line up with test_data"

In [None]:
# Undo the normalization: multiply each sample (and each of its predictions) by its norm.
# Convert the norms to an array once instead of once per sample.
test_norms_arr = np.asarray(test_norms)

renormalized_test = np.array([sample * test_norms_arr[i] for i, sample in enumerate(test_data)])
# Stack the 4 prediction steps, keep channel 0, then move the step axis last:
# result is indexed as [sample, ..., step].
renormalized_predictions = np.transpose(
    np.array([[sample * test_norms_arr[i] for i, sample in enumerate(predictions[key])]
              for key in ['0', '1', '2', '3']])[:, :, :, :, 0],
    (1, 2, 3, 0))

In [None]:
# Shape check; the last axis should be the 4 roll-out steps.
renormalized_predictions.shape

In [None]:
#thresholds: 2, 8, 42
thresholds = [10, 50, 100]
scores = {}
n_steps = renormalized_predictions.shape[-1]  # number of roll-out predictions (4)
for step in range(n_steps):
    # Prediction for this step, its reference frame, and the model's input frames.
    predicted = renormalized_predictions[..., step:step + 1]
    reference = renormalized_test[..., past + step:past + 1 + step]
    inputs = renormalized_test[..., :past]
    for threshold in thresholds:
        scores[f"pred_{step + 1}_threshold_{threshold}"] = src.calculate_skill_scores(
            predicted, reference, x=inputs, threshold=threshold)

In [None]:
# Which metrics calculate_skill_scores returned.
scores["pred_1_threshold_10"].keys()

In [None]:
np.save(f"{name}_scores", scores)  # pickled dict in a 0-d object array (.npy appended)

In [None]:
# Run name used as the prefix of all saved files.
name

In [None]:
# The scores dict was saved with np.save, i.e. pickled into a 0-d object array.
# NumPy >= 1.16.3 refuses to unpickle unless allow_pickle=True.
# Only load files you produced yourself: unpickling untrusted data can run arbitrary code.
loaded_scores = np.load(sys.path[1] + "/" + name + "/" + name + "_scores.npy", allow_pickle=True).item()

In [None]:
[*loaded_scores]  # score entries restored from disk

In [None]:
# Mean correlation to truth for the 1st predicted frame at threshold 100 (NaNs dropped).
# (Was a copy-paste of the pred_4 cell, so pred_1 was never reported.)
np.mean((pd.Series(scores["pred_1_threshold_100"]["corr_to_truth"]).dropna()))

In [None]:
# Mean correlation to truth for the 2nd predicted frame at threshold 100 (NaNs dropped).
pd.Series(scores["pred_2_threshold_100"]["corr_to_truth"]).dropna().mean()

In [None]:
# Mean correlation to truth for the 3rd predicted frame at threshold 100 (NaNs dropped).
pd.Series(scores["pred_3_threshold_100"]["corr_to_truth"]).dropna().mean()

In [None]:
# Mean correlation to truth for the 4th predicted frame at threshold 100 (NaNs dropped).
pd.Series(scores["pred_4_threshold_100"]["corr_to_truth"]).dropna().mean()

In [None]:
def wasserstein_loss(y_true, y_pred):
    """Wasserstein (critic) loss for one batch.

    A WGAN critic has a linear output instead of a sigmoid, so its scores are not
    probabilities. Real samples are labelled +1 and generated samples -1 (instead
    of 1/0). The mean of label * score is then the loss itself, and optimizing it
    pushes the critic's scores for the two groups apart. The value is unbounded
    and is frequently negative.

    Args:
        y_true: Labels, +1 for real and -1 for generated samples.
        y_pred: Raw (linear) critic outputs.

    Returns:
        Mean of ``y_true * y_pred`` as a backend tensor.
    """
    signed_scores = y_true * y_pred
    return K.mean(signed_scores)

In [None]:
def noisy_d_labels(real, fake, flip_fraction=0.05):
    """Build discriminator targets with a small fraction of flipped labels.

    Label-noise idea: https://arxiv.org/pdf/1606.03498.pdf

    Only the shapes/dtypes of ``real`` and ``fake`` are used. The function
    returns fresh ``ones_like(real)`` / ``zeros_like(fake)`` arrays in which the
    same ``int(flip_fraction * len(real))`` distinct positions are flipped
    (real -> 0, fake -> 1).

    Args:
        real: Template for the "real" targets (shape/dtype of the output).
        fake: Template for the "fake" targets (shape/dtype of the output).
        flip_fraction: Fraction of the batch whose labels are flipped (default 5%).

    Returns:
        Tuple ``(d_real, d_fake)`` of label arrays.
    """
    batch_size = len(real)
    n_flip = int(flip_fraction * batch_size)
    # Sample WITHOUT replacement so exactly n_flip distinct positions are flipped;
    # randint could draw duplicates and flip fewer labels than intended.
    idx = np.random.choice(batch_size, n_flip, replace=False)
    d_real = np.ones_like(real)
    d_fake = np.zeros_like(fake)
    d_real[idx] = 0
    d_fake[idx] = 1
    return d_real, d_fake