---
title: "Integro-Difference Equation Models: Fitting (Prototype)"
format:
  html:
    code-fold: true
    toc: true
    mathjax: 
      extensions: ["breqn", "bm"]
jupyter: python3
include-in-header:
  - text: |
      <script>
      window.MathJax = {
        loader: {
          load: ['[tex]/upgreek', '[tex]/boldsymbol', '[tex]/physics', '[tex]/breqn']
        },
        tex: {
          packages: {
            '[+]': ['upgreek', 'boldsymbol', 'physics', 'breqn']
          }
        }
      };
      </script>
bibliography: Bibliography.bib
---




[Index](./index.html)

\DeclareMathOperator{\var}{\mathbb{V}\mathrm{ar}}
\DeclareMathOperator{\cov}{\mathbb{C}\mathrm{ov}}
\renewcommand*{\vec}[1]{\boldsymbol{\mathbf{#1}}}
\newcommand\eqc{\stackrel{\mathclap{c}}{=}}

## Target Spatially Invariant Kernel Model

Using ```gen_example_idem``` with the argument ```k_spat_inv=True```, we can easily generate a model and simulate from it a synthetic dataset to fit to.
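
As a rough picture of what such a synthetic dataset is, an IDE model reduced to its basis coefficients is a linear Gaussian state-space model: $\alpha_t = M\alpha_{t-1} + \eta_t$ for the process coefficients and $z_t = \Phi\alpha_t + X\beta + \epsilon_t$ for the observations, with $\alpha_0 \sim N(m_0, \sigma^2_0 I)$. The NumPy sketch below samples from that model under these assumptions; ```simulate_state_space``` and the toy ```M```, ```Phi``` and ```X``` are hypothetical stand-ins for illustration, not the internals of ```simulate```.

```python
import numpy as np


def simulate_state_space(M, Phi, X, beta, m_0, sigma2_0, sigma2_eta, sigma2_eps, T, seed=0):
    """Sample coefficients alpha_1..alpha_T and observations z_1..z_T (hypothetical sketch).

    alpha_0 ~ N(m_0, sigma2_0 I)
    alpha_t = M alpha_{t-1} + eta_t,        eta_t ~ N(0, sigma2_eta I)
    z_t     = Phi alpha_t + X beta + eps_t, eps_t ~ N(0, sigma2_eps I)
    """
    rng = np.random.default_rng(seed)
    r, n = M.shape[0], Phi.shape[0]
    alpha = m_0 + np.sqrt(sigma2_0) * rng.standard_normal(r)
    alphas, zs = np.empty((T, r)), np.empty((T, n))
    for t in range(T):
        # propagate the process coefficients, then observe them with noise
        alpha = M @ alpha + np.sqrt(sigma2_eta) * rng.standard_normal(r)
        alphas[t] = alpha
        zs[t] = Phi @ alpha + X @ beta + np.sqrt(sigma2_eps) * rng.standard_normal(n)
    return alphas, zs


# Toy usage: 4 basis functions, 6 fixed observation locations, 10 time steps
rng = np.random.default_rng(1)
M = 0.9 * np.eye(4)
Phi = rng.uniform(size=(6, 4))
X = np.column_stack([np.ones(6), rng.uniform(size=(6, 2))])  # intercept + (x, y)
alphas, zs = simulate_state_space(M, Phi, X, np.zeros(3), np.ones(4), 0.001, 0.01, 0.01, T=10)
print(alphas.shape, zs.shape)  # (10, 4) (10, 6)
```

Fitting, later on, means recovering $M$ (through the kernel), $m_0$ and the variances from ```zs``` alone.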


In [None]:
#| output: false
import sys
import os
sys.path.append(os.path.abspath('../src/jax_idem'))

import jax
import jax.numpy as jnp
import jax.random as rand
import utilities
import IDEM

from utilities import *
from IDEM import *
import warnings

key = jax.random.PRNGKey(12)
keys = rand.split(key, 3)

process_basis = place_basis(nres=2, min_knot_num=5)
nbasis = process_basis.nbasis

m_0 = jnp.zeros(nbasis).at[16].set(1)
sigma2_0 = 0.001

truemodel = gen_example_idem(
    keys[0], k_spat_inv=True,
    process_basis=process_basis,
    m_0=m_0, sigma2_0=sigma2_0
)

# Simulation
T = 10
                                            
process_data, obs_data = truemodel.simulate(nobs=50, T=T + 1, key=keys[1])


# Plotting
gif_st_grid(process_data, output_file="target_process.gif")
gif_st_pts(obs_data, output_file="synthetic_observations.gif")
plot_kernel(truemodel.kernel, output_file="target_kernel.png")

::: {#fig-example layout-ncol=3}

![Process](target_process.gif)

![Observations](synthetic_observations.gif)

![Kernel](target_kernel.png)

An example target simulation, showing the underlying process (left), the noisy observations (centre), and the direction of 'flow' dictated by the kernel (right).

:::

We now create a 'shell' model, which we will fit to the above data, initialising all relevant parameters.
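
The kernel used throughout (its form is written out inside ```con_M``` in the Fitting section) is $k(\mathbf{s}, \mathbf{r}) = \theta_1 \exp\left(-\lVert \mathbf{r} - \mathbf{s} - \boldsymbol{\theta}_3 \rVert^2 / \theta_2\right)$, with scale $\theta_1$, shape $\theta_2$ and the two offset ('drift') terms $\boldsymbol{\theta}_3$. With zero offsets, as in the shell model, $k(\mathbf{s}, \cdot)$ peaks at $\mathbf{r} = \mathbf{s}$, so there is no systematic motion; non-zero offsets move the peak to $\mathbf{s} + \boldsymbol{\theta}_3$. The NumPy sketch below checks this numerically; ```exp_kernel``` is a hypothetical helper, not ```param_exp_kernel```.

```python
import numpy as np


def exp_kernel(s, r, scale, shape, offset):
    """k(s, r) = scale * exp(-||r - s - offset||^2 / shape), vectorised over rows of r."""
    d = r - s - offset
    return scale * np.exp(-np.sum(d**2, axis=-1) / shape)


# Evaluate k(s, .) on a 101 x 101 grid over [0, 1]^2 for a fixed location s
xs = np.linspace(0.0, 1.0, 101)
grid = np.stack(np.meshgrid(xs, xs, indexing="ij"), axis=-1).reshape(-1, 2)
s = np.array([0.5, 0.5])

for offset in (np.array([0.0, 0.0]), np.array([0.1, -0.2])):
    vals = exp_kernel(s, grid, scale=150.0, shape=0.002, offset=offset)
    # the grid point where the kernel is largest sits at s + offset
    print(offset, "->", grid[np.argmax(vals)])
```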


In [None]:
#| output: false
# use the same kernel basis as the true model for now
K_basis = truemodel.kernel.basis
# scale and shape of the kernel will be the same, but the offsets will be estimated
k = (
    jnp.array([150]),
    jnp.array([0.002]),
    jnp.array([0]),
    jnp.array([0]),
)
# This is the kind of kernel used by gen_example_idem
kernel = param_exp_kernel(K_basis, k)

process_basis2 = place_basis(nres=1, min_knot_num=5) # coarser process basis with 25 total basis functions
nbasis0 = process_basis2.nbasis

model0 = IDEM_model(
        process_basis = process_basis2,
        kernel=kernel,
        process_grid = create_grid(jnp.array([[0, 1], [0, 1]]), jnp.array([41, 41])),
        sigma2_eta = truemodel.sigma2_eta,
        sigma2_eps = truemodel.sigma2_eps,
        beta = jnp.array([0, 0, 0]),
        m_0 = jnp.zeros(nbasis0),
        sigma2_0=truemodel.sigma2_0)

We can then simulate from this initial model:


In [None]:
#| output: false
unfit_process_data, unfit_obs_data = model0.simulate(nobs=1, T=T + 1, key=key)
# Plotting
gif_st_grid(unfit_process_data, output_file="unfit_process.gif")

::: {#fig-example-2 layout-ncol=1}

![Unfit Process](unfit_process.gif)

The unfit 'shell' model, which we will use as the initial model when fitting to the synthetic 'true' data. The simulated values are much smaller than in the target, because the initial process coefficients are all zero and the kernel has no drift, so nothing moves.

:::

## Filtering (and smoothing)

The first step is to apply the Kalman filter to this model, using the data in ```obs_data``` from the 'true' model.
This is easily done with the methods ```IDEM.filter``` and ```IDEM.smooth```, which each return a tuple of the relevant quantities.
The filtered and smoothed processes can then be plotted, and look reasonable. However, since ```model0``` has no movement, its log-likelihood is, unsurprisingly, lower than that of the true model.
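
For reference, a Kalman filter's log-likelihood comes from the prediction-error decomposition: at each step it predicts $z_t$ and accumulates the Gaussian log-density of the innovation $v_t = z_t - \Phi m_{t\mid t-1} - X\beta$, whose covariance is $S_t = \Phi P_{t\mid t-1}\Phi^\top + \Sigma_\epsilon$. The NumPy sketch below is a minimal version of that recursion, assuming $\alpha_t = M\alpha_{t-1} + \eta_t$, $z_t = \Phi\alpha_t + X\beta + \epsilon_t$, $\alpha_0 \sim N(m_0, P_0)$ and observations at $t = 1, \dots, T$. ```kalman_loglik``` is a hypothetical helper, not the package's ```kalman_filter```, and it keeps the $-\tfrac{n}{2}\log 2\pi$ constant.

```python
import numpy as np


def kalman_loglik(m_0, P_0, M, Phi, Sigma_eta, Sigma_eps, beta, zs, X):
    """Filter z_1..z_T; return (log-likelihood, filtered means, filtered covariances).

    Assumed model (a sketch, not the package's code):
        alpha_0 ~ N(m_0, P_0),  alpha_t = M alpha_{t-1} + eta_t,  eta_t ~ N(0, Sigma_eta)
        z_t = Phi alpha_t + X beta + eps_t,  eps_t ~ N(0, Sigma_eps)
    """
    m, P = m_0, P_0
    ll = 0.0
    ms, Ps = [], []
    for z in zs:
        # Predict the state one step ahead
        m_pred = M @ m
        P_pred = M @ P @ M.T + Sigma_eta
        # Innovation v_t and its covariance S_t
        v = z - Phi @ m_pred - X @ beta
        S = Phi @ P_pred @ Phi.T + Sigma_eps
        # Gaussian log-density of v_t, using a Cholesky factor S = L L^T
        L = np.linalg.cholesky(S)
        w = np.linalg.solve(L, v)  # w @ w equals v^T S^{-1} v
        ll += -0.5 * (w @ w) - np.sum(np.log(np.diag(L))) - 0.5 * len(z) * np.log(2 * np.pi)
        # Update with the Kalman gain K_t = P_pred Phi^T S_t^{-1}
        K = np.linalg.solve(S, Phi @ P_pred).T
        m = m_pred + K @ v
        P = P_pred - K @ S @ K.T
        ms.append(m)
        Ps.append(P)
    return ll, np.array(ms), np.array(Ps)
```

Because this sum equals the joint Gaussian log-density of all the observations, comparing the filter's log-likelihood for ```model0``` and ```truemodel``` compares how well each model explains the same data.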


In [None]:
# Currently, the Kalman filter requires the data to be in wide format.
obs_data_wide = ST_towide(obs_data)
     
# although it is irrelevant for this particular model, the filter still needs a covariate matrix
obs_locs = jnp.column_stack((obs_data_wide.x, obs_data_wide.y))
nobs = obs_locs.shape[0]
X_obs = jnp.column_stack([jnp.ones(nobs), obs_locs])

     
ll, ms, Ps, mpreds, Ppreds, Ks = model0.filter(obs_data_wide, X_obs)

# Convert the filtered means into an ST_Data_long in the Y space
filt_data = basis_params_to_st_data(ms, model0.process_basis, model0.process_grid)


# The model can be smoothed in the same way
m_tTs, P_tTs, Js = model0.smooth(ms, Ps, mpreds, Ppreds)
smooth_data = basis_params_to_st_data(
    m_tTs, model0.process_basis, model0.process_grid
)
# plot the filtered and smoothed data
gif_st_grid(filt_data, output_file="filtered.gif")
gif_st_grid(smooth_data, output_file="smoothed.gif")

true_ll, _, _, _, _, _ = truemodel.filter(obs_data_wide, X_obs)

print(f"The log likelihood (up to a constant) of the unfit model is {ll}")
print(f"The log likelihood (up to a constant) of the true model is {true_ll}")

::: {#fig-example-3 layout-ncol=2}

![Filtered process means](filtered.gif)

![Smoothed process means](smoothed.gif)

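Filtered (left) and smoothed (right) process means under the unfit model ```model0```, given the synthetic observations from the true model.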

:::


## Fitting

We will now fit all the parameters that differ between the true model and ```model0```; these are the kernel 'drift' parameters, ```model0.kernel.params[2:4]```, and the initial process basis coefficients, ```model0.m_0```.
We can do this by creating an objective function which takes these parameters and returns the negative log-likelihood from the Kalman filter.
Since most functions in the project are written to be JIT-compiled and automatically differentiated with JAX, we can also get the gradient of this objective with ```jax.grad```.
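
To illustrate the mechanism on something much smaller than the full filter, the sketch below differentiates a toy objective that takes a dictionary shaped like ```param0``` (an ```m_0``` vector plus two drift scalars). The objective, a squared error between a drifted Gaussian bump and a target bump, is a hypothetical stand-in for the negative log-likelihood; the point is that ```jax.grad``` returns a dictionary with the same keys and shapes as its input.

```python
import jax
import jax.numpy as jnp

# A 21 x 21 grid on [0, 1]^2, flattened to shape (441, 2)
xs = jnp.linspace(0.0, 1.0, 21)
grid = jnp.stack(jnp.meshgrid(xs, xs, indexing="ij"), axis=-1).reshape(-1, 2)


def bump(centre, amplitude):
    """Isotropic Gaussian bump evaluated at every grid point."""
    return amplitude * jnp.exp(-jnp.sum((grid - centre) ** 2, axis=-1) / 0.01)


target = bump(jnp.array([0.6, 0.4]), 1.0)


@jax.jit
def toy_objective(params):
    # the drift terms move the bump away from (0.5, 0.5); m_0 only sets its height
    centre = jnp.array([0.5, 0.5]) + jnp.stack([params["k1"], params["k2"]])
    amplitude = jnp.sum(params["m_0"])
    return jnp.sum((bump(centre, amplitude) - target) ** 2)


toy_params = {"m_0": jnp.full(3, 1.0 / 3.0), "k1": jnp.array(0.0), "k2": jnp.array(0.0)}
toy_grad = jax.grad(toy_objective)(toy_params)
# the target lies at larger k1 and smaller k2, so dk1 < 0 and dk2 > 0
print(toy_grad["k1"], toy_grad["k2"])
```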


In [None]:
#| eval: true
#| echo: true
#| output: true

nobs = obs_locs.shape[0]
PHI_obs = model0.process_basis.mfun(obs_locs)
PHI = model0.process_basis.mfun(model0.process_grid.coords)
GRAM = (PHI.T @ PHI) * model0.process_grid.area

# Function to construct the M matrix from the kernel parameters; this will be built in to IDEM in the future
@jax.jit
def con_M(k):
    @jax.jit
    def kernel(s, r):
        theta = (
            k[0] @ model0.kernel.basis[0].vfun(s),
            k[1] @ model0.kernel.basis[1].vfun(s),
            jnp.array(
                [
                    k[2] @ model0.kernel.basis[2].vfun(s),
                    k[3] @ model0.kernel.basis[3].vfun(s),
                ]
            ),
        )
        return theta[0] * jnp.exp(-(jnp.sum((r - s - theta[2]) ** 2)) / theta[1])

    K = outer_op(model0.process_grid.coords, model0.process_grid.coords, kernel)
    return solve(GRAM, PHI.T @ K @ PHI) * model0.process_grid.area**2

@jax.jit
def objective(params):
    m_0 = params["m_0"]

    # The first two kernel parameters (scale and shape) are held fixed, as they are very hard to fit
    ks = (jnp.array([150]), jnp.array([0.002]), jnp.array([params["k1"]]), jnp.array([params["k2"]]))
     
    M = model0.con_M(ks)
     
    Sigma_eta = model0.sigma2_eta * jnp.eye(nbasis0)
    Sigma_eps = model0.sigma2_eps * jnp.eye(nobs)
    P_0 = model0.sigma2_0 * jnp.eye(nbasis0)
    
    carry, seq = kalman_filter(
        m_0,
        P_0,
        M,
        PHI_obs,
        Sigma_eta,
        Sigma_eps,
        model0.beta,
        obs_data_wide.z,
        X_obs,
    )
    return -carry[4]

#param0 = jnp.concatenate(
#    [model0.m_0, jnp.array([model0.kernel.params[2][0], model0.kernel.params[3][0]])]
#)

param0 = {"m_0": model0.m_0,
          "k1": jnp.array(0.0),
          "k2": jnp.array(0.0)}

obj_grad = jax.grad(objective)
     
print("The initial value of the negative log-likelihood is", objective(param0))
print("with gradient", obj_grad(param0))

We can then minimise this objective with standard gradient-based optimisers;
for example, with the Adam optimiser in Optax:


In [None]:
#| eval: false
#| echo: true
#| output: false

import optax
     
start_learning_rate = 1e-1
optimizer = optax.adam(start_learning_rate)

param_ad = param0
opt_state = optimizer.init(param_ad)

# A simple update loop.
for i in range(10):
    grad = obj_grad(param_ad)
    updates, opt_state = optimizer.update(grad, opt_state)
    param_ad = optax.apply_updates(param_ad, updates)
    nll = objective(param_ad)

print(param_ad)

Putting these parameters into a new model:


In [None]:
#| eval: false
#| echo: true
#| output: false

fitted_m_0 = param_ad["m_0"]
fitted_ks = (jnp.array([150]), jnp.array([0.002]), jnp.array([param_ad["k1"]]), jnp.array([param_ad["k2"]]))
fitted_kernel = param_exp_kernel(K_basis, fitted_ks)

fitted_model = IDEM_model(
        process_basis = process_basis2,
        kernel = fitted_kernel,
        process_grid = create_grid(jnp.array([[0, 1], [0, 1]]), jnp.array([41, 41])),
        sigma2_eta = truemodel.sigma2_eta,
        sigma2_eps = truemodel.sigma2_eps,
        beta = jnp.array([0, 0, 0]),
        m_0 = fitted_m_0,
        sigma2_0=truemodel.sigma2_0)

fit_process_data, fit_obs_data = fitted_model.simulate(nobs=50, T=T + 1, key=key)
gif_st_grid(fit_process_data, output_file="fitted_process.gif")
plot_kernel(fitted_model.kernel, output_file="fitted_kernel.png")

::: {#fig-example-fitted layout-ncol=2}

![Fitted process simulation](fitted_process.gif)

![Fitted kernel](fitted_kernel.png)

A simulation from the fitted model (left), and the fitted kernel (right).

:::
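
For reference, the update that ```optax.adam``` applies in the loop above keeps exponential moving averages of the gradient and of its square, corrects both for being initialised at zero, and scales each step by their ratio. The NumPy sketch below writes this out by hand on a toy quadratic; it is an illustrative re-implementation with the usual defaults ($b_1 = 0.9$, $b_2 = 0.999$, $\epsilon = 10^{-8}$), not Optax's own code, and ```adam_minimise``` is a hypothetical name.

```python
import numpy as np


def adam_minimise(grad_fn, x0, lr=0.1, steps=500, b1=0.9, b2=0.999, eps=1e-8):
    """Minimise with a hand-written Adam update (illustrative, not Optax's code)."""
    x = np.asarray(x0, dtype=float).copy()
    mu = np.zeros_like(x)  # moving average of gradients
    nu = np.zeros_like(x)  # moving average of squared gradients
    for t in range(1, steps + 1):
        g = grad_fn(x)
        mu = b1 * mu + (1 - b1) * g
        nu = b2 * nu + (1 - b2) * g**2
        mu_hat = mu / (1 - b1**t)  # correct the bias from starting at zero
        nu_hat = nu / (1 - b2**t)
        x = x - lr * mu_hat / (np.sqrt(nu_hat) + eps)
    return x


# Toy objective f(x) = ||x - c||^2, whose gradient is 2 (x - c)
c = np.array([0.3, -0.2])
x_opt = adam_minimise(lambda x: 2 * (x - c), x0=np.zeros(2))
print(x_opt)  # close to c
```

On the first step ```mu_hat / sqrt(nu_hat)``` is $\pm 1$ in every coordinate, so each parameter moves by roughly ```lr``` regardless of its gradient's size. That matters when parameters live on very different scales: a step of 0.1 is enormous for a kernel shape of 0.002.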

## Optimising over all parameters, with constraints

Above, we only fitted the initial process coefficients and the two offset terms of the kernel, holding all the variances, and the scale and shape of the kernel, fixed.
Ideally, we want to be able to fit those too.
Of course, these parameters are constrained: each one is non-negative, and it may be worth bounding them above as well, to stop any of them growing without limit during optimisation.
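
Box constraints like these are commonly handled by projected gradient descent: take an ordinary gradient step, then map the result to the nearest feasible point. For a box, that Euclidean projection is just an element-wise clip. The NumPy sketch below shows the idea on a toy quadratic whose unconstrained minimiser has a negative 'variance'; the helper names are hypothetical, and the lower bound is a small positive number rather than exactly zero, since a variance of exactly zero can still leave a covariance matrix singular.

```python
import numpy as np


def project_box(x, lower, upper):
    """Euclidean projection onto {x : lower <= x <= upper}, i.e. an element-wise clip."""
    return np.clip(x, lower, upper)


def projected_gradient_descent(grad_fn, x0, lower, upper, lr=0.1, steps=100):
    x = project_box(np.asarray(x0, dtype=float), lower, upper)
    for _ in range(steps):
        # ordinary gradient step, then snap back into the feasible box
        x = project_box(x - lr * grad_fn(x), lower, upper)
    return x


# Toy problem: the unconstrained minimiser c has a negative first entry
c = np.array([-0.05, 0.3])
lower = np.array([1e-6, 1e-6])
upper = np.array([0.1, 0.1])
x_star = projected_gradient_descent(lambda x: 2 * (x - c), np.array([0.01, 0.01]), lower, upper)
print(x_star)  # first entry on its lower bound, second on its upper bound
```

Because this toy objective's Hessian is a multiple of the identity, its constrained minimiser is exactly ```np.clip(c, lower, upper)```; for general objectives the projected iterates still stay feasible, but the answer need not be a clip of the unconstrained optimum.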

Optax handles constraints of this kind through projections (```optax.projections```), which work in the same way with any of its optimisers. First, let's rewrite some of the above code to include the other parameters as well.


In [None]:
#| eval: false
#| echo: true
#| output: false

k = (
    jnp.array([100.0]),
    jnp.array([0.001]),
    jnp.array([0.0]),
    jnp.array([0.0]),
)
# This is the kind of kernel used by gen_example_idem
kernel = param_exp_kernel(K_basis, k)

model1 = IDEM_model(
        process_basis = process_basis2,
        kernel=kernel,
        process_grid = create_grid(jnp.array([[0, 1], [0, 1]]), jnp.array([41, 41])),
        sigma2_eta = 0.01,
        sigma2_eps = 0.01,
        beta = jnp.array([0.0, 0.0, 0.0]),
        m_0 = jnp.zeros(nbasis0),
        sigma2_0=0.01)
# a model with inaccurate 'guesses'
v_unfit_process_data, v_unfit_obs_data = model1.simulate(nobs=1, T=T + 1, key=key)
# Plotting
gif_st_grid(v_unfit_process_data, output_file="very_unfit_process.gif")

ll, _, _, _, _, _ = model1.filter(obs_data_wide, X_obs)
print(ll)

::: {#fig-example-4 layout-ncol=1}

![Very unfit process](very_unfit_process.gif)

A simulation from ```model1```, the initial model with inaccurate guesses for all the parameters.

:::

Presumably due to the higher variances, the log-likelihood here is higher than that of ```model0```, but still well short of that of the true model.
Next, we write the objective function over all of these parameters:


In [None]:
#| eval: false
#| echo: true
#| output: false

@jax.jit
def objective(params):
    m_0, sigma2_0, sigma2_eta, sigma2_eps, ks = params

    #sigma2_0, sigma2_eta, sigma2_eps = truemodel.sigma2_0, truemodel.sigma2_eta, truemodel.sigma2_eps
    #ks = (ks[0], ks[1], ks[2], ks[3])
    M = con_M(ks)
     
    Sigma_eta = sigma2_eta * jnp.eye(nbasis0)
    Sigma_eps = sigma2_eps * jnp.eye(nobs)
    P_0 = sigma2_0 * jnp.eye(nbasis0)
    
    carry, seq = kalman_filter(
        m_0,
        P_0,
        M,
        PHI_obs,
        Sigma_eta,
        Sigma_eps,
        model0.beta,
        obs_data_wide.z,
        X_obs,
    )
    return -carry[4]

obj_grad = jax.grad(objective)

params0 = (model1.m_0,
           model1.sigma2_0,
           model1.sigma2_eta,
           model1.sigma2_eps,
           model1.kernel.params)

If we naively apply the same optimisation loop as before, we get NaNs: some variances become negative, so the Cholesky decompositions in the filter fail.
We can remedy this by defining a bounding box for the parameters, and projecting back onto it after every update:


In [None]:
#| eval: false
#| echo: true
#| output: false

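# Lower bounds are small and positive rather than exactly zero: a zero observation variance
# or kernel shape can still make a covariance matrix singular or divide by zero in the kernel
tiny = 1e-6
lower = (jnp.full(nbasis0, -jnp.inf),
         jnp.array(tiny),
         jnp.array(tiny),
         jnp.array(tiny),
         (jnp.array(0.0), jnp.array(tiny), jnp.array(-jnp.inf), jnp.array(-jnp.inf)))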

upper = (jnp.full(nbasis0, jnp.inf),
         jnp.array(0.1),
         jnp.array(0.1), 
         jnp.array(0.1),
         (jnp.array(500.0), jnp.array(0.01), jnp.array(jnp.inf), jnp.array(jnp.inf)))

# with this many parameters, a lower starting learning rate is needed
start_learning_rate = 1e-3
optimizer = optax.adam(start_learning_rate)

params_ad = params0
opt_state = optimizer.init(params_ad)

# A simple update loop.
for i in range(10):
    grad = obj_grad(params_ad)
    updates, opt_state = optimizer.update(grad, opt_state)
    params_ad = optax.apply_updates(params_ad, updates)
    params_ad = optax.projections.projection_box(params_ad, lower, upper)
    nll = objective(params_ad)

print(params_ad)

This can take some time, so the code above is not run when this page is rendered. Instead, some scripts are included to run the built-in methods, which contain the logic above.

There is now a built-in method for fitting this kind of model, ```IDEM.fit```. As an example, we read some data (generated with the R-IDE package) and format it with ```pandas```:


In [None]:
import pandas as pd

df = pd.read_csv("z_obs.csv")

# map each unique date (in order of first appearance) to an integer time index 0, 1, 2, ...
date_mapping = {date: i for i, date in enumerate(df['time'].unique())}
df['time'] = df['time'].map(date_mapping)

::: {#fig-example-5 layout-ncol=2}

![Fitted process simulation](new_fitted_process.gif)

![Fitted kernel](new_fitted_kernel.png)

A simulation from the model fitted with ```IDEM.fit``` (left), and the fitted kernel (right).
:::


## Information Filter

The Kalman filter can be written in a different, and possibly more useful, form called the information filter.
Instead of means and covariance matrices, it propagates the information matrices $Q_{k\mid l} = P_{k\mid l}^{-1}$ and the information vectors $\nu_{k\mid l} = Q_{k\mid l}m_{k\mid l}$.
See the [subheading on this in the Mathematics page](./mathematics.html) for more detail.

In exact arithmetic, the two filters are equivalent.
However, the information form makes it easy to start with infinite variance ($Q_0 = 0$); this is also possible with the Kalman filter, but requires replacing the first step with analytically obtained results.
The information filter also handles time-varying data better (a changing number of observations, as well as observation locations that move over time), since its scan runs over vectors in the state space, whose dimension is always constant, rather than in the changing data space.
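
In information form the measurement update becomes additive: $Q_{t\mid t} = Q_{t\mid t-1} + \Phi_t^\top\Sigma_\epsilon^{-1}\Phi_t$ and $\nu_{t\mid t} = \nu_{t\mid t-1} + \Phi_t^\top\Sigma_\epsilon^{-1}(z_t - X_t\beta)$. Both added terms are $r \times r$ and length $r$ whatever the number of observations $n_t$, and starting from $Q = 0$ poses no problem. The NumPy sketch below uses a hypothetical helper (not the package's ```information_filter```) to check this update against the covariance-form Kalman update, and to show the $Q_0 = 0$ case.

```python
import numpy as np


def information_update(nu_pred, Q_pred, Phi, Sigma_eps_inv, resid):
    """Information-form measurement update; resid = z_t - X_t beta (length n_t)."""
    nu_post = nu_pred + Phi.T @ Sigma_eps_inv @ resid
    Q_post = Q_pred + Phi.T @ Sigma_eps_inv @ Phi
    return nu_post, Q_post


rng = np.random.default_rng(0)
r, n = 3, 8
Phi = rng.uniform(size=(n, r))
Sigma_eps = 0.1 * np.eye(n)
resid = rng.standard_normal(n)

# Covariance-form (Kalman) update from a proper prior N(m_pred, P_pred)
A = rng.standard_normal((r, r))
P_pred = A @ A.T + np.eye(r)
m_pred = rng.standard_normal(r)
S = Phi @ P_pred @ Phi.T + Sigma_eps
K = np.linalg.solve(S, Phi @ P_pred).T  # P_pred Phi^T S^{-1}
m_post = m_pred + K @ (resid - Phi @ m_pred)
P_post = P_pred - K @ S @ K.T

# The same update in information form
Q_pred = np.linalg.inv(P_pred)
nu_post, Q_post = information_update(Q_pred @ m_pred, Q_pred, Phi, np.linalg.inv(Sigma_eps), resid)
print(np.allclose(np.linalg.solve(Q_post, nu_post), m_post))  # True
print(np.allclose(np.linalg.inv(Q_post), P_post))  # True

# Infinite prior variance: Q_0 = 0, nu_0 = 0 gives the least-squares estimate
nu_1, Q_1 = information_update(np.zeros(r), np.zeros((r, r)), Phi, np.linalg.inv(Sigma_eps), resid)
print(np.allclose(np.linalg.solve(Q_1, nu_1), np.linalg.lstsq(Phi, resid, rcond=None)[0]))  # True
```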

Let's start by generating some data with much more random observation locations.


In [None]:
key = jax.random.PRNGKey(1)
keys = rand.split(key, 3)

# We'll re-use truemodel from before, but simulate at different locations

# keys[1] is used for the simulation below, so draw the observation counts with keys[0]
nobs = jax.random.randint(keys[0], (T,), 50, 101)

locs_keys = jax.random.split(keys[2], T)

obs_locs = jnp.vstack(
            [
                jnp.column_stack(
                    [
                        jnp.repeat(t + 1, n),
                        rand.uniform(
                            locs_keys[t],
                            shape=(n, 2),
                            minval=0,
                            maxval=1,
                        ),
                    ]
                )
                for t, n in enumerate(nobs)
            ]
        )

# Simulation, but this time providing a custom value for obs_locs
process_data, obs_data = truemodel.simulate(
            T=T, key=keys[1], obs_locs=obs_locs
        )

# Plotting
gif_st_grid(process_data, output_file="target_process_2.gif")
gif_st_pts(obs_data, output_file="synthetic_observations_2.gif")

::: {#fig-example-6 layout-ncol=2}

![Process](target_process_2.gif)

![Observations](synthetic_observations_2.gif)

An example target simulation, with the underlying process (left) and noisy observations (right), this time with the number and positions of the observations varying randomly over time.

:::
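
Observations like these, with a different number of points $n_t$ at each time step, do not fit into one fixed-shape array, which is what makes scanning over the data space awkward. In the information form, though, each time step only contributes an $r \times r$ matrix $\Phi_t^\top\Sigma_\epsilon^{-1}\Phi_t$ and a length-$r$ vector $\Phi_t^\top\Sigma_\epsilon^{-1}(z_t - X_t\beta)$, so ragged data can be reduced to stacked, fixed-shape arrays. The NumPy sketch below assumes $\Sigma_\epsilon = \sigma^2_\epsilon I$ and already-detrended data; ```information_contributions``` is a hypothetical helper, and this is not a description of how ```information_filter``` organises its inputs.

```python
import numpy as np


def information_contributions(Phis, resids, sigma2_eps):
    """Reduce ragged per-time observations to fixed-shape information terms.

    Phis:   list of (n_t, r) basis matrices evaluated at the time-t locations
    resids: list of (n_t,) detrended observations z_t - X_t beta
    Returns info_vecs with shape (T, r) and info_mats with shape (T, r, r).
    """
    info_vecs = np.stack([Phi.T @ z / sigma2_eps for Phi, z in zip(Phis, resids)])
    info_mats = np.stack([Phi.T @ Phi / sigma2_eps for Phi in Phis])
    return info_vecs, info_mats


rng = np.random.default_rng(0)
r = 4
n_t = [50, 73, 61]  # a different number of observations at each time step
Phis = [rng.uniform(size=(n, r)) for n in n_t]
resids = [rng.standard_normal(n) for n in n_t]
info_vecs, info_mats = information_contributions(Phis, resids, sigma2_eps=0.01)
print(info_vecs.shape, info_mats.shape)  # (3, 4) (3, 4, 4)
```

Arrays of this shape can be passed to a scan such as ```jax.lax.scan```, whose carry (the state-space $\nu$ and $Q$) keeps the same shape at every step.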

The function ```kalman_filter``` has no support for these kinds of observations, but ```information_filter``` does:


In [None]:
nbasis = truemodel.process_basis.nbasis
nu_0 = jnp.zeros(nbasis)
Q_0 = jnp.zeros((nbasis, nbasis)) # infinite variance!

obs_locs = jnp.column_stack((obs_data.x, obs_data.y))

X_obs = jnp.column_stack([jnp.ones(obs_locs.shape[0]), obs_locs[:, -2:]])

nus, Qs = truemodel.filter_information(
           obs_data,
           X_obs,
           nu_0,
           Q_0,
)

Of course, the $\nu_{t}$ values output by the information filter need a little more work before they can be visualised like the means from the Kalman filter. Since $\nu_t = Q_t m_t$, we can use ```jnp.linalg.solve``` to extract the means:


In [None]:
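# solve Q_t m_t = nu_t for each t; the trailing axis makes the batching explicit
ms = jnp.linalg.solve(Qs, nus[..., None])[..., 0]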
print(ms.shape)
print(Qs)
print(ms[0])

filt_data = basis_params_to_st_data(ms, truemodel.process_basis, truemodel.process_grid)

gif_st_grid(filt_data, output_file="filtered_2.gif")

::: {#fig-example-80orsomething layout-ncol=1}

![Filtered process means](filtered_2.gif)

Filtered process means recovered from the information filter, using the randomly varying observations.

:::
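
A note on the batched solve used to recover the means: NumPy 2.0 changed ```numpy.linalg.solve``` so that a right-hand side with more than one dimension is treated as a stack of matrices rather than a stack of vectors, and, as far as we are aware, recent JAX versions have deprecated the older batched-vector interpretation too. Giving ```nus``` an explicit trailing axis makes the intent unambiguous under either convention. The NumPy sketch below demonstrates this with made-up ```Qs``` and ```nus```; the same indexing works with ```jnp.linalg.solve```.

```python
import numpy as np

rng = np.random.default_rng(0)
T, r = 5, 3
A = rng.standard_normal((T, r, r))
Qs = A @ np.transpose(A, (0, 2, 1)) + r * np.eye(r)  # stack of SPD information matrices
ms_true = rng.standard_normal((T, r))
nus = np.einsum("tij,tj->ti", Qs, ms_true)  # nu_t = Q_t m_t for each t

# Solve Q_t m_t = nu_t for every t: add a trailing axis, then drop it again
ms = np.linalg.solve(Qs, nus[..., None])[..., 0]
print(ms.shape)  # (5, 3)
```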