In [None]:
import os
import pickle as pk
from ipywidgets import interact
from copy import deepcopy

from tqdm import tqdm_notebook as tqdm

import xarray as xr
import numpy as np
import pandas as pd

import cartopy.crs as ccrs
import matplotlib.pyplot as plt

import distributed

from pytassim.localization import GaspariCohn

In [None]:
# Seeded legacy NumPy generator; the fixed seed keeps any draws from `rnd` repeatable.
rnd = np.random.RandomState(seed=42)

In [None]:
# NOTE(review): presumably the density of water in kg/m^3 -- confirm.
# Not referenced by any of the cells visible in this part of the notebook.
DENSITY = 1000

In [None]:
# Regular lon/lat projection; passed as the source CRS when the grid
# coordinates are transformed further down.
plate_carree = ccrs.PlateCarree()

# Rotated-pole projection that the lon/lat grid coordinates are transformed into.
rotated_pole = ccrs.RotatedPole(
    pole_latitude=41.5,
    pole_longitude=-171.0,
)

In [None]:
# Create a distributed client with default arguments.
client = distributed.Client()
# Bare expression: the notebook renders the client's repr as the cell output.
client

# Load data

In [None]:
# Directory holding auxiliary files.
util_dir = '/work/um0203/u300636/for2131/runs/utilities'
# Root directory of the experiment runs; sub-directories ('015', '016', '023', ...)
# are joined onto it by the loading cells.
base_path = '/work/um0203/u300636/for2131/runs/da_enkf_for_soil/'

## H2O

In [None]:
# 'H2OSOI' variable of run '016'.
vr_h2o_path = os.path.join(base_path, '016', 'h2o_cleaned.nc')
vr_h2o = (
    xr.open_dataset(vr_h2o_path)['H2OSOI']
    .squeeze(drop=True)      # remove length-1 dimensions and their coordinates
    .isel(levsoi=4)          # keep only the soil level at index 4
    .chunk((1, 302, 267))
)

In [None]:
# 'H2OSOI' variable of run '015'; chunked with a leading block of 40
# (the ensemble coordinate is later set to 40 members).
ens_h2o_path = os.path.join(base_path, '015', 'h2o_cleaned.nc')
ens_h2o = (
    xr.open_dataset(ens_h2o_path)['H2OSOI']
    .squeeze(drop=True)      # remove length-1 dimensions and their coordinates
    .isel(levsoi=4)          # keep only the soil level at index 4
    .chunk((40, 1, 302, 267))
)

In [None]:
# 'H2OSOI' variable of run '023/juwels'.
sekf_h2o_path = os.path.join(base_path, '023', 'juwels', 'h2o_cleaned.nc')
sekf_h2o = (
    xr.open_dataset(sekf_h2o_path)['H2OSOI']
    .squeeze(drop=True)      # remove length-1 dimensions and their coordinates
    .isel(levsoi=4)          # keep only the soil level at index 4
    .chunk((1, 302, 267))
)

## T2m

In [None]:
# 'T_2M' variable of run '016'.
vr_t2m_path = os.path.join(base_path, '016', 't2m_cleaned.nc')
vr_t2m = (
    xr.open_dataset(vr_t2m_path)['T_2M']
    .squeeze(drop=True)      # remove length-1 dimensions and their coordinates
    .chunk((1, 109, 98))
)

In [None]:
# 'T_2M' variable of run '015'; chunked with a leading block of 40
# (the ensemble coordinate is later set to 40 members).
ens_t2m_path = os.path.join(base_path, '015', 't2m_cleaned.nc')
ens_t2m = (
    xr.open_dataset(ens_t2m_path)['T_2M']
    .squeeze(drop=True)      # remove length-1 dimensions and their coordinates
    .chunk((40, 1, 109, 98))
)

In [None]:
# 'T_2M' variable of run '023/juwels' (file 't2m_smoother_cleaned.nc').
sekf_t2m_path = os.path.join(base_path, '023', 'juwels', 't2m_smoother_cleaned.nc')
sekf_t2m = (
    xr.open_dataset(sekf_t2m_path)['T_2M']
    .squeeze(drop=True)      # remove length-1 dimensions and their coordinates
    .chunk((1, 109, 98))
)

In [None]:
# Auxiliary constants file; built from `util_dir` so the utilities directory
# is defined in exactly one place (the resulting path is unchanged).
const_path = os.path.join(util_dir, 'clm_aux.nc')
# Only the first time step of the auxiliary data is used.
const_data = xr.open_dataset(const_path).isel(time=0)
# 'WATSAT' at soil level index 4, the same level selected for the H2OSOI fields.
# NOTE(review): presumably the saturation water content used as the upper
# bound of the analyses -- confirm units against H2OSOI.
level_sat = const_data['WATSAT'].isel(levsoi=4)

