In [1]:
import contextlib
import logging
import os

import numpy as np
import torch
import xarray as xr
from tqdm import tqdm_notebook
from tqdm.auto import tqdm

from pytassim.model import Lorenz96
from pytassim.model.forward_model import forward_model
from pytassim.model.integration import RK4Integrator

In [2]:
# Keep notebook output quiet: only WARNING and above reach the root handler.
logging.basicConfig(level=logging.WARNING)
# Module-level logger for this notebook.
logger = logging.getLogger(name=__name__)

In [3]:
# Single seeded random stream used for every stochastic draw in this notebook
# (VR1 initial state, ensemble forcings, ensemble perturbations), in that order.
SEED = 42
rnd = np.random.RandomState(SEED)

# Output directory for the generated netCDF files. It can be overridden via the
# NEURAL_ASSIM_DATA environment variable instead of editing a hardcoded local path.
BASE_PATH = os.environ.get(
    'NEURAL_ASSIM_DATA', '/scratch/local1/Data/neural_nets/neural_assim/data/'
)

General time and grid settings

In [4]:
# Spin-up length in days: the part of the run before this point is only used
# for initialization.
start_point = 1000
# Model time step in model time units; 0.05 model units correspond to ~6 hours.
dt = 0.05
# 0.05 model units ~ 6 h implies 1 model unit ~ 5 days.
days_per_model_unit = 5
dt_days = dt * days_per_model_unit

# Last time (in days) of the run, spin-up included.
end = 20000 + start_point

# Time axis 0, dt_days, ..., end (inclusive). It is built from an integer step
# count rather than `np.arange(0, end + dt_days, dt_days)`. With a float step the
# number of elements depends on how the stop value rounds, and an extra step can
# appear when dt_days is not exactly representable in binary.
n_steps = int(round(end / dt_days))
assert np.isclose(n_steps * dt_days, end), 'end must be a multiple of dt_days'
all_steps = np.arange(n_steps + 1) * dt_days

nr_grids = 40

Special settings for VR1

In [5]:
# Constant forcing of the VR1 Lorenz-96 model.
F = 8.0
# Small Gaussian initial state (mean 0, std 0.01) drawn from the shared stream.
start_state = rnd.normal(loc=0, scale=0.01, size=(1, nr_grids))

Initialize VR1

In [6]:
# VR1: a Lorenz-96 model with the constant forcing F, stepped by RK4 with step dt.
l96_vr1 = Lorenz96(forcing=F)
vr1_integrator = RK4Integrator(
    l96_vr1,
    dt=dt,
)

Generate the VR1 dataset

In [7]:
# Integrate VR1 over all_steps from the random initial state; start_point is
# passed through as the spin-up length in days.
vr1_initial = torch.tensor(start_state)
ds_vr1 = forward_model(all_steps, start_point, vr1_initial, vr1_integrator)



Settings for observations (TODO: this section is currently empty — no observation settings are defined in this notebook yet)

Save VR1

In [8]:
# Write VR1 to netCDF. Create the output directory first; writing into a
# missing directory would otherwise fail.
os.makedirs(BASE_PATH, exist_ok=True)
vr1_path = os.path.join(BASE_PATH, 'train_vr1.nc')
ds_vr1.to_netcdf(vr1_path)

# Generate Ensemble

In [9]:
ens_size = 50

# Each ensemble member gets its own forcing drawn from N(F, 0.5**2). A standard
# deviation of 0.5 is roughly 6 % of F (for F = 8), so every member's forcing is
# slightly biased relative to the VR1 forcing.
# NOTE(review): the (1, ens_size, 1) shape presumably broadcasts one forcing per
# member over the remaining axes — confirm against Lorenz96.
ens_f = torch.tensor(rnd.normal(0, 0.5, size=(1, ens_size, 1)) + F)

# Initialize the ensemble model and its RK4 integrator (same dt as VR1).
# `forcing` is passed by keyword, as for the VR1 model.
l96_ensemble = Lorenz96(forcing=ens_f)
ensemble_integrator = RK4Integrator(l96_ensemble, dt=dt)

In [10]:
# Perturbations of the ensemble initial state are roundabout 10 % of the interspatial variability of VR1
# Standard deviation of the Gaussian perturbations added to the VR1 state to
# build the ensemble initial states (original note: roughly 10 % of the spatial
# variability of VR1 — NOTE(review): not verified in this notebook).
ens_pert_std = 2.0

