In [None]:
import os
import pickle as pk
from ipywidgets import interact
from copy import deepcopy

from tqdm import tqdm_notebook as tqdm

import xarray as xr
import numpy as np
import pandas as pd

import distributed

from cartopy.mpl.ticker import LongitudeFormatter, LatitudeFormatter
import matplotlib.pyplot as plt
import matplotlib.colors
import cartopy.crs as ccrs
import matplotlib.dates as mdates
from pandas.plotting import register_matplotlib_converters

from pytassim.localization import GaspariCohn
from pytassim.model.terrsysmp import preprocess_cosmo
from pytassim.obs_ops.terrsysmp import CosmoT2mOperator
from pytassim.assimilation import LETKFUncorr
from py_bacy.intf_pytassim.io import load_observations
from py_bacy.intf_pytassim.clm import distance_func

In [None]:
# Legacy NumPy random generator with a fixed seed (42) so that any random
# draws made through `rnd` are repeatable across kernel restarts.
rnd = np.random.RandomState(seed=42)

In [None]:
# NOTE(review): presumably the density of water in kg/m^3 — TODO confirm the
# meaning and unit against the cells that consume this constant.
DENSITY = 1000

In [None]:
# Apply the two custom matplotlib style sheets in order: 'paper' first, then
# 'egu_journals' on top, so settings from the latter win where both define one.
plt.style.use(['paper', 'egu_journals'])

# Register the pandas date/time converters with matplotlib for time axes.
register_matplotlib_converters()

In [None]:
# Rotated-pole projection (pole at longitude -171.0, latitude 41.5); it is used
# further down to turn geographic lon/lat into rlon/rlat coordinates.
rotated_pole = ccrs.RotatedPole(pole_latitude=41.5, pole_longitude=-171.0)

# Plain geographic lon/lat projection, the source CRS of that transformation.
plate_carree = ccrs.PlateCarree()

In [None]:
# Local Dask cluster with a single worker and a single thread, limited to
# 4GB of memory; the client makes it the scheduler for the lazy computations.
cluster = distributed.LocalCluster(
    n_workers=1,
    threads_per_worker=1,
    memory_limit='4GB',
)
client = distributed.Client(cluster)

# Bare last expression: renders the client summary as the cell output.
client

# Load data

In [None]:
# Root directory of the experiment data. The experiment sub-directories
# ('015', '016', '020') are joined onto this path in the loading cells below.
# This is an absolute, machine-specific path: adjust it when running elsewhere.
base_path = '/p/project/chbn29/hbn29p/Projects/phd_thesis/data/da_enkf_for_soil/'

## H2O

In [None]:
# Soil moisture (H2OSOI) of experiment '016'; this run later serves as the
# reference the background is compared against when building the innovation.
vr_h2o_path = os.path.join(base_path, '016', 'h2o_cleaned.nc')
vr_h2o = (
    xr.open_dataset(vr_h2o_path)['H2OSOI']
    .squeeze(drop=True)      # drop all length-one dimensions
    .isel(levsoi=4)          # keep only the soil level with index 4
    .chunk((1, 302, 267))    # presumably (time, lat, lon) chunks — TODO confirm
)

In [None]:
# Soil moisture (H2OSOI) of experiment '015'; only its 2015-07-31 12:00 state
# is used later, as the first time step of the ensemble series.
da_h2o_first_path = os.path.join(base_path, '015', 'h2o_cleaned.nc')
da_h2o_first = (
    xr.open_dataset(da_h2o_first_path)['H2OSOI']
    .squeeze(drop=True)          # drop all length-one dimensions
    .isel(levsoi=4)              # keep only the soil level with index 4
    .chunk((10, 1, 302, 267))    # chunk sizes are positional — TODO confirm dim order
)

