# Electrical Component Detection Pipeline

This notebook consolidates the refactored Faster R-CNN training and inference workflow into a single place for convenient experimentation on Kaggle. It provides reusable configuration objects, dataset loaders with optional augmentation, detailed metric utilities (including per-class TP/FP/FN and mAP), and helpers for both training and inference.


In [1]:
from __future__ import annotations

import argparse
import contextlib
import inspect
import json
import logging
import math
import multiprocessing as mp
import os
import random
import warnings
from dataclasses import dataclass, field
from pathlib import Path
from typing import Dict, Iterable, List, Mapping, Optional, Sequence, Set, Tuple, Union

import numpy as np
import pandas as pd
import torch
from PIL import Image as PILImage, ImageDraw, ImageEnhance, ImageFont
from torch import Tensor, nn
from torch.optim import AdamW
from torch.utils.data import DataLoader, Dataset
from torchvision.models.detection import (
    FasterRCNN_ResNet50_FPN_V2_Weights,
    fasterrcnn_resnet50_fpn_v2,
)
from torchvision.models.detection.faster_rcnn import FastRCNNPredictor
from torchvision.models.detection.rpn import AnchorGenerator, RPNHead
from torchvision.transforms import InterpolationMode
from torchvision.transforms import functional as TVF
from tqdm import tqdm


## Configuration objects


In [2]:
"""Configuration objects for the electrical component detection project."""

from dataclasses import dataclass, field
from pathlib import Path
from typing import Dict, List, Optional, Tuple


DEFAULT_PRETRAINED_URL = (
    "https://download.pytorch.org/models/fasterrcnn_resnet50_fpn_v2_coco-dd69338a.pth"
)

#0.999 menas class does not exist
DEFAULT_CLASS_SCORE_THRESHOLDS = {
    3: 0.999,  
    6: 0.8,
    7: 0.9,
    8: 0.999,
    12: 0.999,
    16: 0.97,
    17: 0.999,
    20: 0.9,
    21: 0.9,
    24: 0.999,
    25: 0.97,
    26: 0.999,
    30: 0.95,
}


@dataclass
class DatasetConfig:
    """Configuration describing the dataset layout and metadata."""

    base_dir: Path = Path("data")
    train_split: str = "train"
    valid_split: str = "valid"
    test_split: str = "test"
    image_folder: str = "images"
    label_folder: str = "labels"
    num_classes: int = 32
    class_names: Tuple[str, ...] = ()

    def __post_init__(self) -> None:
        if not self.class_names:
            # Fallback names are useful for logging when a mapping file is not provided.
            self.class_names = tuple(f"class_{idx:02d}" for idx in range(self.num_classes))


@dataclass
class TrainingConfig:
    """Hyper-parameters and runtime settings for model training."""

    epochs: int = 20
    batch_size: int = 4
    learning_rate: float = 5e-5
    weight_decay: float = 5e-5
    num_workers: int = 0
    amp: bool = True
    augmentation: bool = True
    mosaic_prob: float = 0.6
    mixup_prob: float = 0.6
    mixup_alpha: float = 0.4
    scale_jitter_min: float = 0.8
    scale_jitter_max: float = 1.2
    rotation_prob: float = 0.5
    rotation_max_degrees: float = 30.0
    affine_prob: float = 0.3
    affine_translate: Tuple[float, float] = (0.1, 0.1)
    affine_scale_range: Tuple[float, float] = (0.9, 1.1)
    affine_shear: Tuple[float, float] = (5, 5)
    small_object: bool = True
    score_threshold: float = 0.6
    iou_threshold: float = 0.5
    eval_interval: int = 1
    seed: int = 37
    output_dir: Path = Path("outputs")
    checkpoint_path: Path = Path("outputs/best_model.pth")
    pretrained_weights_path: Path = Path("weights/fasterrcnn_resnet50_fpn_v2_coco.pth")
    pretrained_weights_url: str = DEFAULT_PRETRAINED_URL
    log_every: int = 20
    resume: bool = False
    resume_path: Optional[Path] = None
    last_checkpoint_path: Path = Path("outputs/last_checkpoint.pth")
    class_score_thresholds: Dict[int, float] = field(
        default_factory=lambda: DEFAULT_CLASS_SCORE_THRESHOLDS.copy()
    )
    exclude_samples: Tuple[str, ...] = tuple()
    fp_visual_dir: Optional[Path] = Path("outputs/fp_images")
    fp_report_path: Optional[Path] = None
    fp_list_path: Optional[Path] = None
    fp_classes: Tuple[int, ...] = (16, 30)

    def ensure_directories(self) -> None:
        """Create output directories if they do not exist."""
        self.output_dir.mkdir(parents=True, exist_ok=True)
        self.pretrained_weights_path.parent.mkdir(parents=True, exist_ok=True)
        self.checkpoint_path.parent.mkdir(parents=True, exist_ok=True)
        self.last_checkpoint_path.parent.mkdir(parents=True, exist_ok=True)
        if self.fp_visual_dir:
            self.fp_visual_dir.mkdir(parents=True, exist_ok=True)


@dataclass
class InferenceConfig:
    """Options for running model inference and visualisation."""

    score_threshold: float = 0.7
    max_images: int = 200
    output_dir: Path = Path("outputs/inference")
    draw_ground_truth: bool = True
    class_colors: List[str] = field(default_factory=list)
    class_score_thresholds: Dict[int, float] = field(
        default_factory=lambda: DEFAULT_CLASS_SCORE_THRESHOLDS.copy()
    )

    def ensure_directories(self) -> None:
        self.output_dir.mkdir(parents=True, exist_ok=True)


## Dataset loading and augmentation


In [3]:
"""Dataset and data loading utilities for electrical component detection."""

import logging
import math
import multiprocessing as mp
import os
import random
import warnings
from dataclasses import dataclass
from pathlib import Path
from typing import Dict, Iterable, List, Optional, Set, Tuple

import numpy as np
import pandas as pd
import torch
from PIL import Image as PILImage, ImageEnhance
from torch.utils.data import DataLoader, Dataset
from torchvision.transforms import InterpolationMode
from torchvision.transforms import functional as TVF




LOGGER = logging.getLogger(__name__)


@dataclass
class AugmentationParams:
    """Parameters controlling the dataset level image augmentations."""

    horizontal_flip_prob: float = 0.5
    vertical_flip_prob: float = 0.2
    brightness: float = 0.2
    contrast: float = 0.2
    saturation: float = 0.2
    hue: float = 0.02
    rotation_prob: float = 0.0
    rotation_max_degrees: float = 0.0
    affine_prob: float = 0.0
    affine_translate: Tuple[float, float] = (0.0, 0.0)
    affine_scale_range: Tuple[float, float] = (1.0, 1.0)
    affine_shear: Tuple[float, float] = (0.0, 0.0)
    mosaic_prob: float = 0.0
    mixup_prob: float = 0.0
    mixup_alpha: float = 0.4
    scale_jitter_range: Tuple[float, float] = (1.0, 1.0)


def load_image_hwc_uint8(path: Path) -> np.ndarray:
    """Load an ``.npy`` image stored as HWC and return an ``uint8`` array."""
    image = np.load(path, allow_pickle=False, mmap_mode="r")

    if image.dtype != np.uint8:
        image = image.astype(np.float32, copy=False)
        vmin, vmax = float(image.min()), float(image.max())
        if 0.0 <= vmin and vmax <= 1.0:
            image = (image * 255.0).round()
        elif -1.0 <= vmin and vmax <= 1.0:
            image = ((image + 1.0) * 0.5 * 255.0).round()
        image = np.clip(image, 0, 255).astype(np.uint8)

    channels = image.shape[2]
    if channels == 1:
        image = np.repeat(image, 3, axis=2)
    elif channels == 4:
        image = image[..., :3]
    if not image.flags.writeable or not image.flags.c_contiguous:
        image = np.array(image, copy=True)
    return image


class ElectricalComponentsDataset(Dataset):
    """Dataset of electrical component detections stored as ``.npy`` images and CSV labels."""

    def __init__(
        self,
        root: Path,
        split: str,
        class_names: Iterable[str],
        transform: Optional[AugmentationParams] = None,
        use_augmentation: bool = False,
        exclude_stems: Optional[Iterable[str]] = None,
    ) -> None:
        self.root = Path(root)
        self.split = split
        self.class_names = list(class_names)
        self.transform_params = transform or AugmentationParams()
        self.use_augmentation = use_augmentation

        self.image_dir = self.root / split / "images"
        self.label_dir = self.root / split / "labels"

        if not self.image_dir.exists():
            raise FileNotFoundError(f"Missing image directory: {self.image_dir}")
        if not self.label_dir.exists():
            raise FileNotFoundError(f"Missing label directory: {self.label_dir}")

        self.image_stems = sorted(p.stem for p in self.label_dir.glob("*.csv"))
        if not self.image_stems:
            raise RuntimeError(f"No label files found in {self.label_dir}")

        exclude_set: Set[str] = set()
        if exclude_stems:
            exclude_set = {Path(stem).stem for stem in exclude_stems}
            if exclude_set:
                before = len(self.image_stems)
                self.image_stems = [stem for stem in self.image_stems if stem not in exclude_set]
                removed = before - len(self.image_stems)
                if removed:
                    LOGGER.info(
                        "Split %s: excluded %d samples based on provided stem filter.",
                        self.split,
                        removed,
                    )

        self.excluded_stems = sorted(exclude_set)

        # Pre-load all annotations to reduce I/O during training.
        self.annotations: Dict[str, pd.DataFrame] = {
            stem: pd.read_csv(self.label_dir / f"{stem}.csv") for stem in self.image_stems
        }

    def __len__(self) -> int:
        return len(self.image_stems)

    def __getitem__(self, index: int) -> Tuple[torch.Tensor, Dict[str, torch.Tensor]]:
        stem = self.image_stems[index]
        image, boxes, labels = self._load_raw_sample(stem)

        if self.use_augmentation:
            image, boxes, labels = self._apply_composite_augmentations(stem, image, boxes, labels)
            image, boxes = self._apply_augmentations(image, boxes)

        height, width = image.shape[:2]

        image_tensor = torch.from_numpy(image).permute(2, 0, 1).float() / 255.0
        boxes_tensor = torch.from_numpy(boxes).float() if boxes.size else torch.zeros((0, 4), dtype=torch.float32)
        labels_tensor = (
            torch.from_numpy(labels).long() if labels.size else torch.zeros((0,), dtype=torch.long)
        )

        boxes_tensor, labels_tensor = sanitize_boxes_and_labels(
            boxes_tensor, labels_tensor, height, width
        )

        target: Dict[str, torch.Tensor] = {
            "boxes": boxes_tensor,
            "labels": labels_tensor,
            "image_id": torch.tensor(index, dtype=torch.int64),
            "area": (boxes_tensor[:, 2] - boxes_tensor[:, 0])
            * (boxes_tensor[:, 3] - boxes_tensor[:, 1])
            if boxes_tensor.numel()
            else torch.tensor([], dtype=torch.float32),
            "iscrowd": torch.zeros((boxes_tensor.shape[0],), dtype=torch.int64),
            "orig_size": torch.tensor([height, width], dtype=torch.int64),
        }

        return image_tensor, target

    def _load_raw_sample(self, stem: str) -> Tuple[np.ndarray, np.ndarray, np.ndarray]:
        image_path = self.image_dir / f"{stem}.npy"
        image = load_image_hwc_uint8(image_path)
        height, width = image.shape[:2]

        ann = self.annotations[stem]
        boxes, labels = self._annotation_to_boxes(ann, width, height)
        return image, boxes, labels

    @staticmethod
    def _annotation_to_boxes(
        ann: pd.DataFrame, width: int, height: int
    ) -> Tuple[np.ndarray, np.ndarray]:
        if ann.empty:
            return np.zeros((0, 4), dtype=np.float32), np.zeros((0,), dtype=np.int64)

        x_center = ann["x_center"].to_numpy(dtype=np.float32)
        y_center = ann["y_center"].to_numpy(dtype=np.float32)
        box_width = ann["width"].to_numpy(dtype=np.float32)
        box_height = ann["height"].to_numpy(dtype=np.float32)

        # Auto-detect normalised coordinates and scale back to pixel space.
        if (
            (x_center.size == 0 or float(x_center.max()) <= 1.0)
            and (y_center.size == 0 or float(y_center.max()) <= 1.0)
            and (box_width.size == 0 or float(box_width.max()) <= 1.0)
            and (box_height.size == 0 or float(box_height.max()) <= 1.0)
        ):
            x_center = x_center * width
            y_center = y_center * height
            box_width = box_width * width
            box_height = box_height * height

        x1 = x_center - box_width / 2.0
        y1 = y_center - box_height / 2.0
        x2 = x_center + box_width / 2.0
        y2 = y_center + box_height / 2.0

        boxes = np.stack([x1, y1, x2, y2], axis=1).astype(np.float32)
        labels = ann["class"].to_numpy(dtype=np.int64) + 1  # shift to 1..K so 0 remains reserved for background
        return boxes, labels

    def _apply_composite_augmentations(
        self,
        stem: str,
        image: np.ndarray,
        boxes: np.ndarray,
        labels: np.ndarray,
    ) -> Tuple[np.ndarray, np.ndarray, np.ndarray]:
        params = self.transform_params

        if (
            params.mosaic_prob > 0.0
            and random.random() < params.mosaic_prob
            and len(self.image_stems) >= 4
        ):
            image, boxes, labels = self._apply_mosaic(stem, image, boxes, labels)

        if (
            params.mixup_prob > 0.0
            and params.mixup_alpha > 0.0
            and random.random() < params.mixup_prob
        ):
            image, boxes, labels = self._apply_mixup(stem, image, boxes, labels)

        if params.scale_jitter_range != (1.0, 1.0):
            image, boxes = self._apply_scale_jitter(image, boxes, params.scale_jitter_range)

        return image, boxes, labels

    def _apply_augmentations(
        self, image: np.ndarray, boxes: np.ndarray
    ) -> Tuple[np.ndarray, np.ndarray]:
        params = self.transform_params
        height, width = image.shape[:2]

        if (
            params.rotation_prob > 0.0
            and params.rotation_max_degrees > 0.0
            and random.random() < params.rotation_prob
        ):
            angle = random.uniform(-params.rotation_max_degrees, params.rotation_max_degrees)
            image, boxes = self._apply_affine_transform(
                image,
                boxes,
                angle=angle,
                translate=(0.0, 0.0),
                scale=1.0,
                shear=(0.0, 0.0),
            )
            height, width = image.shape[:2]

        if params.affine_prob > 0.0 and random.random() < params.affine_prob:
            max_tx = abs(params.affine_translate[0]) * width
            max_ty = abs(params.affine_translate[1]) * height
            translate = (
                random.uniform(-max_tx, max_tx),
                random.uniform(-max_ty, max_ty),
            )

            scale_min, scale_max = params.affine_scale_range
            if scale_min > scale_max:
                scale_min, scale_max = scale_max, scale_min
            scale_min = max(scale_min, 0.0)
            scale_max = max(scale_max, 0.0)
            scale = 1.0
            if scale_max > 0.0:
                if math.isclose(scale_min, scale_max):
                    scale = max(scale_min, 1e-3)
                else:
                    scale = max(random.uniform(scale_min, scale_max), 1e-3)

            shear_x = params.affine_shear[0]
            shear_y = params.affine_shear[1]
            shear = (
                random.uniform(-abs(shear_x), abs(shear_x)),
                random.uniform(-abs(shear_y), abs(shear_y)),
            )

            image, boxes = self._apply_affine_transform(
                image,
                boxes,
                angle=0.0,
                translate=translate,
                scale=scale,
                shear=shear,
            )
            height, width = image.shape[:2]

        if boxes.size and random.random() < params.horizontal_flip_prob:
            image = np.ascontiguousarray(image[:, ::-1, :])
            x1 = width - boxes[:, 2]
            x2 = width - boxes[:, 0]
            boxes[:, 0], boxes[:, 2] = x1, x2

        if boxes.size and random.random() < params.vertical_flip_prob:
            image = np.ascontiguousarray(image[::-1, :, :])
            y1 = height - boxes[:, 3]
            y2 = height - boxes[:, 1]
            boxes[:, 1], boxes[:, 3] = y1, y2

        if params.brightness or params.contrast or params.saturation or params.hue:
            pil = PILImage.fromarray(image)
            if params.brightness:
                enhancer = ImageEnhance.Brightness(pil)
                factor = 1.0 + random.uniform(-params.brightness, params.brightness)
                pil = enhancer.enhance(max(0.1, factor))
            if params.contrast:
                enhancer = ImageEnhance.Contrast(pil)
                factor = 1.0 + random.uniform(-params.contrast, params.contrast)
                pil = enhancer.enhance(max(0.1, factor))
            if params.saturation:
                enhancer = ImageEnhance.Color(pil)
                factor = 1.0 + random.uniform(-params.saturation, params.saturation)
                pil = enhancer.enhance(max(0.1, factor))
            if params.hue:
                hsv_image = pil.convert("HSV")
                h_channel, s_channel, v_channel = hsv_image.split()
                delta = int(params.hue * 255.0 * random.choice([-1, 1]))
                h_channel = h_channel.point(lambda h: (h + delta) % 255)
                hsv_image = PILImage.merge("HSV", (h_channel, s_channel, v_channel))
                pil = hsv_image.convert("RGB")
            image = np.array(pil)

        if boxes.size:
            boxes[:, [0, 2]] = boxes[:, [0, 2]].clip(0, width)
            boxes[:, [1, 3]] = boxes[:, [1, 3]].clip(0, height)
        return image, boxes

    def _apply_scale_jitter(
        self, image: np.ndarray, boxes: np.ndarray, scale_range: Tuple[float, float]
    ) -> Tuple[np.ndarray, np.ndarray]:
        min_scale, max_scale = scale_range
        if max_scale <= 0 or min_scale <= 0:
            return image, boxes

        factor = random.uniform(min_scale, max_scale)
        if np.isclose(factor, 1.0):
            return image, boxes

        height, width = image.shape[:2]
        new_height = max(1, int(round(height * factor)))
        new_width = max(1, int(round(width * factor)))

        pil = PILImage.fromarray(image)
        resized = pil.resize((new_width, new_height), resample=PILImage.BILINEAR)
        image = np.array(resized)

        if boxes.size:
            boxes = boxes.copy()
            boxes[:, [0, 2]] *= float(new_width) / float(width)
            boxes[:, [1, 3]] *= float(new_height) / float(height)
        return image, boxes

    def _apply_affine_transform(
        self,
        image: np.ndarray,
        boxes: np.ndarray,
        *,
        angle: float,
        translate: Tuple[float, float],
        scale: float,
        shear: Tuple[float, float],
    ) -> Tuple[np.ndarray, np.ndarray]:
        if not np.isfinite(scale) or scale <= 0.0:
            scale = 1.0

        height, width = image.shape[:2]
        pil = PILImage.fromarray(image)
        translate_int = (
            int(round(float(translate[0]))),
            int(round(float(translate[1]))),
        )
        transformed = TVF.affine(
            pil,
            angle=float(angle),
            translate=translate_int,
            scale=float(max(scale, 1e-3)),
            shear=(float(shear[0]), float(shear[1])),
            interpolation=InterpolationMode.BILINEAR,
            fill=0,
        )
        image_out = np.array(transformed)
        out_height, out_width = image_out.shape[:2]

        if not boxes.size:
            return image_out, boxes.astype(np.float32, copy=False)

        matrix = _compute_affine_forward_matrix(
            center=(width * 0.5, height * 0.5),
            angle=float(angle),
            translate=(float(translate_int[0]), float(translate_int[1])),
            scale=float(max(scale, 1e-3)),
            shear=(float(shear[0]), float(shear[1])),
        )

        corners = _boxes_to_corners(boxes)
        ones = np.ones((corners.shape[0], 1), dtype=np.float32)
        coords = np.concatenate([corners, ones], axis=1)
        full_matrix = np.vstack([matrix, [0.0, 0.0, 1.0]])
        transformed_coords = coords @ full_matrix.T
        transformed_corners = transformed_coords[:, :2].reshape(-1, 4, 2)
        min_xy = transformed_corners.min(axis=1)
        max_xy = transformed_corners.max(axis=1)

        boxes_out = np.concatenate([min_xy, max_xy], axis=1)
        boxes_out[:, [0, 2]] = boxes_out[:, [0, 2]].clip(0, out_width)
        boxes_out[:, [1, 3]] = boxes_out[:, [1, 3]].clip(0, out_height)
        return image_out, boxes_out.astype(np.float32, copy=False)

    def _apply_mosaic(
        self,
        stem: str,
        image: np.ndarray,
        boxes: np.ndarray,
        labels: np.ndarray,
    ) -> Tuple[np.ndarray, np.ndarray, np.ndarray]:
        height, width = image.shape[:2]
        candidate_stems = [s for s in self.image_stems if s != stem]
        if len(candidate_stems) < 3:
            return image, boxes, labels

        selected = random.sample(candidate_stems, 3)

        images: List[np.ndarray] = [image]
        boxes_list: List[np.ndarray] = [boxes]
        labels_list: List[np.ndarray] = [labels]

        for other_stem in selected:
            other_img, other_boxes, other_labels = self._load_raw_sample(other_stem)
            other_img, other_boxes = self._resize_like(other_img, other_boxes, width, height)
            images.append(other_img)
            boxes_list.append(other_boxes)
            labels_list.append(other_labels)

        canvas = np.zeros((height * 2, width * 2, 3), dtype=np.uint8)
        offsets = [(0, 0), (0, width), (height, 0), (height, width)]
        combined_boxes: List[np.ndarray] = []
        combined_labels: List[np.ndarray] = []

        for img, bxs, lbls, (y_off, x_off) in zip(images, boxes_list, labels_list, offsets):
            canvas[y_off : y_off + height, x_off : x_off + width] = img
            if bxs.size:
                shifted = bxs.copy()
                shifted[:, [0, 2]] += x_off
                shifted[:, [1, 3]] += y_off
                combined_boxes.append(shifted)
                combined_labels.append(lbls)

        if combined_boxes:
            boxes = np.concatenate(combined_boxes, axis=0)
            labels = np.concatenate(combined_labels, axis=0)
        else:
            boxes = np.zeros((0, 4), dtype=np.float32)
            labels = np.zeros((0,), dtype=np.int64)

        crop_x = random.randint(0, width)
        crop_y = random.randint(0, height)
        canvas = canvas[crop_y : crop_y + height, crop_x : crop_x + width]

        if boxes.size:
            boxes = boxes.copy()
            boxes[:, [0, 2]] -= crop_x
            boxes[:, [1, 3]] -= crop_y

            keep = (
                (boxes[:, 2] > 0)
                & (boxes[:, 3] > 0)
                & (boxes[:, 0] < width)
                & (boxes[:, 1] < height)
            )
            boxes = boxes[keep]
            labels = labels[keep]
            boxes[:, [0, 2]] = boxes[:, [0, 2]].clip(0, width)
            boxes[:, [1, 3]] = boxes[:, [1, 3]].clip(0, height)

        return canvas, boxes, labels

    def _apply_mixup(
        self,
        stem: str,
        image: np.ndarray,
        boxes: np.ndarray,
        labels: np.ndarray,
    ) -> Tuple[np.ndarray, np.ndarray, np.ndarray]:
        other_stem = self._sample_alternative_stem(stem)
        if other_stem is None:
            return image, boxes, labels

        other_img, other_boxes, other_labels = self._load_raw_sample(other_stem)
        height, width = image.shape[:2]
        other_img, other_boxes = self._resize_like(other_img, other_boxes, width, height)

        alpha = max(self.transform_params.mixup_alpha, 1e-3)
        lam = np.random.beta(alpha, alpha)
        lam = float(np.clip(lam, 0.3, 0.7))

        mixed = (
            image.astype(np.float32) * lam + other_img.astype(np.float32) * (1.0 - lam)
        ).astype(np.uint8)

        if boxes.size and other_boxes.size:
            boxes = np.concatenate([boxes, other_boxes], axis=0)
            labels = np.concatenate([labels, other_labels], axis=0)
        elif other_boxes.size:
            boxes = other_boxes.copy()
            labels = other_labels.copy()

        return mixed, boxes, labels

    def _sample_alternative_stem(self, current: str) -> Optional[str]:
        if len(self.image_stems) <= 1:
            return None
        candidates = [stem for stem in self.image_stems if stem != current]
        if not candidates:
            return None
        return random.choice(candidates)

    @staticmethod
    def _resize_like(
        image: np.ndarray, boxes: np.ndarray, target_width: int, target_height: int
    ) -> Tuple[np.ndarray, np.ndarray]:
        height, width = image.shape[:2]
        if height == target_height and width == target_width:
            return image, boxes

        pil = PILImage.fromarray(image)
        resized = pil.resize((target_width, target_height), resample=PILImage.BILINEAR)
        image = np.array(resized)

        if boxes.size:
            boxes = boxes.copy()
            boxes[:, [0, 2]] *= float(target_width) / float(width)
            boxes[:, [1, 3]] *= float(target_height) / float(height)
        return image, boxes


