Delegate solve() routines to lower-level versions #87

Closed
pnkraemer opened this issue Oct 11, 2022 · 2 comments

@pnkraemer
Owner

How about we implement a couple of functions like

from typing import TypeVar

import diffrax  # only used by simulate_diffrax
from jax import lax

T = TypeVar("T")  # the solver-specific solution type
# "Solver[T]" below: any object with init_fn, step_fn and extract_fn

# The wrappers that a user would expect

def simulate_terminal_value(vector_field, t0, t1, u0, taylor_diff_fn, solver):
    # Taylor-mode differentiation turns u0 into the Taylor coefficients the solver needs
    taylor_coefficients = taylor_diff_fn(vector_field, u0, t0, num=solver.num_derivatives)
    return odesimulate_terminal_value(vector_field, t0, t1, taylor_coefficients, solver)

def simulate_checkpoints(vector_field, ts, u0, taylor_diff_fn, solver):
    taylor_coefficients = taylor_diff_fn(vector_field, u0, ts[0], num=solver.num_derivatives)
    return odesimulate_checkpoints(vector_field, ts, taylor_coefficients, solver)

# The actual solvers I'd like to provide

def odesimulate_terminal_value(vector_field, t0, t1, taylor_coefficients, solver):

    # Creates an initial solution object from the Taylor coefficients
    # (but not the full state: this decouples the Taylor-coefficient stuff
    # from the state initialisation and essentially
    # resolves #48, #85 and probably even more issues).
    # In the ``jax.example_libraries.optimizers`` world, it would be the initial PyTree of params.
    # But here, this is a little too solver-dependent to ask from the user.
    solution = solver.taylorcoefficients_to_solution(taylor_coefficients, t0, t1)

    def cond_fun(state):  # can make an argument, no problem
        return state.accepted.t < state.t1

    return simulate(vector_field, t0, t1, solution, solver, cond_fun)

def odesimulate_checkpoints(vector_field, ts, taylor_coefficients, solver):

    # See above
    solution = solver.taylorcoefficients_to_solution(taylor_coefficients, ts[0], ts[-1])

    def cond_fun(state):  # can make an argument, no problem
        return state.accepted.t < state.t1

    full_solution = []  # pseudo-init_fn()
    for t0, t1 in zip(ts[:-1], ts[1:]):  # this would be a scan, actually.
        solution = simulate(vector_field, t0, t1, solution, solver, cond_fun)  # pseudo-apply_fn()
        full_solution.append(solution)
    return full_solution  # pseudo-extract_fn()

# The low-level init-apply-extract schemes and while-loops.
# All three share one signature, so we could even make the choice of
# backend function an argument of the simulation.

def simulate_no_lax(vector_field, t0, t1, solution: T, solver: "Solver[T]", cond_fun) -> T:
    problem = (vector_field, t0, t1)
    state = solver.init_fn(*problem, solution)
    while cond_fun(state):
        state = solver.step_fn(*problem, state)
    return solver.extract_fn(state)

def simulate(vector_field, t0, t1, solution: T, solver: "Solver[T]", cond_fun) -> T:
    problem = (vector_field, t0, t1)
    state = solver.init_fn(*problem, solution)
    state = lax.while_loop(cond_fun, lambda s: solver.step_fn(*problem, s), state)
    return solver.extract_fn(state)

def simulate_diffrax(vector_field, t0, t1, solution: T, solver: "Solver[T]", cond_fun) -> T:
    problem = (vector_field, t0, t1)
    state = solver.init_fn(*problem, solution)
    state = diffrax.bounded_while_loop(cond_fun, lambda s: solver.step_fn(*problem, s), state)
    return solver.extract_fn(state)

which would resolve a couple of problems:

  • init_fn and extract_fn could be (properly!) inverse to each other (see #85, State vs. solution)
  • The solver would lose the Taylor-mode component, which really extends the problem definition rather than helping to solve it (this would make it much clearer how the ODE filter operates!)
  • The code would be very readable because every function is minimal. We would not need much documentation, because each function is so simple.
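A minimal, runnable sketch of the first point, assuming a toy fixed-step explicit Euler solver in place of an ODE filter. `Solution`, `State`, `ToyEulerSolver` and `run_until_t1` are hypothetical names for illustration, not this project's API. `init_fn` wraps the user-facing solution into a loop state, `extract_fn` drops the solver internals again (so `extract_fn(init_fn(solution, t1))` gives back the original solution), and the loop itself is a plain `jax.lax.while_loop` with the `cond_fun` from the proposal.

```python
from typing import Callable, NamedTuple

import jax
import jax.numpy as jnp


class Solution(NamedTuple):
    """User-facing object: a time point and the ODE value at that time."""

    t: jnp.ndarray
    u: jnp.ndarray


class State(NamedTuple):
    """What the loop carries: the current solution plus solver internals."""

    accepted: Solution
    t1: jnp.ndarray
    dt: jnp.ndarray


class ToyEulerSolver(NamedTuple):
    """Hypothetical stand-in for an ODE filter: fixed-step explicit Euler."""

    dt: float

    def init_fn(self, solution: Solution, t1) -> State:
        sol = Solution(t=jnp.asarray(solution.t, jnp.float32),
                       u=jnp.asarray(solution.u, jnp.float32))
        return State(accepted=sol, t1=jnp.asarray(t1, jnp.float32),
                     dt=jnp.asarray(self.dt, jnp.float32))

    def step_fn(self, vector_field: Callable, state: State) -> State:
        t, u = state.accepted
        # Shorten the final step and snap onto t1 so the loop ends exactly there.
        last = state.t1 - t <= state.dt
        dt = jnp.where(last, state.t1 - t, state.dt)
        t_new = jnp.where(last, state.t1, t + dt)
        u_new = u + dt * vector_field(u, t)
        return state._replace(accepted=Solution(t=t_new, u=u_new))

    def extract_fn(self, state: State) -> Solution:
        # Inverse of init_fn: drop the solver internals, keep the solution.
        return state.accepted


def run_until_t1(vector_field, t1, solution: Solution, solver) -> Solution:
    def cond_fun(state: State):
        return state.accepted.t < state.t1

    def body_fun(state: State):
        return solver.step_fn(vector_field, state)

    state = solver.init_fn(solution, t1)
    state = jax.lax.while_loop(cond_fun, body_fun, state)
    return solver.extract_fn(state)


solver = ToyEulerSolver(dt=0.01)
u0 = Solution(t=0.0, u=1.0)
print(solver.extract_fn(solver.init_fn(u0, 1.0)))  # same t and u as u0
print(run_until_t1(lambda u, t: -u, 1.0, u0, solver))  # Euler approximation of exp(-1)
```

Because `extract_fn` only strips what `init_fn` added, the user-facing `Solution` stays small, and everything the loop needs (here `t1` and the step size) lives in `State`.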
@pnkraemer
Owner Author

Open questions:

  • How do constant-step and constant-grid solving fit into this picture?
  • How well does end-of-time handling work in this context? (It should work well)
  • Is the solver an AdaptiveSolver or a generic ODE filter?
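One hypothetical way to look at the constant-grid and end-of-time questions: write the checkpoint loop as a `jax.lax.scan` over consecutive `(t0, t1)` pairs, and let each inner `jax.lax.while_loop` shorten its final step so that it lands exactly on the next checkpoint. This sketch assumes fixed-step explicit Euler in place of the ODE filter and a solution that is just the ODE value; `euler_until` and `simulate_checkpoints_scan` are illustrative names only.

```python
import jax
import jax.numpy as jnp


def euler_until(vector_field, t0, t1, u0, dt):
    """Hypothetical inner loop: fixed-step Euler from t0 to exactly t1."""

    def cond_fun(carry):
        t, _ = carry
        return t < t1

    def body_fun(carry):
        t, u = carry
        # End-of-time handling: shorten the last step and snap onto t1.
        last = t1 - t <= dt
        h = jnp.where(last, t1 - t, dt)
        t_new = jnp.where(last, t1, t + h)
        return t_new, u + h * vector_field(u, t)

    _, u1 = jax.lax.while_loop(cond_fun, body_fun, (t0, u0))
    return u1


def simulate_checkpoints_scan(vector_field, ts, u0, dt):
    """Solve between consecutive checkpoints; the scan carry is the solution."""
    ts = jnp.asarray(ts, jnp.float32)
    u0 = jnp.asarray(u0, jnp.float32)
    dt = jnp.asarray(dt, jnp.float32)

    def step(u, interval):
        t0, t1 = interval
        u1 = euler_until(vector_field, t0, t1, u, dt)
        return u1, u1  # (new carry, value stored at checkpoint t1)

    _, us = jax.lax.scan(step, u0, (ts[:-1], ts[1:]))
    # Prepend the initial value so the output lines up with ts.
    return jnp.concatenate([u0[None], us])


ts = jnp.linspace(0.0, 1.0, 5)
print(simulate_checkpoints_scan(lambda u, t: -u, ts, 1.0, 0.01))
```

Snapping `t` onto `t1` (instead of accumulating `t + h`) avoids an extra, tiny step caused by floating-point round-off at the end of each interval.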

@pnkraemer
Owner Author

Kind of resolved by now
