# 0 准备工作

```bash
# 复制smart环境
conda create --name smart-limsim --clone smart

# 安装SUMO
sudo add-apt-repository ppa:sumo/stable
sudo apt-get update
sudo apt-get install sumo sumo-tools sumo-doc

# 导入limsim需要的库
cd MARL-LimSim
pip install -r requirements.txt
```

smart是基于python3.8的版本，而limsim支持3.9-3.11，因此在运行时有一些报错需要额外处理：

1. TypeError: 'type' object is not subscriptable错误
    + 解决方法：在每一个报错路径的首行添加：`from __future__ import annotations`

# 1 地图信息提取

## 1.1 Waymo数据集中的地图信息

### 1.1.1 validation数据

In [57]:
import os
import sys
sys.path.append(os.path.dirname(os.getcwd()))

import torch
from torch_geometric.loader import DataLoader
from smart.datasets.scalable_dataset import MultiDataset
from smart.transforms import WaymoTargetBuilder
from smart.utils.config import load_config_act

config = load_config_act("../configs/validation/validation_scalable.yaml")
data_config = config.Dataset
val_dataset = {
    "scalable": MultiDataset,
}[data_config.dataset](root=data_config.root, split='val',
                        raw_dir=data_config.val_raw_dir,
                        processed_dir=data_config.val_processed_dir,
                        transform=WaymoTargetBuilder(config.Model.num_historical_steps, config.Model.decoder.num_future_steps))
dataloader = DataLoader(val_dataset, batch_size=data_config.batch_size, shuffle=False, num_workers=data_config.num_workers,
                        pin_memory=data_config.pin_memory, persistent_workers=True if data_config.num_workers > 0 else False)

# 修改scenario_idx以获取不同的场景batch信息
SCENARIO_IDX = 3
data_iter = iter(dataloader)
for i in range(SCENARIO_IDX):
    batch = next(data_iter)

print("-" * 100)
print(f"Scenario Index: {SCENARIO_IDX}")
print(f"Scenario ID: {batch.scenario_id[0]}")
print(f"不同多段线类型对应的多段线条数: {torch.bincount(batch['map_polygon']['type'], minlength=4).tolist()}")
# 0:'VEHICLE', 1:'BIKE', 2:'BUS', 3:'PEDESTRIAN'
print(f"不同信控信号对应的多段线条数: {torch.bincount(batch['map_polygon']['light_type'], minlength=4).tolist()}")
# 0:'LANE_STATE_STOP', 1:'LANE_STATE_GO', 2:'LANE_STATE_CAUTION', 3:'LANE_STATE_UNKNOWN'
print(f"不同的地图点类型对应的个数: {torch.bincount(batch['map_point']['type'], minlength=17).tolist()}")
# 0:'DASH_SOLID_YELLOW', 1:'DASH_SOLID_WHITE', 2:'DASHED_WHITE', 3:'DASHED_YELLOW', 4:'DOUBLE_SOLID_YELLOW', 5:'DOUBLE_SOLID_WHITE', 6:'DOUBLE_DASH_YELLOW', 7:'DOUBLE_DASH_WHITE',
# 8:'SOLID_YELLOW', 9:'SOLID_WHITE', 10:'SOLID_DASH_WHITE', 11:'SOLID_DASH_YELLOW', 12:'EDGE', 13:'NONE', 14:'UNKNOWN', 15:'CROSSWALK', 16:'CENTERLINE'

2024-11-20 12:14:19,080-DEBUG-scalable_dataset.py-Line:38-Message:Starting loading dataset with MultiDataset
2024-11-20 12:14:19,188-DEBUG-scalable_dataset.py-Line:62-Message:The number of val dataset is 44097


----------------------------------------------------------------------------------------------------
Scenario Index: 3
Scenario ID: 1001ebb6d3905d92
不同多段线类型对应的多段线条数: [685, 37, 2, 11]
不同信控信号对应的多段线条数: [3, 5, 0, 727]
不同的地图点类型对应的个数: [751, 0, 3333, 0, 640, 0, 0, 0, 23, 3770, 0, 0, 7319, 0, 0, 33, 24906]


### 1.1.2 testing数据

In [205]:
import os
import sys
sys.path.append(os.path.dirname(os.getcwd()))

import torch
from torch_geometric.loader import DataLoader
from smart.datasets.scalable_dataset import MultiDataset
from smart.transforms import WaymoTargetBuilder
from smart.utils.config import load_config_act

config = load_config_act("../configs/testing/testing_scalable.yaml")
data_config = config.Dataset
val_dataset = {
    "scalable": MultiDataset,
}[data_config.dataset](root=data_config.root, split='test',
                        raw_dir=data_config.val_raw_dir,
                        processed_dir=data_config.val_processed_dir,
                        transform=WaymoTargetBuilder(config.Model.num_historical_steps, config.Model.decoder.num_future_steps))
dataloader = DataLoader(val_dataset, batch_size=data_config.batch_size, shuffle=False, num_workers=data_config.num_workers,
                        pin_memory=data_config.pin_memory, persistent_workers=True if data_config.num_workers > 0 else False)

# 修改scenario_idx以获取不同的场景batch信息
SCENARIO_IDX = 3
data_iter = iter(dataloader)
for i in range(SCENARIO_IDX):
    batch = next(data_iter)

2024-11-21 16:14:32,044-DEBUG-scalable_dataset.py-Line:38-Message:Starting loading dataset with MultiDataset
2024-11-21 16:14:32,149-DEBUG-scalable_dataset.py-Line:62-Message:The number of test dataset is 44911


In [213]:
for k, v in batch['agent'].items():
    print(k)

num_nodes
av_index
valid_mask
predict_mask
id
type
category
position
heading
velocity
shape
token_idx
token_contour
token_pos
token_heading
agent_valid_mask
token_velocity
batch
ptr


In [243]:
batch['agent']['velocity'].shape

torch.Size([62, 91, 3])

