From 90877869f324b1a69a65a37e2bc930442d8faf5a Mon Sep 17 00:00:00 2001 From: Huang-Xu-Yang Date: Wed, 3 Sep 2025 20:45:26 +0800 Subject: [PATCH] new default, test and docstring --- tensorcircuit/timeevol.py | 53 +++++++++++++++++++++++---------------- tests/test_timeevol.py | 14 ++++++----- 2 files changed, 40 insertions(+), 27 deletions(-) diff --git a/tensorcircuit/timeevol.py b/tensorcircuit/timeevol.py index d4263526..9f13cf5a 100644 --- a/tensorcircuit/timeevol.py +++ b/tensorcircuit/timeevol.py @@ -435,10 +435,10 @@ def _solve_ode( args: Any, solver_kws: Dict[str, Any], ) -> Tensor: - rtol = solver_kws.get("rtol", 1e-12) - atol = solver_kws.get("atol", 1e-12) + rtol = solver_kws.get("rtol", 1e-8) + atol = solver_kws.get("atol", 1e-8) ode_backend = solver_kws.get("ode_backend", "jaxode") - max_steps = solver_kws.get("max_steps", 10000) + max_steps = solver_kws.get("max_steps", 4096) ts = backend.convert_to_tensor(times) ts = backend.cast(ts, dtype=rdtypestr) @@ -513,15 +513,21 @@ def ode_evol_local( :type callback: Optional[Callable[..., Tensor]] :param args: Additional arguments to pass to the Hamiltonian function. :param solver_kws: Additional keyword arguments to pass to the ODE solver. - - ode_backend='jaxode'(default) uses ``jax.experimental.ode.odeint``; ode_backend='diffrax' - uses ``diffrax.diffeqsolve``. - - rtol (default: 1e-12) and atol (default: 1e-12) are used to determine how accurately you would - like the numerical approximation to your equation. - - The solver parameter accepts one of {'Tsit5' (default), 'Dopri5', 'Dopri8', 'Kvaerno5'} - and only works when ode_backend='diffrax'. - - dt0 (default: 0.01) specifies the initial step size and only works when ode_backend='diffrax'. - - max_steps (default: 10000) The maximum number of steps to take before quitting the computation - unconditionally and only works when ode_backend='diffrax'. + + - ``ode_backend='jaxode'`` (default) uses ``jax.experimental.ode.odeint``; ``ode_backend='diffrax'`` + uses ``diffrax.diffeqsolve``. + + - ``rtol`` (default: 1e-8) and ``atol`` (default: 1e-8) are used to determine how accurately you would + like the numerical approximation to your equation. + + - The ``solver`` parameter accepts one of {'Tsit5' (default), 'Dopri5', 'Dopri8', 'Kvaerno5'} + and only works when ``ode_backend='diffrax'``. + + - ``t0`` (default: 0.01) specifies the initial step size and only works when ``ode_backend='diffrax'``. + + - ``max_steps`` (default: 4096) The maximum number of steps to take before quitting the computation + unconditionally and only works when ``ode_backend='diffrax'``. + :type solver_kws: dict :return: Evolved quantum states at the specified time points. If callback is provided, returns the callback results; otherwise returns the state vectors. @@ -585,17 +591,22 @@ def ode_evol_global( :param args: Additional arguments to pass to the Hamiltonian function. :type args: tuple | list :param solver_kws: Additional keyword arguments to pass to the ODE solver. - - ode_backend='jaxode'(default) uses ``jax.experimental.ode.odeint``; ode_backend='diffrax' - uses ``diffrax.diffeqsolve``. - - rtol (default: 1e-12) and atol (default: 1e-12) are used to determine how accurately you would - like the numerical approximation to your equation. - - The solver parameter accepts one of {'Tsit5' (default), 'Dopri5', 'Dopri8', 'Kvaerno5'} - and only works when ode_backend='diffrax'. - - dt0 (default: 0.01) specifies the initial step size and only works when ode_backend='diffrax'. - - max_steps (default: 10000) The maximum number of steps to take before quitting the computation - unconditionally and only works when ode_backend='diffrax'. + - ``ode_backend='jaxode'`` (default) uses ``jax.experimental.ode.odeint``; ``ode_backend='diffrax'`` + uses ``diffrax.diffeqsolve``. + + - ``rtol`` (default: 1e-8) and ``atol`` (default: 1e-8) are used to determine how accurately you would + like the numerical approximation to your equation. + + - The ``solver`` parameter accepts one of {'Tsit5' (default), 'Dopri5', 'Dopri8', 'Kvaerno5'} + and only works when ``ode_backend='diffrax'``. + + - ``t0`` (default: 0.01) specifies the initial step size and only works when ``ode_backend='diffrax'``. + + - ``max_steps`` (default: 4096) The maximum number of steps to take before quitting the computation + unconditionally and only works when ``ode_backend='diffrax'``. :type solver_kws: dict + :return: Evolved quantum states at the specified time points. If callback is provided, returns the callback results; otherwise returns the state vectors. :rtype: Tensor diff --git a/tests/test_timeevol.py b/tests/test_timeevol.py index 5dc3a227..ceeb4a13 100644 --- a/tests/test_timeevol.py +++ b/tests/test_timeevol.py @@ -105,14 +105,14 @@ def local_hamiltonian(t, Omega, phi): 1.0, 2.0, # Omega=1.0, phi=2.0 solver="Dopri8", - atol=1.0e-13, - rtol=1.0e-13, + atol=1.0e-11, + rtol=1.0e-11, ode_backend="diffrax", dt0=0.005, ) - np.testing.assert_allclose(states2, states1, atol=1e-10, rtol=0.0) - np.testing.assert_allclose(states0, states1, atol=1e-10, rtol=0.0) + np.testing.assert_allclose(states2, states1, atol=1e-5, rtol=0.0) + np.testing.assert_allclose(states0, states1, atol=1e-5, rtol=0.0) def test_ode_evol_global(highp, jaxb): @@ -270,6 +270,8 @@ def do5_ode_solver_(params): tc.backend.convert_to_tensor([0, 10.0]), None, *params, + atol=1.0e-13, + rtol=1.0e-13, ) return tc.backend.real(zz_correlation(states[-1])) @@ -339,8 +341,8 @@ def do5_ode_solver_local(paras): v1, g1 = s1 v2, g2 = s2 - np.testing.assert_allclose(g1, g2, atol=1e-8, rtol=0) - np.testing.assert_allclose(v1, v2, atol=1e-8, rtol=0) + np.testing.assert_allclose(g1, g2, atol=1e-5, rtol=0) + np.testing.assert_allclose(v1, v2, atol=1e-5, rtol=0) @pytest.mark.parametrize("backend", [lf("npb"), lf("tfb"), lf("jaxb")])