In [None]:
import os
from copy import deepcopy

import cartopy.crs as ccrs
import distributed
import matplotlib.colors as mplc
import matplotlib.dates as mdates
import matplotlib.lines as mpll
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import scipy.ndimage
import scipy.optimize
import scipy.stats
import xarray as xr
from cartopy.mpl.ticker import LongitudeFormatter, LatitudeFormatter
from tqdm.notebook import tqdm

from pytassim.localization import GaspariCohn
from py_bacy.intf_pytassim.clm import distance_func
from py_bacy.intf_pytassim.io import load_observations

In [None]:
# Seeded legacy NumPy generator; all random index draws in this notebook go
# through it so the sampled grid-point pairs are reproducible.
rnd = np.random.RandomState(seed=42)

In [None]:
# Apply both style sheets in order; settings of the later sheet override the
# earlier one where they overlap.
# NOTE(review): 'paper' and 'egu_journals' are presumably user-installed style
# sheets -- they must be discoverable by matplotlib for this cell to run.
plt.style.use(['paper', 'egu_journals'])
# Do not route text rendering through an external LaTeX installation.
plt.rcParams.update({'text.usetex': False})

# Coordinate reference systems used to convert between geographic lon/lat
# (plate carree) and the rotated-pole grid.
plate_carree = ccrs.PlateCarree()
rotated_pole = ccrs.RotatedPole(pole_latitude=41.5, pole_longitude=-171.0)

In [None]:
# Constant factor sqrt(41 / 40) (about 1.0124).
# NOTE(review): presumably a standard-deviation correction tied to the
# 40-member ensemble; it is not referenced in the cells visible here --
# confirm its purpose before relying on or removing it.
std_corr = np.sqrt(41 / 40)

In [None]:
# Start a local dask cluster with 4 single-threaded workers and a 16 GB memory
# limit per worker, then connect a client so later dask-backed xarray
# computations run on it.
cluster = distributed.LocalCluster(n_workers=4, threads_per_worker=1, memory_limit='16GB')
client = distributed.Client(cluster)
# Bare expression: the notebook renders the client summary as the cell output.
client

# Load data

In [None]:
# Root directory holding the experiment runs.  The default is the original,
# cluster-specific location; set the DA_ENKF_BASE_PATH environment variable to
# point the notebook at a copy of the data elsewhere without editing the code.
base_path = os.environ.get(
    'DA_ENKF_BASE_PATH',
    '/work/um0203/u300636/for2131/runs/da_enkf_for_soil'
)

## H2O

In [None]:
# Soil-moisture file of experiment '016'.  `vr_h2o` serves as the reference
# that the ensemble mean is compared against when `h2o_err` is computed.
vr_h2o_path = os.path.join(base_path, '016', 'h2o_cleaned.nc')
# Open H2OSOI, drop length-one dimensions, keep soil level index 4, chunk it as
# (1, 302, 267) for dask, and trim 30 grid cells from both ends of the last
# two dimensions.
vr_h2o = xr.open_dataset(vr_h2o_path)['H2OSOI'].squeeze(drop=True).isel(levsoi=4).chunk((1, 302, 267))[..., 30:-30, 30:-30]

In [None]:
# Soil-moisture file of experiment '020'; this field carries the `ensemble`
# dimension that is averaged over in later cells.
da_h2o_path = os.path.join(base_path, '020', 'h2o_cleaned.nc')
# Same selection as for experiment '016' (H2OSOI, soil level index 4, 30-cell
# border removed from the last two dimensions), but chunked as
# (40, 1, 302, 267).
# NOTE(review): the leading chunk size of 40 presumably equals the ensemble
# size -- confirm.
da_h2o = xr.open_dataset(da_h2o_path)['H2OSOI'].squeeze(drop=True).isel(levsoi=4).chunk((40, 1, 302, 267))[..., 30:-30, 30:-30]

## T2m

