In [1]:
using Pkg
# Activate the project environment two directories up (the GenSMCP3.jl package).
Pkg.activate("../..")

[32m[1m  Activating[22m[39m project at `~/Developer/research/SMCP3OpenSource/GenSMCP3.jl`


In [2]:
using Gen
# Generative model: a single latent value `x` observed `t` times with unit noise.
#   x               ~ normal(0, 100)   (broad prior; std 100)
#   "obs1".."obs<t>" ~ normal(x, 1)    (independent noisy measurements of x)
# The inference goal is P(x | observations). The particle filter below adds
# one observation at a time and updates its posterior approximation each time.
# Returns the vector of sampled observations, in address order.
@gen function model(t)
    # Latent quantity to be inferred.
    x ~ normal(0, 100)

    observations = []
    for k = 1:t
        y = {"obs$k"} ~ normal(x, 1)
        push!(observations, y)
    end
    return observations
end

DynamicDSLFunction{Any}(Dict{Symbol, Any}(), Dict{Symbol, Any}(), Type[Any], false, Union{Nothing, Some{Any}}[nothing], var"##model#292", Bool[0], false)

In [3]:
# Draw one trace of `model` with 3 observations and display its choices.
simulate(model, (3,)) |> get_choices

│
├── "obs3" : -7.1291761085078065
│
├── "obs2" : -6.882139464429483
│
├── "obs1" : -6.4590919402763
│
└── :x : -6.371462094539885


In [4]:
import GenSMCP3: @kernel

# Forward proposal for the SMCP3 update used in `smcp3_algorithm`.
# It receives a trace of `model` with t-1 observed datapoints and the new
# observation (to be recorded at "obs<t>"). It proposes a new value for the
# latent `x` that incorporates all t observations.
#
# The proposal is the exact conjugate posterior P(x | obs1, ..., obs_t).
# `model` puts a normal(0, 100) prior on `x`, and each observation is
# normal(x, 1). So the posterior precision is t + 1/100^2, and the posterior
# mean is sum(observations) / precision.
@kernel function forward_proposal(previous_trace, new_observation)
    t_prev = get_args(previous_trace)[1]
    t = t_prev + 1

    # Must match the prior standard deviation of `x` in `model`.
    prior_std = 100

    # All observations as of time t (never empty: it contains the new one).
    old_observations = [previous_trace["obs$i"] for i in 1:t_prev]
    new_observations = vcat(old_observations, [new_observation])

    # Conjugate Normal-Normal posterior with unit observation noise.
    # (The original used mean sum/t with variance 1/(t+1), which mixes a
    # flat-prior mean with a unit-variance-prior variance.)
    post_precision = t + 1 / prior_std^2
    post_mean = sum(new_observations) / post_precision
    post_std = sqrt(1 / post_precision)

    # Propose a new value for x from that posterior.
    new_x ~ normal(post_mean, post_std)

    # Return two things.
    # First: a choicemap telling Gen how to overwrite the latent state of the
    # current trace, producing the new trace proposed by this update.
    # Second: a choicemap with every choice the _backward proposal_ would make
    # to invert this update. Here, that is the value x had before the update.
    return (
        choicemap((:x, new_x)),
        choicemap((:previous_x, previous_trace[:x]))
    )
end

GenTraceKernelDSL.Kernel(var"#3#5"())

In [5]:
# Backward proposal paired with `forward_proposal`.
# Given the updated trace with t observations, it proposes what `x` was before
# the update, when only "obs1", ..., "obs<t-1>" had been observed.
#
# It samples from the exact posterior of `x` given those t-1 observations:
# `model`'s prior is normal(0, 100) with unit-std observations, so the
# precision is (t-1) + 1/100^2. When t == 1 this reduces to the prior
# normal(0, 100) itself, which is where the previous `x` was generated.
# (The original used normal(0, 1) there. Under the broad prior that gives
# near-zero weight to most particles.)
# `new_observation` is unused here; it is part of the argument tuple that
# `SMCP3Update` passes to this kernel.
@kernel function backward_proposal(updated_trace, new_observation)
    t = get_args(updated_trace)[1]

    # Must match the prior standard deviation of `x` in `model`.
    prior_std = 100

    # Sum of the observations present before the update (0.0 when t == 1).
    obs_sum = sum((updated_trace["obs$i"] for i in 1:t-1); init=0.0)

    # Conjugate Normal-Normal posterior given the t-1 earlier observations.
    post_precision = (t - 1) + 1 / prior_std^2
    post_mean = obs_sum / post_precision
    post_std = sqrt(1 / post_precision)

    # Propose what the value of x was before the update.
    previous_x ~ normal(post_mean, post_std)

    # Return two things.
    # First: a choicemap telling Gen how to overwrite the latent state so the
    # forward proposal's update is inverted.
    # Second: a choicemap with every choice the _forward proposal_ would make
    # to re-apply this update. Here, that is the value x has after the update.
    return (
        choicemap((:x, previous_x)),
        choicemap((:new_x, updated_trace[:x]))
    )
end


GenTraceKernelDSL.Kernel(var"#7#9"())

In [6]:
import GenSMCP3: SMCP3Update
using GenParticleFilters

"""
    smcp3_algorithm(observations, n_particles; ess_threshold=1/5)

Run an SMCP3 particle filter on `model`, adding `observations` one at a time.

The filter starts from `n_particles` traces of `model` with zero observations.
For the `t`-th observation, it extends the model to `t` observations and
constrains the address `obs<t>` to that value. It then updates the latent `x`
with an `SMCP3Update` built from `forward_proposal` and `backward_proposal`.
After every step, if the effective sample size falls below
`ess_threshold * n_particles`, the particles are residual-resampled.

Returns the final particle filter state.
"""
function smcp3_algorithm(observations, n_particles; ess_threshold=1/5)
    # Initialize with 0 observations. The empty choicemap means nothing has
    # been observed yet.
    state = pf_initialize(model, (0,), choicemap(), n_particles)

    for (t, observation) in enumerate(observations)
        pf_update!(
            state,
            # New argument to the model, so it generates `t` observations.
            (t,),

            # Tell Gen the argument changed, without giving details about how.
            # (Specific change information would let Gen use incremental
            # computation.)
            (UnknownChange(),),

            # Constrain the new observation at address "obs$t".
            choicemap(("obs$t", observation)),

            # Update the latent state with an SMCP3 move. Both proposals get
            # the new observation as an extra argument after the trace.
            SMCP3Update(
                forward_proposal,
                backward_proposal,
                (observation,),
                (observation,)
            )
        )

        # Resample whenever the ESS becomes too small. This must happen
        # inside the loop; checking only after the loop lets weight
        # degeneracy build up across every step.
        if effective_sample_size(state) < ess_threshold * n_particles
            # Residual resampling prunes low-weight particles.
            pf_resample!(state, :residual)
        end
    end

    return state
end

smcp3_algorithm (generic function with 1 method)

In [7]:
"""
    estimate_expectation(state, f)

Estimate the expectation of `f(trace)` under the particle approximation in
`state`. Each trace's `f` value is weighted by its normalized particle weight
and the results are summed.
"""
function estimate_expectation(state, f)
    norm_weights = GenParticleFilters.get_norm_weights(state)
    return sum(zip(norm_weights, state.traces)) do (weight, trace)
        weight * f(trace)
    end
end
    

estimate_expectation (generic function with 1 method)

In [8]:
# Run the filter on observations [1, 2, 3] with 10000 particles, then read off
# the particle estimate of E[x | observations] (`mean` from GenParticleFilters).
# For comparison, the exact conjugate answer under `model` (prior normal(0, 100),
# unit-std observations) is sum(obs) / (n + 1/100^2) = 6 / 3.0001 ≈ 2.0.
inference_result_state = smcp3_algorithm([1, 2, 3], 10000);
empirical_expected_x = mean(inference_result_state, :x)

1.9777302118684603

In [9]:
# Same experiment with an extra observation of 10.
# Exact conjugate posterior mean for comparison: 16 / 4.0001 ≈ 4.0.
inference_result_state = smcp3_algorithm([1, 2, 3, 10], 10000);
empirical_expected_x = mean(inference_result_state, :x)

4.007261796253326