In [None]:
%pip install -U cragmm-search-pipeline
%pip install datasets torch transformers matplotlib accelerate ipywidgets

In [3]:
"""
from cragmm_search.search import UnifiedSearchPipeline

search_pipeline = UnifiedSearchPipeline(
    image_model_name="openai/clip-vit-large-patch14-336",
    image_hf_dataset_id="crag-mm-2025/image-search-index-validation",
    text_model_name="sentence-transformers/all-MiniLM-L6-v2",
    web_hf_dataset_id="crag-mm-2025/web-search-index-validation",
)
"""

'\nfrom cragmm_search.search import UnifiedSearchPipeline\n\nsearch_pipeline = UnifiedSearchPipeline(\n    image_model_name="openai/clip-vit-large-patch14-336",\n    image_hf_dataset_id="crag-mm-2025/image-search-index-validation",\n    text_model_name="sentence-transformers/all-MiniLM-L6-v2",\n    web_hf_dataset_id="crag-mm-2025/web-search-index-validation",\n)\n'

In [4]:
import os

from huggingface_hub import login

# Never hardcode access tokens in a notebook: they end up in version control and in
# shared copies. Any token previously pasted here should be revoked in the Hugging Face
# account settings. The token is read from the HF_TOKEN environment variable; when it is
# unset, `login` falls back to prompting interactively.
login(token=os.environ.get("HF_TOKEN"))

In [None]:
from dataclasses import dataclass
from typing import Optional, Tuple
import torch
from transformers import AutoProcessor, MllamaForConditionalGeneration, MllamaConfig
from transformers.utils import ModelOutput
from torch import nn

@dataclass
class CausalLMOutputWithPastAndBBox(ModelOutput):
    """Causal-LM style model output extended with a predicted bounding box.

    Carries the usual causal-LM fields (loss, logits, cache, hidden states,
    attentions) plus ``bbox_coords``, which ``CustomLlamaVLM.forward`` fills with the
    bounding-box head's prediction and leaves as ``None`` when no box was requested.
    """

    loss: Optional[torch.FloatTensor] = None
    logits: Optional[torch.FloatTensor] = None
    past_key_values: Optional[Tuple[Tuple[torch.FloatTensor, ...], ...]] = None
    hidden_states: Optional[Tuple[torch.FloatTensor, ...]] = None
    attentions: Optional[Tuple[torch.FloatTensor, ...]] = None
    # (batch_size, 4) box from the bounding-box head, or None when not requested.
    bbox_coords: Optional[torch.FloatTensor] = None