def _boxes_to_corners(boxes: np.ndarray) -> np.ndarray:
    corners = np.stack(
        [
            boxes[:, [0, 1]],
            boxes[:, [2, 1]],
            boxes[:, [2, 3]],
            boxes[:, [0, 3]],
        ],
        axis=1,
    )
    return corners.reshape(-1, 2).astype(np.float32, copy=False)


def _compute_affine_forward_matrix(
    *,
    center: Tuple[float, float],
    angle: float,
    translate: Tuple[float, float],
    scale: float,
    shear: Tuple[float, float],
) -> np.ndarray:
    rot = math.radians(angle)
    shear_x = math.radians(shear[0])
    shear_y = math.radians(shear[1])

    cx, cy = center
    tx, ty = translate

    cos_sy = math.cos(shear_y)
    if abs(cos_sy) < 1e-6:
        cos_sy = 1e-6 if cos_sy >= 0 else -1e-6

    a = math.cos(rot - shear_y) / cos_sy
    b = -math.cos(rot - shear_y) * math.tan(shear_x) / cos_sy - math.sin(rot)
    c = math.sin(rot - shear_y) / cos_sy
    d = -math.sin(rot - shear_y) * math.tan(shear_x) / cos_sy + math.cos(rot)

    matrix = [a, b, 0.0, c, d, 0.0]
    matrix = [x * scale for x in matrix]

    matrix[2] += matrix[0] * (-cx) + matrix[1] * (-cy)
    matrix[5] += matrix[3] * (-cx) + matrix[4] * (-cy)
    matrix[2] += cx + tx
    matrix[5] += cy + ty

    return np.array(matrix, dtype=np.float32).reshape(2, 3)


def detection_collate(batch: List[Tuple[torch.Tensor, Dict[str, torch.Tensor]]]):
    """Collate function for detection datasets returning lists of tensors."""
    images, targets = zip(*batch)
    return list(images), list(targets)


def _safe_worker_count(requested: int) -> int:
    cpu_count = os.cpu_count() or 1
    if requested <= 0:
        return 0
    # Leave one core free so that the main process remains responsive on small machines.
    max_workers = max(1, cpu_count - 1)
    return min(requested, max_workers)


def _should_force_single_worker(dataset: Dataset) -> bool:
    """Determine whether multiprocessing workers should be disabled."""

    module_name = getattr(dataset.__class__, "__module__", "")
    if module_name in {"__main__", "__mp_main__", "builtins"}:
        return True

    if module_name.startswith("ipykernel"):  # pragma: no cover - notebook specific
        return True

    return running_in_ipython_kernel()


def create_data_loaders(
    dataset: Dataset,
    batch_size: int,
    shuffle: bool,
    num_workers: int,
) -> DataLoader:
    """Create a :class:`~torch.utils.data.DataLoader` for the detection dataset.

    Kaggle notebooks occasionally run in restricted multiprocessing environments where
    ``fork`` based workers cannot be reaped cleanly.  We therefore favour a conservative
    default (``num_workers=0``), detect in-notebook dataset definitions that cannot be
    spawned safely, and fall back to single-process loading automatically when worker
    start-up fails.
    """

    worker_count = _safe_worker_count(num_workers)
    if worker_count > 0 and _should_force_single_worker(dataset):
        LOGGER.info(
            "Detected interactive environment or in-notebook dataset definition; forcing num_workers=0."
        )
        worker_count = 0

    loader_kwargs = dict(
        dataset=dataset,
        batch_size=batch_size,
        shuffle=shuffle,
        pin_memory=torch.cuda.is_available(),
        collate_fn=detection_collate,
    )

    if worker_count > 0:
        loader_kwargs["num_workers"] = worker_count
        loader_kwargs["persistent_workers"] = True
        # ``spawn`` avoids PID mismatches that surface as ``AssertionError: can only
        # test a child process`` when the notebook kernel re-uses processes.
        loader_kwargs["multiprocessing_context"] = mp.get_context("spawn")
    else:
        loader_kwargs["num_workers"] = 0

    try:
        return DataLoader(**loader_kwargs)
    except (RuntimeError, OSError, AssertionError) as exc:
        if worker_count == 0:
            raise
        warnings.warn(
            "Falling back to num_workers=0 because DataLoader worker initialisation "
            f"failed with: {exc}",
            RuntimeWarning,
        )
        LOGGER.warning("DataLoader workers failed to start (%s). Using num_workers=0 instead.", exc)
        loader_kwargs.pop("persistent_workers", None)
        loader_kwargs.pop("multiprocessing_context", None)
        loader_kwargs["num_workers"] = 0
        return DataLoader(**loader_kwargs)


## Utility helpers and metrics


In [4]:
"""Utility helpers for training and evaluating the detection model."""

import json
import logging
import random
from pathlib import Path
from typing import Dict, Iterable, List, Mapping, Optional, Sequence, Tuple, Union

import numpy as np
import torch
from PIL import Image as PILImage, ImageDraw, ImageFont
from torch import Tensor




# ``DEFAULT_COLORS`` previously lived in the Kaggle notebook version of the
# project.  The constant was referenced when rendering detections but never
# actually defined in the standalone module, which caused a ``NameError`` when
# the training script attempted to export visualisations.  Keeping the palette
# here restores the expected behaviour while remaining independent from the
# notebook.
DEFAULT_COLORS: Tuple[str, ...] = (
    "#FF6B6B",
    "#4ECDC4",
    "#FFD93D",
    "#1A535C",
    "#FF9F1C",
    "#2EC4B6",
    "#E71D36",
    "#9B5DE5",
    "#F15BB5",
    "#00BBF9",
    "#00F5D4",
    "#6C5CE7",
    "#45B7D1",
    "#F9C80E",
    "#F86624",
    "#EA3546",
    "#662E9B",
    "#43BCCD",
    "#A1C181",
    "#BB9F06",
)


def set_seed(seed: int) -> None:
    """Set seeds for the Python, NumPy and PyTorch RNGs."""
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed_all(seed)


def sanitize_boxes_and_labels(
    boxes: Tensor, labels: Tensor, height: int, width: int, min_size: float = 1.0
) -> Tuple[Tensor, Tensor]:
    """Clamp bounding boxes to the image size and drop invalid boxes."""
    if boxes.numel() == 0:
        return boxes.reshape(0, 4).float(), labels.reshape(0).long()

    boxes = boxes.clone()
    boxes[:, 0::2] = boxes[:, 0::2].clamp(0, float(width))
    boxes[:, 1::2] = boxes[:, 1::2].clamp(0, float(height))

    widths = boxes[:, 2] - boxes[:, 0]
    heights = boxes[:, 3] - boxes[:, 1]
    keep = (widths > min_size) & (heights > min_size)

    if keep.sum() == 0:
        return boxes.new_zeros((0, 4)), labels.new_zeros((0,), dtype=torch.long)
    return boxes[keep].float(), labels[keep].long()


def compute_iou_matrix(boxes1: np.ndarray, boxes2: np.ndarray) -> np.ndarray:
    """Compute the IoU matrix between two sets of boxes in ``xyxy`` format."""
    if boxes1.size == 0 or boxes2.size == 0:
        return np.zeros((boxes1.shape[0], boxes2.shape[0]), dtype=np.float32)

    x11, y11, x12, y12 = np.split(boxes1, 4, axis=1)
    x21, y21, x22, y22 = np.split(boxes2, 4, axis=1)

    inter_x1 = np.maximum(x11, x21.T)
    inter_y1 = np.maximum(y11, y21.T)
    inter_x2 = np.minimum(x12, x22.T)
    inter_y2 = np.minimum(y12, y22.T)

    inter_w = np.clip(inter_x2 - inter_x1, a_min=0.0, a_max=None)
    inter_h = np.clip(inter_y2 - inter_y1, a_min=0.0, a_max=None)
    inter_area = inter_w * inter_h

    area1 = (x12 - x11) * (y12 - y11)
    area2 = (x22 - x21) * (y22 - y21)

    union = area1 + area2.T - inter_area
    return np.divide(inter_area, union, out=np.zeros_like(inter_area), where=union > 0)


def compute_average_precision(recalls: np.ndarray, precisions: np.ndarray) -> float:
    """Compute the interpolated Average Precision (AP) following COCO style."""
    if recalls.size == 0 or precisions.size == 0:
        return 0.0

    mrec = np.concatenate(([0.0], recalls, [1.0]))
    mpre = np.concatenate(([0.0], precisions, [0.0]))

    for i in range(mpre.size - 1, 0, -1):
        mpre[i - 1] = max(mpre[i - 1], mpre[i])

    recall_points = np.linspace(0, 1, 101)
    precision_interp = np.interp(recall_points, mrec, mpre)
    return float(np.trapz(precision_interp, recall_points))


