In [1]:
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms, models
import json
import cv2
import numpy as np
from sklearn.model_selection import train_test_split
import matplotlib.pyplot as plt
from tqdm import tqdm
import os


In [None]:
# Prefer the GPU when CUDA is available; otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device('cuda')
else:
    device = torch.device('cpu')
print(f"使用设备: {device}")

In [3]:
# Custom dataset class
class MelonDataset(Dataset):
    """Image-classification dataset driven by a list of annotation records.

    Each record must provide ``'image_path'`` (joined onto ``data_path`` to
    locate the file) and ``'class_id'`` (returned unchanged as the label).
    """

    def __init__(self, data_path, annotations, transform=None):
        """Store the dataset configuration; no files are touched here.

        Args:
            data_path: Directory that the per-record ``'image_path'`` values
                are relative to.
            annotations: Indexable, sized collection of annotation records.
            transform: Optional callable applied to the RGB image array.
        """
        self.data_path = data_path
        self.annotations = annotations
        self.transform = transform

    def __len__(self):
        """Return the number of annotation records."""
        return len(self.annotations)

    def __getitem__(self, idx):
        """Load one sample.

        Returns:
            ``(image, label)`` where ``image`` is the RGB image (after
            ``transform`` when one was given) and ``label`` is the record's
            ``'class_id'``.

        Raises:
            FileNotFoundError: If the image file cannot be read.
        """
        annotation = self.annotations[idx]
        img_path = os.path.join(self.data_path, annotation['image_path'])

        # cv2.imread reports a missing/corrupt file by returning None rather
        # than raising, which would otherwise surface as an opaque error
        # inside cvtColor. Fail early with the offending path instead.
        image = cv2.imread(img_path)
        if image is None:
            raise FileNotFoundError(f"Could not read image: {img_path}")

        # OpenCV decodes to BGR; convert to RGB channel order.
        image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)

        if self.transform is not None:
            image = self.transform(image)

        label = annotation['class_id']
        return image, label

In [4]:
# Data preprocessing
# Per-channel normalisation statistics shared by both pipelines.
NORMALIZE_MEAN = [0.485, 0.456, 0.406]
NORMALIZE_STD = [0.229, 0.224, 0.225]

# Training pipeline: random crop / flip / rotation / colour jitter, then
# tensor conversion and normalisation.
train_steps = [
    transforms.ToPILImage(),
    transforms.RandomResizedCrop(224),
    transforms.RandomHorizontalFlip(),
    transforms.RandomRotation(10),
    transforms.ColorJitter(brightness=0.2, contrast=0.2),
    transforms.ToTensor(),
    transforms.Normalize(mean=NORMALIZE_MEAN, std=NORMALIZE_STD),
]
train_transform = transforms.Compose(train_steps)

# Validation pipeline: deterministic resize + centre crop, same normalisation.
val_steps = [
    transforms.ToPILImage(),
    transforms.Resize(256),
    transforms.CenterCrop(224),
    transforms.ToTensor(),
    transforms.Normalize(mean=NORMALIZE_MEAN, std=NORMALIZE_STD),
]
val_transform = transforms.Compose(val_steps)

In [5]:
# Load the annotation data (the train/validation split is done afterwards)
with open('data/annotations/annotations.json', mode='r', encoding='utf-8') as annotation_file:
    raw_text = annotation_file.read()
annotations_data = json.loads(raw_text)

# Pull out the per-image records and the number of classes.
annotations = annotations_data['annotations']
num_classes = annotations_data['num_classes']


In [None]:
# Split into training and validation sets (80:20), stratified by class id
class_labels = [record['class_id'] for record in annotations]
train_annotations, val_annotations = train_test_split(
    annotations,
    test_size=0.2,
    random_state=42,
    stratify=class_labels,
)

n_train = len(train_annotations)
n_val = len(val_annotations)
print(f"训练集大小: {n_train}")
print(f"验证集大小: {n_val}")

In [7]:
# Create the datasets and data loaders
DATA_ROOT = "data/processed/melon17_clean"
train_dataset = MelonDataset(DATA_ROOT, train_annotations, transform=train_transform)
val_dataset = MelonDataset(DATA_ROOT, val_annotations, transform=val_transform)

# Both loaders share batch size and worker count; only training is shuffled.
loader_options = dict(batch_size=32, num_workers=4)
train_loader = DataLoader(train_dataset, shuffle=True, **loader_options)
val_loader = DataLoader(val_dataset, shuffle=False, **loader_options)

