In [5]:
import sys
import os
import numpy as np
import cv2
import matplotlib.pyplot as plt
import glob
from tqdm import tqdm
import open3d as o3d # 建议安装 open3d 用于更好的点云显示，如果没有安装，后面有点云显示的 fallback 代码

# ================= Path setup & imports =================
# Make the sibling ../Projection directory importable so that rellis_utils can be found.
project_root = os.path.abspath(os.path.join(os.getcwd(), "..", "Projection"))
if project_root not in sys.path:
    sys.path.append(project_root)

try:
    from rellis_utils.lidar2img import load_from_bin, get_cam_mtx, get_mtx_from_yaml
    print("成功导入 rellis_utils")
except ImportError as e:
    print(f"导入失败: {e}，请检查 project_root: {project_root}")
    # Every later cell needs these helpers; re-raise so the failure surfaces here
    # instead of as a confusing NameError several cells further down.
    raise

# ================= Step 1 core functions =================
def project_and_filter(lidar_points, P, RT_os1_to_cam, dist_coeff, img_size, min_depth=2.0):
    """Project LiDAR points into the camera image and keep only those landing inside it.

    Args:
        lidar_points: (N, 3) points in the LiDAR (os1) frame.
        P: camera intrinsic matrix, forwarded to ``cv2.projectPoints``.
        RT_os1_to_cam: (4, 4) homogeneous transform from the LiDAR frame to the camera frame.
        dist_coeff: distortion coefficients, forwarded to ``cv2.projectPoints``.
        img_size: (H, W) of the image used for the bounds check.
        min_depth: points whose camera-frame Z is <= this value are dropped
            (behind the camera or too close). Defaults to the previous hard-coded 2.0.

    Returns:
        (img_points (M, 2), cam_points (M, 3), lidar_points (M, ...)) for the M points that
        pass both the depth and the image-bounds filter; the three arrays are row-aligned.
    """
    h, w = img_size
    xyz_h = np.hstack((lidar_points, np.ones((lidar_points.shape[0], 1))))
    xyz_cam = (RT_os1_to_cam @ xyz_h.T).T[:, :3]

    # Depth filter (in front of the camera and not too close)
    mask_z = xyz_cam[:, 2] > min_depth
    xyz_cam = xyz_cam[mask_z]
    lidar_points_filtered = lidar_points[mask_z]

    if xyz_cam.shape[0] == 0:
        # Nothing survived: skip the OpenCV call and return correctly-shaped empty results.
        return np.empty((0, 2)), xyz_cam, lidar_points_filtered

    # Projection: points are already in the camera frame, so rotation/translation are zero.
    rvec = np.zeros((3, 1))
    tvec = np.zeros((3, 1))
    img_points, _ = cv2.projectPoints(xyz_cam, rvec, tvec, P, dist_coeff)
    # projectPoints returns (K, 1, 2); reshape rather than squeeze so that a single
    # point stays (1, 2) instead of collapsing to (2,).
    img_points = img_points.reshape(-1, 2)

    # Image-bounds filter
    u, v = img_points[:, 0], img_points[:, 1]
    mask_uv = (u >= 0) & (u < w) & (v >= 0) & (v < h)
    return img_points[mask_uv], xyz_cam[mask_uv], lidar_points_filtered[mask_uv]

def handle_occlusion(img_points, cam_points, img_size):
    """Z-buffer style occlusion removal: keep only the nearest point per image pixel.

    Args:
        img_points: (N, 2) pixel coordinates (u, v) of the projected points.
        cam_points: (N, 3) camera-frame points; column 2 (Z) is used as depth.
        img_size: (H, W) of the image the pixel coordinates refer to.

    Returns:
        Integer indices into the inputs of the kept points, one per occupied pixel,
        ordered by flattened pixel index (not by input order).
    """
    h, w = img_size
    # Round to the nearest pixel, then clamp. A coordinate in [w - 0.5, w) rounds to w,
    # and without the clamp it would alias pixel (v + 1, 0) in the flattened index below.
    u = np.clip(np.round(img_points[:, 0]).astype(int), 0, w - 1)
    v = np.clip(np.round(img_points[:, 1]).astype(int), 0, h - 1)
    depth = cam_points[:, 2]

    # Flattened pixel index, used to group points that share a pixel
    flat_indices = v * w + u
    # Stable sort so depth ties resolve deterministically (earlier input point wins)
    sort_idx = np.argsort(depth, kind="stable")
    # return_index gives the first occurrence in the depth-sorted order, i.e. the nearest point
    _, unique_idx = np.unique(flat_indices[sort_idx], return_index=True)
    return sort_idx[unique_idx]