In [None]:
# Ensemble soil moisture (H2OSOI) of experiment '020' at soil level index 4.
da_h2o_path = os.path.join(base_path, '020', 'h2o_cleaned.nc')
da_h2o = (
    xr.open_dataset(da_h2o_path)['H2OSOI']
    .squeeze(drop=True)          # drop all length-one dimensions
    .isel(levsoi=4)              # keep only the soil level with index 4
    .chunk((10, 1, 302, 267))    # chunk sizes are positional — TODO confirm dim order
)

In [None]:
# Prepend the 2015-07-31 12:00 state of experiment '015' to the '020' series
# along the time dimension.
da_h2o = xr.concat(
    [da_h2o_first.sel(time='2015-07-31 12:00'), da_h2o],
    dim='time',
)

## T2m

In [None]:
# T_2M field of experiment '016'; this run later serves as the observation
# source for the T2m innovation.
vr_t2m_path = os.path.join(base_path, '016', 't2m_cleaned.nc')
vr_t2m = (
    xr.open_dataset(vr_t2m_path)['T_2M']
    .squeeze(drop=True)     # drop all length-one dimensions
    .chunk((1, 109, 98))    # presumably (time, rlat, rlon) chunks — TODO confirm
)

In [None]:
# Ensemble T_2M field of experiment '020'.
da_t2m_path = os.path.join(base_path, '020', 't2m_cleaned.nc')
da_t2m = (
    xr.open_dataset(da_t2m_path)['T_2M']
    .squeeze(drop=True)         # drop all length-one dimensions
    .chunk((10, 1, 109, 98))    # chunk sizes are positional — TODO confirm dim order
)

In [None]:
# T_2M field of experiment '015'; only its 2015-07-31 12:00 state is used
# later, as the first time step of the ensemble series.
da_t2m_first_path = os.path.join(base_path, '015', 't2m_cleaned.nc')
da_t2m_first = (
    xr.open_dataset(da_t2m_first_path)['T_2M']
    .squeeze(drop=True)         # drop all length-one dimensions
    .chunk((10, 1, 109, 98))    # chunk sizes are positional — TODO confirm dim order
)

In [None]:
# Prepend the 2015-07-31 12:00 state of experiment '015' to the '020' series
# along the time dimension.
da_t2m = xr.concat(
    [da_t2m_first.sel(time='2015-07-31 12:00'), da_t2m],
    dim='time',
)

## Const data

In [None]:
# Auxiliary CLM file with time-invariant fields (absolute, machine-specific path).
clm_const_path = '/p/scratch/chbn29/hbn29p/data/tsmp/runs/utilities/clm_aux.nc'
clm_const_ds = xr.open_dataset(clm_const_path)

# DZSOI with its dimension order reversed, taken at the same soil level
# (index 4) that is selected for H2OSOI above.
# NOTE(review): presumably the soil layer thickness — TODO confirm.
delta_z = clm_const_ds['DZSOI'].transpose().isel(levsoi=4)

## Prepare

In [None]:
# Keep only the first occurrence of every time stamp in both T2m series.
# NOTE(review): da_h2o is built with the same concat pattern but is not
# de-duplicated here — confirm that its time axis is already unique.
da_t2m = da_t2m.sel(
    time=~da_t2m.indexes['time'].duplicated(keep='first')
)
vr_t2m = vr_t2m.sel(
    time=~vr_t2m.indexes['time'].duplicated(keep='first')
)

In [None]:
# Give both ensembles the same integer member labels 0..39 so that they align
# member-by-member in the covariance computation. The size 40 is hard-coded.
da_h2o['ensemble'] = np.arange(40)
da_t2m['ensemble'] = np.arange(40)

In [None]:
# Background times: every T2m time stamp on the full hour, starting at
# 2015-07-31 12:00 (inclusive).
bg_idx = da_t2m.indexes['time']
bg_idx = bg_idx[
    (bg_idx.minute == 0) & (bg_idx >= pd.to_datetime('2015-07-31 12:00'))
]

## Create background

