<a href="https://colab.research.google.com/github/yuanwxu/mbpert/blob/main/mbpert.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
import torch
from torch import nn
import numpy as np
from scipy import linalg
from scipy.integrate import solve_ivp

In [None]:
class RK45():
  """Fixed-step explicit Runge-Kutta integrator of order 5(4), written with
  PyTorch tensors so that the solution stays differentiable.

  Only the higher-order solution of the pair is propagated; no error estimate
  or adaptive step-size control is performed.

  Parameters
  ----------
  fun : callable
      Right-hand side, called as ``fun(t, y, *args)``; must return a tensor
      with the same number of elements as ``y``.
  t_span : 2-sequence
      ``(t0, tf)`` integration interval; both entries are converted to float.
  step_size : float
      Positive step size ``h``.
  args : tuple, optional
      Extra positional arguments forwarded to ``fun``.

  Raises
  ------
  ValueError
      If ``step_size`` is not strictly positive.
  """

  # Butcher tableau
  # https://github.com/scipy/scipy/blob/b5d8bab88af61d61de09641243848df63380a67f/scipy/integrate/_ivp/rk.py#L280
  # Stored in float64 and cast to the dtype/device of the state at use, so a
  # float32 state sees the same rounded coefficients as before while a float64
  # state keeps full precision.
  n_stages = 6
  C = torch.tensor([0, 1/5, 3/10, 4/5, 8/9, 1], dtype=torch.float64)
  A = torch.tensor([
      [0, 0, 0, 0, 0],
      [1/5, 0, 0, 0, 0],
      [3/40, 9/40, 0, 0, 0],
      [44/45, -56/15, 32/9, 0, 0],
      [19372/6561, -25360/2187, 64448/6561, -212/729, 0],
      [9017/3168, -355/33, 46732/5247, 49/176, -5103/18656]
  ], dtype=torch.float64)
  B = torch.tensor([35/384, 0, 500/1113, 125/192, -2187/6784, 11/84],
                   dtype=torch.float64)

  def __init__(self, fun, t_span, step_size, args=()):
    self.t0, self.tf = map(float, t_span)

    # A zero or negative step would never advance time towards `tf`.
    if not step_size > 0:
      raise ValueError("`step_size` must be positive.")

    args = () if args is None else tuple(args)

    # This line is from the source code of scipy's `solve_ivp`. It wraps the user
    # function in a lambda to hide additional arguments.
    self.fun = lambda t, x, fun=fun: fun(t, x, *args)

    self.h = step_size


  # This is a rewrite of `rk_step` in scipy.integrate.solve_ivp source code using Pytorch tensors
  def _step(self, t, y, f, K, h=None):
    """Perform a single Runge-Kutta step.
    Parameters as in https://github.com/scipy/scipy/blob/b5d8bab88af61d61de09641243848df63380a67f/scipy/integrate/_ivp/rk.py#L14

    Parameters
    ----------
    t : float
        Current time.
    y : tensor, shape (n,)
        Current state.
    f : tensor, shape (n,)
        Derivative ``fun(t, y)`` at the current state.
    K : tensor, shape (n_stages + 1, n)
        Work array for the stage derivatives; overwritten in place.
    h : float, optional
        Step size for this step. Defaults to ``self.h``.

    Returns
    -------
    y_new : tensor, shape (n,)
        Solution at t + h computed with a higher accuracy.
    f_new : tensor, shape (n,)
        Derivative ``fun(t + h, y_new)``.
    """
    if h is None:
      h = self.h

    # Match the tableau to the work array so float64 / non-CPU states work.
    A = RK45.A.to(dtype=K.dtype, device=K.device)
    B = RK45.B.to(dtype=K.dtype, device=K.device)
    C = RK45.C.to(dtype=K.dtype, device=K.device)

    K[0] = f
    for s, (a, c) in enumerate(zip(A[1:], C[1:]), start=1):
      # Weighted sum of the stages computed so far.
      dy = torch.matmul(K[:s].T, a[:s]) * h
      K[s] = self.fun(t + c * h, y + dy)

    y_new = y + h * torch.matmul(K[:-1].T, B)
    f_new = self.fun(t + h, y_new)

    K[-1] = f_new

    return y_new, f_new


  def solve(self, y):
    """Integrate from ``t0`` to ``tf`` starting at state ``y``.

    The interval is covered by whole steps of size ``self.h``; if it is not a
    multiple of ``self.h`` a final shorter step is taken so that integration
    ends at ``tf`` rather than overshooting it.

    Integration stops early (with a printed message) when a step produces
    NaN, a state entry above 1e6 in magnitude, or a derivative entry above
    1e2 in magnitude; the last accepted state is returned in that case.

    Parameters
    ----------
    y : tensor, shape (n,)
        Initial state.

    Returns
    -------
    tensor, shape (n,)
        State at the end of the integration (or at the early stop).
    """
    # Storage for RK stages
    K = torch.empty((RK45.n_stages + 1, y.numel()), dtype=y.dtype,
                    device=y.device)

    # Step schedule: whole steps plus an optional shortened last step.
    # The small tolerances absorb floating-point round-off in span / h.
    span = self.tf - self.t0
    n_full = int(span / self.h + 1e-9) if span > 0 else 0
    steps = [self.h] * n_full
    remainder = span - n_full * self.h
    if remainder > 1e-9 * self.h:
      steps.append(remainder)

    # Initial time and derivative
    t = self.t0
    f = self.fun(t, y)

    for i, h in enumerate(steps):
      y_new, f_new = self._step(t, y, f, K, h)

      # Stop early when the prediction is unbounded
      if torch.any(y_new.isnan() |
                   (y_new.abs() > 1e6) |
                   (f_new.abs() > 1e2)):
        print("Unbounded function prediction, stop early so not integarating the whole range.")
        break

      # Recompute time from the step index to avoid accumulating round-off.
      t = min(self.t0 + (i + 1) * self.h, self.tf)
      y = y_new
      f = f_new

    return y


