# Electrical Component Detection Pipeline

This notebook consolidates the refactored Faster R-CNN training and inference workflow into a single place for convenient experimentation on Kaggle. It provides reusable configuration objects, dataset loaders with optional augmentation, detailed metric utilities (including per-class TP/FP/FN and mAP), and helpers for both training and inference.


In [None]:
import argparse
import multiprocessing as mp
import os
import warnings
import json
import logging
import random
from dataclasses import dataclass, field
from pathlib import Path
from typing import Dict, Iterable, List, Optional, Sequence, Tuple

import numpy as np
import pandas as pd
import torch
from PIL import Image as PILImage, ImageDraw, ImageEnhance, ImageFont
from torch import Tensor, nn
from torch.cuda.amp import GradScaler, autocast
from torch.optim import AdamW
from torch.utils.data import DataLoader, Dataset
from tqdm import tqdm

from torchvision.models.detection import (
    FasterRCNN_ResNet50_FPN_V2_Weights,
    fasterrcnn_resnet50_fpn_v2,
)
from torchvision.models.detection.faster_rcnn import FastRCNNPredictor
from torchvision.models.detection.rpn import AnchorGenerator

# Configure the root logger once for the notebook: INFO level, with records
# rendered as "timestamp | level | message".
logging.basicConfig(level=logging.INFO, format="%(asctime)s | %(levelname)s | %(message)s")
# Logger used by the helper functions in the cells below.
LOGGER = logging.getLogger("notebook")


## Configuration objects


In [None]:
# URL of the COCO-pretrained Faster R-CNN (ResNet50-FPN v2) weight file; used as
# the default value of ``TrainingConfig.pretrained_weights_url``.
DEFAULT_PRETRAINED_URL = "https://download.pytorch.org/models/fasterrcnn_resnet50_fpn_v2_coco-dd69338a.pth"


@dataclass
class DatasetConfig:
    """Dataset location, split/folder names and class metadata.

    If ``class_names`` is left empty, placeholder names of the form
    ``class_00``, ``class_01``, ... are generated, one per ``num_classes``.
    """

    base_dir: Path = Path("data")
    train_split: str = "train"
    valid_split: str = "valid"
    test_split: str = "test"
    image_folder: str = "images"
    label_folder: str = "labels"
    num_classes: int = 32
    class_names: Tuple[str, ...] = ()

    def __post_init__(self) -> None:
        # Explicit names win; only synthesise placeholders when none were given.
        if self.class_names:
            return
        self.class_names = tuple(map("class_{:02d}".format, range(self.num_classes)))


@dataclass
class TrainingConfig:
    """Hyper-parameters and runtime settings for model training."""

    epochs: int = 20
    batch_size: int = 4
    learning_rate: float = 1e-4
    weight_decay: float = 1e-4
    # Requested DataLoader worker processes (0 keeps loading in the main process).
    num_workers: int = 0
    # Enable automatic mixed precision (autocast + GradScaler).
    amp: bool = True
    # Turn on dataset-level augmentation for the training split.
    augmentation: bool = True
    # Swap in the custom small-object anchor configuration in ``build_model``.
    small_object: bool = False
    # Predictions scoring below this are dropped before validation metrics.
    score_threshold: float = 0.05
    # IoU needed for a detection to count as a true positive.
    iou_threshold: float = 0.5
    # Validate every ``eval_interval`` epochs.
    eval_interval: int = 1
    seed: int = 2024
    output_dir: Path = Path("outputs")
    checkpoint_path: Path = Path("outputs/best_model.pth")
    pretrained_weights_path: Path = Path("weights/fasterrcnn_resnet50_fpn_v2_coco.pth")
    pretrained_weights_url: str = DEFAULT_PRETRAINED_URL
    # Refresh the progress-bar loss read-out every ``log_every`` steps.
    log_every: int = 20

    def ensure_directories(self) -> None:
        """Create every directory the training run writes into.

        Covers the output directory and the parent directories of both the
        checkpoint file and the cached pretrained weights, so a checkpoint path
        outside ``output_dir`` does not fail at the first save.
        """
        for directory in (
            self.output_dir,
            self.checkpoint_path.parent,
            self.pretrained_weights_path.parent,
        ):
            Path(directory).mkdir(parents=True, exist_ok=True)


@dataclass
class InferenceConfig:
    """Settings that control inference runs and the rendered visualisations."""

    score_threshold: float = 0.3
    max_images: int = 50
    output_dir: Path = Path("outputs/inference")
    draw_ground_truth: bool = True
    class_colors: List[str] = field(default_factory=list)

    def ensure_directories(self) -> None:
        """Create the visualisation output directory (and parents) if missing."""
        target_dir = self.output_dir
        target_dir.mkdir(exist_ok=True, parents=True)


## Dataset loading and augmentation


In [None]:
# Rebinds the module-level LOGGER to a logger named after the current module,
# replacing the "notebook" logger created in the imports cell.
LOGGER = logging.getLogger(__name__)


@dataclass
class AugmentationParams:
    """Parameters controlling the dataset level image augmentations."""

    # Probability of mirroring the image (and its boxes) left-right.
    horizontal_flip_prob: float = 0.5
    # Probability of mirroring the image (and its boxes) top-bottom.
    vertical_flip_prob: float = 0.2
    # Jitter strengths: each enhancement factor is drawn uniformly from
    # [1 - value, 1 + value]; a value of 0 disables that enhancement.
    brightness: float = 0.2
    contrast: float = 0.2
    saturation: float = 0.2
    # Hue shift as a fraction of the 0-255 hue channel; applied with a random sign.
    hue: float = 0.02


