# GPT 图像分类训练与推理

本 Notebook 使用 `archive/train` 与 `archive/valid` 中的 YOLO 格式钢板缺陷图像，构建端到端的分类训练与推理流程。

主要步骤：
- 读取 YOLO 标注并整理成分类数据索引
- 构建 PyTorch Dataset 与 DataLoader
- 微调预训练 ResNet50 分类模型
- 保存最优权重并提供单张图片推理函数


## 0. 准备环境

首次运行前，请先执行下方单元格下载并解压数据集，并按需安装依赖（根据实际 CUDA 版本调整 PyTorch 安装命令）。数据集与依赖均已就绪的环境可直接跳过此步骤。


In [None]:
# Download the dataset archive from Kaggle, unpack it into the current
# directory, then delete the zip file.
!curl -L -o ./neu-yolo.zip https://www.kaggle.com/api/v1/datasets/download/zymzym/neu-yolo
!unzip -q neu-yolo.zip
!rm neu-yolo.zip
# Dependency installation: skip if the packages are already installed.
# !pip install --upgrade pip
# CPU example: !pip install torch torchvision torchaudio
# GPU example: !pip install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cu118
# !pip install numpy pandas matplotlib seaborn scikit-learn tqdm Pillow


## 1. 导入库


In [2]:
import os
import random
from pathlib import Path
from collections import Counter, defaultdict

import numpy as np
import pandas as pd
import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms, models
from torchvision.models import ResNet50_Weights
from PIL import Image

from sklearn.model_selection import train_test_split
from sklearn.metrics import classification_report, confusion_matrix

import matplotlib.pyplot as plt
import seaborn as sns
from tqdm.auto import tqdm
from IPython.display import display


  from .autonotebook import tqdm as notebook_tqdm


## 2. 基础配置与随机种子


In [3]:
# Seed every random number generator used below (Python, NumPy, PyTorch CPU
# and, when present, all CUDA devices).
SEED = 42
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)
if torch.cuda.is_available():
    torch.cuda.manual_seed_all(SEED)

# NOTE(review): cuDNN benchmark mode may select different kernels between
# runs, which can undercut the reproducibility the seeds above aim for.
# Confirm this speed/determinism trade-off is intended.
torch.backends.cudnn.benchmark = True

DEVICE = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f'Using device: {DEVICE}')

# Data root: the working directory, or its `archive/` sub-folder when `train/`
# is not found directly under the working directory.
BASE_DIR = Path.cwd()
if not (BASE_DIR / 'train').exists():
    candidate = BASE_DIR / 'archive'
    if candidate.exists():
        BASE_DIR = candidate

# Each split is nested one level deep (train/train, valid/valid).
TRAIN_ROOT = BASE_DIR / 'train' / 'train'
VALID_ROOT = BASE_DIR / 'valid' / 'valid'
# Fail fast when a split lacks its images/ or labels/ folder; otherwise report
# file counts. The image count includes every entry in images/, not only files
# with an image suffix.
for name, root in {'train': TRAIN_ROOT, 'valid': VALID_ROOT}.items():
    images_dir = root / 'images'
    labels_dir = root / 'labels'
    if not images_dir.exists():
        raise FileNotFoundError(f'未找到 {images_dir}，请确认 notebook 的工作路径。')
    if not labels_dir.exists():
        raise FileNotFoundError(f'未找到 {labels_dir}，请确认 YOLO 标签是否存在。')
    num_images = len(list(images_dir.glob('*')))
    num_labels = len(list(labels_dir.glob('*.txt')))
    print(f"{name:<5} images: {num_images} | labels: {num_labels}")

# The best model weights are written to <BASE_DIR>/checkpoints/resnet50_best.pth.
CHECKPOINT_DIR = BASE_DIR / 'checkpoints'
CHECKPOINT_DIR.mkdir(exist_ok=True)
BEST_MODEL_PATH = CHECKPOINT_DIR / 'resnet50_best.pth'
print(f'Checkpoint dir: {CHECKPOINT_DIR}')


Using device: cpu
train images: 1770 | labels: 1770
valid images: 30 | labels: 30
Checkpoint dir: /Users/zhaizeyu/Documents/vscode/archive/checkpoints


## 3. 读取 YOLO 标注并构建索引

遍历图像文件与对应的 YOLO 标签，记录图片路径、类别名称以及标签中出现的类别 ID。