In [None]:
# T_2M file of experiment '016'; after interpolation it is the reference used
# when `t2m_err` is computed.
vr_t2m_path = os.path.join(base_path, '016', 't2m_cleaned.nc')
# Open T_2M, drop length-one dimensions and chunk it as (1, 109, 99) for dask.
# No border is trimmed here; the field is interpolated to the soil grid later.
vr_t2m = xr.open_dataset(vr_t2m_path)['T_2M'].squeeze(drop=True).chunk((1, 109, 99))

In [None]:
# T_2M file of experiment '020' (the field with the `ensemble` dimension).
da_t2m_path = os.path.join(base_path, '020', 't2m_cleaned.nc')
# Open T_2M, drop length-one dimensions and chunk it as (40, 1, 109, 99).
# NOTE(review): the leading chunk size of 40 presumably equals the ensemble
# size -- confirm.
da_t2m = xr.open_dataset(da_t2m_path)['T_2M'].squeeze(drop=True).chunk((40, 1, 109, 99))

## Additional data

## Pre-process data

In [None]:
# Remove repeated time stamps: `duplicated()` flags every occurrence after the
# first, and the inverted mask keeps only the first occurrence of each time.
da_bg_h2o = da_h2o.isel(time=~da_h2o.indexes['time'].duplicated())
# For T_2M the first entry along the second axis is dropped as well.
# NOTE(review): presumably this aligns the T_2M time axis with the soil
# moisture one -- confirm.
da_bg_t2m = da_t2m.isel(time=~da_t2m.indexes['time'].duplicated())[:, 1:]
vr_h2o = vr_h2o.isel(time=~vr_h2o.indexes['time'].duplicated())
# NOTE(review): `vr_t2m` is overwritten with a version that lost its first
# entry, so re-running this cell without reloading drops one more entry each
# time.  Re-run the loading cell first, or assign to a new name.
vr_t2m = vr_t2m.isel(time=~vr_t2m.indexes['time'].duplicated())[1:]

## Get clm coordinates

In [None]:
# Collapse the (lat, lon) mesh of the soil-moisture field into a single
# `grid` dimension so every grid point has one flat index.
prep_clm = da_bg_h2o.stack(grid=['lat', 'lon'])

# Transform every grid point from geographic lon/lat into the rotated-pole
# system; column 0 is stored as `rlon`, column 1 as `rlat`.
clm_coords_rotated = rotated_pole.transform_points(
    plate_carree, prep_clm.lon.values, prep_clm.lat.values
)
_rlon_values = clm_coords_rotated[:, 0]
_rlat_values = clm_coords_rotated[:, 1]

clm_rot_index = pd.MultiIndex.from_arrays(
    [_rlon_values, _rlat_values], names=['rlon', 'rlat']
)

# Rotated coordinates as DataArrays that share the stacked `grid` coordinate,
# so they can serve as point-wise interpolation targets.
_grid_coord = {'grid': prep_clm.grid}
clm_rlon = xr.DataArray(_rlon_values, coords=_grid_coord, dims=['grid'])
clm_rlat = xr.DataArray(_rlat_values, coords=_grid_coord, dims=['grid'])

In [None]:
def _interp_to_clm_grid(field):
    """Linearly interpolate `field` from its (rlon, rlat) grid onto the CLM points.

    The targets `clm_rlon` / `clm_rlat` are indexed by the stacked `grid`
    dimension, so the interpolated field is one-dimensional in space.  The
    rotated coordinates are then dropped and `grid` is unstacked again.
    """
    on_points = field.interp(rlon=clm_rlon, rlat=clm_rlat, method='linear')
    return on_points.drop(['rlon', 'rlat']).unstack('grid')


# `da_bg_t2m` is replaced by its interpolated version, so this cell cannot be
# re-run without first re-running the cell that created `da_bg_t2m`.
da_bg_t2m = _interp_to_clm_grid(da_bg_t2m)
vr_t2m_interp = _interp_to_clm_grid(vr_t2m)

# Estimate semi-variogram

In [None]:
def _flatten_grid(field):
    """Stack (lat, lon) into one `grid` axis and discard its MultiIndex."""
    return field.stack(grid=['lat', 'lon']).reset_index('grid', drop=True)


_t2m_ens_mean = da_bg_t2m.mean('ensemble')
_h2o_ens_mean = da_bg_h2o.mean('ensemble')

