加载Lucchi++的函数设定


In [1]:
import os
import cv2
import numpy as np
from torch.utils.data import Dataset

from skimage import measure

class LucchiPPDataset(Dataset):
    """
    Loader for the Lucchi++ dataset.

    Expected directory layout::

        dataset/Lucchi++/
            ├── Train_In/
            ├── Train_Out/
            ├── Test_In/
            └── Test_Out/

    Images are read from ``*_In`` and masks from ``*_Out``. Every item is a
    tuple ``(image, binary_mask, points, num_masks)`` as described in
    ``__getitem__``.
    """
    def __init__(self, data_dir, split='test', transform=None):
        """
        Args:
            data_dir (str): Root directory of the Lucchi++ dataset.
            split (str): 'train' selects Train_In/Train_Out; any other value
                selects Test_In/Test_Out.
            transform: Optional image transform. It is stored on the instance
                but is not applied by ``__getitem__``.
        """
        self.data_dir = data_dir
        self.split = split
        self.transform = transform

        # Pick the image / mask directories for the requested split.
        if split == 'train':
            self.image_dir = os.path.join(data_dir, "Train_In")
            self.mask_dir = os.path.join(data_dir, "Train_Out")
        else:
            self.image_dir = os.path.join(data_dir, "Test_In")
            self.mask_dir = os.path.join(data_dir, "Test_Out")

        # Sorted so that the positional index of an image is deterministic.
        self.image_files = sorted(os.listdir(self.image_dir))

    def __len__(self):
        """Return the number of files found in the image directory."""
        return len(self.image_files)

    def __getitem__(self, idx):
        """
        Load one sample and derive point prompts from its mask.

        Args:
            idx (int): Position of the image in the sorted file list.

        Returns:
            tuple: ``(image, binary_mask, points, num_masks)`` where
                * image is the RGB image, rescaled (aspect ratio preserved)
                  by ``min(1024 / width, 1024 / height)``;
                * binary_mask is a uint8 array of shape (1, H, W) with 1
                  wherever the mask is non-zero;
                * points has shape (N, 1, 2) holding one random (x, y) pixel
                  per connected component of the eroded mask, or is an empty
                  array when there is no component;
                * num_masks is N, the number of connected components.

        Raises:
            ValueError: If the image or the mask cannot be read.
        """
        image_path = os.path.join(self.image_dir, self.image_files[idx])
        # Masks are looked up by position ("<idx>.png"), not by image file name.
        mask_path = os.path.join(self.mask_dir, f"{idx}.png")

        image = cv2.imread(image_path)
        mask = cv2.imread(mask_path, cv2.IMREAD_GRAYSCALE)

        # cv2.imread signals failure by returning None, so validate before any
        # indexing; otherwise a missing image surfaces as an opaque TypeError.
        if image is None or mask is None:
            raise ValueError(f"无法读取图像或掩码: {image_path}, {mask_path}")

        image = image[..., ::-1]  # BGR -> RGB

        # Rescale image and mask by the same factor; nearest-neighbour keeps
        # the mask free of interpolated in-between values.
        r = min(1024 / image.shape[1], 1024 / image.shape[0])
        image = cv2.resize(image, (int(image.shape[1] * r), int(image.shape[0] * r)))
        mask = cv2.resize(mask, (int(mask.shape[1] * r), int(mask.shape[0] * r)),
                          interpolation=cv2.INTER_NEAREST)

        # Binarise the mask.
        binary_mask = (mask > 0).astype(np.uint8)

        # Erode before labelling so sampled points lie inside each component
        # rather than on its border.
        eroded_mask = cv2.erode(binary_mask, np.ones((5, 5), np.uint8), iterations=1)
        labels = measure.label(eroded_mask)
        regions = measure.regionprops(labels)

        # One random pixel per component; region.coords is (row, col), the
        # prompt is stored as (x, y).
        points = []
        for region in regions:
            y, x = region.coords[np.random.randint(len(region.coords))]
            points.append([x, y])

        points = np.array(points)

        # Add the leading / per-point axes.
        binary_mask = np.expand_dims(binary_mask, axis=0)  # (1, H, W)
        if len(points) > 0:
            points = np.expand_dims(points, axis=1)  # (N, 1, 2)

        num_masks = len(regions)

        return image, binary_mask, points, num_masks

def load_lucchi_dataset(data_dir="dataset/Lucchi++", split='test'):
    """
    Build a Lucchi++ dataset object for the requested split.

    Args:
        data_dir (str): Root directory of the dataset.
        split (str): 'train' or 'test'.

    Returns:
        A LucchiPPDataset instance.
    """
    dataset = LucchiPPDataset(data_dir=data_dir, split=split)
    return dataset

