In [1]:
# Load the Julia environment needed for this notebook
import Pkg
Pkg.activate("../Tasks2D")

[32m[1m  Activating[22m[39m project at `~/Developer/research/fall2023/cocosci/tasks2D/Tasks2D`


In [2]:
import Makie   # Visualization Library
using Revise      # For development; makes it so modifications
                  # to imported modules are immediately reflected in this Julia session
using Gen         # Gen probabilistic programming library
import GenParticleFilters # Additional particle filtering functionality for Gen
import GridWorlds # Simple gridworld functionality
import LineWorlds
const L = LineWorlds
import LineWorlds: cast # Ray caster
import GenPOMDPs  # Beginnings of a Gen POMDP library

import Tasks2D

includet("SLAM/Utils.jl")

In [105]:
includet("SLAM/Utils.jl")

In [3]:
import GLMakie
GLMakie.activate!()

In [4]:
# Initial position model

# Distribution that samples uniformly from the elements of a Julia `Set`
using Tasks2D.Distributions: uniform_from_set

@gen (static) function uniform_agent_pos(t_to_params)
    # Initial-state model: pick one of the map's empty cells uniformly at random,
    # then a continuous position uniformly inside that cell. The state's time
    # component starts at 0.
    gridworld = t_to_params(0).map   # a GridWorlds.GridWorld
    free_cells = GridWorlds.empty_cells(gridworld)
    cell ~ uniform_from_set(free_cells)

    # Grid cell (i, j) covers the square [i - 1, i] × [j - 1, j].
    x ~ uniform(cell[1] - 1, cell[1])
    y ~ uniform(cell[2] - 1, cell[2])

    return ([x, y], 0)
end

var"##StaticGenFunction_uniform_agent_pos#317"(Dict{Symbol, Any}(), Dict{Symbol, Any}())

In [5]:
"""
    det_next_pos(pos, a, Δ)

Return the noise-free position reached from `pos = (x, y)` after grid action `a`:
`:up`/`:down` move `Δ` along y, `:left`/`:right` move `Δ` along x, `:stay` does
not move. The result is a new 2-element vector. Any other action raises an error.
"""
function det_next_pos(pos, a, Δ)
    x, y = pos
    if a == :up
        return [x, y + Δ]
    elseif a == :down
        return [x, y - Δ]
    elseif a == :left
        return [x - Δ, y]
    elseif a == :right
        return [x + Δ, y]
    elseif a == :stay
        return [x, y]
    end
    error("Unrecognized action: $a")
end

"""
    handle_wall_intersection(prev, new, gridworld; stop_dist=0.05, backoff=0.04)

Resolve a proposed move from `prev` to `new` against the walls of `gridworld`.

Every wall returned by `GridWorlds.wall_segments(gridworld)` is tested with
`L.Geometry.cast` against the move segment `prev → new`. Hits farther away than
the move's length are ignored. If no wall is hit, `new` is returned unchanged.
Otherwise the nearest hit distance `dist` decides the outcome:
- `dist < stop_dist`: the move is cancelled and `prev` is returned;
- otherwise: the agent stops `backoff` short of the wall along the move
  direction, i.e. at `prev + (dist - backoff) * unit_direction`.

# Keywords
- `stop_dist`: hit distance below which the move is cancelled (default `0.05`).
- `backoff`: gap kept between the agent and the wall (default `0.04`).

# Throws
- `ArgumentError` if `backoff > stop_dist`. In that case the returned point could
  land behind `prev`.
"""
function handle_wall_intersection(prev, new, gridworld; stop_dist=0.05, backoff=0.04)
    backoff ≤ stop_dist || throw(ArgumentError(
        "backoff ($backoff) must not exceed stop_dist ($stop_dist)"))

    walls = GridWorlds.wall_segments(gridworld)
    move = L.Segment(prev, new)
    move_len = L.Geometry.norm(move)   # loop-invariant: length of the attempted move

    # Nearest distance along the move at which a wall is crossed (Inf = no hit).
    min_dist = Inf
    for i in axes(walls, 1)
        # Each row of `walls` is turned into one wall segment (layout per GridWorlds).
        hit, dist = L.Geometry.cast(move, L.Segment(walls[i, :]))
        if hit && dist ≤ move_len && dist < min_dist
            min_dist = dist
        end
    end

    # No wall in the way: the proposed move stands.
    min_dist < Inf || return new
    # Wall is too close to make any progress: stay put.
    min_dist < stop_dist && return prev

    direction = L.Geometry.diff(move)
    unit_dir = direction / L.Geometry.norm(direction)
    return prev + (min_dist - backoff) * unit_dir
end

handle_wall_intersection (generic function with 1 method)

In [6]:
# Transition model for the POMDP.
#
# From state `(pos, t_prev)` and a grid action, it:
#   1. takes the deterministic step of length `params.step.Δ` (see `det_next_pos`);
#   2. adds elementwise Gaussian noise with std `params.step.σ` (address `:noisy_next_pos`);
#   3. stops the agent short of any wall the step would cross (see `handle_wall_intersection`).
# Returns the next state `(next_pos, t_prev + 1)`.
#
# Parameters are looked up with the state's own clock, for the time step being
# generated (`t_prev + 1`, the same index `observe_noisy_distances` reads from the
# new state). They must not depend on any global `t`: if they did, the dynamics
# would change with notebook state, e.g. the GUI's time Observable.
@gen (static) function motion_model(state, action, t_to_params)
    (pos, t_prev) = state
    t = t_prev + 1
    params = t_to_params(t)
    w, σ = params.map, params.step.σ

    next_pos_det = det_next_pos(pos, action, params.step.Δ)
    noisy_next_pos ~ broadcasted_normal(next_pos_det, σ)
    next_pos = handle_wall_intersection(pos, noisy_next_pos, w)

    return (next_pos, t)
end

var"##StaticGenFunction_motion_model#348"(Dict{Symbol, Any}(), Dict{Symbol, Any}())

In [7]:
# Observation model for the POMDP.
#
# It ray-casts `params.obs.n_rays` rays from the agent's pose against the map's walls.
# The pose is the position plus the fixed heading `params.obs.orientation`.
# The resulting distances are observed with elementwise Gaussian noise of std
# `params.obs.sensor_args.σ`, sampled at address `:obs`, which is also returned.
#
# NOTE(review): `params.obs.fov` is not forwarded to `L.cast`. Confirm whether
# `L.cast` assumes a full 2π field of view, which is the value the parameter sets use.
@gen function observe_noisy_distances(state, t_to_params)
    (pos, t) = state
    params = t_to_params(t)
    sensor = params.obs.sensor_args

    # Pose is passed to `L.cast` as a 1×3 matrix: [x y orientation].
    pose = reshape([pos..., params.obs.orientation], (1, 3))
    segs = GridWorlds.wall_segments(params.map)
    # Read `zmax` by name: positional destructuring of `sensor_args` would silently
    # pick the wrong value if that NamedTuple's field order ever changed.
    zs = L.cast(pose, segs; num_a=params.obs.n_rays, zmax=sensor.zmax)

    obs ~ broadcasted_normal(zs, sensor.σ)
    return obs
end

DynamicDSLFunction{Any}(Dict{Symbol, Any}(), Dict{Symbol, Any}(), Type[Any, Any], false, Union{Nothing, Some{Any}}[nothing, nothing], var"##observe_noisy_distances#356", Bool[0, 0], false)

In [8]:
# POMDP of this environment, assembled from the component models:
#   INIT    uniform_agent_pos        : t_to_params                  ⇝ state
#   STEP    motion_model             : state, action, t_to_params   ⇝ state
#   OBS     observe_noisy_distances  : state, t_to_params           ⇝ observation
#   UTILITY constant zero            : state, action                → utility
# NOTE(review): the utility takes only (state, action). Confirm GenPOMDPs never
# passes params to it.
pomdp = GenPOMDPs.GenPOMDP(
    uniform_agent_pos,
    motion_model,
    observe_noisy_distances,
    (_state, _action) -> 0.0,
)

GenPOMDPs.GenPOMDP(var"##StaticGenFunction_uniform_agent_pos#317"(Dict{Symbol, Any}(), Dict{Symbol, Any}()), var"##StaticGenFunction_motion_model#348"(Dict{Symbol, Any}(), Dict{Symbol, Any}()), DynamicDSLFunction{Any}(Dict{Symbol, Any}(), Dict{Symbol, Any}(), Type[Any, Any], false, Union{Nothing, Some{Any}}[nothing, nothing], var"##observe_noisy_distances#356", Bool[0, 0], false), var"#51#52"())

In [9]:
# Generative function over trajectories of the POMDP, given a fixed action sequence.
# NOTE(review): later cells suggest its arguments are (number of steps, action
# vector, t_to_params). Confirm against GenPOMDPs.
trajectory_model = GenPOMDPs.ControlledTrajectoryModel(pomdp)

GenPOMDPs.var"##StaticGenFunction__ControlledTrajectoryModel#428"(Dict{Symbol, Any}(), Dict{Symbol, Any}())

In [32]:
# Flatten every observation of a trajectory trace into a 1-D vector.
# `vec` shares memory with the original array, as `reshape(o, (:,))` did.
_flat_observation_sequence(trace) =
    [vec(o) for o in GenPOMDPs.observation_retval_sequence(trace)]

"""
    get_posobs_seq(groundtruth_trace)

Map `groundtruth_trace` to a `(positions, observations)` pair.
- `positions` is the position component of each state `(pos, t)`.
- `observations` holds the flattened observation vectors.

The mapping is done with `map`. In this notebook the argument is an
Observable-like container (dereferenced elsewhere with `[]`), so the result is
presumably a derived Observable that tracks the trace — confirm.
"""
function get_posobs_seq(groundtruth_trace)
    return map(groundtruth_trace) do trace
        positions = [pos for (pos, _) in GenPOMDPs.state_sequence(trace)]
        (positions, _flat_observation_sequence(trace))
    end
end

"""
    get_obs_seq(groundtruth_trace)

Like `get_posobs_seq`, but map the trace to its flattened observations only.
"""
function get_obs_seq(groundtruth_trace)
    return map(_flat_observation_sequence, groundtruth_trace)
end

get_obs_seq (generic function with 1 method)

In [20]:
using Dates

"""
    get_save_tr(tr; dir="saves")

Return a zero-argument callback that saves the current value `tr[]`.

Each call writes a new timestamped file `<dir>/<now()>__pomdp_trace.jld` via
`Utils.serialize_pomdp_trace`. The callback creates `dir` if needed and returns
whatever `Utils.serialize_pomdp_trace` returns.

The `.jld` suffix is only a naming convention. The on-disk format is whatever
`Utils.serialize_pomdp_trace` writes.
"""
function get_save_tr(tr; dir="saves")
    function save_tr()
        mkpath(dir)  # writing into a missing directory would fail
        filename = joinpath(dir, string(now()) * "__pomdp_trace.jld")
        return Utils.serialize_pomdp_trace(filename, tr[])
    end
    return save_tr
end

get_save_tr (generic function with 1 method)

In [13]:
# Exploratory: list the available methods (signatures) of `Utils.serialize_pomdp_trace`.
methods(Utils.serialize_pomdp_trace)

In [261]:
"""
    ConstantTToParams(params)

Time → parameters map that ignores the time step. Calling `p(t)` always
returns `params`.
"""
struct ConstantTToParams
    # Left untyped on purpose. Parametrizing the struct would change its layout and
    # could break `Serialization.deserialize` of traces saved with this definition.
    params
end
function (p::ConstantTToParams)(t)
    return p.params
end

In [283]:
"""
    SwitchTToParams(params1, params2, switch)

Time → parameters map that returns `params1` at every time step `t` where
`switch(t)` is true, and `params2` otherwise.

A `Makie.Observable` time index is unwrapped to its current value `t[]` first.
"""
struct SwitchTToParams
    # Fields left untyped on purpose. Parametrizing would change the struct layout
    # and could break `Serialization.deserialize` of previously saved traces.
    params1
    params2
    switch
end
(p::SwitchTToParams)(t::Makie.Observable) = p(t[])
function (p::SwitchTToParams)(t)
    if p.switch(t)
        return p.params1
    end
    return p.params2
end

In [353]:
# NOTE(review): "1, 18" was noted here — presumably HouseExpo map indices of
# interest. Confirm.
DO_CUSTOM_MAP = true

# Map selection: either a fixed custom map or a randomly chosen HouseExpo map.
if !DO_CUSTOM_MAP
    i = uniform_discrete(1, 20)
    mp = GridWorlds.load_houseexpo_gridworld(24, i)
else
    mp = GridWorlds.load_custom_map(2)
end

"""
    _make_params(gridworld; step_σ, obs_σ)

Build the parameter NamedTuple read by the motion and observation models.

Only the motion noise `step_σ` and the distance-observation noise `obs_σ` differ
between the parameter sets in this notebook. Everything else is shared.
"""
function _make_params(gridworld; step_σ, obs_σ)
    return (;
        map = gridworld,
        # step model arguments: step length Δ and motion-noise std σ
        step = (; Δ = 0.5, σ = step_σ),
        # obs model arguments
        obs = (;
            fov = 2π, n_rays = 40,
            orientation = π / 2,
            sensor_args = (;
                w = 5, s_noise = 0.02,
                outlier = 0.0001, outlier_vol = 100.0,
                zmax = 100.0, σ = obs_σ,
            ),
        ),
    )
end

_PARAMS_CLEAN = _make_params(mp; step_σ = 0.005, obs_σ = 0.005);
_PARAMS_NOISY = _make_params(mp; step_σ = 0.35, obs_σ = 0.5);

# Clean parameters while t < 60, noisy parameters afterwards.
T_TO_PARAMS = SwitchTToParams(_PARAMS_CLEAN, _PARAMS_NOISY, t -> t < 60)
T_TO_PARAMS(0);  # smoke test: the time → params map is callable

In [354]:
# Create the interactive world trace.
# For the custom map, the initial grid cell (the time-0 `:cell` choice) is pinned
# to (30, 5).
# NOTE(review): (30, 5) must be an empty cell of the custom map. Re-check if the
# map changes.
world_trace, take_action = let
    constraints = DO_CUSTOM_MAP ?
        (choicemap((GenPOMDPs.state_addr(0, :cell), (30, 5))),) : ()
    GenPOMDPs.interactive_world_trace(trajectory_model, T_TO_PARAMS, constraints...)
end;

In [355]:
# Interactive GUI over the live world trace, drawing lines from the agent to the walls.
(f, t) = GridWorlds.Viz.interactive_gui(
    T_TO_PARAMS(0).map,
    get_posobs_seq(world_trace),
    take_action;
    save_fn = get_save_tr(world_trace),
    show_lines_to_walls = true,
)
display(f)

GLMakie.Screen(...)

In [146]:
# i = 

In [338]:
# Interactive GUI over the live world trace, without agent-to-wall lines.
(f, t) = GridWorlds.Viz.interactive_gui(
    T_TO_PARAMS(0).map,
    get_posobs_seq(world_trace),
    take_action;
    save_fn = get_save_tr(world_trace),
    show_lines_to_walls = false,
)
display(f)

GLMakie.Screen(...)

In [356]:
# Agent's-eye GUI: shows only the observations while actions are taken.
(f, t) = GridWorlds.Viz.play_as_agent_gui(
    get_obs_seq(world_trace),
    take_action;
    show_lines_to_walls = false,
    save_fn = get_save_tr(world_trace),
    framerate = 8,
)
display(f)

GLMakie.Screen(...)

In [255]:
# Interactive GUI with visible walls (uses the GUI's default `show_lines_to_walls`).
(f, t) = GridWorlds.Viz.interactive_gui(
    T_TO_PARAMS(0).map,
    get_posobs_seq(world_trace),
    take_action;
    save_fn = get_save_tr(world_trace),
)
display(f)

GLMakie.Screen(...)

In [357]:
# Reload a saved trajectory and replay it from the agent's point of view.
# Only the first two stored arguments are reused. The time → params map (third
# argument) is rebuilt here instead of being taken from the file.
# NOTE(review): this switch (clean for t < 60 or t > 325) differs from
# T_TO_PARAMS (clean only for t < 60). Confirm which regime the save expects.
tr2 = Utils.deserialize_pomdp_trace(
    "saves/2023-12-12T17:45:51.846__pomdp_trace.jld",
    trajectory_model;
    args_from_serializeable = args -> (
        args[1],
        args[2],
        SwitchTToParams(_PARAMS_CLEAN, _PARAMS_NOISY, t -> t < 60 || t > 325),
    ),
);
trace2, take_action2 = GenPOMDPs.make_trace_interactive(tr2);

(f, t) = GridWorlds.Viz.play_as_agent_gui(
    get_obs_seq(trace2),
    take_action2;
    show_lines_to_walls = false,
    save_fn = get_save_tr(trace2),
    framerate = 8,
)
display(f)
# trace2, take_action2 = GenPOMDPs.make_trace_interactive(tr2);

GLMakie.Screen(...)

[Animating]


In [21]:
# Exploratory: show the arguments the reloaded trace was generated with.
get_args(trace2[])

(71, Any[:left, :left, :left, :left, :left, :right, :right, :right, :right, :right  …  :down, :down, :right, :down, :down, :down, :down, :down, :down, :down], T_TO_PARAMS)

In [28]:
# T_TO_PARAMS(0).map :: GridWorlds.GridWorld

In [38]:
# Exploratory: does the reloaded trace use the same map as the current parameters?
# NOTE(review): unless GridWorlds defines `==` for its GridWorld type, Julia's
# default `==` falls back to `===` (identity). A `false` here may then only mean
# the two map objects are distinct, not that their contents differ.
T_TO_PARAMS(0).map==get_args(trace2[])[3](0).map

false

In [25]:
"""
    choicemap_to_serializable(cm)

Convert the choice map `cm` into nested plain data
`(address => value pairs, [(address, nested data of submap), ...])`.
Inverse of `serializable_to_choicemap`.
"""
function choicemap_to_serializable(cm)
    leaf_pairs = collect(get_values_shallow(cm))
    nested = [
        (address, choicemap_to_serializable(sub))
        for (address, sub) in get_submaps_shallow(cm)
    ]
    return (leaf_pairs, nested)
end

"""
    serializable_to_choicemap(s)

Rebuild a choice map from data produced by `choicemap_to_serializable`.
Leaf values are set directly and nested entries are restored recursively as submaps.
"""
function serializable_to_choicemap(s)
    leaf_pairs, nested = s
    rebuilt = choicemap()
    for (address, value) in leaf_pairs
        rebuilt[address] = value
    end
    for (address, sub_data) in nested
        Gen.set_submap!(rebuilt, address, serializable_to_choicemap(sub_data))
    end
    return rebuilt
end

serializable_to_choicemap (generic function with 1 method)

In [36]:
"""
    choicemap_to_serializable(cm)

Convert the choice map `cm` into nested plain data
`(address => value pairs, [(address, nested data of submap), ...])`.
Inverse of `serializable_to_choicemap`.
"""
function choicemap_to_serializable(cm)
    leaf_pairs = collect(get_values_shallow(cm))
    nested = [
        (address, choicemap_to_serializable(sub))
        for (address, sub) in get_submaps_shallow(cm)
    ]
    return (leaf_pairs, nested)
end

"""
    serializable_to_choicemap(s)

Rebuild a choice map from data produced by `choicemap_to_serializable`.
Leaf values are set directly and nested entries are restored recursively as submaps.
"""
function serializable_to_choicemap(s)
    leaf_pairs, nested = s
    rebuilt = choicemap()
    for (address, value) in leaf_pairs
        rebuilt[address] = value
    end
    for (address, sub_data) in nested
        Gen.set_submap!(rebuilt, address, serializable_to_choicemap(sub_data))
    end
    return rebuilt
end

import Serialization

"""
    serialize_pomdp_trace(filename, trace; args_to_serializeable=identity)

Write `trace` to `filename` with the `Serialization` stdlib.

The file holds `Dict("args" => ..., "choices" => ...)`:
- `"args"` is `args_to_serializeable(get_args(trace))`. Use the keyword to replace
  arguments that should not be stored as-is, e.g. closures.
- `"choices"` is the trace's choice map converted with `choicemap_to_serializable`.
"""
function serialize_pomdp_trace(filename, trace; args_to_serializeable=identity)
    return Serialization.serialize(filename, Dict(
        "args" => args_to_serializeable(get_args(trace)),
        "choices" => choicemap_to_serializable(get_choices(trace)),
    ))
end

"""
    deserialize_pomdp_trace(filename, pomdp_trajectory_model;
                            args_from_serializeable=identity)

Inverse of `serialize_pomdp_trace`. It reads the Dict stored in `filename` and
rebuilds the choice map. It then regenerates the trace with
`Gen.generate(pomdp_trajectory_model, args_from_serializeable(args), choices)`
and returns that trace.

`Serialization.deserialize` is not safe for untrusted files, and its format may
change between Julia versions. Only load files you wrote yourself with a
compatible Julia version.
"""
function deserialize_pomdp_trace(filename, pomdp_trajectory_model;
                                 args_from_serializeable=identity)
    s = Serialization.deserialize(filename)
    args = args_from_serializeable(s["args"])
    choices = serializable_to_choicemap(s["choices"])
    tr, _ = Gen.generate(pomdp_trajectory_model, args, choices)
    return tr
end

In [41]:
 using Dates
# Save the current world trace to a new timestamped file under `saves/`.
mkpath("saves")  # writing into a missing directory would fail
filename = joinpath("saves", string(now()) * "__pomdp_trace.jld")
serialize_pomdp_trace(filename, world_trace[])

In [42]:
# Reload the trace that was just saved to `filename`.
tr2 = deserialize_pomdp_trace(filename, trajectory_model);

[33m[1m└ [22m[39m[90m@ JSServe ~/.julia/packages/JSServe/BRpDB/src/connection/websocket.jl:40[39m


In [40]:
# Round-trip check: the reloaded trace should carry the same choices as the original.
get_choices(tr2) == get_choices(world_trace[])

true

[33m[1m└ [22m[39m[90m@ JSServe ~/.julia/packages/JSServe/BRpDB/src/serialization/protocol.jl:32[39m
[33m[1m└ [22m[39m[90m@ JSServe ~/.julia/packages/JSServe/BRpDB/src/serialization/protocol.jl:32[39m
[33m[1m└ [22m[39m[90m@ JSServe ~/.julia/packages/JSServe/BRpDB/src/serialization/protocol.jl:32[39m
[33m[1m└ [22m[39m[90m@ JSServe ~/.julia/packages/JSServe/BRpDB/src/serialization/protocol.jl:32[39m
[33m[1m└ [22m[39m[90m@ JSServe ~/.julia/packages/JSServe/BRpDB/src/serialization/protocol.jl:32[39m
[33m[1m└ [22m[39m[90m@ JSServe ~/.julia/packages/JSServe/BRpDB/src/serialization/protocol.jl:32[39m
[33m[1m└ [22m[39m[90m@ JSServe ~/.julia/packages/JSServe/BRpDB/src/serialization/protocol.jl:32[39m
[33m[1m└ [22m[39m[90m@ JSServe ~/.julia/packages/JSServe/BRpDB/src/serialization/protocol.jl:32[39m
[33m[1m└ [22m[39m[90m@ JSServe ~/.julia/packages/JSServe/BRpDB/src/serialization/protocol.jl:32[39m
[33m[1m└ [22m[39m[90m@ JSServe ~/.julia/p