Skip to content

Commit

Permalink
Implement the OrdinaryDiffEq interface for Kets
Browse files Browse the repository at this point in the history
The following now works

```julia
using QuantumOptics
using DifferentialEquations

ℋ = SpinBasis(1//2)

σx = sigmax(ℋ)

↓ = s =  spindown(ℋ)

schrod(ψ,p,t) = im * σx * ψ

t₀, t₁ = (0.0, pi)
Δt = 0.1

prob = ODEProblem(schrod, ↓, (t₀, t₁))
sol = solve(prob,Tsit5())
```

It works for Bras as well.
It works for in-place operations, however there are spurrious
allocations due to inefficient broadcasting that ruin the performance.
  • Loading branch information
Krastanov committed Apr 4, 2021
1 parent 201b630 commit c3621ca
Show file tree
Hide file tree
Showing 3 changed files with 25 additions and 5 deletions.
2 changes: 1 addition & 1 deletion src/operators_dense.jl
Original file line number Diff line number Diff line change
Expand Up @@ -340,7 +340,7 @@ Broadcast.BroadcastStyle(::OperatorStyle{B1,B2}, ::OperatorStyle{B3,B4}) where {
end
find_basis(a::DataOperator, rest) = (a.basis_l, a.basis_r)

const BasicMathFunc = Union{typeof(+),typeof(-),typeof(*)}
const BasicMathFunc = Union{typeof(+),typeof(-),typeof(*),typeof(/)}
function Broadcasted_restrict_f(f::BasicMathFunc, args::Tuple{Vararg{<:DataOperator}}, axes)
args_ = Tuple(a.data for a=args)
return Broadcast.Broadcasted(f, args_, axes)
Expand Down
26 changes: 23 additions & 3 deletions src/states.jl
Original file line number Diff line number Diff line change
Expand Up @@ -228,7 +228,7 @@ find_basis(x) = x
find_basis(a::StateVector, rest) = a.basis
find_basis(::Any, rest) = find_basis(rest)

const BasicMathFunc = Union{typeof(+),typeof(-),typeof(*)}
const BasicMathFunc = Union{typeof(+),typeof(-),typeof(*),typeof(/)} # `:/` was added for use with scalars in the DifferentialEquations interface
function Broadcasted_restrict_f(f::BasicMathFunc, args::Tuple{Vararg{<:T}}, axes) where T<:StateVector
args_ = Tuple(a.data for a=args)
return Broadcast.Broadcasted(f, args_, axes)
Expand All @@ -237,6 +237,15 @@ function Broadcasted_restrict_f(f, args::Tuple{Vararg{<:T}}, axes) where T<:Stat
throw(error("Cannot broadcast function `$f` on type `$T`"))
end

# Broadcast with scalars (of use in ODE solvers checking for tolerances, e.g. `.* reltol .+ abstol`)
Broadcast.BroadcastStyle(::T, ::Broadcast.DefaultArrayStyle{0}) where {B<:Basis, T<:QuantumOpticsBase.KetStyle{B}} = T()
Broadcast.BroadcastStyle(::T, ::Broadcast.DefaultArrayStyle{0}) where {B<:Basis, T<:QuantumOpticsBase.BraStyle{B}} = T()
getdata(arg::StateVector) = arg.data
getdata(arg) = arg
function Broadcasted_restrict_f(f, args, axes)
args_ = Tuple(getdata(a) for a=args)
return Broadcast.Broadcasted(f, args_, axes)
end

# In-place broadcasting for Kets
@inline function Base.copyto!(dest::Ket{B}, bc::Broadcast.Broadcasted{Style,Axes,F,Args}) where {B<:Basis,Style<:KetStyle{B},Axes,F,Args}
Expand All @@ -250,8 +259,8 @@ end
end
# Get the underlying data fields of kets and broadcast them as arrays
bcf = Broadcast.flatten(bc)
args_ = Tuple(a.data for a=bcf.args)
bc_ = Broadcast.Broadcasted(bcf.f, args_, axes(bcf))
args_ = Tuple(getdata(a) for a=bcf.args)
bc_ = Broadcasted_restrict_f(bcf.f, bcf.args, axes(bcf))
copyto!(dest.data, bc_)
return dest
end
Expand All @@ -278,3 +287,14 @@ end
throw(IncompatibleBases())

@inline Base.copyto!(A::T,B::T) where T<:StateVector = (copyto!(A.data,B.data); A)

# A few more standard interfaces: These do not necessarily make sense for a StateVector, but enable transparent use of DifferentialEquations.jl
Base.eltype(::Type{Ket{B,A}}) where {B,N,A<:AbstractVector{N}} = N # ODE init
Base.eltype(::Type{Bra{B,A}}) where {B,N,A<:AbstractVector{N}} = N
Base.zero(k::StateVector) = typeof(k)(k.basis, zero(k.data)) # ODE init
Base.any(f::Function, x::StateVector; kwargs...) = any(f, x.data; kwargs...) # ODE nan checks
Base.all(f::Function, x::StateVector; kwargs...) = all(f, x.data; kwargs...)
Broadcast.similar(k::StateVector, t) = typeof(k)(k.basis, copy(k.data))
using RecursiveArrayTools
RecursiveArrayTools.recursivecopy!(dst::Ket{B,A},src::Ket{B,A}) where {B,A} = copy!(dst.data,src.data) # ODE in-place equations
RecursiveArrayTools.recursivecopy!(dst::Bra{B,A},src::Bra{B,A}) where {B,A} = copy!(dst.data,src.data)
2 changes: 1 addition & 1 deletion src/superoperators.jl
Original file line number Diff line number Diff line change
Expand Up @@ -232,7 +232,7 @@ end
# end
find_basis(a::SuperOperator, rest) = (a.basis_l, a.basis_r)

const BasicMathFunc = Union{typeof(+),typeof(-),typeof(*)}
const BasicMathFunc = Union{typeof(+),typeof(-),typeof(*),typeof(/)}
function Broadcasted_restrict_f(f::BasicMathFunc, args::Tuple{Vararg{<:SuperOperator}}, axes)
args_ = Tuple(a.data for a=args)
return Broadcast.Broadcasted(f, args_, axes)
Expand Down

0 comments on commit c3621ca

Please sign in to comment.