In [1]:
import torch
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import DataLoader, TensorDataset
import os
import numpy as np
import torch
from torch.utils.data import Dataset, DataLoader
from PIL import Image
import torchvision.transforms as transforms
from dataset.semKITTI_dataset import SemKITTI_DVPS_Dataset
from diffusers import AutoencoderKL
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from ldmseg.models import GeneralVAESeg
from ldmseg.trainers import TrainerAE
from ldmseg.utils import prepare_config, Logger, is_main_process

In [2]:
def train_classifier(model, dataloader, num_epochs=20, device="cuda"):
    """
    Train ``model`` as a per-pixel classifier using cross-entropy loss.

    Batches are indexed positionally: ``batch[0]`` is used as the input images
    and ``batch[2]`` as the segmentation target; other entries are ignored.
    The model is called as ``model(images, sample_posterior=True)`` and the
    ``.sample`` attribute of its output is used as the logits.

    Expected shapes (KITTI pixel-level classification):
      - logits: [B, 19, h, w]
      - targets: [B, H, W] (or [B, 1, H, W]) with class ids in 0..18

    Args:
        model: module to optimise; moved to ``device`` and set to train mode.
        dataloader: iterable of indexable batches as described above.
        num_epochs: number of passes over ``dataloader``.
        device: device the model and tensors are moved to.

    Returns:
        The same ``model`` object, after training.
    """
    model.to(device)
    optimizer = optim.Adam(model.parameters(), lr=1e-4)
    criterion = nn.CrossEntropyLoss()
    model.train()

    for epoch in range(num_epochs):
        epoch_loss = 0.0
        num_batches = 0
        for batch in dataloader:
            images = batch[0].to(device)
            segments = batch[2].to(device)
            # Drop a singleton channel dimension: [B, 1, H, W] -> [B, H, W].
            if segments.ndim == 4 and segments.shape[1] == 1:
                segments = segments.squeeze(1)
            # CrossEntropyLoss needs int64 class indices regardless of the
            # dtype the dataset produced, so cast on every path.
            segments = segments.long()

            optimizer.zero_grad()
            # Forward pass: the ``sample`` field of the output holds the logits.
            output = model(images, sample_posterior=True)
            logits = output.sample
            # Bilinearly resize the logits to the label resolution when they
            # differ, instead of assuming a fixed (192, 640) target size.
            target_size = tuple(segments.shape[-2:])
            if tuple(logits.shape[-2:]) != target_size:
                logits = F.interpolate(logits, size=target_size, mode='bilinear', align_corners=False)

            loss = criterion(logits, segments)

            loss.backward()
            optimizer.step()

            epoch_loss += loss.item()
            num_batches += 1

        # Count batches explicitly so an empty dataloader reports NaN rather
        # than failing on an undefined loop variable.
        mean_loss = epoch_loss / num_batches if num_batches else float("nan")
        print(f"Epoch {epoch+1}/{num_epochs}, Loss: {mean_loss:.4f}")
    return model


In [3]:
# RGB preprocessing: resize to 192x640, convert to a float tensor, then
# normalise each channel with the mean/std values listed below.
image_transforms = transforms.Compose(
    [
        transforms.Resize((192, 640)),
        transforms.ToTensor(),
        transforms.Normalize(
            mean=[0.485, 0.456, 0.406],
            std=[0.229, 0.224, 0.225],
        ),
    ]
)

# Preprocessing for the segmentation labels (ground truth).
# Nearest-neighbour resizing avoids blending neighbouring class ids, and
# PILToTensor keeps the raw integer label values. ToTensor() would instead
# rescale 8-bit images to floats in [0, 1], so a later cast to long would
# collapse every class id to 0.
# NOTE(review): assumes the dataset passes PIL images to this transform — confirm.
GT_transforms = transforms.Compose([
    transforms.Resize((192, 640), interpolation=transforms.InterpolationMode.NEAREST),
    transforms.PILToTensor(),  # output shape [C, H, W]; values are not rescaled
])

# Dataset root directory. Defaults to the original location; set the
# SEMKITTI_ROOT environment variable to point somewhere else without
# editing the notebook.
dataset_root = os.environ.get('SEMKITTI_ROOT', '/root/autodl-tmp/video_sequence')

# Build the training split.
train_dataset = SemKITTI_DVPS_Dataset(
    root=dataset_root,
    split='train',
    image_transform=image_transforms,
    GT_transform=GT_transforms
)
# NOTE(review): shuffle=False keeps the dataset's own ordering; confirm this
# is intended for training rather than a leftover from debugging.
train_loader = DataLoader(train_dataset, batch_size=2, shuffle=False, num_workers=4)

