-
Notifications
You must be signed in to change notification settings - Fork 107
/
ode.py
483 lines (402 loc) · 16.8 KB
/
ode.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
import torch
import torch.nn as nn
from abc import ABC, abstractmethod
from neuromancer.dynamics.library import FunctionLibrary
class SSM(nn.Module):
    """
    Discrete-time (neural) state space model.

    Computes the one-step update

        x_{k+1} = fx(x_k) + fu(u_k) [+ fd(d_k)]

    where x are states, u control inputs and d optional disturbances.
    """

    def __init__(self, fx, fu, nx, nu, fd=None, nd=0):
        """
        :param fx: (nn.Module) map applied to the state
        :param fu: (nn.Module) map applied to the control input
        :param nx: (int) number of states
        :param nu: (int) number of control inputs
        :param fd: (nn.Module or None) map applied to the disturbance
        :param nd: (int) number of disturbances
        """
        super().__init__()
        self.fx = fx
        self.fu = fu
        self.fd = fd
        self.nx = nx
        self.nu = nu
        self.nd = nd
        # the model consumes [x, u, d] and produces the next state
        self.in_features = nx + nu + nd
        self.out_features = nx

    def forward(self, x, u, d=None):
        """
        :param x: (torch.Tensor, shape=[batchsize, nx]) current state
        :param u: (torch.Tensor, shape=[batchsize, nu]) control input
        :param d: (torch.Tensor, shape=[batchsize, nd], optional) disturbance;
            silently ignored when the model was built without ``fd``
        :return: (torch.Tensor) next state, shape=[batchsize, nx] provided
            fx/fu/fd all produce nx features
        """
        assert x.dim() == 2
        assert u.dim() == 2
        x_next = self.fx(x) + self.fu(u)
        has_disturbance = self.fd is not None and d is not None
        if has_disturbance:
            assert d.dim() == 2
            x_next += self.fd(d)
        return x_next
class ODESystem(nn.Module, ABC):
    """
    Abstract base for ODE right-hand sides, dx/dt = f(x, ...).

    Subclasses implement :meth:`ode_equations`; the resulting modules can be
    mixed and matched with other components according to expert knowledge.
    """

    def __init__(self, insize, outsize):
        """
        :param insize: (int) total input width (states plus any extra inputs)
        :param outsize: (int) number of states, i.e. width of dx/dt
        """
        super().__init__()
        self.in_features = insize
        self.out_features = outsize
        self.nx = outsize
        # input columns beyond the states are counted as inputs
        self.nu = insize - outsize

    @abstractmethod
    def ode_equations(self, x, *args):
        """Return dx/dt for the state batch ``x`` and any extra arguments."""

    def forward(self, x, *args):
        """
        :param x: (torch.Tensor) 2D state batch
        :param args: extra tensors forwarded unchanged to :meth:`ode_equations`
        :return: whatever :meth:`ode_equations` returns
        """
        assert x.dim() == 2
        return self.ode_equations(x, *args)
