In [None]:
import numpy as np
import mediapy
from tqdm import tqdm
import dataclasses
import jax
import jax.numpy as jnp
from waymax import config as _config
from waymax import dataloader
from waymax import datatypes
from waymax import visualization
import cv2

# Dataset configuration: start from the stock WOD 1.1.0 training config and
# override the object / roadgraph budgets and the local TFRecord path.
max_num_objects = 32
config = dataclasses.replace(
    _config.WOD_1_1_0_TRAINING,
    path="./data/training_tfexample.tfrecord@5",
    max_num_rg_points=30000,
    max_num_objects=max_num_objects,
)
data_iter = dataloader.simulator_state_generator(config=config)

# Discard the first two scenarios and keep the third one yielded.
next(data_iter)
next(data_iter)
scenario = next(data_iter)

# `state` is the simulator state that the playback loop advances; the two
# lists collect one rendered frame per timestep for the final videos.
state = scenario
imgs, ray_visualization = [], []

# Rendering of the initial scenario (with logged trajectories).
img = visualization.plot_simulator_state(scenario, use_log_traj=True)


def find_closest_distance(i, initval):
    """Loop body for ``jax.lax.fori_loop``: compute the range reading of ray ``i``.

    A roadgraph point counts as a hit for ray ``i`` when it is flagged valid,
    its type is one of the three ``ROAD_EDGE_*`` ids, and its angle lies
    strictly within ``pi / num_rays`` of the ray angle.

    Args:
        i: Index of the ray to process.
        initval: Loop carry ``(distances, rg_points, rg_angles, num_rays,
            ray_angles)``. ``rg_points`` must expose ``x``, ``y``, ``valid``
            and ``types``; ``rg_angles`` holds one angle per point.

    Returns:
        The carry with ``distances[i]`` replaced by the smallest
        ``sqrt(x**2 + y**2)`` over the hits, or left unchanged when there is
        no hit. All other carry entries are passed through untouched.
    """
    distances, points, point_angles, num_rays, ray_angles = initval

    # Signed angular offset of every point from ray i, wrapped into [-pi, pi).
    offset = point_angles - ray_angles[i]
    offset = (offset + jnp.pi) % (2 * jnp.pi) - jnp.pi

    # A ray owns the sector of half-width pi / num_rays around its direction.
    half_width = jnp.pi / num_rays
    in_sector = jnp.abs(offset) < half_width

    # Build the "is a road edge" mask by OR-ing one comparison per edge type.
    edge_kinds = (
        datatypes.MapElementIds.ROAD_EDGE_BOUNDARY,
        datatypes.MapElementIds.ROAD_EDGE_MEDIAN,
        datatypes.MapElementIds.ROAD_EDGE_UNKNOWN,
    )
    is_road_edge = points.types == edge_kinds[0]
    for kind in edge_kinds[1:]:
        is_road_edge = is_road_edge | (points.types == kind)

    hits = in_sector & points.valid & is_road_edge

    # Non-hits are replaced by 100.0 before the reduction, so the minimum
    # cannot exceed 100.0 whenever at least one non-hit point exists.
    ranges = jnp.sqrt(points.x**2 + points.y**2)
    nearest = jnp.min(jnp.where(hits, ranges, 100.0))

    # Keep the previous entry when no point qualified for this ray.
    updated = jnp.where(jnp.any(hits), nearest, distances[i])
    return distances.at[i].set(updated), points, point_angles, num_rays, ray_angles

def create_circogram(state, num_rays, ray_angles):
    """Return one road-edge range reading per ray for the given state.

    Args:
        state: Simulator state handed to
            ``datatypes.sdc_observation_from_state``.
        num_rays: Number of rays; also the length of the returned array and
            the trip count of the loop.
        ray_angles: Ray directions, indexed ``0 .. num_rays - 1``, using the
            same angle convention as ``arctan2(x, y)`` of the point
            coordinates.

    Returns:
        Array of ``num_rays`` distances. Every entry starts at 100.0 and is
        overwritten by ``find_closest_distance`` when a road-edge point
        falls inside that ray's sector.
    """
    sdc_view = datatypes.sdc_observation_from_state(state, roadgraph_top_k=2000)
    points = sdc_view.roadgraph_static_points

    # Note the argument order: arctan2(x, y), not arctan2(y, x).
    bearings = jnp.arctan2(points.x, points.y)

    # The carry bundles the output buffer with the read-only inputs that
    # find_closest_distance needs on each iteration.
    carry = (jnp.full(num_rays, 100.0), points, bearings, num_rays, ray_angles)
    result = jax.lax.fori_loop(0, num_rays, find_closest_distance, carry)
    return result[0]