In [None]:
# 'H2OSOI_LIQ' variable of the Jacobian file of run '023/juwels'.
jacob_path = os.path.join(base_path, '023', 'juwels', 'jacobian_cleaned.nc')
jacob_sekf = (
    xr.open_dataset(jacob_path)['H2OSOI_LIQ']
    .squeeze(drop=True)      # remove length-1 dimensions and their coordinates
    .isel(levtot=4)          # keep only index 4 along 'levtot'
    .chunk((1, 302, 267))
)

In [None]:
# 'H2OSOI_LIQ' variable of the gain file of run '023/juwels',
# flattened so that (lat, lon) becomes one stacked 'grid' dimension.
gain_path = os.path.join(base_path, '023', 'juwels', 'gain_orig_sekf.nc')
gain_sekf = (
    xr.open_dataset(gain_path)['H2OSOI_LIQ']
    .squeeze(drop=True)      # remove length-1 dimensions and their coordinates
    .chunk((302, 267, 1, ))
    .stack(grid=('lat', 'lon'))
)

## Prepare

In [None]:
# Drop repeated time stamps: `duplicated()` flags every repeat after the first
# occurrence, and the inverted mask keeps the remaining entries.
vr_t2m = vr_t2m.sel(
    time=~vr_t2m.indexes['time'].duplicated()
)
ens_t2m = ens_t2m.sel(
    time=~ens_t2m.indexes['time'].duplicated()
)

In [None]:
# Give both ensemble fields the same integer member labels 0..39.
ens_t2m['ensemble'] = np.arange(40)
ens_h2o['ensemble'] = np.arange(40)

In [None]:
# First-guess times: every (12*4)-th time stamp starting at index 12*4-1,
# with the final selected stamp removed.
# NOTE(review): 12*4 presumably means 12 hours of 15-minute output -- confirm.
fg_time = sekf_t2m.indexes['time'][slice(12 * 4 - 1, None, 12 * 4)][:-1]
# Background times: the same stamps with the time of day reset to midnight.
bg_time = fg_time.normalize()

## Create background

In [None]:
# SEKF soil water at the background times, flattened to (time, grid).
sekf_bg = (
    sekf_h2o.sel(time=bg_time)
    .stack(grid=['lat', 'lon'])
    .transpose('time', 'grid')
)

In [None]:
# Ensemble soil water at the background times, flattened to (time, ensemble, grid).
ens_bg = (
    ens_h2o.sel(time=bg_time)
    .stack(grid=['lat', 'lon'])
    .transpose('time', 'ensemble', 'grid')
)

In [None]:
# Soil water of run '016' at the background times, flattened to (time, grid).
vr_h2o_stacked = (
    vr_h2o.sel(time=bg_time)
    .stack(grid=['lat', 'lon'])
    .transpose('time', 'grid')
)

## Create CLM coordinates

In [None]:
# Transform the lon/lat of every stacked grid point from plate carree
# into the rotated-pole projection.
clm_coords_rotated = rotated_pole.transform_points(
    plate_carree,
    sekf_bg.lon.values,
    sekf_bg.lat.values,
)
# Column 0 is used as rotated longitude, column 1 as rotated latitude.
clm_rot_index = pd.MultiIndex.from_arrays(
    [clm_coords_rotated[:, 0], clm_coords_rotated[:, 1]],
    names=['rlon', 'rlat'],
)

In [None]:
# Rotated coordinates as 1-D arrays over the stacked 'grid' dimension,
# so they can be used as point-wise interpolation targets.
clm_rlon = xr.DataArray(
    clm_coords_rotated[:, 0],
    dims=['grid'],
    coords={'grid': sekf_bg.grid},
)
clm_rlat = xr.DataArray(
    clm_coords_rotated[:, 1],
    dims=['grid'],
    coords={'grid': sekf_bg.grid},
)

### Get pseudo state

In [None]:
# SEKF T_2M at the first-guess times, linearly interpolated onto the
# rotated coordinates of the stacked grid points; result is (time, grid).
sekf_fg = (
    sekf_t2m.sel(time=fg_time)
    .drop(['lon', 'lat'])
    .interp(rlon=clm_rlon, rlat=clm_rlat, method='linear')
    .drop(['rlon', 'rlat'])
    .transpose('time', 'grid')
)

In [None]:
# Ensemble T_2M at the first-guess times, linearly interpolated onto the
# rotated coordinates of the stacked grid points; result is (time, ensemble, grid).
ens_fg = (
    ens_t2m.sel(time=fg_time)
    .drop(['lon', 'lat'])
    .interp(rlon=clm_rlon, rlat=clm_rlat, method='linear')
    .drop(['rlon', 'rlat'])
    .transpose('time', 'ensemble', 'grid')
)

