In [None]:
import os
import sys
import yaml
import argparse
from glob import glob
from datetime import datetime, timedelta

import numpy as np
import xarray as xr

In [None]:
sys.path.insert(0, os.path.realpath('../libs/'))
import graph_utils as gu
import verif_utils as vu

In [None]:
# Absolute path of the plotting/verification config used by every later cell.
config_name = os.path.realpath('plot_config.yml')

# Parse the YAML into a plain dict (safe_load: no arbitrary object construction).
with open(config_name) as stream:
    conf = yaml.safe_load(stream)

In [None]:
# graph tools
import matplotlib.pyplot as plt
import matplotlib.gridspec as gridspec
%matplotlib inline

## Figure settings

In [None]:
# Figure resolution switch:
#   True  -> publication-quality DPI read from plot_config.yml
#   False -> low-resolution (75 dpi) figures for quick viewing in the notebook
need_publish = False

dpi_ = conf['figure']['keys']['dpi'] if need_publish else 75

## Import data

In [None]:
# Inclusive span of years to verify, as strings so they can be matched
# against file names.
year_range = conf['example']['year_range']
years_pick = np.arange(year_range[0], year_range[1] + 1).astype(str)

# 1D model-grid coordinates from the geo file, expanded to 2D lon/lat arrays.
OURS_dataset = xr.open_dataset(conf['geo']['geo_file_nc'])
x_OURS, y_OURS = (np.array(OURS_dataset[coord]) for coord in ('longitude', 'latitude'))

lon_OURS, lat_OURS = np.meshgrid(x_OURS, y_OURS)

In [None]:
# ERA5 verification target: files matching the configured glob pattern,
# kept only if their path contains one of the selected years (substring match).
filename_ERA5 = sorted(glob(conf['example']['save_loc_target']))
filename_ERA5 = [fn for fn in filename_ERA5 if any(year in fn for year in years_pick)]

# Fail early with a readable message: xr.concat on an empty list only reports
# that it needs at least one object, without saying which pattern/years failed.
if not filename_ERA5:
    raise FileNotFoundError(
        'No ERA5 files matched {} for years {}'.format(
            conf['example']['save_loc_target'], ', '.join(years_pick)))

# merge yearly ERA5 as one dataset along time
ds_ERA5 = [vu.get_forward_data(fn) for fn in filename_ERA5]
ds_ERA5_merge = xr.concat(ds_ERA5, dim='time')

# Select the specified variables and their levels
variables_levels = conf['example']['verif_variables']

# subset merged ERA5 and unify coord names (latitude/longitude -> lat/lon)
ds_ERA5_merge = vu.ds_subset_everything(ds_ERA5_merge, variables_levels)
ds_ERA5_merge = ds_ERA5_merge.rename({'latitude':'lat','longitude':'lon'})

In [None]:
# Forecast files (*.nc) for the medium and large models, restricted to paths
# that contain one of the selected years.
filename_medium = [
    path
    for path in sorted(glob(conf['example']['save_loc_base'] + '*.nc'))
    if any(yr in path for yr in years_pick)
]

filename_large = [
    path
    for path in sorted(glob(conf['example']['save_loc_ours'] + '*.nc'))
    if any(yr in path for yr in years_pick)
]

## Pick an example day and extract its variable subsets

In [None]:
# example_day indexes the sorted, year-filtered forecast file lists;
# example_lead_index is the time step within that forecast.
example_day = 2105
example_lead_index = 239 # day-10

# NOTE(review): 7 variables are extracted here but the figure grid below has
# 5 rows — confirm which variables are meant to be plotted.
varnames = ['V500', 'U500', 'T500', 'Q500', 'Z500', 'SP', 't2m']

# Open each forecast file once, outside the loop; re-opening per variable
# created a new (never closed) dataset handle on every iteration.
ds_medium = xr.open_dataset(filename_medium[example_day])
ds_large = xr.open_dataset(filename_large[example_day])

# ERA5 at this forecast's time stamps; identical for every variable, so
# selected once.
ds_target = ds_ERA5_merge.sel(time=ds_medium['time'])

# Keys are '<varname>_medium', '<varname>_large', '<varname>_target';
# values are the fields at example_lead_index as numpy arrays.
dict_example = {}

for varname in varnames:
    var_medium = ds_medium[varname].isel(time=example_lead_index)
    var_large = ds_large[varname].isel(time=example_lead_index)

    dict_example['{}_medium'.format(varname)] = np.array(var_medium)
    dict_example['{}_large'.format(varname)] = np.array(var_large)
    dict_example['{}_target'.format(varname)] = np.array(ds_target[varname].isel(time=example_lead_index))

In [None]:
# Column identifiers in panel order; they match the key suffixes used in
# dict_example. TODO(review): swap for display labels if titles need them.
model_names = ['medium', 'large', 'target']

fig = plt.figure(figsize=(13, 13), dpi=dpi_)
gs = gridspec.GridSpec(5, 3, height_ratios=[1, 1, 1, 1, 1], width_ratios=[1, 1, 1])

ind_x = [0, 1, 2, 3, 4]
ind_y = [0, 1, 2]

# AX[row][col] for 2D access and AX_flat in row-major order for bulk styling.
# Each Axes is created exactly once and shared by both containers; calling
# plt.subplot twice for the same grid cell depends on matplotlib's
# version-specific reuse of existing axes.
AX = [[None] * len(ind_y) for _ in ind_x]
AX_flat = []
for ix, ix_gs in enumerate(ind_x):
    for iy, iy_gs in enumerate(ind_y):
        ax = plt.subplot(gs[ix_gs, iy_gs])
        AX[ix][iy] = ax
        AX_flat.append(ax)

plt.subplots_adjust(0, 0, 1, 1, hspace=0.1, wspace=0.1)

for ax in AX_flat:
    # Return value was never stored anywhere; presumably the helper styles the
    # Axes in place — TODO confirm in graph_utils.
    gu.ax_decorate_box(ax)