# Ensemble perturbations: each member minus the ensemble mean.
t2m_perts = _flatten_grid(da_bg_t2m - _t2m_ens_mean)
h2o_perts = _flatten_grid(da_bg_h2o - _h2o_ens_mean)

# Errors: ensemble mean minus the reference fields from experiment '016'.
t2m_err = _flatten_grid(_t2m_ens_mean - vr_t2m_interp)
h2o_err = _flatten_grid(_h2o_ens_mean - vr_h2o)

In [None]:
def estimate_dist(idx_first, idx_second):
    """Return the separation of paired grid points.

    Parameters
    ----------
    idx_first, idx_second : array-like of int
        Positions along the stacked `grid` dimension; element ``i`` of both
        arrays forms one pair.

    Returns
    -------
    numpy.ndarray
        Euclidean distance in rotated lon/lat degrees, scaled by the
        module-level `m_per_deg` (metres per degree).

    Notes
    -----
    Reads the globals `clm_rlat`, `clm_rlon` and `m_per_deg`; the cell that
    defines `m_per_deg` must have run before this function is called.
    """
    d_rlat = clm_rlat.isel(grid=idx_first).values - clm_rlat.isel(grid=idx_second).values
    d_rlon = clm_rlon.isel(grid=idx_first).values - clm_rlon.isel(grid=idx_second).values
    return np.sqrt(d_rlat ** 2 + d_rlon ** 2) * m_per_deg

In [None]:
# Conversion from degrees of arc to metres on a sphere: circumference / 360.
# `estimate_dist` reads `m_per_deg` as a global.
earth_radius = 6378137  # metres; NOTE(review): equals the WGS84 equatorial radius -- confirm intent
earth_perim = 2 * np.pi * earth_radius
m_per_deg = earth_perim / 360

In [None]:
# Number of random grid-point pairs to draw.
nr_samples = 2_000_000
# Largest pair separation that is kept, in the unit returned by
# `estimate_dist` (metres).
bound = 50_000

In [None]:
# Draw both members of every pair independently and uniformly (with
# replacement) from all grid points.  The first members are drawn before the
# second ones, which fixes the order in which the random stream is consumed.
n_grid_points = len(prep_clm.grid)
idx_first = rnd.choice(n_grid_points, size=nr_samples)
idx_second = rnd.choice(n_grid_points, size=nr_samples)

# Separation of every pair.
dist = estimate_dist(idx_first, idx_second)

In [None]:
# Rejection sampling: pairs further apart than `bound` are redrawn until every
# pair lies within the bound.  `idx_first`, `idx_second` and `dist` are updated
# in place.
#
# Only the redrawn pairs get a new distance; the distances of accepted pairs
# cannot change, so recomputing all of them every iteration is wasted work.
# The random stream is consumed exactly as before, so the final sample is
# unchanged.
n_grid_points = len(prep_clm.grid)
too_large = dist > bound
nr_large = np.sum(too_large)
# The context manager closes the progress bar once the loop has finished.
with tqdm() as pbar:
    while nr_large > 0:
        idx_first[too_large] = rnd.choice(n_grid_points, size=nr_large)
        idx_second[too_large] = rnd.choice(n_grid_points, size=nr_large)
        dist[too_large] = estimate_dist(idx_first[too_large], idx_second[too_large])
        too_large = dist > bound
        nr_large = np.sum(too_large)
        # Show how many pairs are still waiting to be accepted.
        pbar.set_postfix(sum=nr_large)
        pbar.update()

In [None]:
# Positional subsampling applied after the pair selection: keep every fourth
# entry starting at position 3, then discard the first six of those.
# NOTE(review): the sliced axis is presumably time (axis 1 for the
# perturbations, axis 0 for the errors) -- confirm.
every_fourth = slice(3, None, 4)
after_first_six = slice(6, None)

# T_2M is taken at the first point of each pair and soil moisture at the
# second, so products of the two are cross-covariances at separation `dist`.
sel_t2m_perts = t2m_perts.isel(grid=idx_first)[:, every_fourth][:, after_first_six]
sel_h2o_perts = h2o_perts.isel(grid=idx_second)[:, every_fourth][:, after_first_six]

