In [4]:
import os
import cv2
import numpy as np
import xml.etree.ElementTree as ET
from sklearn.model_selection import train_test_split
import torch
import pandas as pd 
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
import matplotlib.pyplot as plt


U-Net 训练

In [6]:

# 1. 自定义数据集类
class FoveaDataset(Dataset):
    """Dataset pairing each ``.jpg`` image with a bounding-box mask and a fovea heatmap.

    For every ``<name>.jpg`` in ``image_dir`` the sample is built from the image
    itself, the annotation ``<name>.xml`` in ``xml_dir`` (bounding box -> binary
    mask) and, when ``fovea_csv`` is given, the row of that CSV whose ``data``
    column equals the integer file stem (``Fovea_X`` / ``Fovea_Y`` columns ->
    Gaussian heatmap).

    Args:
        image_dir: Directory containing the ``.jpg`` images.
        xml_dir: Directory containing one ``.xml`` annotation per image.
        fovea_csv: Optional path to a CSV with ``data``, ``Fovea_X`` and
            ``Fovea_Y`` columns.
        transform: Stored on the instance; ``__getitem__`` does not apply it.
    """

    def __init__(self, image_dir, xml_dir, fovea_csv=None, transform=None):
        self.image_dir = image_dir
        self.xml_dir = xml_dir
        self.transform = transform
        # os.listdir returns entries in arbitrary order; sorting keeps the
        # index -> file mapping stable across runs and machines.
        self.images = sorted(f for f in os.listdir(image_dir) if f.endswith('.jpg'))

        # Load the fovea coordinates table, if one was supplied.
        if fovea_csv is not None:
            self.fovea_coords = pd.read_csv(fovea_csv)
        else:
            self.fovea_coords = None

    def __len__(self):
        """Return the number of ``.jpg`` images found in ``image_dir``."""
        return len(self.images)

    def get_fovea_coords(self, img_name):
        """Look up the fovea coordinates for ``img_name``.

        The part of the file name before the first ``.`` is parsed as an
        integer and matched against the ``data`` column of the CSV.

        Returns:
            ``(Fovea_X, Fovea_Y)`` from the first matching row, or ``None`` when
            no CSV was loaded or no row matches.

        Raises:
            ValueError: If a CSV is loaded and the file stem is not an integer.
        """
        if self.fovea_coords is not None:
            img_id = int(img_name.split('.')[0])
            coords = self.fovea_coords[self.fovea_coords['data'] == img_id]
            if not coords.empty:
                return (coords['Fovea_X'].values[0], coords['Fovea_Y'].values[0])
        return None

    def create_fovea_heatmap(self, size, coords, original_size, sigma=5):
        """Create a Gaussian heatmap centred on the fovea position.

        Args:
            size: ``(width, height)`` of the heatmap to produce.
            coords: ``(x, y)`` fovea position in original-image pixels.
            original_size: ``(height, width)`` of the original image, i.e. the
                first two entries of ``image.shape``.
            sigma: Standard deviation of the Gaussian, in heatmap pixels.

        Returns:
            Array of shape ``(height, width)`` whose peak value is 1.0.
        """
        x, y = coords
        out_w, out_h = size
        orig_h, orig_w = original_size[:2]

        # Rescale the coordinates to the resized image, then clamp them so the
        # peak always lies inside the heatmap.
        x = min(max(int(x * out_w / orig_w), 0), out_w - 1)
        y = min(max(int(y * out_h / orig_h), 0), out_h - 1)

        x_grid, y_grid = np.meshgrid(np.arange(out_w), np.arange(out_h))
        heatmap = np.exp(-((x_grid - x) ** 2 + (y_grid - y) ** 2) / (2 * sigma ** 2))
        # Normalise; the centre pixel is exp(0), so the maximum is already 1.
        return heatmap / heatmap.max()

    def parse_xml(self, xml_path):
        """Read the image size and the first object's bounding box from an XML file.

        Returns:
            ``((width, height), (xmin, ymin, xmax, ymax))`` as integers, taken
            from the ``size`` element and the first ``object/bndbox`` element.
        """
        tree = ET.parse(xml_path)
        root = tree.getroot()

        size = root.find('size')
        width = int(size.find('width').text)
        height = int(size.find('height').text)

        obj = root.find('object')
        bbox = obj.find('bndbox')
        xmin = int(bbox.find('xmin').text)
        ymin = int(bbox.find('ymin').text)
        xmax = int(bbox.find('xmax').text)
        ymax = int(bbox.find('ymax').text)

        return (width, height), (xmin, ymin, xmax, ymax)

    def create_mask(self, img_shape, bbox):
        """Rasterise ``bbox`` into a float32 mask of shape ``img_shape[:2]``.

        The box is clipped to the image first. Without clipping, a negative
        coordinate would be interpreted by NumPy as an index from the end and
        fill the wrong region (or nothing at all).
        """
        height, width = img_shape[:2]
        mask = np.zeros((height, width), dtype=np.float32)
        xmin, ymin, xmax, ymax = bbox
        xmin, ymin = max(xmin, 0), max(ymin, 0)
        xmax, ymax = min(xmax, width), min(ymax, height)
        if xmax > xmin and ymax > ymin:
            mask[ymin:ymax, xmin:xmax] = 1
        return mask

    def __getitem__(self, idx):
        """Build the sample at position ``idx``.

        Returns:
            ``(image, mask, heatmap, coords)`` where ``image`` is a channel-first
            float tensor of the 256x256 resized image scaled to [0, 1] (channel
            order as returned by ``cv2.imread``), ``mask`` and ``heatmap`` are
            ``(1, 256, 256)`` float tensors, and ``coords`` is the original
            ``(x, y)`` fovea position. ``heatmap`` and ``coords`` are ``None``
            when no coordinates are available for the image.

        Raises:
            FileNotFoundError: If the image cannot be read.
        """
        img_name = self.images[idx]
        img_path = os.path.join(self.image_dir, img_name)
        xml_path = os.path.join(self.xml_dir, img_name.replace('.jpg', '.xml'))

        # cv2.imread signals failure by returning None instead of raising.
        image = cv2.imread(img_path)
        if image is None:
            raise FileNotFoundError(f'Could not read image: {img_path}')
        orig_size = image.shape[:2]

        # The mask is drawn at the original resolution and resized with the image.
        _, bbox = self.parse_xml(xml_path)
        mask = self.create_mask(orig_size, bbox)

        coords = self.get_fovea_coords(img_name)

        image = cv2.resize(image, (256, 256))
        mask = cv2.resize(mask, (256, 256))

        # HWC uint8 -> CHW float in [0, 1].
        image = image.transpose(2, 0, 1) / 255.0
        image = torch.FloatTensor(image)
        mask = torch.FloatTensor(mask).unsqueeze(0)

        if coords is not None:
            heatmap = self.create_fovea_heatmap((256, 256), coords, orig_size)
            heatmap = torch.FloatTensor(heatmap).unsqueeze(0)
            return image, mask, heatmap, coords

        # NOTE: PyTorch's default collate function cannot batch ``None``; a
        # DataLoader over samples without coordinates needs a custom collate_fn.
        return image, mask, None, None




