In [1]:
import os
import cv2
import numpy as np
import xml.etree.ElementTree as ET
from sklearn.model_selection import train_test_split
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
import matplotlib.pyplot as plt


Unet训练

In [None]:

# 1. 自定义数据集类
class FoveaDataset(Dataset):
    def __init__(self, image_dir, xml_dir, transform=None):
        self.image_dir = image_dir
        self.xml_dir = xml_dir
        self.transform = transform
        self.images = [f for f in os.listdir(image_dir) if f.endswith('.jpg')]
        
            
    def get_fovea_coords(self, img_name):
        if self.fovea_coords is not None:
            img_id = int(img_name.split('.')[0])
            coords = self.fovea_coords[self.fovea_coords['data'] == img_id]
            if not coords.empty:
                return (coords['Fovea_X'].values[0], coords['Fovea_Y'].values[0])
        return None
        
    def __len__(self):
        return len(self.images)
    
    def parse_xml(self, xml_path):
        tree = ET.parse(xml_path)
        root = tree.getroot()
        
        size = root.find('size')
        width = int(size.find('width').text)
        height = int(size.find('height').text)
        
        obj = root.find('object')
        bbox = obj.find('bndbox')
        xmin = int(bbox.find('xmin').text)
        ymin = int(bbox.find('ymin').text)
        xmax = int(bbox.find('xmax').text)
        ymax = int(bbox.find('ymax').text)
        
        return (width, height), (xmin, ymin, xmax, ymax)
    
    def create_mask(self, img_shape, bbox):
        mask = np.zeros(img_shape[:2], dtype=np.float32)
        xmin, ymin, xmax, ymax = bbox
        mask[ymin:ymax, xmin:xmax] = 1
        return mask
    
    def __getitem__(self, idx):
        img_name = self.images[idx]
        img_path = os.path.join(self.image_dir, img_name)
        xml_path = os.path.join(self.xml_dir, img_name.replace('.jpg', '.xml'))
        
        # 读取图像
        image = cv2.imread(img_path)
        orig_size = image.shape[:2]
        
        # 解析XML
        _, bbox = self.parse_xml(xml_path)
        mask = self.create_mask(orig_size, bbox)
        
        # 调整大小
        image = cv2.resize(image, (256, 256))
        mask = cv2.resize(mask, (256, 256))
        
        # 转换为张量
        image = image.transpose(2, 0, 1) / 255.0  # 归一化并转换通道顺序
        image = torch.FloatTensor(image)
        mask = torch.FloatTensor(mask).unsqueeze(0)
        
        return image, mask


# 2. 定义U-Net模型
class DoubleConv(nn.Module):
    def __init__(self, in_channels, out_channels):
        super().__init__()
        self.double_conv = nn.Sequential(
            nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(inplace=True),
            nn.Conv2d(out_channels, out_channels, kernel_size=3, padding=1),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(inplace=True)
        )

    def forward(self, x):
        return self.double_conv(x)
    


class UNet(nn.Module):
    def __init__(self):
        super().__init__()
        self.enc1 = DoubleConv(3, 64)
        self.enc2 = DoubleConv(64, 128)
        self.enc3 = DoubleConv(128, 256)
        self.enc4 = DoubleConv(256, 512)
        
        self.pool = nn.MaxPool2d(2)
        self.upsample = nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True)
        
        self.dec3 = DoubleConv(512 + 256, 256)
        self.dec2 = DoubleConv(256 + 128, 128)
        self.dec1 = DoubleConv(128 + 64, 64)
        
        self.final_conv = nn.Conv2d(64, 1, kernel_size=1)
        self.sigmoid = nn.Sigmoid()

    def forward(self, x):
        # Encoder
        e1 = self.enc1(x)
        e2 = self.enc2(self.pool(e1))
        e3 = self.enc3(self.pool(e2))
        e4 = self.enc4(self.pool(e3))
        
        # Decoder
        d3 = self.dec3(torch.cat([self.upsample(e4), e3], dim=1))
        d2 = self.dec2(torch.cat([self.upsample(d3), e2], dim=1))
        d1 = self.dec1(torch.cat([self.upsample(d2), e1], dim=1))
        
        out = self.sigmoid(self.final_conv(d1))
        return out

def train_model(model, train_loader, val_loader, device, num_epochs=100):
    criterion = nn.BCELoss()
    optimizer = optim.Adam(model.parameters(), lr=1e-4)
    
    best_val_loss = float('inf')
    for epoch in range(num_epochs):
        # 训练阶段
        model.train()
        train_loss = 0
        for images, masks in train_loader:
            images = images.to(device)
            masks = masks.to(device)
            
            optimizer.zero_grad()
            outputs = model(images)
            loss = criterion(outputs, masks)
            loss.backward()
            optimizer.step()
            
            train_loss += loss.item()
            
        train_loss /= len(train_loader)
        
        # 验证阶段
        model.eval()
        val_loss = 0
        with torch.no_grad():
            for images, masks in val_loader:
                images = images.to(device)
                masks = masks.to(device)
                outputs = model(images)
                loss = criterion(outputs, masks)
                val_loss += loss.item()
                
        val_loss /= len(val_loader)
        
        print(f'Epoch [{epoch+1}/{num_epochs}], Train Loss: {train_loss:.4f}, Val Loss: {val_loss:.4f}')
        
        # 保存最佳模型
        if val_loss < best_val_loss:
            best_val_loss = val_loss
            torch.save(model.state_dict(), 'best_model.pth')

