In [None]:
import os
import pickle as pk
#from ipywidgets import interact
from IPython.display import clear_output

import xarray as xr
import numpy as np
import pandas as pd
from tqdm import tqdm_notebook as tqdm

import distributed

from cartopy.mpl.ticker import LongitudeFormatter, LatitudeFormatter
import matplotlib.pyplot as plt
import matplotlib.colors as mplc
import cartopy.crs as ccrs
import matplotlib.dates as mdates
import matplotlib.lines as mpll
from pandas.plotting import register_matplotlib_converters

from pytassim.localization import GaspariCohn
from pytassim.model.terrsysmp import preprocess_cosmo
from pytassim.obs_ops.terrsysmp import CosmoT2mOperator
from pytassim.assimilation.filter.letkf_dist import DistributedLETKFUncorr
from py_bacy.intf_pytassim.io import load_observations
from py_bacy.intf_pytassim.clm import distance_func

#import common_utils

%matplotlib inline

In [None]:
# Seeded legacy NumPy generator so that any random draws are reproducible.
rnd = np.random.RandomState(seed=42)

In [None]:
# Coordinate reference systems used for map plots: a rotated-pole projection
# and plain latitude/longitude.
plate_carree = ccrs.PlateCarree()
rotated_pole = ccrs.RotatedPole(pole_latitude=41.5, pole_longitude=-171.0)

# Apply both style sheets in the given order ('paper' first, then
# 'egu_journals'), then register the pandas date converters with matplotlib.
plt.style.use(['paper', 'egu_journals'])
register_matplotlib_converters()

In [None]:
# Start a local Dask cluster (16 single-threaded workers, 8 GB memory limit
# per worker) and connect a client to it. The trailing expression renders
# the client summary as the cell output.
cluster = distributed.LocalCluster(
    n_workers=16,
    threads_per_worker=1,
    memory_limit='8GB',
)
client = distributed.Client(cluster)
client

# Load data

In [None]:
# Root directory of the experiment data; the per-run sub-folders
# ('015', '016', '020') used in the cells below are joined onto this path.
base_path = '/p/project/chbn29/hbn29p/Projects/phd_thesis/data/da_enkf_for_soil'

## H2O

In [None]:
# Soil-water field ('H2OSOI') of run '016', opened lazily and re-chunked.
# Chunk sizes are given in dimension order.
# NOTE(review): the actual dimension names/order are not visible here — confirm.
vr_h2o_path = os.path.join(base_path, '016', 'h2o_cleaned.nc')
vr_h2o = (
    xr.open_dataset(vr_h2o_path)['H2OSOI']
    .squeeze(drop=True)
    .chunk((1, 302, 267))
)

In [None]:
# Soil-water field ('H2OSOI') of run '020', opened lazily and re-chunked.
# NOTE(review): the leading chunk size of 40 presumably spans the ensemble
# dimension — confirm.
ens_h2o_path = os.path.join(base_path, '020', 'h2o_cleaned.nc')
ens_h2o = (
    xr.open_dataset(ens_h2o_path)['H2OSOI']
    .squeeze(drop=True)
    .chunk((40, 1, 302, 267))
)

In [None]:
# Soil-water field ('H2OSOI') of run '015'; a later cell takes a single time
# step from it and prepends that step to `ens_h2o`.
ens_h2o_first_path = os.path.join(base_path, '015', 'h2o_cleaned.nc')
ens_h2o_first = (
    xr.open_dataset(ens_h2o_first_path)['H2OSOI']
    .squeeze(drop=True)
    .chunk((40, 1, 302, 267))
)

In [None]:
# Prepend the 2015-07-31 12:00 state of run '015' to the run '020' series.
# Caution: this overwrites `ens_h2o` with a value built from itself, so
# re-running the cell without reloading `ens_h2o` prepends that state again.
ens_h2o = xr.concat(
    [
        ens_h2o_first.sel(time='2015-07-31 12:00'),
        ens_h2o,
    ],
    dim='time',
)

## T2m

In [None]:
# 2 m temperature ('T_2M') of run '016', opened lazily and re-chunked.
# Chunk sizes are given in dimension order.
# NOTE(review): the actual dimension names/order are not visible here — confirm.
vr_t2m_path = os.path.join(base_path, '016', 't2m_cleaned.nc')
vr_t2m = (
    xr.open_dataset(vr_t2m_path)['T_2M']
    .squeeze(drop=True)
    .chunk((1, 109, 98))
)

