# YOLOv3 完整实现

## 目录
1. 导入库和数据准备
2. Darknet骨干网络
3. YOLOv3检测头
4. 损失函数
5. 完整模型
6. 训练和推理

In [None]:
# ============================================================
# 第1部分：导入必要的库
# ============================================================

# PyTorch核心库 - 深度学习的基础框架
import torch

# nn模块包含神经网络的各种层(卷积、全连接等)
import torch.nn as nn

# F模块包含函数式API(激活函数、损失函数等)
import torch.nn.functional as F

# 数据加载工具
from torch.utils.data import DataLoader, Dataset

# torchvision是PyTorch的计算机视觉工具库
import torchvision
from torchvision import transforms

# 目标检测专用操作：非极大值抑制(NMS)
from torchvision.ops import nms

# 数值计算库
import numpy as np

# 绘图库
import matplotlib.pyplot as plt
import matplotlib.patches as patches

# Select the compute device, preferring a CUDA GPU when one is available.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
print(f"使用设备: {device}")

## 1. 数据集准备

将CIFAR-10分类数据集适配为目标检测格式。

In [None]:
# ============================================================
# CIFAR-10目标检测数据集适配器
# 将分类数据集转换为检测格式(为每张图创建边界框)
# ============================================================

class CIFAR10Detection(Dataset):
    """Expose CIFAR-10 through a detection-style (image, target) interface.

    CIFAR-10 only carries one class label per image, so every sample is given
    a single bounding box that covers the whole picture.  This is sufficient
    to exercise the YOLOv3 pipeline end to end.
    """

    def __init__(self, root, train=True, transform=None, img_size=416):
        """Load CIFAR-10 (downloading it when missing) and store the settings.

        Args:
            root: Directory in which the dataset is stored / downloaded.
            train: Use the training split when True, the test split otherwise.
            transform: Optional callable applied to the resized image.
            img_size: Side length the image is resized to before ``transform``.
        """
        self.img_size = img_size
        self.transform = transform
        # The wrapped dataset gets no transform of its own: resizing and the
        # user transform are applied by hand in __getitem__.
        self.cifar10 = torchvision.datasets.CIFAR10(
            root=root,
            train=train,
            download=True,
            transform=None,
        )
        # Class names and their count are taken from the wrapped dataset.
        self.classes = self.cifar10.classes
        self.num_classes = len(self.classes)

    def __len__(self):
        """Return the number of samples of the wrapped dataset."""
        return len(self.cifar10)

    def __getitem__(self, idx):
        """Return sample ``idx`` as an ``(image, target)`` pair.

        Returns:
            image: The sample resized to ``img_size`` x ``img_size`` and, when
                a transform was given, passed through that transform.
            target: Dict with
                ``"boxes"``  - float32 tensor ``[[0.5, 0.5, 1.0, 1.0]]``, i.e.
                one normalised (x_center, y_center, width, height) box that
                spans the entire image, and
                ``"labels"`` - int64 tensor holding the class index.
        """
        picture, class_id = self.cifar10[idx]

        # Bring the picture to the network input resolution first.
        side = self.img_size
        picture = picture.resize((side, side))

        if self.transform:
            picture = self.transform(picture)

        return picture, {
            "boxes": torch.tensor([[0.5, 0.5, 1.0, 1.0]], dtype=torch.float32),
            "labels": torch.tensor([class_id], dtype=torch.int64),
        }


# Image preprocessing: convert to a [0, 1] float tensor, then normalise with
# the ImageNet channel statistics.
transform = transforms.Compose([
    transforms.ToTensor(),  # [0,255] -> [0,1]
    transforms.Normalize(mean=[0.485, 0.456, 0.406],
                        std=[0.229, 0.224, 0.225]),
])


def detection_collate(batch):
    """Turn a list of (image, target) samples into (images, targets) tuples.

    Targets are dictionaries, so they are kept as a tuple instead of being
    stacked.  A named module-level function is used instead of a lambda so
    that both loaders share one definition and the collate function stays
    picklable should ``num_workers > 0`` ever be used.
    """
    return tuple(zip(*batch))


# Datasets for the two CIFAR-10 splits, resized to the YOLOv3 input size.
train_dataset = CIFAR10Detection(root='./data', train=True, transform=transform, img_size=416)
test_dataset = CIFAR10Detection(root='./data', train=False, transform=transform, img_size=416)

