In [None]:
from enkf_lorenz.models import Lorenz96
from enkf_lorenz.integrator import RK4Integrator
from enkf_lorenz.utilities import forward_model
from enkf_lorenz.observation.generator import observation_generator
from enkf_lorenz.assimilation import Letkf

import tensorflow as tf
import keras.backend as k_backend
import numpy as np
import logging
import xarray as xr
import matplotlib.pyplot as plt
from tqdm import tqdm_notebook

from model_dense import load_model

In [None]:
# One seeded generator shared by all random draws below, so a fresh run is reproducible
rnd = np.random.RandomState(seed=1)

In [None]:
# Location of the trained dense network checkpoint (absolute, machine-specific path)
model_path = (
    '/scratch/local1/Data/neural_nets/neural_assim/models/'
    'lorenz_dense_cycle_mem_201806221357/model-5'
)

# Virtual reality

In [None]:
# Spin-up length: the part of the run before this time is only used as
# initialization (passed to forward_model as `start`; presumably model time units).
start = 1000
# Integration time step in model time units (0.05 ~ 6 hours)
dt = 0.05
# Output interval of five integration steps.
# NOTE(review): with 0.05 ~ 6 h this is ~30 h rather than one day, despite the
# name -- confirm the intended interval.
dt_days = dt*5

# End time of the run: spin-up plus 730 time units
end = 730 + start

# Output times 0, dt_days, ..., end (the extra dt_days makes `end` inclusive)
all_steps = np.arange(0, end+dt_days, dt_days)

# Number of Lorenz96 grid points
nr_grids = 40

In [None]:
# Forcing constant of the virtual-reality Lorenz96 model
F = 8
# VR1 initial state: small Gaussian noise (std 0.01) around zero on every grid point
start_state = rnd.normal(0, 0.01, size=(1, nr_grids))

In [None]:
# Virtual-reality model: Lorenz96 with forcing F on nr_grids points,
# stepped by the RK4Integrator with the global time step dt
l96_vr1 = Lorenz96(F, nr_grids)
vr1_integrator = RK4Integrator(l96_vr1, dt=dt)

In [None]:
# Integrate the virtual reality over all_steps, starting from start_state;
# `start` is the spin-up length handed to forward_model
ds_vr1 = forward_model(all_steps, start, start_state, vr1_integrator,
                       nr_grids=nr_grids)

# Observations

In [None]:
# Observation-error model: Gaussian draws from the seeded generator,
# with mean obs_bias and standard deviation obs_std
obs_random = rnd.normal
obs_bias = 0
obs_std = 0.3

# Observed grid points and the observation interval (presumably in model time units)
obs_indices = [6, 17, 25, 36, 13]
obs_timestep = 2

In [None]:
# Synthetic observations of VR1 at obs_indices every obs_timestep,
# perturbed with obs_random(loc=obs_bias, scale=obs_std)
ds_obs = observation_generator(
    ds_vr1,
    obs_random,
    obs_indices,
    obs_timestep,
    time_axis='time',
    loc=obs_bias,
    scale=obs_std,
)

In [None]:
# Persist the virtual-reality run as NetCDF
vr1_file = '/scratch/local1/Data/lorenz_test/vr1_test.nc'
ds_vr1.to_netcdf(vr1_file)

In [None]:
# Persist the synthetic observations as NetCDF
obs_file = '/scratch/local1/Data/lorenz_test/obs_test.nc'
ds_obs.to_netcdf(obs_file)

# Assimilation models

In [None]:
# LETKF set-up. The filter's observation error is taken from obs_std (the std of
# the noise actually added to the synthetic observations) instead of a duplicated
# literal, so changing obs_std keeps the filter consistent.
# NOTE(review): loc_radius=99 exceeds the 40-point domain, which presumably
# switches localization off -- confirm this is intended.
letkf = Letkf(loc_radius=99, obs_err=obs_std, adap_inflation=False)

# Set up ensemble

In [None]:
ens_size = 50