In [None]:
# 2 m temperature ('T_2M') of run '020', opened lazily and re-chunked.
# NOTE(review): the leading chunk size of 40 presumably spans the ensemble
# dimension — confirm.
ens_t2m_path = os.path.join(base_path, '020', 't2m_cleaned.nc')
ens_t2m = (
    xr.open_dataset(ens_t2m_path)['T_2M']
    .squeeze(drop=True)
    .chunk((40, 1, 109, 98))
)

## Prepare

In [None]:
# Drop repeated time stamps from both T2m series. `Index.duplicated()` flags
# every occurrence after the first, so the negated mask keeps the first one.
ens_t2m = ens_t2m.sel(time=~ens_t2m.indexes['time'].duplicated())
vr_t2m = vr_t2m.sel(time=~vr_t2m.indexes['time'].duplicated())

In [None]:
# Give both ensembles the same integer member labels 0..39 so they can be
# matched on the 'ensemble' coordinate (this requires 40 members in each).
ens_t2m['ensemble'] = np.arange(40)
ens_h2o['ensemble'] = np.arange(40)

In [None]:
# Time steps used for the analysis: full hours (minute == 0) of the T2m
# ensemble, starting at 2015-07-31 12:00.
bg_idx = ens_t2m.indexes['time']
bg_idx = bg_idx[
    (bg_idx.minute == 0)
    & (bg_idx >= pd.to_datetime('2015-07-31 12:00'))
]

### Load stations

In [None]:
# Station table, read from the 'stations' key of the HDF5 store.
stations_path = '/p/scratch/chbn29/hbn29p/data/tsmp/runs/utilities/stations.hd5'
df_stations = pd.read_hdf(stations_path, key='stations')

### Load constant

In [None]:
# Constant fields, read fully into memory. The context manager closes the
# underlying file handle once the data has been loaded, so no open file is
# left behind for the rest of the session.
const_path = '/p/scratch/chbn29/hbn29p/data/tsmp/runs/utilities/cosmo_const.nc'
with xr.open_dataset(const_path) as _ds_const_file:
    ds_cos_const = _ds_const_file.load()

### Build lat/lon coordinates

In [None]:
# Pair up the latitude and longitude values of the T2m grid along a new
# trailing axis: the last axis holds (lat, lon).
coords_latlon = np.stack(
    [ens_t2m['lat'].values, ens_t2m['lon'].values],
    axis=-1,
)

In [None]:
# Sanity check: show the shape of the stacked coordinate array
# (the last axis has length 2: lat, lon).
coords_latlon.shape

### Define observation operator

In [None]:
# Observation operator built from the station table, the grid coordinates
# and the constant fields.
obs_op = CosmoT2mOperator(
    df_stations,
    cosmo_coords=coords_latlon,
    cosmo_const=ds_cos_const,
)
# Override the operator's lapse-rate lookup on this instance so that it
# always returns 0.
# NOTE(review): presumably this switches off the height correction — confirm.
obs_op.get_lapse_rate = lambda x: 0

### Prepare VR

In [None]:
# Bring the single-run T2m field into the state layout passed to the
# observation operator: add singleton 'var_name', 'ensemble' and 'vgrid'
# dimensions, then flatten (rlat, rlon, vgrid) into one 'grid' dimension.
vr_obs_prep = (
    vr_t2m
    .expand_dims('var_name', axis=0)
    .expand_dims('ensemble', axis=2)
    .expand_dims('vgrid', axis=-3)
    .stack(grid=['rlat', 'rlon', 'vgrid'])
)
vr_obs_prep['var_name'] = ['T_2M']

### Load observations

In [None]:
%%capture
# Load the stored T2m observations and attach the observation operator to
# the dataset's `obs` accessor. Output of this cell is suppressed by the
# `%%capture` cell magic above.
obs_path = '/p/scratch/chbn29/hbn29p/data/tsmp/runs/obs/ens/t2m_obs_016_0_1_long.nc'
obs_det = load_observations(obs_path)
obs_det.obs.operator = obs_op.get_obs_method