# 2. 定义U-Net模型
class DoubleConv(nn.Module):
    """Two consecutive 3x3 convolution stages, each followed by BatchNorm and ReLU.

    The first stage maps ``in_channels`` to ``out_channels``; the second keeps
    ``out_channels``. Padding of 1 leaves the spatial size unchanged.
    """

    def __init__(self, in_channels, out_channels):
        super().__init__()
        stages = []
        for stage_in in (in_channels, out_channels):
            stages.append(nn.Conv2d(stage_in, out_channels, kernel_size=3, padding=1))
            stages.append(nn.BatchNorm2d(out_channels))
            stages.append(nn.ReLU(inplace=True))
        # The attribute name and the layer order determine the state_dict keys
        # (double_conv.0, .1, .3, .4), so both are kept as they were.
        self.double_conv = nn.Sequential(*stages)

    def forward(self, x):
        """Apply both conv/BN/ReLU stages to ``x``."""
        features = self.double_conv(x)
        return features
    


class UNetWithFovea(nn.Module):
    """U-Net style encoder/decoder with two output heads.

    A four-level encoder (three 2x max-pool steps) is followed by a three-level
    decoder with skip connections. The shared decoder output feeds:

    * a segmentation head (1x1 conv + sigmoid), and
    * a fovea head (3x3 conv, ReLU, 1x1 conv, sigmoid).

    Both heads return a single-channel map with the same spatial size as the
    input. Inputs whose height or width is not divisible by 8 are supported:
    decoder features are resized to match their skip connection when needed.
    """

    def __init__(self):
        super().__init__()
        # Encoder
        self.enc1 = DoubleConv(3, 64)
        self.enc2 = DoubleConv(64, 128)
        self.enc3 = DoubleConv(128, 256)
        self.enc4 = DoubleConv(256, 512)

        self.pool = nn.MaxPool2d(2)
        self.upsample = nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True)

        # Decoder; input channels = upsampled features + skip connection
        self.dec3 = DoubleConv(512 + 256, 256)
        self.dec2 = DoubleConv(256 + 128, 128)
        self.dec1 = DoubleConv(128 + 64, 64)

        # Segmentation head
        self.final_conv_seg = nn.Conv2d(64, 1, kernel_size=1)

        # Fovea heatmap head
        self.final_conv_fovea = nn.Sequential(
            nn.Conv2d(64, 32, kernel_size=3, padding=1),
            nn.ReLU(inplace=True),
            nn.Conv2d(32, 1, kernel_size=1),
            nn.Sigmoid()
        )

    def _upsample_like(self, x, reference):
        """Upsample ``x`` by 2 and make its spatial size equal to ``reference``'s.

        Max-pooling floors odd sizes, so doubling can come out one pixel short
        of the skip connection. Only in that case is ``x`` interpolated directly
        to the reference size; otherwise the plain 2x upsample is returned
        unchanged.
        """
        upsampled = self.upsample(x)
        if upsampled.shape[-2:] != reference.shape[-2:]:
            upsampled = nn.functional.interpolate(
                x, size=reference.shape[-2:], mode='bilinear', align_corners=True
            )
        return upsampled

    def forward(self, x):
        """Run the network.

        Args:
            x: Input batch with 3 channels, shape ``(N, 3, H, W)``.

        Returns:
            ``(mask, fovea_heatmap)``, both of shape ``(N, 1, H, W)`` with
            values in [0, 1] (sigmoid outputs).
        """
        # Encoder
        e1 = self.enc1(x)
        e2 = self.enc2(self.pool(e1))
        e3 = self.enc3(self.pool(e2))
        e4 = self.enc4(self.pool(e3))

        # Decoder with skip connections
        d3 = self.dec3(torch.cat([self._upsample_like(e4, e3), e3], dim=1))
        d2 = self.dec2(torch.cat([self._upsample_like(d3, e2), e2], dim=1))
        d1 = self.dec1(torch.cat([self._upsample_like(d2, e1), e1], dim=1))

        # Two output heads sharing the decoder features
        mask = torch.sigmoid(self.final_conv_seg(d1))
        fovea_heatmap = self.final_conv_fovea(d1)

        return mask, fovea_heatmap


