## 4.2最新版训练Unet全流程
    包含 random随机种子存储，大图预处理(包含旋转，缩放，色彩对比度范围随机变化，图像填充裁切)，patch分割，训练，存储model，测试model。

### 加载基本库

In [1]:
import torch
import cv2
import numpy as np
import matplotlib.pyplot as plt
import random
import os
import math
import torch.nn as nn
from torch.optim.lr_scheduler import ReduceLROnPlateau
from torch.utils.data import Dataset, DataLoader
import torch.nn.functional as F
from torch.optim import Adam
from tqdm import tqdm
import gc

# 封装的函数模块，model部分全部引用
from functions.data import prepare_dataset,get_optimized_loaders
from functions.model import *
from functions.data import SegmentationDataset
from functions.imagePreprocessing import ImagePreprocessor


### 数据集准备

In [2]:
# Build the train/test patch datasets and their loaders for this notebook run.
# Alternative dataset (disabled):
# train_data, test_data = prepare_dataset("Kasthuri++")
train_data, test_data = prepare_dataset("Lucchi++")

# First run: preprocess and save.
# NOTE(review): SegmentationDataset is a project class; what "save" covers is
# not visible in this file — confirm against functions.data.
train_dataset = SegmentationDataset(
    data_list=train_data,
    patch_size=256,
    stride=128,
    preProcess=True,  # enable preprocessing

)

# NOTE(review): the test split is built with preProcess=True as well; if the
# preprocessing is randomised, validation numbers will vary run to run — confirm.
test_dataset = SegmentationDataset(
    data_list=test_data,
    patch_size=256,
    stride=128,
    preProcess=True,  # enable preprocessing

)

train_loader, test_loader = get_optimized_loaders(
    train_dataset, 
    test_dataset, 
    batch_size=32,  # larger batch size
    num_workers=4
)



