Inference/test notebook (run in Google Colab) that evaluates the fine-tuned Qwen2.5-VL adapter with the mAP@50 and mAP@50-95 detection metrics.

In [None]:
# Colab setup: install the runtime dependencies used by the cells below.
# NOTE(review): only transformers is pinned (presumably to match the environment
# the adapter was trained in — confirm); unpinned packages can drift between runs.
!pip install -U unsloth
!pip install transformers==4.57.1 trl accelerate peft bitsandbytes
!pip install sentencepiece einops timm qwen-vl-utils
!pip install pillow matplotlib
# Provides MetricBuilder for mAP computation.
!pip install mean-average-precision

In [None]:
%matplotlib inline
import json
import torch
import numpy as np
# import unsloth
from PIL import Image
import matplotlib.pyplot as plt
import matplotlib.patches as patches
from transformers import Qwen2_5_VLProcessor
from unsloth import FastVisionModel
from IPython.display import display
import traceback
from google.colab import drive
from mean_average_precision import MetricBuilder  # <-- ADDED

drive.mount("/content/drive", force_remount=True)
import os
os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "expandable_segments:True"

In [None]:
BASE_MODEL_ID = "Qwen/Qwen2.5-VL-7B-Instruct"
ADAPTER_PATH = "/content/drive/MyDrive/qwen-nrp-output"
TEST_JSONL = "/content/drive/MyDrive/test-colab.jsonl"

# Class names; these must match the class names written in the training/test
# annotations exactly, otherwise boxes are dropped by the mAP parsers.
# The list order only fixes each class's integer id. Ground truth and predictions
# are both mapped through CLASS_TO_IDX, so reordering does not change the mAP.
CLASSES = [
    "fabrics", "rigid-plastic", "non-recyclables", "large-plastic-films",
    "unopened-plastic-bags", "metal", "wrappables", "wood"
]
CLASS_TO_IDX = {name: i for i, name in enumerate(CLASSES)}
# A duplicated name would silently merge two classes into one id.
assert len(CLASS_TO_IDX) == len(CLASSES), "CLASSES contains duplicate names"

print("Loading model...")
# Base model in 4-bit; the fine-tuned LoRA adapter is attached on top of it.
model, _ = FastVisionModel.from_pretrained(
    BASE_MODEL_ID,
    load_in_4bit=True,
    device_map="auto",
)
model.load_adapter(ADAPTER_PATH)
# Unsloth's documented switch into inference mode for generation; model.eval()
# is kept too so training-only layers (e.g. dropout) are disabled regardless.
FastVisionModel.for_inference(model)
model.eval()

# Processor for the base checkpoint; used to build prompts and decode outputs.
processor = Qwen2_5_VLProcessor.from_pretrained(BASE_MODEL_ID)
print("Model loaded successfully.")

In [None]:
def load_jsonl(path):
    """Read a JSON Lines file and return its records as a list.

    Whitespace-only lines are skipped, so a blank separator line or an extra
    trailing newline does not raise ``json.JSONDecodeError``.

    Args:
        path: Path to the ``.jsonl`` file (one JSON document per line, UTF-8).

    Returns:
        list: The decoded records, in file order.
    """
    with open(path, "r", encoding="utf-8") as f:
        return [json.loads(line) for line in f if line.strip()]

# Read the evaluation split into memory; `samples` is indexed by the eval loop.
print("Loading test dataset...")
samples = load_jsonl(TEST_JSONL)
print("Loaded {} test samples.".format(len(samples)))