In [None]:
# T_2M of run '016' at the first-guess times, linearly interpolated onto the
# rotated coordinates of the stacked grid points; result is (time, grid).
vr_t2m_interp = (
    vr_t2m.sel(time=fg_time)
    .drop(['lon', 'lat'])
    .interp(rlon=clm_rlon, rlat=clm_rlat, method='linear')
    .drop(['rlon', 'rlat'])
    .transpose('time', 'grid')
)

# Estimate vertical weight

In [None]:
EARTH_RADIUS = 6371000
# Length of one degree of arc on a sphere of radius EARTH_RADIUS.
DEG_TO_M = 2 * np.pi / 360 * EARTH_RADIUS


def distance_func(x, y):
    """Return horizontal and vertical distances between one point and many.

    Parameters
    ----------
    x : sequence
        A single position. All entries but the last are horizontal
        coordinates in degrees; the last entry is the vertical coordinate.
    y : numpy.ndarray
        Two-dimensional array with one position per row, using the same
        column layout as ``x``.

    Returns
    -------
    tuple
        ``(horizontal, vertical)``: the Euclidean norm of the horizontal
        offsets after scaling by ``DEG_TO_M``, and the absolute difference
        of the vertical coordinates.

    Notes
    -----
    Degrees are converted with one constant factor; no cos(latitude)
    scaling is applied to the longitude offset.
    """
    horizontal_offset_m = (y[:, :-1] - x[:-1]) * DEG_TO_M
    horizontal_dist = np.sqrt(np.sum(horizontal_offset_m ** 2, axis=-1))
    vertical_dist = np.abs(y[:, -1] - x[-1])
    return horizontal_dist, vertical_dist

# NOTE(review): presumably (horizontal radius in metres, vertical radius in the
# units of levsoi), matching the two distances returned by distance_func -- confirm.
loc_radius = (15000, 0.7)
local_gc = GaspariCohn(loc_radius, distance_func)
# Localize between the selected soil level (horizontal offset 0, vertical
# coordinate -levsoi) and a single point at the origin; only the second
# return value is kept as the vertical weight.
_, vert_weight = local_gc.localize_obs(
    (0, 0, -ens_h2o.levsoi.values),
    np.array([[0, 0, 0]]),
)

# Get innovation

In [None]:
# Innovation: interpolated T_2M of run '016' minus the SEKF first guess.
innov_t2m = vr_t2m_interp - sekf_fg

# Assimilate T2m nature, grid-point-based, with finite differences

In [None]:
# Observation-error term; get_gain divides by it directly.
obs_cov = 0.01
# Background-error scale; get_gain squares it to obtain the background variance.
b_scale = 0.01

In [None]:
# Increment = SEKF gain times the T_2M innovation. The time coordinate is
# dropped afterwards, presumably so the sum with the background (which is
# labelled with the midnight stamps) is matched by position -- confirm.
sekf_increment = gain_sekf * innov_t2m
sekf_increment = sekf_increment.drop('time')

In [None]:
# Analysis = background + increment, back on the (lat, lon) grid. The value is
# divided by level_sat, clipped to [0, 1] and rescaled, which bounds the
# analysis between 0 and level_sat.
sekf_analysis = (
    (
        (sekf_bg + sekf_increment).unstack('grid') / level_sat
    ).clip(min=0, max=1)
    * level_sat
)

In [None]:
# Write the SEKF analysis next to the other '023/juwels' files. The target is
# built from `base_path` like the input paths, so the experiment root is
# defined in one place (the resulting path is unchanged).
sekf_analysis_path = os.path.join(base_path, '023', 'juwels', 'da_offline_sekf_nature.nc')
sekf_analysis.to_netcdf(sekf_analysis_path)

# Assimilate T2m nature, grid-point-based, with the ECMWF strategy

In [None]:
def get_gain(h_jacob, b_scale, obs_cov):
    """Compute a gain from a Jacobian, a background scale and an observation term.

    The result is ``cov_ana * h_jacob / obs_cov`` with
    ``cov_ana = 1 / (1 / b_scale**2 + sum(h_jacob * h_jacob) / obs_cov)``,
    where the sum runs over a length-1 ``pseudo_time`` dimension that is
    added internally and removed again before returning.

    Parameters
    ----------
    h_jacob : xarray object
        Jacobian; must support ``expand_dims``, ``rename`` and arithmetic.
    b_scale : float
        Background-error scale; squared to obtain the background variance.
    obs_cov : float
        Observation-error term used as divisor.

    Returns
    -------
    xarray object
        Gain with the ``pseudo_time`` dimension squeezed out.
    """
    jacob = h_jacob.expand_dims('pseudo_time', axis=0)
    # Same data under a second dimension name, so the product below spans
    # both pseudo-time dimensions before they are summed out.
    jacob_other = jacob.rename({'pseudo_time': 'pseudo_time_1'})

    background_var = b_scale ** 2
    obs_term = (jacob_other * jacob).sum(['pseudo_time', 'pseudo_time_1']) / obs_cov
    analysis_var = 1 / (1 / background_var + obs_term)

    scaled_jacob = jacob / obs_cov
    return (analysis_var * scaled_jacob).squeeze('pseudo_time')

