In [1]:
# Load the Julia environment needed for this notebook.
# The project path is relative to the working directory the notebook is started from.
import Pkg
Pkg.activate("../Tasks2D")

[32m[1m  Activating[22m[39m project at `~/Developer/research/fall2023/cocosci/tasks2D/Tasks2D`


In [2]:
import Makie   # Visualization Library
using Revise      # For development; makes it so modifications
                  # to imported modules are immediately reflected in this Julia session
using Gen         # Gen probabilistic programming library
import GenParticleFilters # Additional particle filtering functionality for Gen
import GridWorlds # Simple gridworld functionality
import LineWorlds
const L = LineWorlds
import LineWorlds: cast # Ray caster
import GenPOMDPs  # Beginnings of a Gen POMDP library

import Tasks2D

includet("SLAM/Utils.jl")

[36m[1m[ [22m[39m[36m[1mInfo: [22m[39mPrecompiling GridWorlds [c15fd557-8ec5-4bf9-9d87-df57ac477796]


In [3]:
import LinearAlgebra

In [4]:
# `SLAM/Utils.jl` is already loaded (and tracked by Revise) by the `includet` in the
# imports cell; `include`-ing it a second time would redefine the `Utils` module.
# Echo the custom distribution to confirm the module is available.
Utils.mapped_uniform



Main.Utils.MappedUniform()

In [5]:
# Use the GLMakie backend (native GLFW windows, see `close_window`) for all figures below.
import GLMakie
GLMakie.activate!()

The model state is the tuple `state = (position, time, has_ever_hit_a_wall)`.

In [6]:
# Initial position model

# Distribution to sample uniformly from a Julia `Set`
using Tasks2D.Distributions: uniform_from_set

# INIT model: sample the agent's initial state.
#   `t_to_params(0).map` is the map in effect at time 0 (a GridWorlds.GridWorld).
#   :cell  — an empty grid cell, drawn uniformly from `GridWorlds.empty_cells(w)`
#   :x, :y — a position drawn uniformly inside that cell
# Returns the state tuple `(position, time, has_ever_hit_a_wall)` = `([x, y], 0, false)`.
# (Comments rather than a docstring: this is a static-DSL `@gen` macro call.)
@gen (static) function uniform_agent_pos(t_to_params)
    w = t_to_params(0).map # a map, represented as a GridWorlds.GridWorld
    
    cell ~ uniform_from_set(GridWorlds.empty_cells(w))
    
    # Cell (i, j) corresponds to the region from i-1 to i and j-1 to j
    x ~ uniform(cell[1] - 1, cell[1])
    y ~ uniform(cell[2] - 1, cell[2])
    
    return ([x, y], 0, false)
end

var"##StaticGenFunction_uniform_agent_pos#317"(Dict{Symbol, Any}(), Dict{Symbol, Any}())

In [7]:
"""
    det_next_pos(pos, a, Δ)

Noise-free position reached by taking action `a` from `pos = (x, y)` with step length `Δ`.

`a` must be one of `:up`, `:down`, `:left`, `:right` or `:stay`; any other value raises an
error. Returns a new 2-element vector `[x′, y′]`.
"""
function det_next_pos(pos, a, Δ)
    x, y = pos
    if a == :up
        return [x, y + Δ]
    elseif a == :down
        return [x, y - Δ]
    elseif a == :left
        return [x - Δ, y]
    elseif a == :right
        return [x + Δ, y]
    elseif a == :stay
        return [x, y]
    end
    error("Unrecognized action: $a")
end

"""
    handle_wall_intersection(prev, new, gridworld; stop_dist=0.05, backoff=0.04)

Resolve a proposed move from position `prev` to position `new` against the walls of
`gridworld` (the segments returned by `GridWorlds.nonempty_segments`).

The move segment is cast against every wall. Among the walls it hits within its own length,
the closest hit (at distance `dist` along the move) is used. Returns `(position, hit_wall)`:

- no wall hit → `(new, false)`;
- `dist < stop_dist` → `(prev, true)`: the agent does not move;
- otherwise → `(prev + (dist - backoff) * u, true)` with `u` the unit direction of the
  move, i.e. the agent stops `backoff` short of the wall.

The keyword defaults reproduce the previously hard-coded thresholds.
"""
function handle_wall_intersection(prev, new, gridworld; stop_dist=0.05, backoff=0.04)
    walls = GridWorlds.nonempty_segments(gridworld)
    move = L.Segment(prev, new)
    move_len = L.Geometry.norm(move)  # loop-invariant

    # Find the nearest wall crossed by the move (one wall per row of `walls`).
    hit = false
    min_collision_dist = Inf
    for i in axes(walls, 1)
        do_intersect, dist = L.Geometry.cast(move, L.Segment(walls[i, :]))
        if do_intersect && dist ≤ move_len && dist < min_collision_dist
            hit = true
            min_collision_dist = dist
        end
    end

    hit || return (new, false)

    # Too close to the wall to make progress: stay put.
    min_collision_dist < stop_dist && return (prev, true)

    # Stop just short of the wall along the direction of motion.
    direction = L.Geometry.diff(move)
    unit_dir = direction / L.Geometry.norm(direction)
    return (prev + (min_collision_dist - backoff) * unit_dir, true)
end

handle_wall_intersection (generic function with 1 method)

In [8]:
# STEP model: one noisy move of the agent.
#   `state`  — `(position, time, has_ever_hit_a_wall)`
#   `action` — `(a, wall_clock_time)`; `a` is an action understood by `det_next_pos`,
#              and the wall-clock time is carried in the action but not used here.
# Draws `:noisy_next_pos` from independent Gaussians (std `params.step.σ`) around the
# deterministic step of length `params.step.Δ`, then resolves the move against the map's
# walls with `handle_wall_intersection`. Parameters are those for the *next* time step.
# Returns the next state; the hit-wall flag stays true once it has been set.
# (Comments rather than a docstring: this is a static-DSL `@gen` macro call.)
@gen (static) function motion_model(state, action, t_to_params)
    (a, wall_clock_time) = action
    (pos, t_prev, prev_hit_wall) = state
    params = t_to_params(t_prev + 1)
    w, σ = params.map, params.step.σ

    next_pos_det = det_next_pos(pos, a, params.step.Δ)
    noisy_next_pos ~ broadcasted_normal(next_pos_det, σ)
    (next_pos, hit_wall) = handle_wall_intersection(pos, noisy_next_pos, w)

    return (next_pos, t_prev + 1, prev_hit_wall || hit_wall)
end

var"##StaticGenFunction_motion_model#355"(Dict{Symbol, Any}(), Dict{Symbol, Any}())

In [9]:
# OBS model: noisy range readings along `params.obs.n_rays` rays from the agent.
#
# For every ray, the noise-free distance to the nearest regular wall and to the nearest
# "strange" segment are computed by ray casting. Readings are then drawn as follows:
#   * A ray whose wall hit is strictly nearer gets a Gaussian reading. These are drawn at
#     `:noisy_wall_measurements` with std `wall_sensor_args.s_noise`.
#   * The other rays get a reading from
#     `Utils.mapped_uniform(dist_to_zmin.(d), dist_to_zmax.(d))` of `strange_sensor_args`,
#     drawn at `:noisy_strange_measurements`.
# Returns a vector with one reading per ray, in ray order.
# (Comments rather than a docstring: this is a `@gen` macro call.)
@gen function observe_noisy_distances(state, t_to_params)
    (pos, t, _) = state
    params = t_to_params(t)
    n_rays = params.obs.n_rays
    # Access sensor settings by name: positional destructuring of both tuples into the
    # same locals previously made wall readings use the strange sensor's `s_noise`.
    wall_args = params.obs.wall_sensor_args
    strange_args = params.obs.strange_sensor_args

    # Sensor pose as a 1×3 row: (x, y, orientation).
    p = reshape([pos..., params.obs.orientation], (1, 3))
    wall_segs = GridWorlds.wall_segments(params.map)
    strange_segs = GridWorlds.strange_segments(params.map)

    # Noise-free distance along each ray to each kind of segment, flattened to vectors.
    dists_walls = reshape(
        L.cast(p, wall_segs; num_a=n_rays, zmax=wall_args.zmax), (:,))
    dists_strange = reshape(
        L.cast(p, strange_segs; num_a=n_rays, zmax=strange_args.zmax), (:,))

    # Ties go to the strange segment.
    is_wall_measurement = dists_walls .< dists_strange
    wall_measurements = dists_walls[is_wall_measurement]
    strange_measurements = dists_strange[.!is_wall_measurement]

    noisy_wall_measurements ~ Gen.broadcasted_normal(wall_measurements, wall_args.s_noise)

    if !isempty(strange_measurements)
        mins = [strange_args.dist_to_zmin(m) for m in strange_measurements]
        maxs = [strange_args.dist_to_zmax(m) for m in strange_measurements]
        noisy_strange_measurements ~ Utils.mapped_uniform(mins, maxs)
    else
        noisy_strange_measurements = Float64[]
    end

    # Scatter both groups back into per-ray order.
    obs = Vector{Float64}(undef, length(is_wall_measurement))
    obs[is_wall_measurement] = noisy_wall_measurements
    obs[.!is_wall_measurement] = noisy_strange_measurements
    return obs
end

DynamicDSLFunction{Any}(Dict{Symbol, Any}(), Dict{Symbol, Any}(), Type[Any, Any], false, Union{Nothing, Some{Any}}[nothing, nothing], var"##observe_noisy_distances#363", Bool[0, 0], false)

In [10]:
# POMDP of this environment: initial-state, transition and observation generative
# functions, plus a utility function.
# NOTE(review): the UTILITY comment below lists `params`, but the function passed takes only
# `(state, action)` and always returns 0.0 — confirm the arity GenPOMDPs calls it with.
pomdp = GenPOMDPs.GenPOMDP(
    uniform_agent_pos,       # INIT   : params                      ⇝ state
    motion_model,            # STEP   : prev_state, actions, params ⇝ state
    observe_noisy_distances, # OBS    : state, params               ⇝ observation
    (state, action) -> 0.    # UTILITY: state, action, params       → utility
)

GenPOMDPs.GenPOMDP(var"##StaticGenFunction_uniform_agent_pos#317"(Dict{Symbol, Any}(), Dict{Symbol, Any}()), var"##StaticGenFunction_motion_model#355"(Dict{Symbol, Any}(), Dict{Symbol, Any}()), DynamicDSLFunction{Any}(Dict{Symbol, Any}(), Dict{Symbol, Any}(), Type[Any, Any], false, Union{Nothing, Some{Any}}[nothing, nothing], var"##observe_noisy_distances#363", Bool[0, 0], false), var"#69#70"())

In [11]:
# Generative function over trajectories of the POMDP,
# given a fixed action sequence. Used below both to build the live interactive trace
# (`get_interactive_trace`) and to reload saved traces.
trajectory_model = GenPOMDPs.ControlledTrajectoryModel(pomdp)

GenPOMDPs.var"##StaticGenFunction__ControlledTrajectoryModel#435"(Dict{Symbol, Any}(), Dict{Symbol, Any}())

In [12]:
"""
    get_posobs_seq(groundtruth_trace)

Derive, with `map`, a value tracking `groundtruth_trace` (the observable trace from
`get_interactive_trace` / `make_trace_interactive`). The derived value is the pair
`(positions, observations)`:
- `positions` holds the position of every state in the trace;
- `observations` holds every observation return value, flattened to a vector.
"""
function get_posobs_seq(groundtruth_trace)
    return map(groundtruth_trace) do trace
        positions = [pos for (pos, _, _) in GenPOMDPs.state_sequence(trace)]
        observations = [vec(o) for o in GenPOMDPs.observation_retval_sequence(trace)]
        (positions, observations)
    end
end
"""
    get_obs_seq(groundtruth_trace)

Derive, with `map`, a value tracking `groundtruth_trace` that holds every observation
return value of the trace, each flattened to a vector.
"""
function get_obs_seq(groundtruth_trace)
    return map(groundtruth_trace) do trace
        [vec(o) for o in GenPOMDPs.observation_retval_sequence(trace)]
    end
end

using Dates

"""
    get_save_tr(tr; dir="saves")

Return a callback `save_tr(viz_actions)` that serializes the trace currently held by `tr`
(read with `tr[]`) together with `viz_actions`, via `Utils.serialize_trace_and_viz_actions`.

Each call writes a new file `<dir>/<timestamp>__pomdp_trace.jld`. The timestamp is
`string(now())` at save time. `dir` is created first if it does not exist, so saving does
not fail on a fresh checkout.
"""
function get_save_tr(tr; dir="saves")
    function save_tr(viz_actions)
        mkpath(dir)
        filename = joinpath(dir, string(now()) * "__pomdp_trace.jld")
        Utils.serialize_trace_and_viz_actions(filename, tr[]; viz_actions=viz_actions)
    end
    return save_tr
end

"""
    get_take_action(_take_action)

Wrap `_take_action` so that each action `a` is forwarded as `(a, Dates.now())`, i.e.
tagged with the wall-clock time at which it was taken. Returns the wrapped callback.
"""
get_take_action(_take_action) = a -> _take_action((a, Dates.now()))
"""
    get_interactive_trace(args...; kwargs...)

Forward all arguments to `GenPOMDPs.interactive_world_trace`. Return `(trace, take_action)`,
where `take_action(a)` time-stamps `a` before passing it on (see `get_take_action`).
"""
function get_interactive_trace(args...; kwargs...)
    trace, raw_take_action = GenPOMDPs.interactive_world_trace(args...; kwargs...)
    timed_take_action = get_take_action(raw_take_action)
    return (trace, timed_take_action)
end
"""
    make_trace_interactive(args...; kwargs...)

Forward all arguments to `GenPOMDPs.make_trace_interactive`. Return `(trace, take_action)`,
where `take_action(a)` time-stamps `a` before passing it on (see `get_take_action`).
"""
function make_trace_interactive(args...; kwargs...)
    trace, raw_take_action = GenPOMDPs.make_trace_interactive(args...; kwargs...)
    timed_take_action = get_take_action(raw_take_action)
    return (trace, timed_take_action)
end

"""
    get_did_hitwall_observable(trace)

Derive, with `map`, a value tracking `trace` that holds the third component
(`has_ever_hit_a_wall`) of the last state in the trace.
"""
function get_did_hitwall_observable(trace)
    return map(trace) do tr
        final_state = GenPOMDPs.state_sequence(tr)[end]
        final_state[3]
    end
end
"""
    close_window(f)

Flag the native GLFW window showing figure `f` (the screen returned by `display(f)`) to close.
"""
close_window(f) = GLMakie.GLFW.SetWindowShouldClose(GLMakie.to_native(display(f)), true)

"""
    get_action_times_observable(trace)

Derive, with `map`, a value tracking `trace` that holds the second component (the
timestamp) of every action in the trace.
"""
function get_action_times_observable(trace)
    return map(trace) do tr
        [act[2] for act in GenPOMDPs.action_sequence(tr)]
    end
end
"""
    get_timing_args(trace; speedup_factor=1, max_delay=5)

Bundle the playback timing arguments for the GUIs as
`(action_times, speedup_factor, max_delay)`, where `max_delay` is in seconds.
"""
function get_timing_args(trace; speedup_factor=1, max_delay=5)
    action_times = get_action_times_observable(trace)
    return (action_times, speedup_factor, max_delay)
end

"""
    ConstantTToParams(params)

Time-to-parameters map that ignores the time step and always returns `params`.
"""
struct ConstantTToParams
    params
end
function (p::ConstantTToParams)(t)
    return p.params
end

"""
    SwitchTToParams(params1, params2, switch)

Time-to-parameters map that returns `params1` when `switch(t)` is true and `params2`
otherwise. Called with a `Makie.Observable`, it uses the observable's current value.
"""
struct SwitchTToParams
    params1
    params2
    switch
end
(p::SwitchTToParams)(t::Makie.Observable) = p(t[])
function (p::SwitchTToParams)(t)
    if p.switch(t)
        return p.params1
    else
        return p.params2
    end
end

"""
    TimedSwitchTToParams(params1, params2, timewindows_params1, timewindows_params2)

Time-to-parameters map driven by closed time windows, each given as a `(start, stop)` pair.

For time `t` it returns:
- `params1` if `t` lies in any window of `timewindows_params1`;
- otherwise `params2` if `t` lies in any window of `timewindows_params2`;
- otherwise it raises an error.

The `params1` windows take priority where the two lists overlap. Called with a
`Makie.Observable`, it uses the observable's current value.
"""
struct TimedSwitchTToParams
    params1
    params2
    timewindows_params1
    timewindows_params2
end
(p::TimedSwitchTToParams)(t::Makie.Observable) = p(t[])
function (p::TimedSwitchTToParams)(t)
    covers(windows) = any(((start, stop),) -> start ≤ t ≤ stop, windows)
    covers(p.timewindows_params1) && return p.params1
    covers(p.timewindows_params2) && return p.params2
    error("No time window found for time $t")
end

In [13]:
"""
    DistanceToConstant(c)

Callable that ignores its distance argument and always returns `c`.
"""
struct DistanceToConstant
    c
end
function (d::DistanceToConstant)(x)
    return d.c
end

"""
    MultiplyDistanceByConstant(c)

Callable that maps a distance `x` to `x * c`.
"""
struct MultiplyDistanceByConstant
    c
end
function (d::MultiplyDistanceByConstant)(x)
    return x * d.c
end

In [81]:
# Maps: custom map 4 (`MAP_W_STRANGE`) and custom map 5 (`MAP_WOUT_STRANGE`).
# Per the names, map 4 contains "strange" segments and map 5 does not — TODO confirm.
MAP_W_STRANGE = GridWorlds.load_custom_map(4)
MAP_WOUT_STRANGE = GridWorlds.load_custom_map(5)
# Parameter NamedTuple for one time step, given the map in effect at that step.
# `motion_model` reads `map`, `step.Δ` and `step.σ`; `observe_noisy_distances` reads
# `map` and the `obs` fields.
_get_params(map) = (;
    # map = GridWorlds.load_custom_map(5),
    map=map,
    # Δ: step length of each move action; σ: std of the Gaussian noise on each step.
    step = (; Δ = .5, σ = 0.005),
    obs = (; fov = 2π, n_rays = 90, orientation = π/2,
        # NOTE(review): the observation model uses `zmax` as the ray-cast cutoff and draws
        # wall readings with an `s_noise` std; `σ` does not feed any random choice there,
        # and `w`/`outlier`/`outlier_vol` are not read by the visible code — confirm.
        wall_sensor_args = (;
            w = 5, s_noise = 0.02,
            outlier = 0.0001,
            outlier_vol = 100.0,
            zmax = 100.0, σ=0.005
        ),

        # Consider a ray whose nearest hit is a strange segment at true distance d.
        # Its reading is drawn from `Utils.mapped_uniform(mins, maxs)`, with
        # mins = dist_to_zmin(d) (= 1.0) and maxs = dist_to_zmax(d) (= 2.5·d).
        strange_sensor_args = (;
            w = 5, s_noise = 0.02,
            outlier = 0.0001,
            outlier_vol = 100.0,
            zmax = 100.0, σ=5.5,
            dist_to_zmin = DistanceToConstant(1.0),
            dist_to_zmax = MultiplyDistanceByConstant(2.5)
            # min_dist=1.0
            # max_dist=10.0
        )
    )
)

# Map with strange segments for t ≥ 6, map without them for t < 6.
# t = 6 lies in both windows; the params1 windows are checked first, so it uses params1.
T_TO_PARAMS = TimedSwitchTToParams(
    _get_params(MAP_W_STRANGE),
    _get_params(MAP_WOUT_STRANGE),
    ((6, Inf),),
    ((-Inf, 6),),
)
# T_TO_PARAMS = ConstantTToParams(_PARAMS)
# Sanity check: t = 0 resolves to a parameter set (errors otherwise).
T_TO_PARAMS(0);

In [82]:
# Build the ground-truth world trace plus an action callback for live play.
# The initial state's `:cell` choice is constrained to cell (3, 3).
# `take_action` tags each action with the wall-clock time (see `get_take_action`).
world_trace, take_action = get_interactive_trace(
    trajectory_model, T_TO_PARAMS,
    choicemap((GenPOMDPs.state_addr(0, :cell), (3, 3)))  
);

In [83]:
# Interactive GUI over the map in effect at each time step (`T_TO_PARAMS(t).map`),
# with lines drawn from the agent to the wall hits.
# - `save_fn` writes the current trace to a timestamped file (see `get_save_tr`).
# - `close_on_hitwall` / `did_hitwall_observable` / `close_window` presumably close the
#   window once the state's has-hit-a-wall flag turns true.
# - `timing_args` supplies action timestamps for playback (speedup 1, max delay 5 s).
(f, t, actions) = GridWorlds.Viz.interactive_gui(
    t -> T_TO_PARAMS(t).map, get_posobs_seq(world_trace), take_action,
    save_fn=get_save_tr(world_trace),
    show_lines_to_walls=true,
    framerate=8,
    close_on_hitwall=true,
    did_hitwall_observable=get_did_hitwall_observable(world_trace),
    close_window=close_window,

    # For playback
    timing_args=get_timing_args(world_trace; speedup_factor=1, max_delay=5)
)
display(f)

GLMakie.Screen(...)

Trace & viz actions serialized to saves/2023-12-13T15:17:15.827__pomdp_trace.jld.
Trace & viz actions serialized to saves/2023-12-13T15:17:19__pomdp_trace.jld.


In [79]:
# Same interactive GUI as above, but without the agent-to-wall lines
# (`show_lines_to_walls=false`).
(f, t, actions) = GridWorlds.Viz.interactive_gui(
    t -> T_TO_PARAMS(t).map, get_posobs_seq(world_trace), take_action,
    save_fn=get_save_tr(world_trace),
    show_lines_to_walls=false,
    framerate=8,
    close_on_hitwall=true,
    did_hitwall_observable=get_did_hitwall_observable(world_trace),
    close_window=close_window,

    # For playback
    timing_args=get_timing_args(world_trace; speedup_factor=1, max_delay=5)
)
display(f)

GLMakie.Screen(...)

### Agent view

In [80]:
# Agent's-eye GUI: it is given only the observation sequence (no map or positions),
# plus the same action, save, close-on-hit-wall and playback-timing hooks as the
# map GUIs above.
(f, t, actions) = GridWorlds.Viz.play_as_agent_gui(
    get_obs_seq(world_trace),
    take_action,
    show_lines_to_walls=false,
    save_fn=get_save_tr(world_trace),
    framerate=8,
    close_on_hitwall=true,
    did_hitwall_observable=get_did_hitwall_observable(world_trace),
    close_window=close_window,

    timing_args=get_timing_args(world_trace; speedup_factor=1, max_delay=5)
)
display(f)

GLMakie.Screen(...)

In [77]:
# times = map(x -> x[2], GenPOMDPs.action_sequence(world_trace[]))
# times[2:end] - times[1:(end-1)]

Trace & viz actions serialized to saves/2023-12-13T14:45:20.944__pomdp_trace.jld.


## Load saved trace

In [None]:
# Only referenced by the commented-out `SwitchTToParams` loader option in the next cell
# (noisy parameters from t = 60 on when enabled); otherwise unused.
_noise_enabled = false

false

In [14]:
# Load a saved trace (and the viz actions stored with it) for `trajectory_model`.
# The commented paths are earlier saves; only the uncommented one is loaded.
# The trace's arguments are restored as well: a later cell calls
# `get_args(trace2[])[3](t)` to recover the saved time → params map.
tr2, viz_actions = Utils.deserialize_trace_and_viz_actions(
    # "saves/2023-12-12T15:51:18.757__pomdp_trace.jld",
    # "saves/2023-12-12T16:47:41.689__pomdp_trace.jld",
    # "saves/2023-12-12T17:07:55.823__pomdp_trace.jld",
    # "saves/2023-12-12T17:10:15.804__pomdp_trace.jld",
    # "saves/2023-12-12T17:45:51.846__pomdp_trace.jld",
    # "saves/2023-12-12T20:21:58.518__pomdp_trace.jld",
    # "saves/2023-12-12T20:23:53.018__pomdp_trace.jld",
    # "saves/2023-12-12T20:25:10.383__pomdp_trace.jld",
    # "saves/2023-12-12T20:44:22.276__pomdp_trace.jld",
    # "saves/2023-12-12T20:49:13.392__pomdp_trace.jld",
    # "saves/2023-12-12T20:49:13.392__pomdp_trace.jld",
    # "saves/2023-12-12T20:52:10.206__pomdp_trace.jld",
    # "saves/2023-12-12T21:02:18.894__pomdp_trace.jld",
    # "saves/2023-12-12T21:00:33.101__pomdp_trace.jld",
    # "saves/2023-12-12T20:49:13.392__pomdp_trace.jld",
    # "saves/2023-12-12T20:52:10.206__pomdp_trace.jld",
    # "saves/2023-12-13T14:23:50.487__pomdp_trace.jld",
    # "saves/2023-12-13T14:26:06.838__pomdp_trace.jld",
    # "saves/2023-12-13T14:28:28.878__pomdp_trace.jld",
    # "saves/2023-12-13T14:30:52.756__pomdp_trace.jld",
    # "saves/2023-12-13T14:45:20.944__pomdp_trace.jld",
    # "saves/2023-12-13T14:48:28.013__pomdp_trace.jld",
    "saves/2023-12-13T15:17:19__pomdp_trace.jld",

    # clean version - saves/2023-12-12T20:52:10.206__pomdp_trace.jld
    # noisy version - saves/2023-12-12T20:49:13.392__pomdp_trace.jld

    # map 18 - clean - saves/2023-12-12T21:00:33.101__pomdp_trace.jld
    # map 18 - noisy - saves/2023-12-12T21:02:18.894__pomdp_trace.jld
    trajectory_model,
    # args_from_serializeable=(args -> (args[1:2]..., ConstantTToParams(args[3]))),
    # args_from_serializeable=(args -> (
    #         args[1:2]...,
    #         SwitchTToParams(_PARAMS_CLEAN, _PARAMS_NOISY, t -> t < 60 || !_noise_enabled)
    #     ))
    );
# Wrap the loaded trace so that actions can be fed to it (presumably for replay);
# each action is tagged with the wall-clock time (see `get_take_action`).
trace2, take_action2 = make_trace_interactive(tr2);

# Replay the loaded trace from the agent's point of view, at 3× speed, with the wait
# between consecutive actions capped by `max_delay=5` (seconds).
(f, t, _replay_actions) = GridWorlds.Viz.play_as_agent_gui(
    get_obs_seq(trace2),
    take_action2,
    show_lines_to_walls=false,
    save_fn=get_save_tr(trace2),
    framerate=8,

    # Timing control for playback
    timing_args=get_timing_args(trace2; speedup_factor=3, max_delay=5)
)
display(f)

GLMakie.Screen(...)

In [25]:
# Longest wall-clock gap between consecutive actions in the loaded trace.
times = map(act -> act[2], GenPOMDPs.action_sequence(trace2[]))
maximum(diff(times))

238495 milliseconds

### Saved trace on visible map

In [15]:
# Replay the loaded trace on the full map at 3× speed. The map for each time step comes
# from the third argument of the loaded trace, called here as a time → params function.
# NOTE(review): `interactive_gui` is destructured into three values in earlier cells;
# only the first two are kept here.
(f, t) = GridWorlds.Viz.interactive_gui(
    t -> get_args(trace2[])[3](t).map, get_posobs_seq(trace2), take_action2,
    save_fn=get_save_tr(trace2),
    framerate=8,

    # Timing control for playback
    timing_args=get_timing_args(trace2; speedup_factor=3, max_delay=5)
)
display(f)

GLMakie.Screen(...)