In [1]:
import os
import cv2
import numpy as np
import xml.etree.ElementTree as ET
from sklearn.model_selection import train_test_split
import torch
import pandas as pd 
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
import matplotlib.pyplot as plt
import albumentations as A

Unet训练

In [4]:

# 1. 自定义数据集类
from torch import GradScaler
from tqdm import tqdm


class FoveaDataset(Dataset):
    def __init__(self, image_dir, xml_dir, fovea_csv=None, transform=None, is_test=False):
        self.image_dir = image_dir
        self.xml_dir = xml_dir
        self.transform = A.Compose([  # 使用albumentations
            A.CLAHE(clip_limit=3.0, tile_grid_size=(8, 8)),
            A.Sharpen(alpha=(0.2, 0.5), lightness=(0.5, 1.0)),
        ])
        # self.transform = transform
        self.is_test = is_test
        self.images = [f for f in os.listdir(image_dir) if f.endswith('.jpg')]
        if fovea_csv is not None:
            self.fovea_coords = pd.read_csv(fovea_csv)
        else:
            self.fovea_coords = None
         
        # 预处理所有图像
        self.processed_images = {}
        print("预处理图像...")
        for img_name in tqdm(self.images):
            img_path = os.path.join(image_dir, img_name)
            image = cv2.imread(img_path)
            enhanced_image = self.enhance_image(image)
            self.processed_images[img_name] = enhanced_image
            
    
    def enhance_image(self, image):
        """简化的图像增强处理"""
        lab = cv2.cvtColor(image, cv2.COLOR_BGR2LAB)
        l, a, b = cv2.split(lab)
        clahe = cv2.createCLAHE(clipLimit=3.0, tileGridSize=(8,8))
        l = clahe.apply(l)
        enhanced_lab = cv2.merge((l,a,b))
        enhanced_image = cv2.cvtColor(enhanced_lab, cv2.COLOR_LAB2BGR)
        
        return enhanced_image
    
    def __len__(self):  # 添加这个方法
        return len(self.images)
            
    def get_fovea_coords(self, img_name):
        if self.fovea_coords is not None:
            img_id = int(img_name.split('.')[0])
            coords = self.fovea_coords[self.fovea_coords['data'] == img_id]
            if not coords.empty:
                return (coords['Fovea_X'].values[0], coords['Fovea_Y'].values[0])
        return None

    def create_fovea_heatmap(self, size, coords, original_size):
        """创建中心凹位置的高斯热图"""
        x, y = coords
        # 调整坐标到调整大小后的图像尺寸
        x = int(x * size[0] / original_size[1])
        y = int(y * size[1] / original_size[0])
        
        heatmap = np.zeros(size)
        y = min(max(y, 0), size[1]-1)
        x = min(max(x, 0), size[0]-1)
        
        # 创建高斯核
        sigma = 5
        kernel_size = 6 * sigma + 1
        x_grid, y_grid = np.meshgrid(np.arange(size[0]), np.arange(size[1]))
        heatmap = np.exp(-((x_grid - x)**2 + (y_grid - y)**2) / (2 * sigma**2))
        heatmap = heatmap / heatmap.max()  # 归一化
        
        return heatmap

    def parse_xml(self, xml_path):
        tree = ET.parse(xml_path)
        root = tree.getroot()
        
        size = root.find('size')
        width = int(size.find('width').text)
        height = int(size.find('height').text)
        
        obj = root.find('object')
        bbox = obj.find('bndbox')
        xmin = int(bbox.find('xmin').text)
        ymin = int(bbox.find('ymin').text)
        xmax = int(bbox.find('xmax').text)
        ymax = int(bbox.find('ymax').text)
        
        return (width, height), (xmin, ymin, xmax, ymax)
    
    def create_mask(self, img_shape, bbox):
        mask = np.zeros(img_shape[:2], dtype=np.float32)
        xmin, ymin, xmax, ymax = bbox
        mask[ymin:ymax, xmin:xmax] = 1
        return mask
    
    def __getitem__(self, idx):
        img_name = self.images[idx]
        # 直接使用预处理好的图像
        image = self.processed_images[img_name]
        
        img_path = os.path.join(self.image_dir, img_name)
        xml_path = os.path.join(self.xml_dir, img_name.replace('.jpg', '.xml'))
        
        # 读取图像
        image = cv2.imread(img_path)
        
        image = self.enhance_image(image)  # 添加图像增强
        
        orig_size = image.shape[:2]
        
        # 解析XML
        _, bbox = self.parse_xml(xml_path)
        mask = self.create_mask(orig_size, bbox)
        
        # 获取中心凹坐标并创建热图
        coords = self.get_fovea_coords(img_name)
        
        # 调整大小
        image = cv2.resize(image, (256, 256))
        mask = cv2.resize(mask, (256, 256))
        
        # 转换为张量
        image = image.transpose(2, 0, 1) / 255.0
        image = torch.FloatTensor(image)
        mask = torch.FloatTensor(mask).unsqueeze(0)
        
        if coords is not None:
            heatmap = self.create_fovea_heatmap((256, 256), coords, orig_size)
            heatmap = torch.FloatTensor(heatmap).unsqueeze(0)
            return image, mask, heatmap, coords
        
        return image, mask, None, None




