# In [None]:  (Jupyter cell marker left over from notebook export)
# TEST: experimental version, to be used later for optimization work.
# yolo_with_seg.py
import cv2
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from typing import List, Tuple, Optional

from utils.general import non_max_suppression, scale_boxes


# ----------------------------
# YOLO-style letterbox (with ±0.1 rounding)
# ----------------------------
def _letterbox(im, new_shape=640, color=(114, 114, 114),
               auto=True, scaleFill=False, scaleup=True, stride=32):
    """
    Resize and pad an image to ``new_shape`` while keeping its aspect ratio.

    Odd paddings are split between the two sides with the ±0.1 rounding
    trick, so ``top``/``left`` can be one pixel smaller than ``bottom``/``right``.

    Args:
        im: HxW(xC) image array accepted by ``cv2.resize``.
        new_shape: target size; an int (square) or ``(height, width)``.
        color: border value passed to ``cv2.copyMakeBorder``.
        auto: pad only up to the next multiple of ``stride`` (minimum rectangle)
            instead of padding all the way to ``new_shape``.
        scaleFill: when ``auto`` is False, stretch to exactly ``new_shape``
            with no padding; width and height ratios may then differ.
        scaleup: if False, never enlarge the image (ratio capped at 1.0).
        stride: stride used by ``auto`` padding.

    Returns:
        ``(img, (ratio_w, ratio_h), (dw, dh))`` where dw/dh are HALF paddings
        (per-side amount before the ±0.1 rounding).
    """
    h, w = im.shape[:2]
    if isinstance(new_shape, int):
        new_shape = (new_shape, new_shape)

    r = min(new_shape[0] / h, new_shape[1] / w)  # uniform scale ratio
    if not scaleup:
        r = min(r, 1.0)
    ratio = (r, r)

    new_unpad = (int(round(w * r)), int(round(h * r)))  # (width, height), cv2 order
    dw, dh = new_shape[1] - new_unpad[0], new_shape[0] - new_unpad[1]  # remaining wh

    if auto:  # minimum rectangle
        dw, dh = np.mod(dw, stride), np.mod(dh, stride)
    elif scaleFill:  # stretch
        dw, dh = 0.0, 0.0
        new_unpad = (new_shape[1], new_shape[0])
        # Non-uniform stretch: report the per-axis ratios directly. Overwriting
        # ``r`` and returning ``(r, r)`` produced a nested ((rw, rh), (rw, rh)).
        ratio = (new_shape[1] / w, new_shape[0] / h)

    dw /= 2
    dh /= 2

    if (w, h) != new_unpad:
        im = cv2.resize(im, new_unpad, interpolation=cv2.INTER_LINEAR)

    # ±0.1 before rounding splits an odd total padding as (n, n+1).
    top, bottom = int(round(dh - 0.1)), int(round(dh + 0.1))
    left, right = int(round(dw - 0.1)), int(round(dw + 0.1))
    im = cv2.copyMakeBorder(im, top, bottom, left, right, cv2.BORDER_CONSTANT, value=color)
    return im, ratio, (dw, dh)


