Feature/torchsde integration4 #145

Open
wants to merge 7 commits into base: develop
Binary file added examples/SDEs/.DS_Store
Binary file not shown.
12 changes: 12 additions & 0 deletions examples/SDEs/README.md
@@ -0,0 +1,12 @@
# TorchSDE x NeuroMANCER

The example in this folder, sde_walkthrough.ipynb, demonstrates how functionality from TorchSDE (https://github.com/google-research/torchsde/tree/master) is integrated into the NeuroMANCER workflow.

TorchSDE provides stochastic differential equation (SDE) solvers with GPU support and efficient backpropagation, based on this paper: http://proceedings.mlr.press/v108/li20i.html

NeuroMANCER already has a robust and extensive library for neural ODEs and ODE solvers. We extend that functionality to the stochastic case by incorporating TorchSDE solvers. To motivate and teach the user how to progress from neural ODEs to "neural SDEs", we have written a detailed walkthrough notebook, sde_walkthrough.ipynb.

Please ensure torchsde is installed:
```
pip install torchsde
```
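
For orientation, here is a minimal, self-contained sketch of the core `torchsde.sdeint` call that the walkthrough builds on. The toy SDE class, shapes, and parameter values below are purely illustrative and are not part of the NeuroMANCER API:

```python
import torch
import torchsde

class ToySDE(torch.nn.Module):
    noise_type = "diagonal"  # diffusion returns one coefficient per state dimension
    sde_type = "ito"

    def f(self, t, y):   # drift term
        return -y

    def g(self, t, y):   # diffusion term; same shape as y for diagonal noise
        return 0.1 * torch.ones_like(y)

y0 = torch.zeros(8, 2)             # (batch_size, state_size)
ts = torch.linspace(0.0, 1.0, 50)  # integration time points
ys = torchsde.sdeint(ToySDE(), y0, ts, method="euler")
print(ys.shape)                    # torch.Size([50, 8, 2]) -> (t, batch_size, state_size)
```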
Empty file removed examples/SDEs/sde_test.py
Empty file.
5,777 changes: 5,777 additions & 0 deletions examples/SDEs/sde_walkthrough.ipynb

Large diffs are not rendered by default.

3 changes: 2 additions & 1 deletion pyproject.toml
@@ -54,7 +54,8 @@ dependencies = [
"cvxpy",
"cvxpylayers",
"casadi",
"wandb"
"wandb",
"torchsde"
]

version = "1.5.0"
58 changes: 58 additions & 0 deletions src/neuromancer/dynamics/integrators.py
@@ -7,6 +7,7 @@

from torchdiffeq import odeint_adjoint as odeint
import torchdiffeq
import torchsde
from abc import ABC, abstractmethod


@@ -81,6 +82,63 @@ def integrate(self, x, *args):
adjoint_options=dict(norm=make_norm(x)))
x_t = solution[-1]
return x_t

class BasicSDEIntegrator(Integrator):
"""
Integrator (from TorchSDE) for the basic/explicit SDE case where the drift (f) and diffusion (g) terms are defined explicitly.
Returns a single tensor of size (t, batch_size, state_size).

Please see https://github.com/google-research/torchsde/blob/master/torchsde/_core/sdeint.py
Currently only Euler integration is supported. The choice of integration method depends
on the integral type (Ito/Stratonovich) and on the drift/diffusion terms.
"""
def __init__(self, block):
"""
:param block: (nn.Module) The BasicSDE block
"""
super().__init__(block)

def integrate(self, x, t):
"""
x is the initial state of size (batch_size, state_size)
t is the time-step vector over which to integrate
"""
ys = torchsde.sdeint(self.block, x, t, method='euler')
return ys
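
# Illustrative usage (a sketch, not part of this module). Assumes a BaseSDESystem
# subclass from neuromancer.dynamics.sde, e.g. StochasticLorenzAttractor, as the block:
#   integrator = BasicSDEIntegrator(StochasticLorenzAttractor())
#   x0 = torch.randn(64, 3)              # (batch_size, state_size)
#   ts = torch.linspace(0., 1., 100)     # time-step vector
#   xs = integrator.integrate(x0, ts)    # (t, batch_size, state_size)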

class LatentSDEIntegrator(Integrator):
"""
Integrator (from TorchSDE) for the LatentSDE case. Please see https://github.com/google-research/torchsde/blob/master/examples/latent_sde_lorenz.py for more
information. Integration here takes place in the latent space produced by the first stage (encoding process) of the LatentSDE_Encoder block.
Note that torchsde.sdeint() is called, as in BasicSDEIntegrator, so the output of integrate() is a single tensor of size (t, batch_size, latent_size).
In this case we also set logqp to True so that the log-ratio penalty is returned as well.
Please see: https://github.com/google-research/torchsde/blob/master/torchsde/_core/sdeint.py
"""
def __init__(self, block, dt=1e-2, method='euler', adjoint=False):
"""
:param block:(nn.Module) The LatentSDE_Encoder block
:param dt: (float, optional) The constant step size or initial step size for
adaptive time-stepping.
:param method: (str, optional) Numerical integration method to use. Must be
compatible with the SDE type (Ito/Stratonovich) and the noise type
(scalar/additive/diagonal/general). Defaults to a sensible choice
depending on the SDE type and noise type of the supplied SDE.
:param adjoint: (bool, optional) If True, integrate with torchsde.sdeint_adjoint for
adjoint-based backpropagation; requires the LatentSDE_Encoder block to be
constructed with adjoint=True.
"""
super().__init__(block)
self.method = method
self.dt = dt
self.adjoint = adjoint
if self.adjoint:
assert self.block.adjoint, "LatentSDE_Encoder block must have adjoint=True if using adjoint method here"

def integrate(self, x):
if not self.adjoint:
z0, xs, ts, qz0_mean, qz0_logstd = self.block(x)
zs, log_ratio = torchsde.sdeint(self.block, z0, ts, dt=self.dt, logqp=True, method=self.method)
else:
z0, xs, ts, qz0_mean, qz0_logstd, adjoint_params = self.block(x)
zs, log_ratio = torchsde.sdeint_adjoint(self.block, z0, ts, adjoint_params=adjoint_params, dt=self.dt, logqp=True, method=self.method)
return zs, z0, log_ratio, xs, qz0_mean, qz0_logstd


class Euler(Integrator):
Expand Down
288 changes: 288 additions & 0 deletions src/neuromancer/dynamics/sde.py
@@ -0,0 +1,288 @@


import abc
import torch
from torch import nn
import torchsde

from torch.distributions import Normal
from typing import Sequence


class BaseSDESystem(abc.ABC, nn.Module):
"""
Base class for SDEs for integration with TorchSDE library
"""
def __init__(self):
super().__init__()
self.noise_type = "diagonal" #only supports diagonal diffusion right now
self.sde_type = "ito" #only supports Ito integrals right now
self.in_features = 0 #for compatibility with Neuromancer integrators; unused
self.out_features = 0

@abc.abstractmethod
def f(self, t, y):
"""
Define the drift term of the SDE (the deterministic, ODE-like part of the dynamics).

Args:
t (Tensor): The current time (often unused)
y (Tensor): The current state variables of the system.

Returns:
Tensor: The derivatives of the state variables with respect to time.
The output should be of shape [batch size x state size]
"""
pass

@abc.abstractmethod
def g(self, t, y):
"""
Define the diffusion equations for the system.

Args:
t (Tensor): The current time (often unused)
y (Tensor): The current state variables of the system.

Returns:
Tensor: The diffusion coefficients per batch item (output is of size
[batch size x state size]) for noise_type 'diagonal'
"""
pass
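
# Illustrative minimal subclass (a sketch, not part of this module) showing the
# shape contract for diagonal noise: f and g both return [batch size x state size].
#   class ScaledBrownianMotion(BaseSDESystem):
#       def f(self, t, y):
#           return torch.zeros_like(y)       # zero drift
#       def g(self, t, y):
#           return 0.5 * torch.ones_like(y)  # constant diagonal diffusion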

class Encoder(nn.Module):
"""
Encoder module for time-series data (as arises in the stochastic/SDE case).
A GRU maps the observed sequence to a context sequence that parameterizes the latent space.
This class is used only in LatentSDE_Encoder.
"""
def __init__(self, input_size, hidden_size, output_size):
super().__init__()
self.gru = nn.GRU(input_size=input_size, hidden_size=hidden_size)
self.lin = nn.Linear(hidden_size, output_size)

def forward(self, inp):
out, _ = self.gru(inp)
out = self.lin(out)
return out

class LatentSDE_Encoder(BaseSDESystem):
def __init__(self, data_size, latent_size, context_size, hidden_size, ts, adjoint=False):
"""
LatentSDE_Encoder is a neural network module designed for encoding time-series data into a latent space representation,
which is then used to model the system dynamics using Stochastic Differential Equations (SDEs).

The primary purpose of this class is to transform high-dimensional time-series data into a lower-dimensional latent space
while capturing the underlying stochastic dynamics. This transformation facilitates efficient modeling, prediction, and
inference of complex temporal processes.

Taken from https://github.com/google-research/torchsde/blob/master/examples/latent_sde_lorenz.py and modified to support
NeuroMANCER library

:param data_size: (int) state size of the data
:param latent_size: (int) input latent size for the encoder
:param context_size: (int) size of context vector (output of encoder)
:param hidden_size: (int) size of the hidden layer of encoder
:param ts: (tensor) tensor of timesteps over which data should be predicted
:param adjoint: (bool, optional) if True, forward() additionally returns the parameter tuple required by torchsde.sdeint_adjoint

"""
super().__init__()

self.adjoint = adjoint

# Encoder.
self.encoder = Encoder(input_size=data_size, hidden_size=hidden_size, output_size=context_size)
self.qz0_net = nn.Linear(context_size, latent_size + latent_size) # Layer returning the mean and log-std of the parameterized latent initial state

# Decoder.
self.f_net = nn.Sequential(
nn.Linear(latent_size + context_size, hidden_size),
nn.Softplus(),
nn.Linear(hidden_size, hidden_size),
nn.Softplus(),
nn.Linear(hidden_size, latent_size),
)
self.h_net = nn.Sequential(
nn.Linear(latent_size, hidden_size),
nn.Softplus(),
nn.Linear(hidden_size, hidden_size),
nn.Softplus(),
nn.Linear(hidden_size, latent_size),
)
# This needs to be an element-wise function for the SDE to satisfy diagonal noise.
self.g_nets = nn.ModuleList(
[
nn.Sequential(
nn.Linear(1, hidden_size),
nn.Softplus(),
nn.Linear(hidden_size, 1),
nn.Sigmoid()
)
for _ in range(latent_size)
]
)
self.projector = nn.Linear(latent_size, data_size)

self.pz0_mean = nn.Parameter(torch.zeros(1, latent_size))
self.pz0_logstd = nn.Parameter(torch.zeros(1, latent_size))

self._ctx = None
self.ts = ts

def contextualize(self, ctx):
self._ctx = ctx # A tuple of tensors of sizes (T,), (T, batch_size, d).

def f(self, t, y):
ts, ctx = self._ctx

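# Look up the context vector for the current time t (index clamped to the last time point).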
i = min(torch.searchsorted(ts, t, right=True), len(ts) - 1)

return self.f_net(torch.cat((y, ctx[i]), dim=1))

def h(self, t, y):
return self.h_net(y)

def g(self, t, y): # Diagonal diffusion.
y = torch.split(y, split_size_or_sections=1, dim=1)
out = [g_net_i(y_i) for (g_net_i, y_i) in zip(self.g_nets, y)]
return torch.cat(out, dim=1)

def forward(self, xs):
# Contextualization is only needed for posterior inference.
ctx = self.encoder(torch.flip(xs, dims=(0,)))
ctx = torch.flip(ctx, dims=(0,))
self.contextualize((self.ts, ctx))

qz0_mean, qz0_logstd = self.qz0_net(ctx[0]).chunk(chunks=2, dim=1)
z0 = qz0_mean + qz0_logstd.exp() * torch.randn_like(qz0_mean)
if not self.adjoint:
return z0, xs, self.ts, qz0_mean, qz0_logstd
else:
adjoint_params = (
(ctx,) +
tuple(self.f_net.parameters()) + tuple(self.g_nets.parameters()) + tuple(self.h_net.parameters())
)
return z0, xs, self.ts, qz0_mean, qz0_logstd, adjoint_params

class LatentSDE_Decoder(BaseSDESystem):
"""
Second part of the wrapper for torchsde's latent SDE example to integrate with NeuroMANCER. This takes in the output of
LatentSDEIntegrator and decodes it back into the "real" data space and also outputs associated Gaussian distributions
to be used in the final loss function.
Please see https://github.com/google-research/torchsde/blob/master/examples/latent_sde_lorenz.py

:param data_size: (int) state size of the data
:param latent_size: (int) input latent size for the encoder
:param noise_std: (float) standard deviation of the Gaussian noise applied during decoding
"""
def __init__(self, data_size, latent_size, noise_std):
super().__init__()
self.noise_std = noise_std
self.pz0_mean = nn.Parameter(torch.zeros(1, latent_size))
self.pz0_logstd = nn.Parameter(torch.zeros(1, latent_size))
self.projector = nn.Linear(latent_size, data_size)

def f(self, t, y):
pass #unused

def g(self, t, y):
pass #unused

def forward(self, xs, zs, log_ratio, qz0_mean, qz0_logstd):
_xs = self.projector(zs)
xs_dist = Normal(loc=_xs, scale=self.noise_std)
log_pxs = xs_dist.log_prob(xs).sum(dim=(0, 2)).mean(dim=0)

qz0 = torch.distributions.Normal(loc=qz0_mean, scale=qz0_logstd.exp())
pz0 = torch.distributions.Normal(loc=self.pz0_mean, scale=self.pz0_logstd.exp())
logqp0 = torch.distributions.kl_divergence(qz0, pz0).sum(dim=1).mean(dim=0)
logqp_path = log_ratio.sum(dim=0).mean(dim=0)
return _xs, log_pxs, logqp0 + logqp_path, log_ratio
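
# Illustrative end-to-end sketch (LatentSDEIntegrator lives in neuromancer.dynamics.integrators;
# kl_weight below is a hypothetical weighting term, everything else is defined in this PR):
#   encoder = LatentSDE_Encoder(data_size, latent_size, context_size, hidden_size, ts)
#   integrator = LatentSDEIntegrator(encoder, dt=1e-2)
#   decoder = LatentSDE_Decoder(data_size, latent_size, noise_std)
#   zs, z0, log_ratio, xs, qz0_mean, qz0_logstd = integrator.integrate(xs)
#   _xs, log_pxs, logqp, _ = decoder(xs, zs, log_ratio, qz0_mean, qz0_logstd)
#   loss = -log_pxs + kl_weight * logqp   # negative evidence lower bound (ELBO)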

"""
---------------------------------- Data Generation Classes, for forward pass only -------------------------------------------
"""
class StochasticLorenzAttractor(BaseSDESystem):
def __init__(self, a: Sequence = (10., 28., 8 / 3), b: Sequence = (.1, .28, .3)):
super().__init__()
self.a = a
self.b = b

def f(self, t, y):
x1, x2, x3 = torch.split(y, split_size_or_sections=(1, 1, 1), dim=1)
a1, a2, a3 = self.a

f1 = a1 * (x2 - x1)
f2 = a2 * x1 - x2 - x1 * x3
f3 = x1 * x2 - a3 * x3
return torch.cat([f1, f2, f3], dim=1)

def g(self, t, y):
x1, x2, x3 = torch.split(y, split_size_or_sections=(1, 1, 1), dim=1)
b1, b2, b3 = self.b

g1 = x1 * b1
g2 = x2 * b2
g3 = x3 * b3
return torch.cat([g1, g2, g3], dim=1)

@torch.no_grad()
def sample(self, x0, ts, noise_std, normalize):
"""Sample data for training. Store data normalization constants if necessary."""
xs = torchsde.sdeint(self, x0, ts)
if normalize:
mean, std = torch.mean(xs, dim=(0, 1)), torch.std(xs, dim=(0, 1))
xs.sub_(mean).div_(std).add_(torch.randn_like(xs) * noise_std)
return xs
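
# Illustrative usage for generating training data (a sketch; the values are arbitrary):
#   sde = StochasticLorenzAttractor()
#   x0 = torch.randn(1024, 3)         # (batch_size, state_size)
#   ts = torch.linspace(0., 2., 100)  # sample times
#   xs = sde.sample(x0, ts, noise_std=0.01, normalize=True)   # (t, batch_size, 3)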


class SDECoxIngersollRand(BaseSDESystem):
def __init__(self, alpha: float=0.1,
beta: float=0.05,
sigma: float=0.02):
super().__init__()
self.alpha = alpha
self.beta = beta
self.sigma = sigma

def f(self, t, y):
r = y
return self.alpha * (self.beta - r)

def g(self, t, y):
r = y
return self.sigma * torch.sqrt(torch.abs(r))


class SDEOrnsteinUhlenbeck(BaseSDESystem):
def __init__(self, theta: float = 0.1, sigma: float = 0.2):
super().__init__()
self.theta = theta
self.sigma = sigma

def f(self, t, y):
return -self.theta * y

def g(self, t, y):
return self.sigma * torch.ones_like(y)  # broadcast the constant diffusion to [batch size x state size] for diagonal noise


class LotkaVolterraSDE(BaseSDESystem):
def __init__(self, a, b, c, d, g_params):
super().__init__()
self.a = a
self.b = b
self.c = c
self.d = d
self.g_params = g_params

def f(self, t, x):
x1 = x[:,[0]]
x2 = x[:,[1]]
dx1 = self.a * x1 - self.b * x1*x2
dx2 = self.c * x1*x2 - self.d * x2
return torch.cat([dx1, dx2], dim=-1)

def g(self, t, x):
return self.g_params