<a href="https://colab.research.google.com/github/yuanwxu/mbpert/blob/main/mbpert.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
import torch
from torch import nn
import numpy as np
from scipy import linalg
from scipy.integrate import solve_ivp

In [None]:
# These constants are defined the same as in `scipy.integrate.solve_ivp`
SAFETY = 0.9  # Safety factor applied to the predicted optimal step-size change.
MIN_FACTOR = 0.2  # Minimum allowed decrease in a step size.
MAX_FACTOR = 10  # Maximum allowed increase in a step size.

class RK45():
  """Explicit Runge-Kutta method of order 5(4), written with PyTorch tensors.

  A port of the adaptive-step RK45 integrator used by `scipy.integrate.solve_ivp`
  (same Butcher tableau, error norm and step-size control). All arithmetic on
  the state stays in torch, so the final state returned by `solve` can be
  back-propagated through. Only forward integration (t_span[0] <= t_span[1]) is
  supported, and only the state at the end of the interval is returned.

  Args:
    fun: right-hand side, called as ``fun(t, y, *args)``; must return dy/dt
      as a tensor shaped like the 1-D state ``y``.
    t_span: ``(t0, tf)`` integration interval.
    args: extra positional arguments forwarded to ``fun`` (default: none).
    first_step: initial step size; must be positive.
    rtol, atol: relative and absolute tolerances (solve_ivp's defaults).

  Raises:
    ValueError: if ``tf < t0``.
  """

  # Butcher tableau
  # https://github.com/scipy/scipy/blob/b5d8bab88af61d61de09641243848df63380a67f/scipy/integrate/_ivp/rk.py#L280
  # Kept in float64 and cast to the state's dtype/device where used: float32
  # states get exactly the same coefficients as a float32 table would give,
  # float64 states get full-precision coefficients.
  n_stages = 6
  error_estimator_order = 4
  C = torch.tensor([0, 1/5, 3/10, 4/5, 8/9, 1], dtype=torch.float64)
  A = torch.tensor([
      [0, 0, 0, 0, 0],
      [1/5, 0, 0, 0, 0],
      [3/40, 9/40, 0, 0, 0],
      [44/45, -56/15, 32/9, 0, 0],
      [19372/6561, -25360/2187, 64448/6561, -212/729, 0],
      [9017/3168, -355/33, 46732/5247, 49/176, -5103/18656]
  ], dtype=torch.float64)
  B = torch.tensor([35/384, 0, 500/1113, 125/192, -2187/6784, 11/84],
                   dtype=torch.float64)
  E = torch.tensor([-71/57600, 0, 71/16695, -71/1920, 17253/339200, -22/525, 1/40],
                   dtype=torch.float64)

  def __init__(self, fun, t_span, args=(), first_step=1e-3, rtol=1e-3, atol=1e-6):
    # Beginning and end time points of integration
    self.t0, self.tf = map(float, t_span)
    if self.tf < self.t0:
      # Otherwise `solve` would silently return the initial state unchanged.
      raise ValueError(
          "RK45 only integrates forward in time: t_span[1] must be >= t_span[0].")

    # Wrap the user function in a lambda to hide the additional arguments
    # (same trick as in the source code of scipy's `solve_ivp`).
    self.fun = lambda t, x, fun=fun: fun(t, x, *args)

    # Relative and absolute tolerance as in solve_ivp
    self.rtol, self.atol = rtol, atol

    # Initial step size; the direction of integration is always positive (h > 0)
    self.h0 = first_step

    # Exponent of the error norm in the step-size update
    # https://en.wikipedia.org/wiki/Adaptive_step_size#Embedded_error_estimates
    self.error_exponent = -1 / (self.error_estimator_order + 1)


  # This is a rewrite of `rk_step` in scipy.integrate.solve_ivp source code using Pytorch tensors
  def _step_rk(self, t, y, f, h, K):
    """Perform a single Runge-Kutta step.
    Parameters as in https://github.com/scipy/scipy/blob/b5d8bab88af61d61de09641243848df63380a67f/scipy/integrate/_ivp/rk.py#L14

    ``K`` (shape (n_stages + 1, n)) is overwritten in place: rows 0..n_stages-1
    receive the stage derivatives and the last row receives ``f_new``.

    Returns
    -------
    y_new : tensor, shape (n,)
        Solution at t + h computed with a higher accuracy.
    f_new : tensor, shape (n,)
        Derivative ``fun(t + h, y_new)``.
    """
    # Match the float64 tableau to the state's dtype/device.
    A, B, C = RK45.A.to(K), RK45.B.to(K), RK45.C.to(K)

    K[0] = f
    for s, (a, c) in enumerate(zip(A[1:], C[1:]), start=1):
      dy = torch.matmul(torch.t(K[:s]), a[:s]) * h
      K[s] = self.fun(t + c * h, y + dy)

    y_new = y + h * torch.matmul(torch.t(K[:-1]), B)
    f_new = self.fun(t + h, y_new)

    K[-1] = f_new

    return y_new, f_new

  # These two methods are equivalent to those in scipy `solve_ivp`
  def _estimate_error(self, K, h):
    """Embedded error estimate of the last step (difference of the 5th/4th order solutions)."""
    return torch.matmul(torch.t(K), self.E.to(K)) * h

  # NOTE: the `.item()` following `norm()` is essential, omitting it will get
  # runtime error "inplace modification of variables requiring gradient" that
  # is hard to debug!
  def _estimate_error_norm(self, K, h, scale):
    """RMS norm of the scaled error estimate, returned as a Python float."""
    x = self._estimate_error(K, h) / scale
    return torch.linalg.norm(x).item() / x.numel() ** 0.5


  # This is a rewrite of `_step_impl` using Pytorch tensors
  def _step_impl(self):
    """Advance (t, y, f) by one accepted adaptive step and update the step size.

    Raises:
      RuntimeError: if the step size shrinks below the smallest step that can
        still advance ``t`` (the same bound scipy uses). This happens e.g. when
        the solution blows up to inf/NaN, so that every trial step is rejected.
    """
    # Current time, state and derivative
    t, y, f = self.t, self.y, self.f

    rtol = self.rtol
    atol = self.atol
    h = self.h # current step size

    # Same lower bound on the step size as scipy (10 ulps of t). Without it a
    # NaN/inf error norm (always rejected) would shrink h forever, and an
    # underflowed h would never advance t.
    min_step = 10 * abs(np.nextafter(t, np.inf) - t)

    step_accepted = False
    step_rejected = False

    while not step_accepted:
      if h < min_step:
        raise RuntimeError(
            f"RK45: step size {h} fell below the minimum {min_step} at t={t}; "
            "the solution may have diverged (inf/NaN) or the problem is too stiff.")

      t_new = t + h

      # Never step past the end of the integration interval
      if (t_new - self.tf) > 0:
        t_new = self.tf

      h = t_new - t

      y_new, f_new = self._step_rk(t, y, f, h, self.K)
      scale = atol + torch.maximum(torch.abs(y), torch.abs(y_new)) * rtol
      error_norm = self._estimate_error_norm(self.K, h, scale)

      # A NaN error norm fails `< 1` and is therefore treated as a rejection.
      if error_norm < 1:
        if error_norm == 0:
          factor = MAX_FACTOR
        else:
          factor = min(MAX_FACTOR, SAFETY * error_norm ** self.error_exponent)

        # Do not grow the step right after a rejection
        if step_rejected:
          factor = min(1, factor)

        h *= factor

        step_accepted = True
      else:
        h *= max(MIN_FACTOR, SAFETY * error_norm ** self.error_exponent)
        step_rejected = True

    # Update time, state, derivative and step size
    self.t, self.y, self.f, self.h = t_new, y_new, f_new, h


  def solve(self, y):
    """Integrate from the 1-D state `y` at t0 to tf and return the state at tf.

    Prints a diagnostic when the mean absolute derivative at tf exceeds 1e-3,
    i.e. when the trajectory has likely not settled to a steady state.
    """
    # Initial time, step size, state and derivative dy/dt = f(t,y)
    self.t = self.t0
    self.h = self.h0
    self.y = y
    self.f = self.fun(self.t0, self.y)

    # Storage for RK stages, on the same dtype and device as the state
    self.K = torch.empty((RK45.n_stages + 1, y.numel()), dtype=y.dtype, device=y.device)

    while self.t < self.tf:
      self._step_impl()

    mean_abs_f = self.f.abs().mean().item()
    if mean_abs_f > 1e-3:
      print(f'Mean absolute derivative at t={self.t}: {mean_abs_f}')

    return self.y

  # TODO: optionally stop early when the prediction is unbounded, e.g. when
  # y_new contains NaNs, |y_new| > 1e6 or |f_new| > 1e2, instead of
  # integrating the whole range.

      


