In [2]:
from typing import Tuple
import numpy as np
import mediapy
from tqdm import tqdm
import dataclasses
import jax
import jax.numpy as jnp
from waymax import config as _config
from waymax import dataloader
from waymax import datatypes
from waymax import visualization
import cv2

# Upper bound on the number of objects kept per scenario.
max_num_objects = 32

# Start from the stock training config and override the object / roadgraph-point
# limits and the dataset path.
config = dataclasses.replace(
    _config.WOD_1_1_0_TRAINING,
    path="./data/training_tfexample.tfrecord@5",
    max_num_rg_points=30000,
    max_num_objects=max_num_objects,
)
data_iter = dataloader.simulator_state_generator(config=config)

# Pull three states from the generator; `scenario` ends up bound to the third.
for _ in range(3):
    scenario = next(data_iter)

def construct_SDC_route(
    state,
):
    """Construct a SDC route from the logged trajectory.

    This is necessary for the progression metric as WOMD doesn't release their
    routes. The SDC's logged x/y/z are downsampled to every ``stride``-th point.
    The final logged point is appended only when the stride did not already land
    on it, so the route never ends in a duplicated point (a zero-length step).

    Args:
        state: The simulator state.
    Returns:
        The updated simulator state with the SDC route.
    """
    # Select sdc trajectory (keepdims keeps a leading axis of size one).
    sdc_trajectory: datatypes.Trajectory = datatypes.select_by_onehot(
        state.log_trajectory,
        state.object_metadata.is_sdc,
        keepdims=True,
    )

    # Downsample trajectory coordinates by keeping every `stride`-th point.
    stride = 10
    num_points = sdc_trajectory.x.shape[-1]
    # Shapes are static, so this is a plain Python bool and is safe under jit.
    last_included = (num_points - 1) % stride == 0

    def _downsample(values):
        """Keep every stride-th sample and make sure the last sample is present."""
        sampled = values[..., ::stride]
        if last_included:
            return sampled
        return jnp.concatenate([sampled, values[..., -1:]], axis=-1)

    x = _downsample(sdc_trajectory.x)
    y = _downsample(sdc_trajectory.y)
    z = _downsample(sdc_trajectory.z)

    # Planar distance between consecutive route points.
    step_distances = jnp.sqrt(jnp.diff(x, axis=-1) ** 2 + jnp.diff(y, axis=-1) ** 2)

    # Arc length is 0 at the first point and the running sum afterwards.
    arc_lengths = jnp.zeros_like(x)
    arc_lengths = arc_lengths.at[..., 1:].set(jnp.cumsum(step_distances, axis=-1))

    logged_route = datatypes.Paths(
        x=x,
        y=y,
        z=z,
        # Every route point is marked valid.
        valid=jnp.ones_like(x, dtype=jnp.bool_),
        arc_length=arc_lengths,
        # One flag per path: the single logged path is the route.
        on_route=jnp.ones(x.shape[:-1] + (1,), dtype=jnp.bool_),
        # Dummy ID, one per route point so that ids has the same shape as x/y/z.
        ids=jnp.zeros_like(x, dtype=jnp.int32),
    )
    return dataclasses.replace(
        state,
        sdc_paths=logged_route,
    )


def ray_segment_intersection(
    ray_angle: jax.Array, start_points: jax.Array, segment_dirs: jax.Array
) -> jax.Array:
    """
    Distance from the origin to where a ray through the origin meets each segment.

    Every segment is given as a start point plus a direction vector; the segment
    covers ``start + t * dir`` for ``t`` in ``[0, 1]``.

    Args:
        ray_angle: The angle of the ray (in radians).
        start_points: Array of shape (N, 2) for segment start points (x,y).
        segment_dirs: Array of shape (N, 2) for segment directions (dx, dy).

    Returns:
        One distance per segment. 100.0 is returned when the segment is
        (near-)parallel to the ray, when the crossing is outside the segment,
        when the ray-side check fails, or when the crossing is 100.0 or farther.

    NOTE(review): solving ``start + t1 * dir = t2 * (cos a, sin a)`` gives
    ``t2 = -(sx * dy - sy * dx) / det``, which is the opposite sign of
    ``along_ray`` below. As written, a crossing is accepted when it lies in
    direction ``ray_angle + pi``. The behaviour is deliberately left unchanged
    here: the drawing loop later in this cell uses angles shifted by pi relative
    to the casting angles, which lines the picture up with this behaviour. Fix
    both together.
    """
    max_range = 100.0

    cos_a = jnp.cos(ray_angle)
    sin_a = jnp.sin(ray_angle)
    seg_x0, seg_y0 = start_points[:, 0], start_points[:, 1]
    seg_dx, seg_dy = segment_dirs[:, 0], segment_dirs[:, 1]

    # 2-D cross product of segment direction and ray direction; ~0 means parallel.
    cross = seg_dx * sin_a - seg_dy * cos_a
    degenerate = jnp.abs(cross) < 1e-8
    # Substitute a harmless divisor for parallel pairs; they are masked out below.
    safe_cross = jnp.where(degenerate, 1.0, cross)

    # Position of the crossing along the segment and along the ray.
    along_segment = -(seg_x0 * sin_a - seg_y0 * cos_a) / safe_cross
    along_ray = (seg_x0 * seg_dy - seg_y0 * seg_dx) / safe_cross

    # Crossing point and its distance from the origin.
    hit_x = seg_x0 + along_segment * seg_dx
    hit_y = seg_y0 + along_segment * seg_dy
    hit_range = jnp.sqrt(hit_x**2 + hit_y**2)

    is_hit = (
        (along_segment >= 0.0)
        & (along_segment <= 1.0)
        & (along_ray >= 0.0)
        & ~degenerate
        & (hit_range < max_range)
    )
    return jnp.where(is_hit, hit_range, max_range)


