[![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/tee-lab/PyDaddy/blob/colab/notebooks/7_example_cell_migration.ipynb)

# Example analysis: SDEs for cancer cell migration

(This notebook assumes that you have gone through the [Getting Started](./1_getting_started.ipynb) and [Getting Started with Vector Data](./1_getting_started_vector.ipynb) notebooks.)

This notebook illustrates the use of PyDaddy to discover SDEs for the migration of a confined cancer cell. The notebook uses a dataset by [Brückner et al.](https://doi.org/10.1038/s41567-019-0445-4), which is also provided with PyDaddy as an example dataset. The dataset consists of the position and velocity of a confined cancer cell, moving back and forth on a bridge-like micropattern. Brückner et al. observed that the movement of the cell can be explained as a relaxation oscillation, with stochasticity playing a minor role.

## Initialization

In [None]:
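# Only needed when running against a local checkout of PyDaddy;
# skip these two lines if PyDaddy is installed as a package.
import sys
sys.path.append('/Users/nabeel/Documents/Research/Code/pyddsde')
import pydaddy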

import numpy as np
import matplotlib.pyplot as plt

In [None]:
data, t = pydaddy.load_sample_dataset('cell-data-cellhopping')

In [None]:
ddsde = pydaddy.Characterize(data=data, t=t, bins=21)

Note that some of the plots, namely the $|\mathbf{x}|$ histogram and autocorrelation, are not meaningful in this context: the individual components here are the position $x$ and velocity $v$, so $|\mathbf{x}| = \sqrt{x^2 + v^2}$ is not a meaningful quantity.

Visualize the drift and diffusion coefficients to guess appropriate polynomial orders for fitting.

In [None]:
ddsde.drift()

In [None]:
ddsde.diffusion(limits=[0, 0.1])

## Fitting

Here, the key goal is to model the dynamics of $v$, i.e. to discover $f_2$ and $G_{22}$. The dynamics of $x$ are given simply by $\dot x = v$, and we assume that there are no cross-diffusion terms.
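
Written out, the model being fitted has the form below. This is a sketch of the setup rather than a formula taken from the dataset paper: the symbol $\eta(t)$ for unit-intensity Gaussian white noise is notation assumed here, and the Itô interpretation is assumed because it matches the Euler-type update used in the simulation later in this notebook.

$$
\dot{x} = v, \qquad \dot{v} = f_2(x, v) + \sqrt{G_{22}(x, v)}\,\eta(t)
$$

With this form, $f_1 = v$ and $G_{11} = G_{12} = 0$, so only $f_2$ and $G_{22}$ carry information that has to be learned from the data.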

Based on the visualizations, we choose a cubic polynomial for the drift and a quartic polynomial for the diffusion. The model diagnostics (see below) will verify that this choice is sufficient to capture the essential aspects of the dynamics.
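
The `threshold` argument passed to `fit` controls how aggressively small polynomial coefficients are dropped. The sketch below illustrates the general idea with sequentially thresholded least squares on synthetic data. The function `stlsq`, the hand-built polynomial library, and the synthetic target are illustrative assumptions, not PyDaddy's internal implementation.

```python
import numpy as np


def stlsq(library, target, threshold, n_iter=10):
    """Sequentially thresholded least squares (illustrative helper, not PyDaddy API)."""
    coef = np.linalg.lstsq(library, target, rcond=None)[0]
    for _ in range(n_iter):
        small = np.abs(coef) < threshold
        coef[small] = 0.0          # drop terms below the threshold
        keep = ~small
        if not keep.any():
            break
        # Refit using only the surviving terms
        coef[keep] = np.linalg.lstsq(library[:, keep], target, rcond=None)[0]
    return coef


# Synthetic example: the true function is -2 x + 0.5 v, plus a little noise
rng = np.random.default_rng(0)
x = rng.uniform(-1, 1, 2000)
v = rng.uniform(-1, 1, 2000)
library = np.column_stack([np.ones_like(x), x, v, x * v, x ** 2, v ** 2])
target = -2.0 * x + 0.5 * v + 0.01 * rng.standard_normal(x.size)

coef = stlsq(library, target, threshold=0.1)
print(dict(zip(['1', 'x', 'v', 'xv', 'x^2', 'v^2'], np.round(coef, 3))))
```

A larger threshold yields a sparser expression at the risk of discarding genuine small terms, which is why the thresholds for `F2` and `G22` in this notebook are much smaller than those for `F1` and `G11`.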

In [None]:
f1 = ddsde.fit('F1', order=3, threshold=0.5)
f1

In [None]:
g11 = ddsde.fit('G11', order=3, threshold=1)
g11

$f_1$ and $G_{11}$ are discovered correctly: since $\dot x = v$, the expected result is $f_1 = v$ and $G_{11} = 0$ (no noise in $x$).

In [None]:
f2 = ddsde.fit('F2', order=3, threshold=0.01)
f2

In [None]:
g22 = ddsde.fit('G22', order=4, threshold=0.005)
g22

In [None]:
g12 = ddsde.fit('G12', order=2, threshold=1)
g12

## Diagnostics

For this dataset, we do the diagnostics manually. 

### Noise diagnostics

In [None]:
drift_raw = ddsde._ddsde._driftY_  # Raw estimate for drift based on forward difference

# Compute residuals: (raw drift - fitted drift) / sqrt(fitted diffusion)
x_prev, v_prev = data[0][:-1], data[1][:-1]
eta = (drift_raw - ddsde.F2(x_prev, v_prev)) / np.sqrt(ddsde.G22(x_prev, v_prev))
lags, acf_eta = ddsde._acf(eta, t_lag=1000)


In [None]:
xxx = np.linspace(-3, 3, 100)

fig, ax = plt.subplots(1, 2, figsize=(14, 6))
ax[0].hist(eta, bins=51, range=(-3, 3), label='Actual', density=True)
ax[0].plot(xxx, np.exp(- xxx ** 2 / 2) / (np.sqrt(2 * np.pi)), label='Theoretical')
ax[1].plot(acf_eta[:100])

ax[0].set(title='Residual Distribution', xlabel='$r$', ylabel='$P(r)$')
ax[1].set(title='Residual Autocorrelation', xlabel='Lag', ylabel='Autocorrelation')
ax[0].legend()
plt.show()

The residual autocorrelation decays quickly (within one time-step). The residual distribution resembles a Gaussian, although it is more sharply peaked and heavier-tailed.
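
These two visual checks can also be summarized numerically. The sketch below is a minimal example on synthetic series, not on the actual residuals. The helper names `autocorrelation` and `excess_kurtosis` are hypothetical and not part of PyDaddy. For Gaussian white noise, the autocorrelation at nonzero lags and the excess kurtosis should both be close to zero, and a positive excess kurtosis indicates heavier tails than a Gaussian.

```python
import numpy as np


def autocorrelation(series, max_lag):
    """Normalized autocorrelation at lags 0..max_lag (hypothetical helper)."""
    centered = series - series.mean()
    var = np.mean(centered ** 2)
    acf = [1.0]
    for k in range(1, max_lag + 1):
        acf.append(np.mean(centered[:-k] * centered[k:]) / var)
    return np.array(acf)


def excess_kurtosis(series):
    """Fourth standardized moment minus 3; zero for a Gaussian (hypothetical helper)."""
    centered = series - series.mean()
    return np.mean(centered ** 4) / np.mean(centered ** 2) ** 2 - 3.0


# Synthetic stand-ins for a residual series
rng = np.random.default_rng(0)
gaussian = rng.standard_normal(100_000)
laplace = rng.laplace(size=100_000)   # heavier-tailed than a Gaussian

acf = autocorrelation(gaussian, max_lag=5)
print('ACF (Gaussian):', np.round(acf, 3))
print('Excess kurtosis (Gaussian):', round(excess_kurtosis(gaussian), 3))
print('Excess kurtosis (Laplace): ', round(excess_kurtosis(laplace), 3))
```

The same two functions could be applied to `eta` from the cell above, after removing any non-finite values.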

### Model diagnostics

To check for model consistency, simulate a time series with the discovered SDE using the `ddsde.simulate` function. Below is a function to generate a simulated time series of a specified length, ensuring that the cell does not leave the valid range of $(-1, 1)$. When the cell position leaves the valid range, the cell is reset to a random (valid) value and the simulation begins again.

When $x$ crosses the valid range (due to discretization errors etc.), the simulation will start to diverge. The resetting strategy prevents this.

In [None]:
def simulate(f, g, t_int, timepoints, x0, v0):
    """
    Simulate position and velocity with a forward-Euler scheme.

    Args:
        f: Drift function for the velocity, called as f(x, v)
        g: Diffusion function for the velocity, called as g(x, v)
        t_int: Integration time step
        timepoints: Number of time points to simulate
        x0: Initial position
        v0: Initial velocity

    Returns:
        [x, v]: Simulated position and velocity arrays
    """
    x = np.empty(timepoints)
    v = np.empty(timepoints)

    x[0], v[0] = x0, v0

    for i in range(timepoints - 1):
        # Velocity update: drift term plus Gaussian noise scaled by sqrt(t_int * g)
        v[i + 1] = (v[i] +
                    t_int * f(x[i], v[i]) +
                    np.sqrt(t_int * g(x[i], v[i])) * np.random.randn())
        # Position update: dx/dt = v
        x[i + 1] = x[i] + t_int * v[i]

    return [x, v]
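
The resetting strategy described above can be sketched as follows. This is a minimal sketch under stated assumptions, not the notebook's original code. The name `simulate_with_reset`, the choice to draw the reset position uniformly from the valid range, and the choice to reset the velocity to zero are all illustrative assumptions. The toy drift and diffusion functions exist only to force the trajectory out of range.

```python
import numpy as np


def simulate_with_reset(f, g, t_int, timepoints, x0, v0, x_limit=1.0, rng=None):
    """Euler-type simulation that resets the state when |x| leaves (-x_limit, x_limit).

    Hypothetical variant for illustration; the reset rule is an assumption.
    """
    rng = np.random.default_rng() if rng is None else rng
    x = np.empty(timepoints)
    v = np.empty(timepoints)
    x[0], v[0] = x0, v0
    n_resets = 0

    for i in range(timepoints - 1):
        v[i + 1] = (v[i]
                    + t_int * f(x[i], v[i])
                    + np.sqrt(t_int * g(x[i], v[i])) * rng.standard_normal())
        x[i + 1] = x[i] + t_int * v[i]

        # Reset if the position left the valid range or the velocity is not finite
        if not (abs(x[i + 1]) < x_limit and np.isfinite(v[i + 1])):
            x[i + 1] = rng.uniform(-x_limit, x_limit)  # assumed reset rule
            v[i + 1] = 0.0                             # assumed reset rule
            n_resets += 1

    return x, v, n_resets


# Toy example: a constant push that would carry x out of range without resets
x, v, n_resets = simulate_with_reset(
    f=lambda x, v: 50.0,
    g=lambda x, v: 0.01,
    t_int=0.01, timepoints=5000, x0=0.1, v0=0.1,
    rng=np.random.default_rng(0),
)
print('Number of resets:', n_resets)
```

Because every reset introduces an artificial jump, a large reset count suggests that the fitted functions or the time step should be revisited before re-estimating from the simulated series.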

In [None]:
t_sim = 0.001
sim = simulate(ddsde.F2, ddsde.G22, t_int=t_sim, timepoints=1000000, x0=0.1, v0=0.1)
# sim = ddsde.simulate(t_int=t_sim, timepoints=1000000, x0=[0.1, 0.1])

Before proceeding to the estimation procedure, ensure that the simulation hasn't diverged by plotting it. In case the simulation has diverged, go back to the previous cell and re-simulate.
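
Alongside the plot, a quick programmatic check can flag divergence. The sketch below assumes that "diverged" means a non-finite value or a position outside the valid range. The helper `has_diverged` is hypothetical and is shown on small hand-made arrays rather than on `sim`.

```python
import numpy as np


def has_diverged(x, v, x_limit=1.0):
    """Return True if the series contains non-finite values or |x| exceeds x_limit."""
    x = np.asarray(x, dtype=float)
    v = np.asarray(v, dtype=float)
    finite = np.all(np.isfinite(x)) and np.all(np.isfinite(v))
    return bool(not finite or np.any(np.abs(x) > x_limit))


# Hand-made examples
ok = has_diverged([0.1, 0.2, -0.3], [0.0, 0.5, -0.5])
out_of_range = has_diverged([0.1, 1.5, 0.2], [0.0, 0.5, -0.5])
not_finite = has_diverged([0.1, 0.2, 0.3], [0.0, np.nan, 0.1])
print(ok, out_of_range, not_finite)
```

For the simulation above, the equivalent call would be `has_diverged(sim[0], sim[1])`.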

In [None]:
plt.plot(sim[0])
plt.plot(sim[1])
plt.show()

Now, re-estimate the drift and diffusion functions from the simulated time series.

In [None]:
ddsde_sim = pydaddy.Characterize(data=sim, t=t_sim, bins=21)

In [None]:
ddsde_sim.drift(limits=[-1, 1])

In [None]:
ddsde_sim.diffusion(limits=[0, 0.1])

In [None]:
ddsde_sim.fit('F1', order=3, threshold=0.5)

In [None]:
ddsde_sim.fit('F2', order=3, threshold=0.01)

In [None]:
# Compare with the original estimate for drift.
ddsde.F2

In [None]:
ddsde_sim.fit('G11', order=3, threshold=1)

In [None]:
ddsde_sim.fit('G12', order=3, threshold=1)

In [None]:
ddsde_sim.fit('G22', order=4, threshold=0.005)

In [None]:
# Compare with the original estimate for diffusion.
ddsde.G22

The re-estimated expressions deviate slightly from the original expressions for $f$ and $g^2$. To examine if the original estimates and re-estimates are meaningfully different, we can plot the functions overlaid on one another.

In [None]:
xx, vv = np.meshgrid(np.linspace(-1, 1, 201), np.linspace(-1, 1, 201))

fig, ax = plt.subplots(1, 2, figsize=(16, 7), subplot_kw=dict(projection='3d'))
ax[0].plot_wireframe(xx, vv, ddsde.F2(xx, vv), color='r', alpha=0.5, label='Original')
ax[0].plot_wireframe(xx, vv, ddsde_sim.F2(xx, vv), alpha=0.5, label='Re-estimated')
ax[1].plot_wireframe(xx, vv, ddsde.G22(xx, vv), color='r', alpha=0.5, label='Original')
ax[1].plot_wireframe(xx, vv, ddsde_sim.G22(xx, vv), alpha=0.5, label='Re-estimated')

ax[0].set(title='Drift', xlabel='$x$', ylabel='$v$', zlabel='$f(x, v)$')
ax[1].set(title='Diffusion', xlabel='$x$', ylabel='$v$', zlabel='$g^2(x, v)$')
for a in ax:
    a.legend()  # plt.legend() alone would only label the last panel
plt.show()

We conclude that the re-estimated drift and diffusion functions are not meaningfully different from the original estimates, i.e. the model is self-consistent.
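
One way to put a number on "not meaningfully different" is a relative root-mean-square difference over the same grid used for the plots. The sketch below is illustrative only: the helper `relative_difference` is hypothetical, the two polynomials are made-up stand-ins for the original and re-estimated functions, and what counts as an acceptable value is a judgment call.

```python
import numpy as np


def relative_difference(func_a, func_b, lo=-1.0, hi=1.0, n=201):
    """RMS of (func_a - func_b) on a grid, relative to the RMS of func_a."""
    xx, vv = np.meshgrid(np.linspace(lo, hi, n), np.linspace(lo, hi, n))
    a = func_a(xx, vv)
    b = func_b(xx, vv)
    return np.sqrt(np.mean((a - b) ** 2)) / np.sqrt(np.mean(a ** 2))


# Made-up stand-ins for an original and a re-estimated drift function
def original(x, v):
    return -2.0 * x + 0.5 * v - v ** 3


def reestimated(x, v):
    return -1.9 * x + 0.5 * v - v ** 3


rel = relative_difference(original, reestimated)
print(f'Relative RMS difference: {rel:.3f}')
```

With the fitted objects from this notebook, the analogous calls would be `relative_difference(ddsde.F2, ddsde_sim.F2)` and `relative_difference(ddsde.G22, ddsde_sim.G22)`.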