In [4]:
# Lower-case file suffixes that collect_records accepts as images.
IMAGE_EXTENSIONS = {'.jpg', '.jpeg', '.png', '.bmp'}

def extract_class_name(image_path: Path) -> str:
    """Return the class name encoded in an image file name.

    The class name is the file stem with its last ``_``-separated token
    removed (``crazing_104.jpg`` -> ``crazing``). A stem without any
    underscore is returned unchanged.
    """
    stem = image_path.stem
    prefix, separator, _ = stem.rpartition('_')
    # rpartition yields an empty prefix when there is no underscore at all;
    # fall back to the whole stem in that case.
    return prefix if separator else stem

def parse_label_file(label_path: Path) -> list:
    """Return the class id found at the start of every row of a YOLO label file.

    Blank rows and rows whose first token is not an integer are skipped.
    A missing file yields an empty list.
    """
    class_ids = []
    if label_path.exists():
        with open(label_path, 'r') as handle:
            for raw_row in handle:
                tokens = raw_row.split()
                if not tokens:
                    # Empty or whitespace-only row.
                    continue
                try:
                    class_ids.append(int(tokens[0]))
                except ValueError:
                    # First token is not an integer class id; ignore the row.
                    continue
    return class_ids

def collect_records(split_name: str, split_root: Path) -> list:
    """Build one record dict per image found under ``split_root / 'images'``.

    Files are visited in sorted order and kept only when their lower-cased
    suffix is in ``IMAGE_EXTENSIONS``. Each record holds the split name, the
    image path, the class name derived from the file name, the YOLO class ids
    read from the matching ``labels/<stem>.txt`` file, and that label path.
    """
    image_dir = split_root / 'images'
    label_dir = split_root / 'labels'

    def build_record(image_file: Path) -> dict:
        annotation_file = label_dir / f'{image_file.stem}.txt'
        ids = parse_label_file(annotation_file)
        return {
            'split': split_name,
            'image_path': image_file,
            'label_name': extract_class_name(image_file),
            'yolo_ids': ids,
            'label_path': annotation_file,
        }

    return [
        build_record(image_file)
        for image_file in sorted(image_dir.glob('*'))
        if image_file.suffix.lower() in IMAGE_EXTENSIONS
    ]

# Index both on-disk splits and pool them into a single list of records.
train_records_raw = collect_records('train', TRAIN_ROOT)
valid_records_raw = collect_records('valid', VALID_ROOT)
all_records = train_records_raw + valid_records_raw
print(f'Train images: {len(train_records_raw)} | Valid images: {len(valid_records_raw)} | Total: {len(all_records)}')

# Records whose label file lists more than one distinct YOLO class id. For
# these the file-name prefix (label_name) is what serves as the class label.
multi_label_records = [rec for rec in all_records if len(set(rec['yolo_ids'])) > 1]
print(f'Multi-class YOLO annotations: {len(multi_label_records)} (文件名前缀将作为分类标签)')
if multi_label_records:
    # Show one example of such an image together with its raw id list.
    sample = multi_label_records[0]
    print(f"示例: {sample['image_path'].name} -> {sample['yolo_ids']}")


Train images: 1770 | Valid images: 30 | Total: 1800
Multi-class YOLO annotations: 123 (文件名前缀将作为分类标签)
示例: crazing_104.jpg -> [0, 0, 0, 2]


## 4. 类别统计

依据文件名前缀统计类别分布，并查看 YOLO 标签中出现的主次类别，用于确认映射关系。


In [None]:
# For every class name (file-name prefix), count how often each YOLO class id
# appears in the label files of that class's images.
# NOTE: a class only gets an entry once at least one id has been read for it,
# so a class whose images carry no annotations is absent from the summary and
# the `else` branch below is purely defensive.
name_to_id_counts = defaultdict(Counter)
for rec in all_records:
    for cls_id in rec['yolo_ids']:
        name_to_id_counts[rec['label_name']][cls_id] += 1

# Count images per class in a single pass instead of rescanning all_records
# once for every class inside the loop below.
samples_per_class = Counter(rec['label_name'] for rec in all_records)

# One summary row per class: the most frequent YOLO id, the share of
# annotations carrying that id, and the number of images of the class.
summary_rows = []
for name, counter in name_to_id_counts.items():
    total = sum(counter.values())
    if total:
        top_id, top_count = counter.most_common(1)[0]
        top_ratio = top_count / total
    else:
        top_id, top_ratio = None, 0.0
    sample_count = samples_per_class[name]
    summary_rows.append({
        'class_name': name,
        'preferred_yolo_id': top_id,
        'top_id_ratio': round(top_ratio, 4),
        'samples': sample_count
    })
