# Electrical Component Detection Pipeline

This notebook consolidates the refactored Faster R-CNN training and inference workflow into a single place for convenient experimentation on Kaggle. It provides reusable configuration objects, dataset loaders with optional augmentation, detailed metric utilities (including per-class TP/FP/FN and mAP), and helpers for both training and inference.


In [None]:
import argparse
import multiprocessing as mp
import os
import warnings
import json
import logging
import random
import contextlib
from dataclasses import dataclass, field
from pathlib import Path
from typing import Dict, Iterable, List, Mapping, Optional, Sequence, Set, Tuple, Union

import numpy as np
import pandas as pd
import torch
from PIL import Image as PILImage, ImageDraw, ImageEnhance, ImageFont
from torch import Tensor, nn
from torch.cuda.amp import GradScaler
from torch.optim import AdamW
from torch.utils.data import DataLoader, Dataset
from tqdm import tqdm

from torchvision.models.detection import (
    FasterRCNN_ResNet50_FPN_V2_Weights,
    fasterrcnn_resnet50_fpn_v2,
)
from torchvision.models.detection.faster_rcnn import FastRCNNPredictor
from torchvision.models.detection.rpn import AnchorGenerator

LOGGER = logging.getLogger("notebook")
# Root logging configuration for the whole notebook: INFO level, timestamped records.
logging.basicConfig(
    level=logging.INFO,
    format="%(asctime)s | %(levelname)s | %(message)s",
)


## Configuration objects


In [None]:
DEFAULT_PRETRAINED_URL = (
    "https://download.pytorch.org/models/fasterrcnn_resnet50_fpn_v2_coco-dd69338a.pth"
)

# Per-class score thresholds that override the global score threshold.
# A value of 0.999 marks a class that is not expected to exist in the data,
# which effectively suppresses every prediction of that class.
_ABSENT_CLASS_THRESHOLD = 0.999

DEFAULT_CLASS_SCORE_THRESHOLDS = {
    3: _ABSENT_CLASS_THRESHOLD,
    6: 0.8,
    7: 0.9,
    8: _ABSENT_CLASS_THRESHOLD,
    12: _ABSENT_CLASS_THRESHOLD,
    16: 0.9,
    17: _ABSENT_CLASS_THRESHOLD,
    20: 0.9,
    21: 0.8,
    24: _ABSENT_CLASS_THRESHOLD,
    25: 0.9,
    26: _ABSENT_CLASS_THRESHOLD,
    30: 0.9,
}


@dataclass
class DatasetConfig:
    """Describe where the dataset lives on disk and how its classes are named.

    When ``class_names`` is left empty, placeholder names ``class_00``,
    ``class_01``, ... are generated, one for each of the ``num_classes`` classes.
    """

    base_dir: Path = Path("data")
    train_split: str = "train"
    valid_split: str = "valid"
    test_split: str = "test"
    image_folder: str = "images"
    label_folder: str = "labels"
    num_classes: int = 32
    class_names: Tuple[str, ...] = ()

    def __post_init__(self) -> None:
        # Explicitly supplied names win; only fill in placeholders when none were given.
        if self.class_names:
            return
        self.class_names = tuple(map("class_{:02d}".format, range(self.num_classes)))


@dataclass
class TrainingConfig:
    """Hyper-parameters and runtime settings for model training."""

    epochs: int = 20
    batch_size: int = 4
    learning_rate: float = 1e-4
    weight_decay: float = 1e-4
    num_workers: int = 0
    amp: bool = True
    augmentation: bool = True
    mosaic_prob: float = 0.0
    mixup_prob: float = 0.0
    mixup_alpha: float = 0.4
    scale_jitter_min: float = 1.0
    scale_jitter_max: float = 1.0
    small_object: bool = True
    score_threshold: float = 0.6
    iou_threshold: float = 0.5
    eval_interval: int = 1
    seed: int = 2024
    output_dir: Path = Path("outputs")
    checkpoint_path: Path = Path("outputs/best_model.pth")
    pretrained_weights_path: Path = Path("weights/fasterrcnn_resnet50_fpn_v2_coco.pth")
    pretrained_weights_url: str = DEFAULT_PRETRAINED_URL
    log_every: int = 20
    # Fresh copy per instance so mutating one config never alters the module default.
    class_score_thresholds: Dict[int, float] = field(
        default_factory=lambda: DEFAULT_CLASS_SCORE_THRESHOLDS.copy()
    )
    exclude_samples: Tuple[str, ...] = tuple()
    fp_visual_dir: Optional[Path] = Path("outputs/fp_images")
    fp_report_path: Optional[Path] = None
    fp_list_path: Optional[Path] = None
    fp_classes: Tuple[int, ...] = (16, 30)

    def ensure_directories(self) -> None:
        """Create every directory this configuration points into.

        Creates ``output_dir`` and the optional ``fp_visual_dir``, plus the parent
        directories of ``checkpoint_path``, ``pretrained_weights_path`` and, when
        set, ``fp_report_path`` / ``fp_list_path`` (presumably file paths, judging
        by their names — the consumers are not shown here).  Existing directories
        are left untouched.
        """
        directories = [
            self.output_dir,
            # The checkpoint may be redirected outside ``output_dir``; without this a
            # later save into a non-existent directory would fail.
            Path(self.checkpoint_path).parent,
            Path(self.pretrained_weights_path).parent,
        ]
        for optional_file in (self.fp_report_path, self.fp_list_path):
            if optional_file:
                directories.append(Path(optional_file).parent)
        if self.fp_visual_dir:
            directories.append(self.fp_visual_dir)

        for directory in directories:
            Path(directory).mkdir(parents=True, exist_ok=True)


@dataclass
class InferenceConfig:
    """Settings for running model inference and visualising its predictions.

    ``class_score_thresholds`` defaults to a fresh copy of
    ``DEFAULT_CLASS_SCORE_THRESHOLDS`` so instances never share a single dict.
    """

    score_threshold: float = 0.6
    max_images: int = 50
    output_dir: Path = Path("outputs/inference")
    draw_ground_truth: bool = True
    class_colors: List[str] = field(default_factory=list)
    class_score_thresholds: Dict[int, float] = field(
        default_factory=lambda: dict(DEFAULT_CLASS_SCORE_THRESHOLDS)
    )

    def ensure_directories(self) -> None:
        """Create ``output_dir`` together with any missing parent directories."""
        self.output_dir.mkdir(exist_ok=True, parents=True)



## Dataset loading and augmentation


In [None]:

@dataclass
class AugmentationParams:
    """Parameters controlling the dataset level image augmentations.

    Read by ``ElectricalComponentsDataset`` (as ``transform_params``) when
    augmentation is enabled.  Field order defines the positional constructor,
    so keep it stable.
    """

    # Chance of a horizontal (left-right) flip in ``_apply_augmentations``.
    horizontal_flip_prob: float = 0.5
    # Chance of a vertical (top-bottom) flip in ``_apply_augmentations``.
    vertical_flip_prob: float = 0.2
    # Colour-jitter strengths: each PIL enhancement factor is drawn uniformly from
    # ``1 - value .. 1 + value`` and floored at 0.1; a value of 0 disables it.
    brightness: float = 0.2
    contrast: float = 0.2
    saturation: float = 0.2
    # Hue shift as a fraction of the 8-bit hue channel, applied with a random sign;
    # 0 disables it.
    hue: float = 0.02
    # Chance of building a 2x2 mosaic with three other samples (needs >= 4 samples).
    mosaic_prob: float = 0.0
    # Chance of blending with another sample (mixup).
    mixup_prob: float = 0.0
    # Beta(alpha, alpha) parameter for the mixup weight; the drawn weight is then
    # clipped to [0.3, 0.7].  Mixup is skipped when this is <= 0.
    mixup_alpha: float = 0.4
    # ``(min, max)`` bounds for a uniform random resize factor; ``(1.0, 1.0)`` disables it.
    scale_jitter_range: Tuple[float, float] = (1.0, 1.0)


def load_image_hwc_uint8(path: Path) -> np.ndarray:
    """Load an ``.npy`` image stored as HWC (or 2-D grayscale) and return an RGB ``uint8`` array.

    Non-``uint8`` data is converted as follows: values within ``[0, 1]`` are scaled
    by 255, values within ``[-1, 1]`` are mapped linearly onto ``[0, 255]``, and
    anything else is clipped to ``[0, 255]``.  Single-channel images are repeated
    to three channels and a fourth channel is dropped.

    The returned array is always an in-memory, C-contiguous, writable copy, so it
    no longer references the memory-mapped file and can be passed straight to
    ``torch.from_numpy`` or modified in place.

    Args:
        path: Location of the ``.npy`` file.

    Returns:
        ``(H, W, 3)`` ``uint8`` array.

    Raises:
        ValueError: If the array is not 2-D/3-D or has a channel count other
            than 1, 3 or 4.
    """
    image = np.load(path, allow_pickle=False, mmap_mode="r")

    # Plain 2-D arrays are grayscale images without an explicit channel axis.
    if image.ndim == 2:
        image = image[:, :, np.newaxis]
    if image.ndim != 3:
        raise ValueError(f"Expected an HWC image in {path}, got shape {image.shape}")

    if image.dtype != np.uint8:
        image = image.astype(np.float32, copy=False)
        vmin, vmax = float(image.min()), float(image.max())
        if 0.0 <= vmin and vmax <= 1.0:
            image = (image * 255.0).round()
        elif -1.0 <= vmin and vmax <= 1.0:
            image = ((image + 1.0) * 0.5 * 255.0).round()
        image = np.clip(image, 0, 255).astype(np.uint8)

    channels = image.shape[2]
    if channels == 1:
        image = np.repeat(image, 3, axis=2)
    elif channels == 4:
        image = image[..., :3]
    elif channels != 3:
        raise ValueError(
            f"Unsupported channel count {channels} in {path} (expected 1, 3 or 4)"
        )

    # ``np.array`` copies: it detaches the result from the read-only memory map
    # and guarantees a writable, contiguous buffer.
    return np.array(image, dtype=np.uint8, order="C")


