In [None]:
!pip install -U unsloth
!pip install transformers==4.57.1 trl accelerate peft bitsandbytes
!pip install sentencepiece einops timm qwen-vl-utils
!pip install pillow matplotlib
%matplotlib inline

In [None]:
# Allocator options are set before anything imports torch, so they are already
# in the environment when the CUDA caching allocator is configured.
import os
os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "expandable_segments:True"

import json
import traceback

# Unsloth asks to be imported before transformers/trl/peft so that its
# patches are applied to those libraries.
from unsloth import FastVisionModel

import torch
import matplotlib.patches as patches
import matplotlib.pyplot as plt
from IPython.display import display
from PIL import Image
from transformers import Qwen2_5_VLProcessor

# Colab only: mount Google Drive, where the adapter and the test set live.
from google.colab import drive
drive.mount("/content/drive", force_remount=True)

In [None]:
# Notebook-wide configuration.
# BASE_MODEL_ID: model id passed to both from_pretrained() calls below.
# ADAPTER_PATH:  location on the mounted Drive handed to model.load_adapter().
# TEST_JSONL:    JSONL file on the mounted Drive that load_jsonl() reads later.
BASE_MODEL_ID = "Qwen/Qwen2.5-VL-7B-Instruct"
ADAPTER_PATH = "/content/drive/MyDrive/qwen-nrp-output"
TEST_JSONL   = "/content/drive/MyDrive/test-colab.jsonl"

# Load only the model: from_pretrained() returns a pair, and the second
# element is deliberately discarded because the processor is loaded
# separately further down.
model, _ = FastVisionModel.from_pretrained(
    BASE_MODEL_ID,
    load_in_4bit=True,
    device_map="auto",
)
# Attach the adapter stored at ADAPTER_PATH, then switch to evaluation mode
# before any generation is run.
model.load_adapter(ADAPTER_PATH)
model.eval()

# Load the processor explicitly from the base model id; it is used by
# run_inference() for the chat template, input tensors and decoding.
processor = Qwen2_5_VLProcessor.from_pretrained(BASE_MODEL_ID)
print("Model loaded successfully.")

In [None]:
def load_jsonl(path):
    """Read a JSON Lines file and return its records as a list.

    Args:
        path: Path of the ``.jsonl`` file; each non-blank line must hold one
            JSON document.

    Returns:
        A list with one decoded object per non-blank line, in file order.

    Raises:
        OSError: If the file cannot be opened.
        json.JSONDecodeError: If a non-blank line is not valid JSON.
    """
    # JSON text is UTF-8, so do not depend on the platform default encoding.
    with open(path, "r", encoding="utf-8") as f:
        # Blank or whitespace-only lines (e.g. a trailing newline at the end
        # of the file) are skipped instead of crashing json.loads.
        return [json.loads(line) for line in f if line.strip()]

# Read every record of the test JSONL into memory once; the evaluation loop
# later in the notebook indexes into `samples`.
print("Loading test dataset...")
samples = load_jsonl(TEST_JSONL)
print(f"Loaded {len(samples)} test samples.")

In [None]:
def run_inference(sample, max_new_tokens=1024):
    """Run inference on a single JSONL sample.

    Only the image path is taken from the sample: the system and user prompts
    sent to the model are the fixed ones built below, not the text stored in
    the sample. Uses the notebook-level ``model`` and ``processor``.

    Args:
        sample: Dict with a ``"messages"`` list. The user message is expected
            to carry a content list containing an entry of the form
            ``{"type": "image", "image": <path>}``.
        max_new_tokens: Upper bound on the number of generated tokens.

    Returns:
        A ``(image, result)`` tuple: the RGB ``PIL.Image`` that was fed to the
        model and the decoded text of the newly generated tokens.

    Raises:
        ValueError: If the sample has no user message or no image path.
    """
    user_msg = next(
        (m for m in sample["messages"] if m["role"] == "user"),
        None,
    )
    if user_msg is None:
        # A bare next() would leak StopIteration with an empty message.
        raise ValueError("No user message found in sample")

    # Content may be a plain string (no image entries) or a list of typed
    # parts. If several image entries are present, the last one wins.
    image_path = None
    content = user_msg["content"]
    if isinstance(content, list):
        for item in content:
            if (
                isinstance(item, dict)
                and item.get("type") == "image"
                and item.get("image")
            ):
                image_path = item["image"]

    if image_path is None:
        raise ValueError("No image path found in sample")

    # Image.open keeps the file open until the data is read; convert() returns
    # a fully loaded copy, so the handle can be released right away.
    with Image.open(image_path) as img:
        image = img.convert("RGB")

    messages = [
        {
            "role": "system",
            "content": (
                "You are an assistant that detects jam causing objects in images. "
                "The possible jam causing objects are: fabrics, rigid-plastic, "
                "non-recyclables, large-plastic-films, unopened-plastic-bags, "
                "metal, wrappables, wood."
            )
        },
        {
            "role": "user",
            "content": [
                {"type": "image"},
                {
                    "type": "text",
                    "text": (
                        "Detect the waste objects in this image and output their "
                        "bounding boxes in the format: class_name xmin xmax ymin ymax"
                    )
                }
            ]
        }
    ]

    # Render the chat as prompt text ending with the assistant generation cue.
    text = processor.apply_chat_template(
        messages,
        tokenize=False,
        add_generation_prompt=True
    )

    inputs = processor(
        text=[text],
        images=[image],
        return_tensors="pt",
        padding=True
    ).to(model.device)

    # Greedy decoding (do_sample=False); no gradients are needed.
    with torch.no_grad():
        outputs = model.generate(
            **inputs,
            max_new_tokens=max_new_tokens,
            do_sample=False
        )

    # generate() returns prompt + continuation; drop the prompt tokens so only
    # the newly generated part is decoded.
    prompt_len = inputs["input_ids"].shape[1]
    generated_ids = outputs[0][prompt_len:]
    result = processor.decode(generated_ids, skip_special_tokens=True)

    return image, result