summary_df = pd.DataFrame(summary_rows).sort_values('preferred_yolo_id')
display(summary_df)

# Image counts per (split, class) as a split-by-class table.
counts_df = pd.DataFrame({
    'split': [rec['split'] for rec in all_records],
    'class_name': [rec['label_name'] for rec in all_records]
})
distribution = counts_df.groupby(['split', 'class_name']).size().unstack(fill_value=0)
display(distribution)


## 5. 构建训练 / 验证 / 测试划分

将所有样本按 7:2:1 比例做分层划分，确保各类别在不同子集中的占比接近。


In [None]:
# Order the classes by their preferred YOLO id so the classifier indices
# follow the YOLO id order.
id_name_pairs = []
for _, row in summary_df.iterrows():
    # pd.notna rejects both None and NaN: a missing id stored in a numeric
    # DataFrame column comes back as NaN, which `is not None` would let
    # through and int() would then fail on.
    if pd.notna(row['preferred_yolo_id']):
        id_name_pairs.append((int(row['preferred_yolo_id']), row['class_name']))
id_name_pairs = sorted(id_name_pairs, key=lambda x: x[0])
class_names = [name for _, name in id_name_pairs]
class_to_idx = {name: idx for idx, name in enumerate(class_names)}
idx_to_class = {idx: name for name, idx in class_to_idx.items()}
print(f'Class order: {class_names}')

# Train / validation / test proportions; they must add up to 1.
ratios = (0.7, 0.2, 0.1)
train_ratio, val_ratio, test_ratio = ratios
# Raise explicitly rather than assert, so the check survives `python -O`.
if abs(sum(ratios) - 1.0) >= 1e-6:
    raise ValueError(f'Split ratios must sum to 1.0, got {ratios}')

# Stage 1: carve the training share off, stratified by class label.
indices = np.arange(len(all_records))
labels = np.array([class_to_idx[rec['label_name']] for rec in all_records])
train_idx, temp_idx = train_test_split(
    indices,
    test_size=val_ratio + test_ratio,
    stratify=labels,
    random_state=SEED,
)
# Stage 2: divide the remainder into validation and test, again stratified.
# val_fraction is the validation share *within* the remainder.
temp_labels = labels[temp_idx]
val_fraction = val_ratio / (val_ratio + test_ratio)
val_idx, test_idx = train_test_split(
    temp_idx,
    test_size=1 - val_fraction,
    stratify=temp_labels,
    random_state=SEED,
)

def select_records(base, idxs):
    """Return a list with the items of ``base`` at positions ``idxs``, in that order."""
    return list(map(base.__getitem__, idxs))

# Materialise the three subsets from the index arrays produced by the split.
train_records = select_records(all_records, train_idx)
val_records = select_records(all_records, val_idx)
test_records = select_records(all_records, test_idx)

def attach_label_idx(records):
    """Add a ``'label_idx'`` entry to every record, in place.

    The index is looked up in the module-level ``class_to_idx`` mapping using
    the record's ``'label_name'``.
    """
    for entry in records:
        entry.update(label_idx=class_to_idx[entry['label_name']])

# Tag every record of each subset with its integer class index (in place).
attach_label_idx(train_records)
attach_label_idx(val_records)
attach_label_idx(test_records)

# Report subset sizes and the per-class image counts of each subset.
print(f"Train: {len(train_records)} | Val: {len(val_records)} | Test: {len(test_records)}")
for split_name, records in [('Train', train_records), ('Val', val_records), ('Test', test_records)]:
    counter = Counter([rec['label_name'] for rec in records])
    print(f"{split_name} per class: {dict(counter)}")


## 6. Dataset 与 DataLoader


In [7]:
class YoloClassificationDataset(Dataset):
    """Classification dataset built from record dicts of YOLO-annotated images.

    Each record must provide ``'image_path'`` and ``'label_idx'``. Items are
    returned as ``(image, label)`` where the image has been converted to RGB
    and, if a transform was given, passed through it.
    """

    def __init__(self, records, transform=None):
        # Keep a private list copy so later changes to the caller's container
        # do not alter the dataset length or order.
        self.transform = transform
        self.records = [*records]

    def __len__(self):
        return len(self.records)

    def __getitem__(self, idx):
        entry = self.records[idx]
        rgb_image = Image.open(entry['image_path']).convert('RGB')
        sample = rgb_image if self.transform is None else self.transform(rgb_image)
        return sample, entry['label_idx']