def _joint_loss(model, batch, device, criterion_mask, criterion_fovea):
    """Return mask loss plus (when a heatmap target is present) fovea loss for one batch."""
    images, masks, heatmaps, _ = batch
    images = images.to(device)
    masks = masks.to(device)

    mask_pred, fovea_pred = model(images)
    loss = criterion_mask(mask_pred, masks)
    if heatmaps is not None:
        # Out-of-place add: the mask loss tensor itself is left untouched.
        loss = loss + criterion_fovea(fovea_pred, heatmaps.to(device))
    return loss


def train_model(model, train_loader, val_loader, device, num_epochs=100,
                lr=1e-4, checkpoint_path='best_model.pth'):
    """Train ``model`` on the joint segmentation + fovea-heatmap objective.

    The loss per batch is binary cross-entropy on the mask output plus mean
    squared error on the fovea heatmap output. After every epoch the model is
    evaluated on ``val_loader`` and its ``state_dict`` is written to
    ``checkpoint_path`` whenever the validation loss improves.

    Args:
        model: Module returning ``(mask_pred, fovea_pred)``, with ``mask_pred``
            already in (0, 1) as required by ``nn.BCELoss``.
        train_loader: Iterable of ``(images, masks, heatmaps, coords)`` batches.
        val_loader: Iterable of batches in the same format.
        device: Device the batches are moved to.
        num_epochs: Number of passes over ``train_loader``.
        lr: Learning rate for the Adam optimizer.
        checkpoint_path: Where the best ``state_dict`` is saved.

    Returns:
        The lowest mean validation loss observed.

    Raises:
        ValueError: If either loader has no batches. This is checked up front
            rather than failing with a division by zero after an epoch of work.
    """
    if len(train_loader) == 0:
        raise ValueError('train_loader is empty')
    if len(val_loader) == 0:
        raise ValueError('val_loader is empty')

    criterion_mask = nn.BCELoss()
    criterion_fovea = nn.MSELoss()
    optimizer = optim.Adam(model.parameters(), lr=lr)

    best_val_loss = float('inf')
    for epoch in range(num_epochs):
        # Training phase
        model.train()
        train_loss = 0
        for batch in train_loader:
            optimizer.zero_grad()
            loss = _joint_loss(model, batch, device, criterion_mask, criterion_fovea)
            loss.backward()
            optimizer.step()
            train_loss += loss.item()

        train_loss /= len(train_loader)

        # Validation phase
        model.eval()
        val_loss = 0
        with torch.no_grad():
            for batch in val_loader:
                loss = _joint_loss(model, batch, device, criterion_mask, criterion_fovea)
                val_loss += loss.item()

        val_loss /= len(val_loader)
        print(f'Epoch [{epoch+1}/{num_epochs}], Train Loss: {train_loss:.4f}, Val Loss: {val_loss:.4f}')

        # Keep only the weights of the best epoch so far.
        if val_loss < best_val_loss:
            best_val_loss = val_loss
            torch.save(model.state_dict(), checkpoint_path)

    return best_val_loss