def accumulate_classification_stats(
    predictions: Sequence[Dict[str, np.ndarray]],
    targets: Sequence[Dict[str, np.ndarray]],
    num_classes: int,
    iou_threshold: float,
) -> Tuple[np.ndarray, np.ndarray, np.ndarray, List[List[float]], List[List[int]], np.ndarray]:
    """Accumulate TP/FP/FN statistics and matched predictions for AP computation."""
    tp = np.zeros(num_classes, dtype=np.int64)
    fp = np.zeros(num_classes, dtype=np.int64)
    fn = np.zeros(num_classes, dtype=np.int64)
    scores: List[List[float]] = [[] for _ in range(num_classes)]
    matches: List[List[int]] = [[] for _ in range(num_classes)]
    gt_counter = np.zeros(num_classes, dtype=np.int64)

    for pred, tgt in zip(predictions, targets):
        pred_boxes = pred["boxes"]
        pred_scores = pred["scores"]
        pred_labels = pred["labels"].astype(np.int64)

        gt_boxes = tgt["boxes"]
        gt_labels = tgt["labels"].astype(np.int64)

        unique_classes = np.unique(np.concatenate((pred_labels, gt_labels)))
        for cls in unique_classes:
            pb = pred_boxes[pred_labels == cls]
            ps = pred_scores[pred_labels == cls]
            tb = gt_boxes[gt_labels == cls]
            gt_counter[cls] += len(tb)

            if len(tb) == 0:
                fp[cls] += len(pb)
                scores[cls].extend(ps.tolist())
                matches[cls].extend([0] * len(pb))
                continue

            order = np.argsort(-ps)
            pb = pb[order]
            ps = ps[order]
            iou_matrix = compute_iou_matrix(pb, tb)

            matched_gt: set[int] = set()
            for det_idx, score in enumerate(ps):
                if tb.size == 0:
                    fp[cls] += 1
                    scores[cls].append(float(score))
                    matches[cls].append(0)
                    continue

                best_gt = int(np.argmax(iou_matrix[det_idx]))
                best_iou = iou_matrix[det_idx, best_gt]

                if best_iou >= iou_threshold and best_gt not in matched_gt:
                    tp[cls] += 1
                    matched_gt.add(best_gt)
                    scores[cls].append(float(score))
                    matches[cls].append(1)
                else:
                    fp[cls] += 1
                    scores[cls].append(float(score))
                    matches[cls].append(0)

            fn[cls] += len(tb) - len(matched_gt)

    return tp, fp, fn, scores, matches, gt_counter


def identify_false_positive_predictions(
    prediction: Dict[str, np.ndarray],
    target: Dict[str, np.ndarray],
    num_classes: int,
    iou_threshold: float,
) -> List[Dict[str, Union[int, float, List[float]]]]:
    """Return detailed records for false-positive detections in a single sample."""

    boxes = np.asarray(prediction.get("boxes", np.empty((0, 4), dtype=np.float32)), dtype=np.float32)
    scores = np.asarray(prediction.get("scores", np.empty((0,), dtype=np.float32)), dtype=np.float32)
    labels = np.asarray(prediction.get("labels", np.empty((0,), dtype=np.int64)), dtype=np.int64)

    gt_boxes = np.asarray(target.get("boxes", np.empty((0, 4), dtype=np.float32)), dtype=np.float32)
    gt_labels = np.asarray(target.get("labels", np.empty((0,), dtype=np.int64)), dtype=np.int64)

    if boxes.size == 0:
        return []

    fp_records: List[Dict[str, Union[int, float, List[float]]]] = []
    unique_classes = np.unique(labels) if labels.size else np.asarray([], dtype=np.int64)

    for cls in unique_classes:
        cls = int(cls)
        cls_mask = labels == cls
        cls_indices = np.nonzero(cls_mask)[0]
        if cls_indices.size == 0:
            continue

        pb = boxes[cls_mask]
        ps = scores[cls_mask]
        order = np.argsort(-ps)
        pb_sorted = pb[order]
        ps_sorted = ps[order]
        original_indices = cls_indices[order]

        tb = gt_boxes[gt_labels == cls]
        iou_matrix = compute_iou_matrix(pb_sorted, tb) if tb.size else np.zeros((pb_sorted.shape[0], 0), dtype=np.float32)
        matched_gt: set[int] = set()

        for rank, (pred_idx, score_value) in enumerate(zip(original_indices, ps_sorted)):
            if iou_matrix.shape[1]:
                row = iou_matrix[rank]
                best_gt = int(row.argmax())
                best_iou = float(row[best_gt])
            else:
                best_gt = -1
                best_iou = 0.0

            is_true_positive = (
                iou_matrix.shape[1] > 0
                and best_iou >= iou_threshold
                and best_gt not in matched_gt
            )

            if is_true_positive:
                matched_gt.add(best_gt)
                continue

            fp_records.append(
                {
                    "index": int(pred_idx),
                    "class": cls,
                    "score": float(score_value),
                    "best_iou": best_iou,
                    "box": boxes[pred_idx].astype(float).tolist(),
                }
            )

    return fp_records


def compute_detection_metrics(
    predictions: Sequence[Dict[str, np.ndarray]],
    targets: Sequence[Dict[str, np.ndarray]],
    num_classes: int,
    iou_threshold: float,
) -> Dict[str, np.ndarray]:
    """Compute per-class metrics and mAP for detection results."""
    tp, fp, fn, scores, matches, gt_counter = accumulate_classification_stats(
        predictions, targets, num_classes, iou_threshold
    )

    precision = np.divide(tp, np.clip(tp + fp, a_min=1, a_max=None))
    recall = np.divide(tp, np.clip(tp + fn, a_min=1, a_max=None))

    ap = np.zeros(num_classes, dtype=np.float32)
    for cls in range(num_classes):
        if gt_counter[cls] == 0:
            ap[cls] = np.nan
            continue
        if not scores[cls]:
            ap[cls] = 0.0
            continue

        order = np.argsort(-np.asarray(scores[cls]))
        match_array = np.asarray(matches[cls], dtype=np.int32)[order]
        cumulative_tp = np.cumsum(match_array)
        cumulative_fp = np.cumsum(1 - match_array)

        recalls = cumulative_tp / gt_counter[cls]
        precisions = cumulative_tp / np.maximum(cumulative_tp + cumulative_fp, 1)
        ap[cls] = compute_average_precision(recalls, precisions)

    valid_ap = ap[np.isfinite(ap)]
    map_value = float(valid_ap.mean()) if valid_ap.size else 0.0

    return {
        "TP": tp,
        "FP": fp,
        "FN": fn,
        "precision": precision,
        "recall": recall,
        "AP": ap,
        "mAP": map_value,
        "gt_counter": gt_counter,
    }


def score_threshold_mask(
    scores: np.ndarray,
    labels: np.ndarray,
    default_threshold: float,
    class_thresholds: Mapping[int, float],
) -> np.ndarray:
    """Return a boolean mask keeping predictions that pass per-class thresholds."""

    if scores.size == 0:
        return np.zeros_like(scores, dtype=bool)

    labels_int = labels.astype(np.int64, copy=False)
    thresholds = np.full(scores.shape, default_threshold, dtype=scores.dtype)
    if class_thresholds:
        labels_fg = labels_int - 1  # convert to original 0-based ids
        for cls, value in class_thresholds.items():
            thresholds[labels_fg == int(cls)] = float(value)

    keep = labels_int != 0  # drop background predictions outright
    keep &= scores >= thresholds
    return keep


def parse_class_threshold_entries(entries: Sequence[str]) -> Dict[int, float]:
    """Parse ``CLS=THRESH`` strings into a mapping of per-class thresholds."""

    thresholds: Dict[int, float] = {}
    for entry in entries:
        if not entry:
            continue

        if "=" in entry:
            key, value = entry.split("=", 1)
        elif ":" in entry:
            key, value = entry.split(":", 1)
        else:
            raise ValueError(f"Invalid class threshold format: {entry!r}")

        key = key.strip()
        value = value.strip()
        if not key or not value:
            raise ValueError(f"Invalid class threshold entry: {entry!r}")

        thresholds[int(key)] = float(value)

    return thresholds


def running_in_ipython_kernel() -> bool:
    """Return ``True`` when executing inside an IPython/Jupyter kernel."""

    try:  # ``IPython`` may be absent in some execution environments.
        from IPython import get_ipython  # type: ignore
    except Exception:  # pragma: no cover - depends on environment
        return False

    shell = get_ipython()
    return bool(shell and getattr(shell, "kernel", None))


def emit_metric_lines(
    lines: Sequence[str],
    *,
    logger: Optional[logging.Logger] = None,
    force_print: Optional[bool] = None,
) -> None:
    """Log metric lines and optionally echo them to ``stdout``.

    Kaggle notebooks buffer ``logging`` output differently from regular Python
    scripts, so we proactively mirror the messages with ``print`` when we detect
    an IPython kernel.  Callers may override this behaviour by passing
    ``force_print`` explicitly.
    """

    if logger is None:
        logger = logging.getLogger(__name__)

    should_print = force_print if force_print is not None else running_in_ipython_kernel()

    for line in lines:
        if logger is not None:
            logger.info(line)
        if should_print:
            print(line)


def _resolve_class_label(dataset_cfg: DatasetConfig, index: int) -> str:
    if index < len(dataset_cfg.class_names):
        label = dataset_cfg.class_names[index]
    else:
        label = f"class_{index:02d}"

    if label.startswith("class_") and label[6:].isdigit():
        return f"class {int(label[6:]):02d}"
    return label


def format_epoch_metrics(
    epoch: Optional[int],
    train_loss: Optional[float],
    metrics: Dict[str, torch.Tensor | float | List[float]],
    dataset_cfg: DatasetConfig,
    *,
    header: Optional[str] = None,
) -> List[str]:
    lines: List[str] = []

    val_loss = float(metrics.get("loss", float("nan")))
    map_value = float(metrics.get("mAP", float("nan")))

    if header is not None:
        summary = header
    elif epoch is not None:
        summary = f"Epoch {epoch:02d}"
    else:
        summary = "Metrics"

    if train_loss is not None and np.isfinite(train_loss):
        summary += f" | train loss {train_loss:.4f}"
    if np.isfinite(val_loss):
        summary += f" | val loss {val_loss:.4f}"
    if np.isfinite(map_value):
        summary += f" | mAP {map_value:.4f}"
    lines.append(summary)

    precision = np.asarray(metrics.get("precision", []), dtype=float)
    recall = np.asarray(metrics.get("recall", []), dtype=float)
    tp = np.asarray(metrics.get("TP", []), dtype=int)
    fp = np.asarray(metrics.get("FP", []), dtype=int)
    fn = np.asarray(metrics.get("FN", []), dtype=int)
    ap = np.asarray(metrics.get("AP", []), dtype=float)
    gt_counter = np.asarray(metrics.get("gt_counter", np.zeros_like(tp)), dtype=int)

    num_classes = min(len(tp), dataset_cfg.num_classes)
    for cls_idx in range(num_classes):
        gt_value = int(gt_counter[cls_idx]) if gt_counter.size > cls_idx else 0
        tp_value = int(tp[cls_idx]) if tp.size > cls_idx else 0
        fp_value = int(fp[cls_idx]) if fp.size > cls_idx else 0
        fn_value = int(fn[cls_idx]) if fn.size > cls_idx else 0

        if gt_value == 0 and tp_value == 0 and fp_value == 0 and fn_value == 0:
            continue

        label = _resolve_class_label(dataset_cfg, cls_idx)
        p_val = (
            float(np.nan_to_num(precision[cls_idx], nan=0.0))
            if precision.size > cls_idx
            else 0.0
        )
        r_val = (
            float(np.nan_to_num(recall[cls_idx], nan=0.0))
            if recall.size > cls_idx
            else 0.0
        )
        line = f"{label} | P={p_val:.3f} R={r_val:.3f}  TP={tp_value} FP={fp_value} FN={fn_value}"
        if ap.size > cls_idx and np.isfinite(ap[cls_idx]):
            line += f" AP={ap[cls_idx]:.3f}"
        lines.append(line)

    return lines


def load_default_font() -> ImageFont.FreeTypeFont | ImageFont.ImageFont:
    """Load a truetype font when available, otherwise fall back to default."""

    try:
        return ImageFont.truetype("DejaVuSans.ttf", size=14)
    except Exception:  # pragma: no cover - fallback when font unavailable
        return ImageFont.load_default()


def _resolve_class_name(class_names: Sequence[str], label: int) -> str:
    if 0 <= label < len(class_names):
        name = class_names[label]
    else:
        name = f"class_{label:02d}"

    if name.startswith("class_") and name[6:].isdigit():
        return f"class {int(name[6:]):02d}"
    return name


def render_detections(
    image: np.ndarray,
    prediction: Mapping[str, np.ndarray],
    target: Mapping[str, np.ndarray] | None,
    class_names: Sequence[str],
    score_threshold: float,
    class_thresholds: Mapping[int, float],
    draw_ground_truth: bool,
) -> PILImage:
    """Render detection predictions (and optional ground truth) onto an image."""

    if image.dtype != np.uint8:
        image_array = np.clip(image, 0, 255).astype(np.uint8)
    else:
        image_array = image

    pil = PILImage.fromarray(image_array)
    draw = ImageDraw.Draw(pil)
    font = load_default_font()

    boxes = np.asarray(prediction.get("boxes", np.empty((0, 4))), dtype=float)
    labels = np.asarray(prediction.get("labels", np.empty((0,), dtype=int)), dtype=int)
    scores = np.asarray(prediction.get("scores", np.empty((0,), dtype=float)), dtype=float)

    for box, label, score in zip(boxes, labels, scores):
        threshold = class_thresholds.get(int(label), score_threshold)
        if score < threshold:
            continue

        color = DEFAULT_COLORS[int(label) % len(DEFAULT_COLORS)] if len(DEFAULT_COLORS) else "#FF6B6B"
        x1, y1, x2, y2 = [float(coord) for coord in box]
        draw.rectangle([x1, y1, x2, y2], outline=color, width=2)

        caption = f"{_resolve_class_name(class_names, int(label))} {score:.2f}"
        text_width = draw.textlength(caption, font=font)
        draw.rectangle([x1, y1 - 16, x1 + text_width + 8, y1], fill=color)
        draw.text((x1 + 4, y1 - 14), caption, fill="white", font=font)

    if draw_ground_truth and target is not None:
        gt_boxes = np.asarray(target.get("boxes", np.empty((0, 4))), dtype=float)
        gt_labels = np.asarray(target.get("labels", np.empty((0,), dtype=int)), dtype=int)
        for box, label in zip(gt_boxes, gt_labels):
            x1, y1, x2, y2 = [float(coord) for coord in box]
            draw.rectangle([x1, y1, x2, y2], outline="#FFFFFF", width=1)
            caption = f"GT {_resolve_class_name(class_names, int(label))}"
            text_width = draw.textlength(caption, font=font)
            draw.rectangle([x1, y2, x1 + text_width + 6, y2 + 14], fill="#FFFFFF")
            draw.text((x1 + 3, y2), caption, fill="black", font=font)

    return pil


def save_detection_visual(
    image: np.ndarray,
    prediction: Mapping[str, np.ndarray],
    target: Mapping[str, np.ndarray] | None,
    class_names: Sequence[str],
    score_threshold: float,
    class_thresholds: Mapping[int, float],
    draw_ground_truth: bool,
    output_path: Path,
) -> None:
    """Render and persist a detection visualisation to ``output_path``."""

    visual = render_detections(
        image,
        prediction,
        target,
        class_names,
        score_threshold,
        class_thresholds,
        draw_ground_truth,
    )
    output_path.parent.mkdir(parents=True, exist_ok=True)
    visual.save(output_path)


def write_false_positive_report(
    fp_records: Sequence[Dict[str, object]],
    report_path: Path,
    *,
    split: str,
    score_threshold: float,
    class_score_thresholds: Mapping[int, float],
    iou_threshold: float,
) -> None:
    """Serialise detailed false-positive information to JSON."""

    report_payload = {
        "split": split,
        "score_threshold": score_threshold,
        "class_score_thresholds": dict(class_score_thresholds),
        "iou_threshold": iou_threshold,
        "false_positive_images": list(fp_records),
    }
    report_path.parent.mkdir(parents=True, exist_ok=True)
    report_path.write_text(json.dumps(report_payload, indent=2, ensure_ascii=False))


def write_false_positive_list(fp_records: Sequence[Dict[str, object]], list_path: Path) -> None:
    """Write newline separated image identifiers that triggered false positives."""

    stems = sorted({str(record["image_id"]) for record in fp_records})
    list_path.parent.mkdir(parents=True, exist_ok=True)
    list_path.write_text("\n".join(stems) + ("\n" if stems else ""))


class SmoothedValue:
    """Track a series of values and provide access to smoothed statistics."""

    def __init__(self, window_size: int = 20) -> None:
        self.window_size = window_size
        self.deque: List[float] = []
        self.total = 0.0
        self.count = 0

    def update(self, value: float) -> None:
        if len(self.deque) == self.window_size:
            self.total -= self.deque.pop(0)
        self.deque.append(value)
        self.total += value
        self.count += 1

    @property
    def avg(self) -> float:
        if not self.deque:
            return 0.0
        return self.total / len(self.deque)


class MetricLogger:
    """Helper class that logs running averages for multiple metrics."""

    def __init__(self) -> None:
        self.meters: Dict[str, SmoothedValue] = {}

    def update(self, **kwargs: float) -> None:
        for name, value in kwargs.items():
            if name not in self.meters:
                self.meters[name] = SmoothedValue()
            self.meters[name].update(float(value))

    def format(self) -> str:
        parts = [f"{name}: {meter.avg:.4f}" for name, meter in self.meters.items()]
        return " | ".join(parts)


## Model construction


In [5]:
"""Model building utilities for Faster R-CNN based detectors."""

