In [1]:
import copy
import json
import logging
import os
import torch
import numpy as np
import cv2
import torch.nn.functional as F

from PIL import Image, ImageDraw
from tqdm import tqdm
from transformers import CLIPModel, CLIPProcessor
from transformers.models.qwen2_vl.image_processing_qwen2_vl_fast import smart_resize
from sklearn.metrics.pairwise import cosine_similarity


  from .autonotebook import tqdm as notebook_tqdm


In [None]:
# ================= CONFIG =================
# Identifier handed to the grounding model wrapper's `load_model`.
MODEL_NAME_OR_PATH = "Qwen/Qwen3-VL-8B-Instruct"

# Base CLIP weights plus the fine-tuned checkpoint that is loaded on top.
BASE_CLIP = "/home/phamlong/Downloads/clip finetune/clip-vit-base-patch32"
CKPT_PATH = "/home/phamlong/Downloads/clip finetune/clip_finetuned/clip_finetuned_epoch5.pth"
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"

NUM_GRID = 36         # the image is cut into NUM_GRID x NUM_GRID patches
TOP_K_REGIONS = 3     # number of best connected regions kept for the mask
REGION_THRESH = 0.25  # threshold for both the grid scores and the blurred mask
CROP_PAD = 8          # padding (pixels) added around the ROI bounding box
FEATHER_SIGMA = 8     # Gaussian sigma used to soften the upsampled mask

SCREENSPOT_IMGS = "/home/phamlong/Downloads/v2p/ScreenSpot-Pro/images"
SCREENSPOT_TEST = "/home/phamlong/Downloads/v2p/ScreenSpot-Pro/annotations"

LOG_PATH = "/home/phamlong/Downloads/zoom/ZoomClick/grounding/zoomclick_qwen3vl_8b.json"
VIS_DIR = "/home/phamlong/Downloads/zoom/ZoomClick/grounding/vis"
os.makedirs(VIS_DIR, exist_ok=True)

logging.basicConfig(level=logging.INFO)
torch.manual_seed(114514)

# ================= LOAD CLIP =================
print("Loading finetuned CLIP...")
clip_model = CLIPModel.from_pretrained(BASE_CLIP)
clip_processor = CLIPProcessor.from_pretrained(BASE_CLIP)

if os.path.exists(CKPT_PATH):
    # NOTE(security): torch.load unpickles the file; only load checkpoints
    # from a trusted source.
    state = torch.load(CKPT_PATH, map_location="cpu")
    if "model" in state:
        # The checkpoint wraps the weights under a "model" key.
        state = state["model"]
    load_result = clip_model.load_state_dict(state, strict=False)
    # strict=False hides key mismatches; report them so a checkpoint that
    # does not match the model is noticed instead of silently ignored.
    if load_result.missing_keys or load_result.unexpected_keys:
        logging.warning(
            "CLIP checkpoint %s loaded with %d missing and %d unexpected keys",
            CKPT_PATH,
            len(load_result.missing_keys),
            len(load_result.unexpected_keys),
        )
else:
    # Previously this fell through silently and evaluated the base model.
    logging.warning(
        "CLIP checkpoint not found at %s; using base weights from %s",
        CKPT_PATH,
        BASE_CLIP,
    )

clip_model = clip_model.to(DEVICE).eval()
# Inference-only notebook: gradients are disabled globally from here on.
torch.set_grad_enabled(False)

# ================= UTILS =================
def _norm_to_pixel_point(p, img_size):
    x, y = p
    w, h = img_size
    return int(round(x * w)), int(round(y * h))


def remap_point(local_pt, roi_bbox):
    """Translate a point from crop coordinates back to full-image coordinates.

    Args:
        local_pt: Pair ``(x, y)`` relative to the top-left corner of the crop.
        roi_bbox: ``(x1, y1, x2, y2)`` of the crop inside the full image, or
            ``None`` when no crop was applied.

    Returns:
        Tuple ``(x, y)`` in full-image coordinates. When ``roi_bbox`` is
        ``None`` the point is returned unchanged, because an uncropped image
        already shares the full image's origin.
    """
    x, y = local_pt
    if roi_bbox is None:
        # No crop: previously this raised TypeError while unpacking None.
        return x, y
    x1, y1, _, _ = roi_bbox
    return x + x1, y + y1

