In [1]:
# # %% [setup]
# import sys, subprocess

# def _pip(pkg):
#     try:
#         __import__(pkg)
#     except Exception:
#         subprocess.check_call([sys.executable, '-m', 'pip', 'install', pkg])

# for p in ['torch', 'torchvision', 'matplotlib', 'numpy', 'Pillow']:
#     _pip(p)

# !git clone https://github.com/facebookresearch/dinov3.git

In [4]:
import os
import sys
sys.path.append("../")
import torch
from pathlib import Path
import numpy as np
from PIL import Image

import matplotlib.pyplot as plt
from matplotlib.widgets import Button
from utils.download_image import download_image_from_url

# DOWNLOAD FROM URL

In [5]:
# Query image used for this experiment:
# https://s.hdnux.com/photos/01/36/56/03/24830199/1/998x0.jpg
# NOTE(review): download_image_from_url() is called without a URL, so it presumably
# asks for one interactively — confirm in utils/download_image.py.
query_image_pre_bbox = download_image_from_url()
if not query_image_pre_bbox or not os.path.exists(query_image_pre_bbox):
    print("No valid image file is available.")
else:
    print(f"Using image file at: {query_image_pre_bbox}")

Downloaded and saved image as: downloads/998x0_1.jpg
Using image file at: downloads/998x0_1.jpg


# SELECT BBOX and CROP

In [None]:
%matplotlib widget

import os
import numpy as np
import matplotlib.pyplot as plt
from matplotlib.patches import Rectangle
from matplotlib.widgets import Button
from PIL import Image

# --- Load the image whose path is stored in `query_image_pre_bbox` ---
# Fail fast with an explicit error if the download cell above was not run:
# the bare name reference below raises NameError when the variable is undefined.
try:
    query_image_pre_bbox  # noqa: F821
except NameError as e:
    raise NameError("Definisci prima `query_image_pre_bbox = 'path/alla/immagine.jpg'`.") from e

# The download cell may leave a non-string/falsy value when it fails —
# NOTE(review): confirm what download_image_from_url returns on failure.
if not isinstance(query_image_pre_bbox, str):
    raise TypeError("`query_image_pre_bbox` deve essere una STRINGA col percorso dell'immagine.")
if not os.path.exists(query_image_pre_bbox):
    raise FileNotFoundError(f"File non trovato: {query_image_pre_bbox}")

# Decode as RGB and keep a numpy copy; `img` is the array the selection
# callbacks below draw on (imshow) and crop from (slicing).
_img_pil = Image.open(query_image_pre_bbox).convert("RGB")
img = np.array(_img_pil)

# --- Interaction state (a single box at a time) ---
# start_pt:    (x0, y0) of the first click; None while no box is being drawn
# rect_patch:  the Rectangle patch shown on the image axes (created lazily)
# current_box: finalized (x0, y0, x1, y1) after the second click
# cid_click / cid_move: matplotlib callback ids; None when not connected
start_pt = rect_patch = current_box = None
cid_click = cid_move = None

def clamp(val, lo, hi):
    """Limit `val` to `hi` from above, then to `lo` from below (lo wins if lo > hi)."""
    upper_capped = hi if hi < val else val
    return upper_capped if upper_capped > lo else lo

def onclick(event):
    """
    Mouse-click handler for the image axes.

    First click: anchor the box at the clicked pixel, show a 1x1 rectangle
    (creating it on first use) and start following mouse motion.
    Second click: sort and clamp both corners to the image, finalize the
    rectangle, store it in `current_box` as (x0, y0, x1, y1), and stop
    following mouse motion.
    """
    global start_pt, rect_patch, current_box
    if event.inaxes != ax_img or event.xdata is None or event.ydata is None:
        return

    click = (int(event.xdata), int(event.ydata))

    if start_pt is None:
        # First click: anchor the box and start the rubber band.
        start_pt = click
        if rect_patch is not None:
            rect_patch.set_xy(click)
            rect_patch.set_width(1)
            rect_patch.set_height(1)
            rect_patch.set_visible(True)
        else:
            rect_patch = Rectangle(click, 1, 1, fill=False, linewidth=2)
            ax_img.add_patch(rect_patch)
        connect_motion()
        fig.canvas.draw_idle()
        return

    # Second click: close the box with ordered, in-bounds corners.
    H, W = img.shape[:2]
    sx, sy = start_pt
    cx, cy = click
    x0, x1 = (clamp(v, 0, W - 1) for v in sorted([sx, cx]))
    y0, y1 = (clamp(v, 0, H - 1) for v in sorted([sy, cy]))

    rect_patch.set_xy((x0, y0))
    rect_patch.set_width(max(1, x1 - x0))
    rect_patch.set_height(max(1, y1 - y0))
    fig.canvas.draw_idle()

    current_box = (x0, y0, x1, y1)
    start_pt = None
    disconnect_motion()

def onmove(event):
    """
    Mouse-motion handler: while a box is being drawn (after the first click),
    stretch the rectangle from the anchor point to the cursor, clamped to the image.
    """
    if start_pt is None or event.inaxes != ax_img:
        return
    if event.xdata is None or event.ydata is None or rect_patch is None:
        return

    H, W = img.shape[:2]
    sx, sy = start_pt
    cx, cy = int(event.xdata), int(event.ydata)
    left, right = (clamp(v, 0, W - 1) for v in sorted([sx, cx]))
    top, bottom = (clamp(v, 0, H - 1) for v in sorted([sy, cy]))

    rect_patch.set_xy((left, top))
    rect_patch.set_width(max(1, right - left))
    rect_patch.set_height(max(1, bottom - top))
    fig.canvas.draw_idle()

def connect_motion():
    """Attach `onmove` to mouse-motion events on `fig` unless already attached."""
    global cid_move
    if cid_move is not None:
        return
    cid_move = fig.canvas.mpl_connect('motion_notify_event', onmove)

def disconnect_motion():
    """Detach the mouse-motion callback from `fig`, if connected, and forget its id."""
    global cid_move
    if cid_move is None:
        return
    fig.canvas.mpl_disconnect(cid_move)
    cid_move = None

def on_ok(event):
    """
    OK-button handler.

    Crops `img` to `current_box`, stores the crop in the global `query_image`,
    shows it in a new figure, then disconnects the selection callbacks and
    closes the selection figure. Prints a message and does nothing else if no
    valid box has been drawn.
    """
    global query_image
    if current_box is None:
        print("Nessuna box definita: fai due click (inizio/fine) sull'immagine.")
        return
    left, top, right, bottom = current_box
    if right <= left or bottom <= top:
        print("Box non valida. Ridisegna la box.")
        return

    # numpy slicing on the RGB array; `right`/`bottom` are exclusive.
    # NOTE(review): corners are clamped to W-1/H-1 in onclick, so the last
    # column/row of the image can never be part of the crop.
    query_image = img[top:bottom, left:right].copy()

    _, crop_ax = plt.subplots(1, 1, figsize=(6, 6))
    crop_ax.imshow(query_image)
    crop_ax.set_title("Cropped query_image")
    crop_ax.axis("off")
    plt.show()

    # Stop listening to the selection figure and close it.
    if cid_click is not None:
        fig.canvas.mpl_disconnect(cid_click)
    disconnect_motion()
    plt.close(fig)

def on_reset(event):
    """Reset-button handler: forget the current/partial box, stop motion tracking, hide the rectangle."""
    global start_pt, current_box
    start_pt, current_box = None, None
    disconnect_motion()
    if rect_patch is not None:
        rect_patch.set_visible(False)
    fig.canvas.draw_idle()
    print("Reset eseguito: ridisegna la box con due click.")

# --- Figure and button layout ---
fig, ax = plt.subplots(1, 1, figsize=(8, 8))
fig.subplots_adjust(bottom=0.3)  # leave room below the image for the buttons

ax_img = ax
ax_img.imshow(img)
ax_img.set_title("Click 1: inizio box | muovi il mouse | Click 2: chiudi box. Poi premi OK.")
ax_img.axis("off")

# Button axes in figure coordinates: [left, bottom, width, height].
ax_ok = fig.add_axes([0.70, 0.05, 0.10, 0.075])
ax_reset = fig.add_axes([0.50, 0.05, 0.15, 0.075])

# Keep module-level references: matplotlib widgets stop responding once garbage-collected.
btn_ok = Button(ax_ok, 'OK')
btn_reset = Button(ax_reset, 'Reset')

# Image clicks draw the box; buttons confirm or reset it.
cid_click = fig.canvas.mpl_connect('button_press_event', onclick)
btn_ok.on_clicked(on_ok)
btn_reset.on_clicked(on_reset)

plt.show()


In [None]:
# Inspect the box chosen in the widget above: (x0, y0, x1, y1) in pixel
# coordinates, or None if no box was completed (or Reset was pressed).
current_box


# TEST IMAGE

In [None]:
# Test image used for this experiment:
# https://s.yimg.com/ny/api/res/1.2/beOCT4bnh9ULHdfnLuImrA--/YXBwaWQ9aGlnaGxhbmRlcjt3PTY0MDtoPTQyNztjZj13ZWJw/https://media.zenfs.com/en/warriors_wire_usa_today_sports_articles_759/c34198dce54f9a3b5f3cac1b6e5ef91a
test_image = download_image_from_url()
if not test_image or not os.path.exists(test_image):
    print("No valid image file is available.")
else:
    print(f"Using image file at: {test_image}")

# CODE TO LOAD FOUNDATION MODELS AND EXTRACT SPATIAL FEATURES

In [48]:
import os
import math
import torch
import numpy as np
from PIL import Image
import matplotlib.pyplot as plt
from sklearn.decomposition import PCA
import torchvision.transforms as T

# -----------------------------
# I/O & common utilities
# -----------------------------

