In [5]:
import os, glob
import numpy as np
import tensorflow as tf
from waymo_open_dataset import dataset_pb2 as open_dataset
from waymo_open_dataset.protos import end_to_end_driving_data_pb2 as wod_e2ed_pb2


def parse_e2ed(raw_bytes: bytes):
    """Deserialize one serialized record into a ``wod_e2ed_pb2.E2EDFrame``.

    Args:
        raw_bytes: the serialized message bytes (e.g. one TFRecord payload).

    Returns:
        A freshly created ``E2EDFrame`` populated from ``raw_bytes``.
    """
    frame_msg = wod_e2ed_pb2.E2EDFrame()
    frame_msg.ParseFromString(raw_bytes)
    return frame_msg

def split_context_name(ctx_name: str):
    """Split ``'<segment>-<idx>'`` into ``(segment_id, idx)``.

    The split happens at the last ``-``. When the name is empty/None, has no
    ``-``, or the part after the last ``-`` is not an integer, the name is
    returned unchanged together with ``None``.
    """
    if ctx_name:
        segment, dash, tail = ctx_name.rpartition("-")
        if dash:
            try:
                return segment, int(tail)
            except ValueError:
                # Suffix is not numeric: fall through to the "no index" result.
                pass
    return ctx_name, None

def get_segment_and_keyidx(e2e):
    """Return ``(segment_id, key_idx, context_name)`` for one record.

    The context name is read from ``e2e.frame.context.name``; if ``e2e`` has
    no ``frame`` attribute or the frame has no ``context`` attribute, an empty
    string is used instead. The name is then split by ``split_context_name``.
    """
    has_context = hasattr(e2e, "frame") and hasattr(e2e.frame, "context")
    context_name = e2e.frame.context.name if has_context else ""
    segment, idx = split_context_name(context_name)
    return segment, idx, context_name

def extract_ego_states(states_obj, keys=(
    "pos_x","pos_y","pos_z",
    "vel_x","vel_y","vel_z",
    "accel_x","accel_y","accel_z",
    "heading"
)):
    """
    Pull per-field sample arrays out of an ego-trajectory states object.

    Args:
        states_obj: object exposing some of ``keys`` as iterable attributes
            (presumably repeated scalar fields of a protobuf message — TODO
            confirm), or None.
        keys: attribute names to look for; names the object does not have are
            skipped.

    Returns:
        (out, L) where ``out[k]`` is a 1-D float32 array for every attribute
        found (possibly empty), and ``L`` is the length of the longest of
        those arrays (0 when nothing was found or ``states_obj`` is None).
    """
    out = {}
    if states_obj is None:
        return out, 0

    for k in keys:
        if hasattr(states_obj, k):
            out[k] = np.array(list(getattr(states_obj, k)), dtype=np.float32)

    # Use the longest field rather than the first one found: an attribute can
    # be present but empty, which would otherwise report L == 0 even though
    # other fields carry samples.
    L = max((len(v) for v in out.values()), default=0)
    return out, L

def extract_ego_from_record(e2e):
    """
    Build the keyframe-level ego summary for a single record.

    Returns a dict with the segment id, key index and full context name (from
    ``get_segment_and_keyidx``), the record's ``intent`` as an int (-1 when
    the attribute is missing), and the past/future state arrays plus their
    lengths (from ``extract_ego_states``).
    """
    segment, idx, context_name = get_segment_and_keyidx(e2e)
    intent_code = int(getattr(e2e, "intent", -1))

    past_states, n_past = extract_ego_states(getattr(e2e, "past_states", None))
    future_states, n_future = extract_ego_states(getattr(e2e, "future_states", None))

    summary = dict(
        segment_id=segment,
        key_idx=idx,
        key_id=context_name,
        intent=intent_code,
        past=past_states,
        past_L=n_past,
        future=future_states,
        future_L=n_future,
    )
    return summary


In [6]:
def list_tfrecords(folder):
    """Return the sorted paths in ``folder`` whose names match ``*.tfrecord*``.

    Raises:
        FileNotFoundError: if no file in ``folder`` matches the pattern.
    """
    pattern = os.path.join(folder, "*.tfrecord*")
    matches = glob.glob(pattern)
    if len(matches) == 0:
        raise FileNotFoundError(f"No tfrecord files under: {folder}")
    matches.sort()
    return matches

