# Create plots of spatial maps

In [None]:
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import src.evaluation_utils as eu
import src.evaluation_plots as ep
from src.plot_utils import ScoreData, PlotSpatialHSSScores, PlotSingleFrames
import time

## Load Data

In [None]:
def load_data(fname):
    """Unpickle and return the object stored in ``<fname>.dat``.

    ``fname`` is the path *without* the ``.dat`` extension (the same
    convention ``save_data`` uses). Only call this on files written by this
    notebook: unpickling can execute arbitrary code.
    """
    import pickle

    dat_path = '{}.dat'.format(fname)
    with open(dat_path, "rb") as handle:
        return pickle.load(handle)

def save_data(data, fname):
    """Pickle ``data`` into ``<fname>.dat``, overwriting any existing file.

    ``fname`` is given without the ``.dat`` extension so that the same
    string can be passed back to ``load_data``.
    """
    import pickle

    dat_path = '{}.dat'.format(fname)
    with open(dat_path, "wb") as handle:
        pickle.dump(data, handle)

In [None]:
def load_high_resolution(model_file_name: str = 'dnn_weighted.npy',
                         models_dir: str = '/path/to/models'):
    """Load one model's high-resolution prediction plus the target and IFS fields.

    Parameters
    ----------
    model_file_name : str
        Name of the ``.npy`` file inside ``models_dir`` whose first entry
        (``[0]``) is the model prediction. Defaults to ``'dnn_weighted.npy'``
        so argument-less calls in this notebook work.
    models_dir : str
        Directory containing the model files and ``dnn.npy``.

    Returns
    -------
    tuple
        ``(prediction, target, ifs)``: entry 0 of ``model_file_name`` and
        entries 1 and 2 of ``dnn.npy``. The target and IFS fields always
        come from ``dnn.npy``, whichever model is requested.
    """
    prediction = np.load(f'{models_dir}/{model_file_name}')[0]

    # dnn.npy also stores the reference fields alongside its own prediction.
    reference = np.load(f'{models_dir}/dnn.npy')
    target, ifs = reference[1], reference[2]

    return prediction, target, ifs

## Compute statistics and save results to disk

In [None]:
# NOTE(review): `xr` (xarray) is used here but never imported in this
# notebook; add `import xarray as xr` to the imports cell or this raises NameError.
dataset_path = '/path/to/training_dataset.nc4'
ds = xr.open_dataset(dataset_path, chunks={'time': 1})
# Pull the full TRMM precipitation field into memory as a NumPy array.
data = ds.trmm_total_precipitation.values
# Passed as `data_min` to the threshold helper; presumably rain below this
# amount is ignored when computing percentiles — TODO confirm in eu.
min_precipitation_threshold_in_mm_per_3hours = 0.1

# Percentiles at which event thresholds are computed (index 4 -> 95th).
percentiles = [75, 80, 85, 90, 95, 97.5, 99]

In [None]:
# Model whose predictions are scored; re-run this cell with 'dnn_weighted.npy'
# to produce the second scores file.
model_name = 'dnn.npy'

# Output file per model. These must match the names read back in the
# "Load statistics from disk" cell, so each model gets its own file.
scores_fnames = {
    'dnn.npy': "/path/to/categorical_hourly_geographic_scores_dnn",
    'dnn_weighted.npy': "/path/to/categorical_hourly_geographic_scores",
}

# One threshold field per entry of `percentiles`, computed from the TRMM data.
thresholds = [
    eu.local_thresholds_from_percentiles(
        data, percentile,
        data_min=min_precipitation_threshold_in_mm_per_3hours)
    for percentile in percentiles
]

# load_high_resolution returns the requested model's prediction first,
# followed by the target and IFS fields.
dnn, target, ifs = load_high_resolution(model_name)

target_binary = eu.continuous_to_categorical_with_thresholds(target, thresholds)
ifs_binary = eu.continuous_to_categorical_with_thresholds(ifs, thresholds)
dnn_binary = eu.continuous_to_categorical_with_thresholds(dnn, thresholds)

metric = 'heidke_skill_score'

dnn_geographic_scores = eu.geographic_categorical_evaluation(dnn_binary, target_binary, metric)
ifs_geographic_scores = eu.geographic_categorical_evaluation(ifs_binary, target_binary, metric)

# Saved as [model scores, IFS scores]; `data` keeps the precipitation array.
results = [dnn_geographic_scores, ifs_geographic_scores]
fname = scores_fnames[model_name]
save_data(results, fname)

## Load statistics from disk

### Figure 3

In [None]:
# Scores of the dnn.npy model (plotted as "DNN (MSE)" in Figure 4).
fname = "/path/to/categorical_hourly_geographic_scores_dnn"
data = load_data(fname)
dnn_scores = data[0]

# Scores of the weighted DNN together with the IFS scores used below.
fname = "/path/to/categorical_hourly_geographic_scores"
data = load_data(fname)
dnn_weighted_geographic_scores, ifs_geographic_scores = data[0], data[1]

# NOTE(review): `xr` (xarray) is never imported in this notebook; add
# `import xarray as xr` to the imports cell or this raises NameError.
path = '/path/to/training_dataset.nc4'
ds = xr.open_dataset(path, chunks={'time': 1})
data = ds.trmm_total_precipitation.values
min_precipitation_threshold_in_mm_per_3hours = 0.1
percentiles = [75, 80, 85, 90, 95, 97.5, 99]

# Single percentile shown in this figure. It drives both the threshold map
# and the colour-bar label, so the two cannot drift apart.
percentile = 95
threshold = eu.local_thresholds_from_percentiles(
    data, percentile,
    data_min=min_precipitation_threshold_in_mm_per_3hours)

