# Imports and global definitions

In [None]:
# full imports
import gc
import matplotlib.patches as patches
import matplotlib.pyplot as plt
import numpy as np
import os
import subprocess
import sys
import tensorflow as tf
import time

# partial imports
from enum import Enum
from matplotlib import cm, path
from PIL import Image

# custom imports
sys.path.append("../code/waymo-open-dataset/waymo_open_dataset/protos")  # replace with own path
import scenario_pb2
import map_pb2
from waymo_open_dataset import dataset_pb2 as open_dataset


# Constants and enums

In [None]:
# Category of a scenario agent. Members and values, in declaration order:
# vehicle=1, pedestrian=2, cyclist=3, other=4.
AgentType = Enum(
    "AgentType",
    [("vehicle", 1), ("pedestrian", 2), ("cyclist", 3), ("other", 4)],
)
    
# Coarse grouping of raw traffic-light lane-state codes. find_traffic_lights_transitions maps
# code 0 -> unknown, 1/4/7 -> stop, 2/5/8 -> caution and 3/6 -> go.
TrafficLightFamily = Enum(
    "TrafficLightFamily",
    [("unknown", -1), ("stop", 0), ("caution", 1), ("go", 2)],
)

# Agent types drawn frame by frame in generate_video (AgentType.other is not included).
LIST_AGENTS_TYPES = [AgentType.vehicle, AgentType.cyclist, AgentType.pedestrian]
# Matplotlib scatter marker for each drawn agent type (none is defined for AgentType.other).
AGENTS_TYPES_MARKERS = {AgentType.vehicle: '.', AgentType.pedestrian: '+', AgentType.cyclist: "x"}
# Matplotlib color used to draw a traffic light for each raw lane-state code:
# 0 unknown, 1-3 arrow stop/caution/go, 4-6 stop/caution/go, 7-8 flashing stop/caution.
TL_COLOR_MAPPING = {0: "grey", 1: "red", 2: "orange", 3: "green", 4: "red", 5: "orange", 6: "green", 7: "red", 8: "orange"}
FIGURE_SIZE = 22  # width and height passed to plt.figure(figsize=...), in inches
FIGURE_DPI = 160  # resolution passed to plt.figure(dpi=...)
FILENAME = '/mnt/waymo/data/uncompressed_scenario_training_training.tfrecord-00000-of-01000'  # replace with own path
HALF_LANE_WIDTH = 12 * 0.3048 / 2  # half of a 12-foot lane, in meters (1 ft = 0.3048 m exactly)
FAKE_LANE_POLYGON_LENGTH = 10  # 10 meters
MAX_DISTANCE_NEIGHBORING_POINT_LANE_POLYGON_COMPUTATION = 1.5  # 1.5 meters
# Presumably the bounds (as multiples of the nominal lane width) used to reject unrealistic
# polygon points in generate_polygon_for_lane -- TODO confirm against its implementation.
MAX_FACTOR_LANE_WIDTH = 3
MIN_FACTOR_LANE_WIDTH = 0.3

# Utilities

In [None]:
def compute_2d_distance(A, B):
    """
    Computes the 2d (Euclidean) distance between two points.

    Args:
        A (MapPoint or Sequence[float]): first point. Any object exposing ``x`` and ``y``
            attributes (e.g. a MapPoint) is accepted, as is any indexable ``[x, y, ...]``.
        B (MapPoint or Sequence[float]): second point, same conventions as ``A``.

    Returns:
        float: distance between A and B in the x/y plane (further coordinates are ignored).
    """
    # Duck-type on the attributes rather than `type(...) == map_pb2.MapPoint`, so MapPoint
    # subclasses and any other point-like object are handled too.
    ax, ay = (A.x, A.y) if hasattr(A, "x") and hasattr(A, "y") else (A[0], A[1])
    bx, by = (B.x, B.y) if hasattr(B, "x") and hasattr(B, "y") else (B[0], B[1])
    return np.hypot(ax - bx, ay - by)

def find_closest_point_in_polyline(interest_point, polyline):
    """
    Returns the polyline point nearest to ``interest_point`` (2d distance).

    Args:
        interest_point (MapPoint or List[float]): point whose nearest neighbor is searched.
        polyline (List[MapPoint] or List[List[float]]): candidate points.

    Returns:
        Tuple: (closest point, its distance, its index). On ties the earliest point wins;
        an empty polyline yields (None, np.inf, None).
    """
    best_point, best_distance, best_index = None, np.inf, None
    for index, candidate in enumerate(polyline):
        distance = compute_2d_distance(candidate, interest_point)
        # Strict comparison keeps the first of several equally close points.
        if distance < best_distance:
            best_point, best_distance, best_index = candidate, distance, index
    return best_point, best_distance, best_index

def compute_local_lane_directions(lane):
    """
    Computes a unit direction vector at every point of a lane centerline.

    The first and last points use their single adjacent segment. Every interior point uses the
    normalized average of the unit directions of the segment before and the segment after it.

    Args:
        lane (LaneCenter): lane to process; points are read from ``lane.lane.polyline`` and must
            expose ``x`` and ``y``.

    Returns:
        List[np.ndarray]: one 2d unit vector per polyline point, in polyline order. A degenerate
        direction (zero-length segment, or a point where the lane exactly reverses) is returned
        as a zero vector instead of NaNs.

    Raises:
        ValueError: if the lane has fewer than 2 points, so no direction can be computed.
    """
    points = list(lane.lane.polyline)
    if len(points) < 2:
        raise ValueError(f"problem: lane too short ({len(points)} point(s))")

    def unit(vector):
        norm = np.linalg.norm(vector)
        return vector / norm if norm > 0 else vector

    # Unit direction of each segment points[k] -> points[k + 1]. Using numpy arrays matters:
    # in-place division of a plain Python list by a float raises TypeError.
    segment_directions = [
        unit(np.array([end.x - start.x, end.y - start.y], dtype=float))
        for start, end in zip(points[:-1], points[1:])
    ]

    directions = [segment_directions[0]]
    for before, after in zip(segment_directions[:-1], segment_directions[1:]):
        # The average of two unit vectors is shorter than 1, so renormalize it.
        directions.append(unit((before + after) / 2))
    directions.append(segment_directions[-1])
    return directions
    
def compute_intersection_infinite_lines(P, n, B, v):
    """
    Intersection between 2 infinite lines, each defined by a point and a direction vector (in 2d).

    Line A:
        x = P[0] + n[0] * t_A
        y = P[1] + n[1] * t_A
    Line B:
        x = B[0] + v[0] * t_B
        y = B[1] + v[1] * t_B

    Taking the 2d cross product (cross(a, b) = a[0] * b[1] - a[1] * b[0]) of
    P + n * t_A = B + v * t_B with n eliminates t_A:
        t_B = cross(P - B, n) / cross(v, n)
    This equals the slope form (B[1] - P[1] - (n[1]/n[0])*(B[0]-P[0])) / ((n[1]/n[0]) * v[0] - v[1])
    but never divides by n[0], so vertical lines (n[0] == 0) are supported.

    Args:
        P (List[float]): Point in the first line.
        n (List[float]): Direction vector for the first line.
        B (List[float]): Point in the second line.
        v (List[float]): Direction vector for the second line.

    Returns:
        List[float]: Intersection point [x, y].

    Raises:
        ValueError: if the lines are parallel (or a direction vector is zero), so there is no
            unique intersection.
    """
    denominator = v[0] * n[1] - v[1] * n[0]  # cross(v, n)
    # Relative tolerance so the parallelism test does not depend on the vectors' magnitudes.
    scale = np.hypot(n[0], n[1]) * np.hypot(v[0], v[1])
    if scale == 0 or abs(denominator) <= 1e-12 * scale:
        raise ValueError(f"lines are parallel or degenerate: n={n}, v={v}")

    t_B = ((P[0] - B[0]) * n[1] - (P[1] - B[1]) * n[0]) / denominator  # cross(P - B, n) / cross(v, n)
    return [B[0] + v[0] * t_B, B[1] + v[1] * t_B]

def project_point_on_line(P, A, B):
    """
    Orthogonally projects point P onto the infinite line through A and B.

    Args:
        P (MapPoint or List[float]): Point to project.
        A (MapPoint or List[float]): First point of the line to project onto.
        B (MapPoint or List[float]): Second point of the line to project onto.
        Any object exposing ``x`` and ``y`` attributes is read as a MapPoint.

    Returns:
        np.ndarray: Projection of P onto line (A, B). If A and B coincide the line is undefined,
        and A itself is returned instead of a NaN vector.
    """
    def as_xy(point):
        # Duck-type MapPoint-like objects; anything else is treated as an [x, y] sequence.
        if hasattr(point, "x") and hasattr(point, "y"):
            return np.array([point.x, point.y], dtype=float)
        return np.array(point, dtype=float)

    p, a, b = as_xy(P), as_xy(A), as_xy(B)
    direction = b - a
    length_sq = np.dot(direction, direction)
    if length_sq == 0:
        return a
    return a + (np.dot(direction, p - a) / length_sq) * direction


