Skip to content

Commit

Permalink
Make index optional arg and add more tests
Browse files Browse the repository at this point in the history
  • Loading branch information
david-pl committed Jan 22, 2018
1 parent 2f40924 commit 6c8ad7e
Show file tree
Hide file tree
Showing 2 changed files with 55 additions and 9 deletions.
27 changes: 18 additions & 9 deletions src/particle.jl
Expand Up @@ -299,16 +299,25 @@ function transform(basis_l::PositionBasis, basis_r::MomentumBasis; ket_only::Boo
end
end

function transform(basis_l::CompositeBasis, basis_r::CompositeBasis; ket_only::Bool=false)
function transform(basis_l::CompositeBasis, basis_r::CompositeBasis; ket_only::Bool=false, index::Vector{Int}=Int[])
@assert length(basis_l.bases) == length(basis_r.bases)
check_pos = typeof.(basis_l.bases) .== PositionBasis
check_mom = typeof.(basis_l.bases) .== MomentumBasis
if any(check_pos) && !any(check_mom)
@assert all(typeof.(basis_r.bases[check_pos]) .== MomentumBasis)
transform_xp(basis_l, basis_r, [1:length(basis_l.bases);][check_pos]; ket_only=ket_only)
elseif any(check_mom) && !any(check_pos)
@assert all(typeof.(basis_r.bases[check_pos]) .== PositionBasis)
transform_px(basis_l, basis_r, [1:length(basis_l.bases);][check_mom]; ket_only=ket_only)
if length(index) == 0
check_pos = typeof.(basis_l.bases) .== PositionBasis
check_mom = typeof.(basis_l.bases) .== MomentumBasis
if any(check_pos) && !any(check_mom)
index = [1:length(basis_l.bases);][check_pos]
elseif any(check_mom) && !any(check_pos)
index = [1:length(basis_l.bases);][check_mom]
else
throw(IncompatibleBases())
end
end
if all(typeof.(basis_l.bases[index]) .== PositionBasis)
@assert all(typeof.(basis_r.bases[index]) .== MomentumBasis)
transform_xp(basis_l, basis_r, index; ket_only=ket_only)
elseif all(typeof.(basis_l.bases[index]) .== MomentumBasis)
@assert all(typeof.(basis_r.bases[index]) .== PositionBasis)
transform_px(basis_l, basis_r, index; ket_only=ket_only)
else
throw(IncompatibleBases())
end
Expand Down
37 changes: 37 additions & 0 deletions test/test_particle.jl
Expand Up @@ -303,6 +303,10 @@ psi_x_fft = dagger(tensor(psi0_p...))*Tpx
psi_x_fft2 = tensor((dagger.(psi0_p).*Tpx_sub)...)
@test norm(psi_p_fft - psi_p_fft2) < 1e-15

psi_x_fft = Txp*tensor(psi0_p...)
psi_x_fft2 = tensor(Txp_sub...)*tensor(psi0_p...)
@test norm(psi_x_fft - psi_x_fft2) < 1e-15

# Test composite basis of mixed type
bc = FockBasis(2)
psi_fock = fockstate(FockBasis(2), 1)
Expand All @@ -327,4 +331,37 @@ Txp_sub = [transform(basis_position[i], basis_momentum[i]) for i=1:2]
difference = (full(Txp) - tensor(full(Txp_sub[1]), full(one(bc)), full(Txp_sub[2]))).data
@test isapprox(difference, zeros(difference); atol=1e-12)

basis_l = tensor(bc, basis_position[1], basis_position[2])
basis_r = tensor(bc, basis_momentum[1], basis_momentum[2])
Txp2 = transform(basis_l, basis_r)
Tpx2 = transform(basis_r, basis_l)
difference = (full(Txp) - permutesystems(full(Txp2), [2, 1, 3])).data
@test isapprox(difference, zeros(difference); atol=1e-13)
difference = (full(dagger(Txp)) - permutesystems(full(Tpx2), [2, 1, 3])).data
@test isapprox(difference, zeros(difference); atol=1e-13)

# Test error messages
b1 = PositionBasis(-1, 1, 50)
b2 = MomentumBasis(-1, 1, 30)
@test_throws bases.IncompatibleBases transform(b1, b2)
@test_throws bases.IncompatibleBases transform(b2, b1)

bc1 = b1 bc
bc2 = b2 bc
@test_throws bases.IncompatibleBases transform(bc1, bc2)
@test_throws bases.IncompatibleBases transform(bc2, bc1)

b1 = PositionBasis(-1, 1, 50)
b2 = MomentumBasis(-1, 1, 50)
bc1 = b1 bc
bc2 = b2 bc
@test_throws bases.IncompatibleBases transform(bc1, bc2)
@test_throws bases.IncompatibleBases transform(bc2, bc1)
@test_throws bases.IncompatibleBases transform(bc1, bc2; index=[2])

bc1 = b1 b2
bc2 = b1 b2
@test_throws bases.IncompatibleBases transform(bc1, bc2)
@test_throws bases.IncompatibleBases transform(bc2, bc1)

end # testset

0 comments on commit 6c8ad7e

Please sign in to comment.