def circogram_subroutine(
    i: int,
    initval: Tuple[
        jax.Array, jax.Array, jax.Array, Tuple[jax.Array, jax.Array], jax.Array
    ],
):
    """Loop body for `jax.lax.fori_loop`: fill in ray `i` of the circogram.

    Args:
        i: Index of the ray to evaluate.
        initval: Loop carry ``(circogram, winning_indices, ray_angles,
            (starting_points, dir_xy), candidate_mask)``.

    Returns:
        The carry with the same structure, where entry ``i`` of the circogram
        holds the smallest masked distance and entry ``i`` of the winning
        indices holds the argmin of the masked distances, or -1 when that
        smallest distance is 100.0 or more (no hit).
    """
    ranges, hit_ids, angles, segments, mask = initval
    seg_starts, seg_dirs = segments

    # Segments outside the mask are treated as "no hit" (100.0).
    per_segment = jnp.where(
        mask,
        ray_segment_intersection(angles[i], seg_starts, seg_dirs),
        100.0,
    )

    nearest = jnp.min(per_segment)
    nearest_id = jnp.where(nearest >= 100.0, -1, jnp.argmin(per_segment))

    return (
        ranges.at[i].set(nearest),
        hit_ids.at[i].set(nearest_id),
        angles,
        (seg_starts, seg_dirs),
        mask,
    )


def create_road_circogram(
    observation: datatypes.Observation, num_rays: int
) -> Tuple[jax.Array, jax.Array, jax.Array]:
    """Cast `num_rays` rays against road-edge roadgraph points.

    Only valid roadgraph points whose type is ROAD_EDGE_BOUNDARY,
    ROAD_EDGE_MEDIAN or ROAD_EDGE_UNKNOWN are considered. Each point is used as
    a segment start with (dir_x, dir_y) as the segment direction.

    Args:
        observation: The observation data containing roadgraph information.
        num_rays: The number of rays to cast for the circogram.

    Returns:
        A tuple containing:
            - circogram: Per-ray distance to the nearest candidate road edge
              (100.0 when nothing is hit).
            - ray_radial_speed: Zeros of length `num_rays`.
            - ray_tangential_speed: Zeros of length `num_rays`.
    """
    rg_points = observation.roadgraph_static_points

    is_road_edge = (
        (rg_points.types == datatypes.MapElementIds.ROAD_EDGE_BOUNDARY)
        | (rg_points.types == datatypes.MapElementIds.ROAD_EDGE_MEDIAN)
        | (rg_points.types == datatypes.MapElementIds.ROAD_EDGE_UNKNOWN)
    )
    candidate_mask = rg_points.valid & is_road_edge

    # NOTE(review): stacking is along axis 1; the resulting layout depends on
    # the rank of rg_points.x - confirm against the observation's shape.
    segments = (
        jnp.stack([rg_points.x, rg_points.y], axis=1),
        jnp.stack([rg_points.dir_x, rg_points.dir_y], axis=1),
    )

    # Carry layout expected by circogram_subroutine.
    initial_carry = (
        jnp.full(num_rays, 100.0),  # distances, default = max range
        jnp.full(num_rays, -1, dtype=jnp.int32),  # winning segment indices
        jnp.linspace(0, 2 * jnp.pi, num_rays, endpoint=False),
        segments,
        candidate_mask,
    )
    final_carry = jax.lax.fori_loop(0, num_rays, circogram_subroutine, initial_carry)
    circogram = final_carry[0]

    # Zero speeds are returned so the result has the same three-part form as
    # create_object_circogram.
    ray_radial_speed = jnp.zeros(num_rays)
    ray_tangential_speed = jnp.zeros(num_rays)
    return circogram, ray_radial_speed, ray_tangential_speed