configs = {
    'HSS': {
        'cmap': 'viridis_r',
        'cbar_title': 'HSS',
        'alpha': 0.6,
        'vmin': 0.05,
        'vmax': 0.25,
        'cbar_extend': 'both',
        'title': '',
    },
    'Percentile': {
        'cmap': 'viridis_r',
        'cbar_title': f'{percentile}th rainfall percentile [mm/3h]',
        'alpha': 0.7,
        # NOTE(review): vmin > vmax. matplotlib's Normalize raises ValueError
        # for that, so this only works if PlotSpatialHSSScores handles it
        # itself — confirm whether vmin=0, vmax=20 was intended.
        'vmin': 20.,
        'vmax': 0,
        'cbar_extend': 'max',
        'title': '',
    },
}

out_path = '/path/to/figures/'
# NOTE(review): the score containers are indexed with [0] here but with
# ['heidke_skill_score'][4] in Figure 4 — confirm which entry holds the
# 95th-percentile HSS map.
idx = 0
# `data` (this ScoreData) is also read by the Figure S7 cell.
data = ScoreData(percentile=threshold,
                 ifs=ifs_geographic_scores[idx],
                 dnn=dnn_weighted_geographic_scores[idx])

PlotSpatialHSSScores(data, percentile, configs, out_path, plot_percentiles=True).plot()

### Figure 4

In [None]:
file_name = '/path/to/figures/hss_skill_over_latitudes.pdf'


def _zonal_mean_hss(scores):
    """Average a score map over axis 1, counting fill values (< -990) as 0."""
    # NOTE(review): fill values are counted as 0 HSS, which pulls the mean
    # down; np.nanmean on NaN-masked values may be intended — confirm.
    cleaned = np.where(scores < -990, 0, scores)
    return cleaned.mean(axis=1)


fig = plt.figure(figsize=(9, 4))
ax1 = fig.add_subplot(111)

# NOTE(review): must have the same length as axis 0 of the score maps;
# ds.latitude would be a safer source than this hand-built range — confirm.
lats = np.arange(-50, 52, 0.51)

# Plot handles use distinct names so this cell does not rebind `data` (the
# ScoreData that Figure S7 reads) or the `dnn`/`ifs` prediction arrays.
# Index 4 presumably selects percentiles[4] == 95 — confirm.
dnn_line = ax1.plot(lats,
                    _zonal_mean_hss(dnn_weighted_geographic_scores['heidke_skill_score'][4]),
                    label='HSS DNN (Weighted)', color='tab:red')
ifs_line = ax1.plot(lats,
                    _zonal_mean_hss(ifs_geographic_scores['heidke_skill_score'][4]),
                    label='HSS IFS', color='tab:blue')
dnn_mse_line = ax1.plot(lats, _zonal_mean_hss(dnn_scores[0]),
                        label='HSS DNN (MSE)', color='tab:green')

ax1.set_ylabel('HSS')
ax1.set_xlabel('Latitude')
ax1.set_xticks([-40, -20, 0, 20, 40])
# Negative latitudes lie south of the equator, positive ones north.
ax1.set_xticklabels([r'$40^{\circ}$S', r'$20^{\circ}$S', r'$0^{\circ}$',
                     r'$20^{\circ}$N', r'$40^{\circ}$N'])
ax1.set_ylim(0, 0.35)

# Second y-axis for the percentile threshold, drawn explicitly on ax2.
ax2 = ax1.twinx()
ax2.set_ylabel('95th rainfall percentile [mm/3h]')
th_line = ax2.plot(lats, threshold.mean(axis=1), color='k',
                   label='TRMM 95th rainfall percentile')

# One legend for the lines of both axes.
legend_lines = ifs_line + dnn_line + dnn_mse_line + th_line
ax1.legend(legend_lines, [line.get_label() for line in legend_lines], loc='best')
ax2.set_ylim(0, 23)
ax1.grid()
plt.savefig(file_name, bbox_inches='tight', format='pdf')
plt.show()
print(file_name)

### Figure S7

In [None]:
from src.evaluation_geographic import GeographicValidation

# Expects `data` to still be the ScoreData built in the Figure 3 cell, and
# `ds` to be the opened dataset. Each map becomes +1 where its value is
# below 0.001 and -1 elsewhere; each mask keeps its own source's dtype.
ifs = np.where(data.ifs < 0.001, np.ones_like(data.ifs), np.ones_like(data.ifs) * -1)
dnn = np.where(data.dnn < 0.001, np.ones_like(data.dnn), np.ones_like(data.dnn) * -1)

# Named so that it does not shadow the builtin `eval`.
geo_validation = GeographicValidation(ds.latitude, ds.longitude,
                                      orography_flag=False,
                                      mask_threshold=-1,
                                      clean_threshold=None,
                                      show_coordinates=False)
metric_name = 'HSS'

# This mutates the `configs` dict defined for Figure 3.
configs['HSS']['title'] = None
configs['HSS']['cbar_title'] = 'DNN (Weighted) HSS'

plt.figure(figsize=(16, 6))

geo_validation.plot_overlap(metric_name, ifs, dnn, configs=configs, single_plot=False)

fname = 'figure_name.pdf'
plt.tight_layout()
plt.savefig(fname, format='pdf')

### Figure S1

In [None]:
# Weighted-DNN prediction plus the TRMM target and IFS fields for Figure S1.
dnn_weighted, trmm, ifs = load_high_resolution('dnn_weighted.npy')

In [None]:
# Figure S1: DNN (weighted), IFS and TRMM fields at 00h on 16 July of 2012, 2013 and 2014.
frame_timestamps = ['2012-07-16T00', '2013-07-16T00', '2014-07-16T00']
single_frames = PlotSingleFrames(dnn_weighted, ifs, trmm, timestamps=frame_timestamps)
single_frames.plot()