In [None]:
# A helper to reshape a tensor in Fortran-like order
# Reference: https://stackoverflow.com/questions/63960352/reshaping-order-in-pytorch-fortran-like-index-ordering
def reshape_fortran(x, shape):
    """Reshape tensor `x` to `shape` using Fortran-like (column-major) ordering.

    The axes of `x` are reversed, the result is reshaped to the reversed
    target shape, and the axes are reversed once more.
    """
    n_in = x.dim()
    if n_in > 0:
        x = x.permute(tuple(range(n_in - 1, -1, -1)))

    flipped_shape = tuple(shape)[::-1]
    flipped_axes = tuple(range(len(flipped_shape) - 1, -1, -1))
    return x.reshape(flipped_shape).permute(flipped_axes)

In [None]:
def glvp(t, x, r, A, eps, p):
    """Right-hand side of the generalized Lotka-Volterra system with perturbations.

    All perturbation conditions are evaluated at once: the flat state vector
    is unfolded (column-major) into a matrix shaped like `p`, the dynamics are
    applied to the whole matrix, and the result is flattened back in the same
    order.

       t --- Time (unused; present for the ODE-solver call signature)
       x --- (n_species*n_conds,) Species (dimensionless) absolute abundances
                                  under all pert conditions, stored as
                                  contiguous blocks, one block per condition
       r --- (n_species,) Growth rate
       A --- (n_species, n_species) Species interaction matrix
       eps --- (n_species,) Species susceptibility to perturbation
       p --- Perturbation matrix; its shape is used to unfold `x`

    Returns the time derivative of `x`, flattened like `x`.
    """
    abundances = reshape_fortran(x, p.shape)
    per_capita_rate = r[:, None] + A @ abundances + eps[:, None] * p
    return reshape_fortran(abundances * per_capita_rate, (-1,))

In [None]:
# numpy version
def glvp2(t, x, r, A, eps, p):
    """NumPy right-hand side of the perturbed generalized Lotka-Volterra system.

    To vectorize over conditions, `x` is a long state vector of shape
    (n_species*n_conds,) made of contiguous blocks of species abundances,
    one block per condition. It is unfolded column-major into the shape of
    `p`, and the derivative is flattened back the same way.
    """
    abundances = x.reshape(p.shape, order='F')
    per_capita_rate = r[:, np.newaxis] + A @ abundances + eps[:, np.newaxis] * p
    derivative = abundances * per_capita_rate
    return derivative.ravel(order='F')

