In [None]:
    # Load the POMDP solver stack and the automotive modelling packages on the
    # current process.
    using POMDPs
    using DiscreteValueIteration 
    
    using AutomotiveDrivingModels
    using AutoViz
    using AutomotiveSensors
    using AutomotivePOMDPs
    using Parameters
    using StaticArrays

    using GridInterpolations 
    using POMDPToolbox
    using JLD

    
    # Include the pedestrian-crossing source files by relative path.
    # NOTE(review): the files are presumably order-dependent (types first) — confirm
    # before reordering.
    include("../src/pedestrian_crossing/pomdp_types.jl")
    include("../src/pedestrian_crossing/spaces.jl")
    include("../src/pedestrian_crossing/transition.jl")
    include("../src/pedestrian_crossing/observation.jl")
    include("../src/pedestrian_crossing/belief.jl")

    include("../src/pedestrian_crossing/frenet_pedestrian_pomdp.jl")

In [None]:
# Start the worker processes; the same count is passed to the solver below.
N_PROCS = 56
addprocs(N_PROCS)

# Everything inside this block is evaluated on every process, so each one loads
# the packages, includes the model sources and builds its own `pomdp`.
@everywhere begin
    using POMDPs, DiscreteValueIteration
    using AutomotiveDrivingModels, AutoViz, AutomotiveSensors, AutomotivePOMDPs
    using Parameters, StaticArrays
    using GridInterpolations, POMDPToolbox, JLD

    # Source files are included in the listed order.
    foreach(include, [
        "../src/pedestrian_crossing/pomdp_types.jl",
        "../src/pedestrian_crossing/spaces.jl",
        "../src/pedestrian_crossing/transition.jl",
        "../src/pedestrian_crossing/observation.jl",
        "../src/pedestrian_crossing/belief.jl",
        "../src/pedestrian_crossing/frenet_pedestrian_pomdp.jl"
    ])

    pomdp = SingleOCFPOMDP()
end

# Parallel value iteration that also returns the Q-matrix (`include_Q=true`).
solver = ParallelValueIterationSolver(
    n_procs=N_PROCS,
    max_iterations=200,
    belres=1e-4,
    include_Q=true,
    verbose=true
)



In [None]:
# Solve `pomdp` with the configured value-iteration `solver`.
vi_policy = solve(solver, pomdp)
# Wrap the solver's Q-matrix (`qmat`) and its action map in an AlphaVectorPolicy.
qmdp_policy = AlphaVectorPolicy(pomdp, vi_policy.qmat, vi_policy.action_map)

# Persist the policy to "policy.jld" under the key "policy".
JLD.save("policy.jld", "policy", qmdp_policy)


In [None]:
"""
    action(policy::AlphaVectorPolicy, b::SingleOCFBelief)

Return the entry of `policy.action_map` whose alpha vector has the largest expected
value under the belief `b`.

For alpha vector `i`, the utility is the sum over the belief support of
`policy.alphas[i][state_index(s)] * probability(s)`, where `b.vals` holds the
states and `b.probs` the matching probabilities (same positions).

The problem used for `n_actions` and `state_index` is the one stored in the policy
(`policy.pomdp`), so the result does not depend on whatever global `pomdp` happens
to be bound in the notebook when this is called.
"""
function AutomotivePOMDPs.action(policy::AlphaVectorPolicy, b::SingleOCFBelief)
    problem = policy.pomdp
    alphas = policy.alphas
    n_act = n_actions(problem)
    util = zeros(n_act)

    # Visit each belief state once: its index is looked up a single time and reused
    # for every action. Per action, terms are still accumulated in belief order.
    for (j, s) in enumerate(b.vals)
        si = state_index(problem, s)
        weight = b.probs[j]
        for i in 1:n_act
            util[i] += alphas[i][si] * weight
        end
    end

    ihi = indmax(util)
    return policy.action_map[ihi]
end


In [None]:
# Show how many states `pomdp` has.
n_states(pomdp)