In [None]:
def plot_agents_at_all_steps(agent_states, agent_type):
    """
    Scatters every valid position of every agent, across all timesteps.

    TODO Could maybe be improved to get rid of the valid mask by using similar logic as in plot_agents_at_one_step

    Args:
        agent_states (List[List[ObjectState]]): matrix of agent states (rows: agents, columns: timesteps).
        agent_type (AgentType): selects the scatter marker from AGENTS_TYPES_MARKERS.
    """
    positions_x = np.array([[s.center_x for s in trajectory] for trajectory in agent_states])
    positions_y = np.array([[s.center_y for s in trajectory] for trajectory in agent_states])
    valid_mask = np.array([[s.valid for s in trajectory] for trajectory in agent_states])

    # Boolean-mask indexing flattens row by row (agent by agent), so colors are listed
    # agent by agent as well: one entry per valid state, using that agent's color.
    point_colors = []
    for agent_idx, agent_mask in enumerate(valid_mask):
        agent_color = COLORS_ARRAY_AGENTS[agent_idx]
        point_colors += [agent_color] * sum(agent_mask)

    plt.scatter(
        positions_x[valid_mask],
        positions_y[valid_mask],
        color=point_colors,
        marker=AGENTS_TYPES_MARKERS[agent_type],
    )

    
def plot_agents_at_one_step(agent_states, agent_type):
    """
    Scatters the valid agents of a single timestep.

    Args:
        agent_states (List[ObjectState]): agent states at this timestep; the i-th state is drawn
            with COLORS_ARRAY_AGENTS[i].
        agent_type (AgentType): selects the scatter marker from AGENTS_TYPES_MARKERS.
    """
    xs, ys, colors_agents = [], [], []
    for agent_idx, state in enumerate(agent_states):
        if not state.valid:
            continue
        xs.append(state.center_x)
        ys.append(state.center_y)
        colors_agents.append(COLORS_ARRAY_AGENTS[agent_idx])
    plt.scatter(xs, ys, color=colors_agents, marker=AGENTS_TYPES_MARKERS[agent_type])
    
def plot_road_edges(road_edges):
    """
    Draws each road edge polyline as a solid black line.

    Args:
        road_edges (List[RoadEdge]): features whose points are read from ``.road_edge.polyline``.
    """
    for edge in road_edges:
        points = edge.road_edge.polyline
        plt.plot([p.x for p in points], [p.y for p in points], color="black")

def plot_lanes(lanes, color="darkgrey"):
    """
    Draws each lane centerline as a semi-transparent dashed line.

    Args:
        lanes (List[LaneCenter]): lanes; anything whose type is not exactly ``list`` is rejected
            with a printed message and nothing is drawn.
        color (str): matplotlib color for the centerlines.
    """
    if type(lanes) is not list:
        print("I expect a list of lanes")
        return

    for lane in lanes:
        centerline = lane.lane.polyline
        plt.plot(
            [pt.x for pt in centerline],
            [pt.y for pt in centerline],
            "--",
            color=color,
            alpha=0.5,
        )

def plot_traffic_light_stops_points(traffic_light_stops_points, colors=None):
    """
    Draws a diamond marker at each traffic-light stop point.

    Args:
        traffic_light_stops_points (List[MapPoint]): stop point associated with each traffic light.
        colors (List[str] or None): matplotlib color per stop point; when empty or None every
            point is drawn in grey.
    """
    point_colors = colors if colors else ["grey"] * len(traffic_light_stops_points)
    for stop_point, point_color in zip(traffic_light_stops_points, point_colors):
        plt.scatter(stop_point.x, stop_point.y, marker="d", color=point_color)
        
def plot_crosswalks(crosswalks):
    """
    Draws each crosswalk as a closed orange outline.

    Args:
        crosswalks (List[Crosswalk]): features whose vertices are read from ``.crosswalk.polygon``.
    """
    for crosswalk in crosswalks:
        vertices = list(crosswalk.crosswalk.polygon)
        vertices.append(vertices[0])  # repeat the first vertex so the outline is closed
        plt.plot([v.x for v in vertices], [v.y for v in vertices], color="orange")

def plot_road_lines(road_lines):
    """
    Draws each road line polyline in semi-transparent dim grey.

    Args:
        road_lines (List[RoadLine]): features whose points are read from ``.road_line.polyline``.
    """
    for line in road_lines:
        points = line.road_line.polyline
        plt.plot([p.x for p in points], [p.y for p in points], color="dimgrey", alpha=0.5)

def plot_stop_signs(stop_signs, lanes):
    """
    Marks, on every lane controlled by a stop sign, the lane point closest to the sign.

    The dataset gives no stop line, so the closest centerline point is used as an
    approximation of where vehicles stop.

    Args:
        stop_signs (List[StopSign]): signs; ``.stop_sign.lane`` lists the ids of the controlled lanes.
        lanes (List[LaneCenter]): lanes to look those ids up in (an unknown id raises KeyError).
    """
    lane_by_id = dict((lane.id, lane) for lane in lanes)

    for stop_sign in stop_signs:
        for lane_id in stop_sign.stop_sign.lane:
            centerline = lane_by_id[lane_id].lane.polyline
            stop_point = find_closest_point_in_polyline(stop_sign.stop_sign.position, centerline)[0]
            plt.scatter(stop_point.x, stop_point.y, marker="H", color="red", s=150)

def plot_speed_bumps(speed_bumps):
    """
    Draws each speed bump as a closed, dashed black outline.

    Args:
        speed_bumps (List[SpeedBump]): features whose vertices are read from ``.speed_bump.polygon``.
    """
    for speed_bump in speed_bumps:
        vertices = list(speed_bump.speed_bump.polygon)
        vertices.append(vertices[0])  # repeat the first vertex so the outline is closed
        plt.plot([v.x for v in vertices], [v.y for v in vertices], "--", color="black")
        
        
def plot_scenario_image(dict_agents_states, crosswalks, lanes, road_edges, road_lines, speed_bumps, stop_signs, scenario_idx, save=False, traffic_light_stop_points=None):
    """
    Generates an image representing all the positions of all the agents and all the map features for the scenario.

    Args:
        dict_agents_states (Dict[AgentType: List[List[ObjectState]]]): dict of agents states matrices;
            agent types that are missing or empty are skipped.
        crosswalks (List[Crosswalk]): crosswalks
        lanes (List[LaneCenter]): lanes
        road_edges (List[RoadEdge]): road edges
        road_lines (List[RoadLine]): road lines
        speed_bumps (List[SpeedBump]): speed bumps
        stop_signs (List[StopSign]): stop signs
        scenario_idx (int): id of the scenario in the file, used as the image file name
        save (bool): whether to save the image to ``scenario_images/<scenario_idx>.png``
            (the directory is created if needed)
        traffic_light_stop_points (List[MapPoint] or None): traffic-light stop points to draw.
            When None, a module-level global ``traffic_light_stop_points`` is used if it exists
            (this function used to read that global implicitly); otherwise none are drawn.
    """
    if traffic_light_stop_points is None:
        traffic_light_stop_points = globals().get("traffic_light_stop_points", [])

    fig = plt.figure(figsize=(FIGURE_SIZE, FIGURE_SIZE), dpi=FIGURE_DPI)

    # Drawing order kept as before (vehicles, pedestrians, cyclists) so overlaps look the same.
    for agent_type in (AgentType.vehicle, AgentType.pedestrian, AgentType.cyclist):
        agent_states = dict_agents_states.get(agent_type, [])
        if len(agent_states):
            plot_agents_at_all_steps(agent_states, agent_type)

    plot_crosswalks(crosswalks)
    plot_lanes(lanes)
    plot_road_edges(road_edges)
    plot_road_lines(road_lines)
    plot_speed_bumps(speed_bumps)
    plot_stop_signs(stop_signs, lanes)
    plot_traffic_light_stops_points(traffic_light_stop_points)

    plt.axis("equal")
    if save:
        os.makedirs("scenario_images", exist_ok=True)
        plt.savefig(f"scenario_images/{scenario_idx}.png")

    plt.close(fig)

# Analysis and generation functions (traffic-light transitions, videos, lane polygons)

