How about we implement a couple of methods like

import diffrax
from jax import lax


# The wrappers that a user would expect

def simulate_terminal_value(vector_field, t0, t1, u0, taylor_diff_fn, solver):
    # u0 is passed along on the assumption that the Taylor coefficients need it
    taylor_coefficients = taylor_diff_fn(vector_field, u0, num=solver.num_derivatives)
    return odesimulate_terminal_value(vector_field, t0, t1, taylor_coefficients, solver)


def simulate_checkpoints(vector_field, ts, u0, taylor_diff_fn, solver):
    taylor_coefficients = taylor_diff_fn(vector_field, u0, num=solver.num_derivatives)
    return odesimulate_checkpoints(vector_field, ts, taylor_coefficients, solver)


# The actual solvers I'd like to provide

def odesimulate_terminal_value(vector_field, t0, t1, taylor_coefficients, solver):
    # Creates an initial solution object from the Taylor coefficients
    # (but not the full state -- this decouples the Taylor-coefficient stuff
    # from the state initialisation and essentially
    # resolves #48, #85, and probably even more issues).
    # In the ``jax.optimizers`` world, it would be the initial PyTree of params.
    # But here, this is a little too solver-dependent to ask from the user.
    solution = solver.taylorcoefficients_to_solution(taylor_coefficients, t0, t1)

    def cond_fun(state):  # could be made an argument, no problem
        return state.accepted.t < state.t1

    return simulate(vector_field, t0, t1, solver, solution, cond_fun)


def odesimulate_checkpoints(vector_field, ts, taylor_coefficients, solver):
    # See above (assuming the first sub-interval initialises the solution)
    solution = solver.taylorcoefficients_to_solution(taylor_coefficients, ts[0], ts[1])

    def cond_fun(state):  # could be made an argument, no problem
        return state.accepted.t < state.t1

    full_solution = []  # pseudo-init_fn()
    for t0, t1 in zip(ts[:-1], ts[1:]):  # this would be a scan, actually.
        solution = simulate(vector_field, t0, t1, solver, solution, cond_fun)  # pseudo-apply_fn()
        full_solution.append(solution)
    return full_solution  # pseudo-extract_fn()


# The low-level init-apply-extract schemes and while-loops
# We could even make the choice of backend function an argument of the simulation.

def simulate_no_lax(vector_field, t0, t1, solver: Solver[T], solution: T, cond_fun) -> T:
    problem = (vector_field, t0, t1)
    state = solver.init_fn(*problem, solution)
    while cond_fun(state):
        state = solver.step_fn(*problem, state)
    solution = solver.extract_fn(state)
    return solution


def simulate(vector_field, t0, t1, solver, solution, cond_fun):
    problem = (vector_field, t0, t1)
    state = solver.init_fn(*problem, solution)
    state = lax.while_loop(cond_fun, lambda s: solver.step_fn(*problem, s), state)
    return solver.extract_fn(state)


def simulate_diffrax(vector_field, t0, t1, solver, solution, cond_fun):
    problem = (vector_field, t0, t1)
    state = solver.init_fn(*problem, solution)
    state = diffrax.bounded_while_loop(cond_fun, lambda s: solver.step_fn(*problem, s), state)
    return solver.extract_fn(state)

which would resolve a couple of problems:
- init_fn and extract_fn could be proper inverses of each other (see "State vs. solution", #85).
- The solver loses the Taylor-mode component, which is really something that extends the problem definition instead of helping to solve it (the clarity of how the ODE filter operates would improve!).
- Code would be super readable because every function is minimal. We would not need many docs, because the code is so trivial.
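
To make the init/step/extract idea concrete, here is a minimal pure-Python sketch in the spirit of simulate_no_lax. Everything in it is an assumption for illustration: `Solution`, `State`, and `EulerSolver` are hypothetical names, and explicit Euler stands in for the actual ODE filter. It only aims to show that the terminal-value and checkpoint loops can share one low-level driver, and that extract_fn can undo init_fn.

```python
from typing import NamedTuple


class Solution(NamedTuple):
    """Hypothetical user-facing object: a time and a value."""
    t: float
    u: float


class State(NamedTuple):
    """Hypothetical loop carry: the accepted solution plus solver internals."""
    accepted: Solution
    t1: float
    dt: float


class EulerSolver(NamedTuple):
    """Toy explicit-Euler stand-in for a real solver (illustration only)."""
    dt: float

    def init_fn(self, vector_field, t0, t1, solution):
        # Wrap the solution into a state; no Taylor-mode logic is needed here.
        return State(accepted=solution, t1=t1, dt=self.dt)

    def step_fn(self, vector_field, t0, t1, state):
        t, u = state.accepted
        t_new = min(t + state.dt, state.t1)  # never step past t1
        u_new = u + (t_new - t) * vector_field(t, u)
        return state._replace(accepted=Solution(t=t_new, u=u_new))

    def extract_fn(self, state):
        # Inverse of init_fn: unwrap the solution from the state.
        return state.accepted


def cond_fun(state):
    return state.accepted.t < state.t1


def simulate_no_lax(vector_field, t0, t1, solver, solution, cond_fun):
    problem = (vector_field, t0, t1)
    state = solver.init_fn(*problem, solution)
    while cond_fun(state):
        state = solver.step_fn(*problem, state)
    return solver.extract_fn(state)


def simulate_checkpoints_no_lax(vector_field, ts, solver, solution, cond_fun):
    full_solution = []
    for t0, t1 in zip(ts[:-1], ts[1:]):
        # The extracted solution of one interval initialises the next one.
        solution = simulate_no_lax(vector_field, t0, t1, solver, solution, cond_fun)
        full_solution.append(solution)
    return full_solution


def vector_field(t, u):
    return -u  # linear test problem u' = -u


solver = EulerSolver(dt=0.25)
initial = Solution(t=0.0, u=1.0)

terminal = simulate_no_lax(vector_field, 0.0, 1.0, solver, initial, cond_fun)
checkpoints = simulate_checkpoints_no_lax(
    vector_field, [0.0, 0.5, 1.0], solver, initial, cond_fun
)
roundtrip = solver.extract_fn(solver.init_fn(vector_field, 0.0, 1.0, initial))
```

In this toy setting, `roundtrip` equals `initial`, and the last checkpoint coincides with `terminal`, because each interval restarts from the previously extracted solution. Swapping the Python `while` for `lax.while_loop` would presumably only require the state to be a PyTree of arrays.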