Loading finetuned CLIP...


Using a slow image processor as `use_fast` is unset and a slow processor was saved with this model. `use_fast=True` will be the default behavior in v4.52, even if the model was saved with a slow processor. This will result in minor differences in outputs. You'll still be able to use a slow processor with `use_fast=False`.


In [3]:
def hybrid_vector_field_grid(
    image,
    instruction,
    batch_size=8,
    gate_slope=3.0,
    gate_center=0.5,
    field_sigma=0.4,
    num_neighbors=8,
):
    """Score every cell of a NUM_GRID x NUM_GRID grid against an instruction.

    The image is cut into equally sized patches, each patch and the
    instruction are embedded with CLIP, and a per-patch score is derived from
    the text similarity combined with a vector field built in embedding space.

    Args:
        image: PIL image; it is converted to RGB internally.
        instruction: Text embedded with the CLIP text encoder.
        batch_size: Number of patches embedded per CLIP forward pass.
        gate_slope: Slope of the sigmoid gate applied to the cosine similarity.
        gate_center: Cosine similarity at which the gate equals 0.5.
        field_sigma: Width of the Gaussian kernel weighting patch differences.
        num_neighbors: Number of nearest field vectors (the patch itself
            included) used for the divergence-like term.

    Returns:
        Tuple ``(scores, (H, W))``: ``scores`` is a ``(NUM_GRID, NUM_GRID)``
        array min-max normalized to [0, 1]; ``(H, W)`` is the image size in
        pixels.
    """
    img = image.convert("RGB")
    W, H = img.size
    # Integer patch size: the W % NUM_GRID / H % NUM_GRID pixels at the right
    # and bottom edges are not covered by any patch.
    pw, ph = W // NUM_GRID, H // NUM_GRID

    # Row-major patch list, so index = gy * NUM_GRID + gx.
    patches = [
        img.crop((gx * pw, gy * ph, gx * pw + pw, gy * ph + ph))
        for gy in range(NUM_GRID)
        for gx in range(NUM_GRID)
    ]

    # truncation=True keeps instructions longer than the CLIP text context
    # from failing in the text encoder; short prompts are unaffected.
    text_inputs = clip_processor(
        text=instruction,
        return_tensors="pt",
        padding=True,
        truncation=True,
    ).to(DEVICE)
    text_emb = clip_model.get_text_features(**text_inputs)
    text_emb = text_emb / text_emb.norm(dim=-1, keepdim=True)
    text_emb = text_emb.cpu().numpy()

    feats = []
    for i in range(0, len(patches), batch_size):
        inputs = clip_processor(
            images=patches[i:i + batch_size], return_tensors="pt"
        ).to(DEVICE)
        emb = clip_model.get_image_features(**inputs)
        emb = emb / emb.norm(dim=-1, keepdim=True)
        feats.append(emb.cpu().numpy())

    patch_embs = np.concatenate(feats, axis=0)
    cos_sim = cosine_similarity(patch_embs, text_emb).flatten()

    # Sigmoid gate: patches that match the text poorly get a small weight.
    gates = 1 / (1 + np.exp(-gate_slope * (cos_sim - gate_center)))
    gated = patch_embs * gates[:, None]

    N = gated.shape[0]
    F_field = np.zeros_like(gated)

    # For each patch: gate- and distance-weighted sum of the differences to
    # all other gated embeddings.
    for i in range(N):
        diff = gated - gated[i]
        dist = np.linalg.norm(diff, axis=1)
        w = np.exp(-dist**2 / field_sigma**2) * gates
        F_field[i] = np.sum(w[:, None] * diff, axis=0)

    F_field /= (np.linalg.norm(F_field, axis=1, keepdims=True) + 1e-8)

    # Divergence-like term: mean offset to the nearest field vectors,
    # projected onto the patch's own field direction.
    div = np.zeros(N)
    for i in range(N):
        idx = np.argsort(np.linalg.norm(F_field - F_field[i], axis=1))[:num_neighbors]
        grad = (F_field[idx] - F_field[i]).mean(axis=0)
        div[i] = np.dot(grad, F_field[i])

    # Only negative values of `div` contribute to the score.
    score = np.log1p(np.maximum(0, -div)) * cos_sim * gates
    score = (score - score.min()) / (score.max() - score.min() + 1e-8)

    return score.reshape(NUM_GRID, NUM_GRID), (H, W)