In [None]:
def find_traffic_lights_transitions(start_state, end_state, traffic_lights_stop_points_dict, dynamic_map_states):
    """
    Finds the traffic lights whose state family changes from ``start_state`` to ``end_state``.

    Raw lane-state codes are grouped into families: 0 -> unknown, 1/4/7 -> stop,
    2/5/8 -> caution, 3/6 -> go (any other code prints "ERROR" and maps to None).
    A transition is detected when a lane is observed in ``end_state`` and its previous
    observation was ``start_state``. A lane that is not reported in a step keeps its last
    family and cannot produce a (repeated) transition in that step.

    Args:
        start_state (TrafficLightFamily): state before the transition.
        end_state (TrafficLightFamily): state after the transition.
        traffic_lights_stop_points_dict (Dict{int: MapPoint}): traffic-light lane ids and their
            stop points. Lanes reported in dynamic_map_states but missing here are tracked too.
        dynamic_map_states (List[DynamicMapState]): dynamic states to search, in time order.

    Returns:
        List[Tuple[int, int]]: (step index, lane id) of every detected transition. A line is also
        printed for each step that contains at least one transition.
    """
    state_to_family = {
        0: TrafficLightFamily.unknown,
        1: TrafficLightFamily.stop, 4: TrafficLightFamily.stop, 7: TrafficLightFamily.stop,
        2: TrafficLightFamily.caution, 5: TrafficLightFamily.caution, 8: TrafficLightFamily.caution,
        3: TrafficLightFamily.go, 6: TrafficLightFamily.go,
    }

    # Last observed family per lane; None means "not observed yet".
    last_family = {lane: None for lane in traffic_lights_stop_points_dict}
    transitions = []

    for step_index, step in enumerate(dynamic_map_states):
        step_has_transition = False
        for lane_state in step.lane_states:
            family = state_to_family.get(lane_state.state)
            if family is None:
                print("ERROR")
            previous_family = last_family.get(lane_state.lane)
            last_family[lane_state.lane] = family
            if previous_family == start_state and family == end_state:
                transitions.append((step_index, lane_state.lane))
                step_has_transition = True

        if step_has_transition:
            print(f"Transition: {start_state} to {end_state}")

    return transitions

In [None]:
def count_traffic_lights_transitions(traffic_lights_stop_points_dict, dynamic_map_states, verbose=True):
    """
    Counts how many of each type of traffic-light state transition occur.

    A transition is counted each time a lane is observed in a raw state different from the
    state of its previous observation. Lanes start in state 0 (LANE_STATE_UNKNOWN), so the first
    observation of another state counts as a transition out of 0. A lane that is not reported
    in a step keeps its last state and is not counted again for that step.

    Args:
        traffic_lights_stop_points_dict (Dict{int: MapPoint}): traffic-light lane ids and their
            stop points. Lanes reported in dynamic_map_states but missing here are tracked too.
        dynamic_map_states (List[DynamicMapState]): dynamic states to search, in time order.
        verbose (bool): whether to print the count of every transition type.

    Returns:
        np.ndarray: 9x9 matrix where entry [i][j] counts transitions from raw state i to raw state j.
    """
    transition_matrix = np.zeros((9, 9))

    # Last observed raw state per lane.
    last_state = {lane: 0 for lane in traffic_lights_stop_points_dict}

    for step in dynamic_map_states:
        for lane_state in step.lane_states:
            previous_state = last_state.get(lane_state.lane, 0)
            current_state = lane_state.state
            if current_state != previous_state:
                transition_matrix[previous_state][current_state] += 1
            last_state[lane_state.lane] = current_state

    if verbose:
        dict_names = {
            0: "LANE_STATE_UNKNOWN",
            1: "LANE_STATE_ARROW_STOP",
            2: "LANE_STATE_ARROW_CAUTION",
            3: "LANE_STATE_ARROW_GO",
            4: "LANE_STATE_STOP",
            5: "LANE_STATE_CAUTION",
            6: "LANE_STATE_GO",
            7: "LANE_STATE_FLASHING_STOP",
            8: "LANE_STATE_FLASHING_CAUTION",
        }

        for i in range(0, 9):
            for j in range(0, 9):
                if i != j:
                    print(f"Transition {dict_names[i]} to {dict_names[j]}: {transition_matrix[i][j]}")
            print("\n")

    return transition_matrix
    

In [None]:
                
def generate_video(dict_agents_states, timestamps_seconds, dynamic_map_states, crosswalks, lanes, road_edges, road_lines, speed_bumps, stop_signs, scenario_idx, traffic_lights_stop_points_dict=None):
    """
    Generates a video where each frame is a plot of all agents and all map features at each timestamp.

    Frames are written to ``scenario_videos/<scenario_idx>/<frame:02d>.png``. The ``ffmpeg``
    command-line tool (10 frames per second) then assembles them into
    ``scenario_videos/<scenario_idx>/video_<scenario_idx>.mp4``, replacing any previous video.

    Args:
        dict_agents_states (Dict[AgentType: List[List[ObjectState]]]): dict of agents states matrices
        timestamps_seconds (List[double]): list of timestamps, one frame per entry
        dynamic_map_states (List[DynamicMapState]): list of the dynamic map states at all timestamps
        crosswalks (List[Crosswalk]): crosswalks
        lanes (List[LaneCenter]): lanes
        road_edges (List[RoadEdge]): road edges
        road_lines (List[RoadLine]): road lines
        speed_bumps (List[SpeedBump]): speed bumps
        stop_signs (List[StopSign]): stop signs
        scenario_idx (int): id of the scenario in the file
        traffic_lights_stop_points_dict (Dict{int: MapPoint} or None): traffic-light lane ids and
            their stop points. When None, a module-level global of the same name is used if it
            exists (this function used to read that global implicitly); otherwise no traffic
            lights are drawn.
    """
    if traffic_lights_stop_points_dict is None:
        traffic_lights_stop_points_dict = globals().get("traffic_lights_stop_points_dict", {})

    frames_dir = f"scenario_videos/{scenario_idx}"
    # makedirs also creates the parent "scenario_videos" directory if needed.
    os.makedirs(frames_dir, exist_ok=True)

    # Raw state of every traffic-light lane at every dynamic-map step. A lane that is not reported
    # in a step is recorded as 0 (unknown), and exactly one entry is stored per lane per step so
    # indices stay aligned with frames. Reported lanes without a stop point cannot be drawn and are ignored.
    dict_all_states = {lane: [] for lane in traffic_lights_stop_points_dict}
    for dynamic_map_state in dynamic_map_states:
        reported = {lane_state.lane: lane_state.state for lane_state in dynamic_map_state.lane_states}
        for lane, lane_history in dict_all_states.items():
            lane_history.append(reported.get(lane, 0))

    tl_lanes = list(dict_all_states)
    traffic_light_stops_points = [traffic_lights_stop_points_dict[lane] for lane in tl_lanes]

    # Convert each (agents x timesteps) matrix once instead of once per frame.
    agent_matrices = {
        agent_type: np.array(dict_agents_states[agent_type])
        for agent_type in LIST_AGENTS_TYPES
        if dict_agents_states.get(agent_type)
    }

    for i, _timestamp in enumerate(timestamps_seconds):
        fig = plt.figure(figsize=(FIGURE_SIZE, FIGURE_SIZE), dpi=FIGURE_DPI)

        for agent_type, agent_matrix in agent_matrices.items():
            plot_agents_at_one_step(agent_matrix[:, i], agent_type)
        plot_crosswalks(crosswalks)
        plot_lanes(lanes)
        plot_road_edges(road_edges)
        plot_road_lines(road_lines)
        plot_speed_bumps(speed_bumps)
        plot_stop_signs(stop_signs, lanes)

        # Frames beyond the recorded dynamic map states show the light as unknown (0).
        colors = [
            TL_COLOR_MAPPING[dict_all_states[lane][i] if i < len(dict_all_states[lane]) else 0]
            for lane in tl_lanes
        ]
        plot_traffic_light_stops_points(traffic_light_stops_points, colors)

        plt.axis("equal")
        plt.savefig(f"{frames_dir}/{i:02d}.png")
        plt.close(fig)

    video_path = f"{frames_dir}/video_{scenario_idx}.mp4"
    if os.path.exists(video_path):
        os.remove(video_path)
    # Explicit argument list (no shell, no str.split), so paths containing spaces stay intact.
    subprocess.call(["ffmpeg", "-framerate", "10", "-i", f"{frames_dir}/%02d.png", video_path])