# JIT-compiled wrappers. `static_argnums` marks the positional arguments that
# are treated as compile-time constants (step counts, top-k size, ray count).
jit_step = jax.jit(datatypes.update_state_by_log, static_argnums=(1,))
jit_observed = jax.jit(datatypes.sdc_observation_from_state, static_argnums=(1, 2))
jit_create_circogram = jax.jit(create_circogram, static_argnums=(1,))

# --- Loop-invariant circogram parameters -----------------------------------
num_rays = 24
ray_angles = jnp.linspace(-jnp.pi, jnp.pi, num_rays, endpoint=False)
# Host-side unit directions, computed once. The angle convention matches
# create_circogram (arctan2(x, y)): sin gives the x component, cos the y.
ray_sin = np.sin(np.asarray(ray_angles))
ray_cos = np.cos(np.asarray(ray_angles))

# --- Loop-invariant drawing parameters -------------------------------------
small_size = 500  # Side length of the square visualization image (pixels)
max_range = 100.0  # Metres; equals the default distance used by create_circogram
margin = 10  # Pixels kept free between the max-range circle and the border
x = small_size // 2  # Center of the image
y = small_size // 2  # Center of the image
center = (x, y)
# Pixels per metre, chosen so the max-range circle fits inside the canvas.
# (A fixed 5 px/m on a 500 px canvas pushed the 80 m and 100 m rings and
# most labels outside the image.)
scale_factor = (small_size // 2 - margin) / max_range
max_range_radius = int(max_range * scale_factor)

# Frames are handed to mediapy, which interprets channels as RGB (not
# OpenCV's BGR), so the colours below are RGB triples.
ray_color = (255, 0, 0)  # Red
sdc_color = (0, 255, 0)  # Green
range_color = (200, 200, 200)  # Gray, max-range circle
marker_color = (220, 220, 220)  # Light gray, intermediate rings
label_color = (100, 100, 100)  # Dark gray, ring labels

for _ in tqdm(range(scenario.remaining_timesteps)):
    state = jit_step(state, num_steps=1)
    imgs.append(
        visualization.plot_observation(
            jit_observed(state, roadgraph_top_k=2000),
            obj_idx=0,
        )
    )

    circogram = jit_create_circogram(state, num_rays, ray_angles)

    # Single device-to-host transfer per frame, then vectorised endpoints.
    scaled = np.asarray(circogram) * scale_factor
    end_xs = (center[0] + ray_sin * scaled).astype(int)
    # Image rows grow downward, so subtract the y component to get +y at the
    # top. Rays and circles then share exactly the same center pixel.
    end_ys = (center[1] - ray_cos * scaled).astype(int)

    ray_img = np.full((small_size, small_size, 3), 255, dtype=np.uint8)  # White

    # One line per ray; .tolist() yields plain Python ints for OpenCV.
    for end_point in zip(end_xs.tolist(), end_ys.tolist()):
        cv2.line(ray_img, center, end_point, ray_color, 1)

    # Marker for the SDC at the image center.
    cv2.circle(ray_img, center, 5, sdc_color, -1)

    # Reference circle at maximum range.
    cv2.circle(ray_img, center, max_range_radius, range_color, 1)

    # Intermediate distance rings with a text label just above each ring.
    for dist in [20, 40, 60, 80]:
        radius = int(dist * scale_factor)
        cv2.circle(ray_img, center, radius, marker_color, 1, cv2.LINE_AA)
        text_pos = (center[0] + 5, center[1] - radius - 5)
        cv2.putText(ray_img, f"{dist}m", text_pos, cv2.FONT_HERSHEY_SIMPLEX,
                    0.4, label_color, 1, cv2.LINE_AA)

    ray_visualization.append(ray_img)

mediapy.show_video(imgs, fps=10)
mediapy.show_video(ray_visualization, fps=10)



100%|██████████| 90/90 [00:11<00:00,  7.56it/s]


0
This browser does not support the video tag.


0
This browser does not support the video tag.