In [None]:
def get_cov(x, y, dim='ensemble'):
    """Sample covariance between ``x`` and ``y`` along dimension ``dim``.

    Both inputs are centred with their mean along ``dim``; the perturbation
    products are summed over ``dim`` and divided by ``n - 1``, where ``n`` is
    the length of ``dim``.

    Parameters
    ----------
    x, y : xarray.DataArray
        Arrays that both contain the dimension ``dim``.
    dim : str, optional
        Dimension to reduce over. Defaults to ``'ensemble'``.

    Returns
    -------
    xarray.DataArray
        Covariance without the ``dim`` dimension.
    """
    # Centre along `dim`. The previous version always used 'ensemble' here,
    # so a non-default `dim` gave means and sums over different dimensions.
    x_perts = x - x.mean(dim)
    y_perts = y - y.mean(dim)
    dot_prod = xr.dot(x_perts, y_perts, dims=dim)
    cov = dot_prod / (len(x_perts[dim]) - 1)
    return cov

In [None]:
# Ensemble variance of the background soil water (ddof=1). The time coordinate
# is dropped, presumably so first-guess and background stamps are matched by
# position rather than by label -- confirm.
var = ens_bg.drop('time').var(dim='ensemble', ddof=1)
# Ensemble covariance between the T_2M first guess and the background soil water.
cov = get_cov(ens_fg, ens_bg.drop('time'))

# Ensemble-based Jacobian: covariance over variance, scaled by the vertical weight.
jacob_ens = vert_weight * cov / var
# Keep only entries with magnitude below 50; everything else (including NaN,
# for which the comparison is False) is set to 0.
jacob_ens = jacob_ens.where(
    np.abs(jacob_ens) < 50,
    other=0,
)
gain_senkf = get_gain(jacob_ens, b_scale, obs_cov)

In [None]:
# Increment = ensemble-Jacobian gain times the T_2M innovation, without the
# time coordinate.
senkf_increment = gain_senkf * innov_t2m
senkf_increment = senkf_increment.drop('time')
# Analysis = SEKF background + increment on the (lat, lon) grid, bounded
# between 0 and level_sat by clipping the ratio to [0, 1].
senkf_analysis = (
    (
        (sekf_bg + senkf_increment).unstack('grid') / level_sat
    ).clip(min=0, max=1)
    * level_sat
)
# Write the analysis next to the other '023/juwels' files. The target is built
# from `base_path` like the input paths, so the experiment root is defined in
# one place (the resulting path is unchanged).
senkf_analysis_path = os.path.join(base_path, '023', 'juwels', 'da_offline_senkf.nc')
senkf_analysis.to_netcdf(senkf_analysis_path)

# Assimilate T2m nature, grid-point-based, with the EnKF

In [None]:
# Ensemble variance of the interpolated T_2M first guess (ddof=1).
hbht = ens_fg.var('ensemble', ddof=1)
# NOTE(review): 0.1 ** 2 looks like the same observation-error value as
# obs_cov = 0.01 -- confirm whether the constant should be shared.
hbht_r = hbht + 0.1 ** 2
# Ensemble covariance between the T_2M first guess and the background soil water.
bht = get_cov(ens_fg, ens_bg.drop('time'))
# Gain: covariance over (variance + observation term), scaled by the vertical weight.
gain_ens = bht / hbht_r * vert_weight

In [None]:
# Increment = ensemble gain times the T_2M innovation, without the time coordinate.
ens_sekf_increment = gain_ens * innov_t2m
ens_sekf_increment = ens_sekf_increment.drop('time')

In [None]:
# Analysis = SEKF background + increment on the (lat, lon) grid, bounded
# between 0 and level_sat by clipping the ratio to [0, 1].
ens_sekf_analysis = (
    (
        (sekf_bg + ens_sekf_increment).unstack('grid') / level_sat
    ).clip(min=0, max=1)
    * level_sat
)

In [None]:
# Write the analysis next to the other '023/juwels' files. The target is built
# from `base_path` like the input paths, so the experiment root is defined in
# one place (the resulting path is unchanged).
ens_sekf_analysis_path = os.path.join(base_path, '023', 'juwels', 'da_offline_enkf_nature.nc')
ens_sekf_analysis.to_netcdf(ens_sekf_analysis_path)