# Data loaders: only the training loader shuffles.
train_loader = DataLoader(
    train_dataset,
    batch_size=4,
    shuffle=True,
    collate_fn=detection_collate,
)
test_loader = DataLoader(
    test_dataset,
    batch_size=4,
    shuffle=False,
    collate_fn=detection_collate,
)

print(f"训练集大小: {len(train_dataset)}")
print(f"测试集大小: {len(test_dataset)}")
print(f"类别: {train_dataset.classes}")

## 2. Darknet骨干网络

YOLOv3使用Darknet-53作为骨干网络，具有残差连接结构。

In [None]:
# ============================================================
# Darknet基础模块
# ============================================================

class ConvBNLeaky(nn.Module):
    """Convolution followed by batch normalisation and LeakyReLU(0.1).

    This is the basic building unit used by the other network modules in
    this file.
    """

    def __init__(self, in_channels, out_channels, kernel_size, stride=1, padding=0):
        """Create the three layers.

        Args:
            in_channels: Number of input channels.
            out_channels: Number of output channels.
            kernel_size: Convolution kernel size.
            stride: Convolution stride.
            padding: Zero padding added by the convolution.
        """
        super().__init__()

        # The convolution is created without a bias term; it is followed
        # directly by batch normalisation.
        self.conv = nn.Conv2d(
            in_channels=in_channels,
            out_channels=out_channels,
            kernel_size=kernel_size,
            stride=stride,
            padding=padding,
            bias=False,
        )
        self.bn = nn.BatchNorm2d(num_features=out_channels)
        # inplace=True: the activation overwrites the batch-norm output.
        self.leaky = nn.LeakyReLU(negative_slope=0.1, inplace=True)

    def forward(self, x):
        """Apply conv -> batch norm -> LeakyReLU to ``x``."""
        features = self.conv(x)
        normalised = self.bn(features)
        return self.leaky(normalised)


class DarknetResidualBlock(nn.Module):
    """Darknet residual block.

    Layout: 1x1 convolution (halves the channels) -> 3x3 convolution
    (restores the channels) -> element-wise addition with the block input.
    Input and output have identical shape, so the skip connection is a plain
    identity.
    """

    def __init__(self, in_channels):
        """Build the two convolutions.

        Args:
            in_channels: Number of input channels (also the output channels).
        """
        super().__init__()

        bottleneck = in_channels // 2  # channel count between the two convs

        # 1x1 convolution: channel reduction.
        self.conv1 = ConvBNLeaky(in_channels, bottleneck, 1, 1, 0)
        # 3x3 convolution: back to the original channel count.
        self.conv2 = ConvBNLeaky(bottleneck, in_channels, 3, 1, 1)

    def forward(self, x):
        """Return ``conv2(conv1(x)) + x``."""
        return self.conv2(self.conv1(x)) + x


print("Darknet基础模块定义完成")

In [None]:
# ============================================================
# Darknet-53骨干网络
# ============================================================

class Darknet53(nn.Module):
    """Darknet-53 style backbone used as the YOLOv3 feature extractor.

    The module built here contains 52 convolutions: 1 stem convolution,
    5 stride-2 down-sampling convolutions and 2 convolutions in each of the
    1 + 2 + 8 + 8 + 4 = 23 residual blocks.
    NOTE(review): the "53" of the name presumably also counts the classifier
    layer of the original network, which is not part of this module.

    Properties:
    1. Residual connections inside every stage.
    2. No pooling layers; down-sampling is done by stride-2 convolutions.
    3. The outputs of the last three stages are returned for multi-scale
       detection.
    """

    def __init__(self):
        super().__init__()

        # Stem: 3 -> 32 channels, spatial size unchanged.
        self.conv1 = ConvBNLeaky(3, 32, kernel_size=3, stride=1, padding=1)

        # (in_channels, out_channels, residual blocks) for stage1..stage5.
        # Each stage halves the spatial resolution; for a 416x416 input the
        # sizes are 208, 104, 52, 26 and 13.
        stage_specs = (
            (32, 64, 1),
            (64, 128, 2),
            (128, 256, 8),
            (256, 512, 8),
            (512, 1024, 4),
        )
        for number, (c_in, c_out, blocks) in enumerate(stage_specs, start=1):
            setattr(self, f"stage{number}", self._make_stage(c_in, c_out, num_blocks=blocks))

    def _make_stage(self, in_channels, out_channels, num_blocks):
        """Build one stage: a stride-2 convolution followed by residual blocks.

        Args:
            in_channels: Channels entering the stage.
            out_channels: Channels leaving the stage.
            num_blocks: Number of residual blocks after the down-sampling conv.
        """
        # Stride 2 halves the spatial size.
        downsample = ConvBNLeaky(in_channels, out_channels,
                                 kernel_size=3, stride=2, padding=1)
        residuals = [DarknetResidualBlock(out_channels) for _ in range(num_blocks)]
        return nn.Sequential(downsample, *residuals)

    def forward(self, x):
        """Run the backbone.

        Args:
            x: Input images [B, 3, H, W] (shapes below assume 416x416).
        Returns:
            Tuple ``(scale1, scale2, scale3)``:
            - scale1: [B, 1024, 13, 13]  coarsest grid (large objects)
            - scale2: [B, 512, 26, 26]   medium objects
            - scale3: [B, 256, 52, 52]   finest grid (small objects)
        """
        stem = self.stage2(self.stage1(self.conv1(x)))   # [B, 128, 104, 104]

        fine = self.stage3(stem)      # [B, 256, 52, 52]
        middle = self.stage4(fine)    # [B, 512, 26, 26]
        coarse = self.stage5(middle)  # [B, 1024, 13, 13]

        return coarse, middle, fine