# 2. 定义U-Net模型
class DoubleConv(nn.Module):
    def __init__(self, in_channels, out_channels):
        super().__init__()
        self.double_conv = nn.Sequential(
            nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(inplace=True),
            nn.Conv2d(out_channels, out_channels, kernel_size=3, padding=1),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(inplace=True)
        )

    def forward(self, x):
        return self.double_conv(x)
    


class UNetWithFovea(nn.Module):
    def __init__(self):
        super().__init__()
        # 原有的U-Net编码器部分
        self.enc1 = DoubleConv(3, 64)
        self.enc2 = DoubleConv(64, 128)
        self.enc3 = DoubleConv(128, 256)
        self.enc4 = DoubleConv(256, 512)
        
        self.pool = nn.MaxPool2d(2)
        self.upsample = nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True)
        
        # 原有的U-Net解码器部分
        self.dec3 = DoubleConv(512 + 256, 256)
        self.dec2 = DoubleConv(256 + 128, 128)
        self.dec1 = DoubleConv(128 + 64, 64)
        
        # 分割分支
        self.final_conv_seg = nn.Conv2d(64, 1, kernel_size=1)
        
        # 中心凹预测分支
        self.final_conv_fovea = nn.Sequential(
            nn.Conv2d(64, 32, kernel_size=3, padding=1),
            nn.ReLU(inplace=True),
            nn.Conv2d(32, 1, kernel_size=1),
            nn.Sigmoid()
        )
        
    def forward(self, x):
        # Encoder
        e1 = self.enc1(x)
        e2 = self.enc2(self.pool(e1))
        e3 = self.enc3(self.pool(e2))
        e4 = self.enc4(self.pool(e3))
        
        # Decoder
        d3 = self.dec3(torch.cat([self.upsample(e4), e3], dim=1))
        d2 = self.dec2(torch.cat([self.upsample(d3), e2], dim=1))
        d1 = self.dec1(torch.cat([self.upsample(d2), e1], dim=1))
        
        # 两个输出分支
        mask = torch.sigmoid(self.final_conv_seg(d1))
        fovea_heatmap = self.final_conv_fovea(d1)
        
        return mask, fovea_heatmap

def train_model(model, train_loader, val_loader, device, num_epochs=100):
    criterion_mask = nn.BCELoss()
    criterion_fovea = nn.MSELoss()
    optimizer = optim.Adam(model.parameters(), lr=1e-4)
    
    best_val_loss = float('inf')
    for epoch in range(num_epochs):
        # 训练阶段
        model.train()
        train_loss = 0
        for images, masks, heatmaps, _ in train_loader:
            images = images.to(device)
            masks = masks.to(device)
            if heatmaps is not None:
                heatmaps = heatmaps.to(device)
            
            optimizer.zero_grad()
            mask_pred, fovea_pred = model(images)
            
            loss_mask = criterion_mask(mask_pred, masks)
            loss = loss_mask
            
            if heatmaps is not None:
                loss_fovea = criterion_fovea(fovea_pred, heatmaps)
                loss += loss_fovea
            
            loss.backward()
            optimizer.step()
            train_loss += loss.item()
        
        train_loss /= len(train_loader)
        
        # 验证阶段
        model.eval()
        val_loss = 0
        with torch.no_grad():
            for images, masks, heatmaps, _ in val_loader:
                images = images.to(device)
                masks = masks.to(device)
                if heatmaps is not None:
                    heatmaps = heatmaps.to(device)
                
                mask_pred, fovea_pred = model(images)
                loss = criterion_mask(mask_pred, masks)
                
                if heatmaps is not None:
                    loss += criterion_fovea(fovea_pred, heatmaps)
                
                val_loss += loss.item()
        
        val_loss /= len(val_loader)
        print(f'Epoch [{epoch+1}/{num_epochs}], Train Loss: {train_loss:.4f}, Val Loss: {val_loss:.4f}')
        
        if val_loss < best_val_loss:
            best_val_loss = val_loss
            torch.save(model.state_dict(), 'best_model.pth')