class GeneralNetworkedODE(ODESystem):
    """
    Coupled nonlinear dynamical system with heterogeneous agents.

    Aggregates several interacting physics modules that together define the
    dynamics of one or more agents: per-agent intrinsic physics (``agents``)
    and pairwise interaction physics (``couplings``).
    """

    def __init__(self, map=None,
                 agents=None,
                 couplings=None,
                 insize=None,
                 outsize=None,
                 inductive_bias="additive"
                 ):
        """
        :param map: list with one mapping per agent; each mapping's values are
            column indices into ``torch.cat([x, *args], dim=-1)``
        :param agents: list of nn.Modules, one per agent, in the same order as ``map``
        :param couplings: list of coupling nn.Modules, one per interaction type
        :param insize: dimensionality of input, including disturbances and control
        :param outsize: dimensionality of output, just for agent evolution
        :param inductive_bias: "additive", "general" (not implemented) or "compositional"
        """
        super().__init__(insize=insize, outsize=outsize)
        # Composition of network:
        self.map = map
        self.agents = nn.ModuleList(agents)
        self.couplings = nn.ModuleList(couplings)
        self.insize = insize
        self.outsize = outsize
        self.inductive_bias = inductive_bias
        assert len(self.map) == len(self.agents)

    def ode_equations(self, x, *args):
        """
        Combine intrinsic and coupling physics according to the inductive bias:

        - additive:      f(x_i) + sum_j g(x_i, x_j)
        - general:       f(x_i, sum_j g(x_i, x_j))   (not implemented)
        - compositional: f(sum_j g(x_i, x_j))

        :return: dx/dt truncated to the first ``outsize`` columns
        :raises NotImplementedError: for the "general" inductive bias
        :raises ValueError: for an unrecognised inductive bias
        """
        if self.inductive_bias == "additive":
            dx = self.intrinsic_physics(x, *args) + self.coupling_physics(x, *args)
        elif self.inductive_bias == "general":
            raise NotImplementedError("General RHS not implemented.")
        elif self.inductive_bias == "compositional":
            dx = self.intrinsic_physics(self.coupling_physics(x, *args), *args)
        else:
            raise ValueError("No inductive bias match.")
        return dx[:, :self.outsize]

    def intrinsic_physics(self, x, *args):
        """
        Evaluate every agent's intrinsic physics and concatenate the results.

        Agent ``idx`` receives the columns of ``torch.cat([x, *args], dim=-1)``
        listed in ``self.map[idx].values()``; outputs are concatenated along the
        last dimension in agent order.
        """
        features = torch.cat([x, *args], dim=-1)
        # Collect all outputs and concatenate once. Seeding with torch.tensor([])
        # (CPU, float32) could clash with the device/dtype of x and re-allocated
        # the accumulator on every iteration.
        contributions = [
            agent(features[:, list(agent_map.values())])
            for agent, agent_map in zip(self.agents, self.map)
        ]
        if not contributions:
            return features.new_zeros((features.shape[0], 0))
        return torch.cat(contributions, dim=-1)

    def coupling_physics(self, x, *args):
        """
        Accumulate the contributions of all coupling physics.

        Each coupling module is expected to expose ``pins`` (pairs of agent
        indices ``(sender, receiver)``), ``feature_name`` (a key into each
        agent's mapping) and ``symmetric`` (bool). For every pin the module is
        applied to the two selected feature columns. Its output is added to the
        sender's column of dx and, when symmetric, subtracted from the
        receiver's column.

        NOTE(review): dx has the width of ``x``, so the mapped feature indices
        must address state columns (< x.shape[1]); confirm for new couplings.
        """
        dx = torch.zeros_like(x)
        features = torch.cat([x, *args], dim=-1)
        for physics in self.couplings:
            for pin in physics.pins:
                send = self.map[pin[0]][physics.feature_name]
                receive = self.map[pin[1]][physics.feature_name]
                contribution = physics(features[:, [send, receive]])
                dx[:, [send]] += contribution
                if physics.symmetric:
                    dx[:, [receive]] -= contribution
        return dx
class SINDy(ODESystem):
    """
    Sparse Identification of Nonlinear Dynamics (SINDy).

    The right-hand side is a linear combination of candidate functions,
    dx/dt = library.evaluate(x) @ coef, where ``coef`` is a learnable matrix
    of shape ``library.shape`` (rows: candidate functions, columns: states).

    Reference: https://www.pnas.org/doi/10.1073/pnas.1517384113
    """

    def __init__(
        self,
        library,
        threshold=1e-2
    ):
        """
        :param library: (FunctionLibrary) the library of candidate functions
        :param threshold: (float) terms whose coefficient magnitude does not
            exceed this are omitted from the string representation
            (they still contribute to :meth:`ode_equations`)
        """
        assert isinstance(library, FunctionLibrary), "Must be valid library"
        super().__init__(library.shape[1], library.shape[1])
        self.library = library
        self.threshold = threshold
        init_coef = torch.rand(self.library.shape)
        self.coef = torch.nn.Parameter(init_coef, requires_grad=True)
        self.float()

    def ode_equations(self, x):
        """
        :param x: (torch.Tensor, shape=[batchsize, n_states]) time series data
        :return: (torch.Tensor) ``library.evaluate(x) @ coef``; presumably
            [batchsize, n_states] if evaluate returns [batchsize, n_functions]
        """
        assert x.ndim == 2, "Input must be a 2D tensor of shape [batchsize, n_states]"
        assert x.shape[1] == self.library.shape[1], "Must have same number of states as insize"
        lib_eval = self.library.evaluate(x)
        return torch.matmul(lib_eval, self.coef)

    def __str__(self):
        """
        :return: (str) one line per state variable, ``dx{i}/dt = `` followed by
            the terms whose coefficient magnitude exceeds ``threshold`` joined
            with `` + `` (``0`` when no term does), each line ending in a newline
        """
        f_names = str(self.library).split(", ")
        coefs = self.coef.detach()
        lines = []
        for i in range(self.library.shape[1]):
            terms = []
            for j, func in enumerate(f_names):
                value = coefs[j, i].item()
                if abs(value) > self.threshold:
                    terms.append(f"{value:.3f}*{func}")
            # an all-below-threshold state used to lose its "= " to the [:-2] strip
            rhs = " + ".join(terms) if terms else "0"
            lines.append(f"dx{i}/dt = {rhs}\n")
        return "".join(lines)