In [None]:
def run_inference(sample, max_new_tokens=1024):
    """Run the fine-tuned model on the image referenced by one test sample.

    Args:
        sample: A record from the test JSONL. Its ``"messages"`` list must hold a
            ``"user"`` message whose ``"content"`` items include one with
            ``"type": "image"`` and an ``"image"`` path (if several are present,
            the last one is used).
        max_new_tokens: Generation budget for the detection text (default 1024).

    Returns:
        tuple: ``(image, result)`` — the RGB ``PIL.Image`` fed to the model and the
        decoded generated text with the prompt tokens stripped.

    Raises:
        StopIteration: If the sample has no ``"user"`` message.
        ValueError: If the user message contains no image entry.
    """
    user_msg = next(m for m in sample["messages"] if m["role"] == "user")
    image_path = None
    for item in user_msg["content"]:
        if item["type"] == "image":
            image_path = item["image"]
    if image_path is None:
        raise ValueError("No image path found in sample")
    # The context manager closes the underlying file; convert() returns an
    # independent in-memory copy, so `image` stays valid afterwards.
    with Image.open(image_path) as img:
        image = img.convert("RGB")

    # NOTE(review): presumably this prompt mirrors the training prompt verbatim;
    # keep the wording and the "class_name xmin xmax ymin ymax" order in sync
    # with parse_boxes.
    messages = [
        {
            "role": "system",
            "content": (
                "You are an assistant that detects jam causing objects in images. "
                "The possible jam causing objects are: fabrics, rigid-plastic, "
                "non-recyclables, large-plastic-films, unopened-plastic-bags, "
                "metal, wrappables, wood."
            )
        },
        {
            "role": "user",
            "content": [
                {"type": "image"},
                {
                    "type": "text",
                    "text": (
                        "Detect the waste objects in this image and output their "
                        "bounding boxes in the format: class_name xmin xmax ymin ymax"
                    )
                }
            ]
        }
    ]
    text = processor.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
    inputs = processor(text=[text], images=[image], return_tensors="pt", padding=True).to(model.device)

    with torch.no_grad():
        outputs = model.generate(**inputs, max_new_tokens=max_new_tokens, do_sample=False)
    # Drop the prompt tokens so only the newly generated text is decoded.
    generated_ids = outputs[0][inputs["input_ids"].shape[1]:]
    result = processor.decode(generated_ids, skip_special_tokens=True)
    return image, result

In [None]:
def _parse_coord(token):
    """Return ``token`` as an int when possible, otherwise as a finite float.

    Raises:
        ValueError: If ``token`` is not numeric, or is NaN/inf.
    """
    import math

    try:
        return int(token)
    except ValueError:
        value = float(token)  # raises ValueError for non-numeric tokens
        if not math.isfinite(value):
            raise ValueError(f"non-finite coordinate: {token!r}")
        return value


def parse_boxes(text):
    """Parse raw model output or GT text into a list of boxes.

    Each line is expected to read ``class_name xmin xmax ymin ymax`` (the order the
    prompt asks for). Lines are skipped silently when they:

    * do not have exactly five whitespace-separated fields,
    * contain a coordinate that is not a finite number, or
    * describe an inverted box (``xmin > xmax`` or ``ymin > ymax``).

    Integer coordinates are returned as ``int``. Decimal ones (e.g. ``"12.5"``) are
    kept as ``float`` instead of discarding the whole line.

    Args:
        text: Raw text; ``None`` or an empty string yields ``[]``.

    Returns:
        list[tuple]: ``(cls, xmin, xmax, ymin, ymax)`` tuples in input order.
    """
    if not text:
        return []
    boxes = []
    for line in text.strip().splitlines():
        parts = line.split()
        if len(parts) != 5:
            continue
        cls, *coords = parts
        try:
            xmin, xmax, ymin, ymax = (_parse_coord(c) for c in coords)
        except ValueError:
            continue
        if xmin > xmax or ymin > ymax:
            continue
        boxes.append((cls, xmin, xmax, ymin, ymax))
    return boxes

def parse_ground_truth(sample):
    """Extract ground-truth boxes from the sample's assistant message.

    The assistant ``"content"`` may be a plain string or, like the user message, a
    list of typed items. In the list form, the ``"text"`` items are joined with
    newlines before parsing.

    Args:
        sample: A test record with a ``"messages"`` list.

    Returns:
        list[tuple]: Boxes in the ``parse_boxes`` format; ``[]`` when the sample has
        no assistant message.
    """
    gt_msg = next((m for m in sample["messages"] if m["role"] == "assistant"), None)
    if gt_msg is None:
        return []
    content = gt_msg["content"]
    # A list here used to hit str.strip() and crash with AttributeError.
    if isinstance(content, list):
        content = "\n".join(
            item.get("text", "")
            for item in content
            if isinstance(item, dict) and item.get("type") == "text"
        )
    return parse_boxes(content)