def pad_image(image, target_size):
    """填充图像到目标大小"""
    h, w = image.shape[:2]
    target_h, target_w = target_size
    
    # 计算需要的填充
    pad_h = max(0, target_h - h)
    pad_w = max(0, target_w - w)
    
    # 计算上下左右填充量
    top = pad_h // 2
    bottom = pad_h - top
    left = pad_w // 2
    right = pad_w - left
    
    # 填充图像
    padded_image = cv2.copyMakeBorder(image, top, bottom, left, right,
                                     cv2.BORDER_CONSTANT, value=(0,0,0))
    return padded_image

def predict_test_images(model, test_dir, device, output_csv):
    model.eval()
    os.makedirs('predictions', exist_ok=True)
    
    results = []
    
    for img_name in sorted(os.listdir(test_dir)):
        if img_name.endswith('.jpg'):
            img_path = os.path.join(test_dir, img_name)
            image = cv2.imread(img_path)
            original_size = image.shape[:2]
            
            # 图像增强
            enhanced_image = enhance_image(image)
            
            # 多尺度预测
            scales = [0.75, 1.0, 1.25]
            fovea_heatmaps = []
            
            for scale in scales:
                # 计算缩放尺寸
                scaled_size = (int(256 * scale), int(256 * scale))
                image_scaled = cv2.resize(enhanced_image, scaled_size)
                
                # 安全地填充到256x256
                if scale != 1.0:
                    image_scaled = pad_image(image_scaled, (256, 256))
                
                # 确保图像大小正确
                image_scaled = cv2.resize(image_scaled, (256, 256))
                
                # 转换为张量
                image_tensor = torch.FloatTensor(image_scaled.transpose(2, 0, 1) / 255.0).unsqueeze(0)
                
                # 预测
                with torch.no_grad():
                    image_tensor = image_tensor.to(device)
                    _, fovea_heatmap = model(image_tensor)
                    fovea_heatmap = fovea_heatmap.cpu().numpy()[0, 0]
                    fovea_heatmaps.append(fovea_heatmap)
            
            # 融合多尺度预测结果
            fovea_heatmap = np.mean(fovea_heatmaps, axis=0)
            
            # 使用高斯滤波平滑热图
            fovea_heatmap = cv2.GaussianBlur(fovea_heatmap, (7, 7), 0)
            
            # 找到热图中的最大值位置
            y, x = np.unravel_index(np.argmax(fovea_heatmap), fovea_heatmap.shape)
            
            # 将坐标转换回原始图像大小
            original_x = int(x * original_size[1] / 256)
            original_y = int(y * original_size[0] / 256)
            
            # 验证预测结果
            if validate_prediction(original_x, original_y, original_size):
                results.extend([
                    {'ImageID': f'{img_name.split(".")[0]}_Fovea_X', 'value': original_x},
                    {'ImageID': f'{img_name.split(".")[0]}_Fovea_Y', 'value': original_y}
                ])
            else:
                backup_x, backup_y = get_backup_prediction(image)
                results.extend([
                    {'ImageID': f'{img_name.split(".")[0]}_Fovea_X', 'value': backup_x},
                    {'ImageID': f'{img_name.split(".")[0]}_Fovea_Y', 'value': backup_y}
                ])
            
            # 可视化结果
            plt.figure(figsize=(12, 4))
            plt.subplot(131)
            plt.imshow(cv2.cvtColor(image, cv2.COLOR_BGR2RGB))
            plt.plot(original_x, original_y, 'r+', markersize=10)
            plt.title('Predicted Fovea Location')
            
            plt.subplot(132)
            plt.imshow(fovea_heatmap, cmap='jet')
            plt.title('Fovea Heatmap')
            
            plt.subplot(133)
            plt.imshow(cv2.cvtColor(image, cv2.COLOR_BGR2RGB))
            plt.imshow(cv2.resize(fovea_heatmap, (original_size[1], original_size[0])), 
                      alpha=0.3, cmap='jet')
            plt.plot(original_x, original_y, 'r+', markersize=10)
            plt.title('Overlay')
            
            plt.savefig(f'predictions/{img_name}_prediction.png')
            plt.close()
    
    # 保存预测结果到CSV
    df = pd.DataFrame(results)
    df.to_csv(output_csv, index=False)