# Input resolution (images are resized to IMG_SIZE x IMG_SIZE) and loader settings.
IMG_SIZE = 224
BATCH_SIZE = 32
NUM_WORKERS = 0  # multi-process loading tends to trigger pickling errors in notebooks
# NOTE(review): with NUM_WORKERS == 0 this is always False, even on CUDA.
# Confirm that tying pinned memory to the worker count is intended.
PIN_MEMORY = DEVICE.type == 'cuda' and NUM_WORKERS > 0

# Training pipeline: resize, random flips / rotation / colour jitter, then
# tensor conversion and per-channel normalisation.
train_transform = transforms.Compose([
    transforms.Resize((IMG_SIZE, IMG_SIZE)),
    transforms.RandomHorizontalFlip(p=0.5),
    transforms.RandomVerticalFlip(p=0.2),
    transforms.RandomRotation(15),
    transforms.ColorJitter(brightness=0.2, contrast=0.2, saturation=0.2),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
])

# Evaluation pipeline: the same resize and normalisation, without any random
# augmentation.
eval_transform = transforms.Compose([
    transforms.Resize((IMG_SIZE, IMG_SIZE)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
])

# Only the training set uses the augmenting transform.
train_dataset = YoloClassificationDataset(train_records, transform=train_transform)
val_dataset = YoloClassificationDataset(val_records, transform=eval_transform)
test_dataset = YoloClassificationDataset(test_records, transform=eval_transform)

# Only the training loader shuffles; validation and test keep a fixed order.
train_loader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True, num_workers=NUM_WORKERS, pin_memory=PIN_MEMORY)
val_loader = DataLoader(val_dataset, batch_size=BATCH_SIZE, shuffle=False, num_workers=NUM_WORKERS, pin_memory=PIN_MEMORY)
test_loader = DataLoader(test_dataset, batch_size=BATCH_SIZE, shuffle=False, num_workers=NUM_WORKERS, pin_memory=PIN_MEMORY)

print(f'Dataloaders ready -> batches: train={len(train_loader)}, val={len(val_loader)}, test={len(test_loader)}')


Dataloaders ready -> batches: train=40, val=12, test=6


## 7. 可视化样本


In [None]:
def show_samples(dataset, num_images=9):
    """Plot a grid of randomly chosen samples from ``dataset``.

    Args:
        dataset: Indexable dataset whose items are ``(image, label)``; the
            image must offer ``.numpy()`` and is treated as a channel-first
            array normalised with the mean/std used below.
        num_images: Upper bound on the number of samples to draw. Fewer are
            shown when the dataset is smaller.

    Nothing is drawn when there is nothing to show (empty dataset or a
    non-positive ``num_images``).
    """
    count = min(num_images, len(dataset))
    if count <= 0:
        return
    # Near-square grid: floor(sqrt(count)) columns, enough rows for the rest.
    cols = max(1, int(count ** 0.5))
    rows = (count + cols - 1) // cols
    # squeeze=False always yields a 2-D array of Axes. Without it a 1x1 grid
    # returns a bare Axes object that has no .flatten().
    _, axes = plt.subplots(rows, cols, figsize=(cols * 3, rows * 3), squeeze=False)
    axes = axes.flatten()
    for ax in axes:
        ax.axis('off')
    chosen_indices = np.random.choice(len(dataset), size=count, replace=False)
    mean = np.array([0.485, 0.456, 0.406])
    std = np.array([0.229, 0.224, 0.225])
    for ax, idx in zip(axes, chosen_indices):
        image, label = dataset[idx]
        # Channel-first -> channel-last, then undo the normalisation so the
        # values fall back into the displayable [0, 1] range.
        image_np = image.numpy().transpose(1, 2, 0)
        image_np = image_np * std + mean
        image_np = np.clip(image_np, 0, 1)
        ax.imshow(image_np)
        ax.set_title(idx_to_class[label])
    plt.tight_layout()

# Preview up to nine samples from the training dataset (augmentations applied).
show_samples(train_dataset, num_images=9)


## 8. 创建模型与优化器


