Skip to content

Commit

Permalink
Error bars for KL plots
Browse files Browse the repository at this point in the history
  • Loading branch information
adscib committed May 17, 2016
1 parent b81bee3 commit b387985
Show file tree
Hide file tree
Showing 2 changed files with 72 additions and 18 deletions.
62 changes: 44 additions & 18 deletions examples/anglican.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,28 @@
"using Gadfly"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"collapsed": true
},
"outputs": [],
"source": [
"set_default_plot_size(20cm, 8cm)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"collapsed": true
},
"outputs": [],
"source": [
"srand(0);"
]
},
{
"cell_type": "code",
"execution_count": null,
Expand All @@ -30,7 +52,7 @@
},
"outputs": [],
"source": [
"models = [\"anglican_gaussian\", \"anglican_hmm\", \"anglican_crp\", \"anglican_branching\"];"
"models = [\"anglican_gaussian\", \"anglican_hmm\", \"anglican_crp\"]; #Branching seems broken"
]
},
{
Expand All @@ -48,16 +70,11 @@
"cell_type": "code",
"execution_count": null,
"metadata": {
"collapsed": false
"collapsed": true
},
"outputs": [],
"source": [
"results = Dict();\n",
"for m = models\n",
" for n = ns\n",
" results[(m,n)] = benchmark(m, SMC(n))\n",
" end\n",
"end"
"n_repeat = 5;"
]
},
{
Expand All @@ -68,11 +85,10 @@
},
"outputs": [],
"source": [
"println(\"times: \")\n",
"results = Dict();\n",
"for m = models\n",
" println(\"\\t\", m, \": \")\n",
" for n = ns\n",
" println(\"\\t\\t\", n, \": \", results[(m,n)][:time], \" s\")\n",
" results[(m,n)] = multibenchmark(n_repeat, m, SMC(n))\n",
" end\n",
"end"
]
Expand All @@ -85,11 +101,13 @@
},
"outputs": [],
"source": [
"plots = Dict();\n",
"println(\"times: \")\n",
"for m = models\n",
" kls = map(n -> results[(m,n)][:KL], ns)\n",
" plots[m] = plot(x = ns, y = kls, Scale.x_log, Scale.y_log, Guide.xlabel(\"Number of particles\"),\n",
" Guide.ylabel(\"KL\"), Guide.title(m))\n",
" println(\"\\t\", m, \": \")\n",
" for n = ns\n",
" t,s = results[(m,n)][:time]\n",
" println(\"\\t\\t\", n, \": \", t, \" +- \", s , \" s\")\n",
" end\n",
"end"
]
},
Expand All @@ -101,18 +119,26 @@
},
"outputs": [],
"source": [
"plots[\"anglican_gaussian\"]"
"plots = Dict();\n",
"for m = models\n",
" kls = map(n -> results[(m,n)][:KL], ns)\n",
" means = map(x -> x[1], kls)\n",
" sds = map(x -> x[2], kls)\n",
" plots[m] = plot(x = ns, y = means, ymin = max(means - sds, 0.5^10), ymax = means + sds, Scale.x_log10, Scale.y_log2,\n",
" Geom.point, Geom.errorbar, Guide.xlabel(\"Particles\"), Guide.ylabel(\"KL\"), Guide.title(m))\n",
"end"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"collapsed": false
"collapsed": false,
"scrolled": true
},
"outputs": [],
"source": [
"hstack(plots[\"anglican_gaussian\"], plots[\"anglican_hmm\"], plots[\"anglican_crp\"]) #Branching seems broken"
"hstack(plots[\"anglican_gaussian\"], plots[\"anglican_hmm\"], plots[\"anglican_crp\"])"
]
},
{
Expand Down
28 changes: 28 additions & 0 deletions examples/benchmark.jl
Original file line number Diff line number Diff line change
Expand Up @@ -35,3 +35,31 @@ end
function benchmark(modelname :: AbstractString, algs, do_eval=true, do_warmup=true)
return map(a -> benchmark(modelname, a, do_eval, do_warmup), algs)
end

function mean_sd(k, results)
values = map(r -> r[k], results)
d = fit_mle(Normal, values)
return (d.μ, d.σ)
end

function multibenchmark(n_repeat :: Int, modelname :: AbstractString, alg, do_eval=true, do_warmup=true)
# model definition
include(string(modelname, ".jl"))

# extract model and evaluation function
model = eval(symbol(modelname))
evaluate = eval(symbol(string(modelname,"_evaluate")))

if do_warmup
#warmup run
sample(model, alg)
end

results = map(x -> benchmark(modelname, alg, do_eval, false), zeros(n_repeat))

summary = Dict()
summary[:time] = mean_sd(:time, results)
summary[:KL] = mean_sd(:KL, results)

return summary
end

0 comments on commit b387985

Please sign in to comment.