In [None]:
import os

train_path_1x = "archive/data_relabeled_balanced_1x/train"
test_path_1x = "archive/data_relabeled_balanced_1x/test"
train_path_2x = "archive/data_relabeled_balanced_2x/train"
test_path_2x = "archive/data_relabeled_balanced_2x/test"
train_path_3x = "archive/data_relabeled_balanced_3x/train"
test_path_3x = "archive/data_relabeled_balanced_3x/test"

train_class_1x = next(os.walk(train_path_1x))[1]
test_class_1x = next(os.walk(test_path_1x))[1]
train_class_2x = next(os.walk(train_path_2x))[1]
test_class_2x = next(os.walk(test_path_2x))[1]
train_class_3x = next(os.walk(train_path_3x))[1]
test_class_3x = next(os.walk(test_path_3x))[1]

print("train_path_1x：",train_class_1x)
print("test_path_1x",test_class_1x)
print("train_path_2x：",train_class_2x)
print("test_path_2x",test_class_2x)
print("train_path_3x：",train_class_3x)
print("test_path_3x",test_class_3x)

['anger', 'contempt', 'disgust', 'fear', 'happy', 'neutral', 'sad', 'surprise']


In [3]:
import torch

print("PyTorch 版本:", torch.__version__)
print("CUDA 可用:", torch.cuda.is_available())
print("CUDA 设备数量:", torch.cuda.device_count())
print("当前设备:", torch.cuda.current_device())
print("设备名称:", torch.cuda.get_device_name(0))



PyTorch 版本: 2.6.0+cu124
CUDA 可用: True
CUDA 设备数量: 1
当前设备: 0
设备名称: NVIDIA GeForce RTX 3060 Laptop GPU


In [3]:
os.environ["HTTP_PROXY"] = "http://127.0.0.1:10809"
os.environ["HTTPS_PROXY"] = "http://127.0.0.1:10809"
os.environ["TORCH_HOME"] = "."

# 禁用torch ssl验证
# os.environ["CURL_CA_BUNDLE"] = ""

In [5]:
# train_color.py

# ----------------------
# 1. 基础导入和环境配置
# ----------------------
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, Dataset
from torch.amp import GradScaler,autocast
from torchvision import transforms, models
from PIL import Image
import matplotlib.pyplot as plt
import numpy as np
import os
# 设置代理和下载目录
os.environ["HTTP_PROXY"] = "http://127.0.0.1:10809"
os.environ["HTTPS_PROXY"] = "http://127.0.0.1:10809"
os.environ["TORCH_HOME"] = "."
from tqdm.auto import tqdm
import psutil
import gc
import warnings
warnings.filterwarnings("ignore", category=UserWarning, module="torch.nn.functional")
import cv2
from sklearn.metrics import confusion_matrix, classification_report
import seaborn as sns
from datetime import datetime

# ----------------------
# 2. 基础配置
# ----------------------
# 2.1 设备配置
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
torch.backends.cudnn.benchmark = True  # 加速卷积运算
if device.type == 'cuda':
    print("Using GPU for training.")
else:
    print("Warning: GPU not available. Training on CPU.")

# 关键参数配置 (根据6GB显存优化)
# 2.2 训练参数配置
config = {
    "data_dir": "archive/data_relabeled_balanced_1x/train",
    "batch_size": 32,
    "grad_accum_steps": 4,
    "num_epochs": 100,
    "learning_rate": 1e-3,
    "num_classes": 8,
    "input_size": 224,
    "valid_ratio": 0.15,
    "seed": 42,
    "num_workers": 0,
    "max_grad_norm": 1.0,
    "mixup_alpha": 0.2,
    "label_smoothing": 0.1,
    "model_dir": "visionModel",
    "freeze_blocks": 2,
    "patience": 15,        # 增加耐心值
    "min_lr": 1e-6,
    "weight_decay": 5e-4,
    "T_0": 10,            # 余弦退火周期
    "T_mult": 2,          # 周期倍增因子
}

# 2.3 固定随机种子
def seed_everything(seed):
    torch.manual_seed(seed)
    np.random.seed(seed)
    os.environ['PYTHONHASHSEED'] = str(seed)
seed_everything(config['seed'])

# ----------------------
# 内存优化数据集
# 3. 数据集和数据加载
# ----------------------
# 3.1 数据集类定义
class OptimizedDataset(Dataset):
    """优化的数据集类，支持内存缓存"""
    def __init__(self, root_dir, transform=None):
        self.root_dir = root_dir
        self.classes = sorted([d for d in os.listdir(root_dir) 
                             if os.path.isdir(os.path.join(root_dir, d)) and '_gray' not in d]) # 过滤掉灰度数据集
        self.class_to_idx = {cls: i for i, cls in enumerate(self.classes)}
        self.samples = []
        self.transform = transform
        
        # 预加载图像路径
        for cls in self.classes:
            cls_dir = os.path.join(root_dir, cls)
            self.samples.extend([
                (os.path.join(cls_dir, fname), self.class_to_idx[cls]) 
                for fname in os.listdir(cls_dir)
                if fname.lower().endswith(('png', 'jpg', 'jpeg'))
            ])
        
        # 内存优化：预加载小尺寸图像
        self.cache = {}
        if psutil.virtual_memory().available > 8*1024**3:  # 仅当内存>8GB时启用
            for idx in tqdm(range(len(self.samples)), desc="预加载图像"):
                img_path, label = self.samples[idx]
                self.cache[idx] = (Image.open(img_path).convert('RGB'), label)

    def __len__(self):
        return len(self.samples)

    def __getitem__(self, idx):
        if idx in self.cache:
            img, label = self.cache[idx]
        else:
            img_path, label = self.samples[idx]
            img = Image.open(img_path).convert('RGB')
        
        if self.transform:
            # img = self.transform(img) # 转换为tensor
            img = np.array(img)  # 转换为numpy array
            augmented = self.transform(image=img)
            img = augmented['image']
            
        return img, label

# 3.2 数据增强定义
import albumentations as A
from albumentations.pytorch import ToTensorV2

train_transform = A.Compose([
    A.Resize(config['input_size'], config['input_size']),
    A.HorizontalFlip(p=0.5),
    # A.ShiftScaleRotate(shift_limit=0.1, scale_limit=0.1, rotate_limit=15, p=0.5),
    # 将ShiftScaleRotate替换为 Affine
    A.Affine(
        scale=(0.9, 1.1), # 允许图像随机缩放，范围为 0.9 到 1.1
        translate_percent=(-0.1, 0.1),  # 允许图像随机平移，范围为 -10% 到 10%
        rotate=(-15, 15), # 允许图像随机旋转，范围为 -15 到 15 度 
        shear=(-10, 10),              # 加入轻微shear变换，范围设为 -10 到 10 度
        interpolation=cv2.INTER_LINEAR, # 使用cv2.INTER_LINEAR进行线性插值
        border_mode=cv2.BORDER_REFLECT_101, # 使用cv2.BORDER_REFLECT_101方式填充边缘
        p=0.5 # 应用概率为 0.5
    ),
    A.RandomBrightnessContrast(p=0.3),
    # A.CoarseDropout(max_holes=3, max_height=20, max_width=20, p=0.3),
    # 修改 CoarseDropout 参数
    A.CoarseDropout(
        num_holes_range=(3, 6),           # 每次随机遮挡 3-6 个区域
        hole_height_range=(0.1, 0.2),     # 高度为图像高度的 10%-20%
        hole_width_range=(0.1, 0.2),      # 宽度为图像宽度的 10%-20%
        fill=0,                     # 使用黑色填充
        p=0.3                             # 应用概率为 0.3
    ),
    A.Normalize(mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)),
    ToTensorV2()
])