import logging
from pathlib import Path
from typing import Optional

import torch
from torch import nn
from torchvision.models.detection import fasterrcnn_resnet50_fpn_v2
from torchvision.models.detection import FasterRCNN_ResNet50_FPN_V2_Weights
from torchvision.models.detection.faster_rcnn import FastRCNNPredictor
from torchvision.models.detection.rpn import AnchorGenerator, RPNHead



LOGGER = logging.getLogger(__name__)


def _save_state_dict(model: nn.Module, path: Path) -> None:
    try:
        torch.save(model.state_dict(), path)
        LOGGER.info("Saved pretrained weights to %s", path)
    except Exception as exc:  # pragma: no cover - safety net
        LOGGER.warning("Unable to save pretrained weights: %s", exc)


def build_model(
    dataset_cfg: DatasetConfig,
    train_cfg: TrainingConfig,
    device: Optional[torch.device] = None,
) -> nn.Module:
    """Create a Faster R-CNN model adjusted for the project dataset."""
    device = device or torch.device("cuda" if torch.cuda.is_available() else "cpu")

    model = _load_pretrained_model(train_cfg)

    in_features = model.roi_heads.box_predictor.cls_score.in_features
    num_classes_with_background = dataset_cfg.num_classes + 1
    model.roi_heads.box_predictor = FastRCNNPredictor(
        in_features, num_classes_with_background
    )

    if train_cfg.small_object:
        anchor_generator = AnchorGenerator(
            #sizes=((16,), (32,), (64,), (128,), (256,)),
            #aspect_ratios=((0.5, 1.0, 2.0),) * 5,
            sizes=((16, 24), (32, 48), (64, 96), (128, 192), (256, 384)),
            aspect_ratios=((0.2, 0.5, 1.0, 2.0, 5.0),) * 5
        )
        model.rpn.anchor_generator = anchor_generator
        LOGGER.info("Using custom anchor sizes optimised for small objects")

        # 1. 从旧的 RPNHead 中获取 in_channels
        in_channels = model.rpn.head.cls_logits.in_channels

        # 2. 从新的 AnchorGenerator 获取每个位置的锚框数
        #    (例如，5 个 aspect_ratios * 2 个 sizes = 10)
        num_anchors_per_location = anchor_generator.num_anchors_per_location()[0]
    
        # 3. 创建并替换 RPNHead
        new_head = RPNHead(in_channels, num_anchors_per_location)
        model.rpn.head = new_head
        LOGGER.info(
            "Re-created RPN head for %d anchors per location to match AnchorGenerator.",
            num_anchors_per_location
        )

    model.to(device)
    return model


def _load_pretrained_model(train_cfg: TrainingConfig) -> nn.Module:
    """Load a Faster R-CNN model with fallback to local weights when offline."""
    pretrained_path = train_cfg.pretrained_weights_path
    pretrained_path.parent.mkdir(parents=True, exist_ok=True)

    weights_enum = FasterRCNN_ResNet50_FPN_V2_Weights.DEFAULT
    try:
        model = fasterrcnn_resnet50_fpn_v2(weights=weights_enum)
        LOGGER.info("Loaded torchvision Faster R-CNN weights")
        if not pretrained_path.exists():
            _save_state_dict(model, pretrained_path)
        return model
    except Exception:  # pragma: no cover - fallback when torchvision download fails
        LOGGER.warning("Falling back to locally saved pretrained detector weights")
        if not pretrained_path.exists():
            raise RuntimeError(
                "No pretrained weights available. Download the torchvision weights manually "
                "and place them at %s" % pretrained_path
            )
        state_dict = torch.load(pretrained_path, map_location="cpu")
        model = fasterrcnn_resnet50_fpn_v2(weights=None)
        model.load_state_dict(state_dict)
        return model


## Training utilities


In [6]:
"""Training script for the electrical component Faster R-CNN detector."""

import argparse
import contextlib
import inspect
import json
import logging
import numpy as np
from dataclasses import asdict
from pathlib import Path
from typing import Dict, List, Set
import pickle
import torch
from torch import nn
from torch.optim import AdamW
from torch.utils.data import DataLoader
from tqdm import tqdm

try:
    from torch.amp import GradScaler  # PyTorch 2.1+
except ImportError:  # pragma: no cover - compatibility path
    from torch.cuda.amp import GradScaler  # type: ignore[attr-defined]

try:
    from torch.serialization import add_safe_globals
except ImportError:  # pragma: no cover - compatibility path
    add_safe_globals = None




LOGGER = logging.getLogger("train")


def parse_args() -> argparse.Namespace:
    parser = argparse.ArgumentParser(description=__doc__)
    parser.add_argument("--data-dir", type=Path, default=DatasetConfig().base_dir)
    parser.add_argument("--epochs", type=int, default=20)
    parser.add_argument("--batch-size", type=int, default=4)
    parser.add_argument("--lr", type=float, default=1e-4, help="Learning rate")
    parser.add_argument("--weight-decay", type=float, default=1e-4)
    parser.add_argument("--num-workers", type=int, default=TrainingConfig().num_workers)
    parser.add_argument("--no-augmentation", action="store_true", help="Disable data augmentation")
    parser.add_argument(
        "--mosaic-prob",
        type=float,
        default=TrainingConfig().mosaic_prob,
        help="Probability of applying mosaic augmentation (four-image collage)",
    )
    parser.add_argument(
        "--mixup-prob",
        type=float,
        default=TrainingConfig().mixup_prob,
        help="Probability of mixing an additional image into the current sample",
    )
    parser.add_argument(
        "--mixup-alpha",
        type=float,
        default=TrainingConfig().mixup_alpha,
        help="Alpha parameter for the Beta distribution controlling MixUp strength",
    )
    parser.add_argument(
        "--scale-jitter", nargs=2, type=float, metavar=("MIN", "MAX"), default=None,
        help="Uniform scale jitter range applied to each image before flips/jitter",
    )
    parser.add_argument(
        "--rotation-prob",
        type=float,
        default=TrainingConfig().rotation_prob,
        help="Probability of applying a random in-place rotation.",
    )
    parser.add_argument(
        "--rotation-max-degrees",
        type=float,
        default=TrainingConfig().rotation_max_degrees,
        help="Maximum absolute angle in degrees for random rotations.",
    )
    parser.add_argument(
        "--affine-prob",
        type=float,
        default=TrainingConfig().affine_prob,
        help="Probability of applying a random affine transform (translate/scale/shear).",
    )
    parser.add_argument(
        "--affine-translate",
        nargs=2,
        type=float,
        metavar=("FRAC_X", "FRAC_Y"),
        default=None,
        help="Maximum absolute translation as a fraction of image width/height.",
    )
    parser.add_argument(
        "--affine-scale",
        nargs=2,
        type=float,
        metavar=("MIN", "MAX"),
        default=None,
        help="Uniform scaling range applied during affine augmentation.",
    )
    parser.add_argument(
        "--affine-shear",
        nargs=2,
        type=float,
        metavar=("SHEAR_X", "SHEAR_Y"),
        default=None,
        help="Maximum absolute shear angles (degrees) for the affine augmentation.",
    )
    parser.add_argument("--small-object", action="store_true", help="Use smaller RPN anchors")
    parser.add_argument("--score-threshold", type=float, default=0.6)
    parser.add_argument(
        "--class-threshold",
        action="append",
        default=[],
        metavar="CLS=THRESH",
        help="Override per-class score thresholds (e.g. --class-threshold 3=0.8)",
    )
    parser.add_argument("--iou-threshold", type=float, default=0.5)
    parser.add_argument("--no-amp", action="store_true", help="Disable automatic mixed precision")
    parser.add_argument("--eval-interval", type=int, default=1)
    parser.add_argument("--seed", type=int, default=2024)
    parser.add_argument("--checkpoint", type=Path, default=TrainingConfig().checkpoint_path)
    parser.add_argument("--pretrained-path", type=Path, default=TrainingConfig().pretrained_weights_path)
    parser.add_argument("--resume", action="store_true", help="Resume training from a saved checkpoint.")
    parser.add_argument(
        "--resume-path",
        type=Path,
        default=None,
        help="Optional path to the checkpoint used for resuming training. Defaults to the last-checkpoint file.",
    )
    parser.add_argument("--log-every", type=int, default=20)
    parser.add_argument("--train-split", default=DatasetConfig().train_split)
    parser.add_argument("--valid-split", default=DatasetConfig().valid_split)
    parser.add_argument("--num-classes", type=int, default=DatasetConfig().num_classes)
    parser.add_argument(
        "--exclude-list",
        action="append",
        default=[],
        type=Path,
        help="Path(s) to files listing image stems to exclude from training (text or JSON).",
    )
    parser.add_argument(
        "--exclude-sample",
        action="append",
        default=[],
        help="Additional image stems to remove from the training split.",
    )
    parser.add_argument(
        "--fp-dir",
        type=Path,
        default=None,
        help="Directory where false-positive visualisations from the final epoch are written.",
    )
    parser.add_argument(
        "--fp-report",
        type=Path,
        help="Optional path to write a JSON report describing final-epoch false positives.",
    )
    parser.add_argument(
        "--fp-list",
        type=Path,
        help="Optional path to write newline separated image stems containing final-epoch false positives.",
    )
    parser.add_argument(
        "--fp-class",
        action="append",
        type=int,
        default=[],
        help="Restrict false-positive exports to specific classes (repeatable).",
    )
    return parser.parse_args()


def _normalise_stem(value: str) -> str:
    return Path(value).stem if value else value


def _load_exclusions_from_file(path: Path) -> List[str]:
    try:
        text = path.read_text(encoding="utf-8")
    except OSError as exc:
        LOGGER.warning("Unable to read exclude list %s: %s", path, exc)
        return []

    try:
        payload = json.loads(text)
    except json.JSONDecodeError:
        lines = [line.strip() for line in text.splitlines() if line.strip()]
        return [_normalise_stem(line) for line in lines]

    stems: List[str] = []
    if isinstance(payload, dict):
        for key in ("stems", "image_ids", "images"):
            if key in payload and isinstance(payload[key], list):
                stems.extend(_normalise_stem(str(item)) for item in payload[key])
        if "false_positive_images" in payload and isinstance(payload["false_positive_images"], list):
            for entry in payload["false_positive_images"]:
                if isinstance(entry, dict):
                    if "image_id" in entry:
                        stems.append(_normalise_stem(str(entry["image_id"])))
                    elif "image" in entry:
                        stems.append(_normalise_stem(str(entry["image"])))
    elif isinstance(payload, list):
        for entry in payload:
            if isinstance(entry, dict):
                if "image_id" in entry:
                    stems.append(_normalise_stem(str(entry["image_id"])))
                elif "image" in entry:
                    stems.append(_normalise_stem(str(entry["image"])))
            else:
                stems.append(_normalise_stem(str(entry)))

    return [stem for stem in stems if stem]


def _resolve_exclusions(paths: List[Path], samples: List[str]) -> List[str]:
    stems: Set[str] = {_normalise_stem(sample) for sample in samples if sample}
    for path in paths:
        if path is None:
            continue
        if not path.exists():
            LOGGER.warning("Exclude list %s does not exist; skipping.", path)
            continue
        stems.update(_load_exclusions_from_file(path))
    return sorted(stem for stem in stems if stem)


def prepare_configs(args: argparse.Namespace) -> tuple[DatasetConfig, TrainingConfig]:
    default_train_cfg = TrainingConfig()
    dataset_cfg = DatasetConfig(
        base_dir=args.data_dir,
        train_split=args.train_split,
        valid_split=args.valid_split,
        num_classes=args.num_classes,
    )

    class_thresholds = DEFAULT_CLASS_SCORE_THRESHOLDS.copy()
    overrides = parse_class_threshold_entries(args.class_threshold)
    class_thresholds.update(overrides)

    exclude_samples = _resolve_exclusions(args.exclude_list, args.exclude_sample)

    if args.fp_class:
        fp_class_values = sorted({int(value) for value in args.fp_class})
        fp_classes = tuple(fp_class_values)
    else:
        fp_classes = default_train_cfg.fp_classes

    if args.fp_dir is not None:
        fp_dir = Path(args.fp_dir)
    else:
        fp_dir = default_train_cfg.fp_visual_dir

    fp_report_path = Path(args.fp_report) if args.fp_report is not None else args.fp_report
    fp_list_path = Path(args.fp_list) if args.fp_list is not None else args.fp_list

    if args.scale_jitter:
        scale_min, scale_max = args.scale_jitter
    else:
        scale_min = default_train_cfg.scale_jitter_min
        scale_max = default_train_cfg.scale_jitter_max

    rotation_prob = args.rotation_prob
    rotation_max = args.rotation_max_degrees

    affine_prob = args.affine_prob
    if args.affine_translate is not None:
        affine_translate = (float(args.affine_translate[0]), float(args.affine_translate[1]))
    else:
        affine_translate = default_train_cfg.affine_translate

    if args.affine_scale is not None:
        affine_scale_range = (float(args.affine_scale[0]), float(args.affine_scale[1]))
    else:
        affine_scale_range = default_train_cfg.affine_scale_range

    if args.affine_shear is not None:
        affine_shear = (float(args.affine_shear[0]), float(args.affine_shear[1]))
    else:
        affine_shear = default_train_cfg.affine_shear

    checkpoint_path = Path(args.checkpoint)
    pretrained_path = Path(args.pretrained_path)
    last_checkpoint_path = checkpoint_path.parent / "last_checkpoint.pth"
    if args.resume_path is not None:
        resume_path = Path(args.resume_path)
    elif args.resume:
        resume_path = last_checkpoint_path
    else:
        resume_path = None

    train_cfg = TrainingConfig(
        epochs=args.epochs,
        batch_size=args.batch_size,
        learning_rate=args.lr,
        weight_decay=args.weight_decay,
        num_workers=args.num_workers,
        amp=not args.no_amp,
        augmentation=not args.no_augmentation,
        mosaic_prob=args.mosaic_prob,
        mixup_prob=args.mixup_prob,
        mixup_alpha=args.mixup_alpha,
        scale_jitter_min=scale_min,
        scale_jitter_max=scale_max,
        rotation_prob=rotation_prob,
        rotation_max_degrees=rotation_max,
        affine_prob=affine_prob,
        affine_translate=affine_translate,
        affine_scale_range=affine_scale_range,
        affine_shear=affine_shear,
        small_object=args.small_object,
        score_threshold=args.score_threshold,
        iou_threshold=args.iou_threshold,
        eval_interval=args.eval_interval,
        seed=args.seed,
        checkpoint_path=checkpoint_path,
        pretrained_weights_path=pretrained_path,
        resume=args.resume,
        resume_path=resume_path,
        last_checkpoint_path=last_checkpoint_path,
        log_every=args.log_every,
        class_score_thresholds=class_thresholds,
        exclude_samples=tuple(exclude_samples),
        fp_visual_dir=fp_dir,
        fp_report_path=fp_report_path,
        fp_list_path=fp_list_path,
        fp_classes=fp_classes,
    )
    train_cfg.ensure_directories()
    return dataset_cfg, train_cfg


def move_to_device(targets: List[Dict[str, torch.Tensor]], device: torch.device) -> List[Dict[str, torch.Tensor]]:
    return [{k: v.to(device) for k, v in target.items()} for target in targets]


def train_one_epoch(
    model: nn.Module,
    loader: DataLoader,
    optimizer: torch.optim.Optimizer,
    scaler: GradScaler,
    device: torch.device,
    amp: bool,
    log_every: int,
) -> float:
    model.train()
    metric_logger = MetricLogger()
    progress = tqdm(loader, desc="Train", leave=False)

    for step, (images, targets) in enumerate(progress, start=1):
        images = [img.to(device) for img in images]
        targets = move_to_device(targets, device)

        optimizer.zero_grad()
        autocast_enabled = amp and device.type == "cuda"
        autocast_context = (
            torch.amp.autocast(device_type="cuda") if autocast_enabled else contextlib.nullcontext()
        )
        with autocast_context:
            loss_dict = model(images, targets)
            loss = sum(loss_dict.values())

        if torch.isfinite(loss):
            scaler.scale(loss).backward()
            scaler.step(optimizer)
            scaler.update()
        else:  # pragma: no cover - guard for invalid losses
            LOGGER.warning("Skipping step %s due to non-finite loss", step)
            scaler.update()
            continue

        metric_logger.update(loss=loss.item())
        if step % log_every == 0:
            progress.set_postfix_str(metric_logger.format())

    return metric_logger.meters.get("loss").avg if metric_logger.meters else 0.0