def collect_segment_ego(val_or_train_folder, target_segment_id, max_files=None):
    """
    Scan the TFRecord shards in a folder and gather the ego summaries of one segment.

    Args:
        val_or_train_folder: folder passed to ``list_tfrecords``.
        target_segment_id: only records whose segment id equals this are kept.
        max_files: if not None, only the first ``max_files`` shards (in sorted
            order) are scanned.

    Returns:
        A list of ``extract_ego_from_record`` outputs, restricted to records
        with a non-None ``key_idx`` and sorted by ``key_idx``.
    """
    shard_paths = list_tfrecords(val_or_train_folder)
    if max_files is not None:
        shard_paths = shard_paths[:max_files]

    matched = []
    for shard in shard_paths:
        # Shards ending in ".gz" are read as GZIP; everything else uncompressed.
        codec = "GZIP" if shard.endswith(".gz") else ""
        for raw in tf.data.TFRecordDataset(shard, compression_type=codec):
            frame = parse_e2ed(raw.numpy())
            if get_segment_and_keyidx(frame)[0] == target_segment_id:
                entry = extract_ego_from_record(frame)
                if entry["key_idx"] is not None:
                    matched.append(entry)

    return sorted(matched, key=lambda entry: entry["key_idx"])


In [7]:
def fuse_segment_ego_timeline(records, fps=10, past_includes_keyframe=False):
    """
    Project the past/future states of several keyframes onto one absolute index
    axis and fuse them (samples landing on the same index are averaged).

    Args:
        records: list of dicts with keys "segment_id", "key_idx", "past",
            "past_L", "future", "future_L" (the shape produced by
            ``extract_ego_from_record``). Records whose "key_idx" is None
            cannot be placed on the axis and are skipped.
        fps: stored in the output as metadata only; it does not affect how
            indices are computed.
        past_includes_keyframe: if True the past samples map to offsets
            [-L+1 .. 0] relative to the keyframe, otherwise to [-L .. -1].
            Future samples always map to [1 .. L].

    Returns:
        None when ``records`` is empty. Otherwise a dict with "segment_id"
        (taken from the first record), "fps", "idx" (sorted absolute indices)
        and one float32 array per field, aligned with "idx", holding the mean
        of all samples at that index and NaN where a field has no sample.

    NOTE(review): sample output elsewhere in this notebook shows past pos_x /
    pos_y ending at 0, which suggests the last past sample may be the keyframe
    itself — confirm whether past_includes_keyframe=True is the right setting.
    NOTE(review): if positions are expressed relative to each keyframe's own
    pose, averaging pos_* across keyframes mixes coordinate frames — TODO
    confirm against the dataset documentation.
    """
    if not records:
        return None

    # abs_idx -> field -> list of samples contributed by the different keyframes
    samples = {}

    def _accumulate(key_idx, states, offsets):
        """Append states[field][j] at absolute index key_idx + offsets[j]."""
        for j, off in enumerate(offsets):
            abs_i = key_idx + int(off)
            for field, arr in states.items():
                # Fields may be shorter than the nominal length; skip the gap.
                if j < len(arr):
                    samples.setdefault(abs_i, {}).setdefault(field, []).append(float(arr[j]))

    for rec in records:
        key = rec["key_idx"]
        if key is None:
            # No numeric key index: there is no position on the shared axis.
            continue

        past = rec["past"]
        fut = rec["future"]

        if past:
            n_past = rec["past_L"]
            first = -n_past + 1 if past_includes_keyframe else -n_past
            _accumulate(key, past, range(first, first + n_past))

        if fut:
            _accumulate(key, fut, range(1, rec["future_L"] + 1))

    idxs = sorted(samples)
    # Union of every field that received at least one sample.
    fields = sorted({f for per_idx in samples.values() for f in per_idx})

    out = {"segment_id": records[0]["segment_id"], "fps": fps, "idx": np.array(idxs, dtype=int)}
    for f in fields:
        out[f] = np.array(
            [np.mean(samples[i][f]) if f in samples[i] else np.nan for i in idxs],
            dtype=np.float32
        )

    return out