# ----------------------------
# Segmentation head (simple FPN)
# ----------------------------
class PyramidSegHead(nn.Module):
    """
    Small top-down FPN decoder producing dense segmentation logits.

    Forward inputs:
        feats: ``[P3, P4, P5]``. P5 is upsampled onto P4, and that result onto
            P3, so the list is expected from highest to lowest resolution.
        out_hw: ``(H, W)`` size of the returned logits.
    Forward output:
        Logits ``B x out_ch x H x W`` (no activation applied).

    Input channel counts are inferred on the first forward (``nn.LazyConv2d``).
    """

    def __init__(self, out_ch=1, fuse_ch=128):
        super().__init__()
        # Lateral 1x1 projections to ``fuse_ch`` channels. The names are kept
        # as-is (l3 <-> P5, l1 <-> P3) so saved state_dicts keep loading.
        self.l3 = nn.LazyConv2d(fuse_ch, 1)  # lateral for P5
        self.l2 = nn.LazyConv2d(fuse_ch, 1)  # lateral for P4
        self.l1 = nn.LazyConv2d(fuse_ch, 1)  # lateral for P3
        # 3x3 smoothing applied after each top-down merge.
        self.s4 = nn.Conv2d(fuse_ch, fuse_ch, 3, padding=1)
        self.s3 = nn.Conv2d(fuse_ch, fuse_ch, 3, padding=1)
        # Output tower: fuse_ch -> 64 -> 32 -> out_ch.
        self.o1 = nn.Conv2d(fuse_ch, 64, 3, padding=1)
        self.o2 = nn.Conv2d(64, 32, 3, padding=1)
        self.logits = nn.Conv2d(32, out_ch, 1)

    @staticmethod
    def _merge(lateral: torch.Tensor, coarser: torch.Tensor, smooth: nn.Module) -> torch.Tensor:
        """Add nearest-upsampled ``coarser`` onto ``lateral``, then smooth and ReLU."""
        upsampled = F.interpolate(coarser, size=lateral.shape[-2:], mode="nearest")
        return F.relu(smooth(lateral + upsampled), inplace=True)

    def forward(self, feats: List[torch.Tensor], out_hw: Tuple[int, int]):
        fine, mid, coarse = feats
        # Lateral layers are called P5 -> P4 -> P3 (the lazy-init order matters).
        top = self.l3(coarse)
        top = self._merge(self.l2(mid), top, self.s4)
        top = self._merge(self.l1(fine), top, self.s3)
        for conv in (self.o1, self.o2):
            top = F.relu(conv(top), inplace=True)
        scores = self.logits(top)
        return F.interpolate(scores, size=out_hw, mode="bilinear", align_corners=False)


