In [10]:
import os
import numpy as np
import tensorflow as tf
import matplotlib.pyplot as plt
from google.protobuf.descriptor import FieldDescriptor
import matplotlib
matplotlib.rcParams.update(matplotlib.rcParamsDefault)
matplotlib.use('TkAgg', force=True)
import matplotlib.pyplot as plt

from waymo_open_dataset.protos import end_to_end_driving_data_pb2 as wod_e2ed_pb2
from waymo_open_dataset import dataset_pb2 as open_dataset

# -----------------------
# CONFIG (edit here)
# -----------------------
# Path of the TFRecord shard to read.
TFRECORD  = "/mnt/d/Datasets/WOD_E2E_Camera_v1/val/val_202504211843.tfrecord-00026-of-00093"
# Zero-based position of the record inside the shard (matched against enumerate()).
RECORD_IDX = 373
SEG_ID = "fb0ed944efebd34d756103188d59da85"  # only used in the plot title (not required)
KEY_IDX = 147                                # only used in the plot title (not required)
CAM_ID = 7                                   # NOTE(review): the original note said only CAM "1" works, but this is 7 -- confirm

# None -> show the figure interactively; a path -> save the figure there instead.
OUT_PNG = None
# OUT_PNG = f"/mnt/d/Datasets/WOD_E2E_Camera_v1/_debug_overlay/{SEG_ID}_key{KEY_IDX}_cam{CAM_ID}.png"


# -----------------------
# 1) Parse one record
# -----------------------
def parse_e2ed(raw_bytes: bytes):
    """Deserialize one serialized record into an ``E2EDFrame`` message.

    Args:
        raw_bytes: Serialized protobuf payload of a single TFRecord entry.

    Returns:
        A freshly created ``wod_e2ed_pb2.E2EDFrame`` populated from ``raw_bytes``.
    """
    e2ed_frame = wod_e2ed_pb2.E2EDFrame()
    e2ed_frame.ParseFromString(raw_bytes)
    return e2ed_frame

# Open the shard and pull out the serialized bytes of record number RECORD_IDX.
ds = tf.data.TFRecordDataset(TFRECORD, compression_type="")
raw = next(
    # The generator stops at the first (and only) index match; None if the
    # shard holds fewer than RECORD_IDX + 1 records.
    (bytes(rec.numpy()) for idx, rec in enumerate(ds) if idx == RECORD_IDX),
    None,
)
assert raw is not None, "没读到指定 RECORD_IDX，请确认 TFRECORD/RECORD_IDX"

# Decode the record and print a short summary.
e2e = parse_e2ed(raw)
frame = e2e.frame
print("context.name:", getattr(frame.context, "name", None))
print("intent:", int(getattr(e2e, "intent", -1)))


# -----------------------
# 2) Get image + calib
# -----------------------
def get_image_by_cam_id(frame, cam_id: int):
    """Return the decoded image of the camera whose ``name`` equals ``cam_id``.

    Args:
        frame: Object exposing an iterable ``images`` whose items carry
            ``name`` and JPEG-encoded ``image`` attributes.
        cam_id: Camera identifier; compared after conversion with ``int``.

    Returns:
        The decoded image as a NumPy array, or ``None`` when no image matches.
    """
    wanted = int(cam_id)
    match = next((im for im in frame.images if int(im.name) == wanted), None)
    if match is None:
        return None
    return tf.io.decode_jpeg(match.image).numpy()

def get_camera_calib_by_cam_id(frame, cam_id: int):
    """Return the calibration entry whose ``name`` equals ``cam_id``.

    Args:
        frame: Object exposing ``context.camera_calibrations``.
        cam_id: Camera identifier; compared after conversion with ``int``.

    Returns:
        The first matching calibration object, or ``None`` when none matches.
    """
    wanted = int(cam_id)
    return next(
        (cal for cal in frame.context.camera_calibrations if int(cal.name) == wanted),
        None,
    )

# Fetch the selected camera's image and calibration; both are required below.
rgb = get_image_by_cam_id(frame, CAM_ID)
calib = get_camera_calib_by_cam_id(frame, CAM_ID)
assert not (rgb is None or calib is None), f"cam{CAM_ID} missing rgb or calib"

# Image height and width are the first two axes of the decoded array.
H, W, *_ = rgb.shape
print("rgb shape:", rgb.shape)


