Added functionality to learn a driving policy against adversaries (#3)
* Added code to train a SUT (system under test) driving agent

* Simplified example demonstrating how to train an agent to drive with dynamic programming (DP)
ancorso committed Jul 18, 2020
1 parent 354e4b0 commit 86e7c47
Showing 7 changed files with 135 additions and 22 deletions.
3 changes: 0 additions & 3 deletions .travis.yml
@@ -6,9 +6,6 @@ os:
julia:
- 1.4

codecov: true
coveralls: true

notifications:
email: false

30 changes: 30 additions & 0 deletions examples/train_sut_to_drive.jl
@@ -0,0 +1,30 @@
using AdversarialDriving
using POMDPs, POMDPPolicies, POMDPSimulators
using Distributions
using GridInterpolations, LocalFunctionApproximation, LocalApproximationValueIteration

## Set up training parameters (small values for a quick demo; the commented values are larger alternatives)
Np, Nv = 3, 3   # 25, 15
Nsteps = 2      # 50

## Construct a disturbance model for the adversaries
da_std = rand(Exponential(0.5))
goal_toggle_p = 10. ^rand(Uniform(-5., -1.))
blinker_toggle_p = 10. ^rand(Uniform(-5., -1.))
v_des = rand(Uniform(15., 30.))
per_timestep_penalty = rand([0, 1e-4, 1e-3, 5e-3])
disturbances = Sampleable[Normal(0,da_std), Bernoulli(goal_toggle_p), Bernoulli(blinker_toggle_p), Bernoulli(0.), Bernoulli(0.)]

## Construct the MDP
sut_agent = BlinkerVehicleAgent(rand_up_left(id=1, s_dist=Uniform(15.,25.), v_dist=Uniform(10., 29)), TIDM(Tint_TIDM_template))
right_adv = BlinkerVehicleAgent(rand_right(id=3, s_dist=Uniform(15.,35.), v_dist=Uniform(15., 29.)), TIDM(Tint_TIDM_template), disturbance_model = disturbances)
mdp = DrivingMDP(sut_agent, [right_adv], Tint_roadway, 0.2, γ = 0.95, per_timestep_penalty = per_timestep_penalty, v_des = v_des)

## Solve using local approximation value iteration
goals = Float64.(Tint_goals[laneid(adversaries(mdp)[1].get_initial_entity())])
grid = RectangleGrid(range(0., stop=100.,length=Np), range(0, stop=30., length=Nv), goals, [0.0, 1.0],
range(15., stop=100., length=Np), range(0., stop=30., length=Nv), [5.0], [1.0])

interp = LocalGIFunctionApproximator(grid)
solver = LocalApproximationValueIterationSolver(interp, is_mdp_generative = true, n_generative_samples = 5, verbose = true, max_iterations = Nsteps)
policy = solve(solver, mdp)
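
Once the solver returns, the learned policy can be sanity-checked by rolling it out in the same MDP. Below is a minimal rollout sketch (not part of the diff above) using the POMDPSimulators package already imported in the example; the 100-step cap is an arbitrary choice.

## Sketch: roll out the learned policy in the DrivingMDP
hr = HistoryRecorder(max_steps = 100)        # cap the rollout length (arbitrary choice)
hist = POMDPs.simulate(hr, mdp, policy)      # run the MDP under the learned policy
println("discounted return: ", discounted_reward(hist))
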
4 changes: 2 additions & 2 deletions src/AdversarialDriving.jl
@@ -15,7 +15,7 @@ module AdversarialDriving
update_veh_state, noisy_scene, noisy_entity, blinker, goals, noise,
BlinkerVehicle, NoisyPedestrian, Disturbance, BlinkerVehicleControl,
PedestrianControl, AdversarialPedestrian, TIDM, lane_belief, laneid,
can_have_goal, any_collides, ego_collides, end_of_road
can_have_goal, any_collides, ego_collides, end_of_road, PolicyModel
include("driving_models.jl")

# T-Intersection roadway exports
@@ -30,7 +30,7 @@ module AdversarialDriving
# MDP
export Agent, BlinkerVehicleAgent, NoisyPedestrianAgent, id,
AdversarialDrivingMDP, action_probability, step_scene,
agents, adversaries, model, sut, sutid, update_adversary!
agents, adversaries, model, sut, sutid, update_adversary!, DrivingMDP
include("mdp.jl")

# states
1 change: 1 addition & 0 deletions src/actions.jl
@@ -69,6 +69,7 @@ function Distributions.logpdf(m::Vector{Sampleable}, a::Vector{Disturbance}, mdp
sum([logpdf(m[i], avec[i]) for i=1:length(avec)])
end
Base.rand(rng::AbstractRNG, m::Vector{Sampleable}, mdp::AdversarialDrivingMDP) = convert_a(Vector{Disturbance}, [rand(rng, d) for d in m], mdp)
Base.rand(rng::AbstractRNG, m::Vector{Sampleable}, convert_fn::Function) = convert_fn([rand(rng, d) for d in m])


## Discrete BlinkerVehicle Actions
23 changes: 20 additions & 3 deletions src/driving_models.jl
@@ -191,7 +191,7 @@ end
@with_kw mutable struct TIDM <: DriverModel{BlinkerVehicleControl}
idm::IntelligentDriverModel = IntelligentDriverModel() # underlying idm
noisy_observations::Bool = false # Whether or not this model gets noisy observations
ttc_threshold = 5 # threshold through intersection
ttc_threshold = 7 # threshold through intersection
next_action::BlinkerVehicleControl = BlinkerVehicleControl() # The next action that the model will do (for controllable vehicles)

# Describes the intersection and rules of the road
@@ -246,8 +246,8 @@ function AutomotiveSimulator.observe!(model::TIDM, input_scene::Scene, roadway::
# If the vehicle does not have right of way then stop before the intersection
if !has_right_of_way
# Compare ttc
exit_time = [time_to_cross_distance_const_vel(veh, distance_to_point(veh, roadway, model.intersection_exit_loc[laneid(veh)])) for veh in vehicles_to_yield_to]
enter_time = [time_to_cross_distance_const_vel(veh, distance_to_point(veh, roadway, model.intersection_enter_loc[laneid(veh)])) for veh in vehicles_to_yield_to]
exit_time = [time_to_cross_distance_const_acc(veh, model.idm, distance_to_point(veh, roadway, model.intersection_exit_loc[laneid(veh)])) for veh in vehicles_to_yield_to]
enter_time = [time_to_cross_distance_const_acc(veh, model.idm, distance_to_point(veh, roadway, model.intersection_enter_loc[laneid(veh)])) for veh in vehicles_to_yield_to]
Δs_in_lane = [compute_inlane_headway(ego, veh, roadway) for veh in vehicles_to_yield_to]
# The intersection is clear of car i if it exited the intersection in the past, or
# it will enter the intersection after you have crossed it, or
@@ -426,3 +426,20 @@ function AutomotiveSimulator.get_by_id(scene::Scene, id)
scene[entity_index]
end


## Neural Network driving model
@with_kw mutable struct PolicyModel <: DriverModel{LaneFollowingAccel}
policy
state = nothing
end

# Sample an action from the wrapped policy
function Base.rand(rng::AbstractRNG, model::PolicyModel)
action(model.policy, model.state)
end

# Observe function for the PolicyModel: store the current scene as the policy's state
function AutomotiveSimulator.observe!(model::PolicyModel, input_scene::Scene, roadway::Roadway, egoid::Int64)
model.state = input_scene
end
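
The PolicyModel added above is the bridge between a policy trained on the DrivingMDP and an ordinary AutomotiveSimulator rollout: observe! records the latest scene as the model's state, and rand queries the wrapped policy for an action. A rough usage sketch follows, assuming AutomotiveSimulator's simulate(scene, roadway, models, nticks, timestep) method and that policy, mdp, Tint_roadway and Tint_TIDM_template come from the training example; the vehicle ids follow that example (1 for the SUT, 3 for the adversary) and the 50-tick horizon is an arbitrary choice.

# Sketch: drive the SUT with the learned policy while the adversary keeps its rule-based model
models = Dict(1 => PolicyModel(policy = policy),   # SUT (id 1) driven by the learned policy
              3 => TIDM(Tint_TIDM_template))       # adversary (id 3) keeps the TIDM driver model
scene0 = initialstate(mdp)                         # initial Scene from the training MDP
scenes = AutomotiveSimulator.simulate(scene0, Tint_roadway, models, 50, 0.2)  # 50 ticks at the 0.2 s timestep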

86 changes: 77 additions & 9 deletions src/mdp.jl
@@ -74,12 +74,12 @@ function AdversarialDrivingMDP(sut::Agent, adversaries::Vector{Agent}, road::Roa
end

# Returns the initial state of the MDP simulator
function POMDPs.initialstate(mdp::AdversarialDrivingMDP, rng::AbstractRNG = Random.GLOBAL_RNG)
function POMDPs.initialstate(mdp::MDP{Scene, A}, rng::AbstractRNG = Random.GLOBAL_RNG) where A
Scene([a.get_initial_entity(rng) for a in agents(mdp)])
end

# The generative interface to the POMDP
function POMDPs.gen(mdp::AdversarialDrivingMDP, s::Scene, a::Vector{Disturbance}, rng::Random.AbstractRNG = Random.GLOBAL_RNG)
function POMDPs.gen(mdp::MDP{Scene, A}, s::Scene, a::A, rng::Random.AbstractRNG = Random.GLOBAL_RNG) where A
mdp.last_observation = convert_s(AbstractArray, s, mdp)
sp = step_scene(mdp, s, a, rng)
r = reward(mdp, s, a, sp)
@@ -100,10 +100,10 @@ function POMDPs.reward(mdp::AdversarialDrivingMDP, s::Scene, a::Vector{Disturban
end

# Discount factor for the MDP (the adversarial MDP uses γ = 1 for its finite horizon)
POMDPs.discount(mdp::AdversarialDrivingMDP) = mdp.γ
POMDPs.discount(mdp::MDP{Scene, A}) where A = mdp.γ

# The simulation is terminal if any vehicles collide or if the SUT has been removed from the scene (e.g. it passed the end of the road)
POMDPs.isterminal(mdp::AdversarialDrivingMDP, s::Scene) = !(sutid(mdp) in s)|| any_collides(s)
POMDPs.isterminal(mdp::MDP{Scene, A}, s::Scene) where A = !(sutid(mdp) in s)|| any_collides(s)

# Define the set of actions, action index and probability
POMDPs.actions(mdp::AdversarialDrivingMDP) = get_actions(mdp.disturbance_model)
@@ -142,19 +142,19 @@ function step_scene(mdp::AdversarialDrivingMDP, s::Scene, actions::Vector{Distur
end

# Returns the list of agents in the mdp
agents(mdp::AdversarialDrivingMDP) = mdp.agents
agents(mdp::MDP{Scene,A}) where A = mdp.agents

# Returns the list of adversaries in the mdp
adversaries(mdp::AdversarialDrivingMDP) = view(mdp.agents, 1:mdp.num_adversaries)
adversaries(mdp::MDP{Scene,A}) where A = view(mdp.agents, 1:mdp.num_adversaries)

# Returns the model associated with the vehid
model(mdp::AdversarialDrivingMDP, vehid::Int) = mdp.agents[mdp.vehid2ind[vehid]].model
model(mdp::MDP{Scene,A}, vehid::Int) where A = mdp.agents[mdp.vehid2ind[vehid]].model

# Returns the system under test
sut(mdp::AdversarialDrivingMDP) = mdp.agents[mdp.num_adversaries + 1]
sut(mdp::MDP{Scene,A}) where A = mdp.agents[mdp.num_adversaries + 1]

# Returns the sut id
sutid(mdp::AdversarialDrivingMDP) = id(sut(mdp))
sutid(mdp::MDP{Scene,A}) where A = id(sut(mdp))

function update_adversary!(adversary::Agent, action::Disturbance, s::Scene)
index = findfirst(id(adversary), s)
@@ -165,3 +165,71 @@ function update_adversary!(adversary::Agent, action::Disturbance, s::Scene)
s[index] = Entity(state_type(veh.state, noise = action.noise), veh.def, veh.id) # replace the entity in the scene
end


## SUT driving MDP
mutable struct DrivingMDP <: MDP{Scene, BlinkerVehicleControl}
agents::Vector{Agent} # All the agents ordered by (adversaries..., sut, others...)
vehid2ind::Dict{Int64, Int64} # Dictionary that maps vehid to index in agent list
num_adversaries::Int64 # The number of adversaries
roadway::Roadway # The roadway for the simulation
dt::Float64 # Simulation timestep
last_observation::Array{Float64} # Last observation of the vehicle state
γ::Float64 # discount
end_of_road::Float64 # Road position at which the simulation is cut off early
per_timestep_penalty::Float64 # Penalty applied at every timestep
v_des::Float64 # Desired velocity of the SUT
end

# Constructor
function DrivingMDP(sut::Agent, adversaries::Vector{Agent}, road::Roadway, dt::Float64; γ = 1, end_of_road = Inf, per_timestep_penalty = 0, v_des = 25)
agents = [adversaries..., sut]
d = Dict(id(agents[i]) => i for i=1:length(agents))
DrivingMDP(agents, d, length(adversaries), road, dt, Float64[], γ, end_of_road, per_timestep_penalty, v_des)
end

# Get the reward from the actions taken and the next state
function POMDPs.reward(mdp::DrivingMDP, s::Scene, a::BlinkerVehicleControl, sp::Scene)
id = sutid(mdp)
v = vel(get_by_id(s, id))

r = -abs(mdp.per_timestep_penalty)
# Penalize a collision involving the SUT; reward reaching any other terminal state (e.g. the end of the road)
if ego_collides(id, sp)
r += -1
elseif isterminal(mdp, sp)
r += 1
end
r += -.1 * (v < 0)
r += -0.001 * abs(v - mdp.v_des)
r
end
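
# Worked example with illustrative values (not part of the diff): a non-terminal step with
# per_timestep_penalty = 1e-3, v = 20 m/s and v_des = 25 m/s gives
# r = -0.001 - 0.001 * |20 - 25| = -0.006; a collision involving the SUT adds -1,
# and any other terminal step adds +1.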

function step_scene(mdp::DrivingMDP, s::Scene, action::BlinkerVehicleControl, rng::AbstractRNG = Random.GLOBAL_RNG)
entities = []
sid = sutid(mdp)

# Choose random actions for the adversaries
for adversary in adversaries(mdp)
adv_action = rand(rng, adversary.disturbance_model, adversary.vec_to_disturbance)
update_adversary!(adversary, adv_action, s)
end

# Loop through the vehicles in the scene, apply action and add to next scene
for (i, veh) in enumerate(s)
if veh.id == sid # for the sut, use the prescribed action
a = action
else # For the other vehicles use their model
m = model(mdp, veh.id)
observe!(m, s, mdp.roadway, veh.id)
a = rand(rng, m)
end
bv = Entity(propagate(veh, a, mdp.roadway, mdp.dt), veh.def, veh.id)
!end_of_road(bv, mdp.roadway, mdp.end_of_road) && push!(entities, bv)
end
isempty(entities) ? Scene(typeof(sut(mdp).get_initial_entity())) : Scene([entities...])
end

# Define the set of actions and the action index
POMDPs.actions(mdp::DrivingMDP) = [BlinkerVehicleControl(a = -4.), BlinkerVehicleControl(a= -2.), BlinkerVehicleControl(a = 0.), BlinkerVehicleControl(a = 1.5), BlinkerVehicleControl(a = 3.)]
POMDPs.actionindex(mdp::DrivingMDP, a::BlinkerVehicleControl) = findfirst([a] .== actions(mdp))

10 changes: 5 additions & 5 deletions src/states.jl
@@ -1,6 +1,6 @@
## Convert_s functions for the mdp
# Converts from vector to a Scene
function POMDPs.convert_s(::Type{Scene}, s::AbstractArray{Float64}, mdp::AdversarialDrivingMDP)
function POMDPs.convert_s(::Type{Scene}, s::AbstractArray{Float64}, mdp::MDP{Scene, A}) where A
entities = []

# Loop through all the agents of the mdp
@@ -14,7 +14,7 @@ function POMDPs.convert_s(::Type{Scene}, s::AbstractArray{Float64}, mdp::Adversa
end

# Convert from Scene to a vector
function POMDPs.convert_s(::Type{AbstractArray}, state::Scene, mdp::AdversarialDrivingMDP)
function POMDPs.convert_s(::Type{AbstractArray}, state::Scene, mdp::MDP{Scene, A}) where A
isempty(mdp.last_observation) && (mdp.last_observation = zeros(sum([a.entity_dim for a in agents(mdp)])))
index = 1
for a in agents(mdp)
@@ -26,9 +26,9 @@ function POMDPs.convert_s(::Type{AbstractArray}, state::Scene, mdp::AdversarialD
copy(mdp.last_observation)
end

POMDPs.convert_s(::Type{Array{Float64, 1}}, state::Scene, mdp::AdversarialDrivingMDP) = convert_s(AbstractArray, state, mdp)
POMDPs.convert_s(::Type{Array{Float32, 1}}, state::Scene, mdp::AdversarialDrivingMDP) = convert(Array{Float32,1}, convert_s(AbstractArray, state, mdp))
POMDPs.convert_s(::Type{Array{Float32}}, state::Scene, mdp::AdversarialDrivingMDP) = convert_s(Array{Float32,1}, state, mdp)
POMDPs.convert_s(::Type{Array{Float64, 1}}, state::Scene, mdp::MDP{Scene, A}) where A = convert_s(AbstractArray, state, mdp)
POMDPs.convert_s(::Type{Array{Float32, 1}}, state::Scene, mdp::MDP{Scene, A}) where A = convert(Array{Float32,1}, convert_s(AbstractArray, state, mdp))
POMDPs.convert_s(::Type{Array{Float32}}, state::Scene, mdp::MDP{Scene, A}) where A = convert_s(Array{Float32,1}, state, mdp)


## Adversarial Pedestrians vehicles