# Usage example:
if __name__ == "__main__":
    # Load the test split
    test_dataset = load_lucchi_dataset(data_dir="dataset/Lucchi++", split='test')
    
    # Report the dataset size
    print(f"测试集大小: {len(test_dataset)}")
    
    # Fetch a single sample
    image, mask, points, num_masks = test_dataset[0]
    
    # Print the shapes (points is an empty array when no component was found)
    print(f"图像形状: {image.shape}")
    print(f"掩码形状: {mask.shape}")
    print(f"点标注形状: {points.shape if points.size > 0 else 'No points'}")
    print(f"掩码数量: {num_masks}")

测试集大小: 165
图像形状: (768, 1024, 3)
掩码形状: (1, 768, 1024)
点标注形状: (21, 1, 2)
掩码数量: 21


In [2]:
import torch
import numpy as np
from torch.utils.data import DataLoader
import os
from tqdm import tqdm

from sam2.build_sam import build_sam2
from sam2.sam2_image_predictor import SAM2ImagePredictor

def _merge_instance_masks(masks, scores):
    """Collapse per-prompt predictor outputs into one boolean (H, W) foreground mask."""
    masks = np.asarray(masks)
    if masks.ndim == 2:  # already a single (H, W) mask
        return masks > 0.5
    if masks.ndim == 3:  # single prompt: (C, H, W) -> (1, C, H, W)
        masks = masks[None]
    scores = np.asarray(scores).reshape(masks.shape[0], -1)
    # For every prompt keep the candidate with the highest predicted quality,
    # then take the union so the result is comparable to a binary GT mask.
    best = scores.argmax(axis=1)
    chosen = masks[np.arange(masks.shape[0]), best]  # (N, H, W)
    return (chosen > 0.5).any(axis=0)


def evaluate_model(predictor, test_dataset, save_dir=None):
    """
    Evaluate a predictor on a test dataset using per-image foreground IoU.

    For every sample the point prompts are fed to the predictor, the best
    candidate mask of each prompt is merged into a single binary prediction,
    and that prediction is compared against the binary ground-truth mask.

    Args:
        predictor: SAM2 image predictor wrapping the loaded model. Its
            ``predict`` must return ``(masks, scores, low_res_logits)``.
        test_dataset: Iterable of ``(image, gt_mask, points, num_masks)``.
        save_dir: Optional directory; when given, each merged prediction is
            written there as ``pred_<idx>.png`` (0 / 255).

    Returns:
        tuple: ``(mean_iou, std_iou)`` over all evaluated samples. Both are
        NaN when no sample could be evaluated.
    """
    predictor.model.eval()  # evaluation mode
    all_ious = []

    if save_dir:
        import cv2  # only needed when predictions are written to disk
        os.makedirs(save_dir, exist_ok=True)

    with torch.no_grad():  # no gradients needed for evaluation
        for idx, (image, gt_mask, input_point, num_masks) in enumerate(tqdm(test_dataset)):
            # Samples without any prompt cannot be evaluated.
            if image is None or gt_mask is None or num_masks == 0:
                continue

            input_point = np.asarray(input_point)
            if input_point.size == 0:
                continue
            input_label = np.ones((num_masks, 1))  # every point is a foreground click

            predictor.set_image(image)
            # predict() returns a tuple, not a bare mask array.
            masks, scores, _ = predictor.predict(
                point_coords=input_point,
                point_labels=input_label
            )
            pred_mask = _merge_instance_masks(masks, scores)

            # Reduce the (1, H, W) ground truth to a boolean (H, W) mask.
            gt = np.asarray(gt_mask) > 0
            gt = gt.reshape(-1, *gt.shape[-2:]).any(axis=0)

            # Plain NumPy IoU: works identically with or without a GPU.
            intersection = np.logical_and(gt, pred_mask).sum()
            union = np.logical_or(gt, pred_mask).sum()
            all_ious.append(float(intersection / (union + 1e-6)))

            if save_dir:
                save_path = os.path.join(save_dir, f'pred_{idx}.png')
                cv2.imwrite(save_path, pred_mask.astype(np.uint8) * 255)

    # Avoid NumPy's "mean of empty slice" warning when everything was skipped.
    if not all_ious:
        return float('nan'), float('nan')

    mean_iou = np.mean(all_ious)
    std_iou = np.std(all_ious)

    return mean_iou, std_iou



In [10]:

