Skip to content

Commit

Permalink
Remove Task.storage and use flatten names internally
Browse files Browse the repository at this point in the history
  • Loading branch information
xukai92 committed Apr 2, 2017
1 parent 77f81a9 commit 7da3749
Show file tree
Hide file tree
Showing 10 changed files with 62 additions and 56 deletions.
2 changes: 1 addition & 1 deletion src/Turing.jl
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@ import Base: ~, convert, promote_rule

# Turing essentials - modelling macros and inference algorithms
export @model, @sample, @~, InferenceAlgorithm, HMC, IS, SMC, PG, Gibbs, sample, Chain, Sample, Sampler, ImportanceSampler, HMCSampler
export VarName, VarInfo, nextvn, randrn, randrc
export VarName, VarInfo, nextvn, randrn, randrc, randoc
export Dual

# Export Mamba Chain utility functions
Expand Down
12 changes: 1 addition & 11 deletions src/core/container.jl
Original file line number Diff line number Diff line change
Expand Up @@ -182,20 +182,10 @@ end

########### Auxilary Functions ###################


# NOTE: Particle is a type alias of Trace
Base.keys(p :: Particle) = keys(p.task.storage[:turing_predicts])
Base.values(p :: Particle) = values(p.task.storage[:turing_predicts])
Base.getindex(p :: Particle, args...) = getindex(p.task.storage[:turing_predicts], args...)

# ParticleContainer: particles ==> (weight, results)
function getsample(pc :: ParticleContainer, i :: Int, w :: Float64 = 0.)
p = pc.vals[i]

predicts = Dict{Symbol, Any}()
for k in keys(p)
predicts[k] = p[k]
end
predicts = varInfo2samples(p.vi)
return Sample(w, predicts)
end

Expand Down
21 changes: 18 additions & 3 deletions src/core/varinfo.jl
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,9 @@ end

uid(vn::VarName) = (vn.csym, vn.sym, vn.indexing, vn.counter)
string(vn::VarName) = "{$(vn.csym),$(vn.sym)$(vn.indexing)}:$(vn.counter)"
sym(vn::VarName) = Symbol("$(vn.sym)$(vn.indexing)") # simplified symbol
sym(t::Tuple{Symbol,Symbol,String,Int64}) = Symbol("$(t[2])$(t[3])")

isequal(x::VarName, y::VarName) = uid(x) == uid(y)
==(x::VarName, y::VarName) = isequal(x, y)

Expand Down Expand Up @@ -103,9 +106,6 @@ uids(vi::VarInfo) = union(Set(keys(vi.idcs)), Set(vi.names))
Base.keys(vi::VarInfo) = map(t -> VarName(t...), keys(vi.idcs))
Base.haskey(vi::VarInfo, vn::VarName) = haskey(vi.idcs, uid(vn))




nextvn(vi::VarInfo, csym::Symbol, sym::Symbol, indexing::String) = begin
# TODO: update this method when VarInfo internal structure is updated
VarName(csym, sym, indexing, 1)
Expand Down Expand Up @@ -153,3 +153,18 @@ function randrc(vi::VarInfo, vn::VarName, dist::Distribution)
end
r
end

# Randome with force overwriting by counter
function randoc(vi::VarInfo, vn::VarName, dist::Distribution)
vi.index += 1
r = Distributions.rand(dist)
if vi.index <= length(vi.randomness)
vi.randomness[vi.index] = r
else # sample, record
@assert ~(vn in vi.names) "[randr(trace)] attempt to generate an exisitng variable $name to $(vi)"
push!(vi.randomness, r)
push!(vi.names, vn)
push!(vi.tsyms, vn.sym)
end
r
end
10 changes: 2 additions & 8 deletions src/samplers/sampler.jl
Original file line number Diff line number Diff line change
Expand Up @@ -65,14 +65,8 @@ function sample(model::Function, data::Dict, alg::InferenceAlgorithm)
Base.run(model, data, sampler)
end

assume(spl::ParticleSampler, dist::Distribution, vn::VarName, vi) = rand(current_trace(), dist)
assume(spl::ParticleSampler, dist::Distribution, vn::VarName, vi) = rand(current_trace(), vn, dist)

observe(spl :: ParticleSampler, d :: Distribution, value, varInfo) = produce(logpdf(d, value))

function predict(spl :: Sampler, v_name :: Symbol, value)
task = current_task()
if ~haskey(task.storage, :turing_predicts)
task.storage[:turing_predicts] = Dict{Symbol,Any}()
end
task.storage[:turing_predicts][v_name] = isa(value, Dual) ? realpart(value) : value
end
predict(spl :: Sampler, v_name :: Symbol, value) = nothing
1 change: 0 additions & 1 deletion src/samplers/smc.jl
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,6 @@ function Base.run(model, data, spl::Sampler{SMC})
resample!(spl.particles,use_replay=spl.alg.use_replay)
end
end

res = Chain(getsample(spl.particles)...)

