In [None]:
from typing import Tuple
import numpy as np
import mediapy
from tqdm import tqdm
import dataclasses
import jax
import jax.numpy as jnp
from waymax import config as _config
from waymax import dataloader
from waymax import datatypes
from waymax import visualization
import cv2

# Load WOMD v1.1 training scenarios from the local sharded TFRecord
# ("@5" shard suffix), keeping up to 32 objects per scenario.
max_num_objects = 32
config = dataclasses.replace(
    _config.WOD_1_1_0_TRAINING,
    path="./data/training_tfexample.tfrecord@5",
    max_num_objects=max_num_objects,
    max_num_rg_points=30000,
)
data_iter = dataloader.simulator_state_generator(config=config)

# Use the 10th scenario from the iterator: the first nine are discarded.
for _skipped in range(10):
    scenario = next(data_iter)

def construct_SDC_route(
    state,
):
    """Construct an SDC route from the logged SDC trajectory.

    This is necessary for the progression metric, as WOMD doesn't release
    the SDC's routes. The SDC's logged x/y/z are downsampled to every
    ``stride``-th timestep (plus the final timestep when the stride skips
    it) and packed into a single ``datatypes.Paths`` with cumulative 2D arc
    lengths.

    Args:
        state: The simulator state.

    Returns:
        A copy of ``state`` whose ``sdc_paths`` is the logged route.

    NOTE(review): every logged timestep is used regardless of
    ``log_trajectory.valid``; invalid steps, if any, become route points.
    """
    stride = 10  # keep every 10th logged point

    # Select the SDC's logged trajectory. keepdims=True keeps a size-1 axis,
    # presumably giving (1, num_timesteps) for an unbatched state.
    sdc_trajectory: datatypes.Trajectory = datatypes.select_by_onehot(
        state.log_trajectory,
        state.object_metadata.is_sdc,
        keepdims=True,
    )

    # The shape is static, so this is a plain Python bool even under jax.jit.
    num_points = sdc_trajectory.x.shape[-1]
    last_included = (num_points - 1) % stride == 0

    def _downsample(values):
        # Append the final point only if the stride did not already pick it,
        # otherwise the route would end in a zero-length segment.
        picked = values[..., ::stride]
        if not last_included:
            picked = jnp.concatenate([picked, values[..., -1:]], axis=-1)
        return picked

    x = _downsample(sdc_trajectory.x)
    y = _downsample(sdc_trajectory.y)
    z = _downsample(sdc_trajectory.z)

    # Cumulative planar distance along the route, starting at 0.
    step_distances = jnp.sqrt(jnp.diff(x, axis=-1) ** 2 + jnp.diff(y, axis=-1) ** 2)
    arc_lengths = jnp.zeros_like(x)
    arc_lengths = arc_lengths.at[..., 1:].set(jnp.cumsum(step_distances, axis=-1))

    logged_route = datatypes.Paths(
        x=x,
        y=y,
        z=z,
        valid=jnp.ones(x.shape, dtype=bool),
        arc_length=arc_lengths,
        on_route=jnp.ones(x.shape[:-1] + (1,), dtype=bool),
        # Dummy id per point; same shape as x like the other per-point fields.
        ids=jnp.zeros(x.shape, dtype=jnp.int32),
    )
    return dataclasses.replace(
        state,
        sdc_paths=logged_route,
    )

# Attach a route built from the SDC's log, render the initial frame and
# set up the accumulators the rollout loop below appends to.
scenario = construct_SDC_route(scenario)
img = visualization.plot_simulator_state(scenario, use_log_traj=True)
state = scenario
imgs, ray_visualization = [], []