def main(sam2_checkpoint="sam2_lora_checkpoint_3000.pth",
         model_cfg="configs/sam2.1/sam2.1_hiera_l.yaml",
         base_checkpoint=None):
    """
    Build the SAM2 model, load fine-tuned weights and evaluate on the test sets.

    Args:
        sam2_checkpoint (str): Path of the state dict loaded into the model
            with ``strict=False``.
        model_cfg (str): SAM2 config file passed to ``build_sam2``.
        base_checkpoint (str | None): Optional checkpoint forwarded to
            ``build_sam2`` as ``ckpt_path``. With None the model is built
            without loading a checkpoint there.

    Returns:
        dict: ``{dataset_name: (mean_iou, std_iou)}``.
    """
    import warnings

    # Select the device
    device = "cuda" if torch.cuda.is_available() else "cpu"

    # Build on the selected device; build_sam2 would otherwise default to CUDA.
    model = build_sam2(model_cfg, ckpt_path=base_checkpoint, device=device)

    # NOTE(review): torch.load unpickles the file - only load trusted checkpoints.
    checkpoint = torch.load(sam2_checkpoint, map_location=device)
    load_result = model.load_state_dict(checkpoint, strict=False)
    # strict=False silently drops keys the model does not have; surface that,
    # because dropped weights mean the checkpoint is not actually evaluated.
    unexpected = list(load_result.unexpected_keys)
    if unexpected:
        warnings.warn(
            f"{len(unexpected)} checkpoint key(s) did not match any model "
            f"parameter and were ignored (e.g. {unexpected[0]!r}); "
            "the evaluated model does not contain these weights."
        )
    model = model.to(device)

    # Wrap the model in the image predictor
    predictor = SAM2ImagePredictor(model)

    # Test datasets to evaluate
    test_datasets = {
        'Lucchi+': load_lucchi_dataset(),
        #  'VNC': load_vnc_dataset()
    }

    # Evaluate on every test dataset
    results = {}
    for dataset_name, dataset in test_datasets.items():
        print(f"\nEvaluating on {dataset_name} dataset...")

        # Directory for this dataset's predictions
        save_dir = os.path.join("results", dataset_name)
        os.makedirs(save_dir, exist_ok=True)

        mean_iou, std_iou = evaluate_model(
            predictor=predictor,
            test_dataset=dataset,
            save_dir=save_dir
        )

        results[dataset_name] = (mean_iou, std_iou)
        print(f"{dataset_name} Results: IoU = {mean_iou:.3f}±{std_iou:.3f}")

    # Summary
    print("\nFinal Results:")
    print("-" * 50)
    for dataset_name, (mean_iou, std_iou) in results.items():
        print(f"{dataset_name}: {mean_iou:.3f}±{std_iou:.3f}")

    return results

# Entry point: run the full evaluation when this code is executed directly.
if __name__ == "__main__":
    main()



Evaluating on Lucchi+ dataset...


  x = F.scaled_dot_product_attention(
  x = F.scaled_dot_product_attention(
  x = F.scaled_dot_product_attention(
  x = F.scaled_dot_product_attention(
  x = F.scaled_dot_product_attention(
  x = F.scaled_dot_product_attention(
  0%|          | 0/165 [00:00<?, ?it/s]


RuntimeError: No available kernel. Aborting execution.

In [7]:
import torch

# Inspect which parameter names the checkpoint stores. Mapping to the CPU is
# sufficient for listing keys and also works on machines without a GPU.
checkpoint = torch.load("sam2_lora_checkpoint_3000.pth", map_location="cpu")
print(checkpoint.keys())

dict_keys(['sam_mask_decoder.transformer.layers.0.self_attn.q_proj.lora_down.weight', 'sam_mask_decoder.transformer.layers.0.self_attn.q_proj.lora_up.weight', 'sam_mask_decoder.transformer.layers.0.self_attn.k_proj.lora_down.weight', 'sam_mask_decoder.transformer.layers.0.self_attn.k_proj.lora_up.weight', 'sam_mask_decoder.transformer.layers.0.self_attn.v_proj.lora_down.weight', 'sam_mask_decoder.transformer.layers.0.self_attn.v_proj.lora_up.weight', 'sam_mask_decoder.transformer.layers.0.self_attn.out_proj.lora_down.weight', 'sam_mask_decoder.transformer.layers.0.self_attn.out_proj.lora_up.weight', 'sam_mask_decoder.transformer.layers.0.cross_attn_token_to_image.q_proj.lora_down.weight', 'sam_mask_decoder.transformer.layers.0.cross_attn_token_to_image.q_proj.lora_up.weight', 'sam_mask_decoder.transformer.layers.0.cross_attn_token_to_image.k_proj.lora_down.weight', 'sam_mask_decoder.transformer.layers.0.cross_attn_token_to_image.k_proj.lora_up.weight', 'sam_mask_decoder.transformer.lay