In [None]:
# Notebook setup: when the parent directory holds both a Project.toml and a Manifest.toml,
# activate it so the packages loaded below resolve against that environment.
using Pkg
if isfile("../Project.toml") && isfile("../Manifest.toml")
    Pkg.activate("..");
    # Presumably tells PyCall which Python interpreter to use — TODO confirm. Note that it
    # is only set when the parent project was activated.
    ENV["PYTHON"] = "python3";
end

# Left disabled; presumably run once by hand so PyCall picks up ENV["PYTHON"] — confirm.
#Pkg.build("PyCall")
using JuliaProbo, Distributed
using Seaborn: heatmap, savefig
# NOTE(review): `gr` is not imported by name above; presumably it reaches this scope
# through `using JuliaProbo` — confirm.
gr()

In [None]:
"""
    PolicyEvaluator

Discretised `(x, y, θ)` pose space together with the per-cell tables built for it by the
outer constructor `PolicyEvaluator(reso, goal, lowerleft, upperright)`.
"""
mutable struct PolicyEvaluator
    # Lower and upper corner of the covered pose space, ordered (x, y, θ).
    pose_min::Vector{Float64}
    pose_max::Vector{Float64}
    # Cell size along each of the three axes.
    reso::Vector{Float64}
    # Goal the cells are classified against.
    goal::Goal
    # Number of cells along each axis.
    index_nums::Vector{Int64}
    # Every cell index (ix, iy, iθ), in the order the constructor visited them.
    indices::Vector{NTuple{3, Int64}}
    # Per-cell value; the constructor seeds it with goal.value or -100.0.
    value_function_::AbstractArray{Float64, 3}
    # Per-cell flag: 1.0 where the constructor classified the cell as final, else 0.0.
    final_state_flags_::AbstractArray{Float64, 3}
    # Per-cell action; the 4th axis has length 2 and holds [v, ω].
    policy_::AbstractArray{Float64, 4}
end

"""
    set_final_state(reso, goal, pose_min, index)

Return `1.0` when the grid cell `index` lies entirely inside `goal`, otherwise `0.0`.

The cell is treated as the axis-aligned rectangle in the x-y plane spanned by
`pose_min .+ reso .* (index .- 1)` and `pose_min .+ reso .* index`. Each of its four
corners is passed to `inside(goal, corner)`; the cell counts as final only if all four
are inside. The θ component of `index` does not influence the result.

# Arguments
- `reso`: cell size along (x, y, θ).
- `goal`: goal region, tested through `inside`.
- `pose_min`: lower corner of the grid, ordered (x, y, θ).
- `index`: 1-based cell index `(ix, iy, iθ)`.
"""
function set_final_state(
        reso::Vector{Float64},
        goal::Goal,
        pose_min::Vector{Float64},
        index::Tuple{Int64, Int64, Int64}
    )
    # Broadcast over the tuple directly instead of splatting it into a temporary vector.
    lower = pose_min .+ reso .* (index .- 1)
    upper = pose_min .+ reso .* index

    # Only the planar extent is needed, so the θ bounds are never bound to names.
    x_min, y_min = lower[1], lower[2]
    x_max, y_max = upper[1], upper[2]

    corners = ([x_min, y_min], [x_max, y_min], [x_max, y_max], [x_min, y_max])
    # The predicate form of `all` stops at the first corner that is outside the goal.
    return convert(Float64, all(corner -> inside(goal, corner), corners))
end

"""
    initial_policy(pose, goal)

Choose the starting action `[v, ω]` for a robot at `pose = [x, y, θ]`.

The bearing from the pose to `(goal.x, goal.y)` is compared with `θ`, and the difference
is rounded to whole degrees and wrapped into `(-180, 180]`:

- more than 10 degrees: returns `[0.0, 2.0]`
- less than -10 degrees: returns `[0.0, -2.0]`
- otherwise: returns `[1.0, 0.0]`
"""
function initial_policy(pose::Vector{Float64}, goal::Goal)
    px, py, heading = pose[1], pose[2], pose[3]
    bearing = atan(goal.y - py, goal.x - px)
    offset = Int64(round((bearing - heading) * 180 / pi))

    # Wrap into (-180, 180]: `rem` brings the value strictly inside (-360, 360), after
    # which at most one correction by a full turn is needed.
    offset = rem(offset, 360)
    if offset > 180
        offset -= 360
    elseif offset <= -180
        offset += 360
    end

    if offset > 10
        return [0.0, 2.0]
    elseif offset < -10
        return [0.0, -2.0]
    end
    return [1.0, 0.0]