In [None]:
def _find_boundary_polyline(boundary_feature_id, road_edges_dict, road_lines_dict):
    """Return the polyline of a boundary feature (road edges searched first), or None if it is in neither dict."""
    if boundary_feature_id in road_edges_dict:
        return road_edges_dict[boundary_feature_id].road_edge.polyline
    if boundary_feature_id in road_lines_dict:
        return road_lines_dict[boundary_feature_id].road_line.polyline
    return None


def _snap_lane_points_to_boundaries(lane_polyline, boundaries, road_edges_dict, road_lines_dict):
    """Map each lane point index covered by a boundary segment to the closest point of that boundary."""
    snapped = {}
    for boundary in boundaries:
        boundary_polyline = _find_boundary_polyline(boundary.boundary_feature_id, road_edges_dict, road_lines_dict)
        if boundary_polyline is None:
            # Unknown boundary: report it and let these points fall back to the normal offset
            # (previously the polyline of the preceding boundary was silently reused).
            print("problem")
            continue
        if len(boundary_polyline) == 0:
            continue
        # The segment's lane_end_index is treated as inclusive.
        for lane_point_id in range(boundary.lane_start_index, boundary.lane_end_index + 1):
            lane_point = lane_polyline[lane_point_id]
            closest = min(
                boundary_polyline,
                key=lambda candidate: np.sqrt((candidate.x - lane_point.x) ** 2 + (candidate.y - lane_point.y) ** 2),
            )
            # Later boundaries overwrite earlier ones for the same lane point.
            snapped[lane_point_id] = [closest.x, closest.y]
    return snapped


def _offset_lane_point_along_normal(lane_polyline, lane_point_id, side):
    """Shift a lane point by HALF_LANE_WIDTH along its local normal (side=+1: left, side=-1: right)."""
    # Local direction: the first segment for point 0, otherwise the segment arriving at the point.
    if lane_point_id == 0:
        start, end = lane_polyline[0], lane_polyline[1]
    else:
        start, end = lane_polyline[lane_point_id - 1], lane_polyline[lane_point_id]
    normal = side * np.array([-(end.y - start.y), end.x - start.x], dtype=float)
    normal /= np.linalg.norm(normal)
    point = lane_polyline[lane_point_id]
    return [point.x + HALF_LANE_WIDTH * normal[0], point.y + HALF_LANE_WIDTH * normal[1]]


def _lane_side_polygon_points(lane_polyline, boundaries, side, road_edges_dict, road_lines_dict):
    """Polygon points for one side of the lane, ordered by lane point index."""
    points = _snap_lane_points_to_boundaries(lane_polyline, boundaries, road_edges_dict, road_lines_dict)
    for lane_point_id in range(len(lane_polyline)):
        if lane_point_id not in points:
            points[lane_point_id] = _offset_lane_point_along_normal(lane_polyline, lane_point_id, side)
    return [points[index] for index in sorted(points)]


def generate_polygon_for_lane_first_approach(lane, road_edges_dict, road_lines_dict, plot=False):
    """
    Generate a polygon that is my approximation of the lane boundaries, with the following logic:
    for each point from the lane (left and right side independently)
        if it is covered by a boundary segment whose feature is known
            the polygon point is the closest point of that boundary
        if not
            the polygon point is the lane point moved HALF_LANE_WIDTH along the local normal

    Args:
        lane (LaneCenter): lane to generate the polygon for; reads ``lane.lane.polyline`` and
            ``lane.lane.left_boundaries`` / ``right_boundaries``.
        road_edges_dict (Dict{int: MapFeature}): road edge features by id (values expose
            ``.road_edge.polyline``) that can be used as boundaries.
        road_lines_dict (Dict{int: MapFeature}): road line features by id (values expose
            ``.road_line.polyline``) that can be used as boundaries.
        plot (bool): whether to show a plot of the lane, its boundaries and the polygon.

    Returns:
        path.Path: closed path along the left side (increasing index), back along the right side,
        and ending on the first left point.
    """
    polyline = lane.lane.polyline
    left_points = _lane_side_polygon_points(polyline, lane.lane.left_boundaries, 1, road_edges_dict, road_lines_dict)
    right_points = _lane_side_polygon_points(polyline, lane.lane.right_boundaries, -1, road_edges_dict, road_lines_dict)

    all_boundary_points = left_points + right_points[::-1] + [left_points[0]]
    lane_boundaries_path = path.Path(all_boundary_points)

    if plot:
        fig = plt.figure(figsize=(FIGURE_SIZE, FIGURE_SIZE), dpi=FIGURE_DPI)
        ax = fig.add_subplot(111)
        plt.plot([p.x for p in polyline], [p.y for p in polyline], "-", color="red", alpha=0.5)

        for boundary in list(lane.lane.left_boundaries) + list(lane.lane.right_boundaries):
            boundary_polyline = _find_boundary_polyline(boundary.boundary_feature_id, road_edges_dict, road_lines_dict)
            if boundary_polyline is None:
                print("problem")
                continue
            plt.plot([p.x for p in boundary_polyline], [p.y for p in boundary_polyline], "-", color="b")

        ax.add_patch(patches.PathPatch(lane_boundaries_path, facecolor="orange", lw=2, alpha=0.5))

        plt.axis("equal")
        plt.show()

    return lane_boundaries_path


In [None]:
def _lane_polygon_boundary_polyline(boundary, road_edges_dict, road_lines_dict, boundary_label):
    """
    Return the polyline of a lane boundary, looked up first among the road edges, then among the road lines.
    Prints a "problem" message and exits if the boundary is in neither dict.
    """
    boundary_id = boundary.boundary_feature_id
    if boundary_id in road_edges_dict:
        return road_edges_dict[boundary_id].road_edge.polyline
    if boundary_id in road_lines_dict:
        return road_lines_dict[boundary_id].road_line.polyline
    print(f"problem: {boundary_label} is neither an road_edge nor a road_line")
    sys.exit(1)


def _lane_polygon_side_normal(local_lane_direction, side):
    """Return the normal to `local_lane_direction` pointing to `side` ("left" or "right"), with the same norm as the direction."""
    if side == "left":
        return np.array([-local_lane_direction[1], local_lane_direction[0]])
    return np.array([local_lane_direction[1], -local_lane_direction[0]])


def _lane_polygon_closest_point(reference_point, candidates):
    """Return the candidate closest to `reference_point` (the first one in case of ties)."""
    return min(candidates, key=lambda candidate: compute_2d_distance(reference_point, candidate))