Train Data (Lucchi++): [{'image': 'dataset\\Lucchi++\\Train_In\\mask0000.png', 'annotation': 'dataset\\Lucchi++\\Train_Out\\0.png', 'index': 0}, {'image': 'dataset\\Lucchi++\\Train_In\\mask0001.png', 'annotation': 'dataset\\Lucchi++\\Train_Out\\1.png', 'index': 1}, {'image': 'dataset\\Lucchi++\\Train_In\\mask0002.png', 'annotation': 'dataset\\Lucchi++\\Train_Out\\2.png', 'index': 2}, {'image': 'dataset\\Lucchi++\\Train_In\\mask0003.png', 'annotation': 'dataset\\Lucchi++\\Train_Out\\3.png', 'index': 3}, {'image': 'dataset\\Lucchi++\\Train_In\\mask0004.png', 'annotation': 'dataset\\Lucchi++\\Train_Out\\4.png', 'index': 4}, {'image': 'dataset\\Lucchi++\\Train_In\\mask0005.png', 'annotation': 'dataset\\Lucchi++\\Train_Out\\5.png', 'index': 5}, {'image': 'dataset\\Lucchi++\\Train_In\\mask0006.png', 'annotation': 'dataset\\Lucchi++\\Train_Out\\6.png', 'index': 6}, {'image': 'dataset\\Lucchi++\\Train_In\\mask0007.png', 'annotation': 'dataset\\Lucchi++\\Train_Out\\7.png', 'index': 7}, {'image'

处理图像: 100%|██████████| 165/165 [00:20<00:00,  8.12it/s]


数据集预处理完成，共生成 7865 个patch
开始数据集预处理...


处理图像: 100%|██████████| 165/165 [00:23<00:00,  6.91it/s]

数据集预处理完成，共生成 7707 个patch





### 训练Unet模型阶段

In [3]:
# Initialize model: pick CUDA when available, otherwise fall back to CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(device)
# UNet is not defined in this file view; presumably it comes from the
# `from functions.model import *` star import — confirm.
model = UNet(num_classes=1).to(device)

cuda


In [6]:
class Att_YNet(nn.Module):
    """Strided encoder / transposed-conv decoder with additive skip connections.

    The network has four encoder stages (each halves the resolution with a
    stride-2 convolution) and four decoder stages (each doubles it).  The
    first three decoder outputs are summed with the matching encoder feature
    map.  ``forward`` returns raw logits with ``num_classes`` channels.

    ``AttentionBlock`` and ``return_reconstructed`` are defined for future use;
    the current ``forward`` neither instantiates the attention gate nor reads
    the flag.

    Note: the additive skips need matching shapes, so input height/width are
    expected to be multiples of 16.
    """

    class AttentionBlock(nn.Module):
        """Additive attention gate: scales ``x`` by a sigmoid mask built from ``g`` and ``x``."""

        def __init__(self, F_g, F_l, F_int, batch_norm=False):
            """
            Args:
                F_g: channel count of the gating signal ``g``.
                F_l: channel count of the skip feature ``x``.
                F_int: channel count of the intermediate projection.
                batch_norm: insert BatchNorm2d after each 1x1 conv when True,
                    otherwise an Identity placeholder keeps the indices stable.
            """
            super().__init__()

            def _norm(channels):
                return nn.BatchNorm2d(channels) if batch_norm else nn.Identity()

            self.W_g = nn.Sequential(
                nn.Conv2d(F_g, F_int, kernel_size=1, stride=1, padding=0, bias=True),
                _norm(F_int),
            )
            self.W_x = nn.Sequential(
                nn.Conv2d(F_l, F_int, kernel_size=1, stride=1, padding=0, bias=True),
                _norm(F_int),
            )
            self.psi = nn.Sequential(
                nn.Conv2d(F_int, 1, kernel_size=1, stride=1, padding=0, bias=True),
                _norm(1),
                nn.Sigmoid(),
            )
            self.relu = nn.ReLU(inplace=True)

        def forward(self, g, x):
            """Return ``x`` multiplied by the single-channel attention map."""
            # Bring the gating signal to the spatial size of x when they differ.
            if g.size(2) != x.size(2) or g.size(3) != x.size(3):
                g = F.interpolate(
                    g,
                    size=(x.size(2), x.size(3)),
                    mode='bilinear',
                    align_corners=True,
                )

            attention = self.psi(self.relu(self.W_g(g) + self.W_x(x)))
            return x * attention

    def __init__(self, num_classes):
        """
        Args:
            num_classes: number of output channels of the final 3x3 convolution.
        """
        super().__init__()

        # (in_channels, out_channels) per stage; every encoder stage uses stride 2.
        encoder_plan = ((3, 32), (32, 64), (64, 128), (128, 256))
        decoder_plan = ((256, 128), (128, 64), (64, 32), (32, 32))

        self.encoder = nn.ModuleList(
            [self.conv_block(c_in, c_out, stride=2) for c_in, c_out in encoder_plan]
        )
        self.decoder = nn.ModuleList(
            [self.upconv_block(c_in, c_out) for c_in, c_out in decoder_plan]
        )

        # Final classification layer
        self.final = nn.Conv2d(32, num_classes, kernel_size=3, padding=1)

        # Flag reserved for a second (reconstruction) output; not read by forward yet.
        self.return_reconstructed = False

    def conv_block(self, in_channels, out_channels, stride=1):
        """Two Conv3x3 -> BatchNorm -> ReLU units; only the first conv applies ``stride``."""
        layers = [
            nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=stride, padding=1),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(inplace=True),
            nn.Conv2d(out_channels, out_channels, kernel_size=3, padding=1),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(inplace=True),
        ]
        return nn.Sequential(*layers)

    def upconv_block(self, in_channels, out_channels):
        """2x transposed-conv upsampling followed by Conv3x3 -> BatchNorm -> ReLU."""
        layers = [
            nn.ConvTranspose2d(in_channels, out_channels, kernel_size=2, stride=2),
            nn.Conv2d(out_channels, out_channels, kernel_size=3, padding=1),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(inplace=True),
        ]
        return nn.Sequential(*layers)

    def set_return_reconstructed(self, value):
        """Set whether to return the reconstructed image along with the mask"""
        self.return_reconstructed = value

    def forward(self, x):
        """Run the encoder/decoder and return the mask logits."""
        # Encoder: remember every stage output for the skip connections.
        skips = []
        for stage in self.encoder:
            x = stage(x)
            skips.append(x)

        # Decoder: every stage except the last adds the encoder feature map of
        # the same resolution (stage 0 pairs with skips[-2], stage 1 with
        # skips[-3], ...).
        last_stage = len(self.decoder) - 1
        for depth, stage in enumerate(self.decoder):
            x = stage(x)
            if depth < last_stage:
                x = x + skips[-(depth + 2)]

        # Only the segmentation logits are returned for now; the autoencoder
        # branch is not implemented.
        return self.final(x)

In [7]:
# Re-select the device and rebind `model` to a fresh Att_YNet; this replaces
# the UNet instance created in the earlier cell.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(device)
model = Att_YNet(num_classes=1).to(device)

cuda


#### 开始训练

In [None]:
# Train for 50 epochs, using test_loader as the validation set.
# optimized_train_model is not defined in this file view; presumably it comes
# from the `from functions.model import *` star import — confirm.
optimized_train_model(model, train_loader, test_loader, num_epochs=50, device=device)


Epoch 1/50:
Training - Loss: 0.2722, IoU: 0.4515, Dice: 0.5881
Validation - Loss: 0.1169, IoU: 0.6526, Dice: 0.7841
------------------------------------------------------------
Epoch 2/50:
Training - Loss: 0.0904, IoU: 0.7153, Dice: 0.8329
Validation - Loss: 0.3013, IoU: 0.2829, Dice: 0.4229
------------------------------------------------------------
Epoch 3/50:
Training - Loss: 0.0559, IoU: 0.7905, Dice: 0.8824
Validation - Loss: 0.0556, IoU: 0.7298, Dice: 0.8403
------------------------------------------------------------
Epoch 4/50:
Training - Loss: 0.0426, IoU: 0.8276, Dice: 0.9053
Validation - Loss: 0.0547, IoU: 0.7327, Dice: 0.8422
------------------------------------------------------------
Epoch 5/50:
Training - Loss: 0.0345, IoU: 0.8535, Dice: 0.9205
Validation - Loss: 0.6809, IoU: 0.3060, Dice: 0.4508
------------------------------------------------------------
Epoch 6/50:
Training - Loss: 0.0280, IoU: 0.8788, Dice: 0.9353
Validation - Loss: 0.0449, IoU: 0.7641, Dice: 0.8634

#### 可选保存模型

In [None]:
# Save model: persist only the weights (state_dict) of the current `model`.
torch.save(model.state_dict(), "models/UnetTrain/overlaping_Ynet_segmentation_try4.27V0A1K.pth")

### 测试model性能

In [None]:
## 4.2: settings for the latest end-to-end predict flow
# Parameters
# NOTE(review): this is a different checkpoint file from the Ynet weights
# saved in the previous cell — confirm it is the intended model.
model_path = "models/UnetTrain/overlaping_unet_segmentation_try4.13k_random2_iou.pth"  # checkpoint path

image_path = "dataset/Lucchi++/Test_In/mask0022.png"  # test image path
mask_path = "dataset/Lucchi++/Test_Out/22.png"  # ground-truth mask path (if available)

# Alternative inputs (disabled):
# image_path = "dataset/Kasthuri++/Test_In/mask1049.png"  
# mask_path = "dataset/Kasthuri++/Test_Out/mask1049.png"  


# image_path = "dataset/VNC/Test_In/16.tif"  
# mask_path = "dataset/VNC/Test_Out/16.png"  

save_dir = "test/predictData"  # output directory for results
patch_size = 256  # patch size
stride = 128  # sliding-window step

# Run the segmentation pipeline.
# `value` and `alpha` are forwarded as-is; their meaning is defined by
# segmentation_pipeline, which is not visible in this file — TODO confirm.
pred_mask, metrics = segmentation_pipeline(
    model_path=model_path,
    image_path=image_path,
    mask_path=mask_path,
    save_dir=save_dir,
    patch_size=patch_size,
    stride=stride,
    value=-30,
    alpha=1.0
)

# Print the metrics only when the pipeline returned a non-empty value.
if metrics:
    print(f"最终评估指标: IoU={metrics['IoU']:.4f}, Dice={metrics['Dice']:.4f}")
    
plt.imshow(pred_mask)
plt.show()

In [None]:
# Process an explicit list of images.
# The file names below are placeholders; the output recorded for this cell is
# "Error: 无有效图像路径" (no valid image path), so replace them with real paths.
image_list = ["img1.png", "img2.png", "img3.png"]
mask_list = ["mask1.png", "mask2.png", "mask3.png"]

results = batch_segmentation_pipeline(
    model_path="models/UnetTrain/overlaping_unet_segmentation_try4.13k_random2_loss.pth",
    image_paths=image_list,
    mask_paths=mask_list,
    save_dir="test/batch_predictData",
    patch_size=256,
    stride=128,
    value=-30,
    alpha=1.0
)

Error: 无有效图像路径


### 批处理测试model效果，并保存结果同时生成对应的文件(可用，速度一般,1张图大概6s左右)

In [10]:
# Process a whole directory: image_paths / mask_paths are given as directory
# paths here instead of lists.
# NOTE(review): how images are paired with masks is decided inside
# batch_segmentation_pipeline (not visible here); the notes further down this
# notebook report a past image/mask file-name mismatch, so verify the pairing.
results = batch_segmentation_pipeline(
    model_path="models/UnetTrain/overlaping_unet_segmentation_try4.13k_random2_loss.pth",
    image_paths="dataset/VNC/Test_In",  # image directory
    mask_paths="dataset/VNC/Test_Out",  # mask directory
    save_dir="test/batch_predictData",
    patch_size=256,
    stride=128,
    value=-30,
    alpha=1.0
)


开始处理 20 个样本...


 处理进度 1/20: 00

=== 处理图像: 00 ===
设备: cuda, 模型: overlaping_unet_segmentation_try4.13k_random2_loss.pth
Patch尺寸: 256, 步长: 128
评估指标: IoU=0.6089, Dice=0.7569
二值化前的图像范围: 0-255
二值化阈值: 127
二值化后白色像素数: 48985
denoise评估结果: IoU=0.5904, Dice=0.7425


 处理进度 2/20: 01

=== 处理图像: 01 ===
设备: cuda, 模型: overlaping_unet_segmentation_try4.13k_random2_loss.pth
Patch尺寸: 256, 步长: 128
评估指标: IoU=0.5807, Dice=0.7347
二值化前的图像范围: 0-255
二值化阈值: 127
二值化后白色像素数: 44278
denoise评估结果: IoU=0.5719, Dice=0.7277


 处理进度 3/20: 02

=== 处理图像: 02 ===
设备: cuda, 模型: overlaping_unet_segmentation_try4.13k_random2_loss.pth
Patch尺寸: 256, 步长: 128
评估指标: IoU=0.5788, Dice=0.7332
二值化前的图像范围: 0-255
二值化阈值: 127
二值化后白色像素数: 37270
denoise评估结果: IoU=0.5677, Dice=0.7242


 处理进度 4/20: 03

=== 处理图像: 03 ===
设备: cuda, 模型: overlaping_unet_segmentation_try4.13k_random2_loss.pth
Patch尺寸: 256, 步长: 128
评估指标: IoU=0.5248, Dice=0.6883
二值化前的图像范围: 0-255
二值化阈值: 127
二值化后白色像素数: 33233
denoise评估结果: IoU=0.5160, Dice=0.6808


 处理进度 5/20: 04

=== 处理图像: 04 =

4.13遇到的问题：
 首先是K测L的泛化能力很低：之前只针对单张图片评估，导致我误判，以为IoU很高，实际上只有0.08，
 太低了。随后我在分析时发现一个问题：equalization（直方图均衡化）带来的影响很大，甚至可以说均衡化之后的直方图分布直接决定了IoU。也就是说，我的equalization大概率写得不好，导致它对不同图片的均衡化效果不同。简单来说，那几张IoU特别高的特殊图片，均衡化后的直方图都是很平均的一条线，而其它图片均衡化后的直方图基本跟原本的分布差不多。
 解决：实际上是因为mask文件的名称和image没有对上导致的意外。这个问题很致命，原因是我pipeline的batch处理没做好，后面修改之后就解决了。

批量重命名文件

In [None]:
import os
import re

def rename_files_with_padded_numbers(folder_path, prefix="mask"):
    """
    Rename purely numeric file names in a folder to ``<prefix><4-digit number>.<ext>``.

    Only regular files whose name is ``<digits>.<anything>`` are touched, e.g.
    ``7.png`` becomes ``mask0007.png``.  Sub-directories and other files are
    left alone.

    Two source names can map to the same target (``1.png`` and ``01.png``).
    ``os.rename`` would silently overwrite the first result on POSIX, so a
    rename whose target already exists is skipped and reported instead.
    Entries are processed in sorted order so the outcome is deterministic.

    Args:
        folder_path (str): directory whose files are renamed in place.
        prefix (str): text put in front of the zero-padded number
            (default ``"mask"``).
    """
    numeric_name = re.compile(r'^(\d+)\.(.*?)$')

    for filename in sorted(os.listdir(folder_path)):
        old_path = os.path.join(folder_path, filename)
        if not os.path.isfile(old_path):
            continue

        match = numeric_name.match(filename)
        if not match:
            continue

        number, ext = match.groups()
        # Zero-pad to at least four digits; longer numbers are kept in full.
        new_filename = f"{prefix}{int(number):04d}.{ext}"
        new_path = os.path.join(folder_path, new_filename)

        # Never clobber a different, already existing file.
        if new_path != old_path and os.path.exists(new_path):
            print(f"跳过: {filename} -> {new_filename} (目标文件已存在)")
            continue

        os.rename(old_path, new_path)
        print(f"已重命名: {filename} -> {new_filename}")

# Script entry: rename the files of one hard-coded folder.
if __name__ == "__main__":
    folder = "dataset/Lucchi++/Test_Out"
    # Ask for the prefix interactively; an empty answer falls back to "mask".
    prefix = input("请输入要添加的前缀(默认为'mask'，直接回车使用默认值): ") or "mask"
    
    rename_files_with_padded_numbers(folder, prefix)
    print("所有文件重命名完成！")

# 尝试设定完善模型和训练过程

## 定义新的版本Unet框架

In [2]:
import time


class ImprovedUNet(nn.Module):
    """U-Net with max-pool downsampling, concatenating skips and 2D dropout.

    Four encoder levels (32, 64, 128, 256 channels) feed a 512-channel
    bottleneck; each decoder level upsamples with a transposed convolution,
    concatenates the encoder feature map of the same level on the channel
    axis and refines it with a conv block.  ``forward`` returns raw logits
    with ``num_classes`` channels.

    Note: there are four 2x poolings, so input height/width are expected to
    be multiples of 16; otherwise the concatenated shapes will not match.
    """

    def __init__(self, num_classes, dropout_rate=0.2):
        """
        Args:
            num_classes: number of output channels of the final 1x1 convolution.
            dropout_rate: probability passed to the Dropout2d in every conv block.
        """
        super().__init__()

        level_widths = (32, 64, 128, 256)

        # Encoder (downsampling): enc1/pool1 ... enc4/pool4
        in_channels = 3
        for level, width in enumerate(level_widths, start=1):
            setattr(self, f"enc{level}", self.conv_block(in_channels, width, dropout_rate))
            setattr(self, f"pool{level}", nn.MaxPool2d(kernel_size=2, stride=2))
            in_channels = width

        # Bottleneck
        self.bottleneck = self.conv_block(256, 512, dropout_rate)

        # Decoder (upsampling): upconv4/dec4 ... upconv1/dec1.
        # Each dec block receives 2 * width channels: the upsampled map plus
        # the concatenated skip connection.
        for level in range(len(level_widths), 0, -1):
            width = level_widths[level - 1]
            setattr(
                self,
                f"upconv{level}",
                nn.ConvTranspose2d(2 * width, width, kernel_size=2, stride=2),
            )
            setattr(self, f"dec{level}", self.conv_block(2 * width, width, dropout_rate))

        # Final classification layer
        self.final = nn.Conv2d(32, num_classes, kernel_size=1)

    def conv_block(self, in_channels, out_channels, dropout_rate):
        """Conv3x3 -> BN -> ReLU -> Dropout2d -> Conv3x3 -> BN -> ReLU."""
        layers = [
            nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(inplace=True),
            nn.Dropout2d(dropout_rate),
            nn.Conv2d(out_channels, out_channels, kernel_size=3, padding=1),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(inplace=True),
        ]
        return nn.Sequential(*layers)

    def forward(self, x):
        """Return segmentation logits for the input batch ``x``."""
        down_path = (
            (self.enc1, self.pool1),
            (self.enc2, self.pool2),
            (self.enc3, self.pool3),
            (self.enc4, self.pool4),
        )
        up_path = (
            (self.upconv4, self.dec4),
            (self.upconv3, self.dec3),
            (self.upconv2, self.dec2),
            (self.upconv1, self.dec1),
        )

        # Encoder: keep the pre-pooling feature map of every level.
        skips = []
        for encode, pool in down_path:
            x = encode(x)
            skips.append(x)
            x = pool(x)

        # Bottleneck
        x = self.bottleneck(x)

        # Decoder: deepest skip first.
        for (upsample, decode), skip in zip(up_path, reversed(skips)):
            x = upsample(x)
            x = torch.cat((x, skip), dim=1)
            x = decode(x)

        # Final classification
        return self.final(x)

def train_and_validate_model(model, train_loader, val_loader=None, num_epochs=50, 
                           learning_rate=1e-4, device='cuda', save_path='best_model.pth'):
    """
    Train a segmentation model with verbose per-batch diagnostics and optional validation.

    Uses AdamW, BCEWithLogitsLoss and a ReduceLROnPlateau scheduler (factor
    0.5, patience 5) stepped on the validation loss.  Every batch stage
    (loading, device transfer, forward, loss, backward) is timed and wrapped
    in its own ``try``/``except`` so a failing batch is reported and skipped
    instead of aborting the run.

    With a validation loader the weights with the lowest validation loss are
    written to ``save_path``, training stops after 10 epochs without
    improvement, and those best weights are reloaded before returning.
    Without one, the weights are saved after every epoch.

    Args:
        model: model to train; must map the batch's ``'patches'`` tensor to logits.
        train_loader: iterable with ``__len__`` yielding dicts with the keys
            ``'patches'`` and ``'mask_patches'``.
        val_loader: optional validation loader with the same batch format.
        num_epochs: maximum number of epochs.
        learning_rate: initial AdamW learning rate.
        device: training device (``'cuda'`` or ``'cpu'``).
        save_path: file the model ``state_dict`` is saved to.

    Returns:
        model: the trained model (best validation weights when available).
        history: dict of per-epoch lists: ``train_loss``, ``val_loss``,
            ``train_dice``, ``val_dice``, ``train_iou``, ``val_iou``, ``lr``.
            ``lr`` is sampled before the scheduler step of that epoch.
    """
    # Move the model to the requested device
    model = model.to(device)
    
    # Optimizer and loss
    optimizer = torch.optim.AdamW(model.parameters(), lr=learning_rate)
    criterion = torch.nn.BCEWithLogitsLoss()
    
    # Learning-rate scheduler
    scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
        optimizer, mode='min', factor=0.5, patience=5
    )
    
    # Training history
    history = {
        'train_loss': [],
        'val_loss': [],
        'train_dice': [],
        'val_dice': [],
        'train_iou': [],
        'val_iou': [],
        'lr': []
    }
    
    # Early stopping / best-checkpoint bookkeeping
    best_val_metric = float('inf')
    patience = 10
    counter = 0
    # True once THIS run has written a best checkpoint, so a stale file left
    # at save_path by an earlier run is never reloaded at the end.
    checkpoint_written = False
    
    print(f"开始训练，共 {num_epochs} 个轮次")
    
    # Training loop
    for epoch in range(num_epochs):
        epoch_start_time = time.time()
        print(f"\n开始第 {epoch+1}/{num_epochs} 轮训练")
        
        # Training phase
        model.train()
        train_loss = 0.0
        train_dice_scores = []
        train_iou_scores = []
        # Batches that completed the optimisation step; skipped batches must
        # not dilute the average loss.
        processed_batches = 0
        
        # Free cached GPU memory
        if torch.cuda.is_available():
            torch.cuda.empty_cache()
        
        # Explicit iterator so the loading time of each batch can be measured
        train_iter = iter(train_loader)
        
        # Total number of batches
        total_batches = len(train_loader)
        print(f"此轮次共有 {total_batches} 个批次")
        
        # Pull batches one by one from the iterator
        for batch_idx in range(total_batches):
            batch_start_time = time.time()
            
            # Load the next batch
            print(f"正在加载批次 {batch_idx+1}/{total_batches}...")
            try:
                batch = next(train_iter)
                print(f"批次 {batch_idx+1} 加载完成，耗时 {time.time() - batch_start_time:.2f}秒")
            except Exception as e:
                # Deliberately broad: this function is a diagnostic harness
                # that reports the failure and moves on to the next batch.
                print(f"加载批次 {batch_idx+1} 时出错: {str(e)}")
                continue
                
            # Move data to the device
            data_to_device_time = time.time()
            try:
                # Tensors are moved one at a time so each transfer is reported
                inputs = batch['patches'].to(device)
                print(f"输入数据转移到 {device} 完成，形状: {inputs.shape}")
                
                targets = batch['mask_patches'].to(device)
                print(f"目标数据转移到 {device} 完成，形状: {targets.shape}")
                
                print(f"数据转移到 {device} 完成，耗时 {time.time() - data_to_device_time:.2f}秒")
            except Exception as e:
                print(f"将数据转移到 {device} 时出错: {str(e)}")
                continue
                
            # Zero gradients and run the forward pass
            forward_time = time.time()
            try:
                optimizer.zero_grad()
                outputs = model(inputs)
                print(f"前向传播完成，耗时 {time.time() - forward_time:.2f}秒")
            except Exception as e:
                print(f"前向传播时出错: {str(e)}")
                continue
                
            # Loss
            loss_time = time.time()
            try:
                loss = criterion(outputs, targets)
                print(f"损失计算完成，耗时 {time.time() - loss_time:.2f}秒")
            except Exception as e:
                print(f"计算损失时出错: {str(e)}")
                continue
                
            # Backward pass and optimizer step
            backward_time = time.time()
            try:
                loss.backward()
                optimizer.step()
                print(f"反向传播完成，耗时 {time.time() - backward_time:.2f}秒")
            except Exception as e:
                print(f"反向传播时出错: {str(e)}")
                continue
                
            # Accumulate the loss of successfully processed batches
            train_loss += loss.item()
            processed_batches += 1
            
            # Training metrics on the thresholded predictions
            metrics_time = time.time()
            with torch.no_grad():
                preds = (torch.sigmoid(outputs) > 0.5).float()
                iou, dice = calculate_metrics(preds, targets)
                train_dice_scores.append(dice)
                train_iou_scores.append(iou)
                print(f"指标计算完成，耗时 {time.time() - metrics_time:.2f}秒")
            
            # Drop temporaries explicitly to release memory
            del inputs, targets, outputs, preds, loss
            if torch.cuda.is_available():
                torch.cuda.empty_cache()
            
            batch_total_time = time.time() - batch_start_time
            print(f"批次 {batch_idx+1} 处理完成，总耗时 {batch_total_time:.2f}秒")
            
            # One-off confirmation after the very first batch; training continues.
            if batch_idx == 0 and epoch == 0:
                print("\n成功处理第一个批次！继续训练...\n")
        
        # Average training metrics over the batches that were actually trained on
        train_loss = train_loss / processed_batches if processed_batches else 0.0
        train_dice = sum(train_dice_scores) / len(train_dice_scores) if train_dice_scores else 0
        train_iou = sum(train_iou_scores) / len(train_iou_scores) if train_iou_scores else 0
        
        # Current learning rate (before this epoch's scheduler step)
        current_lr = optimizer.param_groups[0]['lr']
        
        # Record training metrics
        history['train_loss'].append(train_loss)
        history['train_dice'].append(train_dice)
        history['train_iou'].append(train_iou)
        history['lr'].append(current_lr)
        
        # Validation phase
        if val_loader is not None:
            val_loss, val_dice, val_iou = validate_model(model, val_loader, criterion, device)
            history['val_loss'].append(val_loss)
            history['val_dice'].append(val_dice)
            history['val_iou'].append(val_iou)
            
            # Adjust the learning rate
            old_lr = optimizer.param_groups[0]['lr']
            scheduler.step(val_loss)
            new_lr = optimizer.param_groups[0]['lr']
            
            # Report learning-rate reductions
            if new_lr < old_lr:
                print(f"学习率从 {old_lr:.6f} 减小到 {new_lr:.6f}")
            
            # Early stopping and best-model checkpointing
            if val_loss < best_val_metric:
                best_val_metric = val_loss
                counter = 0
                # Save the best model so far
                torch.save(model.state_dict(), save_path)
                checkpoint_written = True
                print(f"轮次 {epoch+1}/{num_epochs} - "
                     f"训练损失: {train_loss:.4f}, 训练Dice: {train_dice:.4f}, "
                     f"验证损失: {val_loss:.4f}, 验证Dice: {val_dice:.4f} - "
                     f"耗时: {time.time() - epoch_start_time:.1f}秒 - 保存最佳模型")
            else:
                counter += 1
                print(f"轮次 {epoch+1}/{num_epochs} - "
                     f"训练损失: {train_loss:.4f}, 训练Dice: {train_dice:.4f}, "
                     f"验证损失: {val_loss:.4f}, 验证Dice: {val_dice:.4f} - "
                     f"耗时: {time.time() - epoch_start_time:.1f}秒")
                
                if counter >= patience:
                    print(f"早停: {epoch+1} 轮后未见改善")
                    break
        else:
            # Without a validation set, save the model after every epoch
            torch.save(model.state_dict(), save_path)
            print(f"轮次 {epoch+1}/{num_epochs} - "
                 f"训练损失: {train_loss:.4f}, 训练Dice: {train_dice:.4f} - "
                 f"耗时: {time.time() - epoch_start_time:.1f}秒")
    
    # With a validation set, restore the best weights written by this run.
    # map_location keeps the reload working when the target device differs
    # from the one the checkpoint tensors were saved from.
    if val_loader is not None and checkpoint_written and os.path.exists(save_path):
        model.load_state_dict(torch.load(save_path, map_location=device))
    
    return model, history    