In [None]:
VAL_DIR = r"/mnt/d/Datasets/WOD_E2E_Camera_v1/val"  # local dataset path; adjust for your machine
SEG_ID  = "fb0ed944efebd34d756103188d59da85"

# 1) Collect the keyframe-level ego status of the segment
records = collect_segment_ego(VAL_DIR, SEG_ID)
print("keyframes:", len(records), "key_idx range:", (records[0]["key_idx"], records[-1]["key_idx"]) if records else None)

# 2) Fuse the keyframes into a single segment-level timeline
timeline = fuse_segment_ego_timeline(records, fps=10, past_includes_keyframe=False)
# A timeline can exist yet hold no samples (keyframes without any past/future
# states); min()/max() on an empty array raise ValueError, so treat that case
# the same as "no timeline".
has_samples = timeline is not None and timeline["idx"].size > 0
print("timeline idx range:", (int(timeline["idx"].min()), int(timeline["idx"].max())) if has_samples else None)

# 3) Save as .npz
if has_samples:
    npz_path = f"{SEG_ID}_ego_cam_free.npz"
    np.savez(npz_path, **timeline)
    print("saved:", npz_path)


In [1]:
import tensorflow as tf
import numpy as np

def parse_e2ed(raw_bytes: bytes):
    """Parse serialized bytes into a new ``wod_e2ed_pb2.E2EDFrame`` message.

    NOTE: ``wod_e2ed_pb2`` is looked up when this function is called, so the
    module must have been imported by then.
    """
    parsed = wod_e2ed_pb2.E2EDFrame()
    parsed.ParseFromString(raw_bytes)
    return parsed

def split_context_name(ctx_name: str):
    """Split ``'<segment>-<idx>'`` at its last ``-`` into ``(segment, int(idx))``.

    Returns ``(ctx_name, None)`` when the name is empty/None, contains no
    ``-``, or its trailing part is not an integer.
    """
    unsplit = (ctx_name, None)
    if not ctx_name:
        return unsplit
    head, sep, tail = ctx_name.rpartition("-")
    if not sep:
        return unsplit
    try:
        index = int(tail)
    except ValueError:
        return unsplit
    return head, index

def find_record_by_seg_and_idx(tfrecord_file, target_seg_id, target_key_idx, max_scan=None):
    """
    Search one TFRecord shard for a given segment's given key index.

    Args:
        tfrecord_file: path of the shard; read as GZIP when it ends in ".gz".
        target_seg_id: segment id to match.
        target_key_idx: key index to match.
        max_scan: if not None, stop after examining this many records.

    Returns:
        The first parsed record whose context name splits into
        ``(target_seg_id, target_key_idx)``, or None if none is found.
    """
    codec = "GZIP" if tfrecord_file.endswith(".gz") else ""
    dataset = tf.data.TFRecordDataset(tfrecord_file, compression_type=codec)

    scanned = 0
    for raw in dataset:
        if max_scan is not None and scanned >= max_scan:
            break
        scanned += 1

        candidate = parse_e2ed(raw.numpy())
        segment, idx = split_context_name(candidate.frame.context.name)
        if segment == target_seg_id and idx == target_key_idx:
            return candidate
    return None

def _get_repeated(states_obj, field):
    if states_obj is None or (not hasattr(states_obj, field)):
        return None
    return np.array(list(getattr(states_obj, field)), dtype=np.float32)

