In [None]:
%load_ext autoreload
%autoreload 2
%matplotlib inline

In [None]:
# --- Standard library -------------------------------------------------------
import argparse
import json
import os
import random
import sys
from pathlib import Path

# --- Third-party ------------------------------------------------------------
import h5py
import matplotlib
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
# `torch` and `yaml` are used directly by later cells (torch.device, torch.load, yaml.full_load).
# Before this change they were only reachable through `from training import *`.
import torch
import torch.nn.functional as F
import torch.utils.data as torchdata
import yaml
from torch.utils.data import DataLoader, TensorDataset
from torch.utils.data.dataset import Dataset
from torch.utils.tensorboard import SummaryWriter

# --- Project modules ----------------------------------------------------------
# Extra search paths for the project packages; they must be in place before the imports below.
sys.path.insert(0, '../')
sys.path.append("/scratch2/ml_flood/mlflood/")

from mlflood.conf import PATH_DATA, rain_const, waterdepth_diff_const
from mlflood.utils import new_log
from models.utae import UTAE
from models.utae_old import UTAE as UTAE_old
from models.CNNrolling import CNNrolling
from models.unet3d import UNet3D
from models.unet import UNet
from dataset_utae import load_test_dataset, dataloader_args_utae_test
from dataset_old import load_test_dataset as load_test_unet
from dataset_old import dataloader_args_test

# NOTE(review): star import kept because other names from `training` may be relied upon;
# prefer importing the needed names explicitly once they are known.
from training import *

# Imported after the star import so that `training` cannot shadow these names.
from evaluation import predict_event, predict_batch, mae_event
from evaluation import plot_maes, multiboxplot, plot_answer_sample, boxplot_mae

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print('Using: ', device)
args = []  # placeholder args for models constructed without parsed hyperparameters

In [None]:
# Checkpoint directories for the first comparison. Each path must end with "/" because
# "model.pth.tar" is appended to it when the weights are loaded.
# NOTE(review): the plot labels (utae_8/16/32/64) and `save_folder` refer to the may_24/utae_head_*
# runs listed in the comments, while the active paths point at other runs — confirm which are intended.
path_exp_utae = "/scratch2/ml_flood/data/checkpoints/709/cluster/utae_L1/experiment_0/"
# path_exp_utae = "/scratch2/ml_flood/data/checkpoints/709/cluster/may_24/utae_head_8/experiment_0/"

path_exp_cnn = "/scratch2/ml_flood/data/checkpoints/709/cluster/cnn_temp/experiment_0/"
# path_exp_cnn = "/scratch2/ml_flood/data/checkpoints/709/cluster/apr_22/utae_L1_upd/experiment_5/"

path_exp_unet = "/scratch2/ml_flood/data/checkpoints/709/cluster/unet_temp/experiment_0/"
# path_exp_unet = "/scratch2/ml_flood/data/checkpoints/709/cluster/may_24/utae_head_32/experiment_0/"

# Was commented out although the model-loading cell reads it (NameError on a fresh kernel).
path_exp_utae64 = "/scratch2/ml_flood/data/checkpoints/709/cluster/may_24/utae_head_64/experiment_0/"

***For the models to be comparable, they must be run with the same hyperparameters (batch size, predict_ahead, dim_patch, timestep); this way, all of them can be evaluated on the same test data.***

In [None]:
## Stefania ##
# Alternative configuration: read the catchment settings from a YAML experiment file.
# NOTE: the next cell rebinds `catchment_kwargs`, but this cell's `args` stays bound and is
# later passed to UTAE_old in the model-loading cell.
exp_hp = ['--catchment_kwargs=../mlflood/exp_yml/exp_utae.yml']

parser = argparse.ArgumentParser(description="evaluation")
parser.add_argument('--n_head', type=int, default=16)
parser.add_argument(
    "--catchment_kwargs",
    type=str,
    default='./default_catchment_kwargs.yml',
    help="path to catchment kwargs saved in yml file",
)
args = parser.parse_args(exp_hp)

with open(args.catchment_kwargs) as yml_file:
    catchment_kwargs = yaml.full_load(yml_file)