## 1.2 加载limsim预处理出的地图信息

### 1.2.1 加载预处理后的limsim地图pkl文件

In [244]:
import torch
import os
import pickle

# waymo预处理的地图相关信息格式（仅供参考）
print("Waymo processed map info: ")
print("-" * 50)
with open(os.path.join(f"/mnt/i/smart_waymo_processed/raw/validation/1001ebb6d3905d92.pkl"), 'rb') as handle:
    raw_data = pickle.load(handle)
del raw_data['city']
del raw_data[('map_polygon', 'to', 'map_polygon')]
del raw_data['agent']
for k, v in raw_data.items():
    if isinstance(v, dict):
        print(f"{k}:")
        for kk, vv in v.items():
            if isinstance(vv, torch.Tensor):
                print(f"  {kk}: {vv.shape}")
            else:
                print(f"  {kk}: {vv}")
    else:
        print(f"{k}: {v}")

# 加载limsim预处理的地图信息
print("\nLimSim processed map info: ")
print("-" * 50)
with open(os.path.join(f"/home/yangyh408/codes/SMART/data/limsim/limsim_meta_inter.pkl"), 'rb') as handle:
    raw_data = pickle.load(handle)
for k, v in raw_data.items():
    if isinstance(v, dict):
        print(f"{k}:")
        for kk, vv in v.items():
            if isinstance(vv, torch.Tensor):
                print(f"  {kk}: {vv.shape}")
            else:
                print(f"  {kk}: {vv}")
    else:
        print(f"{k}: {v}")

Waymo processed map info: 
--------------------------------------------------
scenario_id: 1001ebb6d3905d92
map_polygon:
  num_nodes: 735
  type: torch.Size([735])
  light_type: torch.Size([735])
map_point:
  num_nodes: 40775
  position: torch.Size([40775, 3])
  orientation: torch.Size([40775])
  magnitude: torch.Size([40775])
  height: torch.Size([40775])
  type: torch.Size([40775])
('map_point', 'to', 'map_polygon'):
  edge_index: torch.Size([2, 40775])

LimSim processed map info: 
--------------------------------------------------
scenario_id: limsim_meta_inter
map_polygon:
  num_nodes: 75
  type: torch.Size([75])
  light_type: torch.Size([75])
map_point:
  num_nodes: 3675
  position: torch.Size([3675, 3])
  orientation: torch.Size([3675])
  magnitude: torch.Size([3675])
  height: torch.Size([3675])
  type: torch.Size([3675])
('map_point', 'to', 'map_polygon'):
  edge_index: torch.Size([2, 3675])
agent:
  num_nodes: 5
  av_index: None
  valid_mask: torch.Size([5, 91])
  predict_mask

### 1.2.2 将地图多段线按地图token的形式进行切分

每条切分后的多段线的长度为5m，按0.5m的间距划分为11个点，仅采样[0, 5, 10]三个索引处的点作为新多段线，新多段线信息保存在`['map_save']`中用于后续地图token匹配

In [245]:
import math
import numpy as np
from scipy.interpolate import interp1d
from scipy.spatial.distance import euclidean

def wrap_angle(
        angle: torch.Tensor,
        min_val: float = -math.pi,
        max_val: float = math.pi) -> torch.Tensor:
    return min_val + (angle + max_val) % (max_val - min_val)

def interplating_polyline(polylines, heading, distance=0.5, split_distace=5):
    # 多段线切分长度为5米，多段线内部点之间距离为2.5米，即每条多段线由3个点构成
    # Calculate the cumulative distance along the path, up-sample the polyline to 0.5 meter
    dist_along_path_list = [[0]]
    polylines_list = [[polylines[0]]]
    for i in range(1, polylines.shape[0]):
        euclidean_dist = euclidean(polylines[i, :2], polylines[i - 1, :2])
        heading_diff = min(abs(max(heading[i], heading[i - 1]) - min(heading[1], heading[i - 1])),
                           abs(max(heading[i], heading[i - 1]) - min(heading[1], heading[i - 1]) + math.pi))
        if heading_diff > math.pi / 4 and euclidean_dist > 3:
            dist_along_path_list.append([0])
            polylines_list.append([polylines[i]])
        elif heading_diff > math.pi / 8 and euclidean_dist > 3:
            dist_along_path_list.append([0])
            polylines_list.append([polylines[i]])
        elif heading_diff > 0.1 and euclidean_dist > 3:
            dist_along_path_list.append([0])
            polylines_list.append([polylines[i]])
        elif euclidean_dist > 10:
            dist_along_path_list.append([0])
            polylines_list.append([polylines[i]])
        else:
            dist_along_path_list[-1].append(dist_along_path_list[-1][-1] + euclidean_dist)
            polylines_list[-1].append(polylines[i])
    # plt.plot(polylines[:, 0], polylines[:, 1])
    # plt.savefig('tmp.jpg')
    new_x_list = []
    new_y_list = []
    multi_polylines_list = []
    for idx in range(len(dist_along_path_list)):
        if len(dist_along_path_list[idx]) < 2:
            continue
        dist_along_path = np.array(dist_along_path_list[idx])
        polylines_cur = np.array(polylines_list[idx])
        # Create interpolation functions for x and y coordinates
        fx = interp1d(dist_along_path, polylines_cur[:, 0])
        fy = interp1d(dist_along_path, polylines_cur[:, 1])
        # fyaw = interp1d(dist_along_path, heading)

        # Create an array of distances at which to interpolate
        new_dist_along_path = np.arange(0, dist_along_path[-1], distance)
        new_dist_along_path = np.concatenate([new_dist_along_path, dist_along_path[[-1]]])
        # Use the interpolation functions to generate new x and y coordinates
        new_x = fx(new_dist_along_path)
        new_y = fy(new_dist_along_path)
        # new_yaw = fyaw(new_dist_along_path)
        new_x_list.append(new_x)
        new_y_list.append(new_y)

        # Combine the new x and y coordinates into a single array
        new_polylines = np.vstack((new_x, new_y)).T
        polyline_size = int(split_distace / distance)
        if new_polylines.shape[0] >= (polyline_size + 1):
            padding_size = (new_polylines.shape[0] - (polyline_size + 1)) % polyline_size
            final_index = (new_polylines.shape[0] - (polyline_size + 1)) // polyline_size + 1
        else:
            padding_size = new_polylines.shape[0]
            final_index = 0
        multi_polylines = None
        new_polylines = torch.from_numpy(new_polylines)
        new_heading = torch.atan2(new_polylines[1:, 1] - new_polylines[:-1, 1],
                                  new_polylines[1:, 0] - new_polylines[:-1, 0])
        new_heading = torch.cat([new_heading, new_heading[-1:]], -1)[..., None]
        new_polylines = torch.cat([new_polylines, new_heading], -1)
        if new_polylines.shape[0] >= (polyline_size + 1):
            multi_polylines = new_polylines.unfold(dimension=0, size=polyline_size + 1, step=polyline_size)
            multi_polylines = multi_polylines.transpose(1, 2)
            multi_polylines = multi_polylines[:, ::5, :]
        if padding_size >= 3:
            last_polyline = new_polylines[final_index * polyline_size:]
            last_polyline = last_polyline[torch.linspace(0, last_polyline.shape[0] - 1, steps=3).long()]
            if multi_polylines is not None:
                multi_polylines = torch.cat([multi_polylines, last_polyline.unsqueeze(0)], dim=0)
            else:
                multi_polylines = last_polyline.unsqueeze(0)
        if multi_polylines is None:
            continue
        multi_polylines_list.append(multi_polylines)
    if len(multi_polylines_list) > 0:
        multi_polylines_list = torch.cat(multi_polylines_list, dim=0)
    else:
        multi_polylines_list = None
    return multi_polylines_list