In [None]:
# Build the model
def create_model(num_classes, model_name='resnet50'):
    """Create a pretrained torchvision backbone with a new classification head.

    Every pretrained parameter is frozen (``requires_grad = False``); the
    final linear layer is then replaced by a freshly created ``nn.Linear``
    with ``num_classes`` outputs.

    Args:
        num_classes: Number of output classes for the new head.
        model_name: ``'resnet50'`` or ``'efficientnet'`` (EfficientNet-B0).

    Returns:
        The constructed model.

    Raises:
        ValueError: If ``model_name`` is not one of the supported names.
    """
    # NOTE(review): newer torchvision releases deprecate ``pretrained=`` in
    # favour of ``weights=``; kept as-is for compatibility with the installed
    # version — confirm before upgrading.
    if model_name == 'resnet50':
        model = models.resnet50(pretrained=True)
        # Freeze the pretrained layers
        for param in model.parameters():
            param.requires_grad = False
        # Replace the classification head (created after freezing, so it
        # keeps the default requires_grad=True)
        model.fc = nn.Linear(model.fc.in_features, num_classes)

    elif model_name == 'efficientnet':
        model = models.efficientnet_b0(pretrained=True)
        for param in model.parameters():
            param.requires_grad = False
        model.classifier[1] = nn.Linear(model.classifier[1].in_features, num_classes)

    else:
        # Previously an unknown name fell through to ``return model`` with
        # ``model`` unbound, raising a confusing UnboundLocalError.
        raise ValueError(
            f"Unsupported model_name {model_name!r}; "
            "expected 'resnet50' or 'efficientnet'"
        )

    return model

# Instantiate the model and move it onto the selected device
model = create_model(num_classes, 'resnet50').to(device)

# Loss function, optimizer, and step-wise learning-rate schedule
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=0.001)
scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=10, gamma=0.1)

# Count every parameter tensor's elements (frozen and trainable alike)
total_params = sum(p.numel() for p in model.parameters())
print("模型创建完成")
print(f"模型参数数量: {total_params:,}")

In [None]:
# Training function
def train_model(model, train_loader, val_loader, criterion, optimizer, scheduler, num_epochs=20):
    """Train ``model`` and leave it holding its best-validation-accuracy weights.

    Each epoch runs one optimisation pass over ``train_loader`` followed by a
    gradient-free evaluation pass over ``val_loader``; ``scheduler.step()`` is
    called once at the end of every epoch. Whenever validation accuracy
    strictly improves, an independent copy of the weights is taken, and that
    copy is loaded back into ``model`` before returning.

    Args:
        model: Module to train in place. Batches are moved to the device of
            its first parameter (CPU if it has no parameters).
        train_loader: Sized iterable of ``(inputs, labels)`` training batches.
        val_loader: Sized iterable of ``(inputs, labels)`` validation batches.
        criterion: Callable ``criterion(outputs, labels)`` returning the loss.
        optimizer: Optimizer updating ``model``'s parameters.
        scheduler: Learning-rate scheduler stepped once per epoch.
        num_epochs: Number of epochs to run.

    Returns:
        dict with per-epoch lists ``'train_losses'``, ``'val_losses'``
        (mean of the per-batch losses), ``'train_accuracies'``,
        ``'val_accuracies'`` (percentages) and the scalar ``'best_val_acc'``.
    """
    import copy  # local import: only needed for the best-weights snapshot

    # Follow the model's own device instead of relying on a notebook global.
    first_param = next(model.parameters(), None)
    device = first_param.device if first_param is not None else torch.device('cpu')

    train_losses = []
    val_losses = []
    train_accuracies = []
    val_accuracies = []

    best_val_acc = 0.0
    best_model_state = None

    for epoch in range(num_epochs):
        print(f'Epoch {epoch+1}/{num_epochs}')
        print('-' * 50)

        # Training phase
        model.train()
        running_loss = 0.0
        correct_predictions = 0
        total_samples = 0

        train_pbar = tqdm(train_loader, desc='Training')
        for inputs, labels in train_pbar:
            inputs, labels = inputs.to(device), labels.to(device)

            optimizer.zero_grad()
            outputs = model(inputs)
            loss = criterion(outputs, labels)
            loss.backward()
            optimizer.step()

            running_loss += loss.item()
            predicted = outputs.detach().argmax(dim=1)
            total_samples += labels.size(0)
            correct_predictions += (predicted == labels).sum().item()

            current_acc = 100 * correct_predictions / total_samples
            train_pbar.set_postfix({'Loss': f'{loss.item():.4f}', 'Acc': f'{current_acc:.2f}%'})

        # max(..., 1) keeps an empty loader from raising ZeroDivisionError
        epoch_train_loss = running_loss / max(len(train_loader), 1)
        epoch_train_acc = 100 * correct_predictions / max(total_samples, 1)

        # Validation phase
        model.eval()
        val_running_loss = 0.0
        val_correct_predictions = 0
        val_total_samples = 0

        with torch.no_grad():
            for inputs, labels in tqdm(val_loader, desc='Validation'):
                inputs, labels = inputs.to(device), labels.to(device)
                outputs = model(inputs)
                val_loss = criterion(outputs, labels)

                val_running_loss += val_loss.item()
                predicted = outputs.argmax(dim=1)
                val_total_samples += labels.size(0)
                val_correct_predictions += (predicted == labels).sum().item()

        epoch_val_loss = val_running_loss / max(len(val_loader), 1)
        epoch_val_acc = 100 * val_correct_predictions / max(val_total_samples, 1)

        # Keep the best model. state_dict() hands back references to the live
        # tensors, and dict.copy() is shallow, so a deep copy is required —
        # otherwise later epochs silently overwrite the "best" snapshot.
        if epoch_val_acc > best_val_acc:
            best_val_acc = epoch_val_acc
            best_model_state = copy.deepcopy(model.state_dict())

        # Record metrics
        train_losses.append(epoch_train_loss)
        val_losses.append(epoch_val_loss)
        train_accuracies.append(epoch_train_acc)
        val_accuracies.append(epoch_val_acc)

        print(f'Train Loss: {epoch_train_loss:.4f}, Train Acc: {epoch_train_acc:.2f}%')
        print(f'Val Loss: {epoch_val_loss:.4f}, Val Acc: {epoch_val_acc:.2f}%')
        print()

        scheduler.step()

    # Restore the best model. No snapshot exists when no epoch ever exceeded
    # 0% validation accuracy (or num_epochs == 0); keep the current weights.
    if best_model_state is not None:
        model.load_state_dict(best_model_state)

    return {
        'train_losses': train_losses,
        'val_losses': val_losses,
        'train_accuracies': train_accuracies,
        'val_accuracies': val_accuracies,
        'best_val_acc': best_val_acc
    }