def load_image_hwc_uint8(path: Path) -> np.ndarray:
    """Load an ``.npy`` image and return it as a writable HWC ``uint8`` array.

    Accepted inputs are 2-D ``(H, W)`` grayscale arrays and 3-D ``(H, W, C)``
    arrays. Non-``uint8`` data is converted as follows: values within
    ``[0, 1]`` are scaled by 255, values within ``[-1, 1]`` are first shifted to
    ``[0, 1]``, and anything else is only clipped to ``[0, 255]``.
    Single-channel images are replicated to three channels and a fourth
    (alpha) channel is dropped; other channel counts are returned unchanged.

    Args:
        path: Location of the ``.npy`` file.

    Returns:
        An in-memory ``np.ndarray`` of dtype ``uint8`` (never a read-only
        memory map).

    Raises:
        ValueError: If the stored array is neither 2-D nor 3-D.
    """
    # Memory-map the file so only the bytes that are actually used get read.
    image = np.load(path, allow_pickle=False, mmap_mode="r")

    if image.ndim == 2:
        # Grayscale without a channel axis: add one so shape[2] is defined.
        image = image[..., np.newaxis]
    elif image.ndim != 3:
        raise ValueError(
            f"Expected a 2-D or 3-D image array in {path}, got shape {image.shape}"
        )

    if image.dtype != np.uint8:
        image = np.asarray(image, dtype=np.float32)
        vmin, vmax = float(image.min()), float(image.max())
        if 0.0 <= vmin and vmax <= 1.0:
            image = (image * 255.0).round()
        elif -1.0 <= vmin and vmax <= 1.0:
            image = ((image + 1.0) * 0.5 * 255.0).round()
        image = np.clip(image, 0, 255).astype(np.uint8)

    channels = image.shape[2]
    if channels == 1:
        image = np.repeat(image, 3, axis=2)
    elif channels == 4:
        image = image[..., :3]

    # Copy into a plain, contiguous, writable ndarray: the memory map is
    # read-only, which downstream consumers such as torch.from_numpy dislike.
    return np.array(image, dtype=np.uint8, order="C")


class ElectricalComponentsDataset(Dataset):
    """Detection dataset backed by ``.npy`` images and per-image CSV labels.

    Every ``<stem>.csv`` in the label directory describes one image stored as
    ``<stem>.npy`` in the image directory. The CSV must provide the columns
    ``x_center``, ``y_center``, ``width``, ``height`` and ``class``. When all
    box values of an image are ``<= 1`` they are treated as normalised and
    scaled to pixels.

    NOTE(review): the ``class`` column is forwarded unchanged as ``labels``.
    ``build_model`` reserves one extra output for a background class
    (presumably label 0, as in torchvision) -- confirm that the CSV class ids
    follow the convention the model and the metric helpers expect.

    Args:
        root: Dataset root containing one sub-directory per split.
        split: Name of the split sub-directory.
        class_names: Human-readable class names (stored, not used for loading).
        transform: Augmentation strengths; defaults to ``AugmentationParams()``.
        use_augmentation: Apply random flips / colour jitter in ``__getitem__``.
        image_folder: Sub-directory of the split holding the ``.npy`` images.
        label_folder: Sub-directory of the split holding the ``.csv`` labels.

    Raises:
        FileNotFoundError: If the image or label directory does not exist.
        RuntimeError: If the label directory contains no ``.csv`` files.
    """

    def __init__(
        self,
        root: Path,
        split: str,
        class_names: Iterable[str],
        transform: Optional["AugmentationParams"] = None,
        use_augmentation: bool = False,
        image_folder: str = "images",
        label_folder: str = "labels",
    ) -> None:
        self.root = Path(root)
        self.split = split
        self.class_names = list(class_names)
        self.transform_params = transform or AugmentationParams()
        self.use_augmentation = use_augmentation

        self.image_dir = self.root / split / image_folder
        self.label_dir = self.root / split / label_folder

        if not self.image_dir.exists():
            raise FileNotFoundError(f"Missing image directory: {self.image_dir}")
        if not self.label_dir.exists():
            raise FileNotFoundError(f"Missing label directory: {self.label_dir}")

        self.image_stems = sorted(p.stem for p in self.label_dir.glob("*.csv"))
        if not self.image_stems:
            raise RuntimeError(f"No label files found in {self.label_dir}")

        # Pre-load all annotations to reduce I/O during training.
        self.annotations: Dict[str, pd.DataFrame] = {
            stem: pd.read_csv(self.label_dir / f"{stem}.csv") for stem in self.image_stems
        }

    def __len__(self) -> int:
        """Return the number of labelled images in the split."""
        return len(self.image_stems)

    def __getitem__(self, index: int) -> Tuple[torch.Tensor, Dict[str, torch.Tensor]]:
        """Return ``(image, target)`` for the sample at ``index``.

        The image is a float CHW tensor scaled to ``[0, 1]``. The target holds
        ``boxes`` (xyxy, pixels), ``labels``, ``image_id``, ``area``,
        ``iscrowd`` and ``orig_size`` (height, width).
        """
        stem = self.image_stems[index]
        image = load_image_hwc_uint8(self.image_dir / f"{stem}.npy")
        height, width = image.shape[:2]

        ann = self.annotations[stem]
        # copy=True guarantees private buffers, so nothing below can leak back
        # into the cached annotation frame.
        x_center = ann["x_center"].to_numpy(dtype=np.float32, copy=True)
        y_center = ann["y_center"].to_numpy(dtype=np.float32, copy=True)
        box_width = ann["width"].to_numpy(dtype=np.float32, copy=True)
        box_height = ann["height"].to_numpy(dtype=np.float32, copy=True)

        # Auto-detect normalised coordinates and scale back to pixel space.
        is_normalised = all(
            values.size == 0 or float(values.max()) <= 1.0
            for values in (x_center, y_center, box_width, box_height)
        )
        if is_normalised:
            x_center = x_center * width
            y_center = y_center * height
            box_width = box_width * width
            box_height = box_height * height

        half_w = box_width / 2.0
        half_h = box_height / 2.0
        boxes = np.stack(
            [x_center - half_w, y_center - half_h, x_center + half_w, y_center + half_h],
            axis=1,
        )
        labels = ann["class"].to_numpy(dtype=np.int64)

        if self.use_augmentation and len(boxes):
            image, boxes = self._apply_augmentations(image, boxes)
            height, width = image.shape[:2]

        image_tensor = torch.from_numpy(image).permute(2, 0, 1).float() / 255.0
        boxes_tensor = torch.from_numpy(boxes).float()
        labels_tensor = torch.from_numpy(labels).long()

        # Clip to the image and drop degenerate boxes.
        boxes_tensor, labels_tensor = sanitize_boxes_and_labels(
            boxes_tensor, labels_tensor, height, width
        )

        if boxes_tensor.numel():
            area = (boxes_tensor[:, 2] - boxes_tensor[:, 0]) * (
                boxes_tensor[:, 3] - boxes_tensor[:, 1]
            )
        else:
            area = torch.tensor([], dtype=torch.float32)

        target: Dict[str, torch.Tensor] = {
            "boxes": boxes_tensor,
            "labels": labels_tensor,
            "image_id": torch.tensor(index, dtype=torch.int64),
            "area": area,
            "iscrowd": torch.zeros((boxes_tensor.shape[0],), dtype=torch.int64),
            "orig_size": torch.tensor([height, width], dtype=torch.int64),
        }

        return image_tensor, target

    @staticmethod
    def _shift_hue(hsv: np.ndarray, delta: int) -> np.ndarray:
        """Return a copy of ``hsv`` whose hue channel is rotated by ``delta``.

        The hue channel is an 8-bit value, so the rotation wraps modulo 256.
        """
        shifted = hsv.copy()
        shifted[..., 0] = (hsv[..., 0].astype(int) + delta) % 256
        return shifted

    def _apply_augmentations(
        self, image: np.ndarray, boxes: np.ndarray
    ) -> Tuple[np.ndarray, np.ndarray]:
        """Apply random flips and colour jitter; return the new image and boxes.

        ``boxes`` is expected as an ``(N, 4)`` xyxy array in pixels. The input
        array is left untouched; a transformed copy is returned.
        """
        params = self.transform_params
        height, width = image.shape[:2]
        boxes = boxes.copy()

        if random.random() < params.horizontal_flip_prob:
            image = np.ascontiguousarray(image[:, ::-1, :])
            # Mirroring swaps the roles of x1 and x2.
            boxes[:, [0, 2]] = width - boxes[:, [2, 0]]

        if random.random() < params.vertical_flip_prob:
            image = np.ascontiguousarray(image[::-1, :, :])
            # Mirroring swaps the roles of y1 and y2.
            boxes[:, [1, 3]] = height - boxes[:, [3, 1]]

        if params.brightness or params.contrast or params.saturation or params.hue:
            pil = PILImage.fromarray(image)
            jitter_steps = (
                (ImageEnhance.Brightness, params.brightness),
                (ImageEnhance.Contrast, params.contrast),
                (ImageEnhance.Color, params.saturation),
            )
            for enhancer_cls, strength in jitter_steps:
                if strength:
                    factor = 1.0 + random.uniform(-strength, strength)
                    # Floor the factor so the image never collapses to black/grey.
                    pil = enhancer_cls(pil).enhance(max(0.1, factor))
            if params.hue:
                # Hue adjustment via simple conversion to HSV.
                hsv = np.array(pil.convert("HSV"), dtype=np.uint8)
                delta = int(params.hue * 255.0 * random.choice([-1, 1]))
                pil = PILImage.fromarray(self._shift_hue(hsv, delta), mode="HSV").convert("RGB")
            image = np.array(pil)

        boxes[:, [0, 2]] = boxes[:, [0, 2]].clip(0, width)
        boxes[:, [1, 3]] = boxes[:, [1, 3]].clip(0, height)
        return image, boxes


