Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Added strict=True #353

Open
wants to merge 1 commit into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 5 additions & 5 deletions diffrax/_adjoint.py
Original file line number Diff line number Diff line change
Expand Up @@ -107,7 +107,7 @@ def get_ys(_final_state):
return final_state


class AbstractAdjoint(eqx.Module):
class AbstractAdjoint(eqx.Module, strict=True):
"""Abstract base class for all adjoint methods."""

@abc.abstractmethod
Expand Down Expand Up @@ -167,7 +167,7 @@ def _uncallable(*args, **kwargs):
assert False


class RecursiveCheckpointAdjoint(AbstractAdjoint):
class RecursiveCheckpointAdjoint(AbstractAdjoint, strict=True):
"""Backpropagate through [`diffrax.diffeqsolve`][] by differentiating the numerical
solution directly. This is sometimes known as "discretise-then-optimise", or
described as "backpropagation through the solver".
Expand Down Expand Up @@ -318,7 +318,7 @@ def loop(
"""


class DirectAdjoint(AbstractAdjoint):
class DirectAdjoint(AbstractAdjoint, strict=True):
"""A variant of [`diffrax.RecursiveCheckpointAdjoint`][]. The differences are that
`DirectAdjoint`:

Expand Down Expand Up @@ -434,7 +434,7 @@ def _frozenset(x: Union[object, Iterable[object]]) -> frozenset[object]:
return frozenset(iter_x)


class ImplicitAdjoint(AbstractAdjoint):
class ImplicitAdjoint(AbstractAdjoint, strict=True):
r"""Backpropagate via the [implicit function theorem](https://en.wikipedia.org/wiki/Implicit_function_theorem#Statement_of_the_theorem).

This is used when solving towards a steady state, typically using
Expand Down Expand Up @@ -705,7 +705,7 @@ def __get(__aug):
return a_y1, a_diff_args1, a_diff_terms1


class BacksolveAdjoint(AbstractAdjoint):
class BacksolveAdjoint(AbstractAdjoint, strict=True):
"""Backpropagate through [`diffrax.diffeqsolve`][] by solving the continuous
adjoint equations backwards-in-time. This is also sometimes known as
"optimise-then-discretise", the "continuous adjoint method" or simply the "adjoint
Expand Down
2 changes: 1 addition & 1 deletion diffrax/_brownian/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
from .._path import AbstractPath


class AbstractBrownianPath(AbstractPath):
class AbstractBrownianPath(AbstractPath, strict=True):
"""Abstract base class for all Brownian paths."""

levy_area: AbstractVar[LevyArea]
Expand Down
2 changes: 1 addition & 1 deletion diffrax/_brownian/path.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@
from .base import AbstractBrownianPath


class UnsafeBrownianPath(AbstractBrownianPath):
class UnsafeBrownianPath(AbstractBrownianPath, strict=True):
"""Brownian simulation that is only suitable for certain cases.

This is a very quick way to simulate Brownian motion, but can only be used when all
Expand Down
4 changes: 2 additions & 2 deletions diffrax/_brownian/tree.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,7 +60,7 @@
_Spline: TypeAlias = Literal["sqrt", "quad", "zero"]


class _State(eqx.Module):
class _State(eqx.Module, strict=True):
level: IntScalarLike # level of the tree
s: RealScalarLike # starting time of the interval
w_s_u_su: FloatTriple # W_s, W_u, W_{s,u}
Expand Down Expand Up @@ -109,7 +109,7 @@ def _split_interval(
return x_s, x_u, x_su


class VirtualBrownianTree(AbstractBrownianPath):
class VirtualBrownianTree(AbstractBrownianPath, strict=True):
"""Brownian simulation that discretises the interval `[t0, t1]` to tolerance `tol`.

Can be initialised with `levy_area` set to `""`, or `"space-time"`.
Expand Down
2 changes: 1 addition & 1 deletion diffrax/_custom_types.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,7 +56,7 @@
LevyArea: TypeAlias = Literal["", "space-time"]


class LevyVal(eqx.Module):
class LevyVal(eqx.Module, strict=True):
dt: PyTree
W: PyTree
H: Optional[PyTree]
Expand Down
6 changes: 3 additions & 3 deletions diffrax/_event.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@
from ._step_size_controller import AbstractAdaptiveStepSizeController


class AbstractDiscreteTerminatingEvent(eqx.Module):
class AbstractDiscreteTerminatingEvent(eqx.Module, strict=True):
"""Evaluated at the end of each integration step. If true then the solve is stopped
at that time.
"""
Expand All @@ -30,7 +30,7 @@ def __call__(self, state, **kwargs) -> BoolScalarLike:
"""


class DiscreteTerminatingEvent(AbstractDiscreteTerminatingEvent):
class DiscreteTerminatingEvent(AbstractDiscreteTerminatingEvent, strict=True):
"""Terminates the solve if its condition is ever active."""

cond_fn: Callable[..., BoolScalarLike]
Expand All @@ -50,7 +50,7 @@ def __call__(self, state, **kwargs):
"""


class SteadyStateEvent(AbstractDiscreteTerminatingEvent):
class SteadyStateEvent(AbstractDiscreteTerminatingEvent, strict=True):
"""Terminates the solve once it reaches a steady state."""

rtol: Optional[float] = None
Expand Down
8 changes: 4 additions & 4 deletions diffrax/_global_interpolation.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@
from ._path import AbstractPath


class AbstractGlobalInterpolation(AbstractPath):
class AbstractGlobalInterpolation(AbstractPath, strict=True):
ts: AbstractVar[Real[Array, " times"]]
ts_size: AbstractVar[IntScalarLike]

Expand Down Expand Up @@ -52,7 +52,7 @@ def t1(self):
return self.ts[-1]


class LinearInterpolation(AbstractGlobalInterpolation):
class LinearInterpolation(AbstractGlobalInterpolation, strict=True):
"""Linearly interpolates some data `ys` over the interval $[t_0, t_1]$ with knots
at `ts`.

Expand Down Expand Up @@ -178,7 +178,7 @@ def derivative(self, t: RealScalarLike, left: bool = True) -> PyTree[Array]:
"""


class CubicInterpolation(AbstractGlobalInterpolation):
class CubicInterpolation(AbstractGlobalInterpolation, strict=True):
"""Piecewise cubic spline interpolation over the interval $[t_0, t_1]$."""

ts: Real[Array, " times"]
Expand Down Expand Up @@ -302,7 +302,7 @@ def derivative(
"""


class DenseInterpolation(AbstractGlobalInterpolation):
class DenseInterpolation(AbstractGlobalInterpolation, strict=True):
ts: Real[Array, " times"]
# DenseInterpolations typically get `ts` and `infos` that are way longer than they
# need to be, and padded with `nan`s. This means the normal way of measuring how
Expand Down
4 changes: 2 additions & 2 deletions diffrax/_integrate.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,14 +47,14 @@
from ._term import AbstractTerm, MultiTerm, ODETerm, WrapTerm


class SaveState(eqx.Module):
class SaveState(eqx.Module, strict=True):
saveat_ts_index: IntScalarLike
ts: eqxi.MaybeBuffer[Real[Array, " times"]]
ys: PyTree[eqxi.MaybeBuffer[Inexact[Array, "times ..."]]]
save_index: IntScalarLike


class State(eqx.Module):
class State(eqx.Module, strict=True):
# Evolving state during the solve
y: PyTree[Array]
tprev: FloatScalarLike
Expand Down
11 changes: 7 additions & 4 deletions diffrax/_local_interpolation.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
from typing import Optional, TYPE_CHECKING

import equinox as eqx
import jax.numpy as jnp
import jax.tree_util as jtu
import numpy as np
Expand All @@ -17,11 +18,11 @@
from ._path import AbstractPath


class AbstractLocalInterpolation(AbstractPath):
class AbstractLocalInterpolation(AbstractPath, strict=True):
pass


class LocalLinearInterpolation(AbstractLocalInterpolation):
class LocalLinearInterpolation(AbstractLocalInterpolation, strict=True):
t0: RealScalarLike
t1: RealScalarLike
y0: Y
Expand All @@ -39,7 +40,7 @@ def evaluate(
return (coeff * (self.y1**ω - self.y0**ω)).call(jnp.asarray).ω


class ThirdOrderHermitePolynomialInterpolation(AbstractLocalInterpolation):
class ThirdOrderHermitePolynomialInterpolation(AbstractLocalInterpolation, strict=True):
t0: RealScalarLike
t1: RealScalarLike
coeffs: PyTree[Shaped[Array, "4 ?*dims"], "Y"]
Expand Down Expand Up @@ -83,7 +84,9 @@ def _eval(_coeffs):
return jtu.tree_map(_eval, self.coeffs)


class FourthOrderPolynomialInterpolation(AbstractLocalInterpolation):
class FourthOrderPolynomialInterpolation(
AbstractLocalInterpolation, strict=eqx.StrictConfig(allow_abstract_name=True)
):
t0: RealScalarLike
t1: RealScalarLike
coeffs: PyTree[Shaped[Array, "5 ?*y"], "Y"]
Expand Down
2 changes: 1 addition & 1 deletion diffrax/_path.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@
from ._custom_types import RealScalarLike


class AbstractPath(eqx.Module):
class AbstractPath(eqx.Module, strict=True):
"""Abstract base class for all paths.

Every path has a start point `t0` and an end point `t1`. In between these values
Expand Down
6 changes: 3 additions & 3 deletions diffrax/_root_finder/_verychord.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@ def _converged(factor: Scalar, tol: float) -> Bool[Array, ""]:
return (factor > 0) & (factor < tol)


class _VeryChordState(eqx.Module):
class _VeryChordState(eqx.Module, strict=True):
linear_state: tuple[lx.AbstractLinearOperator, PyTree[Any]]
diff: Y
diffsize: Scalar
Expand All @@ -39,7 +39,7 @@ class _VeryChordState(eqx.Module):
step: Scalar


class _NoAux(eqx.Module):
class _NoAux(eqx.Module, strict=True):
fn: Callable

def __call__(self, y, args):
Expand All @@ -48,7 +48,7 @@ def __call__(self, y, args):
return out


class VeryChord(optx.AbstractRootFinder):
class VeryChord(optx.AbstractRootFinder, strict=True):
"""The Chord method of root finding.

As `optimistix.Chord`, except that in Runge--Kutta methods, the linearisation point
Expand Down
4 changes: 2 additions & 2 deletions diffrax/_saveat.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@ def _convert_ts(
return jnp.asarray(ts)


class SubSaveAt(eqx.Module):
class SubSaveAt(eqx.Module, strict=True):
"""Used for finer-grained control over what is saved. A PyTree of these should be
passed to `SaveAt(subs=...)`.

Expand Down Expand Up @@ -53,7 +53,7 @@ def __check_init__(self):
"""


class SaveAt(eqx.Module):
class SaveAt(eqx.Module, strict=True):
"""Determines what to save as output from the differential equation solve.

Instances of this class should be passed as the `saveat` argument of
Expand Down
3 changes: 2 additions & 1 deletion diffrax/_solution.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
from typing import Any, Optional

import equinox as eqx
import jax
import optimistix as optx
from jaxtyping import Array, Bool, PyTree, Real, Shaped
Expand Down Expand Up @@ -55,7 +56,7 @@ def update_result(old_result: RESULTS, new_result: RESULTS) -> RESULTS:
return RESULTS.where(pred, old_result, out_result)


class Solution(AbstractPath):
class Solution(AbstractPath, strict=eqx.StrictConfig(allow_method_override=True)):
"""The solution to a differential equation.

**Attributes:**
Expand Down
12 changes: 6 additions & 6 deletions diffrax/_solver/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,7 +49,7 @@ def __instancecheck__(cls, obj):
_set_metaclass = dict(metaclass=_MetaAbstractSolver)


class AbstractSolver(eqx.Module, Generic[_SolverState], **_set_metaclass):
class AbstractSolver(eqx.Module, Generic[_SolverState], strict=True, **_set_metaclass):
"""Abstract base class for all differential equation solvers.

Subclasses should have a class-level attribute `terms`, specifying the PyTree
Expand Down Expand Up @@ -179,7 +179,7 @@ def func(
"""


class AbstractImplicitSolver(AbstractSolver[_SolverState]):
class AbstractImplicitSolver(AbstractSolver[_SolverState], strict=True):
"""Indicates that this is an implicit differential equation solver, and as such
that it should take a root finder as an argument.
"""
Expand All @@ -188,25 +188,25 @@ class AbstractImplicitSolver(AbstractSolver[_SolverState]):
root_find_max_steps: AbstractVar[int]


class AbstractItoSolver(AbstractSolver[_SolverState]):
class AbstractItoSolver(AbstractSolver[_SolverState], strict=True):
"""Indicates that when used as an SDE solver that this solver will converge to the
Itô solution.
"""


class AbstractStratonovichSolver(AbstractSolver[_SolverState]):
class AbstractStratonovichSolver(AbstractSolver[_SolverState], strict=True):
"""Indicates that when used as an SDE solver that this solver will converge to the
Stratonovich solution.
"""


class AbstractAdaptiveSolver(AbstractSolver[_SolverState]):
class AbstractAdaptiveSolver(AbstractSolver[_SolverState], strict=True):
"""Indicates that this solver provides error estimates, and that as such it may be
used with an adaptive step size controller.
"""


class AbstractWrappedSolver(AbstractSolver[_SolverState]):
class AbstractWrappedSolver(AbstractSolver[_SolverState], strict=True):
"""Wraps another solver "transparently", in the sense that all `isinstance` checks
will be forwarded on to the wrapped solver, e.g. when testing whether the solver is
implicit/adaptive/SDE-compatible/etc.
Expand Down
2 changes: 1 addition & 1 deletion diffrax/_solver/dopri5.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@
)


class _Dopri5Interpolation(FourthOrderPolynomialInterpolation):
class _Dopri5Interpolation(FourthOrderPolynomialInterpolation, strict=True):
c_mid: ClassVar[np.ndarray] = np.array(
[
6025192743 / 30085553152 / 2,
Expand Down
2 changes: 1 addition & 1 deletion diffrax/_solver/dopri8.py
Original file line number Diff line number Diff line change
Expand Up @@ -188,7 +188,7 @@
_vmap_polyval = jax.vmap(jnp.polyval, in_axes=(0, None))


class _Dopri8Interpolation(AbstractLocalInterpolation):
class _Dopri8Interpolation(AbstractLocalInterpolation, strict=True):
t0: RealScalarLike
t1: RealScalarLike
y0: Y
Expand Down
4 changes: 2 additions & 2 deletions diffrax/_solver/kencarp3.py
Original file line number Diff line number Diff line change
Expand Up @@ -89,7 +89,7 @@
)


class KenCarpInterpolation(AbstractLocalInterpolation):
class AbstractKenCarpInterpolation(AbstractLocalInterpolation, strict=True):
t0: RealScalarLike
t1: RealScalarLike
y0: Y
Expand Down Expand Up @@ -120,7 +120,7 @@ def evaluate(
return (self.y0**ω + vector_tree_dot(coeffs, k) ** ω).ω


class _KenCarp3Interpolation(KenCarpInterpolation):
class _KenCarp3Interpolation(AbstractKenCarpInterpolation, strict=True):
coeffs = np.array(
[
[-215264564351 / 13552729205753, 4655552711362 / 22874653954995],
Expand Down
4 changes: 2 additions & 2 deletions diffrax/_solver/kencarp4.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@

from .._root_finder import VeryChord, with_stepsize_controller_tols
from .base import AbstractImplicitSolver
from .kencarp3 import KenCarpInterpolation
from .kencarp3 import AbstractKenCarpInterpolation
from .runge_kutta import (
AbstractRungeKutta,
ButcherTableau,
Expand Down Expand Up @@ -102,7 +102,7 @@
)


class _KenCarp4Interpolation(KenCarpInterpolation):
class _KenCarp4Interpolation(AbstractKenCarpInterpolation, strict=True):
coeffs = np.array(
[
[
Expand Down
Loading
Loading