Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Change hint propagation for user-defined structs. #434

Open
wants to merge 14 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@ MacroTools = "0.5"
Parameters = "0.12"
ReverseDiff = "1.4, 1.5"
SpecialFunctions = "0.8, 0.9, 0.10, 1"
julia = "1"
julia = "1.3"

[extras]
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
Expand Down
49 changes: 49 additions & 0 deletions docs/src/ref/extending.md
Original file line number Diff line number Diff line change
Expand Up @@ -97,6 +97,55 @@ apply_with_state
update_with_state
```

### Incremental computation for user-defined types

Gen provides the generative functions [`Construct`](@ref) and
[`GetField`](@ref) to support incremental computation for user-defined composite
types, i.e., structs. They are useful when structs contain fields which are
returned by other generative functions, or sampled from distributions. For
example, we might define a struct representing 2D points:
```julia
struct Point2f0
x::Float64
y::Float64
end
```
We might then use [`Construct`](@ref) and [`GetField`](@ref) within the
following model of point motion:
```julia
@gen (static) function point_motion()
x ~ normal(0, 1)
y ~ normal(0, 1)
p ~ Construct(Point2f0)(x, y)
new_x ~ x_motion(p)
new_y ~ y_motion(p)
end

@gen (static) function x_motion(p::Point2f0)
x ~ GetField(Point2f0, :x)(p::Point2f0)
new_x = long_computation_for_x(x)
return new_x
end