def detection_collate(batch: List[Tuple[torch.Tensor, Dict[str, torch.Tensor]]]):
    """Split a batch of ``(image, target)`` pairs into two parallel lists.

    Images are kept as a list rather than stacked, so samples may differ in size.
    """
    image_list, target_list = map(list, zip(*batch))
    return image_list, target_list


def _safe_worker_count(requested: int) -> int:
    cpu_count = os.cpu_count() or 1
    if requested <= 0:
        return 0
    max_workers = max(1, cpu_count - 1)
    return min(requested, max_workers)


def create_data_loaders(
    dataset: Dataset,
    batch_size: int,
    shuffle: bool,
    num_workers: int,
) -> DataLoader:
    """Build a detection ``DataLoader`` with notebook friendly defaults.

    The requested worker count is clamped by ``_safe_worker_count``. With
    workers enabled, persistent workers and a ``spawn`` multiprocessing context
    are used; if constructing that loader fails, a single-process loader is
    returned instead and a warning is issued.
    """
    n_workers = _safe_worker_count(num_workers)
    use_workers = n_workers > 0

    options = {
        "dataset": dataset,
        "batch_size": batch_size,
        "shuffle": shuffle,
        "pin_memory": torch.cuda.is_available(),
        "collate_fn": detection_collate,
    }
    if use_workers:
        options.update(
            num_workers=n_workers,
            persistent_workers=True,
            multiprocessing_context=mp.get_context("spawn"),
        )
    else:
        options["num_workers"] = 0

    try:
        return DataLoader(**options)
    except (RuntimeError, OSError, AssertionError) as exc:
        # Nothing to fall back to when we were already single-process.
        if not use_workers:
            raise
        warnings.warn(
            "Falling back to num_workers=0 because DataLoader worker initialisation "
            f"failed with: {exc}",
            RuntimeWarning,
        )
        LOGGER.warning("DataLoader workers failed to start (%s). Using num_workers=0 instead.", exc)
        for worker_only_key in ("persistent_workers", "multiprocessing_context"):
            options.pop(worker_only_key, None)
        options["num_workers"] = 0
        return DataLoader(**options)


## Utility helpers and metrics