print("Darknet-53骨干网络定义完成")

## 3. YOLOv3检测头

YOLOv3在三个尺度上进行检测，每个尺度有独立的检测头。

In [None]:
# ============================================================
# YOLOv3检测头
# ============================================================

class YOLODetectionHead(nn.Module):
    """Detection head applied to the feature map of one scale.

    For every grid cell and every anchor it predicts 4 box values,
    1 objectness value and ``num_classes`` class values.
    """

    def __init__(self, in_channels, num_classes=10, num_anchors=3):
        """Build the head.

        Args:
            in_channels: Channels of the incoming feature map.
            num_classes: Number of object classes.
            num_anchors: Anchors predicted per grid cell.
        """
        super().__init__()

        self.num_classes = num_classes
        self.num_anchors = num_anchors
        # Output channels: one (4 + 1 + num_classes) vector per anchor.
        self.num_outputs = num_anchors * (5 + num_classes)

        half = in_channels // 2

        # Five convolutions alternating 1x1 (reduce) and 3x3 (expand).
        self.conv1 = ConvBNLeaky(in_channels, half, 1, 1, 0)
        self.conv2 = ConvBNLeaky(half, in_channels, 3, 1, 1)
        self.conv3 = ConvBNLeaky(in_channels, half, 1, 1, 0)
        self.conv4 = ConvBNLeaky(half, in_channels, 3, 1, 1)
        self.conv5 = ConvBNLeaky(in_channels, half, 1, 1, 0)

        # Output branch: one more 3x3 block, then a plain 1x1 convolution
        # without batch norm or activation that emits the raw predictions.
        self.conv6 = ConvBNLeaky(half, in_channels, 3, 1, 1)
        self.output = nn.Conv2d(in_channels, self.num_outputs, 1, 1, 0)

    def forward(self, x):
        """Compute the raw predictions and the route feature.

        Args:
            x: Feature map [B, C, H, W].
        Returns:
            output: Raw predictions [B, num_anchors*(5+num_classes), H, W].
            route: Output of the fifth convolution [B, C//2, H, W], returned
                so the caller can up-sample it for the next scale.
        """
        for layer in (self.conv1, self.conv2, self.conv3, self.conv4):
            x = layer(x)
        route = self.conv5(x)

        return self.output(self.conv6(route)), route


class YOLOUpsample(nn.Module):
    """Up-sample a feature map and concatenate it with a larger one.

    Used for the top-down feature fusion between detection scales.
    """

    def __init__(self, in_channels):
        super().__init__()
        # 1x1 convolution that halves the channel count.
        self.conv = ConvBNLeaky(in_channels, in_channels // 2,
                                kernel_size=1, stride=1, padding=0)
        # Nearest-neighbour up-sampling by a factor of two.
        self.upsample = nn.Upsample(scale_factor=2, mode='nearest')

    def forward(self, x, route):
        """Fuse ``x`` with ``route``.

        Args:
            x: Low-resolution feature map.
            route: Feature map with twice the spatial size of ``x``.
        Returns:
            The reduced and up-sampled ``x`` concatenated with ``route`` along
            the channel dimension (up-sampled channels first).
        """
        upsampled = self.upsample(self.conv(x))
        return torch.cat((upsampled, route), dim=1)


print("YOLOv3检测头定义完成")

## 4. 损失函数