def tokenize_map(data):
    data['map_polygon']['type'] = data['map_polygon']['type'].to(torch.uint8)
    data['map_point']['type'] = data['map_point']['type'].to(torch.uint8)
    pt2pl = data[('map_point', 'to', 'map_polygon')]['edge_index']
    pt_type = data['map_point']['type'].to(torch.uint8)
    pt_side = torch.zeros_like(pt_type)
    pt_pos = data['map_point']['position'][:, :2]
    data['map_point']['orientation'] = wrap_angle(data['map_point']['orientation'])
    pt_heading = data['map_point']['orientation']
    split_polyline_type = []
    split_polyline_pos = []
    split_polyline_theta = []
    split_polyline_side = []
    pl_idx_list = []
    split_polygon_type = []
    data['map_point']['type'].unique()

    # 对多段线进行便利
    for i in sorted(np.unique(pt2pl[1])):
        # 每一条多段线对应的点
        index = pt2pl[0, pt2pl[1] == i]
        polygon_type = data['map_polygon']["type"][i]
        cur_side = pt_side[index]
        cur_type = pt_type[index]
        cur_pos = pt_pos[index]
        cur_heading = pt_heading[index]

        for side_val in np.unique(cur_side):
            for type_val in np.unique(cur_type):
                if type_val == 13:
                    continue
                indices = np.where((cur_side == side_val) & (cur_type == type_val))[0]
                if len(indices) <= 2:
                    continue
                split_polyline = interplating_polyline(cur_pos[indices].numpy(), cur_heading[indices].numpy())
                if split_polyline is None:
                    continue
                new_cur_type = cur_type[indices][0]
                new_cur_side = cur_side[indices][0]
                map_polygon_type = polygon_type.repeat(split_polyline.shape[0])
                new_cur_type = new_cur_type.repeat(split_polyline.shape[0])
                new_cur_side = new_cur_side.repeat(split_polyline.shape[0])
                cur_pl_idx = torch.Tensor([i])
                new_cur_pl_idx = cur_pl_idx.repeat(split_polyline.shape[0])
                split_polyline_pos.append(split_polyline[..., :2])
                split_polyline_theta.append(split_polyline[..., 2])
                split_polyline_type.append(new_cur_type)
                split_polyline_side.append(new_cur_side)
                pl_idx_list.append(new_cur_pl_idx)
                split_polygon_type.append(map_polygon_type)

    split_polyline_pos = torch.cat(split_polyline_pos, dim=0)
    split_polyline_theta = torch.cat(split_polyline_theta, dim=0)
    split_polyline_type = torch.cat(split_polyline_type, dim=0)
    split_polyline_side = torch.cat(split_polyline_side, dim=0)
    split_polygon_type = torch.cat(split_polygon_type, dim=0)
    pl_idx_list = torch.cat(pl_idx_list, dim=0)
    vec = split_polyline_pos[:, 1, :] - split_polyline_pos[:, 0, :]
    data['map_save'] = {}
    data['pt_token'] = {}
    data['map_save']['traj_pos'] = split_polyline_pos
    data['map_save']['traj_theta'] = split_polyline_theta[:, 0]  # torch.arctan2(vec[:, 1], vec[:, 0])
    data['map_save']['pl_idx_list'] = pl_idx_list
    data['pt_token']['type'] = split_polyline_type
    data['pt_token']['side'] = split_polyline_side
    data['pt_token']['pl_type'] = split_polygon_type
    data['pt_token']['num_nodes'] = split_polyline_pos.shape[0]
    return data

# 将polygon按最长5米进行
data = tokenize_map(raw_data)

### 1.2.3 将上述字典类型data加载为HeteroData类型

In [246]:
from torch_geometric.data import Dataset, HeteroData
from torch_geometric.loader import DataLoader

