In [8]:
import os
import pandas as pd
import cv2
import torch
import torch.nn.utils
import numpy as np
import matplotlib.pyplot as plt
import matplotlib.colors as mcolors
from sklearn.model_selection import train_test_split
from sam2.build_sam import build_sam2
from sam2.sam2_image_predictor import SAM2ImagePredictor
from peft import LoraConfig, get_peft_model

# ... [the earlier data-loading and preprocessing code is unchanged] ...

# Root folder of the Lucchi++ dataset. Built with os.path.join so the path
# contains no "\L" escape sequence and is valid on every operating system.
data_dir = os.path.join("dataset", "Lucchi++")
train_images_dir = os.path.join(data_dir, "Train_In")
train_masks_dir = os.path.join(data_dir, "Train_Out")
test_images_dir = os.path.join(data_dir, "Test_In")
test_masks_dir = os.path.join(data_dir, "Test_Out")


def _pair_images_with_masks(images_dir, masks_dir):
    """Pair every image in ``images_dir`` with the mask ``<index>.png``.

    The directory listing is sorted first because ``os.listdir`` returns
    entries in arbitrary order; without sorting, the running index could be
    attached to the wrong image.
    NOTE(review): assumes the sorted image names follow the same order as the
    numeric mask indices (0.png, 1.png, ...) - confirm against the dataset.

    Args:
        images_dir: folder containing the input images.
        masks_dir: folder containing the masks named ``0.png``, ``1.png``, ...

    Returns:
        A list of ``{"image": path, "annotation": path}`` dictionaries.
    """
    return [
        {
            "image": os.path.join(images_dir, image_file),
            "annotation": os.path.join(masks_dir, f"{index}.png"),
        }
        for index, image_file in enumerate(sorted(os.listdir(images_dir)))
    ]


# Image / mask path pairs for the training and the test split.
train_data = _pair_images_with_masks(train_images_dir, train_masks_dir)
test_data = _pair_images_with_masks(test_images_dir, test_masks_dir)
print(train_data)