# Start a new ensemble run every `ens_ana_time` days of VR1.
ens_ana_time = 5

# Each run is integrated for `ens_lead_time` days. With ens_lead_time equal to
# ens_ana_time, the forecasts tile the VR1 period and yield as many samples as VR1.
ens_lead_time = 5

# Lead times in days: 0, dt_days, ..., ens_lead_time - dt_days. The axis is built
# from an integer count so its length does not depend on float rounding of the stop value.
n_lead_steps = int(round(ens_lead_time / dt_days))
assert np.isclose(n_lead_steps * dt_days, ens_lead_time), 'ens_lead_time must be a multiple of dt_days'
ens_fcst_steps = np.arange(n_lead_steps) * dt_days

In [11]:
# VR1 states every `ens_ana_time` days; these are the analysis times at which
# ensemble runs are started. The check allows a small tolerance instead of using
# `time % ens_ana_time == 0`. Exact float equality only holds while every time
# value is exactly representable in binary, which is true for dt_days = 0.25
# but not in general.
n_ana_periods = ds_vr1.time / ens_ana_time
base_states = ds_vr1.sel(time=abs(n_ana_periods - n_ana_periods.round()) < 1e-6)

In [12]:
# Accumulator for the per-analysis ensemble forecasts filled by the next cell.
ensemble_data = []

# Progress bar over all analysis times. It is created in this cell because the
# loop cell below runs under `%%capture`, which captures that cell's output.
# `tqdm.auto` replaces the deprecated `tqdm_notebook` and still renders as a
# widget in Jupyter.
pbar = tqdm(total=len(base_states.time.values), desc='ensemble runs')

HBox(children=(IntProgress(value=0, max=4001), HTML(value='')))

In [13]:
%%capture
# Reset the accumulator here as well, so re-running this cell does not append a
# second copy of every forecast. A re-run still draws new random perturbations,
# because `rnd` has advanced.
ensemble_data = []
for ana_time in base_states.time.values:
    # Perturb the VR1 state at this analysis time with Gaussian noise
    # (std ens_pert_std), one draw per member and grid point. The noise is drawn
    # as (ens_size, 1, nr_grids) and transposed to (1, ens_size, nr_grids),
    # which keeps the random stream identical to the original draw order.
    init_state = rnd.normal(scale=ens_pert_std, size=(ens_size, 1, nr_grids))
    init_state = init_state + base_states.sel(time=ana_time).values
    init_state = init_state.transpose(1, 0, 2)
    # `ens_fcst` is kept as a global: the next cell reads its coordinates.
    ens_fcst = forward_model(ens_fcst_steps, 0, torch.tensor(init_state), ensemble_integrator)
    ensemble_data.append(ens_fcst.values)
    pbar.update()
pbar.close()

In [14]:
# Stack the per-analysis forecasts along a new axis 1, giving
# (var_name, analysis, time, ensemble, grid) as used for the DataArray below.
# np.stack(..., axis=1) equals np.array(...).transpose(1, 0, 2, 3, 4), and it
# raises on forecasts of mismatched shape. The isinstance guard keeps the cell
# idempotent: re-running it on the already stacked array would otherwise
# permute the axes a second time.
if isinstance(ensemble_data, list):
    ensemble_data = np.stack(ensemble_data, axis=1)

# Reuse the coordinates of the last forecast (presumably lead time, ensemble and
# grid) and add the analysis start times as a new coordinate.
ensemble_coords = dict(ens_fcst.coords)
ensemble_coords['analysis'] = base_states.time.values

In [15]:
# Wrap the stacked forecasts in a labelled array; axis order matches the stacking above.
ensemble_dims = ['var_name', 'analysis', 'time', 'ensemble', 'grid']
ensemble_ds = xr.DataArray(
    ensemble_data,
    coords=ensemble_coords,
    dims=ensemble_dims,
)

In [16]:
# Persist the ensemble forecasts next to the VR1 file in BASE_PATH.
ens_file_name = 'train_ens.nc'
ens_path = os.path.join(BASE_PATH, ens_file_name)
ensemble_ds.to_netcdf(ens_path)