class CustomHeteroDataset(Dataset):
    def __init__(self, data_list):
        super(CustomHeteroDataset, self).__init__()
        self.data_list = data_list
    
    def len(self):
        return len(self.data_list)
    
    def get(self, idx):
        batch_data = HeteroData()

        for node_type, node_data in self.data_list[idx].items():
            if isinstance(node_type, str):  # 处理节点数据
                if isinstance(node_data, dict):
                    for attr, value in node_data.items():
                        batch_data[node_type][attr] = value
                else:
                    batch_data[node_type] = [node_data]

        for edge_type, edge_data in self.data_list[idx].items():
            if isinstance(edge_type, tuple) and len(edge_type) == 3:  # 处理边数据
                if isinstance(edge_data, dict):
                    for attr, value in edge_data.items():
                        batch_data[edge_type][attr] = value
                else:
                    batch_data[edge_type] = edge_data
        return batch_data

dataset = CustomHeteroDataset([data])
loader = DataLoader(dataset, batch_size=1)
batch = next(iter(loader))

### 1.2.4 进行地图token匹配

In [247]:
noise = False
argmin_sample_len = 3

map_token_traj_path = "/home/yangyh408/codes/SMART/smart/tokens/map_traj_token5.pkl"
map_token_traj = pickle.load(open(map_token_traj_path, 'rb'))

map_token = {'traj_src': map_token_traj['traj_src'], }
traj_end_theta = np.arctan2(map_token['traj_src'][:, -1, 1]-map_token['traj_src'][:, -2, 1],
                            map_token['traj_src'][:, -1, 0]-map_token['traj_src'][:, -2, 0])
# 生成从 start 到 end 的 steps 个等间隔值。
indices = torch.linspace(0, map_token['traj_src'].shape[1]-1, steps=argmin_sample_len).long()
map_token['sample_pt'] = torch.from_numpy(map_token['traj_src'][:, indices]).to(torch.float)
map_token['traj_end_theta'] = torch.from_numpy(traj_end_theta).to(torch.float)
map_token['traj_src'] = torch.from_numpy(map_token['traj_src']).to(torch.float)

def match_token_map(data):
    traj_pos = data['map_save']['traj_pos'].to(torch.float)
    traj_theta = data['map_save']['traj_theta'].to(torch.float)
    pl_idx_list = data['map_save']['pl_idx_list']
    token_sample_pt = map_token['sample_pt'].to(traj_pos.device)
    token_src = map_token['traj_src'].to(traj_pos.device)
    max_traj_len = map_token['traj_src'].shape[1]
    pl_num = traj_pos.shape[0]

    # 各地图多段线的起始点坐标xy
    pt_token_pos = traj_pos[:, 0, :].clone()
    # 各地图多段线的起始位置朝向
    pt_token_orientation = traj_theta.clone()
    # 将地图多段线由全局坐标系转换为局部坐标系
    cos, sin = traj_theta.cos(), traj_theta.sin()
    rot_mat = traj_theta.new_zeros(pl_num, 2, 2)
    rot_mat[..., 0, 0] = cos
    rot_mat[..., 0, 1] = -sin
    rot_mat[..., 1, 0] = sin
    rot_mat[..., 1, 1] = cos
    traj_pos_local = torch.bmm((traj_pos - traj_pos[:, 0:1]), rot_mat.view(-1, 2, 2))
    # 将坐标转换后的多段线与地图map_token进行匹配
    distance = torch.sum((token_sample_pt[None] - traj_pos_local.unsqueeze(1))**2, dim=(-2, -1))
    pt_token_id = torch.argmin(distance, dim=1)

    if noise:
        topk_indices = torch.argsort(torch.sum((token_sample_pt[None] - traj_pos_local.unsqueeze(1))**2, dim=(-2, -1)), dim=1)[:, :8]
        sample_topk = torch.randint(0, topk_indices.shape[-1], size=(topk_indices.shape[0], 1), device=topk_indices.device)
        pt_token_id = torch.gather(topk_indices, 1, sample_topk).squeeze(-1)

    cos, sin = traj_theta.cos(), traj_theta.sin()
    rot_mat = traj_theta.new_zeros(pl_num, 2, 2)
    rot_mat[..., 0, 0] = cos
    rot_mat[..., 0, 1] = sin
    rot_mat[..., 1, 0] = -sin
    rot_mat[..., 1, 1] = cos
    token_src_world = torch.bmm(token_src[None, ...].repeat(pl_num, 1, 1, 1).reshape(pl_num, -1, 2),
                                rot_mat.view(-1, 2, 2)).reshape(pl_num, token_src.shape[0], max_traj_len, 2) + traj_pos[:, None, [0], :]
    token_src_world_select = token_src_world.view(-1, 1024, 11, 2)[torch.arange(pt_token_id.view(-1).shape[0]), pt_token_id.view(-1)].view(pl_num, max_traj_len, 2)

    pl_idx_full = pl_idx_list.clone()
    token2pl = torch.stack([torch.arange(len(pl_idx_list), device=traj_pos.device), pl_idx_full.long()])
    count_nums = []
    for pl in pl_idx_full.unique():
        pt = token2pl[0, token2pl[1, :] == pl]
        left_side = (data['pt_token']['side'][pt] == 0).sum()
        right_side = (data['pt_token']['side'][pt] == 1).sum()
        center_side = (data['pt_token']['side'][pt] == 2).sum()
        count_nums.append(torch.Tensor([left_side, right_side, center_side]))
    # count_nums: [N_polyline, 3]分别记录每个原始多段线对应的左侧、右侧、中心token有多少
    count_nums = torch.stack(count_nums, dim=0)
    # 获取每个原始多段线对应的最多token数量
    max_token_num = int(count_nums.max().item())
    # 构建多段线的轨迹掩码 [N_polyline, 3, max_token_num]
    traj_mask = torch.zeros((int(len(pl_idx_full.unique())), 3, max_token_num), dtype=bool)
    idx_matrix = torch.arange(traj_mask.size(2)).unsqueeze(0).unsqueeze(0)
    idx_matrix = idx_matrix.expand(traj_mask.size(0), traj_mask.size(1), -1)    #[N_polyline, 3, max_token_num]
    counts_num_expanded = count_nums.unsqueeze(-1)                              #[N_polyline, 3, 1]
    traj_mask[idx_matrix < counts_num_expanded] = True

    data['pt_token']['traj_mask'] = traj_mask
    data['pt_token']['position'] = torch.cat([pt_token_pos, torch.zeros((data['pt_token']['num_nodes'], 1),
                                                                        device=traj_pos.device, dtype=torch.float)], dim=-1)
    data['pt_token']['orientation'] = pt_token_orientation
    data['pt_token']['height'] = data['pt_token']['position'][:, -1]
    data[('pt_token', 'to', 'map_polygon')] = {}
    data[('pt_token', 'to', 'map_polygon')]['edge_index'] = token2pl
    data['pt_token']['token_idx'] = pt_token_id
    return data