end

"""
    init_policy(index_nums, indices, reso, pose_min, goal)

Build the initial policy table of size `(index_nums..., 2)`.

For every cell listed in `indices`, the pose at the centre of the cell is handed to
`initial_policy` and the returned `[v, ω]` pair is stored along the last axis. Cells not
listed in `indices` keep the zero action.
"""
function init_policy(
        index_nums::Vector{Int64},
        indices::Vector{Tuple{Int64, Int64, Int64}},
        reso::Vector{Float64},
        pose_min::Vector{Float64},
        goal::Goal
    )
    table = zeros(Float64, index_nums..., 2)
    for cell in indices
        # Centre of the cell: half a step below its upper corner on every axis.
        centre = pose_min .+ reso .* (collect(cell) .- 0.5)
        table[cell..., :] = initial_policy(centre, goal)
    end
    return table
end

# constructor
"""
    PolicyEvaluator(reso, goal, lowerleft=[-4.0, -4.0], upperright=[4.0, 4.0])

Discretise the pose space and build the initial tables for it.

The x-y extent runs from `lowerleft` to `upperright`, and θ runs from `0` to `2pi`. The
number of cells per axis is the extent divided by `reso`, rounded to the nearest integer.
Each cell is classified with `set_final_state`: final cells start with value
`goal.value`, all others with `-100.0`. The policy table is filled by `init_policy`.
"""
function PolicyEvaluator(
        reso::Vector{Float64},
        goal::Goal,
        lowerleft=[-4.0, -4.0],
        upperright=[4.0, 4.0]
    )
    pose_min = vcat(lowerleft, [0.0])
    pose_max = vcat(upperright, [2pi])

    index_nums = map(1:3) do axis
        Int64(round((pose_max[axis] - pose_min[axis]) / reso[axis]))
    end
    n_x, n_y, n_θ = index_nums

    v = zeros(Float64, n_x, n_y, n_θ)
    f = zeros(Float64, n_x, n_y, n_θ)
    indices = Tuple{Int64, Int64, Int64}[]

    # x is the outermost and θ the innermost loop; `indices` records that visiting order.
    for id1 in 1:n_x, id2 in 1:n_y, id3 in 1:n_θ
        index = (id1, id2, id3)
        flag = set_final_state(reso, goal, pose_min, index)
        f[index...] = flag
        v[index...] = flag == 1.0 ? goal.value : -100.0
        # Appending to a shared vector is what gets in the way of running this loop
        # under @distributed.
        push!(indices, index)
    end

    policy_ = init_policy(index_nums, indices, reso, pose_min, goal)
    return PolicyEvaluator(
        pose_min, pose_max, reso, goal, index_nums, indices, v, f, policy_
    )
end

In [None]:
# Build the evaluator with cell sizes 0.2 x 0.2 in x-y and pi/18 in θ. With the
# constructor's default bounds ([-4, 4] in x and y, [0, 2pi] in θ) this gives a
# 40 x 40 x 36 grid.
# Goal(-3.0, -3.0) is presumably the goal's (x, y) position — confirm against `Goal`.
pe = PolicyEvaluator([0.2, 0.2, pi/18], Goal(-3.0, -3.0));

In [None]:
# Collapse each cell's action [v, ω] into the single number v + ω so that the policy can
# be drawn as a scalar field.
p = zeros(pe.index_nums...)
for cell in pe.indices
    p[cell...] = sum(view(pe.policy_, cell..., :))
end

# Draw θ slice 19 (θ between 18 and 19 cell widths above 0) and save the figure.
# rotl90 turns the matrix a quarter turn counter-clockwise before plotting, presumably so
# that x runs horizontally in the image — confirm.
plt = heatmap(rotl90(p[:, :, 19]))
savefig("images/ch10_policy_evaluation3.png")

<img src="images/ch10_policy_evaluation3.png">