In [None]:
class MBPert(nn.Module):
  """Learnable generalized Lotka-Volterra model with perturbations.

  Holds the growth rates `r`, interaction matrix `A` and susceptibilities
  `eps` as trainable parameters, and integrates `glvp` with `RK45` over
  t in [0, 20] with step size 0.1 to produce the predicted state.
  """

  def __init__(self, U):
    """U (torch.Tensor): perturbation matrix of shape (n_species, n_conds)"""
    super().__init__()
    n_species = U.shape[0]

    # Growth rates ~ Unif(0, 1), susceptibilities ~ N(0, 1)
    self.r = nn.Parameter(torch.rand((n_species, )))
    self.eps = nn.Parameter(torch.randn(n_species, ))

    # Proper initialization of interaction matrix for stability:
    # small random off-diagonal entries and -1 on the diagonal.
    off_diag_scale = 1 / (2 * n_species**(0.5))
    interactions = off_diag_scale * torch.randn(n_species, n_species)
    interactions.fill_diagonal_(-1)
    self.A = nn.Parameter(interactions)

    # Kept as a plain attribute (not a parameter or registered buffer).
    self.p = U
    self.solver = RK45(
        fun=glvp,
        t_span=[0, 20],
        step_size=1e-1,
        args=(self.r, self.A, self.eps, self.p),
    )

  def forward(self, x):
    """Integrate the system from initial state `x` and return the final state."""
    return self.solver.solve(x)
    

**Test on toy example**

In [None]:
def get_ode_params(n_species, p, seed=None, max_attempts=500):
    """Get ODE parameters suited for simulation.

    Parameters are redrawn until the steady-state solution is strictly
    positive for every perturbation condition, or until `max_attempts` draws
    have been made.

    Args:
        n_species (int): number of species
        p (nparray): perturbation matrix returned from pert_mat(), with
            n_species rows
        seed: seed passed to `np.random.default_rng`
        max_attempts (int): maximum number of parameter draws. At least one
            draw is always made.

    Returns:
        (growth rate r, interaction matrix A, susceptibility vector eps,
         steady state solutions across all pert conditions)

        If no all-positive steady state was found, a message is printed and
        the last draw (whose steady state has non-positive entries) is
        returned.

    Raises:
        ValueError: if `n_species` differs from the first dimension of `p`.
    """
    if n_species != p.shape[0]:
        raise ValueError(
            "Number of species does not match first dimension of pert matrix.")

    rng = np.random.default_rng(seed)
    off_diag_sd = 1 / (2 * n_species**(0.5))

    for i in range(1, max(int(max_attempts), 1) + 1):
        # Diagonal of A: a_{ii} = -1, off-diag a_{ij} ~ N(0,1/(4n))
        A = rng.normal(0, off_diag_sd, (n_species, n_species))
        np.fill_diagonal(A, -1)

        # r: Unif(0, 1)
        r = rng.random((n_species, ))

        # eps: Unif(-0.2,1)
        eps = rng.uniform(-0.2, 1, (n_species, ))

        # Steady state solution across all pert conditions
        X_ss = -linalg.inv(A) @ (r[:, np.newaxis] + eps[:, np.newaxis] * p)

        # Accept the draw once all solutions are positive
        if np.all(X_ss > 0):
            break
    else:
        # Loop ran out of attempts without finding a positive steady state.
        print(
            f'''Failed to find an all positive steady-state solutions across all 
                perturbation conditions after {i} attempts. Return a steady state 
                solution with negative entries.''')

    return (r, A, eps, X_ss)

In [None]:
# Toy problem: every species starts at abundance 0.2 in every condition.
# The state is stored as one flat column-major vector (one block per condition).
x0 = reshape_fortran(0.2 * torch.ones((3,3)), (-1,)) # 3 species, 3 conditions
p = torch.eye(3) # 3 single-species perturbations
# Ground-truth parameters and their analytic steady states (NumPy, seeded).
r, A, eps, X_ss = get_ode_params(3, p.numpy(), seed=0)
# Flatten the steady states column-major so they line up with x0's layout.
x_true = torch.from_numpy(X_ss.astype('float32').reshape(-1, order='F'))