class BoundingBoxHead(nn.Module):
    """Regress one bounding box from masked mean-pooled hidden states.

    The head predicts ``[x1, y1, x2, y2]``, intended to be coordinates relative (0–1)
    to the input image size. The final layer is a plain linear map, so predictions are
    not constrained to [0, 1].

    Args:
        hidden_size: Width of the hidden states passed to :meth:`forward`. Defaults to
            4096, the hidden size this notebook was written for.
    """

    def __init__(self, hidden_size: int = 4096):
        super().__init__()
        # [x1, y1, x2, y2] pixel relative coordinates (0–1) to input image size.
        self.bbox_regression_head = nn.Linear(hidden_size, 4)

    @property
    def dtype(self) -> torch.dtype:
        """Current dtype of the head's parameters (follows later ``.to(...)`` casts)."""
        return self.bbox_regression_head.weight.dtype

    def forward(self, attention_mask, last_hidden_states, bbox_binary_label=None, bbox_coords_label=None):
        """Pool the hidden states over unmasked tokens and regress a box.

        Args:
            attention_mask: ``(batch_size, seq_len)`` mask; non-zero entries are pooled.
            last_hidden_states: ``(batch_size, seq_len, hidden_size)`` hidden states.
            bbox_binary_label: Optional per-sample flag with ``batch_size`` elements
                (e.g. shape ``(batch_size, 1)``); samples flagged 0 do not contribute
                to the loss.
            bbox_coords_label: Optional ``(batch_size, 4)`` target boxes. The loss is
                computed whenever these are given.

        Returns:
            dict with ``"bbox_coords"`` of shape ``(batch_size, 4)`` and ``"loss"``, a
            float32 scalar tensor, or ``None`` when ``bbox_coords_label`` is ``None``.
        """
        # Build the mask in the hidden states' own dtype so the product does not promote.
        mask = attention_mask.unsqueeze(-1).to(last_hidden_states.dtype)
        summed = (last_hidden_states * mask).sum(dim=1)
        token_counts = torch.clamp(mask.sum(dim=1), min=1e-9)
        pooled_output = summed / token_counts  # (batch_size, hidden_dim)

        # Match the linear layer's dtype (it may differ from the backbone's).
        weight_dtype = self.bbox_regression_head.weight.dtype
        bbox_coords = self.bbox_regression_head(pooled_output.to(weight_dtype))  # (batch_size, 4)

        loss = None
        if bbox_coords_label is not None:
            # Squared error in float32: bf16 loses precision on small residuals.
            per_sample = ((bbox_coords.float() - bbox_coords_label.float()) ** 2).mean(dim=-1)
            if bbox_binary_label is None:
                loss = per_sample.mean()
            else:
                weights = bbox_binary_label.reshape(-1).to(per_sample.dtype)
                # With every flag set to 1 this equals the plain mean-squared error.
                loss = (per_sample * weights).sum() / weights.sum().clamp(min=1.0)

        return {"loss": loss, "bbox_coords": bbox_coords}


class CustomLlamaVLM(MllamaForConditionalGeneration):
    """Mllama conditional-generation model with an extra bounding-box regression head.

    ``forward`` behaves like the parent model and, when ``request_bounding_box`` is
    set, additionally runs ``self.bbox`` on the last hidden layer.
    """

    def __init__(self, config):
        super().__init__(config)
        self.bbox = BoundingBoxHead()

    def forward(
        self,
        input_ids=None,
        attention_mask=None,
        inputs_embeds=None,
        labels=None,
        return_dict=None,
        request_bounding_box=False,
        bbox_binary_label=None,
        bbox_coords_label=None,
        **kwargs,
    ):
        """Run the parent forward pass and, optionally, the bounding-box head.

        Args:
            input_ids, attention_mask, inputs_embeds, labels: forwarded to the parent.
            return_dict: accepted for API compatibility; the result is always a
                ``CausalLMOutputWithPastAndBBox``.
            request_bounding_box: when True, predict a box from the last hidden layer.
            bbox_binary_label, bbox_coords_label: forwarded to the bounding-box head.
            **kwargs: forwarded to the parent (pixel values, cache arguments, ...).

        Returns:
            CausalLMOutputWithPastAndBBox. ``loss`` is the parent's language-model loss
            (when ``labels`` are given), plus the bounding-box loss when the head
            computes one; it is ``None`` when neither is available.
        """
        # Hidden states are only needed for the bbox head. Pop any caller-supplied value
        # so it is not passed to the parent twice (a duplicate-keyword TypeError), and
        # avoid materialising every layer's states during plain generation.
        output_hidden_states = kwargs.pop("output_hidden_states", None)
        if request_bounding_box:
            output_hidden_states = True

        # Always request a ModelOutput so fields can be read by name below; a tuple
        # result would make `outputs.hidden_states` fail.
        outputs = super().forward(
            input_ids=input_ids,
            attention_mask=attention_mask,
            inputs_embeds=inputs_embeds,
            labels=labels,
            output_hidden_states=output_hidden_states,
            return_dict=True,
            **kwargs,
        )

        loss = outputs.loss
        bbox_coords = None
        if request_bounding_box:
            last_hidden = outputs.hidden_states[-1]
            pool_mask = attention_mask
            if pool_mask is None:
                # No mask given: pool over every position.
                pool_mask = torch.ones(last_hidden.shape[:2], dtype=torch.long, device=last_hidden.device)
            bbox = self.bbox(
                pool_mask,
                last_hidden,
                bbox_binary_label=bbox_binary_label,
                bbox_coords_label=bbox_coords_label,
            )
            bbox_coords = bbox["bbox_coords"]
            if bbox["loss"] is not None:
                loss = bbox["loss"] if loss is None else loss + bbox["loss"]

        return CausalLMOutputWithPastAndBBox(
            loss=loss,
            logits=outputs.logits,
            past_key_values=outputs.past_key_values,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
            bbox_coords=bbox_coords,
        )