def load_rgb_image(img_path: str) -> Image.Image:
    """
    Load an image from disk as an RGB PIL image.

    Args:
        img_path: Path to the image file.

    Returns:
        A fully loaded RGB copy of the image; the underlying file is closed.

    Raises:
        FileNotFoundError: if `img_path` is not an existing file.
        RuntimeError: if PIL fails to open/decode the file (original error chained).
    """
    if not os.path.isfile(img_path):
        raise FileNotFoundError(f"Image file not found at '{img_path}'")
    try:
        # `with` closes the file handle; convert() loads the pixels and returns a new image.
        with Image.open(img_path) as src:
            return src.convert("RGB")
    except Exception as e:
        raise RuntimeError(f"Error opening image: {e}") from e

def to_numpy(feature_tensor: torch.Tensor) -> np.ndarray:
    """
    Detach a tensor, move it to CPU and return it as a numpy array.

    A 4-D input is treated as a batch and must have batch size 1; that batch
    dimension is squeezed away. Inputs of any other rank are converted as-is.

    Raises:
        TypeError: if `feature_tensor` is not a torch.Tensor.
        ValueError: if a 4-D input has batch size != 1.
    """
    if not torch.is_tensor(feature_tensor):
        raise TypeError("Expected a torch.Tensor")
    arr = feature_tensor.detach().cpu()
    if arr.ndim != 4:
        return arr.numpy()
    batch = arr.size(0)
    if batch != 1:
        raise ValueError(f"Expected batch size 1, got {batch}")
    return arr.squeeze(0).numpy()

def infer_square_hw_from_seq_len(seq_len: int) -> int:
    """
    Return the side `s` of a square token grid such that `s * s == seq_len`.

    `seq_len` is the number of spatial (patch) tokens, i.e. excluding any CLS
    token. Uses `math.isqrt` (exact integer arithmetic) so large perfect squares
    are not mis-rounded the way `int(math.sqrt(...))` can be.

    Raises:
        ValueError: if `seq_len` is negative or not a perfect square.
    """
    if seq_len < 0:
        raise ValueError(
            f"Cannot infer square spatial size from sequence length {seq_len}."
        )
    side = math.isqrt(int(seq_len))
    if side * side != seq_len:
        raise ValueError(
            f"Cannot infer square spatial size from sequence length {seq_len}."
        )
    return side

def normalize_to_uint8(x: np.ndarray) -> np.ndarray:
    """
    Linearly rescale an array so its minimum maps to 0 and its maximum to 255.

    Intended for visualization. The result has the same shape as `x` and dtype
    uint8 (fractional values are truncated). A constant array maps to all zeros.
    """
    values = x.astype(np.float32)
    lo, hi = values.min(), values.max()
    if hi == lo:
        return np.zeros_like(values, dtype=np.uint8)
    scaled = (values - lo) / (hi - lo) * 255.0
    return np.clip(scaled, 0, 255).astype(np.uint8)

# -----------------------------
# Transforms per model family
# -----------------------------

def get_clip_transform():
    """
    Preprocessing for CLIP's vision tower: bicubic resize to exactly 224x224
    (aspect ratio not preserved), tensor conversion, then normalization with
    CLIP's per-channel pixel mean/std.
    """
    clip_mean = [0.48145466, 0.4578275, 0.40821073]
    clip_std = [0.26862954, 0.26130258, 0.27577711]
    resize = T.Resize((224, 224), interpolation=T.InterpolationMode.BICUBIC)
    return T.Compose([resize, T.ToTensor(), T.Normalize(mean=clip_mean, std=clip_std)])

def get_resnet_transform():
    """
    ResNet preprocessing: resize (shorter side) to 256, center-crop 224x224,
    convert to tensor and normalize with the standard ImageNet mean/std.
    """
    imagenet_mean = [0.485, 0.456, 0.406]
    imagenet_std = [0.229, 0.224, 0.225]
    steps = [
        T.Resize(256),
        T.CenterCrop(224),
        T.ToTensor(),
        T.Normalize(mean=imagenet_mean, std=imagenet_std),
    ]
    return T.Compose(steps)

def get_vit_224_transform():
    """
    Generic ViT preprocessing: bicubic resize to 224x224, tensor conversion and
    per-channel normalization with mean = std = 0.5.
    """
    resize = T.Resize((224, 224), interpolation=T.InterpolationMode.BICUBIC)
    normalize = T.Normalize(mean=[0.5] * 3, std=[0.5] * 3)
    return T.Compose([resize, T.ToTensor(), normalize])

def get_sam_preprocessor():
    """
    Return SAM's `ResizeLongestSide(1024)` transform.

    It only resizes (longest side -> 1024, aspect ratio kept). Normalization and
    padding are done separately by SAM itself. Requires the optional
    `segment_anything` package.

    Raises:
        ImportError: if `segment_anything` cannot be imported (original error chained).
    """
    try:
        from segment_anything.utils.transforms import ResizeLongestSide
    except Exception as e:
        raise ImportError(
            "segment_anything not installed. Install: pip install git+https://github.com/facebookresearch/segment-anything.git"
        ) from e
    return ResizeLongestSide(1024)  # standard SAM input size (longest side)


# -----------------------------
# Loaders per model family
# -----------------------------

def load_clip(model_variant: str = "openai/clip-vit-base-patch32"):
    """
    Load the vision tower of a Hugging Face CLIP checkpoint.

    Args:
        model_variant: Hugging Face model id of a CLIP checkpoint.

    Returns:
        (forward_fn, transform_fn, postproc_fn):
          - forward_fn(t) runs the CLIP vision transformer on a preprocessed batch;
          - transform_fn is `get_clip_transform()` (224x224 resize + CLIP normalization);
          - postproc_fn drops the CLS token and reshapes the patch tokens of a
            batch-1 output into a (C, H, W) numpy array.

    Raises:
        ImportError: if `transformers` is not installed.
    """
    try:
        from transformers import CLIPModel
    except ImportError as e:
        raise ImportError("Please install transformers: pip install transformers") from e

    # Only the vision tower is used, and preprocessing is done by our own
    # torchvision transform, so no CLIPProcessor needs to be downloaded/built.
    model = CLIPModel.from_pretrained(model_variant)
    vision = model.vision_model.eval()

    transform = get_clip_transform()

    def postproc(outputs):
        """Map a vision-model output (BaseModelOutputWithPooling) to (C, H, W) numpy."""
        last = outputs.last_hidden_state  # (B, 1+HW, C)
        if last.ndim != 3:
            raise ValueError(f"Unexpected CLIP hidden shape: {last.shape}")
        b, _, c = last.shape
        if b != 1:
            raise ValueError(f"Expected B=1, got {b}")
        spatial_tokens = last[:, 1:, :]  # drop CLS -> (1, HW, C)
        side = infer_square_hw_from_seq_len(spatial_tokens.shape[1])
        feat = spatial_tokens[0].transpose(0, 1).reshape(c, side, side)  # (C,H,W)
        return to_numpy(feat)

    def forward(t: torch.Tensor):
        """Vision-only forward pass (no gradients needed for feature extraction)."""
        with torch.no_grad():
            return vision(t)

    return forward, transform, postproc

def load_resnet50():
    """
    Build a pretrained ResNet-50 truncated before its global pooling and classifier.

    Uses torchvision's `ResNet50_Weights.DEFAULT` weights.

    Returns:
        (forward_fn, transform_fn, postproc_fn): forward_fn yields the last
        convolutional feature map (B, C, H, W), transform_fn is
        `get_resnet_transform()`, and postproc_fn converts a batch-1 map to a
        (C, H, W) numpy array.
    """
    try:
        import torchvision.models as models
    except ImportError:
        raise ImportError("Please install torchvision: pip install torchvision")

    full_net = models.resnet50(weights=models.ResNet50_Weights.DEFAULT).eval()
    trunk_layers = list(full_net.children())[:-2]  # drop avgpool + fc, keep spatial output
    trunk = torch.nn.Sequential(*trunk_layers).eval()
    preprocess = get_resnet_transform()

    def forward(t: torch.Tensor):
        return trunk(t)

    def postproc(x: torch.Tensor):
        return to_numpy(x)  # (1, C, H, W) -> (C, H, W)

    return forward, preprocess, postproc

def load_dinov2(model_variant: str = "dinov2_vits14"):
    """
    Load DINOv2 from torch.hub and expose its patch tokens as a spatial feature map.

    Valid variants include: dinov2_vits14, dinov2_vitb14, dinov2_vitl14, dinov2_vitg14.

    Returns:
        (forward_fn, transform_fn, postproc_fn):
          - transform_fn resizes to 224x224 (a multiple of the 14-px patch size,
            giving a 16x16 token grid) and normalizes with ImageNet mean/std,
            the normalization DINOv2's own eval transforms use;
          - forward_fn runs the backbone ONCE and returns patch tokens (B, HW, C),
            e.g. `x_norm_patchtokens` when `forward_features` returns a dict;
          - postproc_fn reshapes batch-1 tokens into a (C, H, W) numpy array.

    Raises:
        ImportError: if the hub model cannot be loaded (original error chained).
    """
    try:
        model = torch.hub.load("facebookresearch/dinov2", model_variant)
    except Exception as e:
        raise ImportError(
            "Failed to load DINOv2 via torch.hub. Make sure you have internet access once, "
            "or have the weights cached. Error: " + str(e)
        ) from e
    model.eval()

    transform = T.Compose([
        T.Resize((224, 224), interpolation=T.InterpolationMode.BICUBIC),
        T.ToTensor(),
        T.Normalize(mean=[0.485, 0.456, 0.406],
                    std=[0.229, 0.224, 0.225]),
    ])

    def forward(t: torch.Tensor):
        """Run the backbone once and return its patch tokens (B, HW, C)."""
        with torch.no_grad():
            if not hasattr(model, "forward_features"):
                # Fallback: model(t) presumably yields a pooled/CLS embedding,
                # which postproc rejects as non-spatial.
                return model(t)
            out = model.forward_features(t)
            if isinstance(out, dict):
                for key in ("x_norm_patchtokens", "tokens", "x"):
                    if key in out:
                        return out[key]
                raise ValueError(
                    f"DINOv2 forward_features returned no patch tokens; keys: {sorted(out)}"
                )
            return out  # might already be (B, HW, C)

    def postproc(tokens):
        """Reshape batch-1 patch tokens (1, HW, C) into a (C, H, W) numpy array."""
        if not torch.is_tensor(tokens) or tokens.ndim != 3:
            raise ValueError(
                f"Unexpected DINOv2 token shape: {getattr(tokens, 'shape', type(tokens))}. "
                "You may need to adapt the post-processing for your exact variant."
            )
        b, hw, c = tokens.shape
        if b != 1:
            raise ValueError(f"Expected batch size 1, got {b}")
        side = infer_square_hw_from_seq_len(hw)
        feat = tokens[0].transpose(0, 1).reshape(c, side, side)
        return to_numpy(feat)

    return forward, transform, postproc