# Always in eval mode (presumably: fixed rather than random patch indexes — confirm).
catchment_kwargs['fix_indexes'] = True

In [None]:
## Priyanka ##

### Catchment settings
catchment_kwargs = dict(
    num="709",
    tau=0.5,
    timestep=5,            # for timestep >1 use CNN rolling or Unet
    sample_type="single",
    dim_patch=256,
    fix_indexes=True,
    border_size=0,
    normalize_output=False,
    use_diff_dem=False,
    num_patch=10,          # number of patches to generate from a timestep
    predict_ahead=12,
)

In [None]:
## UTAE attention-head settings ##
# Model 2 is not configured here; the loading cell builds it from the plain `args`.
def _parse_n_head(n_head):
    """Return the argparse Namespace for `--n_head=<n_head>` (parser default: 16)."""
    head_parser = argparse.ArgumentParser(description="evaluation")
    head_parser.add_argument('--n_head', type=int, default=16)
    return head_parser.parse_args([f"--n_head={n_head}"])


args_m1 = _parse_n_head(8)    # Model 1
args_m3 = _parse_n_head(32)   # Model 3
args_m4 = _parse_n_head(64)   # Model 4

In [None]:

# Build the four models compared in this section. `model_cnn` and `model_unet` keep their
# historical names, but per their constructors they are a UTAE_old model and a UTAE (n_head=32).
# NOTE: `args` is whatever was bound last: [] from the imports cell, or the argparse
# Namespace from the YAML configuration cell if that cell was run.
model_utae = UTAE(args_m1)      # n_head=8
model_cnn = UTAE_old(args)
model_unet = UTAE(args_m3)      # n_head=32
model_utae64 = UTAE(args_m4)    # n_head=64

# Load each checkpoint onto `device` (chosen in the imports cell) instead of a hard-coded
# `.cuda()`, so this also works without a GPU; on a GPU machine it is equivalent.
for _net, _exp_dir in [
    (model_utae, path_exp_utae),
    (model_cnn, path_exp_cnn),
    (model_unet, path_exp_unet),
    (model_utae64, path_exp_utae64),
]:
    _net.load_state_dict(torch.load(_exp_dir + "model.pth.tar", map_location=device))
    _net.to(device)

In [None]:
# Test split of the catchment and the matching UTAE test dataloader.
dataset = load_test_dataset(catchment_kwargs)
dataloaders = {
    "test": dataloader_args_utae_test(dataset, catchment_num=catchment_kwargs['num']),
}
dataset_test = dataloaders["test"]

## A. 12-step ahead

In [None]:
event_num = 0
# The calls below have always passed start_ts=None; the former `start_ts = 5` was never used,
# so the variable now matches what is actually passed.
start_ts = None

pred_utae, gt_utae, mask_utae = predict_event(model_utae, dataset, event_num, 'utae', start_ts=start_ts, ar=False, T=None)
pred_cnn, gt_cnn, mask_cnn = predict_event(model_cnn, dataset, event_num, 'cnn', start_ts=start_ts, ar=False, T=None)


In [None]:
# Drop the first two models before predicting with the others (presumably to free GPU memory).
# NOTE: re-running this cell raises NameError because the names are already deleted.
del model_utae, model_cnn

pred_unet, gt_unet, mask_unet = predict_event(
    model_unet, dataset, event_num, 'utae_32',
    start_ts=None, ar=False, T=None,
)
pred_utae64, gt_utae64, mask_utae64 = predict_event(
    model_utae64, dataset, event_num, 'utae_64',
    start_ts=None, ar=False, T=None,
)


In [None]:
# TODO: clarify what is plotted here; check that the x axis is in timesteps and whether the y axis is in cm or meters.

In [None]:
# Output folder for this section's figures.
# NOTE(review): points at the may_24/utae_head_8 run, whereas `path_exp_utae` loads
# cluster/utae_L1 — confirm the intended folder.
save_folder = (
    "/scratch2/ml_flood/data/checkpoints/709/cluster/"
    "may_24/utae_head_8/experiment_0/results/"
)