In [None]:
def parse_boxes(text):
    """Parse ``class_name xmin xmax ymin ymax`` lines into box tuples.

    Each line of ``text`` must consist of exactly five whitespace-separated
    fields: a class label followed by four integers. Lines that do not match,
    or whose minimum exceeds the corresponding maximum, are reported with
    ``print`` and skipped.

    Args:
        text: Multi-line string of box descriptions.

    Returns:
        List of ``(class_name, xmin, xmax, ymin, ymax)`` tuples with integer
        coordinates, in input order.
    """
    parsed = []
    for line_no, raw in enumerate(text.strip().splitlines(), start=1):
        fields = raw.strip().split()
        if len(fields) != 5:
            print(f"Skipping invalid line {line_no}: {repr(raw)} (expected 5 parts)")
            continue

        label = fields[0]
        try:
            x_lo, x_hi, y_lo, y_hi = (int(token) for token in fields[1:])
        except ValueError:
            print(f"Non-integer coordinates in line {line_no}: {raw}")
            continue

        # Reject boxes whose bounds are swapped on either axis.
        if x_lo > x_hi or y_lo > y_hi:
            print(f"Invalid coordinates (xmin>xmax or ymin>ymax): {raw}")
            continue

        parsed.append((label, x_lo, x_hi, y_lo, y_hi))
    return parsed
def parse_ground_truth(sample):
    """Extract the ground-truth boxes from a sample's assistant message.

    The assistant content may be a plain string or a list of typed parts
    (``{"type": "text", "text": ...}``), the same shape the user message
    uses; the text parts are joined with newlines before parsing.

    Args:
        sample: Dict with a ``"messages"`` list.

    Returns:
        The list produced by ``parse_boxes`` for the first assistant message,
        or an empty list when the sample has no assistant message.
    """
    gt_msg = next(
        (m for m in sample["messages"] if m["role"] == "assistant"),
        None,
    )
    if gt_msg is None:
        print("No assistant message found in sample.")
        return []

    content = gt_msg["content"]
    if not isinstance(content, str):
        # parse_boxes needs a string; keep only the text parts of the list.
        content = "\n".join(
            part.get("text", "")
            for part in (content or [])
            if isinstance(part, dict) and part.get("type") == "text"
        )
    return parse_boxes(content)

In [None]:
def _draw_boxes(ax, image, boxes, panel_title, color):
    """Show ``image`` on ``ax`` and overlay one labelled rectangle per box."""
    ax.imshow(image)
    ax.set_title(panel_title, fontsize=14)
    for cls, xmin, xmax, ymin, ymax in boxes:
        # Rectangle takes the (xmin, ymin) corner plus width and height.
        ax.add_patch(
            patches.Rectangle(
                (xmin, ymin),
                xmax - xmin,
                ymax - ymin,
                linewidth=2,
                edgecolor=color,
                facecolor="none"
            )
        )
        # Class label is anchored 5 units away from the box's ymin edge.
        ax.text(
            xmin, ymin - 5,
            cls,
            color=color,
            fontsize=10,
            backgroundcolor="white"
        )
    ax.axis("off")


def visualize_comparison(image, gt_boxes, pred_boxes, title):
    """Plot ground-truth and predicted boxes side by side on the same image.

    Args:
        image: Image accepted by ``Axes.imshow`` (the notebook passes the PIL
            image returned by ``run_inference``).
        gt_boxes: Iterable of ``(class_name, xmin, xmax, ymin, ymax)`` tuples
            drawn in green on the left panel.
        pred_boxes: Iterable of tuples of the same form, drawn in red on the
            right panel.
        title: Figure-level title.

    Returns:
        None. The figure is shown and then closed.
    """
    fig, (gt_ax, pred_ax) = plt.subplots(1, 2, figsize=(18, 8))
    try:
        _draw_boxes(gt_ax, image, gt_boxes, "Ground Truth", "green")
        _draw_boxes(pred_ax, image, pred_boxes, "Model Prediction", "red")

        fig.suptitle(title, fontsize=16)
        fig.tight_layout()
        plt.show()  # Use plt.show() for reliable Colab display
    finally:
        # Close even when drawing fails, so a broken sample does not leave a
        # half-drawn figure open.
        plt.close(fig)

In [None]:
print("Running Model Inference on Test Samples...")
num_samples = 10 # Test x amount of samples

# Walk over at most `num_samples` records. A failure on one sample is reported
# with its traceback and the loop moves on to the next one.
for idx, sample in enumerate(samples[:num_samples]):
    sample_no = idx + 1
    print(f"\n--- Sample {sample_no} ---")
    try:
        # Model side: the image that was used plus the raw generated text.
        image, pred_text = run_inference(sample)
        print(f"Raw model output:\n{repr(pred_text)}\n")

        # Reference boxes from the sample, predicted boxes from the model text.
        gt_boxes = parse_ground_truth(sample)
        pred_boxes = parse_boxes(pred_text)

        print(f"Original GroundTruth boxes: {gt_boxes}")
        print(f"Predicted boxes: {pred_boxes}")

        # Side-by-side figure: ground truth vs. prediction.
        visualize_comparison(
            image,
            gt_boxes,
            pred_boxes,
            title=f"Sample {sample_no}: Ground Truth vs Model Prediction"
        )
    except Exception as e:
        print(f"Error processing sample {sample_no}: {e}")
        traceback.print_exc()