def show_ego_status_one_record(e2e, head=5):
    """
    Print a compact view of one record's ego status (past and future states).

    Args:
        e2e: parsed record; ``e2e.frame.context.name`` is read directly, while
            ``intent``, ``past_states`` and ``future_states`` are read with
            fallbacks when missing.
        head: number of leading and trailing samples shown per field.
            ``head=0`` shows empty head and tail.

    Returns:
        None; everything is written with ``print``.
    """
    ctx = e2e.frame.context.name
    seg_id, key_idx = split_context_name(ctx)
    print("=== Keyframe ===")
    print("context.name:", ctx)
    print("segment_id:", seg_id)
    print("key_idx:", key_idx)
    print("intent:", int(getattr(e2e, "intent", -1)))

    # (section title, states object, fields to show). Fields that the states
    # object does not expose are skipped silently.
    sections = (
        ("past_states", getattr(e2e, "past_states", None),
         ["pos_x","pos_y","vel_x","vel_y","accel_x","accel_y"]),
        ("future_states", getattr(e2e, "future_states", None),
         ["pos_x","pos_y","pos_z"]),
    )

    for title, states, fields in sections:
        print(f"\n=== {title} ===")
        for f in fields:
            arr = _get_repeated(states, f)
            if arr is None:
                continue
            # arr[-0:] is the whole array, so head=0 needs an explicit empty tail.
            tail = arr[-head:] if head else arr[:0]
            print(f"{f}: len={len(arr)} head={arr[:head]} tail={tail}")


2026-01-06 17:28:38.314782: I tensorflow/core/util/port.cc:110] oneDNN custom operations are on. You may see slightly different numerical results due to floating-point round-off errors from different computation orders. To turn them off, set the environment variable `TF_ENABLE_ONEDNN_OPTS=0`.
2026-01-06 17:28:38.315907: I tensorflow/tsl/cuda/cudart_stub.cc:28] Could not find cuda drivers on your machine, GPU will not be used.
2026-01-06 17:28:38.336914: I tensorflow/tsl/cuda/cudart_stub.cc:28] Could not find cuda drivers on your machine, GPU will not be used.
2026-01-06 17:28:38.337555: I tensorflow/core/platform/cpu_feature_guard.cc:182] This TensorFlow binary is optimized to use available CPU instructions in performance-critical operations.
To enable the following instructions: AVX2 AVX512F AVX512_VNNI AVX512_BF16 AVX_VNNI FMA, in other operations, rebuild TensorFlow with the appropriate compiler flags.


In [3]:
from waymo_open_dataset import dataset_pb2 as open_dataset
from waymo_open_dataset.protos import end_to_end_driving_data_pb2 as wod_e2ed_pb2

FILENAME = r"/mnt/d/Datasets/WOD_E2E_Camera_v1/val/val_202504211843.tfrecord-00008-of-00093"
SEG_ID   = "fb0ed944efebd34d756103188d59da85"
KEY_IDX  = 223   # example: key_idx values from 20 to 223 were seen earlier

# Look up one keyframe in the shard and, if present, print its ego status.
e2e = find_record_by_seg_and_idx(FILENAME, SEG_ID, KEY_IDX, max_scan=None)
if e2e is None:
    print("found:", False)
else:
    print("found:", True)
    show_ego_status_one_record(e2e, head=6)


found: True
=== Keyframe ===
context.name: fb0ed944efebd34d756103188d59da85-223
segment_id: fb0ed944efebd34d756103188d59da85
key_idx: 223
intent: 1

=== past_states ===
pos_x: len=16 head=[-21.756836 -20.174805 -18.549805 -16.89746  -15.229492 -13.55957 ] tail=[-5.7597656 -4.435547  -3.2021484 -2.0546875 -0.9863281  0.       ]
pos_y: len=16 head=[-0.02978516 -0.00634766  0.01025391  0.02392578  0.02734375  0.02880859] tail=[ 0.01220703  0.0078125   0.00292969  0.         -0.00195312  0.        ]
vel_x: len=16 head=[6.426236  6.566687  6.650089  6.6838446 6.670882  6.5977283] tail=[5.0934463 4.766901  4.4306293 4.1217246 3.7761056 3.7761056]
vel_y: len=16 head=[ 0.10726788  0.09395864  0.06050611  0.02924885  0.01881703 -0.00044798] tail=[-0.01857909 -0.01457623 -0.02092961 -0.0144933  -0.00751033 -0.0075103 ]
accel_x: len=16 head=[ 0.20577478  0.14045095  0.08340168  0.03375578 -0.01296234 -0.07315397] tail=[-0.41816998 -0.32654524 -0.33627176 -0.30890465 -0.34561896 -0.34561896]
accel