In [None]:
# MAE of every model for the selected event, plotted together (label order matches `maes`).
labels = ['utae_8', 'utae_16', 'utae_32', 'utae_64']
maes = [
    mae_event(pred, gt, mask)
    for pred, gt, mask in (
        (pred_utae, gt_utae, mask_utae),
        (pred_cnn, gt_cnn, mask_cnn),
        (pred_unet, gt_unet, mask_unet),
        (pred_utae64, gt_utae64, mask_utae64),
    )
]
mae_utae, mae_cnn, mae_unet, mae_utae64 = maes

plot_maes(maes, labels, start_ts=0, save_folder=save_folder,
          name='12_ts_ahead', title="MAE for 12 ts ahead")

In [None]:
pred_ts = 12  # number of timesteps ahead to evaluate; if None, all timesteps are computed

In [None]:
# Water-depth bin edges (presumably meters, matching the 0-10 cm ... >100 cm tick labels).
lims = (0.1, 0.2, 0.5, 1)

data_utae, data_cnn, data_unet, data_utae64 = (
    boxplot_mae(pred, gt, mask, lims=lims, pred_ts=pred_ts)
    for pred, gt, mask in (
        (pred_utae, gt_utae, mask_utae),
        (pred_cnn, gt_cnn, mask_cnn),
        (pred_unet, gt_unet, mask_unet),
        (pred_utae64, gt_utae64, mask_utae64),
    )
)

In [None]:
# Tick labels for the water-depth bins, plus legend labels and colours for the models.
ticks = ['0-10 cm', '10-20 cm', '20-50cm', '50-100cm','>100cm']
labels = ['utae_8', 'utae_16', 'utae_32', 'utae_64']
colors = ['#EF8A62', '#67A9CF', '#1B9E77', '#CA0020', '#998EC3']
data = [data_utae, data_cnn, data_unet, data_utae64]

# The file name follows `pred_ts` (it was hard-coded to '1ts_ahead' although pred_ts = 12),
# and the title lists every head count that is plotted (64 was missing).
box_name = 'all_ts_ahead' if pred_ts is None else f'{pred_ts}ts_ahead'
multiboxplot(data, ticks, labels, colors, save_folder=save_folder, name=box_name,
             title="Multiboxplots for models utae with n_heads=[8,16,32,64]")

In [None]:
# Boxplot for a single model. `data_unet` comes from `model_unet`, which in this section is a
# UTAE with n_head=32 (label 'utae_32'), so the title previously called it 'unet' wrongly.
model_name = 'utae_32'
data = data_unet

fig, ax = plt.subplots(figsize=[12, 5])
ax.set_ylabel('Absolute Error (cm)')
ax.set_xlabel('Water Depth')
# Only takes effect if showfliers is switched to True.
flierprops = dict(marker='d', markerfacecolor='black', markersize=4, linestyle='none', markeredgecolor='black')
bp = ax.boxplot(data, showfliers=False, patch_artist=True, flierprops=flierprops)  # hide outlier points
ax.set_xticklabels(ticks)
plt.title("Visualization for model {} ".format(model_name), fontsize=14, fontweight="bold")
# filename = save_folder + 'Mae_boxplot_1_ts_ahead_' + model_name + '.png'
# plt.savefig(filename, dpi=1200)



In [None]:
# Zoomed prediction vs. ground truth at ts=12 for every model of this section.
zoom = [500, 1000, 500, 1000]
for _pred, _gt, _mask, _name in [
    (pred_utae, gt_utae, mask_utae, '12ts_ahead_utae8'),
    (pred_cnn, gt_cnn, mask_cnn, '12ts_ahead_utae16'),
    (pred_unet, gt_unet, mask_unet, '12ts_ahead_utae32'),
    (pred_utae64, gt_utae64, mask_utae64, '12ts_ahead_utae64'),
]:
    plot_answer_sample(_pred, _gt, _mask, ts=12, zoom=zoom, show_diff=False, global_scale=True,
                       save_folder=save_folder, model_name=_name)

## Comparison with other models

In [None]:
# Checkpoint directories for the comparison against other architectures.
_ckpt_709 = "/scratch2/ml_flood/data/checkpoints/709/cluster/"

