Skip to content

Commit

Permalink
Replace NULL using flag
Browse files Browse the repository at this point in the history
  • Loading branch information
xukai92 committed Dec 10, 2017
1 parent a760da5 commit 15af49c
Show file tree
Hide file tree
Showing 7 changed files with 34 additions and 19 deletions.
2 changes: 0 additions & 2 deletions src/Turing.jl
Original file line number Diff line number Diff line change
Expand Up @@ -23,8 +23,6 @@ import Stan: Adapt, Hmc
# Global variables/constants #
##############################

global const NULL = NaN # constant for "delete" vals

global CHUNKSIZE = 0 # default chunksize used by AD
global SEEDS # pre-alloced dual parts
setchunksize(chunk_size::Int) = begin
Expand Down
28 changes: 21 additions & 7 deletions src/core/varinfo.jl
Original file line number Diff line number Diff line change
Expand Up @@ -41,11 +41,15 @@ type VarInfo
pred :: Dict{Symbol,Any}
num_produce :: Int # num of produce calls from trace, each produce corresponds to an observe.
orders :: Vector{Int} # observe statements number associated with random variables
flags :: Dict{String,Vector{Bool}}

VarInfo() = begin
vals = Vector{Vector{Real}}(); push!(vals, Vector{Real}())
vals = Vector{Vector{Real}}(); push!(vals, Vector{Real}())
trans = Vector{Vector{Real}}(); push!(trans, Vector{Real}())
logp = Vector{Real}(); push!(logp, zero(Real))
pred = Dict{Symbol,Any}()
logp = Vector{Real}(); push!(logp, zero(Real))
pred = Dict{Symbol,Any}()
flags = Dict{String,Vector{Bool}}()
flags["del"] = Vector{Bool}()