def predict_test_images(model, test_dir, device):
    model.eval()
    os.makedirs('predictions', exist_ok=True)
    
    for img_name in os.listdir(test_dir):
        if img_name.endswith('.jpg'):
            # 读取和预处理图像
            img_path = os.path.join(test_dir, img_name)
            image = cv2.imread(img_path)
            original_size = image.shape[:2]
            
            # 预处理
            image_resized = cv2.resize(image, (256, 256))
            image_tensor = torch.FloatTensor(image_resized.transpose(2, 0, 1) / 255.0).unsqueeze(0)
            
            # 预测
            with torch.no_grad():
                image_tensor = image_tensor.to(device)
                prediction = model(image_tensor)
                prediction = prediction.cpu().numpy()[0, 0]
            
            # 将预测结果调整回原始大小
            prediction_resized = cv2.resize(prediction, (original_size[1], original_size[0]))
            
            # 可视化结果
            plt.figure(figsize=(12, 4))
            plt.subplot(131)
            plt.imshow(cv2.cvtColor(image, cv2.COLOR_BGR2RGB))
            plt.title('Original Image')
            plt.subplot(132)
            plt.imshow(prediction_resized > 0.5, cmap='gray')
            plt.title('Predicted Mask')
            plt.subplot(133)
            plt.imshow(cv2.cvtColor(image, cv2.COLOR_BGR2RGB))
            plt.imshow(prediction_resized > 0.5, alpha=0.3, cmap='jet')
            plt.title('Overlay')
            plt.savefig(f'predictions/{img_name}_prediction.png')
            plt.close()

def main():
    # 设置路径
    train_image_dir = 'C:/code/vcpython/ML_design_1/task1/detection/train'
    train_xml_dir = 'C:/code/vcpython/ML_design_1/task1/detection/train_location'
    test_image_dir = 'C:/code/vcpython/ML_design_1/task1/detection/test'
    
    # 创建数据集
    dataset = FoveaDataset(train_image_dir, train_xml_dir)
    train_size = int(0.8 * len(dataset))
    val_size = len(dataset) - train_size
    train_dataset, val_dataset = torch.utils.data.random_split(dataset, [train_size, val_size])
    
    # 创建数据加载器
    train_loader = DataLoader(train_dataset, batch_size=4, shuffle=True)
    val_loader = DataLoader(val_dataset, batch_size=4)
    
    # 设置设备
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    
    # 创建模型
    model = UNet().to(device)
    
    # 训练模型
    train_model(model, train_loader, val_loader, device)
    
    # 加载最佳模型进行预测
    model.load_state_dict(torch.load('best_model.pth'))
    predict_test_images(model, test_image_dir, device)

if __name__ == '__main__':
    main()


Epoch [1/100], Train Loss: 0.5683, Val Loss: 0.7335
Epoch [2/100], Train Loss: 0.3949, Val Loss: 0.4950
Epoch [3/100], Train Loss: 0.3586, Val Loss: 0.3711
Epoch [4/100], Train Loss: 0.3511, Val Loss: 0.3614
Epoch [5/100], Train Loss: 0.3375, Val Loss: 0.3526
Epoch [6/100], Train Loss: 0.3200, Val Loss: 0.3303
Epoch [7/100], Train Loss: 0.3068, Val Loss: 0.3127
Epoch [8/100], Train Loss: 0.2958, Val Loss: 0.2912
Epoch [9/100], Train Loss: 0.2840, Val Loss: 0.2775
Epoch [10/100], Train Loss: 0.2761, Val Loss: 0.2753
Epoch [11/100], Train Loss: 0.2683, Val Loss: 0.2848
Epoch [12/100], Train Loss: 0.2611, Val Loss: 0.2560
Epoch [13/100], Train Loss: 0.2529, Val Loss: 0.2535
Epoch [14/100], Train Loss: 0.2447, Val Loss: 0.2446
Epoch [15/100], Train Loss: 0.2396, Val Loss: 0.2377
Epoch [16/100], Train Loss: 0.2314, Val Loss: 0.2288
Epoch [17/100], Train Loss: 0.2243, Val Loss: 0.2308
Epoch [18/100], Train Loss: 0.2173, Val Loss: 0.2221
Epoch [19/100], Train Loss: 0.2114, Val Loss: 0.2128
Ep

  model.load_state_dict(torch.load('best_model.pth'))


In [1]:
print(1)

1