# Usage example: load the processor and the bbox-augmented model in bfloat16.
model_id = "meta-llama/Llama-3.2-11B-Vision-Instruct"

# The `bbox` head is not part of this checkpoint, so it is expected to be freshly
# initialised (confirm in the load warnings). device_map="auto" spreads the weights
# across the available devices.
processor = AutoProcessor.from_pretrained(model_id)
model = CustomLlamaVLM.from_pretrained(
    model_id,
    device_map="auto",
    torch_dtype=torch.bfloat16,
)

In [None]:
# System / user prompts shared by the training dataset and the evaluation loop.
system_prompt = "You are a helpful assistant that can answer questions about images. You will be given an image and a question. If the question requires a bounding box, output the bounding box in the format of [x1, y1, x2, y2]. If the question does not require a bounding box, output \"no\". Do not output any other additional text."
prompt_template = "Given this question \"{question}\", output a single bounding box in the format of [x1, y1, x2, y2] that best captures the object to answer this question."

import torch
from PIL import Image
from io import BytesIO
import requests
from torch.utils.data import Dataset
from torch.utils.data import DataLoader
from torchvision import transforms
from torchvision.transforms.functional import to_pil_image
import matplotlib.pyplot as plt
# `matplotlib.patches` is a module with no attribute named `patches`, so
# `from matplotlib.patches import patches` raised ImportError; import the module itself.
from matplotlib import patches
from datasets import load_dataset

def draw_image_with_bbox(image, bbox_coords):
    """Display ``image`` with a single red bounding box drawn on top.

    Args:
        image: PIL image; its ``.size`` (width, height) scales the relative box.
        bbox_coords: four values ``[left, top, right, bottom]`` relative (0–1) to the
            image size. Any sequence of numbers or a 1-D tensor of four scalars works.
    """
    from matplotlib import patches as mpatches

    # Plain floats: matplotlib should not receive 0-d tensors.
    left, top, right, bottom = (float(value) for value in bbox_coords)
    width, height = image.size

    fig, ax = plt.subplots()
    ax.imshow(image)
    # Anchor at the top-left corner in pixel space (image y grows downward), so both
    # width and height are positive for a well-formed box.
    rect = mpatches.Rectangle(
        (left * width, top * height),
        (right - left) * width,
        (bottom - top) * height,
        linewidth=2,
        edgecolor='r',
        facecolor='none',
    )
    ax.add_patch(rect)
    plt.show()
    # Release the figure; this function is called repeatedly inside loops.
    plt.close(fig)

import os
import hashlib
from concurrent.futures import ThreadPoolExecutor
from tqdm import tqdm

# Root directory for downloaded images. Set the IMAGE_CACHE_DIR environment variable to
# relocate it without editing the notebook; by default it is ~/image_cache.
home_dir = os.path.expanduser("~")
cache_dir = os.environ.get("IMAGE_CACHE_DIR", os.path.join(home_dir, "image_cache"))