path_exp_utae = _ckpt_709 + "may_24/utae_head_32/experiment_0/"     # best performing utae
path_exp_utae64 = _ckpt_709 + "may_24/utae_head_64/experiment_0/"
# path_exp_cnn = "/scratch2/ml_flood/data/checkpoints/709/cnn_temp/experiment_3/"  # does not work for stefania during predictions
path_exp_unet = _ckpt_709 + "apr_22/unet_L1_upd/experiment_0/"      # does not work for stefania during loading of model
path_exp_unet3d = _ckpt_709 + "apr_22/unet3d_L1/experiment_0/"

In [None]:
## Priyanka ##

### Catchment settings
catchment_kwargs = {
    "num": "709",
    "tau": 0.5,
    "timestep": 5,            # for timestep >1 use CNN rolling or Unet
    "sample_type": "single",
    "dim_patch": 256,
    "fix_indexes": True,
    "border_size": 0,
    "normalize_output": False,
    "use_diff_dem": False,
    "num_patch": 10,          # number of patches to generate from a timestep
    "predict_ahead": 12,
}

In [None]:
## UTAE models (n_head=32 -> args_m1, n_head=64 -> args_m2) ##
head_args = []
for n_head in (32, 64):
    head_parser = argparse.ArgumentParser(description="evaluation")
    head_parser.add_argument('--n_head', type=int, default=16)
    head_args.append(head_parser.parse_args([f"--n_head={n_head}"]))
args_m1, args_m2 = head_args

## Models 3 and 4 (UNet, UNet3D) are built from an empty args list ##
args = []

In [None]:
# Build the models of this comparison. The CNN checkpoint is skipped (see the note next to
# `path_exp_cnn` in the paths cell).
model_utae = UTAE(args_m1)       # n_head=32
model_unet = UNet(args)
model_unet3d = UNet3D(args)
model_utae64 = UTAE(args_m2)     # n_head=64

# Load each checkpoint onto `device` (chosen in the imports cell) instead of a hard-coded
# `.cuda()`, so this also works without a GPU; on a GPU machine it is equivalent.
for _net, _exp_dir in [
    (model_utae, path_exp_utae),
    (model_unet, path_exp_unet),
    (model_unet3d, path_exp_unet3d),
    (model_utae64, path_exp_utae64),
]:
    _net.load_state_dict(torch.load(_exp_dir + "model.pth.tar", map_location=device))
    _net.to(device)


In [None]:
# UTAE-style test data (used by the UTAE and UNet3D models).
dataset = load_test_dataset(catchment_kwargs)
dataloaders = {"test": dataloader_args_utae_test(dataset, catchment_num=catchment_kwargs['num'])}
dataset_test = dataloaders["test"]

# The UNet model needs the test data built by the older dataset module.
dataset_u = load_test_unet(catchment_kwargs)
dataloaders_u = {"test": dataloader_args_test(dataset_u, catchment_num=catchment_kwargs['num'])}
dataset_test_u = dataloaders_u["test"]

In [None]:
# Change the model string passed to predict_event according to the model used.
event_num = 0
start_ts = None

pred_utae, gt_utae, mask_utae = predict_event(
    model_utae, dataset, event_num, 'utae',
    start_ts=start_ts, ar=False, T=None,
)


In [None]:
# UNet3D is evaluated on the same UTAE test dataset (`dataset`) as the UTAE models.
pred_unet3d, gt_unet3d, mask_unet3d = predict_event(
    model_unet3d, dataset, event_num, 'unet3d',
    start_ts=None, ar=False, T=None,
)

In [None]:
# UNet uses the dataset built by the older dataset module (`dataset_u`).
pred_unet, gt_unet, mask_unet = predict_event(
    model_unet, dataset_u, event_num, 'unet',
    start_ts=None, ar=False, T=None,
)
pred_unet.shape

In [None]:
pred_utae64, gt_utae64, mask_utae64 = predict_event(
    model_utae64, dataset, event_num, 'utae_64',
    start_ts=None, ar=False, T=None,
)

In [None]:
# NOTE(review): results go under apr_22/unet3d_L1_upd, while the UNet3D weights are loaded from
# apr_22/unet3d_L1 — confirm this folder exists and is intended.
save_folder = (
    "/scratch2/ml_flood/data/checkpoints/709/cluster/"
    "apr_22/unet3d_L1_upd/experiment_0/results/"
)