In [None]:
# Ensemble soil moisture at the background times, with lat/lon flattened into
# a single 'grid' dimension and dimensions ordered (time, ensemble, grid).
background = (
    da_h2o.sel(time=bg_idx)
    .stack(grid=['lat', 'lon'])
    .transpose('time', 'ensemble', 'grid')
)
# Full-hour filter; bg_idx already contains full-hour stamps only, so this
# acts as a safeguard rather than removing anything.
background = background.sel(time=background.indexes['time'].minute == 0)

In [None]:
# Reference soil moisture at the background times, flattened to the same
# 'grid' dimension as the background and ordered (time, grid).
vr_h2o_stacked = (
    vr_h2o.sel(time=bg_idx)
    .stack(grid=['lat', 'lon'])
    .transpose('time', 'grid')
)
# Full-hour filter; bg_idx already contains full-hour stamps only, so this
# acts as a safeguard rather than removing anything.
vr_h2o_stacked = vr_h2o_stacked.sel(
    time=vr_h2o_stacked.indexes['time'].minute == 0
)

## Create clm coordinates

In [None]:
# Transform the lon/lat of every stacked CLM grid point into the rotated-pole
# system; column 0 of the result is used as rlon and column 1 as rlat.
clm_coords_rotated = rotated_pole.transform_points(
    plate_carree,
    background.lon.values,
    background.lat.values,
)
clm_rot_index = pd.MultiIndex.from_arrays(
    [clm_coords_rotated[:, 0], clm_coords_rotated[:, 1]],
    names=['rlon', 'rlat'],
)

In [None]:
# Rotated coordinates of the CLM points as DataArrays on the 'grid' dimension,
# so that interpolating with them yields one value per CLM grid point.
clm_rlon = xr.DataArray(
    clm_coords_rotated[:, 0],
    dims=['grid'],
    coords={'grid': background.grid},
)
clm_rlat = xr.DataArray(
    clm_coords_rotated[:, 1],
    dims=['grid'],
    coords={'grid': background.grid},
)

### Get pseudo state

In [None]:
# Ensemble T2m at the background times, linearly interpolated from its
# rlon/rlat grid onto the CLM grid points and ordered (time, ensemble, grid).
pseudo_state = (
    da_t2m.sel(time=bg_idx)
    .drop(['lon', 'lat'])
    .interp(rlon=clm_rlon, rlat=clm_rlat, method='linear')
    .drop(['rlon', 'rlat'])
    .transpose('time', 'ensemble', 'grid')
)

In [None]:
# Reference T2m at the background times, linearly interpolated from its
# rlon/rlat grid onto the CLM grid points and ordered (time, grid).
vr_t2m_interp = (
    vr_t2m.sel(time=bg_idx)
    .drop(['lon', 'lat'])
    .interp(rlon=clm_rlon, rlat=clm_rlat, method='linear')
    .drop(['rlon', 'rlat'])
    .transpose('time', 'grid')
)

# Assimilate T2m nature grid point based with EnKF

## Estimate vertical weight

In [None]:
EARTH_RADIUS = 6371000
# Length of one degree of arc on a sphere of radius EARTH_RADIUS.
DEG_TO_M = 2 * np.pi / 360 * EARTH_RADIUS


# NOTE: this definition shadows the `distance_func` imported from
# `py_bacy.intf_pytassim.clm` in the import cell.
def distance_func(x, y):
    """Return horizontal and vertical distances between one point and many.

    Parameters
    ----------
    x : sequence
        A single location; all entries but the last are horizontal
        coordinates, the last entry is the vertical coordinate.
    y : array, shape (n_points, n_coords)
        Locations to compare against, using the same column layout as ``x``.

    Returns
    -------
    horizontal : array, shape (n_points,)
        Euclidean norm of the horizontal coordinate differences after scaling
        them with ``DEG_TO_M`` (the differences are treated as degrees and
        scaled by one constant factor, independent of latitude).
    vertical : array, shape (n_points,)
        Absolute difference of the last coordinate, left unscaled.
    """
    horizontal_offset_m = (y[:, :-1] - x[:-1]) * DEG_TO_M
    horizontal = np.sqrt(np.square(horizontal_offset_m).sum(axis=-1))
    vertical = np.abs(y[:, -1] - x[-1])
    return horizontal, vertical