YOLOv3的损失函数包括坐标损失、置信度损失和分类损失。

In [None]:
# ============================================================
# YOLOv3损失函数
# ============================================================

class YOLOLoss(nn.Module):
    """YOLOv3 loss for a single detection scale.

    Components:
    1. Coordinate loss: squared error on the box centre offsets (after
       sigmoid) and on the log-space width/height, for responsible cells.
    2. Objectness loss: binary cross-entropy, with a reduced weight for
       cells that are not responsible for any ground-truth box.
    3. Classification loss: binary cross-entropy over the class logits of
       the responsible cells.

    Every ground-truth box is assigned to the anchor (over all scales) whose
    width/height overlaps it best; only the scale owning that anchor gets a
    positive cell for the box.
    """

    def __init__(self, num_classes=10, anchors=None, img_size=416):
        """Store the configuration.

        Args:
            num_classes: Number of object classes.
            anchors: Per-scale list of (width, height) anchors in input
                pixels; the YOLOv3 defaults are used when omitted.
            img_size: Side length of the network input in pixels.
        """
        super(YOLOLoss, self).__init__()

        self.num_classes = num_classes
        self.img_size = img_size

        # Default YOLOv3 anchors as (width, height), three per scale.
        if anchors is None:
            self.anchors = [
                [(116, 90), (156, 198), (373, 326)],  # 13x13 grid, large objects
                [(30, 61), (62, 45), (59, 119)],       # 26x26 grid, medium objects
                [(10, 13), (16, 30), (33, 23)]         # 52x52 grid, small objects
            ]
        else:
            self.anchors = anchors

        # Loss weights.
        self.lambda_coord = 5.0    # weight of the coordinate loss
        self.lambda_noobj = 0.5    # weight of the no-object objectness loss

        # Summed criteria; the total is divided by the batch size in forward.
        self.mse_loss = nn.MSELoss(reduction='sum')
        self.bce_loss = nn.BCEWithLogitsLoss(reduction='sum')

    def _build_targets(self, targets, scale_idx, grid_h, grid_w, device, dtype):
        """Encode the ground truth on the grid of one scale.

        Returns:
            obj_mask: bool [B, A, H, W], True where an anchor/cell is
                responsible for a ground-truth box.
            box_target: [B, A, H, W, 4] holding (tx, ty, tw, th).
            cls_target: [B, A, H, W, num_classes] one-hot class targets.
        """
        scale_anchors = self.anchors[scale_idx]
        num_anchors = len(scale_anchors)
        shape = (len(targets), num_anchors, grid_h, grid_w)

        # Built on the CPU (many scalar writes) and moved to the device once.
        obj_mask = torch.zeros(shape, dtype=torch.bool)
        box_target = torch.zeros(shape + (4,), dtype=torch.float32)
        cls_target = torch.zeros(shape + (self.num_classes,), dtype=torch.float32)

        all_anchors = torch.tensor(
            [pair for scale in self.anchors for pair in scale],
            dtype=torch.float32,
        ).reshape(-1, 2)
        anchor_area = all_anchors[:, 0] * all_anchors[:, 1]
        # Index of this scale's first anchor within the flattened anchor list.
        first = sum(len(scale) for scale in self.anchors[:scale_idx])

        for b, target in enumerate(targets):
            boxes = target["boxes"].detach().to(device="cpu", dtype=torch.float32).reshape(-1, 4)
            labels = target["labels"].detach().to(device="cpu").reshape(-1)
            if boxes.shape[0] == 0:
                continue

            # Box sizes in input pixels; clamped so the logarithm is defined.
            box_wh = (boxes[:, 2:4] * self.img_size).clamp(min=1e-6)

            # IoU between boxes and anchors when both share the same centre.
            inter = (torch.min(box_wh[:, None, 0], all_anchors[None, :, 0])
                     * torch.min(box_wh[:, None, 1], all_anchors[None, :, 1]))
            union = (box_wh[:, 0] * box_wh[:, 1])[:, None] + anchor_area[None, :] - inter
            best_anchor = (inter / union).argmax(dim=1)

            # Log-space size targets relative to the best anchor.
            size_target = torch.log(box_wh) - torch.log(all_anchors[best_anchor])

            rows = zip(boxes.tolist(), size_target.tolist(),
                       labels.tolist(), best_anchor.tolist())
            for box, (tw, th), label, best in rows:
                a = best - first
                if not 0 <= a < num_anchors:
                    continue  # the box is handled by another scale

                gx = box[0] * grid_w
                gy = box[1] * grid_h
                # Clamp so a centre lying exactly on the border stays inside.
                col = min(max(int(gx), 0), grid_w - 1)
                row = min(max(int(gy), 0), grid_h - 1)

                obj_mask[b, a, row, col] = True
                box_target[b, a, row, col, 0] = gx - col
                box_target[b, a, row, col, 1] = gy - row
                box_target[b, a, row, col, 2] = tw
                box_target[b, a, row, col, 3] = th
                # A later box landing on the same cell/anchor replaces the
                # earlier one entirely, including its class.
                cls_target[b, a, row, col] = 0.0
                cls_target[b, a, row, col, int(label)] = 1.0

        return (
            obj_mask.to(device),
            box_target.to(device=device, dtype=dtype),
            cls_target.to(device=device, dtype=dtype),
        )

    def forward(self, predictions, targets, scale_idx):
        """Compute the loss of one scale.

        Args:
            predictions: Raw head output [B, num_anchors*(5+C), H, W].
            targets: Sequence with one dict per image holding ``"boxes"``
                (normalised x_center, y_center, width, height rows) and
                ``"labels"`` (class indices).
            scale_idx: Index of the scale (selects the anchor group).
        Returns:
            Scalar loss tensor connected to ``predictions``, averaged over
            the batch.
        Raises:
            ValueError: If the number of targets differs from the batch size.
        """
        batch_size, _, grid_h, grid_w = predictions.shape
        num_anchors = len(self.anchors[scale_idx])

        if len(targets) != batch_size:
            raise ValueError(
                f"expected {batch_size} targets, got {len(targets)}"
            )

        # [B, A*(5+C), H, W] -> [B, A, H, W, 5+C]
        predictions = predictions.view(
            batch_size, num_anchors, 5 + self.num_classes,
            grid_h, grid_w
        ).permute(0, 1, 3, 4, 2).contiguous()

        pred_xy = torch.sigmoid(predictions[..., 0:2])  # centre offsets in the cell
        pred_wh = predictions[..., 2:4]                  # log-space width / height
        pred_conf = predictions[..., 4]                  # objectness logit
        pred_cls = predictions[..., 5:]                  # class logits

        obj_mask, box_target, cls_target = self._build_targets(
            targets, scale_idx, grid_h, grid_w,
            predictions.device, predictions.dtype,
        )
        noobj_mask = ~obj_mask
        conf_target = obj_mask.to(predictions.dtype)

        # Cells without an object only contribute to the objectness loss.
        noobj_loss = self.bce_loss(pred_conf[noobj_mask], conf_target[noobj_mask])

        if bool(obj_mask.any()):
            coord_loss = (
                self.mse_loss(pred_xy[obj_mask], box_target[..., 0:2][obj_mask])
                + self.mse_loss(pred_wh[obj_mask], box_target[..., 2:4][obj_mask])
            )
            obj_loss = self.bce_loss(pred_conf[obj_mask], conf_target[obj_mask])
            cls_loss = self.bce_loss(pred_cls[obj_mask], cls_target[obj_mask])
        else:
            # No ground-truth box is assigned to this scale.
            coord_loss = obj_loss = cls_loss = predictions.new_zeros(())

        total = (
            self.lambda_coord * coord_loss
            + obj_loss
            + self.lambda_noobj * noobj_loss
            + cls_loss
        )
        return total / max(batch_size, 1)