batch = match_token_map(batch)

### 1.2.5 随机生成地图token预测掩码信息

地图token预测仅在训练阶段用到，在推理时仅使用对地图进行编码的上下文向量

In [248]:
def sample_pt_pred(data):
    # traj_mask: [n_map_poly, 3, max_token_num]
    traj_mask = data['pt_token']['traj_mask']
    # 从每个原始多段线中随机选取1/3的traj值被掩码掉
    raw_pt_index = torch.arange(1, traj_mask.shape[2]).repeat(traj_mask.shape[0], traj_mask.shape[1], 1)
    masked_pt_index = raw_pt_index.view(-1)[torch.randperm(raw_pt_index.numel())[:traj_mask.shape[0]*traj_mask.shape[1]*((traj_mask.shape[2]-1)//3)].reshape(traj_mask.shape[0], traj_mask.shape[1], (traj_mask.shape[2]-1)//3)]
    masked_pt_index = torch.sort(masked_pt_index, -1)[0]
    # 有效掩码
    pt_valid_mask = traj_mask.clone()
    pt_valid_mask.scatter_(2, masked_pt_index, False)
    # 预测掩码
    pt_pred_mask = traj_mask.clone()
    pt_pred_mask.scatter_(2, masked_pt_index, False)
    tmp_mask = pt_pred_mask.clone()
    tmp_mask[:, :, :] = True
    tmp_mask.scatter_(2, masked_pt_index-1, False)
    pt_pred_mask.masked_fill_(tmp_mask, False)
    pt_pred_mask = pt_pred_mask * torch.roll(traj_mask, shifts=-1, dims=2)
    # 目标掩码
    pt_target_mask = torch.roll(pt_pred_mask, shifts=1, dims=2)
    # 通过traj_mask将生成的掩码向量从[n_map_poly, 3, max_token_num]转换为[n_polyline]的形式，使其与token信息对应
    data['pt_token']['pt_valid_mask'] = pt_valid_mask[traj_mask]
    data['pt_token']['pt_pred_mask'] = pt_pred_mask[traj_mask]
    data['pt_token']['pt_target_mask'] = pt_target_mask[traj_mask]

    return data

batch = sample_pt_pred(batch)

### 1.2.6 加载模型验证MapDecoder能否进行推理

In [249]:
import os
import sys
sys.path.append(os.path.dirname(os.getcwd()))

from smart.model import SMART
from smart.utils.log import Logging
from smart.utils.config import load_config_act

config = load_config_act("../configs/validation/validation_scalable.yaml")

pretrain_ckpt = "../ckpt/20241021_1037/epoch=07-step=30440-val_loss=2.52.ckpt"
Predictor = SMART
logger = Logging().log(level='DEBUG')
model = Predictor(config.Model)
model.load_params_from_file(filename=pretrain_ckpt, logger=logger)

model.eval()
with torch.no_grad():
    map_enc = model.encoder.map_encoder(batch)

map_enc


2024-11-21 20:18:00,468-INFO-smart.py-Line:222-Message:==> Loading parameters from checkpoint ../ckpt/20241021_1037/epoch=07-step=30440-val_loss=2.52.ckpt to GPU
2024-11-21 20:18:01,667-INFO-smart.py-Line:231-Message:The number of disk ckpt keys: 818
2024-11-21 20:18:01,752-INFO-smart.py-Line:247-Message:Missing keys: []
2024-11-21 20:18:01,753-INFO-smart.py-Line:248-Message:The number of missing keys: 0
2024-11-21 20:18:01,754-INFO-smart.py-Line:249-Message:The number of unexpected keys: 0
2024-11-21 20:18:01,754-INFO-smart.py-Line:250-Message:==> Done (total keys 818)


{'x_pt': tensor([[-2.0293,  1.2010, -6.2030,  ..., -1.1212,  9.1886, -0.7540],
         [-1.1812, -0.3491, -5.7021,  ..., -0.2255,  7.4776, -0.3044],
         [ 2.3393, -1.1243, -6.3258,  ..., -0.6223,  4.6040,  2.0616],
         ...,
         [ 4.8849, -6.9743, -5.0598,  ..., -1.0279,  2.5471, -2.3432],
         [ 5.0264, -6.9748, -5.1602,  ..., -2.4884,  1.4778, -3.5238],
         [ 4.6786, -7.1148, -4.2925,  ..., -3.0385,  4.2767, -3.4176]]),
 'map_next_token_idx': tensor([[  35,  630, 1011,  ...,  128,  836,  829],
         [  35, 1011,  128,  ...,  570,  218,  829],
         [  35, 1011,  128,  ...,  218,  630,  378],
         ...,
         [ 476,  885, 1009,  ...,  555,  196,  883],
         [ 885,  555,  476,  ...,  613, 1014,  873],
         [ 885,  555,  289,  ...,   30,  873,  913]]),
 'map_next_token_prob': tensor([[-0.2374, -0.1808,  0.0048,  ..., -0.0427,  0.0704,  0.1785],
         [-0.1539, -0.1308,  0.1705,  ..., -0.0255, -0.0135,  0.0489],
         [-0.1327, -0.1211,  

## 1.3 可视化验证batch加载的地图信息

> 需要先通过1.1或1.2加载包含完整地图信息的batch

### 1.3.1 根据原始map_polygon和map_point进行绘制

可以通过FILTER_TYPE指定突出绘制的地图线型

In [86]:
import matplotlib
matplotlib.use('TkAgg')
import matplotlib.pyplot as plt
import matplotlib.colors as mcolors
import matplotlib.gridspec as gridspec
import torch

# 突出显示的线形
FILTER_TYPE = [16]
FILTER_TYPE = []

# 创建绘图和信息区域
fig = plt.figure(figsize=(18, 20))
gs = gridspec.GridSpec(2, 1, height_ratios=[1, 9])  # 上下两部分，比例为4:1

# 绘图区域
ax = fig.add_subplot(gs[1])

colors = plt.cm.tab20.colors
num_colors = len(colors)

# 准备数据
polylines = []
polyline_type = []
for i in range(batch['map_polygon']['num_nodes']):
    point_idx = batch[('map_point', 'to', 'map_polygon')]['edge_index'][0, batch[('map_point', 'to', 'map_polygon')]['edge_index'][1] == i]
    polylines.append(torch.gather(batch['map_point']['position'][:, :2], dim=0, index=point_idx[..., None].repeat(1, 2)))
    polyline_type.append(batch['map_point']['type'][point_idx[0]])

# 存储绘制的线段和其对应的数据
lines = []
line_data = []

# 绘制每条线段
for idx, (type, data) in enumerate(zip(polyline_type, polylines)):
    x = data[:, 0].numpy()
    y = data[:, 1].numpy()
    if len(FILTER_TYPE) == 0 or type in FILTER_TYPE:
        line, = ax.plot(x, y, marker='', linestyle='-', linewidth=2, color=colors[type], picker=5)  # 启用 picker
    else:
        line, = ax.plot(x, y, marker='', linestyle='-', linewidth=2, color='#F0F0F0', alpha=1, picker=5)
    lines.append(line)
    line_data.append((idx, type, x, y))

# 设置轴比例相同
ax.set_aspect('equal')

# 添加颜色条
sm = plt.cm.ScalarMappable(cmap=plt.cm.tab20, norm=mcolors.Normalize(vmin=0, vmax=num_colors - 1))
sm.set_array([])  # 必须设置 array 以显示颜色条
cbar = plt.colorbar(sm, ax=ax, fraction=0.046, pad=0.04)
cbar.set_label('Color Index')
cbar.set_ticks(range(num_colors))
cbar.set_ticklabels(range(num_colors))

# 设置图例、标题和网格
ax.set_xlabel('X-coordinate')
ax.set_ylabel('Y-coordinate')
ax.set_title('Multiple Line Segments')
ax.grid(True)

# 信息显示区域
info_ax = fig.add_subplot(gs[0])
info_ax.axis('off')  # 关闭坐标轴

# 初始化显示信息
info_text = info_ax.text(0.01, 0.5, "Selected Line Info: None", fontsize=12, verticalalignment='center')

# 用于跟踪当前高亮的线段
highlighted_line = None
highlighted_prev_color = None

# 鼠标点击事件
def on_pick(event):
    global highlighted_line
    global highlighted_prev_color

    # 获取被选中的线段
    line = event.artist
    idx = lines.index(line)
    idx, pl_type, x, y = line_data[idx]

    # 如果有高亮的线段，先恢复默认样式
    if highlighted_line is not None:
        highlighted_line.set_linewidth(2)
        highlighted_line.set_color(highlighted_prev_color)
    # 更新高亮的线段
    highlighted_line = line
    highlighted_prev_color = colors[pl_type] if len(FILTER_TYPE) == 0 or pl_type in FILTER_TYPE else '#F0F0F0'
    line.set_linewidth(4)
    line.set_color('red')

    _point_types = ['DASH_SOLID_YELLOW', 'DASH_SOLID_WHITE', 'DASHED_WHITE', 'DASHED_YELLOW',
                'DOUBLE_SOLID_YELLOW', 'DOUBLE_SOLID_WHITE', 'DOUBLE_DASH_YELLOW', 'DOUBLE_DASH_WHITE',
                'SOLID_YELLOW', 'SOLID_WHITE', 'SOLID_DASH_WHITE', 'SOLID_DASH_YELLOW', 'EDGE',
                'NONE', 'UNKNOWN', 'CROSSWALK', 'CENTERLINE']
    
    # 更新信息显示区域
    info = (f"Selected Line Index: {idx}\n"
            f"Selected Line Type Index: {pl_type}\n"
            f"Selected Line Type: {_point_types[pl_type]}\n")
            # f"Coordinates: {list(zip(x, y))[:5]}... (truncated)")
    info_text.set_text(info)

    # 刷新图像
    fig.canvas.draw_idle()

# 连接事件
fig.canvas.mpl_connect('pick_event', on_pick)

# 显示图像
plt.show()


### 1.3.2 绘制信控信息

In [78]:
import matplotlib
matplotlib.use('TkAgg')
import matplotlib.pyplot as plt
import matplotlib.colors as mcolors
import torch

# 0:'DASH_SOLID_YELLOW', 1:'DASH_SOLID_WHITE', 2:'DASHED_WHITE', 3:'DASHED_YELLOW', 4:'DOUBLE_SOLID_YELLOW', 5:'DOUBLE_SOLID_WHITE', 6:'DOUBLE_DASH_YELLOW', 7:'DOUBLE_DASH_WHITE',
# 8:'SOLID_YELLOW', 9:'SOLID_WHITE', 10:'SOLID_DASH_WHITE', 11:'SOLID_DASH_YELLOW', 12:'EDGE', 13:'NONE', 14:'UNKNOWN', 15:'CROSSWALK', 16:'CENTERLINE'
FILTER_TYPE = [16]
FILTER_TYPE = []

# 创建绘图，包含 Axes 对象
fig, ax = plt.subplots(figsize=(20, 20))

colors = ['red', 'green', 'yellow', 'lightgray']

# 准备数据
polylines = []
polyline_type = []
for i in range(batch['map_polygon']['num_nodes']):
    point_idx = batch[('map_point', 'to', 'map_polygon')]['edge_index'][0, batch[('map_point', 'to', 'map_polygon')]['edge_index'][1] == i]
    polylines.append(torch.gather(batch['map_point']['position'][:, :2], dim=0, index=point_idx[..., None].repeat(1, 2)))
    polyline_type.append(batch['map_point']['type'][point_idx[0]])

# 绘制每条线段
for idx, (type, data) in enumerate(zip(polyline_type, polylines)):
    x = data[:, 0].numpy()
    y = data[:, 1].numpy()
    ax.plot(x, y, marker='', linestyle='-', linewidth=2, color=colors[batch['map_polygon']['light_type'][idx]])

# 设置轴比例相同
ax.set_aspect('equal')

# 设置图例、标题和网格
ax.set_xlabel('X-coordinate')
ax.set_ylabel('Y-coordinate')
ax.set_title('Multiple Line Segments')
ax.grid(True)

# 显示图像
plt.show()

### 1.3.3 根据分段采样后的map_save数据绘制地图

In [85]:
import matplotlib
matplotlib.use('TkAgg')
import matplotlib.pyplot as plt
import matplotlib.colors as mcolors
import matplotlib.gridspec as gridspec
import torch

# 示例数据
FILTER_TYPE = [16]
FILTER_TYPE = []

# 创建绘图和信息区域
fig = plt.figure(figsize=(18, 20))
gs = gridspec.GridSpec(2, 1, height_ratios=[1, 9])  # 上下两部分，比例为4:1

# 绘图区域
ax = fig.add_subplot(gs[1])

colors = plt.cm.tab20.colors
num_colors = len(colors)

polyline_type = []
for i in range(batch['map_polygon']['num_nodes']):
    point_idx = batch[('map_point', 'to', 'map_polygon')]['edge_index'][0, batch[('map_point', 'to', 'map_polygon')]['edge_index'][1] == i]
    polyline_type.append(batch['map_point']['type'][point_idx[0]])

# 存储绘制的线段和其对应的数据
lines = []
line_data = []

# 绘制每条线段
for idx, (pl_idx, data) in enumerate(zip(batch['map_save']['pl_idx_list'], batch['map_save']['traj_pos'])):
    x = data[:, 0].numpy()
    y = data[:, 1].numpy()
    type = polyline_type[int(pl_idx)]
    if len(FILTER_TYPE) == 0 or type in FILTER_TYPE:
        line, = ax.plot(x, y, marker='', linestyle='-', linewidth=2, color=colors[type], picker=5)  # 启用 picker
    else:
        line, = ax.plot(x, y, marker='', linestyle='-', linewidth=2, color='#F0F0F0', alpha=1, picker=5)
    lines.append(line)
    line_data.append((idx, type, x, y))

# 设置轴比例相同
ax.set_aspect('equal')

# 添加颜色条
sm = plt.cm.ScalarMappable(cmap=plt.cm.tab20, norm=mcolors.Normalize(vmin=0, vmax=num_colors - 1))
sm.set_array([])  # 必须设置 array 以显示颜色条
cbar = plt.colorbar(sm, ax=ax, fraction=0.046, pad=0.04)
cbar.set_label('Color Index')
cbar.set_ticks(range(num_colors))
cbar.set_ticklabels(range(num_colors))

# 设置图例、标题和网格
ax.set_xlabel('X-coordinate')
ax.set_ylabel('Y-coordinate')
ax.set_title('Multiple Line Segments')
ax.grid(True)

# 信息显示区域
info_ax = fig.add_subplot(gs[0])
info_ax.axis('off')  # 关闭坐标轴

# 初始化显示信息
info_text = info_ax.text(0.01, 0.5, "Selected Line Info: None", fontsize=12, verticalalignment='center')

# 用于跟踪当前高亮的线段
highlighted_line = None
highlighted_prev_color = None

# 鼠标点击事件
def on_pick(event):
    global highlighted_line
    global highlighted_prev_color

    # 获取被选中的线段
    line = event.artist
    idx = lines.index(line)
    idx, pl_type, x, y = line_data[idx]

    # 如果有高亮的线段，先恢复默认样式
    if highlighted_line is not None:
        highlighted_line.set_linewidth(2)
        highlighted_line.set_color(highlighted_prev_color)
    # 更新高亮的线段
    highlighted_line = line
    highlighted_prev_color = colors[pl_type] if len(FILTER_TYPE) == 0 or pl_type in FILTER_TYPE else '#F0F0F0'
    line.set_linewidth(4)
    line.set_color('red')

    _point_types = ['DASH_SOLID_YELLOW', 'DASH_SOLID_WHITE', 'DASHED_WHITE', 'DASHED_YELLOW',
                'DOUBLE_SOLID_YELLOW', 'DOUBLE_SOLID_WHITE', 'DOUBLE_DASH_YELLOW', 'DOUBLE_DASH_WHITE',
                'SOLID_YELLOW', 'SOLID_WHITE', 'SOLID_DASH_WHITE', 'SOLID_DASH_YELLOW', 'EDGE',
                'NONE', 'UNKNOWN', 'CROSSWALK', 'CENTERLINE']
    
    # 更新信息显示区域
    info = (f"Selected Line Index: {idx}\n"
            f"Selected Line Head: {batch['map_save']['traj_theta'][idx]: .4f}\n"
            f"Selected Line Type Index: {pl_type}\n"
            f"Selected Line Type: {_point_types[pl_type]}\n")
            # f"Coordinates: {list(zip(x, y))[:5]}... (truncated)")
    info_text.set_text(info)

    # 刷新图像
    fig.canvas.draw_idle()

# 连接事件
fig.canvas.mpl_connect('pick_event', on_pick)

# 显示图像
plt.show()


### 1.3.4 根据匹配的地图token进行地图可视化还原

In [182]:
import matplotlib
matplotlib.use('TkAgg')
import matplotlib.pyplot as plt
import matplotlib.colors as mcolors
import matplotlib.gridspec as gridspec
import torch

# 创建绘图和信息区域
fig = plt.figure(figsize=(18, 20))
gs = gridspec.GridSpec(2, 1, height_ratios=[1, 9])  # 上下两部分，比例为4:1

# 绘图区域
ax = fig.add_subplot(gs[1])

colors = plt.cm.tab20.colors
num_colors = len(colors)

# 准备数据
polyline_type = []
for i in range(batch['map_polygon']['num_nodes']):
    point_idx = batch[('map_point', 'to', 'map_polygon')]['edge_index'][0, batch[('map_point', 'to', 'map_polygon')]['edge_index'][1] == i]
    polyline_type.append(batch['map_point']['type'][point_idx[0]])
pl_num = batch['pt_token']['position'].shape[0]
traj_theta = batch['pt_token']['orientation'].clone()
traj_pos = batch['pt_token']['position'].clone()
pt_token_id = batch['pt_token']['token_idx'].clone()

cos, sin = traj_theta.cos(), traj_theta.sin()
rot_mat = traj_theta.new_zeros(pl_num, 2, 2)
rot_mat[..., 0, 0] = cos
rot_mat[..., 0, 1] = sin
rot_mat[..., 1, 0] = -sin
rot_mat[..., 1, 1] = cos

map_token_traj_path = "/home/yangyh408/codes/SMART/smart/tokens/map_traj_token5.pkl"
map_token_traj = pickle.load(open(map_token_traj_path, 'rb'))

token_src = torch.from_numpy(map_token_traj['traj_src']).to(torch.float)
token_src_world = torch.bmm(
                        token_src[None, ...].repeat(pl_num, 1, 1, 1).reshape(pl_num, -1, 2),
                        rot_mat.view(-1, 2, 2)
                    ).reshape(pl_num, token_src.shape[0], -1, 2) + traj_pos[:, None, None, :2]
token_src_world_select = token_src_world.view(-1, 1024, 11, 2)[torch.arange(pt_token_id.view(-1).shape[0]), pt_token_id.view(-1)].view(pl_num, -1, 2)

# 存储绘制的线段和其对应的数据
lines = []
line_data = []

# 绘制每条线段
for idx, data in enumerate(token_src_world_select):
    pl_idx = batch[('pt_token', 'to', 'map_polygon')]['edge_index'][1, idx].item()
    token_idx = batch['pt_token']['token_idx'][idx].item()
    x = data[:, 0].numpy()
    y = data[:, 1].numpy()
    type = polyline_type[int(pl_idx)]
    line, = ax.plot(x, y, marker='', linestyle='-', linewidth=2, color=colors[type], picker=5)
    lines.append(line)
    line_data.append((idx, token_idx, type, x, y))

# 设置轴比例相同
ax.set_aspect('equal')

# 添加颜色条
sm = plt.cm.ScalarMappable(cmap=plt.cm.tab20, norm=mcolors.Normalize(vmin=0, vmax=num_colors - 1))
sm.set_array([])  # 必须设置 array 以显示颜色条
cbar = plt.colorbar(sm, ax=ax, fraction=0.046, pad=0.04)
cbar.set_label('Color Index')
cbar.set_ticks(range(num_colors))
cbar.set_ticklabels(range(num_colors))

# 设置图例、标题和网格
ax.set_xlabel('X-coordinate')
ax.set_ylabel('Y-coordinate')
ax.set_title('Multiple Line Segments')
ax.grid(True)

# 信息显示区域
info_ax = fig.add_subplot(gs[0])
info_ax.axis('off')  # 关闭坐标轴

# 初始化显示信息
info_text = info_ax.text(0.01, 0.5, "Selected Line Info: None", fontsize=12, verticalalignment='center')

# 用于跟踪当前高亮的线段
highlighted_line = None
highlighted_prev_color = None

# 鼠标点击事件
def on_pick(event):
    global highlighted_line
    global highlighted_prev_color

    # 获取被选中的线段
    line = event.artist
    idx = lines.index(line)
    idx, token_idx, pl_type, x, y = line_data[idx]

    # 如果有高亮的线段，先恢复默认样式
    if highlighted_line is not None:
        highlighted_line.set_linewidth(2)
        highlighted_line.set_color(highlighted_prev_color)
    # 更新高亮的线段
    highlighted_line = line
    highlighted_prev_color = colors[pl_type] if len(FILTER_TYPE) == 0 or pl_type in FILTER_TYPE else '#F0F0F0'
    line.set_linewidth(4)
    line.set_color('red')

    _point_types = ['DASH_SOLID_YELLOW', 'DASH_SOLID_WHITE', 'DASHED_WHITE', 'DASHED_YELLOW',
                'DOUBLE_SOLID_YELLOW', 'DOUBLE_SOLID_WHITE', 'DOUBLE_DASH_YELLOW', 'DOUBLE_DASH_WHITE',
                'SOLID_YELLOW', 'SOLID_WHITE', 'SOLID_DASH_WHITE', 'SOLID_DASH_YELLOW', 'EDGE',
                'NONE', 'UNKNOWN', 'CROSSWALK', 'CENTERLINE']
    
    # 更新信息显示区域
    info = (f"Selected Line Index: {idx}\n"
            f"Selected Token Index: {token_idx}\n"
            f"Selected Line Type Index: {pl_type}\n"
            f"Selected Line Type: {_point_types[pl_type]}\n")
            # f"Coordinates: {list(zip(x, y))[:5]}... (truncated)")
    info_text.set_text(info)

    # 刷新图像
    fig.canvas.draw_idle()

# 连接事件
fig.canvas.mpl_connect('pick_event', on_pick)

# 显示图像
plt.show()