def _lane_polygon_points_from_boundaries(lane, boundaries, side, local_lane_directions, road_edges_dict, road_lines_dict, lane_start_index, lane_end_index):
    """
    Compute the polygon points on one side of the lane, for the lane points covered by `boundaries`.

    Args:
        lane (MapFeature): lane feature whose `lane` (LaneCenter) holds the polyline.
        boundaries: boundary segments of that side (with `boundary_feature_id`, `lane_start_index`, `lane_end_index`).
        side (str): "left" or "right", the side the polygon points are computed on.
        local_lane_directions: local lane direction, indexed by lane point index.
        road_edges_dict (Dict{int: MapFeature}): road edges usable as boundaries.
        road_lines_dict (Dict{int: MapFeature}): road lines usable as boundaries.
        lane_start_index (int): first lane point index to consider.
        lane_end_index (int): last lane point index to consider (clamped to the polyline).

    Returns:
        Dict{int: point}: polygon point per lane point index; lane points without a realistic candidate are absent.
    """
    polyline = lane.lane.polyline
    min_distance = MIN_FACTOR_LANE_WIDTH * HALF_LANE_WIDTH
    max_distance = MAX_FACTOR_LANE_WIDTH * HALF_LANE_WIDTH
    polygon_points_dict = {}

    for boundary in boundaries:
        boundary_polyline = _lane_polygon_boundary_polyline(boundary, road_edges_dict, road_lines_dict, f"{side} boundary")

        # the end indices are inclusive; clamp to the polyline so that inconsistent indices cannot overflow it
        effective_start_index = max(lane_start_index, boundary.lane_start_index)
        effective_end_index = min(lane_end_index, boundary.lane_end_index, len(polyline) - 1) + 1
        for lane_point_id in range(effective_start_index, effective_end_index):
            current_lane_point = polyline[lane_point_id]
            current_lane_point_as_list = [current_lane_point.x, current_lane_point.y]

            # local normal going towards `side` (not normalized: only its direction matters below)
            local_lane_normal = _lane_polygon_side_normal(local_lane_directions[lane_point_id], side)
            point_on_normal = np.array(current_lane_point_as_list) + local_lane_normal

            # split the boundary points according to the side of the local normal they lie on
            points_left_of_normal = []
            points_right_of_normal = []
            for boundary_point in boundary_polyline:
                direction_to_boundary_point = np.array([boundary_point.x - current_lane_point.x, boundary_point.y - current_lane_point.y])
                direction_norm = np.linalg.norm(direction_to_boundary_point)
                if direction_norm == 0:
                    # the boundary point coincides with the lane point: it lies on neither side of the normal
                    continue
                direction_to_boundary_point = direction_to_boundary_point / direction_norm
                # signed angle from the normal to the boundary point, logic from
                # https://stackoverflow.com/questions/2150050/finding-signed-angle-between-vectors
                signed_angle = np.arctan2(
                    local_lane_normal[0] * direction_to_boundary_point[1] - local_lane_normal[1] * direction_to_boundary_point[0],
                    local_lane_normal[0] * direction_to_boundary_point[0] + local_lane_normal[1] * direction_to_boundary_point[1],
                )
                if signed_angle > 0:
                    points_left_of_normal.append([boundary_point.x, boundary_point.y])
                elif signed_angle < 0:
                    points_right_of_normal.append([boundary_point.x, boundary_point.y])

            if points_left_of_normal and points_right_of_normal:
                # intersect the normal with the line through the closest boundary points on each side of it
                # (TODO maybe we should use the points closest to the normal instead)
                closest_point_left_of_normal = _lane_polygon_closest_point(current_lane_point, points_left_of_normal)
                closest_point_right_of_normal = _lane_polygon_closest_point(current_lane_point, points_right_of_normal)
                # the two boundary points must be neighbors, to avoid some edge cases
                if compute_2d_distance(closest_point_left_of_normal, closest_point_right_of_normal) > MAX_DISTANCE_NEIGHBORING_POINT_LANE_POLYGON_COMPUTATION:
                    continue
                vec_between_boundary_points = np.array([
                    closest_point_right_of_normal[0] - closest_point_left_of_normal[0],
                    closest_point_right_of_normal[1] - closest_point_left_of_normal[1],
                ])
                vec_between_boundary_points /= np.linalg.norm(vec_between_boundary_points)
                candidate_polygon_point = compute_intersection_infinite_lines(current_lane_point_as_list, local_lane_normal, closest_point_left_of_normal, vec_between_boundary_points)
            else:
                # fall back to projecting the closest boundary point on the local normal
                closest_point_in_boundary, _, _ = find_closest_point_in_polyline(current_lane_point, boundary_polyline)
                candidate_polygon_point = project_point_on_line(closest_point_in_boundary, current_lane_point_as_list, point_on_normal)

            # only keep candidates at a reasonable distance from the lane point
            candidate_distance = compute_2d_distance(candidate_polygon_point, current_lane_point)
            if min_distance <= candidate_distance <= max_distance:
                polygon_points_dict[lane_point_id] = candidate_polygon_point

    return polygon_points_dict


def _lane_polygon_fill_default_points(polygon_points_dict, lane, side, local_lane_directions, lane_start_index, lane_end_index):
    """
    Add, in place, a polygon point at HALF_LANE_WIDTH along the unit normal going to `side` for every
    lane point in [lane_start_index, lane_end_index] (clamped to the polyline) that has no polygon point yet.
    """
    polyline = lane.lane.polyline
    effective_end_index = min(lane_end_index + 1, len(polyline))
    for lane_point_id in range(lane_start_index, effective_end_index):
        if lane_point_id in polygon_points_dict:
            continue
        local_lane_normal = _lane_polygon_side_normal(local_lane_directions[lane_point_id], side)
        local_lane_normal = local_lane_normal / np.linalg.norm(local_lane_normal)
        lane_point = polyline[lane_point_id]
        polygon_points_dict[lane_point_id] = [lane_point.x + HALF_LANE_WIDTH * local_lane_normal[0], lane_point.y + HALF_LANE_WIDTH * local_lane_normal[1]]


def generate_polygon_for_lane(lane, road_edges_dict, road_lines_dict, lane_start_index=None, lane_end_index=None, plot=False):
    """
    Generate a polygon that is my approximation of the lane boundaries, with the following logic:
    for each point from the lane, on each side
        if it is covered by a boundary
            we search for 2 points from the boundary that are on either side of the normal at this point
            if we found them (and they are close enough to each other)
                the polygon point is the intersection of the normal and the line defined by these two points
            else
                the polygon point is the projection of the closest boundary point on the normal
            if the computed polygon point is not realistic (too close or too far from the lane point)
                we discard it
        if no polygon point was kept
            the polygon point is the normal point at distance HALF_LANE_WIDTH

    Args:
        lane (MapFeature): lane feature (its `lane` field is the LaneCenter) to generate the polygon for.
        road_edges_dict (Dict{int: MapFeature}): road edges (read via `.road_edge`) that can be used as boundaries.
        road_lines_dict (Dict{int: MapFeature}): road lines (read via `.road_line`) that can be used as boundaries.
        lane_start_index (int, optional): index of the point of the lane where we start generating a polygon. Defaults to 0.
        lane_end_index (int, optional): index of the point of the lane where we stop generating a polygon
            (clamped to the last point). Defaults to the end of the lane.
        plot (bool): whether to generate a plot or not.

    Returns:
        path.Path: the closed path representing the computed polygon boundary for the lane
        (left points in lane order, then right points in reverse order).
    """
    num_lane_points = len(lane.lane.polyline)

    # Sanity checks (explicit `is None` tests so that an index of 0 is honored)
    if lane_start_index is not None and lane_start_index < 0:
        print(f"problem: lane_start_index smaller than 0: {lane_start_index}")
        sys.exit(1)
    if lane_end_index is not None and lane_end_index > num_lane_points:
        print(f"problem: lane_end_index larger than lane length: {lane_end_index}")
        sys.exit(1)

    # Setting defaults
    if lane_start_index is None:
        lane_start_index = 0
    if lane_end_index is None:
        lane_end_index = num_lane_points

    # an empty index range would leave no polygon point at all
    if lane_start_index > min(lane_end_index, num_lane_points - 1):
        print(f"problem: lane_start_index {lane_start_index} is after lane_end_index {lane_end_index}")
        sys.exit(1)

    # Step 0: pre-computing local lane directions (indexed by lane point index)
    local_lane_directions = compute_local_lane_directions(lane)

    # Steps 1 and 2: left polygon points, from the left boundaries when possible, default width otherwise
    left_polygon_points_dict = _lane_polygon_points_from_boundaries(
        lane, lane.lane.left_boundaries, "left", local_lane_directions,
        road_edges_dict, road_lines_dict, lane_start_index, lane_end_index)
    _lane_polygon_fill_default_points(left_polygon_points_dict, lane, "left", local_lane_directions, lane_start_index, lane_end_index)

    # Steps 3 and 4: right polygon points, from the right boundaries when possible, default width otherwise
    right_polygon_points_dict = _lane_polygon_points_from_boundaries(
        lane, lane.lane.right_boundaries, "right", local_lane_directions,
        road_edges_dict, road_lines_dict, lane_start_index, lane_end_index)
    _lane_polygon_fill_default_points(right_polygon_points_dict, lane, "right", local_lane_directions, lane_start_index, lane_end_index)

    # creating a closed path: left side forwards, right side backwards, back to the first point
    all_boundary_points = [left_polygon_points_dict[i] for i in sorted(left_polygon_points_dict)]
    all_boundary_points.extend(right_polygon_points_dict[i] for i in sorted(right_polygon_points_dict, reverse=True))
    all_boundary_points.append(all_boundary_points[0])
    lane_boundaries_path = path.Path(all_boundary_points)

    # plotting
    if plot:
        fig = plt.figure(figsize=(FIGURE_SIZE, FIGURE_SIZE), dpi=FIGURE_DPI)
        ax = fig.add_subplot(111)
        plot_lanes([lane], "red")

        for boundary in list(lane.lane.left_boundaries) + list(lane.lane.right_boundaries):
            boundary_polyline = _lane_polygon_boundary_polyline(boundary, road_edges_dict, road_lines_dict, "boundary")
            edge_xs = [map_point.x for map_point in boundary_polyline]  # TODO replace by function
            edge_ys = [map_point.y for map_point in boundary_polyline]
            plt.plot(edge_xs, edge_ys, "-", color="b")

        patch = patches.PathPatch(lane_boundaries_path, facecolor="orange", alpha=0.5, lw=2)
        ax.add_patch(patch)

        plt.axis("equal")
        plt.show()
        plt.close()

    return lane_boundaries_path


