-
Notifications
You must be signed in to change notification settings - Fork 0
/
bench_diffrax.py
150 lines (109 loc) · 2.99 KB
/
bench_diffrax.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
#!/usr/bin/env python
# coding: utf-8
# %%
# Benchmarking Diffrax ODE solvers for ensemble problems, via vmap. The Lorenz ODE is integrated by Tsit5.
# Created By: Utkarsh
# Last Updated: 19 April 2023
# %%
import time
import diffrax
import equinox as eqx
import jax
import jax.numpy as jnp
import numpy as np
import os
import timeit
import sys
# Ensemble size (number of Lorenz parameter values to solve for), read from the
# first command-line argument.
if len(sys.argv) < 2:
    # Exit with a usage hint instead of an opaque IndexError traceback.
    sys.exit(f"usage: {sys.argv[0]} <numberOfParameters>")
numberOfParameters = int(sys.argv[1])
# %%
from jax.lib import xla_bridge
# Report the active platform through the public `jax.default_backend()` API
# rather than reaching into `jax.lib.xla_bridge`.
# NOTE(review): `xla_bridge` is no longer used by the line below; the import is
# left in place and can be dropped once nothing else is confirmed to need it.
print("Working on :", jax.default_backend())
# %%
# Defining the Lorenz Problem
class Lorenz(eqx.Module):
    """Right-hand side of the Lorenz system with a tunable `k1` coefficient.

    `k1` multiplies `y[0]` in the second equation; the other two coefficients
    are fixed at 10 and 8/3.
    """

    k1: float

    def __call__(self, t, y, args):
        """Return the stacked time derivative of the 3-component state `y`.

        `t` and `args` are part of the call signature but are not used.
        """
        u, v, w = y[0], y[1], y[2]
        du = 10.0 * (v - u)
        dv = self.k1 * u - v - u * w
        dw = u * v - (8 / 3) * w
        return jnp.stack([du, dv, dw])
# %%
# JIT compilation of ODE solver
@jax.jit
def main(k1):
    """Integrate the Lorenz system from t=0 to t=1 with Tsit5 and no explicit controller.

    Args:
        k1: Value of the `k1` coefficient passed to `Lorenz`.

    Returns:
        The solution object returned by `diffrax.diffeqsolve`.

    No `stepsize_controller` or `saveat` is passed, so diffrax's defaults
    apply (the benchmark reports this variant as "fixed time-stepping",
    starting from step size `dt0`).
    """
    terms = diffrax.ODETerm(Lorenz(k1))
    # The `SaveAt` and `PIDController` objects that used to be built here were
    # never handed to `diffeqsolve`; they are removed so the code no longer
    # suggests that adaptive stepping or custom save points are in effect.
    return diffrax.diffeqsolve(
        terms,
        diffrax.Tsit5(),
        t0=0.0,
        t1=1.0,
        dt0=0.001,
        y0=jnp.array([1.0, 0.0, 0.0]),
    )
# %%
# Uncomment for smoke test
# main(28.0)
# start = time.time()
# sol = main(28.0)
# end = time.time()
# print("Results:")
# for ti, yi in zip(sol.ts, sol.ys):
# print(f"t={ti.item()}, y={yi.tolist()}")
# print(f"Took {sol.stats['num_steps']} steps in {end - start} seconds.")
# %%
# Setting up parameters for parallel simulation
parameterList = jnp.linspace(0.0, 21.0, numberOfParameters)
# %%
# Use jax.vmap to compute parallel solutions of the ODE.
# JAX dispatches work asynchronously, so each timed call blocks until the
# results are actually ready; otherwise the timer could stop after dispatch
# and before the solves have finished.
res = timeit.repeat(
    lambda: jax.block_until_ready(jax.vmap(main)(parameterList)),
    repeat=100,
    number=1,
)
# Best (minimum) time over all repeats, converted from seconds to milliseconds.
best_time = min(res) * 1000
print("{:} ODE solves with fixed time-stepping completed in {:.1f} ms".format(numberOfParameters, best_time))
# %%
# Save the minimum time
# The context manager closes the file even if the write raises.
with open("./data/JAX/Jax_times_unadaptive.txt", "a+") as file:
    file.write('{0} {1}\n'.format(numberOfParameters, best_time))
# %%
# Repeat the same for adaptive time-stepping
@jax.jit
def main(k1):
    """Integrate the Lorenz system from t=0 to t=1 with adaptive-step Tsit5.

    This definition rebinds the module-level name `main`, so code that runs
    after it uses this adaptive variant.

    Args:
        k1: Value of the `k1` coefficient passed to `Lorenz`.

    Returns:
        The solution object returned by `diffrax.diffeqsolve`.

    Step sizes are chosen by a `PIDController` with rtol=1e-8 and atol=1e-8,
    starting from `dt0`. No `saveat` is passed, so diffrax's default save
    behavior applies.
    """
    terms = diffrax.ODETerm(Lorenz(k1))
    stepsize_controller = diffrax.PIDController(rtol=1e-8, atol=1e-8)
    # The unused `SaveAt` object and its commented-out keyword argument were
    # removed; they had no effect on the solve.
    return diffrax.diffeqsolve(
        terms,
        diffrax.Tsit5(),
        t0=0.0,
        t1=1.0,
        dt0=0.001,
        y0=jnp.array([1.0, 0.0, 0.0]),
        stepsize_controller=stepsize_controller,
    )
# %%
import timeit
# %%
# Time the vmapped solve using whichever `main` is bound at this point.
# JAX dispatches work asynchronously, so each timed call blocks until the
# results are actually ready; otherwise the timer could stop after dispatch
# and before the solves have finished.
res = timeit.repeat(
    lambda: jax.block_until_ready(jax.vmap(main)(parameterList)),
    repeat=100,
    number=1,
)
# %%
# Best (minimum) time over all repeats, converted from seconds to milliseconds.
best_time = min(res) * 1000
print("{:} ODE solves with adaptive time-stepping completed in {:.1f} ms".format(numberOfParameters, best_time))
# %%
# Append the minimum time to the results file; the context manager closes the
# file even if the write raises.
with open("./data/JAX/Jax_times_adaptive.txt", "a+") as file:
    file.write('{0} {1}\n'.format(numberOfParameters, best_time))