def load_dinov3(model_variant: str = "vit_base_patch16_224.dino", hf_vit_repo: str | None = None):
    """
    Load a ViT backbone and expose its patch tokens as a spatial feature map.

    NOTE: despite the name, the default `model_variant` is the timm checkpoint
    `vit_base_patch16_224.dino`; any timm ViT name may be passed. If timm is
    missing or model creation fails, a Hugging Face `ViTModel` loaded from
    `hf_vit_repo` is used instead.

    Args:
        model_variant: timm model name.
        hf_vit_repo: optional Hugging Face ViT repo id used as a fallback.

    Returns:
        (forward_fn, transform_fn, postproc_fn): forward_fn returns patch tokens
        (B, HW, C) with leading prefix (CLS/register) tokens removed; postproc_fn
        reshapes batch-1 tokens into a (C, H, W) numpy array.

    Raises:
        ImportError: if the timm path fails and `hf_vit_repo` is None, or if
            `transformers` is missing for the fallback.
    """
    try:
        import timm
        model = timm.create_model(model_variant, pretrained=True)
        model.eval()

        # Prefer the checkpoint's own preprocessing, which timm records in
        # `pretrained_cfg`. Fall back to the generic 224x224 / 0.5 transform.
        cfg = getattr(model, "pretrained_cfg", None)
        if isinstance(cfg, dict) and "mean" in cfg and "std" in cfg:
            in_h, in_w = tuple(cfg.get("input_size") or (3, 224, 224))[-2:]
            transform = T.Compose([
                T.Resize((in_h, in_w), interpolation=T.InterpolationMode.BICUBIC),
                T.ToTensor(),
                T.Normalize(mean=list(cfg["mean"]), std=list(cfg["std"])),
            ])
        else:
            transform = get_vit_224_transform()

        # Number of non-patch tokens (CLS and, for some models, registers)
        # preceding the patch grid. Defaults to 1 (a single CLS token).
        num_prefix = int(getattr(model, "num_prefix_tokens", 1))

        def _strip_prefix(tok: torch.Tensor) -> torch.Tensor:
            """Drop leading prefix tokens from a (B, N, C) tensor."""
            if tok.ndim == 3 and tok.shape[1] > num_prefix:
                tok = tok[:, num_prefix:, :]
            return tok

        def _tokens_from_last_block(x: torch.Tensor) -> torch.Tensor:
            """Fallback: capture the last transformer block's output with a temporary hook."""
            blocks = getattr(model, "blocks", None)
            if blocks is None or len(blocks) == 0:
                raise ValueError("Failed to capture tokens from the last transformer block.")
            captured = {}

            def _hook(_module, _inp, out):
                captured["x"] = out

            handle = blocks[-1].register_forward_hook(_hook)
            try:
                if hasattr(model, "forward_features"):
                    model.forward_features(x)
                else:
                    model(x)
            finally:
                handle.remove()  # never leave the hook attached to the model
            if "x" not in captured:
                raise ValueError("Failed to capture tokens from the last transformer block.")
            return captured["x"]  # (B, N, C)

        def forward(t: torch.Tensor):
            """Run the backbone once and return patch tokens (B, HW, C)."""
            with torch.no_grad():
                if hasattr(model, "forward_features"):
                    out = model.forward_features(t)
                    if isinstance(out, dict):
                        for key in ("x", "tokens", "features"):
                            tok = out.get(key)
                            if torch.is_tensor(tok) and tok.ndim == 3:
                                return _strip_prefix(tok)
                    elif torch.is_tensor(out) and out.ndim == 3:
                        return _strip_prefix(out)
                # Unusable output: grab the last block's tokens via a hook instead.
                return _strip_prefix(_tokens_from_last_block(t))

        def postproc(tokens: torch.Tensor):
            """Reshape batch-1 patch tokens (1, HW, C) into a (C, H, W) numpy array."""
            if tokens.ndim != 3:
                raise ValueError(f"Unexpected token tensor shape from timm model: {tokens.shape}")
            b, hw, c = tokens.shape
            if b != 1:
                raise ValueError(f"Expected batch size 1, got {b}")
            side = infer_square_hw_from_seq_len(hw)
            feat = tokens[0].transpose(0, 1).reshape(c, side, side)
            return to_numpy(feat)

        return forward, transform, postproc

    except Exception as timm_err:
        # ---- Hugging Face ViT fallback ----
        if hf_vit_repo is None:
            raise ImportError(
                "timm not available or model not found, and no Hugging Face ViT repo provided. "
                "Install timm (pip install timm) or pass hf_vit_repo.\n"
                f"Original error: {timm_err}"
            ) from timm_err

        try:
            from transformers import ViTModel
        except ImportError as e:
            raise ImportError("Please install transformers: pip install transformers") from e

        # Preprocessing is done by our torchvision transform, so no image processor is loaded.
        vit = ViTModel.from_pretrained(hf_vit_repo).eval()
        transform = get_vit_224_transform()

        def forward(t: torch.Tensor):
            """Return the last hidden state (B, 1+HW, C), CLS first."""
            with torch.no_grad():
                outputs = vit(pixel_values=t)
            return outputs.last_hidden_state

        def postproc(last_hidden: torch.Tensor):
            """Drop CLS and reshape batch-1 tokens into a (C, H, W) numpy array."""
            if last_hidden.ndim != 3:
                raise ValueError(f"Unexpected ViT hidden shape: {last_hidden.shape}")
            b, _, c = last_hidden.shape
            if b != 1:
                raise ValueError(f"Expected batch size 1, got {b}")
            tokens = last_hidden[:, 1:, :]  # drop CLS
            side = infer_square_hw_from_seq_len(tokens.shape[1])
            feat = tokens[0].transpose(0, 1).reshape(c, side, side)
            return to_numpy(feat)

        return forward, transform, postproc

def load_sam(checkpoint_path: str, model_type: str = "vit_h"):
    """
    Load Segment Anything (SAM) and return (forward_fn, transform_fn, postproc_fn).

    Pipeline:
      - transform_fn: resize the PIL image so its longest side is 1024 (aspect kept),
        returning a (1, 3, Hr, Wr) float32 tensor in [0, 255]. It records the original
        (H0, W0) and resized (Hr, Wr) sizes for postproc.
      - forward_fn: `sam.preprocess` (normalize + pad to 1024x1024), then
        `sam.image_encoder` -> (1, C, Hf, Wf) embedding (64x64 for a 1024 input).
      - postproc_fn: upsample the embedding cells to the pixel grid they cover, crop
        away the bottom/right padding exactly, then resize to (H0, W0).

    Final output: (C, H0, W0) numpy array aligned to the input image.
    NOTE: this is C floats per original pixel, which can be large for big images.

    Raises:
        ImportError: if segment_anything is not installed (original error chained).
        FileNotFoundError: if the checkpoint file does not exist.
        ValueError: for an unknown `model_type`.
    """
    import os
    import numpy as np
    import torch
    import torch.nn.functional as F

    try:
        from segment_anything import sam_model_registry
        from segment_anything.utils.transforms import ResizeLongestSide
    except ImportError as e:
        raise ImportError(
            "Please install Segment Anything:\n"
            "pip install git+https://github.com/facebookresearch/segment-anything.git"
        ) from e

    if not os.path.isfile(checkpoint_path):
        raise FileNotFoundError(
            f"SAM checkpoint not found at '{checkpoint_path}'. "
        )

    if model_type not in {"vit_h", "vit_l", "vit_b"}:
        raise ValueError("SAM model_type must be 'vit_h', 'vit_l' or 'vit_b'.")

    sam = sam_model_registry[model_type](checkpoint=checkpoint_path).eval()
    sam_input_size = 1024  # images are resized to this longest side and padded to a square
    resizer = ResizeLongestSide(sam_input_size)

    # --- closure state carrying geometry from transform to postproc ---
    orig_hw = {"H0": None, "W0": None}   # original image size
    rez_hw = {"Hr": None, "Wr": None}    # resized (before padding) size

    def transform_for_sam(pil_img: "Image.Image") -> torch.Tensor:
        """
        Return a (1, 3, Hr, Wr) float32 tensor in [0, 255] and store (H0, W0) and
        (Hr, Wr) in the closure for postproc.
        """
        image_np = np.array(pil_img)  # (H0, W0, 3) uint8
        H0, W0 = image_np.shape[:2]
        orig_hw["H0"], orig_hw["W0"] = H0, W0

        # Resize (keep aspect) so that max(Hr, Wr) == 1024; no padding yet.
        resized_np = resizer.apply_image(image_np)  # (Hr, Wr, 3)
        Hr, Wr = resized_np.shape[:2]
        rez_hw["Hr"], rez_hw["Wr"] = Hr, Wr

        t = torch.as_tensor(resized_np, dtype=torch.float32).permute(2, 0, 1).unsqueeze(0)
        return t

    def forward(t: torch.Tensor):
        """Normalize + pad with `sam.preprocess`, then run the image encoder."""
        with torch.no_grad():
            device = next(sam.parameters()).device
            x = t.to(device)            # (1, 3, Hr, Wr), 0..255
            x = sam.preprocess(x)       # (1, 3, 1024, 1024)
            emb = sam.image_encoder(x)  # (1, C, Hf, Wf)
            return emb

    def postproc(emb: torch.Tensor):
        """
        Map the encoder embedding back onto the original image:
          1) keep only cells that touch real (unpadded) pixels,
          2) upsample them to the pixel grid they cover and crop exactly to (Hr, Wr),
          3) resize to (H0, W0).
        Returns a (C, H0, W0) numpy array.
        """
        if orig_hw["H0"] is None or rez_hw["Hr"] is None:
            raise RuntimeError("Geometry metadata missing. Did you call transform_for_sam first?")

        H0, W0 = orig_hw["H0"], orig_hw["W0"]
        Hr, Wr = rez_hw["Hr"], rez_hw["Wr"]

        with torch.no_grad():
            B, C, Hf, Wf = emb.shape
            if B != 1:
                raise ValueError(f"Expected batch size 1, got {B}")

            # Input pixels covered by one embedding cell (16 for a 64x64 grid).
            stride_h = sam_input_size // Hf
            stride_w = sam_input_size // Wf

            # Cells that contain real pixels; padding is bottom/right.
            ht = min(int(np.ceil(Hr / stride_h)), Hf)
            wt = min(int(np.ceil(Wr / stride_w)), Wf)
            emb_valid = emb[:, :, :ht, :wt]

            # Upsample to the exact pixel extent of those cells, then crop to (Hr, Wr).
            # Stretching ht cells straight onto Hr rows would misalign whenever Hr is
            # not a multiple of the stride. This mirrors SAM's mask post-processing
            # (upsample, crop padding, resize to original).
            emb_px = F.interpolate(
                emb_valid, size=(ht * stride_h, wt * stride_w), mode="bilinear", align_corners=False
            )
            emb_resized = emb_px[:, :, :Hr, :Wr]

            emb_original = F.interpolate(
                emb_resized, size=(H0, W0), mode="bilinear", align_corners=False
            )  # (1, C, H0, W0)

        return emb_original.squeeze(0).cpu().numpy()  # (C, H0, W0)

    return (forward, transform_for_sam, postproc)