In [None]:
# A helper to reshape a tensor in Fortran-like order
# Reference: https://stackoverflow.com/questions/63960352/reshaping-order-in-pytorch-fortran-like-index-ordering
def reshape_fortran(x, shape):
    """Return `x` reshaped to `shape` in column-major (Fortran) order.

    Equivalent to NumPy's ``reshape(..., order='F')``: reverse the axes of `x`,
    do an ordinary row-major reshape to the reversed target shape, then reverse
    the axes of the result. `shape` may contain one -1 as in `Tensor.reshape`.
    """
    # A 0-d tensor has no axes to reverse, so it is reshaped as is.
    src = x.permute(*range(x.dim() - 1, -1, -1)) if x.dim() > 0 else x
    flipped = src.reshape(*tuple(shape)[::-1])
    return flipped.permute(*range(len(shape) - 1, -1, -1))

In [None]:
def glvp(t, x, r, A, eps, p):
    """Right-hand side of the generalized Lotka-Volterra system with perturbations.

    All perturbation conditions are integrated at once by stacking them into
    one long state vector.

    t   --- time; unused here, kept for the solver's ``fun(t, x, ...)`` signature
    x   --- (n_species*n_conds,) species (dimensionless) absolute abundances
            under all pert conditions: contiguous blocks of species abundances,
            one block per condition
    r   --- (n_species,) growth rates
    A   --- (n_species, n_species) species interaction matrix
    eps --- (n_species,) species susceptibility to perturbation
    p   --- (n_species, n_conds) perturbation matrix

    Returns dx/dt, flattened in the same block layout as ``x``.
    """
    # Unstack the long vector into a matrix with one column per condition.
    X = reshape_fortran(x, p.shape)
    per_capita = r.unsqueeze(1) + A @ X + eps.unsqueeze(1) * p
    dXdt = X * per_capita
    return reshape_fortran(dXdt, (-1,))

