Skip to content

Commit

Permalink
Additional cleanup
Browse files Browse the repository at this point in the history
  • Loading branch information
theogf committed Oct 14, 2022
1 parent c953c6a commit a19ae9e
Show file tree
Hide file tree
Showing 4 changed files with 22 additions and 11 deletions.
2 changes: 1 addition & 1 deletion Project.toml
@@ -1,7 +1,7 @@
name = "ThermodynamicIntegration"
uuid = "1022446e-a4a4-4a46-8bce-0ffd39f68cd3"
authors = ["Theo Galy-Fajou <theo.galyfajou@gmail.com> and contributors"]
version = "0.2.5"
version = "0.2.6"

[deps]
AdvancedHMC = "0bf59076-c3b1-5ca4-86bd-e02cd72cde3d"
Expand Down
16 changes: 14 additions & 2 deletions README.md
Expand Up @@ -20,6 +20,7 @@ For a different way of computing the evidence integral see also my [BayesianQuad

A simple package to compute Thermodynamic Integration for computing the evidence in a Bayesian setting.
You need to provide the `logprior` and the `loglikelihood` as well as an initial sample:

```julia
using Distributions, ThermodynamicIntegration
D = 5
Expand All @@ -36,7 +37,8 @@ You need to provide the `logprior` and the `loglikelihood` as well as an initial
# -8.211990123364176
```

You can also simply pass a Turing model :
You can also simply pass a Turing model:

```julia
using Turing
@model function gauss(y)
Expand All @@ -53,11 +55,21 @@ You can also simply pass a Turing model :
## Parallel sampling

The algorithm also works on multiple threads by calling :

```julia
alg = ThermInt(n_samples=5000)
logZ = alg(logprior, loglikelihood, rand(prior), TIParallelThreads())
logZ = alg(logprior, loglikelihood, rand(prior), TIThreads())
```

or on multiple processes:

```julia
alg = ThermInt(n_samples=5000)
logZ = alg(logprior, loglikelihood, rand(prior), TIDistributed())
```

Note that you need to load `ThermodynamicIntegration` and other necessary external packages on your additional processes via `@everywhere`.

## Sampling methods

Right now sampling is based on [`AdvancedHMC.jl`](https://github.com/TuringLang/AdvancedHMC.jl), with the `ForwardDiff` AD backend.
Expand Down
4 changes: 1 addition & 3 deletions src/ThermodynamicIntegration.jl
Expand Up @@ -4,16 +4,14 @@ using AdvancedHMC
using Distributed
using ForwardDiff
using ProgressMeter
using Random
using Random: Random, AbstractRNG, default_rng
using Requires
using Statistics
using Trapz

export ThermInt
export TISerial, TIThreads, TIDistributed

const GLOBAL_RNG = Random.MersenneTwister(42)

const ADBACKEND = Ref(:ForwardDiff)

set_adbackend(ad::String) = set_adbackend(Symbol(ad))
Expand Down
11 changes: 6 additions & 5 deletions src/thermint.jl
Expand Up @@ -7,12 +7,13 @@
`(1:n_steps) ./ n_steps).^5`
A `ThermInt` object can then be used as a function:
```julia
alg = ThermInt(30)
alg(loglikelihood, logprior, x_init)
alg = ThermInt(30)
alg(loglikelihood, logprior, x_init)
```
"""
struct ThermInt{AD,TRNG,V}
struct ThermInt{AD,TRNG<:AbstractRNG,V}
schedule::V
n_samples::Int
n_warmup::Int
Expand All @@ -26,7 +27,7 @@ function ThermInt(rng::AbstractRNG, schedule; n_samples::Int=2000, n_warmup::Int
end

function ThermInt(schedule; n_samples::Int=2000, n_warmup::Int=500)
return ThermInt(GLOBAL_RNG, schedule; n_samples=n_samples, n_warmup=n_warmup)
return ThermInt(default_rng(), schedule; n_samples=n_samples, n_warmup=n_warmup)
end

function ThermInt(rng::AbstractRNG; n_steps::Int, n_samples::Int=2000, n_warmup::Int=500)
Expand All @@ -37,7 +38,7 @@ end

function ThermInt(; n_steps::Int=30, n_samples::Int=2000, n_warmup::Int=500)
return ThermInt(
GLOBAL_RNG, range(0, 1; length=n_steps) .^ 5; n_samples=n_samples, n_warmup=n_warmup
default_rng(), range(0, 1; length=n_steps) .^ 5; n_samples=n_samples, n_warmup=n_warmup
)
end

Expand Down

0 comments on commit a19ae9e

Please sign in to comment.