I am solving the simple problem below using `DiscreteTerminatingEvent`. Once the event is triggered the integration stops, but the solver returns `inf` values for the save times that follow the event. Is there a way to avoid this, so that the solver returns solution values only for the save times before the event fires, similar to how SciPy's `solve_ivp` does?
```python
import jax.numpy as jnp
import matplotlib.pyplot as plt
from diffrax import diffeqsolve, ODETerm, SaveAt, Dopri5, DiscreteTerminatingEvent


# Lotka-Volterra predator-prey vector field
def vector_field(t, y, args):
    prey, predator = y
    α, β, γ, δ = args
    d_prey = α * prey - β * prey * predator
    d_predator = -γ * predator + δ * prey * predator
    return jnp.array([d_prey, d_predator])


# Terminating event with two conditions: stop once either population crosses its threshold
def terminating_event_fxn(state, args, **kwargs):
    prey_population = state.y[0]
    predator_population = state.y[1]
    return (prey_population < 5) | (predator_population > 15)


# Set up the ODE term, solver, and initial conditions
term = ODETerm(vector_field)
solver = Dopri5()
t0 = 0
t1 = 140
dt0 = 0.1
y0 = jnp.array([10.0, 10.0])
args = (0.1, 0.02, 0.4, 0.02)
saveat = SaveAt(ts=jnp.linspace(t0, t1, 1000))

# Define the terminating event
terminating_event = DiscreteTerminatingEvent(terminating_event_fxn)

# Solve the ODE with the terminating event
sol = diffeqsolve(term, solver, t0, t1, dt0, y0, args=args, saveat=saveat,
                  discrete_terminating_event=terminating_event)

# Plot the results
plt.plot(sol.ts, sol.ys[:, 0], label="Prey")
plt.plot(sol.ts, sol.ys[:, 1], label="Predator")
plt.legend()
plt.show()

# Both report one entry per requested save time, including the inf-filled ones
print(sol.ys[:, 0].size)
print(sol.ts.shape)
```

I'm afraid not. All JAX arrays must have a size known at compile time, but the time of the event isn't known until runtime. So Diffrax initialises an array of the appropriate size (here, one entry per element of `saveat.ts`) filled with `inf`, and then fills in this array as the integration progresses.
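To make the fixed-size-buffer idea concrete, here is a toy analogue written directly in JAX. It is not Diffrax's implementation: the function name `toy_terminating_solve`, the dynamics `y_{k+1} = 1.5 * y_k` and the threshold are all hypothetical. The point it illustrates is that the output shape depends only on the static length of `save_ts`, so the function can be `jax.jit`-compiled, and any slots after the stop condition fires simply keep their initial `inf`.

```python
import jax
import jax.numpy as jnp
from jax import lax


@jax.jit
def toy_terminating_solve(y0, save_ts, threshold):
    """Toy analogue of a terminating solve (NOT Diffrax's implementation).

    Hypothetical dynamics: y_{k+1} = 1.5 * y_k. The value is recorded at each
    entry of save_ts until y exceeds threshold; later slots keep their inf.
    """
    n = save_ts.shape[0]  # static at trace time, so the output shape is fixed
    ts_out = jnp.full((n,), jnp.inf, dtype=jnp.float32)  # preallocated buffers
    ys_out = jnp.full((n,), jnp.inf, dtype=jnp.float32)

    def cond_fn(carry):
        i, y, _, _ = carry
        # Continue while slots remain and the "event" (y > threshold) has not fired
        return (i < n) & (y <= threshold)

    def body_fn(carry):
        i, y, ts_out, ys_out = carry
        ts_out = ts_out.at[i].set(save_ts[i])  # fill this slot
        ys_out = ys_out.at[i].set(y)
        return i + 1, 1.5 * y, ts_out, ys_out

    init = (jnp.int32(0), jnp.asarray(y0, dtype=jnp.float32), ts_out, ys_out)
    _, _, ts_out, ys_out = lax.while_loop(cond_fn, body_fn, init)
    return ts_out, ys_out


ts, ys = toy_terminating_solve(1.0, jnp.linspace(0.0, 9.0, 10), 5.0)
print(ys)
```

With `y0 = 1.0`, `threshold = 5.0` and ten save times, the first four slots hold 1, 1.5, 2.25 and 3.375, and the remaining six stay `inf`, which is the same pattern seen in `sol.ys` above.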
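Since the padding cannot be avoided inside JAX, one workaround is to drop the unused entries after the solve, outside of `jax.jit`, where data-dependent shapes are allowed. The sketch below assumes that unused save slots in `sol.ts` are `inf`, consistent with the explanation above. `drop_unfilled` is a hypothetical helper, not part of Diffrax, and the demo arrays are synthetic stand-ins for `sol.ts` and `sol.ys`.

```python
import numpy as np


def drop_unfilled(ts, ys):
    """Keep only the save points that were actually written during the solve.

    Assumes unused save slots have ts == inf (hypothetical helper, not a
    Diffrax API). The output length depends on runtime values, so call this
    outside jax.jit, e.g. on the arrays returned by diffeqsolve.
    """
    ts = np.asarray(ts)
    ys = np.asarray(ys)
    filled = np.isfinite(ts)  # boolean mask over the time axis
    return ts[filled], ys[filled]


# Synthetic stand-in for sol.ts / sol.ys: three filled points, two inf-padded slots
ts_demo = np.array([0.0, 0.5, 1.0, np.inf, np.inf])
ys_demo = np.array([[10.0, 10.0], [9.0, 11.0], [8.0, 12.0],
                    [np.inf, np.inf], [np.inf, np.inf]])

ts_kept, ys_kept = drop_unfilled(ts_demo, ys_demo)
print(ts_kept.shape, ys_kept.shape)  # (3,) (3, 2)
```

Applied to the question's solution this would be `ts, ys = drop_unfilled(sol.ts, sol.ys)`, after which `ts` and `ys` can be plotted directly. Keeping the fixed-size arrays inside any jitted code and trimming only at the very end avoids recompilation for every different event time.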