def visualize_overlay_mean(img: Image.Image, features_chw: np.ndarray, alpha: float = 0.5, title: str = ""):
    """
    Show the image next to an overlay of its channel-averaged feature map.

    The (C, H, W) features are averaged over C, min-max scaled to uint8 and
    bicubically resized to the image size, then drawn over the image with the
    'viridis' colormap at the given `alpha`.

    Raises:
        ValueError: if `features_chw` is not 3-D.
    """
    if features_chw.ndim != 3:
        raise ValueError(f"Expected (C,H,W) features, got {features_chw.shape}")
    heat_u8 = normalize_to_uint8(features_chw.mean(axis=0))
    heatmap = Image.fromarray(heat_u8).resize(img.size, resample=Image.BICUBIC)

    _, (ax_orig, ax_overlay) = plt.subplots(1, 2, figsize=(10, 5))

    ax_orig.imshow(img)
    ax_orig.set_title("Original Image")
    ax_orig.axis("off")

    ax_overlay.imshow(img)
    ax_overlay.imshow(heatmap, cmap="viridis", alpha=alpha)
    ax_overlay.set_title(title or "Features Overlay (Mean)")
    ax_overlay.axis("off")

    plt.show()

def visualize_overlay_pca(img: Image.Image, features_chw: np.ndarray, alpha: float = 0.98, title: str = ""):
    """
    Show the image next to an RGB overlay of the first 3 PCA components of its features.

    Each spatial location's C-dim feature vector is projected onto 3 principal
    components. The components are jointly min-max scaled to uint8, resized to
    the image size and drawn over the image at the given `alpha`.

    Raises:
        ValueError: if `features_chw` is not 3-D or has fewer than 3 channels.
    """
    if features_chw.ndim != 3 or features_chw.shape[0] < 3:
        raise ValueError(f"PCA visualization requires features with >=3 channels, got {features_chw.shape}")

    C, H, W = features_chw.shape
    pixels = np.moveaxis(features_chw, 0, -1).reshape(-1, C)  # (H*W, C)
    projected = PCA(n_components=3).fit_transform(pixels).reshape(H, W, 3)
    rgb = Image.fromarray(normalize_to_uint8(projected)).resize(img.size, resample=Image.BICUBIC)

    _, (ax_orig, ax_overlay) = plt.subplots(1, 2, figsize=(10, 5))

    ax_orig.imshow(img)
    ax_orig.set_title("Original Image")
    ax_orig.axis("off")

    ax_overlay.imshow(img)
    ax_overlay.imshow(rgb, alpha=alpha)
    ax_overlay.set_title(title or "Features Overlay (PCA)")
    ax_overlay.axis("off")

    plt.show()

def visualize_features(
    img: Image.Image,
    features_chw: np.ndarray,
    mode: str = "mean",
    title_prefix: str = ""
):
    """
    Dispatch (C, H, W) features to the mean or PCA overlay visualization.

    Raises:
        ValueError: if the features are not 3-D or `mode` is not 'mean'/'pca'.
    """
    if features_chw.ndim != 3:
        raise ValueError(f"Cannot visualize features of shape {features_chw.shape}")

    if mode == "mean":
        visualize_overlay_mean(img, features_chw, alpha=0.5, title=f"{title_prefix} Features (Mean)")
    elif mode == "pca":
        visualize_overlay_pca(img, features_chw, alpha=0.98, title=f"{title_prefix} Features (PCA)")
    else:
        raise ValueError(f"Unsupported visualization mode: {mode}")


# -----------------------------
# Orchestrator
# -----------------------------

def get_backbone_and_transform(
    model: str,
    model_variant: str | None = None,
    sam_checkpoint: str | None = None,
    dinov3_hf_vit_repo: str | None = None
):
    """
    Return (forward_fn, transform_fn, postproc_fn, model_label) for the requested model.
    model: 'clip' | 'resnet' | 'dinov2' | 'dinov3' | 'sam' | 'handcrafted'
    model_variant: optional string to pick a specific sub-variant.
    sam_checkpoint: required for 'sam'.
    dinov3_hf_vit_repo: optional HF repo id for ViT fallback if timm variant isn't available.
    """
    m = model.strip().lower()
    if m == "clip":
        fwd, tfm, post = load_clip(model_variant or "openai/clip-vit-base-patch32")
        return fwd, tfm, post, f"CLIP ({model_variant or 'ViT-B/32'})"
    elif m == "resnet":
        fwd, tfm, post = load_resnet50()
        return fwd, tfm, post, "ResNet50 (pretrained)"
    elif m == "dinov2":
        fwd, tfm, post = load_dinov2(model_variant or "dinov2_vits14")
        return fwd, tfm, post, f"DINOv2 ({model_variant or 'dinov2_vits14'})"
    elif m == "dinov3":
        fwd, tfm, post = load_dinov3(model_variant or "vit_base_patch16_224.dino",
                                     hf_vit_repo=dinov3_hf_vit_repo)
        label = f"DINOv3/ViT ({model_variant or dinov3_hf_vit_repo or 'timm default'})"
        return fwd, tfm, post, label
    elif m == "sam":
        if not sam_checkpoint:
            raise ValueError("For 'sam', you must pass sam_checkpoint='path/to/sam_checkpoint.pth'.")
        fwd, tfm, post = load_sam(sam_checkpoint, model_variant or "vit_h")
        return fwd, tfm, post, f"SAM ({model_variant or 'vit_h'})"
    else:
        raise ValueError("Unsupported model. Choose from: 'clip', 'resnet', 'dinov2', 'dinov3', 'sam', 'handcrafted'.")

def extract_spatial_features_from_image(
    img: Image.Image,
    forward_fn,
    transform_fn,
    postproc_fn
) -> np.ndarray:
    """
    Run transform -> forward (without gradients) -> postproc and return a (C, H, W) array.

    A 3-D transformed tensor gets a batch dimension added before the forward pass.
    A 2-D postproc result is promoted to a single-channel (1, H, W) map.

    Raises:
        ValueError: if postproc does not yield a 2-D or 3-D array.
    """
    with torch.no_grad():
        batch = transform_fn(img)  # (1, C, H, W), (C, H, W) or SAM's input
        batch = batch.unsqueeze(0) if batch.ndim == 3 else batch
        raw_out = forward_fn(batch)
    feats = postproc_fn(raw_out)
    if feats.ndim == 2:
        feats = feats[None, ...]  # (H, W) -> (1, H, W)
    if feats.ndim == 3:
        return feats
    raise ValueError(f"Expected (C,H,W) after postproc, got {feats.shape}")

# -----------------------------
# Public API (drop-in for your original function)
# -----------------------------