def preload_and_cache(dataset, cache_dir=cache_dir, num_workers=4):
    """Download each row's image URL into ``cache_dir`` using a thread pool.

    Files are named ``<md5(url)>.jpg``, the same name ``LazyDataset._load_image`` looks
    up. Files that already exist are skipped. Failures (bad URL, missing field, decode
    error) are printed and skipped so one bad row does not abort the preload.

    Args:
        dataset: iterable of raw rows with an ``"image"`` URL field (e.g. a Hugging
            Face split). A wrapper exposing its rows as ``.samples`` (like
            ``LazyDataset``) is unwrapped first, because iterating the wrapper itself
            would load and transform every image instead of yielding URLs.
        cache_dir: destination directory; created if missing.
        num_workers: number of download threads.
    """
    import uuid

    os.makedirs(cache_dir, exist_ok=True)
    rows = getattr(dataset, "samples", dataset)

    def download_sample(sample):
        url = None
        try:
            url = sample["image"]
            filename = hashlib.md5(url.encode()).hexdigest() + ".jpg"
            filepath = os.path.join(cache_dir, filename)
            if os.path.exists(filepath):
                return  # Already cached
            response = requests.get(url, timeout=15)
            response.raise_for_status()
            image = Image.open(BytesIO(response.content)).convert("RGB")
            # Write to a uniquely named temporary file and rename it into place, so an
            # interrupted or concurrent write never leaves a truncated image that later
            # looks "already cached".
            tmp_path = f"{filepath}.{uuid.uuid4().hex}.part"
            image.save(tmp_path, format="JPEG")
            os.replace(tmp_path, filepath)
        except Exception as e:
            print(f"Failed to download {url}: {e}")

    with ThreadPoolExecutor(max_workers=num_workers) as executor:
        list(tqdm(executor.map(download_sample, rows), total=len(rows)))

class LazyDataset(Dataset):
    """Map-style dataset that builds (image, prompt, bbox) training items on demand.

    Each item is a dict with:

    * ``"image"`` - the RGB image resized to ``image_size`` and passed through
      ``transform`` (by default a uint8 ``(3, H, W)`` tensor);
    * ``"texts"`` - the chat-formatted prompt for the sample's question;
    * ``"need_bboxs"`` - ``torch.ones(1)``, since every sample here carries a box;
    * ``"bbox_coords"`` - ``[left, top, right, bottom]`` divided by the original image
      width / height.

    Args:
        dataset: indexable rows with ``"image"`` (URL), ``"question"``, ``"left"``,
            ``"top"``, ``"right"``, ``"bottom"``, ``"width"`` and ``"height"`` fields.
        processor: object whose ``apply_chat_template`` formats the prompt.
        system_prompt: system message text.
        user_prompt_template: user message with a ``{question}`` placeholder.
        transform: callable applied to the resized PIL image. Defaults to
            ``transforms.PILToTensor()`` (uint8 pixel values). ``ToTensor()`` would
            already scale pixels to [0, 1]; the Hugging Face image processor rescales
            pixel values again by default (``do_rescale``), producing near-black inputs.
        cache_dir: directory holding ``<md5(url)>.jpg`` downloads.
        image_size: ``(width, height)`` every image is resized to. A common size lets
            the default collate function stack the images into one tensor.
    """

    def __init__(self, dataset, processor, system_prompt, user_prompt_template, transform=None, cache_dir=cache_dir, image_size=(640, 480)):
        self.samples = dataset
        self.system_prompt = system_prompt
        self.user_prompt_template = user_prompt_template
        self.processor = processor
        self.transform = transform or transforms.PILToTensor()
        self.cache_dir = cache_dir
        self.image_size = tuple(image_size)
        os.makedirs(cache_dir, exist_ok=True)

    def __len__(self):
        return len(self.samples)

    def __getitem__(self, index):
        sample = self.samples[index]
        return {
            "image": self._load_image(sample),
            "texts": self._process_text(sample),
            "need_bboxs": torch.ones(1),
            "bbox_coords": self._process_bbox(sample),
        }

    def _cache_path(self, url):
        """Cached-file path for ``url``: ``<cache_dir>/<md5(url)>.jpg``."""
        return os.path.join(self.cache_dir, hashlib.md5(url.encode()).hexdigest() + ".jpg")

    def _load_image(self, sample):
        """Return the sample's image, resized and transformed.

        Reads the cached file when present; otherwise downloads the URL and caches it.

        Raises:
            RuntimeError: if the image cannot be read or downloaded. Returning ``None``
                here (as before) only broke batching later with an unrelated error.
        """
        url = sample["image"]
        filepath = self._cache_path(url)
        try:
            if os.path.exists(filepath):
                img = Image.open(filepath).convert("RGB")
            else:
                response = requests.get(url, timeout=15)
                response.raise_for_status()
                img = Image.open(BytesIO(response.content)).convert("RGB")
                img.save(filepath)
        except Exception as e:
            raise RuntimeError(f"Error loading image from {url}: {e}") from e

        img = img.resize(self.image_size)
        return self.transform(img) if self.transform else img

    def _process_bbox(self, sample):
        """Box ``[left, top, right, bottom]`` relative to the original image size."""
        width, height = sample["width"], sample["height"]
        return torch.tensor(
            [sample["left"] / width, sample["top"] / height, sample["right"] / width, sample["bottom"] / height],
            dtype=torch.float32,
        )

    def _process_text(self, sample):
        """Chat-formatted prompt (untokenized) with an image slot and the question."""
        return self.processor.apply_chat_template(
            [
                {"role": "system", "content": self.system_prompt},
                {"role": "user", "content": [
                    {"type": "image"},
                    {"type": "text", "text": self.user_prompt_template.format(question=sample["question"])},
                ]},
            ],
            add_generation_prompt=True,
            tokenize=False
        )