In [None]:
def parse_for_map(text, default_score=1.0):
    """Convert model output text into the prediction array for ``metric_fn.add``.

    The ``mean_average_precision`` library reads prediction rows as
    ``[xmin, ymin, xmax, ymax, class_id, confidence]``. Column 4 is the class id and
    column 5 the score. Putting them the other way round makes every box land in
    class ``int(default_score)`` and be ranked by its class index.

    The model emits no confidence, so every box gets ``default_score``.
    NOTE(review): with identical scores the AP ranking among ties is arbitrary, so
    mAP here is an approximation of a properly-scored evaluation.

    Boxes whose class name is not in ``CLASS_TO_IDX`` are dropped.

    Args:
        text: Raw model output in the ``class_name xmin xmax ymin ymax`` format.
        default_score: Confidence assigned to every parsed box.

    Returns:
        np.ndarray: Shape ``(N, 6)``; ``(0, 6)`` when nothing usable was parsed.
    """
    preds = []
    for cls, xmin, xmax, ymin, ymax in parse_boxes(text):
        class_id = CLASS_TO_IDX.get(cls)
        if class_id is None:
            continue
        # Reorder (xmin, xmax, ymin, ymax) -> (xmin, ymin, xmax, ymax), then
        # class id before confidence, as the library expects.
        preds.append([xmin, ymin, xmax, ymax, class_id, default_score])
    return np.array(preds, dtype=float) if preds else np.empty((0, 6))

def parse_gt_for_map(sample):
    """Build the ground-truth array passed to ``metric_fn.add``.

    Each row is ``[x1, y1, x2, y2, class_id, difficult, crowd]``, with both flags
    fixed at 0. Boxes whose class name is not in ``CLASS_TO_IDX`` are ignored.

    Returns:
        np.ndarray: Shape ``(N, 7)``; ``(0, 7)`` when there are no usable boxes.
    """
    rows = [
        [xmin, ymin, xmax, ymax, CLASS_TO_IDX[cls]]
        for cls, xmin, xmax, ymin, ymax in parse_ground_truth(sample)
        if cls in CLASS_TO_IDX
    ]
    if not rows:
        return np.empty((0, 7))
    # difficult = 0, crowd = 0 for every box
    flags = np.zeros((len(rows), 2))
    return np.column_stack([np.array(rows), flags])

In [None]:
def _draw_boxes(ax, image, boxes, title, color):
    """Show ``image`` on ``ax`` with each ``(cls, xmin, xmax, ymin, ymax)`` box outlined and labelled."""
    ax.imshow(image)
    ax.set_title(title, fontsize=14)
    for cls, xmin, xmax, ymin, ymax in boxes:
        rect = patches.Rectangle(
            (xmin, ymin),
            xmax - xmin,
            ymax - ymin,
            linewidth=2,
            edgecolor=color,
            facecolor="none"
        )
        ax.add_patch(rect)
        # Label placed 5 px above the box's top edge (image y grows downward).
        ax.text(xmin, ymin - 5, cls, color=color, fontsize=10, backgroundcolor="white")
    ax.axis("off")


def visualize_comparison(image, gt_boxes, pred_boxes, title):
    """Plot ground-truth (green) and predicted (red) boxes side by side.

    Args:
        image: Image to draw on (anything ``Axes.imshow`` accepts, e.g. a PIL image).
        gt_boxes: Ground-truth boxes as ``(cls, xmin, xmax, ymin, ymax)`` tuples.
        pred_boxes: Predicted boxes in the same format.
        title: Figure-level title.
    """
    fig, (ax_gt, ax_pred) = plt.subplots(1, 2, figsize=(18, 8))
    _draw_boxes(ax_gt, image, gt_boxes, "Ground Truth", "green")
    _draw_boxes(ax_pred, image, pred_boxes, "Model Prediction", "red")
    fig.suptitle(title, fontsize=16)
    fig.tight_layout()
    plt.show()
    # Release the figure so repeated calls in a loop do not accumulate memory.
    plt.close(fig)