In [None]:
# Inspect the initial state, perturbation matrix and target steady state.
print(f'x0: {x0}\n')
print(f'p: {p}\n')
print(f'x_true: {x_true}\n')

x0: tensor([0.2000, 0.2000, 0.2000, 0.2000, 0.2000, 0.2000, 0.2000, 0.2000, 0.2000])

p: tensor([[1., 0., 0.],
        [0., 1., 0.],
        [0., 0., 1.]])

x_true: tensor([1.9095, 0.9769, 0.9886, 1.0155, 0.7485, 0.5896, 1.1531, 0.9953, 1.3845])



In [None]:
# Model with randomly initialised r, A, eps; fitted by mean squared error
# between the integrated end state and the target steady state.
mbpert = MBPert(p)
criterion = torch.nn.MSELoss()
optimizer = torch.optim.Adam(mbpert.parameters()) # Adam is much better than SGD in this case
#optimizer = torch.optim.SGD(mbpert.parameters(), lr=1e-3)

In [None]:
# Fit the model: each iteration integrates the ODE from x0, compares the end
# state with x_true, and backpropagates through the solver steps.
for i in range(400):
    # Forward pass
    x_pred = mbpert(x0)

    # Compute and print loss (every 50th iteration)
    loss = criterion(x_pred, x_true)
    if i % 50 == 49:
        print(f"i={i}, loss={loss.item()}, x_pred={x_pred}") 

    # Zero gradients, perform a backward pass, and update parameters
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()

i=49, loss=0.6304702162742615, x_pred=tensor([0.1588, 0.9540, 0.5373, 0.1130, 0.1930, 0.7643, 0.0784, 0.9718, 1.0748],
       grad_fn=<AddBackward0>)
i=99, loss=0.4330136477947235, x_pred=tensor([0.4588, 1.0868, 0.4836, 0.2853, 0.3504, 0.7058, 0.3092, 1.1315, 1.0851],
       grad_fn=<AddBackward0>)
i=149, loss=0.2534507215023041, x_pred=tensor([0.7812, 1.2150, 0.5827, 0.4884, 0.5018, 0.7762, 0.5908, 1.2747, 1.2474],
       grad_fn=<AddBackward0>)
i=199, loss=0.1249798834323883, x_pred=tensor([1.1154, 1.2279, 0.6764, 0.7104, 0.5682, 0.8224, 0.8894, 1.2833, 1.3458],
       grad_fn=<AddBackward0>)
i=249, loss=0.05721984803676605, x_pred=tensor([1.3982, 1.1691, 0.7443, 0.8972, 0.5695, 0.8403, 1.1294, 1.2131, 1.3793],
       grad_fn=<AddBackward0>)
i=299, loss=0.03296193107962608, x_pred=tensor([1.5817, 1.1019, 0.7854, 1.0121, 0.5533, 0.8387, 1.2646, 1.1358, 1.3787],
       grad_fn=<AddBackward0>)