def predict_test_images(model, test_dir, device, output_csv, input_size=256,
                        output_dir='predictions'):
    """Predict the fovea position for every ``.jpg`` in ``test_dir``.

    Each image is resized to ``input_size`` x ``input_size``, passed through
    ``model``, and the arg-max of the returned fovea heatmap is mapped back to
    original-image pixel coordinates. A three-panel figure (prediction, heatmap,
    overlay) is saved per image into ``output_dir``.

    Args:
        model: Module returning ``(mask, fovea_heatmap)``; only the heatmap is used.
        test_dir: Directory containing the test images.
        device: Device the input tensor is moved to.
        output_csv: Path of the CSV to write. It has columns ``ImageID`` and
            ``value`` and two rows per image: ``<id>_Fovea_X`` and ``<id>_Fovea_Y``.
        input_size: Side length the image is resized to before inference.
        output_dir: Directory for the visualisation PNGs (created if missing).

    Raises:
        IOError: If an image file cannot be read.
    """
    model.eval()
    os.makedirs(output_dir, exist_ok=True)

    results = []
    for img_name in sorted(os.listdir(test_dir)):
        if not img_name.endswith('.jpg'):
            continue

        img_path = os.path.join(test_dir, img_name)
        # cv2.imread signals failure by returning None instead of raising.
        image = cv2.imread(img_path)
        if image is None:
            raise IOError(f'Could not read image: {img_path}')
        orig_h, orig_w = image.shape[:2]

        # Preprocess: resize, HWC -> CHW, scale to [0, 1], add batch dimension.
        image_resized = cv2.resize(image, (input_size, input_size))
        image_tensor = torch.FloatTensor(image_resized.transpose(2, 0, 1) / 255.0).unsqueeze(0)

        with torch.no_grad():
            _, fovea_heatmap = model(image_tensor.to(device))
        fovea_heatmap = fovea_heatmap.cpu().numpy()[0, 0]

        # Location of the heatmap maximum, as (row, col).
        y, x = np.unravel_index(np.argmax(fovea_heatmap), fovea_heatmap.shape)

        # Map heatmap coordinates back to the original image resolution.
        map_h, map_w = fovea_heatmap.shape
        original_x = int(x * orig_w / map_w)
        original_y = int(y * orig_h / map_h)

        img_id = img_name.split('.')[0]
        results.extend([
            {'ImageID': f'{img_id}_Fovea_X', 'value': original_x},
            {'ImageID': f'{img_id}_Fovea_Y', 'value': original_y}
        ])

        # Visualisation: prediction, raw heatmap, and heatmap overlay.
        image_rgb = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
        fig, (ax_pred, ax_map, ax_overlay) = plt.subplots(1, 3, figsize=(12, 4))
        ax_pred.imshow(image_rgb)
        ax_pred.plot(original_x, original_y, 'r+', markersize=10)
        ax_pred.set_title('Predicted Fovea Location')
        ax_map.imshow(fovea_heatmap, cmap='jet')
        ax_map.set_title('Fovea Heatmap')
        ax_overlay.imshow(image_rgb)
        ax_overlay.imshow(cv2.resize(fovea_heatmap, (orig_w, orig_h)),
                          alpha=0.3, cmap='jet')
        ax_overlay.plot(original_x, original_y, 'r+', markersize=10)
        ax_overlay.set_title('Overlay')
        fig.savefig(os.path.join(output_dir, f'{img_name}_prediction.png'))
        # Close explicitly so figures do not accumulate across the loop.
        plt.close(fig)

    # Explicit columns keep the CSV header even when no image was processed.
    df = pd.DataFrame(results, columns=['ImageID', 'value'])
    df.to_csv(output_csv, index=False)