class ElectricalComponentsDataset(Dataset):
    """Dataset of electrical component detections stored as ``.npy`` images and CSV labels.

    Sample ``<stem>`` is read from ``<root>/<split>/<image_folder>/<stem>.npy`` and its
    annotations from ``<root>/<split>/<label_folder>/<stem>.csv``; only stems that have
    a label CSV are indexed.  Label rows provide ``x_center``, ``y_center``, ``width``,
    ``height`` and ``class`` columns, with box coordinates either in pixels or
    normalised to ``[0, 1]`` (auto-detected per file).

    Each item is ``(image, target)``: ``image`` is a ``float32`` CHW tensor scaled to
    ``[0, 1]``; ``target`` holds ``boxes`` (``x1, y1, x2, y2``), ``labels``,
    ``image_id``, ``area``, ``iscrowd`` and ``orig_size``.
    """

    def __init__(
        self,
        root: Path,
        split: str,
        class_names: Iterable[str],
        transform: Optional[AugmentationParams] = None,
        use_augmentation: bool = False,
        exclude_stems: Optional[Iterable[str]] = None,
        image_folder: str = "images",
        label_folder: str = "labels",
    ) -> None:
        """Index a split and pre-load all of its label CSVs.

        Args:
            root: Dataset root directory containing one sub-directory per split.
            split: Name of the split sub-directory.
            class_names: Class names, stored on the instance as a list.
            transform: Augmentation parameters; defaults to ``AugmentationParams()``.
            use_augmentation: Apply augmentations in ``__getitem__`` when ``True``.
            exclude_stems: Sample names to leave out; each entry is reduced with
                ``Path(entry).stem`` before matching.
            image_folder: Image sub-directory inside the split (default ``"images"``).
            label_folder: Label sub-directory inside the split (default ``"labels"``).

        Raises:
            FileNotFoundError: If the image or label directory does not exist.
            RuntimeError: If the label directory contains no ``.csv`` files.
        """
        self.root = Path(root)
        self.split = split
        self.class_names = list(class_names)
        self.transform_params = transform or AugmentationParams()
        self.use_augmentation = use_augmentation

        # Folder names are configurable to match ``DatasetConfig.image_folder`` /
        # ``label_folder``; the defaults reproduce the original fixed layout.
        self.image_dir = self.root / split / image_folder
        self.label_dir = self.root / split / label_folder

        if not self.image_dir.exists():
            raise FileNotFoundError(f"Missing image directory: {self.image_dir}")
        if not self.label_dir.exists():
            raise FileNotFoundError(f"Missing label directory: {self.label_dir}")

        self.image_stems = sorted(p.stem for p in self.label_dir.glob("*.csv"))
        if not self.image_stems:
            raise RuntimeError(f"No label files found in {self.label_dir}")

        exclude_set: Set[str] = set()
        if exclude_stems:
            # NOTE: ``Path(...).stem`` drops directories and the last dotted suffix,
            # so an entry whose own name contains a dot loses its final segment.
            exclude_set = {Path(stem).stem for stem in exclude_stems}
            if exclude_set:
                before = len(self.image_stems)
                self.image_stems = [stem for stem in self.image_stems if stem not in exclude_set]
                removed = before - len(self.image_stems)
                if removed:
                    LOGGER.info(
                        "Split %s: excluded %d samples based on provided stem filter.",
                        self.split,
                        removed,
                    )

        self.excluded_stems = sorted(exclude_set)

        # Pre-load all annotations to reduce I/O during training.
        self.annotations: Dict[str, pd.DataFrame] = {
            stem: self._read_annotation(self.label_dir / f"{stem}.csv")
            for stem in self.image_stems
        }

    @staticmethod
    def _read_annotation(path: Path) -> pd.DataFrame:
        """Read one label CSV, treating a zero-byte file as "no objects"."""
        try:
            return pd.read_csv(path)
        except pd.errors.EmptyDataError:
            # ``pd.read_csv`` raises on completely empty files instead of returning an
            # empty frame; an empty frame is handled downstream as "no boxes".
            return pd.DataFrame()

    def __len__(self) -> int:
        return len(self.image_stems)

    def __getitem__(self, index: int) -> Tuple[torch.Tensor, Dict[str, torch.Tensor]]:
        stem = self.image_stems[index]
        image, boxes, labels = self._load_raw_sample(stem)

        if self.use_augmentation:
            image, boxes, labels = self._apply_composite_augmentations(stem, image, boxes, labels)
            image, boxes = self._apply_augmentations(image, boxes)

        height, width = image.shape[:2]

        # The loader may hand back a read-only memory-mapped array;
        # ``torch.from_numpy`` warns on (and should not share) non-writable memory.
        if not image.flags.writeable:
            image = np.array(image)
        image_tensor = torch.from_numpy(image).permute(2, 0, 1).float() / 255.0
        # ``torch.tensor`` always copies, so targets never alias (possibly read-only)
        # arrays backing the cached annotation frames; ``reshape`` keeps the empty
        # (0, 4) / (0,) shapes for samples without boxes.
        boxes_tensor = torch.tensor(boxes, dtype=torch.float32).reshape(-1, 4)
        labels_tensor = torch.tensor(labels, dtype=torch.int64).reshape(-1)

        boxes_tensor, labels_tensor = sanitize_boxes_and_labels(
            boxes_tensor, labels_tensor, height, width
        )

        target: Dict[str, torch.Tensor] = {
            "boxes": boxes_tensor,
            "labels": labels_tensor,
            "image_id": torch.tensor(index, dtype=torch.int64),
            "area": (boxes_tensor[:, 2] - boxes_tensor[:, 0])
            * (boxes_tensor[:, 3] - boxes_tensor[:, 1])
            if boxes_tensor.numel()
            else torch.tensor([], dtype=torch.float32),
            "iscrowd": torch.zeros((boxes_tensor.shape[0],), dtype=torch.int64),
            "orig_size": torch.tensor([height, width], dtype=torch.int64),
        }

        return image_tensor, target

    def _load_raw_sample(self, stem: str) -> Tuple[np.ndarray, np.ndarray, np.ndarray]:
        """Return the HWC ``uint8`` image and its pixel-space boxes/labels for ``stem``."""
        image_path = self.image_dir / f"{stem}.npy"
        image = load_image_hwc_uint8(image_path)
        height, width = image.shape[:2]

        ann = self.annotations[stem]
        boxes, labels = self._annotation_to_boxes(ann, width, height)
        return image, boxes, labels

    @staticmethod
    def _annotation_to_boxes(
        ann: pd.DataFrame, width: int, height: int
    ) -> Tuple[np.ndarray, np.ndarray]:
        """Convert center/size annotation rows into ``(N, 4)`` ``x1, y1, x2, y2`` boxes.

        Coordinates are treated as normalised (and scaled by the image size) when
        every center and size value is ``<= 1``.
        """
        if ann.empty:
            return np.zeros((0, 4), dtype=np.float32), np.zeros((0,), dtype=np.int64)

        x_center = ann["x_center"].to_numpy(dtype=np.float32)
        y_center = ann["y_center"].to_numpy(dtype=np.float32)
        box_width = ann["width"].to_numpy(dtype=np.float32)
        box_height = ann["height"].to_numpy(dtype=np.float32)

        # Auto-detect normalised coordinates and scale back to pixel space.
        if (
            (x_center.size == 0 or float(x_center.max()) <= 1.0)
            and (y_center.size == 0 or float(y_center.max()) <= 1.0)
            and (box_width.size == 0 or float(box_width.max()) <= 1.0)
            and (box_height.size == 0 or float(box_height.max()) <= 1.0)
        ):
            x_center = x_center * width
            y_center = y_center * height
            box_width = box_width * width
            box_height = box_height * height

        x1 = x_center - box_width / 2.0
        y1 = y_center - box_height / 2.0
        x2 = x_center + box_width / 2.0
        y2 = y_center + box_height / 2.0

        boxes = np.stack([x1, y1, x2, y2], axis=1).astype(np.float32)
        labels = ann["class"].to_numpy(dtype=np.int64)
        return boxes, labels

    def _apply_composite_augmentations(
        self,
        stem: str,
        image: np.ndarray,
        boxes: np.ndarray,
        labels: np.ndarray,
    ) -> Tuple[np.ndarray, np.ndarray, np.ndarray]:
        """Randomly apply mosaic, then mixup, then scale jitter per ``transform_params``."""
        params = self.transform_params

        if (
            params.mosaic_prob > 0.0
            and random.random() < params.mosaic_prob
            and len(self.image_stems) >= 4
        ):
            image, boxes, labels = self._apply_mosaic(stem, image, boxes, labels)

        if (
            params.mixup_prob > 0.0
            and params.mixup_alpha > 0.0
            and random.random() < params.mixup_prob
        ):
            image, boxes, labels = self._apply_mixup(stem, image, boxes, labels)

        if params.scale_jitter_range != (1.0, 1.0):
            image, boxes = self._apply_scale_jitter(image, boxes, params.scale_jitter_range)

        return image, boxes, labels

    def _apply_augmentations(
        self, image: np.ndarray, boxes: np.ndarray
    ) -> Tuple[np.ndarray, np.ndarray]:
        """Apply random flips and colour jitter; boxes are flipped in place and clipped."""
        params = self.transform_params
        height, width = image.shape[:2]

        # Flip regardless of whether the sample has boxes, so background-only
        # images receive the same geometric augmentation as annotated ones.
        if random.random() < params.horizontal_flip_prob:
            image = np.ascontiguousarray(image[:, ::-1, :])
            if boxes.size:
                x1 = width - boxes[:, 2]
                x2 = width - boxes[:, 0]
                boxes[:, 0], boxes[:, 2] = x1, x2

        if random.random() < params.vertical_flip_prob:
            image = np.ascontiguousarray(image[::-1, :, :])
            if boxes.size:
                y1 = height - boxes[:, 3]
                y2 = height - boxes[:, 1]
                boxes[:, 1], boxes[:, 3] = y1, y2

        if params.brightness or params.contrast or params.saturation or params.hue:
            pil = PILImage.fromarray(image)
            if params.brightness:
                enhancer = ImageEnhance.Brightness(pil)
                factor = 1.0 + random.uniform(-params.brightness, params.brightness)
                pil = enhancer.enhance(max(0.1, factor))
            if params.contrast:
                enhancer = ImageEnhance.Contrast(pil)
                factor = 1.0 + random.uniform(-params.contrast, params.contrast)
                pil = enhancer.enhance(max(0.1, factor))
            if params.saturation:
                enhancer = ImageEnhance.Color(pil)
                factor = 1.0 + random.uniform(-params.saturation, params.saturation)
                pil = enhancer.enhance(max(0.1, factor))
            if params.hue:
                hsv_image = pil.convert("HSV")
                h_channel, s_channel, v_channel = hsv_image.split()
                hue_array = np.array(h_channel, dtype=np.uint8)
                delta = int(params.hue * 255.0 * random.choice([-1, 1]))
                # The 8-bit hue channel has 256 levels, so it must wrap modulo 256.
                hue_adjusted = ((hue_array.astype(int) + delta) % 256).astype(np.uint8)
                new_h = PILImage.fromarray(hue_adjusted)
                hsv_image = PILImage.merge("HSV", (new_h, s_channel, v_channel))
                pil = hsv_image.convert("RGB")
            image = np.array(pil)

        if boxes.size:
            boxes[:, [0, 2]] = boxes[:, [0, 2]].clip(0, width)
            boxes[:, [1, 3]] = boxes[:, [1, 3]].clip(0, height)
        return image, boxes

    def _apply_scale_jitter(
        self, image: np.ndarray, boxes: np.ndarray, scale_range: Tuple[float, float]
    ) -> Tuple[np.ndarray, np.ndarray]:
        """Resize by a factor drawn uniformly from ``scale_range`` and rescale boxes."""
        min_scale, max_scale = scale_range
        if max_scale <= 0 or min_scale <= 0:
            return image, boxes

        factor = random.uniform(min_scale, max_scale)
        if np.isclose(factor, 1.0):
            return image, boxes

        height, width = image.shape[:2]
        new_height = max(1, int(round(height * factor)))
        new_width = max(1, int(round(width * factor)))

        pil = PILImage.fromarray(image)
        resized = pil.resize((new_width, new_height), resample=PILImage.BILINEAR)
        image = np.array(resized)

        if boxes.size:
            boxes = boxes.copy()
            boxes[:, [0, 2]] *= float(new_width) / float(width)
            boxes[:, [1, 3]] *= float(new_height) / float(height)
        return image, boxes

    def _apply_mosaic(
        self,
        stem: str,
        image: np.ndarray,
        boxes: np.ndarray,
        labels: np.ndarray,
    ) -> Tuple[np.ndarray, np.ndarray, np.ndarray]:
        """Tile this sample with three random others on a 2x2 canvas and crop back.

        The other samples are resized to this sample's size; a random crop of the
        original size is taken and boxes are shifted, filtered and clipped to it.
        """
        height, width = image.shape[:2]
        candidate_stems = [s for s in self.image_stems if s != stem]
        if len(candidate_stems) < 3:
            return image, boxes, labels

        selected = random.sample(candidate_stems, 3)

        images: List[np.ndarray] = [image]
        boxes_list: List[np.ndarray] = [boxes]
        labels_list: List[np.ndarray] = [labels]

        for other_stem in selected:
            other_img, other_boxes, other_labels = self._load_raw_sample(other_stem)
            other_img, other_boxes = self._resize_like(other_img, other_boxes, width, height)
            images.append(other_img)
            boxes_list.append(other_boxes)
            labels_list.append(other_labels)

        canvas = np.zeros((height * 2, width * 2, 3), dtype=np.uint8)
        offsets = [(0, 0), (0, width), (height, 0), (height, width)]
        combined_boxes: List[np.ndarray] = []
        combined_labels: List[np.ndarray] = []

        for img, bxs, lbls, (y_off, x_off) in zip(images, boxes_list, labels_list, offsets):
            canvas[y_off : y_off + height, x_off : x_off + width] = img
            if bxs.size:
                shifted = bxs.copy()
                shifted[:, [0, 2]] += x_off
                shifted[:, [1, 3]] += y_off
                combined_boxes.append(shifted)
                combined_labels.append(lbls)

        if combined_boxes:
            boxes = np.concatenate(combined_boxes, axis=0)
            labels = np.concatenate(combined_labels, axis=0)
        else:
            boxes = np.zeros((0, 4), dtype=np.float32)
            labels = np.zeros((0,), dtype=np.int64)

        crop_x = random.randint(0, width)
        crop_y = random.randint(0, height)
        canvas = canvas[crop_y : crop_y + height, crop_x : crop_x + width]

        if boxes.size:
            boxes = boxes.copy()
            boxes[:, [0, 2]] -= crop_x
            boxes[:, [1, 3]] -= crop_y

            keep = (
                (boxes[:, 2] > 0)
                & (boxes[:, 3] > 0)
                & (boxes[:, 0] < width)
                & (boxes[:, 1] < height)
            )
            boxes = boxes[keep]
            labels = labels[keep]
            boxes[:, [0, 2]] = boxes[:, [0, 2]].clip(0, width)
            boxes[:, [1, 3]] = boxes[:, [1, 3]].clip(0, height)

        return canvas, boxes, labels

    def _apply_mixup(
        self,
        stem: str,
        image: np.ndarray,
        boxes: np.ndarray,
        labels: np.ndarray,
    ) -> Tuple[np.ndarray, np.ndarray, np.ndarray]:
        """Blend with one other sample using a Beta weight clipped to [0.3, 0.7].

        Boxes and labels from both samples are kept.
        """
        other_stem = self._sample_alternative_stem(stem)
        if other_stem is None:
            return image, boxes, labels

        other_img, other_boxes, other_labels = self._load_raw_sample(other_stem)
        height, width = image.shape[:2]
        other_img, other_boxes = self._resize_like(other_img, other_boxes, width, height)

        alpha = max(self.transform_params.mixup_alpha, 1e-3)
        lam = np.random.beta(alpha, alpha)
        lam = float(np.clip(lam, 0.3, 0.7))

        mixed = (
            image.astype(np.float32) * lam + other_img.astype(np.float32) * (1.0 - lam)
        ).astype(np.uint8)

        if boxes.size and other_boxes.size:
            boxes = np.concatenate([boxes, other_boxes], axis=0)
            labels = np.concatenate([labels, other_labels], axis=0)
        elif other_boxes.size:
            boxes = other_boxes.copy()
            labels = other_labels.copy()

        return mixed, boxes, labels

    def _sample_alternative_stem(self, current: str) -> Optional[str]:
        """Pick a random stem different from ``current``, or ``None`` if there is none."""
        if len(self.image_stems) <= 1:
            return None
        candidates = [stem for stem in self.image_stems if stem != current]
        if not candidates:
            return None
        return random.choice(candidates)

    @staticmethod
    def _resize_like(
        image: np.ndarray, boxes: np.ndarray, target_width: int, target_height: int
    ) -> Tuple[np.ndarray, np.ndarray]:
        """Resize ``image`` to the target size and scale ``boxes`` to match."""
        height, width = image.shape[:2]
        if height == target_height and width == target_width:
            return image, boxes

        pil = PILImage.fromarray(image)
        resized = pil.resize((target_width, target_height), resample=PILImage.BILINEAR)
        image = np.array(resized)

        if boxes.size:
            boxes = boxes.copy()
            boxes[:, [0, 2]] *= float(target_width) / float(width)
            boxes[:, [1, 3]] *= float(target_height) / float(height)
        return image, boxes