def spatial_features_extractor(
    img_path: str,
    model: str = "clip",
    visualization_mode: str = "mean",
    *,
    model_variant: str | None = None,
    sam_checkpoint: str | None = None,
    dinov3_hf_vit_repo: str | None = None
):
    """Load an image, extract a (C, H, W) spatial feature map and visualize it.

    The pipeline has four stages: load image, load backbone, extract features, visualize.
    The first stage that raises prints a message and the function returns ``None``.
    Exceptions are not propagated. On success the feature-map shape is printed and
    ``None`` is returned.

    Args:
        img_path: path of the image file, read with ``load_rgb_image``.
        model: one of 'clip', 'resnet', 'dinov2', 'dinov3', 'sam', 'handcrafted'.
        visualization_mode: forwarded to ``visualize_features`` as ``mode``
            ('mean', 'pca', or a handcrafted-specific mode).
        model_variant:
            - clip: HF model id (default 'openai/clip-vit-base-patch32')
            - resnet: ignored
            - dinov2: torch.hub variants (e.g., 'dinov2_vits14', 'dinov2_vitb14', ...)
            - dinov3: timm name (e.g., 'vit_base_patch16_224.dino'), or use dinov3_hf_vit_repo
            - sam: 'vit_h' | 'vit_l' | 'vit_b'
            - handcrafted: 'hog' | 'lbp' | 'gabor' | 'sobel' | 'sift' | 'orb'
        sam_checkpoint: path to the .pth checkpoint; required for 'sam'.
        dinov3_hf_vit_repo: Hugging Face ViT repo id used as a fallback when timm (or the
            requested timm variant) is unavailable, e.g. 'google/vit-base-patch16-224'.
    """
    # `stage` records how far the pipeline got, so the handler can report the failing step.
    stage = "load"
    try:
        img = load_rgb_image(img_path)

        stage = "model"
        forward_fn, transform_fn, postproc_fn, label = get_backbone_and_transform(
            model=model,
            model_variant=model_variant,
            sam_checkpoint=sam_checkpoint,
            dinov3_hf_vit_repo=dinov3_hf_vit_repo,
        )

        stage = "extract"
        features = extract_spatial_features_from_image(img, forward_fn, transform_fn, postproc_fn)
        print(f"Spatial features shape: {features.shape}")  # (C, H, W)

        stage = "visualize"
        visualize_features(img, features, mode=visualization_mode, title_prefix=label)
    except Exception as err:
        if stage == "load":
            print(str(err))
        elif stage == "model":
            print(f"Error loading model '{model}': {err}")
        elif stage == "extract":
            print(f"Error during feature extraction: {err}")
        else:
            print(f"Error during visualization: {err}")




In [None]:
# spatial_features_extractor(test_image, model="clip", visualization_mode="pca")
# spatial_features_extractor(test_image, model="resnet", visualization_mode="pca")
# spatial_features_extractor(test_image, model="dinov2", model_variant="dinov2_vitb14", visualization_mode="pca")
# spatial_features_extractor(test_image, model="dinov3", model_variant="vit_base_patch16_224.dino", visualization_mode="pca")
# spatial_features_extractor(test_image, model="dinov3", dinov3_hf_vit_repo="google/vit-base-patch16-224", visualization_mode="pca")
# spatial_features_extractor(test_image, model="sam", model_variant="vit_h", sam_checkpoint="../content/sam_vit_h_4b8939.pth", visualization_mode="pca")


# FUNCTIONS for TEMPLATE MATCHING

In [65]:
# -----------------------------
# Helpers minimi
# -----------------------------
def ensure_pil(img):
    """Return ``img`` as an RGB ``PIL.Image``.

    Accepted inputs:
      * ``PIL.Image.Image`` of any mode; converted with ``.convert("RGB")``.
      * ``np.ndarray`` of shape (H, W), (H, W, 1), (H, W, 3) or (H, W, 4).
        Grayscale arrays are replicated to three channels. The 4th (alpha) channel
        is dropped. Non-uint8 arrays are clipped to [0, 255] and cast to uint8, so
        float images must already be on a 0-255 scale (a [0, 1] float image would
        come out almost black).

    Raises:
        TypeError: for any other input type or array shape.
    """
    if isinstance(img, Image.Image):
        return img.convert("RGB")
    if isinstance(img, np.ndarray):
        arr = img
        if arr.ndim == 3 and arr.shape[2] == 1:
            arr = arr[:, :, 0]  # (H, W, 1) -> (H, W)
        if arr.ndim == 2:
            arr = np.stack([arr, arr, arr], axis=-1)  # grayscale -> 3 channels
        if arr.ndim != 3 or arr.shape[2] not in (3, 4):
            raise TypeError(f"Unsupported ndarray shape for image: {img.shape}")
        if arr.dtype != np.uint8:
            arr = np.clip(arr, 0, 255).astype(np.uint8)
        # Drop alpha and hand Pillow a contiguous (H, W, 3) uint8 buffer. Pillow infers
        # mode "RGB" from that shape/dtype, so the deprecated `mode=` argument is not needed.
        return Image.fromarray(np.ascontiguousarray(arr[:, :, :3]))
    raise TypeError(f"Unsupported image type: {type(img)}")

def infer_square_hw_from_seq_len(seq_len: int) -> int:
    """Return the side ``s`` of a square token grid such that ``s * s == seq_len``.

    Used to fold a flat sequence of H*W patch tokens back into an H == W grid.

    Raises:
        ValueError: if ``seq_len`` is negative or not a perfect square.
    """
    import math

    # A float sqrt of a negative number yields a complex value (and an unrelated
    # TypeError later), so reject negatives explicitly with the same ValueError.
    if seq_len < 0:
        raise ValueError(f"Cannot infer (H=W) from seq_len={seq_len}")
    # isqrt is exact for arbitrarily large ints, unlike round(x ** 0.5).
    side = math.isqrt(int(seq_len))
    if side * side != seq_len:
        raise ValueError(f"Cannot infer (H=W) from seq_len={seq_len}")
    return side

def get_vit_224_transform():
    """Build the ViT preprocessing pipeline for a single PIL image.

    The pipeline does three things:
      * a bicubic resize to exactly 224x224 (the aspect ratio is not preserved);
      * ``ToTensor``;
      * per-channel normalization with mean 0.5 and std 0.5.
    """
    steps = [
        T.Resize((224, 224), interpolation=T.InterpolationMode.BICUBIC),
        T.ToTensor(),
        T.Normalize(mean=[0.5] * 3, std=[0.5] * 3),
    ]
    return T.Compose(steps)

# -----------------------------
# Loader DINOv3 (timm) minimale
# -----------------------------
def load_dinov3_minimal(model_variant: str = "vit_base_patch16_224.dino"):
    """Build a pretrained timm ViT and return ``(forward_fn, transform_fn, postproc_fn)``.

    - ``forward_fn(t)``: runs the model on a batched image tensor. It returns the tokens
      emitted by the last transformer block with every prefix token removed (CLS and,
      when the model has them, register tokens), shape (B, N_patches, C).
    - ``transform_fn(pil)``: bicubic resize of the whole image (no crop, aspect ratio not
      kept) to the checkpoint's input size, then ``ToTensor`` and normalization with the
      checkpoint's mean/std. Returns a (3, H, W) tensor.
    - ``postproc_fn(tokens)``: reshapes (1, N, C) tokens into a (C, h, w) numpy patch grid.

    NOTE(review): despite the function name, the default timm tag
    ``vit_base_patch16_224.dino`` appears to be the original DINO ViT-B/16 checkpoint,
    not DINOv3. Confirm which weights are intended.

    Errors raised by the returned callables:
        RuntimeError: ``forward_fn`` captured no tokens (the model has no ``blocks``).
        ValueError: unexpected token shapes, or batch size != 1 in ``postproc_fn``.
    """
    import timm

    model = timm.create_model(model_variant, pretrained=True)
    model.eval()

    # Normalize with the statistics recorded in timm's pretrained config for this
    # checkpoint instead of a hard-coded 0.5/0.5 that need not match the weights.
    # 0.5 and 224x224 remain the fallback when the config does not provide them.
    cfg = getattr(model, "pretrained_cfg", None) or getattr(model, "default_cfg", None) or {}
    if not isinstance(cfg, dict):
        cfg = getattr(cfg, "__dict__", {})
    mean = [float(v) for v in (cfg.get("mean") or (0.5, 0.5, 0.5))]
    std = [float(v) for v in (cfg.get("std") or (0.5, 0.5, 0.5))]
    in_h, in_w = (int(v) for v in tuple(cfg.get("input_size") or (3, 224, 224))[-2:])
    transform = T.Compose([
        T.Resize((in_h, in_w), interpolation=T.InterpolationMode.BICUBIC),
        T.ToTensor(),
        T.Normalize(mean=mean, std=std),
    ])

    # Number of non-patch tokens at the start of the sequence. timm ViTs expose this as
    # num_prefix_tokens (CLS + registers); assume a single CLS token when it is absent.
    num_prefix = int(getattr(model, "num_prefix_tokens", 1) or 0)

    # Forward hook on the last block captures its output tokens.
    last_tokens = {}
    if hasattr(model, "blocks") and len(model.blocks) > 0:
        model.blocks[-1].register_forward_hook(
            lambda _module, _inputs, output: last_tokens.__setitem__("x", output)
        )

    def forward(t: torch.Tensor):
        with torch.no_grad():
            last_tokens.clear()
            if hasattr(model, "forward_features"):
                model.forward_features(t)
            else:
                model(t)
            x = last_tokens.get("x")  # (B, N, C)
            if x is None:
                raise RuntimeError("Failed to capture ViT tokens from last block.")
            if x.ndim != 3:
                raise ValueError(f"Unexpected tokens shape: {x.shape}")
            # Keep only patch tokens. A single leading token was previously always
            # dropped, which left register tokens in (non-square grid) or removed a
            # real patch for models without CLS.
            if num_prefix > 0 and x.shape[1] > num_prefix:
                x = x[:, num_prefix:, :]
            return x  # (B, HW, C)

    def postproc(tokens: torch.Tensor) -> np.ndarray:
        if tokens.ndim != 3:
            raise ValueError(f"Unexpected token tensor shape: {tokens.shape}")
        b, hw, c = tokens.shape
        if b != 1:
            raise ValueError(f"Expected B=1, got {b}")
        # Prefer the patch-embed grid (supports non-square inputs); otherwise assume square.
        grid = getattr(getattr(model, "patch_embed", None), "grid_size", None)
        if isinstance(grid, (tuple, list)) and len(grid) == 2 and int(grid[0]) * int(grid[1]) == hw:
            gh, gw = int(grid[0]), int(grid[1])
        else:
            gh = gw = infer_square_hw_from_seq_len(hw)
        feat = tokens[0].transpose(0, 1).reshape(c, gh, gw)  # (C, H, W)
        return feat.detach().cpu().numpy()

    return forward, transform, postproc