end
25 changes: 6 additions & 19 deletions src/samplers/support/gibbs_helper.jl
Original file line number Diff line number Diff line change
Expand Up @@ -30,27 +30,14 @@ function varInfo2samples(vi)
val = vi[uid]
val = reconstruct(dist, val)
val = invlink(dist, val)
sym = getsym(vi, uid)
if ~(sym in keys(samples))
samples[sym] = Any[realpart(val)]
else
push!(samples[sym], realpart(val))
end
val = Any[realpart(val)]
val = length(val) == 1 ? val[1] : val # Remove un-necessary []'s
samples[sym(uid)] = val
end
for i = 1:length(vi.tsyms)
for i = 1:length(vi.names)
uid = vi.names[i]
val = vi.randomness[i]
sym = vi.tsyms[i]
if ~(sym in keys(samples))
samples[sym] = Any[val]
else
push!(samples[sym], val)
end
end
# Remove un-necessary []'s
for k in keys(samples)
if isa(samples[k], Array) && length(samples[k]) == 1
samples[k] = samples[k][1]
end
samples[sym(uid)] = val
end
samples
end
8 changes: 4 additions & 4 deletions src/trace/trace.jl
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@ Notes:
module Traces
using Distributions
using Turing: VarName, VarInfo
import Turing.randrc
import Turing.randrc, Turing.randoc

# Trick for supressing some warning messages.
# URL: https://github.com/KristofferC/OhMyREPL.jl/issues/14#issuecomment-242886953
Expand Down Expand Up @@ -78,10 +78,10 @@ typealias TraceC Trace{:C} # Replay
randr(t::Trace, vn::VarName, distr::Distribution) = randrc(t.vi, vn, distr)

# generate a new random variable, no replay
randc(t::Trace, distr :: Distribution) = Distributions.rand(distr)
randc(t::Trace, vn::VarName, distr :: Distribution) = randoc(t.vi, vn, distr)

Distributions.rand(t::TraceR, vn::VarName, dist::Distribution) = randr(t, vn, dist)
Distributions.rand(t::TraceC, vn::VarName, dist::Distribution) = randc(t, dist)
Distributions.rand(t::TraceC, vn::VarName, dist::Distribution) = randc(t, vn, dist)

Distributions.rand(t::TraceR, distr :: Distribution) = randr(t, distr)
Distributions.rand(t::TraceC, distr :: Distribution) = randc(t, distr)
Expand All @@ -98,7 +98,7 @@ function forkc(trace :: Trace)
newtrace.vi.vals = trace.vi.vals
newtrace.vi.syms = trace.vi.syms
newtrace.vi.dists = trace.vi.dists
newtrace.vi.randomness = trace.vi.randomness[1:n_rand]
newtrace.vi.randomness = deepcopy(trace.vi.randomness[1:n_rand])
newtrace.vi.names = trace.vi.names[1:n_rand]
newtrace.vi.tsyms = trace.vi.tsyms[1:n_rand]
newtrace.vi.index = trace.vi.index
Expand Down
23 changes: 18 additions & 5 deletions test/gibbs.jl
Original file line number Diff line number Diff line change
Expand Up @@ -13,10 +13,23 @@ x = [1.5 2.0]
s, m
end

gibbs = Gibbs(2000, PG(30, 3, :s), HMC(2, 0.1, 3, :m))
gibbs = Gibbs(2000, PG(30, 3, :s), HMC(2, 0.1, 7, :m))
chain = @sample(gibbstest(x), gibbs)

Turing.TURING[:modelex]
# # print(mean(chain[:s]) )
# @test_approx_eq_eps mean(chain[:s]) 49/24 0.15
# @test_approx_eq_eps mean(chain[:m]) 7/6 0.15
print(" 1. s ≈ 49/24 (ϵ = 0.15)")
ans1 = abs(mean(chain[:s]) - 49/24) <= 0.15
if ans1
print_with_color(:green, "\n")
else
print_with_color(:red, " X\n")
print_with_color(:red, " s = $(mean(chain[:s])), diff = $(abs(mean(chain[:s]) - 49/24))\n")
end

print(" 2. m ≈ 7/6 (ϵ = 0.15)")
ans2 = abs(mean(chain[:m]) - 7/6) <= 0.15
if ans2
print_with_color(:green, "\n")
else
print_with_color(:red, " X\n")
print_with_color(:red, " m = $(mean(chain[:m])), diff = $(abs(mean(chain[:m]) - 7/6))\n")
end
11 changes: 9 additions & 2 deletions test/particlecontainer.jl
Original file line number Diff line number Diff line change
Expand Up @@ -5,13 +5,20 @@ using Distributions

import Turing: ParticleContainer, weights, resample!, effectiveSampleSize, TraceC, TraceR, Trace, current_trace

global n = 0

function f()
global n
t = TArray(Float64, 1);
t[1] = 0;
while true
rand(current_trace(), Normal(0,1))
vn = VarName(gensym(), :x, "[$n]", 1)
rand(current_trace(), vn, Normal(0,1))
n += 1
produce(0)
rand(current_trace(), Normal(0,1))
vn = VarName(gensym(), :x, "[$n]", 1)
rand(current_trace(), vn, Normal(0,1))
n += 1
t[1] = 1 + t[1]
end
end
Expand Down
5 changes: 3 additions & 2 deletions test/test_varname.jl
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,8 @@ v_mat = eval(varname(:((x[1,2][1+5][45][3][i])))[1])
p
end
chain = sample(mat_name_test, HMC(1000, 0.75, 2))
@test_approx_eq_eps mean(mean(chain[:p])) 0 0.25

@test_approx_eq_eps mean(mean(chain[Symbol("p[1,1]")])) 0 0.25

# Multi array
i, j = 1, 2
Expand All @@ -50,4 +51,4 @@ v_arrarr = eval(varname(:(x[i][j]))[1])
p
end
chain = sample(marr_name_test, HMC(1000, 0.75, 2))
@test_approx_eq_eps mean(mean(mean(chain[:p]))) 0 0.25
@test_approx_eq_eps mean(mean(mean(chain[Symbol("p[1][1]")]))) 0 0.25

0 comments on commit 7da3749

Please sign in to comment.