def validate_model(model, val_loader, criterion, device):
    """
    Evaluate the model on a validation loader (no progress output).

    The model is switched to eval mode and left in it; gradients are
    disabled.  Predictions are thresholded at 0.5 after a sigmoid before
    IoU/Dice are computed per batch, and all three values are averaged over
    the batches.

    An empty loader used to raise ``ZeroDivisionError``.  It now yields
    ``(nan, 0.0, 0.0)``: the NaN loss can never compare as an improvement, so
    a caller doing best-model selection will not treat "no data" as a perfect
    score, and the zero metrics match the training-loop fallback.

    Args:
        model: model to evaluate.
        val_loader: iterable yielding dicts with the keys ``'patches'`` and
            ``'mask_patches'``.
        criterion: loss called as ``criterion(outputs, targets)``.
        device: device the batches are moved to.

    Returns:
        val_loss: mean batch loss (NaN when there were no batches).
        val_dice: mean batch Dice coefficient.
        val_iou: mean batch IoU.
    """
    model.eval()
    val_loss = 0.0
    dice_scores = []
    iou_scores = []
    num_batches = 0
    
    with torch.no_grad():
        for batch in val_loader:
            # Inputs and targets
            inputs = batch['patches'].to(device)
            targets = batch['mask_patches'].to(device)
            
            # Forward pass
            outputs = model(inputs)
            
            # Loss
            loss = criterion(outputs, targets)
            val_loss += loss.item()
            
            # Metrics on the thresholded predictions
            preds = (torch.sigmoid(outputs) > 0.5).float()
            iou, dice = calculate_metrics(preds, targets)
            dice_scores.append(dice)
            iou_scores.append(iou)
            num_batches += 1
    
    # Nothing was evaluated: report "no measurement" instead of dividing by zero.
    if num_batches == 0:
        return float('nan'), 0.0, 0.0
    
    # Average loss and metrics over the evaluated batches
    val_loss /= num_batches
    val_dice = sum(dice_scores) / num_batches
    val_iou = sum(iou_scores) / num_batches
    
    return val_loss, val_dice, val_iou