# -----------------------------
# Estrazione feature (comune)
# -----------------------------
def extract_spatial_features_from_image(img_pil: Image.Image, forward_fn, transform_fn, postproc_fn) -> np.ndarray:
    """Encode ``img_pil`` with a backbone and return a (C, H, W) numpy feature map.

    NOTE: this redefines ``extract_spatial_features_from_image`` from an earlier cell,
    where the first parameter is named ``img``. The later definition shadows the earlier.

    ``transform_fn`` output gets a batch axis when it is 3-D. Transform and forward run
    under ``torch.no_grad()``. A 2-D post-processed map is promoted to (1, H, W).

    Raises:
        ValueError: if the post-processed map is neither 2-D nor 3-D.
    """
    with torch.no_grad():
        tensor_in = transform_fn(img_pil)
        if tensor_in.ndim == 3:  # single image -> batch of one
            tensor_in = tensor_in.unsqueeze(0)
        model_out = forward_fn(tensor_in)

    fmap = postproc_fn(model_out)
    fmap = fmap[None, ...] if fmap.ndim == 2 else fmap
    if fmap.ndim == 3:
        return fmap
    raise ValueError(f"Expected (C,H,W), got {fmap.shape}")


# -----------------------------
# Matching & visual
# -----------------------------
def resize_features_chw(features_chw: np.ndarray, out_hw: tuple[int, int]) -> np.ndarray:
    """Bilinearly resample a (C, H, W) feature map to (C, out_hw[0], out_hw[1]).

    Each channel is interpolated independently with ``align_corners=False``.
    The result is float32.

    Raises:
        ValueError: if ``features_chw`` is not 3-D.
    """
    if features_chw.ndim != 3:
        raise ValueError(f"Expected (C,H,W), got {features_chw.shape}")
    out_h, out_w = out_hw
    batched = torch.from_numpy(features_chw).float()[None]  # (1, C, H, W)
    resized = F.interpolate(batched, size=(int(out_h), int(out_w)), mode="bilinear", align_corners=False)
    return resized[0].cpu().numpy()

def cosine_template_match(test_feat_chw: np.ndarray, templ_feat_chw: np.ndarray, stride: int = 1) -> np.ndarray:
    """Dense cosine-similarity template matching in feature space.

    The template (C, Hk, Wk) slides over the test map (C, Ht, Wt) with step ``stride``.
    Each position scores the cosine similarity between the flattened template and the
    flattened test window it covers.

    Returns:
        Array of shape (Hout, Wout), with ``Hout = (Ht - Hk) // stride + 1`` and the same
        rule for Wout. Values lie in [-1, 1]. Entry (i, j) scores the window whose
        top-left feature cell is (i * stride, j * stride). The result is always 2-D,
        even when Hout or Wout is 1.

    Raises:
        ValueError: if an input is not 3-D, channel counts differ, ``stride < 1``, or
            the template is larger than the test map.
    """
    if test_feat_chw.ndim != 3 or templ_feat_chw.ndim != 3:
        raise ValueError("Both feature maps must be (C,H,W).")
    Ct, Ht, Wt = test_feat_chw.shape
    Ck, Hk, Wk = templ_feat_chw.shape
    if Ct != Ck:
        raise ValueError(f"Channel mismatch: test C={Ct} vs templ C={Ck}")
    stride = int(stride)
    if stride < 1:
        raise ValueError(f"stride must be >= 1, got {stride}")
    if Hk > Ht or Wk > Wt:
        raise ValueError(f"Template ({Hk}x{Wk}) is larger than the test feature map ({Ht}x{Wt}).")

    x = torch.from_numpy(test_feat_chw).float().unsqueeze(0)        # (1, C, Ht, Wt)
    # reshape (not view): the numpy input may be a non-contiguous slice.
    k = torch.from_numpy(templ_feat_chw).float().reshape(1, -1, 1)  # (1, C*Hk*Wk, 1)

    # unfold flattens each window channel-major, matching the template's C, Hk, Wk order.
    patches = F.unfold(x, kernel_size=(Hk, Wk), stride=stride)      # (1, C*Hk*Wk, L)
    dots = (patches * k).sum(dim=1)                                 # (1, L)
    norms = torch.norm(k, dim=1) * torch.norm(patches, dim=1)       # (1, L)
    cos = dots / norms.clamp_min(1e-8)                              # zero vectors -> 0

    Hout = (Ht - Hk) // stride + 1
    Wout = (Wt - Wk) // stride + 1
    # Explicit reshape instead of squeeze(): keeps (1, W) / (H, 1) / (1, 1) results 2-D.
    return cos.reshape(Hout, Wout).cpu().numpy()

def visualize_similarity_heatmap(test_img: Image.Image, sim_map: np.ndarray, alpha: float = 0.65, title: str = "Cosine Similarity — DINOv3"):
    """Show ``test_img`` next to a copy of it overlaid with ``sim_map``.

    The map is min-max scaled to 0..255 (all zeros when constant) and resized with
    bicubic resampling to the image size. It is then drawn with the viridis colormap
    at opacity ``alpha``. Because of the per-map scaling, colors show relative rather
    than absolute similarity.

    Raises:
        ValueError: if ``sim_map`` is not 2-D.
    """
    if sim_map.ndim != 2:
        raise ValueError(f"sim_map must be (H,W), got {sim_map.shape}")
    width, height = test_img.size
    scores = sim_map.astype(np.float32)
    lo = float(scores.min())
    hi = float(scores.max())
    if hi != lo:
        scaled = ((scores - lo) / (hi - lo) * 255.0).clip(0, 255).astype(np.uint8)
    else:
        scaled = np.zeros_like(scores, dtype=np.uint8)
    overlay = Image.fromarray(scaled).resize((width, height), resample=Image.BICUBIC)

    plt.figure(figsize=(10, 5))
    # Left panel: plain image. Right panel: same image with the heatmap on top.
    panels = (("Test Image", False), (title, True))
    for idx, (panel_title, with_heat) in enumerate(panels, start=1):
        plt.subplot(1, 2, idx)
        plt.imshow(test_img)
        if with_heat:
            plt.imshow(overlay, cmap="viridis", alpha=alpha)
        plt.title(panel_title)
        plt.axis("off")
    plt.show()

# -----------------------------
# Funzione principale richiesta
# -----------------------------
def _clamp_bbox_xyxy(bbox, width: int, height: int) -> tuple[int, int, int, int]:
    """Order and clip an (x0, y0, x1, y1) box to a ``width`` x ``height`` image.

    x1/y1 are exclusive (PIL ``crop`` convention), so they may equal width/height.
    Raises ValueError when the clipped box has zero area.
    """
    x0, y0, x1, y1 = map(int, bbox)
    x0, x1 = sorted((x0, x1))  # accept boxes dragged right-to-left / bottom-to-top
    y0, y1 = sorted((y0, y1))
    x0 = max(0, min(x0, width)); x1 = max(0, min(x1, width))
    y0 = max(0, min(y0, height)); y1 = max(0, min(y1, height))
    if x1 <= x0 or y1 <= y0:
        raise ValueError(f"BBox non valida: {bbox}")
    return x0, y0, x1, y1


def _center_aligned_sim_map(sim_map: np.ndarray, grid_hw, templ_hw, stride: int) -> np.ndarray:
    """Spread ``sim_map`` over the full (gh, gw) feature grid, placing each score at the
    feature cell under its window's center.

    cosine_template_match indexes scores by window top-left corner and returns a map
    smaller than the grid. Stretching it over the whole image would shift and shrink
    the heatmap. Cells that no window is centered on get the minimum score.
    """
    gh, gw = grid_hw
    kh, kw = templ_hw
    aligned = np.full((gh, gw), float(sim_map.min()), dtype=np.float32)
    # Each score fills a stride x stride block starting at its window's center cell.
    block = np.repeat(np.repeat(sim_map.astype(np.float32), stride, axis=0), stride, axis=1)
    r0, c0 = kh // 2, kw // 2
    bh = min(block.shape[0], gh - r0)
    bw = min(block.shape[1], gw - c0)
    aligned[r0:r0 + bh, c0:c0 + bw] = block[:bh, :bw]
    return aligned