valid_transform = A.Compose([
    A.Resize(config['input_size'], config['input_size']),
    A.Normalize(mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)),
    ToTensorV2()
])

# 3.3 创建数据加载器
def create_data_loaders():
    """创建训练和验证数据加载器"""
    # 创建数据集
    full_dataset = OptimizedDataset(config['data_dir'], transform=train_transform)
    train_size = int((1 - config['valid_ratio']) * len(full_dataset))
    valid_size = len(full_dataset) - train_size
    train_dataset, valid_dataset = torch.utils.data.random_split(full_dataset, [train_size, valid_size])
    valid_dataset.dataset.transform = valid_transform
    
    # 创建数据加载器
    train_loader = DataLoader(
        train_dataset, 
        batch_size=config['batch_size'],
        shuffle=True,
        num_workers=config['num_workers'],
        pin_memory=True,
        # persistent_workers=True
        persistent_workers=False # 关闭持续workers（多进程）
    )
    
    # 验证数据加载器
    valid_loader = DataLoader(
        valid_dataset, 
        batch_size=config['batch_size']*2,
        shuffle=False,
        num_workers=config['num_workers'],
        pin_memory=True
    )

    # 返回训练和验证数据加载器
    return train_loader, valid_loader

# ----------------------
# 4. 模型相关
# ----------------------
# 4.1 模型创建函数
def create_model():
    """创建并初始化模型（带重试机制）"""
    max_retries = 3
    for attempt in range(max_retries):
        try:
            # 使用EfficientNet B0并加载ImageNet预训练权重
            model = models.efficientnet_b0(weights=models.EfficientNet_B0_Weights.IMAGENET1K_V1)
            in_features = model.classifier[1].in_features
            # 使用新的 classifier，加入 dropout 层以降低过拟合
            model.classifier = nn.Sequential(
                nn.Dropout(p=0.5),
                nn.Linear(in_features, config['num_classes'])
            )
            
            # 冻结前 freeze_blocks 个阶段的参数以减少模型复杂度
            for param in model.features[:config['freeze_blocks']].parameters():
                param.requires_grad = False
            
            # 显存优化配置
            model = model.to(device, memory_format=torch.channels_last)
            print("Model created successfully.")
            return model
        except Exception as e:
            print(f"Attempt {attempt + 1}/{max_retries} failed: {str(e)}")
            torch.cuda.empty_cache()
            gc.collect()
            if attempt == max_retries - 1:
                raise e

# ----------------------
# 5. 训练工具函数
# ----------------------
def print_memory_usage():
    """打印显存使用情况"""
    if device.type == 'cuda':
        mem = torch.cuda.memory_reserved(device) / 1e9
        print(f"当前显存占用: {mem:.2f}GB")
    else:
        print("Training on CPU, no GPU memory usage available.")

def mixup_data(x, y, alpha=1.0):
    """Mixup数据增强"""
    if alpha > 0:
        lam = np.random.beta(alpha, alpha)
    else:
        lam = 1
    
    batch_size = x.size()[0]
    index = torch.randperm(batch_size).to(device)
    
    mixed_x = lam * x + (1 - lam) * x[index]
    y_a, y_b = y, y[index]
    return mixed_x, y_a, y_b, lam

# ----------------------
# 6. 训练和验证函数
# ----------------------
def train_one_epoch(model, train_loader, optimizer, criterion, scaler):
    """训练一个epoch"""
    model.train()
    correct = 0
    total = 0
    
    progress_bar = tqdm(train_loader, 
                       desc='Training',
                       bar_format='{l_bar}{bar:20}{r_bar}{bar:-20b}')
    
    optimizer.zero_grad()
    
    for step, (inputs, labels) in enumerate(progress_bar):
        try:
            inputs = inputs.to(device, non_blocking=True)
            labels = labels.to(device, non_blocking=True)
            
            # Mixup数据增强
            inputs, targets_a, targets_b, lam = mixup_data(inputs, labels, config['mixup_alpha'])
            
            with autocast(device_type='cuda' if torch.cuda.is_available() else 'cpu'):
                outputs = model(inputs)
                loss = lam * criterion(outputs, targets_a) + (1 - lam) * criterion(outputs, targets_b)
                loss = loss / config['grad_accum_steps']  # 梯度累积
            
            scaler.scale(loss).backward()
            
            if (step + 1) % config['grad_accum_steps'] == 0 or (step + 1) == len(train_loader):
                # 梯度裁剪
                scaler.unscale_(optimizer)
                nn.utils.clip_grad_norm_(model.parameters(), config['max_grad_norm'])
                
                scaler.step(optimizer)
                scaler.update()
                optimizer.zero_grad()
            
            # 计算准确率
            _, predicted = outputs.max(1)
            correct += (lam * predicted.eq(targets_a).sum().item() + 
                       (1 - lam) * predicted.eq(targets_b).sum().item())
            total += labels.size(0)
            train_acc = 100. * correct / total
            
            # 计算显存使用情况
            mem_str = f"{torch.cuda.memory_reserved(device)/1e9:.1f}G" if device.type == 'cuda' else "N/A"
            
            # 更新进度条
            progress_bar.set_postfix({
                'loss': f"{loss.item()*config['grad_accum_steps']:.3f}",
                'acc': f"{train_acc:.1f}%",
                'lr': f"{optimizer.param_groups[0]['lr']:.1e}",
                'mem': mem_str
            })
        except Exception as e:
            print(f"Error in training step: {str(e)}")
            continue
    
    return loss.item(), train_acc

