
Weird behaviour due to defaults when using Implicit-Euler #156

Closed
SimiPixel opened this issue Sep 9, 2022 · 6 comments
Labels
refactor Tidy things up

Comments

@SimiPixel
Contributor

SimiPixel commented Sep 9, 2022

When using dfx.ImplicitEuler() with all arguments left at their defaults, the following error is raised:

missing rtol and atol of NewtonNonlinearSolver

The error message then prompts you to set these values on the step-size controller, because by default the nonlinear solver is supposed to fall back to the rtol and atol provided to the PIDController.
But dfx.ImplicitEuler() does not support adaptive step sizing with a PIDController.

The workaround is to pass the tolerances to the nonlinear solver explicitly:

solver=dfx.ImplicitEuler(nonlinear_solver=dfx.NewtonNonlinearSolver(rtol=1e-3, atol=1e-6))

Just something that feels a bit odd.
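The fallback described above can be modelled in a few lines of plain Python. This is a toy sketch, not Diffrax's code: ToyPIDController, ToyConstantStepSize, ToyNewton and resolve_newton_tolerances are hypothetical names. A tolerance of None means "inherit from the step-size controller", so a fixed-step controller with no tolerances can only produce an error, whose natural advice ("set them on the controller") is useless for a non-adaptive solver.

```python
from dataclasses import dataclass
from typing import Optional


@dataclass
class ToyPIDController:
    # Hypothetical stand-in for an adaptive controller: it carries tolerances.
    rtol: float
    atol: float


@dataclass
class ToyConstantStepSize:
    # Hypothetical stand-in for a fixed-step controller: it has no tolerances.
    pass


@dataclass
class ToyNewton:
    # None means "inherit this tolerance from the step-size controller".
    rtol: Optional[float] = None
    atol: Optional[float] = None


def resolve_newton_tolerances(newton: ToyNewton, controller) -> ToyNewton:
    """Fill in any missing Newton tolerances from the step-size controller."""
    if newton.rtol is not None and newton.atol is not None:
        return newton  # explicit tolerances: the controller is irrelevant
    if not hasattr(controller, "rtol"):
        # The confusing case from the issue: nothing to inherit from.
        raise ValueError("missing rtol and atol of the Newton solver")
    return ToyNewton(
        rtol=controller.rtol if newton.rtol is None else newton.rtol,
        atol=controller.atol if newton.atol is None else newton.atol,
    )
```

Calling resolve_newton_tolerances(ToyNewton(), ToyConstantStepSize()) raises, while ToyNewton(rtol=1e-3, atol=1e-6) resolves with either controller, mirroring the workaround.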

patrick-kidger added the refactor (Tidy things up) label Sep 9, 2022
@patrick-kidger
Owner

patrick-kidger commented Sep 9, 2022

Yep, your point makes sense. We should add a check for whether the solver is adaptive or not, and display an error message appropriately.
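One possible reading of the suggested check, sketched with a hypothetical ToySolverInfo record and missing_tolerance_message helper (neither is part of Diffrax): branch on whether the solver is adaptive, and only recommend a PIDController when it could actually be used with that solver.

```python
from dataclasses import dataclass


@dataclass
class ToySolverInfo:
    name: str
    adaptive: bool  # hypothetical flag: does the solver produce error estimates?


def missing_tolerance_message(solver: ToySolverInfo) -> str:
    """Build an error message that only suggests fixes which can actually work."""
    base = f"{solver.name}: the nonlinear solver has no rtol/atol set. "
    if solver.adaptive:
        return base + (
            "Either pass rtol and atol to the nonlinear solver, or use an "
            "adaptive step-size controller (e.g. PIDController), whose "
            "tolerances will be inherited."
        )
    return base + (
        "This solver does not provide error estimates, so it cannot be used "
        "with an adaptive step-size controller; pass rtol and atol to the "
        "nonlinear solver directly."
    )
```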

@SimiPixel
Contributor Author

Do you concur with the following changed logic?

from diffrax import AbstractSolver, AbstractNonlinearSolver, NewtonNonlinearSolver


class AbstractImplicitSolver(AbstractSolver):
    """Indicates that this is an implicit differential equation solver, and as such
    that it should take a nonlinear solver as an argument.
    """

    nonlinear_solver: AbstractNonlinearSolver


class AbstractAdaptiveSolver(AbstractSolver):
    """Indicates that this solver provides error estimates, and that as such it may be
    used with an adaptive step size controller.
    """


class AbstractImplicitAdaptiveSolver(AbstractImplicitSolver, AbstractAdaptiveSolver):
    """Indicates that this is an implicit differential equation solver that also 
    provides error estimates, and as such it may be used with an adaptive step
    size controller and that it should take a nonlinear solver as an argument.
    """
    nonlinear_solver: AbstractNonlinearSolver = NewtonNonlinearSolver()

@patrick-kidger
Owner

Hmm. I see what you're aiming for -- add a default nonlinear solver iff we're adaptive -- but I'd prefer to avoid adding abstract base classes with special behaviour for the intersection of two concepts.

I think the change I'd propose here is actually a little more involved. I've thrown together a completely untested draft at #157. The TL;DR is that the nonlinear solver's atol and rtol now default to special markers, indicating that they should be replaced with the step-size controller's atol and rtol.

Why this approach?

  • Avoids adding a special intersection of two concepts, as above.
  • Decouples the nonlinear solvers from the rest of Diffrax. (I have longer-term plans to factor these out elsewhere.)
  • Means that we can mark any other atol and rtol as something that should be inherited too. (For example inside a Conjugate Gradient linear solver.)

WDYT?
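A minimal sketch of the marker approach, in plain Python with frozen dataclasses standing in for Equinox modules. ToyNewton, ToyConjugateGradient and inherit_tolerances are hypothetical names, not Diffrax's API; the marker classes reuse the names UseControllerRtol and UseControllerAtol from this discussion. The nonlinear solver never refers to a step-size controller. A single recursive pass swaps every marker, including those nested inside the linear solver, for the controller's tolerances.

```python
from dataclasses import dataclass, fields, is_dataclass, replace


class UseControllerRtol:
    """Marker default: inherit rtol from the step-size controller."""


class UseControllerAtol:
    """Marker default: inherit atol from the step-size controller."""


@dataclass(frozen=True)
class ToyConjugateGradient:
    # Hypothetical linear solver that also has tolerances.
    rtol: object = UseControllerRtol()
    atol: object = UseControllerAtol()


@dataclass(frozen=True)
class ToyNewton:
    # Hypothetical nonlinear solver: it knows nothing about step-size controllers.
    rtol: object = UseControllerRtol()
    atol: object = UseControllerAtol()
    linear_solver: object = ToyConjugateGradient()


def inherit_tolerances(node, rtol, atol):
    """Recursively replace every marker with the controller's tolerances."""
    if isinstance(node, UseControllerRtol):
        return rtol
    if isinstance(node, UseControllerAtol):
        return atol
    if is_dataclass(node):
        return replace(
            node,
            **{f.name: inherit_tolerances(getattr(node, f.name), rtol, atol)
               for f in fields(node)},
        )
    return node  # any other value (e.g. an explicit float) is kept as-is
```

Explicitly passed tolerances are left alone, since only marker instances are replaced.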

@SimiPixel
Contributor Author

I like it. So the whole mechanism relies on the fact that (leaf) nodes of a pytree can be distinguished by their type. Correct?

import equinox as eqx
import jax.tree_util as jtu


class UseControllerAtol(eqx.Module):
    pass


class UseControllerRtol(eqx.Module):
    pass

# ...

def _replace_tol(x):
    if isinstance(x, UseControllerAtol):
        return stepsize_controller.atol
    elif isinstance(x, UseControllerRtol):
        return stepsize_controller.rtol
    else:
        return x

# solver is another `eqx.Module` (i.e. a pytree)
solver = jtu.tree_map(_replace_tol, solver)

@patrick-kidger
Owner

Pretty much. Although I've just spotted that I should have written tree_map(..., is_leaf=lambda x: isinstance(x, (UseControllerAtol, UseControllerRtol))) so that these are in fact treated as leaves.

(At the moment they're non-leaf nodes that just so happen not to have any child nodes.)
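The leaf versus non-leaf distinction can be mimicked in plain Python. This is a toy model, not JAX's implementation: toy_tree_map, ToyNewton and replace_tol are hypothetical, and dataclasses play the role of Equinox modules. A marker with no fields is an internal node with zero children, so the mapped function never sees it unless is_leaf declares it a leaf (jax.tree_util.tree_map accepts an is_leaf argument for this purpose).

```python
from dataclasses import dataclass, fields, is_dataclass


def toy_tree_map(f, node, is_leaf=None):
    """Toy pytree map: dataclass instances are internal nodes whose children are
    their fields; anything else is a leaf. f is only ever called on leaves."""
    if is_leaf is not None and is_leaf(node):
        return f(node)
    if is_dataclass(node):
        children = {
            fld.name: toy_tree_map(f, getattr(node, fld.name), is_leaf)
            for fld in fields(node)
        }
        return type(node)(**children)
    return f(node)


@dataclass
class UseControllerAtol:
    pass  # no fields, so by default this is an internal node with zero children


@dataclass
class ToyNewton:
    atol: object
    max_steps: int


def replace_tol(x):
    return 1e-6 if isinstance(x, UseControllerAtol) else x


solver = ToyNewton(atol=UseControllerAtol(), max_steps=10)

# Without is_leaf the marker is recursed into, has no leaves, and survives untouched.
unchanged = toy_tree_map(replace_tol, solver)

# With is_leaf the marker itself is handed to replace_tol.
resolved = toy_tree_map(
    replace_tol, solver, is_leaf=lambda x: isinstance(x, UseControllerAtol)
)
```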

@SimiPixel
Contributor Author

Makes sense.