def create_object_circogram(
    observation: datatypes.Observation, num_rays: int
) -> Tuple[jax.Array, jax.Array, jax.Array]:
    """Cast `num_rays` rays against the bounding-box edges of non-ego objects.

    Each object contributes four segments (corner k to corner k+1, wrapping
    around). Segments of objects that are invalid at the indexed timestep, or
    that are flagged as ego, are masked out.

    Args:
        observation: Observation providing `trajectory` (valid, bbox_corners,
            vel_xy) and `is_ego`.
        num_rays: The number of rays to cast for the circogram.

    Returns:
        A tuple containing:
            - circogram: Per-ray distance to the nearest candidate box edge
              (100.0 when nothing is hit).
            - ray_radial_speed: Velocity of the hit object projected onto
              (cos a, sin a) for the ray's angle a; 0.0 without a hit.
            - ray_tangential_speed: Velocity of the hit object projected onto
              (-sin a, cos a); 0.0 without a hit.
    """
    ray_angles = jnp.linspace(0, 2 * jnp.pi, num_rays, endpoint=False)
    trajectory = observation.trajectory

    # One flag per object, repeated once for each of its four edges.
    is_candidate = trajectory.valid[..., 0, :, 0] & ~observation.is_ego[..., 0, :]
    edge_mask = jnp.repeat(is_candidate, 4)

    # Box edges: corner k -> corner (k + 1) % 4, flattened over objects.
    corners = trajectory.bbox_corners[0, :, 0, :, :]
    edge_starts = corners[:, jnp.array([0, 1, 2, 3])]
    edge_ends = corners[:, jnp.array([1, 2, 3, 0])]
    edge_dirs = (edge_ends - edge_starts).reshape(-1, 2)
    edge_starts = edge_starts.reshape(-1, 2)

    initial_carry = (
        jnp.full(num_rays, 100.0),  # distances, default = max range
        jnp.full(num_rays, -1, dtype=jnp.int32),  # winning edge indices
        ray_angles,
        (edge_starts, edge_dirs),
        edge_mask,
    )
    circogram, hit_edges, _, _, _ = jax.lax.fori_loop(
        0, num_rays, circogram_subroutine, initial_carry
    )

    # Edge index -> owning object. Rays without a hit (-1) look up object 0 and
    # are zeroed out again below.
    has_hit = hit_edges >= 0
    owner = jnp.where(has_hit, hit_edges // 4, 0)
    hit_velocity = trajectory.vel_xy[0, :, 0, :][owner]
    vel_x, vel_y = hit_velocity[:, 0], hit_velocity[:, 1]

    # Decompose the velocity along / across each ray direction.
    cos_a = jnp.cos(ray_angles)
    sin_a = jnp.sin(ray_angles)
    radial = vel_x * cos_a + vel_y * sin_a
    tangential = -vel_x * sin_a + vel_y * cos_a

    ray_radial_speed = jnp.where(has_hit, radial, 0.0)
    ray_tangential_speed = jnp.where(has_hit, tangential, 0.0)
    return circogram, ray_radial_speed, ray_tangential_speed



# Roll the logged scenario forward and record, per step, the rendered SDC
# observation and a top-down drawing of the combined circogram.
scenario = construct_SDC_route(scenario)
img = visualization.plot_simulator_state(scenario, use_log_traj=True)
imgs = []
ray_visualization = []
state = scenario

jit_step = jax.jit(datatypes.update_state_by_log, static_argnums=(1,))
jit_observed = jax.jit(datatypes.sdc_observation_from_state, static_argnums=(1,2,3))
jit_create_object_circogram = jax.jit(create_object_circogram, static_argnums=(1))
jit_create_road_circogram = jax.jit(create_road_circogram, static_argnums=(1))

# --- Loop-invariant settings (hoisted so they are built once) ---
num_rays = 64
# NOTE(review): the drawing angles start at -pi while create_*_circogram cast
# ray i at linspace(0, 2*pi)[i], so ray i is drawn half a turn away from its
# nominal angle. This matches the direction in which ray_segment_intersection
# accepts hits in this cell; change the two together.
ray_angles = jnp.linspace(-jnp.pi, jnp.pi, num_rays, endpoint=False)
ray_angles_host = np.asarray(ray_angles)

# Corner k -> corner (k + 1) % 4 gives the four edges of each bounding box.
start_indices = jnp.array([0, 1, 2, 3])
end_indices = jnp.array([1, 2, 3, 0])

# Not used by the drawing below; kept because they are module-level names.
test_points = jnp.array([[20, -10], [-20, 10]])
test_dirs = jnp.array([[0, 20], [0, -20]])

small_size = 500  # Size of the visualization image
scale_factor = 5.0  # Scale factor for distance visualization
x = small_size // 2  # Center of the image
y = small_size // 2  # Center of the image
center = (x, y)  # Center of the image
max_range_radius = int(100.0 * scale_factor)  # 100 meters reference

for _ in tqdm(range(scenario.remaining_timesteps)):
    state = jit_step(state, num_steps=1)

    # Build the observation once and reuse it for the plot and both circograms.
    observation = jit_observed(state, roadgraph_top_k=2500)
    imgs.append(visualization.plot_observation(observation, obj_idx=0))

    object_circogram, _, _ = jit_create_object_circogram(observation, num_rays)
    road_circogram, _, _ = jit_create_road_circogram(observation, num_rays)
    circogram = jnp.minimum(object_circogram, road_circogram)

    # Bounding-box edges of every object, flattened to (num_objects*4, 2).
    obj_corners = observation.trajectory.bbox_corners[0, :, 0, :, :]
    start_points = obj_corners[:, start_indices].reshape(-1, 2)
    end_points = obj_corners[:, end_indices].reshape(-1, 2)

    # Move everything the pixel loops need to the host in one go; indexing JAX
    # arrays element by element inside Python loops is slow.
    ray_dists_host = np.asarray(circogram)
    box_starts_host = np.asarray(start_points)
    box_ends_host = np.asarray(end_points)

    # White canvas for this frame.
    ray_img = np.ones((small_size, small_size, 3), dtype=np.uint8) * 255

    # Draw rays from the image centre, scaled from metres to pixels.
    for angle, ray_dist in zip(ray_angles_host, ray_dists_host):
        scaled_dist = ray_dist * scale_factor
        end_y = int(center[0] + np.sin(angle) * scaled_dist)
        end_x = int(center[1] + np.cos(angle) * scaled_dist)
        cv2.line(ray_img, center, (end_x, end_y), (0, 0, 255), 1)

    # Draw object bounding boxes edge by edge.
    for (sx, sy), (ex, ey) in zip(box_starts_host, box_ends_host):
        cv2.line(
            ray_img,
            (int(center[0] + sx * scale_factor), int(center[1] + sy * scale_factor)),
            (int(center[0] + ex * scale_factor), int(center[1] + ey * scale_factor)),
            (255, 0, 0), 1
        )

    # Add a filled circle for the SDC
    cv2.circle(ray_img, center, 5, (0, 255, 0), -1)

    # Flip the ray image upside down to match conventional coordinate system
    ray_img = cv2.flip(ray_img, 0)

    # Reference circle at maximum range, then distance markers with labels.
    # These are drawn after the flip so the text stays upright.
    cv2.circle(ray_img, center, max_range_radius, (200, 200, 200), 1)
    for dist in [20, 40, 60, 80]:
        radius = int(dist * scale_factor)
        cv2.circle(ray_img, center, radius, (220, 220, 220), 1, cv2.LINE_AA)
        text_pos = (center[0] + 5, center[1] - radius - 5)
        cv2.putText(ray_img, f"{dist}m", text_pos, cv2.FONT_HERSHEY_SIMPLEX,
                    0.4, (100, 100, 100), 1, cv2.LINE_AA)

    # Add to ray_visualization array
    ray_visualization.append(ray_img)



mediapy.show_video(imgs, fps=10)
mediapy.show_video(ray_visualization, fps=10)



100%|██████████| 90/90 [01:19<00:00,  1.14it/s]


0
This browser does not support the video tag.


0
This browser does not support the video tag.


In [None]:
import dataclasses
from typing import Any, Tuple, override

import casadi
import jax
import jax.numpy as jnp
import mediapy
import numpy as np
import waymax.utils.geometry as utils
from dm_env import specs
from tqdm import tqdm
from waymax import agents
from waymax import config as _config
from waymax import dataloader, datatypes, dynamics
from waymax import env as _env
from waymax import  visualization
from waymax.env import typedefs as types
from waymax.metrics.roadgraph import is_offroad
import os
from casadi import external


def construct_SDC_route(
    state,
):
    """Construct a SDC route from the logged trajectory.

    This is necessary for the progression metric as WOMD doesn't release their
    routes. The SDC's logged x/y/z are downsampled to every ``stride``-th point.
    The final logged point is appended only when the stride did not already land
    on it, so the route never ends in a duplicated point (a zero-length step).

    Args:
        state: The simulator state.
    Returns:
        The updated simulator state with the SDC route.
    """
    # Select sdc trajectory (keepdims keeps a leading axis of size one).
    sdc_trajectory: datatypes.Trajectory = datatypes.select_by_onehot(
        state.log_trajectory,
        state.object_metadata.is_sdc,
        keepdims=True,
    )

    # Downsample trajectory coordinates by keeping every `stride`-th point.
    stride = 10
    num_points = sdc_trajectory.x.shape[-1]
    # Shapes are static, so this is a plain Python bool and is safe under jit.
    last_included = (num_points - 1) % stride == 0

    def _downsample(values):
        """Keep every stride-th sample and make sure the last sample is present."""
        sampled = values[..., ::stride]
        if last_included:
            return sampled
        return jnp.concatenate([sampled, values[..., -1:]], axis=-1)

    x = _downsample(sdc_trajectory.x)
    y = _downsample(sdc_trajectory.y)
    z = _downsample(sdc_trajectory.z)

    # Planar distance between consecutive route points.
    step_distances = jnp.sqrt(jnp.diff(x, axis=-1) ** 2 + jnp.diff(y, axis=-1) ** 2)

    # Arc length is 0 at the first point and the running sum afterwards.
    arc_lengths = jnp.zeros_like(x)
    arc_lengths = arc_lengths.at[..., 1:].set(jnp.cumsum(step_distances, axis=-1))

    logged_route = datatypes.Paths(
        x=x,
        y=y,
        z=z,
        # Every route point is marked valid.
        valid=jnp.ones_like(x, dtype=jnp.bool_),
        arc_length=arc_lengths,
        # One flag per path: the single logged path is the route.
        on_route=jnp.ones(x.shape[:-1] + (1,), dtype=jnp.bool_),
        # Dummy ID, one per route point so that ids has the same shape as x/y/z.
        ids=jnp.zeros_like(x, dtype=jnp.int32),
    )
    return dataclasses.replace(
        state,
        sdc_paths=logged_route,
    )


def ray_segment_intersection(
    ray_angle: jax.Array, start_points: jax.Array, segment_dirs: jax.Array
) -> jax.Array:
    """
    Calculate the intersection distances between a ray and line segments.

    The ray starts at the origin and points along ``(cos a, sin a)``. Segment
    ``k`` covers ``start_points[k] + t1 * segment_dirs[k]`` for ``t1`` in
    ``[0, 1]``. Solving ``start + t1 * dir = t2 * ray_dir`` with
    ``det = dir x ray_dir`` (2-D cross product) gives::

        t1 = -(start x ray_dir) / det
        t2 = -(start x dir) / det

    A hit requires ``0 <= t1 <= 1`` (on the segment) and ``t2 >= 0`` (in front
    of the origin along the ray).

    Args:
        ray_angle: The angle of the ray (in radians).
        start_points: Array of shape (N, 2) for segment start points (x,y).
        segment_dirs: Array of shape (N, 2) for segment directions (dx, dy).

    Returns:
        Array of distances from origin to intersections. Returns 100.0 if no
        intersection, including when the segment is (near-)parallel to the ray
        or the intersection is 100.0 or farther away.
    """
    # Calculate ray direction
    ray_dir_x = jnp.cos(ray_angle)
    ray_dir_y = jnp.sin(ray_angle)

    start_x = start_points[:, 0]
    start_y = start_points[:, 1]
    segment_dir_x = segment_dirs[:, 0]
    segment_dir_y = segment_dirs[:, 1]

    # Calculate determinant for intersection test (dir x ray_dir)
    det = segment_dir_x * ray_dir_y - segment_dir_y * ray_dir_x

    # Parallel pairs get a harmless divisor here and are masked out below.
    is_parallel = jnp.abs(det) < 1e-8
    det = jnp.where(is_parallel, 1.0, det)

    # t1: position along the segment; t2: position along the ray.
    t1 = -(start_x * ray_dir_y - start_y * ray_dir_x) / det
    # The leading minus sign matters: without it, hits behind the origin
    # (direction ray_angle + pi) are accepted and hits in front are rejected.
    t2 = -(start_x * segment_dir_y - start_y * segment_dir_x) / det

    # Check if intersection is within segment (0 <= t1 <= 1) and ray (t2 >= 0)
    valid_intersection = (t1 >= 0.0) & (t1 <= 1.0) & (t2 >= 0.0) & ~is_parallel

    # Intersection point and its distance from the origin
    ix = start_x + t1 * segment_dir_x
    iy = start_y + t1 * segment_dir_y
    distances = jnp.sqrt(ix**2 + iy**2)

    # Make sure distances over 100 are not valid
    valid_intersection = valid_intersection & (distances < 100.0)

    # Return distance if valid intersection, otherwise 100.0
    return jnp.where(valid_intersection, distances, 100.0)


def circogram_subroutine(
    i:int, initval: Tuple[jax.Array, jax.Array,Tuple[jax.Array,jax.Array], jax.Array]
):
    """Loop body for `jax.lax.fori_loop`: fill in ray `i` of the circogram.

    Args:
        i: Index of the ray to evaluate.
        initval: Loop carry ``(circogram, ray_angles, (starting_points, dir_xy),
            candidate_mask)``.

    Returns:
        The carry with the same structure, where entry ``i`` of the circogram
        is set to the smallest distance among the masked-in segments (100.0
        when none is hit).
    """
    ranges, angles, segments, mask = initval
    seg_starts, seg_dirs = segments

    # Segments outside the mask are treated as "no hit" (100.0).
    per_segment = jnp.where(
        mask,
        ray_segment_intersection(angles[i], seg_starts, seg_dirs),
        100.0,
    )

    updated_ranges = ranges.at[i].set(jnp.min(per_segment))
    return updated_ranges, angles, (seg_starts, seg_dirs), mask

def create_road_circogram(observation: datatypes.Observation, num_rays: int) -> jax.Array:
    """Cast `num_rays` rays against road-edge roadgraph points.

    Only valid roadgraph points whose type is ROAD_EDGE_BOUNDARY,
    ROAD_EDGE_MEDIAN or ROAD_EDGE_UNKNOWN are considered. Each point is used as
    a segment start with (dir_x, dir_y) as the segment direction.

    Args:
        observation: The observation data containing roadgraph information.
        num_rays: The number of rays to cast, at angles
            ``linspace(0, 2*pi, num_rays, endpoint=False)``.

    Returns:
        Array of length `num_rays` with the distance to the nearest candidate
        road edge per ray (100.0 when nothing is hit).
    """
    rg_points = observation.roadgraph_static_points

    is_road_edge = (
        (rg_points.types == datatypes.MapElementIds.ROAD_EDGE_BOUNDARY)
        | (rg_points.types == datatypes.MapElementIds.ROAD_EDGE_MEDIAN)
        | (rg_points.types == datatypes.MapElementIds.ROAD_EDGE_UNKNOWN)
    )
    candidate_mask = rg_points.valid & is_road_edge

    # NOTE(review): stacking is along axis 1; the resulting layout depends on
    # the rank of rg_points.x - confirm against the observation's shape.
    segments = (
        jnp.stack([rg_points.x, rg_points.y], axis=1),
        jnp.stack([rg_points.dir_x, rg_points.dir_y], axis=1),
    )

    # Carry layout expected by circogram_subroutine.
    initial_carry = (
        jnp.full(num_rays, 100.0),  # distances, default = max range
        jnp.linspace(0, 2 * jnp.pi, num_rays, endpoint=False),
        segments,
        candidate_mask,
    )
    final_carry = jax.lax.fori_loop(0, num_rays, circogram_subroutine, initial_carry)
    return final_carry[0]

def create_object_circogram(
    observation: datatypes.Observation, num_rays: int)-> jax.Array:
    """Cast `num_rays` rays against the bounding-box edges of non-ego objects.

    Each object contributes four segments (corner k to corner k+1, wrapping
    around). Segments of objects that are invalid at the indexed timestep, or
    that are flagged as ego, are masked out.

    Args:
        observation: Observation providing `trajectory` (valid, bbox_corners)
            and `is_ego`.
        num_rays: The number of rays to cast, at angles
            ``linspace(0, 2*pi, num_rays, endpoint=False)``.

    Returns:
        Array of length `num_rays` with the distance to the nearest candidate
        box edge per ray (100.0 when nothing is hit).
    """
    trajectory = observation.trajectory

    # One flag per object, repeated once for each of its four edges.
    is_candidate = trajectory.valid[...,0,:, 0] & ~observation.is_ego[...,0,:]
    edge_mask = jnp.repeat(is_candidate, 4)

    # Box edges: corner k -> corner (k + 1) % 4, flattened over objects.
    corners = trajectory.bbox_corners[0,:, 0,:,:]
    edge_starts = corners[:, jnp.array([0, 1, 2, 3])]
    edge_ends = corners[:, jnp.array([1, 2, 3, 0])]
    edge_dirs = (edge_ends - edge_starts).reshape(-1, 2)
    edge_starts = edge_starts.reshape(-1, 2)

    # Carry layout expected by circogram_subroutine.
    initial_carry = (
        jnp.full(num_rays, 100.0),  # distances, default = max range
        jnp.linspace(0, 2 * jnp.pi, num_rays, endpoint=False),
        (edge_starts, edge_dirs),
        edge_mask,
    )
    final_carry = jax.lax.fori_loop(0, num_rays, circogram_subroutine, initial_carry)
    return final_carry[0]

class WaymaxEnv(_env.PlanningAgentEnvironment):
    """Planning environment whose observation is a flat feature vector.

    The vector produced by `observe` is laid out as::

        [goal_distance, goal_angle, vel_x, vel_y, offroad,
         road_circogram (NUM_RAYS), object_circogram (NUM_RAYS)]

    `observation_spec` lists its bounds in exactly the same order.
    """

    # Number of rays per circogram; shared by observe() and observation_spec()
    # so the two cannot drift apart.
    _NUM_RAYS = 64

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)

    @override
    def observe(self, state: _env.PlanningAgentSimulatorState) -> Any:
        """Computes the flat observation vector for the given simulation state.

        All quantities are taken from the SDC-centred observation returned by
        `datatypes.sdc_observation_from_state`.

        Args:
          state: Current state of the simulator.

        Returns:
          1-D array: goal distance, goal angle, SDC velocity (x, y), offroad
          flag (0.0 / 1.0), road circogram and object circogram, concatenated
          in that order.
        """
        # Get base observation from SDC perspective first
        observation = datatypes.sdc_observation_from_state(state, roadgraph_top_k=1000)

        sdc_trajectory = datatypes.select_by_onehot(
            observation.trajectory,
            observation.is_ego,
            keepdims=True,
        )
        sdc_velocity_xy = sdc_trajectory.vel_xy

        # Create the goal position from the last point in the logged trajectory
        sdc_xy_goal = datatypes.select_by_onehot(
            state.log_trajectory.xy[..., -1, :],
            state.object_metadata.is_sdc,
            keepdims=True,
        )
        # Express the goal in the observation's frame via its pose matrix.
        sdc_xy_goal = utils.transform_points(observation.pose2d.matrix, sdc_xy_goal)[0]

        # Convert the goal position from Cartesian to polar coordinates
        sdc_goal_distance = jnp.sqrt(
            sdc_xy_goal[..., 0] ** 2 + sdc_xy_goal[..., 1] ** 2
        )
        sdc_goal_angle = jnp.arctan2(sdc_xy_goal[..., 1], sdc_xy_goal[..., 0])

        sdc_offroad = is_offroad(sdc_trajectory, observation.roadgraph_static_points)
        sdc_offroad = sdc_offroad.astype(jnp.float32)  # Convert boolean to float32

        road_circogram = create_road_circogram(observation, self._NUM_RAYS)
        object_circogram = create_object_circogram(observation, self._NUM_RAYS)

        # NOTE: keep this order in sync with the bounds in observation_spec().
        obs = jnp.concatenate(
            [
                sdc_goal_distance.flatten(),
                sdc_goal_angle.flatten(),
                sdc_velocity_xy.flatten(),
                sdc_offroad.flatten(),
                road_circogram.flatten(),
                object_circogram.flatten(),
            ],
            axis=-1,
        )
        return obs

    @override
    def observation_spec(self) -> types.Observation:
        """Returns the observation spec of the environment.

        The bounds are listed in the same order as the vector built by
        `observe` (goal distance first, then goal angle).

        Returns:
            Observation spec of the environment.
        """
        # Define dimensions for each observation component
        sdc_goal_distance_dim = 1
        sdc_goal_angle_dim = 1
        sdc_vel_dim = 2
        sdc_offroad_dim = 1
        road_circogram_dim = self._NUM_RAYS
        object_circogram_dim = self._NUM_RAYS

        # Total shape is the sum of all component dimensions
        total_dim = (
            sdc_goal_distance_dim
            + sdc_goal_angle_dim
            + sdc_vel_dim
            + sdc_offroad_dim
            + road_circogram_dim
            + object_circogram_dim
        )

        # Define min/max bounds for each component
        sdc_goal_distance_min = [0]
        sdc_goal_distance_max = [250]
        sdc_goal_angle_min = [-jnp.pi]
        sdc_goal_angle_max = [jnp.pi]

        # First velocity component (vel_xy[..., 0])
        sdc_vel_x_min = [-30]
        sdc_vel_x_max = [30]

        # Second velocity component (vel_xy[..., 1]). The limit is calculated
        # based on a maximum absolute speed of 30 m/s
        sdc_vel_y_min = [-9]
        sdc_vel_y_max = [9]

        sdc_offroad_min = [0]
        sdc_offroad_max = [1]

        # Circogram distances are capped at 100.0 by the ray caster.
        road_circogram_min = [0] * road_circogram_dim
        road_circogram_max = [100] * road_circogram_dim
        object_circogram_min = [0] * object_circogram_dim
        object_circogram_max = [100] * object_circogram_dim

        # Combine all bounds, in the order used by observe()
        min_bounds = jnp.array(
            sdc_goal_distance_min
            + sdc_goal_angle_min
            + sdc_vel_x_min
            + sdc_vel_y_min
            + sdc_offroad_min
            + road_circogram_min
            + object_circogram_min
        )
        max_bounds = jnp.array(
            sdc_goal_distance_max
            + sdc_goal_angle_max
            + sdc_vel_x_max
            + sdc_vel_y_max
            + sdc_offroad_max
            + road_circogram_max
            + object_circogram_max
        )

        return specs.BoundedArray(
            shape=(total_dim,),
            minimum=min_bounds,
            maximum=max_bounds,
            dtype=jnp.float32,
        )