# One radius per distance returned by distance_func.
# NOTE(review): presumably 15000 m horizontally and 0.7 in the unit of
# `levsoi` vertically — TODO confirm against GaspariCohn.
loc_radius = (15000, 0.7)
local_gc = GaspariCohn(loc_radius, distance_func)

# Localisation weight between the selected soil level (placed at -levsoi, zero
# horizontal offset) and a single observation located at (0, 0, 0); only the
# second return value is kept.
_, vert_weight = local_gc.localize_obs(
    (0, 0, -da_h2o.levsoi.values),
    np.array([[0, 0, 0]]),
)

## Get innovation

In [None]:
# Innovation: reference T2m minus the ensemble-mean T2m at every CLM grid point.
innov_t2m = vr_t2m_interp - pseudo_state.mean(dim='ensemble')

## Estimate Kalman gain

In [None]:
# Sample covariance between soil moisture and T2m at each grid point:
# product of the ensemble perturbations summed over members, over (N - 1).
bht = xr.dot(
    background - background.mean('ensemble'),
    pseudo_state - pseudo_state.mean('ensemble'),
    dims='ensemble',
) / (len(pseudo_state['ensemble']) - 1)

# Ensemble variance of T2m plus an observation-error variance of 0.1 ** 2.
hbht = pseudo_state.var('ensemble', ddof=1)
hbht_r = hbht + 0.1 ** 2

# Point-wise Kalman gain, damped by the vertical localisation weight.
gain = bht / hbht_r * vert_weight

## Estimate analysis

In [None]:
# Analysis increment and analysis mean for the T2m-based update; unstacking
# restores the lat/lon dimensions from the flat 'grid' dimension.
inc_gp_ana = gain * innov_t2m
da_h2o_gp_ana = background.mean('ensemble') + inc_gp_ana
da_h2o_gp_ana = da_h2o_gp_ana.unstack('grid')

In [None]:
# Store the T2m-based analysis in the '020' experiment directory. The target is
# derived from base_path, like every input path in this notebook, instead of
# repeating the absolute data directory; the resulting file path is unchanged.
da_h2o_gp_ana.to_netcdf(
    os.path.join(base_path, '020', 'da_offline_1d_enkf_nature.nc')
)

# Assimilate H2OSoi nature grid point based with EnKF

## Get innovation

In [None]:
# Innovation: reference soil moisture minus the ensemble-mean soil moisture.
innov_h2o = vr_h2o_stacked - background.mean(dim='ensemble')

## Estimate Kalman gain

In [None]:
# Soil moisture is compared with the background directly, so B H^T and H B H^T
# are the same ensemble variance: compute it once and reuse it.
bht = background.var('ensemble', ddof=1)
hbht = bht

# Add an observation-error variance of 0.01 ** 2.
hbht_r = hbht + 0.01 ** 2

# NOTE(review): vert_weight was derived for an observation at height 0 and a
# state at the soil level; here observation and state share the same level.
# Confirm that damping the gain with it is intended in this experiment.
gain = bht / hbht_r * vert_weight

## Estimate analysis

In [None]:
# Analysis increment and analysis mean for the soil-moisture-based update;
# unstacking restores the lat/lon dimensions from the flat 'grid' dimension.
inc_h2o_ana = gain * innov_h2o
da_h2o_h2o_ana = background.mean('ensemble') + inc_h2o_ana
da_h2o_h2o_ana = da_h2o_h2o_ana.unstack('grid')

In [None]:
# Store the soil-moisture-based analysis in the '020' experiment directory. The
# target is derived from base_path, like every input path in this notebook,
# instead of repeating the absolute data directory; the file path is unchanged.
da_h2o_h2o_ana.to_netcdf(
    os.path.join(base_path, '020', 'da_offline_1d_enkf_h2o.nc')
)