print("YOLOv3损失函数定义完成")

## 5. YOLOv3完整模型

整合Darknet-53骨干网络和多尺度检测头。

In [None]:
# ============================================================
# YOLOv3完整模型
# ============================================================

class YOLOv3(nn.Module):
    """YOLOv3 object detection model.

    Parts:
    1. Darknet-53 backbone: extracts features at three resolutions.
    2. Top-down fusion: coarser features are up-sampled and concatenated
       with the next finer backbone feature map.
    3. Three detection heads: 13x13 (large objects), 26x26 (medium) and
       52x52 (small) grids for a 416x416 input.

    Each grid cell predicts one box per anchor; objectness and class scores
    are obtained with independent sigmoids.
    """

    def __init__(self, num_classes=10, img_size=416,
                 conf_threshold=0.5, nms_threshold=0.45):
        """Build the network.

        Args:
            num_classes: Number of object classes.
            img_size: Side length of the (square) network input in pixels.
            conf_threshold: Detections whose score (objectness * best class
                score) is not above this value are dropped at inference.
            nms_threshold: IoU threshold passed to non-maximum suppression.
        """
        super(YOLOv3, self).__init__()

        self.num_classes = num_classes
        self.img_size = img_size
        self.conf_threshold = conf_threshold
        self.nms_threshold = nms_threshold

        # Anchor boxes as (width, height) in input pixels, three per scale,
        # ordered from the coarsest to the finest grid.
        self.anchors = [
            [(116, 90), (156, 198), (373, 326)],  # 13x13 grid
            [(30, 61), (62, 45), (59, 119)],       # 26x26 grid
            [(10, 13), (16, 30), (33, 23)]         # 52x52 grid
        ]

        # Backbone.
        self.backbone = Darknet53()

        # Scale-1 head (13x13); its route feature has 1024 // 2 = 512 channels.
        self.head1 = YOLODetectionHead(1024, num_classes, len(self.anchors[0]))

        # Halve route1 to 256 channels and up-sample it to 26x26.
        self.upsample1 = YOLOUpsample(512)

        # Scale-2 head (26x26): 256 up-sampled + 512 backbone channels = 768;
        # its route feature has 768 // 2 = 384 channels.
        self.head2 = YOLODetectionHead(512 + 256, num_classes, len(self.anchors[1]))

        # Halve route2 to 192 channels and up-sample it to 52x52.
        self.upsample2 = YOLOUpsample(384)

        # Scale-3 head (52x52): 192 up-sampled + 256 backbone channels = 448.
        self.head3 = YOLODetectionHead(192 + 256, num_classes, len(self.anchors[2]))

        # Loss, evaluated once per scale.
        self.loss_fn = YOLOLoss(num_classes, self.anchors, img_size)

    def forward(self, x, targets=None):
        """Run the model.

        Args:
            x: Input images [B, 3, img_size, img_size].
            targets: Ground-truth annotations; only used in training mode.
        Returns:
            ``{"loss": tensor}`` when the module is in training mode and
            ``targets`` is given; otherwise a list with one dict of
            ``boxes`` / ``scores`` / ``labels`` per image.
        """
        # 1. Multi-scale backbone features
        #    (1024@13x13, 512@26x26, 256@52x52 for a 416x416 input).
        scale1, scale2, scale3 = self.backbone(x)

        # 2. Coarsest scale (large objects).
        out1, route1 = self.head1(scale1)

        # 3. Up-sample and fuse with the 26x26 backbone features.
        fused = self.upsample1(route1, scale2)

        # 4. Middle scale (medium objects).
        out2, route2 = self.head2(fused)

        # 5. Up-sample and fuse with the 52x52 backbone features.
        fused = self.upsample2(route2, scale3)

        # 6. Finest scale (small objects).
        out3, _ = self.head3(fused)

        outputs = [out1, out2, out3]

        if self.training and targets is not None:
            # Training: add up the per-scale losses.  No separate accumulator
            # leaf is needed; the sum is built from the loss tensors directly.
            total_loss = sum(
                self.loss_fn(out, targets, scale_idx)
                for scale_idx, out in enumerate(outputs)
            )
            return {"loss": total_loss}

        # Inference: decode the raw outputs into detections.
        return self._decode_predictions(outputs)

    def _decode_predictions(self, outputs):
        """Decode raw head outputs into per-image detections.

        Args:
            outputs: Raw outputs of the three heads, coarsest scale first.
        Returns:
            List with one dict per image: ``boxes`` as [x1, y1, x2, y2] in
            input pixels, ``scores`` and ``labels``, after score filtering
            and class-agnostic non-maximum suppression.
        """
        batch_size = outputs[0].size(0)
        all_boxes = []
        all_scores = []
        all_labels = []

        for scale_idx, output in enumerate(outputs):
            grid_size = output.size(2)
            stride = self.img_size // grid_size  # input pixels per grid cell
            scale_anchors = self.anchors[scale_idx]
            num_anchors = len(scale_anchors)

            # [B, A*(5+C), H, W] -> [B, A, H, W, 5+C]
            prediction = output.view(
                batch_size, num_anchors, 5 + self.num_classes,
                grid_size, grid_size
            ).permute(0, 1, 3, 4, 2).contiguous()

            x = torch.sigmoid(prediction[..., 0])     # centre x offset in the cell
            y = torch.sigmoid(prediction[..., 1])     # centre y offset in the cell
            w = prediction[..., 2]                     # log-space width
            h = prediction[..., 3]                     # log-space height
            conf = torch.sigmoid(prediction[..., 4])  # objectness
            cls = torch.sigmoid(prediction[..., 5:])  # per-class scores

            # Anchors are given in input pixels, so exp(w) * anchor is
            # already a pixel width; no grid-unit round trip is required.
            anchors = torch.tensor(scale_anchors, device=output.device).float()
            anchor_w = anchors[:, 0].view(1, -1, 1, 1)
            anchor_h = anchors[:, 1].view(1, -1, 1, 1)

            # Cell indices; [H, W] broadcasts against [B, A, H, W].
            grid_y, grid_x = torch.meshgrid(
                torch.arange(grid_size, device=output.device),
                torch.arange(grid_size, device=output.device),
                indexing='ij'
            )

            # Box centre and size in input pixels.
            center_x = (x + grid_x) * stride
            center_y = (y + grid_y) * stride
            box_w = torch.exp(w) * anchor_w
            box_h = torch.exp(h) * anchor_h

            # Corner format [x1, y1, x2, y2].
            boxes = torch.stack(
                (
                    center_x - box_w / 2,
                    center_y - box_h / 2,
                    center_x + box_w / 2,
                    center_y + box_h / 2,
                ),
                dim=-1,
            )

            # Score = objectness * best class score.
            cls_scores, cls_labels = cls.max(dim=-1)
            scores = conf * cls_scores

            all_boxes.append(boxes.reshape(batch_size, -1, 4))
            all_scores.append(scores.reshape(batch_size, -1))
            all_labels.append(cls_labels.reshape(batch_size, -1))

        # Merge the candidates of all scales.
        all_boxes = torch.cat(all_boxes, dim=1)
        all_scores = torch.cat(all_scores, dim=1)
        all_labels = torch.cat(all_labels, dim=1)

        # Filter and suppress per image.
        results = []
        for i in range(batch_size):
            boxes_i = all_boxes[i]
            scores_i = all_scores[i]
            labels_i = all_labels[i]

            # Drop low-score candidates.
            mask = scores_i > self.conf_threshold
            boxes_i = boxes_i[mask]
            scores_i = scores_i[mask]
            labels_i = labels_i[mask]

            if len(boxes_i) > 0:
                keep = nms(boxes_i, scores_i, self.nms_threshold)
                boxes_i = boxes_i[keep]
                scores_i = scores_i[keep]
                labels_i = labels_i[keep]

            results.append({
                "boxes": boxes_i,
                "scores": scores_i,
                "labels": labels_i
            })

        return results