# Load the dataset split
hf_dataset = load_dataset("toloka/WSDMCup2023", split="train")

# Create the dataset
dataset = LazyDataset(hf_dataset, processor, system_prompt, prompt_template)

# Warm the image cache from the raw rows. Iterating `dataset` itself would go through
# LazyDataset.__getitem__, which already downloads and transforms every image and would
# hand preload_and_cache tensors where it expects URL strings.
preload_and_cache(hf_dataset, cache_dir=cache_dir)

# Create a DataLoader. NOTE(review): the "fork" context is presumably chosen so worker
# processes inherit the notebook-defined LazyDataset class; confirm, since fork can be
# unreliable on some platforms (e.g. macOS).
dataloader = DataLoader(
    dataset,
    batch_size=12,
    shuffle=True,
    num_workers=2,
    multiprocessing_context="fork",
    prefetch_factor=2,
)


In [None]:
import torch.optim as optim
from tqdm.notebook import tqdm
from IPython.display import display, clear_output

# Freeze the whole model, then unfreeze only the bounding-box head.
# `requires_grad_` sets the flag on every parameter of a module. Assigning
# `model.bbox.requires_grad = True` only stores a plain attribute on the module object
# and leaves the head frozen, so `loss.backward()` would have no graph to differentiate.
model.requires_grad_(False)
model.bbox.requires_grad_(True)

# Give the optimizer only the parameters that can receive gradients.
trainable_params = [param for param in model.parameters() if param.requires_grad]
optimizer = optim.Adam(trainable_params, lr=1e-5)
model.train()

LOG_EVERY = 10  # batches averaged into each logged / plotted loss point
running_loss = 0.0
loss_values = []