def validate(model, valid_loader, criterion):
    """验证模型并收集详细指标"""
    model.eval()
    valid_loss = 0.0
    correct = 0
    total = 0
    
    # 收集预测结果和真实标签
    all_preds = []
    all_labels = []
    class_correct = [0] * config['num_classes']
    class_total = [0] * config['num_classes']
    
    try:
        with torch.no_grad(), autocast(device_type='cuda' if torch.cuda.is_available() else 'cpu'):
            for inputs, labels in tqdm(valid_loader, desc='Validating', leave=False):
                inputs = inputs.to(device, non_blocking=True)
                labels = labels.to(device, non_blocking=True)
                
                outputs = model(inputs)
                loss = criterion(outputs, labels)
                
                valid_loss += loss.item()
                _, predicted = outputs.max(1)
                
                # 收集预测结果
                all_preds.extend(predicted.cpu().numpy())
                all_labels.extend(labels.cpu().numpy())
                
                # 计算每个类别的准确率
                for label, pred in zip(labels, predicted):
                    if label == pred:
                        class_correct[label] += 1
                    class_total[label] += 1
                
                correct += predicted.eq(labels).sum().item()
                total += labels.size(0)
        
        # 计算每个类别的准确率
        class_accuracies = [100 * correct / total if total > 0 else 0 
                           for correct, total in zip(class_correct, class_total)]
        
        return (valid_loss / len(valid_loader), 100. * correct / total, 
                all_preds, all_labels, class_accuracies)
    except Exception as e:
        print(f"Error in validation: {str(e)}")
        return float('inf'), 0.0, [], [], [0] * config['num_classes']

# ----------------------
# 7. 主训练循环
# ----------------------
def main():
    """主训练函数"""
    os.makedirs(config['model_dir'], exist_ok=True)
    
    # 创建训练结果保存目录
    timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
    save_dir = os.path.join(config['model_dir'], f'training_results_{timestamp}')
    os.makedirs(save_dir, exist_ok=True)
    
    train_loader, valid_loader = create_data_loaders()
    model = create_model()
    scaler = GradScaler()
    
    # 使用标签平滑的交叉熵损失
    criterion = nn.CrossEntropyLoss(label_smoothing=config['label_smoothing'])
    
    optimizer = optim.AdamW(
        model.parameters(),
        lr=config['learning_rate'],
        weight_decay=config['weight_decay'],
        betas=(0.9, 0.999)
    )
    
    scheduler = optim.lr_scheduler.CosineAnnealingWarmRestarts(
        optimizer,
        T_0=config['T_0'],
        T_mult=config['T_mult'],
        eta_min=config['min_lr']
    )
    
    early_stopping = EarlyStopping(patience=config['patience'])
    best_acc = 0.0
    history = {
        'train_loss': [], 'train_acc': [], 
        'valid_loss': [], 'valid_acc': [],
        'learning_rates': [], 'class_accuracies': [],
        'final_confusion_matrix': None
    }
    
    for epoch in range(config['num_epochs']):
        print(f"\nEpoch {epoch+1}/{config['num_epochs']}")
        
        train_loss, train_acc = train_one_epoch(model, train_loader, optimizer, criterion, scaler)
        valid_loss, valid_acc, all_preds, all_labels, class_accuracies = validate(
            model, valid_loader, criterion)
        
        # 记录历史
        history['train_loss'].append(train_loss)
        history['train_acc'].append(train_acc)
        history['valid_loss'].append(valid_loss)
        history['valid_acc'].append(valid_acc)
        history['learning_rates'].append(optimizer.param_groups[0]['lr'])
        history['class_accuracies'].append(class_accuracies)
        
        # 在最后一个epoch或提前停止时保存混淆矩阵
        if epoch == config['num_epochs']-1 or early_stopping.counter >= config['patience']:
            conf_matrix = confusion_matrix(all_labels, all_preds)
            history['final_confusion_matrix'] = conf_matrix
            
            # 打印分类报告
            print("\nClassification Report:")
            print(classification_report(all_labels, all_preds))
        
        if valid_acc > best_acc:
            best_acc = valid_acc
            save_path = os.path.join(
                save_dir, 
                f'best_model_epoch{epoch+1}_acc{valid_acc:.1f}.pth'
            )
            torch.save({
                'epoch': epoch,
                'model_state_dict': model.state_dict(),
                'optimizer_state_dict': optimizer.state_dict(),
                'best_acc': best_acc,
                'history': history,
            }, save_path)
            print(f"Saved best model to: {save_path}")
        
        scheduler.step()
        current_lr = optimizer.param_groups[0]['lr']
        print(f"Learning rate: {current_lr:.1e}")
        
        # 每5个epoch绘制一次训练图
        if (epoch + 1) % 5 == 0:
            plot_training_history(history, save_dir)
        
        early_stopping(valid_loss)
        if early_stopping.early_stop:
            print("Early stopping triggered")
            break
        
        torch.cuda.empty_cache()
        gc.collect()
    
    # 训练结束后绘制最终图表
    plot_training_history(history, save_dir)
    
    # 保存完整训练历史
    history_path = os.path.join(save_dir, 'training_history.pth')
    torch.save(history, history_path)
    print(f"Saved training history to: {history_path}")

# ----------------------
# 8. 可视化函数
# ----------------------
def plot_training_history(history, save_dir):
    """绘制详细的训练历史"""
    timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
    
    # 1. 损失和准确率图
    plt.figure(figsize=(15, 10))
    
    plt.subplot(2, 2, 1)
    plt.plot(history['train_loss'], label='Train')
    plt.plot(history['valid_loss'], label='Valid')
    plt.title('Training/Validation Loss')
    plt.xlabel('Epoch')
    plt.ylabel('Loss')
    plt.legend()
    
    plt.subplot(2, 2, 2)
    plt.plot(history['train_acc'], label='Train')
    plt.plot(history['valid_acc'], label='Valid')
    plt.title('Training/Validation Accuracy')
    plt.xlabel('Epoch')
    plt.ylabel('Accuracy (%)')
    plt.legend()
    
    # 2. 学习率变化图
    plt.subplot(2, 2, 3)
    plt.plot(history['learning_rates'])
    plt.title('Learning Rate Schedule')
    plt.xlabel('Epoch')
    plt.ylabel('Learning Rate')
    plt.yscale('log')
    
    # 3. 每个类别的准确率变化
    plt.subplot(2, 2, 4)
    for i, class_acc in enumerate(zip(*history['class_accuracies'])):
        plt.plot(class_acc, label=f'Class {i}')
    plt.title('Per-Class Accuracy')
    plt.xlabel('Epoch')
    plt.ylabel('Accuracy (%)')
    plt.legend(bbox_to_anchor=(1.05, 1), loc='upper left')
    
    plt.tight_layout()
    plt.savefig(os.path.join(save_dir, f'training_metrics_{timestamp}.png'), dpi=300, bbox_inches='tight')
    plt.close()
    
    # 4. 混淆矩阵
    if history['final_confusion_matrix'] is not None:
        plt.figure(figsize=(10, 8))
        sns.heatmap(history['final_confusion_matrix'], 
                   annot=True, fmt='d', cmap='Blues')
        plt.title('Confusion Matrix')
        plt.xlabel('Predicted')
        plt.ylabel('True')
        plt.savefig(os.path.join(save_dir, f'confusion_matrix_{timestamp}.png'), dpi=300, bbox_inches='tight')
        plt.close()

# ----------------------
# 9. 程序入口
# ----------------------
if __name__ == "__main__":
    main()

Using GPU for training.




Model created successfully.