def detection_collate(batch: List[Tuple[torch.Tensor, Dict[str, torch.Tensor]]]):
    """Transpose a batch of ``(image, target)`` pairs into ``(images, targets)`` lists.

    Detection images can differ in size, so they are kept as a list instead of
    being stacked into a single tensor.
    """
    image_list, target_list = map(list, zip(*batch))
    return image_list, target_list


def _safe_worker_count(requested: int) -> int:
    cpu_count = os.cpu_count() or 1
    if requested <= 0:
        return 0
    max_workers = max(1, cpu_count - 1)
    return min(requested, max_workers)

def _running_in_ipython_kernel() -> bool:
    """Return ``True`` when executing inside an IPython/Jupyter kernel."""

    try:  # ``IPython`` is an optional dependency in our runtime.
        from IPython import get_ipython  # type: ignore
    except Exception:  # pragma: no cover - depends on environment
        return False

    shell = get_ipython()
    return bool(shell and getattr(shell, "kernel", None))


def emit_metric_lines(
    lines: Sequence[str],
    *,
    logger: Optional[logging.Logger] = None,
    force_print: Optional[bool] = None,
) -> None:
    """Log each metric line at INFO level and optionally echo it to stdout.

    Args:
        lines: Pre-formatted lines to emit, in order.
        logger: Destination logger; the module-level ``LOGGER`` when omitted.
        force_print: ``True``/``False`` forces echoing on/off; ``None`` echoes
            only when running inside an IPython kernel.
    """
    active_logger = LOGGER if logger is None else logger

    if force_print is None:
        echo = _running_in_ipython_kernel()
    else:
        echo = force_print

    for text in lines:
        if active_logger is not None:
            active_logger.info(text)
        if echo:
            print(text)


def _should_force_single_worker(dataset: Dataset) -> bool:
    """Determine whether multiprocessing workers should be disabled."""

    module_name = getattr(dataset.__class__, "__module__", "")
    if module_name in {"__main__", "__mp_main__", "builtins"}:
        return True

    if module_name.startswith("ipykernel"):  # pragma: no cover - notebook specific
        return True

    return _running_in_ipython_kernel()


def create_data_loaders(
    dataset: Dataset,
    batch_size: int,
    shuffle: bool,
    num_workers: int,
) -> DataLoader:
    """Build a detection :class:`torch.utils.data.DataLoader` with notebook-friendly defaults.

    The requested worker count is clamped by ``_safe_worker_count`` and forced
    to 0 for interactive/in-notebook datasets.  With workers, the loader uses
    persistent workers and a ``spawn`` multiprocessing context.  If construction
    with workers raises ``RuntimeError``, ``OSError`` or ``AssertionError``, a
    warning is emitted and a single-process loader is returned instead.

    NOTE(review): PyTorch typically starts worker processes on first iteration
    rather than in the constructor; if so, this fallback only covers errors raised
    while constructing the loader — confirm before relying on it.
    """
    workers = _safe_worker_count(num_workers)
    if workers > 0 and _should_force_single_worker(dataset):
        LOGGER.info(
            "Detected interactive environment or in-notebook dataset definition; forcing num_workers=0."
        )
        workers = 0

    common_kwargs = {
        "dataset": dataset,
        "batch_size": batch_size,
        "shuffle": shuffle,
        "pin_memory": torch.cuda.is_available(),
        "collate_fn": detection_collate,
    }

    if workers == 0:
        return DataLoader(num_workers=0, **common_kwargs)

    spawn_context = mp.get_context("spawn")
    try:
        return DataLoader(
            num_workers=workers,
            persistent_workers=True,
            multiprocessing_context=spawn_context,
            **common_kwargs,
        )
    except (RuntimeError, OSError, AssertionError) as exc:
        warnings.warn(
            "Falling back to num_workers=0 because DataLoader worker initialisation "
            f"failed with: {exc}",
            RuntimeWarning,
        )
        LOGGER.warning("DataLoader workers failed to start (%s). Using num_workers=0 instead.", exc)
        return DataLoader(num_workers=0, **common_kwargs)



## Utility helpers and metrics