for epoch in range(1):
    print(f"Epoch {epoch+1}")
    batch_bar = tqdm(dataloader)
    for num_batches, batch in enumerate(batch_bar, start=1):
        inputs = processor(
            text=batch["texts"],  # List of str
            images=[[img] for img in batch["image"]],  # must break a batch of image into a list
            return_tensors="pt",
            add_special_tokens=False,
            padding=True,  # needed for batch
            truncation=True,  # needed for batch
        ).to(model.device)

        # LazyDataset emits the per-sample flag under the key "need_bboxs".
        bbox_binary_label = batch["need_bboxs"].to(model.device, dtype=model.dtype)
        bbox_coords_label = batch["bbox_coords"].to(model.device, dtype=model.dtype)
        outputs = model(**inputs, request_bounding_box=True, bbox_binary_label=bbox_binary_label, bbox_coords_label=bbox_coords_label)
        loss = outputs.loss
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        running_loss += loss.item()

        # Log and plot the mean loss of the last LOG_EVERY batches.
        if num_batches % LOG_EVERY == 0:
            avg_loss = running_loss / LOG_EVERY
            loss_values.append(avg_loss)
            running_loss = 0.0
            # Clear before printing: with wait=True the clear fires on the next output,
            # so a message printed before it would be wiped once the plot is drawn.
            clear_output(wait=True)
            print(f"[Epoch {epoch+1}], [Batch {num_batches}], [Loss: {avg_loss:.4f}]")
            fig, ax = plt.subplots(figsize=(8, 6))
            ax.set_title('Training Loss')
            ax.set_xlabel(f'Every {LOG_EVERY} Batches')
            ax.set_ylabel('Loss')
            ax.plot(loss_values, label='Loss')
            ax.legend()
            plt.show()
            plt.close(fig)
            display(batch_bar.container)

In [None]:
from datasets import load_dataset

# CRAG-MM single-turn public benchmark, pinned to a fixed revision so evaluation runs
# against the same snapshot every time.
single_turn_dataset = load_dataset(
    path="crag-mm-2025/crag-mm-single-turn-public",
    revision="v0.1.1",
)

In [None]:
import ast

NUM_EVAL_SAMPLES = 10
EVAL_IMAGE_SIZE = (640, 480)  # same 640x480 size the training images are resized to

model.eval()
validation_split = single_turn_dataset["validation"]
for index in range(NUM_EVAL_SAMPLES):
    # Fetch the row once: every lookup on the split re-reads and re-decodes it.
    sample = validation_split[index]
    image = sample["image"]
    resize_image = image.resize(EVAL_IMAGE_SIZE)
    # NOTE(review): assumes "turns" and "answers" are lists of dicts; confirm against
    # the dataset schema.
    question = sample["turns"][0]["query"]
    ground_truth = sample["answers"][0]["ans_full"]

    messages = [
        {"role": "system", "content": system_prompt},
        {
            "role": "user",
            "content": [
                {"type": "image"},
                {"type": "text", "text": prompt_template.format(question=question)},
            ]
        }
    ]

    text = processor.apply_chat_template(messages, add_generation_prompt=True, tokenize=False)
    with torch.no_grad():
        inputs = processor(
            text=[text],
            images=[resize_image],
            return_tensors="pt",
            add_special_tokens=False
        ).to(model.device)

        # 1) Box predicted by the regression head.
        head_output = model(**inputs, request_bounding_box=True)
        draw_image_with_bbox(resize_image, head_output.bbox_coords[0].cpu().float())

        # 2) Box the language model writes out as text. The budget is raised from 20
        #    tokens so decimal coordinate lists are less likely to be cut off.
        generated = model.generate(**inputs, max_new_tokens=32)

    print(ground_truth)
    # Decode only the newly generated tokens instead of splitting the whole transcript
    # on "assistant\n", which raised IndexError whenever that marker was absent.
    prompt_length = inputs["input_ids"].shape[-1]
    answer_text = processor.decode(generated[0][prompt_length:], skip_special_tokens=True).strip()

    # The answer is model-generated text: parse it as a literal and never eval() it.
    try:
        text_bbox = ast.literal_eval(answer_text)
    except (ValueError, SyntaxError, TypeError, MemoryError, RecursionError):
        text_bbox = None

    is_box = (
        isinstance(text_bbox, (list, tuple))
        and len(text_bbox) == 4
        and all(isinstance(value, (int, float)) for value in text_bbox)
    )
    if is_box:
        # NOTE(review): draw_image_with_bbox treats coordinates as fractions of the
        # image size; the prompt does not fix the units, so pixel outputs won't line up.
        draw_image_with_bbox(resize_image, text_bbox)
    else:
        print(f"No bounding box in model output: {answer_text!r}")

mps:0
