# Tutorial: Implementing your own Data Assimilation Scheme

As part of the Data Assimilation framework, the classes that apply the assimilation steps are purposely kept separate, as shown in the [design flowchart](https://ewatercycle-da.readthedocs.io/en/latest/_images/method1_define_upfront.png).<br>
This allows users to implement their own schemes to assimilate models. <br>
Currently (April 2024) these must be registered in the same file as DA, but this registration is just a pointer to a class. <br>
With little modification you could change this (and we might do so later). <br>
Regardless of where it lives, the implementation is the same. <br> 
This tutorial walks you through what you need in order to create a scheme that works within the framework. <br>
The framework has been designed to work with: 
- Particle Filters (PF)
- Ensemble Kalman Filter (EnKF)
- Ensemble Smoothers (ES)
- Ensemble Smoother Multiple Data Assimilation (ES-MDA)
- Iterative Ensemble Smoother (IES)
  
Others will likely work, but there are no guarantees. 4D-Var likely won't work easily, as it requires time to be a parameter, which adds a whole layer of complexity. 
To explain how a scheme should be structured, a `NewDataAssimilationScheme` class is constructed here step by step.

eWaterCycle in general uses [pydantic](https://docs.pydantic.dev/latest/) for validation, so your class should inherit from pydantic's `BaseModel`:

```py
from pydantic import BaseModel


class NewDataAssimilationScheme(BaseModel):
    """Implementation of a new DA scheme"""
```

The class has one argument: 
 - `N` (int): the ensemble size, which is set when initialising the ensemble
The class has five required attributes:
 - `hyperparameters` (dictionary): combination of the different parameters your scheme might need. Can of course be empty if none are needed.
 - `obs` (float): observation value at the current model timestep; set in due course, thus optional
 - `state_vectors` (np.ndarray): current state vector (i.e. input) per ensemble member, with shape N x len(z), where z is the state vector
 - `predictions` (np.ndarray): prior modelled values per ensemble member; assumed to be N x 1, but the shape depends on the H operator you pass
 - `new_state_vectors` (np.ndarray): updated state vector (i.e. output) per ensemble member, thus with shape N x len(z)
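To make the shape conventions concrete, here is a small NumPy sketch. The ensemble size, the state-vector length and the two H operators (`H_first_element`, `H_full_state`) are hypothetical examples, assuming H maps each member's state vector to the observed quantity; the H operator you actually pass to the framework may look different.

```python
import numpy as np

N, len_z = 5, 3  # example ensemble size and state-vector length
rng = np.random.default_rng(seed=42)

# state_vectors: one row per ensemble member, one column per element of z
state_vectors = rng.normal(size=(N, len_z))


def H_first_element(x: np.ndarray) -> np.ndarray:
    """Hypothetical H operator: observe only the first element of z."""
    return x[:, :1]  # slicing keeps the column axis, so the result is N x 1


def H_full_state(x: np.ndarray) -> np.ndarray:
    """Hypothetical H operator: observe every element of z."""
    return x  # N x len(z)


predictions = H_first_element(state_vectors)
print(state_vectors.shape, predictions.shape, H_full_state(state_vectors).shape)
# prints (5, 3) (5, 1) (5, 3)
```

Whatever H you choose, `new_state_vectors` must come back with the same N x len(z) layout as `state_vectors`.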

Then add any optional attributes you may want to access for debugging or other reasons:
- `new_data_assimilation_specific_var` (...): example of what you might add

```py
from typing import Any

from pydantic import BaseModel


class NewDataAssimilationScheme(BaseModel):
    """Implementation of a new DA scheme"""
    N: int
    
    hyperparameters: dict = dict()
    obs: float | Any | None = None
    state_vectors: Any | None = None
    predictions: Any | None = None
    new_state_vectors: Any | None = None
    
    new_data_assimilation_specific_var: Any | None = None
```

The class requires the function `update`, which assimilates the current state vectors and sets the new state vectors:

```py
from typing import Any

from pydantic import BaseModel


class NewDataAssimilationScheme(BaseModel):
    """Implementation of a new DA scheme"""
    N: int
    hyperparameters: dict = dict()
    obs: float | Any | None = None
    state_vectors: Any | None = None
    predictions: Any | None = None
    new_state_vectors: Any | None = None
    new_data_assimilation_specific_var: Any | None = None

    def update(self):
        """Takes the current state vectors of the ensemble and sets the updated state vectors"""
        state_vector = self.state_vectors

        # ... 
        # insert scheme: compute updated state vectors from
        # self.state_vectors, self.predictions, self.obs and self.hyperparameters
        # ... 

        self.new_state_vectors = state_vector
```
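As one illustration of what could go in place of `# insert scheme`, the function below sketches a stochastic (perturbed-observation) Ensemble Kalman Filter analysis step using only NumPy. It is not the `EnsembleKalmanFilter` shipped with ewatercycle_DA; the function name `stochastic_enkf_update` and the `obs_sigma` hyperparameter (observation error standard deviation) are hypothetical.

```python
import numpy as np


def stochastic_enkf_update(
    state_vectors: np.ndarray,
    predictions: np.ndarray,
    obs: float,
    obs_sigma: float,
    rng: np.random.Generator,
) -> np.ndarray:
    """Sketch of a perturbed-observation EnKF analysis step.

    state_vectors: N x len(z), predictions: N x m, returns N x len(z).
    """
    X = np.asarray(state_vectors, dtype=float)
    N = X.shape[0]
    Y = np.asarray(predictions, dtype=float).reshape(N, -1)  # N x m
    m = Y.shape[1]

    # ensemble anomalies: deviations from the ensemble mean
    X_anom = X - X.mean(axis=0)
    Y_anom = Y - Y.mean(axis=0)

    # sample cross-covariance (len(z) x m) and innovation covariance (m x m)
    C_xy = X_anom.T @ Y_anom / (N - 1)
    C_yy = Y_anom.T @ Y_anom / (N - 1) + obs_sigma**2 * np.eye(m)

    # Kalman gain K = C_xy C_yy^-1, via a linear solve instead of an explicit inverse
    K = np.linalg.solve(C_yy, C_xy.T).T  # len(z) x m

    # each member is compared against its own perturbed copy of the observation
    perturbed_obs = obs + rng.normal(0.0, obs_sigma, size=(N, m))
    return X + (perturbed_obs - Y) @ K.T  # N x len(z)


# toy usage: 50 members, 3 state elements, H observes the first element
rng = np.random.default_rng(0)
prior = rng.normal(0.0, 1.0, size=(50, 3))
posterior = stochastic_enkf_update(prior, prior[:, :1], obs=1.0, obs_sigma=0.01, rng=rng)
print(prior[:, 0].mean().round(2), posterior[:, 0].mean().round(2))
```

Inside `update` this could be called as `stochastic_enkf_update(self.state_vectors, self.predictions, self.obs, self.hyperparameters["obs_sigma"], np.random.default_rng())`, assigning the result to `self.new_state_vectors`. With a small `obs_sigma`, the observed element is pulled almost entirely onto the observation.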

As mentioned, the available DA schemes are currently purposely controlled to avoid misuse and confusion. <br>
For this reason, if you implement a scheme you must add it to the `LOADED_METHODS` dictionary.<br>
Each of the DA schemes has its own `.py` file under `ewatercycle_DA.data_assimilation_schemes`.
The current registration looks like this: 
```py
    ...
    
    from ewatercycle_DA.data_assimilation_schemes.PF import ParticleFilter
    from ewatercycle_DA.data_assimilation_schemes.EnKF import EnsembleKalmanFilter

    LOADED_METHODS: dict[str, Any] = dict(
                                            PF=ParticleFilter,
                                            EnKF=EnsembleKalmanFilter,
                                         )
    ...

```

Because every scheme in DA is registered under a shortened name, adding our new method looks like this:

```py
    ...
    
    from ewatercycle_DA.data_assimilation_schemes.NDAS import NewDataAssimilationScheme
    
    ...

    LOADED_METHODS: dict[str, Any] = dict(
                                            PF=ParticleFilter,
                                            EnKF=EnsembleKalmanFilter,
                                            NDAS=NewDataAssimilationScheme
                                         )
    ...

```
This assumes you put a file called `NDAS.py`, containing the class `NewDataAssimilationScheme`, in the folder `data_assimilation_schemes`.

_advanced:_ <br>
This dictionary merely holds a pointer to the class, so you could also host the class somewhere else and provide it: 
```py
    # e.g. a separately installed package; Python import names use underscores, not hyphens
    from ewatercycle_NewDataAssimilationScheme import NewDataAssimilationScheme
    new_scheme_method = {'NDAS': NewDataAssimilationScheme}
    LOADED_METHODS.update(new_scheme_method)
```
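Because the dictionary holds classes rather than instances, a short name can be resolved to a class and instantiated later. The sketch below uses plain stand-in classes so that it runs on its own, and it assumes the ensemble looks the scheme up by its key and instantiates it with `N` (the `ParticleFilter` docstring states that `N` is passed down from `DA.Ensemble()`); the exact call inside ewatercycle_DA may differ.

```python
from typing import Any


class ParticleFilter:
    """Stand-in for the real class, only so that this sketch runs."""
    def __init__(self, N: int):
        self.N = N


class EnsembleKalmanFilter:
    """Stand-in for the real class."""
    def __init__(self, N: int):
        self.N = N


class NewDataAssimilationScheme:
    """Stand-in for a scheme hosted in a separate package."""
    def __init__(self, N: int):
        self.N = N


# the dictionary stores the classes themselves (pointers), not instances
LOADED_METHODS: dict[str, Any] = dict(PF=ParticleFilter, EnKF=EnsembleKalmanFilter)

# registering an externally hosted scheme only adds one more key
LOADED_METHODS.update({"NDAS": NewDataAssimilationScheme})

# assumed usage: resolve the short name, then instantiate with the ensemble size
method = LOADED_METHODS["NDAS"]
scheme = method(N=10)
print(type(scheme).__name__, scheme.N)  # prints NewDataAssimilationScheme 10
```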

As an example, the current `ParticleFilter` class looks like this:

```py
import random
from typing import Any

import numpy as np
import scipy.special
import scipy.stats
from pydantic import BaseModel

# `add_normal_noise` is a helper from ewatercycle_DA; its import is not shown in this excerpt.


class ParticleFilter(BaseModel):
    """Implementation of a particle filter scheme to be applied to the :py:class:`Ensemble`.

    Note:
        The :py:class:`ParticleFilter` is controlled by the :py:class:`Ensemble` and thus has no time reference itself.
        No DA method should need to know where in time it is (for now).
        Currently assumed 1D grid.

    Args:
        N (int): Size of ensemble, passed down from DA.Ensemble().

    Attributes:
        hyperparameters (dict): Combination of many different parameters:
                                like_sigma_weights (float): scale/sigma of logpdf when generating particle weights

                                like_sigma_state_vector (float): scale/sigma of noise added to each value in state vector

        obs (float): observation value of the current model timestep, set in due course thus optional

        state_vectors (np.ndarray): state vector per ensemble member [N x len(z)]

        predictions (np.ndarray): contains prior modeled values per ensemble member [N x 1]

        new_state_vectors (np.ndarray): updated state vector per ensemble member [N x len(z)]

        weights (np.ndarray): contains weights per ensemble member per prior modeled values [N x 1]

        resample_indices (np.ndarray): contains indices of particles that are resampled [N x 1]


    All are :obj:`None` by default


    """
    # args
    N: int

    # required attributes
    hyperparameters: dict = dict(like_sigma_weights=0.05, like_sigma_state_vector=0.0005)
    obs: float | Any | None = None # TODO: refactor to np.ndarray
    state_vectors: Any | None = None
    predictions: Any | None = None
    new_state_vectors: Any | None = None

    # extra attributes
    weights: Any | None = None
    resample_indices: Any | None = None


    def update(self):
        """Takes current state vectors of ensemble and returns updated state vectors ensemble
        """
        self.generate_weights()

        # TODO: Refactor to be more modular i.e. remove if/else

        # 1d for now: weights is N x 1
        if self.weights[0].size == 1:
            self.resample_indices = random.choices(population=np.arange(self.N), weights=self.weights, k=self.N)

            new_state_vectors = self.state_vectors.copy()[self.resample_indices]
            new_state_vectors_transpose = new_state_vectors.T # change to len(z) x N so in future you can vary sigma

            # for now just constant perturbation, can vary this hyperparameter
            like_sigma = self.hyperparameters['like_sigma_state_vector']
            if type(like_sigma) is float:
                for index, row in enumerate(new_state_vectors_transpose):
                    row_with_noise = np.array([s + add_normal_noise(like_sigma) for s in row])
                    new_state_vectors_transpose[index] = row_with_noise

            elif type(like_sigma) is list and len(like_sigma) == len(new_state_vectors_transpose):
                for index, row in enumerate(new_state_vectors_transpose):
                    row_with_noise = np.array([s + add_normal_noise(like_sigma[index]) for s in row])
                    new_state_vectors_transpose[index] = row_with_noise
            else:
                raise RuntimeWarning(f"{like_sigma} should be float or list of length {len(new_state_vectors_transpose)}")

            self.new_state_vectors = new_state_vectors_transpose.T # back to N x len(z) to be set correctly

        # 2d weights is N x len(z)
        else:
            # handle each row separately:
            self.resample_indices = []
            for i in range(len(self.weights[0])):
                self.resample_indices.append(random.choices(population=np.arange(self.N), weights=self.weights[:, i], k=self.N))
            self.resample_indices = np.vstack(self.resample_indices)

            new_state_vectors_transpose = self.state_vectors.copy().T
            for index, indices in enumerate(self.resample_indices):
                new_state_vectors_transpose[index] = new_state_vectors_transpose[index, indices]

            # for now just constant perturbation, can vary this hyperparameter
            like_sigma = self.hyperparameters['like_sigma_state_vector']
            for index, row in enumerate(new_state_vectors_transpose):
                row_with_noise = np.array([s + add_normal_noise(like_sigma) for s in row])
                new_state_vectors_transpose[index] = row_with_noise

            self.new_state_vectors = new_state_vectors_transpose.T  # back to N x len(z) to be set correctly


    def generate_weights(self):
        """Takes the ensemble and observations and returns the posterior"""

        like_sigma = self.hyperparameters['like_sigma_weights']
        difference = (self.obs - self.predictions)
        unnormalised_log_weights = scipy.stats.norm.logpdf(difference, loc=0, scale=like_sigma)
        normalised_weights = np.exp(unnormalised_log_weights - scipy.special.logsumexp(unnormalised_log_weights))

        self.weights = normalised_weights
```
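The core of `generate_weights` and the 1D branch of `update` can be reproduced on toy numbers. The sketch below re-implements the Gaussian log-likelihood and the log-sum-exp normalisation in plain NumPy (equivalent to `scipy.stats.norm.logpdf` and `scipy.special.logsumexp`), so each weight is $w_i = \exp\big(\ell_i - \log\sum_j e^{\ell_j}\big)$ with $\ell_i = \log \mathcal{N}(y - h_i;\, 0, \sigma^2)$, and then resamples with `random.choices`. The predictions and observation are made-up values, and `gaussian_logpdf` and `normalised_weights` are hypothetical helper names.

```python
import random

import numpy as np


def gaussian_logpdf(x: np.ndarray, sigma: float) -> np.ndarray:
    """log N(x; 0, sigma^2), the same values as scipy.stats.norm.logpdf(x, loc=0, scale=sigma)."""
    return -0.5 * (x / sigma) ** 2 - np.log(sigma * np.sqrt(2 * np.pi))


def normalised_weights(obs: float, predictions: np.ndarray, like_sigma: float) -> np.ndarray:
    """Particle weights that sum to one, computed in log space."""
    log_w = gaussian_logpdf(obs - predictions, like_sigma)
    # log(sum(exp(log_w))) via the max trick, as scipy.special.logsumexp does,
    # so the exponentials do not underflow when no particle matches the observation well
    log_norm = log_w.max() + np.log(np.exp(log_w - log_w.max()).sum())
    return np.exp(log_w - log_norm)


predictions = np.array([0.8, 1.0, 1.3, 2.0])  # one prediction per particle (N = 4)
weights = normalised_weights(obs=1.0, predictions=predictions, like_sigma=0.05)

# resample particle indices in proportion to their weights, as in the 1D branch of update()
random.seed(0)
resample_indices = random.choices(population=range(len(predictions)), weights=weights, k=len(predictions))
print(weights.round(4), resample_indices)
```

With `like_sigma_weights = 0.05`, the particle whose prediction equals the observation receives almost all the weight, so resampling duplicates it; in `update` these resampled members are then perturbed with noise of scale `like_sigma_state_vector` to keep the ensemble from collapsing.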