In [None]:
# numpy version
def glvp2(t, x, r, A, eps, p):
    """NumPy counterpart of `glvp`, usable with ``scipy.integrate.solve_ivp``.

    The long state vector ``x`` of shape (n_species*n_conds,) holds one
    contiguous block of species abundances per condition. It is unpacked in
    column-major order into an (n_species, n_conds) matrix matching ``p``, and
    dx/dt is returned flattened the same way. ``t`` is unused.
    """
    X = x.reshape(p.shape, order='F')
    per_capita = r[:, None] + A @ X + eps[:, None] * p
    return (X * per_capita).ravel(order='F')

In [None]:
class MBPert(nn.Module):
  """gLV-with-perturbations model fitted through its integrated state.

  Learnable parameters: growth rates ``r`` (n_species,), interaction matrix
  ``A`` (n_species, n_species) and perturbation susceptibilities ``eps``
  (n_species,). ``forward(x0)`` integrates `glvp` from the long state vector
  ``x0`` over ``t_span`` with the differentiable `RK45` solver and returns the
  state at ``t_span[1]``.
  """

  def __init__(self, U, t_span=(0, 20)):
    """U (torch.Tensor): perturbation matrix of shape (n_species, n_conds)
    t_span: integration interval for `forward` (default (0, 20)).
    """
    super().__init__()
    n_species = U.shape[0]
    self.r = nn.Parameter(torch.rand((n_species, )))
    self.eps = nn.Parameter(torch.randn(n_species, ))

    # Proper initialization of interaction matrix for stability:
    # off-diagonal entries ~ N(0, 1/(4n)), diagonal fixed at -1.
    A_init = 1 / (2 * n_species**(0.5)) * torch.randn(n_species, n_species)
    self.A = nn.Parameter(A_init.fill_diagonal_(-1))

    # Non-persistent buffer: moves with `.to(device)` / `.double()` together
    # with the parameters, but is not added to the state_dict (keys stay
    # r, eps, A).
    self.register_buffer('p', U, persistent=False)

    # The bound method `_rhs` looks r, A, eps and p up at call time, so the
    # solver always uses the module's current tensors (e.g. after `.to()`).
    self.solver = RK45(self._rhs, t_span, args=())

  def _rhs(self, t, x):
    """Right-hand side dx/dt for the solver, built from the current module tensors."""
    return glvp(t, x, self.r, self.A, self.eps, self.p)

  def forward(self, x):
    """Integrate from the initial long state vector `x`; return the final state."""
    return self.solver.solve(x)
    