Epoch 1/50


Training: 100%|████████████████████| 389/389 [08:57<00:00,  1.38s/it, loss=1.883, acc=37.0%, lr=2.0e-04, mem=1.2G]
                                                                     

Saved best model to: visionModel\best_model_epoch1_acc56.4.pth
Learning rate: 2.0e-04

Epoch 2/50


Training: 100%|████████████████████| 389/389 [08:53<00:00,  1.37s/it, loss=2.172, acc=54.9%, lr=2.0e-04, mem=1.2G]
                                                                     

Saved best model to: visionModel\best_model_epoch2_acc65.0.pth
Learning rate: 2.0e-04

Epoch 3/50


Training: 100%|████████████████████| 389/389 [08:44<00:00,  1.35s/it, loss=1.814, acc=61.7%, lr=2.0e-04, mem=1.2G]
                                                                     

Saved best model to: visionModel\best_model_epoch3_acc67.7.pth
Learning rate: 2.0e-04

Epoch 4/50


Training: 100%|████████████████████| 389/389 [08:57<00:00,  1.38s/it, loss=1.366, acc=65.6%, lr=2.0e-04, mem=1.2G]
                                                                     

Saved best model to: visionModel\best_model_epoch4_acc69.6.pth
Learning rate: 2.0e-04

Epoch 5/50


Training: 100%|████████████████████| 389/389 [08:49<00:00,  1.36s/it, loss=1.351, acc=69.8%, lr=2.0e-04, mem=1.2G]
                                                                     

Saved best model to: visionModel\best_model_epoch5_acc70.6.pth
Learning rate: 2.0e-04

Epoch 6/50


Training: 100%|████████████████████| 389/389 [09:14<00:00,  1.42s/it, loss=1.332, acc=71.8%, lr=2.0e-04, mem=1.2G]
                                                                     

Saved best model to: visionModel\best_model_epoch6_acc71.3.pth
Learning rate: 2.0e-04

Epoch 7/50


Training: 100%|████████████████████| 389/389 [08:52<00:00,  1.37s/it, loss=1.482, acc=76.3%, lr=2.0e-04, mem=1.2G]
                                                                     

Learning rate: 2.0e-04

Epoch 8/50


Training: 100%|████████████████████| 389/389 [08:53<00:00,  1.37s/it, loss=1.057, acc=78.7%, lr=2.0e-04, mem=1.2G]
                                                                     

Learning rate: 2.0e-04

Epoch 9/50


Training: 100%|████████████████████| 389/389 [08:52<00:00,  1.37s/it, loss=2.270, acc=81.5%, lr=2.0e-04, mem=1.2G]
                                                                     

Saved best model to: visionModel\best_model_epoch9_acc72.4.pth
Learning rate: 1.0e-04

Epoch 10/50


Training: 100%|████████████████████| 389/389 [08:52<00:00,  1.37s/it, loss=1.670, acc=82.3%, lr=1.0e-04, mem=1.2G]
                                                                     

Saved best model to: visionModel\best_model_epoch10_acc72.8.pth
Learning rate: 1.0e-04

Epoch 11/50


Training: 100%|████████████████████| 389/389 [08:52<00:00,  1.37s/it, loss=1.040, acc=83.3%, lr=1.0e-04, mem=1.2G]
                                                                     

Learning rate: 1.0e-04

Epoch 12/50


Training: 100%|████████████████████| 389/389 [08:52<00:00,  1.37s/it, loss=1.572, acc=82.7%, lr=1.0e-04, mem=1.2G]
                                                                     

Learning rate: 1.0e-04

Epoch 13/50


Training: 100%|████████████████████| 389/389 [08:52<00:00,  1.37s/it, loss=1.626, acc=82.4%, lr=1.0e-04, mem=1.2G]
                                                                     

Learning rate: 5.0e-05

Epoch 14/50


Training: 100%|████████████████████| 389/389 [08:52<00:00,  1.37s/it, loss=1.594, acc=85.6%, lr=5.0e-05, mem=1.2G]
                                                                     

Saved best model to: visionModel\best_model_epoch14_acc73.0.pth
Learning rate: 5.0e-05

Epoch 15/50


Training: 100%|████████████████████| 389/389 [08:52<00:00,  1.37s/it, loss=1.944, acc=85.5%, lr=5.0e-05, mem=1.2G]
                                                                     

Saved best model to: visionModel\best_model_epoch15_acc73.0.pth
Learning rate: 5.0e-05

Epoch 16/50


Training: 100%|████████████████████| 389/389 [08:52<00:00,  1.37s/it, loss=1.034, acc=84.4%, lr=5.0e-05, mem=1.2G]
                                                                     

Saved best model to: visionModel\best_model_epoch16_acc73.4.pth
Learning rate: 5.0e-05

Epoch 17/50


Training: 100%|████████████████████| 389/389 [08:52<00:00,  1.37s/it, loss=1.322, acc=86.1%, lr=5.0e-05, mem=1.2G]
                                                                     

Learning rate: 5.0e-05

Epoch 18/50


Training: 100%|████████████████████| 389/389 [08:52<00:00,  1.37s/it, loss=1.882, acc=84.3%, lr=5.0e-05, mem=1.2G]
                                                                     

Learning rate: 5.0e-05

Epoch 19/50


Training: 100%|████████████████████| 389/389 [08:52<00:00,  1.37s/it, loss=1.230, acc=85.6%, lr=5.0e-05, mem=1.2G]
                                                                     

Learning rate: 5.0e-05

Epoch 20/50


Training: 100%|████████████████████| 389/389 [08:50<00:00,  1.36s/it, loss=1.005, acc=86.1%, lr=5.0e-05, mem=1.2G]
                                                                     

Learning rate: 5.0e-05

Epoch 21/50


Training: 100%|████████████████████| 389/389 [08:52<00:00,  1.37s/it, loss=1.126, acc=84.6%, lr=5.0e-05, mem=1.2G]
                                                                     

Saved best model to: visionModel\best_model_epoch21_acc73.6.pth
Learning rate: 5.0e-05

Epoch 22/50


Training: 100%|████████████████████| 389/389 [08:52<00:00,  1.37s/it, loss=1.719, acc=84.2%, lr=5.0e-05, mem=1.2G]
                                                                     

Learning rate: 5.0e-05

Epoch 23/50


Training: 100%|████████████████████| 389/389 [08:53<00:00,  1.37s/it, loss=2.016, acc=84.6%, lr=5.0e-05, mem=1.2G]
                                                                     

Saved best model to: visionModel\best_model_epoch23_acc73.7.pth
Learning rate: 5.0e-05

Epoch 24/50


Training: 100%|████████████████████| 389/389 [08:52<00:00,  1.37s/it, loss=1.630, acc=85.8%, lr=5.0e-05, mem=1.2G]
                                                                     

Learning rate: 5.0e-05

Epoch 25/50