[{'image': 'dataset\\Lucchi++\\Train_In\\mask0000.png', 'annotation': 'dataset\\Lucchi++\\Train_Out\\0.png'}, {'image': 'dataset\\Lucchi++\\Train_In\\mask0001.png', 'annotation': 'dataset\\Lucchi++\\Train_Out\\1.png'}, {'image': 'dataset\\Lucchi++\\Train_In\\mask0002.png', 'annotation': 'dataset\\Lucchi++\\Train_Out\\2.png'}, {'image': 'dataset\\Lucchi++\\Train_In\\mask0003.png', 'annotation': 'dataset\\Lucchi++\\Train_Out\\3.png'}, {'image': 'dataset\\Lucchi++\\Train_In\\mask0004.png', 'annotation': 'dataset\\Lucchi++\\Train_Out\\4.png'}, {'image': 'dataset\\Lucchi++\\Train_In\\mask0005.png', 'annotation': 'dataset\\Lucchi++\\Train_Out\\5.png'}, {'image': 'dataset\\Lucchi++\\Train_In\\mask0006.png', 'annotation': 'dataset\\Lucchi++\\Train_Out\\6.png'}, {'image': 'dataset\\Lucchi++\\Train_In\\mask0007.png', 'annotation': 'dataset\\Lucchi++\\Train_Out\\7.png'}, {'image': 'dataset\\Lucchi++\\Train_In\\mask0008.png', 'annotation': 'dataset\\Lucchi++\\Train_Out\\8.png'}, {'image': 'dataset

In [9]:
import cv2
import numpy as np
import matplotlib.pyplot as plt
import matplotlib.colors as mcolors
from skimage import measure

def read_batch(data, visualize_data=False):
    """Load one random sample and derive point prompts from its mask.

    Args:
        data: sequence of dicts holding the file paths under the keys
            "image" and "annotation".
        visualize_data: when True, plot the image, the binary mask and the
            sampled points with matplotlib.

    Returns:
        A tuple ``(Img, binary_mask, points, num_labels)``:
        * Img: RGB image, rescaled (aspect ratio kept) by the factor
          ``min(1024 / width, 1024 / height)``.
        * binary_mask: uint8 array of shape (1, H, W), 1 where the annotation
          is non-zero.
        * points: one random ``[x, y]`` per connected region of the eroded
          mask, shape (N, 1, 2); an empty array when no region survives the
          erosion.
        * num_labels: number of distinct non-zero values in the annotation.
          This is NOT necessarily the number of points.
        ``(None, None, None, 0)`` is returned when a file cannot be read.
    """
    # Pick a random entry.
    ent = data[np.random.randint(len(data))]

    # cv2.imread returns None for a missing / undecodable file, so the check
    # has to happen BEFORE the BGR -> RGB slice (slicing None raises TypeError).
    raw_image = cv2.imread(ent["image"])
    ann_map = cv2.imread(ent["annotation"], cv2.IMREAD_GRAYSCALE)  # grayscale annotation

    if raw_image is None or ann_map is None:
        print(f"錯誤：無法從路徑 {ent['image']} 或 {ent['annotation']} 讀取圖像或遮罩")
        return None, None, None, 0

    Img = raw_image[..., ::-1]  # BGR -> RGB

    # Resize image and mask with the same scale factor.
    r = min(1024 / Img.shape[1], 1024 / Img.shape[0])
    Img = cv2.resize(Img, (int(Img.shape[1] * r), int(Img.shape[0] * r)))
    ann_map = cv2.resize(
        ann_map,
        (int(ann_map.shape[1] * r), int(ann_map.shape[0] * r)),
        interpolation=cv2.INTER_NEAREST,  # nearest neighbour keeps label values intact
    )

    # Foreground labels are the non-zero values. Filtering on "!= 0" (instead
    # of dropping the first unique value) stays correct for masks that contain
    # no background pixel at all.
    label_values = np.unique(ann_map)
    inds = label_values[label_values != 0]

    # Union of all foreground labels as a single 0/1 mask.
    binary_mask = (ann_map != 0).astype(np.uint8)

    # Erode so that sampled points do not fall on region borders.
    eroded_mask = cv2.erode(binary_mask, np.ones((5, 5), np.uint8), iterations=1)

    # Connected-component analysis: one region per isolated foreground blob.
    labels = measure.label(eroded_mask)
    regions = measure.regionprops(labels)

    points = []
    for region in regions:
        # region.coords rows are (row, col); store the prompt as [x, y].
        y, x = region.coords[np.random.randint(len(region.coords))]
        points.append([x, y])

    points = np.array(points)

    if visualize_data:
        # Plotting the images and points
        plt.figure(figsize=(15, 5))

        # Original Image
        plt.subplot(1, 3, 1)
        plt.title('Original Image')
        plt.imshow(Img)
        plt.axis('off')

        # Segmentation Mask (binary_mask)
        plt.subplot(1, 3, 2)
        plt.title('Binarized Mask')
        plt.imshow(binary_mask, cmap='gray')
        plt.axis('off')

        # Mask with Points in Different Colors
        plt.subplot(1, 3, 3)
        plt.title('Binarized Mask with Points')
        plt.imshow(binary_mask, cmap='gray')

        # Each point gets its own colour; points are stored as [x, y].
        colors = list(mcolors.TABLEAU_COLORS.values())
        for i, point in enumerate(points):
            plt.scatter(point[0], point[1], c=colors[i % len(colors)], s=100, label=f'Point {i+1}')

        # plt.legend()
        plt.axis('off')

        plt.tight_layout()
        plt.show()

    # (H, W) -> (H, W, 1) -> (1, H, W)
    binary_mask = np.expand_dims(binary_mask, axis=-1)
    binary_mask = binary_mask.transpose((2, 0, 1))
    # (N, 2) -> (N, 1, 2): one single-point prompt per region.
    points = np.expand_dims(points, axis=1)

    return Img, binary_mask, points, len(inds)


In [10]:
import math
import torch
import torch.nn as nn
from torch.nn import functional as F

class LoRALinear(nn.Module):
    """A frozen ``nn.Linear`` plus a trainable low-rank (LoRA) residual branch.

    The forward pass computes
    ``linear(x) + lora_up(lora_down(x)) * scaling``.
    ``lora_up`` starts at zero, so a freshly wrapped layer reproduces the
    output of the original layer.

    Args:
        linear_layer: the ``nn.Linear`` to wrap; its weight and bias are frozen.
        rank: inner dimension of the low-rank branch.
        scaling: factor applied to the low-rank branch output.
    """

    def __init__(self, linear_layer, rank=4, scaling=1.0):
        super().__init__()
        self.in_features, self.out_features = (
            linear_layer.in_features,
            linear_layer.out_features,
        )

        # Keep the wrapped layer as a sub-module.
        self.linear = linear_layer

        # Low-rank pair: in_features -> rank -> out_features, both without bias.
        self.lora_down = nn.Linear(self.in_features, rank, bias=False)
        self.lora_up = nn.Linear(rank, self.out_features, bias=False)
        self.scaling = scaling

        # Down projection gets a Kaiming-uniform init, up projection is zeroed.
        nn.init.kaiming_uniform_(self.lora_down.weight, a=math.sqrt(5))
        nn.init.zeros_(self.lora_up.weight)

        # Freeze the wrapped layer (the bias only if the layer has one).
        frozen = [self.linear.weight]
        if self.linear.bias is not None:
            frozen.append(self.linear.bias)
        for parameter in frozen:
            parameter.requires_grad_(False)

    def forward(self, x):
        """Return the wrapped layer's output plus the scaled LoRA update."""
        base = self.linear(x)
        update = self.lora_up(self.lora_down(x)) * self.scaling
        return base + update

def add_lora_to_model(model, rank=4, scaling=1.0, device="cuda"):
    """Wrap the linear layers of SAM2's prompt encoder / mask decoder with LoRA.

    Every ``nn.Linear`` whose qualified name contains ``sam_prompt_encoder`` or
    ``sam_mask_decoder`` is replaced in place by a ``LoRALinear`` wrapping it.
    Linear layers that already live inside a ``LoRALinear`` (its ``linear``,
    ``lora_down`` and ``lora_up`` children) are skipped, so calling the
    function twice does not stack adapters on top of adapters.

    Args:
        model: the ``nn.Module`` to modify in place.
        rank: LoRA rank passed to ``LoRALinear``.
        scaling: LoRA scaling passed to ``LoRALinear``.
        device: device the new ``LoRALinear`` modules are moved to.

    Returns:
        List of ``(qualified_name, lora_layer)`` tuples for the layers wrapped
        by THIS call (empty when nothing new was wrapped).
    """
    target_keys = ('sam_prompt_encoder', 'sam_mask_decoder')

    # Snapshot the module tree first: the loop below replaces sub-modules, and
    # mutating the tree while the named_modules() generator is live is fragile.
    named_modules = list(model.named_modules())

    # Name prefixes of adapters that already exist; their children are off limits.
    wrapped_prefixes = tuple(
        f"{name}." for name, module in named_modules if isinstance(module, LoRALinear)
    )

    modified_layers = []
    for name, module in named_modules:
        if not isinstance(module, nn.Linear):
            continue
        if not any(key in name for key in target_keys):
            continue
        if name.startswith(wrapped_prefixes):
            continue

        # "a.b.c" -> parent "a.b", attribute "c" ("" parent means the model itself).
        parent_name, _, child_name = name.rpartition('.')
        parent_module = model.get_submodule(parent_name) if parent_name else model

        # Create the LoRA wrapper and move it to the requested device.
        lora_layer = LoRALinear(module, rank=rank, scaling=scaling).to(device)
        setattr(parent_module, child_name, lora_layer)
        modified_layers.append((name, lora_layer))

    return modified_layers

def train_with_lora(predictor, train_data, num_steps=3000, device="cuda"):
    """Fine-tune SAM2's prompt encoder and mask decoder through LoRA adapters.

    Each iteration draws one random sample with ``read_batch``, prompts the
    model once per sampled point and accumulates the gradients of
    ``segmentation_loss + 0.05 * score_loss``. Only the LoRA matrices inserted
    by ``add_lora_to_model`` are optimised. Every 500 iterations the LoRA
    weights are written to ``sam2_lora_checkpoint_<step>.pth``.

    Args:
        predictor: expected to be a ``SAM2ImagePredictor``; the function uses
            its ``model``, ``set_image`` and the private ``_prep_prompts``,
            ``_features``, ``_transforms`` and ``_orig_hw`` members.
        train_data: list of ``{"image": ..., "annotation": ...}`` dicts that is
            handed to ``read_batch``.
        num_steps: number of training iterations.
        device: device used for the model and the tensors.

    Returns:
        The same ``predictor``, now carrying the trained LoRA layers.

    Raises:
        ValueError: if no layer could be wrapped with LoRA.
    """
    # Make sure the model lives on the requested device.
    predictor.model = predictor.model.to(device)

    # Insert the LoRA layers.
    modified_layers = add_lora_to_model(predictor.model, rank=4, scaling=1.0, device=device)

    if not modified_layers:
        raise ValueError("No layers were modified with LoRA!")

    # Only the LoRA matrices are trained.
    trainable_params = []
    for _, layer in modified_layers:
        trainable_params.extend([
            layer.lora_down.weight,
            layer.lora_up.weight
        ])

    if not trainable_params:
        raise ValueError("No trainable parameters found!")

    optimizer = torch.optim.AdamW(
        trainable_params,
        lr=1e-4,
        weight_decay=1e-4
    )

    # The schedule is counted in training iterations (not optimizer updates).
    scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=500, gamma=0.2)

    # Mixed precision only makes sense on CUDA; on other devices both the
    # autocast context and the scaler are switched off explicitly.
    use_amp = str(device).startswith("cuda")
    scaler = torch.cuda.amp.GradScaler(enabled=use_amp)
    accumulation_steps = 4
    # Number of samples whose gradients are waiting for an optimizer update.
    # Counting processed samples (instead of testing `step % accumulation_steps`)
    # keeps the accumulation window intact when samples are skipped.
    pending_batches = 0

    def apply_accumulated_gradients():
        # Gradients still carry the loss scale; unscale them first so that
        # max_norm is compared against the true gradient norm.
        scaler.unscale_(optimizer)
        torch.nn.utils.clip_grad_norm_(trainable_params, max_norm=1.0)
        scaler.step(optimizer)
        scaler.update()
        optimizer.zero_grad()

    print(f"Training {len(trainable_params)} LoRA parameters")
    mean_iou = 0
    optimizer.zero_grad()

    for step in range(1, num_steps + 1):
        image, mask, input_point, num_masks = read_batch(train_data, visualize_data=False)
        if image is None or mask is None or num_masks == 0:
            continue
        if not isinstance(input_point, np.ndarray) or input_point.size == 0:
            continue

        # One positive label per point prompt. read_batch's fourth return value
        # counts label values, not points, so it must not size this array.
        input_label = np.ones((input_point.shape[0], 1))
        gt_mask = torch.tensor(mask.astype(np.float32), device=device)

        # Only the forward computations run under autocast.
        with torch.cuda.amp.autocast(enabled=use_amp):
            predictor.set_image(image)
            mask_input, ucc, labels, unnorm_box = predictor._prep_prompts(
                input_point, input_label, box=None, mask_logits=None, normalize_coords=True
            )

        if ucc is None or labels is None or ucc.shape[0] == 0 or labels.shape[0] == 0:
            continue

        labels = labels.to(device)
        # Image features are identical for every prompt of this image.
        image_embed = predictor._features["image_embed"][-1].unsqueeze(0).to(device)
        high_res_features = [
            feat_level[-1].unsqueeze(0).to(device)
            for feat_level in predictor._features["high_res_feats"]
        ]

        for i in range(ucc.shape[0]):
            with torch.cuda.amp.autocast(enabled=use_amp):
                # Single-point prompt i and its label.
                uc = ucc[i:i+1, :, :].to(device)
                point_label = labels[i:i+1]

                sparse_embeddings, dense_embeddings = predictor.model.sam_prompt_encoder(
                    points=(uc, point_label), boxes=None, masks=None,
                )

                batched_mode = uc.shape[0] > 1

                low_res_masks, prd_scores, _, _ = predictor.model.sam_mask_decoder(
                    image_embeddings=image_embed,
                    image_pe=predictor.model.sam_prompt_encoder.get_dense_pe().to(device),
                    sparse_prompt_embeddings=sparse_embeddings,
                    dense_prompt_embeddings=dense_embeddings,
                    multimask_output=True,
                    repeat_image=batched_mode,
                    high_res_features=high_res_features,
                )

                prd_masks = predictor._transforms.postprocess_masks(low_res_masks, predictor._orig_hw[-1])
                prd_mask = torch.sigmoid(prd_masks[:, 0])

                # Binary cross-entropy between the first predicted mask and the ground truth.
                seg_loss = (-gt_mask * torch.log(prd_mask + 1e-6) -
                            (1 - gt_mask) * torch.log(1 - prd_mask + 1e-6)).mean()

                # IoU of the thresholded prediction; target for the predicted score.
                inter = (gt_mask * (prd_mask > 0.5)).sum(1).sum(1)
                union = gt_mask.sum(1).sum(1) + (prd_mask > 0.5).sum(1).sum(1) - inter
                iou = inter / union
                score_loss = torch.abs(prd_scores[:, 0] - iou).mean()

                loss = seg_loss + score_loss * 0.05
                loss = loss / accumulation_steps

            # Backward runs outside autocast, as recommended for AMP.
            scaler.scale(loss).backward()

        pending_batches += 1
        if pending_batches == accumulation_steps:
            apply_accumulated_gradients()
            pending_batches = 0

        scheduler.step()

        if step % 500 == 0:
            # Store plain CPU tensors so the checkpoint loads on any machine.
            lora_state = {}
            for name, layer in modified_layers:
                lora_state[f"{name}.lora_down.weight"] = layer.lora_down.weight.detach().cpu()
                lora_state[f"{name}.lora_up.weight"] = layer.lora_up.weight.detach().cpu()
            torch.save(lora_state, f"sam2_lora_checkpoint_{step}.pth")

        # Exponential moving average of the IoU of the last prompt of the sample.
        mean_iou = mean_iou * 0.99 + 0.01 * torch.mean(iou).item()
        if step % 100 == 0:
            print(f"Step {step}:\tAccuracy (IoU) = {mean_iou:.4f}")

    # Apply gradients of an incomplete accumulation window instead of dropping them.
    if pending_batches:
        apply_accumulated_gradients()

    return predictor

In [None]:
# Use the GPU when PyTorch can see one, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"

# Checkpoint file and model configuration handed to build_sam2.
sam2_checkpoint = "./checkpoints/sam2.1_hiera_large.pt"
model_cfg = "configs/sam2.1/sam2.1_hiera_l.yaml"

# Build the network first, then wrap it in the image predictor.
sam2_model = build_sam2(model_cfg, sam2_checkpoint, device=device)
predictor = SAM2ImagePredictor(sam2_model)

# Run the LoRA fine-tuning; the returned predictor replaces the fresh one.
predictor = train_with_lora(predictor, train_data, device=device)

Training 100 LoRA parameters


  x = F.scaled_dot_product_attention(
  out = F.scaled_dot_product_attention(q, k, v, dropout_p=dropout_p)
  out = F.scaled_dot_product_attention(q, k, v, dropout_p=dropout_p)
  out = F.scaled_dot_product_attention(q, k, v, dropout_p=dropout_p)
  out = F.scaled_dot_product_attention(q, k, v, dropout_p=dropout_p)
  out = F.scaled_dot_product_attention(q, k, v, dropout_p=dropout_p)
Falling back to all available kernels for scaled_dot_product_attention (which may have a slower speed).
  return forward_call(*args, **kwargs)


In [5]:
def read_image(image_path, mask_path):
    """Read an image and its mask and rescale both with the same factor.

    The scale factor is ``min(1024 / width, 1024 / height)`` of the image, so
    the aspect ratio is preserved.

    Args:
        image_path: path of the image file.
        mask_path: path of the mask file (read as single-channel grayscale).

    Returns:
        ``(img, mask)``: the RGB image and the mask, both resized; the mask is
        resized with nearest-neighbour interpolation so its values stay intact.

    Raises:
        FileNotFoundError: if either file is missing or cannot be decoded.
    """
    # cv2.imread signals failure by returning None, so check before slicing.
    bgr = cv2.imread(image_path)
    if bgr is None:
        raise FileNotFoundError(f"Could not read image: {image_path}")
    mask = cv2.imread(mask_path, 0)
    if mask is None:
        raise FileNotFoundError(f"Could not read mask: {mask_path}")

    img = bgr[..., ::-1]  # Convert BGR to RGB
    r = np.min([1024 / img.shape[1], 1024 / img.shape[0]])
    img = cv2.resize(img, (int(img.shape[1] * r), int(img.shape[0] * r)))
    mask = cv2.resize(mask, (int(mask.shape[1] * r), int(mask.shape[0] * r)), interpolation=cv2.INTER_NEAREST)
    return img, mask

def get_points(mask, num_points):
    """Sample random point prompts from the positive pixels of ``mask``.

    Args:
        mask: array whose entries ``> 0`` mark the foreground.
        num_points: number of points to draw (with replacement).

    Returns:
        Integer array of shape ``(num_points, 1, 2)`` holding ``[x, y]``
        coordinates. An empty array of shape ``(0, 1, 2)`` is returned when
        the mask has no positive pixel or ``num_points`` is not positive.
    """
    coords = np.argwhere(mask > 0)  # rows of (row, col, ...) indices

    # Nothing to sample from / nothing requested: keep the output shape
    # consistent instead of failing inside np.random.randint.
    if num_points <= 0 or len(coords) == 0:
        return np.empty((0, 1, 2), dtype=coords.dtype)

    points = []
    for _ in range(num_points):
        row_col = coords[np.random.randint(len(coords))]
        points.append([[row_col[1], row_col[0]]])  # stored as [x, y]
    return np.array(points)

In [6]:
def generate_auto_prompts(image, strategy="combined"):
    """Automatically generate prompt points for an image.

    Args:
        image: input image; a 3-dimensional array is converted from RGB to
            grayscale, any other array is used as the grayscale image as is.
        strategy: prompt generation strategy, one of "threshold" (centroids of
            the external contours of an Otsu-thresholded image), "gradient"
            (up to 5 random pixels among the 5% strongest Sobel gradients) or
            "combined" (both).

    Returns:
        Integer array of shape (N, 2) with ``[x, y]`` prompt coordinates;
        shape (0, 2) when no point was found.

    Raises:
        ValueError: if ``strategy`` is not one of the supported names.
    """
    valid_strategies = ("threshold", "gradient", "combined")
    if strategy not in valid_strategies:
        # A misspelt strategy used to produce an empty prompt list silently.
        raise ValueError(
            f"Unknown strategy {strategy!r}; expected one of {valid_strategies}"
        )

    # Convert to grayscale.
    if len(image.shape) == 3:
        gray = cv2.cvtColor(image, cv2.COLOR_RGB2GRAY)
    else:
        gray = image.copy()

    points = []

    if strategy == "threshold" or strategy == "combined":
        # 1. Threshold-based prompts.
        _, binary = cv2.threshold(gray, 0, 255, cv2.THRESH_BINARY + cv2.THRESH_OTSU)
        contours, _ = cv2.findContours(binary, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)

        for contour in contours:
            # Centroid of the contour from its moments; skip zero-area contours.
            M = cv2.moments(contour)
            if M["m00"] != 0:
                cx = int(M["m10"] / M["m00"])
                cy = int(M["m01"] / M["m00"])
                points.append([cx, cy])

    if strategy == "gradient" or strategy == "combined":
        # 2. Gradient-based prompts (Sobel edge strength).
        sobelx = cv2.Sobel(gray, cv2.CV_64F, 1, 0, ksize=3)
        sobely = cv2.Sobel(gray, cv2.CV_64F, 0, 1, ksize=3)
        gradient_magnitude = np.sqrt(sobelx**2 + sobely**2)

        # Keep the pixels above the 95th percentile of the gradient magnitude.
        threshold = np.percentile(gradient_magnitude, 95)
        strong_gradient = gradient_magnitude > threshold
        gradient_points = np.where(strong_gradient)

        # Randomly pick up to 5 of the strong-gradient pixels.
        if len(gradient_points[0]) > 0:
            indices = np.random.choice(len(gradient_points[0]),
                                       min(5, len(gradient_points[0])),
                                       replace=False)
            for idx in indices:
                # np.where yields (rows, cols); store the point as [x, y].
                points.append([gradient_points[1][idx], gradient_points[0][idx]])

    # reshape keeps the documented (N, 2) layout even for an empty result.
    return np.array(points, dtype=np.int64).reshape(-1, 2)

def sam2_inference_pipeline(image, sam_model):
    """Run SAM2 inference on an image with automatically generated prompts.

    Args:
        image: input image, forwarded to ``generate_auto_prompts`` and to
            ``sam_model.set_image`` when that method exists.
        sam_model: loaded SAM2 predictor exposing ``predict(point_coords=...,
            point_labels=...)``. ``SAM2ImagePredictor.predict`` returns a
            ``(masks, scores, low_res_logits)`` tuple; such a tuple is handled
            explicitly. Any other return value is treated as an iterable of
            masks.

    Returns:
        The merged, cleaned mask produced by ``post_process_masks``, or None
        when no prompt point / mask was produced.
    """
    # 1. Generate prompt points automatically.
    prompt_points = generate_auto_prompts(image)

    masks = []
    labels = np.ones(len(prompt_points))  # every point is treated as foreground

    # The predictor needs the image embedding before it can answer prompts.
    if len(prompt_points) > 0 and hasattr(sam_model, "set_image"):
        sam_model.set_image(image)

    # 2. Predict in groups of 5 points.
    for start in range(0, len(prompt_points), 5):
        batch_points = prompt_points[start:start + 5]
        batch_labels = labels[start:start + 5]

        prediction = sam_model.predict(
            point_coords=batch_points,
            point_labels=batch_labels,
        )

        if isinstance(prediction, tuple):
            # (masks, scores, logits): keep only the best-scoring candidate.
            # Extending the list with the raw tuple would mix masks, scores
            # and logits of different shapes.
            masks_batch, scores = prediction[0], prediction[1]
            masks.append(masks_batch[int(np.argmax(scores))])
        else:
            masks.extend(prediction)

    # 3. Merge and clean the collected masks.
    final_mask = post_process_masks(masks)

    return final_mask

def post_process_masks(masks):
    """Merge several masks into one and clean the result morphologically.

    Args:
        masks: sequence of mask arrays; an empty sequence yields None.

    Returns:
        uint8 array: the logical OR of all masks after a 3x3 morphological
        opening followed by a 3x3 closing, or None for empty input.
    """
    if not masks:
        return None

    # Start from an all-zero array shaped like the first mask and OR every
    # mask into it.
    merged = np.zeros_like(masks[0])
    for single_mask in masks:
        merged = np.logical_or(merged, single_mask)

    # Opening first, closing second, both with the same 3x3 kernel.
    structuring_element = np.ones((3, 3), np.uint8)
    opened = cv2.morphologyEx(merged.astype(np.uint8), cv2.MORPH_OPEN, structuring_element)
    return cv2.morphologyEx(opened, cv2.MORPH_CLOSE, structuring_element)

# Example usage
def process_medical_image(image_path, mask_path, sam_model):
    """Segment one image with auto-generated prompts and plot the result.

    Args:
        image_path: path of the image, read through ``read_image``.
        mask_path: path of the ground-truth mask, read through ``read_image``.
        sam_model: model object handed to ``sam2_inference_pipeline``.

    Returns:
        The mask returned by ``sam2_inference_pipeline`` (may be None, in
        which case nothing is plotted).
    """
    image, mask = read_image(image_path, mask_path)

    # Prompt generation and SAM2 inference happen inside the pipeline.
    final_mask = sam2_inference_pipeline(image, sam_model)

    if final_mask is None:
        return final_mask

    # Four panels side by side: subplot position, title, picture, imshow options.
    plt.figure(figsize=(15, 5))
    panels = (
        (141, 'Original Image', image, {}),
        (142, 'Ground Truth', mask, {}),
        (143, 'Segmentation Mask', final_mask, {'cmap': 'gray'}),
        (144, 'Overlay', image, {}),
    )
    for position, title, picture, options in panels:
        plt.subplot(position)
        plt.imshow(picture, **options)
        if position == 144:
            # The overlay panel draws the predicted mask on top of the image.
            plt.imshow(final_mask, alpha=0.4, cmap='jet')
        plt.title(title)

    plt.show()

    return final_mask

NameError: name 'predictor' is not defined