def spotlight_bbox(image, instruction):
    """Derive a region-of-interest box and a soft mask for an instruction.

    Args:
        image: PIL image forwarded to ``hybrid_vector_field_grid``.
        instruction: Text forwarded to ``hybrid_vector_field_grid``.

    Returns:
        ``(bbox, mask)`` where ``bbox`` is ``(x1, y1, x2, y2)`` in pixels and
        ``mask`` is the blurred full-resolution mask. Returns ``(None, None)``
        when no connected region qualifies and ``(None, mask)`` when the
        blurred mask never exceeds ``REGION_THRESH``.
    """
    score_grid, (img_h, img_w) = hybrid_vector_field_grid(image, instruction)

    # Label connected groups of grid cells whose score reaches the threshold.
    above_thresh = (score_grid >= REGION_THRESH).astype(np.uint8)
    n_labels, label_map = cv2.connectedComponents(above_thresh)

    # Rank components by (mean score * cell count); label 0 is background and
    # components smaller than three cells are dropped.
    candidates = []
    for label in range(1, n_labels):
        member = label_map == label
        n_cells = member.sum()
        if n_cells >= 3:
            candidates.append((score_grid[member].mean() * n_cells, label))

    candidates.sort(reverse=True)
    kept = candidates[:TOP_K_REGIONS]
    if not kept:
        return None, None

    # Binary grid mask covering the kept components.
    kept_labels = [label for _, label in kept]
    grid_mask = np.isin(label_map, kept_labels).astype(np.float32)
    grid_mask = np.clip(grid_mask, 0, 1)

    # Upsample the grid mask to image resolution.
    upsampled = F.interpolate(
        torch.tensor(grid_mask)[None, None],
        size=(img_h, img_w),
        mode="bicubic",
        align_corners=False,
    )
    upsampled = upsampled.squeeze().numpy()

    # Rescale to [0, 1], then feather the edges.
    lowest, highest = upsampled.min(), upsampled.max()
    upsampled = (upsampled - lowest) / (highest - lowest + 1e-8)
    feathered = cv2.GaussianBlur(upsampled, (0, 0), FEATHER_SIGMA)

    rows, cols = np.where(feathered > REGION_THRESH)
    if len(cols) == 0:
        return None, feathered

    # Padded bounding box of the surviving pixels, clamped to the image.
    left = max(0, cols.min() - CROP_PAD)
    top = max(0, rows.min() - CROP_PAD)
    right = min(img_w - 1, cols.max() + CROP_PAD)
    bottom = min(img_h - 1, rows.max() + CROP_PAD)

    return (left, top, right, bottom), feathered


def build_mask(mask_blur, case):
    """Turn the blurred mask into the mask used for a given evaluation case.

    Args:
        mask_blur: Soft mask array, or ``None`` when no mask is available.
        case: ``"no_mask"`` or ``"hard_mask"``.

    Returns:
        ``None`` for ``"no_mask"`` or when ``mask_blur`` is ``None``;
        otherwise a ``uint8`` array that is 1 where ``mask_blur`` exceeds
        ``REGION_THRESH`` and 0 elsewhere.

    Raises:
        ValueError: If a mask is given and ``case`` is not a known case.
    """
    if mask_blur is None or case == "no_mask":
        return None
    if case != "hard_mask":
        raise ValueError(case)
    return (mask_blur > REGION_THRESH).astype(np.uint8)