new(
Dict{VarName, Int}(),
Expand All @@ -57,7 +61,8 @@ type VarInfo
trans, logp,
pred,
0,
Vector{Int}()
Vector{Int}(),
flags
)
end
end
Expand Down Expand Up @@ -203,6 +208,7 @@ push!(vi::VarInfo, vn::VarName, r::Any, dist::Distributions.Distribution, gid::I
push!(vi.gids, gid)
push!(vi.trans[end], false)
push!(vi.orders, vi.num_produce)
push!(vi.flags["del"], [false for _ = 1:n]...)

vi
end
Expand Down Expand Up @@ -310,9 +316,17 @@ end
# Rand & replaying method for VarInfo #
#######################################

# Check if a vn is set to NULL
isdel(vi::VarInfo, vn::VarName) = any(isnan.(getval(vi, vn)))

# TODO: turn below to marco generated functions
# Check if a vn is set to del
isdel(vi::VarInfo, vn::VarName) = any(vi.flags["del"][getrange(vi, vn)])
set_vn_del!(vi::VarInfo, vn::VarName) = vi.flags["del"][getrange(vi, vn)] = true
unset_vn_del!(vi::VarInfo, vn::VarName) = vi.flags["del"][getrange(vi, vn)] = false
set_vns_del_by_spl!(vi::VarInfo, spl::Sampler) = begin
vview = getretain(vi, spl)
if length(vview) > 0
vi.flags["del"][[i for arr in vview for i in arr]] = true
end
end

updategid!(vi::VarInfo, vn::VarName, spl::Sampler) = begin
if ~isempty(spl.alg.space) && getgid(vi, vn) == 0 && getsym(vi, vn) in spl.alg.space
Expand Down
5 changes: 3 additions & 2 deletions src/samplers/pgibbs.jl
Original file line number Diff line number Diff line change
Expand Up @@ -56,8 +56,8 @@ step(model::Function, spl::Sampler{PG}, vi::VarInfo) = begin
ref_particle = isempty(vi) ?
nothing :
forkr(Trace(model, spl, vi))

vi[getretain(vi, spl)] = NULL
set_vns_del_by_spl!(vi, spl)
resetlogp!(vi)

if ref_particle == nothing
Expand Down Expand Up @@ -149,6 +149,7 @@ assume{T<:Union{PG,SMC}}(spl::Sampler{T}, dist::Distribution, vn::VarName, _::Va
spl.info[:cache_updated] = CACHERESET # sanity flag mask for getidcs and getranges
r
elseif isdel(vi, vn)
unset_vn_del!(vi, vn)
r = rand(dist)
setval!(vi, vectorize(dist, r), vn)
setgid!(vi, spl.alg.gid, vn)
Expand Down
2 changes: 1 addition & 1 deletion src/samplers/smc.jl
Original file line number Diff line number Diff line change
Expand Up @@ -49,7 +49,7 @@ end
step(model::Function, spl::Sampler{SMC}, vi::VarInfo) = begin
particles = ParticleContainer{Trace}(model)
vi.num_produce = 0; # Reset num_produce before new sweep\.
vi[getretain(vi, spl)] = NULL
set_vns_del_by_spl!(vi, spl)
resetlogp!(vi)

push!(particles, spl.alg.n_particles, spl, vi)
Expand Down
4 changes: 2 additions & 2 deletions src/trace/trace.jl
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
module Traces
using Turing: VarInfo, Sampler, getvns, NULL, getretain
using Turing: VarInfo, Sampler, getvns, set_vns_del_by_spl!

# Trick for supressing some warning messages.
# URL: https://github.com/KristofferC/OhMyREPL.jl/issues/14#issuecomment-242886953
Expand Down Expand Up @@ -70,7 +70,7 @@ function fork(trace :: Trace, is_ref :: Bool = false)

newtrace.vi = deepcopy(trace.vi)
if is_ref
newtrace.vi[getretain(newtrace.vi, newtrace.spl)] = NULL
set_vns_del_by_spl!(newtrace.vi, newtrace.spl)
end

newtrace.task.storage[:turing_trace] = newtrace
Expand Down
7 changes: 4 additions & 3 deletions test/varinfo.jl/orders.jl
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
using Turing, Base.Test
using Turing: uid, cuid, reconstruct, invlink, getvals, step, getidcs, getretain, NULL, CACHERESET
using Turing: uid, cuid, reconstruct, invlink, getvals, step, getidcs, getretain, set_vns_del_by_spl!, CACHERESET, isdel, unset_vn_del!
using Turing: VarInfo, VarName

# Mock assume method for CSMC cf src/samplers/pgibbs.jl
Expand All @@ -10,6 +10,7 @@ randr(vi::VarInfo, vn::VarName, dist::Distribution, spl::Turing.Sampler) = begin
spl.info[:cache_updated] = CACHERESET
r
elseif isdel(vi, vn)
unset_vn_del!(vi, vn)
r = rand(dist)
Turing.setval!(vi, Turing.vectorize(dist, r), vn)
Turing.setorder!(vi, vn, vi.num_produce)
Expand Down Expand Up @@ -57,7 +58,7 @@ randr(vi, vn_z3, dists[1], spl1)
vi.num_produce = 0
@test getretain(vi, spl1) == UnitRange[6:6,5:5,4:4,2:2,1:1]
@test getretain(vi, spl2) == UnitRange[3:3]
vi[getretain(vi, spl1)] = NULL
set_vns_del_by_spl!(vi, spl1)

vi.num_produce += 1
randr(vi, vn_z1, dists[1], spl1)
Expand Down Expand Up @@ -89,7 +90,7 @@ randr(vi_ref, vn_a2, dists[2], spl1)
# Change order of samples: z1,a1,z2,z3 (no a2 anymore)
vi = deepcopy(vi_ref)
vi.num_produce = 0
vi[getretain(vi, spl1)] = NULL
set_vns_del_by_spl!(vi, spl1)
vi.num_produce += 1
randr(vi, vn_z1, dists[1], spl1)
randr(vi, vn_a1, dists[2], spl1)
Expand Down
5 changes: 3 additions & 2 deletions test/varinfo.jl/varinfo.jl
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
using Turing, Base.Test
using Turing: uid, cuid, reconstruct, invlink, getvals, step, getidcs, getretain, NULL
using Turing: uid, cuid, reconstruct, invlink, getvals, step, getidcs, getretain, set_vns_del_by_spl!, isdel, unset_vn_del!
using Turing: VarInfo, VarName

randr(vi::VarInfo, vn::VarName, dist::Distribution, spl::Turing.Sampler, count::Bool) = begin
Expand All @@ -8,6 +8,7 @@ randr(vi::VarInfo, vn::VarName, dist::Distribution, spl::Turing.Sampler, count::
Turing.push!(vi, vn, r, dist, spl.alg.gid)
r
elseif isdel(vi, vn)
unset_vn_del!(vi, vn)
r = rand(dist)
Turing.setval!(vi, Turing.vectorize(dist, r), vn)
r
Expand Down Expand Up @@ -63,7 +64,7 @@ randr(vi, vn_u, dists[1], spl2, true)

# println(vi)
vi.num_produce = 1
vi[getretain(vi, spl2)] = NULL
set_vns_del_by_spl!(vi, spl2)

# println(vi)

Expand Down

0 comments on commit 15af49c

Please sign in to comment.