# Start training
print("开始训练模型...")
training_history = train_model(
    model,
    train_loader,
    val_loader,
    criterion,
    optimizer,
    scheduler,
    num_epochs=20,
)

best_accuracy = training_history['best_val_acc']
print(f"训练完成！最佳验证准确率: {best_accuracy:.2f}%")

In [None]:
# Save the model weights and the training history
# Neither torch.save nor open(..., 'w') creates missing parent directories,
# so make sure both output folders exist first.
os.makedirs('models', exist_ok=True)
os.makedirs('results', exist_ok=True)

torch.save(model.state_dict(), 'models/melon_classifier.pth')

# Save the training history, with the class mappings attached alongside the
# metrics.
training_history['class_to_idx'] = annotations_data['class_to_idx']
training_history['idx_to_class'] = annotations_data['idx_to_class']

with open('results/training_history.json', 'w', encoding='utf-8') as f:
    json.dump(training_history, f, indent=2)

print("模型和训练历史已保存")
# Plot the training curves
def plot_training_history(history):
    """Draw loss and accuracy curves side by side, save them, and show them.

    Args:
        history: Mapping with the per-epoch sequences ``'train_losses'``,
            ``'val_losses'``, ``'train_accuracies'`` and ``'val_accuracies'``.

    The figure is written to ``results/training_curves.png`` and then
    displayed.
    """
    fig, axes = plt.subplots(1, 2, figsize=(15, 5))

    # One entry per panel:
    # (train key, train label, val key, val label, title, y-axis label)
    panels = (
        ('train_losses', 'Training Loss',
         'val_losses', 'Validation Loss',
         'Training and Validation Loss', 'Loss'),
        ('train_accuracies', 'Training Accuracy',
         'val_accuracies', 'Validation Accuracy',
         'Training and Validation Accuracy', 'Accuracy (%)'),
    )

    for axis, panel in zip(axes, panels):
        train_key, train_label, val_key, val_label, title, y_label = panel
        axis.plot(history[train_key], label=train_label, marker='o')
        axis.plot(history[val_key], label=val_label, marker='s')
        axis.set_title(title)
        axis.set_xlabel('Epoch')
        axis.set_ylabel(y_label)
        axis.legend()
        axis.grid(True)

    plt.tight_layout()
    plt.savefig('results/training_curves.png', dpi=150, bbox_inches='tight')
    plt.show()

# Render the curves for the run recorded in training_history
plot_training_history(history=training_history)