def main():
    """Train the fovea U-Net on the training set and predict on the test set.

    Steps: build the dataset, split it 80/20 into train/validation, train the
    model (the best weights are written to ``best_model.pth``), reload those
    weights and write test-set predictions to ``fovea_predictions.csv``.
    """
    # Paths. NOTE: these are absolute, machine-specific locations; adjust them
    # before running elsewhere.
    train_image_dir = 'C:/code/vcpython/ML_design_1/task1/detection/train'
    train_xml_dir = 'C:/code/vcpython/ML_design_1/task1/detection/train_location'
    fovea_csv = 'C:/code/vcpython/ML_design_1/task1/detection/fovea_localization_train_GT.csv'  # fovea coordinates file
    test_image_dir = 'C:/code/vcpython/ML_design_1/task1/detection/test'
    output_csv = 'fovea_predictions.csv'

    # Build the dataset and split it 80/20.
    dataset = FoveaDataset(train_image_dir, train_xml_dir, fovea_csv)
    train_size = int(0.8 * len(dataset))
    val_size = len(dataset) - train_size
    # A seeded generator makes the train/validation split reproducible.
    split_rng = torch.Generator().manual_seed(42)
    train_dataset, val_dataset = torch.utils.data.random_split(
        dataset, [train_size, val_size], generator=split_rng
    )

    # Data loaders
    train_loader = DataLoader(train_dataset, batch_size=4, shuffle=True)
    val_loader = DataLoader(val_dataset, batch_size=4)

    # Device
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    # Model
    model = UNetWithFovea().to(device)

    # Training
    train_model(model, train_loader, val_loader, device)

    # Reload the best weights for prediction. map_location keeps this working
    # when the checkpoint was saved on a different device; weights_only=True
    # restricts unpickling to tensors and plain containers.
    state_dict = torch.load('best_model.pth', map_location=device, weights_only=True)
    model.load_state_dict(state_dict)
    predict_test_images(model, test_image_dir, device, output_csv)