In [None]:
def set_seed(seed: int) -> None:
    """Seed Python's ``random``, NumPy's global RNG and PyTorch (plus all CUDA devices)."""
    for seeder in (random.seed, np.random.seed, torch.manual_seed):
        seeder(seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed_all(seed)


def sanitize_boxes_and_labels(
    boxes: Tensor, labels: Tensor, height: int, width: int, min_size: float = 1.0
) -> Tuple[Tensor, Tensor]:
    """Clip boxes to the image and drop degenerate ones together with their labels.

    Args:
        boxes: ``(N, 4)`` tensor of ``x1, y1, x2, y2`` coordinates.
        labels: ``(N,)`` tensor of class ids aligned with ``boxes``.
        height: Image height used as the vertical clipping bound.
        width: Image width used as the horizontal clipping bound.
        min_size: Boxes whose clipped width or height is not strictly greater
            than this are removed.

    Returns:
        ``(boxes, labels)`` as ``float32`` ``(M, 4)`` and ``int64`` ``(M,)``
        tensors with ``M <= N``.  Every path, including the empty ones, returns
        these dtypes.
    """
    if boxes.numel() == 0:
        # Build fresh empties instead of reshaping ``labels``, which would raise
        # if a caller passed labels without matching boxes.
        return (
            boxes.new_zeros((0, 4), dtype=torch.float32),
            labels.new_zeros((0,), dtype=torch.long),
        )

    # Work on a float32 copy so clamping never mutates the caller's tensor and the
    # output dtype does not depend on the input dtype.
    boxes = boxes.to(dtype=torch.float32, copy=True)
    boxes[:, 0::2] = boxes[:, 0::2].clamp(0, float(width))
    boxes[:, 1::2] = boxes[:, 1::2].clamp(0, float(height))

    widths = boxes[:, 2] - boxes[:, 0]
    heights = boxes[:, 3] - boxes[:, 1]
    keep = (widths > min_size) & (heights > min_size)

    # An all-False mask naturally yields (0, 4) / (0,) tensors.
    return boxes[keep], labels[keep].long()


def compute_iou_matrix(boxes1: np.ndarray, boxes2: np.ndarray) -> np.ndarray:
    """Pairwise intersection-over-union between two sets of ``x1, y1, x2, y2`` boxes.

    Returns an ``(len(boxes1), len(boxes2))`` matrix.  Pairs whose union area is
    not positive get 0.  If either set is empty, a ``float32`` zero matrix of the
    matching shape is returned.
    """
    if not (boxes1.size and boxes2.size):
        return np.zeros((boxes1.shape[0], boxes2.shape[0]), dtype=np.float32)

    # Column vectors for the first set, row vectors for the second, so every
    # elementwise op below broadcasts to the full (N, M) pair grid.
    ax1, ay1, ax2, ay2 = (boxes1[:, k][:, np.newaxis] for k in range(4))
    bx1, by1, bx2, by2 = (boxes2[:, k][np.newaxis, :] for k in range(4))

    overlap_w = np.clip(np.minimum(ax2, bx2) - np.maximum(ax1, bx1), a_min=0.0, a_max=None)
    overlap_h = np.clip(np.minimum(ay2, by2) - np.maximum(ay1, by1), a_min=0.0, a_max=None)
    overlap = overlap_w * overlap_h

    area_a = (ax2 - ax1) * (ay2 - ay1)
    area_b = (bx2 - bx1) * (by2 - by1)
    union = area_a + area_b - overlap
    return np.divide(overlap, union, out=np.zeros_like(overlap), where=union > 0)


def compute_average_precision(recalls: np.ndarray, precisions: np.ndarray) -> float:
    """Return the 101-point interpolated average precision of a precision/recall curve.

    For each recall threshold ``r`` in ``0.00, 0.01, ..., 1.00`` the precision is
    the maximum precision achieved at any recall ``>= r``, or 0 when the curve
    never reaches ``r``.  AP is the mean of those 101 values; this is the same
    sampling scheme as COCO's 101-point AP.

    Args:
        recalls: Recall after each ranked detection.  Must be non-decreasing, as
            produced by a cumulative true-positive count.
        precisions: Precision after each ranked detection, aligned with ``recalls``.

    Returns:
        AP in ``[0, 1]``; ``0.0`` for an empty curve.
    """
    recalls = np.asarray(recalls, dtype=np.float64)
    precisions = np.asarray(precisions, dtype=np.float64)
    if recalls.size == 0 or precisions.size == 0:
        return 0.0

    # Monotone precision envelope: best precision at this or any higher recall.
    envelope = np.maximum.accumulate(precisions[::-1])[::-1]

    recall_points = np.linspace(0.0, 1.0, 101)
    # Index of the first ranked detection whose recall reaches each threshold.
    first_reaching = np.searchsorted(recalls, recall_points, side="left")
    reachable = first_reaching < recalls.size

    # Thresholds above the maximum achieved recall contribute zero precision.
    # Interpolating towards a synthetic (recall=1, precision=0) end point would
    # credit recall that was never reached and overestimate AP.
    sampled = np.zeros_like(recall_points)
    sampled[reachable] = envelope[first_reaching[reachable]]
    return float(sampled.mean())


def accumulate_classification_stats(
    predictions: Sequence[Dict[str, np.ndarray]],
    targets: Sequence[Dict[str, np.ndarray]],
    num_classes: int,
    iou_threshold: float,
) -> Tuple[np.ndarray, np.ndarray, np.ndarray, List[List[float]], List[List[int]], np.ndarray]:
    """Greedily match detections to ground truth and tally per-class statistics.

    Within each image and class, detections are visited in descending score
    order.  Each one is compared only with its highest-IoU ground-truth box: it
    is a true positive when that IoU reaches ``iou_threshold`` and the box has not
    been claimed yet, otherwise a false positive.  Unclaimed ground-truth boxes
    count as false negatives.

    Args:
        predictions: Per-image dicts with ``boxes`` ``(N, 4)``, ``scores`` ``(N,)``
            and ``labels`` ``(N,)``.
        targets: Per-image dicts with ``boxes`` ``(M, 4)`` and ``labels`` ``(M,)``,
            aligned with ``predictions``.
        num_classes: Length of the per-class arrays.
        iou_threshold: Minimum IoU for a detection to match a ground-truth box.

    Returns:
        ``(tp, fp, fn, scores, matches, gt_counter)``.  ``scores[c]`` and
        ``matches[c]`` list every detection of class ``c`` with its score and a
        1/0 true-positive flag (the input for AP).  ``gt_counter[c]`` counts
        ground-truth boxes of class ``c``.

    Raises:
        ValueError: If any predicted or ground-truth label lies outside
            ``[0, num_classes)``.
    """
    tp = np.zeros(num_classes, dtype=np.int64)
    fp = np.zeros(num_classes, dtype=np.int64)
    fn = np.zeros(num_classes, dtype=np.int64)
    scores: List[List[float]] = [[] for _ in range(num_classes)]
    matches: List[List[int]] = [[] for _ in range(num_classes)]
    gt_counter = np.zeros(num_classes, dtype=np.int64)

    for pred, tgt in zip(predictions, targets):
        pred_boxes = pred["boxes"]
        pred_scores = pred["scores"]
        pred_labels = pred["labels"].astype(np.int64)

        gt_boxes = tgt["boxes"]
        gt_labels = tgt["labels"].astype(np.int64)

        all_labels = np.concatenate((pred_labels, gt_labels))
        # Negative ids would silently index from the end of the per-class arrays
        # and too-large ids would raise a bare IndexError; fail with context instead.
        if all_labels.size and (all_labels.min() < 0 or all_labels.max() >= num_classes):
            raise ValueError(
                f"Class ids must lie in [0, {num_classes}); got range "
                f"[{int(all_labels.min())}, {int(all_labels.max())}]"
            )

        for cls in np.unique(all_labels):
            pred_mask = pred_labels == cls
            pb = pred_boxes[pred_mask]
            ps = pred_scores[pred_mask]
            tb = gt_boxes[gt_labels == cls]
            gt_counter[cls] += len(tb)

            if len(tb) == 0:
                # No ground truth of this class: every detection is a false positive.
                fp[cls] += len(pb)
                scores[cls].extend(ps.tolist())
                matches[cls].extend([0] * len(pb))
                continue
            if len(pb) == 0:
                # No detections of this class: every ground-truth box is missed.
                fn[cls] += len(tb)
                continue

            order = np.argsort(-ps)
            pb = pb[order]
            ps = ps[order]
            iou_matrix = compute_iou_matrix(pb, tb)

            matched_gt: Set[int] = set()
            for det_idx, score in enumerate(ps):
                best_gt = int(np.argmax(iou_matrix[det_idx]))
                best_iou = iou_matrix[det_idx, best_gt]
                is_match = bool(best_iou >= iou_threshold and best_gt not in matched_gt)

                if is_match:
                    tp[cls] += 1
                    matched_gt.add(best_gt)
                else:
                    fp[cls] += 1
                scores[cls].append(float(score))
                matches[cls].append(int(is_match))

            fn[cls] += len(tb) - len(matched_gt)

    return tp, fp, fn, scores, matches, gt_counter


def compute_detection_metrics(
    predictions: Sequence[Dict[str, np.ndarray]],
    targets: Sequence[Dict[str, np.ndarray]],
    num_classes: int,
    iou_threshold: float,
) -> Dict[str, Union[np.ndarray, float]]:
    """Compute per-class detection counts, precision/recall, AP and mAP.

    Args:
        predictions: Per-image prediction dicts (``boxes``, ``scores``, ``labels``).
        targets: Per-image ground-truth dicts (``boxes``, ``labels``).
        num_classes: Number of classes in the per-class outputs.
        iou_threshold: Minimum IoU for a detection to count as a match.

    Returns:
        A dict with these keys:

        - ``TP`` / ``FP`` / ``FN``: per-class integer counts.
        - ``precision`` / ``recall``: per-class ratios, with denominators clipped to 1.
        - ``AP``: per-class average precision, ``NaN`` for classes without ground truth.
        - ``mAP``: float mean of the finite AP values, or ``0.0`` if there are none.
        - ``gt_counter``: per-class number of ground-truth boxes.
    """
    tp, fp, fn, scores, matches, gt_counter = accumulate_classification_stats(
        predictions, targets, num_classes, iou_threshold
    )

    precision = np.divide(tp, np.clip(tp + fp, a_min=1, a_max=None))
    recall = np.divide(tp, np.clip(tp + fn, a_min=1, a_max=None))

    ap = np.zeros(num_classes, dtype=np.float32)
    for cls in range(num_classes):
        if gt_counter[cls] == 0:
            # AP is undefined without ground truth; NaN keeps it out of the mAP mean.
            ap[cls] = np.nan
            continue
        if not scores[cls]:
            ap[cls] = 0.0
            continue

        # A stable sort keeps tied scores in accumulation order, so AP does not
        # depend on how the sorting algorithm happens to break ties.
        order = np.argsort(-np.asarray(scores[cls]), kind="stable")
        match_array = np.asarray(matches[cls], dtype=np.int32)[order]
        cumulative_tp = np.cumsum(match_array)
        cumulative_fp = np.cumsum(1 - match_array)

        recalls = cumulative_tp / gt_counter[cls]
        precisions = cumulative_tp / np.maximum(cumulative_tp + cumulative_fp, 1)
        ap[cls] = compute_average_precision(recalls, precisions)

    valid_ap = ap[np.isfinite(ap)]
    map_value = float(valid_ap.mean()) if valid_ap.size else 0.0

    return {
        "TP": tp,
        "FP": fp,
        "FN": fn,
        "precision": precision,
        "recall": recall,
        "AP": ap,
        "mAP": map_value,
        "gt_counter": gt_counter,
    }


class SmoothedValue:
    """Track a moving average over the most recent ``window_size`` values.

    Attributes:
        window_size: Maximum number of recent values kept.
        deque: The values currently in the window, oldest first.
        total: Sum of the values currently in the window.
        count: Number of updates ever received (not limited by the window).
    """

    def __init__(self, window_size: int = 20) -> None:
        from collections import deque  # local import: only this class needs it

        if window_size < 1:
            # A zero-sized window previously crashed on the first update.
            raise ValueError(f"window_size must be a positive integer, got {window_size}")
        self.window_size = window_size
        # A bounded deque evicts the oldest value in O(1), unlike ``list.pop(0)``.
        self.deque = deque(maxlen=window_size)
        self.total = 0.0
        self.count = 0

    def update(self, value: float) -> None:
        """Add ``value`` to the window, evicting the oldest value when full."""
        if len(self.deque) == self.window_size:
            # The append below drops the oldest value; remove it from the sum first.
            self.total -= self.deque[0]
        self.deque.append(value)
        self.total += value
        self.count += 1

    @property
    def avg(self) -> float:
        """Mean of the values in the window, or ``0.0`` before the first update."""
        if not self.deque:
            return 0.0
        return self.total / len(self.deque)


class MetricLogger:
    """Keep one ``SmoothedValue`` per named metric and render their averages."""

    def __init__(self) -> None:
        self.meters: Dict[str, SmoothedValue] = {}

    def update(self, **kwargs: float) -> None:
        """Feed each keyword's value, cast to ``float``, into the meter of that name."""
        for metric_name, raw_value in kwargs.items():
            meter = self.meters.get(metric_name)
            if meter is None:
                meter = SmoothedValue()
                self.meters[metric_name] = meter
            meter.update(float(raw_value))

    def format(self) -> str:
        """Return ``"name: avg | name: avg"`` with averages to four decimals."""
        return " | ".join(
            "{}: {:.4f}".format(metric_name, meter.avg)
            for metric_name, meter in self.meters.items()
        )



def score_threshold_mask(
    scores: np.ndarray,
    labels: np.ndarray,
    default_threshold: float,
    class_thresholds: Mapping[int, float],
) -> np.ndarray:
    """Return a boolean mask of predictions whose score meets their class threshold.

    Each prediction is compared against ``class_thresholds[label]`` when its label
    has an entry, and against ``default_threshold`` otherwise (``score >= threshold``).
    """
    if not scores.size:
        return np.zeros_like(scores, dtype=bool)

    # Thresholds share the scores' dtype so the final comparison is like-for-like.
    per_prediction = np.full(scores.shape, default_threshold, dtype=scores.dtype)
    if class_thresholds:
        label_ids = labels.astype(np.int64, copy=False)
        for class_id, threshold in class_thresholds.items():
            per_prediction[label_ids == int(class_id)] = float(threshold)

    return scores >= per_prediction


def parse_class_threshold_entries(entries: Sequence[str]) -> Dict[int, float]:
    """Parse ``CLS=THRESH`` (or ``CLS:THRESH``) strings into per-class thresholds.

    Empty and whitespace-only entries are skipped.  Whitespace around the class id
    and the threshold is ignored.  If an entry contains both separators, ``=`` wins.
    Later entries override earlier ones for the same class.

    Args:
        entries: Raw entries, e.g. ``["16=0.9", "30: 0.85"]``.

    Returns:
        Mapping of integer class id to float threshold.

    Raises:
        ValueError: If an entry has no separator, an empty side, or a class id or
            threshold that cannot be parsed; the message names the entry.
    """
    thresholds: Dict[int, float] = {}
    for entry in entries:
        text = entry.strip() if entry else ""
        if not text:
            continue

        key, sep, value = text.partition("=")
        if not sep:
            key, sep, value = text.partition(":")
        if not sep:
            raise ValueError(f"Invalid class threshold format: {entry!r}")

        key = key.strip()
        value = value.strip()
        if not key or not value:
            raise ValueError(f"Invalid class threshold entry: {entry!r}")

        try:
            thresholds[int(key)] = float(value)
        except ValueError as exc:
            # Re-raise with the offending entry so CLI users can see what was wrong.
            raise ValueError(f"Invalid class threshold entry: {entry!r}") from exc

    return thresholds

def identify_false_positive_predictions(
    prediction: Dict[str, np.ndarray],
    target: Dict[str, np.ndarray],
    num_classes: int,
    iou_threshold: float,
) -> List[Dict[str, Union[int, float, List[float]]]]:
    """Describe every false-positive detection of a single image.

    Detections are grouped by predicted class and visited in descending score
    order. A detection is a true positive only when its highest-IoU ground-truth
    box of the same class reaches ``iou_threshold`` and that box was not already
    claimed by a higher-scoring detection; every other detection is reported.

    Args:
        prediction: Mapping with optional ``"boxes"``, ``"scores"`` and ``"labels"``.
        target: Mapping with optional ``"boxes"`` and ``"labels"``.
        num_classes: Accepted for interface compatibility; not used here.
        iou_threshold: Minimum IoU for a detection to match a ground-truth box.

    Returns:
        One dict per false positive with keys ``index`` (position in the input
        arrays), ``class``, ``score``, ``best_iou`` and ``box``; ordered by
        ascending class id, then descending score.
    """

    def _field(source, key, empty_shape, dtype):
        return np.asarray(source.get(key, np.empty(empty_shape, dtype=dtype)), dtype=dtype)

    pred_boxes = _field(prediction, "boxes", (0, 4), np.float32)
    pred_scores = _field(prediction, "scores", (0,), np.float32)
    pred_labels = _field(prediction, "labels", (0,), np.int64)
    gt_boxes = _field(target, "boxes", (0, 4), np.float32)
    gt_labels = _field(target, "labels", (0,), np.int64)

    if pred_boxes.size == 0:
        return []

    records: List[Dict[str, Union[int, float, List[float]]]] = []
    classes = np.unique(pred_labels) if pred_labels.size else np.asarray([], dtype=np.int64)

    for class_value in classes:
        class_id = int(class_value)
        in_class = pred_labels == class_id
        member_positions = np.flatnonzero(in_class)
        if member_positions.size == 0:
            continue

        class_scores = pred_scores[in_class]
        rank_order = np.argsort(-class_scores)
        ranked_boxes = pred_boxes[in_class][rank_order]
        ranked_scores = class_scores[rank_order]
        ranked_sources = member_positions[rank_order]

        class_gt = gt_boxes[gt_labels == class_id]
        if class_gt.size:
            overlaps = compute_iou_matrix(ranked_boxes, class_gt)
        else:
            overlaps = np.zeros((ranked_boxes.shape[0], 0), dtype=np.float32)
        has_gt = overlaps.shape[1] > 0
        claimed: Set[int] = set()

        for rank, (source_index, score) in enumerate(zip(ranked_sources, ranked_scores)):
            if has_gt:
                ious = overlaps[rank]
                match = int(ious.argmax())
                match_iou = float(ious[match])
            else:
                match, match_iou = -1, 0.0

            if has_gt and match_iou >= iou_threshold and match not in claimed:
                # Greedy matching: the best box is consumed by this detection.
                claimed.add(match)
                continue

            records.append(
                {
                    "index": int(source_index),
                    "class": class_id,
                    "score": float(score),
                    "best_iou": match_iou,
                    "box": pred_boxes[source_index].astype(float).tolist(),
                }
            )

    return records



# Outline/caption colours for predicted boxes. ``render_detections`` picks one as
# ``DEFAULT_COLORS[label % len(DEFAULT_COLORS)]``, so class ids cycle through
# this palette.
DEFAULT_COLORS = [
    "#FF6B6B",
    "#4ECDC4",
    "#556270",
    "#C44D58",
    "#FFB347",
    "#6B5B95",
    "#88B04B",
    "#92A8D1",
    "#955251",
    "#B565A7",
]


def load_default_font() -> ImageFont.FreeTypeFont | ImageFont.ImageFont:
    """Return the font used for detection captions.

    Prefers 14px DejaVu Sans. If loading that TrueType font raises any
    exception, Pillow's built-in default font is returned instead.
    """
    font_file, font_size = "DejaVuSans.ttf", 14
    try:
        caption_font = ImageFont.truetype(font_file, size=font_size)
    except Exception:
        caption_font = ImageFont.load_default()
    return caption_font


def _resolve_class_name(class_names: Sequence[str], label: int) -> str:
    if 0 <= label < len(class_names):
        name = class_names[label]
    else:
        name = f"class_{label:02d}"
    if name.startswith("class_") and name[6:].isdigit():
        return f"class {int(name[6:]):02d}"
    return name


def render_detections(
    image: np.ndarray,
    prediction: Mapping[str, np.ndarray],
    target: Mapping[str, np.ndarray] | None,
    class_names: Sequence[str],
    score_threshold: float,
    class_thresholds: Mapping[int, float],
    draw_ground_truth: bool,
) -> PILImage.Image:
    """Draw predicted boxes (and optionally ground-truth boxes) on ``image``.

    Args:
        image: Array accepted by ``PIL.Image.fromarray``. Non-``uint8`` input is
            clipped to ``[0, 255]`` and cast to ``uint8`` first.
        prediction: Mapping with optional ``"boxes"`` (``x1, y1, x2, y2``),
            ``"labels"`` and ``"scores"``.
        target: Ground-truth mapping with ``"boxes"``/``"labels"``, or ``None``.
        class_names: Names used for captions (via ``_resolve_class_name``).
        score_threshold: Minimum score for labels without a per-class entry.
        class_thresholds: Per-class minimum scores overriding ``score_threshold``.
        draw_ground_truth: When true and ``target`` is given, also draw thin
            white ground-truth outlines with ``GT`` captions.

    Returns:
        A PIL image with the annotations drawn.
    """
    if image.dtype != np.uint8:
        image_array = np.clip(image, 0, 255).astype(np.uint8)
    else:
        image_array = image

    pil = PILImage.fromarray(image_array)
    draw = ImageDraw.Draw(pil)
    font = load_default_font()
    image_height = pil.height
    pred_caption_height = 16
    gt_caption_height = 14

    boxes = np.asarray(prediction.get("boxes", np.empty((0, 4))), dtype=float)
    labels = np.asarray(prediction.get("labels", np.empty((0,), dtype=int)), dtype=int)
    scores = np.asarray(prediction.get("scores", np.empty((0,), dtype=float)), dtype=float)

    for box, label, score in zip(boxes, labels, scores):
        label_id = int(label)
        if score < class_thresholds.get(label_id, score_threshold):
            continue
        color = DEFAULT_COLORS[label_id % len(DEFAULT_COLORS)] if DEFAULT_COLORS else "#FF6B6B"
        x1, y1, x2, y2 = (float(coord) for coord in box)
        draw.rectangle([x1, y1, x2, y2], outline=color, width=2)

        caption = f"{_resolve_class_name(class_names, label_id)} {score:.2f}"
        text_width = draw.textlength(caption, font=font)
        # Caption sits above the box; for boxes touching the top edge it moves
        # inside the box so it is not clipped off-image.
        caption_top = y1 - pred_caption_height if y1 >= pred_caption_height else y1
        draw.rectangle(
            [x1, caption_top, x1 + text_width + 8, caption_top + pred_caption_height], fill=color
        )
        draw.text((x1 + 4, caption_top + 2), caption, fill="white", font=font)

    if draw_ground_truth and target is not None:
        gt_boxes = np.asarray(target.get("boxes", np.empty((0, 4))), dtype=float)
        gt_labels = np.asarray(target.get("labels", np.empty((0,), dtype=int)), dtype=int)
        for box, label in zip(gt_boxes, gt_labels):
            x1, y1, x2, y2 = (float(coord) for coord in box)
            draw.rectangle([x1, y1, x2, y2], outline="#FFFFFF", width=1)
            caption = f"GT {_resolve_class_name(class_names, int(label))}"
            text_width = draw.textlength(caption, font=font)
            # GT captions sit below the box, or just inside it at the bottom edge.
            gt_top = y2 if y2 + gt_caption_height <= image_height else y2 - gt_caption_height
            draw.rectangle(
                [x1, gt_top, x1 + text_width + 6, gt_top + gt_caption_height], fill="#FFFFFF"
            )
            draw.text((x1 + 3, gt_top), caption, fill="black", font=font)

    return pil


def save_detection_visual(
    image: np.ndarray,
    prediction: Mapping[str, np.ndarray],
    target: Mapping[str, np.ndarray] | None,
    class_names: Sequence[str],
    score_threshold: float,
    class_thresholds: Mapping[int, float],
    draw_ground_truth: bool,
    output_path: Path,
) -> None:
    """Render detections with ``render_detections`` and save the image to ``output_path``.

    Missing parent directories of ``output_path`` are created before saving.
    """
    rendered = render_detections(
        image=image,
        prediction=prediction,
        target=target,
        class_names=class_names,
        score_threshold=score_threshold,
        class_thresholds=class_thresholds,
        draw_ground_truth=draw_ground_truth,
    )
    destination_dir = output_path.parent
    destination_dir.mkdir(parents=True, exist_ok=True)
    rendered.save(output_path)


def write_false_positive_report(
    fp_records: Sequence[Dict[str, object]],
    report_path: Path,
    *,
    split: str,
    score_threshold: float,
    class_score_thresholds: Mapping[int, float],
    iou_threshold: float,
) -> None:
    """Write false-positive records, plus the settings that produced them, as JSON.

    The payload has keys ``split``, ``score_threshold``, ``class_score_thresholds``
    (JSON turns its int keys into strings), ``iou_threshold`` and
    ``false_positive_images``. Missing parent directories are created.

    The file is always written as UTF-8. ``ensure_ascii=False`` keeps non-ASCII
    image ids verbatim, and writing them with a non-UTF-8 locale default
    encoding would fail or mis-encode them.

    Args:
        fp_records: JSON-serialisable per-image false-positive records.
        report_path: Destination file (``str`` is accepted as well).
        split: Dataset split the records come from.
        score_threshold: Global score threshold used for the detections.
        class_score_thresholds: Per-class score thresholds used.
        iou_threshold: IoU threshold used for matching.
    """
    payload = {
        "split": split,
        "score_threshold": score_threshold,
        "class_score_thresholds": dict(class_score_thresholds),
        "iou_threshold": iou_threshold,
        "false_positive_images": list(fp_records),
    }
    report_path = Path(report_path)
    report_path.parent.mkdir(parents=True, exist_ok=True)
    report_path.write_text(json.dumps(payload, indent=2, ensure_ascii=False), encoding="utf-8")


def write_false_positive_list(fp_records: Sequence[Dict[str, object]], list_path: Path) -> None:
    """Write the sorted, de-duplicated ``image_id`` values of ``fp_records``, one per line.

    The file ends with a newline when it is non-empty and is left empty
    otherwise. It is written as UTF-8, so non-ASCII ids survive regardless of
    the locale's default encoding. Missing parent directories are created, and
    ``list_path`` may be a ``str``.
    """
    stems = sorted({str(record["image_id"]) for record in fp_records})
    list_path = Path(list_path)
    list_path.parent.mkdir(parents=True, exist_ok=True)
    list_path.write_text("".join(f"{stem}\n" for stem in stems), encoding="utf-8")


## Model construction


In [None]:
def _save_state_dict(model: nn.Module, path: Path) -> None:
    """Best-effort save of ``model``'s state dict to ``path``.

    Any failure is logged as a warning instead of being raised, so callers are
    not interrupted when the weights cannot be cached.
    """
    try:
        torch.save(model.state_dict(), path)
    except Exception as exc:
        LOGGER.warning("Unable to save pretrained weights: %s", exc)
    else:
        LOGGER.info("Saved pretrained weights to %s", path)


def _load_pretrained_model(train_cfg: TrainingConfig) -> nn.Module:
    """Return a pretrained torchvision Faster R-CNN ResNet-50 FPN v2 detector.

    First tries torchvision's ``DEFAULT`` weights. On success, those weights are
    cached to ``train_cfg.pretrained_weights_path`` if that file does not exist
    yet. If torchvision cannot provide the weights (for example when the
    download is unavailable), the cached file is loaded instead.

    Raises:
        RuntimeError: If the torchvision weights cannot be loaded and no cached
            file exists; the original error is chained as the cause.
    """
    pretrained_path = Path(train_cfg.pretrained_weights_path)
    pretrained_path.parent.mkdir(parents=True, exist_ok=True)

    weights_enum = FasterRCNN_ResNet50_FPN_V2_Weights.DEFAULT
    # Only the model construction is guarded, so unrelated errors (logging,
    # caching) are not mistaken for "weights unavailable".
    try:
        model = fasterrcnn_resnet50_fpn_v2(weights=weights_enum)
    except Exception as exc:
        LOGGER.warning(
            "Could not load torchvision Faster R-CNN weights (%s); "
            "falling back to locally saved pretrained detector weights",
            exc,
        )
        if not pretrained_path.exists():
            raise RuntimeError(
                "No pretrained weights available. Download them manually and place them at "
                + str(pretrained_path)
            ) from exc
        state_dict = torch.load(pretrained_path, map_location="cpu")
        model = fasterrcnn_resnet50_fpn_v2(weights=None)
        model.load_state_dict(state_dict)
        return model

    LOGGER.info("Loaded torchvision Faster R-CNN weights")
    if not pretrained_path.exists():
        _save_state_dict(model, pretrained_path)
    return model


def build_model(
    dataset_cfg: DatasetConfig,
    train_cfg: TrainingConfig,
    device: Optional[torch.device] = None,
) -> nn.Module:
    """Build the Faster R-CNN detector configured for this dataset.

    Loads the pretrained detector and replaces its box predictor with one
    sized for ``dataset_cfg.num_classes`` foreground classes plus background.
    Optionally installs smaller anchors for small objects, then moves the
    model to ``device``. When ``device`` is not given, CUDA is used if
    available, otherwise CPU.
    """
    device = device or torch.device("cuda" if torch.cuda.is_available() else "cpu")

    model = _load_pretrained_model(train_cfg)

    in_features = model.roi_heads.box_predictor.cls_score.in_features
    # +1 because torchvision's detection heads reserve index 0 for background.
    num_classes_with_background = dataset_cfg.num_classes + 1
    model.roi_heads.box_predictor = FastRCNNPredictor(in_features, num_classes_with_background)

    if train_cfg.small_object:
        # The pretrained RPN head predicts exactly 3 anchors per location on every
        # FPN level, so the generator must also yield 3 per location (one size x
        # three aspect ratios). Putting all five sizes on every level produced 15
        # anchors per location, which no longer matches the head's outputs and
        # fails during RPN box decoding. torchvision's default uses one size per
        # level (32..512); halving it biases proposals toward small objects while
        # keeping the pretrained RPN head weights usable.
        anchor_generator = AnchorGenerator(
            sizes=((16,), (32,), (64,), (128,), (256,)),
            aspect_ratios=((0.5, 1.0, 2.0),) * 5,
        )
        model.rpn.anchor_generator = anchor_generator
        LOGGER.info("Using custom anchor sizes optimised for small objects")

    model.to(device)
    return model


## Training utilities


In [None]:
def move_to_device(targets: List[Dict[str, torch.Tensor]], device: torch.device) -> List[Dict[str, torch.Tensor]]:
    """Return new target dicts whose tensors have been moved to ``device``.

    The input list and its dicts are left untouched; each dict is rebuilt with
    ``tensor.to(device)`` for every value.
    """

    def _transfer(target: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]:
        return {name: value.to(device) for name, value in target.items()}

    return list(map(_transfer, targets))


def train_one_epoch(
    model: nn.Module,
    loader: DataLoader,
    optimizer: torch.optim.Optimizer,
    scaler: GradScaler,
    device: torch.device,
    amp: bool,
    log_every: int,
) -> float:
    """Train ``model`` for one pass over ``loader``.

    Args:
        model: Detector returning a loss dict when called with targets.
        loader: Yields ``(images, targets)`` batches.
        optimizer: Optimizer stepped through ``scaler``.
        scaler: Gradient scaler (a pass-through when AMP is disabled).
        device: Device the batch is moved to.
        amp: Enable CUDA autocast (ignored on non-CUDA devices).
        log_every: Refresh the progress-bar postfix every N steps; values
            <= 0 disable the refresh.

    Returns:
        The ``avg`` of the ``loss`` meter, or 0.0 if every step was skipped.
    """
    model.train()
    metric_logger = MetricLogger()
    progress = tqdm(loader, desc="Train", leave=False)
    autocast_enabled = amp and device.type == "cuda"

    for step, (images, targets) in enumerate(progress, start=1):
        images = [img.to(device) for img in images]
        targets = move_to_device(targets, device)

        optimizer.zero_grad()
        autocast_context = (
            torch.amp.autocast(device_type="cuda") if autocast_enabled else contextlib.nullcontext()
        )
        with autocast_context:
            loss_dict = model(images, targets)
            loss = sum(loss_dict.values())

        if not torch.isfinite(loss):
            # Skip without touching the scaler: GradScaler.update() requires that
            # scale()/step() ran in this iteration, and calling it here raised an
            # AssertionError whenever AMP was enabled.
            LOGGER.warning("Skipping step %s due to non-finite loss", step)
            continue

        scaler.scale(loss).backward()
        scaler.step(optimizer)
        scaler.update()

        metric_logger.update(loss=loss.item())
        if log_every > 0 and step % log_every == 0:
            progress.set_postfix_str(metric_logger.format())

    loss_meter = metric_logger.meters.get("loss")
    return loss_meter.avg if loss_meter is not None else 0.0


def _set_batchnorm_eval(model: nn.Module) -> None:
    """Put every BatchNorm layer of ``model`` in eval mode so forwards stop updating running stats."""
    batchnorm_types = (nn.BatchNorm1d, nn.BatchNorm2d, nn.BatchNorm3d, nn.SyncBatchNorm)
    for module in model.modules():
        if isinstance(module, batchnorm_types):
            module.eval()


def _sample_image_identity(
    target_cpu: Mapping[str, object],
    dataset_ref: object,
    fallback_prefix: str,
    ordinal: int,
) -> Tuple[int, str]:
    """Return ``(image_index, image_id)`` for one target.

    The index comes from ``target_cpu["image_id"]`` (``-1`` when missing or not
    convertible). The id is the dataset's ``image_stems`` entry when the index is
    in range, else ``"<fallback_prefix>_<ordinal:04d>"``.
    """
    raw_identifier = target_cpu.get("image_id", -1)
    if isinstance(raw_identifier, torch.Tensor):
        image_index = int(raw_identifier.item())
    else:
        try:
            image_index = int(raw_identifier)
        except (TypeError, ValueError, OverflowError):
            image_index = -1

    stems = getattr(dataset_ref, "image_stems", []) if dataset_ref is not None else []
    if 0 <= image_index < len(stems):
        return image_index, stems[image_index]
    return image_index, f"{fallback_prefix}_{ordinal:04d}"


@torch.no_grad()
def evaluate(
    model: nn.Module,
    loader: DataLoader,
    device: torch.device,
    dataset_cfg: DatasetConfig,
    train_cfg: TrainingConfig,
    *,
    dataset: ElectricalComponentsDataset | None = None,
    collect_details: bool = False,
) -> Tuple[Dict[str, torch.Tensor | float | List[float]], List[Dict[str, object]]]:
    """Compute validation loss and detection metrics over ``loader``.

    Each batch is run twice.
    1. In training mode with targets, to get the loss dict. torchvision
       detectors only return losses in that mode, so BatchNorm layers are kept
       in eval mode here and validation images do not overwrite the running
       statistics learned during training.
    2. In eval mode, to get detections.

    Detections are filtered with ``score_threshold_mask`` using
    ``train_cfg.score_threshold`` and ``train_cfg.class_score_thresholds``. The
    model's original train/eval mode is restored on exit, even on error.

    Returns:
        ``(metrics, sample_details)``. ``metrics`` is the result of
        ``compute_detection_metrics`` plus ``"loss"`` (mean loss per batch).
        ``sample_details`` is empty unless ``collect_details`` is true; then
        it holds one dict per image with ``image_index``, ``image_id``,
        ``prediction``, ``target`` and ``false_positives``.
    """
    was_training = model.training

    predictions: List[Dict[str, np.ndarray]] = []
    targets_for_eval: List[Dict[str, np.ndarray]] = []
    total_loss = 0.0
    num_batches = 0
    sample_details: List[Dict[str, object]] = []
    dataset_ref = dataset if dataset is not None else getattr(loader, "dataset", None)

    try:
        model.eval()
        for images, targets in loader:
            images = [img.to(device) for img in images]
            targets_device = move_to_device(targets, device)

            model.train()
            _set_batchnorm_eval(model)
            loss_dict = model(images, targets_device)
            total_loss += sum(loss_dict.values()).item()
            num_batches += 1

            model.eval()
            outputs = model(images)
            for output, target_device, target_cpu in zip(outputs, targets_device, targets):
                scores = output["scores"].detach().cpu().numpy()
                labels = output["labels"].detach().cpu().numpy()
                keep = score_threshold_mask(
                    scores,
                    labels,
                    train_cfg.score_threshold,
                    train_cfg.class_score_thresholds,
                )
                prediction_np = {
                    "boxes": output["boxes"].detach().cpu().numpy()[keep],
                    "scores": scores[keep],
                    "labels": labels[keep],
                }
                target_np = {
                    "boxes": target_device["boxes"].detach().cpu().numpy(),
                    "labels": target_device["labels"].detach().cpu().numpy(),
                }

                predictions.append(prediction_np)
                targets_for_eval.append(target_np)

                if not collect_details:
                    continue

                image_index, image_id = _sample_image_identity(
                    target_cpu, dataset_ref, dataset_cfg.valid_split, len(sample_details)
                )
                fp_info = identify_false_positive_predictions(
                    prediction_np,
                    target_np,
                    dataset_cfg.num_classes,
                    train_cfg.iou_threshold,
                )
                sample_details.append(
                    {
                        "image_index": image_index,
                        "image_id": image_id,
                        "prediction": prediction_np,
                        "target": target_np,
                        "false_positives": fp_info,
                    }
                )
    finally:
        model.train(was_training)

    metrics = compute_detection_metrics(
        predictions, targets_for_eval, dataset_cfg.num_classes, train_cfg.iou_threshold
    )
    metrics["loss"] = total_loss / max(num_batches, 1)
    return metrics, sample_details


def _resolve_class_label(dataset_cfg: DatasetConfig, index: int) -> str:
    """Return the display label for class ``index`` of ``dataset_cfg``.

    Indices outside ``[0, len(class_names))`` become ``class_NN``. Names of the
    form ``class_<digits>`` are shown as ``class NN``.
    """
    class_names = dataset_cfg.class_names
    # The lower bound matters: without it a negative index would silently pick a
    # name from the end of the list via Python's negative indexing.
    if 0 <= index < len(class_names):
        label = class_names[index]
    else:
        label = f"class_{index:02d}"

    if label.startswith("class_") and label[6:].isdigit():
        return f"class {int(label[6:]):02d}"
    return label


def format_epoch_metrics(
    epoch: Optional[int],
    train_loss: Optional[float],
    metrics: Dict[str, torch.Tensor | float | List[float]],
    dataset_cfg: DatasetConfig,
    *,
    header: Optional[str] = None,
) -> List[str]:
    """Turn an evaluation ``metrics`` dict into human-readable log lines.

    The first line is a summary: ``header`` if given, otherwise ``"Epoch NN"``
    or ``"Metrics"``. It is followed by whichever of train loss, val loss and
    mAP are finite, separated by ``" | "``.

    Each following line covers one class index below
    ``min(len(TP), dataset_cfg.num_classes)`` that has a non-zero ground-truth,
    TP, FP or FN count. It shows precision/recall (NaN shown as 0), the counts,
    and AP when it is finite.
    """
    nan = float("nan")
    val_loss = float(metrics.get("loss", nan))
    map_value = float(metrics.get("mAP", nan))

    if header is not None:
        title = header
    else:
        title = f"Epoch {epoch:02d}" if epoch is not None else "Metrics"

    summary_parts = [title]
    if train_loss is not None and np.isfinite(train_loss):
        summary_parts.append(f"train loss {train_loss:.4f}")
    if np.isfinite(val_loss):
        summary_parts.append(f"val loss {val_loss:.4f}")
    if np.isfinite(map_value):
        summary_parts.append(f"mAP {map_value:.4f}")
    lines: List[str] = [" | ".join(summary_parts)]

    def column(key: str, dtype: type) -> np.ndarray:
        return np.asarray(metrics.get(key, []), dtype=dtype)

    precision = column("precision", float)
    recall = column("recall", float)
    tp = column("TP", int)
    fp = column("FP", int)
    fn = column("FN", int)
    ap = column("AP", float)
    gt_counter = np.asarray(metrics.get("gt_counter", np.zeros_like(tp)), dtype=int)

    def value_at(values: np.ndarray, position: int, fallback: float):
        # Shorter arrays are treated as missing the entry rather than raising.
        return values[position] if values.size > position else fallback

    for cls_idx in range(min(len(tp), dataset_cfg.num_classes)):
        gt_value, tp_value, fp_value, fn_value = (
            int(value_at(values, cls_idx, 0)) for values in (gt_counter, tp, fp, fn)
        )
        if not (gt_value or tp_value or fp_value or fn_value):
            continue

        label = _resolve_class_label(dataset_cfg, cls_idx)
        p_val = float(np.nan_to_num(value_at(precision, cls_idx, 0.0), nan=0.0))
        r_val = float(np.nan_to_num(value_at(recall, cls_idx, 0.0), nan=0.0))

        text = f"{label} | P={p_val:.3f} R={r_val:.3f}  TP={tp_value} FP={fp_value} FN={fn_value}"
        if ap.size > cls_idx and np.isfinite(ap[cls_idx]):
            text += f" AP={ap[cls_idx]:.3f}"
        lines.append(text)

    return lines


def save_checkpoint(model: nn.Module, path: Path) -> None:
    """Save ``model.state_dict()`` to ``path`` atomically.

    The weights are first written to a temporary sibling file (``<name>.tmp``)
    and then moved over ``path`` with ``os.replace``. An interruption mid-save
    therefore cannot leave a truncated file in place of the previous checkpoint.
    Missing parent directories are created, and ``path`` may be a ``str``.
    """
    path = Path(path)
    path.parent.mkdir(parents=True, exist_ok=True)
    tmp_path = path.with_name(path.name + ".tmp")
    try:
        torch.save(model.state_dict(), tmp_path)
        os.replace(tmp_path, path)
    except BaseException:
        # Do not leave a half-written temp file behind; propagate the error.
        tmp_path.unlink(missing_ok=True)
        raise
    LOGGER.info("Saved checkpoint to %s", path)


def export_false_positive_visuals(
    *,
    dataset: ElectricalComponentsDataset,
    dataset_cfg: DatasetConfig,
    train_cfg: TrainingConfig,
    sample_details: List[Dict[str, object]],
) -> List[Dict[str, object]]:
    """Render validation images containing false positives of selected classes.

    Only classes listed in ``train_cfg.fp_classes`` are considered. Steps:
    1. Delete existing ``*.png`` files in ``train_cfg.fp_visual_dir``; deletion
       failures are only logged at debug level.
    2. Write one PNG per qualifying image, named ``<image_id>.png``, showing
       predictions and ground truth.

    Returns:
        One ``{"image_id", "false_positives"}`` record per rendered image, with
        ``false_positives`` restricted to the selected classes. Images whose raw
        sample fails to load are logged and omitted. An empty list is returned
        when ``fp_visual_dir`` or ``fp_classes`` is unset/empty.

    NOTE(review): ``train_pipeline`` also passes this return value to the FP
    report/list writers, so those files come out empty whenever no visual
    directory is configured — confirm that coupling is intended.
    """
    output_dir = train_cfg.fp_visual_dir
    if not output_dir:
        return []

    wanted_classes = set(map(int, train_cfg.fp_classes))
    if not wanted_classes:
        return []

    output_dir.mkdir(parents=True, exist_ok=True)
    for stale_png in output_dir.glob("*.png"):
        try:
            stale_png.unlink()
        except OSError:
            LOGGER.debug("Unable to remove previous FP visual %s", stale_png)

    exported: List[Dict[str, object]] = []
    for detail in sample_details:
        candidates = detail.get("false_positives", [])
        if not candidates:
            continue
        selected = [record for record in candidates if int(record.get("class", -1)) in wanted_classes]
        if not selected:
            continue

        image_id = str(detail.get("image_id", detail.get("image_index", "unknown")))
        try:
            image_np, _, _ = dataset._load_raw_sample(image_id)
        except Exception as exc:
            LOGGER.warning("Skipping FP visual for %s due to load error: %s", image_id, exc)
            continue

        save_detection_visual(
            image=image_np,
            prediction=detail.get("prediction", {}),
            target=detail.get("target", {}),
            class_names=dataset_cfg.class_names,
            score_threshold=train_cfg.score_threshold,
            class_thresholds=train_cfg.class_score_thresholds,
            draw_ground_truth=True,
            output_path=output_dir / f"{image_id}.png",
        )
        exported.append({"image_id": image_id, "false_positives": selected})

    return exported


def _export_false_positive_artifacts(
    valid_dataset: ElectricalComponentsDataset,
    dataset_cfg: DatasetConfig,
    train_cfg: TrainingConfig,
    sample_details: List[Dict[str, object]],
) -> None:
    """Export FP visuals, then the FP JSON report and id list when their paths are configured."""
    fp_records = export_false_positive_visuals(
        dataset=valid_dataset,
        dataset_cfg=dataset_cfg,
        train_cfg=train_cfg,
        sample_details=sample_details,
    )
    if train_cfg.fp_report_path:
        write_false_positive_report(
            fp_records,
            train_cfg.fp_report_path,
            split=dataset_cfg.valid_split,
            score_threshold=train_cfg.score_threshold,
            class_score_thresholds=train_cfg.class_score_thresholds,
            iou_threshold=train_cfg.iou_threshold,
        )
        LOGGER.info(
            "Saved false-positive report for %d images to %s",
            len(fp_records),
            train_cfg.fp_report_path,
        )
    if train_cfg.fp_list_path:
        write_false_positive_list(fp_records, train_cfg.fp_list_path)
        LOGGER.info(
            "Saved list of %d false-positive image ids to %s",
            len({str(record["image_id"]) for record in fp_records}),
            train_cfg.fp_list_path,
        )


def train_pipeline(
    dataset_cfg: DatasetConfig,
    train_cfg: TrainingConfig,
    *,
    resume_from: Optional[Path] = None,
) -> Tuple[nn.Module, List[Dict[str, float]]]:
    """Train the detector, evaluate it periodically and keep the best checkpoint.

    Setup: build the model (optionally resuming weights from ``resume_from``
    when that file exists) and create the train/valid datasets and loaders.

    Each epoch trains once. Evaluation on the validation split runs every
    ``train_cfg.eval_interval`` epochs (values <= 0 mean "final epoch only")
    and always on the final epoch.

    Checkpointing: ``train_cfg.checkpoint_path`` is overwritten whenever a
    finite validation mAP improves. If no evaluation produced a finite mAP, the
    final weights are saved there instead, so inference always finds a
    checkpoint.

    Outputs: on the final evaluation, false-positive visuals, report and list
    are exported when configured. The evaluation history is written to
    ``output_dir/training_history.json``.

    Returns:
        The trained model (final-epoch weights, not necessarily the best
        checkpoint) and the list of per-evaluation history entries.
    """
    train_cfg.ensure_directories()
    set_seed(train_cfg.seed)

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model = build_model(dataset_cfg, train_cfg, device=device)

    if resume_from is not None and Path(resume_from).exists():
        LOGGER.info("Resuming model weights from %s", resume_from)
        state_dict = torch.load(resume_from, map_location=device)
        model.load_state_dict(state_dict)

    train_dataset = ElectricalComponentsDataset(
        root=dataset_cfg.base_dir,
        split=dataset_cfg.train_split,
        class_names=dataset_cfg.class_names,
        transform=AugmentationParams(
            mosaic_prob=train_cfg.mosaic_prob,
            mixup_prob=train_cfg.mixup_prob,
            mixup_alpha=train_cfg.mixup_alpha,
            scale_jitter_range=(train_cfg.scale_jitter_min, train_cfg.scale_jitter_max),
        ),
        use_augmentation=train_cfg.augmentation,
        exclude_stems=train_cfg.exclude_samples,
    )
    valid_dataset = ElectricalComponentsDataset(
        root=dataset_cfg.base_dir,
        split=dataset_cfg.valid_split,
        class_names=dataset_cfg.class_names,
        use_augmentation=False,
    )

    if train_cfg.exclude_samples:
        sample_preview = ", ".join(train_cfg.exclude_samples[:5])
        if len(train_cfg.exclude_samples) > 5:
            sample_preview += ", ..."
        LOGGER.info(
            "Excluding %d training samples: %s",
            len(train_cfg.exclude_samples),
            sample_preview,
        )

    train_loader = create_data_loaders(
        train_dataset,
        batch_size=train_cfg.batch_size,
        shuffle=True,
        num_workers=train_cfg.num_workers,
    )
    # Validation gets half the workers (at least one) when several are configured.
    if train_cfg.num_workers > 1:
        valid_workers = max(1, train_cfg.num_workers // 2)
    else:
        valid_workers = train_cfg.num_workers
    valid_loader = create_data_loaders(
        valid_dataset,
        batch_size=train_cfg.batch_size,
        shuffle=False,
        num_workers=valid_workers,
    )

    optimizer = AdamW(model.parameters(), lr=train_cfg.learning_rate, weight_decay=train_cfg.weight_decay)
    scaler = GradScaler(enabled=train_cfg.amp and device.type == "cuda")

    best_map = -float("inf")
    checkpoint_saved = False
    history: List[Dict[str, float]] = []

    for epoch in range(1, train_cfg.epochs + 1):
        LOGGER.info("Epoch %s/%s", epoch, train_cfg.epochs)
        train_loss = train_one_epoch(
            model, train_loader, optimizer, scaler, device, train_cfg.amp, train_cfg.log_every
        )

        is_final_epoch = epoch == train_cfg.epochs
        # Guarding the modulo avoids a ZeroDivisionError for eval_interval == 0.
        on_interval = train_cfg.eval_interval > 0 and epoch % train_cfg.eval_interval == 0
        if not (on_interval or is_final_epoch):
            LOGGER.info(
                "Epoch %02d | train loss %.4f | evaluation skipped (eval_interval=%d)",
                epoch,
                train_loss,
                train_cfg.eval_interval,
            )
            continue

        metrics, sample_details = evaluate(
            model,
            valid_loader,
            device,
            dataset_cfg,
            train_cfg,
            dataset=valid_dataset,
            collect_details=is_final_epoch,
        )
        metric_lines = format_epoch_metrics(epoch, train_loss, metrics, dataset_cfg)
        emit_metric_lines(metric_lines, logger=LOGGER)

        map_value = float(metrics["mAP"])
        history.append(
            {
                "epoch": epoch,
                "train_loss": float(train_loss),
                "val_loss": float(metrics["loss"]),
                "mAP": map_value,
            }
        )

        # NaN compares False against everything, so it could never trigger a
        # save; only finite scores compete for "best".
        if np.isfinite(map_value) and map_value > best_map:
            best_map = map_value
            save_checkpoint(model, train_cfg.checkpoint_path)
            checkpoint_saved = True

        if is_final_epoch:
            _export_false_positive_artifacts(valid_dataset, dataset_cfg, train_cfg, sample_details)

    if not checkpoint_saved:
        LOGGER.warning(
            "No finite validation mAP was recorded; saving final weights to %s",
            train_cfg.checkpoint_path,
        )
        save_checkpoint(model, train_cfg.checkpoint_path)

    (train_cfg.output_dir / "training_history.json").write_text(json.dumps(history, indent=2))
    LOGGER.info("Training complete. Best mAP: %.4f", best_map)

    return model, history






## Inference helpers


In [None]:
@torch.no_grad()
def run_inference(
    dataset_cfg: DatasetConfig,
    inference_cfg: InferenceConfig,
    train_cfg: TrainingConfig,
    checkpoint_path: Path,
    split: Optional[str] = None,
    fp_report_path: Optional[Path] = None,
    fp_list_path: Optional[Path] = None,
) -> Dict[str, object]:
    """Evaluate a saved checkpoint on one dataset split.

    Detections are handled as follows.
    - Filtered with ``inference_cfg``'s global and per-class score thresholds.
    - Matched against ground truth at ``train_cfg.iou_threshold``.
    - Summarised with ``compute_detection_metrics``.

    The first ``inference_cfg.max_images`` images are rendered to
    ``inference_cfg.output_dir`` as ``<split>_<idx:04d>.png``.

    Args:
        dataset_cfg: Dataset layout; ``test_split`` is used when ``split`` is None.
        inference_cfg: Score thresholds and visualisation settings.
        train_cfg: Supplies model-construction settings, the seed and the IoU threshold.
        checkpoint_path: State dict to load into the freshly built model.
        split: Optional split name overriding ``dataset_cfg.test_split``.
        fp_report_path: If given, write a JSON false-positive report there.
        fp_list_path: If given, write the ids of images with false positives there.

    Returns:
        The metrics dict, extended with ``false_positive_images`` (per-image FP
        records) and ``false_positive_stems`` (sorted ids of images with FPs).
    """
    inference_cfg.ensure_directories()
    set_seed(train_cfg.seed)
    split_name = split or dataset_cfg.test_split

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model = build_model(dataset_cfg, train_cfg, device=device)
    state_dict = torch.load(checkpoint_path, map_location=device)
    model.load_state_dict(state_dict)
    model.eval()

    dataset = ElectricalComponentsDataset(
        root=dataset_cfg.base_dir,
        split=split_name,
        class_names=dataset_cfg.class_names,
        use_augmentation=False,
    )
    # batch_size=1 and no shuffling keep ``idx`` aligned with ``dataset.image_stems``.
    loader = create_data_loaders(dataset, batch_size=1, shuffle=False, num_workers=0)

    predictions: List[Dict[str, np.ndarray]] = []
    targets_for_eval: List[Dict[str, np.ndarray]] = []
    fp_records: List[Dict[str, object]] = []

    progress = tqdm(loader, desc="Infer")
    for idx, (images, targets) in enumerate(progress):
        image = images[0].to(device)
        output = model([image])[0]

        boxes_np = output["boxes"].detach().cpu().numpy()
        scores_np = output["scores"].detach().cpu().numpy()
        labels_np = output["labels"].detach().cpu().numpy()
        keep = score_threshold_mask(
            scores_np,
            labels_np,
            inference_cfg.score_threshold,
            inference_cfg.class_score_thresholds,
        )
        prediction_np = {
            "boxes": boxes_np[keep],
            "scores": scores_np[keep],
            "labels": labels_np[keep],
        }
        target_np = {
            "boxes": targets[0]["boxes"].detach().cpu().numpy(),
            "labels": targets[0]["labels"].detach().cpu().numpy(),
        }

        predictions.append(prediction_np)
        targets_for_eval.append(target_np)

        fp_details = identify_false_positive_predictions(
            prediction_np,
            target_np,
            dataset_cfg.num_classes,
            train_cfg.iou_threshold,
        )
        if fp_details:
            if idx < len(dataset.image_stems):
                image_id = dataset.image_stems[idx]
            else:
                image_id = f"{split_name}_{idx:04d}"
            fp_records.append({"image_id": image_id, "false_positives": fp_details})

        if idx < inference_cfg.max_images:
            # Clip before the uint8 cast; out-of-range floats would otherwise wrap around.
            scaled = images[0].permute(1, 2, 0).numpy() * 255.0
            image_np = np.clip(scaled, 0.0, 255.0).astype(np.uint8)
            output_path = inference_cfg.output_dir / f"{split_name}_{idx:04d}.png"
            save_detection_visual(
                image_np,
                prediction_np,
                target_np if inference_cfg.draw_ground_truth else None,
                dataset_cfg.class_names,
                inference_cfg.score_threshold,
                inference_cfg.class_score_thresholds,
                inference_cfg.draw_ground_truth,
                output_path,
            )

    metrics = compute_detection_metrics(
        predictions, targets_for_eval, dataset_cfg.num_classes, train_cfg.iou_threshold
    )
    metrics["false_positive_images"] = fp_records
    metrics["false_positive_stems"] = sorted({record["image_id"] for record in fp_records})
    metric_lines = format_epoch_metrics(
        epoch=None,
        train_loss=None,
        metrics=metrics,
        dataset_cfg=dataset_cfg,
        header=f"Inference @ IoU {train_cfg.iou_threshold:.2f}",
    )
    emit_metric_lines(metric_lines, logger=LOGGER)

    if fp_report_path is not None:
        fp_report_path = Path(fp_report_path)
        write_false_positive_report(
            fp_records,
            fp_report_path,
            split=split_name,
            score_threshold=inference_cfg.score_threshold,
            class_score_thresholds=inference_cfg.class_score_thresholds,
            iou_threshold=train_cfg.iou_threshold,
        )
        LOGGER.info(
            "Wrote false-positive report for %d images to %s",
            len(fp_records),
            fp_report_path,
        )

    if fp_list_path is not None:
        fp_list_path = Path(fp_list_path)
        write_false_positive_list(fp_records, fp_list_path)
        LOGGER.info(
            "Wrote %d image ids with false positives to %s",
            len({str(record["image_id"]) for record in fp_records}),
            fp_list_path,
        )

    return metrics






## Example usage


In [None]:
# End-to-end example: train on the Kaggle dataset, then evaluate the saved
# checkpoint. ``run_inference`` uses ``dataset_cfg.test_split`` because no
# ``split`` argument is passed.
dataset_cfg = DatasetConfig(base_dir=Path('/kaggle/input/electrical-component/dataset'))
# 7 epochs, batch size 2, augmentation off; other TrainingConfig fields keep their defaults.
train_cfg = TrainingConfig(epochs=7, batch_size=2, augmentation=False)
# Global score threshold 0.6 for inference; ground-truth boxes are drawn on the saved visuals.
inference_cfg = InferenceConfig(score_threshold=0.6, draw_ground_truth=True)

model, history = train_pipeline(dataset_cfg, train_cfg)
# Loads the checkpoint that train_pipeline writes to ``train_cfg.checkpoint_path``
# whenever validation mAP improves.
metrics = run_inference(
    dataset_cfg,
    inference_cfg,
    train_cfg,
    checkpoint_path=train_cfg.checkpoint_path,
)



**提示：** `run_inference` 会在返回的 `metrics` 中附带 `false_positive_images` 和 `false_positive_stems`。如果想在下一次训练时忽略这些样本，可以将 `TrainingConfig.exclude_samples` 设为 `tuple(metrics["false_positive_stems"])`，或将 `metrics["false_positive_stems"]` 写入文本文件，然后通过命令行参数 `--exclude-list` 传入。