Skip to content

Commit

Permalink
Simplified example demonstrating how to train an agent to drive using dynamic programming (DP)
Browse files Browse the repository at this point in the history
  • Loading branch information
ancorso committed Jul 18, 2020
1 parent 13c7b19 commit debfa4c
Showing 1 changed file with 6 additions and 32 deletions.
38 changes: 6 additions & 32 deletions examples/train_sut_to_drive.jl
Expand Up @@ -2,20 +2,17 @@ using AdversarialDriving
using POMDPs, POMDPPolicies, POMDPSimulators
using Distributions
using GridInterpolations, LocalFunctionApproximation, LocalApproximationValueIteration
using Serialization

save_folder ="examples/training_hist/"
Np, Nv = 3, 3
N_suts = 10
## setup training params
Np, Nv = 3,3 #25, 15
Nsteps = 2#50

# for i=1:N_suts
i = 1
## Construct a disturbance model for the adversaries
da_std = rand(Exponential(0.5))
goal_toggle_p = 10. ^rand(Uniform(-5., -1.))
blinker_toggle_p = 10. ^rand(Uniform(-5., -1.))
v_des = rand(Uniform(15., 30.))
per_timestep_penalty = rand([0, 1e-4, 1e-3, 5e-3])

disturbances = Sampleable[Normal(0,da_std), Bernoulli(goal_toggle_p), Bernoulli(blinker_toggle_p), Bernoulli(0.), Bernoulli(0.)]

## construct the MDP
Expand All @@ -28,29 +25,6 @@ goals = Float64.(Tint_goals[laneid(adversaries(mdp)[1].get_initial_entity())])
grid = RectangleGrid(range(0., stop=100.,length=Np), range(0, stop=30., length=Nv), goals, [0.0, 1.0],
range(15., stop=100., length=Np), range(0., stop=30., length=Nv), [5.0], [1.0])

interp = LocalGIFunctionApproximator(grid) # Create the local function approximator using the grid
solver = LocalApproximationValueIterationSolver(interp, is_mdp_generative = true, n_generative_samples = 10, verbose = true, max_iterations = 25, belres = 1e-6)
interp = LocalGIFunctionApproximator(grid)
solver = LocalApproximationValueIterationSolver(interp, is_mdp_generative = true, n_generative_samples = 5, verbose = true, max_iterations = Nsteps)
policy = solve(solver, mdp)

folder = string(save_folder, "sut_version_", lpad(i, 2, "0"), "/")
try mkdir(folder) catch end
println("saving...")
serialize(string(folder, "mdp"), mdp)
serialize(string(folder, "policy"), policy)

println("Evaluating the performance...")

h = simulate(HistoryRecorder(max_steps = 150), mdp, policy)
scenes_to_gif(state_hist(h), mdp.roadway, "out.gif")
## Evaluate the performance
tot_suc, Ntrials =0, 10
for k=1:Ntrials
println("k:", k)
h = simulate(HistoryRecorder(max_steps = 150), mdp, policy)
global tot_suc += any(reward_hist(h) .> 1.)
end
tot_suc
prob_suc = tot_suc / Ntrials
println("success prob: ", prob_suc)
write("performance.txt", string(prob_suc))
# end

0 comments on commit debfa4c

Please sign in to comment.