From 0c0283f4fbfb003b2255f7eba47c0e2e60515f42 Mon Sep 17 00:00:00 2001 From: Xuan Date: Fri, 12 Feb 2021 21:51:28 -0500 Subject: [PATCH] Add ExtendingTraceTranslator. --- src/inference/trace_translators.jl | 69 ++++++++++++++++++++++++++++-- 1 file changed, 65 insertions(+), 4 deletions(-) diff --git a/src/inference/trace_translators.jl b/src/inference/trace_translators.jl index 93becd05..c65e836a 100644 --- a/src/inference/trace_translators.jl +++ b/src/inference/trace_translators.jl @@ -764,7 +764,7 @@ end argdiffs::Tuple = (), new_obs::ChoiceMap = EmptyChoiceMap(), q_fwd::GenerativeFunction, - q_fwd_args::Tuple = ()) + q_fwd_args::Tuple = ()) Constructor for a simple extending trace translator. @@ -772,12 +772,12 @@ Run the translator with: (output_trace, log_weight) = translator(input_trace) """ -@with_kw struct SimpleExtendingTraceTranslator +@with_kw struct SimpleExtendingTraceTranslator p_new_args::Tuple = () argdiffs::Tuple = () new_obs::ChoiceMap = EmptyChoiceMap() q_fwd::GenerativeFunction - q_fwd_args::Tuple = () + q_fwd_args::Tuple = () end function (translator::SimpleExtendingTraceTranslator)(prev_model_trace::Trace) @@ -801,6 +801,67 @@ function (translator::SimpleExtendingTraceTranslator)(prev_model_trace::Trace) return (new_model_trace, log_weight) end +################################## +# ExtendingTraceTranslator # +################################## + +""" + translator = ExtendingTraceTranslator(; + p_new_args::Tuple = (), + argdiffs::Tuple = (), + new_obs::ChoiceMap = EmptyChoiceMap(), + q_fwd::GenerativeFunction, + q_fwd_args::Tuple = (), + f::Union{TraceTransformDSLProgram,Nothing} = nothing) + +Constructor for a extending trace translator. + +Run the translator with: + + (output_trace, log_weight) = translator(input_trace) +""" +@with_kw struct ExtendingTraceTranslator + p_new_args::Tuple = () + argdiffs::Tuple = () + new_obs::ChoiceMap = EmptyChoiceMap() + q_fwd::GenerativeFunction + q_fwd_args::Tuple = () + f::Union{TraceTransformDSLProgram,Nothing} = nothing # a bijection +end + +function (translator::ExtendingTraceTranslator)(prev_model_trace::Trace) + + # simulate from auxiliary program + forward_proposal_trace = simulate(translator.q_fwd, (prev_model_trace, translator.q_fwd_args...,)) + forward_proposal_score = get_score(forward_proposal_trace) + + # transform forward proposal + if translator.f === nothing + constraints = get_choices(forward_proposal_trace) + else + first_pass_results = + run_first_pass(translator.f, forward_proposal_trace, nothing) + log_abs_determinant = + jacobian_correction(translator.f, forward_proposal_trace, + nothing, first_pass_results, nothing) + constraints = first_pass_results.constraints + end + + # computing the new trace via update + constraints = merge(constraints, translator.new_obs) + (new_model_trace, log_model_weight, _, discard) = update( + prev_model_trace, translator.p_new_args, + translator.argdiffs, constraints) + + if !isempty(discard) + @error("can only extend the trace with random choices, cannot remove random choices") + error("Invalid ExtendingTraceTranslator") + end + + log_weight = log_model_weight - forward_proposal_score + log_abs_determinant + return (new_model_trace, log_weight) +end + ############################ # SymmetricTraceTranslator # ############################ @@ -884,7 +945,7 @@ function (translator::SymmetricTraceTranslator{<:Function})( forward_retval = get_retval(forward_trace) (new_model_trace, backward_choices, log_weight) = translator.involution( prev_model_trace, forward_choices, forward_retval, translator.q_args) - (backward_score, backward_retval) = assess(translator.q, (new_model_trace, translator.q_args...), backward_choices) + (backward_score, backward_retval) = assess(translator.q, (new_model_trace, translator.q_args...), backward_choices) log_weight += (backward_score - forward_score)