# -----------------------
# 3) Manual projection (the version that previously ran successfully)
# -----------------------
def calib_to_mats(calib):
    """Unpack a camera calibration into pinhole parameters and an extrinsic matrix.

    Args:
        calib: Object with an ``intrinsic`` sequence (its first four entries are
            read as fx, fy, cx, cy) and ``extrinsic.transform`` holding 16 values.

    Returns:
        ``(fx, fy, cx, cy, T_vehicle_from_cam)`` where the last item is a 4x4
        float32 array. The original author documents the extrinsic as the
        camera -> vehicle transform.
    """
    intrinsics = list(calib.intrinsic)
    fx, fy, cx, cy = (intrinsics[k] for k in range(4))
    flat_extrinsic = list(calib.extrinsic.transform)
    T_vehicle_from_cam = np.asarray(flat_extrinsic, dtype=np.float32).reshape((4, 4))
    return fx, fy, cx, cy, T_vehicle_from_cam

def vehicle_xyz_to_cam_xyz(xyz_vehicle, T_vehicle_from_cam):
    """Transform points from the vehicle frame into the camera frame.

    Args:
        xyz_vehicle: (N, 3) array-like of points in the vehicle frame.
        T_vehicle_from_cam: 4x4 camera -> vehicle transform; it is inverted
            here to obtain the vehicle -> camera transform.

    Returns:
        (N, 3) array of the same points expressed in the camera frame.
    """
    pts = np.asarray(xyz_vehicle, dtype=np.float32)
    # Append a column of ones so the 4x4 transform can be applied in one matmul.
    homogeneous = np.hstack([pts, np.ones((pts.shape[0], 1), dtype=np.float32)])
    T_cam_from_vehicle = np.linalg.inv(T_vehicle_from_cam)
    transformed = (T_cam_from_vehicle @ homogeneous.T).T
    return transformed[:, :3]

def project_cam_xyz_to_uv(xyz_cam, fx, fy, cx, cy, min_depth=0.5):
    """Project camera-frame points to pixel coordinates with a pinhole model.

    The camera frame is taken as x = forward (depth), y = left, z = up, so
    ``u = cx - fx * (y / x)`` and ``v = cy - fy * (z / x)``.

    Args:
        xyz_cam: (N, 3) array-like of points in the camera frame.
        fx, fy, cx, cy: Pinhole focal lengths and principal point.
        min_depth: Points whose depth is not greater than this are masked out.

    Returns:
        ``(uv, mask)``: ``uv`` is an (N, 2) float32 array of pixel coordinates
        and ``mask`` is a boolean array that is True where depth > ``min_depth``.
    """
    pts = np.asarray(xyz_cam, dtype=np.float32)
    depth, left, up = pts[:, 0], pts[:, 1], pts[:, 2]
    in_front = depth > float(min_depth)

    # Clamp the divisor so points at or behind the camera cannot divide by
    # zero; those points are rejected by the mask.
    safe_depth = np.clip(depth, 1e-6, None)
    u = cx - fx * (left / safe_depth)
    v = cy - fy * (up / safe_depth)
    return np.column_stack((u, v)).astype(np.float32), in_front

def project_vehicle_xyz_to_uv(xyz_vehicle, calib, min_depth=0.5):
    """Project vehicle-frame points into the image described by ``calib``.

    Args:
        xyz_vehicle: (N, 3) array-like of points in the vehicle frame.
        calib: Camera calibration accepted by ``calib_to_mats``.
        min_depth: Minimum camera-frame depth for a point to pass the mask.

    Returns:
        ``(uv, mask)`` exactly as produced by ``project_cam_xyz_to_uv``.
    """
    fx, fy, cx, cy, T_vehicle_from_cam = calib_to_mats(calib)
    points_in_cam = vehicle_xyz_to_cam_xyz(xyz_vehicle, T_vehicle_from_cam)
    return project_cam_xyz_to_uv(points_in_cam, fx, fy, cx, cy, min_depth=min_depth)

def in_bounds_mask(uv, H, W):
    """Return a boolean mask of pixel coordinates that fall inside an H x W image.

    Args:
        uv: (N, 2) array with u (column) in index 0 and v (row) in index 1.
        H: Image height in pixels.
        W: Image width in pixels.

    Returns:
        Boolean array, True where ``0 <= u < W`` and ``0 <= v < H``.
    """
    u, v = uv[:, 0], uv[:, 1]
    inside_horizontally = (u >= 0) & (u < W)
    inside_vertically = (v >= 0) & (v < H)
    return inside_horizontally & inside_vertically