# ================= VIS =================
def visualize_full(image, mask, roi_bbox, pred_pt, gt_bbox, save_path):
    """Save a debug rendering of one sample.

    Pixels where ``mask`` is 0 are blacked out, the ROI is outlined in
    yellow, the ground-truth box in green, and the predicted point is drawn
    as a red dot.

    Args:
        image: PIL image to draw on (a copy is modified, not the input).
        mask: Optional array-like of shape ``(H, W)``; zero entries are
            blacked out. ``None`` leaves the image untouched.
        roi_bbox: Optional ``(x1, y1, x2, y2)`` ROI box.
        pred_pt: Optional ``(x, y)`` predicted point.
        gt_bbox: ``(x1, y1, x2, y2)`` ground-truth box.
        save_path: Output file path; its parent directory is created.
    """
    # os.makedirs("") raises, so only create a directory when there is one.
    out_dir = os.path.dirname(save_path)
    if out_dir:
        os.makedirs(out_dir, exist_ok=True)

    # Work directly on uint8 pixels: the former float32 /255 -> *255 round
    # trip with a truncating cast can shift values by one.
    canvas = np.array(image.convert("RGB"))

    if mask is not None:
        # np.asarray accepts NumPy arrays as well as CPU torch tensors.
        canvas[np.asarray(mask) == 0] = 0

    vis = Image.fromarray(canvas)
    draw = ImageDraw.Draw(vis)

    if roi_bbox:
        draw.rectangle(roi_bbox, outline="yellow", width=3)
    draw.rectangle(gt_bbox, outline="green", width=3)

    if pred_pt:
        x, y = pred_pt
        draw.ellipse((x - 5, y - 5, x + 5, y + 5), fill="red")

    vis.save(save_path)

In [4]:
class BaseBackend:
    """Interface for grounding backends used by the inference runner."""

    def infer_norm_point(self, instruction, pil_img):
        """Predict a point for ``instruction`` on ``pil_img``.

        Subclasses must override this method.

        Raises:
            NotImplementedError: Always, in this base class.
        """
        raise NotImplementedError()


class Qwen3Backend(BaseBackend):
    """Backend that wraps a model object exposing ``ground_only_positive``."""

    def __init__(self, model):
        """Keep the model and its processor.

        Args:
            model: Object with a ``processor`` attribute and a
                ``ground_only_positive(instruction=..., image=...)`` method.
        """
        self.model = model
        self.processor = model.processor

    def infer_norm_point(self, instruction, pil_img):
        """Run the model and normalize its predicted point.

        Args:
            instruction: Text instruction passed to the model.
            pil_img: PIL image passed to the model.

        Returns:
            Tuple ``(point, raw)``. ``point`` is ``(x, y)`` divided by the
            size computed by ``smart_resize``, or ``None`` when the model
            output is not a dict, has no ``"point"``, or the point is
            malformed. ``raw`` is the unmodified model output.
        """
        res = self.model.ground_only_positive(instruction=instruction, image=pil_img)
        if not isinstance(res, dict):
            return None, res
        pt = res.get("point")
        if pt is None:
            return None, res

        # A malformed point is reported as "no prediction" instead of
        # aborting a multi-hour evaluation run.
        try:
            pt_x, pt_y = float(pt[0]), float(pt[1])
        except (TypeError, ValueError, IndexError, KeyError):
            logging.warning("Ignoring malformed point from model: %r", pt)
            return None, res

        img_w, img_h = pil_img.size
        try:
            patch = self.processor.image_processor.patch_size
            merge = self.processor.image_processor.merge_size
            # NOTE(review): assumes the model reports the point in pixels of
            # the smart-resized image, and that these min/max pixel bounds
            # mirror the ones used by the model wrapper - TODO confirm.
            resized_h, resized_w = smart_resize(
                img_h, img_w,
                factor=patch * merge,
                min_pixels=patch * patch * merge * merge * 16,
                max_pixels=patch * patch * merge * merge * 6400,
            )
        except Exception as exc:
            # Best-effort fallback kept from the original, but no longer
            # silent: normalizing by the raw size may be inaccurate.
            logging.warning(
                "smart_resize failed (%s); normalizing by the raw image size", exc
            )
            resized_w, resized_h = img_w, img_h

        return (pt_x / resized_w, pt_y / resized_h), res