print("YOLOv3完整模型定义完成")

## 6. 训练和推理

In [None]:
# ============================================================
# Create the model and the optimizer
# ============================================================

# Build the YOLOv3 model and move it to the selected device.
model = YOLOv3(num_classes=10, img_size=416).to(device)

# Print a summary of the parameter counts.
total_params = sum(p.numel() for p in model.parameters())
trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
print(f"总参数量: {total_params:,}")
print(f"可训练参数量: {trainable_params:,}")

# Optimizer: Adam with weight decay.
optimizer = torch.optim.Adam(model.parameters(), lr=1e-4, weight_decay=5e-4)

# Learning-rate scheduler: multiply the rate by 0.1 every 10 scheduler steps.
scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=10, gamma=0.1)

print("模型和优化器创建完成")

In [None]:
# ============================================================
# 训练函数
# ============================================================

def train_one_epoch(model, dataloader, optimizer, device,
                    max_batches=11, log_interval=100):
    """Train the model on (at most ``max_batches`` batches of) one epoch.

    Args:
        model: Model whose call ``model(images, targets)`` returns a dict
            with a ``"loss"`` tensor in training mode.
        dataloader: Iterable yielding ``(images, targets)`` where ``images``
            is a sequence of equally sized image tensors.
        optimizer: Optimizer updating the model parameters.
        device: Device the stacked images are moved to.  The targets are
            passed to the model unchanged.
        max_batches: Stop after this many batches.  The default of 11 keeps
            the short demonstration run; pass ``None`` to use every batch.
        log_interval: Print the loss every this many batches (0 or ``None``
            disables the progress output).
    Returns:
        Mean loss over the processed batches, or 0.0 when no batch was
        processed.
    """
    model.train()
    total_loss = 0.0
    num_batches = 0

    if max_batches is not None and max_batches <= 0:
        return 0.0

    for batch_idx, (images, targets) in enumerate(dataloader):
        # Stack the per-sample tensors into one batch on the target device.
        images = torch.stack(images).to(device)

        # Forward pass.
        optimizer.zero_grad()
        loss = model(images, targets)["loss"]

        # Backward pass and parameter update.
        loss.backward()
        optimizer.step()

        batch_loss = loss.item()
        total_loss += batch_loss
        num_batches += 1

        if log_interval and batch_idx % log_interval == 0:
            print(f"Batch {batch_idx}/{len(dataloader)}, Loss: {batch_loss:.4f}")

        # Checked at the end of the body so that no extra batch is fetched
        # from the loader once the limit is reached.
        if max_batches is not None and num_batches >= max_batches:
            break

    # An empty dataloader must not end in a division by zero.
    if num_batches == 0:
        return 0.0
    return total_loss / num_batches