# -----------------------
# 4) Ego past / future xyz (robustly aligns array lengths)
# -----------------------
def ego_states_xyz(states_msg, fallback_z=0.0):
    """Stack the ``pos_x`` / ``pos_y`` / ``pos_z`` sequences of a states message.

    Args:
        states_msg: Object with ``pos_x`` and ``pos_y`` sequences and an
            optional ``pos_z`` sequence; may be ``None``.
        fallback_z: Height used for every point when ``pos_z`` is missing or
            empty.

    Returns:
        (n, 3) float32 array truncated to the shortest of the sequences, or
        ``None`` when the message is ``None``, lacks ``pos_x``/``pos_y``, or
        either of them is empty.
    """
    if states_msg is None:
        return None
    if not hasattr(states_msg, "pos_x") or not hasattr(states_msg, "pos_y"):
        return None

    xs = np.asarray(list(states_msg.pos_x), dtype=np.float32)
    ys = np.asarray(list(states_msg.pos_y), dtype=np.float32)
    zs = np.asarray(list(getattr(states_msg, "pos_z", [])), dtype=np.float32)

    n_xy = min(len(xs), len(ys))
    if n_xy == 0:
        return None

    if len(zs) == 0:
        # No heights recorded: place every point at the fallback height.
        zs = np.full((n_xy,), float(fallback_z), dtype=np.float32)

    n = min(n_xy, len(zs))
    return np.column_stack((xs[:n], ys[:n], zs[:n]))

# Ego trajectory before and after the key frame, as (n, 3) arrays or None.
ego_past = ego_states_xyz(e2e.past_states, fallback_z=0.0)
ego_fut = ego_states_xyz(e2e.future_states, fallback_z=0.0)

print("ego_past:", ego_past.shape if ego_past is not None else None)
print("ego_future:", ego_fut.shape if ego_fut is not None else None)


# -----------------------
# 5) Preference trajectories xyz (recursively look for pos_x/pos_y/pos_z or x/y/z)
# -----------------------
def _is_message(x):
    return hasattr(x, "ListFields") and callable(getattr(x, "ListFields"))

def _try_extract_pos_arrays_from_message(msg):
    """Build an (n, 3) xyz array from coordinate fields set directly on ``msg``.

    Two naming schemes are tried in order: ``pos_x``/``pos_y``/``pos_z`` and
    then ``x``/``y``/``z``. The z field is optional; zeros are used when it is
    absent or empty. The result is truncated to the shortest of the three
    arrays.

    Only sized (repeated) values are treated as coordinate arrays. A message
    whose ``x``/``y`` are singular scalars (for example a single point) is
    skipped instead of raising ``TypeError`` from ``list(<float>)``, so a
    recursive search can continue past it.

    Args:
        msg: Candidate protobuf-like message.

    Returns:
        (n, 3) float32 array, or ``None`` when ``msg`` is not a message or has
        no usable coordinate arrays.
    """
    if not _is_message(msg):
        return None
    # ListFields() yields (descriptor, value) pairs for the fields present on msg.
    field_map = {fd.name: val for fd, val in msg.ListFields()}

    def _as_coords(value):
        # Reject text and anything without a length (scalars, nested messages).
        if isinstance(value, (str, bytes)) or not hasattr(value, "__len__"):
            return None
        return np.array(list(value), dtype=np.float32)

    for x_key, y_key, z_key in (("pos_x", "pos_y", "pos_z"), ("x", "y", "z")):
        if x_key not in field_map or y_key not in field_map:
            continue
        x = _as_coords(field_map[x_key])
        y = _as_coords(field_map[y_key])
        if x is None or y is None:
            continue
        z = _as_coords(field_map.get(z_key, []))
        if z is None or len(z) == 0:
            z = np.zeros_like(x)
        n = min(len(x), len(y), len(z))
        return np.stack([x[:n], y[:n], z[:n]], axis=1)

    return None

def extract_xyz_recursive(msg, max_depth=5):
    """Depth-first search of a message tree for the first xyz coordinate array.

    The message itself is checked first; otherwise each set message-typed field
    is descended into. For a repeated message field only its first element is
    inspected.

    Args:
        msg: Protobuf-like message to search.
        max_depth: Remaining recursion budget; the search stops below zero.

    Returns:
        The first non-empty (n, 3) array found, or ``None``.
    """
    if max_depth < 0 or not _is_message(msg):
        return None

    found = _try_extract_pos_arrays_from_message(msg)
    if found is not None and found.shape[0] > 0:
        return found

    for fd, val in msg.ListFields():
        if fd.type != FieldDescriptor.TYPE_MESSAGE:
            continue

        # Children to visit: the first element of a repeated field, or the
        # singular sub-message itself.
        if fd.label == FieldDescriptor.LABEL_REPEATED:
            children = [val[0]] if len(val) > 0 else []
        else:
            children = [val]

        for child in children:
            nested = extract_xyz_recursive(child, max_depth=max_depth - 1)
            if nested is not None:
                return nested

    return None