def ray_segment_intersection(
    ray_angle: jax.Array,
    start_points: jax.Array,
    segment_dirs: jax.Array,
    max_distance: float = 100.0,
) -> jax.Array:
    """
    Calculate the intersection distances between a ray and line segments.

    The ray starts at the origin with unit direction
    ``(sin(ray_angle), cos(ray_angle))``, i.e. the angle is measured from the
    +y axis towards +x. Segment k covers ``start + t * dir`` for t in [0, 1].

    Args:
        ray_angle: The angle of the ray (in radians).
        start_points: Segment start points; ``[:, 0]`` is x and ``[:, 1]`` is y
            (e.g. shape (N, 2)). All math is elementwise on those slices.
        segment_dirs: Segment direction vectors laid out like ``start_points``;
            the segment ends at ``start + dir``.
        max_distance: Returned for misses, for segments parallel to the ray,
            and for hits at or beyond this distance.

    Returns:
        Distance from the origin to each intersection, or ``max_distance``.
    """
    # Unit ray direction d.
    ray_dir_x = jnp.sin(ray_angle)
    ray_dir_y = jnp.cos(ray_angle)

    px = start_points[:, 0]
    py = start_points[:, 1]
    sx = segment_dirs[:, 0]
    sy = segment_dirs[:, 1]

    # Solve p + t_seg * s == t_ray * d. With cross(a, b) = a.x*b.y - a.y*b.x,
    # crossing both sides with d (resp. s) gives
    #   t_seg = cross(p, d) / cross(d, s),  t_ray = cross(p, s) / cross(d, s).
    denom = ray_dir_x * sy - ray_dir_y * sx
    is_parallel = jnp.abs(denom) < 1e-8
    denom = jnp.where(is_parallel, 1.0, denom)  # Avoid division by zero

    t_seg = (px * ray_dir_y - py * ray_dir_x) / denom
    t_ray = (px * sy - py * sx) / denom

    # d is a unit vector, so the ray parameter is the distance from the origin.
    valid_intersection = (
        ~is_parallel
        & (t_seg >= 0.0)
        & (t_seg <= 1.0)
        & (t_ray >= 0.0)
        & (t_ray < max_distance)
    )
    return jnp.where(valid_intersection, t_ray, max_distance)


def find_closest_distance(
    i, initval: Tuple[jax.Array, datatypes.RoadgraphPoints, jax.Array]
):
    """``jax.lax.fori_loop`` body that fills slot ``i`` of the circogram.

    Ray ``ray_angles[i]`` is cast against every valid road-edge roadgraph
    point (a segment from (x, y) along (dir_x, dir_y)). The nearest hit
    distance, or 100.0 for no hit, is written to ``circogram[i]``. If there
    is no valid road-edge point at all, the slot keeps its previous value.

    Returns:
        The updated ``(circogram, rg_points, ray_angles)`` loop carry.
    """
    circogram, rg_points, ray_angles = initval

    # Valid points whose type is any of the road-edge kinds.
    is_road_edge = rg_points.types == datatypes.MapElementIds.ROAD_EDGE_BOUNDARY
    for edge_type in (
        datatypes.MapElementIds.ROAD_EDGE_MEDIAN,
        datatypes.MapElementIds.ROAD_EDGE_UNKNOWN,
    ):
        is_road_edge = is_road_edge | (rg_points.types == edge_type)
    candidates = rg_points.valid & is_road_edge

    segment_starts = jnp.stack([rg_points.x, rg_points.y], axis=1)
    segment_dirs = jnp.stack([rg_points.dir_x, rg_points.dir_y], axis=1)
    hits = ray_segment_intersection(ray_angles[i], segment_starts, segment_dirs)

    nearest = jnp.where(candidates, hits, 100.0).min()
    new_value = jnp.where(candidates.any(), nearest, circogram[i])
    return circogram.at[i].set(new_value), rg_points, ray_angles


def create_circogram(observation: datatypes.Observation, num_rays: int) -> jax.Array:
    """Cast ``num_rays`` evenly spaced rays against the observation's road edges.

    Ray ``k`` has angle ``2 * pi * k / num_rays``. Each entry starts at 100.0
    and is filled by ``find_closest_distance`` inside a ``fori_loop``.

    Args:
        observation: Observation; only ``roadgraph_static_points`` is read.
        num_rays: Number of rays (must be static under ``jax.jit``).

    Returns:
        Array of shape (num_rays,) with one distance per ray.
    """
    angles = jnp.linspace(0, 2 * jnp.pi, num_rays, endpoint=False)
    carry = (
        jnp.full(num_rays, 100.0),  # Default max distance
        observation.roadgraph_static_points,
        angles,
    )
    circogram, *_ = jax.lax.fori_loop(0, num_rays, find_closest_distance, carry)
    return circogram



