Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix type instabilities for sesolve #189

Merged
merged 3 commits into from
Jul 24, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@ QuantumToolboxCUDAExt = "CUDA"
[compat]
ArrayInterface = "6, 7"
CUDA = "5"
DiffEqCallbacks = "2, 3"
DiffEqCallbacks = "2, <3.2"
FFTW = "1.5"
Graphs = "1.7"
IncompleteLU = "0.2"
Expand Down
25 changes: 13 additions & 12 deletions src/time_evolution/sesolve.jl
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@ end
@doc raw"""
sesolveProblem(H::QuantumObject,
ψ0::QuantumObject,
t_l::AbstractVector;
tlist::AbstractVector;
alg::OrdinaryDiffEq.OrdinaryDiffEqAlgorithm=Tsit5()
e_ops::AbstractVector=[],
H_t::Union{Nothing,Function,TimeDependentOperatorSum}=nothing,
Expand All @@ -44,7 +44,7 @@ Generates the ODEProblem for the Schrödinger time evolution of a quantum system

- `H::QuantumObject`: The Hamiltonian of the system ``\hat{H}``.
- `ψ0::QuantumObject`: The initial state of the system ``|\psi(0)\rangle``.
- `t_l::AbstractVector`: The time list of the evolution.
- `tlist::AbstractVector`: The time list of the evolution.
- `alg::OrdinaryDiffEq.OrdinaryDiffEqAlgorithm`: The algorithm used for the time evolution.
- `e_ops::AbstractVector`: The list of operators to be evaluated during the evolution.
- `H_t::Union{Nothing,Function,TimeDependentOperatorSum}`: The time-dependent Hamiltonian of the system. If `nothing`, the Hamiltonian is time-independent.
Expand All @@ -55,7 +55,7 @@ Generates the ODEProblem for the Schrödinger time evolution of a quantum system
# Notes

- The states will be saved depend on the keyword argument `saveat` in `kwargs`.
- If `e_ops` is specified, the default value of `saveat=[t_l[end]]` (only save the final state), otherwise, `saveat=t_l` (saving the states corresponding to `t_l`). You can also specify `e_ops` and `saveat` separately.
- If `e_ops` is specified, the default value of `saveat=[tlist[end]]` (only save the final state), otherwise, `saveat=tlist` (saving the states corresponding to `tlist`). You can also specify `e_ops` and `saveat` separately.
- The default tolerances in `kwargs` are given as `reltol=1e-6` and `abstol=1e-8`.
- For more details about `alg` and extra `kwargs`, please refer to [`DifferentialEquations.jl`](https://diffeq.sciml.ai/stable/)

Expand All @@ -66,7 +66,7 @@ Generates the ODEProblem for the Schrödinger time evolution of a quantum system
function sesolveProblem(
H::QuantumObject{MT1,OperatorQuantumObject},
ψ0::QuantumObject{<:AbstractArray{T2},KetQuantumObject},
t_l::AbstractVector;
tlist::AbstractVector;
alg::OrdinaryDiffEq.OrdinaryDiffEqAlgorithm = Tsit5(),
e_ops::Vector{QuantumObject{MT2,OperatorQuantumObject}} = QuantumObject{MT1,OperatorQuantumObject}[],
H_t::Union{Nothing,Function,TimeDependentOperatorSum} = nothing,
Expand All @@ -81,6 +81,8 @@ function sesolveProblem(

is_time_dependent = !(H_t === nothing)

t_l = collect(tlist)

ϕ0 = get_data(ψ0)
U = -1im * get_data(H)

Expand Down Expand Up @@ -141,7 +143,7 @@ end
@doc raw"""
sesolve(H::QuantumObject,
ψ0::QuantumObject,
t_l::AbstractVector;
tlist::AbstractVector;
alg::OrdinaryDiffEq.OrdinaryDiffEqAlgorithm=Tsit5(),
e_ops::AbstractVector=[],
H_t::Union{Nothing,Function,TimeDependentOperatorSum}=nothing,
Expand All @@ -159,7 +161,7 @@ Time evolution of a closed quantum system using the Schrödinger equation:

- `H::QuantumObject`: The Hamiltonian of the system ``\hat{H}``.
- `ψ0::QuantumObject`: The initial state of the system ``|\psi(0)\rangle``.
- `t_l::AbstractVector`: List of times at which to save the state of the system.
- `tlist::AbstractVector`: List of times at which to save the state of the system.
- `alg::OrdinaryDiffEq.OrdinaryDiffEqAlgorithm`: Algorithm to use for the time evolution.
- `e_ops::AbstractVector`: List of operators for which to calculate expectation values.
- `H_t::Union{Nothing,Function,TimeDependentOperatorSum}`: Time-dependent part of the Hamiltonian.
Expand All @@ -170,7 +172,7 @@ Time evolution of a closed quantum system using the Schrödinger equation:
# Notes

- The states will be saved depend on the keyword argument `saveat` in `kwargs`.
- If `e_ops` is specified, the default value of `saveat=[t_l[end]]` (only save the final state), otherwise, `saveat=t_l` (saving the states corresponding to `t_l`). You can also specify `e_ops` and `saveat` separately.
- If `e_ops` is specified, the default value of `saveat=[tlist[end]]` (only save the final state), otherwise, `saveat=tlist` (saving the states corresponding to `tlist`). You can also specify `e_ops` and `saveat` separately.
- The default tolerances in `kwargs` are given as `reltol=1e-6` and `abstol=1e-8`.
- For more details about `alg` and extra `kwargs`, please refer to [`DifferentialEquations.jl`](https://diffeq.sciml.ai/stable/)

Expand All @@ -181,7 +183,7 @@ Time evolution of a closed quantum system using the Schrödinger equation:
function sesolve(
H::QuantumObject{MT1,OperatorQuantumObject},
ψ0::QuantumObject{<:AbstractArray{T2},KetQuantumObject},
t_l::AbstractVector;
tlist::AbstractVector;
alg::OrdinaryDiffEq.OrdinaryDiffEqAlgorithm = Tsit5(),
e_ops::Vector{QuantumObject{MT2,OperatorQuantumObject}} = QuantumObject{MT1,OperatorQuantumObject}[],
H_t::Union{Nothing,Function,TimeDependentOperatorSum} = nothing,
Expand All @@ -192,7 +194,7 @@ function sesolve(
prob = sesolveProblem(
H,
ψ0,
t_l;
tlist;
alg = alg,
e_ops = e_ops,
H_t = H_t,
Expand All @@ -206,9 +208,8 @@ end

function sesolve(prob::ODEProblem, alg::OrdinaryDiffEq.OrdinaryDiffEqAlgorithm = Tsit5())
sol = solve(prob, alg)
ψt =
isempty(sol.prob.kwargs[:saveat]) ? QuantumObject[] :
map(ϕ -> QuantumObject(ϕ, type = Ket, dims = sol.prob.p.Hdims), sol.u)

ψt = map(ϕ -> QuantumObject(ϕ, type = Ket, dims = sol.prob.p.Hdims), sol.u)

return TimeEvolutionSol(
sol.t,
Expand Down
10 changes: 9 additions & 1 deletion test/time_evolution_and_partial_trace.jl
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@
psi0 = kron(fock(N, 0), fock(2, 0))
t_l = LinRange(0, 1000, 1000)
e_ops = [a_d * a]
sol = sesolve(H, psi0, t_l, e_ops = e_ops, alg = Vern7(), progress_bar = false)
sol = sesolve(H, psi0, t_l, e_ops = e_ops, progress_bar = false)
sol2 = sesolve(H, psi0, t_l, progress_bar = false)
sol3 = sesolve(H, psi0, t_l, e_ops = e_ops, saveat = t_l, progress_bar = false)
sol_string = sprint((t, s) -> show(t, "text/plain", s), sol)
Expand All @@ -33,6 +33,14 @@
"ODE alg.: $(sol.alg)\n" *
"abstol = $(sol.abstol)\n" *
"reltol = $(sol.reltol)\n"

@testset "Type Inference sesolve" begin
if VERSION >= v"1.10"
@inferred sesolve(H, psi0, t_l, e_ops = e_ops, progress_bar = false)
@inferred sesolve(H, psi0, t_l, progress_bar = false)
@inferred sesolve(H, psi0, t_l, e_ops = e_ops, saveat = t_l, progress_bar = false)
end
end
end

@testset "mesolve and mcsolve" begin
Expand Down