class DirectInferenceRunner:
    """Runs a backend on an image and reports the result in pixel space."""

    def __init__(self, backend):
        """Store the backend that provides ``infer_norm_point``."""
        self.backend = backend

    def ground_only_positive(self, instruction, image):
        """Ground ``instruction`` on ``image``.

        Returns:
            ``{"result": "positive", "point": [x, y]}`` with integer pixel
            coordinates relative to ``image``, or
            ``{"result": "negative", "point": None}`` when the backend
            yields no point.
        """
        norm_pt, _raw = self.backend.infer_norm_point(instruction, image)
        if norm_pt is not None:
            pixel_pt = _norm_to_pixel_point(norm_pt, image.size)
            return {"result": "positive", "point": [pixel_pt[0], pixel_pt[1]]}
        return {"result": "negative", "point": None}

# ================= METRIC =================
def eval_sample_positive_gt(sample, response):
    """Classify a prediction against the sample's ground-truth box.

    Args:
        sample: Dict with ``"bbox"`` as ``(x1, y1, x2, y2)``.
        response: Dict with ``"point"`` as ``(x, y)`` or ``None``.

    Returns:
        ``"wrong_format"`` when there is no point, ``"correct"`` when the
        point lies inside the box (edges included), otherwise ``"wrong"``.
    """
    if response["point"] is None:
        return "wrong_format"

    px, py = response["point"]
    left, top, right, bottom = sample["bbox"]
    inside = left <= px <= right and top <= py <= bottom
    if inside:
        return "correct"
    return "wrong"


def calc_metric(results):
    """Summarize per-sample results.

    Args:
        results: Sequence of dicts, each with a ``"correctness"`` entry.

    Returns:
        Dict with ``"num_total"``, ``"num_correct"`` and ``"acc"``; ``acc``
        is 0 for an empty input.
    """
    n_total = len(results)
    n_correct = len([r for r in results if r["correctness"] == "correct"])
    accuracy = n_correct / n_total if n_total else 0
    return {"num_total": n_total, "num_correct": n_correct, "acc": accuracy}


In [None]:
def build_backend():
    """Create the Qwen3-VL model, load its weights, and wrap it in a backend.

    Returns:
        A ``Qwen3Backend`` around a model loaded from ``MODEL_NAME_OR_PATH``.
    """
    # Imported here so the project module is only needed when a backend is built.
    from models.qwen3vl import Qwen3VLModel

    qwen_model = Qwen3VLModel()
    qwen_model.load_model(model_path=MODEL_NAME_OR_PATH)
    backend_instance = Qwen3Backend(qwen_model)
    return backend_instance

backend = build_backend()
runner = DirectInferenceRunner(backend)

# ================= LOAD DATA =================
# os.listdir returns entries in arbitrary order; sort so the sample order
# (and therefore the order of the saved results) is reproducible.
task_files = sorted(
    os.path.splitext(name)[0]
    for name in os.listdir(SCREENSPOT_TEST)
    if name.endswith(".json")
)
samples = []
for tf in task_files:
    with open(os.path.join(SCREENSPOT_TEST, tf + ".json"), encoding="utf-8") as f:
        for s in json.load(f):
            # Copy so the added key never aliases the parsed annotation.
            ss = copy.deepcopy(s)
            ss["prompt_to_evaluate"] = s.get("instruction", "")
            samples.append(ss)

def to_json_safe(obj):
    """Recursively convert NumPy values into JSON-serializable Python types.

    Args:
        obj: Arbitrary nesting of dicts, lists, tuples, NumPy arrays and
            scalars.

    Returns:
        The same structure with tuples and NumPy arrays turned into lists and
        NumPy bool/integer/floating scalars turned into ``bool``/``int``/
        ``float``. Any other value is returned unchanged.
    """
    if isinstance(obj, dict):
        return {k: to_json_safe(v) for k, v in obj.items()}
    if isinstance(obj, (list, tuple)):
        return [to_json_safe(v) for v in obj]
    if isinstance(obj, np.ndarray):
        # tolist() yields nested Python lists; recurse in case of object arrays.
        return to_json_safe(obj.tolist())
    if isinstance(obj, np.bool_):
        # np.bool_ is neither np.integer nor JSON-serializable.
        return bool(obj)
    if isinstance(obj, np.integer):
        return int(obj)
    if isinstance(obj, np.floating):
        return float(obj)
    return obj