# --- Dataset ---------------------------------------------------------------
# Upper bound on the number of objects kept per scenario; also used below for
# the environment config and for `obj_idx`.
max_num_objects = 32
config = dataclasses.replace(
    _config.WOD_1_1_0_TRAINING,
    max_num_objects=max_num_objects,
    max_num_rg_points=30000,
    path="./data/training_tfexample.tfrecord@5",
)
data_iter = dataloader.simulator_state_generator(config=config)

# --- Environment -----------------------------------------------------------
# Non-SDC objects are driven by IDM sim agents.
sim_agent_config = _config.SimAgentConfig(
    agent_type=_config.SimAgentType.IDM,
    controlled_objects=_config.ObjectType.NON_SDC,
)
metrics_config = _config.MetricsConfig(
    metrics_to_run=("sdc_progression", "offroad")
)
# Reward = +1.0 * sdc_progression - 1.0 * offroad.
reward_config = _config.LinearCombinationRewardConfig(
    rewards={"sdc_progression": 1.0, "offroad": -1.0},
)
env_config = dataclasses.replace(
    _config.EnvironmentConfig(),
    metrics=metrics_config,
    rewards=reward_config,
    max_num_objects=max_num_objects,
    sim_agents=[sim_agent_config],
)
dynamics_model = dynamics.InvertibleBicycleModel(normalize_actions=True)
env = WaymaxEnv(
    dynamics_model=dynamics_model,
    config=env_config,
    sim_agent_actors=[agents.create_sim_agents_from_config(sim_agent_config)],
    sim_agent_params=[{}],
)