def cosine_similarity_from_bbox(
    query_image,                      # PIL or np.ndarray (RGB)
    bbox: tuple[int, int, int, int],  # (x0, y0, x1, y1) in query_image pixels
    test_image,                       # PIL or np.ndarray (RGB)
    *,
    model: str = "dinov3",
    model_variant: str = "vit_base_patch16_224.dino",
    template_target_hw: tuple[int, int] | None = (9, 9),
    stride: int = 1,
    visualize: bool = True,
):
    """Locate the ``bbox`` crop of ``query_image`` inside ``test_image`` by dense cosine
    template matching on timm ViT features (via ``load_dinov3_minimal``).

    Args:
        query_image: source image (PIL or RGB ndarray) the template is cropped from.
        bbox: (x0, y0, x1, y1) in query pixels. x1/y1 are exclusive. The box is ordered
            and clipped to the image.
        test_image: image to search (PIL or RGB ndarray).
        model: only 'dinov3' is supported.
        model_variant: timm model name passed to ``load_dinov3_minimal``.
        template_target_hw: if not None, the template feature map is resized to this
            (H, W). Each side is at least 3 and at most the test feature grid.
        stride: sliding-window step in feature cells (>= 1).
        visualize: if True, plot the center-aligned heatmap, then the best match box
            and peak on the test image.

    Returns:
        dict with:
          sim_map: (H', W') cosine scores indexed by window top-left feature cell.
          query_feat_shape / test_feat_shape: feature-map shapes used for matching.
          peak_feature_xy: (row, col) of the best score in ``sim_map`` (note: y first).
          peak_image_xy: (row, col) of the best window's center in test-image pixels
            (y first, as above).
          peak_cos: best cosine score.
          match_bbox: (x0, y0, x1, y1) of the best window in test-image pixels.
          bbox: the clipped query bbox actually used.

    Raises:
        ValueError: invalid bbox, unsupported model, or stride < 1.
    """
    # 1) Images and template crop
    q_pil = ensure_pil(query_image)
    t_pil = ensure_pil(test_image)
    x0, y0, x1, y1 = _clamp_bbox_xyxy(bbox, *q_pil.size)
    crop_pil = q_pil.crop((x0, y0, x1, y1))

    # 2) Backbone (this minimal version supports only dinov3 via timm)
    if model.lower() != "dinov3":
        raise ValueError("Questa versione minimale implementa solo 'dinov3' (timm).")
    stride = int(stride)
    if stride < 1:
        raise ValueError(f"stride must be >= 1, got {stride}")
    forward_fn, transform_fn, postproc_fn = load_dinov3_minimal(model_variant=model_variant)

    # 3) Feature extraction
    templ_feat = extract_spatial_features_from_image(crop_pil, forward_fn, transform_fn, postproc_fn)  # (C, hk, wk)
    test_feat = extract_spatial_features_from_image(t_pil, forward_fn, transform_fn, postproc_fn)      # (C, ht, wt)
    _, gh, gw = test_feat.shape

    # Optional: resize the template in feature space. Clamp it to the test grid so at
    # least one window position exists.
    if template_target_hw is not None:
        Htgt, Wtgt = template_target_hw
        Htgt = min(max(3, int(Htgt)), gh)
        Wtgt = min(max(3, int(Wtgt)), gw)
        templ_feat = resize_features_chw(templ_feat, (Htgt, Wtgt))
    _, kh, kw = templ_feat.shape

    # 4) Dense cosine matching. Reshape guards against 1-sized axes being squeezed away.
    h_out = (gh - kh) // stride + 1
    w_out = (gw - kw) // stride + 1
    sim_map = np.asarray(cosine_template_match(test_feat, templ_feat, stride=stride)).reshape(h_out, w_out)

    # 5) Peak, mapped back to test-image pixels. The transform resizes the full image,
    #    so the feature grid tiles the whole test image.
    y_f, x_f = np.unravel_index(sim_map.argmax(), sim_map.shape)
    peak_cos = float(sim_map[y_f, x_f])
    W_img, H_img = t_pil.size
    sx, sy = W_img / gw, H_img / gh
    top, left = int(y_f) * stride, int(x_f) * stride
    match_bbox = (
        int(round(left * sx)),
        int(round(top * sy)),
        int(round((left + kw) * sx)),
        int(round((top + kh) * sy)),
    )
    x_img = min(W_img - 1, int(round((left + kw / 2) * sx)))
    y_img = min(H_img - 1, int(round((top + kh / 2) * sy)))

    if visualize:
        aligned = _center_aligned_sim_map(sim_map, (gh, gw), (kh, kw), stride)
        visualize_similarity_heatmap(t_pil, aligned, alpha=0.65,
                                     title=f"Cosine Similarity — DINOv3 ({model_variant})")
        plt.figure(figsize=(5, 5))
        plt.imshow(t_pil)
        plt.axis("off")
        mx0, my0, mx1, my1 = match_bbox
        plt.gca().add_patch(plt.Rectangle((mx0, my0), mx1 - mx0, my1 - my0,
                                          fill=False, edgecolor="red", linewidth=1.5))
        plt.scatter([x_img], [y_img], s=80, c="red", marker="+")
        plt.title(f"Peak cos={peak_cos:.3f}")
        plt.show()

    return {
        "sim_map": sim_map,
        "query_feat_shape": tuple(templ_feat.shape),
        "test_feat_shape": tuple(test_feat.shape),
        "peak_feature_xy": (int(y_f), int(x_f)),
        "peak_image_xy": (int(y_img), int(x_img)),
        "peak_cos": peak_cos,
        "match_bbox": match_bbox,
        "bbox": (x0, y0, x1, y1),
    }


In [66]:
# # old 

# import numpy as np
# import torch
# import torch.nn.functional as F
# import matplotlib.pyplot as plt


# def cosine_similarity_from_bbox(
#     query_image,                      # PIL.Image o np.ndarray RGB (H,W,3)
#     bbox: tuple[int, int, int, int],  # (x0, y0, x1, y1) in pixel dell'immagine query
#     test_image,                       # PIL.Image o np.ndarray RGB
#     *,
#     model: str = "dinov2",            # 'clip' | 'resnet' | 'dinov2' | 'dinov3' | 'sam'
#     model_variant: str | None = None, # es. 'dinov2_vits14', 'openai/clip-vit-base-patch32', 'vit_base_patch16_224.dino'
#     sam_checkpoint: str | None = None,# richiesto se model == 'sam'
#     dinov3_hf_vit_repo: str | None = None,  # fallback HF per dinov3 se timm non disponibile
#     template_target_hw: tuple[int, int] | None = (9, 9),  # ridimensiona il template in feature-space (stabilizza/velocizza)
#     stride: int = 1,                  # stride in feature-space per il matching
#     visualize: bool = True            # se True, disegna heatmap sulla test image
# ):
#     """
#     Estrae il crop dalla query_image via bbox, calcola feature (C,H,W) per crop e test_image
#     con la backbone scelta, poi esegue dense cosine template matching nello spazio feature.

#     Ritorna:
#         {
#             'sim_map': np.ndarray (H', W') in [-1,1],           # mappa di similarità in feature-space
#             'label': str,                                       # etichetta backbone
#             'query_feat_shape': tuple,                          # (C, hq, wq)
#             'test_feat_shape': tuple,                           # (C, ht, wt)
#             'peak_feature_xy': (y_f, x_f),                      # posizione picco in sim_map (feature-space)
#             'peak_image_xy': (y_img, x_img),                    # posizione picco riportata in pixel immagine
#             'bbox': (x0, y0, x1, y1)                            # bbox usata
#         }
#     """
#     # --- 1) Sanity & conversioni ---
#     q_pil = ensure_pil(query_image)  if not isinstance(query_image, str) else ensure_pil(load_rgb_image(query_image))
#     t_pil = ensure_pil(test_image) if not isinstance(test_image, str) else ensure_pil(load_rgb_image(test_image))

#     x0, y0, x1, y1 = map(int, bbox)
#     Wq, Hq = q_pil.size
#     print("Shape query image: ", Wq, Hq)
#     x0 = max(0, min(x0, Wq - 1))
#     x1 = max(0, min(x1, Wq - 1))
#     y0 = max(0, min(y0, Hq - 1))
#     y1 = max(0, min(y1, Hq - 1))
#     if x1 <= x0 or y1 <= y0:
#         print(x0,x1, " ", y0,y1)
#         raise ValueError(f"BBox non valida: {bbox}")

#     crop_pil = q_pil.crop((x0, y0, x1, y1))

#     # --- 2) Carica backbone & trasformazioni ---
#     forward_fn, transform_fn, postproc_fn, label = get_backbone_and_transform(
#         model=model,
#         model_variant=model_variant,
#         sam_checkpoint=sam_checkpoint,
#         dinov3_hf_vit_repo=dinov3_hf_vit_repo,
#     )

#     # --- 3) Estrai feature (C,H,W) ---
#     templ_feat = extract_spatial_features_from_image(crop_pil, forward_fn, transform_fn, postproc_fn)  # (C,hk,wk)
#     test_feat  = extract_spatial_features_from_image(t_pil,  forward_fn, transform_fn, postproc_fn)   # (C,ht,wt)

#     # (opzionale) ridimensiona il template in feature-space per robustezza/velocità
#     if template_target_hw is not None:
#         Htgt, Wtgt = template_target_hw
#         Htgt = max(3, int(Htgt))
#         Wtgt = max(3, int(Wtgt))
#         templ_feat = resize_features_chw(templ_feat, (Htgt, Wtgt))

#     # --- 4) Cosine template matching denso nell spazio feature ---
#     sim_map = cosine_template_match(test_feat, templ_feat, stride=stride)  # (H', W')

#     # --- 5) Prendi il picco e rimappalo in coordinate immagine per comodità ---
#     y_f, x_f = np.unravel_index(sim_map.argmax(), sim_map.shape)  # feature-space (output dell'unfold)
#     # upsample grezza della sim_map alla risoluzione immagine (così otteniamo un (x,y) in pixel)
#     sm = sim_map.astype(np.float32)
#     sm_min, sm_max = float(sm.min()), float(sm.max())
#     if sm_max == sm_min:
#         sm_norm = np.zeros_like(sm, dtype=np.uint8)
#     else:
#         sm_norm = ((sm - sm_min) / (sm_max - sm_min) * 255.0).clip(0, 255).astype(np.uint8)
#     heat = Image.fromarray(sm_norm).resize(t_pil.size, resample=Image.BICUBIC)  # (Wimg, Himg)
#     heat_np = np.array(heat)
#     y_img, x_img = np.unravel_index(heat_np.argmax(), heat_np.shape)  # (Himg, Wimg)

#     # --- 6) Visualizza opzionalmente ---
#     if visualize:
#         visualize_similarity_heatmap(t_pil, sim_map, alpha=0.65,
#                                      title=f"Cosine Similarity — {label}")

#         # marca il picco sulla test image
#         plt.figure(figsize=(6, 6))
#         plt.imshow(t_pil)
#         plt.scatter([x_img], [y_img], s=80, c='red', marker='+')
#         plt.title(f"Peak @ ({x_img}, {y_img}) | cos={sim_map.max():.3f}")
#         plt.axis("off")
#         plt.show()

#     return {
#         "sim_map": sim_map,
#         "label": label,
#         "query_feat_shape": tuple(templ_feat.shape),
#         "test_feat_shape": tuple(test_feat.shape),
#         "peak_feature_xy": (int(y_f), int(x_f)),
#         "peak_image_xy": (int(y_img), int(x_img)),
#         "bbox": (x0, y0, x1, y1),
#     }


