In [1]:
%pip install -U cragmm-search-pipeline
%pip install datasets torch transformers matplotlib accelerate ipywidgets

Note: you may need to restart the kernel to use updated packages.
Note: you may need to restart the kernel to use updated packages.


In [2]:
from datasets import load_dataset
import matplotlib.pyplot as plt

# Load the public single-turn CRAG-MM dataset from the Hugging Face Hub.
# The revision is pinned ("v0.1.1") so that reruns read the same snapshot.
single_turn_dataset = load_dataset("crag-mm-2025/crag-mm-single-turn-public", revision="v0.1.1")

In [3]:
"""
from cragmm_search.search import UnifiedSearchPipeline

search_pipeline = UnifiedSearchPipeline(
    image_model_name="openai/clip-vit-large-patch14-336",
    image_hf_dataset_id="crag-mm-2025/image-search-index-validation",
    text_model_name="sentence-transformers/all-MiniLM-L6-v2",
    web_hf_dataset_id="crag-mm-2025/web-search-index-validation",
)
"""

'\nfrom cragmm_search.search import UnifiedSearchPipeline\n\nsearch_pipeline = UnifiedSearchPipeline(\n    image_model_name="openai/clip-vit-large-patch14-336",\n    image_hf_dataset_id="crag-mm-2025/image-search-index-validation",\n    text_model_name="sentence-transformers/all-MiniLM-L6-v2",\n    web_hf_dataset_id="crag-mm-2025/web-search-index-validation",\n)\n'

In [4]:
import os
from getpass import getpass

from huggingface_hub import login

# Authenticate against the Hugging Face Hub (the Llama checkpoint is gated).
# SECURITY: never hard-code an access token in a notebook -- it ends up in
# version control and in every shared copy. The token is taken from the
# HF_TOKEN environment variable when set, otherwise it is prompted for
# interactively without being echoed or stored in the notebook.
# Any token that was previously committed here must be revoked on the Hub.
login(token=os.environ.get("HF_TOKEN") or getpass("Hugging Face access token: "))

In [None]:
from dataclasses import dataclass
from typing import Optional, Tuple
import torch
from transformers import AutoProcessor, MllamaForConditionalGeneration, MllamaConfig
from transformers.utils import ModelOutput
from torch import nn

@dataclass
class CausalLMOutputWithPastAndBBox(ModelOutput):
    """Causal-LM output container extended with bounding-box predictions.

    Carries the usual language-model fields (loss, logits, cache, hidden
    states, attentions) plus the two outputs of the bounding-box head.
    Every field defaults to ``None``.
    """
    loss: Optional[torch.FloatTensor] = None
    logits: Optional[torch.FloatTensor] = None
    past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None
    hidden_states: Optional[Tuple[torch.FloatTensor, ...]] = None
    attentions: Optional[Tuple[torch.FloatTensor, ...]] = None
    # Raw (pre-sigmoid) output of the "is a bounding box needed" head.
    bbox_needed_logits: Optional[torch.FloatTensor] = None
    # Raw output of the 4-value coordinate regression head.
    bbox_coords: Optional[torch.FloatTensor] = None

