## Initial setup

First, we create an instance of `UnstructuredGrid`, which partitions the grid and handles all communication between the individual processors.

In [None]:
# initialize the PETSc library; init_petsc returns the MPI communicator that is used by all objects below.
# onRank0(comm) is used later to restrict work (e.g. writing files) to the first MPI rank.
from enstools.mpi import init_petsc, onRank0
comm = init_petsc()

In [None]:
# create the grid object itself. It partitions the grid among the MPI ranks and handles their communication.
from enstools.mpi.grids import UnstructuredGrid
from enstools.io import read
# the ICON grid definition file is read on every rank (this line is not restricted to rank 0)
grid_ds = read("/archive/meteo/external-models/dwd/grids/icon_grid_0016_R02B06_G.nc")
# NOTE(review): overlap=25 presumably sets the size of the halo region shared between neighbouring
# partitions - confirm the unit against the UnstructuredGrid documentation
grid = UnstructuredGrid(grid_ds, overlap=25, comm=comm)

In [None]:
# number of cells of the grid (displayed as the output of this cell)
grid.ncells

### Create some observations

Observations are stored in the DWD feedback file format (http://www2.cosmo-model.org/content/model/documentation/core/cosmoFeedbackFileDefinition.pdf) for potential later compatibility with the KENDA system and Leo's Python tools.

The class `FeedbackFile` implements functions to extract observations on pressure or model levels from a given model output file.

In [None]:
from enstools.da.support import FeedbackFile, LevelType
import getpass
import numpy as np
import os

# create a temporary folder for all files we create here and later
tmp_folder = f"/project/meteo/scratch/{getpass.getuser()}/enstools-nda-test"
obs_file = f"{tmp_folder}/tmp-obs.nc"

# the folder and the observation file are created only once, on the first MPI rank
if onRank0(comm):
    # create the folder; no error if it already exists from a previous run
    os.makedirs(tmp_folder, exist_ok=True)

    # create a new feedback file object for the same grid as our grid object above
    ff = FeedbackFile(filename=None, gridfile=grid_ds)

    # create three observations at the equator (10W, 0, 10E) in model level 63 (about 500 hPa).
    # The coordinates are given in degrees and converted to radians here.
    lon = np.asarray([-10.0, 0, 10.0]) * np.pi / 180.0
    lat = np.asarray([0.0, 0.0, 0.0]) * np.pi / 180.0
    # the "nature run" is ensemble member 40, which is not among the members 1-9 loaded as the ensemble below
    nature_run_file = "/archive/meteo/external-models/dwd/icon/oper/icon_oper_eps_gridded-global_rolling/202002/20200201T00/igaf2020020100.m040.grb"
    ff.add_observation_from_model_output(nature_run_file, variables=["T"], error={"T": 1.0}, lon=lon, lat=lat, levels=[63], level_type=LevelType.MODEL_LEVEL)

    # shift the observed values by +5 to get a larger innovation and therefore more impact
    ff.data["obs"].values += 5

    # store the observations in a temporary file
    ff.write_to_file(obs_file)

# All ranks wait here until rank 0 has created the folder and the observation file. Without this
# synchronization, other ranks could try to read obs_file (load_observations) or write into
# tmp_folder (save_state) before they exist.
comm.barrier()

## Create a DA context

The actual data assimilation is handled by the `DataAssimilation` class. It has methods to read the state, to partition observations into non-overlapping subsets, to run the actual data assimilation algorithm, and to store the state back to files.

In [None]:
# create the DA object. It uses the grid object for the communication between the MPI ranks.
from enstools.da.nda import DataAssimilation
da = DataAssimilation(grid)

In [None]:
# Here we load a number of ensemble members into memory; the pattern "m00[1-9]" presumably selects
# members 001-009. If the notebook crashes, increase the available memory!
da.load_state("/archive/meteo/external-models/dwd/icon/oper/icon_oper_eps_gridded-global_rolling/202002/20200201T00/igaf2020020100.m00[1-9].grb")

In [None]:
# create a copy of the loaded state in order to run the algorithm multiple times.
# Attention: this doubles the memory consumption!
state_backup = da.backup_state()

# the shape of the state variable is (n-cells, sum of all levels from all variables, n-members)
state_backup.shape

In [None]:
# load the observations written above. This step also partitions the observations into non-overlapping subsets.
da.load_observations(obs_file)

## Run the default algorithm

The `DataAssimilation` class has a method `run` that runs any algorithm that is passed to it. The algorithm is expected to have an `assimilate` method and a `weights_for_gridpoint` method.

In [None]:
# get a copy of the temperature field before running the algorithm
t_before = da.get_state_variable("T")
print(t_before.shape)
# ensemble mean at level index 63; presumably the axes are (cells, levels, members) like the state above.
# NOTE(review): the observations were created with levels=[63]; confirm whether model level 63 corresponds
# to array index 63 or 62 (1-based level numbering).
t_before_mean = t_before[:,63,:].mean(axis=1)
print(t_before_mean.shape)

In [None]:
# run the Default algorithm. Running it for the first time causes its JIT functions to be compiled.
from enstools.da.nda.algorithms.default import Default
da.run(Default)

In [None]:
# check the results: maximum absolute change of the ensemble-mean temperature at level index 63
t_after_mean = da.get_state_variable("T")[:,63,:].mean(axis=1)
t_diff = t_after_mean - t_before_mean
np.abs(t_diff).max()

In [None]:
# plot the difference on a regular grid
from enstools.misc import generate_coordinates
from enstools.interpolation import nearest_neighbour
from enstools.plot import contour

# regular lon/lat grid with a spacing of 0.2 covering lon -20..20 and lat -10..10,
# which contains the three observation locations
plon, plat = generate_coordinates(0.2, lon_range=[-20, 20], lat_range=[-10, 10])
# nearest-neighbour interpolation function from the unstructured ICON cell centres (clon/clat) to the regular grid
f_interpol = nearest_neighbour(grid_ds["clon"], grid_ds["clat"], plon, plat, src_grid="unstructured", dst_grid="regular")

In [None]:
# interpolate the difference (with an added leading dimension of size 1) to the regular grid
# and plot it with contour levels centred on zero
t_diff_interpol = f_interpol(t_diff.reshape(1, grid.ncells))
contour(t_diff_interpol[0, ...], lon=plon, lat=plat, levels_center_on_zero=True, cmap="PuOr")

## Store the result into files

For now, only storing the complete state is supported. 

In [None]:
# write the complete state (as updated by the Default algorithm) into tmp_folder
da.save_state(tmp_folder)

## Running a new algorithm

### Class for a new algorithm

Have a look at `enstools.da.nda.algorithms.__init__.py` for a description of the arguments of the algorithm methods.

In [None]:
from enstools.da.nda.algorithms.algorithm import Algorithm, model_equivalent, covariance
from numba import jit, prange, i4, f4
import numpy as np


class FancyNew(Algorithm):
    """
    Example of a user-defined algorithm: a serial ensemble Kalman filter update with perturbed
    observations. Observations are assimilated one at a time; every state value within the
    localization radius of an observation is updated with the weighted ensemble Kalman gain.
    """

    @staticmethod
    @jit("void(f4[:,:,::1], i4[:,::1], f4[:,::1], i4[:,::1], i4[:,::1], i4[:,::1], f4[:,::1], i1[::1])",
         nopython=True, nogil=True, parallel=True,
         locals={"i_report": i4, "i_obs": i4, "i_radius": i4, "i_layer": i4, "i_points": i4, "i_cell": i4,
                 "p_equivalent": f4, "denominator": f4, "p": f4, "gain": f4})
    def assimilate(state: np.ndarray, state_map: np.ndarray,
                   observations: np.ndarray, observation_type: np.ndarray, reports: np.ndarray,
                   points_in_radius: np.ndarray, weights: np.ndarray, updated: np.ndarray):
        """
        Assimilate all observations listed in `reports` into `state` (modified in place).

        See the Algorithm class for documentation of the arguments.

        For every observation and every state value x within the localization radius:

            x_i += w * cov(x, y) / (var(y) + sigma_o**2) * (y_obs + eps_i - y_i)

        where y is the ensemble of model equivalents, sigma_o the observation error
        (observations[i_obs, 1]), eps_i a random draw from N(0, sigma_o) and w the
        localization weight.
        """
        # temporary variables
        n_varlayer = state.shape[1]
        n_ens = state.shape[2]
        n_inv = 1. / (n_ens - 1)
        equivalent = np.empty(n_ens, dtype=np.float32)
        deviation_equivalent_mean = np.empty(n_ens, dtype=np.float32)
        innovation = np.empty(n_ens, dtype=np.float32)
        random_error = np.empty(n_ens, dtype=np.float32)

        # observations are processed one by one in the order that they are listed in the reports array
        for i_report in range(reports.shape[0]):

            # all observations in this report are located at this index within the local part of the grid.
            grid_index = reports[i_report, 2]
            assert grid_index != -1

            # loop over all observations in this report
            for i_obs in range(reports[i_report, 0], reports[i_report, 0] + reports[i_report, 1]):
                # get model equivalents for the given observation and their deviations from the ensemble mean,
                # which are later used for covariances. For observations on model levels, model_equivalent
                # returns just the corresponding grid cell.
                model_equivalent(state, state_map, grid_index, observations, observation_type, i_obs,
                                 equivalent, deviation_equivalent_mean)

                # observation value is stored in column 0, observation error (standard deviation) in column 1
                obs_error = observations[i_obs, 1]

                # innovation with perturbed observations: value + random error - model equivalent
                random_error[:] = np.random.normal(0, obs_error, n_ens)
                innovation[:] = observations[i_obs, 0] + random_error - equivalent

                # variance of the model equivalent
                p_equivalent = np.sum(deviation_equivalent_mean**2) * n_inv
                # Kalman gain denominator: variance of the model equivalent PLUS the observation error
                # variance, i.e. 1 / (H P H^T + R)
                denominator = 1.0 / (p_equivalent + obs_error**2)

                # loop over all grid cells and all variables that are within the localization radius.
                # This loop runs in parallel if NUMBA_NUM_THREADS is larger than 1.
                i_points = reports[i_report, 3]
                for i_radius in prange(points_in_radius.shape[1]):
                    # the number of points for each observation is not constant; unused entries
                    # contain a grid cell index of -1 and are skipped.
                    i_cell = points_in_radius[i_points, i_radius]
                    if i_cell == -1:
                        continue

                    # mark the current point as updated. This will cause updates of overlapping areas between processors
                    updated[i_cell] = 1

                    # loop over all layers of the state, this is also a loop over all variables as variables are stacked
                    # on top of each other in the state variable.
                    for i_layer in range(n_varlayer):
                        # covariance between model equivalent and the current location in the state,
                        # scaled by the localization weight
                        p = covariance(state, i_cell, i_layer, deviation_equivalent_mean) * weights[i_points, i_radius]

                        # Kalman gain for this state value; identical for all ensemble members
                        gain = p * denominator

                        # update the state at the current location
                        for i_ens in range(n_ens):
                            state[i_cell, i_layer, i_ens] += gain * innovation[i_ens]


In [None]:
# run the new FancyNew algorithm.
# First, restore the state from before running the Default algorithm, so that both algorithms
# start from the same ensemble.
da.restore_state(state_backup)

In [None]:
# run the new algorithm on the restored state
da.run(FancyNew)

In [None]:
# check the results: maximum absolute change of the ensemble-mean temperature at level index 63,
# compared to the state before any assimilation
t_after_mean = da.get_state_variable("T")[:,63,:].mean(axis=1)
t_diff = t_after_mean - t_before_mean
np.abs(t_diff).max()

In [None]:
# interpolate the difference produced by FancyNew to the regular grid and plot it
# with contour levels centred on zero
t_diff_interpol = f_interpol(t_diff.reshape(1, grid.ncells))
contour(t_diff_interpol[0, ...], lon=plon, lat=plat, levels_center_on_zero=True, cmap="PuOr")