In [None]:
# MAE of every model for the selected event (label order matches `maes`).
labels = ['utae_32', 'unet', 'unet3d', 'utae_64']
maes = [
    mae_event(pred, gt, mask)
    for pred, gt, mask in (
        (pred_utae, gt_utae, mask_utae),
        (pred_unet, gt_unet, mask_unet),
        (pred_unet3d, gt_unet3d, mask_unet3d),
        (pred_utae64, gt_utae64, mask_utae64),
    )
]
mae_utae, mae_unet, mae_unet3d, mae_utae64 = maes

plot_maes(maes, labels, start_ts=17, save_folder=save_folder,
          name='12_ts_ahead', title="MAE for 12 ts ahead")

In [None]:
pred_ts = 12  # number of timesteps ahead to evaluate; if None, all timesteps are computed

In [None]:
# Water-depth bin edges (presumably meters, matching the cm tick labels).
lims = (0.1, 0.2, 0.5, 1)

data_utae, data_unet, data_unet3d, data_utae64 = (
    boxplot_mae(pred, gt, mask, lims=lims, pred_ts=pred_ts)
    for pred, gt, mask in (
        (pred_utae, gt_utae, mask_utae),
        (pred_unet, gt_unet, mask_unet),
        (pred_unet3d, gt_unet3d, mask_unet3d),
        (pred_utae64, gt_utae64, mask_utae64),
    )
)

In [None]:
# Tick labels for the water-depth bins and colours; `labels` comes from the MAE cell.
ticks = ['0-10 cm', '10-20 cm', '20-50cm', '50-100cm','>100cm']
colors = ['#EF8A62', '#67A9CF', '#1B9E77', '#CA0020', '#998EC3']
data = [data_utae, data_unet, data_unet3d, data_utae64]

# The file name follows `pred_ts` (it was hard-coded to '1ts_ahead' although pred_ts = 12),
# and the title now also mentions the utae_64 model that is plotted.
box_name = 'all_ts_ahead' if pred_ts is None else f'{pred_ts}ts_ahead'
multiboxplot(data, ticks, labels, colors, save_folder=save_folder, name=box_name,
             title="Multiboxplots for models utae_32, unet, unet3d and utae_64")

In [None]:
# Boxplot for the UTAE n_head=32 model alone, saved as a high-resolution PNG.
# NOTE(review): the file name says "1_ts_ahead" although `pred_ts` is 12.
model_name = 'utae_32'
data = data_utae

fig, ax = plt.subplots(figsize=(12, 5))
ax.set_ylabel('Absolute Error (cm)')
ax.set_xlabel('Water Depth')
flierprops = {'marker': 'd', 'markerfacecolor': 'black', 'markersize': 4,
              'linestyle': 'none', 'markeredgecolor': 'black'}
bp = ax.boxplot(data, showfliers=False, patch_artist=True, flierprops=flierprops)  # hide outlier points
ax.set_xticklabels(ticks)
ax.set_title(f"Visualization for model {model_name} ", fontsize=14, fontweight="bold")

filename = f"{save_folder}Mae_boxplot_1_ts_ahead_{model_name}.png"
fig.savefig(filename, dpi=1200)



In [None]:
# Zoomed prediction vs. ground truth at ts=12 for every model of this section.
zoom = [500, 1000, 500, 1000]
ts = 12
for _pred, _gt, _mask, _name in [
    (pred_utae, gt_utae, mask_utae, '1ts_ahead_utae32'),
    (pred_unet, gt_unet, mask_unet, '1ts_ahead_unet'),
    (pred_unet3d, gt_unet3d, mask_unet3d, '1ts_ahead_unet3d'),
    (pred_utae64, gt_utae64, mask_utae64, '1ts_ahead_utae64'),
]:
    plot_answer_sample(_pred, _gt, _mask, ts=ts, zoom=zoom, show_diff=False, global_scale=True,
                       save_folder=save_folder, model_name=_name)

