In [None]:
# Standard library
import os
import warnings

# External packages
import intake
from easygems import healpix as egh

import numpy as np

import matplotlib.pyplot as plt
import cartopy.crs as ccrs

warnings.filterwarnings("ignore", category=FutureWarning)

In [None]:
# Functions from our repo

import tools.utils as sc

In [None]:
# Time period
# time = ('2020-04-01','2020-04-30')
time = ('2020-08-01','2020-08-31')

# Region
map_domain = sc.domains10x10['namibian']

In [None]:
# Dataset specification
#
# hknode      : which node of the hackathon intake catalog to read from
# simulations : per-model settings, keyed by a short display name
#   id     - catalog entry name
#   opt    - keyword arguments passed when opening the catalog entry
#   rnm    - variable/dimension renames applied to the opened dataset
#   vunits - (optional) units to attach to 'pressure' when the data lacks them

hknode = 'EU'

simulations = dict(
    IFS=dict(
        id='ifs_tco3999-ng5_rcbmf_cf',
        opt=dict(zoom=11, time='PT1H'),
        rnm=dict(level='pressure', value='cell', clwvi='lwp'),
        vunits='hPa',
    ),
    ICON=dict(
        id='icon_d3hp003',
        opt=dict(zoom=11, time='PT6H', time_method='inst'),
        rnm=dict(qall='lwc'),
    ),
    NICAM=dict(
        id='nicam_gl11',
        opt=dict(zoom=9, time='PT6H'),
        rnm=dict(lev='pressure', qall='lwc'),
    ),
    UM=dict(
        id='um_glm_n2560_RAL3p3',
        opt=dict(zoom=10, time='PT3H'),
        rnm=dict(clw='lwc'),
    ),
)

In [None]:
# Load datasets
#
# Open every entry of `simulations` from the catalog node `hknode`, apply the
# configured renames and attach coordinates via egh.attach_coords.

CATALOG_URL = "https://digital-earths-global-hackathon.github.io/catalog/catalog.yaml"
cat = intake.open_catalog(CATALOG_URL)[hknode]

ds = {}

for name, params in simulations.items():
    ds[name] = (
        cat[params['id']](**params['opt'])
        .to_dask()
        .rename(params['rnm'])
        .pipe(egh.attach_coords)
    )

    # Attach pressure units when the data does not carry them. Only some
    # entries define 'vunits' (e.g. NICAM also renames 'lev' -> 'pressure'
    # without one), so look it up instead of indexing unconditionally.
    if 'pressure' in ds[name] and 'units' not in ds[name]['pressure'].attrs:
        if 'vunits' in params:
            ds[name]['pressure'].attrs['units'] = params['vunits']
        else:
            warnings.warn(
                f"{name}: 'pressure' has no units attribute and no 'vunits' "
                "is configured in `simulations`"
            )

    print(name)

In [None]:
%%time

# Select time period and region

for name in ds.keys():
    cells = egh.isel_extent(ds[name],map_domain)
    ds[name] = ds[name].sel(time=slice(*time)).isel({'cell':cells})

In [None]:
%%time

# Derive lwp, wamax and wamin

for name in ds.keys():

    if 'lwp' not in ds[name] and 'lwc' in ds[name]:
        ds[name]['lwp'] = sc.integrate_wrt_pressure(ds[name]['lwc'])
    ds[name]['lwp'].attrs = {'name':'LWP','units':'kg/m^2'}
    print(name+' LWP integration')
    
    if 'wa' in ds[name]:
        ds[name]['wamax'] = sc.reduce_below(ds[name]['wa'],900e2,np.max)
        ds[name]['wamax'].attrs = {'name':'Max w <900hPa','units':'m/s'}
        print(name+' wa max <900hPa')

    if 'wa' in ds[name]:
        ds[name]['wamin'] = sc.reduce_below(ds[name]['wa'],900e2,np.min)
        ds[name]['wamin'].attrs = {'name':'Min w <900hPa','units':'m/s'}
        print(name+' wa min <900hPa')

In [None]:
# Plot example maps
#
# One figure per snapshot time: rows are simulations, columns are variables.
# Each figure is saved to `plot_path` under its timestamp.

plot_path = "./figures/map/"
# savefig does not create missing directories, so create the output folder.
os.makedirs(plot_path, exist_ok=True)

# Snapshot times: every `time_step` starting at the first time of the first
# dataset and stopping before its last time (np.arange excludes the end point).
time_step = np.timedelta64(10,'D')
ref_time = next(iter(ds.values())).time
times = np.arange(ref_time[0].values, ref_time[-1].values, time_step)

# Panel columns: variable name -> keyword arguments for egh.healpix_show
variables = {
    'lwp':   {'cmap':'Blues_r', 'vmin':0,    'vmax':1  },
    'wamax': {'cmap':'bwr',     'vmin':-0.4, 'vmax':0.4},
    'wamin': {'cmap':'bwr',     'vmin':-0.4, 'vmax':0.4}
}

Ncol = len(variables)
Nrow = len(ds)

for t in times:

    # squeeze=False always yields a 2-D (Nrow, Ncol) array of axes, including
    # the single-row / single-column / 1x1 cases.
    fig, axs = plt.subplots(Nrow, Ncol, figsize=(4*Ncol, 4*Nrow),
                            subplot_kw={"projection": ccrs.PlateCarree()},
                            sharex=True, sharey=True, constrained_layout=True,
                            squeeze=False)

    titled_cols = set()   # columns whose top-row title has been set
    col_images = {}       # column -> last image drawn in it (for the colorbar)

    for r, name in enumerate(ds.keys()):
        for c, (var, opt) in enumerate(variables.items()):
            ax = axs[r, c]
            gl = sc.draw_map(ax, map_domain)

            # Tick labels only on the left column and the bottom row.
            if c > 0:
                gl.left_labels = False
            if r < Nrow-1:
                gl.bottom_labels = False
            if c == 0:
                ax.text(-0.2, 0.5, simulations[name]['id'],
                        transform=ax.transAxes, va='center', ha='right',
                        fontsize='large', fontweight='bold', rotation=90)

            # Some variables are optional in the derivation step (e.g. wamax/
            # wamin need 'wa'); mark the panel instead of raising KeyError.
            if var not in ds[name]:
                ax.text(0.5, 0.5, f"{var}\nnot available",
                        transform=ax.transAxes, va='center', ha='center')
                continue

            da = ds[name][var].sel(time=t)
            im = egh.healpix_show(da, ax=ax, **opt)
            col_images[c] = im

            # Column title on the top row, from the first simulation that has it.
            if c not in titled_cols:
                axs[0, c].set_title(f"{da.attrs['name']} [{da.attrs['units']}]")
                titled_cols.add(c)

    # Colour limits are fixed per column, so one colorbar under the bottom
    # panel describes the whole column.
    for c, im in col_images.items():
        fig.colorbar(im, ax=axs[Nrow-1, c], location='bottom', shrink=0.9, pad=0.02)

    datestr = t.astype('datetime64[h]').item().strftime('%Y-%m-%dT%H')
    fig.suptitle(datestr)
    fig.savefig(os.path.join(plot_path, datestr), bbox_inches='tight', dpi=600)
    print(datestr)
    # Figures are left open so they render inline; add plt.close(fig) here
    # when looping over many time steps to bound memory use.