In [None]:
import os
from copy import deepcopy

import xarray as xr
import numpy as np
import pandas as pd
import scipy.stats
import scipy.ndimage
from tqdm import tqdm_notebook as tqdm

import distributed

from cartopy.mpl.ticker import LongitudeFormatter, LatitudeFormatter
import matplotlib.pyplot as plt
import matplotlib.colors as mplc
import cartopy.crs as ccrs
import matplotlib.dates as mdates
import matplotlib.lines as mpll
import matplotlib.gridspec as gs
from pandas.plotting import register_matplotlib_converters

In [None]:
# Fixed-seed legacy NumPy generator so random draws are reproducible across runs.
rnd = np.random.RandomState(seed=42)

In [None]:
plt.style.use('paper')
plt.style.use('egu_journals')
plt.rcParams['text.usetex'] = False
register_matplotlib_converters()
rotated_pole = ccrs.RotatedPole(pole_longitude=-171.0, pole_latitude=41.5)
plate_carree = ccrs.PlateCarree()

# Load data

In [None]:
# Root directory of the model runs. Set the FOR2131_RUNS environment variable to point at
# a copy of the runs when working outside the original machine; the default is unchanged.
base_path = os.environ.get('FOR2131_RUNS', '/work/um0203/u300636/for2131/runs/')

## Soil moisture (H2OSOI)

In [None]:
vr_h2o_path = os.path.join(base_path, 'da_enkf_for_soil', '016', 'h2o_cleaned.nc')
# NATURE run: soil moisture at soil-level index 4 on 2015-08-07 18:00.
# Opening inside a context manager and calling .load() reads the small selected slice once,
# then releases the file handle; later cells no longer re-read it from disk on every use.
with xr.open_dataset(vr_h2o_path) as vr_h2o_ds:
    vr_h2o = (
        vr_h2o_ds['H2OSOI']
        .squeeze(drop=True)
        .isel(levsoi=4)
        .sel(time='2015-08-07 18:00')
        .load()
    )

In [None]:
det_h2o_path = os.path.join(base_path, 'da_enkf_for_soil', '018', 'h2o_cleaned.nc')
# DET run: soil moisture at soil-level index 4 on 2015-08-07 18:00.
# Opening inside a context manager and calling .load() reads the small selected slice once,
# then releases the file handle; later cells no longer re-read it from disk on every use.
with xr.open_dataset(det_h2o_path) as det_h2o_ds:
    det_h2o = (
        det_h2o_ds['H2OSOI']
        .squeeze(drop=True)
        .isel(levsoi=4)
        .sel(time='2015-08-07 18:00')
        .load()
    )

In [None]:
ens_h2o_path = os.path.join(base_path, 'da_enkf_for_soil', '015', 'h2o_cleaned.nc')
# ENS run (has an 'ensemble' dimension): soil moisture at soil-level index 4 on 2015-08-07 18:00.
# Opening inside a context manager and calling .load() reads the selected slice once,
# then releases the file handle; later cells no longer re-read it from disk on every use.
with xr.open_dataset(ens_h2o_path) as ens_h2o_ds:
    ens_h2o = (
        ens_h2o_ds['H2OSOI']
        .squeeze(drop=True)
        .isel(levsoi=4)
        .sel(time='2015-08-07 18:00')
        .load()
    )

In [None]:
da_off_h2o_path = os.path.join(base_path, 'da_enkf_for_soil', '019', 'h2o_cleaned.nc')
# LETKF Soil run (has an 'ensemble' dimension): soil moisture at soil-level index 4
# on 2015-08-07 18:00.
# Opening inside a context manager and calling .load() reads the selected slice once,
# then releases the file handle; later cells no longer re-read it from disk on every use.
with xr.open_dataset(da_off_h2o_path) as da_off_h2o_ds:
    da_off_h2o = (
        da_off_h2o_ds['H2OSOI']
        .squeeze(drop=True)
        .isel(levsoi=4)
        .sel(time='2015-08-07 18:00')
        .load()
    )

In [None]:
sekf_h2o_path = os.path.join(base_path, 'da_enkf_for_soil', '023', 'juwels', 'h2o_cleaned.nc')
# SEKF run: soil moisture at soil-level index 4 on 2015-08-07 18:00. open_dataarray expects
# the file to hold a single data variable. The context manager closes the file once the
# selected slice has been loaded into memory with .load().
with xr.open_dataarray(sekf_h2o_path) as sekf_h2o_raw:
    sekf_h2o = (
        sekf_h2o_raw
        .squeeze(drop=True)
        .isel(levsoi=4)
        .sel(time='2015-08-07 18:00')
        .load()
    )

## Load observation station locations

In [None]:
# Station metadata table. The figure below plots its 'Länge' (longitude) and
# 'Breite' (latitude) columns.
stations_path = os.path.join(base_path, 'utilities', 'stations.hd5')
ds_stations = pd.read_hdf(path_or_buf=stations_path)

# Estimate assimilation impact