In [None]:
# Full-extent view (no zoom) at ts=10. These plots used to reuse the exact model_name values of
# the zoomed ts=12 plots in the previous cell, so they could overwrite those files if
# plot_answer_sample builds the file name from model_name alone (not visible here). The '_full'
# suffix keeps them separate. The unused `zoom` rebinding was dropped.
ts = 10
for _pred, _gt, _mask, _name in [
    (pred_utae, gt_utae, mask_utae, '1ts_ahead_utae32'),
    (pred_unet, gt_unet, mask_unet, '1ts_ahead_unet'),
    (pred_unet3d, gt_unet3d, mask_unet3d, '1ts_ahead_unet3d'),
    (pred_utae64, gt_utae64, mask_utae64, '1ts_ahead_utae64'),
]:
    plot_answer_sample(_pred, _gt, _mask, ts=ts, zoom=None, show_diff=False, global_scale=True,
                       save_folder=save_folder, model_name=_name + '_full')

## Weighted loss comparison

In [None]:
# Checkpoints of the UTAE runs trained without (wo) and with the weighted L1 loss.
_apr22_root = "/scratch2/ml_flood/data/checkpoints/709/cluster/apr_22/"
path_exp_wo = _apr22_root + "utae_L1/experiment_0/"
path_exp_with = _apr22_root + "utae_L1_upd/experiment_0/"

In [None]:
## Priyanka ##

### Catchment settings
catchment_kwargs = dict(
    num="709",
    tau=0.5,
    timestep=5,            # for timestep >1 use CNN rolling or Unet
    sample_type="single",
    dim_patch=256,
    fix_indexes=True,
    border_size=0,
    normalize_output=False,
    use_diff_dem=False,
    num_patch=10,          # number of patches to generate from a timestep
    predict_ahead=12,
)

In [None]:
## Models 1 and 2: both UTAE_old models are built from an empty args list ##
args = list()

In [None]:
# UTAE_old models trained without (`model_wo`) and with (`model_with`) the weighted L1 loss.
model_wo = UTAE_old(args)
model_with = UTAE_old(args)

# Load onto `device` (chosen in the imports cell) rather than a hard-coded `.cuda()`, so this
# also works without a GPU; on a GPU machine it is equivalent.
for _net, _exp_dir in [(model_wo, path_exp_wo), (model_with, path_exp_with)]:
    _net.load_state_dict(torch.load(_exp_dir + "model.pth.tar", map_location=device))
    _net.to(device)

In [None]:
# Test split of the catchment and the matching UTAE test dataloader.
dataset = load_test_dataset(catchment_kwargs)
dataloaders = {
    "test": dataloader_args_utae_test(dataset, catchment_num=catchment_kwargs['num']),
}
dataset_test = dataloaders["test"]

In [None]:
# Change the model string passed to predict_event according to the model used.
# Naming: the `*_utae` results come from the model trained WITHOUT the weighted loss and the
# `*_unet` results from the one trained WITH it (names reused from the previous section).

event_num = 0
# The calls below have always passed start_ts=None; the former `start_ts = 5` was never used.
start_ts = None

pred_utae, gt_utae, mask_utae = predict_event(model_wo, dataset, event_num, 'wo_wg_L1', start_ts=start_ts, ar=False, T=None)
pred_unet, gt_unet, mask_unet = predict_event(model_with, dataset, event_num, 'with_wg_L1', start_ts=start_ts, ar=False, T=None)

In [None]:
# Output folder for the weighted-loss comparison figures.
save_folder = (
    "/scratch2/ml_flood/data/checkpoints/709/cluster/"
    "apr_22/utae_L1/experiment_0/results/"
)

In [None]:
# MAE without (`mae_utae`) vs. with (`mae_unet`) the weighted loss.
labels = ['wo_wg_L1', 'with_wg_L1']
mae_utae, mae_unet = (
    mae_event(pred, gt, mask)
    for pred, gt, mask in ((pred_utae, gt_utae, mask_utae), (pred_unet, gt_unet, mask_unet))
)
maes = [mae_utae, mae_unet]

plot_maes(maes, labels, start_ts=17, save_folder=save_folder,
          name='12_ts_ahead', title="MAE for 12 ts ahead")

In [None]:
pred_ts = 12  # number of timesteps ahead to evaluate; if None, all timesteps are computed

