In [1]:
# Load JSServe and call `Page()` before any plotting cell runs.
# NOTE(review): presumably this sets up JSServe's notebook display
# integration so the WGLMakie figures below render inline — confirm
# against the JSServe documentation.
using JSServe
Page()

In [None]:
import WGLMakie
using Revise
using Gen
includet("../src/gridworld.jl")
includet("../src/houseexpo.jl")
includet("../src/visualize.jl")
includet("../src/raytrace.jl")
includet("../src/distributions.jl")

In [None]:
import GenPOMDPs

In [None]:
# Observation model.
# Ray-traces `params.obs.n_rays` distances in `moveagent(params.map, pos)`
# (presumably the map with the agent placed at `pos`), then samples the
# observation at address `:obs` from `broadcasted_normal` parameterized by
# those distances and `params.obs.σ`.  Returns the sampled observation.
@gen (static) function observe_noisy_distances(pos, params)
    gridmap = params.map
    ray_count = params.obs.n_rays
    noise_σ = params.obs.σ

    true_distances = ray_trace_distances(moveagent(gridmap, pos), ray_count)
    obs ~ broadcasted_normal(true_distances, noise_σ)
    return obs
end

# Position prior.
# `params.map` should be a gridworld without the agent in it.
# Samples the agent position at address `:pos` with `uniform_from_set`
# over `empty_cells(params.map)` and returns it.
@gen (static) function uniform_agent_pos(params)
    free_cells = empty_cells(params.map)

    pos ~ uniform_from_set(free_cells)
    return pos
end

# Motion model
# `list_categorical(list, probs)`: distribution over the elements of `list`.
# An index is drawn from `categorical(probs)` and the element of `list`
# at that index is the sampled value.
@dist list_categorical(list, probs) = list[categorical(probs)]
"""
    _get_motion_logprobs(w, (x2, y2), σ)

Score every cell of `w` as a landing position for a move aimed at `(x2, y2)`.

Returns `(possible_positions, logprobs)`, where `possible_positions` is a flat
vector of `(x, y)` tuples built from `keys(w)` and `logprobs[i]` is the
(unnormalized) log-weight of `possible_positions[i]`:

- cells with `w[x, y] == empty` get
  `logpdf(broadcasted_normal, [x, y], [x2, y2], σ)`;
- every other cell gets `-Inf`, so it can never be selected.
"""
function _get_motion_logprobs(w, (x2, y2), σ)
    cells = vec(Tuple.(keys(w)))
    scores = map(cells) do (cx, cy)
        # Only evaluate the density for cells the agent can occupy.
        if w[cx, cy] == empty
            logpdf(broadcasted_normal, [cx, cy], [x2, y2], σ)
        else
            -Inf
        end
    end
    return (cells, scores)
end
"""
    _normalized_probs(lps)

Convert the log-weights `lps` into probabilities by subtracting
`logsumexp(lps)` from each entry and exponentiating.
"""
function _normalized_probs(lps)
    log_total = logsumexp(lps)
    return exp.(lps .- log_total)
end
# Noisy one-step motion model.
# `action` should be a symbol in [:up, :down, :left, :right, :stay].
# The intended destination is `newpos(params.map, pos, action)`; the actual
# next position is sampled at address `:pos` from the map's cells, weighted
# by the normalized scores from `_get_motion_logprobs` (scale `params.step.σ`).
@gen (static) function motion_model(pos, action, params)
    w = params.map
    σ = params.step.σ

    target = newpos(w, pos, action)
    (candidates, lps) = _get_motion_logprobs(w, target, σ)
    pos ~ list_categorical(candidates, _normalized_probs(lps))
    return pos
end

In [None]:
# POMDP of this environment.
# The trailing comments list the arguments each component takes and what
# it produces (⇝ for generative functions, → for a plain function).
pomdp = GenPOMDPs.GenPOMDP(
    uniform_agent_pos,       # INIT   : params                      ⇝ state
    motion_model,            # STEP   : prev_state, actions, params ⇝ state
    observe_noisy_distances, # OBS    : state, params               ⇝ observation
    # UTILITY: state, action, params → utility
    # This environment has no reward signal, so the utility is constantly 0.
    # Trailing arguments are accepted (and ignored) so the function is valid
    # whether it is called as `(state, action)` or `(state, action, params)`.
    (state, action, _params...) -> 0.
)