def process_single_frame(lidar_file, prob_map_file, P, RT, dist_coeff, proj_shape, map_shape=None):
    """
    Project one LiDAR scan into the camera, drop occluded points and sample per-point probabilities.

    Args:
        lidar_file: path to the scan, read with ``load_from_bin``.
        prob_map_file: path to a ``.npy`` probability map, indexed below as
            ``prob_map[:, v, u]`` (channel-first (C, H, W)).
        P, RT, dist_coeff: intrinsics, LiDAR->camera transform and distortion, forwarded to
            ``project_and_filter``.
        proj_shape: (H, W) of the image the points are projected into.
        map_shape: (H, W) of the probability map. ``None`` uses the loaded map's own spatial
            size; otherwise it must match the loaded array.

    Returns:
        raw_pts: surviving (non-occluded, in-view) points in the LiDAR frame.
        point_probs: (M, C) probability vectors, row-aligned with ``raw_pts``.

    Raises:
        ValueError: if the map is not 3-D or its spatial size differs from ``map_shape``.
    """
    points = load_from_bin(lidar_file)
    prob_map = np.load(prob_map_file)

    if prob_map.ndim != 3:
        raise ValueError(
            f"Expected a (C, H, W) probability map in {prob_map_file}, got shape {prob_map.shape}"
        )
    actual_map_shape = tuple(prob_map.shape[1:])
    if map_shape is None:
        map_shape = actual_map_shape
    elif tuple(map_shape) != actual_map_shape:
        # A mismatch would otherwise silently sample the wrong pixels (map_shape too small)
        # or raise an IndexError (map_shape too large).
        raise ValueError(
            f"map_shape {tuple(map_shape)} does not match probability map {prob_map_file} "
            f"with spatial shape {actual_map_shape}"
        )

    # 1. Projection and frustum filtering
    img_pts, cam_pts, raw_pts = project_and_filter(points, P, RT, dist_coeff, proj_shape)

    # 2. Occlusion removal (at projection resolution)
    keep_idx = handle_occlusion(img_pts, cam_pts, proj_shape)
    img_pts, raw_pts = img_pts[keep_idx], raw_pts[keep_idx]

    # 3. Rescale projection-resolution pixel coordinates to probability-map pixels
    scale_x = map_shape[1] / proj_shape[1]
    scale_y = map_shape[0] / proj_shape[0]
    u_map = np.clip((img_pts[:, 0] * scale_x).astype(int), 0, map_shape[1] - 1)
    v_map = np.clip((img_pts[:, 1] * scale_y).astype(int), 0, map_shape[0] - 1)

    # Advanced indexing over axes 1-2 yields (C, M); transpose to one row per point.
    point_probs = prob_map[:, v_map, u_map].T
    return raw_pts, point_probs

Jupyter environment detected. Enabling Open3D WebVisualizer.
[Open3D INFO] WebRTC GUI backend enabled.
[Open3D INFO] WebRTCWindowSystem: HTTP handshake server disabled.
成功导入 rellis_utils


In [7]:
# ================= Configuration (updated for Step 2) =================
# 1. Base paths. Environment variables override the defaults so the notebook can run
#    on another machine without editing absolute paths.
RELLIS_ROOT = os.environ.get('RELLIS_ROOT', '/home/xzy/datasets/Rellis-3D')
INFERENCE_DIR = os.environ.get(
    'RELLIS_INFERENCE_DIR',
    '/home/xzy/Downloads/convertedRellis/rellisv3_edl_train-4/01_inferenced_npy',
)
SEQ_ID = '00004'

# 2. Sequence range to accumulate
START_FRAME_IDX = 0    # first frame index
NUM_FRAMES = 10        # number of frames to accumulate (5-10 is a good first test)
BASE_FRAME_IDX = 0     # frame whose pose defines the output origin (usually START_FRAME_IDX)

