In [1]:
# Load the Julia environment needed for this notebook
import Pkg
Pkg.activate("../../Tasks2D")

[32m[1m  Activating[22m[39m project at `~/Developer/research/fall2023/cocosci/tasks2D/Tasks2D`


In [2]:
import Makie   # Visualization Library
using Revise      # For development; makes it so modifications
                  # to imported modules are immediately reflected in this Julia session
using Gen         # Gen probabilistic programming library
import GenParticleFilters # Additional particle filtering functionality for Gen
import GridWorlds # Simple gridworld functionality
import LineWorlds
const L = LineWorlds
import LineWorlds: cast # Ray caster
import GenPOMDPs  # Beginnings of a Gen POMDP library

import Tasks2D

include("SLAM/Utils.jl")

[36m[1m[ [22m[39m[36m[1mInfo: [22m[39mPrecompiling GridWorlds [c15fd557-8ec5-4bf9-9d87-df57ac477796]
[36m[1m[ [22m[39m[36m[1mInfo: [22m[39mPrecompiling GenPOMDPs [f88df91c-fa0e-46d7-b73d-4420684e5acb]
[36m[1m[ [22m[39m[36m[1mInfo: [22m[39mPrecompiling Tasks2D [d49a4b09-0ba2-4df1-b547-2796ddccb60c]


Main.Utils

In [None]:
import GLMakie
GLMakie.activate!()

In [16]:
# Initial position model

# Distribution to sample uniformly from a Julia Set
using Tasks2D.Distributions: uniform_from_set

# Initial-state model: samples the agent's starting position.
# Reads the map from `params.map`, draws one element of
# `GridWorlds.empty_cells(w)` via `uniform_from_set` (trace address `:cell`),
# then draws `x` and `y` uniformly within that cell (trace addresses `:x`, `:y`).
# Returns the position as the 2-element vector `[x, y]`.
@gen (static) function uniform_agent_pos(params)
    w = params.map # a map, represented as a GridWorlds.GridWorld
    
    cell ~ uniform_from_set(GridWorlds.empty_cells(w))
    
    # Cell (i, j) corresponds to the region from i-1 to i and j-1 to j
    x ~ uniform(cell[1] - 1, cell[1])
    y ~ uniform(cell[2] - 1, cell[2])
    
    return [x, y]
end

var"##StaticGenFunction_uniform_agent_pos#344"(Dict{Symbol, Any}(), Dict{Symbol, Any}())

In [17]:
# Noise-free motion: returns the position reached from `pos` after applying
# action `a` with step length `Δ`.
#
# `pos` is destructured into its first two components (x, y). `a` must be one
# of `:up`, `:down`, `:left`, `:right`, `:stay`; any other value raises an
# error. The result is always a freshly allocated 2-element vector.
function det_next_pos(pos, a, Δ)
    (px, py) = pos
    if a == :up
        return [px, py + Δ]
    elseif a == :down
        return [px, py - Δ]
    elseif a == :left
        return [px - Δ, py]
    elseif a == :right
        return [px + Δ, py]
    elseif a == :stay
        return [px, py]
    end
    error("Unrecognized action: $a")
end

# Clips a proposed move from `prev` to `new` against the walls of `gridworld`.
#
# Every row of `GridWorlds.wall_segments(gridworld)` is tested against the
# move segment with `L.Geometry.cast`, which yields an (intersects, distance)
# pair (presumably the distance from `prev` to the hit point — confirm against
# LineWorlds). Only hits no farther away than the length of the move count.
#
# Returns:
#   * `new`  if no wall is hit within the move,
#   * `prev` if the nearest hit is closer than 0.05,
#   * otherwise the point on the move direction 0.04 short of the nearest hit.
function handle_wall_intersection(prev, new, gridworld)
    walls = GridWorlds.wall_segments(gridworld)
    move = L.Segment(prev, new)

    nearest_dist = Inf
    move_vec = nothing  # stays `nothing` until some wall is hit
    for row in 1:size(walls, 1)
        hit, hit_dist = L.Geometry.cast(move, L.Segment(walls[row, :]))
        hit || continue
        hit_dist ≤ L.Geometry.norm(move) || continue
        hit_dist < nearest_dist || continue
        nearest_dist = hit_dist
        move_vec = L.Geometry.diff(move)
    end

    # No wall in the way: accept the proposed position unchanged.
    isnothing(move_vec) && return new

    # Already (almost) touching the wall: do not move at all.
    nearest_dist < 0.05 && return prev

    # Otherwise stop a small margin before the wall.
    unit_vec = move_vec / L.Geometry.norm(move_vec)
    return prev + (nearest_dist - 0.04) * unit_vec
end

handle_wall_intersection (generic function with 1 method)

In [18]:
# Transition model: moves the agent one step and resolves wall collisions.
#
# Arguments:
#   pos    - current position, passed to `det_next_pos` and used as the start
#            of the move in `handle_wall_intersection`
#   action - action symbol understood by `det_next_pos`
#   params - reads `params.map`, `params.step.Δ` (step length) and
#            `params.step.σ` (noise scale of the Gaussian around the
#            deterministic next position)
#
# The only random choice is `:noisy_next_pos`; the returned position is that
# sample after clipping against the walls of the map.
@gen (static) function motion_model(pos, action, params)
    w = params.map
    σ = params.step.σ

    next_pos_det = det_next_pos(pos, action, params.step.Δ)
    noisy_next_pos ~ broadcasted_normal(next_pos_det, σ)
    next_pos = handle_wall_intersection(pos, noisy_next_pos, w)

    return next_pos
end

var"##StaticGenFunction_motion_model#370"(Dict{Symbol, Any}(), Dict{Symbol, Any}())

In [19]:
# Observation model: noisy range readings taken from position `pos`.
#
# Arguments:
#   pos    - agent position; splatted together with `params.obs.orientation`
#            into a 1×3 pose row
#   params - reads `params.map`, `params.obs.orientation`, `params.obs.n_rays`,
#            `params.obs.sensor_args.zmax` and `params.obs.sensor_args.σ`
#
# The ray-cast result `zs` is perturbed with Gaussian noise at trace address
# `:obs`, and that noisy value is returned.
@gen function observe_noisy_distances(pos, params)
    p = reshape([pos..., params.obs.orientation], (1, 3))

    # Read sensor settings by field name so the lookup does not depend on the
    # order in which `sensor_args` lists its fields.
    sensor_args = params.obs.sensor_args
    zmax = sensor_args.zmax

    segs = GridWorlds.wall_segments(params.map)

    # NOTE(review): `params.obs.fov` is not passed to `L.cast`, so the caster
    # presumably uses its own default field of view — confirm against LineWorlds.
    zs = L.cast(p, segs; num_a=params.obs.n_rays, zmax)

    obs ~ broadcasted_normal(zs, sensor_args.σ)
    return obs
end

DynamicDSLFunction{Any}(Dict{Symbol, Any}(), Dict{Symbol, Any}(), Type[Any, Any], false, Union{Nothing, Some{Any}}[nothing, nothing], var"##observe_noisy_distances#378", Bool[0, 0], false)

In [20]:
# POMDP of this environment
# Bundles the initial-state, transition, and observation models defined above
# with a utility function that is constantly 0.
# NOTE(review): the UTILITY comment below lists three arguments
# (state, action, params) but the lambda accepts only two — confirm which
# arity GenPOMDPs.GenPOMDP actually calls it with.
pomdp = GenPOMDPs.GenPOMDP(
    uniform_agent_pos,       # INIT   : params                      ⇝ state
    motion_model,            # STEP   : prev_state, actions, params ⇝ state
    observe_noisy_distances, # OBS    : state, params               ⇝ observation
    (state, action) -> 0.    # UTILITY: state, action, params       → utility
)

GenPOMDPs.GenPOMDP(var"##StaticGenFunction_uniform_agent_pos#344"(Dict{Symbol, Any}(), Dict{Symbol, Any}()), var"##StaticGenFunction_motion_model#370"(Dict{Symbol, Any}(), Dict{Symbol, Any}()), DynamicDSLFunction{Any}(Dict{Symbol, Any}(), Dict{Symbol, Any}(), Type[Any, Any], false, Union{Nothing, Some{Any}}[nothing, nothing], var"##observe_noisy_distances#378", Bool[0, 0], false), var"#64#65"())

In [21]:
# Generative function over trajectories of the POMDP,
# given a fixed action sequence.
# Used below both to start the interactive world trace and as the model
# against which saved traces are regenerated when they are deserialized.
trajectory_model = GenPOMDPs.ControlledTrajectoryModel(pomdp)

GenPOMDPs.var"##StaticGenFunction__ControlledTrajectoryModel#450"(Dict{Symbol, Any}(), Dict{Symbol, Any}())

In [22]:
i = uniform_discrete(1, 20) # Random environment index
# Parameters shared by the initial-position, motion, and observation models.
#   map  - gridworld loaded from the HouseExpo set for the sampled index `i`
#   step - Δ is the step length given to `det_next_pos`; σ is the noise scale
#          of the Gaussian placed around the deterministic next position
#   obs  - n_rays and orientation are read by `observe_noisy_distances`;
#          within `sensor_args`, only `zmax` and `σ` are used by that model
PARAMS = (;
    map = GridWorlds.load_houseexpo_gridworld(24, i),
    step = (; Δ = 1.25, σ = 0.005 ), # step model arguments
    obs = (; fov = 2π, n_rays = 40,  # obs model arguments
        orientation=π/2,
        sensor_args = (;
            w = 5, s_noise = 0.02,
            outlier = 0.0001, outlier_vol = 100.0,
            zmax = 100.0, σ=0.005
)));

In [25]:
# Start an interactive rollout of `trajectory_model` under PARAMS.
# `world_trace` is dereferenced with `world_trace[]` further down, so it is a
# reference-like container holding the current trace (presumably an
# Observable — confirm); `take_action` is the callback handed to the GUI.
world_trace, take_action = GenPOMDPs.interactive_world_trace(trajectory_model, PARAMS);

# Maps `groundtruth_trace` to a (states, observations) pair.
#
# For each trace, the first component is `GenPOMDPs.state_sequence(trace)` and
# the second is the list of observation return values, each reshaped into a
# flat vector. `map` is applied to `groundtruth_trace` itself, so the result
# has whatever container shape `map` produces for that argument.
function get_posobs_seq(groundtruth_trace)
    return map(groundtruth_trace) do tr
        states = GenPOMDPs.state_sequence(tr)
        flat_obs = [reshape(o, (:,)) for o in GenPOMDPs.observation_retval_sequence(tr)]
        (states, flat_obs)
    end
end
# Maps `groundtruth_trace` to its sequence of observations only.
#
# Each observation return value of the trace is reshaped into a flat vector.
# `map` is applied to `groundtruth_trace` itself, so the result has whatever
# container shape `map` produces for that argument.
function get_obs_seq(groundtruth_trace)
    return map(groundtruth_trace) do tr
        raw_obs = GenPOMDPs.observation_retval_sequence(tr)
        [reshape(o, (:,)) for o in raw_obs]
    end
end

# Launch the play-as-agent GUI. It receives only the observation sequence
# (no positions or map) plus the `take_action` callback; lines to the walls
# are switched off. Evaluating `f` last displays the returned figure.
(f, t) = GridWorlds.Viz.play_as_agent_gui(get_obs_seq(world_trace), take_action, show_lines_to_walls=false)
f

In [28]:
using Dates
# Save the current world trace under a timestamped name.
# NOTE(review): `string(now())` contains ':' characters (see the filename
# printed below), which are not valid in Windows filenames, and the "saves/"
# directory presumably has to exist already — confirm.
filename = "saves/" * string(now()) * "__pomdp_trace.jld"
Utils.serialize_pomdp_trace(filename, world_trace[])

Trace serialized to saves/2023-12-06T15:21:41.953__pomdp_trace.jld.


In [34]:
# Reload a previously saved trace (hard-coded to the file written by an
# earlier run of the save cell) and make it interactive again, yielding a
# trace handle and an action callback like the ones used for the live rollout.
tr2 = Utils.deserialize_pomdp_trace("saves/2023-12-06T15:21:41.953__pomdp_trace.jld", trajectory_model);
trace2, take_action2 = GenPOMDPs.make_trace_interactive(tr2);

In [35]:
# Interactive GUI with visible walls
# Unlike the play-as-agent GUI, this one is given the map and the
# (positions, observations) sequence, and it drives the reloaded trace
# (`trace2` / `take_action2`). Evaluating `f` last displays the figure.
(f, t) = GridWorlds.Viz.interactive_gui(
    PARAMS.map, get_posobs_seq(trace2), take_action2
)
f

[36m[1m[ [22m[39m[36m[1mInfo: [22m[39mWS connection closed because of IO error
[33m[1m└ [22m[39m[90m@ JSServe ~/.julia/packages/JSServe/BRpDB/src/connection/websocket.jl:40[39m


In [25]:
# Converts a Gen choice map into plain nested collections.
#
# Returns a 2-tuple:
#   1. the collected result of `get_values_shallow(cm)` (address/value entries
#      stored directly in `cm`), and
#   2. a list of `(address, converted_submap)` tuples, one per entry of
#      `get_submaps_shallow(cm)`, each submap converted recursively.
function choicemap_to_serializable(cm)
    leaf_entries = collect(get_values_shallow(cm))
    nested_entries = collect(
        (key, choicemap_to_serializable(child))
        for (key, child) in get_submaps_shallow(cm)
    )
    return (leaf_entries, nested_entries)
end
# Inverse of `choicemap_to_serializable`: rebuilds a Gen choice map.
#
# `s` is a 2-tuple of (address/value entries, address/nested-structure
# entries). Values are written straight into a fresh choice map; each nested
# structure is rebuilt recursively and attached as a submap at its address.
function serializable_to_choicemap(s)
    leaf_entries, nested_entries = s
    result = choicemap()
    for (key, val) in leaf_entries
        result[key] = val
    end
    for (key, nested) in nested_entries
        child = serializable_to_choicemap(nested)
        Gen.set_submap!(result, key, child)
    end
    return result
end

serializable_to_choicemap (generic function with 1 method)

In [36]:
# Converts a Gen choice map into plain nested collections.
#
# Returns a 2-tuple:
#   1. the collected result of `get_values_shallow(cm)` (address/value entries
#      stored directly in `cm`), and
#   2. a list of `(address, converted_submap)` tuples, one per entry of
#      `get_submaps_shallow(cm)`, each submap converted recursively.
function choicemap_to_serializable(cm)
    leaf_entries = collect(get_values_shallow(cm))
    nested_entries = collect(
        (key, choicemap_to_serializable(child))
        for (key, child) in get_submaps_shallow(cm)
    )
    return (leaf_entries, nested_entries)
end
# Inverse of `choicemap_to_serializable`: rebuilds a Gen choice map.
#
# `s` is a 2-tuple of (address/value entries, address/nested-structure
# entries). Values are written straight into a fresh choice map; each nested
# structure is rebuilt recursively and attached as a submap at its address.
function serializable_to_choicemap(s)
    leaf_entries, nested_entries = s
    result = choicemap()
    for (key, val) in leaf_entries
        result[key] = val
    end
    for (key, nested) in nested_entries
        child = serializable_to_choicemap(nested)
        Gen.set_submap!(result, key, child)
    end
    return result
end

import Serialization

# Writes a trace to `filename` with Julia's `Serialization` module.
#
# Only the trace's arguments and its choices (converted with
# `choicemap_to_serializable`) are stored, in a Dict under the keys "args"
# and "choices". Returns whatever `Serialization.serialize` returns.
function serialize_pomdp_trace(filename, trace)
    payload = Dict(
        "args" => get_args(trace),
        "choices" => choicemap_to_serializable(get_choices(trace)),
    )
    return Serialization.serialize(filename, payload)
end
# Rebuilds a trace from a file written by `serialize_pomdp_trace`.
#
# The stored "args" and "choices" are read back, the choices are turned into
# a choice map, and `Gen.generate` is run on `pomdp_trajectory_model` with
# those arguments and the choice map as constraints. The weight returned by
# `generate` is discarded; only the trace is returned.
function deserialize_pomdp_trace(filename, pomdp_trajectory_model)
    saved = Serialization.deserialize(filename)
    model_args = saved["args"]
    constraints = serializable_to_choicemap(saved["choices"])
    trace, _weight = Gen.generate(pomdp_trajectory_model, model_args, constraints)
    return trace
end

In [41]:
 using Dates
# Save the current world trace under a timestamped name, this time with the
# notebook-local `serialize_pomdp_trace` rather than the one in `Utils`.
# NOTE(review): `string(now())` yields ':' characters in the filename, which
# are not valid on Windows — confirm the target platforms.
filename = "saves/" * string(now()) * "__pomdp_trace.jld"
serialize_pomdp_trace(filename, world_trace[])

In [42]:
# Read the file just written back into a trace of `trajectory_model`.
tr2 = deserialize_pomdp_trace(filename, trajectory_model);

[33m[1m└ [22m[39m[90m@ JSServe ~/.julia/packages/JSServe/BRpDB/src/connection/websocket.jl:40[39m


In [40]:
# Round-trip check: the reloaded trace should carry the same choices as the
# live world trace.
get_choices(tr2) == get_choices(world_trace[])

true

[33m[1m└ [22m[39m[90m@ JSServe ~/.julia/packages/JSServe/BRpDB/src/serialization/protocol.jl:32[39m
[33m[1m└ [22m[39m[90m@ JSServe ~/.julia/packages/JSServe/BRpDB/src/serialization/protocol.jl:32[39m
[33m[1m└ [22m[39m[90m@ JSServe ~/.julia/packages/JSServe/BRpDB/src/serialization/protocol.jl:32[39m
[33m[1m└ [22m[39m[90m@ JSServe ~/.julia/packages/JSServe/BRpDB/src/serialization/protocol.jl:32[39m
[33m[1m└ [22m[39m[90m@ JSServe ~/.julia/packages/JSServe/BRpDB/src/serialization/protocol.jl:32[39m
[33m[1m└ [22m[39m[90m@ JSServe ~/.julia/packages/JSServe/BRpDB/src/serialization/protocol.jl:32[39m
[33m[1m└ [22m[39m[90m@ JSServe ~/.julia/packages/JSServe/BRpDB/src/serialization/protocol.jl:32[39m
[33m[1m└ [22m[39m[90m@ JSServe ~/.julia/packages/JSServe/BRpDB/src/serialization/protocol.jl:32[39m
[33m[1m└ [22m[39m[90m@ JSServe ~/.julia/packages/JSServe/BRpDB/src/serialization/protocol.jl:32[39m
[33m[1m└ [22m[39m[90m@ JSServe ~/.julia/p