Skip to content

Commit

Permalink
Semiclassical MCWF (#255)
Browse files Browse the repository at this point in the history
  • Loading branch information
david-pl committed Sep 23, 2019
1 parent e5eda50 commit b9aa28d
Show file tree
Hide file tree
Showing 3 changed files with 235 additions and 4 deletions.
6 changes: 4 additions & 2 deletions src/mcwf.jl
Original file line number Diff line number Diff line change
Expand Up @@ -373,6 +373,7 @@ function jump(rng, t::Float64, psi::T, J::Vector, psi_new::T, rates::Nothing) wh
if length(J)==1
operators.gemv!(complex(1.), J[1], psi, complex(0.), psi_new)
psi_new.data ./= norm(psi_new)
i=1
else
probs = zeros(Float64, length(J))
for i=1:length(J)
Expand All @@ -384,13 +385,14 @@ function jump(rng, t::Float64, psi::T, J::Vector, psi_new::T, rates::Nothing) wh
i = findfirst(cumprobs.>r)
operators.gemv!(complex(1.)/sqrt(probs[i]), J[i], psi, complex(0.), psi_new)
end
return nothing
return i
end

function jump(rng, t::Float64, psi::T, J::Vector, psi_new::T, rates::Vector{Float64}) where T<:Ket
if length(J)==1
operators.gemv!(complex(sqrt(rates[1])), J[1], psi, complex(0.), psi_new)
psi_new.data ./= norm(psi_new)
i=1
else
probs = zeros(Float64, length(J))
for i=1:length(J)
Expand All @@ -402,7 +404,7 @@ function jump(rng, t::Float64, psi::T, J::Vector, psi_new::T, rates::Vector{Floa
i = findfirst(cumprobs.>r)
operators.gemv!(complex(sqrt(rates[i]/probs[i])), J[i], psi, complex(0.), psi_new)
end
return nothing
return i
end

"""
Expand Down
181 changes: 180 additions & 1 deletion src/semiclassical.jl
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,16 @@ module semiclassical

import Base: ==
import ..bases, ..operators, ..operators_dense
import ..timeevolution: integrate, recast!
import ..timeevolution: integrate, recast!, QO_CHECKS
import ..timeevolution.timeevolution_mcwf: jump
import LinearAlgebra: normalize!

using Random, LinearAlgebra
import OrdinaryDiffEq

# TODO: Remove imports
import DiffEqCallbacks, RecursiveArrayTools.copyat_or_push!
Base.@pure pure_inference(fout,T) = Core.Compiler.return_type(fout, T)

using ..bases, ..states, ..operators, ..operators_dense, ..timeevolution

Expand All @@ -26,6 +35,7 @@ end

Base.length(state::State) = length(state.quantum) + length(state.classical)
Base.copy(state::State) = State(copy(state.quantum), copy(state.classical))
normalize!(state::State{B,T}) where {B,T<:Ket} = normalize!(state.quantum)

function ==(a::State, b::State)
samebases(a.quantum, b.quantum) &&
Expand Down Expand Up @@ -111,6 +121,44 @@ function master_dynamic(tspan, state0::State{B,T}, fquantum, fclassical; kwargs.
master_dynamic(tspan, dm(state0), fquantum, fclassical; kwargs...)
end

"""
semiclassical.mcwf_dynamic(tspan, psi0, fquantum, fclassical, fjump_classical; <keyword arguments>)
Calculate MCWF trajectories coupled to a classical system.
# Arguments
* `tspan`: Vector specifying the points of time for which output should
be displayed.
* `rho0`: Initial semi-classical state [`semiclassical.State`](@ref).
* `fquantum`: Function `f(t, rho, u) -> (H, J, Jdagger)` returning the time
and/or state dependent Hamiltonian and Jump operators.
* `fclassical`: Function `f(t, rho, u, du)` calculating the possibly time and
state dependent derivative of the classical equations and storing it
in the complex vector `du`.
* `fjump_classical`: Function `f(t, rho, u, i)` making a classical jump when a
quantum jump of the i-th jump operator occurs.
* `fout=nothing`: If given, this function `fout(t, state)` is called every time
an output should be displayed. ATTENTION: The given state is not
permanent!
* `kwargs...`: Further arguments are passed on to the ode solver.
"""
function mcwf_dynamic(tspan, psi0::State{B,T}, fquantum, fclassical, fjump_classical;
seed=rand(UInt),
rates::DecayRates=nothing,
fout::Union{Function,Nothing}=nothing,
kwargs...) where {B<:Basis,T<:Ket{B}}
tspan_ = convert(Vector{Float64}, tspan)
tmp=copy(psi0.quantum)
function dmcwf_(t::Float64, psi::S, dpsi::S) where {B<:Basis,T<:Ket{B},S<:State{B,T}}
dmcwf_h_dynamic(t, psi, fquantum, fclassical, rates, dpsi, tmp)
end
j_(rng, t::Float64, psi, psi_new) = jump_dynamic(rng, t, psi, fquantum, fclassical, fjump_classical, psi_new, rates)
x0 = Vector{ComplexF64}(undef, length(psi0))
recast!(psi0, x0)
psi = copy(psi0)
dpsi = copy(psi0)
integrate_mcwf(dmcwf_, j_, tspan_, psi, seed, fout; kwargs...)
end

function recast!(state::State{B,T,C}, x::C) where {B<:Basis,T<:QuantumState{B},C<:Vector{ComplexF64}}
N = length(state.quantum)
Expand Down Expand Up @@ -139,4 +187,135 @@ function dmaster_h_dynamic(t::Float64, state::State{B,T}, fquantum::Function,
fclassical(t, state.quantum, state.classical, dstate.classical)
end

function dmcwf_h_dynamic(t::Float64, psi::T, fquantum::Function, fclassical::Function, rates::DecayRates,
dpsi::T, tmp::K) where {T,K}
fquantum_(t, rho) = fquantum(t, psi.quantum, psi.classical)
timeevolution.timeevolution_mcwf.dmcwf_h_dynamic(t, psi.quantum, fquantum_, rates, dpsi.quantum, tmp)
fclassical(t, psi.quantum, psi.classical, dpsi.classical)
end

function jump_dynamic(rng, t::Float64, psi::T, fquantum::Function, fclassical::Function, fjump_classical::Function, psi_new::T, rates::DecayRates) where T<:State
result = fquantum(t, psi.quantum, psi.classical)
QO_CHECKS[] && @assert 3 <= length(result) <= 4
J = result[2]
if length(result) == 3
rates_ = rates
else
rates_ = result[4]
end
i = jump(rng, t, psi.quantum, J, psi_new.quantum, rates_)
fjump_classical(t, psi_new.quantum, psi.classical, i)
psi_new.classical .= psi.classical
end

function integrate_mcwf(dmcwf::Function, jumpfun::Function, tspan,
psi0::T, seed, fout::Function;
display_beforeevent=false, display_afterevent=false,
#TODO: Remove kwargs
save_everystep=false, callback=nothing,
alg=OrdinaryDiffEq.DP5(),
kwargs...) where {B<:Basis,T<:State}

tmp = copy(psi0)
psi_tmp = copy(psi0)
x0 = [psi0.quantum.data; psi0.classical]
rng = MersenneTwister(convert(UInt, seed))
jumpnorm = Ref(rand(rng))
n = length(psi0.quantum)
djumpnorm(x::Vector{ComplexF64}, t::Float64, integrator) = norm(x[1:n])^2 - (1-jumpnorm[])

if !display_beforeevent && !display_afterevent
function dojump(integrator)
x = integrator.u
recast!(x, psi_tmp)
t = integrator.t
jumpfun(rng, t, psi_tmp, tmp)
recast!(tmp, x)
jumpnorm[] = rand(rng)
end
cb = OrdinaryDiffEq.ContinuousCallback(djumpnorm,dojump,
save_positions = (display_beforeevent,display_afterevent))


timeevolution.integrate(float(tspan), dmcwf, x0,
copy(psi0), copy(psi0), fout;
callback = cb,
kwargs...)
else
# Temporary workaround until proper tooling for saving
# TODO: Replace by proper call to timeevolution.integrate
function fout_(x::Vector{ComplexF64}, t::Float64, integrator)
recast!(x, state)
fout(t, state)
end

state = copy(psi0)
dstate = copy(psi0)
out_type = pure_inference(fout, Tuple{eltype(tspan),typeof(state)})
out = DiffEqCallbacks.SavedValues(Float64,out_type)
scb = DiffEqCallbacks.SavingCallback(fout_,out,saveat=tspan,
save_everystep=save_everystep,
save_start = false)

function dojump_display(integrator)
x = integrator.u
t = integrator.t

affect! = scb.affect!
if display_beforeevent
affect!.saveiter += 1
copyat_or_push!(affect!.saved_values.t, affect!.saveiter, integrator.t)
copyat_or_push!(affect!.saved_values.saveval, affect!.saveiter,
affect!.save_func(integrator.u, integrator.t, integrator),Val{false})
end

recast!(x, psi_tmp)
jumpfun(rng, t, psi_tmp, tmp)
recast!(tmp, x)

if display_afterevent
affect!.saveiter += 1
copyat_or_push!(affect!.saved_values.t, affect!.saveiter, integrator.t)
copyat_or_push!(affect!.saved_values.saveval, affect!.saveiter,
affect!.save_func(integrator.u, integrator.t, integrator),Val{false})
end
jumpnorm[] = rand(rng)
end

cb = OrdinaryDiffEq.ContinuousCallback(djumpnorm,dojump_display,
save_positions = (false,false))
full_cb = OrdinaryDiffEq.CallbackSet(callback,cb,scb)

function df_(dx::Vector{ComplexF64}, x::Vector{ComplexF64}, p, t)
recast!(x, state)
recast!(dx, dstate)
dmcwf(t, state, dstate)
recast!(dstate, dx)
end

prob = OrdinaryDiffEq.ODEProblem{true}(df_, x0,(tspan[1],tspan[end]))

sol = OrdinaryDiffEq.solve(
prob,
alg;
reltol = 1.0e-6,
abstol = 1.0e-8,
save_everystep = false, save_start = false,
save_end = false,
callback=full_cb, kwargs...)
return out.t, out.saveval
end
end

function integrate_mcwf(dmcwf::Function, jumpfun::Function, tspan,
psi0::T, seed, fout::Nothing;
kwargs...) where {T<:State}
function fout_(t::Float64, x::T)
psi = copy(x)
normalize!(psi)
return psi
end
integrate_mcwf(dmcwf, jumpfun, tspan, psi0, seed, fout_; kwargs...)
end

end # module
52 changes: 51 additions & 1 deletion test/test_semiclassical.jl
Original file line number Diff line number Diff line change
Expand Up @@ -125,4 +125,54 @@ semiclassical.master_dynamic(T, state0, fquantum_master, fclassical; fout=f)
tout, state_t = semiclassical.master_dynamic(T, state0, fquantum_master, fclassical)
f(T[end], state_t[end])

end # testset
# Test mcwf
# Set up system where only atom can jump once
ba = SpinBasis(1//2)
bf = FockBasis(5)
sm = sigmam(ba)one(bf)
a = one(ba)destroy(bf)
H = 0*sm
J = [0*a,sm]
Jdagger = dagger.(J)
function fquantum(t,psi,u)
return H, J, Jdagger
end
function fclassical(t,psi,u,du)
du[1] = u[2] # dx
du[2] = 0.0
end
njumps = [0]
function fjump_classical(t,psi,u,i)
@test i==2
njumps .+= 1
u[2] += 1.0
end
u0 = rand(2) .+ 0.0im
ψ0 = semiclassical.State(spinup(ba)fockstate(bf,0),u0)

tout1, ψt1 = semiclassical.mcwf_dynamic(T,ψ0,fquantum,fclassical,fjump_classical,seed=1)
@test njumps == [1]
tout2, ψt2 = semiclassical.mcwf_dynamic(T,ψ0,fquantum,fclassical,fjump_classical,seed=1)
@test ψt2 == ψt1
tout3, ψt3 = semiclassical.mcwf_dynamic(T,ψ0,fquantum,fclassical,fjump_classical;display_beforeevent=true,seed=1)
@test length(ψt3) == length(ψt1)+1
tout4, ψt4 = semiclassical.mcwf_dynamic(T,ψ0,fquantum,fclassical,fjump_classical;display_beforeevent=true,display_afterevent=true,seed=1)
@test length(ψt4) == length(ψt1)+2
tout5, ut = semiclassical.mcwf_dynamic(T,ψ0,fquantum,fclassical,fjump_classical;display_beforeevent=true,display_afterevent=true,seed=1,fout=(t,psi)->copy(psi.classical))

@test ψt1[end].classical[2] == u0[2] + 1.0

# Test classical jump behavior
before_jump = findfirst(t -> !(tT), tout3)
after_jump = findlast(t-> !(tT), tout4)
@test after_jump == before_jump+1
@test ψt3[before_jump].classical[2] == u0[2]
@test ψt4[after_jump].classical[2] == u0[2] + 1.0
@test ut ==.classical for ψ=ψt4]

# Test quantum jumps
@test ψt1[end].quantum == spindown(ba)fockstate(bf,0)
@test ψt4[before_jump].quantum == ψ0.quantum
@test ψt4[after_jump].quantum == spindown(ba)fockstate(bf,0)

end # testsets

0 comments on commit b9aa28d

Please sign in to comment.