# 3. Pose file (poses.txt lives in the sequence directory)
POSES_FILE = os.path.join(RELLIS_ROOT, SEQ_ID, 'poses.txt')

# 4. Camera distortion coefficients (5x1), forwarded to cv2.projectPoints
dist_coeff = np.array([-0.134313, -0.025905, 0.002181, 0.00084, 0]).reshape((5, 1))

# 5. Resolutions, as (H, W)
PROJ_SHAPE = (1200, 1920)  # raw camera image resolution used for projection
MAP_SHAPE = (600, 960)     # resolution of the inference .npy maps; must match the saved arrays

# ================= Calibration =================
# Prefer camera_info.txt under Rellis_3D_cam_intrinsic/, fall back to the sequence directory.
cam_info_path = os.path.join(RELLIS_ROOT, 'Rellis_3D_cam_intrinsic', 'Rellis-3D', SEQ_ID, 'camera_info.txt')
if not os.path.exists(cam_info_path):
    cam_info_path = os.path.join(RELLIS_ROOT, SEQ_ID, 'camera_info.txt')

trans_path = os.path.join(RELLIS_ROOT, SEQ_ID, 'transforms.yaml')

# Report the offending path rather than an opaque parser error inside rellis_utils.
for _calib_path in (cam_info_path, trans_path):
    if not os.path.exists(_calib_path):
        raise FileNotFoundError(f"Calibration file not found: {_calib_path}")

# Load intrinsics P and the LiDAR->camera extrinsic RT
P = get_cam_mtx(cam_info_path)
RT_os1_to_cam = get_mtx_from_yaml(trans_path)

print(f"标定加载完成.\n P shape: {P.shape}\n RT shape: {RT_os1_to_cam.shape}")

# ================= [New] Pose loading =================
def load_poses(pose_file):
    """
    Read a KITTI-format poses.txt and return a list of (4, 4) homogeneous transforms.

    Each non-blank line holds 12 whitespace-separated numbers, reshaped row-major into a
    3x4 [R|t] matrix and completed with a [0, 0, 0, 1] bottom row. Blank lines (e.g. a
    trailing newline) are skipped.

    Raises:
        FileNotFoundError: if ``pose_file`` does not exist.
        ValueError: if a non-blank line does not contain exactly 12 numbers.
    """
    if not os.path.exists(pose_file):
        raise FileNotFoundError(f"位姿文件未找到: {pose_file}")

    poses = []
    with open(pose_file, 'r') as f:
        for line_no, line in enumerate(f, start=1):
            if not line.strip():
                continue
            vals = np.array(line.split(), dtype=float)
            if vals.size != 12:
                raise ValueError(f"{pose_file}:{line_no}: expected 12 values, got {vals.size}")
            mat_4x4 = np.eye(4)
            mat_4x4[:3, :] = vals.reshape(3, 4)
            poses.append(mat_4x4)
    return poses

# Parse every pose of the sequence once; the accumulation cell indexes this list by frame number.
all_poses = load_poses(POSES_FILE)
print("成功加载位姿，共 {} 帧".format(len(all_poses)))

标定加载完成.
 P shape: (3, 3)
 RT shape: (4, 4)
成功加载位姿，共 2059 帧


In [8]:
# ================= Accumulation main loop =================
# Transform NUM_FRAMES consecutive scans into the frame of BASE_FRAME_IDX and stack them.
# NOTE(review): this assumes poses.txt gives the pose of the LiDAR frame, since the points
# transformed below are LiDAR-frame points — confirm against the dataset documentation.

acc_points_list = []  # per-frame aligned geometry (x, y, z)
acc_probs_list = []   # per-frame probability vectors, row-aligned with acc_points_list

end_frame_idx = START_FRAME_IDX + NUM_FRAMES  # exclusive upper bound
# Validate frame indices against the pose list up front instead of hitting IndexError mid-loop.
if not 0 <= BASE_FRAME_IDX < len(all_poses):
    raise IndexError(f"BASE_FRAME_IDX={BASE_FRAME_IDX} is outside the {len(all_poses)} loaded poses")