jit_step = jax.jit(datatypes.update_state_by_log, static_argnums=(1,))
jit_observed = jax.jit(datatypes.sdc_observation_from_state, static_argnums=(1, 2))
jit_create_circogram = jax.jit(create_circogram, static_argnums=(1,))

num_rays = 64
# circogram[i] is the hit distance along ray i, so these angles must be the
# ones create_circogram() casts (0 .. 2*pi, endpoint excluded). The drawing
# below uses the same (sin, cos) direction convention as
# ray_segment_intersection().
ray_angles = np.linspace(0.0, 2.0 * np.pi, num_rays, endpoint=False)

small_size = 500  # Size of the visualization image (pixels)
max_range = 100.0  # Distance create_circogram() reports when nothing is hit
# Pixels per metre, chosen so the full max_range fits on the canvas
# (a fixed 5.0 pushed everything beyond 50 m off-image).
scale_factor = (small_size / 2) / max_range


def _render_circogram(distances):
    """Draw one circogram (per-ray distances, in metres) as an RGB image."""
    center = (small_size // 2, small_size // 2)
    ray_img = np.full((small_size, small_size, 3), 255, dtype=np.uint8)  # White background

    for angle, dist in zip(ray_angles, distances):
        scaled_dist = dist * scale_factor
        end_point = (
            int(center[0] + np.sin(angle) * scaled_dist),
            int(center[1] + np.cos(angle) * scaled_dist),
        )
        cv2.line(ray_img, center, end_point, (0, 0, 255), 1)  # Red lines

    cv2.circle(ray_img, center, 5, (0, 255, 0), -1)  # Green circle for SDC

    # Image rows grow downwards; flip vertically so +y points up.
    ray_img = cv2.flip(ray_img, 0)

    # Reference circle at maximum range plus labelled distance rings.
    cv2.circle(ray_img, center, int(max_range * scale_factor), (200, 200, 200), 1)
    for ring in (20, 40, 60, 80):
        radius = int(ring * scale_factor)
        cv2.circle(ray_img, center, radius, (220, 220, 220), 1, cv2.LINE_AA)
        text_pos = (center[0] + 5, center[1] - radius - 5)
        cv2.putText(ray_img, f"{ring}m", text_pos, cv2.FONT_HERSHEY_SIMPLEX,
                    0.4, (100, 100, 100), 1, cv2.LINE_AA)
    return ray_img


for _ in tqdm(range(scenario.remaining_timesteps)):
    state = jit_step(state, num_steps=1)
    # Compute the observation once and reuse it for both visualizations.
    observation = jit_observed(state, roadgraph_top_k=2000)
    imgs.append(visualization.plot_observation(observation, obj_idx=0))
    circogram = jit_create_circogram(observation, num_rays)
    # One device-to-host copy instead of num_rays scalar reads.
    ray_visualization.append(_render_circogram(np.asarray(circogram)))

mediapy.show_video(imgs, fps=10)
mediapy.show_video(ray_visualization, fps=10)



In [3]:
import dataclasses
from collections import defaultdict
from typing import Any, Tuple, override

import cv2
import jax
import jax.numpy as jnp
import mediapy
import numpy as np
import waymax.utils.geometry as utils
import do_mpc
import casadi
from dm_env import specs
from matplotlib import pyplot as plt
from scipy.spatial import KDTree
from tqdm import tqdm
from waymax import agents
from waymax import config as _config
from waymax import dataloader, datatypes, dynamics
from waymax import env as _env
from waymax import metrics, visualization
from waymax.env import typedefs as types
from waymax.metrics.roadgraph import is_offroad


def construct_SDC_route(
    state: _env.PlanningAgentSimulatorState,
) -> _env.PlanningAgentSimulatorState:
    """Construct an SDC route from the logged SDC trajectory.

    This is necessary for the progression metric, as WOMD doesn't release
    the SDC's routes. The SDC's logged x/y/z are downsampled to every
    ``stride``-th timestep (plus the final timestep when the stride skips
    it) and packed into a single ``datatypes.Paths`` with cumulative 2D arc
    lengths.

    Args:
        state: The simulator state.

    Returns:
        A copy of ``state`` whose ``sdc_paths`` is the logged route.

    NOTE(review): every logged timestep is used regardless of
    ``log_trajectory.valid``; invalid steps, if any, become route points.
    """
    stride = 5  # keep every 5th logged point

    # Select the SDC's logged trajectory. keepdims=True keeps a size-1 axis,
    # presumably giving (1, num_timesteps) for an unbatched state.
    sdc_trajectory: datatypes.Trajectory = datatypes.select_by_onehot(
        state.log_trajectory,
        state.object_metadata.is_sdc,
        keepdims=True,
    )

    # The shape is static, so this is a plain Python bool even under jax.jit.
    num_points = sdc_trajectory.x.shape[-1]
    last_included = (num_points - 1) % stride == 0

    def _downsample(values):
        # Append the final point only if the stride did not already pick it,
        # otherwise the route would end in a zero-length segment.
        picked = values[..., ::stride]
        if not last_included:
            picked = jnp.concatenate([picked, values[..., -1:]], axis=-1)
        return picked

    x = _downsample(sdc_trajectory.x)
    y = _downsample(sdc_trajectory.y)
    z = _downsample(sdc_trajectory.z)

    # Cumulative planar distance along the route, starting at 0.
    step_distances = jnp.sqrt(jnp.diff(x, axis=-1) ** 2 + jnp.diff(y, axis=-1) ** 2)
    arc_lengths = jnp.zeros_like(x)
    arc_lengths = arc_lengths.at[..., 1:].set(jnp.cumsum(step_distances, axis=-1))

    logged_route = datatypes.Paths(
        x=x,
        y=y,
        z=z,
        valid=jnp.ones(x.shape, dtype=bool),
        arc_length=arc_lengths,
        on_route=jnp.ones(x.shape[:-1] + (1,), dtype=bool),
        # Dummy id per point; same shape as x like the other per-point fields.
        ids=jnp.zeros(x.shape, dtype=jnp.int32),
    )
    return dataclasses.replace(
        state,
        sdc_paths=logged_route,
    )


def ray_segment_intersection(
    ray_angle: jax.Array,
    start_points: jax.Array,
    segment_dirs: jax.Array,
    max_distance: float = 100.0,
) -> jax.Array:
    """
    Calculate the intersection distances between a ray and line segments.

    The ray starts at the origin with unit direction
    ``(sin(ray_angle), cos(ray_angle))``, i.e. the angle is measured from the
    +y axis towards +x. Segment k covers ``start + t * dir`` for t in [0, 1].

    Args:
        ray_angle: The angle of the ray (in radians).
        start_points: Segment start points; ``[:, 0]`` is x and ``[:, 1]`` is y
            (e.g. shape (N, 2)). All math is elementwise on those slices.
        segment_dirs: Segment direction vectors laid out like ``start_points``;
            the segment ends at ``start + dir``.
        max_distance: Returned for misses, for segments parallel to the ray,
            and for hits at or beyond this distance.

    Returns:
        Distance from the origin to each intersection, or ``max_distance``.
    """
    # Unit ray direction d.
    ray_dir_x = jnp.sin(ray_angle)
    ray_dir_y = jnp.cos(ray_angle)

    px = start_points[:, 0]
    py = start_points[:, 1]
    sx = segment_dirs[:, 0]
    sy = segment_dirs[:, 1]

    # Solve p + t_seg * s == t_ray * d. With cross(a, b) = a.x*b.y - a.y*b.x,
    # crossing both sides with d (resp. s) gives
    #   t_seg = cross(p, d) / cross(d, s),  t_ray = cross(p, s) / cross(d, s).
    denom = ray_dir_x * sy - ray_dir_y * sx
    is_parallel = jnp.abs(denom) < 1e-8
    denom = jnp.where(is_parallel, 1.0, denom)  # Avoid division by zero

    t_seg = (px * ray_dir_y - py * ray_dir_x) / denom
    t_ray = (px * sy - py * sx) / denom

    # d is a unit vector, so the ray parameter is the distance from the origin.
    valid_intersection = (
        ~is_parallel
        & (t_seg >= 0.0)
        & (t_seg <= 1.0)
        & (t_ray >= 0.0)
        & (t_ray < max_distance)
    )
    return jnp.where(valid_intersection, t_ray, max_distance)


def find_closest_distance(
    i, initval: Tuple[jax.Array, datatypes.RoadgraphPoints, jax.Array]
):
    """``jax.lax.fori_loop`` body that fills slot ``i`` of the circogram.

    Ray ``ray_angles[i]`` is cast against every valid road-edge roadgraph
    point (a segment from (x, y) along (dir_x, dir_y)). The nearest hit
    distance, or 100.0 for no hit, is written to ``circogram[i]``. If there
    is no valid road-edge point at all, the slot keeps its previous value.

    Returns:
        The updated ``(circogram, rg_points, ray_angles)`` loop carry.
    """
    circogram, rg_points, ray_angles = initval

    # Valid points whose type is any of the road-edge kinds.
    is_road_edge = rg_points.types == datatypes.MapElementIds.ROAD_EDGE_BOUNDARY
    for edge_type in (
        datatypes.MapElementIds.ROAD_EDGE_MEDIAN,
        datatypes.MapElementIds.ROAD_EDGE_UNKNOWN,
    ):
        is_road_edge = is_road_edge | (rg_points.types == edge_type)
    candidates = rg_points.valid & is_road_edge

    segment_starts = jnp.stack([rg_points.x, rg_points.y], axis=1)
    segment_dirs = jnp.stack([rg_points.dir_x, rg_points.dir_y], axis=1)
    hits = ray_segment_intersection(ray_angles[i], segment_starts, segment_dirs)

    nearest = jnp.where(candidates, hits, 100.0).min()
    new_value = jnp.where(candidates.any(), nearest, circogram[i])
    return circogram.at[i].set(new_value), rg_points, ray_angles


def create_circogram(observation: datatypes.Observation, num_rays: int) -> jax.Array:
    """Cast ``num_rays`` evenly spaced rays against the observation's road edges.

    Ray ``k`` has angle ``2 * pi * k / num_rays``. Each entry starts at 100.0
    and is filled by ``find_closest_distance`` inside a ``fori_loop``.

    Args:
        observation: Observation; only ``roadgraph_static_points`` is read.
        num_rays: Number of rays (must be static under ``jax.jit``).

    Returns:
        Array of shape (num_rays,) with one distance per ray.
    """
    angles = jnp.linspace(0, 2 * jnp.pi, num_rays, endpoint=False)
    carry = (
        jnp.full(num_rays, 100.0),  # Default max distance
        observation.roadgraph_static_points,
        angles,
    )
    circogram, *_ = jax.lax.fori_loop(0, num_rays, find_closest_distance, carry)
    return circogram


class WaymaxEnv(_env.PlanningAgentEnvironment):
    """Planning environment that observes a flat feature vector for the SDC.

    ``observe`` concatenates, in order: goal distance (1), goal angle (1),
    SDC velocity xy (2), SDC off-road flag (1) and a road-edge circogram
    (``NUM_RAYS``). ``observation_spec`` describes the same layout.
    """

    # Number of circogram rays. Both observe() and observation_spec() read it,
    # so the spec cannot drift from the actual observation length.
    NUM_RAYS = 64

    @override
    def observe(self, state: _env.PlanningAgentSimulatorState) -> Any:
        """Builds the SDC feature vector for ``state``.

        All quantities are expressed in the SDC-centric frame produced by
        ``datatypes.sdc_observation_from_state`` (top 1000 roadgraph points).
        The goal is the SDC's final logged xy position transformed into that
        frame.

        NOTE(review): the goal uses the last logged timestep regardless of its
        valid flag.

        Args:
          state: Current state of the simulator.

        Returns:
          1-D array laid out as [goal distance, goal angle
          (``arctan2(y, x)``), SDC vel_x, SDC vel_y, off-road flag (0.0/1.0),
          circogram distances...].
        """
        # Get base observation from SDC perspective first
        observation = datatypes.sdc_observation_from_state(state, roadgraph_top_k=1000)

        sdc_trajectory = datatypes.select_by_onehot(
            observation.trajectory,
            observation.is_ego,
            keepdims=True,
        )
        sdc_velocity_xy = sdc_trajectory.vel_xy

        # Goal: last logged SDC position, moved into the observation frame.
        sdc_xy_goal = datatypes.select_by_onehot(
            state.log_trajectory.xy[..., -1, :],
            state.object_metadata.is_sdc,
            keepdims=True,
        )
        sdc_xy_goal = utils.transform_points(observation.pose2d.matrix, sdc_xy_goal)[0]
        # Convert the goal position from Cartesian to polar coordinates
        sdc_goal_distance = jnp.sqrt(
            sdc_xy_goal[..., 0] ** 2 + sdc_xy_goal[..., 1] ** 2
        )
        sdc_goal_angle = jnp.arctan2(sdc_xy_goal[..., 1], sdc_xy_goal[..., 0])

        sdc_offroad = is_offroad(sdc_trajectory, observation.roadgraph_static_points)
        sdc_offroad = sdc_offroad.astype(jnp.float32)  # Convert boolean to float32

        circogram = create_circogram(observation, self.NUM_RAYS)

        return jnp.concatenate(
            [
                sdc_goal_distance.flatten(),
                sdc_goal_angle.flatten(),
                sdc_velocity_xy.flatten(),
                sdc_offroad.flatten(),
                circogram.flatten(),
            ],
            axis=-1,
        )

    @override
    def observation_spec(self) -> types.Observation:
        """Returns the observation spec of the environment.

        Returns:
            A bounded float32 spec of shape ``(5 + NUM_RAYS,)`` matching the
            layout produced by ``observe``.
        """
        # (size, min, max) per component, in observe()'s concatenation order.
        components = (
            # NOTE(review): holds [distance, angle]; distance is >= 0 and the
            # angle lies in [-pi, pi], so these bounds are looser than needed.
            (2, -1000, 1000),  # goal distance and goal angle
            (2, -30, 30),  # SDC velocity xy
            (1, 0, 1),  # off-road flag
            (self.NUM_RAYS, 0, 100),  # circogram distances
        )

        min_values = []
        max_values = []
        for size, low, high in components:
            min_values += [low] * size
            max_values += [high] * size

        return specs.BoundedArray(
            shape=(len(min_values),),
            minimum=jnp.array(min_values),
            maximum=jnp.array(max_values),
            dtype=jnp.float32,
        )

# Scenario source: WOMD v1.1 training shards, up to 32 objects each.
max_num_objects = 32
config = dataclasses.replace(
    _config.WOD_1_1_0_TRAINING,
    path="./data/training_tfexample.tfrecord@5",
    max_num_objects=max_num_objects,
    max_num_rg_points=30000,
)
data_iter = dataloader.simulator_state_generator(config=config)

# Every non-SDC object is driven by an IDM sim agent.
sim_agent_config = _config.SimAgentConfig(
    agent_type=_config.SimAgentType.IDM,
    controlled_objects=_config.ObjectType.NON_SDC,
)
# Reward = SDC progression minus an off-road penalty.
metrics_config = _config.MetricsConfig(metrics_to_run=("sdc_progression", "offroad"))
reward_config = _config.LinearCombinationRewardConfig(
    rewards={"sdc_progression": 1.0, "offroad": -1.0},
)
env_config = dataclasses.replace(
    _config.EnvironmentConfig(),
    max_num_objects=max_num_objects,
    metrics=metrics_config,
    rewards=reward_config,
    sim_agents=[sim_agent_config],
)

dynamics_model = dynamics.InvertibleBicycleModel(normalize_actions=True)
env = WaymaxEnv(
    dynamics_model=dynamics_model,
    config=env_config,
    sim_agent_actors=[agents.create_sim_agents_from_config(sim_agent_config)],
    sim_agent_params=[{}],
)

# Baseline actors. obj_idx is never -1, so they control every object.
obj_idx = jnp.arange(max_num_objects)


def _control_all_objects(state):
    return obj_idx != -1


actor = agents.create_constant_speed_actor(
    speed=5.0,
    dynamics_model=dynamics_model,
    is_controlled_func=_control_all_objects,
)
expert_actor = agents.create_expert_actor(
    dynamics_model=dynamics_model,
    is_controlled_func=_control_all_objects,
)

# Compiled entry points used by mpc() and the rollout loop.
jit_step = jax.jit(env.step)
jit_select = jax.jit(datatypes.select_by_onehot, static_argnums=2)
jit_select_action = jax.jit(expert_actor.select_action)
jit_observe = jax.jit(env.observe)
jit_observe_from_state = jax.jit(datatypes.sdc_observation_from_state)
jit_reward = jax.jit(env.reward)
jit_construct_SDC_route = jax.jit(construct_SDC_route)
jit_reset = jax.jit(env.reset)
jit_transform_points = jax.jit(utils.transform_points)

# Limits and discretisation used by mpc(). Normalized controls in [-1, 1]
# are scaled by the two MAX values.
_MPC_MAX_ACCEL = 6.0
_MPC_MAX_STEERING = 0.3
_MPC_DT = 0.1  # Seconds per horizon step
_MPC_HORIZON = 20  # Shorter horizon for speed


def _mpc_solve(current_speed, target_x, target_y):
    """Solve one receding-horizon problem starting at the SDC-frame origin.

    The vehicle starts at (0, 0) with yaw 0 and speed ``current_speed`` and
    follows a simplified bicycle model (yaw rate = steering * speed). The
    objective is ``10 * squared terminal distance to (target_x, target_y)``
    plus ``sum(accel**2 + 5 * steering**2)`` over the normalized controls.

    Returns:
        ``(accel, steering)`` for the first step, normalized to [-1, 1].

    Propagates any exception raised while building or solving the problem
    (e.g. when IPOPT does not converge).
    """
    n = _MPC_HORIZON
    dt = _MPC_DT
    opti = casadi.Opti()

    # Decision variables: states over n+1 knots, normalized controls over n.
    x = opti.variable(n + 1)
    y = opti.variable(n + 1)
    yaw = opti.variable(n + 1)
    speed = opti.variable(n + 1)
    accel_cmd = opti.variable(n)
    steer_cmd = opti.variable(n)

    # Initial conditions: origin, heading along +x, measured speed.
    opti.subject_to(x[0] == 0.0)
    opti.subject_to(y[0] == 0.0)
    opti.subject_to(yaw[0] == 0.0)
    opti.subject_to(speed[0] == current_speed)

    # Simplified bicycle dynamics in denormalized units.
    for k in range(n):
        accel = accel_cmd[k] * _MPC_MAX_ACCEL
        steering = steer_cmd[k] * _MPC_MAX_STEERING
        opti.subject_to(x[k + 1] == x[k] + dt * speed[k] * casadi.cos(yaw[k]))
        opti.subject_to(y[k + 1] == y[k] + dt * speed[k] * casadi.sin(yaw[k]))
        opti.subject_to(speed[k + 1] == speed[k] + dt * accel)
        opti.subject_to(yaw[k + 1] == yaw[k] + dt * steering * speed[k])

    # Normalized controls stay in [-1, 1].
    for k in range(n):
        opti.subject_to(accel_cmd[k] >= -1.0)
        opti.subject_to(accel_cmd[k] <= 1.0)
        opti.subject_to(steer_cmd[k] >= -1.0)
        opti.subject_to(steer_cmd[k] <= 1.0)

    # Speed limited to [0, 20] m/s.
    for k in range(n + 1):
        opti.subject_to(speed[k] >= 0.0)
        opti.subject_to(speed[k] <= 20.0)

    terminal_error = (x[n] - target_x) ** 2 + (y[n] - target_y) ** 2
    effort = casadi.sum1(accel_cmd**2 + 5.0 * steer_cmd**2)

    # Initial guess: straight line to the target at (at least) 0.5 m/s.
    opti.set_initial(x, np.linspace(0.0, target_x, n + 1))
    opti.set_initial(y, np.linspace(0.0, target_y, n + 1))
    opti.set_initial(yaw, np.zeros(n + 1))
    opti.set_initial(speed, np.ones(n + 1) * max(0.5, current_speed))
    opti.set_initial(accel_cmd, np.zeros(n))
    opti.set_initial(steer_cmd, np.zeros(n))

    opti.minimize(10.0 * terminal_error + effort)
    opti.solver(
        'ipopt',
        {"expand": True},
        {
            "max_iter": 100,
            "print_level": 0,
            "acceptable_tol": 1e-2,
            "acceptable_obj_change_tol": 1e-2,
            "hessian_approximation": "limited-memory",
        },
    )

    solution = opti.solve()
    return float(solution.value(accel_cmd[0])), float(solution.value(steer_cmd[0]))


def mpc(state: _env.PlanningAgentSimulatorState) -> datatypes.Action:
    """Choose the SDC's next normalized ``[accel, steering]`` action by MPC.

    The target is the SDC's final logged xy position expressed in the SDC
    observation frame. If the optimization raises, a fallback steers
    proportionally toward the target with a fixed 0.3 normalized
    acceleration.

    Args:
        state: Current simulator state.

    Returns:
        A ``datatypes.Action`` with data ``[accel, steering]`` and valid ``[True]``.
    """
    observation = jit_observe_from_state(state)
    ego = jit_select(observation.trajectory, observation.is_ego, keepdims=True)
    flat_vel = ego.vel_xy.flatten()
    vel_x = float(flat_vel[0])
    vel_y = float(flat_vel[1])
    current_speed = np.sqrt(vel_x**2 + vel_y**2)

    goal = jit_select(
        state.log_trajectory.xy[..., -1, :],
        state.object_metadata.is_sdc,
        keepdims=True,
    )
    goal = jit_transform_points(observation.pose2d.matrix, goal)[0]
    target_x = float(goal[0])
    target_y = float(goal[1])

    try:
        accel, steering = _mpc_solve(current_speed, target_x, target_y)
    except Exception as e:
        print(f"MPC failed: {str(e)[:100]}...")  # Print first part of error
        # Fallback: the SDC's yaw is 0 in this frame, so the heading error is
        # just the bearing to the target, wrapped to [-pi, pi].
        heading_error = np.arctan2(target_y, target_x)
        heading_error = np.arctan2(np.sin(heading_error), np.cos(heading_error))
        steering = np.clip(0.5 * heading_error / _MPC_MAX_STEERING, -1.0, 1.0)
        accel = 0.3  # Moderate acceleration for fallback

    return datatypes.Action(
        data=jnp.array([accel, steering]),
        valid=jnp.array([True]),
    )

# Roll the MPC controller out on the next 10 scenarios (80 steps each),
# recording every state including each episode's initial one.
states = []
for _episode in range(10):
    scenario = next(data_iter)
    state = jit_construct_SDC_route(jit_reset(scenario))
    states.append(state)
    for _step in tqdm(range(80)):
        action = mpc(state)
        state = jit_step(state, action)
        reward = jit_reward(state, action)
        states.append(state)

# Render the simulated (not logged) trajectories of every recorded state.
imgs = []
for state in states:
    img = visualization.plot_simulator_state(state, use_log_traj=False)
    imgs.append(img)

mediapy.show_video(imgs, fps=10)

0
This browser does not support the video tag.