In [None]:
# `samples` is loaded by the dataset cell earlier in the notebook; it is not re-read here.

# Initialize mAP metric (rebuilt on every run so re-executing this cell starts fresh).
metric_fn = MetricBuilder.build_evaluation_metric("map_2d", async_mode=False, num_classes=len(CLASSES))

NUM_VIS_SAMPLES = 1   # Display N number of samples
TOTAL_SAMPLES = 10    # None = whole test set (will take a long time to run)

# Clamp so a small test file does not produce one IndexError per missing index.
num_eval = len(samples) if TOTAL_SAMPLES is None else min(TOTAL_SAMPLES, len(samples))
failed_samples = []

print("\n" + "="*60)
print("Running Inference + mAP Evaluation...")
print("="*60)

for idx in range(num_eval):
    sample = samples[idx]
    try:
        image, pred_text = run_inference(sample)

        # Parse for mAP
        detections = parse_for_map(pred_text)
        ground_truths = parse_gt_for_map(sample)
        metric_fn.add(detections, ground_truths)
    except Exception as e:
        print(f"Error on sample {idx}: {e}")
        traceback.print_exc()
        failed_samples.append(idx)
        continue

    # Visualize only first N; a plotting failure must not count as an eval failure.
    if idx < NUM_VIS_SAMPLES:
        try:
            print(f"\n--- Sample {idx+1} ---")
            print(f"Raw model output:\n{repr(pred_text)}\n")
            gt_boxes = parse_ground_truth(sample)
            pred_boxes = parse_boxes(pred_text)
            print(f"GT boxes: {gt_boxes}")
            print(f"Pred boxes: {pred_boxes}")
            visualize_comparison(image, gt_boxes, pred_boxes, f"Sample {idx+1}")
        except Exception as e:
            print(f"Visualization failed for sample {idx}: {e}")
            traceback.print_exc()

# Failed samples never reach metric_fn, so the reported mAP covers only the rest.
print(f"\nEvaluated {num_eval - len(failed_samples)}/{num_eval} samples.")
if failed_samples:
    print(f"WARNING: samples {failed_samples} failed and are NOT included in the mAP.")

In [None]:
print("\n\n" + "="*60)
print("FINAL mAP RESULTS")
print("="*60)

# COCO-style settings, as documented by the mean_average_precision library:
# 101-point recall interpolation and the "soft" matching policy. Both metrics use
# them so mAP@50 and mAP@50-95 are directly comparable.
COCO_RECALL_THRESHOLDS = np.linspace(0.0, 1.0, 101)
COCO_MPOLICY = "soft"

# mAP@50
result_50 = metric_fn.value(
    iou_thresholds=[0.5],
    recall_thresholds=COCO_RECALL_THRESHOLDS,
    mpolicy=COCO_MPOLICY,
)
map50 = result_50["mAP"]

# mAP@50-95: IoU thresholds 0.50, 0.55, ..., 0.95 (linspace avoids float-step drift).
iou_range = np.linspace(0.5, 0.95, 10)
result_all = metric_fn.value(
    iou_thresholds=iou_range,
    recall_thresholds=COCO_RECALL_THRESHOLDS,
    mpolicy=COCO_MPOLICY,
)
map5095 = result_all["mAP"]

# NOTE(review): the metric was built with num_classes=len(CLASSES). Confirm how the
# library scores classes that have no GT in the evaluated subset; on small subsets
# such classes may pull the mean down.
print(f"mAP@50 results    : {map50:.3f}")
print(f"mAP@50-95 results : {map5095:.3f}")