if START_FRAME_IDX < 0 or end_frame_idx > len(all_poses):
    raise IndexError(
        f"Frames [{START_FRAME_IDX}, {end_frame_idx}) exceed the {len(all_poses)} loaded poses"
    )

# Pose of the base frame in the world (name kept for compatibility) and its inverse.
T_world_curr = all_poses[BASE_FRAME_IDX]
T_base_inv = np.linalg.inv(T_world_curr)

print(f"开始拼接序列 {SEQ_ID}, 帧 {START_FRAME_IDX} 到 {START_FRAME_IDX + NUM_FRAMES - 1}...")

for i in tqdm(range(START_FRAME_IDX, end_frame_idx)):
    frame_str = f"{i:06d}"

    # --- 1. Build per-frame paths ---
    lidar_path = os.path.join(RELLIS_ROOT, SEQ_ID, 'os1_cloud_node_kitti_bin', f"{frame_str}.bin")

    # Inference files carry a timestamp suffix, so they are located by pattern.
    # glob order is arbitrary; sort so the choice is deterministic when several files match.
    npy_pattern = os.path.join(INFERENCE_DIR, SEQ_ID, f"frame{frame_str}-*.npy")
    npy_matches = sorted(glob.glob(npy_pattern))

    if not os.path.exists(lidar_path) or len(npy_matches) == 0:
        print(f"跳过帧 {frame_str}: 文件缺失")
        continue
    if len(npy_matches) > 1:
        print(f"Warning: frame {frame_str} has {len(npy_matches)} matching .npy files; using {npy_matches[0]}")

    npy_path = npy_matches[0]

    # --- 2. Single-frame projection / occlusion / probability sampling ---
    # pts_local are points in this frame's LiDAR coordinate system
    pts_local, probs = process_single_frame(
        lidar_path, npy_path, P, RT_os1_to_cam, dist_coeff, PROJ_SHAPE, MAP_SHAPE
    )

    if len(pts_local) == 0:
        continue

    # --- 3. Re-express in the base frame: P_base = inv(T_world_base) @ T_world_curr @ P_curr ---
    T_curr = all_poses[i]
    T_rel = T_base_inv @ T_curr

    pts_homo = np.hstack((pts_local, np.ones((pts_local.shape[0], 1))))  # (N, 4)
    pts_aligned = (T_rel @ pts_homo.T).T[:, :3]                            # (N, 3)

    # --- 4. Collect ---
    acc_points_list.append(pts_aligned)
    acc_probs_list.append(probs)

# Stack all frames
if len(acc_points_list) > 0:
    global_points = np.vstack(acc_points_list)
    global_probs = np.vstack(acc_probs_list)

    # Hard label per point (argmax over the probability channels), used for colouring
    global_labels = np.argmax(global_probs, axis=1)

    print(f"拼接完成!")
    print(f"总点数: {global_points.shape[0]}")
    print(f"概率矩阵形状: {global_probs.shape}")
else:
    # NOTE: any global_points / global_labels from a previous run are left untouched here.
    print("错误: 未收集到任何点云数据。")

开始拼接序列 00004, 帧 0 到 9...


100%|█████████████████████████████████████████████████████████████████████████████████| 10/10 [00:00<00:00, 28.95it/s]

拼接完成!
总点数: 58515
概率矩阵形状: (58515, 9)