# --- Actors ----------------------------------------------------------------
# `obj_idx != -1` is True for every entry of arange(max_num_objects), so both
# actors below mark all objects as controlled.
obj_idx = jnp.arange(max_num_objects)
actor = agents.create_constant_speed_actor(
    speed=5.0,
    dynamics_model=dynamics_model,
    is_controlled_func=lambda state: obj_idx !=-1,
)
expert_actor = agents.create_expert_actor(
    dynamics_model=dynamics_model,
    is_controlled_func=lambda state: obj_idx !=-1,
)

# --- JIT-compiled entry points ---------------------------------------------
# Positional arguments listed in static_argnums are treated as compile-time
# constants by jax.jit.
jit_step =jax.jit(env.step)
jit_select = jax.jit(datatypes.select_by_onehot, static_argnums=(2))
jit_select_action = jax.jit(expert_actor.select_action)
jit_observe = jax.jit(env.observe)
jit_observe_from_state = jax.jit(datatypes.sdc_observation_from_state)
jit_reward = jax.jit(env.reward)
jit_construct_SDC_route = jax.jit(construct_SDC_route)
jit_reset = jax.jit(env.reset)
jit_transform_points = jax.jit(utils.transform_points)
jit_create_road_circogram = jax.jit(create_road_circogram, static_argnums=(1))

