Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add convenience forms for update and regenerate with optional args and argdiffs #236

Open
wants to merge 8 commits into
base: master
Choose a base branch
from
3 changes: 1 addition & 2 deletions src/dynamic/regenerate.jl
Original file line number Diff line number Diff line change
Expand Up @@ -74,8 +74,7 @@ function traceat(state::GFRegenerateState, gen_fn::GenerativeFunction{T,U},
prev_call = get_call(state.prev_trace, key)
prev_subtrace = prev_call.subtrace
get_gen_fn(prev_subtrace) === gen_fn || gen_fn_changed_error(key)
(subtrace, weight, _) = regenerate(
prev_subtrace, args, map((_) -> UnknownChange(), args), subselection)
(subtrace, weight, _) = regenerate(prev_subtrace, args, subselection)
else
(subtrace, weight) = generate(gen_fn, args, EmptyChoiceMap())
end
Expand Down
3 changes: 1 addition & 2 deletions src/dynamic/update.jl
Original file line number Diff line number Diff line change
Expand Up @@ -87,8 +87,7 @@ function traceat(state::GFUpdateState, gen_fn::GenerativeFunction{T,U},
prev_call = get_call(state.prev_trace, key)
prev_subtrace = prev_call.subtrace
get_gen_fn(prev_subtrace) === gen_fn || gen_fn_changed_error(key)
(subtrace, weight, _, discard) = update(prev_subtrace,
args, map((_) -> UnknownChange(), args), constraints)
(subtrace, weight, _, discard) = update(prev_subtrace, args, constraints)
else
(subtrace, weight) = generate(gen_fn, args, constraints)
end
Expand Down
60 changes: 55 additions & 5 deletions src/gen_fn_interface.jl
Original file line number Diff line number Diff line change
Expand Up @@ -244,8 +244,9 @@ function assess(gen_fn::GenerativeFunction, args::Tuple, choices::ChoiceMap)
end

"""
(new_trace, weight, retdiff, discard) = update(trace, args::Tuple, argdiffs::Tuple,
constraints::ChoiceMap)

(new_trace, weight, retdiff, discard) = update(
trace, args::Tuple, argdiffs::Tuple, constraints::ChoiceMap)

Update a trace by changing the arguments and/or providing new values for some
existing random choice(s) and values for some newly introduced random choice(s).
Expand All @@ -272,13 +273,37 @@ that if the original `trace` was generated using non-default argument values,
then for each optional argument that is omitted, the old value will be
over-written by the default argument value in the updated trace.
"""
function update(trace, args::Tuple, argdiffs::Tuple, ::ChoiceMap)
function update(trace, args::Tuple, argdiffs::Tuple, constraints::ChoiceMap)
    # Generic fallback for the core four-argument GFI method: each concrete
    # generative function / trace type must override this; reaching here means
    # no specialized `update` method was defined for `typeof(trace)`.
    error("Not implemented")
end

"""
(new_trace, weight, retdiff) = regenerate(trace, args::Tuple, argdiffs::Tuple,
selection::Selection)
update(
trace, constraints::ChoiceMap, args::Tuple;
argdiffs::Tuple=map((_) -> UnknownChange(), args))

Convenience form of `update` with keyword argument for argdiffs.
"""
function update(trace, constraints::ChoiceMap, args::Tuple;
                argdiffs::Tuple=ntuple(_ -> UnknownChange(), length(args)))
    # Convenience reordering: delegate to the core four-positional-argument
    # GFI method, defaulting every argdiff to UnknownChange().
    return update(trace, args, argdiffs, constraints)
end

"""
    update(trace, constraints::ChoiceMap)

Convenience form of `update` that keeps the trace's arguments fixed.
"""
function update(trace, constraints::ChoiceMap)
    prev_args = get_args(trace)
    # Arguments are unchanged, so every argdiff is NoChange().
    return update(trace, prev_args, map((_) -> NoChange(), prev_args), constraints)
end


"""
(new_trace, weight, retdiff) = regenerate(
trace, args::Tuple, argdiffs::Tuple, selection::Selection)

Update a trace by changing the arguments and/or randomly sampling new values
for selected random choices using the internal proposal distribution family.
Expand Down Expand Up @@ -307,6 +332,31 @@ function regenerate(trace, args::Tuple, argdiffs::Tuple, selection::Selection)
error("Not implemented")
end

"""
    regenerate(
        trace, selection::Selection, args::Tuple;
        argdiffs::Tuple=map((_) -> UnknownChange(), args))

Convenience form of `regenerate` with keyword argument for argdiffs.
"""
function regenerate(trace, selection::Selection, args::Tuple;
        argdiffs::Tuple=map((_) -> UnknownChange(), args))
    regenerate(trace, args, argdiffs, selection)
end

"""
    regenerate(trace, selection::Selection)

Convenience form of `regenerate` that keeps the trace's arguments fixed.
"""
function regenerate(trace, selection::Selection)
    prev_args = get_args(trace)
    # Arguments are unchanged, so every argdiff is NoChange().
    return regenerate(trace, prev_args, map((_) -> NoChange(), prev_args), selection)
end



"""
arg_grads = accumulate_param_gradients!(trace, retgrad=nothing, scale_factor=1.)

Expand Down
6 changes: 2 additions & 4 deletions src/inference/elliptical_slice.jl
Original file line number Diff line number Diff line change
Expand Up @@ -12,8 +12,6 @@ Also takes the mean vector and covariance matrix of the prior.
"""
function elliptical_slice(
trace, addr, mu, cov; check=false, observations=EmptyChoiceMap())
args = get_args(trace)
argdiffs = map((_) -> NoChange(), args)

# sample nu
nu = mvnormal(zeros(length(mu)), cov)
Expand All @@ -29,7 +27,7 @@ function elliptical_slice(
f = trace[addr] .- mu

new_f = f * cos(theta) + nu * sin(theta)
new_trace, weight = update(trace, args, argdiffs, choicemap((addr, new_f .+ mu)))
new_trace, weight = update(trace, choicemap((addr, new_f .+ mu)))
while weight <= log(u)
if theta < 0
theta_min = theta
Expand All @@ -38,7 +36,7 @@ function elliptical_slice(
end
theta = uniform(theta_min, theta_max)
new_f = f * cos(theta) + nu * sin(theta)
new_trace, weight = update(trace, args, argdiffs, choicemap((addr, new_f .+ mu)))
new_trace, weight = update(trace, choicemap((addr, new_f .+ mu)))
end
check && check_observations(get_choices(new_trace), observations)
return new_trace
Expand Down
4 changes: 1 addition & 3 deletions src/inference/hmc.jl
Original file line number Diff line number Diff line change
Expand Up @@ -25,9 +25,7 @@ function hmc(
trace::U, selection::Selection; L=10, eps=0.1,
check=false, observations=EmptyChoiceMap()) where {T,U}
prev_model_score = get_score(trace)
args = get_args(trace)
retval_grad = accepts_output_grad(get_gen_fn(trace)) ? zero(get_retval(trace)) : nothing
argdiffs = map((_) -> NoChange(), args)

# run leapfrog dynamics
new_trace = trace
Expand All @@ -46,7 +44,7 @@ function hmc(

# get new gradient
values_trie = from_array(values_trie, values)
(new_trace, _, _) = update(new_trace, args, argdiffs, values_trie)
(new_trace, _, _) = update(new_trace, values_trie)
(_, _, gradient_trie) = choice_gradients(new_trace, selection, retval_grad)
gradient = to_array(gradient_trie, Float64)

Expand Down
2 changes: 1 addition & 1 deletion src/inference/involution_dsl.jl
Original file line number Diff line number Diff line change
Expand Up @@ -299,7 +299,7 @@ function apply_involution(involution::InvolutionDSLProgram, trace, u, proposal_a

# update model trace
(new_trace, model_weight, _, discard) = update(
trace, get_args(trace), map((_) -> NoChange(), get_args(trace)), first_pass_state.constraints)
trace, first_pass_state.constraints)

# create input array and mappings input addresses that are needed for Jacobian
# exclude addresses that were moved to another address
Expand Down
5 changes: 1 addition & 4 deletions src/inference/mala.jl
Original file line number Diff line number Diff line change
Expand Up @@ -11,8 +11,6 @@ Apply a Metropolis-Adjusted Langevin Algorithm (MALA) update.
function mala(
trace, selection::Selection, tau::Real;
check=false, observations=EmptyChoiceMap())
args = get_args(trace)
argdiffs = map((_) -> NoChange(), args)
std = sqrt(2 * tau)
retval_grad = accepts_output_grad(get_gen_fn(trace)) ? zero(get_retval(trace)) : nothing

Expand All @@ -30,8 +28,7 @@ function mala(

# evaluate model weight
constraints = from_array(values_trie, proposed_values)
(new_trace, weight, _, discard) = update(trace,
args, argdiffs, constraints)
(new_trace, weight, _, discard) = update(trace, constraints)
check && check_observations(get_choices(new_trace), observations)

# backward proposal
Expand Down
4 changes: 1 addition & 3 deletions src/inference/map_optimize.jl
Original file line number Diff line number Diff line change
Expand Up @@ -8,8 +8,6 @@ Selected random choices must have support on the entire real line.
"""
function map_optimize(trace, selection::Selection;
max_step_size=0.1, tau=0.5, min_step_size=1e-16, verbose=false)
args = get_args(trace)
argdiffs = map((_) -> NoChange(), args)
retval_grad = accepts_output_grad(get_gen_fn(trace)) ? zero(get_retval(trace)) : nothing

(_, values, gradient) = choice_gradients(trace, selection, retval_grad)
Expand All @@ -21,7 +19,7 @@ function map_optimize(trace, selection::Selection;
new_values_vec = values_vec + gradient_vec * step_size
values = from_array(values, new_values_vec)
# TODO discard and weight are not actually needed, there should be a more specialized variant
(new_trace, _, _, discard) = update(trace, args, argdiffs, values)
(new_trace, _, _, discard) = update(trace, values)
new_score = get_score(new_trace)
change = new_score - score
if verbose
Expand Down
9 changes: 2 additions & 7 deletions src/inference/mh.jl
Original file line number Diff line number Diff line change
Expand Up @@ -14,9 +14,7 @@ Perform a Metropolis-Hastings update that proposes new values for the selected a
function metropolis_hastings(
trace, selection::Selection;
check=false, observations=EmptyChoiceMap())
args = get_args(trace)
argdiffs = map((_) -> NoChange(), args)
(new_trace, weight) = regenerate(trace, args, argdiffs, selection)
(new_trace, weight) = regenerate(trace, selection)
check && check_observations(get_choices(new_trace), observations)
if log(rand()) < weight
# accept
Expand All @@ -41,12 +39,9 @@ If the proposal modifies addresses that determine the control flow in the model,
function metropolis_hastings(
trace, proposal::GenerativeFunction, proposal_args::Tuple;
check=false, observations=EmptyChoiceMap())
model_args = get_args(trace)
argdiffs = map((_) -> NoChange(), model_args)
proposal_args_forward = (trace, proposal_args...,)
(fwd_choices, fwd_weight, _) = propose(proposal, proposal_args_forward)
(new_trace, weight, _, discard) = update(trace,
model_args, argdiffs, fwd_choices)
(new_trace, weight, _, discard) = update(trace, fwd_choices)
proposal_args_backward = (new_trace, proposal_args...,)
(bwd_weight, _) = assess(proposal, proposal_args_backward, discard)
alpha = weight - fwd_weight + bwd_weight
Expand Down
3 changes: 2 additions & 1 deletion src/inference/particle_filter.jl
Original file line number Diff line number Diff line change
Expand Up @@ -142,7 +142,8 @@ function particle_filter_step!(state::ParticleFilterState{U}, new_args::Tuple, a
for i=1:num_particles
(prop_choices, prop_weight, _) = propose(proposal, (state.traces[i], proposal_args...))
constraints = merge(observations, prop_choices)
(state.new_traces[i], up_weight, _, disc) = update(state.traces[i], new_args, argdiffs, constraints)
(state.new_traces[i], up_weight, _, disc) = update(
state.traces[i], new_args, argdiffs, constraints)
@assert isempty(disc)
state.log_weights[i] += up_weight - prop_weight
end
Expand Down