def enhance_image(image):
    """图像增强处理"""
    # 对比度增强
    lab = cv2.cvtColor(image, cv2.COLOR_BGR2LAB)
    l, a, b = cv2.split(lab)
    clahe = cv2.createCLAHE(clipLimit=3.0, tileGridSize=(8,8))
    l = clahe.apply(l)
    enhanced_lab = cv2.merge((l,a,b))
    enhanced_image = cv2.cvtColor(enhanced_lab, cv2.COLOR_LAB2BGR)
    
    # 锐化
    kernel = np.array([[-1,-1,-1],
                      [-1, 9,-1],
                      [-1,-1,-1]])
    sharpened = cv2.filter2D(enhanced_image, -1, kernel)
    
    # 去噪
    denoised = cv2.fastNlMeansDenoisingColored(sharpened)
    
    return denoised

def validate_prediction(x, y, image_size):
    """验证预测的坐标是否在合理范围内"""
    # 检查坐标是否在图像范围内
    if x < 0 or x >= image_size[1] or y < 0 or y >= image_size[0]:
        return False
    
    # 检查坐标是否在图像中心区域附近
    center_x = image_size[1] // 2
    center_y = image_size[0] // 2
    max_distance = min(image_size) * 0.4  # 假设中心凹不会距离中心太远
    
    distance = np.sqrt((x - center_x)**2 + (y - center_y)**2)
    return distance <= max_distance

def get_backup_prediction(image):
    """当主要预测失败时的备选预测策略"""
    # 这里可以实现一个基于传统图像处理的备选方法
    # 例如：使用图像的中心点或基于解剖学特征的启发式方法
    
    # 简单示例：返回图像中心点
    h, w = image.shape[:2]
    return w//2, h//2

def visualize_prediction(original_image, enhanced_image, heatmap, pred_x, pred_y, img_name):
    """可视化预测结果"""
    plt.figure(figsize=(15, 5))
    
    # 原始图像
    plt.subplot(141)
    plt.imshow(cv2.cvtColor(original_image, cv2.COLOR_BGR2RGB))
    plt.plot(pred_x, pred_y, 'r+', markersize=10)
    plt.title('Original Image')
    
    # 增强后的图像
    plt.subplot(142)
    plt.imshow(cv2.cvtColor(enhanced_image, cv2.COLOR_BGR2RGB))
    plt.plot(pred_x, pred_y, 'r+', markersize=10)
    plt.title('Enhanced Image')
    
    # 热图
    plt.subplot(143)
    plt.imshow(heatmap, cmap='jet')
    plt.title('Fovea Heatmap')
    
    # 叠加显示
    plt.subplot(144)
    plt.imshow(cv2.cvtColor(original_image, cv2.COLOR_BGR2RGB))
    plt.imshow(cv2.resize(heatmap, (original_image.shape[1], original_image.shape[0])), 
              alpha=0.3, cmap='jet')
    plt.plot(pred_x, pred_y, 'r+', markersize=10)
    plt.title('Overlay')
    
    plt.savefig(f'predictions/{img_name}_prediction.png')
    plt.close()