In [None]:
def create_model(num_classes: int, dropout: float = 0.3) -> nn.Module:
    """Build a ResNet50 with IMAGENET1K_V2 weights and a fresh classification head.

    The original ``fc`` layer is replaced by dropout followed by a linear
    layer with ``num_classes`` outputs.
    """
    backbone = models.resnet50(weights=ResNet50_Weights.IMAGENET1K_V2)
    head = nn.Sequential(
        nn.Dropout(dropout),
        nn.Linear(backbone.fc.in_features, num_classes),
    )
    backbone.fc = head
    return backbone

# Hyperparameters for fine-tuning.
NUM_CLASSES = len(class_names)
LEARNING_RATE = 1e-3
WEIGHT_DECAY = 1e-4
NUM_EPOCHS = 2

# All model parameters are handed to the optimizer (no layers are frozen
# here). The cosine schedule spans exactly NUM_EPOCHS scheduler steps.
model = create_model(NUM_CLASSES).to(DEVICE)
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.AdamW(model.parameters(), lr=LEARNING_RATE, weight_decay=WEIGHT_DECAY)
scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=NUM_EPOCHS)
# Mixed precision is requested only when running on a CUDA device.
USE_AMP = DEVICE.type == 'cuda'
print(model.fc)


## 9. 训练与验证函数


In [10]:
def train_one_epoch(model, loader, criterion, optimizer, device, scaler=None):
    """Train ``model`` for one pass over ``loader``.

    Args:
        model: Network to optimise; switched to training mode.
        loader: Iterable of ``(images, labels)`` batches.
        criterion: Loss callable taking ``(outputs, labels)``.
        optimizer: Optimizer that updates ``model``'s parameters.
        device: Device the batches are moved to.
        scaler: Optional gradient scaler. When given, the forward pass runs
            under CUDA autocast and the backward/step go through the scaler.

    Returns:
        ``(mean_loss, accuracy)`` over all samples seen, or ``(0.0, 0.0)``
        when the loader yields no samples.
    """
    model.train()
    running_loss = 0.0
    running_corrects = 0
    total = 0
    amp_enabled = scaler is not None
    progress = tqdm(loader, desc='Train', leave=False)
    for images, labels in progress:
        images = images.to(device, non_blocking=True)
        labels = labels.to(device, non_blocking=True)
        optimizer.zero_grad()
        # torch.autocast('cuda', ...) is the non-deprecated spelling of
        # torch.cuda.amp.autocast(...); it is a no-op when disabled.
        with torch.autocast(device_type='cuda', enabled=amp_enabled):
            outputs = model(images)
            loss = criterion(outputs, labels)
        if scaler is not None:
            scaler.scale(loss).backward()
            scaler.step(optimizer)
            scaler.update()
        else:
            loss.backward()
            optimizer.step()
        preds = outputs.argmax(dim=1)
        # loss is a batch mean, so weight it by the batch size before summing.
        running_loss += loss.item() * images.size(0)
        running_corrects += (preds == labels).sum().item()
        total += labels.size(0)
        progress.set_postfix(loss=running_loss / max(total, 1), acc=running_corrects / max(total, 1))
    # Guard the final averages the same way as the progress display, so an
    # empty loader gives zeros instead of ZeroDivisionError.
    denominator = max(total, 1)
    epoch_loss = running_loss / denominator
    epoch_acc = running_corrects / denominator
    return epoch_loss, epoch_acc

def evaluate(model, loader, criterion, device, return_details=False):
    """Evaluate ``model`` on ``loader`` without tracking gradients.

    Args:
        model: Network to evaluate; switched to eval mode.
        loader: Iterable of ``(images, labels)`` batches.
        criterion: Loss callable taking ``(outputs, labels)``.
        device: Device the batches are moved to.
        return_details: When true, also return the per-sample labels and
            predictions as Python lists.

    Returns:
        ``(mean_loss, accuracy)``, or
        ``(mean_loss, accuracy, all_labels, all_preds)`` when
        ``return_details`` is true. Loss and accuracy are ``0.0`` when the
        loader yields no samples.
    """
    model.eval()
    running_loss = 0.0
    running_corrects = 0
    total = 0
    all_preds, all_labels = [], []
    with torch.no_grad():
        for images, labels in tqdm(loader, desc='Eval', leave=False):
            images = images.to(device, non_blocking=True)
            labels = labels.to(device, non_blocking=True)
            outputs = model(images)
            loss = criterion(outputs, labels)
            preds = outputs.argmax(dim=1)
            # loss is a batch mean, so weight it by the batch size.
            running_loss += loss.item() * images.size(0)
            running_corrects += (preds == labels).sum().item()
            total += labels.size(0)
            if return_details:
                all_preds.extend(preds.cpu().tolist())
                all_labels.extend(labels.cpu().tolist())
    # Avoid ZeroDivisionError when the loader produced no samples.
    denominator = max(total, 1)
    epoch_loss = running_loss / denominator
    epoch_acc = running_corrects / denominator
    if return_details:
        return epoch_loss, epoch_acc, all_labels, all_preds
    return epoch_loss, epoch_acc

