Skip to content

Commit

Permalink
Added IArray.
Browse files Browse the repository at this point in the history
  • Loading branch information
Hong Ge committed Aug 22, 2016
1 parent 3fe5441 commit b783424
Show file tree
Hide file tree
Showing 7 changed files with 128 additions and 26 deletions.
66 changes: 66 additions & 0 deletions notebooks/demo-crp.ipynb
Original file line number Diff line number Diff line change
@@ -0,0 +1,66 @@
{
"cells": [
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"collapsed": true
},
"outputs": [],
"source": [
"#DP mixture of Gaussians as used in Anglican\n",
"using Turing, Distributions\n",
"eval(Turing, :(debug_level=-1))\n",
"\n",
"data = [1.0, 1.1, 1.2, -10, -15, -20, 0.01, 0.1, 0.05, 0]\n",
"N = length(data)\n",
"\n",
"\n",
"@model dpmixa begin\n",
" @assume ϕ ~ Inf[Normal(0, 10)]\n",
"\n",
" dprintln(0, ϕ)\n",
" urn = CRP(1.72)\n",
" μ, classes = tzeros(N), tzeros(Int, N)\n",
" for i in 1:N\n",
" classes[i] = rand!(urn)\n",
" μ[i] = ϕ[classes[i]]\n",
" @observe data[i] ~ Normal(μ[i], 1)\n",
" end\n",
" K = length(unique(classes))\n",
"\n",
" @predict K μ\n",
"end\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"collapsed": true
},
"outputs": [],
"source": [
"# Collect 50 samples from Particle Gibbs sampler\n",
"@time resulta = sample(dpmixa, PG(20, 50))\n",
"\n",
"macroexpandTuring.TURING[:modelex]"
]
}
],
"metadata": {
"kernelspec": {
"display_name": "Julia 0.4.6",
"language": "julia",
"name": "julia-0.4"
},
"language_info": {
"file_extension": ".jl",
"mimetype": "application/julia",
"name": "julia",
"version": "0.4.6"
}
},
"nbformat": 4,
"nbformat_minor": 0
}
39 changes: 18 additions & 21 deletions notebooks/hmmdemo.ipynb → notebooks/demo-hmm.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@
},
{
"cell_type": "code",
"execution_count": 18,
"execution_count": 2,
"metadata": {
"collapsed": false
},
Expand All @@ -26,7 +26,7 @@
},
{
"cell_type": "code",
"execution_count": 19,
"execution_count": 12,
"metadata": {
"collapsed": false
},
Expand All @@ -37,7 +37,7 @@
"hmmdemo (generic function with 1 method)"
]
},
"execution_count": 19,
"execution_count": 12,
"metadata": {},
"output_type": "execute_result"
}
Expand All @@ -49,30 +49,35 @@
"means = collect(1.0:K)\n",
"\n",
"@model hmmdemo begin\n",
" states = tzeros(Int,N)\n",
" states = tzeros(Int,N)\n",
" # T = TArray{Array{Float64,}}\n",
" \n",
" # Prior over T\n",
" # for i=1:K, @assume T[i,:] ~ Dirichlet(ones(K)./K); end\n",
" \n",
" @assume states[1] ~ Categorical(initial)\n",
" for i = 2:N\n",
" @assume states[i] ~ Categorical(vec(T[states[i-1],:]))\n",
" @observe obs[i] ~ Normal(means[states[i]], 4)\n",
" @observe obs[i] ~ Normal(means[states[i]], 10)\n",
" end\n",
" @predict states\n",
"end"
]
},
{
"cell_type": "code",
"execution_count": 20,
"execution_count": 13,
"metadata": {
"collapsed": false
},
"outputs": [],
"source": [
"chain = sample(hmmdemo, PG(50,100));"
"chain = sample(hmmdemo, PG(100,100));"
]
},
{
"cell_type": "code",
"execution_count": 21,
"execution_count": 14,
"metadata": {
"collapsed": false
},
Expand All @@ -81,10 +86,10 @@
"data": {
"text/plain": [
"1x51 Array{Int64,2}:\n",
" 7 2 7 9 7 7 9 4 5 8 9 4 … 10 10 10 7 9 8 9 10 7 9 4"
" 9 1 1 9 10 10 10 7 8 9 7 8 … 5 8 9 10 10 7 8 9 8 9 6"
]
},
"execution_count": 21,
"execution_count": 14,
"metadata": {},
"output_type": "execute_result"
}
Expand All @@ -96,7 +101,7 @@
},
{
"cell_type": "code",
"execution_count": 22,
"execution_count": 15,
"metadata": {
"collapsed": false
},
Expand All @@ -108,7 +113,8 @@
"PyPlot.ion()\n",
"# . . . and that previous plots are overwritten\n",
"PyPlot.hold(false)\n",
"figure(figsize=(8,6))\n",
"fig = gcf()\n",
"fig[:figsize]=(8,6)\n",
"\n",
"for i=1:50\n",
" # Plot trajectories\n",
Expand All @@ -119,15 +125,6 @@
"end"
]
},
{
"cell_type": "code",
"execution_count": 17,
"metadata": {
"collapsed": true
},
"outputs": [],
"source": []
},
{
"cell_type": "code",
"execution_count": null,
Expand Down
3 changes: 2 additions & 1 deletion src/Turing.jl
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@ using Turing.Traces
export @model, @assume, @observe, @predict, InferenceAlgorithm, IS, SMC, PG, sample

# Turing-safe data structures and associated functions
export TArray, tzeros, localcopy
export TArray, tzeros, localcopy, IArray

# Debugging helpers
export dprintln
Expand All @@ -24,6 +24,7 @@ include("core/intrinsic.jl")
include("core/conditional.jl")
include("core/container.jl")
include("core/io.jl")
include("core/IArray.jl")
include("samplers/sampler.jl")
include("distributions/bnp.jl")

Expand Down
29 changes: 29 additions & 0 deletions src/core/IArray.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,29 @@
## a lazy, infinite collection of iid params
type IArray
distr :: Distribution
vals :: Dict{Int,Any}
end

DD = Union{Distributions.Distribution{Distributions.Univariate, Distributions.ValueSupport},
Distributions.Distribution{Distributions.Multivariate, Distributions.ValueSupport},
Distributions.Distribution{Distributions.Matrixvariate, Distributions.ValueSupport},
ConjugatePriors.NormalGamma, ConjugatePriors.NormalWishart,
ConjugatePriors.NormalInverseGamma, ConjugatePriors.NormalInverseWishart}

IArray(distr::Distribution) = IArray(distr, Dict{Int,Any}())
IArray(distr::Distribution, vals::Dict{Int,Any}) = IArray(distr, vals)
Distributions.logpdf(x :: IArray, t :: Bool) = mapreduce(v -> logpdf(x.distr, v, t), +, values(x.vals))
Distributions.logpdf(d :: DD, x :: IArray, t :: Bool) = logpdf(x :: IArray, t :: Bool)


function Base.getindex(x :: IArray, i)
global sampler
if i in keys(x.vals)
x.vals[i]
else
x.vals[i] = rand(current_trace(), x.distr)
end
end

# This function is not compitable with replaying.
# Base.setindex!(x :: IArray, val, key) = ( x.vals[key] = val )
13 changes: 10 additions & 3 deletions src/core/compiler.jl
Original file line number Diff line number Diff line change
Expand Up @@ -2,9 +2,16 @@
## a valid distribution from the Distributions package
macro assume(ex)
@assert ex.args[1] == symbol("@~")
esc(quote
$(ex.args[2]) = Turing.assume(Turing.sampler, $(ex.args[3]))
end)
if ex.args[3].args[1] == :Inf # e.g. x ~ Inf[Normal(0,1)]
dprintln(0, (ex.args[2]), " ~ Inf[", (ex.args[3].args[2]), "]")
return esc(quote
$(ex.args[2]) = Turing.assume(Turing.sampler, $(ex.args[3].args[2]), Val{true})
end)
else
return esc(quote
$(ex.args[2]) = Turing.assume(Turing.sampler, $(ex.args[3]), Val{false})
end)
end
end

## Usage: @observe(x ~ Dist) where x is a value and Dist is a valid distribution
Expand Down
2 changes: 2 additions & 0 deletions src/core/intrinsic.jl
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,8 @@ function sample(model::Function, alg :: InferenceAlgorithm)
end

assume(spl :: Sampler, distr :: Distribution) = rand(current_trace(), distr)
assume(spl :: Sampler, distr :: Distribution, ::Type{Val{false}}) = assume(spl, distr)
assume(spl :: Sampler, distr :: Distribution, ::Type{Val{true}}) = IArray(distr)
observe(spl :: Sampler, score :: Float64) = produce(score)

function predict(spl :: Sampler, v_name :: Symbol, value)
Expand Down
2 changes: 1 addition & 1 deletion src/distributions/bnp.jl
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@ end

CRP(alpha) = CRP(convert(TArray, [alpha]), alpha)

function randclass(urn::CRP)
function Base.rand!(urn::CRP)
counts = localcopy(urn.counts)
weights = counts ./ sum(counts)
@assume c ~ Categorical(weights)
Expand Down

0 comments on commit b783424

Please sign in to comment.