def main():
    # 设置路径
    train_image_dir = 'C:/code/vcpython/ML_design_1/task1/detection/train'
    train_xml_dir = 'C:/code/vcpython/ML_design_1/task1/detection/train_location'
    fovea_csv = 'C:/code/vcpython/ML_design_1/task1/detection/fovea_localization_train_GT.csv'  # 添加中心凹坐标文件路径
    test_image_dir = 'C:/code/vcpython/ML_design_1/task1/detection/test'
    output_csv = 'fovea_predictions.csv'
    
    torch.backends.cudnn.benchmark = True  # 加速卷积运算
    torch.backends.cuda.matmul.allow_tf32 = True  # 允许使用TF32
    torch.backends.cudnn.allow_tf32 = True  # 允许cudnn使用TF32

    # 检查CUDA是否可用
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    print(f'使用设备: {device}')
    if torch.cuda.is_available():
        print(f'GPU名称: {torch.cuda.get_device_name(0)}')
        print(f'GPU显存总量: {torch.cuda.get_device_properties(0).total_memory / 1024 ** 2:.0f}MB')
    
    # 创建数据集
    dataset = FoveaDataset(train_image_dir, train_xml_dir, fovea_csv)
    train_size = int(0.8 * len(dataset))
    val_size = len(dataset) - train_size
    train_dataset, val_dataset = torch.utils.data.random_split(dataset, [train_size, val_size])
    
    # 创建数据加载器
    batch_size = 16  # 或者更大，具体取决于你的GPU内存
    
    train_loader = DataLoader(
        train_dataset, 
        batch_size=batch_size, 
        shuffle=True,
        num_workers=0,
        pin_memory=True
    )
    val_loader = DataLoader(
        val_dataset, 
        batch_size=batch_size,
        num_workers=0,
        pin_memory=True
    )
    
    # 设置设备
    device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
    print(torch.cuda.is_available())
    device = torch.device('cuda:0')
    # 创建模型
    model = UNetWithFovea().to(device)
    
    print("begin training")
    # 训练模型
    # train_model(model, train_loader, val_loader, device)
    
    # 加载最佳模型进行预测
    model.load_state_dict(torch.load('best_model.pth'))
    predict_test_images(model, test_image_dir, device, output_csv)

if __name__ == '__main__':
    main()


使用设备: cuda
GPU名称: NVIDIA GeForce RTX 3050 Laptop GPU
GPU显存总量: 4096MB
预处理图像...


100%|██████████| 80/80 [00:05<00:00, 15.36it/s]
  model.load_state_dict(torch.load('best_model.pth'))


True
begin training


查看GPU

In [18]:
import torch

print("CUDA是否可用:", torch.cuda.is_available())
if torch.cuda.is_available():
    print("当前CUDA设备号:", torch.cuda.current_device())
    print("CUDA设备名称:", torch.cuda.get_device_name(0))


CUDA是否可用: True
当前CUDA设备号: 0
CUDA设备名称: NVIDIA GeForce RTX 3050 Laptop GPU


In [17]:
import torch
import os
print(os.environ.get('CUDA_VISIBLE_DEVICES'))



None


查看模型参数

In [4]:
import torch

checkpoint = torch.load('best_model.pth')
for name, param in checkpoint.items():
    print(name, param.shape)


enc1.double_conv.0.weight torch.Size([64, 3, 3, 3])
enc1.double_conv.0.bias torch.Size([64])
enc1.double_conv.1.weight torch.Size([64])
enc1.double_conv.1.bias torch.Size([64])
enc1.double_conv.1.running_mean torch.Size([64])
enc1.double_conv.1.running_var torch.Size([64])
enc1.double_conv.1.num_batches_tracked torch.Size([])
enc1.double_conv.3.weight torch.Size([64, 64, 3, 3])
enc1.double_conv.3.bias torch.Size([64])
enc1.double_conv.4.weight torch.Size([64])
enc1.double_conv.4.bias torch.Size([64])
enc1.double_conv.4.running_mean torch.Size([64])
enc1.double_conv.4.running_var torch.Size([64])
enc1.double_conv.4.num_batches_tracked torch.Size([])
enc2.double_conv.0.weight torch.Size([128, 64, 3, 3])
enc2.double_conv.0.bias torch.Size([128])
enc2.double_conv.1.weight torch.Size([128])
enc2.double_conv.1.bias torch.Size([128])
enc2.double_conv.1.running_mean torch.Size([128])
enc2.double_conv.1.running_var torch.Size([128])
enc2.double_conv.1.num_batches_tracked torch.Size([])
enc2.dou

  checkpoint = torch.load('best_model.pth')