In [None]:
# Rebuild the problem with its default constructor.
pomdp = SingleOCFPOMDP()

# Read "policy.jld" and take the entry stored under the key "policy";
# the trailing `;` suppresses the cell output.
qmdp_policy = load("policy.jld")["policy"];


In [None]:
Pkg.add("Plots")
using Interact
using Plots
gr()

# One cell per (T_RANGE, S_RANGE) pair: rows follow T_RANGE, columns follow S_RANGE.
policy_grid = Matrix(length(pomdp.T_RANGE), length(pomdp.S_RANGE))

# The outer widget selects the ego speed, the inner widget the pedestrian speed.
@manipulate for ego_v in pomdp.EGO_V_RANGE
    @manipulate for ped_v in pomdp.PED_V_RANGE
        # Query the policy with a single-state belief (probability 1) at each grid
        # point and record the `acc` field of the chosen action.
        for (i, ped_t) in enumerate(pomdp.T_RANGE)
            for (j, ped_s) in enumerate(pomdp.S_RANGE)
                point_state = SingleOCFState(0.0, ego_v, ped_s, ped_t, 1.57, ped_v)
                chosen = action(qmdp_policy, SparseCat([point_state], [1.]))
                policy_grid[i, j] = chosen.acc
            end
        end

        s_axis = [s for s in pomdp.S_RANGE]
        t_axis = [t for t in pomdp.T_RANGE]
        heatmap(s_axis, t_axis, policy_grid, aspect_ratio=1)
    end
end




In [None]:
using Interact
using Plots
gr()

# One cell per (T_RANGE, S_RANGE) pair: rows follow T_RANGE, columns follow S_RANGE.
policy_grid = Matrix(length(pomdp.T_RANGE), length(pomdp.S_RANGE))

# The outer widget selects the ego speed, the inner widget the pedestrian speed.
@manipulate for ego_v in pomdp.EGO_V_RANGE
    @manipulate for ped_v in pomdp.PED_V_RANGE
        # Query the policy with a single-state belief (probability 1) at each grid
        # point and record the `lateral_movement` field of the chosen action.
        for (i, ped_t) in enumerate(pomdp.T_RANGE)
            for (j, ped_s) in enumerate(pomdp.S_RANGE)
                point_state = SingleOCFState(0.0, ego_v, ped_s, ped_t, 1.57, ped_v)
                chosen = action(qmdp_policy, SparseCat([point_state], [1.]))
                policy_grid[i, j] = chosen.lateral_movement
            end
        end

        s_axis = [s for s in pomdp.S_RANGE]
        t_axis = [t for t in pomdp.T_RANGE]
        heatmap(s_axis, t_axis, policy_grid, aspect_ratio=1)
    end
end


In [None]:
# Consistency check for the action space: the position of each action in
# `pomdp.action_space` should equal what `action_index` reports for it.
# Prints "error" for every mismatch, every action, and finally the mismatch count.
cnt = 0
for (position, a) in enumerate(pomdp.action_space)
    if action_index(pomdp, a) != position
        println("error")
        cnt += 1
    end
    println(a)
end
println(cnt)


In [None]:
# Fetch the state space of `pomdp`.
state_space = states(pomdp)

# Index the state space with the `state_index` of a hand-built state and display
# the entry found there (presumably a round-trip sanity check — confirm intent).
state_space[state_index(pomdp,SingleOCFState(0.0, 4, 40.0, 0.0, 1.57, 1.5))]

In [None]:

# Start state and successor state are built from identical arguments.
s = SingleOCFState(0.0, 4.3076923076923075, 40.0, 4.0, 1.57, 1.5)
sp = SingleOCFState(0.0, 4.3076923076923075, 40.0, 4.0, 1.57, 1.5)

# NOTE(review): argument order of SingleOCFAction is not visible here; other cells
# read the fields `acc` and `lateral_movement` — confirm which is which.
act = SingleOCFAction(0.0, 1.0)
# Display the reward for the transition s --act--> sp.
reward(pomdp, s, act, sp) 