In [1]:

try:
    import smolgp
except ImportError:
    %pip install -q smolgp

import jax
import jax.numpy as jnp
jax.config.update("jax_enable_x64", True)

(derivative)=

# Derivative Observations

In `tinygp`, one can define a custom covariance function with the appropriate derivatives so that the resulting GP models the derivative of the process. Similarly, one can construct a kernel from a linear combination of a kernel and its derivative(s). See [this `tinygp` tutorial](https://tinygp.readthedocs.io/en/stable/tutorials/derivative.html) for more details.
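As a standalone illustration of the idea behind the `tinygp` approach (this is plain `jax.grad`, not `tinygp`'s kernel API), the sketch below builds the joint covariance of a process `f` and its derivative `f'` by differentiating a squared-exponential kernel. The kernel choice, its unit length scale, and the helper names `k_se` and `pairwise` are assumptions made for illustration.

```python
import jax
import jax.numpy as jnp

jax.config.update("jax_enable_x64", True)


def k_se(x1, x2, ell=1.0):
    """Squared-exponential covariance between two scalar inputs."""
    return jnp.exp(-0.5 * (x1 - x2) ** 2 / ell**2)


# Covariances involving the derivative follow from differentiating k:
#   Cov(f(x1),  f(x2))  = k(x1, x2)
#   Cov(f(x1),  f'(x2)) = dk/dx2
#   Cov(f'(x1), f'(x2)) = d^2 k / (dx1 dx2)
k_f_df = jax.grad(k_se, argnums=1)
k_df_df = jax.grad(jax.grad(k_se, argnums=0), argnums=1)


def pairwise(k, xa, xb):
    """Evaluate a scalar covariance function on all pairs (xa[i], xb[j])."""
    return jax.vmap(lambda a: jax.vmap(lambda b: k(a, b))(xb))(xa)


x = jnp.linspace(0.0, 5.0, 6)

# Joint covariance of [f(x), f'(x)]; the lower-left block is Cov(f', f),
# which is the transpose of Cov(f, f').
K = jnp.block([
    [pairwise(k_se, x, x), pairwise(k_f_df, x, x)],
    [pairwise(k_f_df, x, x).T, pairwise(k_df_df, x, x)],
])
print(K.shape)
```

Because differentiation is linear, the covariance of a combination such as `f + f'` is the sum of `k`, both first-derivative terms, and the mixed second derivative.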

In `smolgp`, observing the derivative of a process is conceptually simpler. Recall that the observation matrix for, e.g., the SHO is `H = [1, 0]`, which picks out the latent process in the first position of the state vector. To observe the derivative instead, we simply use `H = [0, 1]`, and likewise for higher-order kernels (e.g., the Matérn-5/2 is defined by a third-order SDE, so one can go up to the second derivative with `H = [0, 0, 1]`).
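To see why changing `H` is enough, the sketch below writes out the state-space models directly in JAX, without `smolgp`. It assumes the standard companion-form SDE for the SHO (state `[f, f']`, stationary covariance `diag(sigma**2, omega**2 * sigma**2)`) and for the Matérn-5/2 (state `[f, f', f'']`), and checks that `H @ expm(F * dt) @ P_inf @ H.T` reproduces the SHO kernel for `H = [1, 0]`, its negated second derivative `-k''` for `H = [0, 1]`, and the fourth derivative `k''''` of the Matérn-5/2 for `H = [0, 0, 1]`. The matrices, the length scale `ell`, and the helper names are illustrative assumptions, not necessarily how `smolgp` parameterizes these kernels internally.

```python
import jax
import jax.numpy as jnp
from jax.scipy.linalg import expm

jax.config.update("jax_enable_x64", True)


def state_space_cov(F, P_inf, H, dt):
    """Cov(H x(t + dt), H x(t)) for a stationary linear SDE, with dt >= 0."""
    return (H @ expm(F * dt) @ P_inf @ H.T)[0, 0]


# SHO: state x = [f, f'] (assumed companion form)
omega, quality, sigma = 2 * jnp.pi / 50, 5.0, 1.0
F_sho = jnp.array([[0.0, 1.0], [-omega**2, -omega / quality]])
P_sho = jnp.diag(jnp.array([sigma**2, omega**2 * sigma**2]))


def k_sho(tau):
    """Closed-form SHO covariance, valid for quality > 1/2 and tau >= 0."""
    eta = jnp.sqrt(1 - 1 / (4 * quality**2))
    return sigma**2 * jnp.exp(-omega * tau / (2 * quality)) * (
        jnp.cos(eta * omega * tau) + jnp.sin(eta * omega * tau) / (2 * eta * quality)
    )


# Matérn-5/2: state x = [f, f', f''] (assumed companion form)
ell = 10.0
lam = jnp.sqrt(5.0) / ell
kappa = lam**2 * sigma**2 / 3
F_m52 = jnp.array([[0.0, 1.0, 0.0], [0.0, 0.0, 1.0], [-lam**3, -3 * lam**2, -3 * lam]])
P_m52 = jnp.array([[sigma**2, 0.0, -kappa], [0.0, kappa, 0.0], [-kappa, 0.0, lam**4 * sigma**2]])


def k_m52(tau):
    """Closed-form Matérn-5/2 covariance for tau >= 0."""
    return sigma**2 * (1 + lam * tau + (lam * tau) ** 2 / 3) * jnp.exp(-lam * tau)


def second_derivative(k):
    return jax.grad(jax.grad(k))


def fourth_derivative(k):
    return second_derivative(second_derivative(k))


for dt in [1.0, 10.0, 25.0]:
    # H = [1, 0] selects f, whose covariance is the SHO kernel itself
    f_sho = state_space_cov(F_sho, P_sho, jnp.array([[1.0, 0.0]]), dt)
    # H = [0, 1] selects f', whose covariance is -k''(dt)
    df_sho = state_space_cov(F_sho, P_sho, jnp.array([[0.0, 1.0]]), dt)
    # H = [0, 0, 1] selects f'' of the Matérn-5/2, whose covariance is k''''(dt)
    ddf_m52 = state_space_cov(F_m52, P_m52, jnp.array([[0.0, 0.0, 1.0]]), dt)
    print(dt, f_sho - k_sho(dt), df_sho + second_derivative(k_sho)(dt),
          ddf_m52 - fourth_derivative(k_m52)(dt))
```

The printed differences should be at the level of floating-point round-off: the transition matrix and stationary covariance are shared, and only `H` decides which component of the state is observed.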

To do this, define a `Wrapper` kernel that mirrors the kernel of interest but overrides its `observation_matrix` method. For example, for the SHO:

In [2]:
import equinox as eqx
from tinygp.helpers import JAXArray
from smolgp.kernels import Wrapper

class SHODerivative(Wrapper):
    """A GP for the first derivative of an SHO process"""

    omega: JAXArray | float
    quality: JAXArray | float
    sigma: JAXArray | float = eqx.field(default_factory=lambda: jnp.ones(()))

    def __init__(
        self,
        omega: JAXArray | float,
        quality: JAXArray | float,
        sigma: JAXArray | float = 1.0,
        name: str = "SHODerivative",
    ):
        self.omega = omega
        self.quality = quality
        self.sigma = sigma
        self.name = name
        # The Wrapper mirrors this SHO; only observation_matrix is overridden below
        self.kernel = smolgp.kernels.SHO(omega=omega, quality=quality, sigma=sigma)

    def observation_matrix(self, X: JAXArray) -> JAXArray:
        """The observation model H for the derivative of an SHO process"""
        del X  # H does not depend on the input coordinates
        # The SHO state is [f, f']; H = [0, 1] selects f' instead of f
        return jnp.array([[0, 1]])

In [3]:
# omega = 2 * pi / period, here with a period of 50
kernel = SHODerivative(omega=2 * jnp.pi / 50, quality=5.0, sigma=1.0)