In [None]:
# Build observations from the single-run T2m field: apply the observation
# operator to the prepared field and drop the singleton ensemble dimension.
obs_vr_values = obs_det.obs.operator(vr_obs_prep).squeeze('ensemble')
# Start from a deep copy of the loaded observations and replace the observed
# values with the ones computed above.
obs_vr = obs_det.copy(deep=True)
obs_vr['observations'] = obs_vr_values
# Re-assign the observation grid coordinate from the loaded dataset.
# NOTE(review): presumably needed because assigning 'observations' alters it — confirm.
obs_vr['obs_grid_1'] = obs_det['obs_grid_1']
# Attach the operator to the new dataset's accessor as well.
obs_vr.obs.operator = obs_op.get_obs_method

# Define assimilation

In [None]:
# Gaspari-Cohn localization with a two-component radius.
# NOTE(review): presumably (horizontal, vertical); units depend on
# `distance_func` — confirm.
loc_radius = (15000, 0.7)
local_gc = GaspariCohn(loc_radius, distance_func)

# Distributed LETKF that runs on the Dask client created above.
letkf = DistributedLETKFUncorr(
    client=client,
    localization=local_gc,
    inf_factor=1.006,
    chunksize=1000,
)

## Create states

In [None]:
# T2m ensemble at the analysis times, reshaped to
# (var_name, time, ensemble, grid) and loaded into memory. It is passed to
# the filter as the pseudo state in the assimilation loop.
pseudo_state = (
    ens_t2m
    .sel(time=bg_idx)
    .expand_dims('var_name', axis=0)
    .expand_dims('vgrid', axis=-3)
    .stack(grid=['rlat', 'rlon', 'vgrid'])
)
pseudo_state['var_name'] = ['T_2M']
pseudo_state = pseudo_state.transpose(
    'var_name', 'time', 'ensemble', 'grid'
).load()

In [None]:
# Soil-water ensemble at soil-level index 4 for the analysis times, reshaped
# to (var_name, time, ensemble, grid). This is the state the filter updates.
background = (
    ens_h2o
    .isel(levsoi=[4])
    .sel(time=bg_idx)
    .expand_dims('var_name', axis=0)
    .stack(grid=['lat', 'lon', 'levsoi'])
)
background['var_name'] = ['H2OSOI']
background = background.transpose('var_name', 'time', 'ensemble', 'grid')
# Keep full-hour time steps only and pull the data into memory.
background = background.sel(
    time=background.indexes['time'].minute == 0
).load()

In [None]:
# Progress bar over the background time steps; it is consumed by the
# assimilation loop in the next cell.
time_pbar = tqdm(background.time.values)

In [None]:
# Run the LETKF independently for every time step and collect the analyses.
ds_ana = []
for time in time_pbar:
    # Show the time step currently being processed next to the progress bar.
    time_pbar.set_postfix(time=pd.to_datetime(time).strftime('%m-%d %H:%MZ'))
    # Selecting with a one-element list keeps the 'time' dimension (length 1).
    tmp_bg = background.sel(time=[time])
    tmp_pseudo_state = pseudo_state.sel(time=[time])
    tmp_obs = obs_vr.sel(time=[time])
    # `sel` returns a new dataset, so the operator is attached to it again.
    tmp_obs.obs.operator = obs_op.get_obs_method
    # Update the soil-water background with the T2m observations; the T2m
    # ensemble is passed as third argument (the pseudo state).
    tmp_ana = letkf.assimilate(tmp_bg, tmp_obs, tmp_pseudo_state)
    ds_ana.append(tmp_ana)

In [None]:
# Join the per-time-step analyses back into one array along 'time'.
ds_ana_concat = xr.concat(ds_ana, dim='time')

In [None]:
# Restore the original grid dimensions from the flattened 'grid' dimension
# and drop all singleton dimensions.
# Caution: the cell overwrites its own input, so it cannot be re-run without
# re-creating `ds_ana_concat` first ('grid' no longer exists afterwards).
ds_ana_concat = ds_ana_concat.unstack('grid').squeeze()

In [None]:
# Write the analysis into the run '020' folder. The target is built from
# `base_path`, like the input paths, so the data root is defined in one place
# only; it resolves to the same file as the previous hard-coded path.
ds_ana_concat.to_netcdf(
    os.path.join(base_path, '020', 'da_offline_3d_enkf_nature.nc')
)