In [None]:
# Water-depth bin edges (presumably meters, matching the cm tick labels).
lims = (0.1, 0.2, 0.5, 1)

data_utae, data_unet = (
    boxplot_mae(pred, gt, mask, lims=lims, pred_ts=pred_ts)
    for pred, gt, mask in ((pred_utae, gt_utae, mask_utae), (pred_unet, gt_unet, mask_unet))
)

In [None]:
# Tick labels for the water-depth bins and colours; `labels` comes from the MAE cell.
ticks = ['0-10 cm', '10-20 cm', '20-50cm', '50-100cm','>100cm']
colors = ['#EF8A62', '#67A9CF', '#1B9E77', '#CA0020', '#998EC3']
data = [data_utae, data_unet]

# The file name follows `pred_ts` (it was hard-coded to '1ts_ahead' although pred_ts = 12).
box_name = 'all_ts_ahead' if pred_ts is None else f'{pred_ts}ts_ahead'
multiboxplot(data, ticks, labels, colors, save_folder=save_folder, name=box_name,
             title="Multiboxplots for models UTAE with and without weighted loss")

In [None]:
# Zoomed view at ts=12. The second plot is the model trained WITH the weighted loss; it used to
# be saved as '1ts_ahead_unet3d' (copy-paste from the previous section). Both names now state
# the variant and differ from the full-view '-wo'/'-with' plots of the next cell.
zoom = [500, 1000, 500, 1000]
ts = 12
plot_answer_sample(pred_utae, gt_utae, mask_utae, ts=ts, zoom=zoom, show_diff=False, global_scale=True, save_folder=save_folder, model_name='1ts_ahead_utae-wo_zoom')
plot_answer_sample(pred_unet, gt_unet, mask_unet, ts=ts, zoom=zoom, show_diff=False, global_scale=True, save_folder=save_folder, model_name='1ts_ahead_utae-with_zoom')

In [None]:
# Full-extent view (zoom=None) at ts=14, without vs. with the weighted loss.
zoom = [500, 1000, 500, 1000]
ts = 14
for _pred, _gt, _mask, _name in [
    (pred_utae, gt_utae, mask_utae, '1ts_ahead_utae-wo'),
    (pred_unet, gt_unet, mask_unet, '1ts_ahead_utae-with'),
]:
    plot_answer_sample(_pred, _gt, _mask, ts=ts, zoom=None, show_diff=False, global_scale=True,
                       save_folder=save_folder, model_name=_name)

## Comparison of tau = 0.5 and tau = 0.01

In [None]:
# Checkpoints of the UTAE runs trained with tau = 0.5 and tau = 0.01.
path_exp_05 = (
    "/scratch2/ml_flood/data/checkpoints/709/cluster/apr_22/"
    "utae_L1_upd/experiment_0/"
)
path_exp_01 = (
    "/scratch2/ml_flood/data/checkpoints/709/cluster/apr_22/"
    "utae_L1_upd_tau/experiment_0/"
)

In [None]:
## Priyanka ##

### Catchment settings
catchment_kwargs = {
    "num": "709",
    "tau": 0.5,
    "timestep": 5,            # for timestep >1 use CNN rolling or Unet
    "sample_type": "single",
    "dim_patch": 256,
    "fix_indexes": True,
    "border_size": 0,
    "normalize_output": False,
    "use_diff_dem": False,
    "num_patch": 10,          # number of patches to generate from a timestep
    "predict_ahead": 12,
}

In [None]:
## Models 1 and 2: both UTAE_old models are built from an empty args list ##
args = list()

In [None]:
# UTAE_old models trained with tau = 0.5 (`model_05`) and tau = 0.01 (`model_001`).
model_05 = UTAE_old(args)
model_001 = UTAE_old(args)

# Load onto `device` (chosen in the imports cell) rather than a hard-coded `.cuda()`, so this
# also works without a GPU; on a GPU machine it is equivalent.
for _net, _exp_dir in [(model_05, path_exp_05), (model_001, path_exp_01)]:
    _net.load_state_dict(torch.load(_exp_dir + "model.pth.tar", map_location=device))
    _net.to(device)