# -----------------------
# Instantiate the model.
# NOTE(review): the earlier comment here said out_channels should equal the
# number of classes (19), but the value actually passed is 128 — confirm
# which one is intended before training a 19-class classifier.
vae = GeneralVAESeg(
    in_channels=7,  # consider bit encoding
    int_channels=256,
    out_channels=128,
    block_out_channels=[32, 64, 128, 256],
    latent_channels=4,
    num_latents=2,
    num_upscalers=2,
    upscale_channels=256,
    norm_num_groups=32,
    scaling_factor=0.2,
    parametrization='gaussian',
    act_fn='none',
    clamp_output=False,
    freeze_codebook=False,
    num_mid_blocks=0,
    fuse_rgb=False,
    resize_input=False,
    skip_encoder=False,
)
print(vae)

Interpolation factor:  2
Parametrization:  gaussian
Activation function:  none
GeneralVAESeg(
  (encoder): Sequential(
    (0): Conv2d(7, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (1): SiLU()
    (2): Conv2d(32, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (3): Conv2d(32, 64, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1))
    (4): SiLU()
    (5): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (6): Conv2d(64, 128, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1))
    (7): SiLU()
    (8): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (9): Conv2d(128, 256, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1))
    (10): SiLU()
    (11): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (12): Identity()
    (13): GroupNorm(32, 256, eps=1e-06, affine=True)
    (14): SiLU()
    (15): Conv2d(256, 8, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  )
  (decoder): Sequential(
    (0): Conv

In [5]:
# Print only the encoder sub-module of the model built above.
print(vae.encoder)

Sequential(
  (0): Conv2d(7, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  (1): SiLU()
  (2): Conv2d(32, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  (3): Conv2d(32, 64, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1))
  (4): SiLU()
  (5): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  (6): Conv2d(64, 128, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1))
  (7): SiLU()
  (8): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  (9): Conv2d(128, 256, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1))
  (10): SiLU()
  (11): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  (12): Identity()
  (13): GroupNorm(32, 256, eps=1e-06, affine=True)
  (14): SiLU()
  (15): Conv2d(256, 8, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
)


In [7]:
def dice_loss(pred, target, smooth=1.0):
    """
    Soft Dice loss measuring the overlap between predictions and labels.

    Args:
        pred: raw logits of shape [B, C, H, W]; softmax over dim 1 is applied here.
        target: integer class labels of shape [B, H, W] with values in [0, C).
            A singleton channel dimension ([B, 1, H, W]) is also accepted, and
            any integer dtype is cast to int64 before one-hot encoding.
        smooth: additive term in numerator and denominator that keeps the
            ratio defined when a class has no predicted or labelled pixels.

    Returns:
        Scalar tensor: 1 minus the Dice coefficient averaged over batch and class.
    """
    # Accept [B, 1, H, W] label maps; the permute below expects [B, H, W].
    if target.ndim == pred.ndim and target.shape[1] == 1:
        target = target.squeeze(1)
    # F.one_hot only accepts int64 indices, so uint8/int32 label maps are
    # cast here. Floating-point targets are left alone so one_hot still
    # rejects them instead of silently truncating.
    if not target.is_floating_point():
        target = target.long()

    probs = F.softmax(pred, dim=1)
    target_one_hot = F.one_hot(target, num_classes=pred.shape[1]).permute(0, 3, 1, 2).float()
    intersection = (probs * target_one_hot).sum(dim=(2, 3))
    # Sum of both region sizes (the Dice denominator), not a set union.
    cardinality = probs.sum(dim=(2, 3)) + target_one_hot.sum(dim=(2, 3))
    dice = (2. * intersection + smooth) / (cardinality + smooth)
    return 1 - dice.mean()

def panoptic_loss(sem_logits, inst_logits, sem_target, inst_target, lambda_dice=1.0):
    """
    Combined loss over the semantic and instance heads.

    Each head contributes a cross-entropy term and a Dice term; the two Dice
    terms are summed and scaled by ``lambda_dice`` before being added to the
    two cross-entropy terms.

    Args:
        sem_logits: [B, num_sem_classes, H, W] semantic head output.
        inst_logits: [B, num_inst, H, W] instance head output.
        sem_target: [B, H, W] ground-truth semantic labels.
        inst_target: [B, H, W] ground-truth instance labels.
        lambda_dice: weight applied to the summed Dice losses.

    Returns:
        Scalar tensor holding the total loss.
    """
    cross_entropy_term = (
        F.cross_entropy(sem_logits, sem_target)
        + F.cross_entropy(inst_logits, inst_target)
    )
    dice_term = dice_loss(sem_logits, sem_target) + dice_loss(inst_logits, inst_target)
    return cross_entropy_term + lambda_dice * dice_term