def calculate_metrics(pred_mask, true_mask, threshold=0.5):
    """
    Compute IoU and the Dice coefficient between two mask tensors.

    Both scores are smoothed with 1e-7 in numerator and denominator, so two
    empty masks score 1.0 rather than dividing by zero.

    Args:
        pred_mask: predicted mask tensor (binary or probabilities).
        true_mask: target mask tensor; used as given, it is not thresholded.
        threshold: values of ``pred_mask`` strictly above it become 1.0, the
            rest 0.0.  Pass ``None`` to use ``pred_mask`` unchanged.

    Returns:
        (iou, dice) as Python floats.
    """
    eps = 1e-7  # guards against division by zero

    if threshold is None:
        binary_pred = pred_mask
    else:
        binary_pred = (pred_mask > threshold).float()

    overlap = (binary_pred * true_mask).sum()
    mask_total = binary_pred.sum() + true_mask.sum()

    # IoU = |A ∩ B| / |A ∪ B|, with |A ∪ B| = |A| + |B| - |A ∩ B|
    iou = (overlap + eps) / (mask_total - overlap + eps)
    # Dice = 2|A ∩ B| / (|A| + |B|)
    dice = (2. * overlap + eps) / (mask_total + eps)

    return iou.item(), dice.item()

def reconstruct_from_patches(patches, positions, original_size, patch_size, stride):
    """
    Stitch 2D patches back into one image, averaging where they overlap.

    Args:
        patches: sequence of 2D arrays (one predicted patch each).
        positions: top-left corner of every patch as ``(y, x)``, in the same
            order as ``patches``.
        original_size: ``(height, width)`` of the output image.
        patch_size: nominal patch edge length; patches reaching past the image
            border are cropped to fit.
        stride: sliding-window step.  Accepted for interface symmetry only;
            placement relies on ``positions``.

    Returns:
        float32 array of shape ``original_size``.  Pixels not covered by any
        patch stay 0.
    """
    height, width = original_size
    value_sum = np.zeros((height, width), dtype=np.float32)
    hit_count = np.zeros((height, width), dtype=np.float32)

    for tile, (top, left) in zip(patches, positions):
        # Crop the tile where it would extend beyond the image.
        rows = min(patch_size, height - top)
        cols = min(patch_size, width - left)
        window = (slice(top, top + rows), slice(left, left + cols))
        value_sum[window] += tile[:rows, :cols]
        hit_count[window] += 1

    # Uncovered pixels have a count of 0; divide those by 1 instead.
    return value_sum / np.maximum(hit_count, 1)

def plot_training_history(history, save_path=None):
    """
    Plot loss, IoU, Dice and learning rate per epoch on a 2x2 figure.

    Args:
        history: dict with the lists ``train_loss``, ``train_iou``,
            ``train_dice`` and ``lr``; the optional ``val_loss``, ``val_iou``
            and ``val_dice`` lists are drawn when present and non-empty.
        save_path: optional file path; when given, the figure is also saved
            there (300 dpi, tight bounding box).
    """
    epochs = range(1, len(history['train_loss']) + 1)

    plt.figure(figsize=(16, 12))

    # (subplot index, train key, validation key, display name)
    metric_panels = (
        (1, 'train_loss', 'val_loss', 'Loss'),
        (2, 'train_iou', 'val_iou', 'IoU'),
        (3, 'train_dice', 'val_dice', 'Dice'),
    )
    for slot, train_key, val_key, name in metric_panels:
        plt.subplot(2, 2, slot)
        plt.plot(epochs, history[train_key], 'b-', label=f'Training {name}')
        # The validation curve is optional (training may run without one).
        if val_key in history and history[val_key]:
            plt.plot(epochs, history[val_key], 'r-', label=f'Validation {name}')
        plt.title(f'Training and Validation {name}')
        plt.xlabel('Epochs')
        plt.ylabel(name)
        plt.legend()
        plt.grid(True)

    # Learning rate, on a log axis so step-wise reductions stay visible.
    plt.subplot(2, 2, 4)
    plt.plot(epochs, history['lr'], 'g-', label='Learning Rate')
    plt.title('Learning Rate')
    plt.xlabel('Epochs')
    plt.ylabel('Learning Rate')
    plt.legend()
    plt.grid(True)
    plt.yscale('log')

    plt.tight_layout()

    if save_path:
        plt.savefig(save_path, dpi=300, bbox_inches='tight')
        print(f"Training history plot saved to {save_path}")

    plt.show()    
    
  
    
  