class TwoTankParam(ODESystem):
    """
    Two-tank system driven by a pump and a valve, with learnable coefficients
    ``c1`` (inflow gain) and ``c2`` (outflow gain).
    """

    def __init__(self, insize=4, outsize=2):
        """
        :param insize: (int) total input width (2 tank levels + 2 inputs by default)
        :param outsize: (int) number of states (the 2 tank levels)
        """
        super().__init__(insize=insize, outsize=outsize)
        self.c1 = nn.Parameter(torch.tensor([0.1]), requires_grad=True)
        self.c2 = nn.Parameter(torch.tensor([0.1]), requires_grad=True)

    def ode_equations(self, x, u):
        """
        :param x: (torch.Tensor) states; columns 0 and 1 are the tank levels
        :param u: (torch.Tensor) inputs; column 0 is the pump, column 1 the valve
        :return: (torch.Tensor) [dh1/dt, dh2/dt] concatenated along the last dim
        """
        # saturate levels and actuators to the range [0, 1]
        h1, h2 = (torch.clip(x[:, [col]], min=0, max=1.0) for col in (0, 1))
        pump, valve = (torch.clip(u[:, [col]], min=0, max=1.0) for col in (0, 1))
        # tank 1 drains into tank 2, tank 2 drains out of the system
        drain1 = self.c2 * torch.sqrt(h1)
        drain2 = self.c2 * torch.sqrt(h2)
        # the valve splits the pumped flow between the two tanks
        dhdt1 = self.c1 * (1.0 - valve) * pump - drain1
        dhdt2 = self.c1 * valve * pump + drain1 - drain2
        return torch.cat([dhdt1, dhdt2], dim=-1)
class DuffingParam(ODESystem):
    """
    Forced Duffing oscillator:

        dx0/dt = x1
        dx1/dt = -delta*x1 - alpha*x0 - beta*x0^3 + gamma*cos(omega*t)

    Only ``omega`` is trainable; alpha, beta, delta and gamma are frozen.
    """

    def __init__(self, insize=3, outsize=2):
        """
        :param insize: (int) total input width (2 states + time by default)
        :param outsize: (int) number of states
        """
        super().__init__(insize=insize, outsize=outsize)
        self.alpha = nn.Parameter(torch.tensor([1.0]), requires_grad=False)
        self.beta = nn.Parameter(torch.tensor([5.0]), requires_grad=False)
        self.delta = nn.Parameter(torch.tensor([0.02]), requires_grad=False)
        self.gamma = nn.Parameter(torch.tensor([8.0]), requires_grad=False)
        self.omega = nn.Parameter(torch.tensor([0.5]), requires_grad=True)

    def ode_equations(self, x, u):
        """
        :param x: (torch.Tensor) states; column 0 is x0 (position), column 1 is
            x1 = dx0/dt (velocity)
        :param u: (torch.Tensor) time t for the forcing term; assumed to be a
            single column [batchsize, 1] (a wider tensor would broadcast into dx1dt)
        :return: (torch.Tensor) [dx0/dt, dx1/dt] concatenated along the last dim
        """
        # position and velocity, kept 2D: (batchsize, 1)
        x0 = x[:, [0]]
        x1 = x[:, [1]]
        # the extra input is the time used by the periodic forcing
        t = u
        # equations
        dx0dt = x1
        dx1dt = -self.delta*x1 - self.alpha*x0 - self.beta*x0**3 + self.gamma*torch.cos(self.omega*t)
        return torch.cat([dx0dt, dx1dt], dim=-1)
