In [None]:
import torch
import torch.nn as nn
from torch.utils.data import DataLoader
import matplotlib.pyplot as plt

from neuromancer.system import Node, System
from neuromancer.dynamics import integrators, ode
from neuromancer.trainer import Trainer
from neuromancer.problem import Problem
from neuromancer.dataset import DictDataset
from neuromancer.constraint import variable
from neuromancer.loss import PenaltyLoss
from neuromancer.modules import blocks
from neuromancer.psl import plot
from neuromancer import psl


from typing import Sequence
import abc
import torchsde




The general form of an ordinary differential equation (ODE) is:

$$ \frac{{dx}}{{dt}} = f(t, x) $$

A neural ordinary differential equation (NODE) replaces the right-hand side (RHS) with a neural network. That is, the evolution of the system over time is a continuous flow governed by an ODE whose dynamics are parameterized by a neural network, giving the continuous-time NODE model:

$\dot{x} = f_{\theta}(x)$ with trainable parameters $\theta$.

Given training data consisting of several time-series "episodes" (e.g. the system states at t, t+1, t+2, which in Neuromancer terminology is a rollout of nsteps=2), we can train such a Neural ODE.

To do so, we solve the continuous-time NODE model with a suitable ODE solver, e.g., a [Runge–Kutta integrator](https://en.wikipedia.org/wiki/Runge%E2%80%93Kutta_methods).  
$x_{k+1} = \text{ODESolve}(f_{\theta}(x_k))$ 
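
As a rough illustration of what a single $\text{ODESolve}$ step can look like, the sketch below implements the classical fourth-order Runge-Kutta update in plain Python. The function `f_toy` is a hypothetical stand-in for $f_{\theta}$ (a linear decay with a known exact solution), and `rk4_step` is an illustrative helper, not the Neuromancer integrator API.

```python
import math

def rk4_step(f, x, h):
    """Advance x by one step of size h for dx/dt = f(x) using classical RK4."""
    k1 = f(x)
    k2 = f(x + 0.5 * h * k1)
    k3 = f(x + 0.5 * h * k2)
    k4 = f(x + h * k3)
    return x + (h / 6.0) * (k1 + 2.0 * k2 + 2.0 * k3 + k4)

def f_toy(x):
    # Hypothetical stand-in for f_theta: dx/dt = -x, exact solution x0 * exp(-t).
    return -x

x = 1.0
h = 0.1
for _ in range(10):  # integrate from t = 0 to t = 1
    x = rk4_step(f_toy, x, h)

rk4_error = abs(x - math.exp(-1.0))
print(f"x(1) ~ {x:.6f}, absolute error = {rk4_error:.2e}")
```

In a NODE, `f_toy` would be replaced by the neural network, and the same update would be applied to batched tensors so that gradients can flow through each stage.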

For training, we need accurate reverse-mode gradients of the integrated NODE system. These can be obtained in two ways: by unrolling the operations of the ODE solver and applying the [backpropagation through time](https://en.wikipedia.org/wiki/Backpropagation_through_time) (BPTT) algorithm, or via the [adjoint state method](https://en.wikipedia.org/wiki/Adjoint_state_method).

Schematic illustrating the adjoint method used in the [Neural Ordinary Differential Equations](https://arxiv.org/abs/1806.07366) paper:
<img src="../figs/NODE_backprop.png" width="500">  

Neuromancer provides a set of ODE solvers implemented in [integrators.py](https://github.com/pnnl/neuromancer/blob/master/src/neuromancer/dynamics/integrators.py).
For the adjoint method, we provide an interface to the [open-source implementation](https://github.com/rtqichen/torchdiffeq) via the DiffEqIntegrator class.

We give a quick example here to motivate how one might include (and not include) stochasticity (randomness) into the system dynamics. 




The ordinary Lotka-Volterra system, also known as the predator-prey model, describes the dynamics of two interacting species in a biological community. The system consists of two coupled ordinary differential equations (ODEs), typically represented as follows:

$$
\begin{align*}
\frac{dX}{dt} &= \alpha X - \beta XY \\
\frac{dY}{dt} &= \delta XY - \gamma Y
\end{align*}
$$

where:
- $X$ and $Y$ represent the population sizes of the prey and predator species, respectively.
- $\alpha$, $\beta$, $\gamma$, and $\delta$ are parameters governing the growth and interaction rates of the species.

The stochastic variant used later in this notebook adds a noise term to each equation:

$$
\begin{align*}
dX &= (\alpha X - \beta XY) \, dt + \sigma_1 \, dW_1 \\
dY &= (\delta XY - \gamma Y) \, dt + \sigma_2 \, dW_2
\end{align*}
$$

where:
- $dW_1$ and $dW_2$ are increments of independent Wiener processes representing white noise in the population dynamics.
- $\sigma_1$ and $\sigma_2$ are the volatility parameters associated with the noise processes.

The stochastic system captures the fluctuations in population sizes due to random environmental factors, which can influence the dynamics of predator-prey interactions over time.
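
To make the deterministic system concrete, the following sketch integrates the Lotka-Volterra ODE with a hand-written RK4 step in plain Python. The parameter values and initial state are illustrative assumptions, and `lv_rhs`, `rk4_step`, and `invariant` are hypothetical helper names. As a sanity check it monitors the quantity $V = \delta X - \gamma \ln X + \beta Y - \alpha \ln Y$, which is constant along exact solutions of the noise-free system.

```python
import math

# Illustrative parameter values (assumptions, not fitted to any data).
ALPHA, BETA, DELTA, GAMMA = 1.1, 0.4, 0.1, 0.4

def lv_rhs(state):
    """Right-hand side of the deterministic Lotka-Volterra ODE."""
    x, y = state
    return (ALPHA * x - BETA * x * y, DELTA * x * y - GAMMA * y)

def rk4_step(f, state, h):
    """One classical RK4 step for a state stored as a tuple of floats."""
    def shifted(s, k, scale):
        return tuple(si + scale * ki for si, ki in zip(s, k))
    k1 = f(state)
    k2 = f(shifted(state, k1, 0.5 * h))
    k3 = f(shifted(state, k2, 0.5 * h))
    k4 = f(shifted(state, k3, h))
    return tuple(
        s + (h / 6.0) * (a + 2.0 * b + 2.0 * c + d)
        for s, a, b, c, d in zip(state, k1, k2, k3, k4)
    )

def invariant(state):
    """Conserved quantity of the noise-free system."""
    x, y = state
    return DELTA * x - GAMMA * math.log(x) + BETA * y - ALPHA * math.log(y)

state = (10.0, 10.0)
v0 = invariant(state)
trajectory = [state]
for _ in range(4000):  # step size 0.005, t from 0 to 20
    state = rk4_step(lv_rhs, state, 0.005)
    trajectory.append(state)

max_drift = max(abs(invariant(s) - v0) for s in trajectory)
print(f"largest drift of the invariant: {max_drift:.2e}")
```

Once noise is added, this quantity is no longer conserved, which is one simple way to see that the stochastic trajectories differ qualitatively from the closed orbits of the ODE.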


The neural network $f_{\theta}$ represents the right-hand side (RHS) of the ODE: it takes the current state of the system as input and outputs the rate of change (time derivative) of the state variables. 

We then use an ODE solver to generate states at future times, e.g. at $t+1$, given by $x_{k+1} = \text{ODESolve}(f_{\theta}(x_k))$ 


The entire model, including the neural network dynamics and the ODE solver, is trained end-to-end on pairs of consecutive time points from the dataset (a rollout of 1 step; this can be scaled up to predict longer horizons). To enable this rollout-based training, we reshape the full system trajectory into batched time "episodes". 
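
The reshaping into episodes can be sketched in a few lines of plain Python. The helper `make_episodes` and the toy trajectory are hypothetical; the sketch assumes non-overlapping windows of `nsteps` samples, with any leftover samples at the end dropped.

```python
def make_episodes(trajectory, nsteps):
    """Split a list of states into consecutive, non-overlapping windows of length nsteps."""
    nbatch = len(trajectory) // nsteps
    return [trajectory[i * nsteps:(i + 1) * nsteps] for i in range(nbatch)]

# Toy trajectory: 10 time steps of a 2-dimensional state.
trajectory = [[float(k), float(-k)] for k in range(10)]

episodes = make_episodes(trajectory, nsteps=3)

# The first state of each episode serves as the initial condition for the rollout.
initial_states = [episode[0] for episode in episodes]

print(len(episodes), "episodes of length", len(episodes[0]))
print("initial states:", initial_states)
```

Each episode then provides an initial condition and the target states that the integrated model is asked to reproduce.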


In [None]:
class LotkaVolterraHybrid(ode.ODESystem):

    def __init__(self, block, insize=2, outsize=2):
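        """
        Hybrid Lotka-Volterra ODE in which a neural network block replaces
        the interaction term x1*x2.

        :param block: (nn.Module) network mapping the 2-D state to a scalar interaction term
        :param insize: (int) input size (state dimension)
        :param outsize: (int) output size (state dimension)
        """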
        super().__init__(insize=insize, outsize=outsize)
        self.block = block
        self.alpha = nn.Parameter(torch.tensor([.10]), requires_grad=True)
        self.beta = nn.Parameter(torch.tensor([.10]), requires_grad=True)
        self.delta = nn.Parameter(torch.tensor([.10]), requires_grad=True)
        self.gamma = nn.Parameter(torch.tensor([.10]), requires_grad=True)
        assert self.block.in_features == 2
        assert self.block.out_features == 1

    def ode_equations(self, x):
        x1 = x[:, [0]]
        x2 = x[:, [-1]]
        dx1 = self.alpha*x1 - self.beta*self.block(x)
        dx2 = self.delta*self.block(x) - self.gamma*x2
        return torch.cat([dx1, dx2], dim=-1)

$\dot{x} = f_{\vec{\theta}}(x)$ 

Here, $\vec{\theta}$ is a vector of parameters $[\alpha, \beta, \gamma, \delta, \theta']$, where $\theta'$ parameterizes the multi-layer perceptron block.

In [None]:

def get_data(sys, nsim, nsteps, ts, bs):
    """
    :param sys: (psl.system) system to simulate
    :param nsim: (int) Number of simulation steps per generated trajectory
    :param nsteps: (int) Number of timesteps for each batch of training data
    :param ts: (float) step size
    :param bs: (int) batch size
    :return: train_loader, dev_loader, test_data, trainX
    """
    train_sim, dev_sim, test_sim = [sys.simulate(nsim=nsim, ts=ts) for i in range(3)]
    nx = sys.nx
    nbatch = nsim//nsteps  # number of episodes
    length = (nsim//nsteps) * nsteps  # number of samples kept (a multiple of nsteps)
    print('train sim ', train_sim['X'].shape)

    trainX = train_sim['X'][:length].reshape(nbatch, nsteps, nx)
    trainX = torch.tensor(trainX, dtype=torch.float32)

    print(trainX.shape)# N x nsteps x state_size 

    train_data = DictDataset({'X': trainX, 'xn': trainX[:, 0:1, :]}, name='train')
    train_loader = DataLoader(train_data, batch_size=bs,
                              collate_fn=train_data.collate_fn, shuffle=True)

    devX = dev_sim['X'][:length].reshape(nbatch, nsteps, nx)
    devX = torch.tensor(devX, dtype=torch.float32)
    dev_data = DictDataset({'X': devX, 'xn': devX[:, 0:1, :]}, name='dev')
    dev_loader = DataLoader(dev_data, batch_size=bs,
                            collate_fn=dev_data.collate_fn, shuffle=True)

    testX = test_sim['X'][:length].reshape(1, nsim, nx)
    testX = torch.tensor(testX, dtype=torch.float32)
    test_data = {'X': testX, 'xn': testX[:, 0:1, :]}

    return train_loader, dev_loader, test_data, trainX

torch.manual_seed(0)

# %%  ground truth system
system = psl.systems['LotkaVolterra']
modelSystem = system()
ts = modelSystem.ts
nx = modelSystem.nx
raw = modelSystem.simulate(nsim=1000, ts=ts)
plot.pltOL(Y=raw['X'])
plot.pltPhase(X=raw['Y'])

## Neural ODEs with Added Stochasticity

A stochastic differential equation (SDE) is given by:

$$ dx = f(t, x) \, dt + g(t, x) \, dW $$

The $f$ term is known as the drift; the $g$ term is known as the diffusion. If the diffusion is zero, the SDE reduces to an ODE, which can be trained by backpropagating through the ODE solver or by solving the reverse-time (adjoint) ODE, as described previously. 
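
A minimal sketch of how such an equation can be simulated is the Euler-Maruyama scheme, shown here in plain Python for a scalar state. The function `euler_maruyama` and the drift and diffusion choices are illustrative assumptions, not part of TorchSDE or Neuromancer. Setting the diffusion to zero recovers the ordinary explicit Euler method for the corresponding ODE.

```python
import math
import random

def euler_maruyama(f, g, x0, t0, t1, n_steps, rng):
    """Simulate dx = f(t, x) dt + g(t, x) dW for a scalar state with Euler-Maruyama."""
    dt = (t1 - t0) / n_steps
    t, x = t0, x0
    path = [x]
    for _ in range(n_steps):
        dw = rng.gauss(0.0, math.sqrt(dt))  # Wiener increment, distributed as N(0, dt)
        x = x + f(t, x) * dt + g(t, x) * dw
        t += dt
        path.append(x)
    return path

def drift(t, x):
    # Illustrative drift: linear decay towards zero.
    return -x

# Constant diffusion of 0.3 (additive noise) versus no diffusion at all.
noisy = euler_maruyama(drift, lambda t, x: 0.3, 1.0, 0.0, 1.0, 1000, random.Random(0))
no_noise = euler_maruyama(drift, lambda t, x: 0.0, 1.0, 0.0, 1.0, 1000, random.Random(0))

print(f"with noise:    x(1) = {noisy[-1]:.4f}")
print(f"without noise: x(1) = {no_noise[-1]:.4f}  (exact ODE value {math.exp(-1.0):.4f})")
```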

For simplicity, we assume that gradients can likewise be obtained by backpropagating through SDE solvers or by reverse-time integration. The paper [Scalable Gradients for Stochastic Differential Equations](https://arxiv.org/pdf/2001.01328) describes this in detail. 

A natural question, therefore, is how to train such Neural ODEs with stochastic terms. Can we do it using the standard Neuromancer training procedure for Neural ODEs, i.e. our reference tracking and finite difference losses? We attempt to do this below. Note that it will **not** work; the purpose of demonstrating this is to motivate the need for a variational inference approach to training the SDE: the Latent SDE. 

We can use neural networks to parameterize not only the drift, but also the diffusion, e.g.: 

$$ dx = f_{\vec{\theta_f}}(x) \, dt + g_{\vec{\theta_g}}(x) \, dW $$

where $g_{\vec{\theta_g}}$ is a neural network that models the diffusion term. 

To support this framework and integrate it with TorchSDE solvers, we define a base class: 

In [None]:
class BaseSDESystem(abc.ABC, nn.Module):
    """
    Base class for SDEs for integration with TorchSDE library
    """
    def __init__(self):
        super().__init__()
        self.noise_type = "diagonal"
        self.sde_type = "ito"
        self.in_features = 0
        self.out_features = 0

    @abc.abstractmethod
    def f(self, t, y):
        """
        Define the ordinary differential equations (ODEs) for the system.

        Args:
            t (Tensor): The current time (often unused)
            y (Tensor): The current state variables of the system.

        Returns:
            Tensor: The derivatives of the state variables with respect to time.
                    The output should be of shape [batch size x state size]
        """
        pass

    @abc.abstractmethod
    def g(self, t, y):
        """
        Define the diffusion equations for the system.

        Args:
            t (Tensor): The current time (often unused)
            y (Tensor): The current state variables of the system.

        Returns:
            Tensor: The diffusion coefficients per batch item (output is of size 
                    [batch size x state size]) for noise_type 'diagonal'
        """
        pass

We use a stochastic Lotka-Volterra model (for forward passes only) with user-defined parameters to generate ground truth data: 

In [None]:
# Define the Lotka-Volterra SDE
class LotkaVolterraSDE(nn.Module):
    def __init__(self, a, b, c, d, sigma1, sigma2):
        super().__init__()
        self.a = a
        self.b = b
        self.c = c
        self.d = d
        self.sigma1 = sigma1
        self.sigma2 = sigma2
        self.noise_type = "diagonal"
        self.sde_type = "ito"

    def f(self, t, x):
        # Drift: the deterministic Lotka-Volterra dynamics
        x1 = x[:,[0]]
        x2 = x[:,[1]]
        dx1 = self.a * x1 - self.b * x1*x2
        dx2 = self.c * x1*x2 - self.d * x2
        return torch.cat([dx1, dx2], dim=-1)

    def g(self, t, x):
        # Diffusion: constant (additive) noise scale per state, repeated for each batch item
        sigma_diag = torch.tensor([[self.sigma1, self.sigma2]], dtype=x.dtype)
        return sigma_diag.expand(x.shape[0], -1) #[batch_size x state size ]

# Define parameters
a = 1.1    # Prey growth rate
b = 0.4   # Predation rate
c = 0.1   # Predator growth rate
d = 0.4   # Predator death rate
sigma1 = 1
sigma2 = 0

# Create the SDE model
sde = LotkaVolterraSDE(a, b, c, d, sigma1, sigma2)


# Define time span
t_span = torch.linspace(0, 20, 2000)

# Initial condition
x0 = torch.tensor([10.0, 10.0]).unsqueeze(0) #[1x2]


# Integrate the SDE model
sol_train = torchsde.sdeint(sde, x0, t_span, method='euler')
sol_dev = torchsde.sdeint(sde, x0, t_span, method='euler')
sol_test = torchsde.sdeint(sde, x0, t_span, method='euler')




In [None]:
# Plot the results
plt.figure(figsize=(10, 6))
plt.plot(t_span, sol_train[:, 0,0], label='Prey (x1)')
plt.plot(t_span, sol_train[:,0, 1], label='Predator (x2)')
plt.xlabel('Time')
plt.ylabel('Population')
plt.title('Stochastic Lotka-Volterra Predator-Prey Model')
plt.legend()
plt.grid(True)
plt.show()

In [None]:
class LotkaVolterraSDELearnable(BaseSDESystem):
    def __init__(self, block, batch_size):
        super().__init__()
        self.block = block 
        self.alpha = nn.Parameter(torch.tensor([.10]), requires_grad=True)
        self.beta = nn.Parameter(torch.tensor([.10]), requires_grad=True)
        self.delta = nn.Parameter(torch.tensor([.10]), requires_grad=True)
        self.gamma = nn.Parameter(torch.tensor([.10]), requires_grad=True)
        # Learnable diffusion coefficients, one row per batch item
        self.g_params = nn.Parameter(torch.randn(batch_size, 2), requires_grad=True)

    def f(self, t, y):

        x1 = y[:, [0]]
        x2 = y[:, [-1]]

        dx1 = self.alpha*x1 - self.beta*self.block(y)
        dx2 = self.delta*self.block(y) - self.gamma*x2

        return torch.cat([dx1, dx2], dim=-1)

    def g(self, t, y):
        return self.g_params

# construct UDE model in Neuromancer
net = blocks.MLP(2, 1, bias=True,
                    linear_map=torch.nn.Linear,
                    nonlin=torch.nn.GELU,
                    hsizes=4*[20])
fx = LotkaVolterraSDELearnable(block=net, batch_size = 2)


class BasicSDEIntegrator(integrators.Integrator): 
    """
    Integrator (from TorchSDE) for the basic/explicit SDE case where the drift (f) and diffusion (g) terms are defined.
    TorchSDE returns a single tensor of size (t, batch_size, state_size), which integrate()
    permutes to (batch_size, t, state_size).

    Please see https://github.com/google-research/torchsde/blob/master/torchsde/_core/sdeint.py
    Currently only supports Euler integration. Choice of integration method is dependent 
    on integral type (Ito/Stratonovich) and drift/diffusion terms
    """
    def __init__(self, block ): 
        """
        :param block: (nn.Module) The BasicSDE block
        """
        super().__init__(block) 


    def integrate(self, x): 
        """
        x is the initial state; its singleton time dimension is squeezed out to give
        size (batch_size, state_size).
        t is the time-step vector over which to integrate (currently hard-coded below)
        """
        t = torch.tensor([0.,0.01, 0.02], dtype=torch.float32)
        x = x.squeeze(1) #remove time step 
  
        ys = torchsde.sdeint(self.block, x, t, method='euler')
        ys = ys.permute(1, 0, 2)
        return ys 

integrator = BasicSDEIntegrator(fx)
# wrap the SDE integrator as a symbolic Neuromancer node
model_sde = Node(integrator, input_keys=['xn'], output_keys=['xn'])
dynamics_model_sde = model_sde

Training a neural SDE on time pairs with mean squared error (MSE) runs into problems because of the stochastic nature of SDEs. MSE is a common loss function for deterministic systems, but it is not directly applicable to stochastic systems such as SDEs.

Two issues arise when using MSE to train neural SDEs:

- **Ignoring stochasticity**: MSE compares point predictions, so it rewards matching the deterministic part of the model and ignores the stochastic component represented by the diffusion term. This can lead to suboptimal results because the model does not capture the inherent randomness in the system.

- **Overfitting the drift term**: MSE optimization may focus excessively on minimizing errors in the drift term while neglecting the diffusion term. This can result in overfitting of the deterministic part of the model and underfitting of the stochastic part.
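
A small Monte Carlo experiment can illustrate the first point. The sketch below assumes a scalar SDE $dx = -a x \, dt + \sigma \, dW$ and a one-step Euler-Maruyama model that already has the correct drift and differs only in its diffusion scale; all names and values are hypothetical. Because the model's noise draw is independent of the noise in the data, the expected one-step MSE is $(\sigma^2 + \hat{\sigma}^2) \, dt$, which is smallest when the model's diffusion $\hat{\sigma}$ is zero.

```python
import math
import random

rng = random.Random(0)
a_true, sigma_true, dt, n = 0.5, 1.0, 0.01, 20000

def mean_squared_error(sigma_model):
    """One-step MSE of a model with the true drift and diffusion scale sigma_model."""
    total = 0.0
    for _ in range(n):
        x_k = rng.uniform(0.5, 2.0)
        # "Observed" next state generated by the true SDE.
        target = x_k - a_true * x_k * dt + sigma_true * math.sqrt(dt) * rng.gauss(0.0, 1.0)
        # Model prediction: same drift, but its own independent noise draw.
        pred = x_k - a_true * x_k * dt + sigma_model * math.sqrt(dt) * rng.gauss(0.0, 1.0)
        total += (pred - target) ** 2
    return total / n

mse_no_noise = mean_squared_error(0.0)
mse_true_noise = mean_squared_error(sigma_true)
ratio = mse_true_noise / mse_no_noise

print(f"MSE with zero diffusion: {mse_no_noise:.5f}")
print(f"MSE with true diffusion: {mse_true_noise:.5f}")
```

In this setup the model with the correct diffusion scores roughly twice as badly as the noise-free model, so minimizing MSE pushes a learnable diffusion term towards zero rather than towards its true value.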

SDEs inherently involve randomness or uncertainty, typically represented by the stochastic terms in the differential equations. Variational inference allows us to capture this uncertainty by providing a probabilistic characterization of the latent variables' distribution. Instead of obtaining a single point estimate, variational inference provides a full probabilistic description, including measures of uncertainty such as confidence intervals or predictive distributions. 

The Latent SDE is essentially a variational "autoencoder" that, instead of seeking to resynthesize samples corresponding to the "same time" instance, tries to reconstruct future samples given the current dynamics of the system, where the dynamics are modeled by an SDE. To do this, the latent space is **itself** governed by an SDE, and we integrate in the latent space to synthesize forward-looking samples. 

Variational inference is a powerful method used to approximate complex posterior distributions in probabilistic models. In the context of latent stochastic differential equations (SDEs), variational inference plays a crucial role in estimating the posterior distribution of the latent variables given the observed data.

In latent SDEs, the goal is to infer the hidden or latent variables that govern the dynamics of the system. These latent variables capture unobserved factors that influence the observed data, such as underlying trends, patterns, or noise sources. However, directly computing the posterior distribution of the latent variables given the data is often analytically intractable due to the complex and nonlinear nature of the model.

Variational inference offers a solution to this problem by approximating the true posterior distribution with a simpler, parameterized distribution, often chosen from a family such as the Gaussian distributions. This approximation is produced by an encoder network. The decoder network takes samples from this learned, approximate posterior and reconstructs the data distribution. Using the KL divergence between the approximate posterior and the prior, we learn the parameters of the latent distribution, known as the variational parameters. 

## Latent SDE

1. **Encoder**:
   - The encoder maps each input data point $x$ to a distribution over latent variables $z$. This distribution is typically Gaussian with mean $\mu$ and standard deviation $\sigma$.
   $$
   q_{\phi}(z | x) = \mathcal{N}(\mu_{\phi}(x), \sigma_{\phi}^2(x))
   $$
   - Here, $\mu_{\phi}(x)$ and $\sigma_{\phi}(x)$ are the mean and standard deviation parameters of the Gaussian distribution, which are output by the encoder neural network parameterized by $\phi$.

2. **Latent Dynamics**:
   - The latent variables $z$ evolve over time according to a stochastic differential equation (SDE). The dynamics of $z$ are governed by drift and diffusion functions, similar to the SDE for the observed data $x$.
   $$
   dz_t = f(z_t, t) \, dt + G(z_t, t) \, dW_t
   $$
   - $f(z_t, t)$ represents the drift component, determining the deterministic evolution of the latent variables.
   - $G(z_t, t)$ represents the diffusion component, introducing stochasticity into the latent dynamics.
   - $dW_t$ is the increment of a Wiener process (Brownian motion), representing random noise.

3. **Decoder**:
   - The decoder takes samples from the latent space $z$ and maps them back to the data space $x$. It models the conditional distribution of $x$ given $z$.
   $$
   p_{\theta}(x | z)
   $$
   - The decoder neural network, parameterized by $\theta$, outputs the parameters of the conditional distribution $p_{\theta}(x | z)$, such as the mean and variance of a Gaussian distribution or the parameters of a Bernoulli distribution for binary data.

4. **Latent Variable Prior**:
   - We assume a prior distribution over the latent variables $z$. This distribution is typically chosen to be a standard Gaussian.
   $$
   p(z) = \mathcal{N}(0, I)
   $$

   though in TorchSDE's framework (and as shown in the code below), these are learnable parameters qz0_mean and qz0

5. **Objective Function**:
   - The objective function for training the Latent SDE model is similar to the ELBO in VAEs but now includes the evolution of latent variables governed by the SDE.
   $$
   \text{ELBO}(\theta, \phi; x) = \mathbb{E}_{q_{\phi}(z | x)} [\log p_{\theta}(x | z)] - \text{KL}[q_{\phi}(z | x) || p(z)]
   $$
   - The first term represents the reconstruction loss, measuring how well the decoder reconstructs the input data $x$ from the latent variable samples $z$.
   - The second term is the KL divergence between the approximate posterior $q_{\phi}(z | x)$ and the prior $p(z)$, which encourages the approximate posterior to match the prior.
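
The two terms of the objective can be sketched numerically for the simplest case of a diagonal Gaussian posterior and a standard Gaussian prior, as written in the formula above. This is a plain-Python illustration of the static, VAE-style ELBO only; it does not include the path-wise terms that arise when the latent state evolves under an SDE. The helper names, the fixed decoder standard deviation, and the use of the latent sample itself as the decoder mean are all assumptions made for the example.

```python
import math
import random

def kl_diag_gaussian_vs_standard(mu, sigma):
    """KL[N(mu, diag(sigma^2)) || N(0, I)], summed over latent dimensions."""
    return sum(0.5 * (s * s + m * m - 1.0) - math.log(s) for m, s in zip(mu, sigma))

def sample_latent(mu, sigma, rng):
    """Reparameterized sample z = mu + sigma * eps with eps ~ N(0, I)."""
    return [m + s * rng.gauss(0.0, 1.0) for m, s in zip(mu, sigma)]

def gaussian_log_likelihood(x, mean, std):
    """log N(x; mean, std^2 I) for a decoder with a fixed output standard deviation."""
    return sum(
        -0.5 * math.log(2.0 * math.pi * std * std) - 0.5 * ((xi - mi) / std) ** 2
        for xi, mi in zip(x, mean)
    )

rng = random.Random(0)
mu, sigma = [1.0, 0.0], [1.0, 1.0]   # hypothetical encoder outputs for one data point
x = [0.8, -0.1]                      # hypothetical observation

z = sample_latent(mu, sigma, rng)
# Hypothetical identity decoder: the latent sample is used directly as the predicted mean.
reconstruction = gaussian_log_likelihood(x, z, 0.5)
kl_term = kl_diag_gaussian_vs_standard(mu, sigma)
elbo_estimate = reconstruction - kl_term

print(f"KL term: {kl_term:.3f}, single-sample ELBO estimate: {elbo_estimate:.3f}")
```

The KL term vanishes when the posterior equals the prior and grows as the encoder moves its mean or variance away from it, which is the regularizing effect described in the second bullet.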
