"""Swarming particle model integrated with torchdiffeq's adjoint ODE solver.

The state of N particles in the plane is a flat tensor
``Z = [X, Y, Vx, Vy]`` of length ``4 * N``. Velocities are pulled towards
each other with weights ``(1 + r**2) ** -gamma`` (Cucker-Smale-style
alignment). Positions interact through the pairwise potential
``U(r) = cr * exp(-r / lr) - ca * exp(-r / la)`` (Morse-type: ``ca``/``la``
attract, ``cr``/``lr`` repel).
"""
import numpy as np
import torch
from torch import tensor

try:
    from torchdiffeq import odeint_adjoint
except ImportError:  # only needed once Model.forward actually integrates
    odeint_adjoint = None


class ODEFunc(torch.nn.Module):
    """Right-hand side dZ/dt of the swarming ODE.

    Args:
        gamma: Exponent of the alignment weight ``(1 + r**2) ** -gamma``.
        ca: Strength of the attractive exponential.
        la: Length scale of the attractive exponential.
        cr: Strength of the repulsive exponential.
        lr: Length scale of the repulsive exponential.
    """

    def __init__(self, gamma, ca, la, cr, lr):
        super().__init__()
        self.gamma = gamma
        self.ca = ca
        self.cr = cr
        self.la = la
        self.lr = lr

    def forward(self, t, Z):
        """Return dZ/dt for the flattened state ``Z = [X, Y, Vx, Vy]``.

        Args:
            t: Current time. Unused (the system is autonomous), but required
                by the ``func(t, y)`` calling convention of the ODE solver.
            Z: Tensor with ``len(Z) == 4 * N``: x positions, y positions,
                x velocities and y velocities, N entries each.

        Returns:
            Tensor of length ``4 * N``: ``[Vx, Vy, ax, ay]``.
        """
        n = len(Z) // 4
        X, Y, Vx, Vy = torch.reshape(Z, (4, n))

        # Pairwise differences: dx[i, j] = x_i - x_j (likewise for y, vx, vy).
        dx = X[:, None] - X
        dy = Y[:, None] - Y
        dvx = Vx[:, None] - Vx
        dvy = Vy[:, None] - Vy
        r2 = dx ** 2 + dy ** 2

        # Velocity-alignment weights.
        wv = 1.0 / (1.0 + r2) ** self.gamma

        # Coincident pairs (always including the diagonal i == j) have r = 0.
        # There U'(r) / r is singular and sqrt has an infinite derivative.
        # Evaluate on a safe distance of 1 and zero the weight afterwards, so
        # that both the forward value and the adjoint gradients stay finite.
        coincident = r2 == 0
        r = torch.sqrt(torch.where(coincident, torch.ones_like(r2), r2))
        # dU = U'(r) for U(r) = cr * exp(-r / lr) - ca * exp(-r / la).
        dU = (self.ca / self.la) * torch.exp(-r / self.la) - (
            self.cr / self.lr
        ) * torch.exp(-r / self.lr)
        wr = torch.where(coincident, torch.zeros_like(r), dU / r)

        # Average over interaction partners j (the 1/N normalisation).
        ax = -(dvx * wv).mean(dim=1) - (dx * wr).mean(dim=1)
        ay = -(dvy * wv).mean(dim=1) - (dy * wr).mean(dim=1)
        return torch.cat([Vx, Vy, ax, ay])

    def df(self, t, Z):
        """Backward-compatible alias for :meth:`forward`."""
        return self.forward(t, Z)


class Model(torch.nn.Module):
    """Integrate :class:`ODEFunc` over a fixed time grid with the adjoint method.

    Args:
        t: Times at which the solution is requested (passed to the solver).
        gamma, ca, la, cr, lr: Forwarded to :class:`ODEFunc`.
    """

    def __init__(self, t, gamma, ca, la, cr, lr):
        super().__init__()
        self.t = t
        self.fun = ODEFunc(gamma, ca, la, cr, lr)

    def forward(self, Z0):
        """Return ``odeint_adjoint(self.fun, Z0, self.t)``.

        Raises:
            ImportError: If torchdiffeq is not installed.
        """
        if odeint_adjoint is None:
            raise ImportError(
                "torchdiffeq is required to integrate the model "
                "(pip install torchdiffeq)"
            )
        return odeint_adjoint(self.fun, Z0, self.t)


gamma = 0.15
ca = 200.0
la = 100.0
cr = 500.0
lr = 2.0
N = 20


def main():
    """Integrate the model from a uniformly random initial state and return it."""
    Znp = np.random.uniform(0, 10, 4 * N)
    Z = tensor(Znp, dtype=torch.float32)
    tnp = np.arange(0, 10, 0.1)
    t = tensor(tnp, dtype=torch.float32)
    model = Model(t, gamma, ca, la, cr, lr)
    return model(Z)


if __name__ == "__main__":
    res = main()