# One entry per preference trajectory: its (n, 3) xyz array, or None if none was found.
pref_xyz_list = []
for i, pref in enumerate(e2e.preference_trajectories):
    xyz = extract_xyz_recursive(pref, max_depth=6)
    score = float(getattr(pref, "preference_score", np.nan))
    shape_text = "None" if xyz is None else xyz.shape
    print(f"[pref {i}] score={score}, xyz={shape_text}")
    pref_xyz_list.append(xyz)


# -----------------------
# 6) Overlay draw (count visible points + plot)
# -----------------------
def count_visible(xyz):
    """Count the points of ``xyz`` that project inside the selected camera image.

    Uses the module-level ``calib``, ``H`` and ``W``. Inputs that are ``None``
    or hold fewer than two points count as zero.

    Args:
        xyz: (N, 3) array of vehicle-frame points, or ``None``.

    Returns:
        Number of points with depth > 0.05 whose pixel lies inside the image.
    """
    if xyz is None:
        return 0
    if len(xyz) < 2:
        return 0
    uv, depth_ok = project_vehicle_xyz_to_uv(xyz, calib, min_depth=0.05)
    visible = depth_ok & in_bounds_mask(uv, H, W)
    return int(visible.sum())

# Visible-point counts for the ego past, ego future and each preference trajectory.
v_p, v_f = count_visible(ego_past), count_visible(ego_fut)
v_pref = list(map(count_visible, pref_xyz_list))
print(f"cam{CAM_ID}: visible past={v_p}, future={v_f}, prefs={v_pref}")

def _project_visible(xyz, min_depth=0.05):
    """Project vehicle-frame points with ``calib`` and flag the drawable ones.

    Returns ``(uv, mask)``: ``uv`` holds the (N, 2) pixel coordinates and
    ``mask`` is True where the point passed the depth test and lies inside the
    H x W image.
    """
    uv, m = project_vehicle_xyz_to_uv(xyz, calib, min_depth=min_depth)
    return uv, m & in_bounds_mask(uv, H, W)


plt.figure(figsize=(14, 8))
ax = plt.gca()
ax.imshow(rgb)
ax.set_title(f"{SEG_ID} key_idx={KEY_IDX} | cam{CAM_ID} overlay | past={v_p}, future={v_f}, prefs={v_pref}")
ax.axis("off")

# Key-frame origin of the vehicle frame, drawn only if it lands inside the image.
origin = np.array([[0.0, 0.0, 0.0]], dtype=np.float32)
uv0, m0 = _project_visible(origin)
if m0[0]:
    ax.scatter([uv0[0, 0]], [uv0[0, 1]], s=90, marker="*", label="keyframe (0,0,0)")

# Ego past / future: same drawing logic, differing only in marker and line width.
for tag, xyz, marker, width in (
    ("ego_past", ego_past, ".", 2),
    ("ego_future", ego_fut, "x", 3),
):
    if xyz is None or len(xyz) < 2:
        continue
    uv, m = _project_visible(xyz)
    if m.any():
        ax.plot(uv[m, 0], uv[m, 1], marker=marker, linewidth=width,
                label=f"{tag}({m.sum()}/{len(m)})")

# Preference trajectories; pref_xyz_list holds one entry per trajectory, in order.
for i, (xyz, pref) in enumerate(zip(pref_xyz_list, e2e.preference_trajectories)):
    if xyz is None or len(xyz) < 2:
        continue
    uv, m = _project_visible(xyz)
    if m.any():
        score = float(getattr(pref, "preference_score", np.nan))
        ax.plot(uv[m, 0], uv[m, 1], marker="o", linewidth=2,
                label=f"pref{i} score={score:.0f} ({m.sum()})")

# Only build a legend when something labelled was drawn; an empty legend
# just triggers a matplotlib warning.
handles, _ = ax.get_legend_handles_labels()
if handles:
    ax.legend()
plt.tight_layout()

if OUT_PNG:
    # dirname() is "" for a bare file name, and os.makedirs("") raises, so
    # create the directory only when the path actually has one.
    out_dir = os.path.dirname(OUT_PNG)
    if out_dir:
        os.makedirs(out_dir, exist_ok=True)
    plt.savefig(OUT_PNG, dpi=200)
    plt.close()
    print("saved:", OUT_PNG)
else:
    plt.show()


context.name: fb0ed944efebd34d756103188d59da85-147
intent: 1
rgb shape: (551, 972, 3)
ego_past: (16, 3)
ego_future: (20, 3)
[pref 0] score=9.0, xyz=(21, 3)
[pref 1] score=6.0, xyz=(21, 3)
[pref 2] score=10.0, xyz=(21, 3)
cam7: visible past=13, future=0, prefs=[0, 0, 0]