In [None]:
def generate_fake_predecessor_polygon(lane, plot=False):
    """
    When a lane has no predecessor, we can compute a fake predecessor and generate a polygon for it.

    The fake predecessor is a rectangle of width 2 * HALF_LANE_WIDTH, centered on the first lane point,
    that extends FAKE_LANE_POLYGON_LENGTH backwards along the direction of the first lane segment.

    Args:
        lane (MapFeature): lane feature (its `lane` field is the LaneCenter) to generate a fake predecessor
            polygon for; its polyline needs at least two distinct points.
        plot (bool): Whether to generate a plot or not.

    Returns:
        path.Path: closed path through the 4 corners of the fake predecessor polygon.
    """
    polyline = lane.lane.polyline
    start_point = polyline[0]

    # Step 1: find the unit lane direction at the start, using the first point that differs from the
    # start point so that a duplicated first point does not produce a zero-length (NaN) direction
    next_point = next((p for p in polyline if p.x != start_point.x or p.y != start_point.y), None)
    if next_point is None:
        print("problem: cannot compute the start direction of a lane with less than 2 distinct points")
        sys.exit(1)
    local_lane_direction = np.array([next_point.x - start_point.x, next_point.y - start_point.y])
    local_lane_direction /= np.linalg.norm(local_lane_direction)
    # both normals are unit vectors since the direction is
    normal_left = np.array([-local_lane_direction[1], local_lane_direction[0]])
    normal_right = np.array([local_lane_direction[1], -local_lane_direction[0]])

    # Step 2: generate the 4 corners of the polygon, going backwards from the start of the lane
    corner_1 = [start_point.x + HALF_LANE_WIDTH * normal_left[0], start_point.y + HALF_LANE_WIDTH * normal_left[1]]
    corner_2 = [start_point.x + HALF_LANE_WIDTH * normal_right[0], start_point.y + HALF_LANE_WIDTH * normal_right[1]]
    corner_3 = [corner_2[0] - FAKE_LANE_POLYGON_LENGTH * local_lane_direction[0], corner_2[1] - FAKE_LANE_POLYGON_LENGTH * local_lane_direction[1]]
    corner_4 = [corner_1[0] - FAKE_LANE_POLYGON_LENGTH * local_lane_direction[0], corner_1[1] - FAKE_LANE_POLYGON_LENGTH * local_lane_direction[1]]
    corners = [corner_1, corner_2, corner_3, corner_4, corner_1]

    # Step 3: generate the polygon object
    fake_predecessor_polygon = path.Path(corners)

    if plot:
        fig = plt.figure(figsize=(FIGURE_SIZE, FIGURE_SIZE), dpi=FIGURE_DPI)
        ax = fig.add_subplot(111)
        plot_lanes([lane])
        for corner in corners[:-1]:
            plt.scatter(corner[0], corner[1])
        patch = patches.PathPatch(fake_predecessor_polygon, facecolor="orange", lw=2)
        ax.add_patch(patch)

        plt.axis("equal")
        plt.show()
        plt.close()

    return fake_predecessor_polygon

In [None]:
def generate_fake_successor_polygon(lane, plot=False):
    """
    When a lane has no successor, we can compute a fake successor and generate a polygon for it.

    The fake successor is a rectangle of width 2 * HALF_LANE_WIDTH, centered on the last lane point,
    that extends FAKE_LANE_POLYGON_LENGTH forwards along the direction of the last lane segment.

    Args:
        lane (MapFeature): lane feature (its `lane` field is the LaneCenter) to generate a fake successor
            polygon for; its polyline needs at least two distinct points.
        plot (bool): Whether to generate a plot or not.

    Returns:
        path.Path: closed path through the 4 corners of the fake successor polygon.
    """
    polyline = lane.lane.polyline
    end_point = polyline[-1]

    # Step 1: find the unit lane direction at the end, using the last point that differs from the
    # end point so that a duplicated last point does not produce a zero-length (NaN) direction
    previous_point = next(
        (polyline[i] for i in range(len(polyline) - 2, -1, -1)
         if polyline[i].x != end_point.x or polyline[i].y != end_point.y),
        None,
    )
    if previous_point is None:
        print("problem: cannot compute the end direction of a lane with less than 2 distinct points")
        sys.exit(1)
    local_lane_direction = np.array([end_point.x - previous_point.x, end_point.y - previous_point.y])
    local_lane_direction /= np.linalg.norm(local_lane_direction)
    # both normals are unit vectors since the direction is
    normal_left = np.array([-local_lane_direction[1], local_lane_direction[0]])
    normal_right = np.array([local_lane_direction[1], -local_lane_direction[0]])

    # Step 2: generate the 4 corners of the polygon, going forwards from the end of the lane
    corner_1 = [end_point.x + HALF_LANE_WIDTH * normal_left[0], end_point.y + HALF_LANE_WIDTH * normal_left[1]]
    corner_2 = [end_point.x + HALF_LANE_WIDTH * normal_right[0], end_point.y + HALF_LANE_WIDTH * normal_right[1]]
    corner_3 = [corner_2[0] + FAKE_LANE_POLYGON_LENGTH * local_lane_direction[0], corner_2[1] + FAKE_LANE_POLYGON_LENGTH * local_lane_direction[1]]
    corner_4 = [corner_1[0] + FAKE_LANE_POLYGON_LENGTH * local_lane_direction[0], corner_1[1] + FAKE_LANE_POLYGON_LENGTH * local_lane_direction[1]]
    corners = [corner_1, corner_2, corner_3, corner_4, corner_1]

    # Step 3: generate the polygon object
    fake_successor_polygon = path.Path(corners)

    if plot:
        fig = plt.figure(figsize=(FIGURE_SIZE, FIGURE_SIZE), dpi=FIGURE_DPI)
        ax = fig.add_subplot(111)
        plot_lanes([lane])
        for corner in corners[:-1]:
            plt.scatter(corner[0], corner[1])
        # plot the successor polygon computed above (not an unrelated predecessor polygon)
        patch = patches.PathPatch(fake_successor_polygon, facecolor="orange", lw=2)
        ax.add_patch(patch)

        plt.axis("equal")
        plt.show()
        plt.close()

    return fake_successor_polygon