class BrusselatorParam(ODESystem):
    """
    Brusselator model with learnable parameters alpha and beta:

        dx1/dt = alpha + x1^2 * x2 - beta*x1 - x1
        dx2/dt = beta*x1 - x1^2 * x2
    """

    def __init__(self, insize=2, outsize=2):
        """
        :param insize: (int) input width
        :param outsize: (int) number of states
        """
        super().__init__(insize=insize, outsize=outsize)
        self.alpha = nn.Parameter(torch.tensor([5.0]), requires_grad=True)
        self.beta = nn.Parameter(torch.tensor([5.0]), requires_grad=True)

    def ode_equations(self, x):
        """
        :param x: (torch.Tensor) state batch; column 0 is x1, the last column x2
        :return: (torch.Tensor) [dx1/dt, dx2/dt] concatenated along the last dim
        """
        first, second = x[:, [0]], x[:, [-1]]
        # the x1^2 * x2 term appears in both equations with opposite signs
        x1sq_x2 = second * first ** 2
        d_first = self.alpha + x1sq_x2 - self.beta * first - first
        d_second = self.beta * first - x1sq_x2
        return torch.cat([d_first, d_second], dim=-1)
class BrusselatorHybrid(ODESystem):
    """
    Brusselator model whose x1^2 * x2 term is replaced by a learned block:

        dx1/dt = alpha + block(x) - beta*x1 - x1
        dx2/dt = beta*x1 - block(x)
    """

    def __init__(self, block, insize=2, outsize=2):
        """
        :param block: (nn.Module) learned term; must expose in_features == 2
            and out_features == 1
        :param insize: (int) input width
        :param outsize: (int) number of states
        """
        super().__init__(insize=insize, outsize=outsize)
        self.block = block
        self.alpha = torch.nn.Parameter(torch.tensor([5.0], requires_grad=True))
        self.beta = torch.nn.Parameter(torch.tensor([5.0], requires_grad=True))
        assert self.block.in_features == 2
        assert self.block.out_features == 1

    def ode_equations(self, x):
        """
        :param x: (torch.Tensor) state batch; column 0 is x1
        :return: (torch.Tensor) [dx1/dt, dx2/dt] concatenated along the last dim
        """
        x1 = x[:, [0]]
        # Evaluate the learned term once: both equations must see the same value
        # (a stochastic block, e.g. with dropout, would otherwise differ) and the
        # block's forward cost is paid only once.
        interaction = self.block(x)
        dx1 = self.alpha + interaction - self.beta*x1 - x1
        dx2 = self.beta*x1 - interaction
        return torch.cat([dx1, dx2], dim=-1)
class LotkaVolterraHybrid(ODESystem):
    """
    Lotka-Volterra model whose x1 * x2 interaction term is replaced by a learned block:

        dx1/dt = alpha*x1 - beta*block(x)
        dx2/dt = delta*block(x) - gamma*x2
    """

    def __init__(self, block, insize=2, outsize=2):
        """
        :param block: (nn.Module) learned interaction; must expose
            in_features == 2 and out_features == 1
        :param insize: (int) input width
        :param outsize: (int) number of states
        """
        super().__init__(insize=insize, outsize=outsize)
        self.block = block
        self.alpha = nn.Parameter(torch.tensor([.10]), requires_grad=True)
        self.beta = nn.Parameter(torch.tensor([.10]), requires_grad=True)
        self.delta = nn.Parameter(torch.tensor([.10]), requires_grad=True)
        self.gamma = nn.Parameter(torch.tensor([.10]), requires_grad=True)
        assert self.block.in_features == 2
        assert self.block.out_features == 1

    def ode_equations(self, x):
        """
        :param x: (torch.Tensor) state batch; column 0 is x1, the last column x2
        :return: (torch.Tensor) [dx1/dt, dx2/dt] concatenated along the last dim
        """
        x1 = x[:, [0]]
        x2 = x[:, [-1]]
        # Evaluate the learned interaction once so both equations share it
        # (consistent under stochastic blocks, half the forward cost).
        interaction = self.block(x)
        dx1 = self.alpha*x1 - self.beta*interaction
        dx2 = self.delta*interaction - self.gamma*x2
        return torch.cat([dx1, dx2], dim=-1)
class LotkaVolterraParam(ODESystem):
    """
    Lotka-Volterra model with a learnable growth rate ``alpha`` for x1 and
    fixed remaining coefficients:

        dx1/dt = alpha*x1 - 0.4*x1*x2
        dx2/dt = 0.1*x1*x2 - 0.4*x2
    """

    def __init__(self, insize=2, outsize=2):
        """
        :param insize: (int) input width
        :param outsize: (int) number of states
        """
        super().__init__(insize=insize, outsize=outsize)
        self.alpha = nn.Parameter(torch.tensor([1.0]), requires_grad=True)

    def ode_equations(self, x):
        """
        :param x: (torch.Tensor) state batch; column 0 is x1, the last column x2
        :return: (torch.Tensor) [dx1/dt, dx2/dt] concatenated along the last dim
        """
        prey, predator = x[:, [0]], x[:, [-1]]
        d_prey = self.alpha * prey - 0.4 * prey * predator
        d_predator = 0.1 * prey * predator - 0.4 * predator
        return torch.cat((d_prey, d_predator), dim=-1)