In [None]:
def set_seed(seed: int) -> None:
    """Seed Python's ``random``, NumPy and PyTorch (CPU and, if present, CUDA)."""
    for seeder in (random.seed, np.random.seed, torch.manual_seed):
        seeder(seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed_all(seed)


def sanitize_boxes_and_labels(
    boxes: Tensor, labels: Tensor, height: int, width: int, min_size: float = 1.0
) -> Tuple[Tensor, Tensor]:
    """Clip xyxy boxes to the image and discard boxes that are too small.

    A box survives only if both its clipped width and height are strictly
    greater than ``min_size``. Labels are filtered alongside the boxes. The
    input tensors are not modified.
    """
    if boxes.numel() == 0:
        return boxes.reshape(0, 4).float(), labels.reshape(0).long()

    clipped = boxes.clone()
    x_columns = clipped[:, 0::2]
    y_columns = clipped[:, 1::2]
    clipped[:, 0::2] = x_columns.clamp(0, float(width))
    clipped[:, 1::2] = y_columns.clamp(0, float(height))

    wide_enough = (clipped[:, 2] - clipped[:, 0]) > min_size
    tall_enough = (clipped[:, 3] - clipped[:, 1]) > min_size
    survivors = wide_enough & tall_enough

    if not bool(survivors.any()):
        return clipped.new_zeros((0, 4)), labels.new_zeros((0,), dtype=torch.long)
    return clipped[survivors].float(), labels[survivors].long()


def compute_iou_matrix(boxes1: np.ndarray, boxes2: np.ndarray) -> np.ndarray:
    """Return the pairwise IoU between two sets of xyxy boxes.

    The result has one row per box in ``boxes1`` and one column per box in
    ``boxes2``; pairs with an empty union get an IoU of 0.
    """
    if boxes1.size == 0 or boxes2.size == 0:
        return np.zeros((boxes1.shape[0], boxes2.shape[0]), dtype=np.float32)

    # Column vectors for the first set, row vectors for the second, so that
    # every binary operation below broadcasts to an (N, M) grid.
    ax1, ay1, ax2, ay2 = (boxes1[:, i : i + 1] for i in range(4))
    bx1, by1, bx2, by2 = (boxes2[:, i : i + 1].T for i in range(4))

    overlap_w = np.clip(np.minimum(ax2, bx2) - np.maximum(ax1, bx1), a_min=0.0, a_max=None)
    overlap_h = np.clip(np.minimum(ay2, by2) - np.maximum(ay1, by1), a_min=0.0, a_max=None)
    intersection = overlap_w * overlap_h

    area_a = (ax2 - ax1) * (ay2 - ay1)
    area_b = (bx2 - bx1) * (by2 - by1)
    union = area_a + area_b - intersection

    return np.divide(intersection, union, out=np.zeros_like(intersection), where=union > 0)


def compute_average_precision(recalls: np.ndarray, precisions: np.ndarray) -> float:
    """Return the 101-point interpolated average precision of a PR curve.

    The precision curve is first replaced by its monotone envelope (each value
    becomes the best precision achieved at that recall or higher). It is then
    sampled at recall thresholds ``0.00, 0.01, ..., 1.00`` and averaged.

    Sampling uses the first curve point whose recall reaches the threshold,
    rather than linear interpolation followed by numerical integration. That
    keeps a perfect detector at exactly 1.0: the trailing ``(1.0, 0.0)``
    sentinel no longer overrides the real precision at recall 1. It also
    avoids ``np.trapz``, which is deprecated in NumPy 2.

    Args:
        recalls: Recall values ordered by decreasing detection score.
        precisions: Precision values aligned with ``recalls``.

    Returns:
        The average precision in ``[0, 1]``; 0.0 when either input is empty.
    """
    if recalls.size == 0 or precisions.size == 0:
        return 0.0

    # Sentinels anchor the curve at recall 0 and recall 1.
    mrec = np.concatenate(([0.0], recalls, [1.0]))
    mpre = np.concatenate(([0.0], precisions, [0.0]))

    # Monotone envelope: running maximum taken from the right.
    mpre = np.maximum.accumulate(mpre[::-1])[::-1]

    recall_points = np.linspace(0.0, 1.0, 101)
    # Index of the first curve point with recall >= threshold. The final
    # sentinel has recall 1.0, so every index stays in range.
    first_reaching = np.searchsorted(mrec, recall_points, side="left")
    return float(mpre[first_reaching].mean())


def accumulate_classification_stats(
    predictions: Sequence[Dict[str, np.ndarray]],
    targets: Sequence[Dict[str, np.ndarray]],
    num_classes: int,
    iou_threshold: float,
) -> Tuple[np.ndarray, np.ndarray, np.ndarray, List[List[float]], List[List[int]], np.ndarray]:
    """Match detections to ground truth per image and class.

    Within each image and class, detections are visited in order of decreasing
    score. A detection is a true positive when its highest-IoU ground-truth box
    reaches ``iou_threshold`` and has not been claimed yet; otherwise it is a
    false positive. Unclaimed ground-truth boxes count as false negatives.

    NOTE(review): class ids are used directly as indices into arrays of length
    ``num_classes``, so both predictions and targets must use ids in
    ``[0, num_classes)``. Confirm this matches the label convention of the
    model outputs (``build_model`` adds a background class).

    Args:
        predictions: Per-image dicts with ``boxes``, ``scores`` and ``labels``.
        targets: Per-image dicts with ``boxes`` and ``labels``.
        num_classes: Number of object classes.
        iou_threshold: Minimum IoU for a match.

    Returns:
        ``(tp, fp, fn, scores, matches, gt_counter)`` where ``tp``/``fp``/``fn``
        and ``gt_counter`` are per-class integer arrays, and ``scores`` /
        ``matches`` are per-class lists holding each detection's score and its
        1/0 true-positive flag.

    Raises:
        IndexError: If any class id lies outside ``[0, num_classes)``.
    """
    tp = np.zeros(num_classes, dtype=np.int64)
    fp = np.zeros(num_classes, dtype=np.int64)
    fn = np.zeros(num_classes, dtype=np.int64)
    scores: List[List[float]] = [[] for _ in range(num_classes)]
    matches: List[List[int]] = [[] for _ in range(num_classes)]
    gt_counter = np.zeros(num_classes, dtype=np.int64)

    for pred, tgt in zip(predictions, targets):
        pred_boxes = pred["boxes"]
        pred_scores = pred["scores"]
        pred_labels = pred["labels"].astype(np.int64)

        gt_boxes = tgt["boxes"]
        gt_labels = tgt["labels"].astype(np.int64)

        # np.unique returns sorted ids, so the ends give the min and max.
        present = np.unique(np.concatenate((pred_labels, gt_labels)))
        if present.size and (present[0] < 0 or present[-1] >= num_classes):
            # Negative ids would otherwise wrap around silently and corrupt the
            # counts of another class.
            raise IndexError(
                f"Class ids must lie in [0, {num_classes}); "
                f"got ids between {int(present[0])} and {int(present[-1])}"
            )

        for cls in present:
            pred_mask = pred_labels == cls
            pb = pred_boxes[pred_mask]
            ps = pred_scores[pred_mask]
            tb = gt_boxes[gt_labels == cls]
            gt_counter[cls] += len(tb)

            if len(tb) == 0:
                # No ground truth of this class here: every detection is a FP.
                fp[cls] += len(pb)
                scores[cls].extend(ps.tolist())
                matches[cls].extend([0] * len(pb))
                continue

            order = np.argsort(-ps)
            pb = pb[order]
            ps = ps[order]
            iou_matrix = compute_iou_matrix(pb, tb)

            matched_gt: set = set()
            for det_idx, score in enumerate(ps):
                best_gt = int(np.argmax(iou_matrix[det_idx]))
                is_match = (
                    iou_matrix[det_idx, best_gt] >= iou_threshold
                    and best_gt not in matched_gt
                )
                if is_match:
                    tp[cls] += 1
                    matched_gt.add(best_gt)
                else:
                    fp[cls] += 1
                scores[cls].append(float(score))
                matches[cls].append(int(is_match))

            fn[cls] += len(tb) - len(matched_gt)

    return tp, fp, fn, scores, matches, gt_counter


def compute_detection_metrics(
    predictions: Sequence[Dict[str, np.ndarray]],
    targets: Sequence[Dict[str, np.ndarray]],
    num_classes: int,
    iou_threshold: float,
) -> Dict[str, np.ndarray]:
    """Summarise detection quality per class and overall.

    Returns a dict with per-class ``TP``, ``FP``, ``FN``, ``precision``,
    ``recall``, ``AP`` and ``gt_counter`` arrays plus the scalar ``mAP``.
    ``AP`` is NaN for classes without ground truth; ``mAP`` averages the
    finite AP values only (0.0 if there are none).
    """
    tp, fp, fn, scores, matches, gt_counter = accumulate_classification_stats(
        predictions, targets, num_classes, iou_threshold
    )

    # Denominators are floored at 1 so empty classes yield 0 instead of NaN.
    precision = np.divide(tp, np.clip(tp + fp, a_min=1, a_max=None))
    recall = np.divide(tp, np.clip(tp + fn, a_min=1, a_max=None))

    ap = np.zeros(num_classes, dtype=np.float32)
    for cls in range(num_classes):
        n_gt = gt_counter[cls]
        class_scores = scores[cls]

        if n_gt == 0:
            ap[cls] = np.nan
        elif not class_scores:
            ap[cls] = 0.0
        else:
            # Rank detections by confidence and trace out the PR curve.
            ranking = np.argsort(-np.asarray(class_scores))
            hits = np.asarray(matches[cls], dtype=np.int32)[ranking]
            tp_curve = np.cumsum(hits)
            fp_curve = np.cumsum(1 - hits)

            recall_curve = tp_curve / n_gt
            precision_curve = tp_curve / np.maximum(tp_curve + fp_curve, 1)
            ap[cls] = compute_average_precision(recall_curve, precision_curve)

    finite_ap = ap[np.isfinite(ap)]
    map_value = float(finite_ap.mean()) if finite_ap.size else 0.0

    return {
        "TP": tp,
        "FP": fp,
        "FN": fn,
        "precision": precision,
        "recall": recall,
        "AP": ap,
        "mAP": map_value,
        "gt_counter": gt_counter,
    }


class SmoothedValue:
    """Running average over the most recent ``window_size`` values."""

    def __init__(self, window_size: int = 20) -> None:
        self.window_size = window_size
        self.deque: List[float] = []
        self.total = 0.0
        self.count = 0

    def update(self, value: float) -> None:
        """Add ``value``, evicting the oldest entry once the window is full."""
        window = self.deque
        if len(window) == self.window_size:
            oldest = window.pop(0)
            self.total -= oldest
        window.append(value)
        self.total += value
        self.count += 1

    @property
    def avg(self) -> float:
        """Mean of the values currently in the window (0.0 when empty)."""
        held = len(self.deque)
        return self.total / held if held else 0.0


class MetricLogger:
    """Collection of named ``SmoothedValue`` meters with a one-line summary."""

    def __init__(self) -> None:
        self.meters: Dict[str, SmoothedValue] = {}

    def update(self, **kwargs: float) -> None:
        """Feed each keyword value into its meter, creating meters on demand."""
        for name, value in kwargs.items():
            meter = self.meters.get(name)
            if meter is None:
                meter = SmoothedValue()
                self.meters[name] = meter
            meter.update(float(value))

    def format(self) -> str:
        """Render all meters as ``name: avg`` pairs joined by `` | ``."""
        return " | ".join(
            f"{name}: {meter.avg:.4f}" for name, meter in self.meters.items()
        )


## Model construction


In [None]:
def _save_state_dict(model: nn.Module, path: Path) -> None:
    """Best-effort save of ``model``'s state dict; failures are only logged."""
    try:
        torch.save(model.state_dict(), path)
    except Exception as exc:
        # Caching the weights is optional, so never let it abort the caller.
        LOGGER.warning("Unable to save pretrained weights: %s", exc)
    else:
        LOGGER.info("Saved pretrained weights to %s", path)


def _load_pretrained_model(train_cfg: TrainingConfig) -> nn.Module:
    """Return a COCO-pretrained Faster R-CNN (ResNet50-FPN v2).

    The weights are first requested through torchvision. On success they are
    also cached at ``train_cfg.pretrained_weights_path`` if no file exists
    there yet. If torchvision cannot provide them, the cached file is loaded
    instead. ``train_cfg.pretrained_weights_url`` is not consulted here.

    Raises:
        RuntimeError: If torchvision fails and no cached weights file exists;
            the original failure is attached as the cause.
    """
    pretrained_path = train_cfg.pretrained_weights_path
    pretrained_path.parent.mkdir(parents=True, exist_ok=True)

    try:
        # Only model construction / weight download is guarded, so unrelated
        # errors further down are not mistaken for a download failure.
        model = fasterrcnn_resnet50_fpn_v2(weights=FasterRCNN_ResNet50_FPN_V2_Weights.DEFAULT)
    except Exception as exc:
        LOGGER.warning(
            "Falling back to locally saved pretrained detector weights (%s)", exc
        )
        if not pretrained_path.exists():
            raise RuntimeError(
                "No pretrained weights available. Download them manually and place them at "
                + str(pretrained_path)
            ) from exc
        # NOTE(review): torch.load unpickles the file; only point
        # pretrained_weights_path at trusted checkpoints.
        state_dict = torch.load(pretrained_path, map_location="cpu")
        model = fasterrcnn_resnet50_fpn_v2(weights=None)
        model.load_state_dict(state_dict)
        return model

    LOGGER.info("Loaded torchvision Faster R-CNN weights")
    if not pretrained_path.exists():
        _save_state_dict(model, pretrained_path)
    return model


def build_model(
    dataset_cfg: DatasetConfig,
    train_cfg: TrainingConfig,
    device: Optional[torch.device] = None,
) -> nn.Module:
    """Build the detector: pretrained Faster R-CNN with a new box predictor.

    The box predictor is replaced by one with ``dataset_cfg.num_classes + 1``
    outputs (the extra one being the background class). With
    ``train_cfg.small_object`` enabled, the RPN uses five anchor sizes and
    three aspect ratios on each of the five feature maps. That yields 15
    anchors per location, so the RPN head is rebuilt to match; its pretrained
    weights are therefore discarded in this mode.

    Args:
        dataset_cfg: Supplies the number of object classes.
        train_cfg: Supplies the pretrained-weight settings and ``small_object``.
        device: Target device; defaults to CUDA when available, else CPU.

    Returns:
        The model, moved to ``device``.
    """
    device = device or torch.device("cuda" if torch.cuda.is_available() else "cpu")

    model = _load_pretrained_model(train_cfg)

    in_features = model.roi_heads.box_predictor.cls_score.in_features
    num_classes_with_background = dataset_cfg.num_classes + 1
    model.roi_heads.box_predictor = FastRCNNPredictor(
        in_features, num_classes_with_background
    )

    if train_cfg.small_object:
        from torchvision.models.detection.rpn import RPNHead

        anchor_generator = AnchorGenerator(
            sizes=((16, 32, 64, 128, 256),) * 5,
            aspect_ratios=((0.5, 1.0, 2.0),) * 5,
        )
        model.rpn.anchor_generator = anchor_generator
        # The pretrained head predicts 3 anchors per location. The generator
        # above produces 15, so the head must be rebuilt or the RPN outputs and
        # the anchors disagree in shape. conv_depth=2 mirrors the v2 builder.
        model.rpn.head = RPNHead(
            model.backbone.out_channels,
            anchor_generator.num_anchors_per_location()[0],
            conv_depth=2,
        )
        LOGGER.info("Using custom anchor sizes optimised for small objects")

    model.to(device)
    return model


## Training utilities


In [None]:
def move_to_device(targets: List[Dict[str, torch.Tensor]], device: torch.device) -> List[Dict[str, torch.Tensor]]:
    """Return copies of the target dicts with every tensor moved to ``device``."""
    moved: List[Dict[str, torch.Tensor]] = []
    for target in targets:
        moved.append({key: tensor.to(device) for key, tensor in target.items()})
    return moved


def train_one_epoch(
    model: nn.Module,
    loader: DataLoader,
    optimizer: torch.optim.Optimizer,
    scaler: GradScaler,
    device: torch.device,
    amp: bool,
    log_every: int,
) -> float:
    """Run one optimisation pass over ``loader``.

    Steps whose summed loss is not finite are skipped entirely (no backward
    pass, no optimiser step).

    Args:
        model: Detection model returning a dict of losses in training mode.
        loader: Yields ``(images, targets)`` batches.
        optimizer: Optimiser updating ``model``.
        scaler: Gradient scaler used for mixed-precision training.
        device: Device the batch is moved to.
        amp: Whether to run the forward pass under autocast.
        log_every: Refresh the progress-bar loss every this many steps; values
            ``<= 0`` disable the refresh.

    Returns:
        The smoothed loss: the average over the logger's most recent window of
        steps, not the whole epoch. 0.0 if no step was recorded.
    """
    model.train()
    metric_logger = MetricLogger()
    progress = tqdm(loader, desc="Train", leave=False)

    for step, (images, targets) in enumerate(progress, start=1):
        images = [img.to(device) for img in images]
        targets = move_to_device(targets, device)

        optimizer.zero_grad()
        with autocast(enabled=amp):
            loss_dict = model(images, targets)
            loss = sum(loss_dict.values())

        if not torch.isfinite(loss):
            # Nothing was scaled or stepped, so the scaler must not be updated:
            # an enabled GradScaler asserts when update() has no recorded inf
            # checks to work with.
            LOGGER.warning("Skipping step %s due to non-finite loss", step)
            continue

        scaler.scale(loss).backward()
        scaler.step(optimizer)
        scaler.update()

        metric_logger.update(loss=loss.item())
        if log_every > 0 and step % log_every == 0:
            progress.set_postfix_str(metric_logger.format())

    loss_meter = metric_logger.meters.get("loss")
    return loss_meter.avg if loss_meter is not None else 0.0


@torch.no_grad()
def evaluate(
    model: nn.Module,
    loader: DataLoader,
    device: torch.device,
    dataset_cfg: DatasetConfig,
    train_cfg: TrainingConfig,
) -> Dict[str, torch.Tensor | float | List[float]]:
    """Compute validation loss and detection metrics over ``loader``.

    Each batch is run twice: once in training mode (with targets) to obtain the
    loss dict, and once in eval mode to obtain detections. Detections scoring
    below ``train_cfg.score_threshold`` are dropped before metrics are
    computed. The model's train/eval mode is restored on exit, including when
    an exception occurs.

    Returns:
        The dict from ``compute_detection_metrics`` plus ``"loss"``, the mean
        summed loss per batch (0.0 for an empty loader).
    """
    was_training = model.training

    predictions = []
    targets_for_eval = []
    total_loss = 0.0
    num_batches = 0

    try:
        for images, targets in loader:
            images = [img.to(device) for img in images]
            targets_device = move_to_device(targets, device)

            # The model only returns losses in training mode.
            model.train()
            loss_dict = model(images, targets_device)
            total_loss += sum(loss_dict.values()).item()
            num_batches += 1

            model.eval()
            outputs = model(images)

            # Use the targets as delivered by the loader for NumPy conversion
            # instead of round-tripping the device copies.
            for output, target in zip(outputs, targets):
                scores = output["scores"].detach().cpu().numpy()
                keep = scores >= train_cfg.score_threshold
                predictions.append(
                    {
                        "boxes": output["boxes"].detach().cpu().numpy()[keep],
                        "scores": scores[keep],
                        "labels": output["labels"].detach().cpu().numpy()[keep],
                    }
                )
                targets_for_eval.append(
                    {
                        "boxes": target["boxes"].detach().cpu().numpy(),
                        "labels": target["labels"].detach().cpu().numpy(),
                    }
                )
    finally:
        # Leave the model in the mode the caller had it in.
        model.train(was_training)

    metrics = compute_detection_metrics(
        predictions, targets_for_eval, dataset_cfg.num_classes, train_cfg.iou_threshold
    )
    metrics["loss"] = total_loss / max(num_batches, 1)
    return metrics


def save_checkpoint(model: nn.Module, path: Path) -> None:
    """Write ``model``'s state dict to ``path``, creating parent directories."""
    path = Path(path)
    # A checkpoint path outside the prepared output directory would otherwise
    # make torch.save fail after training time has already been spent.
    path.parent.mkdir(parents=True, exist_ok=True)
    torch.save(model.state_dict(), path)
    LOGGER.info("Saved checkpoint to %s", path)


def train_pipeline(
    dataset_cfg: DatasetConfig,
    train_cfg: TrainingConfig,
    *,
    resume_from: Optional[Path] = None,
) -> Tuple[nn.Module, List[Dict[str, float]]]:
    """Train the detector, validate periodically and keep the best checkpoint.

    The run seeds the RNGs, builds the model (optionally restoring weights from
    ``resume_from``), trains with AdamW for ``train_cfg.epochs`` epochs and
    validates every ``train_cfg.eval_interval`` epochs. ``eval_interval == 0``
    disables validation. Whenever the validation mAP improves, the model is
    saved to ``train_cfg.checkpoint_path``. The per-evaluation history is
    written to ``training_history.json`` in the output directory.

    Args:
        dataset_cfg: Dataset location, split names and class metadata.
        train_cfg: Hyper-parameters and runtime settings.
        resume_from: Optional state-dict file to initialise the model from; a
            non-existent path is reported and ignored.

    Returns:
        ``(model, history)`` where ``history`` holds one dict per evaluation
        with ``epoch``, ``train_loss``, ``val_loss`` and ``mAP``.
    """
    train_cfg.ensure_directories()
    set_seed(train_cfg.seed)

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model = build_model(dataset_cfg, train_cfg, device=device)

    if resume_from is not None:
        if Path(resume_from).exists():
            LOGGER.info("Resuming model weights from %s", resume_from)
            state_dict = torch.load(resume_from, map_location=device)
            model.load_state_dict(state_dict)
        else:
            LOGGER.warning(
                "Resume checkpoint %s not found; starting from pretrained weights", resume_from
            )

    train_dataset = ElectricalComponentsDataset(
        root=dataset_cfg.base_dir,
        split=dataset_cfg.train_split,
        class_names=dataset_cfg.class_names,
        use_augmentation=train_cfg.augmentation,
    )
    valid_dataset = ElectricalComponentsDataset(
        root=dataset_cfg.base_dir,
        split=dataset_cfg.valid_split,
        class_names=dataset_cfg.class_names,
        use_augmentation=False,
    )

    # Validation gets half the training workers, but respects num_workers == 0
    # (single-process loading) instead of forcing a worker process.
    if train_cfg.num_workers > 0:
        valid_workers = max(1, train_cfg.num_workers // 2)
    else:
        valid_workers = 0

    train_loader = create_data_loaders(
        train_dataset,
        batch_size=train_cfg.batch_size,
        shuffle=True,
        num_workers=train_cfg.num_workers,
    )
    valid_loader = create_data_loaders(
        valid_dataset,
        batch_size=train_cfg.batch_size,
        shuffle=False,
        num_workers=valid_workers,
    )

    optimizer = AdamW(model.parameters(), lr=train_cfg.learning_rate, weight_decay=train_cfg.weight_decay)
    # Mixed precision only applies on CUDA; use the same switch for autocast and
    # the scaler so CPU runs do not request CUDA autocast.
    use_amp = train_cfg.amp and device.type == "cuda"
    scaler = GradScaler(enabled=use_amp)

    best_map = -float("inf")
    history: List[Dict[str, float]] = []

    for epoch in range(1, train_cfg.epochs + 1):
        LOGGER.info("Epoch %s/%s", epoch, train_cfg.epochs)
        train_loss = train_one_epoch(
            model, train_loader, optimizer, scaler, device, use_amp, train_cfg.log_every
        )
        LOGGER.info("Training loss: %.4f", train_loss)

        # A zero interval disables evaluation instead of dividing by zero.
        if not train_cfg.eval_interval or epoch % train_cfg.eval_interval != 0:
            continue

        metrics = evaluate(model, valid_loader, device, dataset_cfg, train_cfg)
        LOGGER.info(
            "Validation loss: %.4f | mAP@%.2f: %.4f",
            metrics["loss"],
            train_cfg.iou_threshold,
            metrics["mAP"],
        )

        for cls_idx, (tp, fp, fn, ap) in enumerate(
            zip(metrics["TP"], metrics["FP"], metrics["FN"], metrics["AP"])
        ):
            class_name = dataset_cfg.class_names[cls_idx]
            LOGGER.info(
                "Class %-15s | TP: %3d FP: %3d FN: %3d | Precision: %.3f Recall: %.3f | AP: %s",
                class_name,
                int(tp),
                int(fp),
                int(fn),
                float(metrics["precision"][cls_idx]),
                float(metrics["recall"][cls_idx]),
                "nan" if np.isnan(ap) else f"{ap:.3f}",
            )

        history.append(
            {
                "epoch": epoch,
                "train_loss": float(train_loss),
                "val_loss": float(metrics["loss"]),
                "mAP": float(metrics["mAP"]),
            }
        )

        if metrics["mAP"] > best_map:
            best_map = float(metrics["mAP"])
            save_checkpoint(model, train_cfg.checkpoint_path)

    (train_cfg.output_dir / "training_history.json").write_text(json.dumps(history, indent=2))
    LOGGER.info("Training complete. Best mAP: %.4f", best_map)

    return model, history


## Inference helpers


In [None]:
# Hex colour palette for predicted boxes; ``draw_boxes`` picks the entry at
# ``label % len(DEFAULT_COLORS)``, so colours repeat beyond ten classes.
DEFAULT_COLORS = [
    "#FF6B6B",
    "#4ECDC4",
    "#556270",
    "#C44D58",
    "#FFB347",
    "#6B5B95",
    "#88B04B",
    "#92A8D1",
    "#955251",
    "#B565A7",
]


def load_font() -> ImageFont.FreeTypeFont | ImageFont.ImageFont:
    """Return DejaVu Sans at size 14, or Pillow's default font if unavailable."""
    try:
        font = ImageFont.truetype("DejaVuSans.ttf", size=14)
    except Exception:
        font = ImageFont.load_default()
    return font


def draw_boxes(
    image: np.ndarray,
    prediction: Dict[str, np.ndarray],
    target: Dict[str, np.ndarray] | None,
    class_names: List[str],
    score_threshold: float,
    draw_ground_truth: bool,
    output_path: Path,
) -> None:
    """Render predictions (and optionally ground truth) onto ``image`` and save.

    Predictions scoring below ``score_threshold`` are skipped. Predicted boxes
    are outlined in a per-class colour with a ``name score`` caption above the
    box. Ground-truth boxes are outlined in white with a ``GT name`` caption
    below. Class ids without an entry in ``class_names`` are captioned
    ``id_<n>`` instead of raising.

    Args:
        image: Image array accepted by ``PIL.Image.fromarray``.
        prediction: Dict with ``boxes`` (xyxy), ``labels`` and ``scores``.
        target: Dict with ``boxes`` and ``labels``, or ``None``.
        class_names: Names indexed by class id.
        score_threshold: Minimum score for a prediction to be drawn.
        draw_ground_truth: Whether to draw ``target`` when it is given.
        output_path: Destination file; parent directories are created.
    """
    pil = PILImage.fromarray(image)
    draw = ImageDraw.Draw(pil)
    font = load_font()

    def _name_for(label: int) -> str:
        # Guard the lookup: ids outside the list must not abort rendering, and
        # negative ids must not silently wrap around to another class.
        if 0 <= label < len(class_names):
            return class_names[label]
        return f"id_{label}"

    colors = DEFAULT_COLORS
    boxes = prediction["boxes"]
    labels = prediction["labels"].astype(int)
    scores = prediction["scores"]

    for box, label, score in zip(boxes, labels, scores):
        if score < score_threshold:
            continue
        color = colors[label % len(colors)]
        x1, y1, x2, y2 = box.tolist()
        draw.rectangle([x1, y1, x2, y2], outline=color, width=2)
        caption = f"{_name_for(label)} {score:.2f}"
        text_size = draw.textlength(caption, font=font)
        draw.rectangle([x1, y1 - 16, x1 + text_size + 8, y1], fill=color)
        draw.text((x1 + 4, y1 - 14), caption, fill="white", font=font)

    if draw_ground_truth and target is not None:
        gt_boxes = target["boxes"]
        gt_labels = target["labels"].astype(int)
        for box, label in zip(gt_boxes, gt_labels):
            color = "#FFFFFF"
            x1, y1, x2, y2 = box.tolist()
            draw.rectangle([x1, y1, x2, y2], outline=color, width=1)
            caption = f"GT {_name_for(label)}"
            text_size = draw.textlength(caption, font=font)
            draw.rectangle([x1, y2, x1 + text_size + 6, y2 + 14], fill=color)
            draw.text((x1 + 3, y2), caption, fill="black", font=font)

    output_path.parent.mkdir(parents=True, exist_ok=True)
    pil.save(output_path)


@torch.no_grad()
def run_inference(
    dataset_cfg: DatasetConfig,
    inference_cfg: InferenceConfig,
    train_cfg: TrainingConfig,
    checkpoint_path: Path,
    split: Optional[str] = None,
) -> Dict[str, np.ndarray]:
    """Run a trained model over a split, save visualisations, return metrics.

    The model is rebuilt via ``build_model`` and loaded from
    ``checkpoint_path``. Every image in the split is processed with batch size
    1. The first ``inference_cfg.max_images`` images are rendered with
    ``draw_boxes`` into ``inference_cfg.output_dir``. Metrics are computed from
    all raw detections; the score threshold only affects what is drawn.

    Args:
        dataset_cfg: Dataset location and class metadata.
        inference_cfg: Visualisation options.
        train_cfg: Supplies the seed, model options and the IoU threshold.
        checkpoint_path: State-dict file to load.
        split: Split to evaluate; defaults to ``dataset_cfg.test_split``.

    Returns:
        The metrics dict produced by ``compute_detection_metrics``.
    """
    inference_cfg.ensure_directories()
    set_seed(train_cfg.seed)

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model = build_model(dataset_cfg, train_cfg, device=device)
    state_dict = torch.load(checkpoint_path, map_location=device)
    model.load_state_dict(state_dict)
    model.eval()

    split_name = split or dataset_cfg.test_split
    dataset = ElectricalComponentsDataset(
        root=dataset_cfg.base_dir,
        split=split_name,
        class_names=dataset_cfg.class_names,
        use_augmentation=False,
    )
    loader = create_data_loaders(dataset, batch_size=1, shuffle=False, num_workers=0)

    predictions: List[Dict[str, np.ndarray]] = []
    targets_for_eval: List[Dict[str, np.ndarray]] = []

    progress = tqdm(loader, desc="Infer")
    for idx, (images, targets) in enumerate(progress):
        image = images[0].to(device)
        output = model([image])[0]

        prediction_np = {
            "boxes": output["boxes"].detach().cpu().numpy(),
            "scores": output["scores"].detach().cpu().numpy(),
            "labels": output["labels"].detach().cpu().numpy(),
        }
        target_np = {
            "boxes": targets[0]["boxes"].detach().cpu().numpy(),
            "labels": targets[0]["labels"].detach().cpu().numpy(),
        }

        predictions.append(prediction_np)
        targets_for_eval.append(target_np)

        if idx < inference_cfg.max_images:
            # Round before casting: the float round trip (x / 255 * 255) can
            # land just below an integer, and a bare cast would truncate it.
            scaled = images[0].permute(1, 2, 0).numpy() * 255.0
            image_np = np.clip(np.rint(scaled), 0, 255).astype(np.uint8)
            output_path = inference_cfg.output_dir / f"{split_name}_{idx:04d}.png"
            draw_boxes(
                image_np,
                prediction_np,
                target_np if inference_cfg.draw_ground_truth else None,
                dataset_cfg.class_names,
                inference_cfg.score_threshold,
                inference_cfg.draw_ground_truth,
                output_path,
            )

    metrics = compute_detection_metrics(
        predictions, targets_for_eval, dataset_cfg.num_classes, train_cfg.iou_threshold
    )
    LOGGER.info("mAP@%.2f: %.4f", train_cfg.iou_threshold, metrics["mAP"])
    return metrics


## Example usage


In [None]:
# Example configuration overrides (adjust paths to your Kaggle dataset structure)
# dataset_cfg = DatasetConfig(base_dir=Path('/kaggle/input/your-dataset'))
# train_cfg = TrainingConfig(epochs=10, batch_size=2, augmentation=True)
# inference_cfg = InferenceConfig(score_threshold=0.4, draw_ground_truth=True)

# model, history = train_pipeline(dataset_cfg, train_cfg)
# metrics = run_inference(
#     dataset_cfg,
#     inference_cfg,
#     train_cfg,
#     checkpoint_path=train_cfg.checkpoint_path,
# )

# Confirmation that the cells above have run; nothing is trained or evaluated
# until the commented-out example calls are enabled.
print(
    "Notebook utilities loaded. Configure DatasetConfig/TrainingConfig and call "
    "train_pipeline/run_inference as needed."
)