In [None]:
def find_agents_going_though_traffic_light(tl_lane_id, traffic_light_stop_point, lanes_dict, agents_states_dicts, road_edges_dict, road_lines_dict, plot=False, max_distance_to_stop_point=10):
    """
    Find all agents that are going through a traffic light, i.e. that are found in the polygon before as well as in the polygon after.

    Only agent states closer than `max_distance_to_stop_point` to the stop point are tested, and states whose
    center_x is 0 are treated as invalid. The order in which an agent visits the two polygons is not checked.

    Args:
        tl_lane_id (int): Id of the lane associated with the TL.
        traffic_light_stop_point (MapPoint): Stop point associated with the TL; it is expected to be exactly
            the first or the last point of the lane.
        lanes_dict (Dict{int: MapFeature}): Dictionary of lanes.
        agents_states_dicts (List[Dict{int: List[ObjectState]}]): one dictionary per agent; its first key is
            used as the agent id and maps to the agent's states.
        road_edges_dict (Dict{int: MapFeature}): RoadEdges used for polygon computation.
        road_lines_dict (Dict{int: MapFeature}): RoadLines used for polygon computation.
        plot (bool): Whether to generate a plot or not.
        max_distance_to_stop_point (float): agent states further than this distance from the stop point are
            ignored. Defaults to 10.

    Returns:
        Tuple[List[int], List]: the ids of the agents who crossed the TL and, in the same order, their states.
        Both lists are empty if the stop point is neither the first nor the last point of the lane.
    """
    # Step 1: generate the set of lanes that are associated with the TL using the input and output lanes
    lanes_before_stop_point = []
    lanes_after_stop_point = []
    lanes_ids_before_stop_point = []
    lanes_ids_after_stop_point = []

    # find the lane associated with the TL
    tl_lane = lanes_dict[tl_lane_id]
    tl_polyline = tl_lane.lane.polyline

    # the stop point is the first point of the lane: the entry lanes lead to the TL
    if traffic_light_stop_point.x == tl_polyline[0].x and traffic_light_stop_point.y == tl_polyline[0].y:
        for entry_lane in tl_lane.lane.entry_lanes:
            lanes_before_stop_point.append(lanes_dict[entry_lane])
            lanes_ids_before_stop_point.append(entry_lane)
        lanes_after_stop_point.append(tl_lane)
        lanes_ids_after_stop_point.append(tl_lane.id)

    # the stop point is the last point of the lane: the exit lanes leave the TL
    if traffic_light_stop_point.x == tl_polyline[-1].x and traffic_light_stop_point.y == tl_polyline[-1].y:
        lanes_before_stop_point.append(tl_lane)
        lanes_ids_before_stop_point.append(tl_lane.id)
        for exit_lane in tl_lane.lane.exit_lanes:
            lanes_after_stop_point.append(lanes_dict[exit_lane])
            lanes_ids_after_stop_point.append(exit_lane)

    # neither end of the lane matches: there is nothing to build the before/after polygons from
    if not lanes_before_stop_point and not lanes_after_stop_point:
        print(f"problem: traffic light stop point is neither the first nor the last point of lane {tl_lane_id}")
        return [], []

    # Step 2: Compute the paths associated with the lanes
    lane_polygons_dict_before = {
        lane_id: generate_polygon_for_lane(lane_before, road_edges_dict, road_lines_dict)
        for lane_id, lane_before in zip(lanes_ids_before_stop_point, lanes_before_stop_point)
    }
    lane_polygons_dict_after = {
        lane_id: generate_polygon_for_lane(lane_after, road_edges_dict, road_lines_dict)
        for lane_id, lane_after in zip(lanes_ids_after_stop_point, lanes_after_stop_point)
    }

    # Step 3: if there are no lane before, we fake a single predecessor polygon from the (single) lane after
    if not lanes_before_stop_point:
        fake_predecessor_polygon = generate_fake_predecessor_polygon(lanes_after_stop_point[0], False)
        lane_polygons_dict_before[lanes_ids_after_stop_point[0]] = fake_predecessor_polygon

    # Step 4: if there are no lane after, we fake a single successor polygon from the (single) lane before
    if not lanes_after_stop_point:
        fake_successor_polygon = generate_fake_successor_polygon(lanes_before_stop_point[0], False)
        lane_polygons_dict_after[lanes_ids_before_stop_point[0]] = fake_successor_polygon

    # Step 5: find agents that have states in both a polygon before and a polygon after the stop point
    agents_who_crossed = []
    agents_ids_who_crossed = []
    for agents_states_dict in agents_states_dicts:
        agent_id = next(iter(agents_states_dict))
        agent_states = agents_states_dict[agent_id]

        found_before = False
        found_after = False
        for agent_state in agent_states:
            if agent_state.center_x == 0:
                # a zero x coordinate is treated as an invalid state
                continue

            # we do not consider points where the agent is too far from the light
            distance = compute_2d_distance(traffic_light_stop_point, [agent_state.center_x, agent_state.center_y])
            if distance > max_distance_to_stop_point:
                continue

            agent_position = [(agent_state.center_x, agent_state.center_y)]
            if not found_before:
                found_before = any(polygon.contains_points(agent_position)[0] for polygon in lane_polygons_dict_before.values())
            if not found_after:
                found_after = any(polygon.contains_points(agent_position)[0] for polygon in lane_polygons_dict_after.values())
            if found_before and found_after:
                break

        if found_before and found_after:
            agents_who_crossed.append(agent_states)
            agents_ids_who_crossed.append(agent_id)

    # Plotting
    if plot and agents_who_crossed:
        fig = plt.figure(figsize=(FIGURE_SIZE, FIGURE_SIZE), dpi=FIGURE_DPI)
        ax = fig.add_subplot(111)

        for lane_polygon_before in lane_polygons_dict_before.values():
            patch = patches.PathPatch(lane_polygon_before, facecolor="orange", lw=2, alpha=0.5)
            ax.add_patch(patch)
        for lane_polygon_after in lane_polygons_dict_after.values():
            patch = patches.PathPatch(lane_polygon_after, facecolor="blue", lw=2, alpha=0.5)
            ax.add_patch(patch)

        plot_lanes(lanes_before_stop_point, "red")
        plot_lanes(lanes_after_stop_point, "red")

        plt.scatter(traffic_light_stop_point.x, traffic_light_stop_point.y)

        for agent_who_crossed in agents_who_crossed:
            # filter xs and ys with the same validity test so that both lists stay aligned
            valid_states = [state for state in agent_who_crossed if state.center_x != 0]
            agent_xs = [state.center_x for state in valid_states]
            agent_ys = [state.center_y for state in valid_states]
            plt.scatter(agent_xs, agent_ys, color="k")

        plt.axis("equal")
        plt.show()
        plt.close()

    return agents_ids_who_crossed, agents_who_crossed
    


In [None]:
def _stop_sign_before_after_polygons(associated_lane, stop_point_index, road_edges_dict, road_lines_dict):
    """
    Build the polygons covering the lane before and after the stop point.
    Args:
        associated_lane (MapFeature): lane associated with the stop sign.
        stop_point_index (int): index, in the lane polyline, of the point closest to the stop sign.
        road_edges_dict (Dict{int: RoadEdge}): RoadEdges used for polygon computation.
        road_lines_dict (Dict{int: RoadLine}): RoadLines used for polygon computation.
    Returns:
        (predecessor_polygon, successor_polygon)
    """
    polyline_length = len(associated_lane.lane.polyline)
    if stop_point_index == 0:
        # The stop point is at the beginning of the lane: the whole lane is after it,
        # and a fake polygon stands in for what comes before.
        successor_polygon = generate_polygon_for_lane(associated_lane, road_edges_dict, road_lines_dict)
        predecessor_polygon = generate_fake_predecessor_polygon(associated_lane, False)
    elif stop_point_index == polyline_length - 1:
        # The stop point is at the end of the lane: the whole lane is before it,
        # and a fake polygon stands in for what comes after.
        predecessor_polygon = generate_polygon_for_lane(associated_lane, road_edges_dict, road_lines_dict)
        successor_polygon = generate_fake_successor_polygon(associated_lane, False)
    else:
        # The stop point is in the middle of the lane: split the lane at the stop point.
        predecessor_polygon = generate_polygon_for_lane(associated_lane, road_edges_dict, road_lines_dict, 0, stop_point_index)
        successor_polygon = generate_polygon_for_lane(associated_lane, road_edges_dict, road_lines_dict, stop_point_index, polyline_length)
    return predecessor_polygon, successor_polygon


def _stop_sign_agent_crossed(agent_states, stop_point, predecessor_polygon, successor_polygon, max_distance):
    """
    Tell whether an agent has states in both the predecessor and the successor polygon.
    Only states within `max_distance` of the stop point are considered.
    """
    found_before = False
    found_after = False
    for agent_state in agent_states:
        if agent_state.center_x == 0:
            # states with center_x == 0 are skipped (presumably invalid / padding states)
            continue
        position = [agent_state.center_x, agent_state.center_y]
        # we do not consider points where the agent is too far from the stop sign
        if compute_2d_distance(stop_point, position) > max_distance:
            continue
        if predecessor_polygon.contains_points([position]):
            found_before = True
        if successor_polygon.contains_points([position]):
            found_after = True
        if found_before and found_after:
            return True
    return False


def _plot_stop_sign_speed_profile(agent_states, stop_point):
    """
    Plot the 2d speed of an agent.
    A vertical line marks the state closest to the stop point. The x axis indexes
    only the states with center_x != 0.
    """
    valid_states = [state for state in agent_states if state.center_x != 0]
    agent_positions = [[state.center_x, state.center_y] for state in valid_states]
    _, _, closest_point_index = find_closest_point_in_polyline(stop_point, agent_positions)
    agent_speeds = [np.linalg.norm([state.velocity_x, state.velocity_y]) for state in valid_states]

    plt.figure(figsize=(FIGURE_SIZE / 2, FIGURE_SIZE / 2), dpi=FIGURE_DPI)
    plt.scatter(range(len(agent_speeds)), agent_speeds, color="black", marker="+")
    plt.vlines([closest_point_index], 0, max(agent_speeds))
    plt.xlabel("time step")
    plt.ylabel("speed (m/s)")
    plt.show()
    plt.close()


def _plot_stop_sign_crossing(stop_sign, stop_point, associated_lane, predecessor_polygon, successor_polygon, agents_who_crossed):
    """Plot the stop sign, its lane, the before/after polygons and the crossing agents' positions."""
    fig = plt.figure(figsize=(FIGURE_SIZE, FIGURE_SIZE), dpi=FIGURE_DPI)
    ax = fig.add_subplot(111)

    plt.scatter(stop_point.x, stop_point.y)
    plt.scatter(stop_sign.stop_sign.position.x, stop_sign.stop_sign.position.y, marker="H", color="red", s=150)
    plot_lanes([associated_lane], "red")
    ax.add_patch(patches.PathPatch(predecessor_polygon, facecolor="orange", lw=2, alpha=0.5))
    ax.add_patch(patches.PathPatch(successor_polygon, facecolor="blue", lw=2, alpha=0.5))

    for agent_states in agents_who_crossed:
        # filter x and y with the same condition so both coordinate lists stay aligned
        positions = [(state.center_x, state.center_y) for state in agent_states if state.center_x != 0]
        plt.scatter([x for x, _ in positions], [y for _, y in positions], color="k")

    plt.axis("equal")
    plt.show()
    plt.close()