# Each member gets its own forcing: F plus zero-mean Gaussian noise with std 0.5
# (roughly 6 % of F = 8). The (1, ens_size, 1) shape presumably broadcasts over
# the variable and grid axes -- TODO confirm against Lorenz96.
ens_f = F + rnd.normal(0, 0.5, size=(1, ens_size, 1))

# Initialize the ensemble model and its integrator (same time step as VR1)
l96_ensemble = Lorenz96(ens_f, nr_grids)
ensemble_integrator = RK4Integrator(l96_ensemble, dt=dt)

In [None]:
# Forecast length of each ensemble run (presumably model time units)
ens_lead_time = 5

# Intended std of the initial-state perturbations: roughly 10 % of the
# spatial variability of VR1
ens_pert_std = 0.3

# Lead time whose forecast becomes the next cycle's background. It equals the
# observation interval, so a new ensemble run starts at every observation time.
ens_ana_time = obs_timestep

# Forecast lead times 0, dt_days, ...; np.arange excludes ens_lead_time itself
ens_fcst_steps = np.arange(0, ens_lead_time, dt_days)

In [None]:
# VR1 state at its first output time (the state the ensemble perturbations are
# added to). It gets its own name so that it does not clobber the VR1 `start_state`
# consumed by the forward_model cell: re-running that cell after this one would
# otherwise silently start the virtual reality from a different state.
vr1_initial_state = ds_vr1.isel(time=0)

In [None]:
# Ensemble initial state: VR1 at time 0 plus Gaussian perturbations whose std is
# the configured ens_pert_std (the literal used before ignored that setting and
# repeated the forcing-perturbation std of 0.5).
ens_start_pert = rnd.normal(0, ens_pert_std, size=(1, ens_size, nr_grids))
ens_start_state = xr.DataArray(
    data=ens_start_pert,
    coords=dict(
        varname=['T'],
        grid=np.arange(nr_grids),
        ensemble=np.arange(ens_size),
    ),
    dims=['varname', 'ensemble', 'grid'],
)
ens_start_state = ens_start_state + ds_vr1.isel(time=0)

## Assimilation with the LETKF

In [None]:
# Cycled LETKF experiment. For every observation time:
#   1. (from the second cycle on) update the background with the observations,
#   2. run an ensemble forecast from the resulting state,
#   3. keep the forecast at lead ens_ana_time as the next cycle's background.
# The first cycle starts from ens_start_state without assimilation.
letkf_forecasts = []
analysis = None
latest_state = ens_start_state
for obs_time in tqdm_notebook(ds_obs.time):
    if analysis is not None:
        # Background from the previous cycle; presumably valid at obs_time when
        # observations are spaced ens_ana_time (= obs_timestep) apart.
        analysis = letkf.assimilate(analysis, ds_obs, analysis_time=obs_time)
        latest_state = analysis.squeeze(dim='varname')
    ensemble_forecast = forward_model(
        ens_fcst_steps, 0,
        latest_state.values,
        ensemble_integrator,
        nr_grids=nr_grids, ens_mems=ens_size
    )
    letkf_forecasts.append(ensemble_forecast)
    # The slice keeps a length-1 time dimension instead of dropping it
    analysis = ensemble_forecast.sel(
        time=slice(ens_ana_time, ens_ana_time), drop=False
    )
    # Shift lead time to absolute time. assign_coords builds a new index rather
    # than modifying the dimension coordinate in place (which xarray rejects and
    # which could alias the stored forecast). Using .values drops obs_time's own
    # scalar `time` coordinate, which would otherwise conflict with the index.
    analysis = analysis.assign_coords(
        time=analysis['time'].values + obs_time.values
    )

In [None]:
# Stack all cycles along a new 'analysis' dimension labelled with the observation
# times, drop lead times containing NaNs, then remove length-1 dimensions.
letkf_concated = (
    xr.concat(letkf_forecasts, dim='analysis')
    .dropna('time')
    .assign_coords(analysis=ds_obs.time.values)
    .squeeze()
)

In [None]:
# Persist the stacked LETKF forecasts as NetCDF
letkf_file = '/scratch/local1/Data/lorenz_test/letkf_test.nc'
letkf_concated.to_netcdf(letkf_file)