In [1]:
import torch
import pandas as pd
import numpy as np
from tqdm import tqdm
import json
import os
import re
from transformers import AutoModelForCausalLM, AutoProcessor, GenerationConfig
from PIL import Image
import cv2
import pickle
from collections import defaultdict

from utils import get_area, get_img, crop_square_with_context, get_text

## Load MOLMO and Ground Truth annotations

In [None]:
# Hugging Face checkpoint used for both the processor and the model.
MODEL_ID = "allenai/Molmo-7B-D-0924"

# Molmo's processing/modelling code lives on the Hub, hence trust_remote_code=True.
processor = AutoProcessor.from_pretrained(
    MODEL_ID,
    trust_remote_code=True,
    torch_dtype="auto",
    device_map="auto",
)

# Load the model. device_map="auto" lets accelerate place the weights (GPU(s) when
# available). Do NOT follow this with model.to("cuda"): moving an accelerate-dispatched
# model conflicts with its device hooks (breaks multi-GPU sharding / CPU offload) and
# hard-fails on machines without CUDA.
model = AutoModelForCausalLM.from_pretrained(
    MODEL_ID,
    trust_remote_code=True,
    torch_dtype="auto",
    device_map="auto",
)


# Load ground-truth annotations. As consumed below, the structure is
#   anns[video][frame]["challenge_object"] -> iterable of track dicts with "track_id" and "bbox".
# NOTE: pickle.load can execute arbitrary code — only open trusted annotation files.
ANNOTATIONS_PATH = "../../resources/annotations_public.pkl"
with open(ANNOTATIONS_PATH, "rb") as f:
    anns = pickle.load(f)

# Flatten to one row per (video, frame, track), precomputing the box area so each
# track's frames can later be ranked by box size.
ANN_COLUMNS = ["video", "frame", "track_id", "bbox", "area"]
records = [
    {
        "video": video,
        "frame": frame,
        "track_id": track["track_id"],
        "bbox": track["bbox"],
        "area": get_area(track["bbox"]),
    }
    for video, video_data in anns.items()
    for frame, frame_data in video_data.items()
    for track in frame_data["challenge_object"]
]
# Explicit columns keep the schema (and the ["video", "track_id"] groupby keys) valid
# even when `anns` yields no tracks; pd.DataFrame([]) would have no columns at all.
df = pd.DataFrame(records, columns=ANN_COLUMNS)


## Run MOLMO
- For each object (track) in each video:
    - Run inference on the 5 frames in which that object's bounding box is largest
    - Crop a square around the box to prevent aspect-ratio distortion

In [None]:
# Query sent to Molmo for every crop.
# NOTE(review): the triple-quoted literal keeps its leading newline and indentation, so
# that whitespace is part of the prompt the model sees. It is left unchanged here so
# captions stay comparable with earlier runs — consider stripping it in a fresh experiment.
prompt = """
    Propose 5 most likely class labels of the object, context of the image is traffic and unusual hazards such as various animals on the road. Write only the class names separated by spaces.
    """

# Number of frames (those with the largest bounding-box area) captioned per (video, track_id).
N_LARGEST_BOXES = 5
# Third argument of crop_square_with_context; presumably the extra context margin around
# the box (0.0 = tight crop) — TODO confirm against utils.
CROP_CONTEXT = 0.0
RESULTS_DIR = "results"
RESULTS_PATH = os.path.join(RESULTS_DIR, "molmo-obj-cap-largest.pkl")


def _normalize_caption(text):
    """Return `text` with every whitespace run (spaces, newlines, tabs) collapsed to one space.

    Stripping must happen after newlines are turned into spaces; otherwise a trailing
    newline survives as a trailing space in the stored caption.
    """
    return " ".join(text.split())


# Create the output directory before the long inference loop, not after it.
os.makedirs(RESULTS_DIR, exist_ok=True)

# Maps (video, track_id) -> list of normalized captions, largest box first.
data = defaultdict(list)
track_groups = df.groupby(["video", "track_id"])
for (video, track_id), group in tqdm(track_groups, total=track_groups.ngroups):
    select = group.sort_values(by="area", ascending=False).head(N_LARGEST_BOXES)
    for _, row in select.iterrows():
        img = get_img(f"{video}.mp4", frame_id=row["frame"])
        # Square crop so the object is not distorted when the image is resized.
        img_crop = crop_square_with_context(img, row["bbox"], CROP_CONTEXT)
        text = get_text(img_crop, prompt, processor, model)
        data[video, track_id].append(_normalize_caption(text))

data = dict(data)
torch.save(data, RESULTS_PATH)