sel_t2m_err = t2m_err.isel(grid=idx_first)[every_fourth][after_first_six]
sel_h2o_err = h2o_err.isel(grid=idx_second)[every_fourth][after_first_six]

In [None]:
# Normalisation of the ensemble sums.
# NOTE(review): 39 is presumably ensemble size minus one -- confirm.
ens_dof = 39

# Ensemble cross-covariance of T_2M and soil moisture for every sampled pair.
ens_cov = (sel_t2m_perts * sel_h2o_perts).sum('ensemble') / ens_dof

# Same-grid-point ensemble covariance, averaged over the grid, minus the
# product of the grid-mean normalised perturbation sums.
_gp_cross = ((t2m_perts * h2o_perts).sum('ensemble') / ens_dof).mean('grid')
_gp_t2m = (t2m_perts.sum('ensemble') / ens_dof).mean('grid')
_gp_h2o = (h2o_perts.sum('ensemble') / ens_dof).mean('grid')
ens_cov_gp = _gp_cross - _gp_t2m * _gp_h2o

# Same-grid-point covariance of the errors, taken over the grid.
_err_cross = (t2m_err * h2o_err).mean('grid')
err_cov_gp = _err_cross - t2m_err.mean('grid') * h2o_err.mean('grid')

In [None]:
def get_err_gain(err_t2m, err_h2o, smi, bins=100):
    """Covariance of two error samples within bins of a third variable.

    Parameters
    ----------
    err_t2m, err_h2o : array-like
        Paired error samples of equal length.
    smi : array-like
        Variable the samples are binned by (the callers in this notebook pass
        the pair distance).
    bins : int or sequence, optional
        Bin specification forwarded to `scipy.stats.binned_statistic`.

    Returns
    -------
    numpy.ndarray
        Per bin: mean of the product minus the product of the means.
    """
    def _bin_mean(values):
        return scipy.stats.binned_statistic(smi, values, statistic='mean', bins=bins)[0]

    mean_t2m = _bin_mean(err_t2m)
    mean_h2o = _bin_mean(err_h2o)
    mean_product = _bin_mean(err_t2m * err_h2o)
    return mean_product - mean_h2o * mean_t2m

In [None]:
def exponential(x, a, tau, b):
    """Exponential decay with offset: ``a * exp(-x / tau) + b``.

    `a` is the amplitude at ``x = 0``, `tau` the e-folding scale and `b` the
    value approached for large `x`.  Works element-wise on arrays.
    """
    decay = np.exp(-x / tau)
    return a * decay + b

In [None]:
# Bin edges for the pair separation: 0 to 50000 in steps of 2000 (25 bins).
dist_bins = np.arange(0, 51000, 2000)
bin_width = dist_bins[1] - dist_bins[0]
# Bin centres, used as x positions when plotting binned statistics.
plot_bins = dist_bins[:-1] + bin_width / 2
# Number of sampled pairs per bin, obtained by summing ones.
nr_bins = scipy.stats.binned_statistic(
    dist, np.ones_like(dist), statistic='sum', bins=dist_bins
)[0]

In [None]:
# Gaspari-Cohn localisation built with the value 15000 -- the same number the
# figure marks as 'Localization radius' on its distance axis.
# The distance function ignores its first argument and returns the second
# unchanged.
# NOTE(review): presumably this lets precomputed distances be passed straight
# through as the second argument -- confirm against the pytassim API.
localisation = GaspariCohn(15000, dist_func=lambda x, y: y)

In [None]:
# Evaluate the localisation for the sampled pair distances and keep the second
# element of the returned value.
# NOTE(review): presumably `np.array([0])` is a dummy location and element [1]
# holds one weight per entry of `dist` -- confirm against the pytassim API.
loc_weights = localisation.localize_obs(np.array([0]), dist)[1]

In [None]:
# Figure: T_2M / soil-moisture cross-covariance as a function of horizontal
# distance for three snapshot times.  Per snapshot it shows the binned error
# covariance (solid), the binned ensemble covariance (dashed) and the
# localised ensemble covariance (dotted).
time_noon = pd.to_datetime('2015-08-01 12:00')
time_night = pd.to_datetime('2015-08-01 19:00')
time_morning = pd.to_datetime('2015-08-03 06:00')