Training: 100%|████████████████████| 389/389 [08:51<00:00,  1.37s/it, loss=1.162, acc=85.9%, lr=5.0e-05, mem=1.2G]
                                                                     

Learning rate: 5.0e-05

Epoch 26/50


Training: 100%|████████████████████| 389/389 [08:52<00:00,  1.37s/it, loss=1.008, acc=86.5%, lr=5.0e-05, mem=1.2G]
                                                                     

Saved best model to: visionModel\best_model_epoch26_acc73.9.pth
Learning rate: 5.0e-05

Epoch 27/50


Training: 100%|████████████████████| 389/389 [08:52<00:00,  1.37s/it, loss=1.659, acc=86.0%, lr=5.0e-05, mem=1.2G]
                                                                     

Learning rate: 5.0e-05

Epoch 28/50


Training: 100%|████████████████████| 389/389 [08:52<00:00,  1.37s/it, loss=1.572, acc=84.6%, lr=5.0e-05, mem=1.2G]
                                                                     

Learning rate: 5.0e-05

Epoch 29/50


Training: 100%|████████████████████| 389/389 [08:52<00:00,  1.37s/it, loss=1.967, acc=87.5%, lr=5.0e-05, mem=1.2G]
                                                                     

Learning rate: 5.0e-05

Epoch 30/50


Training: 100%|████████████████████| 389/389 [08:52<00:00,  1.37s/it, loss=1.946, acc=86.8%, lr=5.0e-05, mem=1.2G]
                                                                     

Saved best model to: visionModel\best_model_epoch30_acc74.1.pth
Learning rate: 2.5e-05

Epoch 31/50


Training: 100%|████████████████████| 389/389 [08:52<00:00,  1.37s/it, loss=1.542, acc=85.5%, lr=2.5e-05, mem=1.2G]
                                                                     

Learning rate: 2.5e-05

Epoch 32/50


Training: 100%|████████████████████| 389/389 [08:52<00:00,  1.37s/it, loss=1.348, acc=85.1%, lr=2.5e-05, mem=1.2G]
                                                                     

Learning rate: 2.5e-05

Epoch 33/50


Training: 100%|████████████████████| 389/389 [08:52<00:00,  1.37s/it, loss=1.560, acc=86.4%, lr=2.5e-05, mem=1.2G]
                                                                     

Learning rate: 1.3e-05

Epoch 34/50


Training: 100%|████████████████████| 389/389 [08:50<00:00,  1.36s/it, loss=1.148, acc=86.4%, lr=1.3e-05, mem=1.2G]
                                                                     

Learning rate: 1.3e-05

Epoch 35/50


Training: 100%|████████████████████| 389/389 [28:00<00:00,  4.32s/it, loss=1.034, acc=88.0%, lr=1.3e-05, mem=1.2G]
                                                                     

Learning rate: 1.3e-05

Epoch 36/50


Training:  14%|██▋                 | 53/389 [06:20<40:10,  7.17s/it, loss=0.701, acc=88.1%, lr=1.3e-05, mem=1.2G]

In [4]:
# train_gray.py

# ----------------------
# 1. 基础导入和配置
# ----------------------
import os
# 设置代理和下载目录
os.environ["HTTP_PROXY"] = "http://127.0.0.1:10809"
os.environ["HTTPS_PROXY"] = "http://127.0.0.1:10809"
os.environ["TORCH_HOME"] = "."
import re
import glob
import torch
import torch.nn as nn
import torch.optim as optim
from torchvision import models
from torch.utils.data import DataLoader
import albumentations as A
from albumentations.pytorch import ToTensorV2
import tqdm
from train_color import (
    device, seed_everything, OptimizedDataset,
    GradScaler, autocast, A, ToTensorV2, tqdm, validate
)
import gc


# ----------------------
# 2. 增强配置参数
# ----------------------
class GrayConfig:
    # 自动查找最佳模型
    @staticmethod
    def find_best_model(model_dir="visionModel"):
        model_files = glob.glob(os.path.join(model_dir, "best_model_epoch*.pth"))
        best_acc = -1
        best_path = ""
        
        for file_path in model_files:
            match = re.search(r"acc([\d.]+)\.pth", file_path)
            if match:
                acc = float(match.group(1))
                if acc > best_acc:
                    best_acc = acc
                    best_path = file_path
        
        if not best_path:
            raise FileNotFoundError(f"No valid model found in {model_dir}")
        print(f"Selected best model: {best_path} (acc={best_acc:.1f}%)")
        return best_path

    # 训练参数
    config = {
        "data_dir": "archive/data_relabeled_balanced_1x/train",
        "model_dir": "visionModel_gray",
        "pretrained": find_best_model.__func__(),  # 自动选择最佳模型
        "batch_size": 128,
        "grad_accum_steps": 2,
        "num_epochs": 50,
        "learning_rate": 5e-5,
        "num_classes": 8,
        "input_size": 224,
        "valid_ratio": 0.15,
        "num_workers": 4 if os.cpu_count() > 4 else 0,
        "max_grad_norm": 0.5,
        "mixup_alpha": 0.3,
        "label_smoothing": 0.05,
        "freeze_stages": 3,
        "unfreeze_layers": ["features.5", "features.6", "classifier"],
        "lr_decay": 0.95,
        "es_patience": 5
    }

# ----------------------
# 3. 数据集调整
# ----------------------
class GrayDataset(OptimizedDataset):
    """灰度专用数据集"""
    def __init__(self, root_dir, transform=None):
        self.gray_classes = [d for d in os.listdir(root_dir) 
                            if d.endswith('_gray') and os.path.isdir(os.path.join(root_dir, d))]
        super().__init__(root_dir, transform)
        
        # 重建class_to_idx映射
        self.classes = [c.replace('_gray', '') for c in self.gray_classes]
        self.class_to_idx = {cls:i for i, cls in enumerate(self.classes)}
        
        # 重建样本列表
        self.samples = []
        for gray_cls in self.gray_classes:
            cls_dir = os.path.join(root_dir, gray_cls)
            label = self.class_to_idx[gray_cls.replace('_gray', '')]
            self.samples.extend([
                (os.path.join(cls_dir, fname), label)
                for fname in os.listdir(cls_dir)
                if fname.lower().endswith(('png', 'jpg', 'jpeg'))
            ])

    def __getitem__(self, idx):
        img, label = super().__getitem__(idx)
        return img, label

