In [None]:
import os
import random
import math
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
import cv2
import logging
from segment_anything import sam_model_registry
from tqdm import tqdm
import torch.optim.lr_scheduler as lr_scheduler
from lora import LoRA_sam
from types import MethodType

# Seed every random number generator the script relies on
def set_seed(seed):
    """Seed the Python, NumPy and PyTorch RNGs and make cuDNN deterministic.

    Args:
        seed: Integer seed handed to each generator.
    """
    for seed_fn in (random.seed, np.random.seed, torch.manual_seed, torch.cuda.manual_seed):
        seed_fn(seed)

    # Disable cuDNN autotuning and request deterministic kernels.
    cudnn = torch.backends.cudnn
    cudnn.benchmark = False
    cudnn.deterministic = True

# Configuration dictionary: every tunable of the run lives here.
CONFIG = {
    'dataset_name': 'beijing',  # only the dataset name needs to be changed per run
    'data_base_dir': 'datasets',  # base directory that holds all datasets
    'log_dir_base': 'logs',        # base directory for log files
    'log_file': 'best_model_metrics.log',
    'save_dir_base': 'logs',       # base directory for saved models
    'save_prefix': 'best_model',
    'model_type': 'vit_l',
    'checkpoint': 'weights/sam_vit_l_0b3195.pth',
    'device': torch.device("cuda" if torch.cuda.is_available() else "cpu"),
    'num_epochs': 100,
    'learning_rate': 1e-4,
    'betas': (0.9, 0.999),
    'weight_decay': 1e-4,
    # Weights used to blend the validation metrics into one composite score.
    'metric_weights': {
        'iou': 0.25,
        'f1': 0.25,
        'precision': 0.25,
        'recall': 0.25
    },
    'aux_weight': 0.4,
    'num_classes': 1,
    'rank': 512,
    'batch_size': 1
}

# Fix the random seed (42 is an arbitrary choice)
set_seed(42)

# Derive every dataset path from data_base_dir / dataset_name
image_dir, mask_dir, train_txt, val_txt = (
    os.path.join(CONFIG['data_base_dir'], CONFIG['dataset_name'], leaf)
    for leaf in ('images', 'masks', 'train.txt', 'val.txt')
)

# forward variant that also returns the output of every intermediate block
def forward_inter(self, x: torch.Tensor) -> "tuple[torch.Tensor, list[torch.Tensor]]":
    """Run the encoder and additionally collect the output of each block.

    Meant to be bound as a method onto an object exposing ``patch_embed``,
    ``pos_embed``, ``blocks`` and ``neck``.

    Args:
        x: Input handed to ``self.patch_embed``.

    Returns:
        A ``(neck_output, inter_features)`` tuple. ``inter_features`` holds
        the output of every entry of ``self.blocks`` in order, exactly as the
        block produced it (no permutation applied).
    """
    x = self.patch_embed(x)
    if self.pos_embed is not None:
        x = x + self.pos_embed

    inter_features = []
    for blk in self.blocks:
        x = blk(x)
        inter_features.append(x)

    # The last block output is permuted with (0, 3, 1, 2) before the neck.
    x = self.neck(x.permute(0, 3, 1, 2))
    return x, inter_features

# Build the per-dataset log and checkpoint directories and create them if needed
log_dir = os.path.join(CONFIG['log_dir_base'], CONFIG['dataset_name'])
save_dir = os.path.join(CONFIG['save_dir_base'], CONFIG['dataset_name'])
os.makedirs(log_dir, exist_ok=True)
os.makedirs(save_dir, exist_ok=True)

# Configure logging: INFO and above goes to the log file inside log_dir
logging.basicConfig(
    filename=os.path.join(log_dir, CONFIG['log_file']),
    level=logging.INFO,
    format='%(asctime)s - %(levelname)s - %(message)s'
)

# Load the SAM model and wrap it with LoRA
sam_model = sam_model_registry[CONFIG['model_type']](checkpoint=CONFIG['checkpoint'])
# Bind forward_inter onto the image encoder so intermediate block outputs can be requested
sam_model.image_encoder.forward_inter = MethodType(forward_inter, sam_model.image_encoder)
sam_model.to(CONFIG['device'])

lora_sam_model = LoRA_sam(sam_model, rank=CONFIG['rank'])
lora_sam_model.to(CONFIG['device'])

# Read the train / validation split lists
def read_split_files(file_path):
    """Return the non-empty entries of a split file, one entry per line.

    Each line is stripped of surrounding whitespace and blank lines are
    skipped, so an empty file yields ``[]`` rather than ``['']`` and stray
    trailing spaces cannot prevent a name from matching later.

    Args:
        file_path: Path of the text file to read.

    Returns:
        List of entry names in file order.
    """
    with open(file_path, 'r') as f:
        stripped_lines = (line.strip() for line in f)
        return [name for name in stripped_lines if name]

# Dataset definition
class SegmentationDataset(Dataset):
    """Image / mask pairs for the ``.png`` files whose stem is listed in ``file_list``.

    Images are read with OpenCV, converted BGR -> RGB, resized to 1024x1024
    and passed through ``sam_model.preprocess``. The mask is loaded from
    ``mask_dir`` under the same file name and resized to ``mask_size``.
    """

    def __init__(self, image_dir, mask_dir, sam_model, file_list, mask_size=(1024, 1024), device='cpu'):
        """Collect the usable image files.

        Args:
            image_dir: Directory scanned for ``.png`` images.
            mask_dir: Directory holding masks with the same file names.
            sam_model: Object whose ``preprocess`` method is applied to each image.
            file_list: Iterable of file stems (names without ``.png``) to keep.
            mask_size: Size passed to ``cv2.resize`` for the mask.
            device: Device the returned tensors are moved to.
        """
        self.image_dir = image_dir
        self.mask_dir = mask_dir
        self.sam_model = sam_model
        self.mask_size = mask_size
        self.device = device

        suffix = '.png'
        wanted = set(file_list)  # O(1) membership instead of scanning a list per file
        # Sorted so that sample order does not depend on the filesystem's listing order.
        self.image_files = sorted(
            f for f in os.listdir(image_dir)
            if f.endswith(suffix) and f[:-len(suffix)] in wanted
        )

    def __len__(self):
        return len(self.image_files)

    def __getitem__(self, idx):
        """Return ``(input_image, mask)`` for sample ``idx``.

        Raises:
            FileNotFoundError: If OpenCV cannot read the image or the mask.
        """
        image_file = self.image_files[idx]
        image_path = os.path.join(self.image_dir, image_file)
        image = cv2.imread(image_path)
        if image is None:  # cv2.imread signals failure by returning None
            raise FileNotFoundError(f"Could not read image: {image_path}")
        image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
        # NOTE(review): nearest-neighbour resampling is kept from the original;
        # confirm it is intended for the RGB image and not only for the mask.
        image = cv2.resize(image, (1024, 1024), interpolation=cv2.INTER_NEAREST)

        # The mask shares the image's file name.
        mask_path = os.path.join(self.mask_dir, image_file)
        mask = cv2.imread(mask_path, cv2.IMREAD_UNCHANGED)
        if mask is None:
            raise FileNotFoundError(f"Could not read mask: {mask_path}")
        # Nearest-neighbour keeps mask values from being blended during resize.
        mask = cv2.resize(mask, self.mask_size, interpolation=cv2.INTER_NEAREST)

        input_image_torch = torch.as_tensor(image, dtype=torch.float32).to(self.device)
        input_image_torch = input_image_torch.permute(2, 0, 1).contiguous()  # [C, H, W]

        input_image = self.sam_model.preprocess(input_image_torch)

        # Mask values are passed through unchanged (no rescaling or binarisation).
        mask = torch.as_tensor(mask, dtype=torch.float32).to(self.device)

        return input_image, mask

# Read the split file lists
train_files = read_split_files(train_txt)
val_files = read_split_files(val_txt)

# Build the datasets and data loaders.
# Both datasets read from the same image/mask directories and differ only in file_list.
train_dataset = SegmentationDataset(
    image_dir=image_dir,
    mask_dir=mask_dir,
    sam_model=sam_model,
    file_list=train_files,
    device=CONFIG['device']
)

val_dataset = SegmentationDataset(
    image_dir=image_dir,
    mask_dir=mask_dir,
    sam_model=sam_model,
    file_list=val_files,
    device=CONFIG['device']
)

# Training samples are shuffled each epoch
train_loader = DataLoader(
    train_dataset,
    batch_size=CONFIG['batch_size'],
    shuffle=True
)

# Validation samples keep a fixed order
val_loader = DataLoader(
    val_dataset,
    batch_size=CONFIG['batch_size'],
    shuffle=False
)

# def process_class_logits(class_logits):
#     probs = torch.sigmoid(class_logits)
#     binary_masks = (probs > 0.5).cpu().numpy().astype(np.uint8)  # [B,1,1024,1024]
#     batch_results = []
#     for batch_idx in range(binary_masks.shape[0]):
#         current_mask = binary_masks[batch_idx, 0, :, :]
#         num_labels, labels = cv2.connectedComponents(current_mask)
#         sample_results = []
#         for label in range(1, num_labels):
#             current_component = (labels == label).astype(np.uint8)
#             y_coords, x_coords = np.nonzero(current_component)

#             min_x, max_x = np.min(x_coords), np.max(x_coords)
#             min_y, max_y = np.min(y_coords), np.max(y_coords)

#             mask = np.zeros_like(current_component, dtype=np.uint8)
#             mask[min_y:max_y+1, min_x:max_x+1] = current_component[min_y:max_y+1, min_x:max_x+1]

#             mask_resized = cv2.resize(mask, (256, 256), interpolation=cv2.INTER_NEAREST)

#             sample_results.append({
#                 'bbox': [min_x, min_y, max_x, max_y],
#                 'mask': mask_resized
#             })
#         batch_results.append(sample_results)
#     return batch_results

def process_class_logits(class_logits, expand_ratio=0.05, prompt_mask_size=(256, 256)):
    """Turn segmentation logits into per-component box and mask prompts.

    The logits are thresholded at probability 0.5, split into connected
    components with ``cv2.connectedComponents``, and every component yields
    one prompt.

    Args:
        class_logits: Logit tensor indexed as ``[B, 1, H, W]``.
        expand_ratio: Fraction of the box width / height added on each side
            of the tight bounding box.
        prompt_mask_size: Size passed to ``cv2.resize`` for the component mask.

    Returns:
        One list per batch element; each entry is a dict with
        ``'bbox'`` (``[x_min, y_min, x_max, y_max]``, clamped to the image)
        and ``'mask'`` (uint8 component mask resized to ``prompt_mask_size``).
    """
    # Only a thresholded copy is needed, so keep the sigmoid out of autograd.
    with torch.no_grad():
        binary_masks = (torch.sigmoid(class_logits) > 0.5).cpu().numpy().astype(np.uint8)

    # Clamp against the actual mask size instead of assuming 1024x1024.
    height, width = binary_masks.shape[-2:]
    batch_results = []

    for batch_idx in range(binary_masks.shape[0]):
        current_mask = binary_masks[batch_idx, 0, :, :]
        num_labels, labels = cv2.connectedComponents(current_mask)
        sample_results = []

        # Label 0 is skipped; iteration starts at the first component label.
        for label in range(1, num_labels):
            current_component = (labels == label).astype(np.uint8)
            y_coords, x_coords = np.nonzero(current_component)

            min_x, max_x = np.min(x_coords), np.max(x_coords)
            min_y, max_y = np.min(y_coords), np.max(y_coords)

            # Grow the bounding box by expand_ratio on every side.
            pad_x = expand_ratio * (max_x - min_x)
            pad_y = expand_ratio * (max_y - min_y)

            expanded_min_x = max(0, int(min_x - pad_x))
            expanded_max_x = min(width - 1, int(max_x + pad_x))
            expanded_min_y = max(0, int(min_y - pad_y))
            expanded_max_y = min(height - 1, int(max_y + pad_y))

            # The component is already zero outside its tight bounding box,
            # so it can be resized directly without copying it first.
            mask_resized = cv2.resize(current_component, prompt_mask_size, interpolation=cv2.INTER_NEAREST)

            sample_results.append({
                'bbox': [expanded_min_x, expanded_min_y, expanded_max_x, expanded_max_y],
                'mask': mask_resized
            })

        batch_results.append(sample_results)

    return batch_results

def predict_masks_batch(sam_model, image_embeddings, batch_results, device='cuda'):
    """Decode one merged mask per image from its box and mask prompts.

    Each prompt is encoded with ``sam_model.prompt_encoder``, all prompts of
    an image are decoded together by ``sam_model.mask_decoder``, the results
    are upsampled to 1024x1024 and merged with an element-wise maximum.
    Images without prompts get an all-zero tensor.

    The model is switched to eval mode for the duration of the call and the
    mode it had on entry is restored afterwards, so calling this from inside
    a training loop no longer leaves the model in eval mode.

    Args:
        sam_model: Module exposing ``prompt_encoder`` and ``mask_decoder``.
        image_embeddings: Batched image embeddings; sliced per image with
            ``[idx:idx+1]``.
        batch_results: Per-image prompt lists as produced by
            ``process_class_logits`` (dicts with ``'bbox'`` and ``'mask'``).
        device: Device for the prompt tensors and the zero placeholders.

    Returns:
        Tensor of shape ``[B, 1, 1024, 1024]``.
    """
    output_size = (1024, 1024)

    was_training = sam_model.training
    sam_model.eval()
    try:
        final_predictions = []
        for idx, sample_results in enumerate(batch_results):
            if not sample_results:
                # No component was found for this image: emit a zero tensor.
                final_predictions.append(torch.zeros((1, 1) + output_size, device=device))
                continue

            current_image_embedding = image_embeddings[idx:idx+1]

            sparse_embeddings_list = []
            dense_embeddings_list = []

            for mask_info in sample_results:
                box = torch.tensor(mask_info['bbox'], dtype=torch.float, device=device).unsqueeze(0)
                mask = torch.from_numpy(mask_info['mask']).float().to(device).unsqueeze(0).unsqueeze(0)

                sparse_embeddings, dense_embeddings = sam_model.prompt_encoder(
                    points=None,
                    boxes=box,
                    masks=mask
                )

                sparse_embeddings_list.append(sparse_embeddings)
                dense_embeddings_list.append(dense_embeddings)

            # sample_results is non-empty here, so both lists hold at least one entry.
            sparse_embeddings_all = torch.cat(sparse_embeddings_list, dim=0)
            dense_embeddings_all = torch.cat(dense_embeddings_list, dim=0)

            low_res_masks, _ = sam_model.mask_decoder(
                image_embeddings=current_image_embedding,
                image_pe=sam_model.prompt_encoder.get_dense_pe(),
                sparse_prompt_embeddings=sparse_embeddings_all,
                dense_prompt_embeddings=dense_embeddings_all,
                multimask_output=False,
            )

            resized_masks = F.interpolate(
                low_res_masks,
                size=output_size,
                mode='bilinear',
                align_corners=False
            )

            # Merge the per-prompt masks with an element-wise maximum.
            merged_mask = torch.max(resized_masks, dim=0)[0]  # [1,1024,1024]

            final_predictions.append(merged_mask.unsqueeze(0))  # [1,1,1024,1024]
    finally:
        sam_model.train(was_training)

    return torch.cat(final_predictions, dim=0)  # [B,1,1024,1024]

def dice_loss(preds, targets, smooth=1e-6):
    """Soft Dice loss computed per sample and averaged over the batch.

    Args:
        preds: Logits; a sigmoid is applied before the overlap is computed.
        targets: Ground-truth tensor with the same number of elements per
            sample as ``preds``.
        smooth: Constant added to numerator and denominator so that empty
            masks do not divide by zero.

    Returns:
        Scalar tensor ``mean(1 - dice)``.
    """
    batch_size = preds.size(0)
    # reshape (not view) so non-contiguous inputs such as permuted tensors work.
    probs = torch.sigmoid(preds).reshape(batch_size, -1)
    flat_targets = targets.reshape(targets.size(0), -1)

    intersection = (probs * flat_targets).sum(dim=1)
    # Sum of both areas (the Dice denominator), not a set union.
    total_area = probs.sum(dim=1) + flat_targets.sum(dim=1)

    dice = (2. * intersection + smooth) / (total_area + smooth)
    return (1 - dice).mean()

def compute_loss(
    seg_head_logits,
    final_predictions,
    masks,
    loss_fn,
    aux_classifiers=None,
    inter_features=None,
    selected_aux_layers=None,
    aux_weight=0.4,
    seg_weight=1.0,
    sam_weight=1.0,
    loss_weights=(1, 1)
):
    """Combine segmentation-head, SAM-prediction and auxiliary losses.

    Args:
        seg_head_logits: Logits of the segmentation head.
        final_predictions: Logits of the prompt-based prediction.
        masks: Ground-truth masks.
        loss_fn: Callable ``loss_fn(logits, masks)`` used next to the Dice loss.
        aux_classifiers: Optional iterable of auxiliary heads, paired
            one-to-one with ``selected_aux_layers``.
        inter_features: Optional list of intermediate features.
        selected_aux_layers: Optional 1-based layer numbers; entry ``k``
            reads ``inter_features[k - 1]``.
        aux_weight: Factor applied to the mean auxiliary loss.
        seg_weight: Factor applied to the segmentation-head loss.
        sam_weight: Factor applied to the SAM-prediction loss.
        loss_weights: ``(loss_fn weight, dice weight)`` used for both main losses.

    Returns:
        ``(total_loss, seg_main_loss, sam_main_loss, total_aux_loss)`` where
        the first item is a tensor and the other three are Python floats.
    """
    fn_weight, dice_weight = loss_weights[0], loss_weights[1]

    seg_bce = loss_fn(seg_head_logits, masks)
    seg_dice = dice_loss(seg_head_logits, masks)
    seg_main_loss = fn_weight * seg_bce + dice_weight * seg_dice

    sam_bce = loss_fn(final_predictions, masks)
    sam_dice = dice_loss(final_predictions, masks)
    sam_main_loss = fn_weight * sam_bce + dice_weight * sam_dice

    total_aux_loss = torch.tensor(0.0, device=seg_head_logits.device)
    if aux_classifiers is not None and inter_features is not None and selected_aux_layers is not None:
        aux_losses = []
        for idx, aux_cls in zip(selected_aux_layers, aux_classifiers):
            # NOTE(review): layer numbers are shifted by one here, while
            # SegmentationHead in this file indexes the same list directly
            # with the same numbers - confirm the offset is intended.
            feature_idx = idx - 1
            # Reject negative indices too: they would silently wrap around
            # and pick a feature from the end of the list.
            if 0 <= feature_idx < len(inter_features):
                aux_logits = aux_cls(inter_features[feature_idx])
                aux_losses.append(loss_fn(aux_logits, masks))
            else:
                logging.warning(f"inter_features does not have index {feature_idx}")

        if aux_losses:
            aux_loss_mean = torch.mean(torch.stack(aux_losses))
            total_aux_loss = aux_loss_mean * aux_weight

    total_loss = seg_weight * seg_main_loss + sam_weight * sam_main_loss + total_aux_loss

    return (
        total_loss,
        seg_main_loss.item(),
        sam_main_loss.item(),
        total_aux_loss.item()
    )

def initialize_metrics():
    """Return a fresh accumulator with every pixel counter set to zero."""
    counter_names = ('tp', 'fp', 'fn', 'intersection', 'union')
    return dict.fromkeys(counter_names, 0)

def accumulate_metrics(preds, targets, global_metrics, threshold=0.5):
    """Add the pixel counts of one prediction / target pair to ``global_metrics``.

    Both arrays are binarised with ``> threshold``. ``global_metrics`` is
    updated in place under the keys ``'tp'``, ``'fp'``, ``'fn'``,
    ``'intersection'`` and ``'union'``; nothing is returned.
    """
    pred_pos = preds > threshold
    target_pos = targets > threshold

    true_pos = (pred_pos & target_pos).sum()
    increments = {
        'tp': true_pos,
        'fp': (pred_pos & ~target_pos).sum(),
        'fn': (~pred_pos & target_pos).sum(),
        'intersection': true_pos,  # the intersection is exactly the true positives
        'union': (pred_pos | target_pos).sum(),
    }

    for key, count in increments.items():
        global_metrics[key] += count

class AuxiliaryClassifier(nn.Module):
    """Small conv head producing 1024x1024 logits from a channels-last feature map."""

    def __init__(self, in_channels, num_classes=1):
        """Create the 3x3 conv -> batch norm -> ReLU -> 1x1 conv stack.

        Args:
            in_channels: Channel count of the incoming feature map.
            num_classes: Number of output channels.
        """
        super().__init__()
        hidden_channels = 256
        self.aux_conv1 = nn.Conv2d(in_channels, hidden_channels, kernel_size=3, stride=1, padding=1)
        self.aux_bn1 = nn.BatchNorm2d(hidden_channels)
        self.aux_relu1 = nn.ReLU(inplace=True)
        self.aux_conv2 = nn.Conv2d(hidden_channels, num_classes, kernel_size=1, stride=1)

    def forward(self, x):
        """Map a channels-last feature map to logits upsampled to 1024x1024."""
        # Move the last axis to position 1 so Conv2d sees channels first.
        channels_first = x.permute(0, 3, 1, 2)
        hidden = self.aux_relu1(self.aux_bn1(self.aux_conv1(channels_first)))
        logits = self.aux_conv2(hidden)
        return F.interpolate(logits, size=(1024, 1024), mode='bilinear', align_corners=False)

class SegmentationHead(nn.Module):
    """Multi-level aggregation head over selected encoder block features.

    Each selected intermediate feature goes through its own two-layer conv
    branch, the image embedding goes through a 1x1 conv branch, every result
    is upsampled 4x, concatenated on the channel axis and classified. The
    logits are finally resized to ``output_size``.
    """

    def __init__(
        self,
        in_channels,
        intermediate_channels,
        out_channels=1,
        align_corners=False,
        feature_channels=1024,
        feature_indices=(5, 11, 17, 23),
        output_size=(1024, 1024),
    ):
        """Build the branches.

        Args:
            in_channels: Channels of the image embedding.
            intermediate_channels: Hidden width of the classifier branch.
            out_channels: Number of output channels.
            align_corners: Forwarded to every ``F.interpolate`` call.
            feature_channels: Channels of each intermediate feature. The
                default 1024 is the value that used to be hard-coded.
            feature_indices: Positions in ``inter_features`` that are tapped;
                one branch is created per index. The default is the set that
                used to be hard-coded.
            output_size: Spatial size of the returned logits.

        Raises:
            ValueError: If ``feature_indices`` is empty.
        """
        super().__init__()
        self.align_corners = align_corners
        self.feature_indices = tuple(feature_indices)
        self.output_size = tuple(output_size)
        if not self.feature_indices:
            raise ValueError("feature_indices must contain at least one index")

        self.mla_branches = nn.ModuleList([
            nn.Sequential(
                nn.Conv2d(feature_channels, 512, kernel_size=3, padding=1, stride=1),
                nn.BatchNorm2d(512),
                nn.ReLU(inplace=True),
                nn.Conv2d(512, 256, kernel_size=3, padding=1, stride=1),
                nn.BatchNorm2d(256),
                nn.ReLU(inplace=True)
            ) for _ in self.feature_indices
        ])

        self.mla_image_branch = nn.Sequential(
            nn.Conv2d(in_channels, 256, kernel_size=1, stride=1, bias=False),
            nn.BatchNorm2d(256),
            nn.ReLU(inplace=True)
        )
        # One 256-channel map per tapped feature plus one for the image embedding.
        self.mla_classifier_branch = nn.Sequential(
            nn.Conv2d(256 * (len(self.feature_indices) + 1), intermediate_channels, kernel_size=3, stride=1, padding=1),
            nn.BatchNorm2d(intermediate_channels),
            nn.ReLU(inplace=True),
            nn.Conv2d(intermediate_channels, out_channels, kernel_size=1, stride=1)
        )

    def forward(self, image_embedding, inter_features):
        """Return logits resized to ``output_size``.

        Args:
            image_embedding: Channels-first image embedding.
            inter_features: List of channels-last intermediate features; it
                must be long enough to contain every index in
                ``feature_indices``.

        Raises:
            ValueError: If ``inter_features`` is ``None`` or too short.
        """
        if inter_features is None:
            raise ValueError("inter_features must be provided for MLA strategy")
        required = max(self.feature_indices) + 1
        if len(inter_features) < required:
            raise ValueError(f"Expected at least {required} inter_features for MLA strategy, but got {len(inter_features)}")

        processed_features = []
        for branch, feature_idx in zip(self.mla_branches, self.feature_indices):
            # Features arrive channels-last; Conv2d needs channels first.
            feat = inter_features[feature_idx].permute(0, 3, 1, 2)
            x_feat = F.interpolate(branch(feat), scale_factor=4, mode='bilinear', align_corners=self.align_corners)
            processed_features.append(x_feat)

        img_feat = self.mla_image_branch(image_embedding)
        img_feat = F.interpolate(img_feat, scale_factor=4, mode='bilinear', align_corners=self.align_corners)
        processed_features.append(img_feat)

        aggregated = torch.cat(processed_features, dim=1)
        x = self.mla_classifier_branch(aggregated)
        x = F.interpolate(x, size=self.output_size, mode='bilinear', align_corners=self.align_corners)

        return x

# class SegmentationHead(nn.Module):
#     def __init__(self, in_channels, intermediate_channels, out_channels=1, align_corners=False):
#         super(SegmentationHead, self).__init__()
#         self.align_corners = align_corners

#         # Image branch: Processes the image embedding
#         self.image_branch = nn.Sequential(
#             nn.Conv2d(in_channels, 256, kernel_size=1, stride=1, bias=False),
#             nn.BatchNorm2d(256),
#             nn.ReLU(inplace=True)
#         )

#         # Classifier branch: Maps the processed image feature to the desired output channels
#         self.classifier = nn.Sequential(
#             nn.Conv2d(256, intermediate_channels, kernel_size=3, stride=1, padding=1),
#             nn.BatchNorm2d(intermediate_channels),
#             nn.ReLU(inplace=True),
#             nn.Conv2d(intermediate_channels, out_channels, kernel_size=1, stride=1)
#         )

#     def forward(self, image_embedding, inter_features=None):
#         """
#         Forward pass for the segmentation head.

#         Args:
#             image_embedding (torch.Tensor): The embedding of the input image. Shape: (N, C, H, W)
#             inter_features (Optional[torch.Tensor]): Not used in this simplified version.

#         Returns:
#             torch.Tensor: The segmentation map. Shape: (N, out_channels, 1024, 1024)
#         """
#         # Process the image embedding through the image branch
#         img_feat = self.image_branch(image_embedding)
        
#         # Upsample the image feature by a factor of 4
#         img_feat = F.interpolate(img_feat, scale_factor=4, mode='bilinear', align_corners=self.align_corners)
        
#         # Pass the upsampled feature through the classifier to get the segmentation map
#         x = self.classifier(img_feat)
        
#         # Upsample the segmentation map to the desired output size (1024x1024)
#         x = F.interpolate(x, size=(1024, 1024), mode='bilinear', align_corners=self.align_corners)
        
#         return x


# Initialise the segmentation head
model = SegmentationHead(
    in_channels=256,
    intermediate_channels=256,
    out_channels=CONFIG['num_classes'],
    align_corners=False
)
model.to(CONFIG['device'])

# Initialise one auxiliary classifier per selected layer
selected_aux_layers = [5, 11, 17, 23]
aux_classifiers = nn.ModuleList([
    AuxiliaryClassifier(in_channels=1024, num_classes=CONFIG['num_classes']).to(CONFIG['device'])
    for _ in selected_aux_layers
])

# Decide which parameters are trainable:
# the segmentation head and the auxiliary classifiers train fully ...
for param in model.parameters():
    param.requires_grad = True

for aux in aux_classifiers:
    for param in aux.parameters():
        param.requires_grad = True

# ... every parameter under lora_sam_model.sam is frozen first ...
for param in lora_sam_model.sam.parameters():
    param.requires_grad = False

# ... and the LoRA A/B layers are re-enabled afterwards (order matters).
for layer in lora_sam_model.A_weights + lora_sam_model.B_weights:
    for param in layer.parameters():
        param.requires_grad = True

# Collect all trainable parameters
lora_trainable_params = list(filter(lambda p: p.requires_grad, lora_sam_model.parameters()))
model_trainable_params = list(model.parameters())
aux_trainable_params = list(aux_classifiers.parameters())
trainable_params = lora_trainable_params + model_trainable_params + aux_trainable_params

# Initialise the optimizer
optimizer = torch.optim.AdamW(
    trainable_params,
    lr=CONFIG['learning_rate'],
    betas=CONFIG['betas'],
    weight_decay=CONFIG['weight_decay']
)

# Loss function (operates on raw logits)
loss_fn = nn.BCEWithLogitsLoss()

num_epochs = CONFIG['num_epochs']
# Best validation composite score seen so far and the epoch it came from
best_composite_score = float('-inf')
best_epoch = 0

weights = CONFIG['metric_weights']
num_classes = CONFIG['num_classes']
AUX_WEIGHT = CONFIG['aux_weight']

# Schedule constants read by lr_lambda
warmup_epochs = 3
min_lr_factor = 0.01
# Learning-rate schedule: linear warm-up followed by cosine decay
def lr_lambda(epoch):
    """Return the learning-rate multiplier for ``epoch`` (0-based).

    During the first ``warmup_epochs`` epochs the factor rises linearly to 1.
    Afterwards it follows a cosine curve from 1 down to ``min_lr_factor``
    over the remaining ``num_epochs - warmup_epochs`` epochs.

    Reads the module-level ``warmup_epochs``, ``num_epochs`` and
    ``min_lr_factor``.
    """
    if epoch < warmup_epochs:
        return float((epoch + 1) / warmup_epochs)

    # Guard against a zero-length decay phase (num_epochs == warmup_epochs),
    # which would otherwise divide by zero.
    decay_epochs = max(1, num_epochs - warmup_epochs)
    cosine_decay = 0.5 * (1 + math.cos((epoch - warmup_epochs) * math.pi / decay_epochs))
    return float(min_lr_factor + (1 - min_lr_factor) * cosine_decay)

# Scheduler that multiplies the base learning rate by lr_lambda(epoch)
scheduler = lr_scheduler.LambdaLR(optimizer, lr_lambda=lr_lambda)

# Main loop: one training pass and one validation pass per epoch
for epoch in range(num_epochs):
    lora_sam_model.train()
    model.train()
    aux_classifiers.train()

    total_loss = 0
    num_batches = 0

    for images, masks in tqdm(train_loader, desc=f"Epoch {epoch + 1}/{num_epochs} [Train]"):
        images = images.to(CONFIG['device'])
        # Insert a dimension at position 1 so masks become 4-D like the logits
        masks = masks.to(CONFIG['device']).unsqueeze(1)

        # Skip (and log) batches that are not 4-D
        if images.dim() != 4 or masks.dim() != 4:
            logging.error(f"Invalid input dimensions: images {images.shape}, masks {masks.shape}")
            continue

        image_embedding, inter_features = lora_sam_model.sam.image_encoder.forward_inter(images)
        seg_head_logits = model(image_embedding, inter_features)

        # The segmentation head output is converted to prompts for the second stage
        prompts = process_class_logits(seg_head_logits)
        # NOTE(review): final_predictions is produced under no_grad, so the
        # loss term computed from it contributes no gradient - confirm intended.
        with torch.no_grad():
            final_predictions = predict_masks_batch(lora_sam_model.sam, image_embedding, prompts, CONFIG['device'])

        loss, seg_loss_val, sam_loss_val, aux_loss_val = compute_loss(
            seg_head_logits=seg_head_logits,
            final_predictions=final_predictions,
            masks=masks,
            loss_fn=loss_fn,
            aux_classifiers=aux_classifiers,
            inter_features=inter_features,
            selected_aux_layers=selected_aux_layers,
            aux_weight=AUX_WEIGHT,
            seg_weight=1.0,
            sam_weight=1.0,
            loss_weights=[1,1]
        )

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        total_loss += loss.item()
        num_batches += 1

    avg_train_loss = total_loss / num_batches if num_batches > 0 else 0

    # Validation phase
    lora_sam_model.eval()
    model.eval()
    aux_classifiers.eval()

    val_loss = 0
    num_val_batches = 0
    global_metrics_val = initialize_metrics()       # metrics of the final prediction
    global_metrics_val_seg = initialize_metrics()   # metrics of the segmentation head output

    with torch.no_grad():
        for images, masks in tqdm(val_loader, desc=f"Epoch {epoch + 1}/{num_epochs} [Validation]"):
            images = images.to(CONFIG['device'])
            masks = masks.to(CONFIG['device']).unsqueeze(1)

            if images.dim() != 4 or masks.dim() != 4:
                logging.error(f"Invalid input dimensions: images {images.shape}, masks {masks.shape}")
                continue

            image_embedding, inter_features = lora_sam_model.sam.image_encoder.forward_inter(images)
            seg_head_logits = model(image_embedding, inter_features)

            prompts = process_class_logits(seg_head_logits)
            final_predictions = predict_masks_batch(lora_sam_model.sam, image_embedding, prompts, CONFIG['device'])

            # The auxiliary classifiers are left out of the validation loss
            loss, seg_loss_val, sam_loss_val, _ = compute_loss(
                seg_head_logits=seg_head_logits,
                final_predictions=final_predictions,
                masks=masks,
                loss_fn=loss_fn,
                aux_classifiers=None,
                inter_features=None,
                selected_aux_layers=None,
                aux_weight=0.0,
                seg_weight=1.0,
                sam_weight=1.0,
                loss_weights=[1,1]
            )
            val_loss += loss.item()
            num_val_batches += 1

            # Convert both outputs to probabilities for metric accumulation
            preds_seg = torch.sigmoid(seg_head_logits).cpu().numpy()  # segmentation head prediction
            preds_final = torch.sigmoid(final_predictions).cpu().numpy()  # final prediction
            masks_np = masks.cpu().numpy()

            for p_seg, p_final, m_gt in zip(preds_seg, preds_final, masks_np):
                # Segmentation head metrics
                accumulate_metrics(p_seg[0], m_gt, global_metrics_val_seg)
                # Final prediction metrics
                accumulate_metrics(p_final[0], m_gt, global_metrics_val)

    # Metrics of the segmentation head output, computed from counts pooled over the whole validation set
    tp_seg = global_metrics_val_seg['tp']
    fp_seg = global_metrics_val_seg['fp']
    fn_seg = global_metrics_val_seg['fn']
    intersection_seg = global_metrics_val_seg['intersection']
    union_seg = global_metrics_val_seg['union']

    # The 1e-6 terms avoid division by zero when a count is empty
    iou_seg = intersection_seg / (union_seg + 1e-6)
    precision_seg = tp_seg / (tp_seg + fp_seg + 1e-6)
    recall_seg = tp_seg / (tp_seg + fn_seg + 1e-6)
    f1_seg = (2 * precision_seg * recall_seg) / (precision_seg + recall_seg + 1e-6)

    # Metrics of the final prediction
    tp = global_metrics_val['tp']
    fp = global_metrics_val['fp']
    fn = global_metrics_val['fn']
    intersection = global_metrics_val['intersection']
    union = global_metrics_val['union']

    iou = intersection / (union + 1e-6)
    precision = tp / (tp + fp + 1e-6)
    recall = tp / (tp + fn + 1e-6)
    f1 = (2 * precision * recall) / (precision + recall + 1e-6)

    # The avg_* names are plain aliases of the pooled metrics above
    avg_iou_seg = iou_seg
    avg_precision_seg = precision_seg
    avg_recall_seg = recall_seg
    avg_f1_seg = f1_seg

    avg_iou = iou
    avg_precision = precision
    avg_recall = recall
    avg_f1 = f1

    # Model selection uses the final-prediction metrics only
    composite_score = (
        avg_iou * weights['iou'] +
        avg_f1 * weights['f1'] +
        avg_precision * weights['precision'] +
        avg_recall * weights['recall']
    )

    avg_val_loss = val_loss / num_val_batches if num_val_batches > 0 else 0

    log_message = (
        f"Epoch [{epoch + 1}/{num_epochs}], "
        f"Train Loss: {avg_train_loss:.4f}, "
        f"Val Loss: {avg_val_loss:.4f}, "
        f"(SegHead) IoU: {avg_iou_seg:.4f}, F1: {avg_f1_seg:.4f}, Precision: {avg_precision_seg:.4f}, Recall: {avg_recall_seg:.4f}, "
        f"(Final) IoU: {avg_iou:.4f}, F1: {avg_f1:.4f}, Precision: {avg_precision:.4f}, Recall: {avg_recall:.4f}, "
        f"Composite Score: {composite_score:.4f}, "
        f"LR: {optimizer.param_groups[0]['lr']:.6f}"
    )
    logging.info(log_message)
    print(log_message)

    # Save a checkpoint whenever the composite score improves
    if composite_score > best_composite_score:
        best_composite_score = composite_score
        best_epoch = epoch + 1

        strategy = 'MLA'
        lora_path = os.path.join(save_dir, f"{CONFIG['save_prefix']}_lora_{strategy}.safetensors")
        checkpoint_path = os.path.join(save_dir, f"{CONFIG['save_prefix']}_{strategy}.pth")

        if hasattr(lora_sam_model, 'save_lora_parameters'):
            lora_sam_model.save_lora_parameters(lora_path)
        else:
            logging.warning("lora_sam_model 没有定义 save_lora_parameters 方法。")

        # Only the segmentation head's state_dict is written here;
        # aux_classifiers and the optimizer state are not saved.
        torch.save(model.state_dict(), checkpoint_path)

        save_message = (
            f"Best model saved at epoch {best_epoch} with Composite Score {best_composite_score:.4f} using {strategy} strategy"
        )
        logging.info(save_message)
        print(save_message)

    # Advance the learning-rate schedule once per epoch
    scheduler.step()

logging.info("训练完成")
print("训练完成")


  state_dict = torch.load(f)
Epoch 1/100 [Train]: 100%|██████████| 342/342 [04:34<00:00,  1.24it/s]
Epoch 1/100 [Validation]: 100%|██████████| 95/95 [00:35<00:00,  2.66it/s]


Epoch [1/100], Train Loss: 1.6515, Val Loss: 1.3061, (SegHead) IoU: 0.6490, F1: 0.7871, Precision: 0.8795, Recall: 0.7123, (Final) IoU: 0.6576, F1: 0.7935, Precision: 0.8529, Recall: 0.7418, Composite Score: 0.7614, LR: 0.000033
Best model saved at epoch 1 with Composite Score 0.7614 using MLA strategy


Epoch 2/100 [Train]: 100%|██████████| 342/342 [04:18<00:00,  1.33it/s]
Epoch 2/100 [Validation]: 100%|██████████| 95/95 [00:33<00:00,  2.85it/s]


Epoch [2/100], Train Loss: 1.2577, Val Loss: 1.2697, (SegHead) IoU: 0.6491, F1: 0.7872, Precision: 0.8892, Recall: 0.7062, (Final) IoU: 0.6608, F1: 0.7958, Precision: 0.8525, Recall: 0.7461, Composite Score: 0.7638, LR: 0.000067
Best model saved at epoch 2 with Composite Score 0.7638 using MLA strategy


Epoch 3/100 [Train]: 100%|██████████| 342/342 [04:16<00:00,  1.33it/s]
Epoch 3/100 [Validation]: 100%|██████████| 95/95 [00:33<00:00,  2.84it/s]


Epoch [3/100], Train Loss: 1.1467, Val Loss: 1.2557, (SegHead) IoU: 0.6644, F1: 0.7984, Precision: 0.8709, Recall: 0.7370, (Final) IoU: 0.6717, F1: 0.8036, Precision: 0.8558, Recall: 0.7575, Composite Score: 0.7722, LR: 0.000100
Best model saved at epoch 3 with Composite Score 0.7722 using MLA strategy


Epoch 4/100 [Train]: 100%|██████████| 342/342 [04:17<00:00,  1.33it/s]
Epoch 4/100 [Validation]: 100%|██████████| 95/95 [00:33<00:00,  2.84it/s]


Epoch [4/100], Train Loss: 1.0001, Val Loss: 1.5306, (SegHead) IoU: 0.6172, F1: 0.7633, Precision: 0.9111, Recall: 0.6567, (Final) IoU: 0.6307, F1: 0.7736, Precision: 0.8831, Recall: 0.6882, Composite Score: 0.7439, LR: 0.000100


Epoch 5/100 [Train]: 100%|██████████| 342/342 [04:16<00:00,  1.34it/s]
Epoch 5/100 [Validation]: 100%|██████████| 95/95 [00:33<00:00,  2.82it/s]


Epoch [5/100], Train Loss: 0.8858, Val Loss: 1.0226, (SegHead) IoU: 0.7085, F1: 0.8294, Precision: 0.7666, Recall: 0.9034, (Final) IoU: 0.6758, F1: 0.8065, Precision: 0.7342, Recall: 0.8947, Composite Score: 0.7778, LR: 0.000100
Best model saved at epoch 5 with Composite Score 0.7778 using MLA strategy


Epoch 6/100 [Train]: 100%|██████████| 342/342 [04:15<00:00,  1.34it/s]
Epoch 6/100 [Validation]: 100%|██████████| 95/95 [00:33<00:00,  2.82it/s]


Epoch [6/100], Train Loss: 0.7924, Val Loss: 1.2922, (SegHead) IoU: 0.6640, F1: 0.7981, Precision: 0.8580, Recall: 0.7460, (Final) IoU: 0.6676, F1: 0.8007, Precision: 0.8183, Recall: 0.7838, Composite Score: 0.7676, LR: 0.000100


Epoch 7/100 [Train]: 100%|██████████| 342/342 [04:15<00:00,  1.34it/s]
Epoch 7/100 [Validation]: 100%|██████████| 95/95 [00:33<00:00,  2.85it/s]


Epoch [7/100], Train Loss: 0.6864, Val Loss: 1.1906, (SegHead) IoU: 0.6868, F1: 0.8143, Precision: 0.8127, Recall: 0.8160, (Final) IoU: 0.6734, F1: 0.8048, Precision: 0.7891, Recall: 0.8212, Composite Score: 0.7722, LR: 0.000100


Epoch 8/100 [Train]: 100%|██████████| 342/342 [04:45<00:00,  1.20it/s]
Epoch 8/100 [Validation]: 100%|██████████| 95/95 [00:33<00:00,  2.83it/s]


Epoch [8/100], Train Loss: 0.6141, Val Loss: 1.0749, (SegHead) IoU: 0.6922, F1: 0.8181, Precision: 0.8454, Recall: 0.7926, (Final) IoU: 0.6885, F1: 0.8155, Precision: 0.8257, Recall: 0.8055, Composite Score: 0.7838, LR: 0.000100
Best model saved at epoch 8 with Composite Score 0.7838 using MLA strategy


Epoch 9/100 [Train]: 100%|██████████| 342/342 [04:15<00:00,  1.34it/s]
Epoch 9/100 [Validation]: 100%|██████████| 95/95 [00:33<00:00,  2.84it/s]


Epoch [9/100], Train Loss: 0.5980, Val Loss: 1.3361, (SegHead) IoU: 0.6600, F1: 0.7952, Precision: 0.8923, Recall: 0.7171, (Final) IoU: 0.6731, F1: 0.8046, Precision: 0.8635, Recall: 0.7532, Composite Score: 0.7736, LR: 0.000099


Epoch 10/100 [Train]: 100%|██████████| 342/342 [04:15<00:00,  1.34it/s]
Epoch 10/100 [Validation]: 100%|██████████| 95/95 [00:33<00:00,  2.86it/s]


Epoch [10/100], Train Loss: 0.5016, Val Loss: 1.4809, (SegHead) IoU: 0.6349, F1: 0.7767, Precision: 0.9026, Recall: 0.6816, (Final) IoU: 0.6445, F1: 0.7838, Precision: 0.8836, Recall: 0.7043, Composite Score: 0.7540, LR: 0.000099


Epoch 11/100 [Train]: 100%|██████████| 342/342 [04:14<00:00,  1.34it/s]
Epoch 11/100 [Validation]: 100%|██████████| 95/95 [00:33<00:00,  2.84it/s]


Epoch [11/100], Train Loss: 0.4655, Val Loss: 1.6167, (SegHead) IoU: 0.6334, F1: 0.7755, Precision: 0.9019, Recall: 0.6803, (Final) IoU: 0.6501, F1: 0.7879, Precision: 0.8751, Recall: 0.7165, Composite Score: 0.7574, LR: 0.000099


Epoch 12/100 [Train]: 100%|██████████| 342/342 [04:15<00:00,  1.34it/s]
Epoch 12/100 [Validation]: 100%|██████████| 95/95 [00:33<00:00,  2.84it/s]


Epoch [12/100], Train Loss: 0.4312, Val Loss: 1.5619, (SegHead) IoU: 0.6231, F1: 0.7678, Precision: 0.8986, Recall: 0.6702, (Final) IoU: 0.6473, F1: 0.7859, Precision: 0.8696, Recall: 0.7169, Composite Score: 0.7549, LR: 0.000098


Epoch 13/100 [Train]: 100%|██████████| 342/342 [04:15<00:00,  1.34it/s]
Epoch 13/100 [Validation]: 100%|██████████| 95/95 [00:33<00:00,  2.82it/s]


Epoch [13/100], Train Loss: 0.4753, Val Loss: 1.1984, (SegHead) IoU: 0.6794, F1: 0.8091, Precision: 0.8851, Recall: 0.7451, (Final) IoU: 0.6811, F1: 0.8103, Precision: 0.8575, Recall: 0.7680, Composite Score: 0.7792, LR: 0.000098


Epoch 14/100 [Train]: 100%|██████████| 342/342 [04:15<00:00,  1.34it/s]
Epoch 14/100 [Validation]: 100%|██████████| 95/95 [00:33<00:00,  2.84it/s]


Epoch [14/100], Train Loss: 0.4062, Val Loss: 1.3515, (SegHead) IoU: 0.6604, F1: 0.7955, Precision: 0.8950, Recall: 0.7158, (Final) IoU: 0.6719, F1: 0.8038, Precision: 0.8720, Recall: 0.7455, Composite Score: 0.7733, LR: 0.000097


Epoch 15/100 [Train]: 100%|██████████| 342/342 [04:14<00:00,  1.35it/s]
Epoch 15/100 [Validation]: 100%|██████████| 95/95 [00:33<00:00,  2.81it/s]


Epoch [15/100], Train Loss: 0.3573, Val Loss: 1.1234, (SegHead) IoU: 0.6912, F1: 0.8174, Precision: 0.8839, Recall: 0.7602, (Final) IoU: 0.6993, F1: 0.8231, Precision: 0.8593, Recall: 0.7897, Composite Score: 0.7929, LR: 0.000097
Best model saved at epoch 15 with Composite Score 0.7929 using MLA strategy


Epoch 16/100 [Train]: 100%|██████████| 342/342 [04:14<00:00,  1.34it/s]
Epoch 16/100 [Validation]: 100%|██████████| 95/95 [00:33<00:00,  2.83it/s]


Epoch [16/100], Train Loss: 0.3373, Val Loss: 1.2265, (SegHead) IoU: 0.6770, F1: 0.8074, Precision: 0.9018, Recall: 0.7309, (Final) IoU: 0.6967, F1: 0.8213, Precision: 0.8707, Recall: 0.7772, Composite Score: 0.7915, LR: 0.000096


Epoch 17/100 [Train]: 100%|██████████| 342/342 [04:14<00:00,  1.34it/s]
Epoch 17/100 [Validation]: 100%|██████████| 95/95 [00:33<00:00,  2.82it/s]


Epoch [17/100], Train Loss: 0.3089, Val Loss: 1.1013, (SegHead) IoU: 0.6958, F1: 0.8206, Precision: 0.8920, Recall: 0.7599, (Final) IoU: 0.7128, F1: 0.8324, Precision: 0.8568, Recall: 0.8092, Composite Score: 0.8028, LR: 0.000096
Best model saved at epoch 17 with Composite Score 0.8028 using MLA strategy


Epoch 18/100 [Train]: 100%|██████████| 342/342 [04:44<00:00,  1.20it/s]
Epoch 18/100 [Validation]: 100%|██████████| 95/95 [00:33<00:00,  2.81it/s]


Epoch [18/100], Train Loss: 0.3033, Val Loss: 1.1590, (SegHead) IoU: 0.6799, F1: 0.8095, Precision: 0.8843, Recall: 0.7463, (Final) IoU: 0.6949, F1: 0.8200, Precision: 0.8535, Recall: 0.7890, Composite Score: 0.7894, LR: 0.000095


Epoch 19/100 [Train]: 100%|██████████| 342/342 [04:13<00:00,  1.35it/s]
Epoch 19/100 [Validation]: 100%|██████████| 95/95 [00:33<00:00,  2.83it/s]


Epoch [19/100], Train Loss: 0.2902, Val Loss: 1.2674, (SegHead) IoU: 0.6752, F1: 0.8061, Precision: 0.9001, Recall: 0.7299, (Final) IoU: 0.6951, F1: 0.8201, Precision: 0.8728, Recall: 0.7734, Composite Score: 0.7904, LR: 0.000094


Epoch 20/100 [Train]: 100%|██████████| 342/342 [04:14<00:00,  1.34it/s]
Epoch 20/100 [Validation]: 100%|██████████| 95/95 [00:33<00:00,  2.82it/s]


Epoch [20/100], Train Loss: 0.2945, Val Loss: 1.1623, (SegHead) IoU: 0.6725, F1: 0.8042, Precision: 0.8924, Recall: 0.7318, (Final) IoU: 0.6935, F1: 0.8190, Precision: 0.8548, Recall: 0.7861, Composite Score: 0.7884, LR: 0.000094


Epoch 21/100 [Train]: 100%|██████████| 342/342 [04:14<00:00,  1.35it/s]
Epoch 21/100 [Validation]: 100%|██████████| 95/95 [00:33<00:00,  2.81it/s]


Epoch [21/100], Train Loss: 0.3200, Val Loss: 1.1296, (SegHead) IoU: 0.7102, F1: 0.8305, Precision: 0.8788, Recall: 0.7873, (Final) IoU: 0.7135, F1: 0.8328, Precision: 0.8510, Recall: 0.8153, Composite Score: 0.8032, LR: 0.000093
Best model saved at epoch 21 with Composite Score 0.8032 using MLA strategy


Epoch 22/100 [Train]: 100%|██████████| 342/342 [04:15<00:00,  1.34it/s]
Epoch 22/100 [Validation]: 100%|██████████| 95/95 [00:33<00:00,  2.83it/s]


Epoch [22/100], Train Loss: 0.4230, Val Loss: 1.7168, (SegHead) IoU: 0.6027, F1: 0.7521, Precision: 0.9043, Recall: 0.6437, (Final) IoU: 0.6356, F1: 0.7772, Precision: 0.8680, Recall: 0.7036, Composite Score: 0.7461, LR: 0.000092


Epoch 23/100 [Train]: 100%|██████████| 342/342 [04:48<00:00,  1.19it/s]
Epoch 23/100 [Validation]: 100%|██████████| 95/95 [00:33<00:00,  2.81it/s]


Epoch [23/100], Train Loss: 0.4593, Val Loss: 1.2835, (SegHead) IoU: 0.6576, F1: 0.7934, Precision: 0.8742, Recall: 0.7264, (Final) IoU: 0.6658, F1: 0.7994, Precision: 0.8258, Recall: 0.7746, Composite Score: 0.7664, LR: 0.000091


Epoch 24/100 [Train]: 100%|██████████| 342/342 [04:15<00:00,  1.34it/s]
Epoch 24/100 [Validation]: 100%|██████████| 95/95 [00:33<00:00,  2.83it/s]


Epoch [24/100], Train Loss: 0.3293, Val Loss: 1.3940, (SegHead) IoU: 0.6543, F1: 0.7910, Precision: 0.8928, Recall: 0.7101, (Final) IoU: 0.6703, F1: 0.8026, Precision: 0.8568, Recall: 0.7549, Composite Score: 0.7712, LR: 0.000090


Epoch 25/100 [Train]: 100%|██████████| 342/342 [04:14<00:00,  1.34it/s]
Epoch 25/100 [Validation]: 100%|██████████| 95/95 [00:34<00:00,  2.79it/s]


Epoch [25/100], Train Loss: 0.2669, Val Loss: 1.1154, (SegHead) IoU: 0.6990, F1: 0.8229, Precision: 0.8651, Recall: 0.7846, (Final) IoU: 0.7078, F1: 0.8289, Precision: 0.8336, Recall: 0.8244, Composite Score: 0.7987, LR: 0.000089


Epoch 26/100 [Train]: 100%|██████████| 342/342 [04:14<00:00,  1.35it/s]
Epoch 26/100 [Validation]: 100%|██████████| 95/95 [00:33<00:00,  2.81it/s]


Epoch [26/100], Train Loss: 0.2459, Val Loss: 1.2932, (SegHead) IoU: 0.6742, F1: 0.8054, Precision: 0.8860, Recall: 0.7383, (Final) IoU: 0.6872, F1: 0.8146, Precision: 0.8530, Recall: 0.7794, Composite Score: 0.7836, LR: 0.000088


Epoch 27/100 [Train]: 100%|██████████| 342/342 [04:14<00:00,  1.34it/s]
Epoch 27/100 [Validation]: 100%|██████████| 95/95 [00:33<00:00,  2.80it/s]


Epoch [27/100], Train Loss: 0.2333, Val Loss: 1.2508, (SegHead) IoU: 0.6831, F1: 0.8117, Precision: 0.8796, Recall: 0.7536, (Final) IoU: 0.6941, F1: 0.8195, Precision: 0.8494, Recall: 0.7916, Composite Score: 0.7886, LR: 0.000087


Epoch 28/100 [Train]: 100%|██████████| 342/342 [04:13<00:00,  1.35it/s]
Epoch 28/100 [Validation]: 100%|██████████| 95/95 [00:33<00:00,  2.80it/s]


Epoch [28/100], Train Loss: 0.2329, Val Loss: 1.3393, (SegHead) IoU: 0.6740, F1: 0.8052, Precision: 0.8807, Recall: 0.7417, (Final) IoU: 0.6815, F1: 0.8106, Precision: 0.8457, Recall: 0.7783, Composite Score: 0.7790, LR: 0.000086


Epoch 29/100 [Train]: 100%|██████████| 342/342 [04:14<00:00,  1.35it/s]
Epoch 29/100 [Validation]: 100%|██████████| 95/95 [00:34<00:00,  2.79it/s]


Epoch [29/100], Train Loss: 0.2263, Val Loss: 1.3972, (SegHead) IoU: 0.6614, F1: 0.7962, Precision: 0.8879, Recall: 0.7217, (Final) IoU: 0.6751, F1: 0.8060, Precision: 0.8519, Recall: 0.7648, Composite Score: 0.7745, LR: 0.000085


Epoch 30/100 [Train]: 100%|██████████| 342/342 [04:13<00:00,  1.35it/s]
Epoch 30/100 [Validation]: 100%|██████████| 95/95 [00:33<00:00,  2.79it/s]


Epoch [30/100], Train Loss: 0.2250, Val Loss: 1.3971, (SegHead) IoU: 0.6633, F1: 0.7976, Precision: 0.8866, Recall: 0.7248, (Final) IoU: 0.6820, F1: 0.8109, Precision: 0.8529, Recall: 0.7729, Composite Score: 0.7797, LR: 0.000083


Epoch 31/100 [Train]: 100%|██████████| 342/342 [04:14<00:00,  1.34it/s]
Epoch 31/100 [Validation]: 100%|██████████| 95/95 [00:33<00:00,  2.80it/s]


Epoch [31/100], Train Loss: 0.2254, Val Loss: 1.3000, (SegHead) IoU: 0.6721, F1: 0.8039, Precision: 0.8844, Recall: 0.7369, (Final) IoU: 0.6897, F1: 0.8164, Precision: 0.8520, Recall: 0.7836, Composite Score: 0.7854, LR: 0.000082


Epoch 32/100 [Train]: 100%|██████████| 342/342 [04:14<00:00,  1.35it/s]
Epoch 32/100 [Validation]: 100%|██████████| 95/95 [00:33<00:00,  2.81it/s]


Epoch [32/100], Train Loss: 0.2342, Val Loss: 1.5565, (SegHead) IoU: 0.6431, F1: 0.7828, Precision: 0.9036, Recall: 0.6905, (Final) IoU: 0.6691, F1: 0.8017, Precision: 0.8714, Recall: 0.7424, Composite Score: 0.7712, LR: 0.000081


Epoch 33/100 [Train]: 100%|██████████| 342/342 [04:14<00:00,  1.35it/s]
Epoch 33/100 [Validation]: 100%|██████████| 95/95 [00:34<00:00,  2.76it/s]


Epoch [33/100], Train Loss: 0.2275, Val Loss: 1.1889, (SegHead) IoU: 0.6977, F1: 0.8220, Precision: 0.8724, Recall: 0.7770, (Final) IoU: 0.7110, F1: 0.8311, Precision: 0.8360, Recall: 0.8262, Composite Score: 0.8011, LR: 0.000080


Epoch 34/100 [Train]: 100%|██████████| 342/342 [04:14<00:00,  1.34it/s]
Epoch 34/100 [Validation]: 100%|██████████| 95/95 [00:34<00:00,  2.77it/s]


Epoch [34/100], Train Loss: 0.2229, Val Loss: 1.2334, (SegHead) IoU: 0.6875, F1: 0.8148, Precision: 0.8831, Recall: 0.7563, (Final) IoU: 0.7049, F1: 0.8269, Precision: 0.8470, Recall: 0.8078, Composite Score: 0.7966, LR: 0.000078


Epoch 35/100 [Train]: 100%|██████████| 342/342 [04:14<00:00,  1.34it/s]
Epoch 35/100 [Validation]: 100%|██████████| 95/95 [00:34<00:00,  2.76it/s]


Epoch [35/100], Train Loss: 0.2203, Val Loss: 1.3696, (SegHead) IoU: 0.6671, F1: 0.8003, Precision: 0.8874, Recall: 0.7288, (Final) IoU: 0.6856, F1: 0.8135, Precision: 0.8517, Recall: 0.7786, Composite Score: 0.7824, LR: 0.000077


Epoch 36/100 [Train]: 100%|██████████| 342/342 [04:15<00:00,  1.34it/s]
Epoch 36/100 [Validation]: 100%|██████████| 95/95 [00:33<00:00,  2.81it/s]


Epoch [36/100], Train Loss: 0.2283, Val Loss: 1.7799, (SegHead) IoU: 0.6168, F1: 0.7630, Precision: 0.9181, Recall: 0.6527, (Final) IoU: 0.6543, F1: 0.7910, Precision: 0.8831, Recall: 0.7163, Composite Score: 0.7612, LR: 0.000076


Epoch 37/100 [Train]: 100%|██████████| 342/342 [04:14<00:00,  1.35it/s]
Epoch 37/100 [Validation]: 100%|██████████| 95/95 [00:34<00:00,  2.76it/s]


Epoch [37/100], Train Loss: 0.2247, Val Loss: 1.3468, (SegHead) IoU: 0.6727, F1: 0.8043, Precision: 0.8864, Recall: 0.7362, (Final) IoU: 0.6909, F1: 0.8172, Precision: 0.8463, Recall: 0.7901, Composite Score: 0.7861, LR: 0.000074


Epoch 38/100 [Train]: 100%|██████████| 342/342 [04:14<00:00,  1.34it/s]
Epoch 38/100 [Validation]: 100%|██████████| 95/95 [00:34<00:00,  2.78it/s]


Epoch [38/100], Train Loss: 0.2191, Val Loss: 1.3867, (SegHead) IoU: 0.6675, F1: 0.8006, Precision: 0.8922, Recall: 0.7261, (Final) IoU: 0.6944, F1: 0.8197, Precision: 0.8534, Recall: 0.7885, Composite Score: 0.7890, LR: 0.000073


Epoch 39/100 [Train]: 100%|██████████| 342/342 [04:14<00:00,  1.34it/s]
Epoch 39/100 [Validation]: 100%|██████████| 95/95 [00:34<00:00,  2.75it/s]


Epoch [39/100], Train Loss: 0.2177, Val Loss: 1.2339, (SegHead) IoU: 0.7012, F1: 0.8244, Precision: 0.8733, Recall: 0.7806, (Final) IoU: 0.7090, F1: 0.8297, Precision: 0.8361, Recall: 0.8234, Composite Score: 0.7995, LR: 0.000071


Epoch 40/100 [Train]: 100%|██████████| 342/342 [04:14<00:00,  1.35it/s]
Epoch 40/100 [Validation]: 100%|██████████| 95/95 [00:33<00:00,  2.80it/s]


Epoch [40/100], Train Loss: 0.2156, Val Loss: 1.5020, (SegHead) IoU: 0.6533, F1: 0.7903, Precision: 0.8998, Recall: 0.7046, (Final) IoU: 0.6804, F1: 0.8098, Precision: 0.8675, Recall: 0.7592, Composite Score: 0.7792, LR: 0.000070


Epoch 41/100 [Train]: 100%|██████████| 342/342 [04:14<00:00,  1.34it/s]
Epoch 41/100 [Validation]: 100%|██████████| 95/95 [00:34<00:00,  2.78it/s]


Epoch [41/100], Train Loss: 0.2253, Val Loss: 1.2638, (SegHead) IoU: 0.6930, F1: 0.8187, Precision: 0.8766, Recall: 0.7680, (Final) IoU: 0.7050, F1: 0.8270, Precision: 0.8338, Recall: 0.8203, Composite Score: 0.7965, LR: 0.000069


Epoch 42/100 [Train]: 100%|██████████| 342/342 [04:19<00:00,  1.32it/s]
Epoch 42/100 [Validation]: 100%|██████████| 95/95 [00:34<00:00,  2.77it/s]


Epoch [42/100], Train Loss: 0.2249, Val Loss: 1.3525, (SegHead) IoU: 0.6778, F1: 0.8080, Precision: 0.8832, Recall: 0.7446, (Final) IoU: 0.6967, F1: 0.8212, Precision: 0.8495, Recall: 0.7948, Composite Score: 0.7906, LR: 0.000067


Epoch 43/100 [Train]: 100%|██████████| 342/342 [04:18<00:00,  1.32it/s]
Epoch 43/100 [Validation]: 100%|██████████| 95/95 [00:38<00:00,  2.45it/s]


Epoch [43/100], Train Loss: 0.2084, Val Loss: 1.2491, (SegHead) IoU: 0.6925, F1: 0.8183, Precision: 0.8815, Recall: 0.7635, (Final) IoU: 0.7035, F1: 0.8259, Precision: 0.8461, Recall: 0.8067, Composite Score: 0.7956, LR: 0.000065


Epoch 44/100 [Train]: 100%|██████████| 342/342 [04:15<00:00,  1.34it/s]
Epoch 44/100 [Validation]: 100%|██████████| 95/95 [00:34<00:00,  2.73it/s]


Epoch [44/100], Train Loss: 0.2775, Val Loss: 1.2006, (SegHead) IoU: 0.6836, F1: 0.8120, Precision: 0.8017, Recall: 0.8227, (Final) IoU: 0.6813, F1: 0.8104, Precision: 0.7590, Recall: 0.8694, Composite Score: 0.7800, LR: 0.000064


Epoch 45/100 [Train]: 100%|██████████| 342/342 [04:15<00:00,  1.34it/s]
Epoch 45/100 [Validation]: 100%|██████████| 95/95 [00:34<00:00,  2.79it/s]


Epoch [45/100], Train Loss: 0.2945, Val Loss: 1.9866, (SegHead) IoU: 0.6170, F1: 0.7631, Precision: 0.9090, Recall: 0.6576, (Final) IoU: 0.6457, F1: 0.7847, Precision: 0.8713, Recall: 0.7138, Composite Score: 0.7539, LR: 0.000062


Epoch 46/100 [Train]: 100%|██████████| 342/342 [04:13<00:00,  1.35it/s]
Epoch 46/100 [Validation]: 100%|██████████| 95/95 [00:34<00:00,  2.76it/s]


Epoch [46/100], Train Loss: 0.2354, Val Loss: 1.3302, (SegHead) IoU: 0.6767, F1: 0.8072, Precision: 0.8823, Recall: 0.7438, (Final) IoU: 0.6919, F1: 0.8179, Precision: 0.8502, Recall: 0.7881, Composite Score: 0.7870, LR: 0.000061


Epoch 47/100 [Train]: 100%|██████████| 342/342 [04:14<00:00,  1.34it/s]
Epoch 47/100 [Validation]: 100%|██████████| 95/95 [00:34<00:00,  2.76it/s]


Epoch [47/100], Train Loss: 0.2152, Val Loss: 1.4437, (SegHead) IoU: 0.6695, F1: 0.8020, Precision: 0.8891, Recall: 0.7304, (Final) IoU: 0.6908, F1: 0.8171, Precision: 0.8558, Recall: 0.7818, Composite Score: 0.7864, LR: 0.000059


Epoch 48/100 [Train]: 100%|██████████| 342/342 [06:13<00:00,  1.09s/it]
Epoch 48/100 [Validation]: 100%|██████████| 95/95 [01:02<00:00,  1.51it/s]


Epoch [48/100], Train Loss: 0.2088, Val Loss: 1.3992, (SegHead) IoU: 0.6776, F1: 0.8078, Precision: 0.8777, Recall: 0.7482, (Final) IoU: 0.6906, F1: 0.8170, Precision: 0.8453, Recall: 0.7905, Composite Score: 0.7858, LR: 0.000058


Epoch 49/100 [Train]: 100%|██████████| 342/342 [07:54<00:00,  1.39s/it]
Epoch 49/100 [Validation]: 100%|██████████| 95/95 [01:03<00:00,  1.50it/s]


Epoch [49/100], Train Loss: 0.2066, Val Loss: 1.3746, (SegHead) IoU: 0.6812, F1: 0.8104, Precision: 0.8756, Recall: 0.7542, (Final) IoU: 0.6937, F1: 0.8192, Precision: 0.8419, Recall: 0.7977, Composite Score: 0.7881, LR: 0.000056


Epoch 50/100 [Train]: 100%|██████████| 342/342 [07:54<00:00,  1.39s/it]
Epoch 50/100 [Validation]: 100%|██████████| 95/95 [01:02<00:00,  1.51it/s]


Epoch [50/100], Train Loss: 0.2040, Val Loss: 1.4519, (SegHead) IoU: 0.6747, F1: 0.8058, Precision: 0.8816, Recall: 0.7420, (Final) IoU: 0.6901, F1: 0.8167, Precision: 0.8486, Recall: 0.7870, Composite Score: 0.7856, LR: 0.000055


Epoch 51/100 [Train]: 100%|██████████| 342/342 [07:54<00:00,  1.39s/it]
Epoch 51/100 [Validation]: 100%|██████████| 95/95 [01:02<00:00,  1.51it/s]


Epoch [51/100], Train Loss: 0.2004, Val Loss: 1.4664, (SegHead) IoU: 0.6678, F1: 0.8008, Precision: 0.8853, Recall: 0.7311, (Final) IoU: 0.6864, F1: 0.8141, Precision: 0.8547, Recall: 0.7771, Composite Score: 0.7831, LR: 0.000053


Epoch 52/100 [Train]: 100%|██████████| 342/342 [07:55<00:00,  1.39s/it]
Epoch 52/100 [Validation]: 100%|██████████| 95/95 [00:37<00:00,  2.56it/s]


Epoch [52/100], Train Loss: 0.2056, Val Loss: 1.4843, (SegHead) IoU: 0.6672, F1: 0.8004, Precision: 0.8873, Recall: 0.7290, (Final) IoU: 0.6864, F1: 0.8140, Precision: 0.8528, Recall: 0.7787, Composite Score: 0.7830, LR: 0.000051


Epoch 53/100 [Train]: 100%|██████████| 342/342 [05:06<00:00,  1.12it/s]
Epoch 53/100 [Validation]: 100%|██████████| 95/95 [01:02<00:00,  1.52it/s]


Epoch [53/100], Train Loss: 0.2032, Val Loss: 1.5340, (SegHead) IoU: 0.6669, F1: 0.8002, Precision: 0.8938, Recall: 0.7243, (Final) IoU: 0.6871, F1: 0.8146, Precision: 0.8597, Recall: 0.7739, Composite Score: 0.7838, LR: 0.000050


Epoch 54/100 [Train]: 100%|██████████| 342/342 [07:54<00:00,  1.39s/it]
Epoch 54/100 [Validation]: 100%|██████████| 95/95 [01:02<00:00,  1.52it/s]


Epoch [54/100], Train Loss: 0.2022, Val Loss: 1.6811, (SegHead) IoU: 0.6430, F1: 0.7827, Precision: 0.9064, Recall: 0.6888, (Final) IoU: 0.6753, F1: 0.8062, Precision: 0.8667, Recall: 0.7535, Composite Score: 0.7754, LR: 0.000048


Epoch 55/100 [Train]: 100%|██████████| 342/342 [07:54<00:00,  1.39s/it]
Epoch 55/100 [Validation]: 100%|██████████| 95/95 [01:02<00:00,  1.51it/s]


Epoch [55/100], Train Loss: 0.2023, Val Loss: 1.4971, (SegHead) IoU: 0.6703, F1: 0.8026, Precision: 0.8843, Recall: 0.7348, (Final) IoU: 0.6858, F1: 0.8136, Precision: 0.8485, Recall: 0.7815, Composite Score: 0.7824, LR: 0.000046


Epoch 56/100 [Train]: 100%|██████████| 342/342 [07:56<00:00,  1.39s/it]
Epoch 56/100 [Validation]: 100%|██████████| 95/95 [01:03<00:00,  1.51it/s]


Epoch [56/100], Train Loss: 0.2060, Val Loss: 1.4469, (SegHead) IoU: 0.6766, F1: 0.8071, Precision: 0.8809, Recall: 0.7447, (Final) IoU: 0.6906, F1: 0.8170, Precision: 0.8470, Recall: 0.7890, Composite Score: 0.7859, LR: 0.000045


Epoch 57/100 [Train]: 100%|██████████| 342/342 [07:57<00:00,  1.40s/it]
Epoch 57/100 [Validation]: 100%|██████████| 95/95 [01:22<00:00,  1.16it/s]


Epoch [57/100], Train Loss: 0.2014, Val Loss: 1.5490, (SegHead) IoU: 0.6584, F1: 0.7940, Precision: 0.8955, Recall: 0.7132, (Final) IoU: 0.6829, F1: 0.8116, Precision: 0.8596, Recall: 0.7687, Composite Score: 0.7807, LR: 0.000043


Epoch 58/100 [Train]: 100%|██████████| 342/342 [05:07<00:00,  1.11it/s]
Epoch 58/100 [Validation]: 100%|██████████| 95/95 [00:36<00:00,  2.63it/s]


Epoch [58/100], Train Loss: 0.2046, Val Loss: 1.5196, (SegHead) IoU: 0.6673, F1: 0.8005, Precision: 0.8882, Recall: 0.7285, (Final) IoU: 0.6885, F1: 0.8155, Precision: 0.8560, Recall: 0.7787, Composite Score: 0.7847, LR: 0.000042


Epoch 59/100 [Train]: 100%|██████████| 342/342 [04:28<00:00,  1.28it/s]
Epoch 59/100 [Validation]: 100%|██████████| 95/95 [00:34<00:00,  2.75it/s]


Epoch [59/100], Train Loss: 0.1966, Val Loss: 1.4458, (SegHead) IoU: 0.6787, F1: 0.8086, Precision: 0.8803, Recall: 0.7477, (Final) IoU: 0.6940, F1: 0.8193, Precision: 0.8461, Recall: 0.7942, Composite Score: 0.7884, LR: 0.000040


Epoch 60/100 [Train]: 100%|██████████| 342/342 [04:13<00:00,  1.35it/s]
Epoch 60/100 [Validation]: 100%|██████████| 95/95 [00:44<00:00,  2.13it/s]


Epoch [60/100], Train Loss: 0.2033, Val Loss: 1.4215, (SegHead) IoU: 0.6809, F1: 0.8101, Precision: 0.8786, Recall: 0.7515, (Final) IoU: 0.6947, F1: 0.8198, Precision: 0.8424, Recall: 0.7985, Composite Score: 0.7888, LR: 0.000039


Epoch 61/100 [Train]: 100%|██████████| 342/342 [07:54<00:00,  1.39s/it]
Epoch 61/100 [Validation]: 100%|██████████| 95/95 [01:03<00:00,  1.51it/s]


Epoch [61/100], Train Loss: 0.1973, Val Loss: 1.4224, (SegHead) IoU: 0.6813, F1: 0.8104, Precision: 0.8780, Recall: 0.7526, (Final) IoU: 0.6978, F1: 0.8220, Precision: 0.8447, Recall: 0.8005, Composite Score: 0.7912, LR: 0.000037


Epoch 62/100 [Train]: 100%|██████████| 342/342 [07:55<00:00,  1.39s/it]
Epoch 62/100 [Validation]: 100%|██████████| 95/95 [01:03<00:00,  1.50it/s]


Epoch [62/100], Train Loss: 0.1992, Val Loss: 1.4009, (SegHead) IoU: 0.6856, F1: 0.8135, Precision: 0.8758, Recall: 0.7595, (Final) IoU: 0.7024, F1: 0.8252, Precision: 0.8359, Recall: 0.8147, Composite Score: 0.7945, LR: 0.000036


Epoch 63/100 [Train]: 100%|██████████| 342/342 [07:55<00:00,  1.39s/it]
Epoch 63/100 [Validation]: 100%|██████████| 95/95 [01:03<00:00,  1.50it/s]


Epoch [63/100], Train Loss: 0.2000, Val Loss: 1.4045, (SegHead) IoU: 0.6810, F1: 0.8102, Precision: 0.8757, Recall: 0.7538, (Final) IoU: 0.6985, F1: 0.8225, Precision: 0.8412, Recall: 0.8046, Composite Score: 0.7917, LR: 0.000034


Epoch 64/100 [Train]: 100%|██████████| 342/342 [07:55<00:00,  1.39s/it]
Epoch 64/100 [Validation]: 100%|██████████| 95/95 [01:03<00:00,  1.50it/s]


Epoch [64/100], Train Loss: 0.1968, Val Loss: 1.4885, (SegHead) IoU: 0.6746, F1: 0.8057, Precision: 0.8812, Recall: 0.7421, (Final) IoU: 0.6925, F1: 0.8183, Precision: 0.8453, Recall: 0.7930, Composite Score: 0.7873, LR: 0.000032


Epoch 65/100 [Train]: 100%|██████████| 342/342 [06:01<00:00,  1.06s/it]
Epoch 65/100 [Validation]: 100%|██████████| 95/95 [00:34<00:00,  2.74it/s]


Epoch [65/100], Train Loss: 0.2025, Val Loss: 1.4306, (SegHead) IoU: 0.6870, F1: 0.8145, Precision: 0.8729, Recall: 0.7634, (Final) IoU: 0.7019, F1: 0.8248, Precision: 0.8373, Recall: 0.8127, Composite Score: 0.7942, LR: 0.000031


Epoch 66/100 [Train]: 100%|██████████| 342/342 [04:18<00:00,  1.32it/s]
Epoch 66/100 [Validation]: 100%|██████████| 95/95 [00:36<00:00,  2.61it/s]


Epoch [66/100], Train Loss: 0.1986, Val Loss: 1.5184, (SegHead) IoU: 0.6735, F1: 0.8049, Precision: 0.8834, Recall: 0.7392, (Final) IoU: 0.6946, F1: 0.8198, Precision: 0.8486, Recall: 0.7928, Composite Score: 0.7889, LR: 0.000030


Epoch 67/100 [Train]: 100%|██████████| 342/342 [07:37<00:00,  1.34s/it]
Epoch 67/100 [Validation]: 100%|██████████| 95/95 [01:05<00:00,  1.45it/s]


Epoch [67/100], Train Loss: 0.2002, Val Loss: 1.4512, (SegHead) IoU: 0.6824, F1: 0.8112, Precision: 0.8763, Recall: 0.7551, (Final) IoU: 0.6989, F1: 0.8228, Precision: 0.8435, Recall: 0.8030, Composite Score: 0.7920, LR: 0.000028


Epoch 68/100 [Train]: 100%|██████████| 342/342 [08:16<00:00,  1.45s/it]
Epoch 68/100 [Validation]: 100%|██████████| 95/95 [01:03<00:00,  1.50it/s]


Epoch [68/100], Train Loss: 0.1967, Val Loss: 1.5983, (SegHead) IoU: 0.6687, F1: 0.8015, Precision: 0.8887, Recall: 0.7299, (Final) IoU: 0.6908, F1: 0.8172, Precision: 0.8563, Recall: 0.7814, Composite Score: 0.7864, LR: 0.000027


Epoch 69/100 [Train]:   5%|▌         | 18/342 [00:25<07:35,  1.41s/it]IOPub message rate exceeded.
The notebook server will temporarily stop sending output
to the client in order to avoid crashing it.
To change this limit, set the config variable
`--NotebookApp.iopub_msg_rate_limit`.

Current values:
NotebookApp.iopub_msg_rate_limit=1000.0 (msgs/sec)
NotebookApp.rate_limit_window=3.0 (secs)

Epoch 72/100 [Train]: 100%|██████████| 342/342 [08:00<00:00,  1.41s/it]
Epoch 72/100 [Validation]: 100%|██████████| 95/95 [01:03<00:00,  1.50it/s]


Epoch [72/100], Train Loss: 0.1998, Val Loss: 1.9140, (SegHead) IoU: 0.6422, F1: 0.7821, Precision: 0.9060, Recall: 0.6880, (Final) IoU: 0.6712, F1: 0.8033, Precision: 0.8676, Recall: 0.7478, Composite Score: 0.7725, LR: 0.000021


Epoch 73/100 [Train]: 100%|██████████| 342/342 [08:01<00:00,  1.41s/it]
Epoch 73/100 [Validation]: 100%|██████████| 95/95 [01:04<00:00,  1.47it/s]


Epoch [73/100], Train Loss: 0.2009, Val Loss: 1.5208, (SegHead) IoU: 0.6783, F1: 0.8083, Precision: 0.8813, Recall: 0.7466, (Final) IoU: 0.6956, F1: 0.8205, Precision: 0.8456, Recall: 0.7969, Composite Score: 0.7896, LR: 0.000020


Epoch 74/100 [Train]: 100%|██████████| 342/342 [08:03<00:00,  1.41s/it]
Epoch 74/100 [Validation]: 100%|██████████| 95/95 [01:05<00:00,  1.46it/s]


Epoch [74/100], Train Loss: 0.1967, Val Loss: 1.4408, (SegHead) IoU: 0.6933, F1: 0.8189, Precision: 0.8627, Recall: 0.7793, (Final) IoU: 0.7027, F1: 0.8254, Precision: 0.8281, Recall: 0.8227, Composite Score: 0.7948, LR: 0.000019


Epoch 75/100 [Train]: 100%|██████████| 342/342 [07:37<00:00,  1.34s/it]
Epoch 75/100 [Validation]: 100%|██████████| 95/95 [00:34<00:00,  2.72it/s]


Epoch [75/100], Train Loss: 0.1955, Val Loss: 1.6509, (SegHead) IoU: 0.6708, F1: 0.8030, Precision: 0.8874, Recall: 0.7332, (Final) IoU: 0.6881, F1: 0.8152, Precision: 0.8555, Recall: 0.7786, Composite Score: 0.7843, LR: 0.000018


Epoch 76/100 [Train]: 100%|██████████| 342/342 [04:18<00:00,  1.32it/s]
Epoch 76/100 [Validation]: 100%|██████████| 95/95 [00:35<00:00,  2.65it/s]


Epoch [76/100], Train Loss: 0.1994, Val Loss: 1.6146, (SegHead) IoU: 0.6715, F1: 0.8034, Precision: 0.8832, Recall: 0.7369, (Final) IoU: 0.6907, F1: 0.8171, Precision: 0.8492, Recall: 0.7873, Composite Score: 0.7861, LR: 0.000016


Epoch 77/100 [Train]: 100%|██████████| 342/342 [04:17<00:00,  1.33it/s]
Epoch 77/100 [Validation]: 100%|██████████| 95/95 [00:36<00:00,  2.60it/s]


Epoch [77/100], Train Loss: 0.1977, Val Loss: 1.5428, (SegHead) IoU: 0.6825, F1: 0.8113, Precision: 0.8774, Recall: 0.7545, (Final) IoU: 0.6972, F1: 0.8216, Precision: 0.8426, Recall: 0.8016, Composite Score: 0.7908, LR: 0.000015


Epoch 78/100 [Train]: 100%|██████████| 342/342 [04:18<00:00,  1.32it/s]
Epoch 78/100 [Validation]: 100%|██████████| 95/95 [00:35<00:00,  2.64it/s]


Epoch [78/100], Train Loss: 0.1956, Val Loss: 1.5073, (SegHead) IoU: 0.6889, F1: 0.8158, Precision: 0.8670, Recall: 0.7703, (Final) IoU: 0.6997, F1: 0.8233, Precision: 0.8328, Recall: 0.8141, Composite Score: 0.7925, LR: 0.000014


Epoch 79/100 [Train]: 100%|██████████| 342/342 [04:18<00:00,  1.32it/s]
Epoch 79/100 [Validation]: 100%|██████████| 95/95 [00:35<00:00,  2.65it/s]


Epoch [79/100], Train Loss: 0.1963, Val Loss: 1.6585, (SegHead) IoU: 0.6706, F1: 0.8028, Precision: 0.8857, Recall: 0.7341, (Final) IoU: 0.6926, F1: 0.8184, Precision: 0.8527, Recall: 0.7867, Composite Score: 0.7876, LR: 0.000013


Epoch 80/100 [Train]: 100%|██████████| 342/342 [04:16<00:00,  1.33it/s]
Epoch 80/100 [Validation]: 100%|██████████| 95/95 [00:35<00:00,  2.69it/s]


Epoch [80/100], Train Loss: 0.1974, Val Loss: 1.5981, (SegHead) IoU: 0.6769, F1: 0.8073, Precision: 0.8808, Recall: 0.7451, (Final) IoU: 0.6940, F1: 0.8193, Precision: 0.8474, Recall: 0.7931, Composite Score: 0.7885, LR: 0.000012


Epoch 81/100 [Train]:  33%|███▎      | 114/342 [01:24<02:50,  1.34it/s]

In [None]:
# import os
# import random
# import math
# import numpy as np
# import torch
# import torch.nn as nn
# import torch.nn.functional as F
# from torch.utils.data import Dataset, DataLoader
# import cv2
# import logging
# from segment_anything import sam_model_registry
# from tqdm import tqdm
# import torch.optim.lr_scheduler as lr_scheduler
# # Removed LoRA import
# # from lora import LoRA_sam
# from types import MethodType

# # Set the random seeds
# def set_seed(seed):
#     random.seed(seed)
#     np.random.seed(seed)
#     torch.manual_seed(seed)
#     torch.cuda.manual_seed(seed)
#     torch.backends.cudnn.deterministic = True
#     torch.backends.cudnn.benchmark = False

# # Configuration dictionary
# CONFIG = {
#     'dataset_name': 'CTS-Pore',  # only the dataset name needs to be set
#     'data_base_dir': 'datasets',  # base directory for datasets
#     'log_dir_base': 'logs',        # base directory for logs
#     'log_file': 'best_model_metrics.log',
#     'save_dir_base': 'logs',       # base directory for saved models
#     'save_prefix': 'best_model',
#     'model_type': 'vit_l',
#     'checkpoint': 'weights/sam_vit_l_0b3195.pth',
#     'device': torch.device("cuda" if torch.cuda.is_available() else "cpu"),
#     'num_epochs': 100,
#     'learning_rate': 1e-4,
#     'betas': (0.9, 0.999),
#     'weight_decay': 1e-4,
#     'metric_weights': {
#         'iou': 0.25,
#         'f1': 0.25,
#         'precision': 0.25,
#         'recall': 0.25
#     },
#     'aux_weight': 0.4,
#     'num_classes': 1,
#     # Removed 'rank' as it's specific to LoRA
#     'batch_size': 2
# }

# # Update the random seed
# set_seed(42)  # 42 is an arbitrarily chosen seed value

# # Build the paths from dataset_name
# image_dir = os.path.join(CONFIG['data_base_dir'], CONFIG['dataset_name'], 'images')
# mask_dir = os.path.join(CONFIG['data_base_dir'], CONFIG['dataset_name'], 'masks')
# train_txt = os.path.join(CONFIG['data_base_dir'], CONFIG['dataset_name'], 'train.txt')
# val_txt = os.path.join(CONFIG['data_base_dir'], CONFIG['dataset_name'], 'val.txt')

# # Define a forward method that can also return the intermediate layers
# def forward_inter(self, x: torch.Tensor) -> torch.Tensor:
#     x = self.patch_embed(x)
#     if self.pos_embed is not None:
#         x = x + self.pos_embed
#     inter_features = []
#     for blk in self.blocks:
#         x = blk(x)
#         inter_features.append(x)

#     x = self.neck(x.permute(0, 3, 1, 2))
#     return x, inter_features

# # Update the log and save directories
# log_dir = os.path.join(CONFIG['log_dir_base'], CONFIG['dataset_name'])
# save_dir = os.path.join(CONFIG['save_dir_base'], CONFIG['dataset_name'])
# os.makedirs(log_dir, exist_ok=True)
# os.makedirs(save_dir, exist_ok=True)

# # Configure logging
# logging.basicConfig(
#     filename=os.path.join(log_dir, CONFIG['log_file']),
#     level=logging.INFO,
#     format='%(asctime)s - %(levelname)s - %(message)s'
# )

# # Load the SAM model
# sam_model = sam_model_registry[CONFIG['model_type']](checkpoint=CONFIG['checkpoint'])
# sam_model.image_encoder.forward_inter = MethodType(forward_inter, sam_model.image_encoder)
# sam_model.to(CONFIG['device'])

# # Freeze the SAM model parameters
# for param in sam_model.parameters():
#     param.requires_grad = False

# # Read the training-set and validation-set lists
# def read_split_files(file_path):
#     with open(file_path, 'r') as f:
#         file_names = f.read().strip().split('\n')
#     return file_names

# # Dataset loading
# class SegmentationDataset(Dataset):
#     def __init__(self, image_dir, mask_dir, sam_model, file_list, mask_size=(1024, 1024), device='cpu'):
#         self.image_dir = image_dir
#         self.mask_dir = mask_dir
#         self.sam_model = sam_model
#         self.mask_size = mask_size
#         self.device = device
#         self.image_files = [f for f in os.listdir(image_dir) if f.endswith('.png') and f.replace('.png', '') in file_list]

#     def __len__(self):
#         return len(self.image_files)

#     def __getitem__(self, idx):
#         image_file = self.image_files[idx]
#         image_path = os.path.join(self.image_dir, image_file)
#         image = cv2.imread(image_path)
#         image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
#         image = cv2.resize(image, (1024, 1024), interpolation=cv2.INTER_NEAREST)

#         mask_file = image_file
#         mask_path = os.path.join(self.mask_dir, mask_file)
#         mask = cv2.imread(mask_path, cv2.IMREAD_UNCHANGED)
#         mask = cv2.resize(mask, self.mask_size, interpolation=cv2.INTER_NEAREST)

#         input_image_torch = torch.as_tensor(image, dtype=torch.float32).to(self.device)
#         input_image_torch = input_image_torch.permute(2, 0, 1).contiguous()  # [C, H, W]

#         input_image = self.sam_model.preprocess(input_image_torch.to(self.device))

#         mask = torch.as_tensor(mask, dtype=torch.float32).to(self.device)  # single-channel float

#         return input_image, mask

# # Read the file lists
# train_files = read_split_files(train_txt)
# val_files = read_split_files(val_txt)

# # Create the datasets and data loaders
# train_dataset = SegmentationDataset(
#     image_dir=image_dir,
#     mask_dir=mask_dir,
#     sam_model=sam_model,
#     file_list=train_files,
#     device=CONFIG['device']
# )

# val_dataset = SegmentationDataset(
#     image_dir=image_dir,
#     mask_dir=mask_dir,
#     sam_model=sam_model,
#     file_list=val_files,
#     device=CONFIG['device']
# )

# train_loader = DataLoader(
#     train_dataset,
#     batch_size=CONFIG['batch_size'],
#     shuffle=True
# )

# val_loader = DataLoader(
#     val_dataset,
#     batch_size=CONFIG['batch_size'],
#     shuffle=False
# )

# def process_class_logits(class_logits):
#     probs = torch.sigmoid(class_logits)
#     binary_masks = (probs > 0.5).cpu().numpy().astype(np.uint8)  # [B,1,1024,1024]
#     batch_results = []
#     for batch_idx in range(binary_masks.shape[0]):
#         current_mask = binary_masks[batch_idx, 0, :, :]
#         num_labels, labels = cv2.connectedComponents(current_mask)
#         sample_results = []
#         for label in range(1, num_labels):
#             current_component = (labels == label).astype(np.uint8)
#             y_coords, x_coords = np.nonzero(current_component)

#             min_x, max_x = np.min(x_coords), np.max(x_coords)
#             min_y, max_y = np.min(y_coords), np.max(y_coords)

#             mask = np.zeros_like(current_component, dtype=np.uint8)
#             mask[min_y:max_y+1, min_x:max_x+1] = current_component[min_y:max_y+1, min_x:max_x+1]

#             mask_resized = cv2.resize(mask, (256, 256), interpolation=cv2.INTER_NEAREST)

#             sample_results.append({
#                 'bbox': [min_x, min_y, max_x, max_y],
#                 'mask': mask_resized
#             })
#         batch_results.append(sample_results)
#     return batch_results

# def predict_masks_batch(sam_model, image_embeddings, batch_results, device='cuda'):
#     sam_model.eval()
#     final_predictions = []
#     for idx, sample_results in enumerate(batch_results):
#         if not sample_results:
#             final_predictions.append(torch.zeros((1, 1, 1024, 1024), device=device))
#             continue

#         current_image_embedding = image_embeddings[idx:idx+1]

#         sparse_embeddings_list = []
#         dense_embeddings_list = []

#         for mask_info in sample_results:
#             box = torch.tensor(mask_info['bbox'], dtype=torch.float, device=device).unsqueeze(0)
#             mask = torch.from_numpy(mask_info['mask']).float().to(device).unsqueeze(0).unsqueeze(0)

#             sparse_embeddings, dense_embeddings = sam_model.prompt_encoder(
#                 points=None,
#                 boxes=box,
#                 masks=mask
#                 #masks=None
#             )

#             sparse_embeddings_list.append(sparse_embeddings)
#             dense_embeddings_list.append(dense_embeddings)

#         if len(sparse_embeddings_list) == 0:
#             final_predictions.append(torch.zeros((1, 1, 1024, 1024), device=device))
#             continue

#         sparse_embeddings_all = torch.cat(sparse_embeddings_list, dim=0)
#         dense_embeddings_all = torch.cat(dense_embeddings_list, dim=0)

#         low_res_masks, _ = sam_model.mask_decoder(
#             image_embeddings=current_image_embedding,
#             image_pe=sam_model.prompt_encoder.get_dense_pe(),
#             sparse_prompt_embeddings=sparse_embeddings_all,
#             dense_prompt_embeddings=dense_embeddings_all,
#             multimask_output=False,
#         )

#         resized_masks = F.interpolate(
#             low_res_masks,
#             size=(1024, 1024),
#             mode='bilinear',
#             align_corners=False
#         )

#         merged_mask = torch.max(resized_masks, dim=0)[0]  # [1,1024,1024]

#         final_predictions.append(merged_mask.unsqueeze(0))  # [1,1,1024,1024]

#     return torch.cat(final_predictions, dim=0)  # [B,1,1024,1024]

# def dice_loss(preds, targets, smooth=1e-6):
#     preds = torch.sigmoid(preds)
#     preds = preds.view(preds.size(0), -1)
#     targets = targets.view(targets.size(0), -1)

#     intersection = (preds * targets).sum(dim=1)
#     union = preds.sum(dim=1) + targets.sum(dim=1)

#     dice = (2. * intersection + smooth) / (union + smooth)
#     loss = 1 - dice
#     return loss.mean()

# def compute_loss(
#     seg_head_logits,
#     final_predictions,
#     masks,
#     loss_fn,
#     aux_classifiers=None,
#     inter_features=None,
#     selected_aux_layers=None,
#     aux_weight=0.4,
#     seg_weight=1.0,
#     sam_weight=1.0,
#     loss_weights=[1, 1]
# ):
#     seg_bce = loss_fn(seg_head_logits, masks)
#     seg_dice = dice_loss(seg_head_logits, masks)
#     seg_main_loss = loss_weights[0] * seg_bce + loss_weights[1] * seg_dice

#     sam_bce = loss_fn(final_predictions, masks)
#     sam_dice = dice_loss(final_predictions, masks)
#     sam_main_loss = loss_weights[0] * sam_bce + loss_weights[1] * sam_dice
#     # sam_main_loss = sam_dice

#     total_aux_loss = torch.tensor(0.0, device=seg_head_logits.device)
#     if aux_classifiers is not None and inter_features is not None and selected_aux_layers is not None:
#         aux_losses = []
#         for idx, aux_cls in zip(selected_aux_layers, aux_classifiers):
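#             # NOTE(review): with selected_aux_layers = [5, 11, 17, 23] this reads
#             # inter_features[4], [10], [16], [22], whereas SegmentationHead.forward
#             # indexes inter_features[5], [11], [17], [23] directly; confirm the "- 1" is intended.
#             feature_idx = idx - 1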
#             if feature_idx < len(inter_features):
#                 aux_feat = inter_features[feature_idx]
#                 aux_logits = aux_cls(aux_feat)
#                 loss_aux = loss_fn(aux_logits, masks)
#                 aux_losses.append(loss_aux)
#             else:
#                 logging.warning(f"inter_features does not have index {feature_idx}")

#         if aux_losses:
#             aux_loss_mean = torch.mean(torch.stack(aux_losses))
#             total_aux_loss = aux_loss_mean * aux_weight

#     total_loss = seg_weight * seg_main_loss + sam_weight * sam_main_loss + total_aux_loss

#     return (
#         total_loss,
#         seg_main_loss.item(),
#         sam_main_loss.item(),
#         total_aux_loss.item()
#     )

# def initialize_metrics():
#     return {
#         'tp': 0,
#         'fp': 0,
#         'fn': 0,
#         'intersection': 0,
#         'union': 0
#     }

# def accumulate_metrics(preds, targets, global_metrics, threshold=0.5):
#     preds_binary = (preds > threshold).astype(np.uint8)
#     targets_binary = (targets > threshold).astype(np.uint8)

#     tp = np.logical_and(preds_binary == 1, targets_binary == 1).sum()
#     fp = np.logical_and(preds_binary == 1, targets_binary == 0).sum()
#     fn = np.logical_and(preds_binary == 0, targets_binary == 1).sum()

#     intersection = tp
#     union = np.logical_or(preds_binary, targets_binary).sum()

#     global_metrics['tp'] += tp
#     global_metrics['fp'] += fp
#     global_metrics['fn'] += fn
#     global_metrics['intersection'] += intersection
#     global_metrics['union'] += union

# class AuxiliaryClassifier(nn.Module):
#     def __init__(self, in_channels, num_classes=1):
#         super(AuxiliaryClassifier, self).__init__()
#         self.aux_conv1 = nn.Conv2d(in_channels, 256, kernel_size=3, stride=1, padding=1)
#         self.aux_bn1 = nn.BatchNorm2d(256)
#         self.aux_relu1 = nn.ReLU(inplace=True)
#         self.aux_conv2 = nn.Conv2d(256, num_classes, kernel_size=1, stride=1)

#     def forward(self, x):
#         x = x.permute(0, 3, 1, 2)
#         x = self.aux_conv1(x)
#         x = self.aux_bn1(x)
#         x = self.aux_relu1(x)
#         x = self.aux_conv2(x)
#         x = F.interpolate(x, size=(1024, 1024), mode='bilinear', align_corners=False)
#         return x

# class SegmentationHead(nn.Module):
#     def __init__(self, in_channels, intermediate_channels, out_channels=1, align_corners=False):
#         super(SegmentationHead, self).__init__()
#         self.align_corners = align_corners

#         self.mla_branches = nn.ModuleList([
#             nn.Sequential(
#                 nn.Conv2d(1024, 512, kernel_size=3, padding=1, stride=1),
#                 nn.BatchNorm2d(512),
#                 nn.ReLU(inplace=True),
#                 nn.Conv2d(512, 256, kernel_size=3, padding=1, stride=1),
#                 nn.BatchNorm2d(256),
#                 nn.ReLU(inplace=True)
#             ) for _ in range(4)
#         ])

#         self.mla_image_branch = nn.Sequential(
#             nn.Conv2d(in_channels, 256, kernel_size=1, stride=1, bias=False),
#             nn.BatchNorm2d(256),
#             nn.ReLU(inplace=True)
#         )
#         self.mla_classifier_branch = nn.Sequential(
#             nn.Conv2d(256 * 5, intermediate_channels, kernel_size=3, stride=1, padding=1),
#             nn.BatchNorm2d(intermediate_channels),
#             nn.ReLU(inplace=True),
#             nn.Conv2d(intermediate_channels, out_channels, kernel_size=1, stride=1)
#         )

#     def forward(self, image_embedding, inter_features):
#         if inter_features is None:
#             raise ValueError("inter_features must be provided for MLA strategy")
#         if len(inter_features) < 24:
#             raise ValueError(f"Expected at least 24 inter_features for MLA strategy, but got {len(inter_features)}")

#         selected_features = [inter_features[i] for i in [5, 11, 17, 23]]
#         selected_features = [feat.permute(0, 3, 1, 2) for feat in selected_features]

#         processed_features = []
#         for i, feat in enumerate(selected_features):
#             branch = self.mla_branches[i]
#             x_feat = branch(feat)
#             x_feat = F.interpolate(x_feat, scale_factor=4, mode='bilinear', align_corners=self.align_corners)
#             processed_features.append(x_feat)

#         img_feat = self.mla_image_branch(image_embedding)
#         img_feat = F.interpolate(img_feat, scale_factor=4, mode='bilinear', align_corners=self.align_corners)
#         processed_features.append(img_feat)

#         aggregated = torch.cat(processed_features, dim=1)
#         x = self.mla_classifier_branch(aggregated)
#         x = F.interpolate(x, size=(1024, 1024), mode='bilinear', align_corners=self.align_corners)

#         return x 

# # Initialize the segmentation head
# model = SegmentationHead(
#     in_channels=256,
#     intermediate_channels=256,
#     out_channels=CONFIG['num_classes'],
#     align_corners=False
# )
# model.to(CONFIG['device'])

# # Initialize the auxiliary classifiers (one per entry in selected_aux_layers)
# selected_aux_layers = [5, 11, 17, 23]
# aux_classifiers = nn.ModuleList([
#     AuxiliaryClassifier(in_channels=1024, num_classes=CONFIG['num_classes']).to(CONFIG['device'])
#     for _ in selected_aux_layers
# ])

# # Mark the segmentation-head and auxiliary-classifier parameters as trainable
# for param in model.parameters():
#     param.requires_grad = True

# for aux in aux_classifiers:
#     for param in aux.parameters():
#         param.requires_grad = True

# # Ensure the SAM model parameters are excluded from training
# # (Already done by setting requires_grad=False above)

# # Collect all trainable parameters (segmentation head and auxiliary classifiers only)
# model_trainable_params = list(model.parameters())
# aux_trainable_params = list(aux_classifiers.parameters())
# trainable_params = model_trainable_params + aux_trainable_params

# # Initialize the optimizer
# optimizer = torch.optim.AdamW(
#     trainable_params,
#     lr=CONFIG['learning_rate'],
#     betas=CONFIG['betas'],
#     weight_decay=CONFIG['weight_decay']
# )

# # Loss function
# loss_fn = nn.BCEWithLogitsLoss()

# num_epochs = CONFIG['num_epochs']
# best_composite_score = float('-inf')
# best_epoch = 0

# weights = CONFIG['metric_weights']
# num_classes = CONFIG['num_classes']
# AUX_WEIGHT = CONFIG['aux_weight']

# warmup_epochs = 3
# min_lr_factor = 0.01

# # Learning-rate schedule: linear warmup over warmup_epochs, then cosine decay toward min_lr_factor
# def lr_lambda(epoch):
#     if epoch < warmup_epochs:
#         return float((epoch + 1) / warmup_epochs)
#     else:
#         cosine_decay = 0.5 * (1 + math.cos((epoch - warmup_epochs) * math.pi / (num_epochs - warmup_epochs)))
#         return float(min_lr_factor + (1 - min_lr_factor) * cosine_decay)

# scheduler = lr_scheduler.LambdaLR(optimizer, lr_lambda=lr_lambda)

# for epoch in range(num_epochs):
#     model.train()
#     aux_classifiers.train()

#     total_loss = 0
#     num_batches = 0

#     for images, masks in tqdm(train_loader, desc=f"Epoch {epoch + 1}/{num_epochs} [Train]"):
#         images = images.to(CONFIG['device'])
#         masks = masks.to(CONFIG['device']).unsqueeze(1)

#         if images.dim() != 4 or masks.dim() != 4:
#             logging.error(f"Invalid input dimensions: images {images.shape}, masks {masks.shape}")
#             continue

#         with torch.no_grad():
#             image_embedding, inter_features = sam_model.image_encoder.forward_inter(images)
#         seg_head_logits = model(image_embedding, inter_features)

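#         # NOTE(review): process_class_logits thresholds the logits and goes through
#         # NumPy, and image_embedding was computed under torch.no_grad(), so the SAM
#         # branch has no gradient path back to the segmentation head. If SAM is fully
#         # frozen, the sam_* loss terms only change the reported loss value; confirm intent.
#         prompts = process_class_logits(seg_head_logits)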
#         final_predictions = predict_masks_batch(sam_model, image_embedding, prompts, CONFIG['device'])

#         loss, seg_loss_val, sam_loss_val, aux_loss_val = compute_loss(
#             seg_head_logits=seg_head_logits,
#             final_predictions=final_predictions,
#             masks=masks,
#             loss_fn=loss_fn,
#             aux_classifiers=aux_classifiers,
#             inter_features=inter_features,
#             selected_aux_layers=selected_aux_layers,
#             aux_weight=AUX_WEIGHT,
#             seg_weight=1.0,
#             sam_weight=1.0,
#             loss_weights=[1,1]
#         )

#         optimizer.zero_grad()
#         loss.backward()
#         optimizer.step()

#         total_loss += loss.item()
#         num_batches += 1

#     avg_train_loss = total_loss / num_batches if num_batches > 0 else 0

#     # Validation phase
#     model.eval()
#     aux_classifiers.eval()

#     val_loss = 0
#     num_val_batches = 0
#     global_metrics_val = initialize_metrics()       # 最终预测指标
#     global_metrics_val_seg = initialize_metrics()   # 分割头输出指标

#     with torch.no_grad():
#         for images, masks in tqdm(val_loader, desc=f"Epoch {epoch + 1}/{num_epochs} [Validation]"):
#             images = images.to(CONFIG['device'])
#             masks = masks.to(CONFIG['device']).unsqueeze(1)

#             if images.dim() != 4 or masks.dim() != 4:
#                 logging.error(f"Invalid input dimensions: images {images.shape}, masks {masks.shape}")
#                 continue

#             image_embedding, inter_features = sam_model.image_encoder.forward_inter(images)
#             seg_head_logits = model(image_embedding, inter_features)

#             prompts = process_class_logits(seg_head_logits)
#             final_predictions = predict_masks_batch(sam_model, image_embedding, prompts, CONFIG['device'])

#             loss, seg_loss_val, sam_loss_val, _ = compute_loss(
#                 seg_head_logits=seg_head_logits,
#                 final_predictions=final_predictions,
#                 masks=masks,
#                 loss_fn=loss_fn,
#                 aux_classifiers=None,
#                 inter_features=None,
#                 selected_aux_layers=None,
#                 aux_weight=0.0,
#                 seg_weight=1.0,
#                 sam_weight=1.0,
#                 loss_weights=[1,1]
#             )
#             val_loss += loss.item()
#             num_val_batches += 1

#             # 计算分割头输出指标
#             preds_seg = torch.sigmoid(seg_head_logits).cpu().numpy()  # 分割头输出的预测
#             preds_final = torch.sigmoid(final_predictions).cpu().numpy()  # 最终预测结果
#             masks_np = masks.cpu().numpy()

#             for p_seg, p_final, m_gt in zip(preds_seg, preds_final, masks_np):
#                 # 分割头输出指标
#                 accumulate_metrics(p_seg[0], m_gt, global_metrics_val_seg)
#                 # 最终预测指标
#                 accumulate_metrics(p_final[0], m_gt, global_metrics_val)

#     # Compute evaluation metrics (segmentation-head output)
#     tp_seg = global_metrics_val_seg['tp']
#     fp_seg = global_metrics_val_seg['fp']
#     fn_seg = global_metrics_val_seg['fn']
#     intersection_seg = global_metrics_val_seg['intersection']
#     union_seg = global_metrics_val_seg['union']

#     iou_seg = intersection_seg / (union_seg + 1e-6)
#     precision_seg = tp_seg / (tp_seg + fp_seg + 1e-6)
#     recall_seg = tp_seg / (tp_seg + fn_seg + 1e-6)
#     f1_seg = (2 * precision_seg * recall_seg) / (precision_seg + recall_seg + 1e-6)

#     # Compute evaluation metrics (final prediction)
#     tp = global_metrics_val['tp']
#     fp = global_metrics_val['fp']
#     fn = global_metrics_val['fn']
#     intersection = global_metrics_val['intersection']
#     union = global_metrics_val['union']

#     iou = intersection / (union + 1e-6)
#     precision = tp / (tp + fp + 1e-6)
#     recall = tp / (tp + fn + 1e-6)
#     f1 = (2 * precision * recall) / (precision + recall + 1e-6)

#     avg_iou_seg = iou_seg
#     avg_precision_seg = precision_seg
#     avg_recall_seg = recall_seg
#     avg_f1_seg = f1_seg

#     avg_iou = iou
#     avg_precision = precision
#     avg_recall = recall
#     avg_f1 = f1

#     composite_score = (
#         avg_iou * weights['iou'] +
#         avg_f1 * weights['f1'] +
#         avg_precision * weights['precision'] +
#         avg_recall * weights['recall']
#     )

#     avg_val_loss = val_loss / num_val_batches if num_val_batches > 0 else 0

#     log_message = (
#         f"Epoch [{epoch + 1}/{num_epochs}], "
#         f"Train Loss: {avg_train_loss:.4f}, "
#         f"Val Loss: {avg_val_loss:.4f}, "
#         f"(SegHead) IoU: {avg_iou_seg:.4f}, F1: {avg_f1_seg:.4f}, Precision: {avg_precision_seg:.4f}, Recall: {avg_recall_seg:.4f}, "
#         f"(Final) IoU: {avg_iou:.4f}, F1: {avg_f1:.4f}, Precision: {avg_precision:.4f}, Recall: {avg_recall:.4f}, "
#         f"Composite Score: {composite_score:.4f}, "
#         f"LR: {optimizer.param_groups[0]['lr']:.6f}"
#     )
#     logging.info(log_message)
#     print(log_message)

#     if composite_score > best_composite_score:
#         best_composite_score = composite_score
#         best_epoch = epoch + 1

#         strategy = 'MLA'
#         checkpoint_path = os.path.join(save_dir, f"{CONFIG['save_prefix']}_{strategy}.pth")

#         # Save only the segmentation-head weights
#         torch.save(model.state_dict(), checkpoint_path)

#         save_message = (
#             f"Best model saved at epoch {best_epoch} with Composite Score {best_composite_score:.4f} using {strategy} strategy"
#         )
#         logging.info(save_message)
#         print(save_message)

#     scheduler.step()

# logging.info("训练完成")
# print("训练完成")


In [None]:
# import os
# import random
# import math
# import numpy as np
# import torch
# import torch.nn as nn
# import torch.nn.functional as F
# from torch.utils.data import Dataset, DataLoader
# import cv2
# import logging
# from segment_anything import sam_model_registry
# from tqdm import tqdm
# from types import MethodType
# import torch.optim.lr_scheduler as lr_scheduler

# # Set the random seeds (for reproducible results)
# def set_seed(seed):
#     random.seed(seed)
#     np.random.seed(seed)
#     torch.manual_seed(seed)
#     torch.cuda.manual_seed(seed)
#     torch.backends.cudnn.deterministic = True
#     torch.backends.cudnn.benchmark = False

# # Configuration dictionary
# CONFIG = {
#     'dataset_name': 'CTS-Pore',  # 数据集名称
#     'data_base_dir': 'datasets',  # 数据集基目录
#     'log_dir_base': 'logs',        # 日志基目录
#     'log_file': 'best_model_metrics.log',
#     'save_dir_base': 'logs',       # 模型保存基目录
#     'save_prefix': 'best_model',
#     'model_type': 'vit_l',
#     'checkpoint': 'weights/sam_vit_l_0b3195.pth',
#     'device': torch.device("cuda" if torch.cuda.is_available() else "cpu"),
#     'num_epochs': 100,
#     'learning_rate': 1e-4,
#     'betas': (0.9, 0.999),
#     'weight_decay': 1e-4,
#     'metric_weights': {
#         'iou': 0.25,
#         'f1': 0.25,
#         'precision': 0.25,
#         'recall': 0.25
#     },
#     'aux_weight': 0.4,
#     'num_classes': 1,
#     'batch_size': 2
# }

# # Set the random seed
# set_seed(42)  # 42 是任意选择的种子值

# # Build the dataset paths from dataset_name
# image_dir = os.path.join(CONFIG['data_base_dir'], CONFIG['dataset_name'], 'images')
# mask_dir = os.path.join(CONFIG['data_base_dir'], CONFIG['dataset_name'], 'masks')
# train_txt = os.path.join(CONFIG['data_base_dir'], CONFIG['dataset_name'], 'train.txt')
# val_txt = os.path.join(CONFIG['data_base_dir'], CONFIG['dataset_name'], 'val.txt')

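# # Adapter module: LayerNorm -> down-projection -> GELU -> up-projection,
# # with the result added back to the input (residual bottleneck)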
# class Adapter(nn.Module):
#     def __init__(self, input_dim, reduction_factor=256):
#         super(Adapter, self).__init__()
#         self.down_project = nn.Linear(input_dim, input_dim // reduction_factor)
#         self.activation = nn.GELU()
#         self.up_project = nn.Linear(input_dim // reduction_factor, input_dim)
#         self.layer_norm = nn.LayerNorm(input_dim)
        
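#         # Near-zero initialization (std=1e-3 weights, zero biases) so that the
#         # bottleneck output starts small and forward() is initially close to identity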
#         nn.init.normal_(self.down_project.weight, std=1e-3)
#         nn.init.normal_(self.up_project.weight, std=1e-3)
#         nn.init.zeros_(self.down_project.bias)
#         nn.init.zeros_(self.up_project.bias)

#     def forward(self, x):
#         residual = x
#         x = self.layer_norm(x)
#         x = self.down_project(x)
#         x = self.activation(x)
#         x = self.up_project(x)
#         return x + residual

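# # Replacement forward for SAM's image encoder: applies an adapter after every
# # transformer block and returns (neck output, list of per-block features)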
# def forward_inter(self, x: torch.Tensor) -> torch.Tensor:
#     x = self.patch_embed(x)
#     if self.pos_embed is not None:
#         x = x + self.pos_embed
    
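#     # Lazily create one adapter per transformer block (first call only).
#     # NOTE(review): the setup code further down assigns image_encoder.adapters
#     # explicitly, so this fallback is presumably never taken. Adapters created here
#     # are not moved to any device and would not be known to an already-built optimizer.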
#     if not hasattr(self, 'adapters'):
#         self.adapters = nn.ModuleList([
#             Adapter(input_dim=x.shape[-1]) 
#             for _ in self.blocks
#         ])
        
#         # Freeze every parameter of this module
#         for param in self.parameters():
#             param.requires_grad = False
            
#         # Then unfreeze only the adapter parameters
#         for adapter in self.adapters:
#             for param in adapter.parameters():
#                 param.requires_grad = True
                
#         # If a classifier attribute exists, unfreeze its parameters as well
#         if hasattr(self, 'classifier'):
#             for param in self.classifier.parameters():
#                 param.requires_grad = True

#     inter_features = []
#     for i, blk in enumerate(self.blocks):
#         x = blk(x)
#         # Apply the adapter after each transformer block
#         x = self.adapters[i](x)
#         inter_features.append(x)

#     x = self.neck(x.permute(0, 3, 1, 2))
#     return x, inter_features

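# # Segmentation head: aggregates encoder features at indices 5, 11, 17, 23
# # together with the image embedding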
# class SegmentationHead(nn.Module):
#     def __init__(self, in_channels, intermediate_channels, out_channels=1, align_corners=False):
#         super(SegmentationHead, self).__init__()
#         self.align_corners = align_corners

#         self.mla_branches = nn.ModuleList([
#             nn.Sequential(
#                 nn.Conv2d(1024, 512, kernel_size=3, padding=1, stride=1),
#                 nn.BatchNorm2d(512),
#                 nn.ReLU(inplace=True),
#                 nn.Conv2d(512, 256, kernel_size=3, padding=1, stride=1),
#                 nn.BatchNorm2d(256),
#                 nn.ReLU(inplace=True)
#             ) for _ in range(4)
#         ])

#         self.mla_image_branch = nn.Sequential(
#             nn.Conv2d(in_channels, 256, kernel_size=1, stride=1, bias=False),
#             nn.BatchNorm2d(256),
#             nn.ReLU(inplace=True)
#         )
#         self.mla_classifier_branch = nn.Sequential(
#             nn.Conv2d(256 * 5, intermediate_channels, kernel_size=3, stride=1, padding=1),
#             nn.BatchNorm2d(intermediate_channels),
#             nn.ReLU(inplace=True),
#             nn.Conv2d(intermediate_channels, out_channels, kernel_size=1, stride=1)
#         )

#     def forward(self, image_embedding, inter_features):
#         if inter_features is None:
#             raise ValueError("inter_features must be provided for MLA strategy")
#         if len(inter_features) < 24:
#             raise ValueError(f"Expected at least 24 inter_features for MLA strategy, but got {len(inter_features)}")

#         selected_features = [inter_features[i] for i in [5, 11, 17, 23]]
#         selected_features = [feat.permute(0, 3, 1, 2) for feat in selected_features]

#         processed_features = []
#         for i, feat in enumerate(selected_features):
#             branch = self.mla_branches[i]
#             x_feat = branch(feat)
#             x_feat = F.interpolate(x_feat, scale_factor=4, mode='bilinear', align_corners=self.align_corners)
#             processed_features.append(x_feat)

#         img_feat = self.mla_image_branch(image_embedding)
#         img_feat = F.interpolate(img_feat, scale_factor=4, mode='bilinear', align_corners=self.align_corners)
#         processed_features.append(img_feat)

#         aggregated = torch.cat(processed_features, dim=1)
#         x = self.mla_classifier_branch(aggregated)
#         x = F.interpolate(x, size=(1024, 1024), mode='bilinear', align_corners=self.align_corners)

#         return x 

# # Auxiliary classifier
# class AuxiliaryClassifier(nn.Module):
#     def __init__(self, in_channels, num_classes=1):
#         super(AuxiliaryClassifier, self).__init__()
#         self.aux_conv1 = nn.Conv2d(in_channels, 256, kernel_size=3, stride=1, padding=1)
#         self.aux_bn1 = nn.BatchNorm2d(256)
#         self.aux_relu1 = nn.ReLU(inplace=True)
#         self.aux_conv2 = nn.Conv2d(256, num_classes, kernel_size=1, stride=1)

#     def forward(self, x):
#         x = x.permute(0, 3, 1, 2)  # 调整形状为 [B, C, H, W]
#         x = self.aux_conv1(x)
#         x = self.aux_bn1(x)
#         x = self.aux_relu1(x)
#         x = self.aux_conv2(x)
#         x = F.interpolate(x, size=(1024, 1024), mode='bilinear', align_corners=False)
#         return x

# # Read the list of sample names from a split file (one name per line)
# def read_split_files(file_path):
#     with open(file_path, 'r') as f:
#         file_names = f.read().strip().split('\n')
#     return file_names

# # Dataset
# class SegmentationDataset(Dataset):
#     def __init__(self, image_dir, mask_dir, sam_model, file_list, mask_size=(1024, 1024), device='cpu'):
#         self.image_dir = image_dir
#         self.mask_dir = mask_dir
#         self.sam_model = sam_model
#         self.mask_size = mask_size
#         self.device = device
#         self.image_files = [f for f in os.listdir(image_dir) if f.endswith('.png') and f.replace('.png', '') in file_list]

#     def __len__(self):
#         return len(self.image_files)

#     def __getitem__(self, idx):
#         image_file = self.image_files[idx]
#         image_path = os.path.join(self.image_dir, image_file)
#         image = cv2.imread(image_path)
#         image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
#         image = cv2.resize(image, (1024, 1024), interpolation=cv2.INTER_NEAREST)

#         mask_file = image_file
#         mask_path = os.path.join(self.mask_dir, mask_file)
#         mask = cv2.imread(mask_path, cv2.IMREAD_UNCHANGED)
#         mask = cv2.resize(mask, self.mask_size, interpolation=cv2.INTER_NEAREST)

#         input_image_torch = torch.as_tensor(image, dtype=torch.float32).to(self.device)
#         input_image_torch = input_image_torch.permute(2, 0, 1).contiguous()  # [C, H, W]

#         input_image = self.sam_model.preprocess(input_image_torch.to(self.device))

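#         # NOTE(review): mask values are used as loaded (no rescaling or binarization);
#         # if masks are stored as 0/255 the BCE/Dice targets would not be in [0, 1]. Confirm.
#         mask = torch.as_tensor(mask, dtype=torch.float32).to(self.device)  # single-channel float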

#         return input_image, mask

# # Turn segmentation-head logits into per-connected-component box + mask prompts
# def process_class_logits(class_logits):
#     probs = torch.sigmoid(class_logits)
#     binary_masks = (probs > 0.5).cpu().numpy().astype(np.uint8)  # [B,1,1024,1024]
#     batch_results = []
#     for batch_idx in range(binary_masks.shape[0]):
#         current_mask = binary_masks[batch_idx, 0, :, :]
#         num_labels, labels = cv2.connectedComponents(current_mask)
#         sample_results = []
#         for label in range(1, num_labels):
#             current_component = (labels == label).astype(np.uint8)
#             y_coords, x_coords = np.nonzero(current_component)

#             min_x, max_x = np.min(x_coords), np.max(x_coords)
#             min_y, max_y = np.min(y_coords), np.max(y_coords)

#             mask = np.zeros_like(current_component, dtype=np.uint8)
#             mask[min_y:max_y+1, min_x:max_x+1] = current_component[min_y:max_y+1, min_x:max_x+1]

#             mask_resized = cv2.resize(mask, (256, 256), interpolation=cv2.INTER_NEAREST)

#             sample_results.append({
#                 'bbox': [min_x, min_y, max_x, max_y],
#                 'mask': mask_resized
#             })
#         batch_results.append(sample_results)
#     return batch_results

# # Run SAM's prompt encoder and mask decoder on each sample's prompts and merge the results
# def predict_masks_batch(sam_model, image_embeddings, batch_results, device='cuda'):
#     sam_model.eval()
#     final_predictions = []
#     for idx, sample_results in enumerate(batch_results):
#         if not sample_results:
#             final_predictions.append(torch.zeros((1, 1, 1024, 1024), device=device))
#             continue

#         current_image_embedding = image_embeddings[idx:idx+1]

#         sparse_embeddings_list = []
#         dense_embeddings_list = []

#         for mask_info in sample_results:
#             box = torch.tensor(mask_info['bbox'], dtype=torch.float, device=device).unsqueeze(0)
#             mask = torch.from_numpy(mask_info['mask']).float().to(device).unsqueeze(0).unsqueeze(0)

#             sparse_embeddings, dense_embeddings = sam_model.prompt_encoder(
#                 points=None,
#                 boxes=box,
#                 masks=mask
#             )

#             sparse_embeddings_list.append(sparse_embeddings)
#             dense_embeddings_list.append(dense_embeddings)

#         if len(sparse_embeddings_list) == 0:
#             final_predictions.append(torch.zeros((1, 1, 1024, 1024), device=device))
#             continue

#         sparse_embeddings_all = torch.cat(sparse_embeddings_list, dim=0)
#         dense_embeddings_all = torch.cat(dense_embeddings_list, dim=0)

#         low_res_masks, _ = sam_model.mask_decoder(
#             image_embeddings=current_image_embedding,
#             image_pe=sam_model.prompt_encoder.get_dense_pe(),
#             sparse_prompt_embeddings=sparse_embeddings_all,
#             dense_prompt_embeddings=dense_embeddings_all,
#             multimask_output=False,
#         )

#         resized_masks = F.interpolate(
#             low_res_masks,
#             size=(1024, 1024),
#             mode='bilinear',
#             align_corners=False
#         )

#         merged_mask = torch.max(resized_masks, dim=0)[0]  # [1,1024,1024]

#         final_predictions.append(merged_mask.unsqueeze(0))  # [1,1,1024,1024]

#     return torch.cat(final_predictions, dim=0)  # [B,1,1024,1024]

# # Dice loss
# def dice_loss(preds, targets, smooth=1e-6):
#     preds = torch.sigmoid(preds)
#     preds = preds.view(preds.size(0), -1)
#     targets = targets.view(targets.size(0), -1)

#     intersection = (preds * targets).sum(dim=1)
#     union = preds.sum(dim=1) + targets.sum(dim=1)

#     dice = (2. * intersection + smooth) / (union + smooth)
#     loss = 1 - dice
#     return loss.mean()

# # Compute the combined loss (segmentation head + SAM prediction + weighted auxiliary loss)
# def compute_loss(
#     seg_head_logits,
#     final_predictions,
#     masks,
#     loss_fn,
#     aux_classifiers=None,
#     inter_features=None,
#     selected_aux_layers=None,
#     aux_weight=0.4,
#     seg_weight=1.0,
#     sam_weight=1.0,
#     loss_weights=[1, 1]
# ):
#     seg_bce = loss_fn(seg_head_logits, masks)
#     seg_dice = dice_loss(seg_head_logits, masks)
#     seg_main_loss = loss_weights[0] * seg_bce + loss_weights[1] * seg_dice

#     sam_bce = loss_fn(final_predictions, masks)
#     sam_dice = dice_loss(final_predictions, masks)
#     sam_main_loss = loss_weights[0] * sam_bce + loss_weights[1] * sam_dice

#     total_aux_loss = torch.tensor(0.0, device=seg_head_logits.device)
#     if aux_classifiers is not None and inter_features is not None and selected_aux_layers is not None:
#         aux_losses = []
#         for idx, aux_cls in zip(selected_aux_layers, aux_classifiers):
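#             # NOTE(review): with selected_aux_layers = [5, 11, 17, 23] this reads
#             # inter_features[4], [10], [16], [22], whereas SegmentationHead.forward
#             # indexes inter_features[5], [11], [17], [23] directly; confirm the "- 1" is intended.
#             feature_idx = idx - 1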
#             if feature_idx < len(inter_features):
#                 aux_feat = inter_features[feature_idx]
#                 aux_logits = aux_cls(aux_feat)
#                 loss_aux = loss_fn(aux_logits, masks)
#                 aux_losses.append(loss_aux)
#             else:
#                 logging.warning(f"inter_features does not have index {feature_idx}")

#         if aux_losses:
#             aux_loss_mean = torch.mean(torch.stack(aux_losses))
#             total_aux_loss = aux_loss_mean * aux_weight

#     total_loss = seg_weight * seg_main_loss + sam_weight * sam_main_loss + total_aux_loss

#     return (
#         total_loss,
#         seg_main_loss.item(),
#         sam_main_loss.item(),
#         total_aux_loss.item()
#     )

# # Initialize the metric accumulators
# def initialize_metrics():
#     return {
#         'tp': 0,
#         'fp': 0,
#         'fn': 0,
#         'intersection': 0,
#         'union': 0
#     }

# # Accumulate TP/FP/FN and intersection/union counts
# def accumulate_metrics(preds, targets, global_metrics, threshold=0.5):
#     preds_binary = (preds > threshold).astype(np.uint8)
#     targets_binary = (targets > threshold).astype(np.uint8)

#     tp = np.logical_and(preds_binary == 1, targets_binary == 1).sum()
#     fp = np.logical_and(preds_binary == 1, targets_binary == 0).sum()
#     fn = np.logical_and(preds_binary == 0, targets_binary == 1).sum()

#     intersection = tp
#     union = np.logical_or(preds_binary, targets_binary).sum()

#     global_metrics['tp'] += tp
#     global_metrics['fp'] += fp
#     global_metrics['fn'] += fn
#     global_metrics['intersection'] += intersection
#     global_metrics['union'] += union

# # Initialize the segmentation head
# segmentation_head = SegmentationHead(
#     in_channels=256,
#     intermediate_channels=256,
#     out_channels=CONFIG['num_classes'],
#     align_corners=False
# )
# segmentation_head.to(CONFIG['device'])

# # Initialize the auxiliary classifiers (one per entry in selected_aux_layers)
# selected_aux_layers = [5, 11, 17, 23]
# auxiliary_classifiers = nn.ModuleList([
#     AuxiliaryClassifier(in_channels=1024, num_classes=CONFIG['num_classes']).to(CONFIG['device'])
#     for _ in selected_aux_layers
# ])

# # Load the SAM model and bind forward_inter to its image_encoder
# sam_model = sam_model_registry[CONFIG['model_type']](checkpoint=CONFIG['checkpoint'])
# sam_model.image_encoder.forward_inter = MethodType(forward_inter, sam_model.image_encoder)
# sam_model.to(CONFIG['device'])

# # Manually initialize the adapters
# # Determine the embedding dimension from a dummy patch_embed pass
# dummy_input = torch.zeros(1, 3, 1024, 1024).to(CONFIG['device'])
# with torch.no_grad():
#     x = sam_model.image_encoder.patch_embed(dummy_input)
#     input_dim = x.shape[-1]

# # Create the list of adapters (one per transformer block)
# sam_model.image_encoder.adapters = nn.ModuleList([
#       Adapter(
#         input_dim=input_dim,
#         reduction_factor=256
#     ) 
#     for _ in sam_model.image_encoder.blocks
# ])
# sam_model.image_encoder.adapters.to(CONFIG['device'])

# # Freeze every SAM parameter (this also freezes the adapters, which are re-enabled just below)
# for param in sam_model.parameters():
#     param.requires_grad = False

# # Make sure the adapter parameters are trainable
# for adapter in sam_model.image_encoder.adapters:
#     for param in adapter.parameters():
#         param.requires_grad = True
# # Collect all trainable parameters: segmentation head, auxiliary classifiers, and adapters
# trainable_params = list(segmentation_head.parameters()) + list(auxiliary_classifiers.parameters()) + list(sam_model.image_encoder.adapters.parameters())

# # Initialize the optimizer
# optimizer = torch.optim.AdamW(
#     trainable_params,
#     lr=CONFIG['learning_rate'],
#     betas=CONFIG['betas'],
#     weight_decay=CONFIG['weight_decay']
# )

# # Loss function
# loss_fn = nn.BCEWithLogitsLoss()

# # Read the training and validation split lists
# train_files = read_split_files(train_txt)
# val_files = read_split_files(val_txt)

# # Create the datasets and data loaders
# train_dataset = SegmentationDataset(
#     image_dir=image_dir,
#     mask_dir=mask_dir,
#     sam_model=sam_model,
#     file_list=train_files,
#     device=CONFIG['device']
# )

# val_dataset = SegmentationDataset(
#     image_dir=image_dir,
#     mask_dir=mask_dir,
#     sam_model=sam_model,
#     file_list=val_files,
#     device=CONFIG['device']
# )

# train_loader = DataLoader(
#     train_dataset,
#     batch_size=CONFIG['batch_size'],
#     shuffle=True
# )

# val_loader = DataLoader(
#     val_dataset,
#     batch_size=CONFIG['batch_size'],
#     shuffle=False
# )

# # Configure the logging and checkpoint directories
# log_dir = os.path.join(CONFIG['log_dir_base'], CONFIG['dataset_name'])
# save_dir = os.path.join(CONFIG['save_dir_base'], CONFIG['dataset_name'])
# os.makedirs(log_dir, exist_ok=True)
# os.makedirs(save_dir, exist_ok=True)

# logging.basicConfig(
#     filename=os.path.join(log_dir, CONFIG['log_file']),
#     level=logging.INFO,
#     format='%(asctime)s - %(levelname)s - %(message)s'
# )

# # Learning-rate scheduler (linear warmup followed by cosine decay)
# num_epochs = CONFIG['num_epochs']
# warmup_epochs = 3
# min_lr_factor = 0.01

# def lr_lambda(epoch):
#     if epoch < warmup_epochs:
#         return float((epoch + 1) / warmup_epochs)
#     else:
#         cosine_decay = 0.5 * (1 + math.cos((epoch - warmup_epochs) * math.pi / (num_epochs - warmup_epochs)))
#         return float(min_lr_factor + (1 - min_lr_factor) * cosine_decay)

# scheduler = lr_scheduler.LambdaLR(optimizer, lr_lambda=lr_lambda)

# # Initialize the best composite score
# best_composite_score = float('-inf')
# best_epoch = 0

# weights = CONFIG['metric_weights']
# AUX_WEIGHT = CONFIG['aux_weight']

# # Training loop
# for epoch in range(num_epochs):
#     segmentation_head.train()
#     auxiliary_classifiers.train()
#     sam_model.train()

#     total_loss = 0
#     num_batches = 0

#     for images, masks in tqdm(train_loader, desc=f"Epoch {epoch + 1}/{num_epochs} [Train]"):
#         images = images.to(CONFIG['device'])
#         masks = masks.to(CONFIG['device']).unsqueeze(1)

#         if images.dim() != 4 or masks.dim() != 4:
#             logging.error(f"Invalid input dimensions: images {images.shape}, masks {masks.shape}")
#             continue

#         with torch.no_grad():
#             image_embedding, inter_features = sam_model.image_encoder.forward_inter(images)
#         seg_head_logits = segmentation_head(image_embedding, inter_features)

#         prompts = process_class_logits(seg_head_logits)
#         final_predictions = predict_masks_batch(sam_model, image_embedding, prompts, CONFIG['device'])

#         loss, seg_loss_val, sam_loss_val, aux_loss_val = compute_loss(
#             seg_head_logits=seg_head_logits,
#             final_predictions=final_predictions,
#             masks=masks,
#             loss_fn=loss_fn,
#             aux_classifiers=auxiliary_classifiers,
#             inter_features=inter_features,
#             selected_aux_layers=selected_aux_layers,
#             aux_weight=AUX_WEIGHT,
#             seg_weight=1.0,
#             sam_weight=1.0,
#             loss_weights=[1,1]
#         )

#         optimizer.zero_grad()
#         loss.backward()
#         optimizer.step()

#         total_loss += loss.item()
#         num_batches += 1

#     avg_train_loss = total_loss / num_batches if num_batches > 0 else 0

#     # Validation phase
#     segmentation_head.eval()
#     auxiliary_classifiers.eval()
#     sam_model.eval()

#     val_loss = 0
#     num_val_batches = 0
#     global_metrics_val = initialize_metrics()       # metrics for the final (SAM) predictions
#     global_metrics_val_seg = initialize_metrics()   # metrics for the segmentation-head output

#     with torch.no_grad():
#         for images, masks in tqdm(val_loader, desc=f"Epoch {epoch + 1}/{num_epochs} [Validation]"):
#             images = images.to(CONFIG['device'])
#             masks = masks.to(CONFIG['device']).unsqueeze(1)

#             if images.dim() != 4 or masks.dim() != 4:
#                 logging.error(f"Invalid input dimensions: images {images.shape}, masks {masks.shape}")
#                 continue

#             image_embedding, inter_features = sam_model.image_encoder.forward_inter(images)
#             seg_head_logits = segmentation_head(image_embedding, inter_features)

#             prompts = process_class_logits(seg_head_logits)
#             final_predictions = predict_masks_batch(sam_model, image_embedding, prompts, CONFIG['device'])

#             loss, seg_loss_val, sam_loss_val, _ = compute_loss(
#                 seg_head_logits=seg_head_logits,
#                 final_predictions=final_predictions,
#                 masks=masks,
#                 loss_fn=loss_fn,
#                 aux_classifiers=None,
#                 inter_features=None,
#                 selected_aux_layers=None,
#                 aux_weight=0.0,
#                 seg_weight=1.0,
#                 sam_weight=1.0,
#                 loss_weights=[1,1]
#             )
#             val_loss += loss.item()
#             num_val_batches += 1

#             # Convert both outputs to probabilities for metric computation
#             preds_seg = torch.sigmoid(seg_head_logits).cpu().numpy()  # segmentation-head predictions
#             preds_final = torch.sigmoid(final_predictions).cpu().numpy()  # final predictions
#             masks_np = masks.cpu().numpy()

#             for p_seg, p_final, m_gt in zip(preds_seg, preds_final, masks_np):
#                 # Segmentation-head output metrics
#                 accumulate_metrics(p_seg[0], m_gt, global_metrics_val_seg)
#                 # Final prediction metrics
#                 accumulate_metrics(p_final[0], m_gt, global_metrics_val)

#     # Compute evaluation metrics (segmentation-head output)
#     tp_seg = global_metrics_val_seg['tp']
#     fp_seg = global_metrics_val_seg['fp']
#     fn_seg = global_metrics_val_seg['fn']
#     intersection_seg = global_metrics_val_seg['intersection']
#     union_seg = global_metrics_val_seg['union']

#     iou_seg = intersection_seg / (union_seg + 1e-6)
#     precision_seg = tp_seg / (tp_seg + fp_seg + 1e-6)
#     recall_seg = tp_seg / (tp_seg + fn_seg + 1e-6)
#     f1_seg = (2 * precision_seg * recall_seg) / (precision_seg + recall_seg + 1e-6)

#     # Compute evaluation metrics (final predictions)
#     tp = global_metrics_val['tp']
#     fp = global_metrics_val['fp']
#     fn = global_metrics_val['fn']
#     intersection = global_metrics_val['intersection']
#     union = global_metrics_val['union']

#     iou = intersection / (union + 1e-6)
#     precision = tp / (tp + fp + 1e-6)
#     recall = tp / (tp + fn + 1e-6)
#     f1 = (2 * precision * recall) / (precision + recall + 1e-6)

#     avg_iou_seg = iou_seg
#     avg_precision_seg = precision_seg
#     avg_recall_seg = recall_seg
#     avg_f1_seg = f1_seg

#     avg_iou = iou
#     avg_precision = precision
#     avg_recall = recall
#     avg_f1 = f1

#     composite_score = (
#         avg_iou * weights['iou'] +
#         avg_f1 * weights['f1'] +
#         avg_precision * weights['precision'] +
#         avg_recall * weights['recall']
#     )

#     avg_val_loss = val_loss / num_val_batches if num_val_batches > 0 else 0

#     log_message = (
#         f"Epoch [{epoch + 1}/{num_epochs}], "
#         f"Train Loss: {avg_train_loss:.4f}, "
#         f"Val Loss: {avg_val_loss:.4f}, "
#         f"(SegHead) IoU: {avg_iou_seg:.4f}, F1: {avg_f1_seg:.4f}, Precision: {avg_precision_seg:.4f}, Recall: {avg_recall_seg:.4f}, "
#         f"(Final) IoU: {avg_iou:.4f}, F1: {avg_f1:.4f}, Precision: {avg_precision:.4f}, Recall: {avg_recall:.4f}, "
#         f"Composite Score: {composite_score:.4f}, "
#         f"LR: {optimizer.param_groups[0]['lr']:.6f}"
#     )
#     logging.info(log_message)
#     print(log_message)

#     # Save the best model
#     if composite_score > best_composite_score:
#         best_composite_score = composite_score
#         best_epoch = epoch + 1

#         strategy = 'MLA'
#         checkpoint_path = os.path.join(save_dir, f"{CONFIG['save_prefix']}_{strategy}.pth")

#         # Save only the segmentation head, auxiliary classifier, and adapter weights
#         save_dict = {
#             'segmentation_head': segmentation_head.state_dict(),
#             'auxiliary_classifiers': auxiliary_classifiers.state_dict(),
#             'adapters': sam_model.image_encoder.adapters.state_dict()
#         }
#         torch.save(save_dict, checkpoint_path)

#         save_message = (
#             f"Best model saved at epoch {best_epoch} with Composite Score {best_composite_score:.4f} using {strategy} strategy"
#         )
#         logging.info(save_message)
#         print(save_message)

#     # Step the learning-rate scheduler
#     scheduler.step()

# logging.info("训练完成")
# print("训练完成")


In [None]:
# import os
# import torch
# import torch.nn as nn
# import torch.nn.functional as F
# import numpy as np
# import cv2
# import logging
# from tqdm import tqdm
# from types import MethodType
# from torch.utils.data import Dataset, DataLoader
# from segment_anything import sam_model_registry
# import json
# from datetime import datetime

# # Configuration
# CONFIG = {
#     'dataset_name': 'CTS-Pore',
#     'data_base_dir': 'datasets',
#     'model_type': 'vit_l',
#     'checkpoint': 'weights/sam_vit_l_0b3195.pth',
#     'trained_model_path': 'logs/CTS-Pore/best_model_MLA.pth',
#     'device': torch.device("cuda" if torch.cuda.is_available() else "cpu"),
#     'batch_size': 2,
#     'save_predictions': True,  # Whether to save prediction masks
#     'predictions_dir': 'predictions'  # Directory to save predictions
# }

# # Setup paths
# image_dir = os.path.join(CONFIG['data_base_dir'], CONFIG['dataset_name'], 'images')
# mask_dir = os.path.join(CONFIG['data_base_dir'], CONFIG['dataset_name'], 'masks')
# test_txt = os.path.join(CONFIG['data_base_dir'], CONFIG['dataset_name'], 'test.txt')

# # Create predictions directory if needed
# if CONFIG['save_predictions']:
#     predictions_dir = os.path.join(CONFIG['predictions_dir'], CONFIG['dataset_name'])
#     os.makedirs(predictions_dir, exist_ok=True)

# # Setup logging
# log_dir = os.path.join('evaluation_logs', CONFIG['dataset_name'])
# os.makedirs(log_dir, exist_ok=True)
# log_file = f'evaluation_results_{datetime.now().strftime("%Y%m%d_%H%M%S")}.log'
# logging.basicConfig(
#     filename=os.path.join(log_dir, log_file),
#     level=logging.INFO,
#     format='%(asctime)s - %(levelname)s - %(message)s'
# )

# # Model components (same as training)
# class Adapter(nn.Module):
#     def __init__(self, input_dim, reduction_factor=256):
#         super(Adapter, self).__init__()
#         self.down_project = nn.Linear(input_dim, input_dim // reduction_factor)
#         self.activation = nn.GELU()
#         self.up_project = nn.Linear(input_dim // reduction_factor, input_dim)
#         self.layer_norm = nn.LayerNorm(input_dim)
        
#     def forward(self, x):
#         residual = x
#         x = self.layer_norm(x)
#         x = self.down_project(x)
#         x = self.activation(x)
#         x = self.up_project(x)
#         return x + residual

# class SegmentationHead(nn.Module):
#     def __init__(self, in_channels, intermediate_channels, out_channels=1, align_corners=False):
#         super(SegmentationHead, self).__init__()
#         self.align_corners = align_corners

#         self.mla_branches = nn.ModuleList([
#             nn.Sequential(
#                 nn.Conv2d(1024, 512, kernel_size=3, padding=1, stride=1),
#                 nn.BatchNorm2d(512),
#                 nn.ReLU(inplace=True),
#                 nn.Conv2d(512, 256, kernel_size=3, padding=1, stride=1),
#                 nn.BatchNorm2d(256),
#                 nn.ReLU(inplace=True)
#             ) for _ in range(4)
#         ])

#         self.mla_image_branch = nn.Sequential(
#             nn.Conv2d(in_channels, 256, kernel_size=1, stride=1, bias=False),
#             nn.BatchNorm2d(256),
#             nn.ReLU(inplace=True)
#         )
        
#         self.mla_classifier_branch = nn.Sequential(
#             nn.Conv2d(256 * 5, intermediate_channels, kernel_size=3, stride=1, padding=1),
#             nn.BatchNorm2d(intermediate_channels),
#             nn.ReLU(inplace=True),
#             nn.Conv2d(intermediate_channels, out_channels, kernel_size=1, stride=1)
#         )

#     def forward(self, image_embedding, inter_features):
#         selected_features = [inter_features[i] for i in [5, 11, 17, 23]]
#         selected_features = [feat.permute(0, 3, 1, 2) for feat in selected_features]

#         processed_features = []
#         for i, feat in enumerate(selected_features):
#             branch = self.mla_branches[i]
#             x_feat = branch(feat)
#             x_feat = F.interpolate(x_feat, scale_factor=4, mode='bilinear', align_corners=self.align_corners)
#             processed_features.append(x_feat)

#         img_feat = self.mla_image_branch(image_embedding)
#         img_feat = F.interpolate(img_feat, scale_factor=4, mode='bilinear', align_corners=self.align_corners)
#         processed_features.append(img_feat)

#         aggregated = torch.cat(processed_features, dim=1)
#         x = self.mla_classifier_branch(aggregated)
#         x = F.interpolate(x, size=(1024, 1024), mode='bilinear', align_corners=self.align_corners)

#         return x

# # Modified forward_inter for SAM
# def forward_inter(self, x: torch.Tensor) -> torch.Tensor:
#     x = self.patch_embed(x)
#     if self.pos_embed is not None:
#         x = x + self.pos_embed

#     inter_features = []
#     for i, blk in enumerate(self.blocks):
#         x = blk(x)
#         x = self.adapters[i](x)
#         inter_features.append(x)

#     x = self.neck(x.permute(0, 3, 1, 2))
#     return x, inter_features

# # Data loading helpers
# def read_split_files(file_path):
#     with open(file_path, 'r') as f:
#         return f.read().strip().split('\n')

# class SegmentationDataset(Dataset):
#     def __init__(self, image_dir, mask_dir, sam_model, file_list, mask_size=(1024, 1024), device='cpu'):
#         self.image_dir = image_dir
#         self.mask_dir = mask_dir
#         self.sam_model = sam_model
#         self.mask_size = mask_size
#         self.device = device
#         self.image_files = [f for f in os.listdir(image_dir) if f.endswith('.png') and f.replace('.png', '') in file_list]

#     def __len__(self):
#         return len(self.image_files)

#     def __getitem__(self, idx):
#         image_file = self.image_files[idx]
#         image_path = os.path.join(self.image_dir, image_file)
#         image = cv2.imread(image_path)
#         image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
#         image = cv2.resize(image, (1024, 1024), interpolation=cv2.INTER_NEAREST)

#         mask_path = os.path.join(self.mask_dir, image_file)
#         mask = cv2.imread(mask_path, cv2.IMREAD_UNCHANGED)
#         mask = cv2.resize(mask, self.mask_size, interpolation=cv2.INTER_NEAREST)

#         input_image_torch = torch.as_tensor(image, dtype=torch.float32).to(self.device)
#         input_image_torch = input_image_torch.permute(2, 0, 1).contiguous()

#         input_image = self.sam_model.preprocess(input_image_torch.to(self.device))
#         mask = torch.as_tensor(mask, dtype=torch.float32).to(self.device)

#         return input_image, mask, image_file

# # Helper functions for processing predictions
# def process_class_logits(class_logits):
#     probs = torch.sigmoid(class_logits)
#     binary_masks = (probs > 0.5).cpu().numpy().astype(np.uint8)
#     batch_results = []
    
#     for batch_idx in range(binary_masks.shape[0]):
#         current_mask = binary_masks[batch_idx, 0, :, :]
#         num_labels, labels = cv2.connectedComponents(current_mask)
#         sample_results = []
        
#         for label in range(1, num_labels):
#             current_component = (labels == label).astype(np.uint8)
#             y_coords, x_coords = np.nonzero(current_component)
            
#             if len(y_coords) == 0:
#                 continue
                
#             min_x, max_x = np.min(x_coords), np.max(x_coords)
#             min_y, max_y = np.min(y_coords), np.max(y_coords)
            
#             mask = np.zeros_like(current_component, dtype=np.uint8)
#             mask[min_y:max_y+1, min_x:max_x+1] = current_component[min_y:max_y+1, min_x:max_x+1]
            
#             mask_resized = cv2.resize(mask, (256, 256), interpolation=cv2.INTER_NEAREST)
            
#             sample_results.append({
#                 'bbox': [min_x, min_y, max_x, max_y],
#                 'mask': mask_resized
#             })
            
#         batch_results.append(sample_results)
#     return batch_results

# def predict_masks_batch(sam_model, image_embeddings, batch_results, device='cuda'):
#     sam_model.eval()
#     final_predictions = []
    
#     for idx, sample_results in enumerate(batch_results):
#         if not sample_results:
#             final_predictions.append(torch.zeros((1, 1, 1024, 1024), device=device))
#             continue

#         current_image_embedding = image_embeddings[idx:idx+1]
#         sparse_embeddings_list = []
#         dense_embeddings_list = []
        
#         for mask_info in sample_results:
#             box = torch.tensor(mask_info['bbox'], dtype=torch.float, device=device).unsqueeze(0)
#             mask = torch.from_numpy(mask_info['mask']).float().to(device).unsqueeze(0).unsqueeze(0)
            
#             sparse_embeddings, dense_embeddings = sam_model.prompt_encoder(
#                 points=None,
#                 boxes=box,
#                 masks=mask
#             )
            
#             sparse_embeddings_list.append(sparse_embeddings)
#             dense_embeddings_list.append(dense_embeddings)
        
#         if not sparse_embeddings_list:
#             final_predictions.append(torch.zeros((1, 1, 1024, 1024), device=device))
#             continue
            
#         sparse_embeddings_all = torch.cat(sparse_embeddings_list, dim=0)
#         dense_embeddings_all = torch.cat(dense_embeddings_list, dim=0)
        
#         low_res_masks, _ = sam_model.mask_decoder(
#             image_embeddings=current_image_embedding,
#             image_pe=sam_model.prompt_encoder.get_dense_pe(),
#             sparse_prompt_embeddings=sparse_embeddings_all,
#             dense_prompt_embeddings=dense_embeddings_all,
#             multimask_output=False,
#         )
        
#         resized_masks = F.interpolate(
#             low_res_masks,
#             size=(1024, 1024),
#             mode='bilinear',
#             align_corners=False
#         )
        
#         merged_mask = torch.max(resized_masks, dim=0)[0]
#         final_predictions.append(merged_mask.unsqueeze(0))
        
#     return torch.cat(final_predictions, dim=0)

# def calculate_metrics(preds, targets, threshold=0.5):
#     preds_binary = (preds > threshold).astype(np.uint8)
#     targets_binary = (targets > threshold).astype(np.uint8)
    
#     tp = np.logical_and(preds_binary == 1, targets_binary == 1).sum()
#     fp = np.logical_and(preds_binary == 1, targets_binary == 0).sum()
#     fn = np.logical_and(preds_binary == 0, targets_binary == 1).sum()
    
#     intersection = tp
#     union = np.logical_or(preds_binary == 1, targets_binary == 1).sum()
    
#     iou = intersection / (union + 1e-6)
#     precision = tp / (tp + fp + 1e-6)
#     recall = tp / (tp + fn + 1e-6)
#     f1 = 2 * precision * recall / (precision + recall + 1e-6)
    
#     return {
#         'iou': iou,
#         'precision': precision,
#         'recall': recall,
#         'f1': f1,
#         'tp': tp,
#         'fp': fp,
#         'fn': fn
#     }

# def main():
#     # Initialize models
#     sam_model = sam_model_registry[CONFIG['model_type']](checkpoint=CONFIG['checkpoint'])
#     sam_model.to(CONFIG['device'])
    
#     # Initialize segmentation head
#     segmentation_head = SegmentationHead(
#         in_channels=256,
#         intermediate_channels=256,
#         out_channels=1,
#         align_corners=False
#     ).to(CONFIG['device'])
    
#     # Set up forward_inter method
#     sam_model.image_encoder.forward_inter = MethodType(forward_inter, sam_model.image_encoder)
    
#     # Load trained weights
#     checkpoint = torch.load(CONFIG['trained_model_path'], map_location=CONFIG['device'])
#     segmentation_head.load_state_dict(checkpoint['segmentation_head'])
    
#     # Initialize and load adapters
#     dummy_input = torch.zeros(1, 3, 1024, 1024).to(CONFIG['device'])
#     with torch.no_grad():
#         x = sam_model.image_encoder.patch_embed(dummy_input)
#         input_dim = x.shape[-1]

#     # Initialize adapters
#     sam_model.image_encoder.adapters = nn.ModuleList([
#         Adapter(input_dim=input_dim, reduction_factor=256)
#         for _ in sam_model.image_encoder.blocks
#     ]).to(CONFIG['device'])

#     # Load adapter weights
#     sam_model.image_encoder.adapters.load_state_dict(checkpoint['adapters'])

#     # Load test data
#     test_files = read_split_files(test_txt)
#     test_dataset = SegmentationDataset(
#         image_dir=image_dir,
#         mask_dir=mask_dir,
#         sam_model=sam_model,
#         file_list=test_files,
#         device=CONFIG['device']
#     )

#     test_loader = DataLoader(
#         test_dataset,
#         batch_size=CONFIG['batch_size'],
#         shuffle=False
#     )

#     # Set models to evaluation mode
#     segmentation_head.eval()
#     sam_model.eval()

#     # Initialize metric accumulators
#     metrics_seghead = {
#         'iou': [], 'precision': [], 'recall': [], 'f1': [],
#         'tp': 0, 'fp': 0, 'fn': 0
#     }
#     metrics_final = {
#         'iou': [], 'precision': [], 'recall': [], 'f1': [],
#         'tp': 0, 'fp': 0, 'fn': 0
#     }

#     # Evaluation loop
#     with torch.no_grad():
#         for images, masks, filenames in tqdm(test_loader, desc="Evaluating"):
#             masks = masks.unsqueeze(1)  # Add channel dimension

#             # Forward pass through SAM encoder
#             image_embedding, inter_features = sam_model.image_encoder.forward_inter(images)
            
#             # Get segmentation head predictions
#             seg_head_logits = segmentation_head(image_embedding, inter_features)
            
#             # Process segmentation head outputs for SAM
#             prompts = process_class_logits(seg_head_logits)
            
#             # Get final predictions from SAM
#             final_predictions = predict_masks_batch(sam_model, image_embedding, prompts, CONFIG['device'])

#             # Apply sigmoid to the logits, then convert the predictions to numpy
#             seg_head_preds = torch.sigmoid(seg_head_logits).cpu().numpy()
#             final_preds = torch.sigmoid(final_predictions).cpu().numpy()
#             masks_np = masks.cpu().numpy()

#             # Calculate metrics for each image in batch
#             for idx in range(len(filenames)):
#                 # Get metrics for segmentation head output
#                 metrics_seg = calculate_metrics(seg_head_preds[idx, 0], masks_np[idx, 0])
#                 metrics_seghead['iou'].append(metrics_seg['iou'])
#                 metrics_seghead['precision'].append(metrics_seg['precision'])
#                 metrics_seghead['recall'].append(metrics_seg['recall'])
#                 metrics_seghead['f1'].append(metrics_seg['f1'])
#                 metrics_seghead['tp'] += metrics_seg['tp']
#                 metrics_seghead['fp'] += metrics_seg['fp']
#                 metrics_seghead['fn'] += metrics_seg['fn']

#                 # Get metrics for final predictions
#                 metrics_fin = calculate_metrics(final_preds[idx, 0], masks_np[idx, 0])
#                 metrics_final['iou'].append(metrics_fin['iou'])
#                 metrics_final['precision'].append(metrics_fin['precision'])
#                 metrics_final['recall'].append(metrics_fin['recall'])
#                 metrics_final['f1'].append(metrics_fin['f1'])
#                 metrics_final['tp'] += metrics_fin['tp']
#                 metrics_final['fp'] += metrics_fin['fp']
#                 metrics_final['fn'] += metrics_fin['fn']

#                 # Save predictions if enabled
#                 if CONFIG['save_predictions']:
#                     # Save segmentation head predictions
#                     seg_head_pred = (seg_head_preds[idx, 0] > 0.5).astype(np.uint8) * 255
#                     cv2.imwrite(
#                         os.path.join(predictions_dir, f"{filenames[idx]}_seghead.png"),
#                         seg_head_pred
#                     )

#                     # Save final predictions
#                     final_pred = (final_preds[idx, 0] > 0.5).astype(np.uint8) * 255
#                     cv2.imwrite(
#                         os.path.join(predictions_dir, f"{filenames[idx]}_final.png"),
#                         final_pred
#                     )

#     # Calculate final metrics
#     results = {
#         'segmentation_head': {
#             'mean_iou': np.mean(metrics_seghead['iou']),
#             'mean_precision': np.mean(metrics_seghead['precision']),
#             'mean_recall': np.mean(metrics_seghead['recall']),
#             'mean_f1': np.mean(metrics_seghead['f1']),
#             'global_precision': metrics_seghead['tp'] / (metrics_seghead['tp'] + metrics_seghead['fp'] + 1e-6),
#             'global_recall': metrics_seghead['tp'] / (metrics_seghead['tp'] + metrics_seghead['fn'] + 1e-6),
#             'global_f1': 2 * metrics_seghead['tp'] / (2 * metrics_seghead['tp'] + metrics_seghead['fp'] + metrics_seghead['fn'] + 1e-6)
#         },
#         'final_prediction': {
#             'mean_iou': np.mean(metrics_final['iou']),
#             'mean_precision': np.mean(metrics_final['precision']),
#             'mean_recall': np.mean(metrics_final['recall']),
#             'mean_f1': np.mean(metrics_final['f1']),
#             'global_precision': metrics_final['tp'] / (metrics_final['tp'] + metrics_final['fp'] + 1e-6),
#             'global_recall': metrics_final['tp'] / (metrics_final['tp'] + metrics_final['fn'] + 1e-6),
#             'global_f1': 2 * metrics_final['tp'] / (2 * metrics_final['tp'] + metrics_final['fp'] + metrics_final['fn'] + 1e-6)
#         }
#     }

#     # Log results
#     log_message = "\nEvaluation Results:\n" + "-" * 50 + "\n"
    
#     log_message += "\nSegmentation Head Results:\n"
#     for metric, value in results['segmentation_head'].items():
#         log_message += f"{metric}: {value:.4f}\n"
    
#     log_message += "\nFinal Prediction Results:\n"
#     for metric, value in results['final_prediction'].items():
#         log_message += f"{metric}: {value:.4f}\n"

#     logging.info(log_message)
#     print(log_message)

#     # Save results to JSON
#     results_file = os.path.join(log_dir, 'evaluation_results.json')
#     with open(results_file, 'w') as f:
#         json.dump(results, f, indent=4)
    
#     print(f"\nResults saved to {results_file}")
#     if CONFIG['save_predictions']:
#         print(f"Predictions saved to {predictions_dir}")

# if __name__ == "__main__":
#     main()