def create_mpc_solver() -> casadi.Function:
    """
    Creates a CasADi MPC solver for vehicle control and tries to compile it to C.

    The optimisation problem uses a kinematic model over a fixed horizon with
    normalised controls in [-1, 1], a terminal distance-to-goal cost, a control
    regularisation term, and box constraints derived from four circogram rays.

    The returned function takes ``params`` (6 values: start x, y, yaw, speed and
    target x, y) and ``circogram`` (64 ray distances) and returns the first
    acceleration and steering command of the optimal sequence.

    C code generation and compilation are best effort: if any step fails, the
    error is printed and the uncompiled CasADi function is returned instead.

    Returns:
        A CasADi function that takes state and goal parameters and returns
        optimal control actions.
    """
    # Local import: only this function needs it.
    import subprocess

    MAX_ACCEL = 6.0
    MAX_STEERING = 0.3
    MAX_SPEED = 20.0

    # Problem dimensions
    N = 20  # Prediction horizon
    dt = 0.1  # Time step
    NUM_RAYS = 64  # Length of the circogram parameter

    # Create CasADi optimization variables
    opti = casadi.Opti()

    # State variables over the horizon: x, y, yaw, speed
    x = opti.variable(N+1)
    y = opti.variable(N+1)
    yaw = opti.variable(N+1)
    speed = opti.variable(N+1)

    # Control variables over the horizon: acceleration, steering (both normalised)
    accel = opti.variable(N)
    steering = opti.variable(N)

    # Parameters: initial state and target position
    params = opti.parameter(6)  # [start_x, start_y, start_yaw, start_speed, target_x, target_y]

    # Initial state constraints
    opti.subject_to(x[0] == params[0])
    opti.subject_to(y[0] == params[1])
    opti.subject_to(yaw[0] == params[2])
    opti.subject_to(speed[0] == params[3])

    # Target position
    target_x = params[4]
    target_y = params[5]

    # System dynamics: forward-Euler steps; the yaw rate is speed times the
    # scaled steering command.
    for i in range(N):
        opti.subject_to(x[i+1] == x[i] + dt * speed[i] * casadi.cos(yaw[i]))
        opti.subject_to(y[i+1] == y[i] + dt * speed[i] * casadi.sin(yaw[i]))
        opti.subject_to(yaw[i+1] == yaw[i] + dt * speed[i] * steering[i]*MAX_STEERING)
        opti.subject_to(speed[i+1] == speed[i] + dt * accel[i]*MAX_ACCEL)

        # Control constraints
        opti.subject_to(opti.bounded(-1, accel[i], 1))
        opti.subject_to(opti.bounded(-1, steering[i], 1))

        # State constraints
        opti.subject_to(opti.bounded(0.0, speed[i+1], MAX_SPEED))

    # Add circogram parameter for obstacle distances
    circogram = opti.parameter(NUM_RAYS)
    safety_margin = 1.0

    # Rays are evenly spaced starting at angle 0, so the four axis-aligned rays
    # sit a quarter of the ray count apart. get_action passes yaw = 0, which
    # makes +x the vehicle's heading.
    quarter = NUM_RAYS // 4
    idx_pos_x = 0            # angle   0 deg, +x
    idx_pos_y = quarter      # angle  90 deg, +y
    idx_neg_x = 2 * quarter  # angle 180 deg, -x
    idx_neg_y = 3 * quarter  # angle 270 deg, -y

    # Keep every predicted position inside the box spanned by those four rays.
    for i in range(1, N+1):
        opti.subject_to(x[i] - x[0] <= circogram[idx_pos_x] - safety_margin)
        opti.subject_to(y[i] - y[0] <= circogram[idx_pos_y] - safety_margin)
        opti.subject_to(x[0] - x[i] <= circogram[idx_neg_x] - safety_margin)
        opti.subject_to(y[0] - y[i] <= circogram[idx_neg_y] - safety_margin)

    # Simplified objective function
    distance_to_goal = (x[N] - target_x)**2 + (y[N] - target_y)**2  # Terminal cost
    control_effort = casadi.sum1(accel**2 + 5.0 * steering**2)  # Control regularization

    opti.minimize(10*distance_to_goal + control_effort)

    # Set up solver options
    p_opts = {"expand": True, "print_time": 0}
    s_opts = {
        "max_iter": 100,
        "print_level": 0,
        "warm_start_init_point": "yes",
        "acceptable_tol": 1e-2,
        "acceptable_obj_change_tol": 1e-2,
    }
    opti.solver('ipopt', p_opts, s_opts)

    # Create a CasADi function
    mpc_fn = opti.to_function('mpc_solver', [params, circogram], [accel[0], steering[0]],
                              ['params', 'circogram'], ['optimal_accel', 'optimal_steering'])

    # Best-effort C compilation; on any failure the CasADi function above is kept.
    try:
        c_code_name = "mpc_solver"
        c_file_name = c_code_name + ".c"
        # Absolute path so the library is loaded from the file just built rather
        # than looked up by bare name on the dynamic loader's search path.
        so_file_name = os.path.abspath(c_code_name + ".so")

        print(f"Generating C code at: {c_file_name}")

        code_gen = casadi.CodeGenerator(c_file_name, {"with_header": True, "with_mem": True})
        code_gen.add(mpc_fn)
        code_gen.generate()

        # Check if file was created successfully
        if os.path.exists(c_file_name):
            print(f"C file created successfully at {c_file_name}")
            casadi_include = os.path.join(os.path.dirname(casadi.__file__), 'include')
            compile_command = [
                "gcc", "-fPIC", "-shared", "-Ofast", "-march=native",
                f"-I{casadi_include}", c_file_name, "-o", so_file_name,
                "-lipopt", "-ldl", "-lm",
            ]

            print(f"Running: {' '.join(compile_command)}")
            # check=True raises if gcc fails, so a stale library from an earlier
            # run is never loaded after a failed build.
            subprocess.run(compile_command, check=True)

            # Load the compiled function
            mpc_fn = external('mpc_solver', so_file_name)
            print("Successfully compiled MPC solver to C.")
        else:
            print(f"Failed to create C file at {c_file_name}")
    except Exception as e:
        print(f"C compilation failed, using normal CasADi function: {str(e)}...")

    return mpc_fn
    