i=349, loss=0.02518528141081333, x_pred=tensor([1.6744, 1.0637, 0.8109, 1.0621, 0.5504, 0.8296

In [None]:
# Compare the fitted parameters with the ground truth used to build x_true.
print(x_true)
print(f'Estimated A: {mbpert.A}\nTrue: {A}\n')
print(f'Estimated r:{mbpert.r}\nTrue: {r}\n')
print(f'Estimated eps: {mbpert.eps}\nTrue: {eps}\n')

tensor([1.9095, 0.9769, 0.9886, 1.0155, 0.7485, 0.5896, 1.1531, 0.9953, 1.3845])
Estimated A: Parameter containing:
tensor([[-0.6572, -0.0185,  0.2996],
        [ 0.1113, -1.2442,  0.4821],
        [ 0.2032, -0.1336, -1.0052]], requires_grad=True)
True: [[-1.         -0.03813539  0.18487409]
 [ 0.03028206 -1.          0.1043835 ]
 [ 0.37643239  0.27339872 -1.        ]]

Estimated r:Parameter containing:
tensor([0.3427, 0.3189, 0.5595], requires_grad=True)
True: [0.93507242 0.81585355 0.0027385 ]

Estimated eps: Parameter containing:
tensor([0.7611, 0.8222, 0.8329], requires_grad=True)
True: [ 0.82888513 -0.15969731  0.67558654]



**Test ODE solver**

In [None]:
# r comes from NumPy (get_ode_params) while p is a torch tensor: compare dtypes.
r.dtype, p.dtype

(dtype('float64'), torch.float32)

In [None]:
# Build a solver around the *true* parameters (cast to float32 tensors to match
# p and x0) with a finer step of 1e-2, to check the integrator on its own.
r, A, eps, X_ss = get_ode_params(3, p.numpy(), seed=0)
solver = RK45(glvp, [0,20], 1e-2, args=(torch.from_numpy(r.astype('float32')), torch.from_numpy(A.astype('float32')), torch.from_numpy(eps.astype('float32')), p))

In [None]:
# Integrate from x0 with the true parameters.
x_pred = solver.solve(x0)

In [None]:
# Display the end state; compare by eye with x_true printed earlier.
x_pred

tensor([1.9095, 0.9769, 0.9886, 1.0154, 0.7484, 0.5894, 1.1531, 0.9953, 1.3845])

In [None]:
# Now use scipy's solve_ivp
# Now use scipy's solve_ivp
# Reference solution with the NumPy right-hand side and the same parameters.
sol = solve_ivp(glvp2, [0, 20], x0.numpy(), args=(r, A, eps, p.numpy()))

In [None]:
# Last column of sol.y is the state at the final time point.
sol.y[:,-1] # should be similar to x_pred

array([1.91097956, 0.97688841, 0.98805588, 1.01544636, 0.74836967,
       0.58935627, 1.15307287, 0.99528851, 1.38448868])

**General testing**

In [None]:
# Scratch: wrap a random 2x2 tensor in a Parameter and show its repr.
torch.nn.Parameter(torch.randn((2,2)))

Parameter containing:
tensor([[-0.9155,  0.6164],
        [ 0.1008,  0.7246]], requires_grad=True)

In [None]:
# Scratch: NOTE this rebinds `A`, replacing the NumPy interaction matrix
# returned by get_ode_params above.
A = nn.Parameter(torch.randn(3, 3))

In [None]:
# Total number of elements in the 3x3 parameter.
A.numel()

9

In [None]:
# Overwrite the diagonal in place through `.data`.
A.data.fill_diagonal_(-1)

tensor([[-1.0000,  1.3261,  1.2774],
        [-0.9245, -1.0000, -0.0726],
        [-1.1760,  0.0340, -1.0000]])

In [None]:
# Check the parameter's dtype.
A.dtype

torch.float32

In [None]:
# Scratch: a small 1-D integer tensor for the broadcasting checks below.
a = torch.tensor([1,1,1])
a

tensor([1, 1, 1])

In [None]:
# Indexing with None appends an axis, as used for r[:, None] in glvp.
a[:,None].shape

torch.Size([3, 1])

In [None]:
# Scratch: a 3x5 matrix holding 1..15 in row-major order.
b = torch.arange(1,16).reshape(3,5)

In [None]:
# Display b.
b

tensor([[ 1,  2,  3,  4,  5],
        [ 6,  7,  8,  9, 10],
        [11, 12, 13, 14, 15]])

In [None]:
# The column vector a[:, None] is broadcast across the columns of b.
a[:,None] + b

tensor([[ 2,  3,  4,  5,  6],
        [ 7,  8,  9, 10, 11],
        [12, 13, 14, 15, 16]])

In [None]:
# Same broadcasting with an elementwise product, mirroring the form used in glvp.
(a[:,None] + 1) * b

tensor([[ 2,  4,  6,  8, 10],
        [12, 14, 16, 18, 20],
        [22, 24, 26, 28, 30]])

In [None]:
# Check that an infinite value trips the `abs() > 1e6` test used by
# RK45.solve's early-stop condition.
torch.abs(torch.tensor(torch.inf)) > 1e6

tensor(True)