Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
62 changes: 30 additions & 32 deletions tensorcircuit/timeevol.py
Original file line number Diff line number Diff line change
Expand Up @@ -493,17 +493,15 @@ def ode_evol_local(
) -> Tensor:
"""
ODE-based time evolution for a time-dependent Hamiltonian acting on a subsystem of qubits.

This function solves the time-dependent Schrodinger equation using numerical ODE integration.
The Hamiltonian is applied only to a specific subset of qubits (indices) in the system.

The ode_backend parameter defaults to 'jaxode' (which uses `jax.experimental.ode.odeint` with a default solver
of 'Dopri5');if set to 'diffrax', it uses `diffrax.diffeqsolve` instead (with a default solver of 'Tsit5').
The ode_backend parameter defaults to 'jaxode' (which uses ``jax.experimental.ode.odeint`` with a default solver
of 'Dopri5'); if set to 'diffrax', it uses ``diffrax.diffeqsolve`` instead (with a default solver of 'Tsit5').

Note: This function currently only supports the JAX backend.

:param hamiltonian: A function that returns a dense Hamiltonian matrix for the specified
subsystem size. The function signature should be hamiltonian(time, *args) -> Tensor.
subsystem size. The function signature should be ``hamiltonian(time, *args) -> Tensor``.
:type hamiltonian: Callable[..., Tensor]
:param initial_state: The initial quantum state vector of the full system.
:type initial_state: Tensor
Expand All @@ -515,15 +513,16 @@ def ode_evol_local(
:type callback: Optional[Callable[..., Tensor]]
:param args: Additional arguments to pass to the Hamiltonian function.
:param solver_kws: Additional keyword arguments to pass to the ODE solver.
ode_backend='jaxode'(default) uses `jax.experimental.ode.odeint`; ode_backend='diffrax'
uses `diffrax.diffeqsolve`.
rtol (default: 1e-12) and atol (default: 1e-12) are used to determine how accurately you would
like the numerical approximation to your equation.
The solver parameter accepts one of {'Tsit5' (default), 'Dopri5', 'Dopri8', 'Kvaerno5'}
and only works when ode_backend='diffrax'.
dt0 (default: 0.01) specifies the initial step size and only works when ode_backend='diffrax'.
max_steps (default: 10000) The maximum number of steps to take before quitting the computation
unconditionally and only works when ode_backend='diffrax'.
- ode_backend='jaxode'(default) uses ``jax.experimental.ode.odeint``; ode_backend='diffrax'
uses ``diffrax.diffeqsolve``.
- rtol (default: 1e-12) and atol (default: 1e-12) are used to determine how accurately you would
like the numerical approximation to your equation.
- The solver parameter accepts one of {'Tsit5' (default), 'Dopri5', 'Dopri8', 'Kvaerno5'}
and only works when ode_backend='diffrax'.
- dt0 (default: 0.01) specifies the initial step size and only works when ode_backend='diffrax'.
- max_steps (default: 10000) The maximum number of steps to take before quitting the computation
unconditionally and only works when ode_backend='diffrax'.

:return: Evolved quantum states at the specified time points. If callback is provided,
returns the callback results; otherwise returns the state vectors.
:rtype: Tensor
Expand Down Expand Up @@ -566,18 +565,16 @@ def ode_evol_global(
) -> Tensor:
"""
ODE-based time evolution for a time-dependent Hamiltonian acting on the entire system.

This function solves the time-dependent Schrodinger equation using numerical ODE integration.
The Hamiltonian is applied to the full system and should be provided in sparse matrix format
for efficiency.

The ode_backend parameter defaults to 'jaxode' (which uses `jax.experimental.ode.odeint` with a default solver
of 'Dopri5');if set to 'diffrax', it uses `diffrax.diffeqsolve` instead (with a default solver of 'Tsit5').
The Hamiltonian is applied to the full system and should be provided in sparse matrix
format for efficiency.
The ode_backend parameter defaults to 'jaxode' (which uses ``jax.experimental.ode.odeint`` with a default solver
of 'Dopri5'); if set to 'diffrax', it uses ``diffrax.diffeqsolve`` instead (with a default solver of 'Tsit5').

Note: This function currently only supports the JAX backend.

:param hamiltonian: A function that returns a sparse Hamiltonian matrix for the full system.
The function signature should be hamiltonian(time, *args) -> Tensor.
The function signature should be ``hamiltonian(time, *args) -> Tensor``.
:type hamiltonian: Callable[..., Tensor]
:param initial_state: The initial quantum state vector.
:type initial_state: Tensor
Expand All @@ -588,15 +585,16 @@ def ode_evol_global(
:param args: Additional arguments to pass to the Hamiltonian function.
:type args: tuple | list
:param solver_kws: Additional keyword arguments to pass to the ODE solver.
ode_backend='jaxode'(default) uses `jax.experimental.ode.odeint`; ode_backend='diffrax'
uses `diffrax.diffeqsolve`.
rtol (default: 1e-12) and atol (default: 1e-12) are used to determine how accurately you would
like the numerical approximation to your equation.
The solver parameter accepts one of {'Tsit5' (default), 'Dopri5', 'Dopri8', 'Kvaerno5'}
and only works when ode_backend='diffrax'.
dt0 (default: 0.01) specifies the initial step size and only works when ode_backend='diffrax'.
max_steps (default: 10000) The maximum number of steps to take before quitting the computation
unconditionally and only works when ode_backend='diffrax'.
- ode_backend='jaxode'(default) uses ``jax.experimental.ode.odeint``; ode_backend='diffrax'
uses ``diffrax.diffeqsolve``.
- rtol (default: 1e-12) and atol (default: 1e-12) are used to determine how accurately you would
like the numerical approximation to your equation.
- The solver parameter accepts one of {'Tsit5' (default), 'Dopri5', 'Dopri8', 'Kvaerno5'}
and only works when ode_backend='diffrax'.
- dt0 (default: 0.01) specifies the initial step size and only works when ode_backend='diffrax'.
- max_steps (default: 10000) The maximum number of steps to take before quitting the computation
unconditionally and only works when ode_backend='diffrax'.

:type solver_kws: dict
:return: Evolved quantum states at the specified time points. If callback is provided,
returns the callback results; otherwise returns the state vectors.
Expand Down Expand Up @@ -632,7 +630,7 @@ def evol_local(
:param index: qubit sites to evolve
:type index: Sequence[int]
:param h_fun: h_fun should return a dense Hamiltonian matrix
with input arguments time and *args
with input arguments ``time`` and ``*args``
:type h_fun: Callable[..., Tensor]
:param t: evolution time
:type t: float
Expand All @@ -658,7 +656,7 @@ def evol_global(
:param c: _description_
:type c: Circuit
:param h_fun: h_fun should return a **SPARSE** Hamiltonian matrix
with input arguments time and *args
with input arguments ``time`` and ``*args``
:type h_fun: Callable[..., Tensor]
:param t: _description_
:type t: float
Expand Down