In [None]:
# Test split of the catchment and the matching UTAE test dataloader.
dataset = load_test_dataset(catchment_kwargs)
dataloaders = {
    "test": dataloader_args_utae_test(dataset, catchment_num=catchment_kwargs['num']),
}
dataset_test = dataloaders["test"]

In [None]:
# Change the model string passed to predict_event according to the model used.
# Naming: the `*_utae` results come from the tau=0.5 model and the `*_unet` results from the
# tau=0.01 model (names reused from the previous sections).

event_num = 0
# The calls below have always passed start_ts=None; the former `start_ts = 5` was never used.
start_ts = None

pred_utae, gt_utae, mask_utae = predict_event(model_05, dataset, event_num, 'tau=0.5', start_ts=start_ts, ar=False, T=None)
pred_unet, gt_unet, mask_unet = predict_event(model_001, dataset, event_num, 'tau=0.01', start_ts=start_ts, ar=False, T=None)

In [None]:
# Output folder for the tau comparison figures.
save_folder = (
    "/scratch2/ml_flood/data/checkpoints/709/cluster/"
    "apr_22/utae_L1_upd_tau/experiment_0/results/"
)

In [None]:
# MAE with tau = 0.5 (`mae_utae`) vs. tau = 0.01 (`mae_unet`).
labels = ['tau=0.5', 'tau=0.01']
mae_utae, mae_unet = (
    mae_event(pred, gt, mask)
    for pred, gt, mask in ((pred_utae, gt_utae, mask_utae), (pred_unet, gt_unet, mask_unet))
)
maes = [mae_utae, mae_unet]

plot_maes(maes, labels, start_ts=17, save_folder=save_folder,
          name='12_ts_ahead', title="MAE for 12 ts ahead")

In [None]:
pred_ts = 12  # number of timesteps ahead to evaluate; if None, all timesteps are computed

In [None]:
# Water-depth bin edges (presumably meters, matching the cm tick labels).
lims = (0.1, 0.2, 0.5, 1)

data_utae, data_unet = (
    boxplot_mae(pred, gt, mask, lims=lims, pred_ts=pred_ts)
    for pred, gt, mask in ((pred_utae, gt_utae, mask_utae), (pred_unet, gt_unet, mask_unet))
)

In [None]:
# Tick labels for the water-depth bins and colours; `labels` comes from the MAE cell.
ticks = ['0-10 cm', '10-20 cm', '20-50cm', '50-100cm','>100cm']
colors = ['#EF8A62', '#67A9CF', '#1B9E77', '#CA0020', '#998EC3']
data = [data_utae, data_unet]

# The title was copy-pasted from the weighted-loss section; this section compares tau values.
# The file name follows `pred_ts` (it was hard-coded to '1ts_ahead' although pred_ts = 12).
box_name = 'all_ts_ahead' if pred_ts is None else f'{pred_ts}ts_ahead'
multiboxplot(data, ticks, labels, colors, save_folder=save_folder, name=box_name,
             title="Multiboxplots for models UTAE with tau=0.5 and tau=0.01")

In [None]:
# Zoomed view at ts=12 for tau = 0.5 vs. tau = 0.01.
zoom = [500, 1000, 500, 1000]
ts = 12
for _pred, _gt, _mask, _name in [
    (pred_utae, gt_utae, mask_utae, '1ts_ahead_tau0.5'),
    (pred_unet, gt_unet, mask_unet, '1ts_ahead_tau0.01'),
]:
    plot_answer_sample(_pred, _gt, _mask, ts=ts, zoom=zoom, show_diff=False, global_scale=True,
                       save_folder=save_folder, model_name=_name)

In [None]:
# Full-extent view (zoom=None) at ts=10, displayed only (save_folder=None).
zoom = [500, 1000, 500, 1000]
ts = 10
for _pred, _gt, _mask, _name in [
    (pred_utae, gt_utae, mask_utae, '1ts_ahead_tau-0.5'),
    (pred_unet, gt_unet, mask_unet, '1ts_ahead_tau-0.01'),
]:
    plot_answer_sample(_pred, _gt, _mask, ts=ts, zoom=None, show_diff=False, global_scale=True,
                       save_folder=None, model_name=_name)