# ----------------------
# 4. 数据增强优化
# ----------------------
def get_gray_transforms():
    """灰度专用增强策略"""
    return {
        'train': A.Compose([
            A.Resize(GrayConfig.config['input_size'], GrayConfig.config['input_size']),
            A.HorizontalFlip(p=0.5),
            A.Affine(
                scale=(0.8, 1.2),
                translate_percent=(-0.2, 0.2),
                rotate=(-30, 30),
                shear=(-15, 15),
                p=0.7
            ),
            A.OneOf([
                A.GaussianBlur(blur_limit=(3, 7)),  # 模糊增强
                A.MotionBlur(blur_limit=7),        # 运动模糊
                A.GlassBlur(sigma=0.7, max_delta=2) # 玻璃模糊
            ], p=0.3),
            A.RandomGamma(gamma_limit=(80, 120), p=0.3),
            A.GaussNoise(
                std=20.0,  # 标准差范围改为固定值
                mean=0,    # 均值设为0
                per_channel=True,  # 每个通道独立添加噪声
                p=0.2
            ),
            A.CoarseDropout(
                num_holes_range=(3, 6),           # 每次随机遮挡 3-6 个区域
                hole_height_range=(0.1, 0.2),     # 高度为图像高度的 10%-20%
                hole_width_range=(0.1, 0.2),      # 宽度为图像宽度的 10%-20%
                fill=0,                     # 使用黑色填充
                p=0.3                             # 应用概率为 0.3
            ),
            A.Normalize(mean=[0.5], std=[0.5]),  # 单通道标准化
            ToTensorV2()
        ]),
        'valid': A.Compose([
            A.Resize(GrayConfig.config['input_size'], GrayConfig.config['input_size']),
            A.Normalize(mean=[0.5], std=[0.5]),
            ToTensorV2()
        ])
    }

# ----------------------
# 5. 模型适配优化
# ----------------------
class GrayModelAdapter:
    """灰度模型适配器"""
    @staticmethod
    def convert_model(model, checkpoint_path):
        # 加载预训练权重
        checkpoint = torch.load(checkpoint_path, map_location='cpu')
        model.load_state_dict(checkpoint['model_state_dict'])
        
        # 调整第一层卷积
        original_conv = model.features[0][0]
        new_conv = nn.Conv2d(
            in_channels=1,
            out_channels=original_conv.out_channels,
            kernel_size=original_conv.kernel_size,
            stride=original_conv.stride,
            padding=original_conv.padding,
            bias=original_conv.bias is not None
        )
        
        # 权重转换策略
        with torch.no_grad():
            if original_conv.weight.shape[1] == 3:
                # RGB转灰度：使用加权平均
                weights = original_conv.weight.data
                gray_weights = (0.2989 * weights[:,0] + 0.5870 * weights[:,1] + 0.1140 * weights[:,2]).unsqueeze(1)
                new_conv.weight.data = gray_weights
            else:
                new_conv.weight.data = original_conv.weight.data.mean(dim=1, keepdim=True)
            
            if new_conv.bias is not None:
                new_conv.bias.data = original_conv.bias.data.clone()
        
        model.features[0][0] = new_conv
        return model

    @staticmethod
    def freeze_layers(model):
        # 冻结指定层
        for name, param in model.named_parameters():
            if not any([layer in name for layer in GrayConfig.config['unfreeze_layers']]):
                param.requires_grad = False
            else:
                param.requires_grad = True
        return model

# ----------------------
# 6. 训练流程优化
# ----------------------
def train_gray():
    # 初始化配置
    seed_everything(42)
    cfg = GrayConfig.config
    os.makedirs(cfg['model_dir'], exist_ok=True)

    # 数据加载
    transforms = get_gray_transforms()
    full_dataset = GrayDataset(cfg['data_dir'], transform=transforms['train'])
    train_size = int((1 - cfg['valid_ratio']) * len(full_dataset))
    train_dataset, valid_dataset = torch.utils.data.random_split(full_dataset, [train_size, len(full_dataset)-train_size])
    valid_dataset.dataset.transform = transforms['valid']

    train_loader = DataLoader(
        train_dataset,
        batch_size=cfg['batch_size'],
        shuffle=True,
        num_workers=cfg['num_workers'],
        pin_memory=True,
        persistent_workers=True if cfg['num_workers']>0 else False
    )

    # 模型初始化
    model = models.efficientnet_b0()
    model.classifier = nn.Sequential(
        nn.Dropout(0.3),
        nn.Linear(model.classifier[1].in_features, cfg['num_classes'])
    )
    model = GrayModelAdapter.convert_model(model, cfg['pretrained'])
    model = GrayModelAdapter.freeze_layers(model)
    model = model.to(device)

    # 优化器配置
    optimizer = optim.AdamW(
        filter(lambda p: p.requires_grad, model.parameters()),
        lr=cfg['learning_rate'],
        weight_decay=1e-5
    )
    scheduler = optim.lr_scheduler.CosineAnnealingWarmRestarts(
        optimizer,
        T_0=len(train_loader)*2,
        T_mult=1,
        eta_min=1e-7
    )
    criterion = nn.CrossEntropyLoss(label_smoothing=cfg['label_smoothing'])
    scaler = GradScaler()

    # 早停机制
    best_acc = 0.0
    epochs_no_improve = 0
    
    # 记录训练历史
    history = {
        'train_loss': [],
        'train_acc': [],
        'valid_loss': [],
        'valid_acc': []
    }

    # 训练循环
    for epoch in range(cfg['num_epochs']):
        print(f"\nEpoch {epoch+1}/{cfg['num_epochs']}")
        
        # 训练阶段
        model.train()
        progress_bar = tqdm(train_loader, desc='Training', dynamic_ncols=True)
        optimizer.zero_grad()
        
        train_loss = 0.0
        correct = 0
        total = 0
        
        for step, (inputs, labels) in enumerate(progress_bar):
            inputs = inputs.to(device, non_blocking=True)
            labels = labels.to(device, non_blocking=True)

            # 混合精度训练
            with autocast(device_type='cuda' if torch.cuda.is_available() else 'cpu'):
                outputs = model(inputs)
                loss = criterion(outputs, labels)
                loss = loss / cfg['grad_accum_steps']

            scaler.scale(loss).backward()

            # 累积训练统计
            train_loss += loss.item() * cfg['grad_accum_steps']
            _, predicted = outputs.max(1)
            total += labels.size(0)
            correct += predicted.eq(labels).sum().item()
            train_acc = 100. * correct / total

            # 梯度累积
            if (step+1) % cfg['grad_accum_steps'] == 0:
                scaler.step(optimizer)
                scaler.update()
                optimizer.zero_grad()
                scheduler.step()

            # 更新进度条
            progress_bar.set_postfix({
                'loss': f"{loss.item()*cfg['grad_accum_steps']:.3f}",
                'acc': f"{train_acc:.1f}%",
                'lr': f"{optimizer.param_groups[0]['lr']:.1e}",
                'mem': f"{torch.cuda.memory_reserved(device)/1e9:.1f}G" if torch.cuda.is_available() else "N/A"
            })

        # 计算平均训练损失
        train_loss = train_loss / len(train_loader)
        
        # 验证阶段
        valid_loader = DataLoader(
            valid_dataset,
            batch_size=cfg['batch_size']*2,
            shuffle=False,
            num_workers=cfg['num_workers']
        )
        valid_loss, valid_acc = validate(model, valid_loader, criterion)
        
        # 记录历史
        history['train_loss'].append(train_loss)
        history['train_acc'].append(train_acc)
        history['valid_loss'].append(valid_loss)
        history['valid_acc'].append(valid_acc)
        
        # 打印验证结果
        print(f"Validation Loss: {valid_loss:.4f}, Validation Accuracy: {valid_acc:.1f}%")

        # 模型保存逻辑
        if valid_acc > best_acc:
            best_acc = valid_acc
            epochs_no_improve = 0
            save_path = os.path.join(cfg['model_dir'], f'gray_best_acc{best_acc:.1f}.pth')
            torch.save({
                'epoch': epoch,
                'model_state_dict': model.state_dict(),
                'optimizer_state_dict': optimizer.state_dict(),
                'best_acc': best_acc,
            }, save_path)
            print(f"New best model saved with acc {best_acc:.1f}%")
        else:
            epochs_no_improve += 1
            if epochs_no_improve >= cfg['es_patience']:
                print(f"Early stopping triggered after {epoch+1} epochs")
                break

        # 释放内存
        torch.cuda.empty_cache()
        gc.collect()

    print(f"Training completed. Best validation accuracy: {best_acc:.1f}%")
    
    # 绘制训练曲线
    plot_training_history(history, cfg['model_dir'])