def improved_predict(model, image, mask=None, device="cuda", patch_size=256, stride=None, 
                    test_time_augmentation=True, tta_flips=True, tta_scales=True):
    """
    Predict a full-image binary mask by sliding-window inference, optionally with TTA.

    The image is cut into overlapping ``patch_size`` windows, each window is
    scaled to [0, 1] and predicted in mini-batches, and the binary patch
    predictions are averaged over their overlaps and thresholded at 0.5.

    Window placement covers the whole image: when ``(dim - patch_size)`` is
    not a multiple of ``stride``, one extra window is clamped against the
    bottom/right border instead of leaving that border unpredicted.  Images
    smaller than ``patch_size`` are zero-padded to one window.  Results for
    sizes that already tile exactly are unchanged.

    Args:
        model: trained model returning logits for a ``(N, 3, H, W)`` batch.
        image: preprocessed image array, ``(H, W, 3)`` with values in 0-255.
        mask: optional ground-truth array; pixels equal to 255 count as
            foreground, every other value as background.
        device: device the patch batches are moved to (default ``"cuda"``).
            The model is expected to live on that device already.
        patch_size: edge length of the square windows.
        stride: window step; ``None`` means ``patch_size // 2``.
        test_time_augmentation: use ``predict_with_tta`` instead of a plain
            forward pass.
        tta_flips: forwarded to ``predict_with_tta``.
        tta_scales: forwarded to ``predict_with_tta``.

    Returns:
        prediction: float32 array ``(H, W)`` with values 0.0 / 1.0.
        metrics: ``{"IoU": ..., "Dice": ...}`` when ``mask`` is given, else ``None``.
    """
    # Default stride
    if stride is None:
        stride = patch_size // 2
    
    # Image size
    h, w = image.shape[:2]
    
    # Inference mode
    model.eval()
    
    # Window count per axis: ceil((dim - patch_size) / stride) + 1, at least 1.
    # -((patch_size - dim) // stride) is the integer ceiling of that quotient.
    h_steps = max(0, -((patch_size - h) // stride)) + 1
    w_steps = max(0, -((patch_size - w) // stride)) + 1
    
    patches_list = []
    patch_positions = []
    
    # Extract the overlapping patches
    for i in range(h_steps):
        for j in range(w_steps):
            # Clamp the last window to the border; never start before 0 when
            # the image is smaller than one patch.
            y_start = max(0, min(i * stride, h - patch_size))
            x_start = max(0, min(j * stride, w - patch_size))
            
            # Extract the patch
            patch = image[y_start:y_start+patch_size, x_start:x_start+patch_size]
            
            # Zero-pad patches of images smaller than patch_size
            if patch.shape[0] < patch_size or patch.shape[1] < patch_size:
                temp_patch = np.zeros((patch_size, patch_size, 3), dtype=patch.dtype)
                temp_patch[:patch.shape[0], :patch.shape[1]] = patch
                patch = temp_patch
            
            # Normalise to [0, 1] and convert HWC -> CHW
            normalized_patch = patch.astype(np.float32) / 255.0
            normalized_patch = normalized_patch.transpose(2, 0, 1)
            patches_list.append(normalized_patch)
            patch_positions.append((y_start, x_start))
    
    # Keep the full patch stack on the CPU; only the current mini-batch is
    # moved to the device so large images do not exhaust GPU memory.
    patches_array = np.stack(patches_list)
    patches_tensor = torch.from_numpy(patches_array).float()
    
    # Mini-batch size; tune to the available GPU memory
    batch_size = 16
    all_pred_patches = []
    
    with torch.no_grad():
        for i in range(0, len(patches_tensor), batch_size):
            batch = patches_tensor[i:i+batch_size].to(device)
            
            if test_time_augmentation:
                # Prediction with test-time augmentation
                pred_masks = predict_with_tta(
                    model, batch, tta_flips=tta_flips, tta_scales=tta_scales
                )
            else:
                # Plain prediction
                outputs = model(batch)
                pred_masks = (torch.sigmoid(outputs) > 0.5).float()
            
            # Channel 0 of every prediction is the mask.
            all_pred_patches.extend([p[0].cpu().numpy() for p in pred_masks])
    
    # Rebuild the full-size prediction (overlaps are averaged)
    prediction = reconstruct_from_patches(
        all_pred_patches, 
        patch_positions, 
        (h, w), 
        patch_size, 
        stride
    )
    
    # Binarise
    prediction = (prediction > 0.5).astype(np.float32)
    
    # Metrics (only when a ground-truth mask was supplied)
    metrics = None
    if mask is not None:
        # 255 is foreground; any other label value is background.
        mask_float = (mask == 255).astype(np.float32)
            
        # Tensors for the IoU / Dice computation
        pred_tensor = torch.from_numpy(prediction)
        true_tensor = torch.from_numpy((mask_float > 0.5).astype(np.float32))
        
        # Make sure both masks have the same size
        if pred_tensor.shape != true_tensor.shape:
            print(f"警告: 预测掩码 ({pred_tensor.shape}) 和真实掩码 ({true_tensor.shape}) 尺寸不一致")
            # Resize the prediction with nearest-neighbour interpolation
            pred_tensor = torch.nn.functional.interpolate(
                pred_tensor.unsqueeze(0).unsqueeze(0), 
                size=true_tensor.shape, 
                mode='nearest'
            ).squeeze(0).squeeze(0)
        
        # Compute the metrics
        iou, dice = calculate_metrics(pred_tensor, true_tensor)
        metrics = {"IoU": iou, "Dice": dice}
        
        print(f"评估指标: IoU={iou:.4f}, Dice={dice:.4f}")
    
    return prediction, metrics    
    
    
    
def predict_with_tta(model, x, tta_flips=True, tta_scales=True):
    """
    Predict with test-time augmentation and return a binary mask.

    The sigmoid probabilities of the plain input are averaged with those of
    every enabled augmented view (each mapped back to the original
    orientation / size), then thresholded at 0.5.

    Args:
        model: model returning logits for a 4D ``(N, C, H, W)`` batch.
        x: input batch tensor.
        tta_flips: add horizontal, vertical and combined flips.
        tta_scales: add 0.75x and 1.25x bilinear rescales of the input.

    Returns:
        Float tensor of 0.0 / 1.0 with the shape of the model output for ``x``.
    """
    reference = torch.sigmoid(model(x))
    votes = [reference]

    if tta_flips:
        # dims 3 = horizontal, 2 = vertical, [2, 3] = both; flipping the
        # prediction along the same dims undoes the augmentation.
        for axes in ([3], [2], [2, 3]):
            mirrored_input = torch.flip(x, dims=axes)
            mirrored_prob = torch.sigmoid(model(mirrored_input))
            votes.append(torch.flip(mirrored_prob, dims=axes))

    if tta_scales:
        target_hw = reference.shape[2:]
        for factor in (0.75, 1.25):
            resized_input = F.interpolate(
                x, scale_factor=factor, mode='bilinear', align_corners=False
            )
            resized_prob = torch.sigmoid(model(resized_input))
            # Bring the probabilities back to the reference resolution.
            votes.append(
                F.interpolate(resized_prob, size=target_hw, mode='bilinear', align_corners=False)
            )

    # Average every vote, then binarise.
    mean_prob = torch.stack(votes).mean(dim=0)
    return (mean_prob > 0.5).float()


def post_process_mask(mask, min_size=100, close_kernel_size=5, open_kernel_size=3):
    """
    Clean up a predicted mask with morphology and small-component removal.

    Steps, each skipped when its parameter is not positive:
      1. Morphological closing (dilate then erode) to fill small holes.
      2. Morphological opening (erode then dilate) to remove small specks.
      3. Removal of 8-connected foreground components whose pixel area is
         smaller than ``min_size``.

    Args:
        mask: NumPy array; values > 0.5 are treated as foreground.
        min_size: Minimum component area (in pixels) to keep.
        close_kernel_size: Side length of the square closing kernel.
        open_kernel_size: Side length of the square opening kernel.

    Returns:
        ``np.uint8`` array of 0/1 values with the same shape as ``mask``.
    """
    # Binarise first so every later step works on a 0/1 uint8 image.
    binary_mask = (mask > 0.5).astype(np.uint8)

    if close_kernel_size > 0:
        close_kernel = np.ones((close_kernel_size, close_kernel_size), np.uint8)
        binary_mask = cv2.morphologyEx(binary_mask, cv2.MORPH_CLOSE, close_kernel)

    if open_kernel_size > 0:
        open_kernel = np.ones((open_kernel_size, open_kernel_size), np.uint8)
        binary_mask = cv2.morphologyEx(binary_mask, cv2.MORPH_OPEN, open_kernel)

    if min_size > 0:
        _, labels, stats, _ = cv2.connectedComponentsWithStats(binary_mask, connectivity=8)

        # One boolean per label: True when the component is below the size limit.
        too_small = stats[:, cv2.CC_STAT_AREA] < min_size
        too_small[0] = False  # label 0 is the background, never "removed"

        # Single pass over the label image instead of one full scan per component.
        binary_mask[too_small[labels]] = 0

    return binary_mask




## 定义新的dataset类

In [3]:
class OptimizedSegmentationDataset(torch.utils.data.Dataset):
    """Sliding-window patch dataset for image / mask segmentation pairs.

    Two operating modes:

    * ``preload=True``: every image is read and preprocessed once in
      ``__init__`` and all patches are kept in memory as tensors.
    * ``preload=False``: ``__init__`` only indexes (image index, position)
      pairs; images are read and preprocessed lazily in ``__getitem__`` with
      a one-image cache.

    Each sample is a dict with keys ``patches``, ``mask_patches``,
    ``positions``, ``original_size``, ``image_path``, ``mask_path`` and
    ``image_idx``.
    """

    def __init__(self, data_list, patch_size=128, stride=64, is_training=True, preload=True, transform=None):
        """
        Build the dataset.

        Args:
            data_list: List of dicts with at least the keys ``"image"`` and
                ``"annotation"`` (file paths).
            patch_size: Side length of the square patches.
            stride: Sliding-window step.
            is_training: If True, ``ImagePreprocessor.random_preprocess`` is
                used; otherwise ``ImagePreprocessor.preprocess``.
            preload: If True, extract and store all patches in memory now.
            transform: Optional callable invoked as
                ``transform(image=patch, mask=mask_patch)`` returning a dict
                with ``'image'`` and ``'mask'``.
        """
        self.data_list = data_list
        self.patch_size = patch_size
        self.stride = stride
        self.is_training = is_training
        self.preload = preload
        self.transform = transform
        self.preprocessor = ImagePreprocessor()

        # (image index, position) records; only filled in non-preload mode.
        self.patches_info = []

        # Parallel per-patch lists; only present in preload mode.
        if preload:
            self.all_patches = []
            self.all_masks = []
            self.all_positions = []
            self.all_original_sizes = []
            self.all_image_paths = []
            self.all_mask_paths = []
            self.image_indices = []

        # One-image cache used by the lazy (non-preload) __getitem__ path.
        self._cached_image_idx = None
        self._cached_image = None
        self._cached_mask = None
        self._cached_image_path = None
        self._cached_mask_path = None

        print(f"{'训练' if is_training else '验证/测试'}数据集初始化...")

        for idx, item in enumerate(tqdm(data_list, desc="处理图像")):
            image = cv2.imread(item["image"])
            if image is None:
                print(f"警告: 无法读取图像 {item['image']}")
                continue
            image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)

            mask = cv2.imread(item["annotation"], cv2.IMREAD_GRAYSCALE)
            if mask is None:
                print(f"警告: 无法读取掩码 {item['annotation']}")
                continue

            if preload:
                self._preload_image(idx, item, image, mask)
            else:
                # Lazy mode: positions are computed on the raw image size and
                # clamped against the preprocessed size at access time.
                h, w = image.shape[:2]
                for pos in self._calculate_positions(h, w):
                    self.patches_info.append({
                        'image_idx': idx,
                        'position': pos
                    })

        if preload:
            print(f"数据集预处理完成，共生成 {len(self.all_patches)} 个patch，来自 {len(data_list)} 张图像")
        else:
            print(f"数据集索引完成，共索引 {len(self.patches_info)} 个可能的patch，来自 {len(data_list)} 张图像")

    def _preprocess(self, image, mask):
        """Run the training (random) or evaluation preprocessing; return (image, mask)."""
        if self.is_training:
            # random_preprocess also returns the sampled parameters; unused here.
            processed_image, processed_mask, _ = self.preprocessor.random_preprocess(
                image, mask, patch_size=self.patch_size
            )
        else:
            processed_image, processed_mask = self.preprocessor.preprocess(
                image, mask, patch_size=self.patch_size
            )
        return processed_image, processed_mask

    def _extract_sample(self, processed_image, processed_mask, position):
        """Cut one patch at ``position`` and convert it to tensors.

        Returns ``(patch_tensor, mask_tensor, (y, x), (h, w))`` where ``(y, x)``
        is the position after clamping to the image bounds and ``(h, w)`` is
        the size of the preprocessed image.
        """
        h, w = processed_image.shape[:2]
        y, x = position

        # Keep the window inside the preprocessed image.
        y = min(y, h - self.patch_size)
        x = min(x, w - self.patch_size)

        # Copy so the patch does not remain a view onto the full image.
        patch = processed_image[y:y + self.patch_size, x:x + self.patch_size].copy()
        mask_patch = processed_mask[y:y + self.patch_size, x:x + self.patch_size].copy()

        if self.transform:
            augmented = self.transform(image=patch, mask=mask_patch)
            patch = augmented['image']
            mask_patch = augmented['mask']

        # HWC uint8 -> CHW float in [0, 1]; mask gets a leading channel dim.
        patch_tensor = torch.FloatTensor(patch.transpose(2, 0, 1)) / 255.0
        mask_tensor = torch.FloatTensor(mask_patch).unsqueeze(0) / 255.0
        return patch_tensor, mask_tensor, (y, x), (h, w)

    def _preload_image(self, idx, item, image, mask):
        """Preprocess one image and append all of its patches to the in-memory lists."""
        processed_image, processed_mask = self._preprocess(image, mask)
        h, w = processed_image.shape[:2]

        for position in self._calculate_positions(h, w):
            patch_tensor, mask_tensor, clamped, size = self._extract_sample(
                processed_image, processed_mask, position
            )
            self.all_patches.append(patch_tensor)
            self.all_masks.append(mask_tensor)
            self.all_positions.append(clamped)
            self.all_original_sizes.append(size)
            self.all_image_paths.append(item["image"])
            self.all_mask_paths.append(item["annotation"])
            self.image_indices.append(idx)

    def _axis_offsets(self, length):
        """Return the window start offsets along one axis of size ``length``."""
        span = length - self.patch_size
        offsets = np.arange(0, span + 1, self.stride).tolist()

        # The regular grid ends exactly on the border only when the free span
        # is a multiple of the stride; otherwise add one border-aligned window.
        # (Testing ``length % stride`` instead misses or duplicates that window
        # whenever patch_size is not a multiple of stride.)
        if length > self.patch_size and span % self.stride != 0:
            offsets.append(span)
        return offsets

    def _calculate_positions(self, h, w):
        """Return all (y, x) top-left patch positions for an ``h`` x ``w`` image.

        Positions follow a regular grid with step ``self.stride`` plus, when
        needed, one extra row/column aligned with the bottom/right border so
        the whole image is covered. Returns an empty list when the image is
        smaller than the patch along either axis.
        """
        ys = self._axis_offsets(h)
        xs = self._axis_offsets(w)
        return [(int(y), int(x)) for y in ys for x in xs]

    def __len__(self):
        """Number of patches (stored patches in preload mode, indexed ones otherwise)."""
        if self.preload:
            return len(self.all_patches)
        return len(self.patches_info)

    def __getitem__(self, idx):
        """Return the sample dict for patch ``idx``."""
        if self.preload:
            # Everything was computed in __init__; just assemble the record.
            return {
                'patches': self.all_patches[idx],
                'mask_patches': self.all_masks[idx],
                'positions': self.all_positions[idx],
                'original_size': self.all_original_sizes[idx],
                'image_path': self.all_image_paths[idx],
                'mask_path': self.all_mask_paths[idx],
                'image_idx': self.image_indices[idx]
            }

        entry = self.patches_info[idx]
        image_idx = entry['image_idx']

        # Re-read and re-preprocess only when a different image is requested.
        if self._cached_image_idx != image_idx:
            image_path = self.data_list[image_idx]["image"]
            mask_path = self.data_list[image_idx]["annotation"]

            image = cv2.imread(image_path)
            image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
            mask = cv2.imread(mask_path, cv2.IMREAD_GRAYSCALE)

            processed_image, processed_mask = self._preprocess(image, mask)

            self._cached_image_idx = image_idx
            self._cached_image = processed_image
            self._cached_mask = processed_mask
            self._cached_image_path = image_path
            self._cached_mask_path = mask_path
        else:
            processed_image = self._cached_image
            processed_mask = self._cached_mask
            image_path = self._cached_image_path
            mask_path = self._cached_mask_path

        patch, mask_patch, clamped, size = self._extract_sample(
            processed_image, processed_mask, entry['position']
        )

        return {
            'patches': patch,
            'mask_patches': mask_patch,
            'positions': clamped,
            'original_size': size,
            'image_path': image_path,
            'mask_path': mask_path,
            'image_idx': image_idx
        }

    def get_stats(self):
        """Return a dict of summary statistics about the dataset."""
        if self.preload:
            total_patches = len(self.all_patches)
            unique_images = len(set(self.image_indices))
        else:
            total_patches = len(self.patches_info)
            unique_images = len(set(info['image_idx'] for info in self.patches_info))

        patches_per_image = total_patches / unique_images if unique_images > 0 else 0

        return {
            'total_patches': total_patches,
            'unique_images': unique_images,
            'patches_per_image': patches_per_image,
            'patch_size': self.patch_size,
            'stride': self.stride,
            'is_training': self.is_training,
            'preload': self.preload
        }

    def _plot_sample(self, sample, title):
        """Show image patch, mask and red overlay of one sample side by side."""
        patch = sample['patches'].numpy()
        mask = sample['mask_patches'].numpy()

        patch = np.transpose(patch, (1, 2, 0))  # CHW -> HWC
        mask = np.squeeze(mask)

        # Undo the [0, 1] normalisation for display.
        patch = (patch * 255).astype(np.uint8)
        mask = (mask * 255).astype(np.uint8)

        fig, axes = plt.subplots(1, 3, figsize=(12, 4))
        fig.suptitle(title)

        axes[0].imshow(patch)
        axes[0].set_title('Image Patch')
        axes[0].axis('off')

        axes[1].imshow(mask, cmap='gray')
        axes[1].set_title('Mask')
        axes[1].axis('off')

        # Paint mask pixels above mid-grey in red on a copy of the patch.
        overlay = patch.copy()
        overlay[mask > 127] = [255, 0, 0]
        axes[2].imshow(overlay)
        axes[2].set_title('Overlay')
        axes[2].axis('off')

        plt.tight_layout()
        plt.show()

    def visualize_batch(self, indices=None, num_samples=5):
        """
        Plot a set of samples.

        Args:
            indices: Sample indices to plot; if None, ``num_samples`` indices
                are drawn at random without replacement.
            num_samples: Number of random samples when ``indices`` is None.
        """
        if indices is None:
            indices = np.random.choice(len(self), size=min(num_samples, len(self)), replace=False)

        for idx in indices:
            sample = self[idx]
            image_name = os.path.basename(sample["image_path"])
            self._plot_sample(sample, f'Sample #{idx} from {image_name}')

    def _patch_indices_for_image(self, image_idx):
        """Return the dataset indices of all patches cut from image ``image_idx``."""
        if self.preload:
            return [i for i, owner in enumerate(self.image_indices) if owner == image_idx]
        return [i for i, info in enumerate(self.patches_info)
                if info['image_idx'] == image_idx]

    def _patch_position(self, idx):
        """Return the stored (y, x) position of patch ``idx`` in either mode."""
        if self.preload:
            return self.all_positions[idx]
        return self.patches_info[idx]['position']

    def visualize_image_patches(self, image_path, num_patches=5):
        """Plot where patches of one image are located, then each selected patch.

        Works in both preload and lazy mode. Errors are caught and printed
        with a traceback instead of being raised.

        NOTE: in training mode the overview image is produced by a fresh call
        to the random preprocessing, so it may differ from the image the
        stored/cached patches were cut from.

        Args:
            image_path: Value of ``item["image"]`` identifying the image.
            num_patches: Maximum number of patches to display.
        """
        try:
            matches = [i for i, item in enumerate(self.data_list) if item["image"] == image_path]
            if not matches:
                print(f"未找到图片: {image_path}")
                return

            image_idx = matches[0]
            mask_path = self.data_list[image_idx]["annotation"]

            original_image = cv2.imread(image_path)
            original_image = cv2.cvtColor(original_image, cv2.COLOR_BGR2RGB)
            original_mask = cv2.imread(mask_path, cv2.IMREAD_GRAYSCALE)

            # Same preprocessing path as used when building samples.
            processed_image, processed_mask = self._preprocess(original_image, original_mask)

            patch_indices = self._patch_indices_for_image(image_idx)
            if not patch_indices:
                print(f"未找到图片的patches: {image_path}")
                return

            if len(patch_indices) <= num_patches:
                selected_indices = patch_indices
            else:
                selected_indices = random.sample(patch_indices, num_patches)
            num_selected = len(selected_indices)

            plt.figure(figsize=(20, 8))

            plt.subplot(1, 2, 1)
            plt.imshow(processed_image)
            plt.title("Processed Image with Patch Locations")

            plt.subplot(1, 2, 2)
            plt.imshow(processed_mask, cmap='gray')
            plt.title("Processed Mask with Patch Locations")

            # One colour per selected patch, shared between both panels.
            colors = plt.cm.rainbow(np.linspace(0, 1, num_selected))
            h, w = processed_image.shape[:2]

            for subplot_idx in [1, 2]:
                plt.subplot(1, 2, subplot_idx)

                for i, idx in enumerate(selected_indices):
                    y, x = self._patch_position(idx)

                    # Keep the rectangle inside the displayed image.
                    y = min(y, h - self.patch_size)
                    x = min(x, w - self.patch_size)

                    rect = plt.Rectangle(
                        xy=(x, y),
                        width=self.patch_size,
                        height=self.patch_size,
                        fill=False,
                        color=colors[i],
                        linewidth=2
                    )
                    plt.gca().add_patch(rect)

                    plt.text(
                        x, y,
                        str(i+1),
                        color=colors[i],
                        fontsize=12,
                        bbox=dict(facecolor='white', alpha=0.7)
                    )
                plt.axis('off')

            plt.tight_layout()
            plt.show()

            # Detail view of every selected patch.
            for i, idx in enumerate(selected_indices):
                self._plot_sample(self[idx], f'Patch {i+1}')

        except Exception as e:
            # Visualisation is best-effort: report the problem, do not propagate.
            print(f"可视化过程出错: {str(e)}")
            import traceback
            traceback.print_exc()
            

def get_optimized_loaders(train_dataset, val_dataset, batch_size=32, num_workers=4, prefetch_factor=2):
    """
    Create the training and validation ``DataLoader`` pair.

    Both loaders use pinned memory, prefetching and persistent workers. The
    training loader shuffles and drops the last incomplete batch; the
    validation loader keeps order and keeps every sample.

    Side effect: sets the global torch multiprocessing sharing strategy to
    ``'file_system'``.

    Args:
        train_dataset: Training dataset object.
        val_dataset: Validation dataset object.
        batch_size: Batch size for both loaders.
        num_workers: Number of worker processes; a value <= 0 selects
            ``min(cpu_count, 8)`` automatically (at least 1).
        prefetch_factor: Batches prefetched per worker.

    Returns:
        Tuple ``(train_loader, val_loader)``.
    """
    import os
    import torch

    if num_workers <= 0:
        # os.cpu_count() may return None when the count cannot be determined;
        # fall back to a single worker instead of failing in min().
        num_workers = min(os.cpu_count() or 1, 8)
        print(f"自动设置工作线程数: {num_workers}")

    # Use file-system based tensor sharing between worker processes.
    torch.multiprocessing.set_sharing_strategy('file_system')

    # Settings shared by both loaders.
    common_kwargs = dict(
        batch_size=batch_size,
        num_workers=num_workers,
        pin_memory=True,
        prefetch_factor=prefetch_factor,
        persistent_workers=num_workers > 0,
    )

    train_loader = torch.utils.data.DataLoader(
        train_dataset,
        shuffle=True,
        drop_last=True,  # drop the incomplete final batch during training
        **common_kwargs
    )

    val_loader = torch.utils.data.DataLoader(
        val_dataset,
        shuffle=False,
        drop_last=False,  # evaluate on every validation sample
        **common_kwargs
    )

    print("数据加载器创建完成:")
    print(f"- 训练集: {len(train_dataset)} 个样本, {len(train_loader)} 个批次")
    print(f"- 验证集: {len(val_dataset)} 个样本, {len(val_loader)} 个批次")

    return train_loader, val_loader






## 开始部分

In [None]:
# Build the train/validation file lists for the selected dataset.
train_data, val_data = prepare_dataset("Kasthuri++")
# train_data, val_data = prepare_dataset("Lucchi++")

# First run: preprocess and save.
# NOTE(review): ImprovedSegmentationDataset is not defined in the visible part
# of this notebook; presumably it comes from the `functions` package - confirm.
# A later cell rebuilds train_dataset / val_dataset with
# OptimizedSegmentationDataset, which supersedes the objects created here.
train_dataset = ImprovedSegmentationDataset(
    data_list=train_data,
    patch_size=256,
    stride=128,
    preProcess=True,  # enable preprocessing

)

val_dataset = ImprovedSegmentationDataset(
    data_list=val_data,
    patch_size=256,
    stride=128,
    preProcess=True,  # enable preprocessing

)

# Wrap both datasets in DataLoaders.
train_loader, val_loader = get_optimized_loaders(
    train_dataset, 
    val_dataset, 
    batch_size=32,  # larger batch size
    num_workers=4
)



In [None]:
# Initialize model
# Use the GPU when CUDA is available, otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(device)
# Single output channel (binary segmentation logits); the model is moved to `device`.
model = ImprovedUNet(num_classes=1).to(device)

In [None]:
import os
import torch
import numpy as np
import matplotlib.pyplot as plt
from tqdm import tqdm
import random
import time
import cv2

# Seed every random number generator used by the notebook for reproducibility.
def set_seed(seed=42):
    """Seed Python, NumPy and PyTorch RNGs and switch cuDNN to deterministic mode.

    Also writes the seed to the ``PYTHONHASHSEED`` environment variable and
    prints a confirmation message.

    Args:
        seed: Integer seed applied to all generators.
    """
    # CPU-side generators.
    for seed_fn in (random.seed, np.random.seed, torch.manual_seed):
        seed_fn(seed)

    # GPU generators, only when CUDA is present.
    if torch.cuda.is_available():
        for cuda_seed_fn in (torch.cuda.manual_seed, torch.cuda.manual_seed_all):
            cuda_seed_fn(seed)

    # Deterministic cuDNN kernels, no auto-tuning.
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False

    os.environ['PYTHONHASHSEED'] = str(seed)
    print(f"随机种子已设置为 {seed}")

# Seed all RNGs before any data is preprocessed.
set_seed(42)

# Create the output directories for checkpoints and plots.
results_dir = "results"
models_dir = os.path.join(results_dir, "models")
plots_dir = os.path.join(results_dir, "plots")

os.makedirs(results_dir, exist_ok=True)
os.makedirs(models_dir, exist_ok=True)
os.makedirs(plots_dir, exist_ok=True)

# Select the dataset ("Kasthuri++" or "Lucchi++").
dataset_name = "Kasthuri++"
print(f"使用数据集: {dataset_name}")

# Build the train/validation file lists.
train_data, val_data = prepare_dataset(dataset_name)
print(f"训练数据: {len(train_data)} 张图像")
print(f"验证数据: {len(val_data)} 张图像")

# Training patches: random preprocessing, everything preloaded into memory.
train_dataset = OptimizedSegmentationDataset(
    data_list=train_data,
    patch_size=256,
    stride=128,
    is_training=True,
    preload=True  # keep all patches in memory for speed
)

# Validation patches must NOT go through the random preprocessing path,
# otherwise validation metrics are measured on augmented data; use the
# evaluation preprocessing instead (is_training=False).
val_dataset = OptimizedSegmentationDataset(
    data_list=val_data,
    patch_size=256,
    stride=128,
    is_training=False,
    preload=True  # keep all patches in memory for speed
)

# Report dataset statistics.
train_stats = train_dataset.get_stats()
val_stats = val_dataset.get_stats()
print(f"训练数据集统计: {train_stats}")
print(f"验证数据集统计: {val_stats}")

# Wrap the datasets in DataLoaders.
train_loader, val_loader = get_optimized_loaders(
    train_dataset, 
    val_dataset, 
    batch_size=32,
    num_workers=4,
    prefetch_factor=2
)

# Use the GPU when CUDA is available, otherwise the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"使用设备: {device}")

# Create the model and move it to the selected device.
model = ImprovedUNet(num_classes=1, dropout_rate=0.2)
model = model.to(device)
print(f"模型已创建并移至 {device}")

# Report the parameter counts in millions.
total_params = sum(p.numel() for p in model.parameters())
trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
print(f"模型参数: 总计 {total_params/1e6:.2f}M, 可训练 {trainable_params/1e6:.2f}M")

# Training hyper-parameters and checkpoint location.
num_epochs = 50
learning_rate = 1e-4
model_save_path = os.path.join(models_dir, f"{dataset_name}_unet_model.pth")
# Training loop that reports progress with plain prints instead of tqdm.
def train_and_validate_model(model, train_loader, val_loader=None, num_epochs=50, 
                           learning_rate=1e-4, device='cuda', save_path='best_model.pth',
                           patience=10):
    """
    Train ``model`` and optionally validate it after every epoch.

    Uses AdamW, ``BCEWithLogitsLoss`` and a ``ReduceLROnPlateau`` scheduler
    (factor 0.5, patience 5) driven by the validation loss. With a validation
    loader, the state dict with the lowest validation loss is written to
    ``save_path``, training stops early after ``patience`` epochs without
    improvement, and the saved weights are loaded back before returning.
    Without a validation loader the state dict is saved after every epoch.

    Args:
        model: Network returning logits for ``batch['patches']``.
        train_loader: Loader yielding dicts with ``'patches'`` and
            ``'mask_patches'`` tensors.
        val_loader: Optional validation loader with the same batch format.
        num_epochs: Maximum number of epochs.
        learning_rate: Initial learning rate for AdamW.
        device: Device the model and batches are moved to.
        save_path: File path for the saved state dict.
        patience: Early-stopping patience in epochs (validation only).

    Returns:
        Tuple ``(model, history)`` where ``history`` maps ``'train_loss'``,
        ``'val_loss'``, ``'train_dice'``, ``'val_dice'``, ``'train_iou'``,
        ``'val_iou'`` and ``'lr'`` to per-epoch lists.

    Raises:
        ValueError: If ``train_loader`` yields no batches (for example when
            ``drop_last=True`` and the dataset is smaller than one batch).
    """
    total_batches = len(train_loader)
    if total_batches == 0:
        # Fail early with a clear message instead of a ZeroDivisionError
        # when averaging the epoch metrics.
        raise ValueError("train_loader yields no batches; check dataset size, batch_size and drop_last")

    model = model.to(device)

    optimizer = torch.optim.AdamW(model.parameters(), lr=learning_rate)
    criterion = torch.nn.BCEWithLogitsLoss()

    # Halve the learning rate when the validation loss stops improving.
    scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
        optimizer, mode='min', factor=0.5, patience=5
    )

    history = {
        'train_loss': [],
        'val_loss': [],
        'train_dice': [],
        'val_dice': [],
        'train_iou': [],
        'val_iou': [],
        'lr': []
    }

    # Lowest validation loss seen so far and epochs since it last improved.
    best_val_metric = float('inf')
    counter = 0

    print(f"开始训练，共 {num_epochs} 个轮次")

    for epoch in range(num_epochs):
        epoch_start_time = time.time()

        # ---- training phase ----
        model.train()
        train_loss = 0.0
        train_dice_scores = []
        train_iou_scores = []

        for batch_count, batch in enumerate(train_loader, start=1):
            inputs = batch['patches'].to(device)
            targets = batch['mask_patches'].to(device)

            optimizer.zero_grad()
            outputs = model(inputs)
            loss = criterion(outputs, targets)
            loss.backward()
            optimizer.step()

            train_loss += loss.item()

            # Metrics on the thresholded predictions; no gradients needed.
            with torch.no_grad():
                preds = (torch.sigmoid(outputs) > 0.5).float()
                iou, dice = calculate_metrics(preds, targets)
                train_dice_scores.append(dice)
                train_iou_scores.append(iou)

            # Lightweight progress report every 10 batches.
            if batch_count % 10 == 0:
                print(f"Epoch {epoch+1}/{num_epochs} - 批次 {batch_count}/{total_batches}")

        # Per-epoch averages over batches.
        train_loss /= total_batches
        train_dice = sum(train_dice_scores) / len(train_dice_scores)
        train_iou = sum(train_iou_scores) / len(train_iou_scores)

        # Learning rate in effect during this epoch.
        current_lr = optimizer.param_groups[0]['lr']

        history['train_loss'].append(train_loss)
        history['train_dice'].append(train_dice)
        history['train_iou'].append(train_iou)
        history['lr'].append(current_lr)

        # ---- validation phase ----
        if val_loader is not None:
            val_loss, val_dice, val_iou = validate_model(model, val_loader, criterion, device)
            history['val_loss'].append(val_loss)
            history['val_dice'].append(val_dice)
            history['val_iou'].append(val_iou)

            # Let the scheduler react and report a learning-rate drop if any.
            old_lr = optimizer.param_groups[0]['lr']
            scheduler.step(val_loss)
            new_lr = optimizer.param_groups[0]['lr']
            if new_lr < old_lr:
                print(f"学习率从 {old_lr:.6f} 减小到 {new_lr:.6f}")

            improved = val_loss < best_val_metric
            if improved:
                best_val_metric = val_loss
                counter = 0
                torch.save(model.state_dict(), save_path)
            else:
                counter += 1

            print(f"轮次 {epoch+1}/{num_epochs} - "
                  f"训练损失: {train_loss:.4f}, 训练Dice: {train_dice:.4f}, "
                  f"验证损失: {val_loss:.4f}, 验证Dice: {val_dice:.4f} - "
                  f"耗时: {time.time() - epoch_start_time:.1f}秒"
                  + (" - 保存最佳模型" if improved else ""))

            if not improved and counter >= patience:
                print(f"早停: {epoch+1} 轮后未见改善")
                break
        else:
            # No validation data: keep the latest weights after every epoch.
            torch.save(model.state_dict(), save_path)
            print(f"轮次 {epoch+1}/{num_epochs} - "
                 f"训练损失: {train_loss:.4f}, 训练Dice: {train_dice:.4f} - "
                 f"耗时: {time.time() - epoch_start_time:.1f}秒")

    # Restore the saved weights when validation was used.
    # map_location keeps the load valid for whatever device was requested.
    # NOTE(review): if no epoch ever improved (e.g. NaN validation loss), a
    # file left at save_path by an earlier run would be loaded here.
    if val_loader is not None and os.path.exists(save_path):
        model.load_state_dict(torch.load(save_path, map_location=device))

    return model, history

# Validation pass without a progress bar.
def validate_model(model, val_loader, criterion, device):
    """Evaluate ``model`` on ``val_loader``.

    Puts the model in eval mode, runs every batch without gradients, and
    averages the loss, Dice and IoU over batches.

    Args:
        model: Network returning logits for ``batch['patches']``.
        val_loader: Loader yielding dicts with ``'patches'`` and
            ``'mask_patches'`` tensors.
        criterion: Loss called as ``criterion(outputs, targets)``.
        device: Device the batches are moved to.

    Returns:
        Tuple ``(val_loss, val_dice, val_iou)`` of per-batch means.
    """
    model.eval()

    loss_total = 0.0
    batch_dice = []
    batch_iou = []

    with torch.no_grad():
        for batch in val_loader:
            images = batch['patches'].to(device)
            masks = batch['mask_patches'].to(device)

            logits = model(images)
            loss_total += criterion(logits, masks).item()

            # Threshold the sigmoid probabilities at 0.5 before scoring.
            binary_preds = (torch.sigmoid(logits) > 0.5).float()
            batch_iou_value, batch_dice_value = calculate_metrics(binary_preds, masks)
            batch_dice.append(batch_dice_value)
            batch_iou.append(batch_iou_value)

    # Means over the number of batches.
    mean_loss = loss_total / len(val_loader)
    mean_dice = sum(batch_dice) / len(batch_dice)
    mean_iou = sum(batch_iou) / len(batch_iou)

    return mean_loss, mean_dice, mean_iou

# Run the training and measure the wall-clock time it takes.
print(f"开始训练模型，总计 {num_epochs} 轮...")
start_time = time.time()

# Returns the model and the per-epoch history dict.
model, history = train_and_validate_model(
    model=model,
    train_loader=train_loader,
    val_loader=val_loader,
    num_epochs=num_epochs,
    learning_rate=learning_rate,
    device=device,
    save_path=model_save_path
)

# Break the elapsed seconds down into hours / minutes / seconds for display.
total_time = time.time() - start_time
hours = int(total_time // 3600)
minutes = int((total_time % 3600) // 60)
seconds = int(total_time % 60)
print(f"训练完成，总耗时: {hours}小时 {minutes}分钟 {seconds}秒")

# Plot the training history and save the figure under plots_dir.
# NOTE(review): plot_training_history is not defined in the visible part of
# this notebook; presumably provided by the `functions` package - confirm.
history_plot_path = os.path.join(plots_dir, f"{dataset_name}_training_history.png")
plot_training_history(history, save_path=history_plot_path)