# [Generative Combinators](@id combinators_tutorial)

Generative function combinators are Julia functions that take one or more generative functions as input and return a new generative function. They express patterns of repeated computation that appear frequently in generative models. Some combinators are similar to higher-order functions from functional programming languages.
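The sketch below makes the analogy with functional programming concrete. It uses plain Julia only, with no Gen and no random choices. The names `kernel`, `advance` and `branches` are hypothetical. The three patterns are:

- mapping elementwise over vectors (like `Map`),
- threading a state through repeated calls (like `Unfold`),
- selecting a function by index (like `Switch`).

Gen's combinators additionally trace the random choices made in each call.

```julia
# Deterministic analogy using only Julia Base (not Gen's implementation).
kernel(x, y) = x + y

# Map-like: apply the kernel elementwise across vectors of arguments.
mapped = map(kernel, [1, 2, 3], [10, 20, 30])

# Unfold-like: thread a state through repeated applications of a kernel,
# keeping every new state (the initial state 0 is not included).
advance(state, z) = 2 * state + z
states = accumulate((s, _) -> advance(s, 1), 1:4; init = 0)

# Switch-like: an integer index selects which function is called.
branches = (x -> 3x + 1, x -> -x)
switched = branches[2](5.0)
```

Here `mapped` is `[11, 22, 33]`, `states` is `[1, 3, 7, 15]` and `switched` is `-5.0`.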

## Map combinator

In the schematic below, the kernel is denoted $\mathcal{G}_{\mathrm{k}}$.

<div style="text-align:center">
    <img src="./images/map_combinator.png" alt="schematic of map combinator" width="50%"/>
</div>
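Read as an equation, the schematic suggests the relation below. The symbols are our own notation, not Gen's. We assume $m$ argument vectors $x^{(1)}, \ldots, x^{(m)}$ that all have the same length $n$:

```math
r_i = \text{return value of } \mathcal{G}_{\mathrm{k}}\big(x^{(1)}_i, \ldots, x^{(m)}_i\big), \qquad i = 1, \ldots, n,
\qquad \text{with overall return value } (r_1, \ldots, r_n).
```

The random choices of the $i$-th call are placed in address namespace $i$.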

For example, consider the following generative function, which makes one random choice at address `:r`:

```julia
using Gen

@gen function foo(x, y, z)
    r ~ normal(x^2 + y^2 + z^2, 1.0)
    return r
end
```

We apply the `Map` combinator to produce a new generative function `bar`:

```julia
bar = Map(foo)
```

We can then obtain a trace of `bar`:

```julia
trace, _ = generate(bar, ([0.0, 0.5], [0.5, 1.0], [1.0, -1.0]))
trace
```

This causes `foo` to be invoked twice:

- once with arguments `(0.0, 0.5, 1.0)` in address namespace `1`,
- once with arguments `(0.5, 1.0, -1.0)` in address namespace `2`.

The choice map therefore contains the addresses `1 => :r` and `2 => :r`:

```julia
get_choices(trace)
```

The return value of `bar` is a vector whose `i`-th element is the return value of the `i`-th call to `foo`. Here that is the value sampled at `i => :r`:

```julia
get_retval(trace)
```
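The bookkeeping behind `Map` can be sketched in plain Julia without Gen. This is a simplified illustration, not Gen's implementation. The names `foo_sketch` and `map_sketch` are hypothetical, and a `Dict` stands in for a choice map. Each call's choices are stored under the call's index, so `:r` becomes `1 => :r` and `2 => :r`. The overall return value collects the per-call return values.

```julia
using Random

# Hypothetical stand-in for `foo`: records its single choice at :r in the
# supplied Dict and returns it.
function foo_sketch(choices::Dict, x, y, z)
    r = x^2 + y^2 + z^2 + randn()   # sample from normal(x^2 + y^2 + z^2, 1.0)
    choices[:r] = r
    return r
end

# Sketch of Map semantics: call the kernel once per index i and store the
# choices of call i under namespace i (address :r becomes i => :r).
function map_sketch(kernel, args::AbstractVector...)
    n = length(first(args))
    @assert all(length(a) == n for a in args) "argument vectors must have equal length"
    choices = Dict{Any,Any}()
    retvals = Vector{Any}(undef, n)
    for i in 1:n
        sub = Dict{Any,Any}()
        retvals[i] = kernel(sub, (a[i] for a in args)...)
        for (addr, val) in sub
            choices[i => addr] = val
        end
    end
    return choices, retvals
end

Random.seed!(1)
choices, retvals = map_sketch(foo_sketch, [0.0, 0.5], [0.5, 1.0], [1.0, -1.0])
```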

## Unfold combinator

In the schematic below, the kernel is denoted $\mathcal{G}_{\mathrm{k}}$. The initial state is denoted $y_0$ and the number of applications is $n$. The remaining arguments to the kernel, not including the state, are denoted $z$.

<div style="text-align:center">
    <img src="./images/unfold_combinator.png" alt="schematic of unfold combinator" width="70%"/>
</div>
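The Unfold schematic can be read as a recurrence. We assume that $t$ runs from $1$ to $n$, and that the returned vector holds $y_1, \ldots, y_n$ but not $y_0$:

```math
y_t = \text{return value of } \mathcal{G}_{\mathrm{k}}(t,\, y_{t-1},\, z), \qquad t = 1, \ldots, n,
\qquad \text{with overall return value } (y_1, \ldots, y_n).
```

The random choices made at step $t$ are placed in address namespace $t$.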

For example, consider the following kernel, with state type `Bool`, which makes one random choice at address `:y`:

```julia
using Gen

@gen function foo(t::Int, y_prev::Bool, z1::Float64, z2::Float64)
    # The new state is true with probability z1 if the previous state was
    # true, and with probability z2 otherwise.
    y ~ bernoulli(y_prev ? z1 : z2)
    return y
end
```

We apply the `Unfold` combinator to produce a new generative function `bar`:

```julia
bar = Unfold(foo)
```

We can then obtain a trace of `bar`:

```julia
trace, _ = generate(bar, (5, false, 0.05, 0.95))
trace
```

This causes `foo` to be invoked five times. The choice made in the `t`-th invocation is placed in address namespace `t`. The resulting trace may contain the following random choices:

```julia
get_choices(trace)
```

The return value is a vector of the five new states; the initial state `false` is not included:

```julia
get_retval(trace)
```
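The state threading behind `Unfold` can also be sketched in plain Julia without Gen. The names `flip_kernel` and `unfold_sketch` are hypothetical, and a `Dict` stands in for a choice map. The sketch assumes the step index starts at 1 and that the returned vector excludes the initial state, matching the recurrence described above. It is an illustration, not Gen's implementation.

```julia
using Random

# Hypothetical stand-in for the Bool-state kernel `foo`: records its single
# choice at :y in the supplied Dict and returns the new state.
function flip_kernel(choices::Dict, t::Int, y_prev::Bool, z1::Float64, z2::Float64)
    p = y_prev ? z1 : z2    # probability that the new state is true
    y = rand() < p          # sample from bernoulli(p)
    choices[:y] = y
    return y
end

# Sketch of Unfold semantics: thread the state through n kernel calls,
# store the choices of call t under namespace t, and collect every new state.
function unfold_sketch(kernel, n::Int, init_state, params...)
    choices = Dict{Any,Any}()
    states = Vector{typeof(init_state)}(undef, n)
    state = init_state
    for t in 1:n
        sub = Dict{Any,Any}()
        state = kernel(sub, t, state, params...)
        for (addr, val) in sub
            choices[t => addr] = val
        end
        states[t] = state
    end
    return choices, states  # states holds y_1, ..., y_n (not y_0)
end

Random.seed!(2)
choices, states = unfold_sketch(flip_kernel, 5, false, 0.05, 0.95)
```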

## Switch combinator

<div style="text-align:center">
    <img src="./images/switch_combinator.png" alt="schematic of switch combinator" width="100%"/>
</div>

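Consider the following two generative functions, which take the same arguments, and the `Switch` combinator applied to them:

```julia
using Gen

@gen function line(x)
    z ~ normal(3 * x + 1, 1.0)
    return z
end

@gen function outlier(x)
    z ~ normal(3 * x + 1, 10.0)
    return z
end

switch_model = Switch(line, outlier)
```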

This creates a new generative function `switch_model` whose arguments take the form `(branch, args...)`. By default, `branch` is an integer indicating which generative function to execute. For example, branch `2` corresponds to `outlier`:

```julia
trace = simulate(switch_model, (2, 5.0))
get_choices(trace)
```
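The branch selection can be sketched in plain Julia without Gen. The names `line_sketch`, `outlier_sketch` and `switch_sketch` are hypothetical, and a `Dict` stands in for a choice map. The sketch records the selected branch's choice directly at `:z`, with no extra namespace. That layout is an assumption made for illustration, not a statement of how Gen's `Switch` stores choices.

```julia
using Random

# Hypothetical stand-ins for `line` and `outlier`: same signature, each
# records one choice at :z and returns it. Only the noise level differs.
line_sketch(choices::Dict, x) = (choices[:z] = 3x + 1 + 1.0 * randn())
outlier_sketch(choices::Dict, x) = (choices[:z] = 3x + 1 + 10.0 * randn())

# Sketch of Switch semantics: the first argument selects a branch by index and
# the remaining arguments are forwarded unchanged to that branch.
function switch_sketch(branches::Tuple, branch::Int, args...)
    choices = Dict{Any,Any}()
    retval = branches[branch](choices, args...)
    return choices, retval
end

Random.seed!(3)
choices, retval = switch_sketch((line_sketch, outlier_sketch), 2, 5.0)
```

Because every branch accepts the same arguments, the caller can change which model runs by changing only the integer `branch`.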