def train_model(model, train_loader, val_loader, criterion, optimizer, scheduler, device, num_epochs, save_path, use_amp=True):
    """Train for ``num_epochs`` epochs and checkpoint the best validation model.

    After every epoch the model is validated. Whenever the validation accuracy
    improves (and always after the first epoch) a checkpoint is written to
    ``save_path``. The checkpoint also stores the module-level ``class_names``
    and ``IMG_SIZE`` together with the normalisation statistics.

    Args:
        model: Network to train.
        train_loader: Batches used for optimisation.
        val_loader: Batches used for model selection.
        criterion: Loss callable taking ``(outputs, labels)``.
        optimizer: Optimizer for ``model``.
        scheduler: Optional LR scheduler, stepped once per epoch.
        device: Device the batches are moved to.
        num_epochs: Number of epochs to run.
        save_path: Destination of the best checkpoint.
        use_amp: Request mixed precision; only honoured on CUDA devices.

    Returns:
        Dict with per-epoch lists ``train_loss``, ``train_acc``, ``val_loss``
        and ``val_acc``.
    """
    history = {'train_loss': [], 'train_acc': [], 'val_loss': [], 'val_acc': []}
    # Gradient scaling only makes sense on CUDA. Skip it elsewhere so the
    # default use_amp=True does not build a scaler on CPU.
    amp_enabled = bool(use_amp) and torch.device(device).type == 'cuda'
    scaler = torch.cuda.amp.GradScaler(enabled=True) if amp_enabled else None
    best_acc = 0.0
    # Tracks whether any checkpoint exists yet, so a run whose validation
    # accuracy never rises above 0.0 still leaves a loadable file behind.
    has_checkpoint = False
    for epoch in range(1, num_epochs + 1):
        print(f'Epoch {epoch}/{num_epochs}')
        train_loss, train_acc = train_one_epoch(model, train_loader, criterion, optimizer, device, scaler=scaler)
        val_loss, val_acc = evaluate(model, val_loader, criterion, device)
        history['train_loss'].append(train_loss)
        history['train_acc'].append(train_acc)
        history['val_loss'].append(val_loss)
        history['val_acc'].append(val_acc)
        if scheduler is not None:
            scheduler.step()
        print(f"Train loss: {train_loss:.4f} | acc: {train_acc:.4f}")
        print(f"Val   loss: {val_loss:.4f} | acc: {val_acc:.4f}")
        if not has_checkpoint or val_acc > best_acc:
            best_acc = val_acc
            has_checkpoint = True
            checkpoint = {
                'epoch': epoch,
                'model_state': model.state_dict(),
                'optimizer_state': optimizer.state_dict(),
                'val_acc': val_acc,
                'val_loss': val_loss,
                'class_names': class_names,
                'config': {
                    'img_size': IMG_SIZE,
                    'mean': [0.485, 0.456, 0.406],
                    'std': [0.229, 0.224, 0.225]
                }
            }
            torch.save(checkpoint, save_path)
            print(f"✓ Saved new best model to {save_path} (val_acc={val_acc:.4f})")
    print(f"Best validation accuracy: {best_acc:.4f}")
    return history


## 10. 开始训练


In [None]:
# Run the training loop with the objects configured above. The best checkpoint
# goes to BEST_MODEL_PATH and the per-epoch metrics are kept in `history`.
history = train_model(
    model=model,
    train_loader=train_loader,
    val_loader=val_loader,
    criterion=criterion,
    optimizer=optimizer,
    scheduler=scheduler,
    device=DEVICE,
    num_epochs=NUM_EPOCHS,
    save_path=BEST_MODEL_PATH,
    use_amp=USE_AMP
)


## 11. 可视化训练曲线