In [10]:
def get_rellis_colors(labels):
    """
    Map Rellis-3D class IDs to RGB colours through a coarse colour-group table.

    Args:
        labels: (N,) integer array-like of raw Rellis-3D class IDs (0-34); lists are accepted.

    Returns:
        colors: (N, 3) float array in [0, 1], for Open3D / Matplotlib display.

    IDs outside 0-34 (including negative values) are drawn as group 0 (void, black).

    NOTE(review): the accumulation cell passes argmax indices over a 9-channel probability
    map (values 0-8), while this function expects raw Rellis-3D IDs. If the network predicts
    the 9 colour groups directly, those indices should index ``group_colors`` without the
    ID -> group step. Confirm which label space is in use.
    """
    num_ids = 35  # lookup table covers IDs 0..34

    # 1. Class ID -> colour group (table 1). IDs not listed default to group 0 (void).
    mapping_data = {
        0:0, 1:6, 2:5, 3:7, 4:8, 5:3, 6:2, 7:1, 8:3, 9:3,
        10:4, 11:5, 12:3, 13:6, 14:5, 15:6, 16:3, 17:3, 18:0,
        19:8, 20:3, 21:0, 22:3, 23:4, 24:3, 27:0, 31:2, 33:6, 34:0
    }
    id_to_group = np.zeros(num_ids, dtype=int)
    id_to_group[list(mapping_data.keys())] = list(mapping_data.values())

    # 2. Colour group -> RGB (table 2), normalised to [0, 1]
    group_colors = np.array([
        [0, 0, 0],       # 0: void
        [196, 255, 255], # 1: sky
        [0, 0, 255],     # 2: water
        [204, 153, 255], # 3: object
        [255, 255, 0],   # 4: paved
        [255, 153, 204], # 5: unpaved
        [153, 76, 0],    # 6: brown (dirt/mulch)
        [111, 255, 74],  # 7: green (grass)
        [0, 102, 0]      # 8: vegetation (tree/bush)
    ]) / 255.0

    # 3. Apply both lookups. Out-of-range IDs map to void; without the lower bound a negative
    #    ID would wrap around via negative indexing and pick up another class's colour.
    labels = np.asarray(labels, dtype=int)
    in_range = (labels >= 0) & (labels < num_ids)
    groups = id_to_group[np.where(in_range, labels, 0)]
    return group_colors[groups]

In [None]:
import open3d as o3d

def visualize_with_coordinates(points, labels, T_world_curr_base):
    """
    Show the accumulated cloud in an Open3D window together with two reference axes.

    Args:
        points: (N, 3) cloud already expressed in the base frame.
        labels: (N,) class labels, coloured via ``get_rellis_colors``.
        T_world_curr_base: (4, 4) pose of the base frame in the world
            (i.e. ``all_poses[BASE_FRAME_IDX]``).
    """
    # Coloured point cloud
    cloud = o3d.geometry.PointCloud()
    cloud.points = o3d.utility.Vector3dVector(points)
    cloud.colors = o3d.utility.Vector3dVector(get_rellis_colors(labels))

    # Base-frame axis (size 2.0): the cloud lives in the base frame, so it sits at the origin.
    base_axis = o3d.geometry.TriangleMesh.create_coordinate_frame(size=2.0, origin=[0, 0, 0])

    # World-frame axis (size 5.0): P_world = T_world_base @ P_base, so the world origin seen
    # from the base frame is placed by inv(T_world_base).
    world_axis = o3d.geometry.TriangleMesh.create_coordinate_frame(size=5.0)
    world_axis.transform(np.linalg.inv(T_world_curr_base))

    for message in (
        "可视化说明:",
        "1. 点云颜色: 已应用 Rellis-3D 官方/简化配色",
        "2. 小坐标轴 (Size 2.0): 起始帧 (Base Frame) 位置 (你在这里)",
        "3. 大坐标轴 (Size 5.0): 世界原点 (World Frame) 位置 (poses.txt 的 0,0,0)",
    ):
        print(message)

    # NOTE(review): these camera parameters look like generic Open3D example values rather than
    # ones tuned to this scene — TODO consider centring on the cloud instead.
    view_params = dict(
        zoom=0.3412,
        front=[0.4257, -0.2125, -0.8795],
        lookat=[2.6172, 2.0475, 1.532],
        up=[-0.0694, -0.9768, 0.2024],
    )
    o3d.visualization.draw_geometries(
        [cloud, base_axis, world_axis],
        window_name="Rellis-3D Accumulation",
        **view_params,
    )

# ================= Example call =================
# Needs global_points / global_labels from the accumulation cell; the base-frame pose is
# all_poses[BASE_FRAME_IDX].
# NOTE(review): global_labels is an argmax over a 9-channel probability map, whereas
# get_rellis_colors expects raw Rellis-3D IDs (0-34) — confirm the label space before
# trusting the colours.
if 'global_points' not in locals():
    print("请先运行点云拼接代码块生成 global_points 数据。")
else:
    visualize_with_coordinates(global_points, global_labels, all_poses[BASE_FRAME_IDX])