@torch.no_grad()
def evaluate(
    model: nn.Module,
    loader: DataLoader,
    device: torch.device,
    dataset_cfg: DatasetConfig,
    train_cfg: TrainingConfig,
    *,
    dataset: ElectricalComponentsDataset | None = None,
    collect_details: bool = False,
) -> tuple[Dict[str, torch.Tensor | float | List[float]], List[Dict[str, object]]]:
    was_training = model.training
    model.eval()

    predictions = []
    targets_for_eval = []
    total_loss = 0.0
    num_batches = 0

    sample_details: List[Dict[str, object]] = []
    dataset_ref = dataset if dataset is not None else getattr(loader, "dataset", None)

    for images, targets in loader:
        images = [img.to(device) for img in images]
        targets_device = move_to_device(targets, device)

        model.train()
        loss_dict = model(images, targets_device)
        total_loss += sum(loss_dict.values()).item()
        num_batches += 1
        model.eval()

        outputs = model(images)
        for output, target_device, target_cpu in zip(outputs, targets_device, targets):
            boxes_np = output["boxes"].detach().cpu().numpy()
            scores_np = output["scores"].detach().cpu().numpy()
            labels_np = output["labels"].detach().cpu().numpy().astype(np.int64, copy=False)
            keep = score_threshold_mask(
                scores_np,
                labels_np,
                train_cfg.score_threshold,
                train_cfg.class_score_thresholds,
            )
            boxes_np = boxes_np[keep]
            scores_np = scores_np[keep]
            labels_np = labels_np[keep].astype(np.int64, copy=True)
            if labels_np.size:
                labels_np -= 1

            target_boxes = target_device["boxes"].detach().cpu().numpy()
            target_labels = target_device["labels"].detach().cpu().numpy().astype(np.int64, copy=True)
            if target_labels.size:
                gt_keep = target_labels > 0
                target_boxes = target_boxes[gt_keep]
                target_labels = target_labels[gt_keep] - 1

            prediction_np = {
                "boxes": boxes_np,
                "scores": scores_np,
                "labels": labels_np,
            }
            target_np = {
                "boxes": target_boxes,
                "labels": target_labels,
            }

            predictions.append(prediction_np)
            targets_for_eval.append(target_np)

            if collect_details:
                image_identifier = target_cpu.get("image_id", -1)
                if isinstance(image_identifier, torch.Tensor):
                    image_index = int(image_identifier.item())
                else:
                    try:
                        image_index = int(image_identifier)
                    except Exception:
                        image_index = -1
                if dataset_ref is not None and 0 <= image_index < len(getattr(dataset_ref, "image_stems", [])):
                    image_id = dataset_ref.image_stems[image_index]
                else:
                    image_id = f"{dataset_cfg.valid_split}_{len(sample_details):04d}"

                fp_info = identify_false_positive_predictions(
                    prediction_np,
                    target_np,
                    dataset_cfg.num_classes,
                    train_cfg.iou_threshold,
                )

                sample_details.append(
                    {
                        "image_index": image_index,
                        "image_id": image_id,
                        "prediction": prediction_np,
                        "target": target_np,
                        "false_positives": fp_info,
                    }
                )

    metrics = compute_detection_metrics(
        predictions, targets_for_eval, dataset_cfg.num_classes, train_cfg.iou_threshold
    )
    metrics["loss"] = total_loss / max(num_batches, 1)

    if was_training:
        model.train()
    return metrics, sample_details


def save_checkpoint(
    *,
    path: Path,
    model: nn.Module,
    optimizer: torch.optim.Optimizer | None,
    scaler: GradScaler | None,
    epoch: int,
    best_map: float,
    history: List[Dict[str, float]],
    train_cfg: TrainingConfig,
) -> None:
    checkpoint = {
        "model": model.state_dict(),
        "optimizer": optimizer.state_dict() if optimizer is not None else None,
        "scaler": scaler.state_dict() if scaler is not None and getattr(scaler, "is_enabled", lambda: True)() else None,
        "epoch": epoch,
        "best_map": float(best_map),
        "history": list(history),
        "config": asdict(train_cfg),
    }
    torch.save(checkpoint, path)
    LOGGER.info("Saved checkpoint to %s (epoch %d, best mAP %.4f)", path, epoch, best_map)


def _load_checkpoint_file(path: Path, device: torch.device) -> object:
    load_kwargs = {"map_location": device}
    try:
        return torch.load(path, **load_kwargs)
    except pickle.UnpicklingError:
        safe_objects: List[object] = []
        if add_safe_globals is not None:
            safe_objects.append(TrainingConfig)
            try:
                safe_objects.append(type(Path(".")))
            except Exception:  # pragma: no cover - defensive
                pass
            add_safe_globals(safe_objects)
        load_params = inspect.signature(torch.load).parameters
        if "weights_only" in load_params:
            load_kwargs["weights_only"] = False
        return torch.load(path, **load_kwargs)


def load_checkpoint(
    *,
    path: Path,
    model: nn.Module,
    optimizer: torch.optim.Optimizer | None,
    scaler: GradScaler | None,
    device: torch.device,
) -> tuple[int, float, List[Dict[str, float]]]:
    if not path.exists():
        LOGGER.warning("Checkpoint %s does not exist; starting fresh.", path)
        return 1, -float("inf"), []

    checkpoint_obj = _load_checkpoint_file(path, device)

    if isinstance(checkpoint_obj, dict) and "model" in checkpoint_obj:
        model.load_state_dict(checkpoint_obj["model"])
        if optimizer is not None and checkpoint_obj.get("optimizer") is not None:
            optimizer.load_state_dict(checkpoint_obj["optimizer"])
        if scaler is not None and checkpoint_obj.get("scaler") is not None:
            try:
                scaler.load_state_dict(checkpoint_obj["scaler"])
            except Exception as exc:  # pragma: no cover - fallback if scaler config changed
                LOGGER.warning("Unable to load scaler state from %s: %s", path, exc)
        epoch = int(checkpoint_obj.get("epoch", 0))
        best_map = float(checkpoint_obj.get("best_map", -float("inf")))
        history = checkpoint_obj.get("history", [])
        LOGGER.info(
            "Loaded checkpoint from %s (epoch %d, best mAP %.4f)",
            path,
            epoch,
            best_map,
        )
        return epoch + 1, best_map, list(history)

    # Legacy checkpoint: assume it's a plain state dict.
    model.load_state_dict(checkpoint_obj)
    LOGGER.info("Loaded model weights from legacy checkpoint %s", path)
    return 1, -float("inf"), []


def _create_grad_scaler(amp_enabled: bool, device: torch.device) -> GradScaler:
    """Create a GradScaler compatible with both legacy and new AMP APIs."""
    scaler_enabled = amp_enabled and device.type == "cuda"
    init_params = inspect.signature(GradScaler.__init__).parameters
    kwargs = {"enabled": scaler_enabled}
    if "device_type" in init_params:
        kwargs["device_type"] = device.type
    return GradScaler(**kwargs)


def export_false_positive_visuals(
    *,
    dataset: ElectricalComponentsDataset,
    dataset_cfg: DatasetConfig,
    train_cfg: TrainingConfig,
    sample_details: List[Dict[str, object]],
) -> List[Dict[str, object]]:
    """Export false-positive visuals for the configured classes and return records."""

    if not train_cfg.fp_visual_dir:
        return []

    relevant_classes = {int(cls) for cls in train_cfg.fp_classes}
    if not relevant_classes:
        return []

    output_dir = train_cfg.fp_visual_dir
    output_dir.mkdir(parents=True, exist_ok=True)

    # Keep only the last-epoch visuals by clearing previous outputs.
    for existing in output_dir.glob("*.png"):
        try:
            existing.unlink()
        except OSError:
            LOGGER.debug("Unable to remove previous FP visual %s", existing)

    fp_records: List[Dict[str, object]] = []

    for detail in sample_details:
        raw_records = detail.get("false_positives", [])
        if not raw_records:
            continue

        filtered = [fp for fp in raw_records if int(fp.get("class", -1)) in relevant_classes]
        if not filtered:
            continue

        image_id = str(detail.get("image_id", detail.get("image_index", "unknown")))
        try:
            image_np, _, _ = dataset._load_raw_sample(image_id)
        except Exception as exc:  # pragma: no cover - defensive
            LOGGER.warning("Skipping FP visual for %s due to load error: %s", image_id, exc)
            continue

        output_path = output_dir / f"{image_id}.png"
        save_detection_visual(
            image_np,
            detail.get("prediction", {}),
            detail.get("target", {}),
            dataset_cfg.class_names,
            train_cfg.score_threshold,
            train_cfg.class_score_thresholds,
            True,
            output_path,
        )
        fp_records.append({"image_id": image_id, "false_positives": filtered})

    return fp_records