@gen (static) function y_motion(p::Point2f0)
y ~ GetField(Point2f0, :y)(p::Point2f0)
new_y = long_computation_for_y(y)
return new_y
end
```
Because [`Construct`](@ref) propagates changes to each field of `Point2f0`
separately, writing the model in this way ensures that MCMC moves that adjust
`x` only result in `x_motion` being re-run, while `y_motion` is not recomputed.

Note that that `Construct` should be used with constructors `T(args...)`
where the ``n``th argument corresponds to the ``n``th field of the type `T`.
Default constructors meet this requirement. Other constructors are not
guaranteed to propagate changes correctly.

```@docs
Construct
GetField
```

## [Custom distributions](@id custom_distributions)

Expand Down
2 changes: 1 addition & 1 deletion docs/src/ref/modeling.md
Original file line number Diff line number Diff line change
Expand Up @@ -461,7 +461,7 @@ The functions are also subject to the following restrictions:

- Default argument values are not supported.

- Julia closures are not allowed.
- Constructing named or anonymous Julia functions (and closures) is not allowed.

- List comprehensions with internal `@trace` calls are not allowed.

Expand Down
2 changes: 1 addition & 1 deletion src/dsl/dsl.jl
Original file line number Diff line number Diff line change
Expand Up @@ -147,7 +147,7 @@ function parse_gen_function(ast, annotations, __module__)
return_type = get(def, :rtype, :Any)
static = DSL_STATIC_ANNOTATION in annotations
if static
make_static_gen_function(name, args, body, return_type, annotations)
make_static_gen_function(name, args, body, return_type, annotations, __module__)
else
args = map(a -> resolve_grad_arg(a, __module__), args)
make_dynamic_gen_function(name, args, body, return_type, annotations)
Expand Down
62 changes: 37 additions & 25 deletions src/dsl/static.jl
Original file line number Diff line number Diff line change
Expand Up @@ -58,7 +58,7 @@ gen_node_name(arg::Symbol) = gensym(arg)
gen_node_name(arg::QuoteNode) = gensym(string(arg.value))

"Parse @trace expression and add corresponding node to IR."
function parse_trace_expr!(stmts, bindings, fn, args, addr)
function parse_trace_expr!(stmts, bindings, fn, args, addr, __module__)
expr_s = "@trace($fn($(join(args, ", "))), $addr)"
name = gen_node_name(addr) # Each @trace node is named after its address
node = gen_node_name(addr) # Generate a variable name for the StaticIRNode
Expand Down Expand Up @@ -91,7 +91,7 @@ function parse_trace_expr!(stmts, bindings, fn, args, addr)
# Create Julia node for each argument to gen_fn_or_dist
arg_name = gen_node_name(arg_expr)
push!(inputs, parse_julia_expr!(stmts, bindings,
arg_name, arg_expr, QuoteNode(Any)))
arg_name, arg_expr, QuoteNode(Any), __module__))
end
# Add addr node (a GenerativeFunctionCallNode or RandomChoiceNode)
push!(stmts, :($(esc(node)) = add_addr_node!(
Expand All @@ -101,14 +101,26 @@ function parse_trace_expr!(stmts, bindings, fn, args, addr)
return name
end

function set_module_for_global_constants(expr, bindings, __module__)
expr = MacroTools.postwalk(expr) do e
if MacroTools.@capture(e, var_Symbol) && !haskey(bindings, var) && var != :end
:($__module__.$var)
else
e
end
end
return expr
end

"Parse a Julia expression and add a corresponding node to the IR."
function parse_julia_expr!(stmts, bindings, name::Symbol, expr::Expr,
typ::Union{Symbol,Expr,QuoteNode})
typ::Union{Symbol,Expr,QuoteNode}, __module__)
resolved = resolve_symbols(bindings, expr)
inputs = collect(resolved)
input_vars = map((x) -> esc(x[1]), inputs)
input_vars = map((x) -> x[1], inputs)
input_nodes = map((x) -> esc(x[2]), inputs)
fn = Expr(:function, Expr(:tuple, input_vars...), esc(expr))
expr = set_module_for_global_constants(expr, bindings, __module__)
fn = Expr(:function, Expr(:tuple, input_vars...), expr)
node = gensym(name)
push!(stmts, :($(esc(node)) = add_julia_node!(
builder, $fn, inputs=[$(input_nodes...)], name=$(QuoteNode(name)),
Expand All @@ -117,17 +129,17 @@ function parse_julia_expr!(stmts, bindings, name::Symbol, expr::Expr,
end

function parse_julia_expr!(stmts, bindings, name::Symbol, var::Symbol,
typ::Union{Symbol,Expr,QuoteNode})
typ::Union{Symbol,Expr,QuoteNode}, __module__)
if haskey(bindings, var)
# Use the existing node instead of creating a new one
return bindings[var]
end
node = parse_julia_expr!(stmts, bindings, name, Expr(:block, var), typ)
node = parse_julia_expr!(stmts, bindings, name, :($__module__.$var), typ, __module__)
return node
end

function parse_julia_expr!(stmts, bindings, name::Symbol, var::QuoteNode,
typ::Union{Symbol,Expr,QuoteNode})
typ::Union{Symbol,Expr,QuoteNode}, __module__)
fn = Expr(:function, Expr(:tuple), var)
node = gensym(name)
push!(stmts, :($(esc(node)) = add_julia_node!(
Expand All @@ -137,7 +149,7 @@ function parse_julia_expr!(stmts, bindings, name::Symbol, var::QuoteNode,
end

function parse_julia_expr!(stmts, bindings, name::Symbol, value,
typ::Union{Symbol,Expr,QuoteNode})
typ::Union{Symbol,Expr,QuoteNode}, __module__)
fn = Expr(:function, Expr(:tuple), QuoteNode(value))
node = gensym(name)
push!(stmts, :($(esc(node)) = add_julia_node!(
Expand All @@ -159,23 +171,23 @@ function parse_param_line!(stmts::Vector{Expr}, bindings, name::Symbol, typ)
end

"Parse assignments and add corresponding nodes for the right-hand-side."
function parse_assignment_line!(stmts, bindings, lhs, rhs)
function parse_assignment_line!(stmts, bindings, lhs, rhs, __module__)
if isa(lhs, Expr) && lhs.head == :tuple
# Recursively handle tuple assignments
name, typ = gen_node_name(rhs), QuoteNode(Any)
node = parse_julia_expr!(stmts, bindings, name, rhs, typ)
node = parse_julia_expr!(stmts, bindings, name, rhs, typ, __module__)
bindings[name] = node
for (i, lhs_i) in enumerate(lhs.args)
# Assign lhs[i] = rhs[i]
rhs_i = :($name[$i])
parse_assignment_line!(stmts, bindings, lhs_i, rhs_i)
parse_assignment_line!(stmts, bindings, lhs_i, rhs_i, __module__)
end
else
# Handle single variable assignment (base case)
(name::Symbol, typ) = parse_typed_var(lhs)
# Generate new node name if name is already bound
node_name = haskey(bindings, name) ? gensym(name) : name
node = parse_julia_expr!(stmts, bindings, node_name, rhs, typ)
node = parse_julia_expr!(stmts, bindings, node_name, rhs, typ, __module__)
# Old bindings are overwritten with new nodes
bindings[name] = node
end
Expand All @@ -184,32 +196,32 @@ function parse_assignment_line!(stmts, bindings, lhs, rhs)
end

"Parse a return line and add corresponding return node."
function parse_return_line!(stmts, bindings, expr)
function parse_return_line!(stmts, bindings, expr, __module__)
name, typ = gensym("return"), QuoteNode(Any)
node = parse_julia_expr!(stmts, bindings, name, expr, typ)
node = parse_julia_expr!(stmts, bindings, name, expr, typ, __module__)
bindings[name] = node
push!(stmts, :(set_return_node!(builder, $(esc(node)))))
return Expr(:return, expr)
end

"Parse and rewrite expression if it matches an @trace call."
function parse_and_rewrite_trace!(stmts, bindings, expr)
function parse_and_rewrite_trace!(stmts, bindings, expr, __module__)
if MacroTools.@capture(expr, e_gentrace)
# Parse "@trace(f(xs...), addr)" and return fresh variable
call, addr = expr.args
if addr == nothing static_dsl_syntax_error(expr, "Address required.") end
fn, args = call.args[1], call.args[2:end]
parse_trace_expr!(stmts, bindings, fn, args, something(addr))
parse_trace_expr!(stmts, bindings, fn, args, something(addr), __module__)
else
expr # Return expression unmodified
end
end

"Parse line (i.e. top-level expression) of a static Gen function body."
function parse_static_dsl_line!(stmts, bindings, line)
function parse_static_dsl_line!(stmts, bindings, line, __module__)
# Walk each line bottom-up, parsing and rewriting :gentrace expressions
rewritten = MacroTools.postwalk(
e -> parse_and_rewrite_trace!(stmts, bindings, e), line)
e -> parse_and_rewrite_trace!(stmts, bindings, e, __module__), line)
# If line is a top-level @trace call, we are done
if MacroTools.@capture(line, e_gentrace) return end
# Match and parse any other top-level expressions
Expand All @@ -220,10 +232,10 @@ function parse_static_dsl_line!(stmts, bindings, line)
parse_param_line!(stmts, bindings, name, typ)
elseif MacroTools.@capture(line, lhs_ = rhs_)
# Parse "lhs = rhs"
parse_assignment_line!(stmts, bindings, lhs, rhs)
parse_assignment_line!(stmts, bindings, lhs, rhs, __module__)
elseif MacroTools.@capture(line, return expr_)
# Parse "return expr"
parse_return_line!(stmts, bindings, expr)
parse_return_line!(stmts, bindings, expr, __module__)
elseif line isa LineNumberNode
# Skip line number nodes
else
Expand All @@ -234,18 +246,18 @@ end

"Parse static Gen function body line by line."
function parse_static_dsl_function_body!(
stmts::Vector{Expr}, bindings::Dict{Symbol,Symbol}, expr)
stmts::Vector{Expr}, bindings::Dict{Symbol,Symbol}, expr, __module__)
# TODO: Use line number nodes to improve error messages in generated code
if !(isa(expr, Expr) && expr.head == :block)
static_dsl_syntax_error(expr)
end
for line in expr.args
parse_static_dsl_line!(stmts, bindings, line)
parse_static_dsl_line!(stmts, bindings, line, __module__)
end
end

"Generates the code that builds the IR of a static Gen function."
function make_static_gen_function(name, args, body, return_type, annotations)
function make_static_gen_function(name, args, body, return_type, annotations, __module__)
# Construct the builder for the intermediate representation (IR)
stmts = Expr[]
push!(stmts, :(bindings = Dict{Symbol, StaticIRNode}()))
Expand All @@ -265,7 +277,7 @@ function make_static_gen_function(name, args, body, return_type, annotations)
bindings[arg.name] = node
end
# Parse function body and add corresponding nodes to the IR
parse_static_dsl_function_body!(stmts, bindings, body)
parse_static_dsl_function_body!(stmts, bindings, body, __module__)
push!(stmts, :(ir = build_ir(builder)))
expr = gensym("gen_fn_defn")
# Handle function annotations (caching Julia nodes by default)
Expand Down
2 changes: 1 addition & 1 deletion src/dynamic/update.jl
Original file line number Diff line number Diff line change
Expand Up @@ -86,7 +86,7 @@ function traceat(state::GFUpdateState, gen_fn::GenerativeFunction{T,U},
if has_previous
prev_call = get_call(state.prev_trace, key)
prev_subtrace = prev_call.subtrace
get_gen_fn(prev_subtrace) === gen_fn || gen_fn_changed_error(key)
get_gen_fn(prev_subtrace) == gen_fn || gen_fn_changed_error(key)
(subtrace, weight, _, discard) = update(prev_subtrace,
args, map((_) -> UnknownChange(), args), constraints)
else
Expand Down
24 changes: 11 additions & 13 deletions src/modeling_library/distributions/normal.jl
Original file line number Diff line number Diff line change
Expand Up @@ -54,9 +54,8 @@ Float64
const broadcasted_normal = BroadcastedNormal()

function logpdf(::Normal, x::Real, mu::Real, std::Real)
var = std * std
diff = x - mu
-(diff * diff)/ (2.0 * var) - 0.5 * log(2.0 * pi * var)
z = (x - mu) / std
- (abs2(z) + log(2π))/2 - log(std)
end

function logpdf(::BroadcastedNormal,
Expand All @@ -65,17 +64,17 @@ function logpdf(::BroadcastedNormal,
std::Union{AbstractArray{<:Real}, Real})
assert_has_shape(x, broadcast_shapes_or_crash(mu, std);
msg="Shape of `x` does not agree with the sample space")
z = (x .- mu) ./ std
var = std .* std
diff = x .- mu
sum(-(diff .* diff) ./ (2.0 * var) .- 0.5 * log.(2.0 * pi * var))
sum(- (abs2.(z) .+ log(2π)) / 2 .- log.(std))
end

function logpdf_grad(::Normal, x::Real, mu::Real, std::Real)
precision = 1. / (std * std)
diff = mu - x
deriv_x = diff * precision
z = (x - mu) / std
deriv_x = - z / std
deriv_mu = -deriv_x
deriv_std = -1. / std + (diff * diff) / (std * std * std)
deriv_std = -1. / std + abs2(z) / std
(deriv_x, deriv_mu, deriv_std)
end

Expand All @@ -85,11 +84,10 @@ function logpdf_grad(::BroadcastedNormal,
std::Union{AbstractArray{<:Real}, Real})
assert_has_shape(x, broadcast_shapes_or_crash(mu, std);
msg="Shape of `x` does not agree with the sample space")
precision = 1.0 ./ (std .* std)
diff = mu .- x
deriv_x = sum(diff .* precision)
deriv_mu = sum(-deriv_x)
deriv_std = sum(-1.0 ./ std .+ (diff .* diff) ./ (std .* std .* std))
z = (x .- mu) ./ std
deriv_x = sum(- z ./ std)
deriv_mu = -deriv_x
deriv_std = sum(-1. ./ std .+ abs2.(z) ./ std)
(deriv_x, deriv_mu, deriv_std)
end

Expand Down
3 changes: 2 additions & 1 deletion src/modeling_library/modeling_library.jl
Original file line number Diff line number Diff line change
Expand Up @@ -78,7 +78,8 @@ include("recurse/recurse.jl")
include("switch/switch.jl")

#############################################################
# abstractions for constructing custom generative functions #
# custom deterministic generative functions #
#############################################################

include("custom_determ.jl")
include("structs.jl")
Loading