class VanDerPolControl(ODESystem):
    """
    Controlled Van der Pol oscillator with learnable parameter ``mu``:

        dx1/dt = x2
        dx2/dt = mu*(1 - x1^2)*x2 - x1 + u
    """

    def __init__(self, insize=3, outsize=2):
        """
        :param insize: (int) total input width (2 states + 1 control by default)
        :param outsize: (int) number of states
        """
        super().__init__(insize=insize, outsize=outsize)
        self.mu = nn.Parameter(torch.tensor([1.0]), requires_grad=True)

    def ode_equations(self, x, u):
        """
        :param x: (torch.Tensor) state batch; columns 0 and 1 are x1 and x2
        :param u: (torch.Tensor) control added to dx2/dt (presumably one column)
        :return: (torch.Tensor) [dx1/dt, dx2/dt] concatenated along the last dim
        """
        pos, vel = x[:, [0]], x[:, [1]]
        gain = self.mu * (1 - pos ** 2)
        accel = gain * vel - pos + u
        return torch.cat([vel, accel], dim=-1)
class LorenzParam(ODESystem):
    """
    Lorenz system with learnable rho, sigma and beta (all initialised to 5.0):

        dx1/dt = sigma*(x2 - x1)
        dx2/dt = x1*(rho - x3) - x2
        dx3/dt = x1*x2 - beta*x3
    """

    def __init__(self, insize=3, outsize=3):
        """
        :param insize: (int) input width
        :param outsize: (int) number of states
        """
        super().__init__(insize=insize, outsize=outsize)
        self.rho = nn.Parameter(torch.tensor([5.0]))
        self.sigma = nn.Parameter(torch.tensor([5.0]))
        self.beta = nn.Parameter(torch.tensor([5.0]))

    def ode_equations(self, x):
        """
        :param x: (torch.Tensor) state batch; columns 0, 1 and the last are x1, x2, x3
        :return: (torch.Tensor) [dx1/dt, dx2/dt, dx3/dt] concatenated along the last dim
        """
        x1, x2, x3 = x[:, [0]], x[:, [1]], x[:, [-1]]
        derivatives = (
            self.sigma * (x2 - x1),
            x1 * (self.rho - x3) - x2,
            x1 * x2 - self.beta * x3,
        )
        return torch.cat(derivatives, dim=-1)
class LorenzControl(ODESystem):
    """
    Controlled Lorenz system with learnable rho, sigma and beta:

        dx1/dt = sigma*(x2 - x1) + u1
        dx2/dt = x1*(rho - x3) - x2
        dx3/dt = x1*x2 - beta*x3 - u2
    """

    def __init__(self, insize=5, outsize=3):
        """
        :param insize: (int) total input width (3 states + 2 controls by default)
        :param outsize: (int) number of states
        """
        super().__init__(insize=insize, outsize=outsize)
        self.rho = nn.Parameter(torch.tensor([28.0]))
        self.sigma = nn.Parameter(torch.tensor([10.0]))
        self.beta = nn.Parameter(torch.tensor([2.66667]))

    def ode_equations(self, x, u):
        """
        :param x: (torch.Tensor) state batch; columns 0..2 are x1, x2, x3
        :param u: (torch.Tensor) controls; columns 0 and 1 are u1, u2
        :return: (torch.Tensor) [dx1/dt, dx2/dt, dx3/dt] concatenated along the last dim
        """
        x1, x2, x3 = (x[:, [col]] for col in range(3))
        u1, u2 = (u[:, [col]] for col in range(2))
        rhs = [
            self.sigma * (x2 - x1) + u1,
            x1 * (self.rho - x3) - x2,
            x1 * x2 - self.beta * x3 - u2,
        ]
        return torch.cat(rhs, dim=-1)