# def features_cosine_map_from_patch(features_chw: np.ndarray, py: int, px: int) -> np.ndarray:
#     """
#     Calcola la mappa di cosine similarity tra il patch (py, px) e tutti
#     gli altri patch nello stesso feature map (C,H,W).
#     Ritorna una mappa (H, W) con valori in [-1, 1].
#     """
#     if features_chw.ndim != 3:
#         raise ValueError(f"Expected (C,H,W), got {features_chw.shape}")
#     C, H, W = features_chw.shape
#     if not (0 <= py < H and 0 <= px < W):
#         raise ValueError(f"Patch index out of range: py={py}, px={px}, H={H}, W={W}")

#     f = torch.from_numpy(features_chw).float()      # (C,H,W)
#     f = F.normalize(f, dim=0)                       # L2-normalize lungo C per ogni loc

#     q = f[:, py, px]                                # (C,)
#     q = F.normalize(q, dim=0)                       # (C,)

#     # cosine = dot(q, f[:,y,x]) per tutti (y,x) -> equivalente a (q^T @ f) lungo C
#     sim = (q.view(C, 1, 1) * f).sum(dim=0)          # (H, W)
#     return sim.cpu().numpy()

# def image_to_feature_coords(img_hw: tuple[int,int], feat_hw: tuple[int,int], y_img: int, x_img: int) -> tuple[int,int]:
#     """
#     Converte coordinate immagine (pixel) in coordinate feature (patch).
#     Usa scalatura nearest neighbor.
#     """
#     H_img, W_img = img_hw
#     Hf, Wf = feat_hw
#     py = int(round((y_img / max(1, H_img - 1)) * (Hf - 1)))
#     px = int(round((x_img / max(1, W_img - 1)) * (Wf - 1)))
#     py = max(0, min(Hf - 1, py))
#     px = max(0, min(Wf - 1, px))
#     return py, px

# def visualize_patch_similarity(test_img: Image.Image,
#                                sim_map: np.ndarray,
#                                y_img: int,
#                                x_img: int,
#                                alpha: float = 0.65,
#                                title: str = "DINOv3: Cosine map from red-cross patch"):
#     """
#     Mostra la mappa di similarità (ridimensionata all'immagine) e la croce rossa
#     nel punto sorgente (y_img, x_img).
#     """
#     if sim_map.ndim != 2:
#         raise ValueError(f"sim_map must be (H,W), got {sim_map.shape}")

#     H_img, W_img = test_img.size[1], test_img.size[0]
#     sm = sim_map.astype(np.float32)
#     sm_min, sm_max = float(sm.min()), float(sm.max())
#     if sm_max == sm_min:
#         sm_n = np.zeros_like(sm, dtype=np.uint8)
#     else:
#         sm_n = ((sm - sm_min) / (sm_max - sm_min) * 255.0).clip(0, 255).astype(np.uint8)
#     heat = Image.fromarray(sm_n).resize((W_img, H_img), resample=Image.BICUBIC)

#     plt.figure(figsize=(12, 6))
#     plt.subplot(1, 2, 1)
#     plt.imshow(test_img)
#     plt.scatter([x_img], [y_img], s=60, c='red', marker='+')  # croce rossa
#     plt.title("Red-cross patch (reference)")
#     plt.axis("off")

#     plt.subplot(1, 2, 2)
#     plt.imshow(test_img)
#     plt.imshow(heat, cmap="viridis", alpha=alpha)
#     plt.scatter([x_img], [y_img], s=60, c='red', marker='+')
#     plt.title(title)
#     plt.axis("off")
#     plt.show()


In [None]:
# 1) Scegli DINOv3 (o altra backbone che dia (C,H,W))
# model_name = "dinov3"
# model_variant = "vit_base_patch16_224.dino"   # es. timm
# forward_fn, transform_fn, postproc_fn, label = get_backbone_and_transform(
#     model=model_name,
#     model_variant=model_variant,
#     dinov3_hf_vit_repo=None,     # opzionale fallback HF
# )

# # 2) Assicurati che test_image sia PIL
# test_pil = ensure_pil(load_rgb_image(test_image))

# # 3) Estrai feature spaziali della test image
# test_feat = extract_spatial_features_from_image(test_pil, forward_fn, transform_fn, postproc_fn)  # (C,Hf,Wf)
# C, Hf, Wf = test_feat.shape
# print(f"{label} features: {test_feat.shape}")

# # 4) Scegli la posizione della croce rossa in coordinate IMM (pixel)
# #    (Puoi prenderle da un click interattivo o fissarle manualmente)

# query_image_pre_bbox = load_rgb_image(query_image_pre_bbox)

# x1,y1, x2, y2 = current_box
# y_img, x_img = (y1+y2)//2, (x1+x2)//2
# # y_img, x_img = 200, 300   # <-- rimpiazza con le tue coordinate immagine


# # 5) Converti coordinate immagine -> coordinate feature (patch indices)
# py, px = image_to_feature_coords((query_image_pre_bbox.size[1], query_image_pre_bbox.size[0]), (Hf, Wf), y_img, x_img)
# print(f"Red-cross patch @ feature coords: (py={py}, px={px}) / (Hf={Hf}, Wf={Wf})")

# # 6) Calcola la cosine similarity tra il patch selezionato e tutti gli altri patch
# sim_map = features_cosine_map_from_patch(test_feat, py=py, px=px)     # (Hf, Wf)
# print(f"sim_map: shape={sim_map.shape}, min={sim_map.min():.3f}, max={sim_map.max():.3f}")

# # 7) Visualizza: croce rossa + heatmap
# visualize_patch_similarity(test_pil, sim_map, y_img=y_img, x_img=x_img,
#                            alpha=0.65,
#                            title=f"Cosine Similarity from red-cross — {label}")


# Load the un-cropped source image and the target image as RGB PIL images.
# `query_image_pre_bbox` is the file path produced by the download cell, and
# ensure_pil() rejects strings, so paths are read with load_rgb_image() first.
# (Previously this cell passed the raw path and only worked if a commented-out
# cell had already replaced it with a PIL image.)
query_pil = (
    ensure_pil(load_rgb_image(query_image_pre_bbox))
    if isinstance(query_image_pre_bbox, str)
    else ensure_pil(query_image_pre_bbox)
)
test_image = ensure_pil(test_image) if not isinstance(test_image, str) else ensure_pil(load_rgb_image(test_image))
test_pil = test_image

res = cosine_similarity_from_bbox(
    query_image=query_pil,   # source image (not cropped)
    bbox=current_box,        # (x0, y0, x1, y1) in query pixels
    test_image=test_pil,     # target image
    model="dinov3",
    model_variant="vit_base_patch16_224.dino",
    template_target_hw=(9, 9),
    stride=1,
    visualize=True,
)

In [None]:
# Display the selected box: (x0, y0, x1, y1) in query-image pixels (presumably set by
# the interactive bbox-selection cell — confirm).
current_box

In [None]:
# Multi-scale template matching: vary the template size in feature space around the
# 9x9 default of cosine_similarity_from_bbox and keep the scale with the highest peak.
# Every input is computed here. The previous version relied on `templ_feat`/`test_feat`
# (locals of cosine_similarity_from_bbox), `query_feat` (never defined), and
# `test_pil`/`label` (from a commented-out cell), so it failed on a fresh kernel.
MS_MODEL_VARIANT = "vit_base_patch16_224.dino"
MS_BASE_TEMPLATE_HW = (9, 9)
scales = [0.75, 1.0, 1.25]

ms_forward, ms_transform, ms_postproc = load_dinov3_minimal(model_variant=MS_MODEL_VARIANT)
label = f"DINOv3 ({MS_MODEL_VARIANT})"

ms_query_pil = (
    ensure_pil(load_rgb_image(query_image_pre_bbox))
    if isinstance(query_image_pre_bbox, str)
    else ensure_pil(query_image_pre_bbox)
)
test_pil = ensure_pil(load_rgb_image(test_image)) if isinstance(test_image, str) else ensure_pil(test_image)

# Order the box corners and keep them inside the query image (x1/y1 exclusive).
bx0, by0, bx1, by1 = map(int, current_box)
qw, qh = ms_query_pil.size
crop_box = (
    max(0, min(bx0, bx1)), max(0, min(by0, by1)),
    min(qw, max(bx0, bx1)), min(qh, max(by0, by1)),
)
query_feat = extract_spatial_features_from_image(ms_query_pil.crop(crop_box), ms_forward, ms_transform, ms_postproc)
test_feat = extract_spatial_features_from_image(test_pil, ms_forward, ms_transform, ms_postproc)
_, grid_h, grid_w = test_feat.shape

best = None  # (max cosine, sim_map, scale)
for s in scales:
    # Clamp to the test grid so every scale leaves at least one window position.
    Hs = min(grid_h, max(3, int(round(MS_BASE_TEMPLATE_HW[0] * s))))
    Ws = min(grid_w, max(3, int(round(MS_BASE_TEMPLATE_HW[1] * s))))
    tf = resize_features_chw(query_feat, (Hs, Ws))
    # Reshape keeps the map 2-D even if cosine_template_match squeezes 1-sized axes.
    sm = np.asarray(cosine_template_match(test_feat, tf, stride=1)).reshape(grid_h - Hs + 1, grid_w - Ws + 1)
    m = float(sm.max())
    if (best is None) or (m > best[0]):
        best = (m, sm, s)

print(f"Best scale={best[2]} | best max cos={best[0]:.3f}")
visualize_similarity_heatmap(test_pil, best[1], alpha=0.65,
                             title=f"Cosine Similarity (Best scale={best[2]}) — {label}")


In [None]:
# Center pixel of the selected box. Here current_box's corners are unpacked as
# (x1, y1, x2, y2), i.e. its (x0, y0, x1, y1) corners; integer division floors the center.
x1,y1, x2, y2 = current_box
y_img, x_img = (y1+y2)//2, (x1+x2)//2  # (row, col) in query-image pixels

print(current_box)
print(y_img, x_img)