class BoundingBoxHead(nn.Module):
    """Predict whether a bounding box is needed and regress its coordinates.

    The hidden states of a sequence are mean-pooled over the positions kept by
    the attention mask; two linear heads are then applied to the pooled vector:

    * ``bbox_needed_head`` -- one logit per sample (apply a sigmoid to get a
      probability).
    * ``bbox_regression_head`` -- four raw values per sample, intended as
      ``[x1, y1, x2, y2]`` relative (0-1) to the input image size. No
      activation is applied, so the values are not clamped to that range.

    Args:
        hidden_size: Width of the incoming hidden states. Defaults to 4096.
    """

    def __init__(self, hidden_size=4096):
        super().__init__()
        self.bbox_needed_head = nn.Linear(hidden_size, 1)

        # [x1, y1, x2, y2] pixel relative coordinates (0-1) to input image size.
        self.bbox_regression_head = nn.Linear(hidden_size, 4)

    @property
    def dtype(self):
        """Current parameter dtype of the heads.

        Exposed as a property (rather than a value captured in ``__init__``)
        so that it stays correct after ``.to(dtype)`` / ``.half()`` casts.
        """
        return self.bbox_needed_head.weight.dtype

    def forward(self, attention_mask, last_hidden_states, bbox_binary_label = None, bbox_coords_label = None):
        """Pool the hidden states and run both heads.

        Args:
            attention_mask: ``(batch, seq)`` mask, non-zero for positions to
                include in the mean. ``None`` means "use every position".
            last_hidden_states: ``(batch, seq, hidden_size)`` hidden states.
            bbox_binary_label: Optional targets for the "bbox needed" head,
                any shape that can be reshaped to ``(batch, 1)``.
            bbox_coords_label: Optional targets for the coordinate head,
                any shape that can be reshaped to ``(batch, 4)``.

        Returns:
            dict with keys ``"loss"`` (scalar tensor, or ``None`` when no
            label was given), ``"bbox_needed_logits"`` ``(batch, 1)`` and
            ``"bbox_coords"`` ``(batch, 4)``.
        """
        if attention_mask is None:
            # No padding information: every position contributes to the mean.
            attention_mask = last_hidden_states.new_ones(last_hidden_states.shape[:2])

        # Masked mean pooling. The (batch, seq, 1) mask broadcasts over the
        # hidden dimension, so no explicit expand is required.
        mask = attention_mask.unsqueeze(-1).to(
            device=last_hidden_states.device, dtype=last_hidden_states.dtype
        )
        summed = (last_hidden_states * mask).sum(dim=1)
        counts = mask.sum(dim=1).clamp(min=1e-9)  # guards fully-masked rows
        # Match the heads' parameter dtype so the linear layers accept the input.
        pooled_output = (summed / counts).to(self.dtype)  # (batch_size, hidden_dim)

        bbox_needed_logits = self.bbox_needed_head(pooled_output)  # (batch_size, 1)
        bbox_coords = self.bbox_regression_head(pooled_output)  # (batch_size, 4)

        loss = None
        # Compare against None: the truth value of a multi-element tensor is
        # ambiguous (raises), and a single all-zero label would read as False.
        if bbox_binary_label is not None:
            needed_target = bbox_binary_label.to(
                device=bbox_needed_logits.device, dtype=bbox_needed_logits.dtype
            ).reshape(bbox_needed_logits.shape)
            # Functional form: calling nn.BCEWithLogitsLoss(logits, target)
            # would only construct a module, not compute a loss.
            loss = nn.functional.binary_cross_entropy_with_logits(bbox_needed_logits, needed_target)

        if bbox_coords_label is not None:
            coords_target = bbox_coords_label.to(
                device=bbox_coords.device, dtype=bbox_coords.dtype
            ).reshape(bbox_coords.shape)
            bbox_coord_loss = nn.functional.mse_loss(bbox_coords, coords_target)
            loss = bbox_coord_loss if loss is None else loss + bbox_coord_loss

        return {"loss": loss, "bbox_needed_logits": bbox_needed_logits, "bbox_coords": bbox_coords}


class CustomLlamaVLM(MllamaForConditionalGeneration):
    """Mllama conditional-generation model with an extra bounding-box head.

    The base model's behavior is unchanged. When ``request_bounding_box=True``
    the last layer's hidden states are additionally fed to ``self.bbox``,
    whose predictions (and loss, when labels are given) are attached to the
    returned output.
    """

    def __init__(self, config):
        super().__init__(config)

        self.bbox = BoundingBoxHead()

    def forward(
        self,
        input_ids=None,
        attention_mask=None,
        inputs_embeds=None,
        labels=None,
        return_dict=None,
        request_bounding_box=False,
        bbox_binary_label = None,
        bbox_coords_label = None,
        **kwargs
    ):
        """Run the base model and, optionally, the bounding-box head.

        Args:
            input_ids, attention_mask, inputs_embeds, labels: Forwarded to the
                parent ``forward`` unchanged.
            return_dict: When false, the result is returned as a plain tuple
                of its non-``None`` fields; defaults to the config setting.
            request_bounding_box: Also compute the bounding-box predictions.
            bbox_binary_label: Optional targets for the "bbox needed" head.
            bbox_coords_label: Optional targets for the coordinate head.
            **kwargs: Any other argument accepted by the parent ``forward``
                (e.g. pixel values, cross-attention mask).

        Returns:
            ``CausalLMOutputWithPastAndBBox`` (or its tuple form). ``loss`` is
            the language-model loss plus the bounding-box loss, using
            whichever of the two is available.
        """
        return_dict = return_dict if return_dict is not None else self.config.return_dict

        # Hidden states are always requested (the bbox head pools the last
        # layer). Drop a caller-supplied value so it cannot collide with the
        # forced keyword below and raise a TypeError.
        kwargs.pop("output_hidden_states", None)

        # Always ask the parent for a ModelOutput: attribute access below
        # would fail on the tuple returned with return_dict=False.
        outputs = super().forward(
            input_ids=input_ids,
            attention_mask=attention_mask,
            inputs_embeds=inputs_embeds,
            labels=labels,
            output_hidden_states=True, # required for bounding box
            return_dict=True,
            **kwargs
        )

        # Keep the language-model loss (present when ``labels`` is given)
        # instead of discarding it.
        loss = outputs.loss
        bbox_needed_logits = None
        bbox_coords = None
        if request_bounding_box:
            last_hidden_states = outputs.hidden_states[-1]
            seq_len = last_hidden_states.shape[1]

            bbox_mask = attention_mask
            if bbox_mask is None:
                # No mask supplied: treat every position as valid.
                bbox_mask = torch.ones(
                    last_hidden_states.shape[:2],
                    dtype=torch.long,
                    device=last_hidden_states.device,
                )
            elif bbox_mask.shape[1] != seq_len:
                # With a key/value cache the mask spans the whole sequence
                # while the hidden states only cover the newest positions.
                bbox_mask = bbox_mask[:, -seq_len:]
            bbox_mask = bbox_mask.to(last_hidden_states.device)

            bbox = self.bbox(
                bbox_mask,
                last_hidden_states,
                bbox_binary_label = bbox_binary_label,
                bbox_coords_label = bbox_coords_label,
            )
            bbox_needed_logits = bbox["bbox_needed_logits"]
            bbox_coords = bbox["bbox_coords"]

            bbox_loss = bbox["loss"]
            if bbox_loss is not None:
                # The two losses may live on different devices when the model
                # is sharded, so align before summing.
                loss = bbox_loss if loss is None else loss + bbox_loss.to(loss.device)

        result = CausalLMOutputWithPastAndBBox(
            loss=loss,
            logits=outputs.logits,
            past_key_values=outputs.past_key_values,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
            bbox_needed_logits=bbox_needed_logits,
            bbox_coords=bbox_coords
        )
        return result if return_dict else result.to_tuple()