# One colour per snapshot, shared by all three line styles of that snapshot.
snapshots = [
    (time_noon, '#FF504F'),
    (time_night, '#4488B3'),
    (time_morning, '#7ACCFF'),
]
time_label_fmt = '%m/%d %H%M UTC'
# Position of the vertical marker; must equal the value that was handed to
# GaspariCohn when `loc_weights` was computed.
loc_radius = 15000


def _binned_mean(values):
    """Average `values` (one entry per sampled pair) over the distance bins."""
    return scipy.stats.binned_statistic(dist, values, statistic='mean', bins=dist_bins)[0]


def _fit_decorrelation(binned_cov):
    """Fit `exponential` to a binned covariance curve and return its `tau`."""
    popt, _ = scipy.optimize.curve_fit(
        exponential, plot_bins, binned_cov, p0=[0.005, 15000, -0.003]
    )
    return popt[1]


figsize = [s*0.48 for s in plt.rcParams['figure.figsize']]
plt.rcParams['lines.linewidth'] = 0.8
fig, ax = plt.subplots(figsize=figsize)

# Solid lines: binned error covariance.  All three are drawn before any
# ensemble line so the stacking order of the artists stays as it was.
err_cov_binned = {}
for snap_time, colour in snapshots:
    err_cov_binned[snap_time] = get_err_gain(
        sel_t2m_err.sel(time=snap_time).values,
        sel_h2o_err.sel(time=snap_time).values,
        dist, dist_bins
    )
    ax.plot(
        plot_bins, err_cov_binned[snap_time], c=colour,
        label=snap_time.strftime(time_label_fmt)
    )

# Decorrelation lengths from exponential fits.  They are not drawn in this
# figure but are kept as notebook variables (e.g. for optional axvline markers).
decorr_noon = _fit_decorrelation(err_cov_binned[time_noon])
decorr_morn = _fit_decorrelation(err_cov_binned[time_morning])

# Dashed / dotted lines: raw and localised ensemble covariance.  The (dask
# backed) covariance of each snapshot is evaluated once and reused for both.
for snap_time, colour in snapshots:
    snap_ens_cov = ens_cov.sel(time=snap_time).values
    ax.plot(plot_bins, _binned_mean(snap_ens_cov), c=colour, ls='dashed')
    ax.plot(
        plot_bins, _binned_mean(loc_weights * snap_ens_cov),
        c=colour, alpha=.5, ls='dotted'
    )

ax.axhline(y=0, c='black', alpha=0.5)
ax.axvline(x=loc_radius, c='black', label='Localization radius')

ax.set_ylabel(r'Covariance (K m$^3$/m$^3$)')
ax.set_xlabel(r'Horizontal Distance (m)')
ax.set_xlim(0, 40000)
ax.ticklabel_format(style='plain', useOffset=False, axis='x')

# The legend is built from explicit proxy handles: black entries explain the
# line styles, coloured entries the snapshot times.  Labels on the plotted
# lines themselves are therefore not shown.
style_handles = [
    mpll.Line2D([0], [0], color='black', lw=2, label='Error cov'),
    mpll.Line2D([0], [0], color='black', lw=2, ls='--', label='Ensemble cov'),
    mpll.Line2D([0], [0], color='black', alpha=0.5, lw=2, ls='dotted', label='Localised ensemble cov'),
]
time_handles = [
    mpll.Line2D([0], [0], color=colour, lw=1, label=snap_time.strftime(time_label_fmt))
    for snap_time, colour in snapshots
]
own_handles = style_handles + time_handles
legend = ax.legend(
    loc='lower center', bbox_to_anchor=(0.446, 1.07), handles=own_handles, ncol=2
)
legend.get_frame().set_linewidth(0.8)

# Save before showing so the file is written regardless of how the active
# backend treats the figure after `plt.show()`.
fig.savefig('../figures/fig_11_covariance_dist.png', dpi=300)
plt.show()