def run_training(args: argparse.Namespace) -> None:
    dataset_cfg, train_cfg = prepare_configs(args)

    logging.basicConfig(level=logging.INFO, format="%(asctime)s | %(levelname)s | %(message)s")
    set_seed(train_cfg.seed)

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model = build_model(dataset_cfg, train_cfg, device=device)

    train_dataset = ElectricalComponentsDataset(
        root=dataset_cfg.base_dir,
        split=dataset_cfg.train_split,
        class_names=dataset_cfg.class_names,
        transform=AugmentationParams(
            mosaic_prob=train_cfg.mosaic_prob,
            mixup_prob=train_cfg.mixup_prob,
            mixup_alpha=train_cfg.mixup_alpha,
            scale_jitter_range=(train_cfg.scale_jitter_min, train_cfg.scale_jitter_max),
            rotation_prob=train_cfg.rotation_prob,
            rotation_max_degrees=train_cfg.rotation_max_degrees,
            affine_prob=train_cfg.affine_prob,
            affine_translate=train_cfg.affine_translate,
            affine_scale_range=train_cfg.affine_scale_range,
            affine_shear=train_cfg.affine_shear,
        ),
        use_augmentation=train_cfg.augmentation,
        exclude_stems=train_cfg.exclude_samples,
    )
    valid_dataset = ElectricalComponentsDataset(
        root=dataset_cfg.base_dir,
        split=dataset_cfg.valid_split,
        class_names=dataset_cfg.class_names,
        use_augmentation=False,
    )

    if train_cfg.exclude_samples:
        sample_preview = ", ".join(train_cfg.exclude_samples[:5])
        if len(train_cfg.exclude_samples) > 5:
            sample_preview += ", ..."
        LOGGER.info(
            "Excluding %d training samples: %s",
            len(train_cfg.exclude_samples),
            sample_preview,
        )

    train_loader = create_data_loaders(
        train_dataset,
        batch_size=train_cfg.batch_size,
        shuffle=True,
        num_workers=train_cfg.num_workers,
    )
    if train_cfg.num_workers > 1:
        valid_workers = max(1, train_cfg.num_workers // 2)
    else:
        valid_workers = train_cfg.num_workers

    valid_loader = create_data_loaders(
        valid_dataset,
        batch_size=train_cfg.batch_size,
        shuffle=False,
        num_workers=valid_workers,
    )

    optimizer = AdamW(model.parameters(), lr=train_cfg.learning_rate, weight_decay=train_cfg.weight_decay)
    scaler = _create_grad_scaler(train_cfg.amp, device)

    start_epoch = 1
    best_map = -float("inf")
    history: List[Dict[str, float]] = []

    if train_cfg.resume or train_cfg.resume_path is not None:
        resume_source = train_cfg.resume_path or train_cfg.last_checkpoint_path
        start_epoch, best_map, history = load_checkpoint(
            path=resume_source,
            model=model,
            optimizer=optimizer,
            scaler=scaler,
            device=device,
        )
        if start_epoch > train_cfg.epochs:
            LOGGER.info(
                "Checkpoint %s already reached epoch %d; target epochs=%d. Nothing left to train. "
                "Increase --epochs if you want to continue.",
                resume_source,
                start_epoch - 1,
                train_cfg.epochs,
            )
            return

    for epoch in range(start_epoch, train_cfg.epochs + 1):
        LOGGER.info("Epoch %s/%s", epoch, train_cfg.epochs)
        train_loss = train_one_epoch(
            model, train_loader, optimizer, scaler, device, train_cfg.amp, train_cfg.log_every
        )

        should_evaluate = (epoch % train_cfg.eval_interval == 0) or (epoch == train_cfg.epochs)
        collect_fp = should_evaluate and epoch == train_cfg.epochs

        if should_evaluate:
            metrics, sample_details = evaluate(
                model,
                valid_loader,
                device,
                dataset_cfg,
                train_cfg,
                dataset=valid_dataset,
                collect_details=collect_fp,
            )
            metric_lines = format_epoch_metrics(epoch, train_loss, metrics, dataset_cfg)
            emit_metric_lines(metric_lines, logger=LOGGER)

            history.append(
                {
                    "epoch": epoch,
                    "train_loss": float(train_loss),
                    "val_loss": float(metrics["loss"]),
                    "mAP": float(metrics["mAP"]),
                }
            )

            if metrics["mAP"] > best_map:
                best_map = float(metrics["mAP"])
                save_checkpoint(
                    path=train_cfg.checkpoint_path,
                    model=model,
                    optimizer=optimizer,
                    scaler=scaler,
                    epoch=epoch,
                    best_map=best_map,
                    history=history,
                    train_cfg=train_cfg,
                )

            if collect_fp:
                fp_records = export_false_positive_visuals(
                    dataset=valid_dataset,
                    dataset_cfg=dataset_cfg,
                    train_cfg=train_cfg,
                    sample_details=sample_details,
                )
                if train_cfg.fp_report_path:
                    write_false_positive_report(
                        fp_records,
                        train_cfg.fp_report_path,
                        split=dataset_cfg.valid_split,
                        score_threshold=train_cfg.score_threshold,
                        class_score_thresholds=train_cfg.class_score_thresholds,
                        iou_threshold=train_cfg.iou_threshold,
                    )
                    LOGGER.info(
                        "Saved false-positive report for %d images to %s",
                        len(fp_records),
                        train_cfg.fp_report_path,
                    )
                if train_cfg.fp_list_path:
                    write_false_positive_list(fp_records, train_cfg.fp_list_path)
                    LOGGER.info(
                        "Saved list of %d false-positive image ids to %s",
                        len({str(record["image_id"]) for record in fp_records}),
                        train_cfg.fp_list_path,
                    )
        else:
            LOGGER.info(
                "Epoch %02d | train loss %.4f | evaluation skipped (eval_interval=%d)",
                epoch,
                train_loss,
                train_cfg.eval_interval,
            )

        save_checkpoint(
            path=train_cfg.last_checkpoint_path,
            model=model,
            optimizer=optimizer,
            scaler=scaler,
            epoch=epoch,
            best_map=best_map,
            history=history,
            train_cfg=train_cfg,
        )

    (train_cfg.output_dir / "training_history.json").write_text(json.dumps(history, indent=2))
    LOGGER.info("Training complete. Best mAP: %.4f", best_map)


def main() -> None:
    args = parse_args()
    run_training(args)


if __name__ == "__main__":
    main()


## Inference helpers


In [7]:
"""Inference utilities for evaluating trained models on the test split."""

import argparse
import logging
import pickle
import inspect
from pathlib import Path
from typing import Dict, List

import numpy as np
import torch
from torch.utils.data import DataLoader
from tqdm import tqdm

try:
    from torch.serialization import add_safe_globals
except ImportError:  # pragma: no cover - compatibility for older PyTorch
    add_safe_globals = None




LOGGER = logging.getLogger("inference")


def parse_args() -> argparse.Namespace:
    parser = argparse.ArgumentParser(description=__doc__)
    parser.add_argument("--data-dir", type=Path, default=DatasetConfig().base_dir)
    parser.add_argument("--split", default=DatasetConfig().test_split)
    parser.add_argument("--checkpoint", type=Path, required=True)
    parser.add_argument("--output-dir", type=Path, default=InferenceConfig().output_dir)
    parser.add_argument("--score-threshold", type=float, default=InferenceConfig().score_threshold)
    parser.add_argument(
        "--class-threshold",
        action="append",
        default=[],
        metavar="CLS=THRESH",
        help="Override per-class score thresholds for inference",
    )
    parser.add_argument("--iou-threshold", type=float, default=0.5)
    parser.add_argument("--max-images", type=int, default=InferenceConfig().max_images)
    parser.add_argument("--num-classes", type=int, default=DatasetConfig().num_classes)
    parser.add_argument("--draw-ground-truth", action="store_true", default=False)
    parser.add_argument("--seed", type=int, default=2024)
    parser.add_argument("--pretrained-path", type=Path, default=TrainingConfig().pretrained_weights_path)
    parser.add_argument(
        "--fp-report",
        type=Path,
        help="Optional path to write a JSON report of images containing false positives.",
    )
    parser.add_argument(
        "--fp-list",
        type=Path,
        help="Optional path to write newline separated image stems that produced false positives.",
    )
    return parser.parse_args()


def _load_checkpoint_state(path: Path, device: torch.device) -> Dict[str, torch.Tensor]:
    load_kwargs = {"map_location": device}
    try:
        checkpoint_obj = torch.load(path, **load_kwargs)
    except pickle.UnpicklingError:
        if add_safe_globals is not None:
            add_safe_globals([TrainingConfig])
        load_params = inspect.signature(torch.load).parameters
        if "weights_only" in load_params:
            load_kwargs["weights_only"] = False
        checkpoint_obj = torch.load(path, **load_kwargs)

    if isinstance(checkpoint_obj, dict) and "model" in checkpoint_obj:
        return checkpoint_obj["model"]

    return checkpoint_obj
@torch.no_grad()
def run_inference(args: argparse.Namespace) -> None:
    dataset_cfg = DatasetConfig(
        base_dir=args.data_dir,
        test_split=args.split,
        num_classes=args.num_classes,
    )
    class_thresholds = DEFAULT_CLASS_SCORE_THRESHOLDS.copy()
    overrides = parse_class_threshold_entries(args.class_threshold)
    class_thresholds.update(overrides)

    inference_cfg = InferenceConfig(
        score_threshold=args.score_threshold,
        max_images=args.max_images,
        output_dir=args.output_dir,
        draw_ground_truth=args.draw_ground_truth,
        class_score_thresholds=class_thresholds,
    )
    inference_cfg.ensure_directories()

    train_cfg = TrainingConfig(
        augmentation=False,
        score_threshold=args.score_threshold,
        iou_threshold=args.iou_threshold,
        pretrained_weights_path=args.pretrained_path,
        class_score_thresholds=class_thresholds,
    )

    set_seed(args.seed)
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model = build_model(dataset_cfg, train_cfg, device=device)
    state_dict = _load_checkpoint_state(args.checkpoint, device)
    model.load_state_dict(state_dict)
    model.eval()

    dataset = ElectricalComponentsDataset(
        root=dataset_cfg.base_dir,
        split=args.split,
        class_names=dataset_cfg.class_names,
        use_augmentation=False,
    )
    loader = create_data_loaders(dataset, batch_size=1, shuffle=False, num_workers=0)

    predictions: List[Dict[str, np.ndarray]] = []
    targets_for_eval: List[Dict[str, np.ndarray]] = []

    progress = tqdm(loader, desc="Infer")
    fp_records: List[Dict[str, object]] = []
    for idx, (images, targets) in enumerate(progress):
        image = images[0].to(device)
        output = model([image])[0]

        boxes_np = output["boxes"].detach().cpu().numpy()
        scores_np = output["scores"].detach().cpu().numpy()
        labels_np = output["labels"].detach().cpu().numpy().astype(np.int64, copy=False)
        keep = score_threshold_mask(
            scores_np,
            labels_np,
            inference_cfg.score_threshold,
            inference_cfg.class_score_thresholds,
        )
        boxes_np = boxes_np[keep]
        scores_np = scores_np[keep]
        labels_np = labels_np[keep].astype(np.int64, copy=True)
        if labels_np.size:
            labels_np -= 1

        target_boxes = targets[0]["boxes"].detach().cpu().numpy()
        target_labels = targets[0]["labels"].detach().cpu().numpy().astype(np.int64, copy=True)
        if target_labels.size:
            gt_keep = target_labels > 0
            target_boxes = target_boxes[gt_keep]
            target_labels = target_labels[gt_keep] - 1

        prediction_np = {
            "boxes": boxes_np,
            "scores": scores_np,
            "labels": labels_np,
        }
        target_np = {
            "boxes": target_boxes,
            "labels": target_labels,
        }

        predictions.append(prediction_np)
        targets_for_eval.append(target_np)

        fp_details = identify_false_positive_predictions(
            prediction_np,
            target_np,
            dataset_cfg.num_classes,
            args.iou_threshold,
        )
        if fp_details:
            image_id = dataset.image_stems[idx] if idx < len(dataset.image_stems) else f"{args.split}_{idx:04d}"
            fp_records.append(
                {
                    "image_id": image_id,
                    "false_positives": fp_details,
                }
            )

        if idx < inference_cfg.max_images:
            image_np = (images[0].permute(1, 2, 0).numpy() * 255.0).astype(np.uint8)
            output_path = inference_cfg.output_dir / f"{args.split}_{idx:04d}.png"
            save_detection_visual(
                image_np,
                prediction_np,
                target_np if args.draw_ground_truth else None,
                dataset_cfg.class_names,
                inference_cfg.score_threshold,
                inference_cfg.class_score_thresholds,
                args.draw_ground_truth,
                output_path,
            )

    metrics = compute_detection_metrics(
        predictions, targets_for_eval, dataset_cfg.num_classes, args.iou_threshold
    )
    metric_lines = format_epoch_metrics(
        epoch=None,
        train_loss=None,
        metrics=metrics,
        dataset_cfg=dataset_cfg,
        header=f"Inference @ IoU {args.iou_threshold:.2f}",
    )
    emit_metric_lines(metric_lines, logger=LOGGER)

    if args.fp_report:
        write_false_positive_report(
            fp_records,
            args.fp_report,
            split=args.split,
            score_threshold=inference_cfg.score_threshold,
            class_score_thresholds=inference_cfg.class_score_thresholds,
            iou_threshold=args.iou_threshold,
        )
        LOGGER.info(
            "Wrote false-positive report for %d images to %s",
            len(fp_records),
            args.fp_report,
        )

    if args.fp_list:
        write_false_positive_list(fp_records, args.fp_list)
        LOGGER.info(
            "Wrote %d image ids with false positives to %s",
            len({str(record["image_id"]) for record in fp_records}),
            args.fp_list,
        )


def main() -> None:
    args = parse_args()
    logging.basicConfig(level=logging.INFO, format="%(asctime)s | %(levelname)s | %(message)s")
    run_inference(args)


if __name__ == "__main__":
    main()


## Example usage


1. 依次运行上面的代码单元，确保数据集配置、数据加载、训练与推理函数均已注册到当前会话。
2. 在示例代码单元中按需调整 `train_args` / `inference_args`（尤其是数据路径、批大小、是否恢复训练等），然后调用 `run_training(train_args)` 或 `run_inference(inference_args)`。
3. 训练过程中会把最佳权重写入 `outputs/best_model.pth`，每个 epoch 的断点写入 `outputs/last_checkpoint.pth`；推理会将可视化结果输出到 `InferenceConfig.output_dir`。
---
Tuning History
1. 改小lr
2. 改小0、16、25、30阈值
3. 调整窗口比例((0.2, 0.5, 1.0, 2.0, 5.0),) * len(anchor_sizes)
4. 增加epoch次数
5. 改大16阈值，改小25、30阈值
6. shuffle dataset
7. 改大 25、30阈值
8. 调整RPN比例

In [None]:
# Example usage inside the notebook
# 1) Configure training arguments
train_args = argparse.Namespace(
    data_dir=Path('/path/to/dataset'),
    epochs=2,
    batch_size=2,
    lr=5e-5,
    weight_decay=5e-5,
    num_workers=2,
    no_augmentation=True,
    mosaic_prob=TrainingConfig().mosaic_prob,
    mixup_prob=TrainingConfig().mixup_prob,
    mixup_alpha=TrainingConfig().mixup_alpha,
    scale_jitter=None,
    rotation_prob=TrainingConfig().rotation_prob,
    rotation_max_degrees=TrainingConfig().rotation_max_degrees,
    affine_prob=TrainingConfig().affine_prob,
    affine_translate=None,
    affine_scale=None,
    affine_shear=None,
    small_object=True,
    score_threshold=0.6,
    class_threshold=[],
    iou_threshold=0.5,
    no_amp=False,
    eval_interval=1,
    seed=37,
    checkpoint=Path('outputs/best_model.pth'),
    pretrained_path=TrainingConfig().pretrained_weights_path,
    resume=False,
    resume_path=None,
    log_every=20,
    train_split=DatasetConfig().train_split,
    valid_split=DatasetConfig().valid_split,
    num_classes=DatasetConfig().num_classes,
    exclude_list=[],
    exclude_sample=[],
    fp_dir=None,
    fp_report=None,
    fp_list=None,
    fp_class=[],
)

# 2) Launch training (writes checkpoints to outputs/)
# run_training(train_args)

# 3) Run inference once a checkpoint is available
# inference_args = argparse.Namespace(
#     data_dir=train_args.data_dir,
#     split=DatasetConfig().valid_split,
#     checkpoint=Path('outputs/best_model.pth'),
#     output_dir=Path('outputs/inference'),
#     score_threshold=0.6,
#     class_threshold=[],
#     iou_threshold=0.5,
#     max_images=50,
#     num_classes=train_args.num_classes,
#     draw_ground_truth=True,
#     seed=2024,
#     pretrained_path=TrainingConfig().pretrained_weights_path,
#     fp_report=None,
#     fp_list=None,
# )
# run_inference(inference_args)


Downloading: "https://download.pytorch.org/models/fasterrcnn_resnet50_fpn_v2_coco-dd69338a.pth" to /root/.cache/torch/hub/checkpoints/fasterrcnn_resnet50_fpn_v2_coco-dd69338a.pth
100%|██████████| 167M/167M [00:00<00:00, 200MB/s]
  scaler = GradScaler(enabled=train_cfg.amp and device.type == "cuda")
  image_tensor = torch.from_numpy(image).permute(2, 0, 1).float() / 255.0


Epoch 01 | train loss 0.7170 | val loss 0.7185 | mAP 0.1218
class 00 | P=0.000 R=0.000  TP=0 FP=0 FN=41 AP=0.000
class 01 | P=0.000 R=0.000  TP=0 FP=0 FN=35 AP=0.000
class 02 | P=0.000 R=0.000  TP=0 FP=0 FN=35 AP=0.000
class 04 | P=0.000 R=0.000  TP=0 FP=0 FN=27 AP=0.000
class 05 | P=0.000 R=0.000  TP=0 FP=0 FN=75 AP=0.000
class 06 | P=0.000 R=0.000  TP=0 FP=0 FN=5 AP=0.000
class 07 | P=0.000 R=0.000  TP=0 FP=0 FN=72 AP=0.000
class 09 | P=1.000 R=0.355  TP=11 FP=0 FN=20 AP=0.677
class 10 | P=0.000 R=0.000  TP=0 FP=0 FN=45 AP=0.000
class 11 | P=0.000 R=0.000  TP=0 FP=0 FN=31 AP=0.000
class 13 | P=0.000 R=0.000  TP=0 FP=0 FN=85 AP=0.000
class 14 | P=0.000 R=0.000  TP=0 FP=0 FN=93 AP=0.000
class 15 | P=0.000 R=0.000  TP=0 FP=0 FN=31 AP=0.000
class 16 | P=0.000 R=0.000  TP=0 FP=0 FN=6 AP=0.000
class 18 | P=0.000 R=0.000  TP=0 FP=0 FN=37 AP=0.000
class 19 | P=0.000 R=0.000  TP=0 FP=0 FN=31 AP=0.000
class 20 | P=0.727 R=0.533  TP=8 FP=3 FN=7 AP=0.664
class 21 | P=0.000 R=0.000  TP=0 FP=0 FN=

                                                                      

Epoch 02 | train loss 0.4304 | val loss 0.4355 | mAP 0.5096
class 00 | P=0.000 R=0.000  TP=0 FP=0 FN=41 AP=0.000
class 01 | P=0.810 R=0.486  TP=17 FP=4 FN=18 AP=0.684
class 02 | P=0.870 R=0.571  TP=20 FP=3 FN=15 AP=0.754
class 04 | P=0.826 R=0.704  TP=19 FP=4 FN=8 AP=0.786
class 05 | P=0.000 R=0.000  TP=0 FP=0 FN=75 AP=0.000
class 06 | P=1.000 R=0.800  TP=4 FP=0 FN=1 AP=0.900
class 07 | P=0.000 R=0.000  TP=0 FP=0 FN=72 AP=0.000
class 09 | P=0.931 R=0.871  TP=27 FP=2 FN=4 AP=0.929
class 10 | P=0.515 R=0.756  TP=34 FP=32 FN=11 AP=0.753
class 11 | P=0.833 R=0.806  TP=25 FP=5 FN=6 AP=0.887
class 13 | P=0.000 R=0.000  TP=0 FP=0 FN=85 AP=0.000
class 14 | P=0.000 R=0.000  TP=0 FP=0 FN=93 AP=0.000
class 15 | P=0.000 R=0.000  TP=0 FP=0 FN=31 AP=0.000
class 16 | P=0.400 R=0.333  TP=2 FP=3 FN=4 AP=0.300
class 18 | P=0.000 R=0.000  TP=0 FP=0 FN=37 AP=0.000
class 19 | P=0.000 R=0.000  TP=0 FP=0 FN=31 AP=0.000
class 20 | P=0.733 R=0.733  TP=11 FP=4 FN=4 AP=0.762
class 21 | P=1.000 R=0.143  TP=1 FP=0

                                                                      

Epoch 03 | train loss 0.2919 | val loss 0.3532 | mAP 0.6792
class 00 | P=0.000 R=0.000  TP=0 FP=0 FN=41 AP=0.000
class 01 | P=0.718 R=0.800  TP=28 FP=11 FN=7 AP=0.832
class 02 | P=0.789 R=0.857  TP=30 FP=8 FN=5 AP=0.892
class 04 | P=0.852 R=0.852  TP=23 FP=4 FN=4 AP=0.906
class 05 | P=0.893 R=0.667  TP=50 FP=6 FN=25 AP=0.806
class 06 | P=1.000 R=0.800  TP=4 FP=0 FN=1 AP=0.900
class 07 | P=0.000 R=0.000  TP=0 FP=1 FN=72 AP=0.000
class 09 | P=0.966 R=0.903  TP=28 FP=1 FN=3 AP=0.950
class 10 | P=0.704 R=0.844  TP=38 FP=16 FN=7 AP=0.880
class 11 | P=0.722 R=0.839  TP=26 FP=10 FN=5 AP=0.887
class 13 | P=1.000 R=0.024  TP=2 FP=0 FN=83 AP=0.512
class 14 | P=0.912 R=0.333  TP=31 FP=3 FN=62 AP=0.626
class 15 | P=0.950 R=0.613  TP=19 FP=1 FN=12 AP=0.781
class 16 | P=0.286 R=0.333  TP=2 FP=5 FN=4 AP=0.206
class 18 | P=1.000 R=0.081  TP=3 FP=0 FN=34 AP=0.541
class 19 | P=1.000 R=0.097  TP=3 FP=0 FN=28 AP=0.548
class 20 | P=0.786 R=0.733  TP=11 FP=3 FN=4 AP=0.808
class 21 | P=0.750 R=0.429  TP=3 FP

                                                                      

Epoch 04 | train loss 0.2484 | val loss 0.3071 | mAP 0.7506
class 00 | P=0.000 R=0.000  TP=0 FP=0 FN=41 AP=0.000
class 01 | P=0.800 R=0.800  TP=28 FP=7 FN=7 AP=0.864
class 02 | P=0.800 R=0.914  TP=32 FP=8 FN=3 AP=0.933
class 04 | P=0.862 R=0.926  TP=25 FP=4 FN=2 AP=0.950
class 05 | P=0.932 R=0.733  TP=55 FP=4 FN=20 AP=0.854
class 06 | P=1.000 R=0.800  TP=4 FP=0 FN=1 AP=0.900
class 07 | P=0.706 R=0.167  TP=12 FP=5 FN=60 AP=0.412
class 09 | P=0.933 R=0.903  TP=28 FP=2 FN=3 AP=0.948
class 10 | P=0.709 R=0.867  TP=39 FP=16 FN=6 AP=0.882
class 11 | P=0.634 R=0.839  TP=26 FP=15 FN=5 AP=0.867
class 13 | P=1.000 R=0.059  TP=5 FP=0 FN=80 AP=0.529
class 14 | P=0.750 R=0.839  TP=78 FP=26 FN=15 AP=0.865
class 15 | P=0.867 R=0.839  TP=26 FP=4 FN=5 AP=0.888
class 16 | P=0.600 R=0.500  TP=3 FP=2 FN=3 AP=0.566
class 18 | P=1.000 R=0.568  TP=21 FP=0 FN=16 AP=0.784
class 19 | P=0.917 R=0.710  TP=22 FP=2 FN=9 AP=0.837
class 20 | P=0.750 R=0.800  TP=12 FP=4 FN=3 AP=0.766
class 21 | P=0.286 R=0.286  TP=2 F

                                                                      

Epoch 05 | train loss 0.2702 | val loss 0.2924 | mAP 0.8014
class 00 | P=0.000 R=0.000  TP=0 FP=0 FN=41 AP=0.000
class 01 | P=0.744 R=0.829  TP=29 FP=10 FN=6 AP=0.884
class 02 | P=0.756 R=0.971  TP=34 FP=11 FN=1 AP=0.971
class 04 | P=0.897 R=0.963  TP=26 FP=3 FN=1 AP=0.973
class 05 | P=0.901 R=0.853  TP=64 FP=7 FN=11 AP=0.901
class 06 | P=0.833 R=1.000  TP=5 FP=1 FN=0 AP=0.962
class 07 | P=0.800 R=0.333  TP=24 FP=6 FN=48 AP=0.533
class 09 | P=0.938 R=0.968  TP=30 FP=2 FN=1 AP=0.981
class 10 | P=0.615 R=0.889  TP=40 FP=25 FN=5 AP=0.861
class 11 | P=0.758 R=0.806  TP=25 FP=8 FN=6 AP=0.873
class 13 | P=0.848 R=0.459  TP=39 FP=7 FN=46 AP=0.653
class 14 | P=0.849 R=0.849  TP=79 FP=14 FN=14 AP=0.882
class 15 | P=0.867 R=0.839  TP=26 FP=4 FN=5 AP=0.905
class 16 | P=0.750 R=0.500  TP=3 FP=1 FN=3 AP=0.686
class 18 | P=1.000 R=0.676  TP=25 FP=0 FN=12 AP=0.838
class 19 | P=0.926 R=0.806  TP=25 FP=2 FN=6 AP=0.886
class 20 | P=0.706 R=0.800  TP=12 FP=5 FN=3 AP=0.762
class 21 | P=0.600 R=0.429  TP=3

                                                                      

Epoch 06 | train loss 0.2430 | val loss 0.2736 | mAP 0.8187
class 00 | P=0.000 R=0.000  TP=0 FP=0 FN=41 AP=0.000
class 01 | P=0.848 R=0.800  TP=28 FP=5 FN=7 AP=0.882
class 02 | P=0.821 R=0.914  TP=32 FP=7 FN=3 AP=0.940
class 04 | P=0.867 R=0.963  TP=26 FP=4 FN=1 AP=0.979
class 05 | P=0.915 R=0.867  TP=65 FP=6 FN=10 AP=0.912
class 06 | P=0.833 R=1.000  TP=5 FP=1 FN=0 AP=0.995
class 07 | P=0.778 R=0.194  TP=14 FP=4 FN=58 AP=0.491
class 09 | P=0.938 R=0.968  TP=30 FP=2 FN=1 AP=0.982
class 10 | P=0.625 R=0.889  TP=40 FP=24 FN=5 AP=0.897
class 11 | P=0.676 R=0.806  TP=25 FP=12 FN=6 AP=0.861
class 13 | P=0.839 R=0.612  TP=52 FP=10 FN=33 AP=0.719
class 14 | P=0.764 R=0.871  TP=81 FP=25 FN=12 AP=0.912
class 15 | P=0.900 R=0.871  TP=27 FP=3 FN=4 AP=0.929
class 16 | P=0.500 R=0.500  TP=3 FP=3 FN=3 AP=0.540
class 18 | P=0.967 R=0.784  TP=29 FP=1 FN=8 AP=0.883
class 19 | P=1.000 R=0.871  TP=27 FP=0 FN=4 AP=0.935
class 20 | P=0.812 R=0.867  TP=13 FP=3 FN=2 AP=0.846
class 21 | P=0.800 R=0.571  TP=4 

                                                                      

Epoch 07 | train loss 0.2182 | val loss 0.2625 | mAP 0.8238
class 00 | P=0.000 R=0.000  TP=0 FP=0 FN=41 AP=0.000
class 01 | P=0.844 R=0.771  TP=27 FP=5 FN=8 AP=0.864
class 02 | P=0.744 R=0.914  TP=32 FP=11 FN=3 AP=0.939
class 04 | P=1.000 R=0.963  TP=26 FP=0 FN=1 AP=0.981
class 05 | P=0.929 R=0.867  TP=65 FP=5 FN=10 AP=0.915
class 06 | P=0.833 R=1.000  TP=5 FP=1 FN=0 AP=0.995
class 07 | P=0.694 R=0.347  TP=25 FP=11 FN=47 AP=0.501
class 09 | P=0.909 R=0.968  TP=30 FP=3 FN=1 AP=0.982
class 10 | P=0.714 R=0.889  TP=40 FP=16 FN=5 AP=0.916
class 11 | P=0.781 R=0.806  TP=25 FP=7 FN=6 AP=0.877
class 13 | P=0.742 R=0.776  TP=66 FP=23 FN=19 AP=0.763
class 14 | P=0.872 R=0.882  TP=82 FP=12 FN=11 AP=0.912
class 15 | P=0.967 R=0.935  TP=29 FP=1 FN=2 AP=0.946
class 16 | P=0.667 R=0.667  TP=4 FP=2 FN=2 AP=0.721
class 18 | P=0.966 R=0.757  TP=28 FP=1 FN=9 AP=0.868
class 19 | P=1.000 R=0.903  TP=28 FP=0 FN=3 AP=0.951
class 20 | P=0.706 R=0.800  TP=12 FP=5 FN=3 AP=0.815
class 21 | P=0.500 R=0.429  TP=3

                                                                      

Epoch 08 | train loss 0.2244 | val loss 0.2720 | mAP 0.8203
class 00 | P=0.000 R=0.000  TP=0 FP=0 FN=41 AP=0.000
class 01 | P=0.816 R=0.886  TP=31 FP=7 FN=4 AP=0.922
class 02 | P=0.780 R=0.914  TP=32 FP=9 FN=3 AP=0.937
class 04 | P=0.963 R=0.963  TP=26 FP=1 FN=1 AP=0.981
class 05 | P=0.928 R=0.853  TP=64 FP=5 FN=11 AP=0.898
class 06 | P=0.833 R=1.000  TP=5 FP=1 FN=0 AP=0.995
class 07 | P=0.794 R=0.375  TP=27 FP=7 FN=45 AP=0.569
class 09 | P=0.938 R=0.968  TP=30 FP=2 FN=1 AP=0.982
class 10 | P=0.700 R=0.933  TP=42 FP=18 FN=3 AP=0.932
class 11 | P=0.743 R=0.839  TP=26 FP=9 FN=5 AP=0.897
class 13 | P=0.780 R=0.753  TP=64 FP=18 FN=21 AP=0.784
class 14 | P=0.889 R=0.860  TP=80 FP=10 FN=13 AP=0.893
class 15 | P=1.000 R=0.806  TP=25 FP=0 FN=6 AP=0.903
class 16 | P=0.667 R=0.667  TP=4 FP=2 FN=2 AP=0.743
class 18 | P=0.969 R=0.838  TP=31 FP=1 FN=6 AP=0.910
class 19 | P=0.964 R=0.871  TP=27 FP=1 FN=4 AP=0.918
class 20 | P=0.750 R=0.800  TP=12 FP=4 FN=3 AP=0.859
class 21 | P=0.500 R=0.429  TP=3 F

                                                                      

Epoch 09 | train loss 0.2363 | val loss 0.2696 | mAP 0.8331
class 00 | P=0.000 R=0.000  TP=0 FP=0 FN=41 AP=0.000
class 01 | P=0.744 R=0.914  TP=32 FP=11 FN=3 AP=0.927
class 02 | P=0.762 R=0.914  TP=32 FP=10 FN=3 AP=0.943
class 04 | P=0.867 R=0.963  TP=26 FP=4 FN=1 AP=0.979
class 05 | P=0.903 R=0.867  TP=65 FP=7 FN=10 AP=0.901
class 06 | P=0.833 R=1.000  TP=5 FP=1 FN=0 AP=0.995
class 07 | P=0.815 R=0.611  TP=44 FP=10 FN=28 AP=0.681
class 09 | P=0.968 R=0.968  TP=30 FP=1 FN=1 AP=0.983
class 10 | P=0.782 R=0.956  TP=43 FP=12 FN=2 AP=0.949
class 11 | P=0.812 R=0.839  TP=26 FP=6 FN=5 AP=0.899
class 13 | P=0.789 R=0.835  TP=71 FP=19 FN=14 AP=0.837
class 14 | P=0.889 R=0.860  TP=80 FP=10 FN=13 AP=0.902
class 15 | P=0.967 R=0.935  TP=29 FP=1 FN=2 AP=0.964
class 16 | P=0.600 R=0.500  TP=3 FP=2 FN=3 AP=0.608
class 18 | P=0.958 R=0.622  TP=23 FP=1 FN=14 AP=0.781
class 19 | P=0.967 R=0.935  TP=29 FP=1 FN=2 AP=0.966
class 20 | P=0.765 R=0.867  TP=13 FP=4 FN=2 AP=0.893
class 21 | P=0.500 R=0.571  TP

                                                                      

Epoch 10 | train loss 0.1676 | val loss 0.2566 | mAP 0.8374
class 00 | P=0.000 R=0.000  TP=0 FP=0 FN=41 AP=0.000
class 01 | P=0.861 R=0.886  TP=31 FP=5 FN=4 AP=0.928
class 02 | P=0.838 R=0.886  TP=31 FP=6 FN=4 AP=0.926
class 04 | P=0.963 R=0.963  TP=26 FP=1 FN=1 AP=0.981
class 05 | P=0.868 R=0.880  TP=66 FP=10 FN=9 AP=0.906
class 06 | P=0.714 R=1.000  TP=5 FP=2 FN=0 AP=0.995
class 07 | P=0.788 R=0.569  TP=41 FP=11 FN=31 AP=0.619
class 09 | P=0.938 R=0.968  TP=30 FP=2 FN=1 AP=0.982
class 10 | P=0.638 R=0.978  TP=44 FP=25 FN=1 AP=0.944
class 11 | P=0.788 R=0.839  TP=26 FP=7 FN=5 AP=0.894
class 13 | P=0.734 R=0.812  TP=69 FP=25 FN=16 AP=0.778
class 14 | P=0.884 R=0.903  TP=84 FP=11 FN=9 AP=0.917
class 15 | P=1.000 R=0.839  TP=26 FP=0 FN=5 AP=0.919
class 16 | P=0.400 R=0.667  TP=4 FP=6 FN=2 AP=0.566
class 18 | P=0.971 R=0.919  TP=34 FP=1 FN=3 AP=0.953
class 19 | P=0.935 R=0.935  TP=29 FP=2 FN=2 AP=0.960
class 20 | P=0.722 R=0.867  TP=13 FP=5 FN=2 AP=0.863
class 21 | P=0.800 R=0.571  TP=4 F

                                                                      

Epoch 11 | train loss 0.1759 | val loss 0.2565 | mAP 0.8454
class 00 | P=0.000 R=0.000  TP=0 FP=0 FN=41 AP=0.000
class 01 | P=0.806 R=0.829  TP=29 FP=7 FN=6 AP=0.894
class 02 | P=0.842 R=0.914  TP=32 FP=6 FN=3 AP=0.945
class 04 | P=1.000 R=0.963  TP=26 FP=0 FN=1 AP=0.981
class 05 | P=0.892 R=0.880  TP=66 FP=8 FN=9 AP=0.908
class 06 | P=1.000 R=1.000  TP=5 FP=0 FN=0 AP=0.995
class 07 | P=0.855 R=0.653  TP=47 FP=8 FN=25 AP=0.744
class 09 | P=0.938 R=0.968  TP=30 FP=2 FN=1 AP=0.982
class 10 | P=0.754 R=0.956  TP=43 FP=14 FN=2 AP=0.957
class 11 | P=0.812 R=0.839  TP=26 FP=6 FN=5 AP=0.899
class 13 | P=0.779 R=0.788  TP=67 FP=19 FN=18 AP=0.816
class 14 | P=0.921 R=0.882  TP=82 FP=7 FN=11 AP=0.925
class 15 | P=0.935 R=0.935  TP=29 FP=2 FN=2 AP=0.964
class 16 | P=0.667 R=0.667  TP=4 FP=2 FN=2 AP=0.638
class 18 | P=0.970 R=0.865  TP=32 FP=1 FN=5 AP=0.924
class 19 | P=1.000 R=0.839  TP=26 FP=0 FN=5 AP=0.919
class 20 | P=0.778 R=0.933  TP=14 FP=4 FN=1 AP=0.888
class 21 | P=0.750 R=0.429  TP=3 FP=

                                                                      

Epoch 12 | train loss 0.1996 | val loss 0.2567 | mAP 0.8325
class 00 | P=0.000 R=0.000  TP=0 FP=0 FN=41 AP=0.000
class 01 | P=0.816 R=0.886  TP=31 FP=7 FN=4 AP=0.925
class 02 | P=0.805 R=0.943  TP=33 FP=8 FN=2 AP=0.958
class 04 | P=1.000 R=0.963  TP=26 FP=0 FN=1 AP=0.981
class 05 | P=0.917 R=0.880  TP=66 FP=6 FN=9 AP=0.917
class 06 | P=1.000 R=1.000  TP=5 FP=0 FN=0 AP=0.995
class 07 | P=0.738 R=0.667  TP=48 FP=17 FN=24 AP=0.655
class 09 | P=0.909 R=0.968  TP=30 FP=3 FN=1 AP=0.982
class 10 | P=0.764 R=0.933  TP=42 FP=13 FN=3 AP=0.945
class 11 | P=0.833 R=0.806  TP=25 FP=5 FN=6 AP=0.887
class 13 | P=0.758 R=0.812  TP=69 FP=22 FN=16 AP=0.831
class 14 | P=0.905 R=0.925  TP=86 FP=9 FN=7 AP=0.949
class 15 | P=0.966 R=0.903  TP=28 FP=1 FN=3 AP=0.949
class 16 | P=0.667 R=0.667  TP=4 FP=2 FN=2 AP=0.638
class 18 | P=0.971 R=0.892  TP=33 FP=1 FN=4 AP=0.941
class 19 | P=0.964 R=0.871  TP=27 FP=1 FN=4 AP=0.931
class 20 | P=0.824 R=0.933  TP=14 FP=3 FN=1 AP=0.874
class 21 | P=0.500 R=0.429  TP=3 FP=

                                                                      

Epoch 13 | train loss 0.1858 | val loss 0.2462 | mAP 0.8512
class 00 | P=0.000 R=0.000  TP=0 FP=0 FN=41 AP=0.000
class 01 | P=0.872 R=0.971  TP=34 FP=5 FN=1 AP=0.975
class 02 | P=0.825 R=0.943  TP=33 FP=7 FN=2 AP=0.954
class 04 | P=1.000 R=0.963  TP=26 FP=0 FN=1 AP=0.981
class 05 | P=0.904 R=0.880  TP=66 FP=7 FN=9 AP=0.915
class 06 | P=1.000 R=1.000  TP=5 FP=0 FN=0 AP=0.995
class 07 | P=0.745 R=0.528  TP=38 FP=13 FN=34 AP=0.588
class 09 | P=0.938 R=0.968  TP=30 FP=2 FN=1 AP=0.982
class 10 | P=0.741 R=0.956  TP=43 FP=15 FN=2 AP=0.961
class 11 | P=0.867 R=0.839  TP=26 FP=4 FN=5 AP=0.903
class 13 | P=0.753 R=0.859  TP=73 FP=24 FN=12 AP=0.856
class 14 | P=0.926 R=0.935  TP=87 FP=7 FN=6 AP=0.938
class 15 | P=0.964 R=0.871  TP=27 FP=1 FN=4 AP=0.933
class 16 | P=0.400 R=0.667  TP=4 FP=6 FN=2 AP=0.599
class 18 | P=0.971 R=0.919  TP=34 FP=1 FN=3 AP=0.950
class 19 | P=0.967 R=0.935  TP=29 FP=1 FN=2 AP=0.962
class 20 | P=0.765 R=0.867  TP=13 FP=4 FN=2 AP=0.898
class 21 | P=0.571 R=0.571  TP=4 FP=

                                                                      

Epoch 14 | train loss 0.1832 | val loss 0.2566 | mAP 0.8438
class 00 | P=0.000 R=0.000  TP=0 FP=0 FN=41 AP=0.000
class 01 | P=0.791 R=0.971  TP=34 FP=9 FN=1 AP=0.974
class 02 | P=0.767 R=0.943  TP=33 FP=10 FN=2 AP=0.944
class 04 | P=0.963 R=0.963  TP=26 FP=1 FN=1 AP=0.981
class 05 | P=0.930 R=0.880  TP=66 FP=5 FN=9 AP=0.915
class 06 | P=1.000 R=1.000  TP=5 FP=0 FN=0 AP=0.995
class 07 | P=0.875 R=0.486  TP=35 FP=5 FN=37 AP=0.691
class 09 | P=1.000 R=0.968  TP=30 FP=0 FN=1 AP=0.984
class 10 | P=0.754 R=0.956  TP=43 FP=14 FN=2 AP=0.953
class 11 | P=0.818 R=0.871  TP=27 FP=6 FN=4 AP=0.916
class 13 | P=0.806 R=0.882  TP=75 FP=18 FN=10 AP=0.866
class 14 | P=0.913 R=0.903  TP=84 FP=8 FN=9 AP=0.938
class 15 | P=0.931 R=0.871  TP=27 FP=2 FN=4 AP=0.931
class 16 | P=0.714 R=0.833  TP=5 FP=2 FN=1 AP=0.655
class 18 | P=0.967 R=0.784  TP=29 FP=1 FN=8 AP=0.887
class 19 | P=0.966 R=0.903  TP=28 FP=1 FN=3 AP=0.947
class 20 | P=0.875 R=0.933  TP=14 FP=2 FN=1 AP=0.930
class 21 | P=0.500 R=0.429  TP=3 FP=

                                                                      

Epoch 15 | train loss 0.1788 | val loss 0.2400 | mAP 0.8539
class 00 | P=0.000 R=0.000  TP=0 FP=0 FN=41 AP=0.000
class 01 | P=0.941 R=0.914  TP=32 FP=2 FN=3 AP=0.951
class 02 | P=0.805 R=0.943  TP=33 FP=8 FN=2 AP=0.950
class 04 | P=1.000 R=0.963  TP=26 FP=0 FN=1 AP=0.981
class 05 | P=0.917 R=0.880  TP=66 FP=6 FN=9 AP=0.896
class 06 | P=0.714 R=1.000  TP=5 FP=2 FN=0 AP=0.995
class 07 | P=0.860 R=0.597  TP=43 FP=7 FN=29 AP=0.709
class 09 | P=0.912 R=1.000  TP=31 FP=3 FN=0 AP=0.994
class 10 | P=0.494 R=0.978  TP=44 FP=45 FN=1 AP=0.934
class 11 | P=0.737 R=0.903  TP=28 FP=10 FN=3 AP=0.917
class 13 | P=0.776 R=0.894  TP=76 FP=22 FN=9 AP=0.905
class 14 | P=0.929 R=0.849  TP=79 FP=6 FN=14 AP=0.896
class 15 | P=0.938 R=0.968  TP=30 FP=2 FN=1 AP=0.982
class 16 | P=0.571 R=0.667  TP=4 FP=3 FN=2 AP=0.695
class 18 | P=1.000 R=0.892  TP=33 FP=0 FN=4 AP=0.946
class 19 | P=0.935 R=0.935  TP=29 FP=2 FN=2 AP=0.965
class 20 | P=0.765 R=0.867  TP=13 FP=4 FN=2 AP=0.828
class 21 | P=0.600 R=0.429  TP=3 FP=

                                                                      

Epoch 16 | train loss 0.1420 | val loss 0.2502 | mAP 0.8538
class 00 | P=0.000 R=0.000  TP=0 FP=0 FN=41 AP=0.000
class 01 | P=0.917 R=0.943  TP=33 FP=3 FN=2 AP=0.966
class 02 | P=0.791 R=0.971  TP=34 FP=9 FN=1 AP=0.966
class 04 | P=0.794 R=1.000  TP=27 FP=7 FN=0 AP=0.989
class 05 | P=0.905 R=0.893  TP=67 FP=7 FN=8 AP=0.904
class 06 | P=0.833 R=1.000  TP=5 FP=1 FN=0 AP=0.995
class 07 | P=0.833 R=0.625  TP=45 FP=9 FN=27 AP=0.712
class 09 | P=0.838 R=1.000  TP=31 FP=6 FN=0 AP=0.987
class 10 | P=0.694 R=0.956  TP=43 FP=19 FN=2 AP=0.949
class 11 | P=0.903 R=0.903  TP=28 FP=3 FN=3 AP=0.941
class 13 | P=0.795 R=0.824  TP=70 FP=18 FN=15 AP=0.861
class 14 | P=0.906 R=0.828  TP=77 FP=8 FN=16 AP=0.871
class 15 | P=0.903 R=0.903  TP=28 FP=3 FN=3 AP=0.945
class 16 | P=0.500 R=0.667  TP=4 FP=4 FN=2 AP=0.616
class 18 | P=0.889 R=0.865  TP=32 FP=4 FN=5 AP=0.911
class 19 | P=0.935 R=0.935  TP=29 FP=2 FN=2 AP=0.949
class 20 | P=0.875 R=0.933  TP=14 FP=2 FN=1 AP=0.945
class 21 | P=0.600 R=0.429  TP=3 FP=

                                                                      

Epoch 17 | train loss 0.1386 | val loss 0.2581 | mAP 0.8503
class 00 | P=0.000 R=0.000  TP=0 FP=0 FN=41 AP=0.000
class 01 | P=0.821 R=0.914  TP=32 FP=7 FN=3 AP=0.946
class 02 | P=0.791 R=0.971  TP=34 FP=9 FN=1 AP=0.968
class 04 | P=0.929 R=0.963  TP=26 FP=2 FN=1 AP=0.980
class 05 | P=0.905 R=0.893  TP=67 FP=7 FN=8 AP=0.918
class 06 | P=0.833 R=1.000  TP=5 FP=1 FN=0 AP=0.995
class 07 | P=0.823 R=0.708  TP=51 FP=11 FN=21 AP=0.744
class 09 | P=0.886 R=1.000  TP=31 FP=4 FN=0 AP=0.992
class 10 | P=0.688 R=0.978  TP=44 FP=20 FN=1 AP=0.953
class 11 | P=0.903 R=0.903  TP=28 FP=3 FN=3 AP=0.940
class 13 | P=0.847 R=0.847  TP=72 FP=13 FN=13 AP=0.870
class 14 | P=0.903 R=0.903  TP=84 FP=9 FN=9 AP=0.928
class 15 | P=0.871 R=0.871  TP=27 FP=4 FN=4 AP=0.912
class 16 | P=0.444 R=0.667  TP=4 FP=5 FN=2 AP=0.518
class 18 | P=1.000 R=0.838  TP=31 FP=0 FN=6 AP=0.919
class 19 | P=0.964 R=0.871  TP=27 FP=1 FN=4 AP=0.923
class 20 | P=0.824 R=0.933  TP=14 FP=3 FN=1 AP=0.923
class 21 | P=0.571 R=0.571  TP=4 FP=

                                                                      

Epoch 18 | train loss 0.1343 | val loss 0.2400 | mAP 0.8551
class 00 | P=0.000 R=0.000  TP=0 FP=0 FN=41 AP=0.000
class 01 | P=0.868 R=0.943  TP=33 FP=5 FN=2 AP=0.958
class 02 | P=0.825 R=0.943  TP=33 FP=7 FN=2 AP=0.948
class 04 | P=0.929 R=0.963  TP=26 FP=2 FN=1 AP=0.980
class 05 | P=0.917 R=0.880  TP=66 FP=6 FN=9 AP=0.913
class 06 | P=0.714 R=1.000  TP=5 FP=2 FN=0 AP=0.995
class 07 | P=0.842 R=0.667  TP=48 FP=9 FN=24 AP=0.737
class 09 | P=0.939 R=1.000  TP=31 FP=2 FN=0 AP=0.993
class 10 | P=0.811 R=0.956  TP=43 FP=10 FN=2 AP=0.964
class 11 | P=0.906 R=0.935  TP=29 FP=3 FN=2 AP=0.959
class 13 | P=0.780 R=0.835  TP=71 FP=20 FN=14 AP=0.877
class 14 | P=0.933 R=0.903  TP=84 FP=6 FN=9 AP=0.935
class 15 | P=0.966 R=0.903  TP=28 FP=1 FN=3 AP=0.945
class 16 | P=0.556 R=0.833  TP=5 FP=4 FN=1 AP=0.636
class 18 | P=1.000 R=0.838  TP=31 FP=0 FN=6 AP=0.919
class 19 | P=0.964 R=0.871  TP=27 FP=1 FN=4 AP=0.926
class 20 | P=0.875 R=0.933  TP=14 FP=2 FN=1 AP=0.935
class 21 | P=0.600 R=0.429  TP=3 FP=2

                                                                      

Epoch 19 | train loss 0.1703 | val loss 0.2494 | mAP 0.8560
class 00 | P=0.000 R=0.000  TP=0 FP=0 FN=41 AP=0.000
class 01 | P=0.865 R=0.914  TP=32 FP=5 FN=3 AP=0.946
class 02 | P=0.750 R=0.943  TP=33 FP=11 FN=2 AP=0.946
class 04 | P=0.929 R=0.963  TP=26 FP=2 FN=1 AP=0.980
class 05 | P=0.917 R=0.880  TP=66 FP=6 FN=9 AP=0.908
class 06 | P=0.667 R=0.800  TP=4 FP=2 FN=1 AP=0.865
class 07 | P=0.827 R=0.597  TP=43 FP=9 FN=29 AP=0.677
class 09 | P=0.833 R=0.968  TP=30 FP=6 FN=1 AP=0.980
class 10 | P=0.854 R=0.911  TP=41 FP=7 FN=4 AP=0.947
class 11 | P=0.906 R=0.935  TP=29 FP=3 FN=2 AP=0.963
class 13 | P=0.896 R=0.812  TP=69 FP=8 FN=16 AP=0.878
class 14 | P=0.933 R=0.903  TP=84 FP=6 FN=9 AP=0.941
class 15 | P=0.966 R=0.903  TP=28 FP=1 FN=3 AP=0.947
class 16 | P=0.625 R=0.833  TP=5 FP=3 FN=1 AP=0.822
class 18 | P=0.943 R=0.892  TP=33 FP=2 FN=4 AP=0.941
class 19 | P=0.931 R=0.871  TP=27 FP=2 FN=4 AP=0.925
class 20 | P=0.875 R=0.933  TP=14 FP=2 FN=1 AP=0.904
class 21 | P=0.750 R=0.429  TP=3 FP=1 

                                                                      

Epoch 20 | train loss 0.1537 | val loss 0.2631 | mAP 0.8544
class 00 | P=0.000 R=0.000  TP=0 FP=0 FN=41 AP=0.000
class 01 | P=0.780 R=0.914  TP=32 FP=9 FN=3 AP=0.943
class 02 | P=0.750 R=0.943  TP=33 FP=11 FN=2 AP=0.944
class 04 | P=0.929 R=0.963  TP=26 FP=2 FN=1 AP=0.980
class 05 | P=0.909 R=0.933  TP=70 FP=7 FN=5 AP=0.937
class 06 | P=0.833 R=1.000  TP=5 FP=1 FN=0 AP=0.995
class 07 | P=0.773 R=0.708  TP=51 FP=15 FN=21 AP=0.679
class 09 | P=0.909 R=0.968  TP=30 FP=3 FN=1 AP=0.982
class 10 | P=0.732 R=0.911  TP=41 FP=15 FN=4 AP=0.933
class 11 | P=0.903 R=0.903  TP=28 FP=3 FN=3 AP=0.942
class 13 | P=0.841 R=0.812  TP=69 FP=13 FN=16 AP=0.874
class 14 | P=0.923 R=0.903  TP=84 FP=7 FN=9 AP=0.919
class 15 | P=0.900 R=0.871  TP=27 FP=3 FN=4 AP=0.918
class 16 | P=0.625 R=0.833  TP=5 FP=3 FN=1 AP=0.689
class 18 | P=1.000 R=0.811  TP=30 FP=0 FN=7 AP=0.905
class 19 | P=0.966 R=0.903  TP=28 FP=1 FN=3 AP=0.943
class 20 | P=0.812 R=0.867  TP=13 FP=3 FN=2 AP=0.903
class 21 | P=0.750 R=0.429  TP=3 FP

                                                                      

Epoch 21 | train loss 0.1295 | val loss 0.2599 | mAP 0.8591
class 00 | P=0.000 R=0.000  TP=0 FP=0 FN=41 AP=0.000
class 01 | P=0.821 R=0.914  TP=32 FP=7 FN=3 AP=0.941
class 02 | P=0.805 R=0.943  TP=33 FP=8 FN=2 AP=0.950
class 04 | P=0.963 R=0.963  TP=26 FP=1 FN=1 AP=0.981
class 05 | P=0.915 R=0.867  TP=65 FP=6 FN=10 AP=0.905
class 06 | P=0.833 R=1.000  TP=5 FP=1 FN=0 AP=0.995
class 07 | P=0.807 R=0.639  TP=46 FP=11 FN=26 AP=0.699
class 09 | P=0.882 R=0.968  TP=30 FP=4 FN=1 AP=0.981
class 10 | P=0.772 R=0.978  TP=44 FP=13 FN=1 AP=0.967
class 11 | P=0.963 R=0.839  TP=26 FP=1 FN=5 AP=0.915
class 13 | P=0.863 R=0.812  TP=69 FP=11 FN=16 AP=0.879
class 14 | P=0.914 R=0.914  TP=85 FP=8 FN=8 AP=0.938
class 15 | P=0.933 R=0.903  TP=28 FP=2 FN=3 AP=0.944
class 16 | P=0.500 R=0.667  TP=4 FP=4 FN=2 AP=0.692
class 18 | P=0.971 R=0.892  TP=33 FP=1 FN=4 AP=0.938
class 19 | P=1.000 R=0.839  TP=26 FP=0 FN=5 AP=0.919
class 20 | P=0.875 R=0.933  TP=14 FP=2 FN=1 AP=0.911
class 21 | P=0.667 R=0.571  TP=4 FP

                                                                      

Epoch 22 | train loss 0.1403 | val loss 0.2596 | mAP 0.8675
class 00 | P=0.000 R=0.000  TP=0 FP=0 FN=41 AP=0.000
class 01 | P=0.865 R=0.914  TP=32 FP=5 FN=3 AP=0.947
class 02 | P=0.829 R=0.971  TP=34 FP=7 FN=1 AP=0.955
class 04 | P=0.963 R=0.963  TP=26 FP=1 FN=1 AP=0.981
class 05 | P=0.896 R=0.920  TP=69 FP=8 FN=6 AP=0.932
class 06 | P=0.714 R=1.000  TP=5 FP=2 FN=0 AP=0.995
class 07 | P=0.864 R=0.708  TP=51 FP=8 FN=21 AP=0.781
class 09 | P=0.857 R=0.968  TP=30 FP=5 FN=1 AP=0.981
class 10 | P=0.786 R=0.978  TP=44 FP=12 FN=1 AP=0.964
class 11 | P=0.848 R=0.903  TP=28 FP=5 FN=3 AP=0.934
class 13 | P=0.812 R=0.812  TP=69 FP=16 FN=16 AP=0.850
class 14 | P=0.902 R=0.892  TP=83 FP=9 FN=10 AP=0.920
class 15 | P=0.867 R=0.839  TP=26 FP=4 FN=5 AP=0.905
class 16 | P=0.714 R=0.833  TP=5 FP=2 FN=1 AP=0.811
class 18 | P=0.943 R=0.892  TP=33 FP=2 FN=4 AP=0.942
class 19 | P=0.929 R=0.839  TP=26 FP=2 FN=5 AP=0.913
class 20 | P=0.875 R=0.933  TP=14 FP=2 FN=1 AP=0.911
class 21 | P=0.667 R=0.571  TP=4 FP=

                                                                      

Epoch 23 | train loss 0.1326 | val loss 0.2610 | mAP 0.8554
class 00 | P=0.000 R=0.000  TP=0 FP=0 FN=41 AP=0.000
class 01 | P=0.821 R=0.914  TP=32 FP=7 FN=3 AP=0.947
class 02 | P=0.791 R=0.971  TP=34 FP=9 FN=1 AP=0.957
class 04 | P=0.929 R=0.963  TP=26 FP=2 FN=1 AP=0.980
class 05 | P=0.918 R=0.893  TP=67 FP=6 FN=8 AP=0.916
class 06 | P=0.714 R=1.000  TP=5 FP=2 FN=0 AP=0.995
class 07 | P=0.862 R=0.694  TP=50 FP=8 FN=22 AP=0.747
class 09 | P=0.857 R=0.968  TP=30 FP=5 FN=1 AP=0.981
class 10 | P=0.854 R=0.911  TP=41 FP=7 FN=4 AP=0.943
class 11 | P=0.933 R=0.903  TP=28 FP=2 FN=3 AP=0.945
class 13 | P=0.787 R=0.824  TP=70 FP=19 FN=15 AP=0.839
class 14 | P=0.884 R=0.903  TP=84 FP=11 FN=9 AP=0.910
class 15 | P=0.933 R=0.903  TP=28 FP=2 FN=3 AP=0.948
class 16 | P=0.462 R=1.000  TP=6 FP=7 FN=0 AP=0.760
class 18 | P=0.939 R=0.838  TP=31 FP=2 FN=6 AP=0.910
class 19 | P=0.938 R=0.968  TP=30 FP=2 FN=1 AP=0.968
class 20 | P=0.812 R=0.867  TP=13 FP=3 FN=2 AP=0.870
class 21 | P=0.667 R=0.571  TP=4 FP=2

Infer: 100%|██████████| 200/200 [00:37<00:00,  5.28it/s]


Inference @ IoU 0.50 | mAP 0.8599
class 00 | P=0.000 R=0.000  TP=0 FP=0 FN=8 AP=0.000
class 01 | P=0.967 R=0.989  TP=89 FP=3 FN=1 AP=0.994
class 02 | P=0.889 R=0.941  TP=16 FP=2 FN=1 AP=0.927
class 04 | P=1.000 R=0.842  TP=16 FP=0 FN=3 AP=0.921
class 05 | P=0.647 R=0.917  TP=11 FP=6 FN=1 AP=0.943
class 06 | P=0.500 R=1.000  TP=1 FP=1 FN=0 AP=0.995
class 07 | P=0.909 R=0.741  TP=20 FP=2 FN=7 AP=0.823
class 09 | P=0.952 R=1.000  TP=20 FP=1 FN=0 AP=0.983
class 10 | P=0.917 R=0.957  TP=22 FP=2 FN=1 AP=0.974
class 11 | P=1.000 R=0.944  TP=17 FP=0 FN=1 AP=0.972
class 13 | P=0.733 R=0.917  TP=22 FP=8 FN=2 AP=0.928
class 14 | P=1.000 R=0.909  TP=20 FP=0 FN=2 AP=0.955
class 15 | P=0.894 R=0.854  TP=76 FP=9 FN=13 AP=0.890
class 16 | P=0.667 R=1.000  TP=2 FP=1 FN=0 AP=0.828
class 17 | P=0.000 R=0.000  TP=0 FP=0 FN=2 AP=0.000
class 18 | P=0.918 R=0.817  TP=67 FP=6 FN=15 AP=0.864
class 19 | P=0.975 R=0.898  TP=79 FP=2 FN=9 AP=0.945
class 20 | P=1.000 R=1.000  TP=5 FP=0 FN=0 AP=0.995
class 21 | P=0.

**提示：** `run_inference` 会在返回的 `metrics` 中附带 `false_positive_images` 和 `false_positive_stems`。 如果想在下一次训练时忽略这些样本，可以将 `TrainingConfig.exclude_samples` 设为 `tuple(metrics["false_positive_stems"])`，或将 `metrics["false_positive_stems"]` 写入文本文件， 然后通过命令行参数 `--exclude-list` 传入。