**Test on toy example**

In [None]:
def get_ode_params(n_species, p, seed=None, max_tries=500):
    """Get ODE parameters suited for simulation.

    Parameters are redrawn until the steady states under all perturbation
    conditions are positive, or until ``max_tries`` draws have been made.

    Args:
        n_species (int): number of species
        p (nparray): perturbation matrix of shape (n_species, n_conds), e.g.
            as returned from pert_mat()
        seed: seed for ``np.random.default_rng``; fix it for reproducible draws.
        max_tries (int): maximum number of draws (default 500). If none of them
            gives an all-positive steady state, a message is printed and the
            last draw is returned.

    Returns:
        (growth rate r, interaction matrix A, susceptibility vector eps,
         steady state solutions across all pert conditions), with shapes
        (n_species,), (n_species, n_species), (n_species,) and
        (n_species, n_conds).

    Raises:
        ValueError: if ``p.shape[0] != n_species`` or ``max_tries < 1``.
    """
    if n_species != p.shape[0]:
        raise ValueError(
            "Number of species does not match first dimension of pert matrix.")
    if max_tries < 1:
        raise ValueError("max_tries must be at least 1.")

    rng = np.random.default_rng(seed)
    for i in range(1, max_tries + 1):
        # Diagonal of A: a_{ii} = -1, off-diag a_{ij} ~ N(0,1/(4n))
        A = rng.normal(0, 1 / (2 * n_species**(0.5)), (n_species, n_species))
        np.fill_diagonal(A, -1)

        # r: Unif(0, 1)
        r = rng.random((n_species, ))

        # eps: Unif(-0.2,1)
        eps = rng.uniform(-0.2, 1, (n_species, ))

        # Steady state across all pert conditions: 0 = r + A x + eps * p, i.e.
        # x = -A^{-1} (r + eps * p). Solve the linear system instead of
        # forming the inverse explicitly (more accurate and cheaper).
        X_ss = -linalg.solve(A, r[:, np.newaxis] + eps[:, np.newaxis] * p)

        # Check all solutions are positive
        if np.all(X_ss > 0):
            break
    else:
        print(
            f'''Failed to find an all positive steady-state solutions across all 
                perturbation conditions after {i} attempts. Return a steady state 
                solution with negative entries.''')

    return (r, A, eps, X_ss)  # TODO: consider returning a namedtuple

In [None]:
# Toy problem: 3 species under 3 single-species perturbations (identity p).
# x0 is the long state vector (one contiguous block of species per condition,
# column-major flattening) with every abundance starting at 0.2.
x0 = reshape_fortran(0.2 * torch.ones((3,3)), (-1,)) # 3 species, 3 conditions
p = torch.eye(3) # 3 single-species perturbations
# Ground-truth parameters and steady states; X_ss is flattened in Fortran
# order so that it has the same block layout as x0.
r, A, eps, X_ss = get_ode_params(3, p.numpy(), seed=0)
x_true = torch.from_numpy(X_ss.astype('float32').reshape(-1, order='F'))