CASES = ["no_mask", "hard_mask"]

# Create the log directory up front so a missing folder cannot make the
# final write fail after hours of inference.
log_dir = os.path.dirname(LOG_PATH)
if log_dir:
    os.makedirs(log_dir, exist_ok=True)

for case in CASES:
    print(f"\n=== Running case: {case} ===")
    case_results = []

    case_vis_dir = os.path.join(VIS_DIR, case)
    os.makedirs(case_vis_dir, exist_ok=True)

    for s in tqdm(samples):
        # 1) Load the original image; the context manager closes the file
        #    handle once the RGB copy has been made.
        with Image.open(os.path.join(SCREENSPOT_IMGS, s["img_filename"])) as raw_img:
            img = raw_img.convert("RGB")

        # 2) Region of interest and soft mask from the CLIP score grid.
        roi_bbox, mask_blur = spotlight_bbox(img, s["prompt_to_evaluate"])

        # 3) Build the case mask once; it is used for the model input and
        #    for the visualization. It is None for "no_mask" or when no
        #    mask could be computed.
        case_mask = build_mask(mask_blur, case)

        img_for_model = img
        if case_mask is not None:
            # Hard mask (0 / 1): zero out every pixel outside the mask.
            hard = case_mask
            if hard.ndim == 2:
                hard = hard[..., None]
            img_for_model = Image.fromarray(np.array(img) * hard)

        # 4) Crop AFTER masking.
        crop_img = (
            img_for_model
            if roi_bbox is None
            else img_for_model.crop(roi_bbox)
        )

        # 5) Inference.
        r_crop = runner.ground_only_positive(
            s["prompt_to_evaluate"],
            crop_img
        )

        # 6) Remap the point to the original image. Without an ROI nothing
        #    was cropped, so the point is already in full-image coordinates
        #    (remapping with roi_bbox=None used to raise TypeError).
        if r_crop["point"] is not None and roi_bbox is not None:
            r_crop["point"] = list(
                remap_point(r_crop["point"], roi_bbox)
            )

        correctness = eval_sample_positive_gt(s, r_crop)

        # 7) Visualization (does NOT affect the model).
        # NOTE(review): the file name depends only on the image name and the
        # correctness label, so samples sharing both overwrite each other.
        visualize_full(
            img,
            case_mask,
            roi_bbox,
            r_crop["point"],
            s["bbox"],
            os.path.join(
                case_vis_dir,
                f"{os.path.splitext(s['img_filename'])[0]}_{correctness}.jpg"
            )
        )

        case_results.append({
            **s,
            "pred": r_crop["point"],
            "correctness": correctness
        })

    # 8) Save JSON.
    case_results = to_json_safe(case_results)
    case_log = LOG_PATH.replace(".json", f"_{case}.json")

    with open(case_log, "w") as f:
        json.dump(
            {
                "case": case,
                "metrics": calc_metric(case_results),
                "details": case_results
            },
            f,
            indent=2
        )

print("DONE")

`torch_dtype` is deprecated! Use `dtype` instead!
INFO:accelerate.utils.modeling:We will use 90% of the memory on device 0 for storing the model, and 10% for the buffer to avoid OOM. You can set `max_memory` in to a higher value to use more memory (at your own risk).
Loading checkpoint shards: 100%|██████████| 4/4 [00:02<00:00,  1.43it/s]



=== Running case: no_mask ===


  0%|          | 0/1581 [00:00<?, ?it/s]The following generation flags are not valid and may be ignored: ['temperature', 'top_p', 'top_k']. Set `TRANSFORMERS_VERBOSITY=info` for more details.
100%|██████████| 1581/1581 [3:25:52<00:00,  7.81s/it]  



=== Running case: hard_mask ===


100%|██████████| 1581/1581 [3:27:16<00:00,  7.87s/it]  

DONE