In [None]:
def plot_history(history_dict):
    """Draw training and validation curves side by side.

    The left panel shows the loss and the right panel the accuracy, both
    against the 1-based epoch number. ``history_dict`` must contain the keys
    ``train_loss``, ``val_loss``, ``train_acc`` and ``val_acc``.
    """
    epochs = range(1, len(history_dict['train_loss']) + 1)
    fig, axes = plt.subplots(1, 2, figsize=(12, 4))
    # (train key, validation key, panel title / y-axis label) per panel.
    panels = (
        ('train_loss', 'val_loss', 'Loss'),
        ('train_acc', 'val_acc', 'Accuracy'),
    )
    for ax, (train_key, val_key, heading) in zip(axes, panels):
        ax.plot(epochs, history_dict[train_key], label='Train')
        ax.plot(epochs, history_dict[val_key], label='Val')
        ax.set_title(heading)
        ax.set_xlabel('Epoch')
        ax.set_ylabel(heading)
        ax.legend()

    plt.tight_layout()

# Plot the loss and accuracy curves recorded during training.
plot_history(history)


## 12. 加载最优模型并在测试集评估


In [None]:
# Restore the weights stored in the best checkpoint written during training.
best_checkpoint = torch.load(BEST_MODEL_PATH, map_location=DEVICE)
print(f"Best epoch: {best_checkpoint['epoch']} | Val acc: {best_checkpoint['val_acc']:.4f}")
model.load_state_dict(best_checkpoint['model_state'])
# Score the test subset and keep the per-sample labels / predictions for the
# reports below.
test_loss, test_acc, y_true, y_pred = evaluate(model, test_loader, criterion, DEVICE, return_details=True)
print(f"Test loss: {test_loss:.4f} | acc: {test_acc:.4f}")

# Per-class report, with rows named after class_names.
report = classification_report(y_true, y_pred, target_names=class_names)
print(report)

# Confusion matrix rendered as an annotated heatmap (x axis labelled
# 'Predicted', y axis labelled 'True').
cm = confusion_matrix(y_true, y_pred)
plt.figure(figsize=(6, 5))
sns.heatmap(cm, annot=True, fmt='d', cmap='Blues', xticklabels=class_names, yticklabels=class_names)
plt.xlabel('Predicted')
plt.ylabel('True')
plt.title('Confusion Matrix')
plt.tight_layout()
plt.show()


## 13. 推理函数


In [14]:
def predict_image(image_path: Path, model: nn.Module, transform, device):
    """Classify a single image file.

    The image is opened, converted to RGB, passed through ``transform`` and
    fed to ``model`` (in eval mode, without gradients) as a batch of one.

    Returns:
        Dict with the RGB ``image``, the softmax ``probs`` as a NumPy array,
        the ``pred_class`` name looked up in the module-level
        ``idx_to_class``, and the ``confidence`` of that class as a float.
    """
    model.eval()
    image = Image.open(image_path).convert('RGB')
    batch = transform(image).unsqueeze(0).to(device)
    with torch.no_grad():
        scores = torch.softmax(model(batch), dim=1)
        probs = scores.squeeze(0).cpu().numpy()
    top_index = int(probs.argmax())
    return {
        'image': image,
        'probs': probs,
        'pred_class': idx_to_class[top_index],
        'confidence': float(probs[top_index]),
    }

def show_prediction(result, title: str = None):
    """Display a predicted image with its class and confidence as the caption.

    Args:
        result: Dict with the keys ``'image'``, ``'pred_class'`` and
            ``'confidence'`` (as returned by ``predict_image``).
        title: Optional text placed before the prediction in the caption.
    """
    plt.figure(figsize=(4, 4))
    plt.imshow(result['image'])
    plt.axis('off')
    prediction_text = f"Pred: {result['pred_class']} ({result['confidence']:.2%})"
    # Separate the optional title from the prediction; plain concatenation
    # rendered captions such as "Test SamplePred: ...".
    caption = f"{title} | {prediction_text}" if title else prediction_text
    plt.title(caption)
    plt.tight_layout()
    plt.show()


## 14. 推理示例


In [None]:
# Run the inference helpers on the first image of the test subset and print
# the full probability vector returned for it.
sample_path = Path(test_records[0]['image_path'])
print(f'Sample image: {sample_path}')
prediction = predict_image(sample_path, model, eval_transform, DEVICE)
show_prediction(prediction, title='Test Sample')
print('Probability vector:', prediction['probs'])
