
Commit

update for POMDPs v0.9
ancorso committed Oct 21, 2020
1 parent 7bf3797 commit 7e97fe7
Showing 3 changed files with 13 additions and 7 deletions.
Project.toml (1 change: 1 addition & 0 deletions)
@@ -7,6 +7,7 @@ version = "0.1.0"
 Distributions = "31c24e10-a181-5473-b8eb-7969acd0382f"
 GridInterpolations = "bb4c363b-b914-514b-8517-4eb369bc008a"
 LocalFunctionApproximation = "db97f5ab-fc25-52dd-a8f9-02a257c35074"
+POMDPLinter = "f3bd98c0-eb40-45e2-9eb1-f2763262d755"
 POMDPModelTools = "08074719-1b2a-587c-a292-00f91cc44415"
 POMDPModels = "355abbd5-f08e-5560-ac9e-8b5f2592a0ca"
 POMDPSimulators = "e0d0a172-29c6-5d4e-96d0-f262df5d01fd"
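The new POMDPLinter dependency provides the requirement-declaration macros (@POMDP_require, @req, @subreq, @warn_requirements) that were split out of POMDPs.jl in v0.9, which is why it is added here and imported in the source file below. A minimal sketch of how these macros are typically used, with MySolver standing in as a hypothetical placeholder for this package's solver type:

    # Sketch only: MySolver is a placeholder, not this package's actual API.
    using POMDPs
    using POMDPLinter: @POMDP_require, @req, @warn_requirements

    struct MySolver <: Solver end

    @POMDP_require solve(solver::MySolver, mdp::Union{MDP,POMDP}) begin
        P = typeof(mdp)
        S = statetype(P)
        @req actions(::P)            # the problem must define an action space
        @req discount(::P)           # ... a discount factor
        @req isterminal(::P, ::S)    # ... and a terminal-state test
    end

    # @warn_requirements solve(MySolver(), SomeMDP()) then reports any missing methods.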
src/LocalApproximationPolicyEvaluation.jl (14 changes: 9 additions & 5 deletions)
@@ -9,6 +9,7 @@ module LocalApproximationPolicyEvaluation
 using LocalFunctionApproximation
 using Distributions
 import POMDPs: Solver, solve, Policy, action, value
+using POMDPLinter: @POMDP_require, @req, @subreq, @warn_requirements

 export LocalPolicyEvalSolver, LocalPolicyEvalPolicy, action_and_prob

@@ -153,13 +154,13 @@ module LocalApproximationPolicyEvaluation
 # Generative Model
 for j in 1:solver.n_generative_samples
     sp_point, r, isTerm = 0, 0, false
-    if haskey(next_state_dict, (s,a))
+    if solver.n_generative_samples ==1 && haskey(next_state_dict, (s,a))
         sp_point, r, isTerm = next_state_dict[(s,a)]
     else
-        sp, r = gen(DDNOut(:sp,:r), mdp, s, a, solver.rng)
+        sp, r = gen(mdp, s, a, solver.rng)
         isTerm = isterminal(mdp, sp)
         sp_point = POMDPs.convert_s(Vector{Float64}, sp, mdp)
-        next_state_dict[(s,a)] = (sp_point, r, isTerm)
+        solver.n_generative_samples == 1 && (next_state_dict[(s,a)] = (sp_point, r, isTerm))
     end

     u += r
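For context, POMDPs v0.9 dropped the DDNOut form of the generative interface: gen(mdp, s, a, rng) is now called directly and returns a NamedTuple whose fields the caller unpacks, exactly as the updated line above does. The next-state cache is also guarded by n_generative_samples == 1, presumably so that caching a single transition does not collapse multiple Monte Carlo samples into identical draws. A rough sketch of the v0.9 convention, with MyMDP as a hypothetical problem type:

    # Sketch only: MyMDP and its transition/reward methods are placeholders.
    function POMDPs.gen(m::MyMDP, s, a, rng)
        sp = rand(rng, transition(m, s, a))   # sample a successor state
        r  = reward(m, s, a)                  # immediate reward
        return (sp=sp, r=r)                   # NamedTuple replaces DDNOut(:sp,:r)
    end

    # Callers unpack only the fields they need:
    # sp, r = gen(m, s, a, rng)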
@@ -253,9 +254,12 @@ module LocalApproximationPolicyEvaluation
 # mdp is generative or explicit
 if policy.is_mdp_generative
     for j in 1:policy.n_generative_samples
-        sp, r = gen(DDNOut(:sp,:r), mdp, s, a, policy.rng)
+        sp, r = gen(mdp, s, a, policy.rng)
         sp_point = POMDPs.convert_s(Vector{Float64}, sp, mdp)
-        u += r + discount_factor*LocalFunctionApproximation.compute_value(policy.interp, sp_point)
+        u += r
+        if !isterminal(mdp, sp)
+            u += discount_factor*LocalFunctionApproximation.compute_value(policy.interp, sp_point)
+        end
     end
     u = u / policy.n_generative_samples
 else
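The second hunk changes the Monte Carlo value estimate so that terminal successor states no longer contribute a bootstrapped value: each sample adds its reward, and the discounted interpolated value of sp is added only when sp is non-terminal. A compact sketch of the resulting per-sample estimate (the names here are illustrative, not the package's API):

    # q(s,a) is the average over samples of r + γ*Vhat(sp), skipping Vhat at terminals.
    function sample_q(mdp, s, a, rng, n, γ, Vhat)
        u = 0.0
        for _ in 1:n
            sp, r = gen(mdp, s, a, rng)
            u += r
            isterminal(mdp, sp) || (u += γ * Vhat(sp))
        end
        return u / n
    end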
test/gridworld_test.jl (5 changes: 3 additions & 2 deletions)
@@ -10,8 +10,9 @@ g_size = (9,9)
 g = SimpleGridWorld(size = g_size, rewards = Dict(GWPos(g_size...) => 1, GWPos(1,1) => 0), tprob = 1., discount=1)
 action_probability(g::SimpleGridWorld, s, a) = 0.25
 POMDPs.convert_s(::Type{AbstractArray}, s::GWPos, g::SimpleGridWorld) = SVector{4, Float64}(1., s[1], s[2], s[1]*s[2])
-POMDPs.gen(d::DDNOut{(:sp,:r)}, mdp::SimpleGridWorld, s, a, rng = Random.GLOBAL_RNG) = (sp =rand(transition(g, s, a )), r=reward(g, s, a))
-POMDPs.initialstate(g::SimpleGridWorld) = rand(initialstate_distribution(g))
+POMDPs.gen(mdp::SimpleGridWorld, s, a, rng = Random.GLOBAL_RNG) = (sp =rand(transition(g, s, a )), r=reward(g, s, a))
+POMDPs.initialstate(g::SimpleGridWorld) = initialstate_distribution(g)
+


 # Step 2 - Solve the problem semi-exactly using local approximation
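The test also follows the v0.9 convention that initialstate returns an initial-state distribution rather than an already-sampled state (here it simply forwards to the older initialstate_distribution accessor). A small sketch of how a caller then obtains a concrete start state:

    # Sketch: drawing a start state from the distribution returned in v0.9.
    using Random
    d  = POMDPs.initialstate(g)        # a distribution over GWPos states
    s0 = rand(Random.GLOBAL_RNG, d)    # sample a concrete initial state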