class CSTR_Param(ODESystem):
    """
    Continuous stirred-tank reactor (CSTR) with a single reaction A -> B.

    States: concentration of A (Ca) and reactor temperature (T).
    Control: cooling-jacket temperature (Tc).
    q, rho and Cp are trainable; every other constant is frozen.
    """

    def __init__(self, insize=3, outsize=2):
        """
        :param insize: (int) total input width (2 states + 1 control by default)
        :param outsize: (int) number of states
        """
        super().__init__(insize=insize, outsize=outsize)
        self.nx = 2
        self.nu = 1
        # requires_grad must be passed to nn.Parameter itself: nn.Parameter ignores
        # the wrapped tensor's flag and defaults to requires_grad=True, so passing it
        # to torch.tensor(...) left the constants meant to be frozen trainable.
        # NOTE(review): unit labels are kept as given in the original; confirm they
        # are mutually consistent with the numeric values.
        # Volumetric Flowrate (m^3/sec)
        self.q = nn.Parameter(torch.tensor([100.0]), requires_grad=True)
        # Volume of CSTR (m^3)
        self.V = nn.Parameter(torch.tensor([100.0]), requires_grad=False)
        # Density of A-B Mixture (kg/m^3)
        self.rho = nn.Parameter(torch.tensor([1000.0]), requires_grad=True)
        # Heat capacity of A-B Mixture (J/kg-K)
        self.Cp = nn.Parameter(torch.tensor([0.239]), requires_grad=True)
        # Heat of reaction for A->B (J/mol)
        self.mdelH = nn.Parameter(torch.tensor([5e4]), requires_grad=False)
        # E - Activation energy in the Arrhenius Equation (J/mol)
        # R - Universal Gas Constant = 8.31451 J/mol-K
        self.EoverR = nn.Parameter(torch.tensor([8750.]), requires_grad=False)
        # Pre-exponential factor (1/sec)
        self.k0 = nn.Parameter(torch.tensor([7.2e10]), requires_grad=False)
        # U - Overall Heat Transfer Coefficient (W/m^2-K)
        # A - Area - this value is specific for the U calculation (m^2)
        self.UA = nn.Parameter(torch.tensor([5e4]), requires_grad=False)
        # Disturbances: Tf - Feed Temperature (K), Caf - Feed Concentration (mol/m^3)
        self.Tf = nn.Parameter(torch.tensor([350.]), requires_grad=False)
        self.Caf = nn.Parameter(torch.tensor([1.]), requires_grad=False)

    def ode_equations(self, x, u):
        """
        :param x: (torch.Tensor) state batch; column 0 is Ca, column 1 is T
        :param u: (torch.Tensor) cooling-jacket temperature Tc (presumably one column)
        :return: (torch.Tensor) [dCa/dt, dT/dt] concatenated along the last dim
        """
        Ca = x[:, [0]]  # state: Concentration of A in CSTR (mol/m^3)
        T = x[:, [1]]   # state: Temperature in CSTR (K)
        Tc = u          # control: Temperature of cooling jacket (K)
        # Arrhenius reaction rate, first order in Ca
        rA = self.k0 * torch.exp(-self.EoverR / T) * Ca
        # concentration balance: feed/outflow exchange minus consumption by reaction
        dCadt = self.q / self.V * (self.Caf - Ca) - rA
        # energy balance: flow exchange + heat of reaction + jacket heat transfer
        dTdt = self.q / self.V * (self.Tf - T) \
            + self.mdelH / (self.rho * self.Cp) * rA \
            + self.UA / self.V / self.rho / self.Cp * (Tc - T)
        return torch.cat([dCadt, dTdt], dim=-1)
# Name -> class registries for the ODE right-hand sides defined above.
# "auto": ode_equations takes only the state x (autonomous);
# "nonauto": ode_equations additionally takes an input u.
ode_param_systems_auto = dict(
    LorenzParam=LorenzParam,
    LotkaVolterraParam=LotkaVolterraParam,
    BrusselatorParam=BrusselatorParam,
)
ode_param_systems_nonauto = dict(
    DuffingParam=DuffingParam,
    TwoTankParam=TwoTankParam,
    LorenzControl=LorenzControl,
    CSTR_Param=CSTR_Param,
    VanDerPolControl=VanDerPolControl,
)
# hybrid models embed a learned block inside known physics
ode_hybrid_systems_auto = dict(
    LotkaVolterraHybrid=LotkaVolterraHybrid,
    BrusselatorHybrid=BrusselatorHybrid,
)
ode_networked_systems = dict(GeneralNetworkedODE=GeneralNetworkedODE)
# union of all registries; later registries win on (currently nonexistent) key clashes
odes = dict(ode_param_systems_auto)
odes.update(ode_param_systems_nonauto)
odes.update(ode_hybrid_systems_auto)
odes.update(ode_networked_systems)