In [1]:
import os
import numpy as np
import torch
from torch.utils.data import Dataset, DataLoader
from PIL import Image
import torchvision.transforms as transforms
import torch.nn as nn
class SemKITTI_DVPS_Dataset(Dataset):
    def __init__(self, root, split='train',
                 image_transform=None,
                 GT_transform=None,
                 ):
        """
        Args:
            root (str): 数据集根目录，例如 '/path/to/dataset'
            split (str): 数据集的划分，'train' 或 'val'
            image_transform: 对 RGB 图像进行预处理的 transform
            depth_transform: 对深度图进行预处理的 transform
            seg_transform: 对语义分割标签进行预处理的 transform
            inst_transform: 对实例分割标签进行预处理的 transform
        """
        self.root = root
        self.split = split
        self.image_transform = image_transform
        self.GT_transform = GT_transform
        
        self.samples = []  # 每个元素为一个字典，包含该样本的各个图片路径
        split_dir = os.path.join(root, split)
        all_files = sorted(os.listdir(split_dir))
        
        # 按照样本前缀分组，假设前缀为前4个字符（例如 "0001"）
        sample_dict = {}
        for file in all_files:
            # 去除空格、统一小写
            file_name_element = file.split("_")#[00000_00000_depth_718]
            scene=file_name_element[0]#标记场景
            frame=file_name_element[1]#标记帧
            
            # if scene != "000000":#只加载第一组
            #     continue
                
            if scene not in sample_dict:
                sample_dict[scene] = {}
                if frame not in sample_dict[scene]:
                    sample_dict[scene][frame]={}
                else:
                    pass
            else:
                if frame not in sample_dict[scene]:
                    sample_dict[scene][frame]={}
                else:
                    pass
            
            # 根据文件名中包含的关键字确定图片类型
            if 'depth' in file_name_element:
                sample_dict[scene][frame]['depth'] = os.path.join(split_dir, file)
                sample_dict[scene][frame]['focal'] = file_name_element[3].split(".")[0]
            else:
                pass
            
            if 'class.png' in file_name_element:
                sample_dict[scene][frame]['class'] = os.path.join(split_dir, file)
                
            else:
                pass
            
            if 'instance.png' in file_name_element:
                sample_dict[scene][frame]['instance'] = os.path.join(split_dir, file)
                
            else:
                pass
            
            if 'leftImg8bit.png' in file_name_element:
                sample_dict[scene][frame]['Img'] = os.path.join(split_dir, file)
                
            else:
                pass
        
        # 过滤出包含所有四种图片的样本frame
        for scene, frames in sample_dict.items():
            for frame, files in frames.items():
                if all(key in files for key in ['depth', 'Img', 'class', 'instance']):
                    self.samples.append(files)
        
    def __len__(self):
        return len(self.samples)

    def __getitem__(self, idx):
        sample = self.samples[idx]
        # 加载 RGB 图像，并转换为 RGB 格式
        image = Image.open(sample['Img']).convert('RGB')
        # 加载深度图（深度图可能为单通道图像）
        depth = Image.open(sample['depth'])
        # 加载语义分割标签，通常为单通道标签图
        seg = Image.open(sample['class'])
        # 加载实例分割标签
        inst = Image.open(sample['instance'])
        
        # 对图像应用预处理 transform
        if self.image_transform:
            image = self.image_transform(image)
        else:
            image = transforms.ToTensor()(image)
        
        
        depth = GT_transforms(depth)
        
        seg = GT_transforms(seg)
            
        inst = GT_transforms(inst)
        
        return image, depth, seg, inst




In [2]:
import numpy as np
import matplotlib.pyplot as plt

def get_color_map(num_colors):
    """
    生成一个包含 num_colors 个随机颜色的映射表。
    """
    np.random.seed(42)  # 固定种子，保证每次生成相同的颜色
    return np.random.randint(0, 256, (num_colors, 3), dtype=np.uint8)

def colorize_panoptic(panoptic_map, colormap):
    """
    根据 panoptic_map 中每个像素的 panoptic_id，从 colormap 中取对应颜色，
    生成彩色图像。
    """
    b,c,h, w = panoptic_map.shape
    color_image = np.zeros((h, w, 3), dtype=np.uint8)
    unique_ids = np.unique(panoptic_map[0,0,:,:])
    for uid in unique_ids:
        # 如果 uid 为 0 或 2550000，设定为黑色
        if uid == 2550000:
            color = np.array([0, 0, 0], dtype=np.uint8)
        else:
            # 使用 modulo 确保 uid 超过颜色数量时仍然可以映射
            color = colormap[uid % len(colormap)]
        color_image[panoptic_map[0,0,:,:] == uid] = color
    return color_image
# 定义颜色映射表，假设最多256种不同颜色
num_colors = 256
colormap = get_color_map(num_colors)

In [3]:
if __name__ == '__main__':
    dataset_root = '/root/autodl-tmp/video_sequence'  # 修改为你的数据集根目录

    # 定义图像预处理
    image_transforms = transforms.Compose([
        transforms.Resize((376, 1241)),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406],
                             std=[0.229, 0.224, 0.225])
    ])
    GT_transforms = transforms.Compose([
        transforms.Resize((376, 1241),interpolation=transforms.InterpolationMode.NEAREST),
        transforms.ToTensor(),
    ])
    # 可根据需要为深度图单独定义 transform，例如仅做 resize 和 ToTensor
    # 对于 segment 和 instance，由于它们是标签，通常不希望有归一化操作，可以直接转换为 Tensor
    # 这里我们在 __getitem__ 中已处理

    # 构造训练集
    train_dataset = SemKITTI_DVPS_Dataset(root=dataset_root,
                                          split='train',
                                          image_transform=image_transforms,
                                             GT_transform=GT_transforms)
    train_loader = DataLoader(train_dataset, batch_size=1, shuffle=False, num_workers=4)

    # 测试加载数据
    i=0
    max_pool = nn.MaxPool2d(kernel_size=2, stride=1,padding=1)
    for images, depths, segments, instances in train_loader:
        # print("RGB 图像 batch 尺寸:", images.shape)       # 例如 [4, 3, 256, 512]
        # print("深度图 batch 尺寸:", depths.shape)          # 例如 [4, 1, 256, 512]
        # print("语义标签 batch 尺寸:", segments)       # 例如 [4, 256, 512]
        # print("实例标签 batch 尺寸:", instances.shape)      # 例如 [4, 256, 512]
        # 生成彩色图像
        gt_pop=segments*10000+instances
        color_image = colorize_panoptic(pop, colormap)
        img_tensor = torch.tensor(color_image, dtype=torch.float32).permute(2, 0, 1).unsqueeze(0)

        # 3. 定义最大池化层，kernel_size=2, stride=2（对每个通道独立操作）
        

        pooled_tensor = max_pool(max_pool(max_pool(img_tensor)))
        
        # 4. 转换回 NumPy 数组，并将通道维度移到最后，得到形状 (88, 620, 3)
        pooled_image = pooled_tensor.squeeze(0).permute(1, 2, 0).numpy().astype(np.uint8)
        
        img = Image.fromarray(pooled_image)

        # 保存图片为 "output.png"
        img.save("/root/autodl-tmp/pop_gt/"+str(i)+"_output.png")
        #flag=i%10
        i=i+1
        print(i)

NameError: name 'gt_cls' is not defined