diff --git a/Project.toml b/Project.toml index 2825964..fa27222 100644 --- a/Project.toml +++ b/Project.toml @@ -11,7 +11,6 @@ Parameters = "d96e819e-fc66-5662-9728-84c9c7592b0a" Printf = "de0858da-6303-5e67-8744-51eddeeeb8d7" Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c" Reexport = "189a3867-3050-52da-a836-e630ba90ab69" -SparseArrays = "2f01184e-e22b-5df5-ae63-d93ebab69eaf" StaticArrays = "90137ffa-7385-5640-81b9-e52037218182" Tricks = "410a4b4d-49e4-4fbc-ab6d-cb71b17b3775" diff --git a/src/AutomotiveDrivingModels.jl b/src/AutomotiveDrivingModels.jl index d5de315..02b8808 100644 --- a/src/AutomotiveDrivingModels.jl +++ b/src/AutomotiveDrivingModels.jl @@ -7,72 +7,14 @@ using StaticArrays using Distributions using Reexport using Random -using SparseArrays using DataFrames using Tricks: static_hasmethod include("vec/Vec.jl") @reexport using .Vec -# Records -export - Entity, - Frame, - EntityFrame, - RecordFrame, - RecordState, - ListRecord, - QueueRecord, - EntityQueueRecord, - - ListRecordFrameIterator, - ListRecordStateByIdIterator, - - get_statetype, - get_deftype, - get_idtype, - - capacity, - nframes, - nstates, - nids, - frame_inbounds, - pastframe_inbounds, - n_objects_in_frame, - id2index, - get_ids, - nth_id, - get_state, - get_def, - get_time, - get_timestep, - get_elapsed_time, - get_subinterval, - get_by_id, - findfirst_stateindex_with_id, - findfirst_frame_with_id, - findlast_frame_with_id, - get_first_available_id, - push_back_records!, - update!, - allocate_frame, - get_sparse_lookup - -include("records/common.jl") -include("records/entities.jl") -include("records/frames.jl") -include("records/listrecords.jl") -include("records/queuerecords.jl") -include("records/conversions.jl") - # Roadways -export StraightRoadway, - mod_position_to_roadway, - get_headway - -include("roadways/straight_1d_roadways.jl") - export CurvePt, Curve, CurveIndex, @@ -158,20 +100,27 @@ export include("agent-definitions/agent_definitions.jl") export + Entity, + Frame, + EntityFrame, + capacity, + id2index, + get_by_id, + get_first_available_id, posf, posg, vel, velf, velg, VehicleState, - get_vel_s, - get_vel_t, get_center, get_footpoint, get_front, get_rear, get_lane +include("states/entities.jl") +include("states/frames.jl") include("states/interface.jl") include("states/vehicle_state.jl") @@ -285,6 +234,7 @@ include("feature-extraction/lidar_sensor.jl") export propagate, + EntityAction, LaneFollowingAccel, AccelTurnrate, AccelDesang, @@ -303,12 +253,11 @@ include("actions/pedestrian_lat_lon_accel.jl") export DriverModel, StaticDriver, - get_name, action_type, set_desired_speed!, observe!, reset_hidden_state!, - prime_with_history! + reset_hidden_states! include("behaviors/interface.jl") @@ -346,24 +295,16 @@ include("behaviors/tim_2d_driver.jl") include("behaviors/sidewalk_pedestrian_model.jl") export - get_actions!, - tick!, - reset_hidden_states!, simulate, simulate!, - EntityAction, run_callback, - CollisionCallback + CollisionCallback, + observe_from_history!, + simulate_from_history!, + simulate_from_history include("simulation/simulation.jl") include("simulation/callbacks.jl") - - -export - State1D, - Vehicle1D, - Scene1D - -include("deprecated.jl") +include("simulation/simulation_from_history.jl") end # AutomotiveDrivingModels diff --git a/src/agent-definitions/agent_definitions.jl b/src/agent-definitions/agent_definitions.jl index 725baaa..3523386 100644 --- a/src/agent-definitions/agent_definitions.jl +++ b/src/agent-definitions/agent_definitions.jl @@ -76,16 +76,6 @@ function Base.read(io::IO, ::MIME"text/plain", ::Type{VehicleDef}) return VehicleDef(class, length, width) end -""" - Base.convert(::Type{Entity{S, VehicleDef, I}}, veh::Entity{S, D, I}) where {S,D<:AbstractAgentDefinition,I} - -Converts the definition of an entity -""" -function Base.convert(::Type{Entity{S, VehicleDef, I}}, veh::Entity{S, D, I}) where {S,D<:AbstractAgentDefinition,I} - vehdef = VehicleDef(class(veh.def), length(veh.def), width(veh.def)) - return Entity{S, VehicleDef, I}(veh.state, vehdef, veh.id) -end - """ BicycleModel BicycleModel(def::VehicleDef; a::Float64 = 1.5, b::Float64 = 1.5) diff --git a/src/behaviors/MOBIL.jl b/src/behaviors/MOBIL.jl index e3bf2d7..0287275 100644 --- a/src/behaviors/MOBIL.jl +++ b/src/behaviors/MOBIL.jl @@ -30,11 +30,6 @@ function MOBIL( return MOBIL(DIR_MIDDLE, mlon, safe_decel, politeness, advantage_threshold) end -""" -Return the name of the lane changing model -""" -get_name(::MOBIL) = "MOBIL" - """ Set the desired speed of the longitudinal model within MOBIL """ diff --git a/src/behaviors/intelligent_driver_model.jl b/src/behaviors/intelligent_driver_model.jl index 335c039..69fb7c2 100644 --- a/src/behaviors/intelligent_driver_model.jl +++ b/src/behaviors/intelligent_driver_model.jl @@ -34,7 +34,7 @@ around the non-errorable IDM output. d_cmf::Float64 = 2.0 # comfortable deceleration [m/s²] (positive) d_max::Float64 = 9.0 # maximum deceleration [m/s²] (positive) end -get_name(::IntelligentDriverModel) = "IDM" + function set_desired_speed!(model::IntelligentDriverModel, v_des::Float64) model.v_des = v_des model diff --git a/src/behaviors/interface.jl b/src/behaviors/interface.jl index 2253e22..8cdfe75 100644 --- a/src/behaviors/interface.jl +++ b/src/behaviors/interface.jl @@ -9,12 +9,6 @@ The DriverModel type is an abstract type! Custom driver models should inherit fr """ abstract type DriverModel{DriveAction} end -""" - get_name(::DriverModel) -returns the name of the driver model -""" -function get_name end - """ action_type(::DriverModel{A}) where {A} returns the type of the actions that are sampled from the model @@ -34,6 +28,17 @@ Resets the hidden states of the model. """ function reset_hidden_state! end +""" + reset_hidden_states!(models::Dict{I,M}) where {M<:DriverModel} +reset hidden states of all driver models in `models` +""" +function reset_hidden_states!(models::Dict{I,M}) where {I, M<:DriverModel} + for model in values(models) + reset_hidden_state!(model) + end + return models +end + """ observe!(model::DriverModel, scene, roadway, egoid) Observes the scene and updates the model states accordingly. @@ -48,40 +53,6 @@ Samples an action from the model. Base.rand(model::DriverModel) = rand(Random.GLOBAL_RNG, model) Base.rand(rng::AbstractRNG, model::DriverModel) = error("AutomotiveDrivingModelsError: Base.rand(::AbstractRNG, ::$(typeof(model))) not implemented") -function prime_with_history!( - model::DriverModel, - trajdata::ListRecord{S,D,I}, - roadway::R, - frame_start::Int, - frame_end::Int, - egoid::I, - scene::EntityFrame{S,D,I} = allocate_frame(trajdata), - ) where {S,D,I,R} - - reset_hidden_state!(model) - - for frame in frame_start : frame_end - get!(scene, trajdata, frame) - observe!(model, scene, roadway, egoid) - end - - return model -end -function prime_with_history!(model::DriverModel, rec::EntityQueueRecord{S,D,I}, roadway::R, egoid::I; - pastframe_start::Int=1-nframes(rec), - pastframe_end::Int=0, - ) where {S,D,I,R} - - reset_hidden_state!(model) - - for pastframe in pastframe_start : pastframe_end - scene = rec[pastframe] - observe!(model, scene, roadway, egoid) - end - - model -end - #### """ @@ -97,7 +68,6 @@ struct StaticDriver{A,P<:ContinuousMultivariateDistribution} <: DriverModel{A} distribution::P end -get_name(::StaticDriver) = "StaticDriver" function Base.rand(rng::AbstractRNG, model::StaticDriver{A,P}) where {A,P} a = rand(rng, model.distribution) return convert(A, a) diff --git a/src/behaviors/lane_following_drivers.jl b/src/behaviors/lane_following_drivers.jl index 978787b..f8d5410 100644 --- a/src/behaviors/lane_following_drivers.jl +++ b/src/behaviors/lane_following_drivers.jl @@ -26,7 +26,6 @@ mutable struct StaticLaneFollowingDriver <: LaneFollowingDriver end StaticLaneFollowingDriver() = StaticLaneFollowingDriver(LaneFollowingAccel(0.0)) StaticLaneFollowingDriver(a::Float64) = StaticLaneFollowingDriver(LaneFollowingAccel(a)) -get_name(::StaticLaneFollowingDriver) = "ProportionalSpeedTracker" Base.rand(rng::AbstractRNG, model::StaticLaneFollowingDriver) = model.a Distributions.pdf(model::StaticLaneFollowingDriver, a::LaneFollowingAccel) = isapprox(a.a, model.a.a) ? Inf : 0.0 Distributions.logpdf(model::StaticLaneFollowingDriver, a::LaneFollowingAccel) = isapprox(a.a, model.a.a) ? Inf : -Inf diff --git a/src/behaviors/lat_lon_separable_driver.jl b/src/behaviors/lat_lon_separable_driver.jl index 777e99a..5d60568 100644 --- a/src/behaviors/lat_lon_separable_driver.jl +++ b/src/behaviors/lat_lon_separable_driver.jl @@ -5,7 +5,6 @@ end LatLonSeparableDriver(mlat::LateralDriverModel, mlon::LaneFollowingDriver) = LatLonSeparableDriver{LatLonAccel}(mlat, mlon) -get_name(model::LatLonSeparableDriver) = @sprintf("%s + %s", get_name(model.mlat), get_name(model.mlon)) function set_desired_speed!(model::LatLonSeparableDriver, v_des::Float64) set_desired_speed!(model.mlon, v_des) model diff --git a/src/behaviors/lateral_driver_models.jl b/src/behaviors/lateral_driver_models.jl index e3d9148..21b89c7 100644 --- a/src/behaviors/lateral_driver_models.jl +++ b/src/behaviors/lateral_driver_models.jl @@ -22,8 +22,6 @@ A controller that executes the lane change decision made by the `lane change mod kd::Float64 = 2.0 # derivative constant for lane tracking end -get_name(::ProportionalLaneTracker) = "ProportionalLaneTracker" - function track_lateral!(model::ProportionalLaneTracker, laneoffset::Float64, lateral_speed::Float64) model.a = -laneoffset*model.kp - lateral_speed*model.kd model diff --git a/src/behaviors/princeton_driver.jl b/src/behaviors/princeton_driver.jl index 8cf7fea..9986afb 100644 --- a/src/behaviors/princeton_driver.jl +++ b/src/behaviors/princeton_driver.jl @@ -14,7 +14,7 @@ A lane following driver model that controls longitudinal speed by following a fr k::Float64 = 1.0 # proportional constant for speed tracking [s⁻¹] v_des::Float64 = 29.0 # desired speed [m/s] end -get_name(::PrincetonDriver) = "PrincetonDriver" + function set_desired_speed!(model::PrincetonDriver, v_des::Float64) model.v_des = v_des model diff --git a/src/behaviors/sidewalk_pedestrian_model.jl b/src/behaviors/sidewalk_pedestrian_model.jl index 21a0a00..44c6536 100644 --- a/src/behaviors/sidewalk_pedestrian_model.jl +++ b/src/behaviors/sidewalk_pedestrian_model.jl @@ -65,7 +65,6 @@ Walks along the sidewalk until approaching the crosswalk. Waits for the cars to phases = Int[] end -AutomotiveDrivingModels.get_name(model::SidewalkPedestrianModel) = "SidewalkPedestrianModel" Base.rand(rng::AbstractRNG, model::SidewalkPedestrianModel) = model.a function AutomotiveDrivingModels.observe!(model::SidewalkPedestrianModel, scene::Frame{Entity{VehicleState, D, I}}, roadway::Roadway, egoid::I) where {D, I} diff --git a/src/behaviors/speed_trackers.jl b/src/behaviors/speed_trackers.jl index f1e2371..6f5e92c 100644 --- a/src/behaviors/speed_trackers.jl +++ b/src/behaviors/speed_trackers.jl @@ -14,7 +14,7 @@ Longitudinal proportional speed control. k::Float64 = 1.0# proportional constant for speed tracking [s⁻¹] v_des::Float64 = 29.0 # desired speed [m/s] end -get_name(::ProportionalSpeedTracker) = "ProportionalSpeedTracker" + function set_desired_speed!(model::ProportionalSpeedTracker, v_des::Float64) model.v_des = v_des model diff --git a/src/behaviors/tim_2d_driver.jl b/src/behaviors/tim_2d_driver.jl index 361e447..324efc9 100644 --- a/src/behaviors/tim_2d_driver.jl +++ b/src/behaviors/tim_2d_driver.jl @@ -16,8 +16,6 @@ Driver that combines longitudinal driver and lateral driver into one model. mlane::LaneChangeModel = TimLaneChanger() end -get_name(::Tim2DDriver) = "Tim2DDriver" - function set_desired_speed!(model::Tim2DDriver, v_des::Float64) set_desired_speed!(model.mlon, v_des) set_desired_speed!(model.mlane, v_des) diff --git a/src/behaviors/tim_lane_changer.jl b/src/behaviors/tim_lane_changer.jl index 7feb371..b3a4488 100644 --- a/src/behaviors/tim_lane_changer.jl +++ b/src/behaviors/tim_lane_changer.jl @@ -8,11 +8,10 @@ Has not been published anywhere, so first use in a paper would have to describe See MOBIL if you want a lane changer you can cite. # Constructors - TimLaneChanger(timestep::Float64;v_des::Float64=29.0,rec::QueueRecord=QueueRecord(2,timestep),threshold_fore::Float64 = 50.0,threshold_lane_change_gap_fore::Float64 = 10.0, threshold_lane_change_gap_rear::Float64 = 10.0,dir::Int=DIR_MIDDLE) + TimLaneChanger(v_des::Float64=29.0, threshold_fore::Float64 = 50.0,threshold_lane_change_gap_fore::Float64 = 10.0, threshold_lane_change_gap_rear::Float64 = 10.0,dir::Int=DIR_MIDDLE) # Fields - `dir::Int = DIR_MIDDLE` the desired lane to go to eg: left,middle (i.e. stay in same lane) or right -- `rec::QueueRecord` TODO - `v_des::Float64 = 29.0` desired velocity - `threshold_fore::Float64 = 50.0` Distance from lead vehicle - `threshold_lane_change_gap_fore::Float64 = 10.0` Space in front @@ -26,8 +25,6 @@ See MOBIL if you want a lane changer you can cite. threshold_lane_change_gap_rear::Float64 = 10.0 end -get_name(::TimLaneChanger) = "TimLaneChanger" - function set_desired_speed!(model::TimLaneChanger, v_des::Float64) model.v_des = v_des model diff --git a/src/deprecated.jl b/src/deprecated.jl deleted file mode 100644 index 3ae4dd0..0000000 --- a/src/deprecated.jl +++ /dev/null @@ -1,194 +0,0 @@ -@deprecate get_vel_s velf(state).s -@deprecate get_vel_t velf(state).t -@deprecate get_name x->@sprintf("%s", typeof(x)) - -@deprecate propagate(veh, state, roadway, Δt) propagate(veh, state, a, roadway, Δt) - -function simulate!( - scene::Frame{E}, - roadway::R, - models::Dict{I,M}, - nticks::Int64, - timestep::Float64; - rng::AbstractRNG = Random.GLOBAL_RNG, - callbacks = nothing -) where {E<:Entity,A,R,I,M<:DriverModel} - Base.depwarn( -"`simulate!` without specifying a pre-allocated data structure for `scenes` is now deprecated.\n -You probably want to use the `simulate` function instead.\n -Alternatively, you can provide a pre-allocated data structure via the `scenes=` keyword", - :simulate_no_prealloc - ) - return simulate(scene, roadway, models, nticks, timestep, rng=rng, callbacks=callbacks) -end - -""" - # TODO: this should be removed, but where to document the function then? - run_callback(callback::Any, scenes::Vector{F}, roadway::R, models::Dict{I,M}, tick::Int) where {F,I,R,M<:DriverModel} - run_callback(callback::Any, rec::EntityQueueRecord{S,D,I}, roadway::R, models::Dict{I,M}, tick::Int) where {S,D,I,R,M<:DriverModel} -run callback and return whether simlation should terminate -A new method should be implemented when defining a new callback object. -""" -function run_callback(callback, scenes, actions::Union{Nothing, Vector{Frame{A}}}, roadway, models, tick) where {A<:EntityAction} - Base.depwarn( -"Using a deprecated version of `run_callback`. Since v0.7.10, user-defined callback functions should also take an `actions` argument. - If you have implemented `run_callback` with an actions argument, make sure the method signature is more specific than this one.\n -Check the section `Simulation > Callbacks` in the package documentation for more information - on how to define callback functions.\n -Your function call is being forwarded to `run_callback` without actions argument.", - :run_callback_no_actions - ) - return run_callback(callback, scenes, roadway, models, tick) -end - -# 1D state - -state_1d_dep_msgs = """ - State1D, Vehicle1D, and Scene1D are deprecated, use VehicleState, Vehicle, and Scene instead. -""" - -""" - State1D - -A data type to represent one dimensional states - -# Fields - - `s::Float64` position - - `v::Float64` speed [m/s] -""" -struct State1D - s::Float64 # position - v::Float64 # speed [m/s] -end - -function vel(s::State1D) - Base.depwarn(state_1d_dep_msgs, :vel) - return s.v -end - -posf(s::State1D) = VecSE2{Float64}(s.s) -posg(s::State1D) = VecSE2{Float64}(s.s, 0., 0.) # TODO: how to derive global position & angle -velf(s::State1D) = (s=s.v, t=0.) -velg(s::State1D) = (x=s.v*cos(posg(s).θ), y=s.v*sin(posg(s).θ)) - -function Base.write(io::IO, ::MIME"text/plain", s::State1D) - Base.depwarn(state_1d_dep_msgs, :write) - @printf(io, "%.16e %.16e", s.s, s.v) -end -function Base.read(io::IO, ::MIME"text/plain", ::Type{State1D}) - Base.depwarn(state_1d_dep_msgs, :read) - i = 0 - tokens = split(strip(readline(io)), ' ') - s = parse(Float64, tokens[i+=1]) - v = parse(Float64, tokens[i+=1]) - return State1D(s,v) -end - -""" - Vehicle1D -A specific instance of the Entity type defined in Records.jl to represent vehicles in 1d environments. -""" -const Vehicle1D = Entity{State1D, VehicleDef, Int64} - -""" - Scene1D - -A specific instance of the Frame type defined in Records.jl to represent a list of vehicles in 1d environments. - -# constructors - Scene1D(n::Int=100) - Scene1D(arr::Vector{Vehicle1D}) -""" -const Scene1D = Frame{Vehicle1D} - -function Scene1D(n::Int=100) - Base.depwarn(state_1d_dep_msgs, :Scene1D) - Frame(Vehicle1D, n) -end - -function Scene1D(arr::Vector{Vehicle1D}) - Base.depwarn(state_1d_dep_msgs, :Scene1D) - Frame{Vehicle1D}(arr, length(arr)) -end - -function get_center(veh::Vehicle1D) - Base.depwarn(state_1d_dep_msgs, :get_center) - veh.state.s -end -function get_footpoint(veh::Vehicle1D) - Base.depwarn(state_1d_dep_msgs, :get_footpoint) - veh.state.s -end -function get_front(veh::Vehicle1D) - Base.depwarn(state_1d_dep_msgs, :get_front) - veh.state.s + length(veh.def)/2 -end -function get_rear(veh::Vehicle1D) - Base.depwarn(state_1d_dep_msgs, :get_rear) - veh.state.s - length(veh.def)/2 -end - - -function get_headway(veh_rear::Vehicle1D, veh_fore::Vehicle1D, roadway::StraightRoadway) - Base.depwarn(state_1d_dep_msgs, :get_headway) - return get_headway(get_front(veh_rear), get_rear(veh_fore), roadway) -end -function get_neighbor_fore(scene::Frame{Entity{State1D, D, I}}, vehicle_index::I, roadway::StraightRoadway) where {D, I} - Base.depwarn(state_1d_dep_msgs, :get_neighbor_fore) - ego = scene[vehicle_index] - best_ind = nothing - best_gap = Inf - for (i,veh) in enumerate(scene) - if i != vehicle_index - Δs = get_headway(ego, veh, roadway) - if Δs < best_gap - best_gap, best_ind = Δs, i - end - end - end - return NeighborLongitudinalResult(best_ind, best_gap) -end -function get_neighbor_rear(scene::Frame{Entity{State1D, D, I}}, vehicle_index::I, roadway::StraightRoadway) where {D, I} - Base.depwarn(state_1d_dep_msgs, :get_neighbor_rear) - ego = scene[vehicle_index] - best_ind = nothing - best_gap = Inf - for (i,veh) in enumerate(scene) - if i != vehicle_index - Δs = get_headway(veh, ego, roadway) - if Δs < best_gap - best_gap, best_ind = Δs, i - end - end - end - return NeighborLongitudinalResult(best_ind, best_gap) -end - -function propagate(veh::Vehicle1D, action::LaneFollowingAccel, roadway::StraightRoadway, Δt::Float64) - Base.depwarn(state_1d_dep_msgs, :propagate) - a = action.a - s, v = veh.state.s, veh.state.v - - s′ = s + v*Δt + a*Δt*Δt/2 - v′ = max(v + a*Δt, 0.) # no negative velocities - - s′ = mod_position_to_roadway(s′, roadway) - - return State1D(s′, v′) -end - - -function observe!(model::LaneFollowingDriver, scene::Frame{Entity{State1D, D, I}}, roadway::StraightRoadway, egoid::I) where {D, I} - Base.depwarn(state_1d_dep_msgs, :observe!) - vehicle_index = findfirst(egoid, scene) - - fore_res = get_neighbor_fore(scene, vehicle_index, roadway) - - v_ego = vel(scene[vehicle_index].state) - v_oth = vel(scene[fore_res.ind].state) - headway = fore_res.Δs - - track_longitudinal!(model, v_ego, v_oth, headway) - - return model -end diff --git a/src/records/common.jl b/src/records/common.jl deleted file mode 100644 index 8134bf3..0000000 --- a/src/records/common.jl +++ /dev/null @@ -1,11 +0,0 @@ -Base.write(io::IO, ::MIME"text/plain", ::Nothing) = nothing -function Base.read(io::IO, ::MIME"text/plain", ::Nothing) - readline(io) - return nothing -end - -Base.write(io::IO, ::MIME"text/plain", i::Integer) = print(io, i) -Base.read(io::IO, ::MIME"text/plain", ::Type{I}) where {I<:Integer} = parse(I, readline(io)) - -Base.write(io::IO, ::MIME"text/plain", r::Float64) = print(io, r) -Base.read(io::IO, ::MIME"text/plain", ::Type{F}) where {F<:AbstractFloat} = parse(F, readline(io)) \ No newline at end of file diff --git a/src/records/conversions.jl b/src/records/conversions.jl deleted file mode 100644 index b389f11..0000000 --- a/src/records/conversions.jl +++ /dev/null @@ -1,96 +0,0 @@ -""" - convert(ListRecord, qrec::QueueRecord{E}) - -Converts a QueueRecord into the corresponding ListRecord. -""" -function Base.convert(::Type{ListRecord{S,D,I}}, qrec::QueueRecord{Entity{S,D,I}}) where {S,D,I} - - frames = Array{RecordFrame}(undef, nframes(qrec)) - states = Array{RecordState{S,I}}(undef, nstates(qrec)) - defs = Dict{I, D}() - - lo = 1 - for (i,pastframe) in enumerate(1-nframes(qrec) : 0) - frame = qrec[pastframe] - - hi = lo-1 - for entity in frame - hi += 1 - defs[entity.id] = entity.def - states[hi] = RecordState{S,I}(entity.state, entity.id) - end - - frames[i] = RecordFrame(lo, hi) - lo = hi + 1 - end - - return ListRecord{S,D,I}(get_timestep(qrec), frames, states, defs) -end -Base.convert(::Type{ListRecord}, qrec::QueueRecord{Entity{S,D,I}}) where {S,D,I} = convert(ListRecord{S,D,I}, qrec) - -""" - convert(QueueRecord, lrec::ListRecord) - -Converts a ListRecord into the corresponding QueueRecord{Entity{S,D,I}}. -Note that the timesteps for a ListRecord are not necessarily constant timesteps. -""" -function Base.convert(::Type{QueueRecord{Entity{S,D,I}}}, lrec::ListRecord{S,D,I}) where {S,D,I} - - N = nframes(lrec) - M = maximum(n_objects_in_frame(lrec, i) for i in 1 : N) - retval = QueueRecord(Entity{S,D,I}, N, get_timestep(lrec), M) - - frame = Frame(Entity{S,D,I}, M) - for i in 1 : N - get!(frame, lrec, i) - update!(retval, frame) - end - - return retval -end - - -""" - get_sparse_lookup(lrec::ListRecord{S,D,I}) - -Converts a ListRecord into the corresponding sparse matrix containing the states, SparseMatrixCSC{S,Int}. -This requires I <: Integer. - -In the sparse array, each column is the state of an entity, and reach row is a frame entry. - -A tuple containing the sparse matrix and a dictionary mapping from ids to the column number are returned. -""" -function get_sparse_lookup(rec::ListRecord{S,D,I}) where {S,D,I<:Integer} - - # ids are by time of entry - # translate id to index on range of 1:n - id_lookup = Dict{I,Int}(id => index for (index, id) in - (rec.defs |> keys |> collect |> sort |> enumerate)) - - m = nframes(rec) # num rows - n = nids(rec) # num cols - n_states = length(rec.states) - - # the row/frame of each state - Is = Vector{Int}(undef, n_states) - # the column/car index (not the id, see id_lookup) - Js = similar(Is) - # the states themselves - Vs = similar(Is, S) - - # the index into I, J, and V (on [1:n_states]) - idx = 1 - for (fid, frame) in enumerate(rec.frames) - for stateid in frame.lo : frame.hi - recstate = rec.states[stateid] - Is[idx] = fid - Js[idx] = id_lookup[recstate.id] - Vs[idx] = recstate.state - - idx += 1 - end - end - - sparsemat = sparse(Is, Js, Vs, m, n) - return (sparsemat, id_lookup) -end \ No newline at end of file diff --git a/src/records/entities.jl b/src/records/entities.jl deleted file mode 100644 index 904746e..0000000 --- a/src/records/entities.jl +++ /dev/null @@ -1,8 +0,0 @@ -struct Entity{S,D,I} # state, definition, identification - state::S - def::D - id::I -end -Entity(entity::Entity{S,D,I}, s::S) where {S,D,I} = Entity(s, entity.def, entity.id) - -Base.:(==)(a::Entity{S,D,I}, b::Entity{S,D,I}) where {S,D,I} = a.state == b.state && a.def == b.def && a.id == b.id \ No newline at end of file diff --git a/src/records/listrecords.jl b/src/records/listrecords.jl deleted file mode 100644 index 33b72d3..0000000 --- a/src/records/listrecords.jl +++ /dev/null @@ -1,252 +0,0 @@ -struct RecordFrame - lo::Int - hi::Int -end -Base.length(recframe::RecordFrame) = recframe.hi - recframe.lo + 1 # number of objects in the frame -Base.write(io::IO, ::MIME"text/plain", recframe::RecordFrame) = @printf(io, "%d %d", recframe.lo, recframe.hi) -function Base.read(io::IO, ::MIME"text/plain", ::Type{RecordFrame}) - tokens = split(strip(readline(io)), ' ') - lo = parse(Int, tokens[1]) - hi = parse(Int, tokens[2]) - return RecordFrame(lo, hi) -end - -struct RecordState{S,I} - state::S - id::I -end - -mutable struct ListRecord{S,D,I} # State, Definition, Identification - timestep::Float64 - frames::Vector{RecordFrame} - states::Vector{RecordState{S,I}} - defs::Dict{I, D} -end -ListRecord(timestep::Float64, ::Type{S}, ::Type{D}, ::Type{I}=Int) where {S,D,I} = ListRecord{S,D,I}(timestep, RecordFrame[], RecordState{S}[], Dict{I,D}()) - -Base.show(io::IO, rec::ListRecord{S,D,I}) where {S,D,I} = @printf(io, "ListRecord{%s, %s, %s}(%d frames)", string(S), string(D), string(I), nframes(rec)) -function Base.write(io::IO, mime::MIME"text/plain", rec::ListRecord) - - show(io, rec) - print(io, "\n") - @printf(io, "%.16e\n", rec.timestep) - - # defs - println(io, length(rec.defs)) - for (id,def) in rec.defs - write(io, mime, id) - print(io, "\n") - write(io, mime, def) - print(io, "\n") - end - - # ids & states - println(io, length(rec.states)) - for recstate in rec.states - write(io, mime, recstate.id) - print(io, "\n") - write(io, mime, recstate.state) - print(io, "\n") - end - - # frames - println(io, nframes(rec)) - for recframe in rec.frames - write(io, mime, recframe) - print(io, "\n") - end -end -function Base.read(io::IO, mime::MIME"text/plain", ::Type{ListRecord{S,D,I}}) where {S,D,I} - readline(io) # skip first line - - timestep = parse(Float64, readline(io)) - - n = parse(Int, readline(io)) - defs = Dict{I,D}() - for i in 1 : n - id = read(io, mime, I) - defs[id] = read(io, mime, D) - end - - n = parse(Int, readline(io)) - states = Array{RecordState{S,I}}(undef, n) - for i in 1 : n - id = read(io, mime, I) - state = read(io, mime, S) - states[i] = RecordState{S,I}(state, id) - end - - n = parse(Int, readline(io)) - frames = Array{RecordFrame}(undef, n) - for i in 1 : n - frames[i] = read(io, mime, RecordFrame) - end - - return ListRecord{S,D,I}(timestep, frames, states, defs) -end - -function Base.push!(rec::ListRecord{S,D,I}, frame::EntityFrame{S,D,I}) where {S,D,I} - states = RecordState{S,I}[RecordState(e.state, e.id) for e in frame] - push!(rec.frames, RecordFrame(length(rec.states)+1, length(rec.states)+length(states))) - append!(rec.states, states) - for e in frame - if !haskey(rec.defs, e.id) - rec.defs[e.id] = e.def - else - @assert rec.defs[e.id] == e.def # not sure whether to keep this check - end - end - return rec -end - -get_statetype(rec::ListRecord{S,D,I}) where {S,D,I} = S -get_deftype(rec::ListRecord{S,D,I}) where {S,D,I} = D -get_idtype(rec::ListRecord{S,D,I}) where {S,D,I} = I - -nframes(rec::ListRecord) = length(rec.frames) -nstates(rec::ListRecord) = length(rec.states) -nids(rec::ListRecord) = length(keys(rec.defs)) - -frame_inbounds(rec::ListRecord, frame_index::Int) = 1 ≤ frame_index ≤ nframes(rec) -n_objects_in_frame(rec::ListRecord, frame_index::Int) = length(rec.frames[frame_index]) - -get_ids(rec::ListRecord) = collect(keys(rec.defs)) -nth_id(rec::ListRecord, frame_index::Int, n::Int=1) = rec.states[rec.frames[frame_index].lo + n-1].id - -get_time(rec::ListRecord, frame_index::Int) = rec.timestep * (frame_index-1) -get_timestep(rec::ListRecord) = rec.timestep -get_elapsed_time(rec::ListRecord, frame_lo::Int, frame_hi::Int) = rec.timestep * (frame_hi - frame_lo) - -function findfirst_stateindex_with_id(rec::ListRecord{S,D,I}, id::I, frame_index::Int) where {S,D,I} - recframe = rec.frames[frame_index] - for i in recframe.lo : recframe.hi - if rec.states[i].id == id - return i - end - end - return nothing -end -function findfirst_frame_with_id(rec::ListRecord{S,D,I}, id::I) where {S,D,I} - for frame in 1:length(rec.frames) - if findfirst_stateindex_with_id(rec, id, frame) != nothing - return frame - end - end - return nothing -end -function findlast_frame_with_id(rec::ListRecord{S,D,I}, id::Int) where {S,D,I} - for frame in reverse(1:length(rec.frames)) - if findfirst_stateindex_with_id(rec, id, frame) != nothing - return frame - end - end - return nothing -end - -Base.in(id::I, rec::ListRecord{S,D,I}, frame_index::Int) where {S,D,I} = findfirst_stateindex_with_id(rec, id, frame_index) != nothing -get_state(rec::ListRecord{S,D,I}, id::I, frame_index::Int) where {S,D,I} = rec.states[findfirst_stateindex_with_id(rec, id, frame_index)].state -get_def(rec::ListRecord{S,D,I}, id::I) where {S,D,I} = rec.defs[id] -Base.get(rec::ListRecord{S,D,I}, id::I, frame_index::Int) where {S,D,I} = Entity(get_state(rec, id, frame_index), get_def(rec,id), id) -function Base.get(rec::ListRecord, stateindex::Int) - recstate = rec.states[stateindex] - return Entity(recstate.state, get_def(rec, recstate.id), recstate.id) -end - -function get_subinterval(rec::ListRecord{S,D,I}, frame_index_lo::Int, frame_index_hi::Int) where {S,D,I} - frame_index_hi ≥ frame_index_lo || throw(DomainError()) - - frame_indexes = frame_index_lo : frame_index_hi - frames = Array{RecordFrame}(undef, length(frame_indexes)) - states = Array{RecordState{S,I}}(undef, rec.frames[frame_index_hi].hi - rec.frames[frame_index_lo].lo + 1) - defs = Dict{I, D}() - - hi = 1 - for (i,frame_index) in enumerate(frame_indexes) - frame = rec.frames[frame_index] - n = length(frame) - lo = hi - hi += n-1 - copyto!(states, lo, rec.states, frame.lo, n) - frames[i] = RecordFrame(lo, hi) - hi += 1 - end - - for state in states - defs[state.id] = get_def(rec, state.id) - end - - return ListRecord{S,D,I}(rec.timestep, frames, states, defs) -end -get_subinterval(rec::ListRecord, range::UnitRange{Int64}) = get_subinterval(rec, a.start, a.stop) - -################################# - -EntityFrame(rec::ListRecord{S,D,I}, N::Int=100) where {S,D,I} = Frame(Entity{S,D,I}, N) -function allocate_frame(rec::ListRecord{S,D,I}) where {S,D,I} - max_n_objects = maximum(n_objects_in_frame(rec,i) for i in 1 : nframes(rec)) - return Frame(Entity{S,D,I}, max_n_objects) -end -function Base.get!(frame::EntityFrame{S,D,I}, rec::ListRecord{S,D,I}, frame_index::Int) where {S,D,I} - - empty!(frame) - - if frame_inbounds(rec, frame_index) - recframe = rec.frames[frame_index] - for stateindex in recframe.lo : recframe.hi - push!(frame, get(rec, stateindex)) - end - end - - return frame -end - - -################################# - -""" -An iterator for looping over all states for a particular entity id. -Each element is a Tuple{Int,S} containing the frame index and the state. -""" -struct ListRecordStateByIdIterator{S,D,I} - rec::ListRecord{S,D,I} - id::I -end - -function Base.iterate(iter::ListRecordStateByIdIterator, frame_index::Int=1) - while frame_index ≤ nframes(iter.rec) && !in(iter.id, iter.rec, frame_index) - frame_index += 1 - end - if frame_index > nframes(iter.rec) - return nothing - end - item = (frame_index, get_state(iter.rec, iter.id, frame_index)) - return (item, frame_index+1) -end - -Base.length(iter::ListRecordStateByIdIterator) = sum(in(iter.id, iter.rec, i) for i in 1:nframes(iter.rec)) -Base.eltype(iter::ListRecordStateByIdIterator{S,D,I}) where {S,D,I} = Tuple{Int, S} - -################################ - -""" -An iterator for looping over all scenes. -Each element is an EntityFrame{S,D,I}. -The same frame is continuously overwritten. -As such, one should not call collect() on a frame iterator. -""" -struct ListRecordFrameIterator{S,D,I} - rec::ListRecord{S,D,I} - scene::EntityFrame{S,D,I} -end -ListRecordFrameIterator(rec::ListRecord{S,D,I}) where {S,D,I} = ListRecordFrameIterator(rec, allocate_frame(rec)) - -function Base.iterate(iter::ListRecordFrameIterator, frame_index::Int=1) - if frame_index > nframes(iter.rec) - return nothing - end - get!(iter.scene, iter.rec, frame_index) - return (iter.scene, frame_index+1) -end - -Base.length(iter::ListRecordFrameIterator) = nframes(iter.rec) -Base.eltype(iter::ListRecordFrameIterator{S,D,I}) where {S,D,I} = EntityFrame{S,D,I} \ No newline at end of file diff --git a/src/records/queuerecords.jl b/src/records/queuerecords.jl deleted file mode 100644 index d380e51..0000000 --- a/src/records/queuerecords.jl +++ /dev/null @@ -1,89 +0,0 @@ -mutable struct QueueRecord{E} - frames::Vector{Frame{E}} - timestep::Float64 - nframes::Int # number of active Frames -end -function QueueRecord(::Type{E}, capacity::Int, timestep::Float64, frame_capacity::Int=100) where {E} - frames = Array{Frame{E}}(undef, capacity) - for i in 1 : length(frames) - frames[i] = Frame(E, frame_capacity) - end - QueueRecord{E}(frames, timestep, 0) -end - -Base.show(io::IO, rec::QueueRecord) = print(io, "QueueRecord(nframes=", rec.nframes, ")") - -capacity(rec::QueueRecord) = length(rec.frames) -nframes(rec::QueueRecord) = rec.nframes -function nstates(rec::QueueRecord) - retval = 0 - for frame_index in 1 : nframes(rec) - retval += length(rec.frames[frame_index]) - end - return retval -end - -function Base.deepcopy(rec::QueueRecord) - retval = QueueRecord(capacity(rec), rec.timestep, capacity(rec.frames[1])) - for i in 1 : rec.nframes - copyto!(retval.frames[i], rec.frames[i]) - end - retval -end - -pastframe_inbounds(rec::QueueRecord, pastframe::Int) = 1 ≤ 1-pastframe ≤ rec.nframes - -""" -Indexed by pastframe, so pastframe == 0 is the current scene, -1 is the previous frame, etc. -""" -Base.getindex(rec::QueueRecord, pastframe::Int) = rec.frames[1 - pastframe] - -get_time(rec::QueueRecord, pastframe::Int) = -get_elapsed_time(rec, pastframe) -get_timestep(rec::QueueRecord) = rec.timestep -get_elapsed_time(rec::QueueRecord, pastframe::Int) = -pastframe*rec.timestep -function get_elapsed_time( - rec::QueueRecord, - pastframe_farthest_back::Int, - pastframe_most_recent::Int, - ) - - (pastframe_most_recent - pastframe_farthest_back)*rec.timestep -end - -function Base.empty!(rec::QueueRecord) - rec.nframes = 0 - return rec -end - -function push_back_records!(rec::QueueRecord) - for i in min(rec.nframes+1, capacity(rec)) : -1 : 2 - copyto!(rec.frames[i], rec.frames[i-1]) - end - return rec -end -function Base.insert!(rec::QueueRecord{E}, frame::Frame{E}, pastframe::Int=0) where {E} - copyto!(rec[pastframe], frame) - return rec -end -function Base.get!(frame::Frame{E}, rec::QueueRecord{E}, pastframe::Int=0) where {E} - copyto!(frame, rec[pastframe]) - frame -end -function update!(rec::QueueRecord{E}, frame::Frame{E}) where {E} - push_back_records!(rec) - insert!(rec, frame, 0) - rec.nframes = min(rec.nframes+1, capacity(rec)) - return rec -end - -function allocate_frame(rec::QueueRecord{E}) where {E} - max_n_objects = maximum(length(rec[j]) for j in 0 : 1-length(rec)) - return Frame(E, max_n_objects) -end - -const EntityQueueRecord{S,D,I} = QueueRecord{Entity{S,D,I}} - -Base.length(record::QueueRecord) = nframes(record) -function Base.iterate(record::QueueRecord, state::Int64=(-nframes(record)+1)) - return state<=0 ? (record[state], state+1) : nothing -end diff --git a/src/roadways/straight_1d_roadways.jl b/src/roadways/straight_1d_roadways.jl deleted file mode 100644 index cd2d636..0000000 --- a/src/roadways/straight_1d_roadways.jl +++ /dev/null @@ -1,37 +0,0 @@ -""" - StraightRoadway -A simple type representing a one lane, one dimensional straight roadway -# Fields -- `length::Float64` -""" -struct StraightRoadway - length::Float64 -end - -""" - mod_position_to_roadway(s::Float64, roadway::StraightRoadway) -performs a modulo of the position `s` with the length of `roadway` -""" -function mod_position_to_roadway(s::Float64, roadway::StraightRoadway) - while s > roadway.length - s -= roadway.length - end - while s < 0.0 - s += roadway.length - end - return s -end - -""" - get_headway(s_rear::Float64, s_fore::Float64, roadway::StraightRoadway) -returns a positive distance between s_rear and s_fore. -""" -function get_headway(s_rear::Float64, s_fore::Float64, roadway::StraightRoadway) - while s_fore < s_rear - s_fore += roadway.length - end - while s_fore > s_rear + roadway.length - s_fore -= roadway.length - end - return s_fore - s_rear # positive distance -end diff --git a/src/simulation/callbacks.jl b/src/simulation/callbacks.jl index a04ab7a..861072c 100644 --- a/src/simulation/callbacks.jl +++ b/src/simulation/callbacks.jl @@ -1,7 +1,7 @@ """ Run all callbacks """ -function _run_callbacks(callbacks::C, scenes::Union{EntityQueueRecord{S,D,I}, Vector{Frame{Entity{S,D,I}}}}, actions::Union{Nothing, Vector{Frame{A}}}, roadway::R, models::Dict{I,M}, tick::Int) where {S,D,I,A<:EntityAction,R,M<:DriverModel,C<:Tuple{Vararg{Any}}} +function _run_callbacks(callbacks::C, scenes::Vector{Frame{Entity{S,D,I}}}, actions::Union{Nothing, Vector{Frame{A}}}, roadway::R, models::Dict{I,M}, tick::Int) where {S,D,I,A<:EntityAction,R,M<:DriverModel,C<:Tuple{Vararg{Any}}} isdone = false for callback in callbacks isdone |= run_callback(callback, scenes, actions, roadway, models, tick) @@ -9,54 +9,6 @@ function _run_callbacks(callbacks::C, scenes::Union{EntityQueueRecord{S,D,I}, Ve return isdone end -function simulate!( - ::Type{A}, - rec::EntityQueueRecord{S,D,I}, - scene::EntityFrame{S,D,I}, - roadway::R, - models::Dict{I,M}, - nticks::Int, - callbacks::C, - ) where {S,D,I,A,R,M<:DriverModel,C<:Tuple{Vararg{Any}}} - Base.depwarn( -"`simulate!` using `EntityQueueRecord`s is deprecated since v0.7.10 and may be removed in future versions. - You should pass a pre-allocated vector of entitites `scenes::Vector{Frame{Entity{S,D,I}}}` to `simulate!` - or use the convenience function `simulate` without pre-allocation instead.", - :simulate_rec - ) - - empty!(rec) - update!(rec, scene) - - # potential early out right off the bat - if _run_callbacks(callbacks, rec, nothing, roadway, models, 0) - return rec - end - - actions = Array{A}(undef, length(scene)) - for tick in 1 : nticks - get_actions!(actions, scene, roadway, models) - tick!(scene, roadway, actions, get_timestep(rec)) - update!(rec, scene) - if _run_callbacks(callbacks, rec, nothing, roadway, models, tick) - break - end - end - - return rec -end -function simulate!( - rec::EntityQueueRecord{S,D,I}, - scene::EntityFrame{S,D,I}, - roadway::R, - models::Dict{I,M}, - nticks::Int, - callbacks::C, - ) where {S,D,I,R,M<:DriverModel,C<:Tuple{Vararg{Any}}} - - return simulate!(Any, rec, scene, roadway, models, nticks, callbacks) -end - ## Implementations of useful callbacks """ @@ -68,17 +20,6 @@ Terminates the simulation once a collision occurs mem::CPAMemory=CPAMemory() end -function run_callback( - callback::CollisionCallback, - rec::EntityQueueRecord{S,D,I}, - roadway::R, - models::Dict{I,M}, - tick::Int - ) where {S,D,I,R,M<:DriverModel} - - return !is_collision_free(rec[0], callback.mem) -end - function run_callback( callback::CollisionCallback, scenes::Vector{Frame{E}}, diff --git a/src/simulation/simulation.jl b/src/simulation/simulation.jl index ab1d0c3..24dc08a 100644 --- a/src/simulation/simulation.jl +++ b/src/simulation/simulation.jl @@ -1,117 +1,3 @@ -""" - get_actions!(actions::Vector{A}, scene::EntityFrame{S,D,I}, roadway::R, models::Dict{I, M},) where {S,D,I,A,R,M<:DriverModel} -Fill in `actions` with the actions of each agent present in the scene. It calls `observe!` -and `rand` for each driver models. -`actions` will contain the actions to apply to update the state of each vehicle. -""" -function get_actions!( - actions::Vector{A}, - scene::EntityFrame{S,D,I}, - roadway::R, - models::Dict{I, M}, # id → model - ) where {S,D,I,A,R,M<:DriverModel} - - - for (i,veh) in enumerate(scene) - model = models[veh.id] - observe!(model, scene, roadway, veh.id) - actions[i] = rand(model) - end - - actions -end - -""" - tick!(scene::EntityFrame{S,D,I}, roadway::R, actions::Vector{A}, Δt::Float64) where {S,D,I,A,R} -update `scene` in place by updating the state of each vehicle given their current action in `actions`. -It calls the `propagate` method for each vehicle in the scene. -""" -function tick!( - scene::EntityFrame{S,D,I}, - roadway::R, - actions::Vector{A}, - Δt::Float64, - ) where {S,D,I,A,R} - - for i in 1 : length(scene) - veh = scene[i] - state′ = propagate(veh, actions[i], roadway, Δt) - scene[i] = Entity(state′, veh.def, veh.id) - end - - return scene -end - -""" - reset_hidden_states!(models::Dict{Int,M}) where {M<:DriverModel} -reset hidden states of all driver models in `models` -""" -function reset_hidden_states!(models::Dict{Int,M}) where {M<:DriverModel} - for model in values(models) - reset_hidden_state!(model) - end - return models -end - -""" -DEPRECATION WARNING: this version of `simulate!` is now deprecated. - - simulate!(scene::Frame{E}, roadway::R, models::Dict{I,M<:DriverModel}, nticks::Int64, timestep::Float64; rng::AbstractRNG = Random.GLOBAL_RNG, scenes::Vector{Frame{E}} = [Frame(E, length(scene)) for i=1:nticks+1], callbacks=nothing) - -Run `nticks` steps of simulation with time step `dt` and return a vector of scenes from time step 0 to nticks. - - simulate!(::Type{A}, rec::EntityQueueRecord{S,D,I}, scene::EntityFrame{S,D,I}, roadway::R, models::Dict{I,M<:DriverModel}, nticks::Int) - simulate!(rec::EntityQueueRecord{S,D,I}, scene::EntityFrame{S,D,I}, roadway::R, models::Dict{I,M<:DriverModel}, nticks::Int) - -Run nticks of simulation and place all nticks+1 scenes into the QueueRecord - - simulate!(::Type{A},rec::EntityQueueRecord{S,D,I}, scene::EntityFrame{S,D,I}, roadway::R, models::Dict{I,M}, nticks::Int, callbacks::C) where {S,D,I,A,R,M<:DriverModel,C<:Tuple{Vararg{Any}}} - simulate!(rec::EntityQueueRecord{S,D,I}, scene::EntityFrame{S,D,I}, roadway::R, models::Dict{I,M}, nticks::Int, callbacks::C) where {S,D,I,A,R,M<:DriverModel,C<:Tuple{Vararg{Any}}} - -Callback objects can also be passed in the simulate! function. - -""" -function simulate!( - ::Type{A}, - rec::EntityQueueRecord{S,D,I}, - scene::EntityFrame{S,D,I}, - roadway::R, - models::Dict{I,M}, - nticks::Int, - ) where {S,D,I,A,R,M<:DriverModel} - Base.depwarn( -"`simulate!` using `EntityQueueRecord`s is deprecated since v0.7.10 and may be removed in future versions. - You should pass a pre-allocated vector of entitites `scenes::Vector{Frame{Entity{S,D,I}}}` to `simulate!` - or use the convenience function `simulate` without pre-allocation instead.", - :simulate_rec - ) - - empty!(rec) - update!(rec, scene) - actions = Array{A}(undef, length(scene)) - - for tick in 1 : nticks - get_actions!(actions, scene, roadway, models) - tick!(scene, roadway, actions, rec.timestep) - update!(rec, scene) - end - - return rec -end - - -function simulate!( - rec::EntityQueueRecord{S,D,I}, - scene::EntityFrame{S,D,I}, - roadway::R, - models::Dict{I,M}, - nticks::Int - ) where {S,D,I,R,M<:DriverModel} - - return simulate!(Any, rec, scene, roadway, models, nticks) -end - - """ simulate( scene::Frame{E}, roadway::R, models::Dict{I,M}, nticks::Int64, timestep::Float64; @@ -138,7 +24,6 @@ function simulate( return scenes[1:(n+1)] end - """ simulate!( @@ -207,56 +92,3 @@ function simulate!( end return nticks end - -""" - Run a simulation and store the resulting scenes in the provided QueueRecord. -Only the ego vehicle is simulated; the other vehicles are as they were in the provided trajdata -Other vehicle states will be interpolated -""" -function simulate!( - rec::EntityQueueRecord{S,D,I}, - model::DriverModel, - egoid::I, - trajdata::ListRecord{S,D,I}, - roadway::R, - frame_start::Int, - frame_end::Int; - prime_history::Int=0, # no prime-ing - scene::EntityFrame{S,D,I} = allocate_frame(trajdata), - ) where {S,D,I,R} - - @assert(isapprox(get_timestep(rec), get_timestep(trajdata))) - - # prime with history - prime_with_history!(model, trajdata, roadway, frame_start, frame_end, egoid, scene) - - # add current frame - update!(rec, get!(scene, trajdata, frame_start)) - observe!(model, scene, roadway, egoid) - - # run simulation - frame_index = frame_start - ego_veh = get_by_id(scene, egoid) - while frame_index < frame_end - - # pull original scene - get!(scene, trajdata, frame_index) - - # propagate ego vehicle and set - ego_action = rand(model) - ego_state = propagate(ego_veh, ego_action, roadway, get_timestep(rec)) - ego_veh = Entity(ego_veh, ego_state) - scene[findfirst(ego_veh.id, scene)] = ego_veh - - # update record - update!(rec, scene) - - # observe - observe!(model, scene, roadway, ego_veh.id) - - # update time - frame_index += 1 - end - - return rec -end diff --git a/src/simulation/simulation_from_history.jl b/src/simulation/simulation_from_history.jl new file mode 100644 index 0000000..0251c17 --- /dev/null +++ b/src/simulation/simulation_from_history.jl @@ -0,0 +1,80 @@ +""" + observe_from_history!(model::DriverModel, roadway::Roadway, trajdata::Vector{<:EntityFrame}, egoid, start::Int, stop::Int) + +Given a prerecorded trajectory `trajdata`, run the observe function of a driver model for the scenes between `start` and `stop` for the vehicle of id `egoid`. +The ego vehicle does not take any actions, it just observe the scenes, +""" +function observe_from_history!( + model::DriverModel, + roadway::Roadway, + trajdata::Vector{<:EntityFrame}, + egoid, + start::Int = 1, + stop::Int = length(trajdata)) + reset_hidden_state!(model) + + for i=start:stop + observe!(model, trajdata[i], roadway, egoid) + end + + return model +end + +function maximum_entities(trajdata::Vector{<:EntityFrame}) + return maximum(capacity, trajdata) +end + +function simulate_from_history( + model::DriverModel, + roadway::Roadway, + trajdata::Vector{Frame{E}}, + egoid, + timestep::Float64, + start::Int = 1, + stop::Int = length(trajdata); + rng::AbstractRNG = Random.GLOBAL_RNG + ) where {E<:Entity} + scenes = [Frame(E, maximum_entities(trajdata)) for i=1:(stop - start + 1)] + n = simulate_from_history!(model, roadway, trajdata, egoid, timestep, + start, stop, scenes, + rng=rng) + return scenes[1:(n+1)] +end + +function simulate_from_history!( + model::DriverModel, + roadway::Roadway, + trajdata::Vector{Frame{E}}, + egoid, + timestep::Float64, + start::Int, + stop::Int, + scenes::Vector{Frame{E}}; + actions::Union{Nothing, Vector{Frame{A}}} = nothing, + rng::AbstractRNG = Random.GLOBAL_RNG + ) where {E<:Entity, A<:EntityAction} + + # run model (unsure why it is needed, it was in the old code ) + observe_from_history!(model, roadway, trajdata, egoid, start, stop) + + copyto!(scenes[1], trajdata[start]) + for tick=1:(stop - start) + + empty!(scenes[tick + 1]) + if (actions !== nothing) empty!(actions[tick]) end + + ego = get_by_id(scenes[tick], egoid) + observe!(model, scenes[tick], roadway, egoid) + a = rand(rng, model) + + ego_state_p = propagate(ego, a, roadway, timestep) + + copyto!(scenes[tick+1], trajdata[start+tick]) + egoind = findfirst(egoid, scenes[tick+1]) + scenes[tick+1][egoind] = Entity(ego_state_p, ego.def, egoid) + + if (actions !== nothing) push!(actions[tick], EntityAction(a, egoid)) end + + end + return (stop - start) +end diff --git a/src/states/entities.jl b/src/states/entities.jl new file mode 100644 index 0000000..a467df7 --- /dev/null +++ b/src/states/entities.jl @@ -0,0 +1,26 @@ +""" + Entity{S,D,I} +Immutable data structure to represent entities (vehicle, pedestrian, ...). +Entities are defined by a state, a definition, and an id. +The state of an entity usually models changing values while the definition and the id should not change. + +# Constructor + +`Entity(state, definition, id)` + +Copy constructor that keeps the definition and id but changes the state (a new object is still created): + +`Entity(entity::Entity{S,D,I}, s::S)` + +# Fields + +- `state::S` +- `def::D` +- `id::I` +""" +struct Entity{S,D,I} # state, definition, identification + state::S + def::D + id::I +end +Entity(entity::Entity{S,D,I}, s::S) where {S,D,I} = Entity(s, entity.def, entity.id) diff --git a/src/records/frames.jl b/src/states/frames.jl similarity index 77% rename from src/records/frames.jl rename to src/states/frames.jl index 06f9d4e..eda3455 100644 --- a/src/records/frames.jl +++ b/src/states/frames.jl @@ -1,3 +1,23 @@ +""" + Frame{E} + +Container to store a list of entities. +The main difference from a regular array is that its size is defined at construction and is fixed. +(`push!` is O(1)) + +# Constructors + +- `Frame(arr::AbstractVector; capacity::Int=length(arr))` +- `Frame(::Type{E}, capacity::Int=100) where {E}` + + +# Fields + +To interact with `Frame` object it is preferable to use functions rather than accessing the fields directly. + +- `entities::Vector{E}` +- `n::Int` current number of entities in the scene +""" mutable struct Frame{E} entities::Vector{E} # NOTE: I tried StaticArrays; was not faster n::Int @@ -15,7 +35,13 @@ end Base.show(io::IO, frame::Frame{E}) where {E}= @printf(io, "Frame{%s}(%d entities)", string(E), length(frame)) +""" + capacity(frame::Frame) +returns the maximum number of entities that can be put in the frame. +To get the current number of entities use `length` instead. +""" capacity(frame::Frame) = length(frame.entities) + Base.length(frame::Frame) = frame.n Base.getindex(frame::Frame, i::Int) = frame.entities[i] Base.eltype(frame::Frame{E}) where {E} = E @@ -61,12 +87,21 @@ end #### +""" + EntityFrame{S,D,I} = Frame{Entity{S,D,I}} +Alias for `Frame` when the entities in the frame are of type `Entity` + +# Constructors +- `EntityFrame(::Type{S},::Type{D},::Type{I}) where {S,D,I}` +- `EntityFrame(::Type{S},::Type{D},::Type{I},capacity::Int)` +""" const EntityFrame{S,D,I} = Frame{Entity{S,D,I}} EntityFrame(::Type{S},::Type{D},::Type{I}) where {S,D,I} = Frame(Entity{S,D,I}) -EntityFrame(::Type{S},::Type{D},::Type{I},N::Int) where {S,D,I} = Frame(Entity{S,D,I}, N) +EntityFrame(::Type{S},::Type{D},::Type{I},capacity::Int) where {S,D,I} = Frame(Entity{S,D,I}, capacity) + +Base.in(id::I, frame::EntityFrame{S,D,I}) where {S,D,I} = findfirst(id, frame) !== nothing -Base.in(id::I, frame::EntityFrame{S,D,I}) where {S,D,I} = findfirst(id, frame) != nothing function Base.findfirst(id::I, frame::EntityFrame{S,D,I}) where {S,D,I} for entity_index in 1 : frame.n entity = frame.entities[entity_index] @@ -76,6 +111,7 @@ function Base.findfirst(id::I, frame::EntityFrame{S,D,I}) where {S,D,I} end return nothing end + function id2index(frame::EntityFrame{S,D,I}, id::I) where {S,D,I} entity_index = findfirst(id, frame) if entity_index === nothing @@ -83,7 +119,13 @@ function id2index(frame::EntityFrame{S,D,I}, id::I) where {S,D,I} end return entity_index end + +""" + get_by_id(frame::EntityFrame{S,D,I}, id::I) where {S,D,I} +Retrieve the entity by its `id`. This function uses `findfirst` which is O(n). +""" get_by_id(frame::EntityFrame{S,D,I}, id::I) where {S,D,I} = frame[id2index(frame, id)] + function get_first_available_id(frame::EntityFrame{S,D,I}) where {S,D,I} ids = Set{I}(entity.id for entity in frame) id_one = one(I) diff --git a/src/states/vehicle_state.jl b/src/states/vehicle_state.jl index 2e26295..36f68ca 100644 --- a/src/states/vehicle_state.jl +++ b/src/states/vehicle_state.jl @@ -66,17 +66,6 @@ function Vec.lerp(a::VehicleState, b::VehicleState, t::Float64, roadway::Roadway VehicleState(posG, roadway, v) end -""" - get_vel_s(s::VehicleState) -returns the longitudinal velocity (along the lane) -""" -get_vel_s(s::VehicleState) = s.v * cos(s.posF.ϕ) # velocity along the lane -""" - get_vel_t(s::VehicleState) -returns the lateral velocity (⟂ to lane) -""" -get_vel_t(s::VehicleState) = s.v * sin(s.posF.ϕ) # velocity ⟂ to lane - """ move_along(vehstate::VehicleState, roadway::Roadway, Δs::Float64; ϕ₂::Float64=vehstate.posF.ϕ, t₂::Float64=vehstate.posF.t, v₂::Float64=vehstate.v) @@ -123,7 +112,6 @@ returns the position of the rear of the vehicle """ get_rear(veh::Entity{VehicleState, D, I}) where {D<:AbstractAgentDefinition, I} = veh.state.posG - polar(length(veh.def)/2, veh.state.posG.θ) - """ get_lane(roadway::Roadway, vehicle::Entity) get_lane(roadway::Roadway, vehicle::VehicleState) @@ -136,3 +124,13 @@ function get_lane(roadway::Roadway, vehicle::VehicleState) lane_tag = vehicle.posF.roadind.tag return roadway[lane_tag] end + +""" + Base.convert(::Type{Entity{S, VehicleDef, I}}, veh::Entity{S, D, I}) where {S,D<:AbstractAgentDefinition,I} + +Converts the definition of an entity +""" +function Base.convert(::Type{Entity{S, VehicleDef, I}}, veh::Entity{S, D, I}) where {S,D<:AbstractAgentDefinition,I} + vehdef = VehicleDef(class(veh.def), length(veh.def), width(veh.def)) + return Entity{S, VehicleDef, I}(veh.state, vehdef, veh.id) +end diff --git a/test/runtests.jl b/test/runtests.jl index bee6009..2034c6d 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -7,13 +7,10 @@ using Distributions include("vec-tests/vec_runtests.jl") end -@testset "Data structures" begin - include("test_records.jl") -end - @testset "AutomotiveDrivingModels" begin include("test_roadways.jl") include("test_agent_definitions.jl") + include("test_frames.jl") include("test_states.jl") include("test_collision_checkers.jl") include("test_actions.jl") diff --git a/test/test_actions.jl b/test/test_actions.jl index 7480abf..0022436 100644 --- a/test/test_actions.jl +++ b/test/test_actions.jl @@ -1,7 +1,7 @@ @testset "action interface" begin roadway = get_test_roadway() trajdata = get_test_trajdata(roadway) - veh = get(trajdata, 1, 1) + veh = trajdata[1][1] s = VehicleState() @test VehicleState() == propagate(veh, s, roadway, NaN) end @@ -9,7 +9,7 @@ end @testset "AccelTurnrate" begin roadway = get_test_roadway() trajdata = get_test_trajdata(roadway) - veh = get(trajdata, 1, 1) + veh = trajdata[1][1] a = AccelTurnrate(0.1,0.2) io = IOBuffer() show(io, a) @@ -29,7 +29,7 @@ end @testset "AccelDesang" begin roadway = get_test_roadway() trajdata = get_test_trajdata(roadway) - veh = get(trajdata, 1, 1) + veh = trajdata[1][1] a = AccelDesang(0.1,0.2) @test a == convert(AccelDesang, [0.1,0.2]) @test copyto!([NaN, NaN], AccelDesang(0.1,0.2)) == [0.1,0.2] @@ -48,7 +48,7 @@ end @testset "AccelSteeringAngle" begin roadway = get_test_roadway() trajdata = get_test_trajdata(roadway) - veh = get(trajdata, 1, 1) + veh = trajdata[1][1] a = AccelSteeringAngle(0.1,0.2) io = IOBuffer() show(io, a) @@ -74,17 +74,17 @@ end @testset "LaneFollowingAccel" begin a = LaneFollowingAccel(1.0) - roadway1d = StraightRoadway(20.0) - s1d = State1D(10.0, 10.0) - veh = Vehicle1D(s1d, VehicleDef(), 1) - vehp = propagate(veh, a, roadway1d, 1.0) - @test vehp.s == 0.5 + roadway = gen_straight_roadway(3, 100.0) + s = VehicleState(VecSE2(0.0, 0.0, 0.0), roadway, 0.0) + veh = Entity(s, VehicleDef(), 1) + vehp = propagate(veh, a, roadway, 1.0) + @test posf(vehp).s == 0.5 end @testset "LatLonAccel" begin roadway = get_test_roadway() trajdata = get_test_trajdata(roadway) - veh = get(trajdata, 1, 1) + veh = trajdata[1][1] a = LatLonAccel(0.1,0.2) io = IOBuffer() show(io, a) @@ -104,7 +104,7 @@ end @testset "Pedestrian LatLon" begin roadway = get_test_roadway() trajdata = get_test_trajdata(roadway) - veh = get(trajdata, 1, 1) + veh = trajdata[1][1] a = PedestrianLatLonAccel(0.5,1.0, roadway[LaneTag(2,1)]) Δt = 1.0 s = propagate(veh, a, roadway, Δt) diff --git a/test/test_behaviors.jl b/test/test_behaviors.jl index a6c87b1..202a254 100644 --- a/test/test_behaviors.jl +++ b/test/test_behaviors.jl @@ -8,8 +8,8 @@ struct FakeDriverModel <: DriverModel{FakeDriveAction} end model = FakeDriverModel() @test_throws MethodError reset_hidden_state!(model) - @test_throws MethodError observe!(model, Frame(), roadway, 1) - @test_throws MethodError prime_with_history!(model, trajdata, roadway, 1, 2, 1) + @test_throws MethodError observe!(model, Frame(Entity{VehicleState, VehicleDef, Int64}), roadway, 1) + @test_throws MethodError observe_from_history!(model, roadway, trajdata, 1, 2, 1) @test action_type(model) <: FakeDriveAction @test_throws MethodError set_desired_speed!(model, 0.0) @@ -35,8 +35,6 @@ end n_steps = 40 dt = 0.1 - rec = QueueRecord(typeof(veh1), n_steps, dt) - @test_deprecated simulate!(rec, scene, roadway, models, n_steps) simulate(scene, roadway, models, n_steps, dt) @test isapprox(get_by_id(scene, 2).state.v, models[2].v_des) @@ -49,11 +47,10 @@ end @test logpdf(models[1], LaneFollowingAccel(0.0)) < 0.0 n_steps = 40 dt = 0.1 - rec = QueueRecord(typeof(veh1), n_steps, dt) - @test_deprecated simulate!(rec, scene, roadway, models, n_steps) - simulate(scene, roadway, models, n_steps, dt) - prime_with_history!(IntelligentDriverModel(), rec, roadway, 2) + scenes = simulate(scene, roadway, models, n_steps, dt) + + observe_from_history!(IntelligentDriverModel(), roadway, scenes, 2) println("There should be a warning here: ") @@ -65,8 +62,6 @@ end scene = Frame([veh1, veh2]) - rec = QueueRecord(eltype(scene), n_steps, dt) - @test_deprecated simulate!(rec, scene, roadway, models, 1) simulate(scene, roadway, models, 1, dt) end @@ -141,12 +136,10 @@ end scene = Frame([veh1, veh2]) - rec = QueueRecord(typeof(veh1), n_steps, dt) - @test_deprecated simulate!(rec, scene, roadway, models, n_steps) - simulate(scene, roadway, models, n_steps, dt) + scenes = simulate(scene, roadway, models, n_steps, dt) - @test scene[1].state.posF.roadind.tag == LaneTag(1, 3) - @test scene[2].state.posF.roadind.tag == LaneTag(1, 2) + @test scenes[end][1].state.posF.roadind.tag == LaneTag(1, 3) + @test scenes[end][2].state.posF.roadind.tag == LaneTag(1, 2) end @testset "lane following" begin @@ -178,12 +171,11 @@ end n_steps = 40 dt = 0.1 - rec = QueueRecord(typeof(veh1), n_steps, dt) - @test_deprecated simulate!(rec, scene, roadway, models, n_steps) - simulate(scene, roadway, models, n_steps, dt) - @test isapprox(get_by_id(scene, 2).state.v, models[2].v_des, atol=1e-3) - @test isapprox(get_by_id(scene, 3).state.v, models[3].v_des) + scenes = simulate(scene, roadway, models, n_steps, dt) + + @test isapprox(get_by_id(scenes[end], 2).state.v, models[2].v_des, atol=1e-3) + @test isapprox(get_by_id(scenes[end], 3).state.v, models[3].v_des) # same wth noise models = Dict{Int, DriverModel}() @@ -199,13 +191,11 @@ end n_steps = 40 dt = 0.1 - rec = QueueRecord(typeof(veh1), n_steps, dt) - @test_deprecated simulate!(rec, scene, roadway, models, n_steps) - simulate(scene, roadway, models, n_steps, dt) + scenes = simulate(scene, roadway, models, n_steps, dt) - @test isapprox(get_by_id(scene, 2).state.v, models[2].v_des, atol=1.0) - @test isapprox(get_by_id(scene, 3).state.v, models[3].v_des, atol=1.0) + @test isapprox(get_by_id(scenes[end], 2).state.v, models[2].v_des, atol=1.0) + @test isapprox(get_by_id(scenes[end], 3).state.v, models[3].v_des, atol=1.0) end function generate_sidewalk_env() @@ -285,10 +275,7 @@ end ) nticks = 300 - rec = QueueRecord(typeof(car), nticks+1, timestep) - # Execute the simulation - @test_deprecated simulate!(rec, scene, roadway, models, nticks) - simulate(scene, roadway, models, nticks, timestep) + scenes = simulate(scene, roadway, models, nticks, timestep) - ped = get_by_id(rec[0], ped_id) + ped = get_by_id(scenes[end], ped_id) end diff --git a/test/test_collision_checkers.jl b/test/test_collision_checkers.jl index 1f254d5..a25224e 100644 --- a/test/test_collision_checkers.jl +++ b/test/test_collision_checkers.jl @@ -33,18 +33,18 @@ end trajdata = get_test_trajdata(roadway) scene = Frame(Entity{VehicleState, VehicleDef, Int64}) - col = get_first_collision(get!(scene, trajdata, 1)) + col = get_first_collision(trajdata[1]) @test col.is_colliding @test col.A == 1 @test col.B == 2 - @test get_first_collision(get!(scene, trajdata, 2), CPAMemory()).is_colliding == true - scene = Frame(Entity{VehicleState, VehicleDef, Int64}) - @test is_collision_free(get!(scene, trajdata, 1)) == false - @test is_collision_free(get!(scene, trajdata, 1), [1]) == false + @test get_first_collision(trajdata[1], CPAMemory()).is_colliding == true + scene = trajdata[1] + @test is_collision_free(scene) == false + @test is_collision_free(scene, [1]) == false @test is_colliding(scene[1], scene[2]) @test get_distance(scene[1], scene[2]) == 0 - @test is_collision_free(get!(scene, trajdata, 3)) + @test is_collision_free(trajdata[2]) get_distance(scene[1], scene[2]) roadway = gen_straight_roadway(2, 100.0) @@ -77,7 +77,6 @@ end roadway = get_test_roadway() trajdata = get_test_trajdata(roadway) - scene = Frame(Entity{VehicleState, VehicleDef, Int64}) - get!(scene, trajdata, 1) + scene = trajdata[1] @test collision_checker(scene, 1) end diff --git a/test/test_features.jl b/test/test_features.jl index d002e2e..5844d28 100644 --- a/test/test_features.jl +++ b/test/test_features.jl @@ -23,13 +23,11 @@ @test find_neighbor(scene, roadway, scene[1], lane=rightlane(roadway, scene[1]), rear=true) == NeighborLongitudinalResult(4,10.0) trajdata = get_test_trajdata(roadway) - scene = get!(Frame(Entity{VehicleState, VehicleDef, Int64}), trajdata, 1) + scene = trajdata[1] @test find_neighbor(scene, roadway, scene[1]) == NeighborLongitudinalResult(2, 3.0) - scene = get!(Frame(Entity{VehicleState, VehicleDef, Int64}), trajdata, 1) @test find_neighbor(scene, roadway, scene[2]) == NeighborLongitudinalResult(nothing, 250.0) - scene = get!(Frame(Entity{VehicleState, VehicleDef, Int64}), trajdata, 2) + scene = trajdata[2] @test find_neighbor(scene, roadway, scene[1]) == NeighborLongitudinalResult(2, 4.0) - scene = get!(Frame(Entity{VehicleState, VehicleDef, Int64}), trajdata, 2) @test find_neighbor(scene, roadway, scene[2]) == NeighborLongitudinalResult(nothing, 250.0) roadway = gen_stadium_roadway(1) diff --git a/test/test_frames.jl b/test/test_frames.jl new file mode 100644 index 0000000..8de1421 --- /dev/null +++ b/test/test_frames.jl @@ -0,0 +1,100 @@ +@testset "Frame" begin + @testset begin + frame = Frame([1,2,3]) + @test length(frame) == 3 + @test capacity(frame) == 3 + for i in 1 : 3 + @test frame[i] == i + end + + frame = Frame([1,2,3], capacity=5) + @test length(frame) == 3 + @test capacity(frame) == 5 + for i in 1 : 3 + @test frame[i] == i + end + + @test_throws ErrorException Frame([1,2,3], capacity=2) + + frame = Frame(Int) + @test length(frame) == 0 + @test capacity(frame) > 0 + + frame = Frame(Int, 2) + @test length(frame) == 0 + @test capacity(frame) == 2 + + frame[1] = 999 + frame[2] = 888 + + @test frame[1] == 999 + @test frame[2] == 888 + @test length(frame) == 0 # NOTE: length does not change + @test capacity(frame) == 2 + + empty!(frame) + @test length(frame) == 0 + @test capacity(frame) == 2 + + push!(frame, 999) + push!(frame, 888) + @test length(frame) == 2 + @test capacity(frame) == 2 + + @test_throws BoundsError push!(frame, 777) + + frame = Frame([999,888]) + deleteat!(frame, 1) + @test length(frame) == 1 + @test capacity(frame) == 2 + @test frame[1] == 888 + + deleteat!(frame, 1) + @test length(frame) == 0 + @test capacity(frame) == 2 + + frame = Frame([1,2,3]) + frame2 = copy(frame) + for i in 1 : 3 + @test frame[i] == frame2[i] + end + frame[1] = 999 + @test frame2[1] == 1 + end + + @testset begin + frame = EntityFrame(Int, Float64, String) + @test eltype(frame) == Entity{Int,Float64,String} + + frame = Frame([Entity(1,1,"A"),Entity(2,2,"B"),Entity(3,3,"C")]) + @test in("A", frame) + @test in("B", frame) + @test in("C", frame) + @test !in("D", frame) + @test findfirst("A", frame) == 1 + @test findfirst("B", frame) == 2 + @test findfirst("C", frame) == 3 + @test findfirst("D", frame) == nothing + @test id2index(frame, "A") == 1 + @test id2index(frame, "B") == 2 + @test id2index(frame, "C") == 3 + @test_throws BoundsError id2index(frame, "D") + + frame = Frame([Entity(1,1,"A"),Entity(2,2,"B"),Entity(3,3,"C")]) + @test get_by_id(frame, "A") == frame[1] + @test get_by_id(frame, "B") == frame[2] + @test get_by_id(frame, "C") == frame[3] + + delete!(frame, Entity(2,2,"B")) + @test frame[1] == Entity(1,1,"A") + @test frame[2] == Entity(3,3,"C") + @test length(frame) == 2 + + delete!(frame, "A") + @test frame[1] == Entity(3,3,"C") + @test length(frame) == 1 + + frame = Frame([Entity(1,1,1),Entity(2,2,2)], capacity=3) + @test get_first_available_id(frame) == 3 + end +end diff --git a/test/test_records.jl b/test/test_records.jl deleted file mode 100644 index 22699f8..0000000 --- a/test/test_records.jl +++ /dev/null @@ -1,291 +0,0 @@ -@testset "Conversions" begin - @testset "list" begin - lrec = ListRecord(1.0, Float64, Bool, Int) - append!(lrec.frames, [RecordFrame(1,2), RecordFrame(3,4)]) - append!(lrec.states, [RecordState(1.0,1),RecordState(2.0,2),RecordState(3.0,1),RecordState(4.0,3)]) - lrec.defs[1] = true - lrec.defs[2] = true - lrec.defs[3] = false - - sparsemat, id_lookup = get_sparse_lookup(lrec) - @test sparsemat[1,1] == 1.0 - @test sparsemat[2,1] == 3.0 - @test sparsemat[1,2] == 2.0 - @test sparsemat[2,2] == 0.0 - @test sparsemat[1,3] == 0.0 - @test sparsemat[2,3] == 4.0 - @test id_lookup == Dict(2=>2,3=>3,1=>1) - - qrec = convert(QueueRecord{Entity{Float64, Bool, Int}}, lrec) - @test qrec[-1][1] == Entity(1.0,true,1) - @test qrec[-1][2] == Entity(2.0,true,2) - @test qrec[ 0][1] == Entity(3.0,true,1) - @test qrec[ 0][2] == Entity(4.0,false,3) - @test nframes(qrec) == 2 - end - - @testset "queue" begin - qrec = QueueRecord(Entity{Float64, Bool, Int}, 2, 0.1) - @test nframes(qrec) == 0 - - update!(qrec, Frame([Entity(1.0,true,1), Entity(2.0,true,2)])) - update!(qrec, Frame([Entity(3.0,true,1), Entity(4.0,false,3)])) - - @test nframes(qrec) == 2 - - lrec = convert(ListRecord{Float64, Bool, Int}, qrec) - @test nframes(lrec) == 2 - @test lrec.frames == [RecordFrame(1,2), RecordFrame(3,4)] - @test lrec.states == [RecordState(1.0,1),RecordState(2.0,2),RecordState(3.0,1),RecordState(4.0,3)] - @test lrec.defs[1] == true - @test lrec.defs[2] == true - @test lrec.defs[3] == false - @test length(lrec.defs) == 3 - end -end - -@testset "Frame" begin - @testset begin - frame = Frame([1,2,3]) - @test length(frame) == 3 - @test capacity(frame) == 3 - for i in 1 : 3 - @test frame[i] == i - end - - frame = Frame([1,2,3], capacity=5) - @test length(frame) == 3 - @test capacity(frame) == 5 - for i in 1 : 3 - @test frame[i] == i - end - - @test_throws ErrorException Frame([1,2,3], capacity=2) - - frame = Frame(Int) - @test length(frame) == 0 - @test capacity(frame) > 0 - - frame = Frame(Int, 2) - @test length(frame) == 0 - @test capacity(frame) == 2 - - frame[1] = 999 - frame[2] = 888 - - @test frame[1] == 999 - @test frame[2] == 888 - @test length(frame) == 0 # NOTE: length does not change - @test capacity(frame) == 2 - - empty!(frame) - @test length(frame) == 0 - @test capacity(frame) == 2 - - push!(frame, 999) - push!(frame, 888) - @test length(frame) == 2 - @test capacity(frame) == 2 - - @test_throws BoundsError push!(frame, 777) - - frame = Frame([999,888]) - deleteat!(frame, 1) - @test length(frame) == 1 - @test capacity(frame) == 2 - @test frame[1] == 888 - - deleteat!(frame, 1) - @test length(frame) == 0 - @test capacity(frame) == 2 - - frame = Frame([1,2,3]) - frame2 = copy(frame) - for i in 1 : 3 - @test frame[i] == frame2[i] - end - frame[1] = 999 - @test frame2[1] == 1 - end - - @testset begin - frame = EntityFrame(Int, Float64, String) - @test eltype(frame) == Entity{Int,Float64,String} - - frame = Frame([Entity(1,1,"A"),Entity(2,2,"B"),Entity(3,3,"C")]) - @test in("A", frame) - @test in("B", frame) - @test in("C", frame) - @test !in("D", frame) - @test findfirst("A", frame) == 1 - @test findfirst("B", frame) == 2 - @test findfirst("C", frame) == 3 - @test findfirst("D", frame) == nothing - @test id2index(frame, "A") == 1 - @test id2index(frame, "B") == 2 - @test id2index(frame, "C") == 3 - @test_throws BoundsError id2index(frame, "D") - - frame = Frame([Entity(1,1,"A"),Entity(2,2,"B"),Entity(3,3,"C")]) - @test get_by_id(frame, "A") == frame[1] - @test get_by_id(frame, "B") == frame[2] - @test get_by_id(frame, "C") == frame[3] - - delete!(frame, Entity(2,2,"B")) - @test frame[1] == Entity(1,1,"A") - @test frame[2] == Entity(3,3,"C") - @test length(frame) == 2 - - delete!(frame, "A") - @test frame[1] == Entity(3,3,"C") - @test length(frame) == 1 - - frame = Frame([Entity(1,1,1),Entity(2,2,2)], capacity=3) - @test get_first_available_id(frame) == 3 - end -end - -@testset "IO" begin - @testset begin - lrec = ListRecord(1.0, Float64, Bool, Int) - append!(lrec.frames, [RecordFrame(1,2), RecordFrame(3,4)]) - append!(lrec.states, [RecordState(1.0,1),RecordState(2.0,2),RecordState(3.0,1),RecordState(4.0,3)]) - lrec.defs[1] = true - lrec.defs[2] = true - lrec.defs[3] = false - - file = tempname() - open(file, "w") do io - write(io, MIME"text/plain"(), lrec) - end - - lrec2 = open(file, "r") do io - read(io, MIME"text/plain"(), ListRecord{Float64, Bool, Int}) - end - @test nframes(lrec2) == 2 - @test lrec2.frames == [RecordFrame(1,2), RecordFrame(3,4)] - @test lrec2.states == [RecordState(1.0,1),RecordState(2.0,2),RecordState(3.0,1),RecordState(4.0,3)] - @test lrec2.defs[1] == true - @test lrec2.defs[2] == true - @test lrec2.defs[3] == false - @test length(lrec2.defs) == 3 - - rm(file) - end - - @testset begin - frames = [ - Frame([Entity(1.0,true,1), Entity(2.0, true,2)]), - Frame([Entity(3.0,true,1), Entity(4.0,false,3)]), - ] - - file = tempname() - open(file, "w") do io - write(io, MIME"text/plain"(), frames) - end - - frames2 = open(file, "r") do io - read(io, MIME"text/plain"(), Vector{Frame{Entity{Float64, Bool, Int}}}) - end - @test frames2[1][1] == Entity(1.0,true,1) - @test frames2[1][2] == Entity(2.0,true,2) - @test frames2[2][1] == Entity(3.0,true,1) - @test frames2[2][2] == Entity(4.0,false,3) - @test length(frames2) == 2 - - rm(file) - end -end - -@testset "ListRecord" begin - @testset begin - rec = ListRecord(1.0, Float64, Bool, Int) - @test get_statetype(rec) == Float64 - @test get_deftype(rec) == Bool - @test get_idtype(rec) == Int - - append!(rec.frames, [RecordFrame(1,2), RecordFrame(3,4)]) - append!(rec.states, [RecordState(1.0,1),RecordState(2.0,2),RecordState(3.0,1),RecordState(4.0,3)]) - rec.defs[1] = true - rec.defs[2] = true - rec.defs[3] = false - - @test nframes(rec) == 2 - @test nstates(rec) == 4 - @test nids(rec) == 3 - - @test sort!(get_ids(rec)) == [1,2,3] - @test nth_id(rec, 1, 1) == 1 - @test nth_id(rec, 1, 2) == 2 - @test nth_id(rec, 2, 1) == 1 - @test nth_id(rec, 2, 2) == 3 - - @test length(ListRecordStateByIdIterator(rec, 1)) == 2 - @test length(ListRecordStateByIdIterator(rec, 2)) == 1 - @test length(ListRecordStateByIdIterator(rec, 3)) == 1 - - @test collect(ListRecordStateByIdIterator(rec, 1)) == [(1,1.0),(2,3.0)] - @test collect(ListRecordStateByIdIterator(rec, 2)) == [(1,2.0)] - @test collect(ListRecordStateByIdIterator(rec, 3)) == [(2,4.0)] - - @test get_time(rec, 1) == 0.0 - @test get_time(rec, 2) == 1.0 - @test get_time(rec, 3) == 2.0 - @test get_timestep(rec) == 1.0 - @test get_elapsed_time(rec, 1, 10) == 9*1.0 - - @test findfirst_frame_with_id(rec, 1) == 1 - @test findfirst_frame_with_id(rec, 3) == 2 - @test findfirst_frame_with_id(rec, 4) == nothing - - @test findlast_frame_with_id(rec, 2) == 1 - @test findlast_frame_with_id(rec, 3) == 2 - @test findlast_frame_with_id(rec, 4) == nothing - - @test length(ListRecordFrameIterator(rec)) == 2 - len = 0 - for frame in ListRecordFrameIterator(rec) - len += 1 - end - @test len == 2 - - subrec = get_subinterval(rec, 2, 2) - @test nframes(subrec) == 1 - @test get(subrec, 1, 1) == Entity(3.0, true, 1) - @test get(subrec, 3, 1) == Entity(4.0, false, 3) - end - - - @testset begin - rec = ListRecord(1.0, Float64, Bool, Int) - @test nframes(rec) == 0 - - scene = EntityFrame(Float64, Bool, Int, 2) - push!(scene, Entity(1.0, true, 1)) - push!(scene, Entity(2.0, true, 2)) - push!(rec, scene) - - @test nframes(rec) == 1 - - empty!(scene) - push!(scene, Entity(3.0, true, 1)) - push!(scene, Entity(4.0, false, 3)) - push!(rec, scene) - @test nframes(rec) == 2 - - @test get(rec, 1, 1) == Entity(1.0, true, 1) - @test get(rec, 1, 2) == Entity(3.0, true, 1) - @test get(rec, 2, 1) == Entity(2.0, true, 2) - @test get(rec, 3, 2) == Entity(4.0, false, 3) - end -end - -@testset "records_iterator" begin - q = QueueRecord(Int64, 100, 1., 10) - data = [[1,2,3],[4,5,6],[7,8,9]] - for d in data - update!(q, Frame(d)) - end - @test all([data[i] == [x[1], x[2], x[3]] for (i, x) in enumerate(q)]) - @test iterate(q,1) === nothing -end diff --git a/test/test_roadways.jl b/test/test_roadways.jl index 0fe5cf7..8465956 100644 --- a/test/test_roadways.jl +++ b/test/test_roadways.jl @@ -60,25 +60,6 @@ function get_test_roadway() roadway end -@testset "1d roadway" begin - roadway = StraightRoadway(20.0) - s = 10.0 - @test mod_position_to_roadway(s, roadway) == s - s = 25.0 - @test mod_position_to_roadway(s, roadway) == 5.0 - s = 45.0 - @test mod_position_to_roadway(s, roadway) == 5.0 - s = -5.0 - @test mod_position_to_roadway(s, roadway) == 15.0 - s_rear = 10.0 - s_fore = 15.0 - @test get_headway(s_rear, s_fore, roadway) == 5.0 - s_fore = 25.0 - @test get_headway(s_rear, s_fore, roadway) == 15.0 - s_fore = 5.0 - @test get_headway(s_rear, s_fore, roadway) == 15.0 -end - @testset "Curves" begin p = lerp(CurvePt(VecSE2(0.0,0.0,0.0), 0.0), CurvePt(VecSE2(1.0,2.0,3.0), 4.0), 0.25) show(IOBuffer(), p) diff --git a/test/test_simulation.jl b/test/test_simulation.jl index dcbad51..7c4ff5f 100644 --- a/test/test_simulation.jl +++ b/test/test_simulation.jl @@ -20,9 +20,6 @@ AutomotiveDrivingModels.run_callback(callback::WithActionCallback, scenes::Vecto n_steps = 40 dt = 0.1 - rec = QueueRecord(typeof(veh1), n_steps, dt) - @test_deprecated simulate!(rec, scene, roadway, models, n_steps) - @test_deprecated simulate!(scene, roadway, models, n_steps, dt) @inferred simulate(scene, roadway, models, n_steps, dt) reset_hidden_states!(models) @@ -35,15 +32,10 @@ AutomotiveDrivingModels.run_callback(callback::WithActionCallback, scenes::Vecto scene = Frame([veh1, veh2]) - rec = QueueRecord(typeof(veh1), n_steps, dt) - @test_deprecated simulate!(rec, scene, roadway, models, 10, (CollisionCallback(),)) - scenes = @inferred simulate(scene, roadway, models, n_steps, dt, callbacks=(CollisionCallback(),)) @test length(scenes) < 10 # make sure warnings, errors and deprecations in run_callback work as expected - @test_deprecated @test_throws MethodError simulate!(rec, scene, roadway, models, 10, (NoCallback(),)) - @test_deprecated simulate(scene, roadway, models, 10, .1, callbacks=(NoActionCallback(),)) @test_nowarn simulate(scene, roadway, models, 10, .1, callbacks=(WithActionCallback(),)) # collision right from start @@ -54,9 +46,6 @@ AutomotiveDrivingModels.run_callback(callback::WithActionCallback, scenes::Vecto scene = Frame([veh1, veh2]) - rec = QueueRecord(eltype(scene), n_steps, dt) - @test_deprecated simulate!(rec, scene, roadway, models, 10, (CollisionCallback(),)) - scenes = @inferred simulate(scene, roadway, models, n_steps, dt, callbacks=(CollisionCallback(),)) @test length(scenes) == 1 end @@ -68,8 +57,8 @@ end veh_state = VehicleState(Frenet(roadway[LaneTag(1,1)], 6.0), roadway, 10.) ego = Entity(veh_state, VehicleDef(), 2) model = ProportionalSpeedTracker() - dt = get_timestep(trajdata) - rec = QueueRecord(typeof(ego), 3, dt) - simulate!(rec, model, ego.id, trajdata, roadway, 1, 2) - @test findfirst(ego.id, rec[0]) != nothing + + scenes = simulate_from_history(model, roadway, trajdata, ego.id, 0.1) + + @test findfirst(ego.id, scenes[end]) != nothing end diff --git a/test/test_states.jl b/test/test_states.jl index 380bf6b..50f819c 100644 --- a/test/test_states.jl +++ b/test/test_states.jl @@ -1,42 +1,17 @@ function get_test_trajdata(roadway::Roadway) - trajdata = ListRecord(0.1, VehicleState, VehicleDef, Int64) - - trajdata.defs[1] = VehicleDef(AgentClass.CAR, 5.0, 3.0) - trajdata.defs[2] = VehicleDef(AgentClass.CAR, 5.0, 3.0) - - push!(trajdata.states, RecordState{VehicleState, Int64}(VehicleState(VecSE2(0.0,0.0,0.0), roadway, 10.0), 1)) # car 1, frame 1 - push!(trajdata.states, RecordState{VehicleState, Int64}(VehicleState(VecSE2(3.0,0.0,0.0), roadway, 20.0), 2)) # car 2, frame 1 - push!(trajdata.states, RecordState{VehicleState, Int64}(VehicleState(VecSE2(1.0,0.0,0.0), roadway, 10.0), 1)) # car 1, frame 2 - push!(trajdata.states, RecordState{VehicleState, Int64}(VehicleState(VecSE2(5.0,0.0,0.0), roadway, 20.0), 2)) # car 2, frame 2 - - push!(trajdata.frames, RecordFrame(1,2)) - push!(trajdata.frames, RecordFrame(3,4)) - - trajdata -end - -@testset "1d state" begin - s = State1D(0.0, 0.0) - path, io = mktemp() - write(io, MIME"text/plain"(), s) - close(io) - io = open(path) - s2 = read(io, MIME"text/plain"(), State1D) - close(io) - @test s == s2 - - veh = Vehicle1D(s, VehicleDef(), 1) - scene = Scene1D() - push!(scene, veh) - scene2 = Scene1D([veh]) - @test scene[1].state.s == 0.0 - @test first(scene2.entities) == first(scene.entities) - @test scene2.n == scene.n == 1 - @test get_center(veh) == 0.0 - @test get_footpoint(veh) == 0.0 - @test get_front(veh) == veh.def.length/2 - @test get_rear(veh) == - veh.def.length/2 - + scene1 = Frame([ + Entity(VehicleState(VecSE2(0.0,0.0,0.0), roadway, 10.0), VehicleDef(), 1), + Entity(VehicleState(VecSE2(3.0,0.0,0.0), roadway, 20.0), VehicleDef(), 2) + ] + ) + scene2 = Frame([ + Entity(VehicleState(VecSE2(1.0,0.0,0.0), roadway, 10.0), VehicleDef(), 1), + Entity(VehicleState(VecSE2(5.0,0.0,0.0), roadway, 20.0), VehicleDef(), 2) + ] + ) + + trajdata = [scene1, scene2] + return trajdata end @testset "VehicleState" begin @@ -47,8 +22,6 @@ end show(IOBuffer(), s) s = VehicleState(VecSE2(0.0,0.0,0.0), Frenet(NULL_ROADINDEX, 0.0, 0.0, 0.1), 10.0) - @test isapprox(get_vel_s(s), 10.0*cos(0.1)) - @test isapprox(get_vel_t(s), 10.0*sin(0.1)) @test isapprox(velf(s).s, 10.0*cos(0.1)) @test isapprox(velf(s).t, 10.0*sin(0.1)) @test isapprox(velg(s).x, 10.0) @@ -106,25 +79,24 @@ end veh3 = convert(Entity{VehicleState, VehicleDef, Int64}, veh2) @test veh3 == veh - rec = QueueRecord(typeof(veh), 1, 0.5) - rec = QueueRecord(typeof(veh), 1, 0.5, 10) + scene = Frame([veh, veh2, veh3]) io = IOBuffer() - show(io, rec) + show(io, scene) close(io) scene = Frame(typeof(veh)) - get!(scene, trajdata, 1) + copyto!(scene, trajdata[1]) @test length(scene) == 2 for (i,veh) in enumerate(scene) - @test scene[i].state == get_state(trajdata, i, 1) - @test scene[i].def == get_def(trajdata, i) + @test scene[i].state == trajdata[1][i].state + @test scene[i].def == trajdata[1][i].def end scene2 = Frame(deepcopy(scene.entities), 2) @test length(scene2) == 2 for (i,veh) in enumerate(scene2) - @test scene2[i].state == get_state(trajdata, i, 1) - @test scene2[i].def == get_def(trajdata, i) + @test scene2[i].state == trajdata[1][i].state + @test scene2[i].def == trajdata[1][i].def end @test get_by_id(scene, 1) == scene[1] @@ -135,17 +107,17 @@ end copyto!(scene2, scene) @test length(scene2) == 2 for (i,veh) in enumerate(scene2) - @test scene2[i].state == get_state(trajdata, i, 1) - @test scene2[i].def == get_def(trajdata, i) + @test scene2[i].state == trajdata[1][i].state + @test scene2[i].def == trajdata[1][i].def end delete!(scene2, scene2[1]) @test length(scene2) == 1 - @test scene2[1].state == get_state(trajdata, 2, 1) - @test scene2[1].def == get_def(trajdata, 2) + @test scene2[1].state == trajdata[1][2].state + @test scene2[1].def == trajdata[1][2].def scene2[1] = deepcopy(scene[1]) - @test scene2[1].state == get_state(trajdata, 1, 1) - @test scene2[1].def == get_def(trajdata, 1) + @test scene2[1].state == trajdata[1][1].state + @test scene2[1].def == trajdata[1][1].def @test findfirst(1, scene) == 1 @test findfirst(2, scene) == 2 @@ -156,217 +128,8 @@ end @test !in(3, scene) veh = scene[2] - @test veh.state == get_state(trajdata, 2, 1) - @test veh.def == get_def(trajdata, 2) - - push!(scene, get_state(trajdata, 1, 1)) -end - -@testset "trajdata" begin - roadway = get_test_roadway() - trajdata = get_test_trajdata(roadway) - - @test nframes(trajdata) == 2 - @test !frame_inbounds(trajdata, 0) - @test frame_inbounds(trajdata, 1) - @test frame_inbounds(trajdata, 2) - @test !frame_inbounds(trajdata, 3) - - @test n_objects_in_frame(trajdata, 1) == 2 - @test n_objects_in_frame(trajdata, 2) == 2 - - @test nth_id(trajdata, 1) == 1 - @test nth_id(trajdata, 1, 2) == 2 - @test nth_id(trajdata, 2, 1) == 1 - @test nth_id(trajdata, 2, 2) == 2 - - @test findfirst_frame_with_id(trajdata, 1) == 1 - @test findfirst_frame_with_id(trajdata, 2) == 1 - @test findfirst_frame_with_id(trajdata, -1) == nothing - @test findlast_frame_with_id(trajdata, 1) == 2 - @test findlast_frame_with_id(trajdata, 2) == 2 - @test findlast_frame_with_id(trajdata, -1) == nothing - - @test sort!(get_ids(trajdata)) == [1,2] - - @test in(1, trajdata, 1) - @test in(1, trajdata, 2) - @test in(2, trajdata, 1) - @test in(2, trajdata, 2) - @test !in(3, trajdata, 1) - - @test isapprox(get_time(trajdata, 1), 0.0) - @test isapprox(get_time(trajdata, 2), 0.1) - - @test isapprox(get_elapsed_time(trajdata, 1, 2), 0.1) - @test isapprox(get_elapsed_time(trajdata, 2, 1), -0.1) - - @test get_timestep(trajdata) == 0.1 - - veh = get(trajdata, 1, 1) - @test veh.state == VehicleState(VecSE2(0.0,0.0,0.0), roadway, 10.0) - @test_throws ArgumentError get(trajdata, 10, 1) - @test_throws BoundsError get(trajdata, 1, 10) - - let - iter = ListRecordStateByIdIterator(trajdata, 1) - items = collect(iter) # list of (frame_index, state) - @test length(items) == 2 - @test items[1][1] == 1 - @test items[1][2] == get_state(trajdata, 1, 1) - @test items[2][1] == 2 - @test items[2][2] == get_state(trajdata, 1, 2) - - iter = ListRecordStateByIdIterator(trajdata, 2) - items = collect(iter) - @test length(items) == 2 - @test items[1][1] == 1 - @test items[1][2] == get_state(trajdata, 2, 1) - @test items[2][1] == 2 - @test items[2][2] == get_state(trajdata, 2, 2) - end - - path, io = mktemp() - write(io, MIME"text/plain"(), trajdata) - close(io) - - io = open(path) - trajdata2 = read(io, MIME"text/plain"(), ListRecord{VehicleState, VehicleDef, Int64}) - close(io) - rm(path) - - @test nframes(trajdata2) == nframes(trajdata) - for i in 1 : nframes(trajdata2) - @test n_objects_in_frame(trajdata2, i) == n_objects_in_frame(trajdata, i) - for j in 1 : n_objects_in_frame(trajdata, i) - veh1 = get(trajdata, j, i) - veh2 = get(trajdata2, j, i) - @test veh1.id == veh2.id - @test veh1.def.class == veh2.def.class - @test isapprox(veh1.def.length, veh2.def.length) - @test isapprox(veh1.def.width, veh2.def.width) - - @test isapprox(veh1.state.v, veh2.state.v) - @test isapprox(veh1.state.posG, veh2.state.posG, atol=1e-3) - @test isapprox(veh1.state.posF.s, veh2.state.posF.s, atol=1e-3) - @test isapprox(veh1.state.posF.t, veh2.state.posF.t, atol=1e-3) - @test isapprox(veh1.state.posF.ϕ, veh2.state.posF.ϕ, atol=1e-6) - @test veh1.state.posF.roadind.tag == veh2.state.posF.roadind.tag - @test veh1.state.posF.roadind.ind.i == veh2.state.posF.roadind.ind.i - @test isapprox(veh1.state.posF.roadind.ind.t, veh2.state.posF.roadind.ind.t, atol=1e-3) - end - end - - trajdata3 = get_subinterval(trajdata2, 1, nframes(trajdata2)) - @test nframes(trajdata3) == nframes(trajdata2) - for i in 1 : nframes(trajdata3) - @test n_objects_in_frame(trajdata3, i) == n_objects_in_frame(trajdata2, i) - for j in 1 : n_objects_in_frame(trajdata2, i) - veh1 = get(trajdata2, j, i) - veh2 = get(trajdata3, j, i) - @test veh1.id == veh2.id - @test veh1.def.class == veh2.def.class - @test isapprox(veh1.def.length, veh2.def.length) - @test isapprox(veh1.def.width, veh2.def.width) - - @test isapprox(veh1.state.v, veh2.state.v) - @test isapprox(veh1.state.posG, veh2.state.posG, atol=1e-3) - @test isapprox(veh1.state.posF.s, veh2.state.posF.s, atol=1e-3) - @test isapprox(veh1.state.posF.t, veh2.state.posF.t, atol=1e-3) - @test isapprox(veh1.state.posF.ϕ, veh2.state.posF.ϕ, atol=1e-6) - @test veh1.state.posF.roadind.tag == veh2.state.posF.roadind.tag - @test veh1.state.posF.roadind.ind.i == veh2.state.posF.roadind.ind.i - @test isapprox(veh1.state.posF.roadind.ind.t, veh2.state.posF.roadind.ind.t, atol=1e-3) - end - end - - trajdata3 = get_subinterval(trajdata2, 1, 1) - @test nframes(trajdata3) == 1 - let - i = 1 - @test n_objects_in_frame(trajdata3, i) == n_objects_in_frame(trajdata2, i) - for j in 1 : n_objects_in_frame(trajdata2, i) - veh1 = get(trajdata2, j, i) - veh2 = get(trajdata3, j, i) - @test veh1.id == veh2.id - @test veh1.def.class == veh2.def.class - @test isapprox(veh1.def.length, veh2.def.length) - @test isapprox(veh1.def.width, veh2.def.width) - - @test isapprox(veh1.state.v, veh2.state.v) - @test isapprox(veh1.state.posG, veh2.state.posG, atol=1e-3) - @test isapprox(veh1.state.posF.s, veh2.state.posF.s, atol=1e-3) - @test isapprox(veh1.state.posF.t, veh2.state.posF.t, atol=1e-3) - @test isapprox(veh1.state.posF.ϕ, veh2.state.posF.ϕ, atol=1e-6) - @test veh1.state.posF.roadind.tag == veh2.state.posF.roadind.tag - @test veh1.state.posF.roadind.ind.i == veh2.state.posF.roadind.ind.i - @test isapprox(veh1.state.posF.roadind.ind.t, veh2.state.posF.roadind.ind.t, atol=1e-3) - end - end -end - -@testset "QueueRecord" begin - roadway = get_test_roadway() - trajdata = get_test_trajdata(roadway) - - Δt = 0.1 - rec = QueueRecord(Entity{VehicleState, VehicleDef, Int64}, 5, Δt) - @test capacity(rec) == 5 - @test nframes(rec) == 0 - @test !pastframe_inbounds(rec, 0) - @test !pastframe_inbounds(rec, -1) - @test !pastframe_inbounds(rec, 1) - - scene = get!(Frame(Entity{VehicleState, VehicleDef, Int64}), trajdata, 1) - update!(rec, scene) - @test nframes(rec) == 1 - @test pastframe_inbounds(rec, 0) - @test !pastframe_inbounds(rec, -1) - @test !pastframe_inbounds(rec, 1) - @test isapprox(get_elapsed_time(rec, 0), 0) - @test rec[0][1].state == get_state(trajdata, 1, 1) - @test rec[0][1].def == get_def(trajdata, 1) - @test rec[0][2].state == get_state(trajdata, 2, 1) - @test rec[0][2].def == get_def(trajdata, 2) - show(IOBuffer(), rec) - - - get!(scene, trajdata, 2) - update!(rec, scene) - @test nframes(rec) == 2 - @test pastframe_inbounds(rec, 0) - @test pastframe_inbounds(rec, -1) - @test !pastframe_inbounds(rec, 1) - @test isapprox(get_elapsed_time(rec, 0), 0) - @test isapprox(get_elapsed_time(rec, -1), Δt) - @test isapprox(get_elapsed_time(rec, -1, 0), Δt) - @test rec[0][1].state == get_state(trajdata, 1, 2) - @test rec[0][1].def == get_def(trajdata, 1) - @test rec[0][2].state == get_state(trajdata, 2, 2) - @test rec[0][2].def == get_def(trajdata, 2) - @test rec[-1][1].state == get_state(trajdata, 1, 1) - @test rec[-1][1].def == get_def(trajdata, 1) - @test rec[-1][2].state == get_state(trajdata, 2, 1) - @test rec[-1][2].def == get_def(trajdata, 2) - - scene2 = get!(Frame(Entity{VehicleState, VehicleDef, Int64}), rec) - @test scene2[1].state == get_state(trajdata, 1, 2) - @test scene2[1].def == get_def(trajdata, 1) - @test scene2[2].state == get_state(trajdata, 2, 2) - @test scene2[2].def == get_def(trajdata, 2) - - get!(scene2, rec, -1) - @test scene2[1].state == get_state(trajdata, 1, 1) - @test scene2[1].def == get_def(trajdata, 1) - @test scene2[2].state == get_state(trajdata, 2, 1) - @test scene2[2].def == get_def(trajdata, 2) - - empty!(rec) - @test nframes(rec) == 0 + @test veh.state == trajdata[1][2].state + @test veh.def == trajdata[1][2].def - test_veh_state = VehicleState(VecSE2(7.0,7.0,2.0), roadway, 10.0) - test_veh_def = VehicleDef(AgentClass.CAR, 5.0, 3.0) - test_veh = Entity(test_veh_state, test_veh_def, 999) - rec[-1][1] = test_veh - @test rec[-1][1].state == test_veh_state + push!(scene, trajdata[1][1].state) end