In [None]:
# Inspect the toy inputs and the true (flattened) steady state
print(f'x0: {x0}\n')
print(f'p: {p}\n')
print(f'x_true: {x_true}\n')

x0: tensor([0.2000, 0.2000, 0.2000, 0.2000, 0.2000, 0.2000, 0.2000, 0.2000, 0.2000])

p: tensor([[1., 0., 0.],
        [0., 1., 0.],
        [0., 0., 1.]])

x_true: tensor([1.9095, 0.9769, 0.9886, 1.0155, 0.7485, 0.5896, 1.1531, 0.9953, 1.3845])



In [None]:
# Model with randomly initialized r, A, eps; trained with MSE between the
# integrated state and the true steady states.
mbpert = MBPert(p)
criterion = torch.nn.MSELoss()
optimizer = torch.optim.Adam(mbpert.parameters()) # Adam is much better than SGD in this case
#optimizer = torch.optim.SGD(mbpert.parameters(), lr=1e-3)

In [None]:
# Training: each forward pass integrates the gLV system from x0 with RK45, and
# the loss is back-propagated through all solver steps. Progress is printed
# every 50 iterations.
for i in range(400):
    # Forward pass
    x_pred = mbpert(x0)

    # Compute and print loss
    loss = criterion(x_pred, x_true)
    if i % 50 == 49:
        print(f"i={i}, loss={loss.item()}, x_pred={x_pred}") 

    # Zero gradients, perform a backward pass, and update parameters
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()

ODE solver did not converge to steady state solution in the given time range.
i=49, loss=0.3762568235397339, x_pred=tensor([1.1089e+00, 2.8623e-03, 9.3774e-01, 1.1826e+00, 1.9877e-06, 9.7273e-01,
        1.3347e+00, 9.8607e-04, 1.1906e+00], grad_fn=<AddBackward0>)
ODE solver did not converge to steady state solution in the given time range.
ODE solver did not converge to steady state solution in the given time range.
ODE solver did not converge to steady state solution in the given time range.
ODE solver did not converge to steady state solution in the given time range.
ODE solver did not converge to steady state solution in the given time range.
ODE solver did not converge to steady state solution in the given time range.
ODE solver did not converge to steady state solution in the given time range.
ODE solver did not converge to steady state solution in the given time range.
ODE solver did not converge to steady state solution in the given time range.
ODE solver did not converge to st

In [None]:
# Compare the learned parameters with the ground truth used to generate x_true
print(x_true)
print(f'Estimated A: {mbpert.A}\nTrue: {A}\n')
print(f'Estimated r:{mbpert.r}\nTrue: {r}\n')
print(f'Estimated eps: {mbpert.eps}\nTrue: {eps}\n')

tensor([1.9095, 0.9769, 0.9886, 1.0155, 0.7485, 0.5896, 1.1531, 0.9953, 1.3845])
Estimated A: Parameter containing:
tensor([[-0.6572, -0.0185,  0.2996],
        [ 0.1113, -1.2442,  0.4821],
        [ 0.2032, -0.1336, -1.0052]], requires_grad=True)
True: [[-1.         -0.03813539  0.18487409]
 [ 0.03028206 -1.          0.1043835 ]
 [ 0.37643239  0.27339872 -1.        ]]

Estimated r:Parameter containing:
tensor([0.3427, 0.3189, 0.5595], requires_grad=True)
True: [0.93507242 0.81585355 0.0027385 ]

Estimated eps: Parameter containing:
tensor([0.7611, 0.8222, 0.8329], requires_grad=True)
True: [ 0.82888513 -0.15969731  0.67558654]



**Test ODE solver**

In [None]:
# r comes from NumPy (float64) while p is a float32 tensor, hence the float32
# casts when building the torch solver below
r.dtype, p.dtype

(dtype('float64'), torch.float32)

In [None]:
# Build the torch solver with the true parameters (cast to float32 to match p
# and x0) so its result can be checked against scipy's solve_ivp
r, A, eps, X_ss = get_ode_params(3, p.numpy(), seed=0)
solver = RK45(glvp, [0,20], args=(torch.from_numpy(r.astype('float32')), torch.from_numpy(A.astype('float32')), torch.from_numpy(eps.astype('float32')), p))