# ----------------------------
# Shared encoder multitask model
# ----------------------------
class YoloV5WithSeg(nn.Module):
    """
    ONE model: YOLO backbone+neck+Detect (shared) + a segmentation head.

    Training:
        forward(images_01: torch.Tensor[B,3,H,W]) -> (pred_raw, seg_logits)
        - pred_raw is whatever the Detect module returns in its current mode.
        - seg_logits are *logits*; use BCEWithLogitsLoss or similar.

    Inference:
        detect(path_or_np, size=640) -> (det_xyxy[N,6], seg[H0,W0])
        - det boxes are mapped back to original image size
        - seg is a probability map (or logits with return_logits=True)
          in original image size
    """
    def __init__(self, yolo_variant="yolov5s", pretrained=True,
                 seg_out_ch=1, fuse_ch=128,
                 conf=0.25, iou=0.45, max_det=300, agnostic=False):
        super().__init__()
        # Loads the model through torch.hub (may download code/weights).
        yolo = torch.hub.load("ultralytics/yolov5", yolo_variant, pretrained=pretrained)
        if hasattr(yolo, "model"):  # unwrap AutoShape
            yolo = yolo.model

        self.yolo = yolo
        # NMS settings live on the wrapped model and are read back in detect().
        self.yolo.conf = conf
        self.yolo.iou = iou
        self.yolo.max_det = max_det
        self.yolo.agnostic = agnostic
        self.stride = self._max_stride(getattr(self.yolo, "stride", 32))
        self.yolo.eval()

        self.names = self._normalize_names(getattr(self.yolo, "names", None))

        # Detect is located by class name so no YOLOv5 import is needed here.
        self.detect_head = next(
            (m for m in self.yolo.modules() if m.__class__.__name__ == "Detect"), None
        )
        if self.detect_head is None:
            raise RuntimeError("Couldn't find YOLOv5 Detect module.")

        # segmentation head
        self.seg_head = PyramidSegHead(out_ch=seg_out_ch, fuse_ch=fuse_ch)

    # ----- small helpers -----
    @staticmethod
    def _max_stride(stride) -> int:
        """Collapse a scalar or per-level stride (int, sequence, tensor) to its max as int."""
        if isinstance(stride, torch.Tensor):
            return int(stride.max().item()) if stride.numel() else 32
        if isinstance(stride, (list, tuple)):
            return int(max(stride)) if stride else 32
        return int(stride)

    @staticmethod
    def _normalize_names(raw_names) -> List[str]:
        """Return class names as a list whether YOLO stores a dict or a sequence."""
        if isinstance(raw_names, dict):
            # Sorting keys (instead of indexing 0..len-1) tolerates non-contiguous ids.
            return [raw_names[k] for k in sorted(raw_names)]
        if isinstance(raw_names, (list, tuple)):
            return list(raw_names)
        return []

    # ----- collect outputs -----
    def _forward_collect(self, x: torch.Tensor):
        """
        Run the YOLO layer sequence by hand, keeping every layer's output.

        Each layer reads its input from ``m.f``: -1 = previous output, int =
        cached output of that layer index, list = several of those.

        Returns:
            ``(pred_raw, [p3, p4, p5])`` where the feature maps are the last
            three inputs fed to the Detect module.
        Raises:
            RuntimeError: if no layer sequence is found, or Detect / its three
                input feature maps could not be collected.
        """
        seq = getattr(self.yolo, "model", None)
        if seq is not None and hasattr(seq, "model"):  # unwrap one more level if present
            seq = seq.model
        if seq is None:
            raise RuntimeError("Unexpected YOLO structure")

        cache = []  # cache[i] = output of layer i
        z = x
        pred_raw = None
        feats = None

        for m in seq:
            f = getattr(m, "f", -1)
            if f == -1:
                u = z
            elif isinstance(f, int):
                u = cache[f]
            else:
                u = [z if j == -1 else cache[j] for j in f]

            if m is self.detect_head:
                # Snapshot Detect's inputs *before* calling it. YOLOv5's Detect
                # overwrites entries of its input list in place, and reading
                # cache[j] afterwards is wrong for j == -1 (Detect's own output).
                detect_inputs = list(u) if isinstance(u, (list, tuple)) else [u]
                z = m(u)
                pred_raw = z
                if len(detect_inputs) >= 3:
                    feats = detect_inputs[-3:]
            else:
                z = m(u)
            cache.append(z)

        if pred_raw is None or feats is None:
            raise RuntimeError("Failed to obtain Detect outputs or P3/P4/P5.")
        return pred_raw, feats

    # ----- training path -----
    def forward(self, images_01: torch.Tensor):
        """Return ``(pred_raw, seg_logits)``; seg logits match the input's H x W."""
        pred_raw, feats = self._forward_collect(images_01)
        _, _, H, W = images_01.shape
        seg_logits = self.seg_head(feats, out_hw=(H, W))
        return pred_raw, seg_logits

    # ----- eval convenience -----
    @staticmethod
    def _load_image(img_path_or_np) -> np.ndarray:
        """Return the image array; paths are read with OpenCV and converted BGR->RGB."""
        import os

        if isinstance(img_path_or_np, (str, os.PathLike)):
            path = os.fspath(img_path_or_np)
            im0 = cv2.imread(path)
            if im0 is None:
                raise FileNotFoundError(path)
            # Contiguous copy so downstream cv2/numpy ops see a standard layout.
            return np.ascontiguousarray(im0[:, :, ::-1])
        # NOTE(review): arrays are used as-is; presumably RGB to match the path branch.
        return img_path_or_np

    @staticmethod
    def _select_inference_pred(pred_raw) -> torch.Tensor:
        """
        Normalize the Detect output to a ``[B, N, 5+nc]`` tensor.

        Accepts a bare tensor, a dict with a ``"pred"`` entry, or a list/tuple
        whose first element is the 3-D inference tensor. YOLOv5's Detect in
        eval mode returns ``(pred, per_level_maps)`` (``(pred,)`` when exporting).

        Raises:
            RuntimeError: for list/tuple outputs without a leading 3-D tensor
                (e.g. training-mode per-level maps).
        """
        if isinstance(pred_raw, dict) and "pred" in pred_raw:
            pred = pred_raw["pred"]
        elif isinstance(pred_raw, (list, tuple)):
            first = pred_raw[0] if len(pred_raw) else None
            if not (isinstance(first, torch.Tensor) and first.ndim == 3):
                shapes = [getattr(p, "shape", None) for p in pred_raw]
                raise RuntimeError(f"Unexpected Detect output in eval(): {shapes}")
            pred = first
        else:
            pred = pred_raw

        if pred.ndim == 2:
            pred = pred.unsqueeze(0)
        return pred

    @staticmethod
    def _unletterbox_seg(seg_logits: torch.Tensor, pad_wh, out_hw: Tuple[int, int]) -> torch.Tensor:
        """Crop letterbox padding from ``seg_logits`` (B,C,H,W) and resize to ``out_hw``."""
        dw, dh = float(pad_wh[0]), float(pad_wh[1])
        Hlb, Wlb = seg_logits.shape[-2:]
        # Same ±0.1 rounding as the letterbox, so exactly the padding is removed.
        top, bot = int(round(dh - 0.1)), int(round(dh + 0.1))
        left, right = int(round(dw - 0.1)), int(round(dw + 0.1))
        cropped = seg_logits[..., top:Hlb - bot, left:Wlb - right]
        return F.interpolate(cropped, size=out_hw, mode="bilinear", align_corners=False)

    def detect(self, img_path_or_np, size=640, device: Optional[torch.device] = None, return_logits=False):
        """
        Run detection + segmentation on one image (no gradients are recorded).

        Args:
            img_path_or_np: image path (read with OpenCV, converted to RGB) or
                an HxWx3 array used as-is (presumably RGB uint8 — confirm).
            size: letterbox target size.
            device: target device; defaults to CUDA when available. The whole
                module is moved there and stays there after the call.
            return_logits: return raw seg logits instead of sigmoid probabilities.

        Returns:
            det_xyxy: (N, 6) tensor from NMS; columns 0-3 (xyxy boxes) are
                rescaled to the original image and rounded.
            seg: (H0, W0) map for seg channel 0 in original image size.

        Raises:
            FileNotFoundError: if an image path cannot be read.
            RuntimeError: if the Detect output has an unexpected structure.
        """
        dev = device or ("cuda" if torch.cuda.is_available() else "cpu")
        # Move parameters *before* entering inference mode: tensors created in
        # inference mode cannot take part in autograd later, so a device move
        # inside it could break training after this call.
        self.to(dev)

        # Snapshot every submodule's flag so mixed states (e.g. frozen YOLO kept
        # in eval while the seg head trains) are restored exactly afterwards.
        prev_modes = [(m, m.training) for m in self.modules()]
        self.eval()  # consistent Detect behavior (eval-mode output format)
        try:
            with torch.inference_mode():
                im0 = self._load_image(img_path_or_np)
                H0, W0 = im0.shape[:2]

                lb, _, pad = _letterbox(im0, new_shape=size, stride=self.stride)
                im = torch.from_numpy(np.ascontiguousarray(lb.transpose(2, 0, 1))).float() / 255.0
                im = im.unsqueeze(0).to(dev)

                pred_raw, feats = self._forward_collect(im)
                pred = self._select_inference_pred(pred_raw)

                det = non_max_suppression(
                    pred,
                    conf_thres=self.yolo.conf,
                    iou_thres=self.yolo.iou,
                    classes=None,
                    agnostic=self.yolo.agnostic,
                    max_det=self.yolo.max_det,
                )[0]
                if det is None or len(det) == 0:
                    det_xyxy = torch.empty((0, 6), device=im.device)
                else:
                    det_xyxy = det.clone()
                    det_xyxy[:, :4] = scale_boxes(im.shape[2:], det_xyxy[:, :4], im0.shape).round()

                # Predict at letterboxed size, strip padding, resize to original.
                seg_logits = self.seg_head(feats, out_hw=(im.shape[2], im.shape[3]))
                seg_logits = self._unletterbox_seg(seg_logits, pad, (H0, W0))

                seg = seg_logits[0, 0]
                if not return_logits:
                    seg = torch.sigmoid(seg)
                return det_xyxy, seg
        finally:
            for m, was_training in prev_modes:
                m.training = was_training

    # ----- freeze helpers -----
    def freeze_backbone(self, freeze: bool = True):
        """
        Freeze/unfreeze all YOLO backbone+neck layers (everything except Detect + seg head).

        Detect parameters are left untouched; use ``freeze_detect`` for them.
        """
        detect_ids = set()
        if self.detect_head is not None:
            detect_ids = {id(p) for p in self.detect_head.parameters()}
        for p in self.yolo.parameters():
            if id(p) not in detect_ids:
                p.requires_grad = not freeze

    def freeze_detect(self, freeze: bool = True):
        """Freeze/unfreeze the YOLO Detect head only."""
        if self.detect_head is not None:
            for p in self.detect_head.parameters():
                p.requires_grad = not freeze