print("训练函数定义完成")

In [None]:
# ============================================================
# 推理函数
# ============================================================

def inference(model, image, device):
    """Run the model on a single image without tracking gradients.

    Args:
        model: Detection model; it is switched to eval mode.
        image: Image tensor [3, H, W]; a tensor that is not 3-dimensional is
            passed on unchanged.
        device: Device the image is moved to.
    Returns:
        The first element of the model output.
    """
    with torch.no_grad():
        model.eval()

        # A 3-D tensor is a single image: add the batch dimension.
        batch = image.unsqueeze(0) if image.dim() == 3 else image

        detections = model(batch.to(device))

        return detections[0]


def visualize_detection(image, result, classes, threshold=0.5):
    """Draw the detections of one image with matplotlib.

    Args:
        image: Image to show.  A tensor is interpreted as [C, H, W]
            normalised with the ImageNet mean/std and is de-normalised
            before display; anything else is handed to ``imshow`` as is.
        result: Dict with ``"boxes"`` ([x1, y1, x2, y2] rows), ``"scores"``
            and ``"labels"`` tensors.
        classes: List of class names.
        threshold: Only detections with a score above this value are drawn.
    """
    _fig, ax = plt.subplots(1, figsize=(10, 10))

    if isinstance(image, torch.Tensor):
        # detach + cpu so CUDA or grad-tracking tensors can become arrays.
        image = image.detach().cpu().permute(1, 2, 0).numpy()
        # Undo the ImageNet normalisation.
        mean = np.array([0.485, 0.456, 0.406])
        std = np.array([0.229, 0.224, 0.225])
        image = np.clip(std * image + mean, 0, 1)

    ax.imshow(image)

    boxes = result["boxes"].detach().cpu().numpy()
    scores = result["scores"].detach().cpu().numpy()
    labels = result["labels"].detach().cpu().numpy()

    palette = plt.cm.hsv(np.linspace(0, 1, len(classes) + 1))

    for box, score, label in zip(boxes, scores, labels):
        if score <= threshold:
            continue

        label = int(label)
        # Wrap the index so an unexpected label cannot run past the palette.
        colour = palette[label % len(palette)]

        x1, y1, x2, y2 = box

        # Bounding box.
        rect = patches.Rectangle(
            (x1, y1), x2 - x1, y2 - y1,
            linewidth=2, edgecolor=colour,
            facecolor='none'
        )
        ax.add_patch(rect)

        # Caption; labels outside the class list get a generic name.
        class_name = classes[label] if 0 <= label < len(classes) else f"class_{label}"
        ax.text(
            x1, y1 - 5, f"{class_name}: {score:.2f}",
            color='white', fontsize=10,
            bbox=dict(boxstyle='round', facecolor=colour, alpha=0.8)
        )

    ax.axis('off')
    plt.tight_layout()
    plt.show()


print("推理和可视化函数定义完成")

In [None]:
# ============================================================
# Short training demonstration
# ============================================================

print("开始训练演示...")

# Run train_one_epoch once on the training loader and report the mean loss.
avg_loss = train_one_epoch(model, train_loader, optimizer, device)
print(f"\n训练完成! 平均损失: {avg_loss:.4f}")

In [None]:
# ============================================================
# Inference demonstration
# ============================================================

print("推理演示...")

# Take the first image of the first test batch.
test_images, test_targets = next(iter(test_loader))
test_image = test_images[0]

# Run inference on that single image.
result = inference(model, test_image, device)

print(f"检测到 {len(result['boxes'])} 个目标")
print(f"边界框: {result['boxes']}")
print(f"置信度: {result['scores']}")
print(f"类别: {result['labels']}")

# Draw the detections on the image.
visualize_detection(test_image, result, train_dataset.classes, threshold=0.3)