# Usage example: build the processor and the bbox-augmented model from the
# base vision-instruct checkpoint.
model_id = "meta-llama/Llama-3.2-11B-Vision-Instruct"

# Initialize components
processor = AutoProcessor.from_pretrained(model_id)
# The bbox heads are not stored in the checkpoint, so they are newly
# initialized (untrained) here; all other weights are loaded in bfloat16 and
# placed across the available devices by device_map="auto".
model = CustomLlamaVLM.from_pretrained(
    model_id,
    torch_dtype=torch.bfloat16,
    device_map="auto",
    low_cpu_mem_usage=True
)

The model weights are not tied. Please use the `tie_weights` method before using the `infer_auto_device` function.


Loading checkpoint shards:   0%|          | 0/5 [00:00<?, ?it/s]

Some weights of CustomLlamaVLM were not initialized from the model checkpoint at meta-llama/Llama-3.2-11B-Vision-Instruct and are newly initialized: ['bbox.bbox_needed_head.bias', 'bbox.bbox_needed_head.weight', 'bbox.bbox_regression_head.bias', 'bbox.bbox_regression_head.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.
Some parameters are on the meta device because they were offloaded to the disk.


In [6]:
# Pick one validation example to run through the model.
index = 1
# Fetch the row once and reuse it: each ``dataset[index]`` lookup materializes
# the whole record again (image included), so three lookups tripled the work.
sample = single_turn_dataset["validation"][index]
image = sample["image"]
question = sample["turns"][0]["query"]          # first (only) turn's question
ground_truth = sample["answers"][0]["ans_full"]  # full reference answer for that turn

In [None]:
from PIL import Image 
# Run one forward pass of the bbox-augmented model on the selected sample.
prompt = question
# Single user turn: the downscaled image followed by the question text.
# NOTE(review): the fixed 640x320 resize does not preserve the original aspect
# ratio -- confirm this is intended.
messages = [
    {
        "role": "user",
        "content": [
            {"type": "image", "image": image.resize((640, 320), resample=Image.LANCZOS)},
            {"type": "text", "text": prompt}
        ]
    }
]
# Render the chat template to a prompt string only (tokenize=False);
# tokenization happens in the processor call below.
text = processor.apply_chat_template(messages, add_generation_prompt=True, tokenize=False)
# Collect the image objects referenced by the user message for the processor.
images = [msg["image"] for msg in messages[0]["content"] if msg["type"] == "image"]
print(model.device)
with torch.no_grad():
    # add_special_tokens=False: presumably because the chat template already
    # inserted the special tokens -- TODO confirm.
    inputs = processor(
        text=[text],
        images=images,
        return_tensors="pt",
        add_special_tokens=False
    ).to(model.device)

    # Forward pass that also runs the bounding-box head; the printed output
    # includes every field of the result, hidden states included.
    output = model(**inputs, request_bounding_box=True)
    print(output)
    # Disabled: text generation and comparison against the reference answer.
    #output = model.generate(**inputs, max_new_tokens=5)

    ##print(ground_truth)
    #print(processor.decode(output[0], skip_special_tokens=True).split("assistant\n", 1)[1])

mps:0