In [None]:
def calc_bounds(grid):
    """Return cell edges for an (approximately) regularly spaced 1-D grid of cell centres.

    The spacing is taken as the mean of consecutive differences. Every centre is shifted
    back by half a cell, and one extra edge is appended past the last centre. The result
    therefore has ``len(grid) + 1`` elements and can be used as a ``pcolormesh`` coordinate.
    Descending grids are supported, because the spacing is then negative.

    Parameters
    ----------
    grid : array_like
        1-D cell-centre coordinates with at least two points.

    Returns
    -------
    numpy.ndarray
        Float array of cell edges.

    Raises
    ------
    ValueError
        If ``grid`` is not 1-D or has fewer than two points (spacing undefined).
        Previously such input silently produced NaN bounds.
    """
    centres = np.asarray(grid, dtype=float)
    if centres.ndim != 1 or centres.size < 2:
        raise ValueError(
            'calc_bounds needs a 1-D grid with at least two points, '
            f'got shape {centres.shape}'
        )
    grid_delta = np.diff(centres).mean()
    # Append the edge beyond the last centre, then shift everything back by half a cell.
    new_grid = np.concatenate((centres, [centres[-1] + grid_delta]), axis=0)
    return new_grid - grid_delta / 2

def calc_pcolormesh_grid(*orig_grid):
    new_grids = tuple([calc_bounds(g) for g in orig_grid])
    return new_grids

In [None]:
# Assimilation increments: the LETKF Soil ensemble mean relative to the ENS ensemble mean,
# and the SEKF run relative to the DET run.
da_off_inc = da_off_h2o.mean(dim='ensemble') - ens_h2o.mean(dim='ensemble')
sekf_inc = sekf_h2o - det_h2o

In [None]:
# Errors with respect to the NATURE run (vr_h2o). Runs with an 'ensemble' dimension
# are reduced to their ensemble mean before differencing.
ens_err = ens_h2o.mean(dim='ensemble') - vr_h2o
det_err = det_h2o - vr_h2o

da_off_err = da_off_h2o.mean(dim='ensemble') - vr_h2o
sekf_err = sekf_h2o - vr_h2o

In [None]:
# Diverging colour map discretised into 50 equal bins over +/-0.075;
# values outside that range are clipped to the end colours.
cmap = plt.get_cmap('BrBG')
norm = mplc.BoundaryNorm(
    boundaries=np.linspace(-0.075, 0.075, num=51),
    ncolors=cmap.N,
    clip=True,
)
# Cell-edge (lon, lat) coordinates shared by every map panel.
pcm_lon_lat = calc_pcolormesh_grid(da_off_err['lon'].values, da_off_err['lat'].values)

In [None]:
# Figure 6: spatial impact of soil-moisture assimilation (soil-level index 4,
# 2015-08-07 18:00), all panels on one shared colour scale:
#   (a) ENS ensemble mean minus NATURE
#   (b) LETKF Soil increment relative to ENS, with station locations
#   (c) LETKF Soil ensemble mean minus NATURE
#   (d) SEKF minus NATURE
# The 2 x 41 grid gives each map 20 columns and the colourbar the last column.
grid_spec = gs.GridSpec(2, 41, wspace=0, hspace=0)
figsize = list(plt.rcParams['figure.figsize'])

figure = plt.figure(dpi=300, figsize=figsize)
ax_ens = figure.add_subplot(grid_spec[0, :20])
ax_ls_inc = figure.add_subplot(grid_spec[0, 20:40])
ax_ls_err = figure.add_subplot(grid_spec[1, :20])
ax_se_err = figure.add_subplot(grid_spec[1, 20:40])
ax_cbar = figure.add_subplot(grid_spec[:, 40:])

lon_bounds, lat_bounds = pcm_lon_lat
panels = (
    (ax_ens, ens_err, '(a)'),
    (ax_ls_inc, da_off_inc, '(b)'),
    (ax_ls_err, da_off_err, '(c)'),
    (ax_se_err, sekf_err, '(d)'),
)
meshes = {}
for ax, field, label in panels:
    # `norm` (a BoundaryNorm) already fixes the colour range. matplotlib >= 3.5 raises
    # ValueError when vmin/vmax are passed together with a Normalize instance.
    meshes[label] = ax.pcolormesh(lon_bounds, lat_bounds, field, cmap=cmap, norm=norm)
    ax.text(x=0.02, y=0.925, s=label, c='black', transform=ax.transAxes, va='center', ha='left')

# Observation station locations drawn on top of the LETKF Soil increment.
ax_ls_inc.scatter(ds_stations['Länge'], ds_stations['Breite'], s=5, marker='x', color='black', zorder=1)

# All meshes share `norm`, so any of them can drive the colourbar.
cf = meshes['(b)']
cbar = figure.colorbar(cf, cax=ax_cbar)
cbar.set_label('Difference in soil moisture (m$^3$/m$^3$)')
cbar.set_ticks(np.linspace(-0.075, 0.075, 7))

# Every panel shows the full grid extent. Only the outer edges of the
# 2 x 2 block carry tick labels and axis labels.
axes_grid = ((ax_ens, ax_ls_inc), (ax_ls_err, ax_se_err))
for row, row_axes in enumerate(axes_grid):
    for col, ax in enumerate(row_axes):
        ax.set_xlim(lon_bounds[0], lon_bounds[-1])
        ax.set_ylim(lat_bounds[0], lat_bounds[-1])
        if row == 0:
            ax.set_xticks([])
        else:
            ax.set_xticks([8, 9, 10])
            ax.set_xlabel('Longitude (deg)')
        if col == 0:
            ax.set_ylabel('Latitude (deg)')
        else:
            ax.set_yticks([])

# Save before showing: inline backends may close the figure on show().
os.makedirs('../figures', exist_ok=True)
figure.savefig('../figures/fig_06_spatial_impact.png', dpi=300)
plt.show()