# Generative function over trajectories of the POMDP,
# given a fixed action sequence.
trajectory_model = GenPOMDPs.ControlledTrajectoryModel(pomdp)

In [None]:
# Parameters read by the model components above:
#   map  - houseexpo env #7, discretized with 24 gridsquares along the x axis
#   obs  - observation model settings (`n_rays`, `σ`)
#   step - motion model settings (`σ`)
PARAMS = (;
    map = load_gridworld(24, 7),
    obs = (n_rays = 20, σ = 1.0),
    step = (σ = 0.1,),
);

In [None]:
# Call the trajectory model directly with first argument 0 and an empty
# action sequence, keeping the return value in `ret`.
# NOTE(review): presumably the first argument is the number of timesteps,
# making this a zero-step rollout — confirm against GenPOMDPs.
ret = trajectory_model(0, [], PARAMS)

In [None]:
# Set up "interactive mode" with this model.
# `take_action` is a function which accepts an action as input.
# `world_trace` is an `Observable` whose value is a trace of
# `trajectory_model`, describing the "ground truth" trajectory which
# has played out in the environment.  (This includes the state,
# action, observation, and utility sequences.)
# Each time `take_action` is called, `world_trace` automatically updates
# to extend the simulation roll-out by one timestep.
# The trailing `;` suppresses the cell's output.
world_trace, take_action = GenPOMDPs.interactive_world_trace(trajectory_model, PARAMS);

In [None]:
# Build a Figure object, `f`, holding a GUI that displays the world state
# and lets the user take actions in the world (playing as the agent).
# Keyboard events in `f` are mapped onto the `take_action` function
# obtained above.
# (`t` is an observable controlling which timestep is displayed in the
# figure, so the user can inspect past timesteps.)

# Lift the trace observable to an observable over the pair
# (state sequence, observation sequence), which is the form consumed by
# the GridWorld visualization library.
pos_obs_values = Makie.@lift((
    GenPOMDPs.state_sequence($world_trace),

    # Note: as GenPOMDPs treats it, observations look like
    # `choicemap((:obs, point_cloud_distances))`, while the
    # observation distribution's "return values" just look like
    # `point_cloud_distances`.  The visualization wants the latter.
    GenPOMDPs.observation_retval_sequence($world_trace),
))

f, t = interactive_mode_gui(PARAMS.map, pos_obs_values, take_action)

f # Render the figure object

# Controls:
#   W/A/S/D - (noisily) move left/right/up/down
#   E       - (noisily) stay still
#   G       - decrement the displayed timestep
#   T       - increment the displayed timestep
#             (up to the time of the latest action taken)
# Actions can only be taken when the visualized timestep is the
# latest one simulated.

In [None]:
# Start a fresh interactive session, independent of the one above.
(world_trace2, take_action2) =
    GenPOMDPs.interactive_world_trace(trajectory_model, PARAMS)

# This session also gets an interactive particle filter, which updates
# as the environment roll-out is extended.

# Observable over the pair (observation_sequence, action_sequence).
# Whenever the ground-truth trace updates, this updates with the
# observations and actions seen so far.  Observations and actions are
# all the information available to an agent acting in the world, so this
# observable is exactly what the agent sees.
observation_action_sequence = Makie.@lift(
    (
        GenPOMDPs.observation_sequence($world_trace2),
        GenPOMDPs.action_sequence($world_trace2),
    )
)

# Observable of a particle-filter belief state.
# Filtering is done in the ground-truth world model (`pomdp`) with the
# ground-truth parameter values (`PARAMS`); 200 is the third argument to
# `bootstrap_pf` (presumably the particle count — confirm in GenPOMDPs).
pf_states = GenPOMDPs.pf_observable(
    GenPOMDPs.bootstrap_pf(pomdp, PARAMS, 200), observation_action_sequence
);

In [None]:
# Show the second session in interactive mode, with the particle filter
# results drawn on top of the environment.
pos_obs_values2 = Makie.@lift((
    GenPOMDPs.state_sequence($world_trace2),
    GenPOMDPs.observation_retval_sequence($world_trace2),
))
f2, t2 = interactive_mode_gui(PARAMS.map, pos_obs_values2, take_action2)

# Add the particle filter localization display to the existing figure.
display_pf_localization!(f2, t2, pf_states)

f2 # Render the figure object