def plot_training_history(history, save_dir):
    """绘制训练历史"""
    import matplotlib.pyplot as plt
    
    plt.figure(figsize=(16, 6))
    
    # 损失曲线
    plt.subplot(1, 2, 1)
    plt.plot(history['train_loss'], label='Train')
    plt.plot(history['valid_loss'], label='Valid')
    plt.title('Training/Validation Loss')
    plt.xlabel('Epoch')
    plt.ylabel('Loss')
    plt.legend()
    
    # 准确率曲线
    plt.subplot(1, 2, 2)
    plt.plot(history['train_acc'], label='Train')
    plt.plot(history['valid_acc'], label='Valid')
    plt.title('Training/Validation Accuracy')
    plt.xlabel('Epoch')
    plt.ylabel('Accuracy (%)')
    plt.legend()
    
    # 保存图像
    save_path = os.path.join(save_dir, 'gray_training_metrics.png')
    plt.savefig(save_path, dpi=300, bbox_inches='tight')
    plt.close()
    print(f"Saved training metrics plot to: {save_path}")

if __name__ == "__main__":
    train_gray()

  A.GaussNoise(var_limit=(10.0, 50.0), p=0.2),
  A.CoarseDropout(


Using GPU for training.
Selected best model: visionModel\best_model_epoch30_acc74.1.pth (acc=74.1%)


预加载图像: 100%|██████████| 29217/29217 [00:24<00:00, 1206.94it/s]



Epoch 1/50


Training:   0%|          | 0/389 [00:00<?, ?it/s]

In [None]:
# train.py
import os
import re
import glob
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader, random_split
from torchvision import models, transforms
from torch.amp import GradScaler, autocast
from PIL import Image
import albumentations as A
from albumentations.pytorch import ToTensorV2
import numpy as np
import matplotlib.pyplot as plt
from tqdm.auto import tqdm
import gc
import cv2

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# ----------------------
# 1. 系统配置
# ----------------------
class TrainingConfig:
    # 自动模型选择
    @staticmethod
    def select_model(model_dir="visionModel"):
        models = glob.glob(os.path.join(model_dir, "*.pth"))
        if not models:
            return None
        
        print("\n发现预训练模型:")
        for i, path in enumerate(models, 1):
            acc = re.search(r"acc([\d.]+)", path)
            acc = acc.group(1) if acc else "未知"
            print(f"[{i}] {os.path.basename(path)} (准确率: {acc}%)")
        
        print("[0] 从头开始训练")
        choice = int(input("请选择模型编号: "))
        return models[choice-1] if choice > 0 else None

    # 训练参数
    config = {
        "data_dir": "archive/data_relabeled_balanced_1x/train",
        "model_dir": "visionModel_enhanced",
        "batch_size": 128,
        "grad_accum_steps": 2,
        "num_epochs": 100,
        "learning_rate": 3e-5,
        "num_classes": 8,
        "input_size": 224,
        "valid_ratio": 0.15,
        "num_workers": 4 if os.cpu_count() > 4 else 0,
        "max_grad_norm": 0.5,
        "label_smoothing": 0.1,
        "freeze_blocks": 2,
        "es_patience": 8
    }

# ----------------------
# 2. 混合数据集
# ----------------------
class HybridDataset(Dataset):
    """支持彩色和灰度混合训练的数据集"""
    def __init__(self, root_dir, transform=None):
        self.root_dir = root_dir
        self.classes = self._get_classes()
        self.class_to_idx = {cls:i for i, cls in enumerate(self.classes)}
        self.samples = []
        self.transform = transform
        
        # 加载所有图像路径（包含彩色和灰度）
        for cls in self.classes:
            # 彩色目录
            color_dir = os.path.join(root_dir, cls)
            if os.path.exists(color_dir):
                self.samples += [
                    (os.path.join(color_dir, f), self.class_to_idx[cls])
                    for f in os.listdir(color_dir) 
                    if f.lower().endswith(('png','jpg','jpeg'))
                ]

            # 灰度目录
            gray_dir = os.path.join(root_dir, f"{cls}_gray")
            if os.path.exists(gray_dir):
                self.samples += [
                    (os.path.join(gray_dir, f), self.class_to_idx[cls])
                    for f in os.listdir(gray_dir)
                    if f.lower().endswith(('png','jpg','jpeg'))
                ]

    def _get_classes(self):
        """获取有效类别列表（过滤灰度目录）"""
        all_dirs = [d for d in os.listdir(self.root_dir) 
                   if os.path.isdir(os.path.join(self.root_dir, d))]
        return sorted({d.replace('_gray', '') for d in all_dirs})

    def __len__(self):
        return len(self.samples)

    def __getitem__(self, idx):
        path, label = self.samples[idx]
        img = Image.open(path).convert('RGB')  # 统一转为三通道
        
        if self.transform:
            img = np.array(img)
            img = self.transform(image=img)['image']
        return img, label

# ----------------------
# 3. 数据增强
# ----------------------
def get_transforms():
    return {
        'train': A.Compose([
            A.Resize(224, 224),
            A.HorizontalFlip(p=0.5),
            A.OneOf([
                A.RandomGamma(gamma_limit=(80, 120)),
                A.RandomBrightnessContrast(brightness_limit=0.2, contrast_limit=0.2),
            ], p=0.5),
            A.ShiftScaleRotate(
                shift_limit=0.1, 
                scale_limit=0.2, 
                rotate_limit=20,
                border_mode=cv2.BORDER_REFLECT_101,
                p=0.5
            ),
            A.CoarseDropout(
                max_holes=5,
                max_height=0.2, 
                max_width=0.2,
                p=0.3
            ),
            A.Normalize(
                mean=[0.485, 0.456, 0.406],
                std=[0.229, 0.224, 0.225]
            ),
            ToTensorV2()
        ]),
        'valid': A.Compose([
            A.Resize(224, 224),
            A.Normalize(
                mean=[0.485, 0.456, 0.406],
                std=[0.229, 0.224, 0.225]
            ),
            ToTensorV2()
        ])
    }

# ----------------------
# 4. 模型工具
# ----------------------
def create_model(num_classes, pretrained=None):
    """创建/加载模型"""
    model = models.efficientnet_b0(pretrained=False)
    
    # 修改分类层
    in_features = model.classifier[1].in_features
    model.classifier = nn.Sequential(
        nn.Dropout(0.3),
        nn.Linear(in_features, num_classes)
    )
    
    # 加载预训练权重
    if pretrained:
        print(f"加载预训练模型: {pretrained}")
        checkpoint = torch.load(pretrained)
        
        # 自适应调整第一层卷积
        state_dict = checkpoint['model_state_dict']
        current_conv = model.features[0][0]
        
        # 处理输入通道不匹配的情况
        if state_dict['features.0.0.weight'].shape[1] != current_conv.in_channels:
            print("调整输入通道...")
            orig_weight = state_dict['features.0.0.weight']
            if current_conv.in_channels == 3:
                # 灰度转彩色：复制单通道权重到三通道
                new_weight = orig_weight.repeat(1,3,1,1) / 3.0
            else:
                # 彩色转灰度：取三通道均值
                new_weight = orig_weight.mean(dim=1, keepdim=True)
            
            state_dict['features.0.0.weight'] = new_weight
        
        model.load_state_dict(state_dict, strict=False)
    
    # 冻结部分层
    for param in model.features[:TrainingConfig.config['freeze_blocks']].parameters():
        param.requires_grad = False
        
    return model.to(device)

# ----------------------
# 5. 训练引擎
# ----------------------
def train():
    # 初始化配置
    config = TrainingConfig.config
    os.makedirs(config['model_dir'], exist_ok=True)
    
    # 选择设备
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print(f"使用设备: {device}")
    
    # 数据加载
    transforms = get_transforms()
    dataset = HybridDataset(config['data_dir'], transform=transforms['train'])
    train_size = int((1 - config['valid_ratio']) * len(dataset))
    train_set, valid_set = random_split(dataset, [train_size, len(dataset)-train_size])
    valid_set.dataset.transform = transforms['valid']
    
    train_loader = DataLoader(
        train_set,
        batch_size=config['batch_size'],
        shuffle=True,
        num_workers=config['num_workers'],
        pin_memory=True
    )
    
    # 模型初始化
    pretrained = TrainingConfig.select_model()
    model = create_model(config['num_classes'], pretrained)
    
    # 优化器配置
    optimizer = optim.AdamW(
        filter(lambda p: p.requires_grad, model.parameters()),
        lr=config['learning_rate'],
        weight_decay=1e-5
    )
    
    # 加载优化器状态（如果存在）
    if pretrained:
        checkpoint = torch.load(pretrained)
        if 'optimizer_state_dict' in checkpoint:
            optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
            print("已加载优化器状态")
    
    scheduler = optim.lr_scheduler.ReduceLROnPlateau(
        optimizer, 
        mode='max', 
        factor=0.5, 
        patience=3,
        verbose=True
    )
    
    criterion = nn.CrossEntropyLoss(label_smoothing=config['label_smoothing'])
    scaler = GradScaler()
    
    # 训练循环
    best_acc = 0.0
    history = {'train_loss': [], 'train_acc': [], 'valid_loss': [], 'valid_acc': []}
    
    for epoch in range(config['num_epochs']):
        print(f"\nEpoch {epoch+1}/{config['num_epochs']}")
        
        # 训练阶段
        model.train()
        total_loss = 0.0
        correct = 0
        total = 0
        
        progress_bar = tqdm(train_loader, desc='Training')
        for step, (inputs, labels) in enumerate(progress_bar):
            inputs = inputs.to(device)
            labels = labels.to(device)
            
            with autocast(device_type=device.type):
                outputs = model(inputs)
                loss = criterion(outputs, labels)
            
            scaler.scale(loss).backward()
            
            if (step+1) % config['grad_accum_steps'] == 0:
                scaler.step(optimizer)
                scaler.update()
                optimizer.zero_grad()
            
            # 统计指标
            total_loss += loss.item()
            _, predicted = outputs.max(1)
            correct += predicted.eq(labels).sum().item()
            total += labels.size(0)
            acc = 100. * correct / total
            
            progress_bar.set_postfix({
                'loss': f"{loss.item():.3f}",
                'acc': f"{acc:.1f}%",
                'lr': f"{optimizer.param_groups[0]['lr']:.1e}"
            })
        
        # 验证阶段
        valid_loss, valid_acc = validate(model, DataLoader(
            valid_set, 
            batch_size=config['batch_size']*2,
            shuffle=False
        ), criterion)
        
        # 记录历史
        history['train_loss'].append(total_loss/len(train_loader))
        history['train_acc'].append(acc)
        history['valid_loss'].append(valid_loss)
        history['valid_acc'].append(valid_acc)
        
        # 学习率调整
        scheduler.step(valid_acc)
        
        # 保存最佳模型
        if valid_acc > best_acc:
            best_acc = valid_acc
            torch.save({
                'epoch': epoch,
                'model_state_dict': model.state_dict(),
                'optimizer_state_dict': optimizer.state_dict(),
                'best_acc': best_acc,
            }, os.path.join(config['model_dir'], f'enhanced_acc{best_acc:.1f}.pth'))
            print(f"保存最佳模型，准确率: {best_acc:.1f}%")
        
        # 早停检测
        if (epoch - np.argmax(history['valid_acc'])) >= config['es_patience']:
            print("早停触发")
            break
    
    # 保存训练曲线
    plt.figure(figsize=(12,5))
    plt.subplot(121)
    plt.plot(history['train_loss'], label='Train')
    plt.plot(history['valid_loss'], label='Valid')
    plt.title('Loss Curve')
    plt.legend()
    
    plt.subplot(122)
    plt.plot(history['train_acc'], label='Train')
    plt.plot(history['valid_acc'], label='Valid')
    plt.title('Accuracy Curve')
    plt.legend()
    
    plt.savefig(os.path.join(config['model_dir'], 'training_curves.png'))
    plt.close()

def validate(model, loader, criterion):
    model.eval()
    total_loss = 0.0
    correct = 0
    total = 0
    
    with torch.no_grad():
        for inputs, labels in tqdm(loader, desc='Validating'):
            inputs = inputs.to(device)
            labels = labels.to(device)
            
            outputs = model(inputs)
            loss = criterion(outputs, labels)
            
            total_loss += loss.item()
            _, predicted = outputs.max(1)
            correct += predicted.eq(labels).sum().item()
            total += labels.size(0)
    
    return total_loss/len(loader), 100.*correct/total

if __name__ == "__main__":
    train()