# 1 Preprocessing the raw WOMD files

The Waymo map definitions can be found in the [proto file](https://github.com/waymo-research/waymo-open-dataset/blob/master/src/waymo_open_dataset/protos/map.proto).

## 1.1 Extracting raw data from the proto

**save_infos**
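+ `scenario_id`: *str*, scenario ID
+ `timestamps_seconds`: *list[91]*, list of timestamps in seconds
+ `current_time_index`: *int*, index of the current time step
+ `sdc_track_index`: *int*, track index of the ego vehicle (self-driving car)
+ `objects_of_interest`: *list*, objects of interest (may be empty)
+ `tracks_to_predict`: *dict*, tracks that must be predicted
    - `track_index`: *list*, track indices of the agents
    - `difficulty`: *list*, prediction difficulty **(the criterion behind the levels is unclear)**
    - `object_type`: *list*, agent types
+ `track_infos`: *dict*, agent track information
    - `object_id`: *list*, IDs of all agents
    - `object_type`: *list[str]*, agent types
        > 'TYPE_UNSET', 'TYPE_VEHICLE', 'TYPE_PEDESTRIAN', 'TYPE_CYCLIST', 'TYPE_OTHER'
    - `trajs`: *array[NA, NT, 10]*, agent states (NA agents, NT time steps)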
        > center_x, center_y, center_z, length, width, height, heading, velocity_x, velocity_y, valid
+ `map_infos`: *dict*, static map information
    + `lane`: *list(dict)*, lane information
        - `id`: *int*, map element ID
        - `speed_limit_mph`: *float*, speed limit (in mph)
        - `type`: *int*, map element type
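            > Stored as the proto enum value plus 1 (the code uses `cur_data.lane.type + 1`): 'TYPE_UNDEFINED': 1, 'TYPE_FREEWAY': 2, 'TYPE_SURFACE_STREET': 3, 'TYPE_BIKE_LANE': 4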
        - `left_neighbors`: *list(int)*, lanes to the left of this lane
        - `right_neighbors`: *list(int)*, lanes to the right of this lane
        - `interpolating`: *bool*, whether the lane is interpolated
        - `entry_lanes`: *list(int)*, upstream lanes
        - `exit_lanes`: *list(int)*, downstream lanes
        - `left_boundary_type`: *list(int)*, types of the left boundary lines
        - `right_boundary_type`: *list(int)*, types of the right boundary lines
        - `left_boundary`: *list(int)*, IDs of the left boundary lines
        - `right_boundary`: *list(int)*, IDs of the right boundary lines
        - `left_boundary_start_index`: *list(int)*
        - `left_boundary_end_index`: *list(int)*
        - `right_boundary_start_index`: *list(int)*
        - `right_boundary_end_index`: *list(int)*
        - `polyline_index`: start and end indices of this polyline's points in `all_polylines`
    + `lane_dict`: *dict*, lane information looked up by map element ID
    + `road_line`: *list(dict)*, road line (lane divider) information
        - `id`: *int*, map element ID
        - `type`: *int*, map element type
            > 'TYPE_BROKEN_SINGLE_WHITE': 6, 'TYPE_SOLID_SINGLE_WHITE': 7, 'TYPE_SOLID_DOUBLE_WHITE': 8, 'TYPE_BROKEN_SINGLE_YELLOW': 9, 'TYPE_BROKEN_DOUBLE_YELLOW': 10, 'TYPE_SOLID_SINGLE_YELLOW': 11, 'TYPE_SOLID_DOUBLE_YELLOW': 12, 'TYPE_PASSING_DOUBLE_YELLOW': 13
        - `polyline_index`: start and end indices of this polyline's points in `all_polylines`
    + `road_edge`: *list(dict)*, road edge information
        - `id`: *int*, map element ID
        - `type`: *int*, map element type
            > 'TYPE_ROAD_EDGE_BOUNDARY': 15, 'TYPE_ROAD_EDGE_MEDIAN': 16
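        - `polyline_index`: start and end indices of this polyline's points in `all_polylines`
    + `stop_sign`: *list(dict)*, stop sign information
        - `id`: *int*, map element ID
        - `lane_ids`: *list*, IDs of the lanes controlled by the stop sign
        - `position`: *array[3,]*, stop sign coordinates
        - `polyline_index`: start and end indices of this polyline's points in `all_polylines`
    + `crosswalk`: *list(dict)*, crosswalk information
        - `id`: *int*, map element ID
        - `polyline_index`: start and end indices of this polyline's points in `all_polylines`
    + `speed_bump`: *list(dict)*, speed bump information
        - `id`: *int*, map element ID
        - `polyline_index`: start and end indices of this polyline's points in `all_polylines`
    + `all_polylines`: *array(n, 5)*, all polyline points (x, y, z, type of the owning map element, ID of the owning map element)
    + `lane2other_dict`: *dict*, maps each lane ID to the IDs of its associated map elements (left/right boundary lines, stop signs)
+ `dynamic_map_infos`: *dict*, dynamic map information
    - `lane_id`: *list(91)*, per frame, the lane IDs controlled by the observed traffic signals
    - `state`: *list(91)*, per frame, the states of the observed traffic signals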
        > 'LANE_STATE_UNKNOWN', 'LANE_STATE_ARROW_STOP', 'LANE_STATE_ARROW_CAUTION', 'LANE_STATE_ARROW_GO', 'LANE_STATE_STOP', 'LANE_STATE_CAUTION', 'LANE_STATE_GO', 'LANE_STATE_FLASHING_STOP', 'LANE_STATE_FLASHING_CAUTION'
    - `stop_point`: *list(91)*, per frame, the stop points (x, y, z) of the observed traffic signals
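
The `polyline_index` fields above are what tie each map element back to `all_polylines`. The following is a minimal sketch of that lookup on made-up data; the helper `element_points` and the toy values are hypothetical and not part of the notebook, and the range is assumed to be half-open, as the extraction code's `(point_cnt, point_cnt + len(cur_polyline))` suggests.

```python
import numpy as np

# Toy stand-in for map_infos: two map elements packed into one point array.
# Columns follow the layout above: x, y, z, element type, element id.
all_polylines = np.array([
    [0.0, 0.0, 0.0, 2, 101],
    [1.0, 0.0, 0.0, 2, 101],
    [2.0, 0.0, 0.0, 2, 101],
    [0.0, 3.0, 0.0, 6, 205],
    [1.0, 3.0, 0.0, 6, 205],
], dtype=np.float32)

lane = {"id": 101, "polyline_index": (0, 3)}
road_line = {"id": 205, "polyline_index": (3, 5)}


def element_points(element, polylines):
    """Return the rows of `polylines` that belong to one map element."""
    start, end = element["polyline_index"]  # half-open range [start, end)
    return polylines[start:end]


lane_points = element_points(lane, all_polylines)
line_points = element_points(road_line, all_polylines)
```

Storing all points in one array and keeping only index ranges per element avoids a ragged list of per-element arrays.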


In [1]:
# Path to the raw WOMD files
WOMD_DIR = "/mnt/i/womd_scenario_v_1_2_0/training"

In [2]:
import os
import tensorflow as tf
from waymo_open_dataset.protos import scenario_pb2

for tfrecord in os.listdir(WOMD_DIR):
    file_path = os.path.join(WOMD_DIR, tfrecord)
    dataset = tf.data.TFRecordDataset(file_path, compression_type='', num_parallel_reads=3)
    for cnt, data in enumerate(dataset):
        scenario = scenario_pb2.Scenario()
        scenario.ParseFromString(bytearray(data.numpy()))
        print(f"scene: {scenario.scenario_id}")
        break
    break

scene: 4b60f9400a30ceaf


In [3]:
import numpy as np
from collections import defaultdict

def decode_tracks_from_proto(tracks):
    object_type = {
        0: 'TYPE_UNSET',
        1: 'TYPE_VEHICLE',
        2: 'TYPE_PEDESTRIAN',
        3: 'TYPE_CYCLIST',
        4: 'TYPE_OTHER'
    }

    track_infos = {
        'object_id': [],
        'object_type': [],  # {0: unset, 1: vehicle, 2: pedestrian, 3: cyclist, 4: others}
        'trajs': []
    }

    for cur_data in tracks:  # number of objects
        cur_traj = [np.array([x.center_x, x.center_y, x.center_z, x.length, x.width, x.height, x.heading,
                              x.velocity_x, x.velocity_y, x.valid], dtype=np.float32) for x in cur_data.states]
        cur_traj = np.stack(cur_traj, axis=0)  # (num_timestamp, 10)

        track_infos['object_id'].append(cur_data.id)
        track_infos['object_type'].append(object_type[cur_data.object_type])
        track_infos['trajs'].append(cur_traj)

    track_infos['trajs'] = np.stack(track_infos['trajs'], axis=0)  # (num_objects, num_timestamp, 10)
    return track_infos

def decode_map_features_from_proto(map_features):
    polyline_type = {
        # for lane
        'TYPE_UNDEFINED': -1,
        'TYPE_FREEWAY': 1,
        'TYPE_SURFACE_STREET': 2,
        'TYPE_BIKE_LANE': 3,

        # for roadline
        'TYPE_UNKNOWN': -1,
        'TYPE_BROKEN_SINGLE_WHITE': 6,
        'TYPE_SOLID_SINGLE_WHITE': 7,
        'TYPE_SOLID_DOUBLE_WHITE': 8,
        'TYPE_BROKEN_SINGLE_YELLOW': 9,
        'TYPE_BROKEN_DOUBLE_YELLOW': 10,
        'TYPE_SOLID_SINGLE_YELLOW': 11,
        'TYPE_SOLID_DOUBLE_YELLOW': 12,
        'TYPE_PASSING_DOUBLE_YELLOW': 13,

        # for roadedge
        'TYPE_ROAD_EDGE_BOUNDARY': 15,
        'TYPE_ROAD_EDGE_MEDIAN': 16,

        # for stopsign
        'TYPE_STOP_SIGN': 17,

        # for crosswalk
        'TYPE_CROSSWALK': 18,

        # for speed bump
        'TYPE_SPEED_BUMP': 19
    }

    map_infos = {
        'lane': [],
        'road_line': [],
        'road_edge': [],
        'stop_sign': [],
        'crosswalk': [],
        'speed_bump': [],
        'lane_dict': {},
        'lane2other_dict': {}
    }
    polylines = []

    point_cnt = 0
    lane2other_dict = defaultdict(list)

    for cur_data in map_features:
        cur_info = {'id': cur_data.id}

        if cur_data.lane.ByteSize() > 0:
            cur_info['speed_limit_mph'] = cur_data.lane.speed_limit_mph
            # proto enum (0: undefined, 1: freeway, 2: surface_street, 3: bike_lane) shifted by 1, so stored values are 1-4
            cur_info['type'] = cur_data.lane.type + 1
            cur_info['left_neighbors'] = [lane.feature_id for lane in cur_data.lane.left_neighbors]

            cur_info['right_neighbors'] = [lane.feature_id for lane in cur_data.lane.right_neighbors]

            cur_info['interpolating'] = cur_data.lane.interpolating
            cur_info['entry_lanes'] = list(cur_data.lane.entry_lanes)
            cur_info['exit_lanes'] = list(cur_data.lane.exit_lanes)

            cur_info['left_boundary_type'] = [x.boundary_type + 5 for x in cur_data.lane.left_boundaries]
            cur_info['right_boundary_type'] = [x.boundary_type + 5 for x in cur_data.lane.right_boundaries]

            cur_info['left_boundary'] = [x.boundary_feature_id for x in cur_data.lane.left_boundaries]
            cur_info['right_boundary'] = [x.boundary_feature_id for x in cur_data.lane.right_boundaries]
            cur_info['left_boundary_start_index'] = [lane.lane_start_index for lane in cur_data.lane.left_boundaries]
            cur_info['left_boundary_end_index'] = [lane.lane_end_index for lane in cur_data.lane.left_boundaries]
            cur_info['right_boundary_start_index'] = [lane.lane_start_index for lane in cur_data.lane.right_boundaries]
            cur_info['right_boundary_end_index'] = [lane.lane_end_index for lane in cur_data.lane.right_boundaries]

            lane2other_dict[cur_data.id].extend(cur_info['left_boundary'])
            lane2other_dict[cur_data.id].extend(cur_info['right_boundary'])

            global_type = cur_info['type']
            cur_polyline = np.stack(
                [np.array([point.x, point.y, point.z, global_type, cur_data.id]) for point in cur_data.lane.polyline],
                axis=0)
            cur_polyline = np.concatenate((cur_polyline[:, 0:3], cur_polyline[:, 3:]), axis=-1)
            if cur_polyline.shape[0] <= 1:
                continue
            map_infos['lane'].append(cur_info)
            map_infos['lane_dict'][cur_data.id] = cur_info

        elif cur_data.road_line.ByteSize() > 0:
            cur_info['type'] = cur_data.road_line.type + 5

            global_type = cur_info['type']
            cur_polyline = np.stack([np.array([point.x, point.y, point.z, global_type, cur_data.id]) for point in
                                     cur_data.road_line.polyline], axis=0)
            cur_polyline = np.concatenate((cur_polyline[:, 0:3], cur_polyline[:, 3:]), axis=-1)
            if cur_polyline.shape[0] <= 1:
                continue
            map_infos['road_line'].append(cur_info)

        elif cur_data.road_edge.ByteSize() > 0:
            cur_info['type'] = cur_data.road_edge.type + 14

            global_type = cur_info['type']
            cur_polyline = np.stack([np.array([point.x, point.y, point.z, global_type, cur_data.id]) for point in
                                     cur_data.road_edge.polyline], axis=0)
            cur_polyline = np.concatenate((cur_polyline[:, 0:3], cur_polyline[:, 3:]), axis=-1)
            if cur_polyline.shape[0] <= 1:
                continue
            map_infos['road_edge'].append(cur_info)

        elif cur_data.stop_sign.ByteSize() > 0:
            cur_info['lane_ids'] = list(cur_data.stop_sign.lane)
            for i in cur_info['lane_ids']:
                lane2other_dict[i].append(cur_data.id)
            point = cur_data.stop_sign.position
            cur_info['position'] = np.array([point.x, point.y, point.z])

            global_type = polyline_type['TYPE_STOP_SIGN']
            cur_polyline = np.array([point.x, point.y, point.z, global_type, cur_data.id]).reshape(1, 5)
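            # NOTE: a stop sign is a single point (shape (1, 5)), so this check always triggers:
            # stop signs reach lane2other_dict but never map_infos['stop_sign'] or all_polylines.
            if cur_polyline.shape[0] <= 1:
                continue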
            map_infos['stop_sign'].append(cur_info)
        elif cur_data.crosswalk.ByteSize() > 0:
            global_type = polyline_type['TYPE_CROSSWALK']
            cur_polyline = np.stack([np.array([point.x, point.y, point.z, global_type, cur_data.id]) for point in
                                     cur_data.crosswalk.polygon], axis=0)
            cur_polyline = np.concatenate((cur_polyline[:, 0:3], cur_polyline[:, 3:]), axis=-1)
            if cur_polyline.shape[0] <= 1:
                continue
            map_infos['crosswalk'].append(cur_info)

        elif cur_data.speed_bump.ByteSize() > 0:
            global_type = polyline_type['TYPE_SPEED_BUMP']
            cur_polyline = np.stack([np.array([point.x, point.y, point.z, global_type, cur_data.id]) for point in
                                     cur_data.speed_bump.polygon], axis=0)
            cur_polyline = np.concatenate((cur_polyline[:, 0:3], cur_polyline[:, 3:]), axis=-1)
            if cur_polyline.shape[0] <= 1:
                continue
            map_infos['speed_bump'].append(cur_info)

        else:
            # print(cur_data)
            continue
        polylines.append(cur_polyline)
        cur_info['polyline_index'] = (point_cnt, point_cnt + len(cur_polyline))
        point_cnt += len(cur_polyline)

    # try:
    polylines = np.concatenate(polylines, axis=0).astype(np.float32)
    # except:
    #     polylines = np.zeros((0, 8), dtype=np.float32)
    #     print('Empty polylines: ')
    map_infos['all_polylines'] = polylines
    map_infos['lane2other_dict'] = lane2other_dict
    return map_infos

def decode_dynamic_map_states_from_proto(dynamic_map_states):
    signal_state = {
        0: 'LANE_STATE_UNKNOWN',

        # // States for traffic signals with arrows.
        1: 'LANE_STATE_ARROW_STOP',
        2: 'LANE_STATE_ARROW_CAUTION',
        3: 'LANE_STATE_ARROW_GO',

        # // Standard round traffic signals.
        4: 'LANE_STATE_STOP',
        5: 'LANE_STATE_CAUTION',
        6: 'LANE_STATE_GO',

        # // Flashing light signals.
        7: 'LANE_STATE_FLASHING_STOP',
        8: 'LANE_STATE_FLASHING_CAUTION'
    }

    dynamic_map_infos = {
        'lane_id': [],
        'state': [],
        'stop_point': []
    }
    for cur_data in dynamic_map_states:  # (num_timestamp)
        lane_id, state, stop_point = [], [], []
        for cur_signal in cur_data.lane_states:  # (num_observed_signals)
            lane_id.append(cur_signal.lane)
            state.append(signal_state[cur_signal.state])
            stop_point.append([cur_signal.stop_point.x, cur_signal.stop_point.y, cur_signal.stop_point.z])

        dynamic_map_infos['lane_id'].append(np.array([lane_id]))
        dynamic_map_infos['state'].append(np.array([state]))
        dynamic_map_infos['stop_point'].append(np.array([stop_point]))

    return dynamic_map_infos

def process_single_data(scenario):
    info = {}
    info['scenario_id'] = scenario.scenario_id
    info['timestamps_seconds'] = list(scenario.timestamps_seconds)  # list of float of length 91
    info['current_time_index'] = scenario.current_time_index  # int, 10
    info['sdc_track_index'] = scenario.sdc_track_index  # int
    info['objects_of_interest'] = list(scenario.objects_of_interest)  # list, could be empty list

    info['tracks_to_predict'] = {
        'track_index': [cur_pred.track_index for cur_pred in scenario.tracks_to_predict],
        'difficulty': [cur_pred.difficulty for cur_pred in scenario.tracks_to_predict]
    }  # for training: suggestion of objects to train on, for val/test: need to be predicted

    track_infos = decode_tracks_from_proto(scenario.tracks)
    info['tracks_to_predict']['object_type'] = [track_infos['object_type'][cur_idx] for cur_idx in
                                                info['tracks_to_predict']['track_index']]

    # decode map related data
    map_infos = decode_map_features_from_proto(scenario.map_features)
    dynamic_map_infos = decode_dynamic_map_states_from_proto(scenario.dynamic_map_states)

    save_infos = {
        'track_infos': track_infos,
        'dynamic_map_infos': dynamic_map_infos,
        'map_infos': map_infos
    }
    save_infos.update(info)
    return save_infos

save_infos = process_single_data(scenario)
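
The way `lane2other_dict` is filled in `decode_map_features_from_proto` can be sketched in isolation as follows. The lane and stop-sign records here are invented for illustration; only the `defaultdict(list)` pattern and the field names (`left_boundary`, `right_boundary`, `lane_ids`) come from the code above.

```python
from collections import defaultdict

# Hypothetical, already-decoded map elements (IDs are made up).
lanes = [
    {"id": 101, "left_boundary": [205], "right_boundary": [301]},
    {"id": 102, "left_boundary": [], "right_boundary": [205]},
]
stop_signs = [{"id": 400, "lane_ids": [101, 102]}]

# lane id -> ids of the boundary lines and stop signs associated with that lane
lane2other = defaultdict(list)
for lane in lanes:
    lane2other[lane["id"]].extend(lane["left_boundary"])
    lane2other[lane["id"]].extend(lane["right_boundary"])
for sign in stop_signs:
    for lane_id in sign["lane_ids"]:
        lane2other[lane_id].append(sign["id"])
```

Because the container is a `defaultdict`, a stop sign may reference a lane that has not been visited yet without raising a `KeyError`.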

## 1.2 Extracting traffic-signal information

`tf_lights`: *DataFrame (n × 3)*, state of every traffic signal at every time step, with columns (lane_id, time_step, state)
> LANE_STATE_STOP, LANE_STATE_GO, LANE_STATE_CAUTION
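
As a rough illustration of the table layout described above, the sketch below flattens per-frame signal lists into (lane_id, time_step, state) rows using plain Python. The `frames` structure is a simplified, hypothetical stand-in for `dynamic_map_infos`, not the exact format produced earlier.

```python
# Hypothetical per-frame signal lists, shaped like dynamic_map_infos but tiny.
frames = [
    {"lane_id": [12, 15], "state": ["LANE_STATE_STOP", "LANE_STATE_GO"]},
    {"lane_id": [], "state": []},  # no signal observed in this frame
    {"lane_id": [12], "state": ["LANE_STATE_CAUTION"]},
]

# One row per (signal, frame) observation.
rows = []
for time_step, frame in enumerate(frames):
    for lane_id, state in zip(frame["lane_id"], frame["state"]):
        rows.append((lane_id, time_step, state))
```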

In [4]:
import torch
from typing import Any, Dict, List, Optional
import easydict
import pandas as pd

def process_dynamic_map(dynamic_map_infos):
    lane_ids = dynamic_map_infos["lane_id"]
    tf_lights = []
    for t in range(len(lane_ids)):
        lane_id = lane_ids[t]
        time = np.ones_like(lane_id) * t
        state = dynamic_map_infos["state"][t]
        tf_light = np.concatenate([lane_id, time, state], axis=0)
        tf_lights.append(tf_light)
    tf_lights = np.concatenate(tf_lights, axis=1).transpose(1, 0)
    tf_lights = pd.DataFrame(data=tf_lights, columns=["lane_id", "time_step", "state"])
    tf_lights["time_step"] = tf_lights["time_step"].astype("str")
    tf_lights["lane_id"] = tf_lights["lane_id"].astype("str")
    tf_lights["state"] = tf_lights["state"].astype("str")
    tf_lights.loc[tf_lights["state"].str.contains("STOP"), ["state"] ] = 'LANE_STATE_STOP'
    tf_lights.loc[tf_lights["state"].str.contains("GO"), ["state"] ] = 'LANE_STATE_GO'
    tf_lights.loc[tf_lights["state"].str.contains("CAUTION"), ["state"] ] = 'LANE_STATE_CAUTION'
    return tf_lights

dynamic_map_infos = save_infos["dynamic_map_infos"]
tf_lights = process_dynamic_map(dynamic_map_infos)
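
The three `str.contains` assignments at the end of `process_dynamic_map` collapse the nine raw signal states into four classes. A minimal pure-Python sketch of the same mapping is shown below; `collapse_state` is a hypothetical helper name, and it relies on the observation that no raw state name contains more than one of the three keywords.

```python
def collapse_state(state: str) -> str:
    """Map a raw signal state name to its coarse class."""
    # Same substring checks, in the same order, as process_dynamic_map.
    for key in ("STOP", "GO", "CAUTION"):
        if key in state:
            return f"LANE_STATE_{key}"
    return state  # e.g. LANE_STATE_UNKNOWN is left unchanged


collapsed = [collapse_state(s) for s in (
    "LANE_STATE_ARROW_STOP",
    "LANE_STATE_FLASHING_CAUTION",
    "LANE_STATE_ARROW_GO",
    "LANE_STATE_UNKNOWN",
)]
```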

## 1.3 Extracting map features

In [5]:
def _safe_list_index(ls: List[Any], elem: Any) -> Optional[int]:
    try:
        return ls.index(elem)
    except ValueError:
        return None

def get_map_features(map_infos, tf_current_light, dim=3):
    _point_types = ['DASH_SOLID_YELLOW', 'DASH_SOLID_WHITE', 'DASHED_WHITE', 'DASHED_YELLOW',
                'DOUBLE_SOLID_YELLOW', 'DOUBLE_SOLID_WHITE', 'DOUBLE_DASH_YELLOW', 'DOUBLE_DASH_WHITE',
                'SOLID_YELLOW', 'SOLID_WHITE', 'SOLID_DASH_WHITE', 'SOLID_DASH_YELLOW', 'EDGE',
                'NONE', 'UNKNOWN', 'CROSSWALK', 'CENTERLINE']
    _polygon_types = ['VEHICLE', 'BIKE', 'BUS', 'PEDESTRIAN']
    _polygon_light_type = ['LANE_STATE_STOP', 'LANE_STATE_GO', 'LANE_STATE_CAUTION', 'LANE_STATE_UNKNOWN']
    _polygon_to_polygon_types = ['NONE', 'PRED', 'SUCC', 'LEFT', 'RIGHT']
    
    Lane_type_hash = {
        4: "BIKE",
        3: "VEHICLE",
        2: "VEHICLE",
        1: "BUS"
    }
    boundary_type_hash = {
        5: "UNKNOWN",
        6: "DASHED_WHITE",
        7: "SOLID_WHITE",
        8: "DOUBLE_DASH_WHITE",
        9: "DASHED_YELLOW",
        10: "DOUBLE_DASH_YELLOW",
        11: "SOLID_YELLOW",
        12: "DOUBLE_SOLID_YELLOW",
        13: "DASH_SOLID_YELLOW",
        14: "UNKNOWN",
        15: "EDGE",
        16: "EDGE"
    }

    lane_segments = map_infos['lane']
    all_polylines = map_infos["all_polylines"]
    crosswalks = map_infos['crosswalk']
    road_edges = map_infos['road_edge']
    road_lines = map_infos['road_line']
    lane_segment_ids = [info["id"] for info in lane_segments]
    cross_walk_ids = [info["id"] for info in crosswalks]
    road_edge_ids = [info["id"] for info in road_edges]
    road_line_ids = [info["id"] for info in road_lines]
    polygon_ids = lane_segment_ids + road_edge_ids + road_line_ids + cross_walk_ids
    num_polygons = len(lane_segment_ids) + len(road_edge_ids) + len(road_line_ids) + len(cross_walk_ids)

    # Type of each polygon (index into _polygon_types)
    polygon_type = torch.zeros(num_polygons, dtype=torch.uint8)
    # Traffic-signal state of each polygon (defaults to 3, unknown; index into _polygon_light_type)
    polygon_light_type = torch.ones(num_polygons, dtype=torch.uint8) * 3

    # Coordinates of the points of each polyline
    point_position: List[Optional[torch.Tensor]] = [None] * num_polygons
    # Orientation (in radians) of the vector starting at each point
    point_orientation: List[Optional[torch.Tensor]] = [None] * num_polygons
    # Magnitude (i.e. length) of the vector starting at each point
    point_magnitude: List[Optional[torch.Tensor]] = [None] * num_polygons
    # Height difference between each point and the next one
    point_height: List[Optional[torch.Tensor]] = [None] * num_polygons
    # Type of each point (index into _point_types)
    point_type: List[Optional[torch.Tensor]] = [None] * num_polygons

    for lane_segment in lane_segments:
        lane_segment = easydict.EasyDict(lane_segment)
        lane_segment_idx = polygon_ids.index(lane_segment.id)
        polyline_index = lane_segment.polyline_index
        centerline = all_polylines[polyline_index[0]:polyline_index[1], :]
        centerline = torch.from_numpy(centerline).float()
        polygon_type[lane_segment_idx] = _polygon_types.index(Lane_type_hash[lane_segment.type])

        # Check whether this lane is controlled by a traffic signal in the current frame
        res = tf_current_light[tf_current_light["lane_id"] == str(lane_segment.id)]
        if len(res) != 0:
            polygon_light_type[lane_segment_idx] = _polygon_light_type.index(res["state"].item())

        point_position[lane_segment_idx] = torch.cat([centerline[:-1, :dim]], dim=0)
        center_vectors = centerline[1:] - centerline[:-1]
        point_orientation[lane_segment_idx] = torch.cat([torch.atan2(center_vectors[:, 1], center_vectors[:, 0])], dim=0)
        point_magnitude[lane_segment_idx] = torch.norm(torch.cat([center_vectors[:, :2]], dim=0), p=2, dim=-1)
        point_height[lane_segment_idx] = torch.cat([center_vectors[:, 2]], dim=0)
        center_type = _point_types.index('CENTERLINE')
        point_type[lane_segment_idx] = torch.cat(
            [torch.full((len(center_vectors),), center_type, dtype=torch.uint8)], dim=0)

    for lane_segment in road_edges:
        lane_segment = easydict.EasyDict(lane_segment)
        lane_segment_idx = polygon_ids.index(lane_segment.id)
        polyline_index = lane_segment.polyline_index
        centerline = all_polylines[polyline_index[0]:polyline_index[1], :]
        centerline = torch.from_numpy(centerline).float()
        polygon_type[lane_segment_idx] = _polygon_types.index("VEHICLE")

        point_position[lane_segment_idx] = torch.cat([centerline[:-1, :dim]], dim=0)
        center_vectors = centerline[1:] - centerline[:-1]
        point_orientation[lane_segment_idx] = torch.cat([torch.atan2(center_vectors[:, 1], center_vectors[:, 0])], dim=0)
        point_magnitude[lane_segment_idx] = torch.norm(torch.cat([center_vectors[:, :2]], dim=0), p=2, dim=-1)
        point_height[lane_segment_idx] = torch.cat([center_vectors[:, 2]], dim=0)
        center_type = _point_types.index('EDGE')
        point_type[lane_segment_idx] = torch.cat(
            [torch.full((len(center_vectors),), center_type, dtype=torch.uint8)], dim=0)

    for lane_segment in road_lines:
        lane_segment = easydict.EasyDict(lane_segment)
        lane_segment_idx = polygon_ids.index(lane_segment.id)
        polyline_index = lane_segment.polyline_index
        centerline = all_polylines[polyline_index[0]:polyline_index[1], :]
        centerline = torch.from_numpy(centerline).float()

        polygon_type[lane_segment_idx] = _polygon_types.index("VEHICLE")

        point_position[lane_segment_idx] = torch.cat([centerline[:-1, :dim]], dim=0)
        center_vectors = centerline[1:] - centerline[:-1]
        point_orientation[lane_segment_idx] = torch.cat([torch.atan2(center_vectors[:, 1], center_vectors[:, 0])], dim=0)
        point_magnitude[lane_segment_idx] = torch.norm(torch.cat([center_vectors[:, :2]], dim=0), p=2, dim=-1)
        point_height[lane_segment_idx] = torch.cat([center_vectors[:, 2]], dim=0)
        center_type = _point_types.index(boundary_type_hash[lane_segment.type])
        point_type[lane_segment_idx] = torch.cat(
            [torch.full((len(center_vectors),), center_type, dtype=torch.uint8)], dim=0)

    for crosswalk in crosswalks:
        crosswalk = easydict.EasyDict(crosswalk)
        lane_segment_idx = polygon_ids.index(crosswalk.id)
        polyline_index = crosswalk.polyline_index
        centerline = all_polylines[polyline_index[0]:polyline_index[1], :]
        centerline = torch.from_numpy(centerline).float()

        polygon_type[lane_segment_idx] = _polygon_types.index("PEDESTRIAN")

        point_position[lane_segment_idx] = torch.cat([centerline[:-1, :dim]], dim=0)
        center_vectors = centerline[1:] - centerline[:-1]
        point_orientation[lane_segment_idx] = torch.cat([torch.atan2(center_vectors[:, 1], center_vectors[:, 0])], dim=0)
        point_magnitude[lane_segment_idx] = torch.norm(torch.cat([center_vectors[:, :2]], dim=0), p=2, dim=-1)
        point_height[lane_segment_idx] = torch.cat([center_vectors[:, 2]], dim=0)
        center_type = _point_types.index("CROSSWALK")
        point_type[lane_segment_idx] = torch.cat(
            [torch.full((len(center_vectors),), center_type, dtype=torch.uint8)], dim=0)

    # Number of points in each polyline
    num_points = torch.tensor([point.size(0) for point in point_position], dtype=torch.long)
    # [2, N_points] mapping from point index to polygon index
    point_to_polygon_edge_index = torch.stack(
        [torch.arange(num_points.sum(), dtype=torch.long),
            torch.arange(num_polygons, dtype=torch.long).repeat_interleave(num_points)], dim=0)
    # [2, N_edge] lane topology as pairs [lane index 1, lane index 2] (the order does not encode upstream/downstream)
    polygon_to_polygon_edge_index = []
    # [N_edge] topology relation type, see _polygon_to_polygon_types
    polygon_to_polygon_type = []
    polygon_to_polygon_type = []
    for lane_segment in lane_segments:
        lane_segment = easydict.EasyDict(lane_segment)
        lane_segment_idx = polygon_ids.index(lane_segment.id)
        pred_inds = []
        for pred in lane_segment.entry_lanes:
            pred_idx = _safe_list_index(polygon_ids, pred)
            if pred_idx is not None:
                pred_inds.append(pred_idx)
        if len(pred_inds) != 0:
            polygon_to_polygon_edge_index.append(
                torch.stack([torch.tensor(pred_inds, dtype=torch.long),
                             torch.full((len(pred_inds),), lane_segment_idx, dtype=torch.long)], dim=0))
            polygon_to_polygon_type.append(
                torch.full((len(pred_inds),), _polygon_to_polygon_types.index('PRED'), dtype=torch.uint8))
        succ_inds = []
        for succ in lane_segment.exit_lanes:
            succ_idx = _safe_list_index(polygon_ids, succ)
            if succ_idx is not None:
                succ_inds.append(succ_idx)
        if len(succ_inds) != 0:
            polygon_to_polygon_edge_index.append(
                torch.stack([torch.tensor(succ_inds, dtype=torch.long),
                             torch.full((len(succ_inds),), lane_segment_idx, dtype=torch.long)], dim=0))
            polygon_to_polygon_type.append(
                torch.full((len(succ_inds),), _polygon_to_polygon_types.index('SUCC'), dtype=torch.uint8))
        if len(lane_segment.left_neighbors) != 0:
            left_neighbor_ids = lane_segment.left_neighbors
            for left_neighbor_id in left_neighbor_ids:
                left_idx = _safe_list_index(polygon_ids, left_neighbor_id)
                if left_idx is not None:
                    polygon_to_polygon_edge_index.append(
                        torch.tensor([[left_idx], [lane_segment_idx]], dtype=torch.long))
                    polygon_to_polygon_type.append(
                        torch.tensor([_polygon_to_polygon_types.index('LEFT')], dtype=torch.uint8))
        if len(lane_segment.right_neighbors) != 0:
            right_neighbor_ids = lane_segment.right_neighbors
            for right_neighbor_id in right_neighbor_ids:
                right_idx = _safe_list_index(polygon_ids, right_neighbor_id)
                if right_idx is not None:
                    polygon_to_polygon_edge_index.append(
                        torch.tensor([[right_idx], [lane_segment_idx]], dtype=torch.long))
                    polygon_to_polygon_type.append(
                        torch.tensor([_polygon_to_polygon_types.index('RIGHT')], dtype=torch.uint8))
    if len(polygon_to_polygon_edge_index) != 0:
        polygon_to_polygon_edge_index = torch.cat(polygon_to_polygon_edge_index, dim=1)
        polygon_to_polygon_type = torch.cat(polygon_to_polygon_type, dim=0)
    else:
        polygon_to_polygon_edge_index = torch.tensor([[], []], dtype=torch.long)
        polygon_to_polygon_type = torch.tensor([], dtype=torch.uint8)

    map_data = {
        'map_polygon': {},
        'map_point': {},
        ('map_point', 'to', 'map_polygon'): {},
        ('map_polygon', 'to', 'map_polygon'): {},
    }
    map_data['map_polygon']['num_nodes'] = num_polygons
    map_data['map_polygon']['type'] = polygon_type
    map_data['map_polygon']['light_type'] = polygon_light_type
    if len(num_points) == 0:
        map_data['map_point']['num_nodes'] = 0
        map_data['map_point']['position'] = torch.tensor([], dtype=torch.float)
        map_data['map_point']['orientation'] = torch.tensor([], dtype=torch.float)
        map_data['map_point']['magnitude'] = torch.tensor([], dtype=torch.float)
        if dim == 3:
            map_data['map_point']['height'] = torch.tensor([], dtype=torch.float)
        map_data['map_point']['type'] = torch.tensor([], dtype=torch.uint8)
        map_data['map_point']['side'] = torch.tensor([], dtype=torch.uint8)
    else:
        map_data['map_point']['num_nodes'] = num_points.sum().item()
        map_data['map_point']['position'] = torch.cat(point_position, dim=0)
        map_data['map_point']['orientation'] = torch.cat(point_orientation, dim=0)
        map_data['map_point']['magnitude'] = torch.cat(point_magnitude, dim=0)
        if dim == 3:
            map_data['map_point']['height'] = torch.cat(point_height, dim=0)
        map_data['map_point']['type'] = torch.cat(point_type, dim=0)
    map_data['map_point', 'to', 'map_polygon']['edge_index'] = point_to_polygon_edge_index
    map_data['map_polygon', 'to', 'map_polygon']['edge_index'] = polygon_to_polygon_edge_index
    map_data['map_polygon', 'to', 'map_polygon']['type'] = polygon_to_polygon_type
    # import matplotlib.pyplot as plt
    # plt.axis('equal')
    # plt.scatter(map_data['map_point']['position'][:, 0],
    #             map_data['map_point']['position'][:, 1], s=0.2, c='black', edgecolors='none')
    # plt.show(dpi=600)
    return map_data

map_info = save_infos["map_infos"]
tf_current_light = tf_lights.loc[tf_lights["time_step"] == "11"]
map_data = get_map_features(map_info, tf_current_light)
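
Each loop in `get_map_features` derives the same per-point features from consecutive polyline points. The sketch below reproduces that computation with NumPy on a made-up three-point polyline; it uses NumPy rather than torch purely to keep the example small, and the variable names are illustrative.

```python
import numpy as np

# Toy polyline with three points (x, y, z); values are made up for illustration.
polyline = np.array([
    [0.0, 0.0, 0.0],
    [3.0, 4.0, 1.0],
    [3.0, 6.0, 1.5],
])

vectors = polyline[1:] - polyline[:-1]                   # one vector per consecutive pair
position = polyline[:-1]                                 # start point of each vector
orientation = np.arctan2(vectors[:, 1], vectors[:, 0])   # heading in radians
magnitude = np.linalg.norm(vectors[:, :2], axis=-1)      # planar (x, y) segment length
height = vectors[:, 2]                                   # z difference to the next point
```

A polyline with N points therefore yields N - 1 feature rows, which is why single-point polylines are skipped during extraction.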

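## 1.4 Extracting agent features

**new_agents_array**: *DataFrame with 20 columns*, detailed per-step track records (one row per valid agent state)
+ `observed`: whether this step is observed
+ `track_id`: unique track ID of the object
+ `object_type`: object type (vehicle, pedestrian, etc.)
+ `object_category`: object category
+ `timestep`: current time step
+ `position_x`: position along x
+ `position_y`: position along y
+ `position_z`: position along z
+ `length`: object length
+ `width`: object width
+ `height`: object height
+ `heading`: object heading
+ `velocity_x`: velocity along x
+ `velocity_y`: velocity along y
+ `scenario_id`: scenario ID
+ `start_timestamp`: start timestamp
+ `end_timestamp`: end timestamp
+ `num_timestamps`: total number of time steps
+ `focal_track_id`: focal track ID
+ `city`: city identifier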

In [6]:
def process_agent(track_info, tracks_to_predict, sdc_track_index, scenario_id, start_timestamp, end_timestamp):
    agents_array = track_info["trajs"].transpose(1, 0, 2)   #[NT, NA, 10]
    # Agent IDs
    object_id = np.array(track_info["object_id"])
    # Agent types (str)
    object_type = track_info["object_type"]
    # Mapping {agent ID: agent type}
    id_hash = {object_id[o_idx]: object_type[o_idx] for o_idx in range(len(object_id))}
    def type_hash(x):
        tp = id_hash[x]
        type_re_hash = {
            "TYPE_VEHICLE": "vehicle",
            "TYPE_PEDESTRIAN": "pedestrian",
            "TYPE_CYCLIST": "cyclist",
            "TYPE_OTHER": "background",
            "TYPE_UNSET": "background"
        }
        return type_re_hash[tp]

    columns = ['observed', 'track_id', 'object_type', 'object_category', 'timestep',
               'position_x', 'position_y', 'position_z', 'length', 'width', 'height', 'heading', 'velocity_x', 'velocity_y',
               'scenario_id', 'start_timestamp', 'end_timestamp', 'num_timestamps',
               'focal_track_id', 'city']
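    # Metadata columns prepended to the raw state; everything defaults to 1,
    # so object_category defaults to 1 ("other agent").
    new_columns = np.ones((agents_array.shape[0], agents_array.shape[1], 11))
    new_columns[:11, :, 0] = True    # observed: the first 11 steps are history
    new_columns[11:, :, 0] = False
    for index in range(new_columns.shape[0]):
        new_columns[index, :, 4] = int(index)    # timestep
    new_columns[..., 1] = object_id    # track_id
    new_columns[..., 2] = object_id    # object_type placeholder, mapped to a string by type_hash below
    new_columns[:, tracks_to_predict["track_index"], 3] = 3    # object_category: 3 = track to predict
    new_columns[..., 5] = 11    # scenario_id placeholder, overwritten below
    new_columns[..., 6] = int(start_timestamp)
    new_columns[..., 7] = int(end_timestamp)
    new_columns[..., 8] = int(91)    # num_timestamps
    new_columns[..., 9] = object_id    # focal_track_id
    new_columns[..., 10] = 10086    # city placeholder
    new_agents_array = np.concatenate([new_columns, agents_array], axis=-1)
    # Keep only valid (timestep, agent) rows; the last channel is the `valid` flag
    new_agents_array = new_agents_array[new_agents_array[..., -1] == 1.0].reshape(-1, new_agents_array.shape[-1])
    # Reorder to match `columns`; the `valid` channel (index 20) is dropped
    new_agents_array = new_agents_array[..., [0, 1, 2, 3, 4, 11, 12, 13, 14, 15, 16, 17, 18, 19, 5, 6, 7, 8, 9, 10]]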
    new_agents_array = pd.DataFrame(data=new_agents_array, columns=columns)
    new_agents_array["object_type"] = new_agents_array["object_type"].apply(func=type_hash)
    new_agents_array["start_timestamp"] = new_agents_array["start_timestamp"].astype(int)
    new_agents_array["end_timestamp"] = new_agents_array["end_timestamp"].astype(int)
    new_agents_array["num_timestamps"] = new_agents_array["num_timestamps"].astype(int)
    new_agents_array["scenario_id"] = scenario_id
    return new_agents_array


track_info = save_infos['track_infos']
tracks_to_predict = save_infos['tracks_to_predict']
assert len(tracks_to_predict["track_index"]) >= 1 
sdc_track_index = save_infos['sdc_track_index']
scenario_id = save_infos['scenario_id']

new_agents_array = process_agent(track_info, tracks_to_predict, sdc_track_index, scenario_id, 0, 91) # mtr2argo

In [7]:
predict_unseen_agents = False
vector_repr = True
split = 'train'

def get_agent_features(df: pd.DataFrame, av_id, num_historical_steps=10, dim=3, num_steps=91) -> Dict[str, Any]:
    _agent_types = ['vehicle', 'pedestrian', 'cyclist', 'background']
    
    if not predict_unseen_agents:  # filter out agents that are unseen during the historical time steps
        historical_df = df[df['timestep'] == num_historical_steps-1]
        agent_ids = list(historical_df['track_id'].unique())
        df = df[df['track_id'].isin(agent_ids)]
    else:
        agent_ids = list(df['track_id'].unique())

    num_agents = len(agent_ids)

    # Whether each agent is valid at each time step
    valid_mask = torch.zeros(num_agents, num_steps, dtype=torch.bool)
    # Whether each agent is valid at the current frame
    current_valid_mask = torch.zeros(num_agents, dtype=torch.bool)
    # Whether each agent needs to be predicted at each time step
    predict_mask = torch.zeros(num_agents, num_steps, dtype=torch.bool)
    # Agent IDs
    agent_id: List[Optional[str]] = [None] * num_agents
    # Agent type: 0 'vehicle', 1 'pedestrian', 2 'cyclist', 3 'background'
    agent_type = torch.zeros(num_agents, dtype=torch.uint8)
    # Agent category: 1 other agent, 3 agent to predict
    agent_category = torch.zeros(num_agents, dtype=torch.uint8)
    # Agent position [NA, NT, 3] (x, y, z)
    position = torch.zeros(num_agents, num_steps, dim, dtype=torch.float)
    # Agent heading [NA, NT]
    heading = torch.zeros(num_agents, num_steps, dtype=torch.float)
    # Agent velocity [NA, NT, 3] (vx, vy, 0)
    velocity = torch.zeros(num_agents, num_steps, dim, dtype=torch.float)
    # Agent size [NA, NT, 3] (length, width, height)
    shape = torch.zeros(num_agents, num_steps, dim, dtype=torch.float)

    for track_id, track_df in df.groupby('track_id'):
        agent_idx = agent_ids.index(track_id)
        agent_steps = track_df['timestep'].values

        # Same as the `valid` flag provided in the tracks
        valid_mask[agent_idx, agent_steps] = True
        # Whether the agent is valid at the current frame
        current_valid_mask[agent_idx] = valid_mask[agent_idx, num_historical_steps - 1]
        # Initialized from the `valid` flag provided in the tracks
        predict_mask[agent_idx, agent_steps] = True
        # With the vector representation, a time step t is valid only when both t and t-1 are valid
        if vector_repr:
            valid_mask[agent_idx, 1: num_historical_steps] = (
                valid_mask[agent_idx, :num_historical_steps - 1] &
                valid_mask[agent_idx, 1: num_historical_steps])
            valid_mask[agent_idx, 0] = False
        # Historical time steps are never prediction targets
        predict_mask[agent_idx, :num_historical_steps] = False
        # If the current frame is invalid, do not predict any later time step either
        if not current_valid_mask[agent_idx]:
            predict_mask[agent_idx, num_historical_steps:] = False

        agent_id[agent_idx] = track_id
        agent_type[agent_idx] = _agent_types.index(track_df['object_type'].values[0])
        agent_category[agent_idx] = track_df['object_category'].values[0]
        position[agent_idx, agent_steps, :3] = torch.from_numpy(np.stack([track_df['position_x'].values,
                                                                          track_df['position_y'].values,
                                                                          track_df['position_z'].values],
                                                                         axis=-1)).float()
        heading[agent_idx, agent_steps] = torch.from_numpy(track_df['heading'].values).float()
        velocity[agent_idx, agent_steps, :2] = torch.from_numpy(np.stack([track_df['velocity_x'].values,
                                                                          track_df['velocity_y'].values],
                                                                         axis=-1)).float()
        shape[agent_idx, agent_steps, :3] = torch.from_numpy(np.stack([track_df['length'].values,
                                                                       track_df['width'].values,
                                                                       track_df["height"].values],
                                                                      axis=-1)).float()
    av_idx = agent_id.index(av_id)
    if split == 'test':
        predict_mask[current_valid_mask
                     | (agent_category == 2)
                     | (agent_category == 3), num_historical_steps:] = True

    return {
        'num_nodes': num_agents,
        'av_index': av_idx,
        'valid_mask': valid_mask,
        'predict_mask': predict_mask,
        'id': agent_id,
        'type': agent_type,
        'category': agent_category,
        'position': position,
        'heading': heading,
        'velocity': velocity,
        'shape': shape
    }

av_id = track_info["object_id"][sdc_track_index]
agent_features = get_agent_features(new_agents_array, av_id, num_historical_steps=11)
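
The masking rules in `get_agent_features` are easier to follow on a toy track. The sketch below is a NumPy re-implementation for a single agent, assuming 11 historical steps; `toy_masks` and the sample `valid` list are hypothetical names used only for illustration.

```python
import numpy as np

def toy_masks(valid, num_historical_steps=11, vector_repr=True):
    """Return (valid_mask, predict_mask) for one agent, mirroring the rules above."""
    valid = np.asarray(valid, dtype=bool)
    valid_mask = valid.copy()
    predict_mask = valid.copy()
    current_valid = valid[num_historical_steps - 1]
    if vector_repr:
        # A historical step t stays valid only if both t and t-1 were valid.
        h = num_historical_steps
        valid_mask[1:h] = valid[:h - 1] & valid[1:h]
        valid_mask[0] = False
    # History is never a prediction target.
    predict_mask[:num_historical_steps] = False
    # Agents missing at the current step are not predicted at all.
    if not current_valid:
        predict_mask[num_historical_steps:] = False
    return valid_mask, predict_mask

# 13 steps: step 3 is missing, everything else is observed.
valid = [True] * 13
valid[3] = False
valid_mask, predict_mask = toy_masks(valid)
print(valid_mask.astype(int))
print(predict_mask.astype(int))
```

A single missing step therefore invalidates two historical steps (the missing one and its successor), because the vector representation needs a displacement from the previous step.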

## 1.5 Assembling the preprocessed data

**data**
+ `scenario_id`: *str*, scenario ID
+ `city`: *float*, ID of the city the scenario belongs to (defaults to 10086)
+ `agent`: *dict*, agents in the scenario
    - `num_nodes`: *int*, number of agents
    - `av_index`: *int*, index of the AV among the agents
    - `valid_mask`: *array([NA, NT])*, whether each agent is valid at each time step
    - `predict_mask`: *array([NA, NT])*, time steps at which each agent must be predicted
    - `id`: *list([NA])*, agent IDs
    - `type`: *array([NA])*, agent type
        > 0 'vehicle', 1 'pedestrian', 2 'cyclist', 3 'background'
    - `category`: *array([NA])*, agent category
        > 1 other agent, 2 agent of interest, 3 agent to predict
    - `position`: *array([NA, NT, 3])*, agent position (x, y, z)
    - `heading`: *array([NA, NT])*, agent heading
    - `velocity`: *array([NA, NT, 3])*, agent velocity (vx, vy, 0)
    - `shape`: *array([NA, NT, 3])*, agent size (length, width, height)
+ `map_polygon`: *dict*, polylines in the map
    - `num_nodes`: *int*, number of polylines
    - `type`: *array([num_nodes])*, type of each polyline
        > ['VEHICLE', 'BIKE', 'BUS', 'PEDESTRIAN']
    - `light_type`: *array([num_nodes])*, traffic-signal state of each polyline
        > ['LANE_STATE_STOP', 'LANE_STATE_GO', 'LANE_STATE_CAUTION', 'LANE_STATE_UNKNOWN']
+ `map_point`: *dict*, points in the map
    - `num_nodes`: *int*, number of points
    - `position`: *array([num_nodes, 3])*, point coordinates (x, y, z)
    - `orientation`: *array([num_nodes])*, direction angle of the vector at each point (in radians)
    - `magnitude`: *array([num_nodes])*, length of the vector at each point (i.e. distance)
    - `height`: *array([num_nodes])*, height difference between each point and the next *(only present when dim=3)*
    - `type`: *array([num_nodes])*, type of each point
        > [ 'DASH_SOLID_YELLOW', 'DASH_SOLID_WHITE', 'DASHED_WHITE', 'DASHED_YELLOW',
            'DOUBLE_SOLID_YELLOW', 'DOUBLE_SOLID_WHITE', 'DOUBLE_DASH_YELLOW', 'DOUBLE_DASH_WHITE',
            'SOLID_YELLOW', 'SOLID_WHITE', 'SOLID_DASH_WHITE', 'SOLID_DASH_YELLOW', 'EDGE',
            'NONE', 'UNKNOWN', 'CROSSWALK', 'CENTERLINE']
+ `'map_point', 'to', 'map_polygon'`: *dict*, index mapping between points and polylines
    - `edge_index`: *array([2, num_points])*, [point_idx, polygon_idx]
+ `'map_polygon', 'to', 'map_polygon'`: *dict*, topology between lanes
    - `edge_index`: *array([2, num_edges])*, [polygon1_idx, polygon2_idx]
    - `type`: *array([num_edges])*, relation type: polygon1 is the predecessor/successor/left neighbor/right neighbor of polygon2
        > ['NONE', 'PRED', 'SUCC', 'LEFT', 'RIGHT']
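
As an illustration of how a point-to-polygon `edge_index` is read, the following sketch groups point indices by polygon with NumPy. The tiny `edge_index` array and the helper name `points_of_polygon` are made up for the example; real values come from the map preprocessing.

```python
import numpy as np

# Row 0: point indices, row 1: index of the polygon each point belongs to.
edge_index = np.array([[0, 1, 2, 3, 4, 5],
                       [0, 0, 0, 1, 1, 2]])

def points_of_polygon(edge_index, polygon_idx):
    """Return the indices of all points attached to one polygon."""
    return edge_index[0, edge_index[1] == polygon_idx]

for pl in np.unique(edge_index[1]):
    print(pl, points_of_polygon(edge_index, pl))
```

The same boolean-mask lookup is what the map tokenization later uses to collect the points of each polygon.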

In [8]:
data = dict()
data['scenario_id'] = new_agents_array['scenario_id'].values[0]
data['city'] = new_agents_array['city'].values[0]
data['agent'] = agent_features
data.update(map_data)

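# 2 Tokenization preprocessing at data-loading time

**agent_token_data**

Token information obtained by clustering: with the first-frame center placed at [0, 0], each token is the end pose reached after a 5-frame interval.

+ `token`: *dict*, (x, y) coordinates of the four corners of the end-pose rectangle of each token, **corner order: left-front, right-front, right-rear, left-rear**
    - `veh`: *array([2048, 4, 2])*
    - `ped`: *array([2048, 4, 2])*
    - `cyc`: *array([2048, 4, 2])*
+ `traj`: *dict*, center coordinates (x, y, z) of the full 6-frame trajectory of each token
    - `veh`: *array([2048, 6, 3])*
    - `ped`: *array([2048, 6, 3])*
    - `cyc`: *array([2048, 6, 3])*
+ `token_all`: *dict*, four corner coordinates of the rectangle at each of the 6 frames of every token's trajectory
    - `veh`: *array([2048, 6, 4, 2])*
    - `ped`: *array([2048, 6, 4, 2])*
    - `cyc`: *array([2048, 6, 4, 2])*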
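
Given the corner order left-front, right-front, right-rear, left-rear, the center and heading of a token can be recovered from its contour: the center is the mean of the four corners, and the heading is the direction from the left-rear to the left-front corner. A minimal NumPy sketch follows; the helper name `contour_to_pose` and the sample rectangle are illustrative and not part of the pickle file.

```python
import numpy as np

def contour_to_pose(contour):
    """contour: [..., 4, 2] corners ordered left-front, right-front, right-rear, left-rear."""
    contour = np.asarray(contour, dtype=float)
    center = contour.mean(axis=-2)
    diff = contour[..., 0, :] - contour[..., 3, :]   # left-front minus left-rear
    heading = np.arctan2(diff[..., 1], diff[..., 0])
    return center, heading

# A 4.8 m x 2.0 m box centered at (1, 0), pointing along +x.
box = [[3.4, 1.0], [3.4, -1.0], [-1.4, -1.0], [-1.4, 1.0]]
center, heading = contour_to_pose(box)
print(center, heading)
```

The tokenization code applies the same two operations when it derives `token_pos` and `token_heading` from `token_contour`.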

In [9]:
import pickle

current_step = 10
shift = 5
noise = True
training = False

agent_token_path = "/home/yangyh408/codes/SMART/smart/tokens/cluster_frame_5_2048.pkl"
with open(agent_token_path, 'rb') as f:
    agent_token_data = pickle.load(f)
trajectory_token = agent_token_data['token']
trajectory_token_all = agent_token_data['token_all']
# For every token, normalize the last frame using the second-to-last frame as the reference state
token_last_all = {}

for k, v in trajectory_token_all.items():
    # Heading of each token near the end of its trajectory
    token_last = torch.from_numpy(v[:, -2:]).to(torch.float)    # [2048, 2, 4, 2]
    diff_xy = token_last[:, 0, 0] - token_last[:, 0, 3]         # second-to-last frame: left-front minus left-rear
    theta = torch.arctan2(diff_xy[:, 1], diff_xy[:, 0])         # heading at the second-to-last frame
    cos, sin = theta.cos(), theta.sin()
    # Build the rotation matrix
    rot_mat = theta.new_zeros(token_last.shape[0], 2, 2)
    rot_mat[:, 0, 0] = cos
    rot_mat[:, 0, 1] = -sin
    rot_mat[:, 1, 0] = sin
    rot_mat[:, 1, 1] = cos
    # Apply the rotation and normalize the token data
    agent_token = torch.bmm(token_last[:, 1], rot_mat)
    agent_token -= token_last[:, 0].mean(1)[:, None, :]
    token_last_all[k] = agent_token.numpy()

In [10]:
def clean_heading(data):
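    """
        Clean the heading angles by repairing abrupt jumps: when the wrapped heading difference
        between two consecutive valid frames exceeds 1.0 rad, the later frame takes the heading
        of the earlier one, so that the heading sequence changes smoothly.
        data['agent']['heading'] is modified in place.
    """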
    heading = data['agent']['heading']
    valid = data['agent']['valid_mask']
    pi = torch.tensor(torch.pi)
    n_vehicles, n_frames = heading.shape

    heading_diff_raw = heading[:, :-1] - heading[:, 1:]
    heading_diff = torch.remainder(heading_diff_raw + pi, 2 * pi) - pi
    heading_diff[heading_diff > pi] -= 2 * pi
    heading_diff[heading_diff < -pi] += 2 * pi

    valid_pairs = valid[:, :-1] & valid[:, 1:]

    for i in range(n_frames - 1):
        change_needed = (torch.abs(heading_diff[:, i:i + 1]) > 1.0) & valid_pairs[:, i:i + 1]

        heading[:, i + 1][change_needed.squeeze()] = heading[:, i][change_needed.squeeze()]

        if i < n_frames - 2:
            heading_diff_raw = heading[:, i + 1] - heading[:, i + 2]
            heading_diff[:, i + 1] = torch.remainder(heading_diff_raw + pi, 2 * pi) - pi
            heading_diff[heading_diff[:, i + 1] > pi] -= 2 * pi
            heading_diff[heading_diff[:, i + 1] < -pi] += 2 * pi

def cal_polygon_contour(x, y, theta, width, length):
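    """
        Compute the four corner coordinates (contour) of each rectangle.
        Returns an array `polygon_contour` of shape [n, 4, 2] with corners ordered
        left-front, right-front, right-rear, left-rear, for later use in drawing, collision checks, etc.
    """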
    left_front_x = x + 0.5 * length * np.cos(theta) - 0.5 * width * np.sin(theta)
    left_front_y = y + 0.5 * length * np.sin(theta) + 0.5 * width * np.cos(theta)
    left_front = np.column_stack((left_front_x, left_front_y))

    right_front_x = x + 0.5 * length * np.cos(theta) + 0.5 * width * np.sin(theta)
    right_front_y = y + 0.5 * length * np.sin(theta) - 0.5 * width * np.cos(theta)
    right_front = np.column_stack((right_front_x, right_front_y))

    right_back_x = x - 0.5 * length * np.cos(theta) + 0.5 * width * np.sin(theta)
    right_back_y = y - 0.5 * length * np.sin(theta) - 0.5 * width * np.cos(theta)
    right_back = np.column_stack((right_back_x, right_back_y))

    left_back_x = x - 0.5 * length * np.cos(theta) - 0.5 * width * np.sin(theta)
    left_back_y = y - 0.5 * length * np.sin(theta) + 0.5 * width * np.cos(theta)
    left_back = np.column_stack((left_back_x, left_back_y))

    polygon_contour = np.concatenate(
        (left_front[:, None, :], right_front[:, None, :], right_back[:, None, :], left_back[:, None, :]), axis=1)

    return polygon_contour

def match_token(pos, valid_mask, heading, category, agent_category, extra_mask):
    """
        Match trajectory positions and headings against the predefined token data, so that
        every `shift`-frame step of each agent is assigned its closest token.
    """
    agent_token_src = trajectory_token[category]
    token_last = token_last_all[category]
    if shift <= 2:
        if category == 'veh':
            width = 1.0
            length = 2.4
        elif category == 'cyc':
            width = 0.5
            length = 1.5
        else:
            width = 0.5
            length = 0.5
    else:
        if category == 'veh':
            width = 2.0
            length = 4.8
        elif category == 'cyc':
            width = 1.0
            length = 2.0
        else:
            width = 1.0
            length = 1.0

    prev_heading = heading[:, 0]
    prev_pos = pos[:, 0]
    agent_num, num_step, feat_dim = pos.shape   # [NA, 91, 2]
    token_num, token_contour_dim, feat_dim = agent_token_src.shape  # [2048, 4, 2]
    agent_token_src = agent_token_src.reshape(1, token_num * token_contour_dim, feat_dim).repeat(agent_num, 0)
    token_last = token_last.reshape(1, token_num * token_contour_dim, feat_dim).repeat(extra_mask.sum(), 0)
    token_index_list = []
    token_contour_list = []
    prev_token_idx = None

    for i in range(shift, pos.shape[1], shift):
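        # Heading at the previous token position (5 frames earlier)
        theta = prev_heading
        # Current heading and position
        cur_heading = heading[:, i]
        cur_pos = pos[:, i]
        # Move the normalized tokens into the global frame, using the previous position and heading as the reference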
        cos, sin = theta.cos(), theta.sin()
        rot_mat = theta.new_zeros(agent_num, 2, 2)
        rot_mat[:, 0, 0] = cos
        rot_mat[:, 0, 1] = sin
        rot_mat[:, 1, 0] = -sin
        rot_mat[:, 1, 1] = cos
        agent_token_world = torch.bmm(torch.from_numpy(agent_token_src).to(torch.float), rot_mat).reshape(agent_num,
                                                                                                            token_num,
                                                                                                            token_contour_dim,
                                                                                                            feat_dim)
        agent_token_world += prev_pos[:, None, None, :]

        # Four corners of the rectangle at the current position
        cur_contour = cal_polygon_contour(cur_pos[:, 0], cur_pos[:, 1], cur_heading, width, length)
        # Pick the token closest to the current contour and record its index
        agent_token_index = torch.from_numpy(np.argmin(
            np.mean(np.sqrt(np.sum((cur_contour[:, None, ...] - agent_token_world.numpy()) ** 2, axis=-1)), axis=2),
            axis=-1))
        if prev_token_idx is not None and noise:
            same_idx = prev_token_idx == agent_token_index
            same_idx[:] = True
            topk_indices = np.argsort(
                np.mean(np.sqrt(np.sum((cur_contour[:, None, ...] - agent_token_world.numpy()) ** 2, axis=-1)),
                        axis=2), axis=-1)[:, :5]
            sample_topk = np.random.choice(range(0, topk_indices.shape[1]), topk_indices.shape[0])
            agent_token_index[same_idx] = \
                torch.from_numpy(topk_indices[np.arange(topk_indices.shape[0]), sample_topk])[same_idx]
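        # Convert the matched token index into its four rectangle corners
        token_contour_select = agent_token_world[torch.arange(agent_num), agent_token_index]

        # The current frame becomes the previous frame of the next iteration
        diff_xy = token_contour_select[:, 0, :] - token_contour_select[:, 3, :]
        # Original heading from the dataset
        prev_heading = heading[:, i].clone()
        # Agents that were valid at the previous token step take heading and position from the matched token instead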
        prev_heading[valid_mask[:, i - shift]] = torch.arctan2(diff_xy[:, 1], diff_xy[:, 0])[
            valid_mask[:, i - shift]]

        prev_pos = pos[:, i].clone()
        prev_pos[valid_mask[:, i - shift]] = token_contour_select.mean(dim=1)[valid_mask[:, i - shift]]
        prev_token_idx = agent_token_index
        token_index_list.append(agent_token_index[:, None])
        token_contour_list.append(token_contour_select[:, None, ...])

    token_index = torch.cat(token_index_list, dim=1)
    token_contour = torch.cat(token_contour_list, dim=1)

    # Extra matching: for agents present at frame 11 but absent at frame 6, match a token from the state at frame 10
    if not training:
        theta = heading[extra_mask, current_step - 1]
        prev_pos = pos[extra_mask, current_step - 1]
        cur_pos = pos[extra_mask, current_step]
        cur_heading = heading[extra_mask, current_step]
        cos, sin = theta.cos(), theta.sin()
        rot_mat = theta.new_zeros(extra_mask.sum(), 2, 2)
        rot_mat[:, 0, 0] = cos
        rot_mat[:, 0, 1] = sin
        rot_mat[:, 1, 0] = -sin
        rot_mat[:, 1, 1] = cos
        agent_token_world = torch.bmm(torch.from_numpy(token_last).to(torch.float), rot_mat).reshape(
            extra_mask.sum(), token_num, token_contour_dim, feat_dim)
        agent_token_world += prev_pos[:, None, None, :]

        cur_contour = cal_polygon_contour(cur_pos[:, 0], cur_pos[:, 1], cur_heading, width, length)
        agent_token_index = torch.from_numpy(np.argmin(
            np.mean(np.sqrt(np.sum((cur_contour[:, None, ...] - agent_token_world.numpy()) ** 2, axis=-1)), axis=2),
            axis=-1))
        token_contour_select = agent_token_world[torch.arange(extra_mask.sum()), agent_token_index]

        token_index[extra_mask, 1] = agent_token_index
        token_contour[extra_mask, 1] = token_contour_select

    return token_index, token_contour

def tokenize_agent(data):
    if data['agent']["velocity"].shape[1] == 90:
        print(data['scenario_id'], data['agent']["velocity"].shape)
    
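    # Build interplote_mask: agents that are invalid at the current step but have a non-zero position there, i.e. points that need interpolation
    interplote_mask = (data['agent']['valid_mask'][:, current_step] == False) * (
            data['agent']['position'][:, current_step, 0] != 0)
    # For those agents, estimate and fill in the position, velocity and heading of the previous step so the trajectory stays continuous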
    if data['agent']["velocity"].shape[-1] == 2:
        data['agent']["velocity"] = torch.cat([data['agent']["velocity"],
                                                torch.zeros(data['agent']["velocity"].shape[0],
                                                            data['agent']["velocity"].shape[1], 1)], dim=-1)
    vel = data['agent']["velocity"][interplote_mask, current_step]
    # Back-fill the position, heading and velocity of the previous time step (0.1 s earlier)
    data['agent']['position'][interplote_mask, current_step - 1, :3] = data['agent']['position'][
                                                                            interplote_mask, current_step,
                                                                            :3] - vel * 0.1
    data['agent']['heading'][interplote_mask, current_step - 1] = data['agent']['heading'][
        interplote_mask, current_step]
    data['agent']["velocity"][interplote_mask, current_step - 1] = data['agent']["velocity"][
        interplote_mask, current_step]
    data['agent']['valid_mask'][interplote_mask, current_step - 1:current_step + 1] = True

    data['agent']['type'] = data['agent']['type'].to(torch.uint8)

    clean_heading(data)
    matching_extra_mask = (data['agent']['valid_mask'][:, current_step] == True) * (
            data['agent']['valid_mask'][:, current_step - 5] == False)

    interplote_mask_first = (data['agent']['valid_mask'][:, 0] == False) * (data['agent']['position'][:, 0, 0] != 0)
    data['agent']['valid_mask'][interplote_mask_first, 0] = True

    agent_pos = data['agent']['position'][:, :, :2]
    valid_mask = data['agent']['valid_mask']
    # Sliding windows of length 6 with stride 5, starting at index 0
    valid_mask_shift = valid_mask.unfold(1, shift + 1, shift)         # [NA, 18, 6]
    # A window is valid only if both its first and last steps are valid
    token_valid_mask = valid_mask_shift[:, :, 0] * valid_mask_shift[:, :, -1]   # [NA, 18]
    agent_type = data['agent']['type']
    agent_category = data['agent']['category']
    agent_heading = data['agent']['heading']
    vehicle_mask = agent_type == 0
    cyclist_mask = agent_type == 2
    ped_mask = agent_type == 1

    veh_pos = agent_pos[vehicle_mask, :, :]
    veh_valid_mask = valid_mask[vehicle_mask, :]
    cyc_pos = agent_pos[cyclist_mask, :, :]
    cyc_valid_mask = valid_mask[cyclist_mask, :]
    ped_pos = agent_pos[ped_mask, :, :]
    ped_valid_mask = valid_mask[ped_mask, :]

    veh_token_index, veh_token_contour = match_token(veh_pos, veh_valid_mask, agent_heading[vehicle_mask],
                                                            'veh', agent_category[vehicle_mask],
                                                            matching_extra_mask[vehicle_mask])
    ped_token_index, ped_token_contour = match_token(ped_pos, ped_valid_mask, agent_heading[ped_mask], 'ped',
                                                            agent_category[ped_mask], matching_extra_mask[ped_mask])
    cyc_token_index, cyc_token_contour = match_token(cyc_pos, cyc_valid_mask, agent_heading[cyclist_mask],
                                                            'cyc', agent_category[cyclist_mask],
                                                            matching_extra_mask[cyclist_mask])

    # token_index: [NA, 18 (90/5)] indices of the 18 tokens matched to each agent over the 90 frames
    token_index = torch.zeros((agent_pos.shape[0], veh_token_index.shape[1])).to(torch.int64)
    token_index[vehicle_mask] = veh_token_index
    token_index[ped_mask] = ped_token_index
    token_index[cyclist_mask] = cyc_token_index

    # token_contour: [NA, 18, 4, 2] rectangles of the 18 tokens matched to each agent over the 90 frames
    token_contour = torch.zeros((agent_pos.shape[0], veh_token_contour.shape[1],
                                    veh_token_contour.shape[2], veh_token_contour.shape[3]))
    token_contour[vehicle_mask] = veh_token_contour
    token_contour[ped_mask] = ped_token_contour
    token_contour[cyclist_mask] = cyc_token_contour

    # trajectory_token_veh = torch.from_numpy(trajectory_token['veh']).clone().to(torch.float)
    # trajectory_token_ped = torch.from_numpy(trajectory_token['ped']).clone().to(torch.float)
    # trajectory_token_cyc = torch.from_numpy(trajectory_token['cyc']).clone().to(torch.float)

    # agent_token_traj = torch.zeros((agent_pos.shape[0], trajectory_token_veh.shape[0], 4, 2))
    # agent_token_traj[vehicle_mask] = trajectory_token_veh
    # agent_token_traj[ped_mask] = trajectory_token_ped
    # agent_token_traj[cyclist_mask] = trajectory_token_cyc

    if not training:
        token_valid_mask[matching_extra_mask, 1] = True

    data['agent']['token_idx'] = token_index            # [NA, 18]
    data['agent']['token_contour'] = token_contour      # [NA, 18, 4, 2]
    token_pos = token_contour.mean(dim=2)               
    data['agent']['token_pos'] = token_pos              # [NA, 18, 2]
    diff_xy = token_contour[:, :, 0, :] - token_contour[:, :, 3, :]
    data['agent']['token_heading'] = torch.arctan2(diff_xy[:, :, 1], diff_xy[:, :, 0])  # [NA, 18]
    data['agent']['agent_valid_mask'] = token_valid_mask                                # [NA, 18]

    vel = torch.cat([token_pos.new_zeros(data['agent']['num_nodes'], 1, 2),
                        ((token_pos[:, 1:] - token_pos[:, :-1]) / (0.1 * shift))], dim=1)
    vel_valid_mask = torch.cat([torch.zeros(token_valid_mask.shape[0], 1, dtype=torch.bool),
                                (token_valid_mask * token_valid_mask.roll(shifts=1, dims=1))[:, 1:]], dim=1)
    vel[~vel_valid_mask] = 0
    vel[data['agent']['valid_mask'][:, current_step], 1] = data['agent']['velocity'][
                                                                data['agent']['valid_mask'][:, current_step],
                                                                current_step, :2]

    data['agent']['token_velocity'] = vel

    return data

token_data = tokenize_agent(data)
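
The core of `match_token` is a nearest-neighbor search: among all candidate token contours, pick the one whose four corners are on average closest to the agent's current contour. The sketch below isolates that step with NumPy on a three-token toy vocabulary. `nearest_token`, `box` and the candidate set are hypothetical, and the top-5 noise sampling of the original is left out.

```python
import numpy as np

def nearest_token(cur_contour, token_contours):
    """cur_contour: [N, 4, 2]; token_contours: [N, K, 4, 2]; returns [N] token indices."""
    # Euclidean distance per corner, averaged over the four corners.
    dist = np.sqrt(((cur_contour[:, None] - token_contours) ** 2).sum(-1)).mean(-1)
    return dist.argmin(-1)

def box(cx, cy, length=4.8, width=2.0):
    """Axis-aligned rectangle, corners ordered left-front, right-front, right-rear, left-rear."""
    hl, hw = length / 2, width / 2
    return np.array([[cx + hl, cy + hw], [cx + hl, cy - hw],
                     [cx - hl, cy - hw], [cx - hl, cy + hw]])

# One agent, three candidate tokens ending 0 m, 2 m and 5 m ahead.
tokens = np.stack([box(0, 0), box(2, 0), box(5, 0)])[None]   # [1, 3, 4, 2]
current = box(2.3, 0.1)[None]                                # [1, 4, 2]
print(nearest_token(current, tokens))
```

Comparing corners rather than centers makes the match sensitive to heading as well as position, since a rotated rectangle moves its corners even when its center stays put.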

In [11]:
import math
from scipy.interpolate import interp1d
from scipy.spatial.distance import euclidean

def wrap_angle(
        angle: torch.Tensor,
        min_val: float = -math.pi,
        max_val: float = math.pi) -> torch.Tensor:
    return min_val + (angle + max_val) % (max_val - min_val)

def interplating_polyline(polylines, heading, distance=0.5, split_distace=5):
    # Polylines are cut every 5 m and keep one point every 2.5 m, so each resulting polyline consists of 3 points
    # Calculate the cumulative distance along the path, up-sample the polyline to 0.5 meter
    dist_along_path_list = [[0]]
    polylines_list = [[polylines[0]]]
    for i in range(1, polylines.shape[0]):
        euclidean_dist = euclidean(polylines[i, :2], polylines[i - 1, :2])
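        # Absolute heading change between consecutive points (the original indexed heading[1] inside min(), most likely a typo for heading[i])
        heading_diff = min(abs(max(heading[i], heading[i - 1]) - min(heading[i], heading[i - 1])),
                           abs(max(heading[i], heading[i - 1]) - min(heading[i], heading[i - 1]) + math.pi))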
        if heading_diff > math.pi / 4 and euclidean_dist > 3:
            dist_along_path_list.append([0])
            polylines_list.append([polylines[i]])
        elif heading_diff > math.pi / 8 and euclidean_dist > 3:
            dist_along_path_list.append([0])
            polylines_list.append([polylines[i]])
        elif heading_diff > 0.1 and euclidean_dist > 3:
            dist_along_path_list.append([0])
            polylines_list.append([polylines[i]])
        elif euclidean_dist > 10:
            dist_along_path_list.append([0])
            polylines_list.append([polylines[i]])
        else:
            dist_along_path_list[-1].append(dist_along_path_list[-1][-1] + euclidean_dist)
            polylines_list[-1].append(polylines[i])
    # plt.plot(polylines[:, 0], polylines[:, 1])
    # plt.savefig('tmp.jpg')
    new_x_list = []
    new_y_list = []
    multi_polylines_list = []
    for idx in range(len(dist_along_path_list)):
        if len(dist_along_path_list[idx]) < 2:
            continue
        dist_along_path = np.array(dist_along_path_list[idx])
        polylines_cur = np.array(polylines_list[idx])
        # Create interpolation functions for x and y coordinates
        fx = interp1d(dist_along_path, polylines_cur[:, 0])
        fy = interp1d(dist_along_path, polylines_cur[:, 1])
        # fyaw = interp1d(dist_along_path, heading)

        # Create an array of distances at which to interpolate
        new_dist_along_path = np.arange(0, dist_along_path[-1], distance)
        new_dist_along_path = np.concatenate([new_dist_along_path, dist_along_path[[-1]]])
        # Use the interpolation functions to generate new x and y coordinates
        new_x = fx(new_dist_along_path)
        new_y = fy(new_dist_along_path)
        # new_yaw = fyaw(new_dist_along_path)
        new_x_list.append(new_x)
        new_y_list.append(new_y)

        # Combine the new x and y coordinates into a single array
        new_polylines = np.vstack((new_x, new_y)).T
        polyline_size = int(split_distace / distance)
        if new_polylines.shape[0] >= (polyline_size + 1):
            padding_size = (new_polylines.shape[0] - (polyline_size + 1)) % polyline_size
            final_index = (new_polylines.shape[0] - (polyline_size + 1)) // polyline_size + 1
        else:
            padding_size = new_polylines.shape[0]
            final_index = 0
        multi_polylines = None
        new_polylines = torch.from_numpy(new_polylines)
        new_heading = torch.atan2(new_polylines[1:, 1] - new_polylines[:-1, 1],
                                  new_polylines[1:, 0] - new_polylines[:-1, 0])
        new_heading = torch.cat([new_heading, new_heading[-1:]], -1)[..., None]
        new_polylines = torch.cat([new_polylines, new_heading], -1)
        if new_polylines.shape[0] >= (polyline_size + 1):
            multi_polylines = new_polylines.unfold(dimension=0, size=polyline_size + 1, step=polyline_size)
            multi_polylines = multi_polylines.transpose(1, 2)
            multi_polylines = multi_polylines[:, ::5, :]
        if padding_size >= 3:
            last_polyline = new_polylines[final_index * polyline_size:]
            last_polyline = last_polyline[torch.linspace(0, last_polyline.shape[0] - 1, steps=3).long()]
            if multi_polylines is not None:
                multi_polylines = torch.cat([multi_polylines, last_polyline.unsqueeze(0)], dim=0)
            else:
                multi_polylines = last_polyline.unsqueeze(0)
        if multi_polylines is None:
            continue
        multi_polylines_list.append(multi_polylines)
    if len(multi_polylines_list) > 0:
        multi_polylines_list = torch.cat(multi_polylines_list, dim=0)
    else:
        multi_polylines_list = None
    return multi_polylines_list

def tokenize_map(data):
    data['map_polygon']['type'] = data['map_polygon']['type'].to(torch.uint8)
    data['map_point']['type'] = data['map_point']['type'].to(torch.uint8)
    pt2pl = data[('map_point', 'to', 'map_polygon')]['edge_index']
    pt_type = data['map_point']['type'].to(torch.uint8)
    pt_side = torch.zeros_like(pt_type)
    pt_pos = data['map_point']['position'][:, :2]
    data['map_point']['orientation'] = wrap_angle(data['map_point']['orientation'])
    pt_heading = data['map_point']['orientation']
    split_polyline_type = []
    split_polyline_pos = []
    split_polyline_theta = []
    split_polyline_side = []
    pl_idx_list = []
    split_polygon_type = []
    data['map_point']['type'].unique()

    # Iterate over the polylines
    for i in sorted(np.unique(pt2pl[1])):
        # Points belonging to this polyline
        index = pt2pl[0, pt2pl[1] == i]
        polygon_type = data['map_polygon']["type"][i]
        cur_side = pt_side[index]
        cur_type = pt_type[index]
        cur_pos = pt_pos[index]
        cur_heading = pt_heading[index]

        for side_val in np.unique(cur_side):
            for type_val in np.unique(cur_type):
                if type_val == 13:
                    continue
                indices = np.where((cur_side == side_val) & (cur_type == type_val))[0]
                if len(indices) <= 2:
                    continue
                split_polyline = interplating_polyline(cur_pos[indices].numpy(), cur_heading[indices].numpy())
                if split_polyline is None:
                    continue
                new_cur_type = cur_type[indices][0]
                new_cur_side = cur_side[indices][0]
                map_polygon_type = polygon_type.repeat(split_polyline.shape[0])
                new_cur_type = new_cur_type.repeat(split_polyline.shape[0])
                new_cur_side = new_cur_side.repeat(split_polyline.shape[0])
                cur_pl_idx = torch.Tensor([i])
                new_cur_pl_idx = cur_pl_idx.repeat(split_polyline.shape[0])
                split_polyline_pos.append(split_polyline[..., :2])
                split_polyline_theta.append(split_polyline[..., 2])
                split_polyline_type.append(new_cur_type)
                split_polyline_side.append(new_cur_side)
                pl_idx_list.append(new_cur_pl_idx)
                split_polygon_type.append(map_polygon_type)

    split_polyline_pos = torch.cat(split_polyline_pos, dim=0)
    split_polyline_theta = torch.cat(split_polyline_theta, dim=0)
    split_polyline_type = torch.cat(split_polyline_type, dim=0)
    split_polyline_side = torch.cat(split_polyline_side, dim=0)
    split_polygon_type = torch.cat(split_polygon_type, dim=0)
    pl_idx_list = torch.cat(pl_idx_list, dim=0)
    vec = split_polyline_pos[:, 1, :] - split_polyline_pos[:, 0, :]
    data['map_save'] = {}
    data['pt_token'] = {}
    data['map_save']['traj_pos'] = split_polyline_pos
    data['map_save']['traj_theta'] = split_polyline_theta[:, 0]  # torch.arctan2(vec[:, 1], vec[:, 0])
    data['map_save']['pl_idx_list'] = pl_idx_list
    data['pt_token']['type'] = split_polyline_type
    data['pt_token']['side'] = split_polyline_side
    data['pt_token']['pl_type'] = split_polygon_type
    data['pt_token']['num_nodes'] = split_polyline_pos.shape[0]
    return data

token_data = tokenize_map(token_data)
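
To make the resampling scheme concrete, here is a simplified NumPy sketch: a polyline is up-sampled to 0.5 m spacing, cut into 5 m windows that share their end points, and every window is reduced to 3 points 2.5 m apart. It uses `np.interp` instead of `scipy.interpolate.interp1d`, ignores the heading-based splitting and the leftover tail, and `resample_and_split` is a name invented for this example.

```python
import numpy as np

def resample_and_split(polyline, distance=0.5, split_distance=5.0):
    """polyline: [N, 2] -> [M, 3, 2] segments of `split_distance` meters, 3 points each."""
    polyline = np.asarray(polyline, dtype=float)
    seg_len = np.linalg.norm(np.diff(polyline, axis=0), axis=1)
    s = np.concatenate([[0.0], np.cumsum(seg_len)])          # arc length at each vertex
    new_s = np.arange(0.0, s[-1] + 1e-9, distance)           # 0.5 m spacing
    pts = np.stack([np.interp(new_s, s, polyline[:, 0]),
                    np.interp(new_s, s, polyline[:, 1])], axis=1)
    size = int(round(split_distance / distance))             # 10 intervals per window
    starts = range(0, len(pts) - size, size)                 # consecutive windows share an end point
    windows = [pts[i:i + size + 1] for i in starts]
    # Keep the first, middle and last point of every window.
    return np.array([w[::size // 2] for w in windows])

segments = resample_and_split([[0.0, 0.0], [12.0, 0.0]])
print(segments.shape)
print(segments[:, :, 0])
```

For the 12 m straight line above, two full 5 m windows fit; the remaining 2 m tail is dropped in this sketch, whereas `interplating_polyline` keeps a tail as an extra 3-point polyline when it is long enough.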

In [12]:
del token_data['city']
if 'polygon_is_intersection' in token_data['map_polygon']:
    print("delete polygon_is_intersection")
    del token_data['map_polygon']['polygon_is_intersection']
if 'route_type' in token_data['map_polygon']:
    print("delete route_type")
    del token_data['map_polygon']['route_type']

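# 3 Converting the dictionary to HeteroData

When the data is loaded through a DataLoader, the `batch` and `ptr` fields are added automatically during batching.

+ `batch` field: indicates which sample each node or edge belongs to. For every node (or edge), the value in `batch` is the index of the sample it came from.

+ `ptr` field: records the start index of each sample inside the batched tensors. This is helpful when connecting edges while building a batch (for example, when offsetting indices across graphs).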
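
To make the two fields concrete, here is a minimal pure-Python sketch of what they contain for one node type, given the node counts of each sample. The helper `make_batch_and_ptr` and the example counts are hypothetical and not part of torch_geometric; the library builds the equivalent tensors internally.

```python
from itertools import accumulate

def make_batch_and_ptr(num_nodes_per_sample):
    """Return (batch, ptr) for one node type, following the layout described above."""
    # batch: one entry per node, holding the index of the sample it belongs to
    batch = [i for i, n in enumerate(num_nodes_per_sample) for _ in range(n)]
    # ptr: cumulative node counts, so sample i owns nodes ptr[i]:ptr[i + 1]
    ptr = [0] + list(accumulate(num_nodes_per_sample))
    return batch, ptr

# Hypothetical batch of two scenarios with 3 and 2 agents
batch_vec, ptr = make_batch_and_ptr([3, 2])
print(batch_vec)  # [0, 0, 0, 1, 1]
print(ptr)        # [0, 3, 5]

# A per-sample index (e.g. the AV index inside each scenario) becomes a
# batch-wide index by adding the start offset of its sample, ptr[:-1]
local_av_index = [1, 0]
global_av_index = [a + p for a, p in zip(local_av_index, ptr[:-1])]
print(global_av_index)  # [1, 3]
```

The last step is the same kind of offset that `batch['agent']['av_index'] += batch['agent']['ptr'][:-1]` applies later in this notebook.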

In [13]:
from torch_geometric.data import Dataset, HeteroData
from torch_geometric.loader import DataLoader

class CustomHeteroDataset(Dataset):
    def __init__(self, data_list):
        super(CustomHeteroDataset, self).__init__()
        self.data_list = data_list
    
    def len(self):
        return len(self.data_list)
    
    def get(self, idx):
        batch_data = HeteroData()

        for node_type, node_data in self.data_list[idx].items():
            if isinstance(node_type, str):  # node data
                if isinstance(node_data, dict):
                    for attr, value in node_data.items():
                        batch_data[node_type][attr] = value
                else:
                    batch_data[node_type] = [node_data]

        for edge_type, edge_data in self.data_list[idx].items():
            if isinstance(edge_type, tuple) and len(edge_type) == 3:  # edge data
                if isinstance(edge_data, dict):
                    for attr, value in edge_data.items():
                        batch_data[edge_type][attr] = value
                else:
                    batch_data[edge_type] = edge_data
        return batch_data

dataset = CustomHeteroDataset([token_data])
loader = DataLoader(dataset, batch_size=1)
batch = next(iter(loader))


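# 4 Matching map tokens inside the model

**data**

+ `scenario_id`: *str*, scenario id

+ `agent`: *dict*, information on the agents in the scenario

  - `num_nodes`: *int*, number of agents

  - `av_index`: *int*, index of the agent that is the AV

  - `valid_mask`: *array([NA, NT])*, whether each agent is valid at each time step

  - `predict_mask`: *array([NA, NT])*, time steps at which each agent has to be predicted

  - `id`: *list([NA])*, agent IDs

  - `type`: *array([NA])*, agent type

    > 0 'vehicle', 1 'pedestrian', 2 'cyclist', 3 'background'

  - `category`: *array([NA])*, agent category

    > 1 other agent, 2 agent of interest, 3 agent to be predicted

  - `position`: *array([NA, NT, 3])*, agent position (x, y, z)

  - `heading`: *array([NA, NT])*, agent heading

  - `velocity`: *array([NA, NT, 3])*, agent velocity (vx, vy, 0)

  - `shape`: *array([NA, NT, 3])*, agent size (length, width, height)

  - `token_idx`: *array([NA, 18])*, indices of the 18 tokens matched for each agent over the 90 frames (one token every 5 frames, 18 = 90/5)

  - `token_contour`: *array([NA, 18, 4, 2])*, rectangle corners of the 18 tokens matched for each agent over the 90 frames

  - `token_pos`: *array([NA, 18, 2])*, global coordinates of the rectangle centers of the 18 tokens matched for each agent over the 90 frames

  - `token_heading`: *array([NA, 18])*, global heading at each of the 18 tokens matched for each agent over the 90 frames

  - `token_velocity`: *array([NA, 18, 2])*, velocity (vx, vy) at each of the 18 tokens matched for each agent over the 90 frames

  - `agent_valid_mask`: *array([NA, 18])*, whether each of the 18 tokens matched for each agent over the 90 frames is valid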

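+ `map_save`: *dict*, the map re-stored after splitting the polylines into 5 m pieces, with the points inside each piece spaced 2.5 m apart

  - `traj_pos`: *array([n_polyline, 3, 2])*, [number of polylines, 3 points per polyline, xy coordinates of each point]

  - `traj_theta`: *array([n_polyline])*, heading of each polyline (the heading of its first segment is used as the representative)

  - `pl_idx_list`: *array([n_polyline])*, index of the original (pre-split) polyline that each split polyline came from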

+ `pt_token`: *dict*, map token information for the split polylines

  - `type`: *array([n_polyline])*, type of the points in the polyline

    > [ 'DASH_SOLID_YELLOW', 'DASH_SOLID_WHITE', 'DASHED_WHITE', 'DASHED_YELLOW',
    > 'DOUBLE_SOLID_YELLOW', 'DOUBLE_SOLID_WHITE', 'DOUBLE_DASH_YELLOW', 'DOUBLE_DASH_WHITE',
    > 'SOLID_YELLOW', 'SOLID_WHITE', 'SOLID_DASH_WHITE', 'SOLID_DASH_YELLOW', 'EDGE',
    > 'NONE', 'UNKNOWN', 'CROSSWALK', 'CENTERLINE']

  - `side`: *array([n_polyline])*, which side of the road the polyline is on

    > 0: left_side, 1: right_side, 2: center_side

  - `pl_type`: *array([n_polyline])*, polyline type

    > ['VEHICLE', 'BIKE', 'BUS', 'PEDESTRIAN']

  - `num_nodes`: *int*, number of polylines after splitting

  - `position`: *array([n_polyline, 3])*, start point of each map polyline, x, y, z (z = 0)

  - `orientation`: *array([n_polyline])*, heading at the start of each map polyline (in radians relative to the positive x axis, range [-π, π])

  - `height`: *array([n_polyline])*, height of the start point of each map polyline (0)

  - `token_idx`: *array([n_polyline])*, index of the map token matched from map_save['traj_pos']

  - `traj_mask`: *array([n_map_poly, 3, max_token_num])*, trajectory mask of each original map polyline for its left, right, and center sides (for each polyline, indices from 0 up to its token count are true and the rest are false)

  - `pt_valid_mask`: *array([n_polyline])*, based on traj_mask; a random 1/3 of the traj entries of each original polyline are chosen as prediction targets and masked out (set to false)

  - `pt_pred_mask`: *array([n_polyline])*, marks the indices from which the next map token has to be predicted

  - `pt_target_mask`: *array([n_polyline])*, marks the indices of the ground-truth next map token

+ `map_polygon`: *dict*, polyline information in the map

  - `num_nodes`: *int*, number of polylines

  - `type`: *array([num_nodes])*, type of each polyline

    > ['VEHICLE', 'BIKE', 'BUS', 'PEDESTRIAN']

  - `light_type`: *array([num_nodes])*, traffic signal state of each polyline

    > ['LANE_STATE_STOP', 'LANE_STATE_GO', 'LANE_STATE_CAUTION', 'LANE_STATE_UNKNOWN']

+ `map_point`: *dict*, information on each point in the map

  - `num_nodes`: *int*, number of points

  - `position`: *array([num_nodes, 3])*, point coordinates (x, y, z)

  - `orientation`: *array([num_nodes])*, direction angle of the vector at each point (in radians)

  - `magnitude`: *array([num_nodes])*, magnitude of the vector at each point (i.e. the distance)

  - `height`: *array([num_nodes])*, height difference between each point and the next one *(only present when dim=3)*

  - `type`: *array([num_nodes])*, type of each point

    > [ 'DASH_SOLID_YELLOW', 'DASH_SOLID_WHITE', 'DASHED_WHITE', 'DASHED_YELLOW',
    > 'DOUBLE_SOLID_YELLOW', 'DOUBLE_SOLID_WHITE', 'DOUBLE_DASH_YELLOW', 'DOUBLE_DASH_WHITE',
    > 'SOLID_YELLOW', 'SOLID_WHITE', 'SOLID_DASH_WHITE', 'SOLID_DASH_YELLOW', 'EDGE',
    > 'NONE', 'UNKNOWN', 'CROSSWALK', 'CENTERLINE']

+ `'map_point', 'to', 'map_polygon'`: *dict*, index mapping between points and polylines

  - `edge_index`: *array([2, num_points])*, [point_idx, polygon_idx]

+ `'map_polygon', 'to', 'map_polygon'`: *dict*, topological relations between lanes

  - `edge_index`: *array([2, num_edges])*, [polygon1_idx, polygon2_idx]

  - `type`: *array([num_edges])*, type of topological relation: polygon1 is the predecessor/successor/left neighbor/right neighbor of polygon2

    > ['NONE', 'PRED', 'SUCC', 'LEFT', 'RIGHT']

+ `'pt_token', 'to', 'map_polygon'`: *dict*, index mapping between pt_token['token_idx'] and the original polylines

  - `edge_index`: *array([2, n_polyline])*, [token_idx, polygon_idx]
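
The `('pt_token', 'to', 'map_polygon')` edge index and the shape of `traj_mask` both follow from `pl_idx_list` and `side`. The sketch below rebuilds them in plain Python for a toy input. The function name `build_token_to_polygon` and the sample values are made up for illustration; the notebook does the same thing with torch tensors.

```python
def build_token_to_polygon(pl_idx_list, side):
    """Toy re-creation of edge_index, per-side token counts, and traj_mask."""
    # edge_index: row 0 = token index, row 1 = index of its original polyline
    edge_index = [list(range(len(pl_idx_list))), list(pl_idx_list)]
    # count_nums[k] = [left, right, center] token counts of the k-th unique polyline
    count_nums = []
    for pl in sorted(set(pl_idx_list)):
        sides = [s for p, s in zip(pl_idx_list, side) if p == pl]
        count_nums.append([sides.count(0), sides.count(1), sides.count(2)])
    # traj_mask[k][s][j] is True for j < count_nums[k][s], False otherwise
    max_token_num = max(max(row) for row in count_nums)
    traj_mask = [[[j < c for j in range(max_token_num)] for c in row]
                 for row in count_nums]
    return edge_index, count_nums, traj_mask

# Five split polylines: three from original polyline 0, two from polyline 1
pl_idx_list = [0, 0, 0, 1, 1]
side = [2, 2, 0, 2, 1]  # 0: left, 1: right, 2: center
edge_index, count_nums, traj_mask = build_token_to_polygon(pl_idx_list, side)
print(edge_index)  # [[0, 1, 2, 3, 4], [0, 0, 0, 1, 1]]
print(count_nums)  # [[1, 0, 2], [0, 1, 1]]
```

Here `max_token_num` is 2, so `traj_mask` has shape [2, 3, 2] and exactly five `True` entries, one per split polyline.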

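## 4.1 Loading the trajectory tokens

**agent_token_data**

Token information obtained by clustering: with [0,0] as the center in the first frame, each token is the end position reached after a 5-frame interval.

+ `token`: *dict*, (x, y) coordinates of the four corners of the rectangle at each token's end position, **corner order: front-left, front-right, rear-right, rear-left**
    - `veh`: *array([2048, 4, 2])*
    - `ped`: *array([2048, 4, 2])*
    - `cyc`: *array([2048, 4, 2])*
+ `traj`: *dict*, center coordinates (x, y, z) of the full 6-frame trajectory of each token
    - `veh`: *array([2048, 6, 3])*
    - `ped`: *array([2048, 6, 3])*
    - `cyc`: *array([2048, 6, 3])*
+ `token_all`: *dict*, four rectangle corners at every frame of the full 6-frame trajectory of each token
    - `veh`: *array([2048, 6, 4, 2])*
    - `ped`: *array([2048, 6, 4, 2])*
    - `cyc`: *array([2048, 6, 4, 2])*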
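
As an illustration of the corner order, the following sketch computes the four corners of a bounding rectangle from its center, heading, length, and width. It assumes the common convention that heading 0 points along +x and that "left" is the +y side. The function name `box_corners` is hypothetical, and this is not the clustering code that produced the token file.

```python
import math

def box_corners(cx, cy, heading, length, width):
    """Corners in the order front-left, front-right, rear-right, rear-left."""
    cos_h, sin_h = math.cos(heading), math.sin(heading)
    # corner offsets in the agent frame as (forward, left)
    offsets = [
        (length / 2, width / 2),    # front-left
        (length / 2, -width / 2),   # front-right
        (-length / 2, -width / 2),  # rear-right
        (-length / 2, width / 2),   # rear-left
    ]
    # rotate each offset by the heading and translate to the center
    return [(cx + fx * cos_h - fy * sin_h, cy + fx * sin_h + fy * cos_h)
            for fx, fy in offsets]

# A 4 m x 2 m box at the origin, facing +x
corners = box_corners(0.0, 0.0, 0.0, 4.0, 2.0)
print(corners)  # [(2.0, 1.0), (2.0, -1.0), (-2.0, -1.0), (-2.0, 1.0)]
```

An array of such corners per token has the same [4, 2] layout as one entry of `token['veh']`.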

In [14]:
agent_token_path = "/home/yangyh408/codes/SMART/smart/tokens/cluster_frame_5_2048.pkl"
with open(agent_token_path, 'rb') as f:  # close the file handle once the tokens are loaded
    agent_token_data = pickle.load(f)
trajectory_token = agent_token_data['token']
trajectory_token_traj = agent_token_data['traj']
trajectory_token_all = agent_token_data['token_all']

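## 4.2 Loading the map tokens

**map_token_data**

Map token information obtained by clustering: the coordinates of 11 consecutive points, with [0,0] as the initial position.

+ `traj_src`: *array([1024, 11, 2])*, polyline of each map token, i.e. the xy coordinates of its 11 points
+ `sample_pt`: *array([1024, 3, 2])*, subsampled polyline of each map token, keeping only the three points at indices [0, 5, 10]
+ `traj_end_theta`: *array([1024])*, heading of each map token at its last position, computed from traj_src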
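
The two derived fields can be illustrated on a single token in plain Python. This is a sketch with a made-up 11-point token; the helper names `evenly_spaced_indices` and `end_heading` are hypothetical stand-ins for the `torch.linspace(...).long()` and `np.arctan2` calls used in the cell below.

```python
import math

def evenly_spaced_indices(num_points, steps):
    """Integer indices from 0 to num_points - 1, evenly spaced (steps >= 2)."""
    last = num_points - 1
    return [int(i * last / (steps - 1)) for i in range(steps)]

def end_heading(points):
    """Heading of the last segment, from the second-to-last to the last point."""
    (x0, y0), (x1, y1) = points[-2], points[-1]
    return math.atan2(y1 - y0, x1 - x0)

# Made-up token: ten points along +x, then one final step in +y
token = [(0.5 * i, 0.0) for i in range(10)] + [(4.5, 0.5)]

indices = evenly_spaced_indices(len(token), 3)
sample_pt = [token[i] for i in indices]
print(indices)    # [0, 5, 10]
print(sample_pt)  # [(0.0, 0.0), (2.5, 0.0), (4.5, 0.5)]
print(end_heading(token) == math.pi / 2)  # True
```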

In [15]:
argmin_sample_len = 3

map_token_traj_path = "/home/yangyh408/codes/SMART/smart/tokens/map_traj_token5.pkl"
with open(map_token_traj_path, 'rb') as f:  # close the file handle once the tokens are loaded
    map_token_traj = pickle.load(f)

map_token = {'traj_src': map_token_traj['traj_src'], }
traj_end_theta = np.arctan2(map_token['traj_src'][:, -1, 1]-map_token['traj_src'][:, -2, 1],
                            map_token['traj_src'][:, -1, 0]-map_token['traj_src'][:, -2, 0])
# Generate `steps` evenly spaced values from start to end (here the indices 0, 5, 10).
indices = torch.linspace(0, map_token['traj_src'].shape[1]-1, steps=argmin_sample_len).long()
map_token['sample_pt'] = torch.from_numpy(map_token['traj_src'][:, indices]).to(torch.float)
map_token['traj_end_theta'] = torch.from_numpy(traj_end_theta).to(torch.float)
map_token['traj_src'] = torch.from_numpy(map_token['traj_src']).to(torch.float)

## 4.3 Matching the map tokens

In [16]:
def match_token_map(data, noise=False):
    # noise: if True, pick the token at random among the 8 nearest candidates instead of the single nearest one
    traj_pos = data['map_save']['traj_pos'].to(torch.float)
    traj_theta = data['map_save']['traj_theta'].to(torch.float)
    pl_idx_list = data['map_save']['pl_idx_list']
    token_sample_pt = map_token['sample_pt'].to(traj_pos.device)
    token_src = map_token['traj_src'].to(traj_pos.device)
    max_traj_len = map_token['traj_src'].shape[1]
    pl_num = traj_pos.shape[0]

    # xy coordinates of the start point of each map polyline
    pt_token_pos = traj_pos[:, 0, :].clone()
    # heading at the start of each map polyline
    pt_token_orientation = traj_theta.clone()
    # transform the map polylines from the global frame to the local frame
    cos, sin = traj_theta.cos(), traj_theta.sin()
    rot_mat = traj_theta.new_zeros(pl_num, 2, 2)
    rot_mat[..., 0, 0] = cos
    rot_mat[..., 0, 1] = -sin
    rot_mat[..., 1, 0] = sin
    rot_mat[..., 1, 1] = cos
    traj_pos_local = torch.bmm((traj_pos - traj_pos[:, 0:1]), rot_mat.view(-1, 2, 2))
    # match the transformed polylines against the map tokens (summed squared distance over the 3 sample points)
    distance = torch.sum((token_sample_pt[None] - traj_pos_local.unsqueeze(1))**2, dim=(-2, -1))
    pt_token_id = torch.argmin(distance, dim=1)

    if noise:
        topk_indices = torch.argsort(distance, dim=1)[:, :8]
        sample_topk = torch.randint(0, topk_indices.shape[-1], size=(topk_indices.shape[0], 1), device=topk_indices.device)
        pt_token_id = torch.gather(topk_indices, 1, sample_topk).squeeze(-1)

    cos, sin = traj_theta.cos(), traj_theta.sin()
    rot_mat = traj_theta.new_zeros(pl_num, 2, 2)
    rot_mat[..., 0, 0] = cos
    rot_mat[..., 0, 1] = sin
    rot_mat[..., 1, 0] = -sin
    rot_mat[..., 1, 1] = cos
    token_src_world = torch.bmm(token_src[None, ...].repeat(pl_num, 1, 1, 1).reshape(pl_num, -1, 2),
                                rot_mat.view(-1, 2, 2)).reshape(pl_num, token_src.shape[0], max_traj_len, 2) + traj_pos[:, None, [0], :]
    # pick the matched token of each polyline: [pl_num, max_traj_len, 2]
    token_src_world_select = token_src_world[torch.arange(pl_num, device=traj_pos.device), pt_token_id]

    pl_idx_full = pl_idx_list.clone()
    token2pl = torch.stack([torch.arange(len(pl_idx_list), device=traj_pos.device), pl_idx_full.long()])
    count_nums = []
    for pl in pl_idx_full.unique():
        pt = token2pl[0, token2pl[1, :] == pl]
        left_side = (data['pt_token']['side'][pt] == 0).sum()
        right_side = (data['pt_token']['side'][pt] == 1).sum()
        center_side = (data['pt_token']['side'][pt] == 2).sum()
        count_nums.append(torch.Tensor([left_side, right_side, center_side]))
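    # count_nums: [N_polyline, 3], number of left, right, and center tokens of each original polyline
    count_nums = torch.stack(count_nums, dim=0)
    # largest token count found on any side of any original polyline
    max_token_num = int(count_nums.max().item())
    # build the trajectory mask of the polylines [N_polyline, 3, max_token_num]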
    traj_mask = torch.zeros((int(len(pl_idx_full.unique())), 3, max_token_num), dtype=bool)
    idx_matrix = torch.arange(traj_mask.size(2)).unsqueeze(0).unsqueeze(0)
    idx_matrix = idx_matrix.expand(traj_mask.size(0), traj_mask.size(1), -1)    #[N_polyline, 3, max_token_num]
    counts_num_expanded = count_nums.unsqueeze(-1)                              #[N_polyline, 3, 1]
    traj_mask[idx_matrix < counts_num_expanded] = True

    data['pt_token']['traj_mask'] = traj_mask
    data['pt_token']['position'] = torch.cat([pt_token_pos, torch.zeros((data['pt_token']['num_nodes'], 1),
                                                                        device=traj_pos.device, dtype=torch.float)], dim=-1)
    data['pt_token']['orientation'] = pt_token_orientation
    data['pt_token']['height'] = data['pt_token']['position'][:, -1]
    data[('pt_token', 'to', 'map_polygon')] = {}
    data[('pt_token', 'to', 'map_polygon')]['edge_index'] = token2pl
    data['pt_token']['token_idx'] = pt_token_id
    return data

batch = match_token_map(batch)
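
The core of the matching step, a global-to-local transform followed by a nearest-template search, can be sketched in plain Python for one polyline. The three-entry vocabulary, the sample segment, and the helper names `to_local` and `nearest_token` are all made up for illustration; the real vocabulary is `map_token['sample_pt']` with 1024 entries.

```python
import math

def to_local(points, origin, theta):
    """Express global xy points in the frame at `origin` whose x axis points along `theta`."""
    cos_t, sin_t = math.cos(theta), math.sin(theta)
    ox, oy = origin
    return [((x - ox) * cos_t + (y - oy) * sin_t,
             -(x - ox) * sin_t + (y - oy) * cos_t) for x, y in points]

def nearest_token(local_points, vocabulary):
    """Index of the template with the smallest summed squared distance to `local_points`."""
    def cost(template):
        return sum((tx - px) ** 2 + (ty - py) ** 2
                   for (tx, ty), (px, py) in zip(template, local_points))
    return min(range(len(vocabulary)), key=lambda k: cost(vocabulary[k]))

# Made-up 3-point templates: straight, bending left, bending right
vocabulary = [
    [(0.0, 0.0), (2.5, 0.0), (5.0, 0.0)],
    [(0.0, 0.0), (2.5, 0.3), (5.0, 1.2)],
    [(0.0, 0.0), (2.5, -0.3), (5.0, -1.2)],
]
# A global segment that starts out along +y and drifts towards -x (to its left)
segment = [(10.0, 20.0), (9.8, 22.5), (9.0, 25.0)]
local = to_local(segment, segment[0], math.pi / 2)
token_id = nearest_token(local, vocabulary)
print(token_id)  # 1
```

In the local frame the segment becomes roughly (0, 0), (2.5, 0.2), (5, 1), which is closest to the left-bending template. With `noise=True` the function above samples among the 8 nearest templates instead of always taking the minimum.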

## 4.4 Randomly generating the map-token prediction masks

In [17]:
def sample_pt_pred(data):
    # traj_mask: [n_map_poly, 3, max_token_num]
    traj_mask = data['pt_token']['traj_mask']
    # randomly pick 1/3 of the traj entries of each original polyline to mask out (index 0 is never picked)
    raw_pt_index = torch.arange(1, traj_mask.shape[2]).repeat(traj_mask.shape[0], traj_mask.shape[1], 1)
    masked_pt_index = raw_pt_index.view(-1)[torch.randperm(raw_pt_index.numel())[:traj_mask.shape[0]*traj_mask.shape[1]*((traj_mask.shape[2]-1)//3)].reshape(traj_mask.shape[0], traj_mask.shape[1], (traj_mask.shape[2]-1)//3)]
    masked_pt_index = torch.sort(masked_pt_index, -1)[0]
    # valid mask
    pt_valid_mask = traj_mask.clone()
    pt_valid_mask.scatter_(2, masked_pt_index, False)
    # prediction mask
    pt_pred_mask = traj_mask.clone()
    pt_pred_mask.scatter_(2, masked_pt_index, False)
    tmp_mask = pt_pred_mask.clone()
    tmp_mask[:, :, :] = True
    tmp_mask.scatter_(2, masked_pt_index-1, False)
    pt_pred_mask.masked_fill_(tmp_mask, False)
    pt_pred_mask = pt_pred_mask * torch.roll(traj_mask, shifts=-1, dims=2)
    # target mask
    pt_target_mask = torch.roll(pt_pred_mask, shifts=1, dims=2)
    # use traj_mask to convert the masks from [n_map_poly, 3, max_token_num] to [n_polyline], so that they line up with the token information
    data['pt_token']['pt_valid_mask'] = pt_valid_mask[traj_mask]
    data['pt_token']['pt_pred_mask'] = pt_pred_mask[traj_mask]
    data['pt_token']['pt_target_mask'] = pt_target_mask[traj_mask]

    return data

batch = sample_pt_pred(batch)
batch['agent']['av_index'] += batch['agent']['ptr'][:-1]
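
The relationship between the three masks is easier to see on a single row. The sketch below restates the mask logic, as read from the tensor code above, for one side of one polyline in plain Python. The helper `masks_for_row` and the example inputs are hypothetical, and the set of masked indices is fixed here instead of being drawn at random.

```python
def masks_for_row(valid, masked):
    """valid: list of bools for one row of traj_mask; masked: set of hidden indices."""
    n = len(valid)
    # valid mask: existing tokens that were not hidden
    pt_valid = [v and j not in masked for j, v in enumerate(valid)]
    # prediction mask: a visible token whose successor exists and was hidden
    pt_pred = [pt_valid[j] and j + 1 < n and valid[j + 1] and (j + 1) in masked
               for j in range(n)]
    # target mask: the prediction mask shifted one position to the right
    pt_target = [False] + pt_pred[:-1]
    return pt_valid, pt_pred, pt_target

# A row with 5 real tokens padded to length 7; indices 2, 3 and 6 are hidden
valid = [True] * 5 + [False] * 2
pt_valid, pt_pred, pt_target = masks_for_row(valid, {2, 3, 6})
print(pt_valid)   # [True, True, False, False, True, False, False]
print(pt_pred)    # [False, True, False, False, False, False, False]
print(pt_target)  # [False, False, True, False, False, False, False]
```

Token 1 is asked to predict token 2. Token 3 is hidden as well, but its predecessor is also hidden, so it is not used as a target. Index 6 lies in the padding and has no effect.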

# 5 Running model inference on the batch

In [19]:
import os
import sys
sys.path.append(os.path.dirname(os.getcwd()))
torch.manual_seed(12)

from smart.model import SMART
from smart.utils.config import load_config_act
from smart.utils.log import Logging

config = load_config_act("../configs/validation/validation_scalable.yaml")
pretrain_ckpt = "../ckpt/20241021_1037/epoch=07-step=30440-val_loss=2.52.ckpt"
Predictor = SMART
logger = Logging().log(level='DEBUG')
model = Predictor(config.Model)
model.load_params_from_file(filename=pretrain_ckpt, logger=logger)
model.eval()

with torch.no_grad():
    pred = model.inference(batch)
    # pred = model(batch)

pred

2024-11-17 23:44:08,730-INFO-smart.py-Line:222-Message:==> Loading parameters from checkpoint ../ckpt/20241021_1037/epoch=07-step=30440-val_loss=2.52.ckpt to GPU
2024-11-17 23:44:09,189-INFO-smart.py-Line:231-Message:The number of disk ckpt keys: 818
2024-11-17 23:44:09,296-INFO-smart.py-Line:247-Message:Missing keys: []
2024-11-17 23:44:09,297-INFO-smart.py-Line:248-Message:The number of missing keys: 0
2024-11-17 23:44:09,298-INFO-smart.py-Line:249-Message:The number of unexpected keys: 0
2024-11-17 23:44:09,298-INFO-smart.py-Line:250-Message:==> Done (total keys 818)


{'x_pt': tensor([[-2.8628,  0.7293, -7.0975,  ...,  1.2902,  2.5426,  1.9984],
         [-2.6509,  0.5438, -7.2383,  ...,  1.6891,  2.8929,  1.8436],
         [-2.1915,  0.7175, -8.1312,  ...,  2.0025,  2.2056,  1.7226],
         ...,
         [ 4.0198, -2.7484, -5.1329,  ...,  4.4264,  3.8843,  2.2978],
         [ 3.1914, -2.2112, -6.0598,  ...,  4.0475,  4.3945,  3.8649],
         [ 5.0166, -2.0447, -5.4248,  ...,  3.6456,  4.0969,  3.7391]]),
 'map_next_token_idx': tensor([[218, 166, 444,  ..., 908, 630, 487],
         [218, 166, 444,  ..., 128, 362,  52],
         [933, 218, 603,  ..., 630, 555,  35],
         ...,
         [613, 871, 873,  ..., 311, 659,  25],
         [873, 954, 533,  ..., 659, 603, 123],
         [267, 633, 494,  ..., 867, 873, 365]]),
 'map_next_token_prob': tensor([[-8.8503e-02,  2.6148e-02,  2.8189e-01,  ...,  1.3059e-01,
          -7.5653e-02, -2.6913e-01],
         [-3.1502e-02,  3.2731e-02,  3.2199e-01,  ...,  7.6186e-02,
          -1.1064e-01, -2.7430e-01

In [22]:
from smart.datasets.scalable_dataset import MultiDataset
from smart.model import SMART
from smart.transforms import WaymoTargetBuilder
from smart.utils.config import load_config_act
from smart.utils.log import Logging
from smart.metrics.real_metrics.custom_metrics import RealMetrics
from smart.metrics.real_metrics.real_features import compute_real_metric_features

import matplotlib
matplotlib.use('TkAgg')
import matplotlib.pyplot as plt
from matplotlib.animation import FuncAnimation, PillowWriter

import os
import torch
import pickle
import numpy as np
from pathlib import Path
from tqdm import tqdm
from torch_geometric.loader import DataLoader
from waymo_open_dataset.utils.sim_agents import visualizations
from waymo_open_dataset.protos import scenario_pb2

config = load_config_act("../configs/testing/testing_scalable.yaml")

data_config = config.Dataset
test_dataset = {
    "scalable": MultiDataset,
}[data_config.dataset](root=data_config.root, split='val',
                        raw_dir=data_config.val_raw_dir,
                        processed_dir=data_config.val_processed_dir,
                        transform=WaymoTargetBuilder(config.Model.num_historical_steps, config.Model.decoder.num_future_steps))
dataloader = DataLoader(test_dataset, batch_size=1, shuffle=False, num_workers=4, pin_memory=False, persistent_workers=True)
    

2024-11-17 21:11:59,903-DEBUG-scalable_dataset.py-Line:38-Message:Starting loading dataset with MultiDataset
2024-11-17 21:12:18,186-DEBUG-scalable_dataset.py-Line:62-Message:The number of val dataset is 44911


In [24]:
data_iter = iter(dataloader)
batch = next(data_iter)
batch

HeteroDataBatch(
  scenario_id=[1],
  agent={
    num_nodes=14,
    av_index=[1],
    valid_mask=[14, 91],
    predict_mask=[14, 91],
    id=[1],
    type=[14],
    category=[14],
    position=[14, 91, 3],
    heading=[14, 91],
    velocity=[14, 91, 3],
    shape=[14, 91, 3],
    token_idx=[14, 18],
    token_contour=[14, 18, 4, 2],
    token_pos=[14, 18, 2],
    token_heading=[14, 18],
    agent_valid_mask=[14, 18],
    token_velocity=[14, 18, 2],
    batch=[14],
    ptr=[2],
  },
  map_polygon={
    num_nodes=121,
    type=[121],
    light_type=[121],
    batch=[121],
    ptr=[2],
  },
  map_point={
    num_nodes=16751,
    position=[16751, 3],
    orientation=[16751],
    magnitude=[16751],
    height=[16751],
    type=[16751],
    batch=[16751],
    ptr=[2],
  },
  map_save={
    traj_pos=[1679, 3, 2],
    traj_theta=[1679],
    pl_idx_list=[1679],
  },
  pt_token={
    type=[1679],
    side=[1679],
    pl_type=[1679],
    num_nodes=1679,
    batch=[1679],
    ptr=[2],
  },
  (map_

In [36]:
batch['agent']['token_idx']


tensor([[ 180,  255, 1401, 1401,   31,   31,   31,   31,   31,   31,   31,   31,
           31,   31,   31,   31,   31,   31],
        [  31,  443, 1401, 1401,   31,   31,   31,   31,   31,   31,   31,   31,
           31,   31,   31,   31,   31,   31],
        [  31,  443, 1401, 1401,   31,   31,   31,   31,   31,   31,   31,   31,
           31,   31,   31,   31,   31,   31],
        [  31,  443,  510,  510,   31,   31,   31,   31,   31,   31,   31,   31,
           31,   31,   31,   31,   31,   31],
        [  31,  443,  245,  245,   31,   31,   31,   31,   31,   31,   31,   31,
           31,   31,   31,   31,   31,   31],
        [  31,  443,  510, 1753,   31,   31,   31,   31,   31,   31,   31,   31,
           31,   31,   31,   31,   31,   31],
        [  31,  443,  245,  245,   31,   31,   31,   31,   31,   31,   31,   31,
           31,   31,   31,   31,   31,   31],
        [  31,  443,  510, 1833,   31,   31,   31,   31,   31,   31,   31,   31,
           31,   31,   31,   3

In [32]:
import os
import sys
sys.path.append(os.path.dirname(os.getcwd()))
torch.manual_seed(12)

from smart.model import SMART
from smart.utils.config import load_config_act
from smart.utils.log import Logging

config = load_config_act("../configs/validation/validation_scalable.yaml")
pretrain_ckpt = "../ckpt/20241021_1037/epoch=07-step=30440-val_loss=2.52.ckpt"
Predictor = SMART
logger = Logging().log(level='DEBUG')
model = Predictor(config.Model)
model.load_params_from_file(filename=pretrain_ckpt, logger=logger)
model.eval()

with torch.no_grad():
    data = model.match_token_map(batch)
    data = model.sample_pt_pred(data)
    data['agent']['av_index'] += data['agent']['ptr'][:-1]
    pred = model.inference(data)
    # pred = model(batch)

pred

2024-11-17 21:17:40,474-INFO-smart.py-Line:222-Message:==> Loading parameters from checkpoint ../ckpt/20241021_1037/epoch=07-step=30440-val_loss=2.52.ckpt to GPU
2024-11-17 21:17:40,932-INFO-smart.py-Line:231-Message:The number of disk ckpt keys: 818
2024-11-17 21:17:41,061-INFO-smart.py-Line:247-Message:Missing keys: []
2024-11-17 21:17:41,062-INFO-smart.py-Line:248-Message:The number of missing keys: 0
2024-11-17 21:17:41,062-INFO-smart.py-Line:249-Message:The number of unexpected keys: 0
2024-11-17 21:17:41,063-INFO-smart.py-Line:250-Message:==> Done (total keys 818)


{'x_pt': tensor([[ 4.0965,  1.5222, -3.1555,  ..., -1.6200,  4.8519, -5.5597],
         [ 1.8451, -1.2355, -2.8023,  ..., -4.7137,  5.5168, -4.8254],
         [ 0.8611, -2.9032, -1.4840,  ..., -4.2531,  5.7611, -0.6347],
         ...,
         [ 1.7210, -6.5852, -3.3734,  ..., -1.1363,  2.4527,  0.9055],
         [ 0.8461, -5.1582, -3.0413,  ..., -2.0396,  3.1396,  0.3555],
         [ 0.2282, -8.2308, -3.9181,  ..., -2.7247,  3.4893, -2.0633]]),
 'map_next_token_idx': tensor([[ 733,  648,  167,  ...,  885,  448,  166],
         [ 476,  167,  448,  ..., 1009,  961,  722],
         [ 476,  448,  933,  ...,  514,  166,  173],
         ...,
         [ 885,  613,   25,  ...,    6,  593,  883],
         [ 885,  864,  287,  ...,  873,  122,   30],
         [ 885,  613,  593,  ...,  196,  528,  194]]),
 'map_next_token_prob': tensor([[-0.1768, -0.1692, -0.1485,  ..., -0.1117, -0.3023,  0.2430],
         [-0.1941, -0.1418,  0.0420,  ..., -0.0900, -0.3504,  0.3512],
         [-0.1816, -0.0418,  