# Build (and, if possible, compile) the MPC solver once at module level;
# get_action calls this same object on every step.
compiled_mpc_solver = create_mpc_solver()

def get_action(state: datatypes.SimulatorState) -> datatypes.Action:
    """Pick the ego vehicle's next action by running the MPC solver once.

    The solver receives a parameter vector made of a start pose fixed at
    (0, 0) with zero yaw, the ego's current speed, and a goal point taken from
    the last step of the ego's logged trajectory after it is passed through
    the observation's pose matrix. It also receives a road circogram requested
    with 64 rays. If building the inputs, solving, or reading the outputs
    raises, the failure is printed and a zero action is used instead.

    Args:
        state: Current simulator state.

    Returns:
        An Action whose data is ``[acceleration, steering]`` and whose valid
        flag is True.
    """
    ego_obs = jit_observe_from_state(state)

    # Current ego velocity, read from the first two entries of the flattened
    # velocity array of the ego's observed trajectory.
    ego_traj = jit_select(
        ego_obs.trajectory,
        ego_obs.is_ego,
        keepdims=True,
    )
    vel_flat = ego_traj.vel_xy
    # Goal: final logged position of the SDC.
    logged_goal_xy = jit_select(
        state.log_trajectory.xy[..., -1, :],
        state.object_metadata.is_sdc,
        keepdims=True,
    )
    # NOTE(review): presumably this maps the goal into the ego-centred frame
    # used by the observation -- confirm against jit_transform_points.
    goal_xy = jit_transform_points(ego_obs.pose2d.matrix, logged_goal_xy)[0]

    # The start pose is pinned to the origin with zero heading.
    origin_x, origin_y, origin_yaw = 0.0, 0.0, 0.0
    vel_flat = vel_flat.flatten()
    vel_x, vel_y = float(vel_flat[0]), float(vel_flat[1])
    speed = np.sqrt(vel_x**2 + vel_y**2)

    goal_x, goal_y = float(goal_xy[0]), float(goal_xy[1])

    ray_count = 64
    road_rays = jit_create_road_circogram(ego_obs, ray_count)

    try:
        solver_params = casadi.DM(
            [origin_x, origin_y, origin_yaw, speed, goal_x, goal_y]
        )
        solver_rays = casadi.DM(road_rays)

        solution = compiled_mpc_solver(solver_params, solver_rays)

        # Outputs are converted to plain Python floats before use.
        optimal_accel, optimal_steering = (
            float(solution[0]),
            float(solution[1]),
        )
    except Exception as err:
        print(f"MPC solver failed: {str(err)[:100]}...")
        # Best-effort fallback: neither accelerate nor steer.
        optimal_accel, optimal_steering = 0.0, 0.0

    return datatypes.Action(
        data=jnp.array([optimal_accel, optimal_steering]),
        valid=jnp.array([True]),
    )