In [None]:
# Integrate from x0 over [0, 20]
x_pred = solver.solve(x0)

In [None]:
# Torch solution at t=20
x_pred

tensor([1.9110, 0.9769, 0.9880, 1.0154, 0.7484, 0.5894, 1.1531, 0.9953, 1.3845])

In [None]:
# Now use scipy's solve_ivp
# Reference solution with scipy and the NumPy right-hand side (float64 params)
sol = solve_ivp(glvp2, [0, 20], x0.numpy(), args=(r, A, eps, p.numpy()))

In [None]:
# State at the last time point returned by scipy
sol.y[:,-1] # should be similar to x_pred

array([1.91097956, 0.97688841, 0.98805588, 1.01544636, 0.74836967,
       0.58935627, 1.15307287, 0.99528851, 1.38448868])

**General testing**

In [None]:
# A Parameter wraps a tensor and sets requires_grad=True
torch.nn.Parameter(torch.randn((2,2)))

Parameter containing:
tensor([[-0.9155,  0.6164],
        [ 0.1008,  0.7246]], requires_grad=True)

In [None]:
# NOTE: this rebinds A, which previously held the true NumPy interaction matrix
A = nn.Parameter(torch.randn(3, 3))

In [None]:
# Total number of elements (3 * 3)
A.numel()

9

In [None]:
# Set the diagonal to -1 in place through `.data`, i.e. outside autograd tracking
A.data.fill_diagonal_(-1)

tensor([[-1.0000,  1.3261,  1.2774],
        [-0.9245, -1.0000, -0.0726],
        [-1.1760,  0.0340, -1.0000]])

In [None]:
# dtype of the Parameter created from torch.randn
A.dtype

torch.float32

In [None]:
# Broadcasting checks (as used by r[:, None] / eps[:, None] in glvp)
a = torch.tensor([1,1,1])
a

tensor([1, 1, 1])

In [None]:
# Indexing with None adds a trailing axis: (3,) -> (3, 1)
a[:,None].shape

torch.Size([3, 1])

In [None]:
# 3 x 5 matrix to broadcast against
b = torch.arange(1,16).reshape(3,5)

In [None]:
# Show b
b

tensor([[ 1,  2,  3,  4,  5],
        [ 6,  7,  8,  9, 10],
        [11, 12, 13, 14, 15]])

In [None]:
# The (3, 1) column is added to every column of b
a[:,None] + b

tensor([[ 2,  3,  4,  5,  6],
        [ 7,  8,  9, 10, 11],
        [12, 13, 14, 15, 16]])

In [None]:
# Row i of b is scaled by a[i] + 1
(a[:,None] + 1) * b

tensor([[ 2,  4,  6,  8, 10],
        [12, 14, 16, 18, 20],
        [22, 24, 26, 28, 30]])

In [None]:
# inf passes an |y| > 1e6 "unbounded prediction" check
torch.abs(torch.tensor(torch.inf)) > 1e6

tensor(True)

In [None]:
# View semantics of torch.t: start from a random 3 x 3 tensor
a = torch.rand(3,3)

In [None]:
# Show a
a

tensor([[0.4725, 0.2918, 0.8230],
        [0.3189, 0.0867, 0.9742],
        [0.6353, 0.7285, 0.5710]])

In [None]:
# torch.t returns a transposed view of a (no copy)
b = torch.t(a)

In [None]:
# Modify a in place
a[0,0] = 99
a

tensor([[9.9000e+01, 2.9181e-01, 8.2298e-01],
        [3.1887e-01, 8.6709e-02, 9.7421e-01],
        [6.3533e-01, 7.2852e-01, 5.7097e-01]])

In [None]:
# The in-place change to a is visible through the view b
b

tensor([[9.9000e+01, 3.1887e-01, 6.3533e-01],
        [2.9181e-01, 8.6709e-02, 7.2852e-01],
        [8.2298e-01, 9.7421e-01, 5.7097e-01]])