def find_agents_going_though_stop_sign(stop_sign, agents_states_dicts, road_edges_dict, road_lines_dict, plot=False,
                                       lanes_dict=None, max_distance=10):
    """
    Find all agents that go through a stop sign.
    An agent goes through the stop sign when it is found both in the polygon before and
    in the polygon after the stop point, on any of the lanes the stop sign is associated with.
    Args:
        stop_sign (MapFeature): map feature whose `stop_sign` field (StopSign) is processed.
        agents_states_dicts (List[Dict{int: List[ObjectState]}]): list of single-entry dictionaries
            mapping an agent id to its states.
        road_edges_dict (Dict{int: RoadEdge}): RoadEdges used for polygon computation.
        road_lines_dict (Dict{int: RoadLine}): RoadLines used for polygon computation.
        plot (bool): whether to plot, for each lane with crossing agents, the speed profile of every
            crossing agent and the crossing scene.
        lanes_dict (Dict{int: MapFeature} or None): lanes indexed by id. When None, the notebook-level
            `lanes_dict` global (built in the main loop) is used, as in previous versions of this function.
        max_distance (float): agent positions farther than this from the stop point are ignored
            (presumably meters, like the other constants of this notebook).
    Returns:
        (List[int], List[List[ObjectState]]): ids and states of the agents that crossed the stop sign
        on at least one associated lane. Each agent appears once.
    """
    if lanes_dict is None:
        # backward compatibility: this function used to read the notebook global directly
        lanes_dict = globals()["lanes_dict"]

    all_ids_who_crossed = []
    all_agents_who_crossed = []
    seen_ids = set()

    # we iterate over the multiple lanes a stop sign can be associated with
    for associated_lane_id in stop_sign.stop_sign.lane:
        associated_lane = lanes_dict.get(associated_lane_id)
        if associated_lane is None:
            # the stop sign references a lane that is not among this scenario's lanes
            continue

        # Step 1: find the point of the lane closest to the stop sign
        stop_point, _, stop_point_index = find_closest_point_in_polyline(
            stop_sign.stop_sign.position, associated_lane.lane.polyline)

        # Step 2: generate the before and after polygons
        predecessor_polygon, successor_polygon = _stop_sign_before_after_polygons(
            associated_lane, stop_point_index, road_edges_dict, road_lines_dict)

        # Step 3: find all agents found in both the before and after polygons
        agents_who_crossed = []
        for agents_states_dict in agents_states_dicts:
            agent_id, agent_states = next(iter(agents_states_dict.items()))
            if not _stop_sign_agent_crossed(agent_states, stop_point, predecessor_polygon, successor_polygon, max_distance):
                continue
            agents_who_crossed.append(agent_states)
            if agent_id not in seen_ids:
                seen_ids.add(agent_id)
                all_ids_who_crossed.append(agent_id)
                all_agents_who_crossed.append(agent_states)

        # Step 4: plot speed profiles and the crossing scene (based on the whole list of
        # crossing agents, not only on the last agent examined)
        if plot and agents_who_crossed:
            for agent_who_crossed in agents_who_crossed:
                _plot_stop_sign_speed_profile(agent_who_crossed, stop_point)
            _plot_stop_sign_crossing(stop_sign, stop_point, associated_lane, predecessor_polygon,
                                     successor_polygon, agents_who_crossed)

    return all_ids_who_crossed, all_agents_who_crossed
    
    

# Main loop

In [None]:
# The scenario shard is stored uncompressed (see FILENAME), hence the empty compression type.
dataset = tf.data.TFRecordDataset(filenames=FILENAME, compression_type="")

In [None]:
for scenario_idx, data in enumerate(dataset):

    print(f"Processing scenario: {scenario_idx}")
    scenario = scenario_pb2.Scenario()
    scenario.ParseFromString(bytearray(data.numpy()))

    # Agent tracks grouped by AgentType, in the two layouts used by the other cells:
    #   dict_agents_states:      {AgentType: [[ObjectState, ...], ...]}
    #   dict_dict_agents_states: {AgentType: [{track_id: [ObjectState, ...]}, ...]}
    # Both are filled in a single pass; each layout gets its own list of states.
    dict_agents_states = {agent_type: [] for agent_type in AgentType}
    dict_dict_agents_states = {agent_type: [] for agent_type in AgentType}
    for agent_tracks in scenario.tracks:
        agent_type = AgentType(agent_tracks.object_type)
        dict_agents_states[agent_type].append(list(agent_tracks.states))
        dict_dict_agents_states[agent_type].append({agent_tracks.id: list(agent_tracks.states)})

    # We create an array of evenly spaced colors, in random order, with as many colors as there
    # are agents in the most populated class. plt.get_cmap is used because
    # matplotlib.cm.get_cmap was deprecated in matplotlib 3.7 and removed in 3.9.
    max_agents_one_class = max(len(agents) for agents in dict_agents_states.values())
    cmap = plt.get_cmap("jet", max_agents_one_class)
    COLORS_ARRAY_AGENTS = cmap(range(max_agents_one_class))
    np.random.shuffle(COLORS_ARRAY_AGENTS)

    # setting the static map elements (a feature belongs to a category when its
    # category-specific field is set)
    crosswalks =  [map_feature for map_feature in scenario.map_features if map_feature.crosswalk.polygon]
    lanes =       [map_feature for map_feature in scenario.map_features if map_feature.lane.type]
    road_edges =  [map_feature for map_feature in scenario.map_features if map_feature.road_edge.type]
    road_lines =  [map_feature for map_feature in scenario.map_features if map_feature.road_line.type]
    speed_bumps = [map_feature for map_feature in scenario.map_features if map_feature.speed_bump.polygon]
    stop_signs =  [map_feature for map_feature in scenario.map_features if map_feature.stop_sign.lane]

    # constructing a dictionary for some static map elements
    lanes_dict =      {lane.id: lane for lane in lanes}
    road_edges_dict = {road_edge.id: road_edge for road_edge in road_edges}
    road_lines_dict = {road_line.id: road_line for road_line in road_lines}

    # setting the traffic lights: the first stop point seen for each controlled lane
    traffic_lights_stop_points_dict = {}
    for step in scenario.dynamic_map_states:
        for lane_state in step.lane_states:
            traffic_lights_stop_points_dict.setdefault(lane_state.lane, lane_state.stop_point)
    traffic_light_stop_points = traffic_lights_stop_points_dict.values()

    # Plotting and save an image representing the scenario
    plot_scenario_image(dict_agents_states, crosswalks, lanes, road_edges, road_lines, speed_bumps, stop_signs, scenario_idx, save=True)

    # Generating a video representing the scenario
    generate_video(dict_agents_states, scenario.timestamps_seconds, scenario.dynamic_map_states, crosswalks, lanes, road_edges, road_lines, speed_bumps, stop_signs, scenario_idx)

    # Generating the lane polygons for the first few lanes
    for i, lane in enumerate(lanes[0:5]):
        print(f"lane {i}")
        generate_polygon_for_lane(lane, road_edges_dict, road_lines_dict, plot=True)

    # Finding traffic lights that change from a "go" state to a "caution" state
    find_traffic_lights_transitions(TrafficLightFamily.go, TrafficLightFamily.caution, traffic_lights_stop_points_dict, scenario.dynamic_map_states)

    # Counting traffic lights transitions
    transition_matrix = count_traffic_lights_transitions(traffic_lights_stop_points_dict, scenario.dynamic_map_states)

    # Finding agents that cross traffic lights
    for tl_lane_id, traffic_light_stop_point in traffic_lights_stop_points_dict.items():
        print(f"TL: {tl_lane_id}")
        who_crossed_ids, _ = find_agents_going_though_traffic_light(tl_lane_id, traffic_light_stop_point, lanes_dict, dict_dict_agents_states[AgentType.vehicle], road_edges_dict, road_lines_dict, True)
        print(f"Who crossed: {who_crossed_ids}")

    # Finding agents that cross stop signs
    for stop_sign in stop_signs:
        print(f"Stop sign: {stop_sign.id}")
        find_agents_going_though_stop_sign(stop_sign, dict_dict_agents_states[AgentType.vehicle], road_edges_dict, road_lines_dict, True)

    # only the first scenario of the file is processed
    break

    