# Roll out the MPC controller on several scenarios and render every visited
# simulator state as one video.

# Scenarios pulled from the iterator and discarded before the evaluated ones.
NUM_SCENARIOS_TO_SKIP = 28
# Scenarios that are actually simulated.
NUM_SCENARIOS = 13
# Simulation steps taken per scenario.
NUM_STEPS_PER_SCENARIO = 80

states = []

# Advance the data iterator past the scenarios that are not evaluated.
for _ in range(NUM_SCENARIOS_TO_SKIP):
    scenario = next(data_iter)

for _scenario_idx in range(NUM_SCENARIOS):
    scenario = next(data_iter)
    state = jit_reset(scenario)
    state = jit_construct_SDC_route(state)
    states.append(state)
    for _step_idx in tqdm(range(NUM_STEPS_PER_SCENARIO)):
        action = get_action(state)
        state = jit_step(state, action)
        # NOTE(review): the reward is evaluated but not consumed in this cell.
        reward = jit_reward(state, action)
        states.append(state)

# Render each collected state (simulated, not logged, trajectories).
imgs = []
for state in states:
    img = visualization.plot_simulator_state(state, use_log_traj=False)
    imgs.append(img)

mediapy.show_video(imgs, fps=10)

Generating C code at: mpc_solver.c
C file created successfully at mpc_solver.c
Running: gcc -fPIC -shared -Ofast -march=native -I/home/poggle/Documents/waymax-auto-planner/.venv/lib64/python3.12/site-packages/casadi/include mpc_solver.c -o mpc_solver.so -lipopt -ldl -lm
Successfully compiled MPC solver to C.


100%|██████████| 80/80 [00:06<00:00, 12.79it/s]
100%|██████████| 80/80 [00:02<00:00, 27.22it/s]
100%|██████████| 80/80 [00:03<00:00, 25.32it/s]
100%|██████████| 80/80 [00:03<00:00, 25.59it/s]
100%|██████████| 80/80 [00:03<00:00, 25.43it/s]
100%|██████████| 80/80 [00:05<00:00, 13.74it/s]
100%|██████████| 80/80 [00:02<00:00, 27.10it/s]
100%|██████████| 80/80 [00:03<00:00, 26.05it/s]
100%|██████████| 80/80 [00:03<00:00, 23.17it/s]
100%|██████████| 80/80 [00:03<00:00, 24.84it/s]
100%|██████████| 80/80 [00:03<00:00, 23.29it/s]
100%|██████████| 80/80 [00:02<00:00, 27.00it/s]
100%|██████████| 80/80 [00:03<00:00, 26.63it/s]


0
This browser does not support the video tag.