if __name__ == '__main__':
    main()


Epoch [1/100], Train Loss: 0.8387, Val Loss: 0.9508
Epoch [2/100], Train Loss: 0.6082, Val Loss: 0.6489
Epoch [3/100], Train Loss: 0.5293, Val Loss: 0.4973
Epoch [4/100], Train Loss: 0.4850, Val Loss: 0.4938
Epoch [5/100], Train Loss: 0.4405, Val Loss: 0.4402
Epoch [6/100], Train Loss: 0.4196, Val Loss: 0.4623
Epoch [7/100], Train Loss: 0.3962, Val Loss: 0.3860
Epoch [8/100], Train Loss: 0.3858, Val Loss: 0.4452
Epoch [9/100], Train Loss: 0.3641, Val Loss: 0.3682
Epoch [10/100], Train Loss: 0.3560, Val Loss: 0.6349
Epoch [11/100], Train Loss: 0.3459, Val Loss: 0.3396
Epoch [12/100], Train Loss: 0.3323, Val Loss: 0.3315
Epoch [13/100], Train Loss: 0.3271, Val Loss: 0.3168
Epoch [14/100], Train Loss: 0.3156, Val Loss: 0.3208
Epoch [15/100], Train Loss: 0.3102, Val Loss: 0.3244
Epoch [16/100], Train Loss: 0.2949, Val Loss: 0.3129
Epoch [17/100], Train Loss: 0.2909, Val Loss: 0.3020
Epoch [18/100], Train Loss: 0.2813, Val Loss: 0.2845
Epoch [19/100], Train Loss: 0.2741, Val Loss: 0.2804
Ep

  model.load_state_dict(torch.load('best_model.pth'))


查看模型参数

In [4]:
import torch

# Load the saved state_dict onto the CPU so inspection also works on a machine
# without the GPU it was trained on; weights_only=True restricts unpickling to
# tensors and plain containers.
checkpoint = torch.load('best_model.pth', map_location='cpu', weights_only=True)
# Print every parameter/buffer name together with its tensor shape.
for name, param in checkpoint.items():
    print(name, param.shape)


enc1.double_conv.0.weight torch.Size([64, 3, 3, 3])
enc1.double_conv.0.bias torch.Size([64])
enc1.double_conv.1.weight torch.Size([64])
enc1.double_conv.1.bias torch.Size([64])
enc1.double_conv.1.running_mean torch.Size([64])
enc1.double_conv.1.running_var torch.Size([64])
enc1.double_conv.1.num_batches_tracked torch.Size([])
enc1.double_conv.3.weight torch.Size([64, 64, 3, 3])
enc1.double_conv.3.bias torch.Size([64])
enc1.double_conv.4.weight torch.Size([64])
enc1.double_conv.4.bias torch.Size([64])
enc1.double_conv.4.running_mean torch.Size([64])
enc1.double_conv.4.running_var torch.Size([64])
enc1.double_conv.4.num_batches_tracked torch.Size([])
enc2.double_conv.0.weight torch.Size([128, 64, 3, 3])
enc2.double_conv.0.bias torch.Size([128])
enc2.double_conv.1.weight torch.Size([128])
enc2.double_conv.1.bias torch.Size([128])
enc2.double_conv.1.running_mean torch.Size([128])
enc2.double_conv.1.running_var torch.Size([128])
enc2.double_conv.1.num_batches_tracked torch.Size([])
enc2.dou

  checkpoint = torch.load('best_model.pth')
