In [40]:
import torch
import pandas as pd
import numpy as np
from tqdm import tqdm
import json
import os
import re
from transformers import AutoModelForCausalLM, AutoProcessor, GenerationConfig
from PIL import Image
import cv2
import pickle
from collections import defaultdict

def get_text(image, prompt, processor, model, max_new_tokens=400):
    """
    Run the vision-language model on a single image and return the generated text.

    Args:
        image: array accepted by ``PIL.Image.fromarray`` (in this notebook, an RGB
            crop from ``crop_square_with_context``).
        prompt: text prompt passed to ``processor.process``.
        processor: processor exposing ``process(images=..., text=...)`` and ``tokenizer``.
        model: model exposing ``device`` and ``generate_from_batch``.
        max_new_tokens: upper bound on the number of generated tokens
            (default 400, the value previously hard-coded).

    Returns:
        The decoded text of the newly generated tokens, with special tokens removed.
    """
    img = Image.fromarray(image)
    inputs = processor.process(images=[img], text=prompt)

    # Move every input tensor to the model's device and add a leading batch axis (batch of 1).
    inputs = {k: v.to(model.device).unsqueeze(0) for k, v in inputs.items()}
    output = model.generate_from_batch(
        inputs,
        GenerationConfig(max_new_tokens=max_new_tokens, stop_strings="<|endoftext|>"),
        tokenizer=processor.tokenizer,
    )

    # output[0] is assumed to begin with the prompt tokens; keep only what follows them.
    prompt_len = inputs['input_ids'].size(1)
    generated_tokens = output[0, prompt_len:]
    return processor.tokenizer.decode(generated_tokens, skip_special_tokens=True)


def _pad_to_min_size(lo, hi, min_size):
    """Widen the interval [lo, hi) about its middle so its length is at least ``min_size``."""
    deficit = min_size - (hi - lo)
    if deficit > 0:
        # Put the odd pixel on the high side so the length reaches min_size exactly.
        lo -= deficit // 2
        hi += deficit - deficit // 2
    return lo, hi


def _fit_into_bounds(lo, hi, limit):
    """Shift [lo, hi) inside [0, limit) keeping its length; clip only when it is longer than ``limit``."""
    length = hi - lo
    if length >= limit:
        return 0, limit
    if lo < 0:
        return 0, length
    if hi > limit:
        return limit - length, limit
    return lo, hi


def crop_square_with_context(image, bounding_box, context_percent=0.0, min_size=10):
    """
    Crop a square region around a bounding box, with optional extra context.

    The box is first enlarged to at least ``min_size`` pixels per side, then
    made square using its larger side, then grown by ``context_percent`` of
    that side on every edge. If the square crosses the image border it is
    shifted back inside the image, so the crop stays square. It is clipped
    only when the square is larger than the image in that dimension.

    Args:
        image: array indexed as ``image[y, x, ...]`` (``shape[:2]`` is height, width).
        bounding_box: ``(x1, y1, x2, y2)``; values are rounded to ints.
        context_percent: extra margin per side, as a fraction of the square side.
        min_size: minimum box width/height in pixels before squaring.

    Returns:
        A view of ``image`` covering the square region.
    """
    x1, y1, x2, y2 = np.array(bounding_box).round().astype(int)

    # Ensure minimum box size.
    x1, x2 = _pad_to_min_size(x1, x2, min_size)
    y1, y2 = _pad_to_min_size(y1, y2, min_size)

    # Square side = larger box dimension; context is added on every side.
    side_length = max(x2 - x1, y2 - y1)
    context = int(side_length * context_percent)
    full_side = side_length + 2 * context

    # Centre the square on the box. Ending at start + full_side (instead of
    # centre + side // 2) keeps the exact side length when it is odd.
    x_start = (x1 + x2) // 2 - side_length // 2 - context
    y_start = (y1 + y2) // 2 - side_length // 2 - context

    height, width = image.shape[:2]
    x1, x2 = _fit_into_bounds(x_start, x_start + full_side, width)
    y1, y2 = _fit_into_bounds(y_start, y_start + full_side, height)

    return image[y1:y2, x1:x2]



def get_img(filename, frame_id, folder_path="../data"):
    """
    Decode one frame of a video file and return it as an RGB array.

    Args:
        filename: video file name, relative to ``folder_path``.
        frame_id: zero-based index of the frame to return.
        folder_path: directory containing the video.

    Returns:
        The requested frame, converted from OpenCV's BGR channel order to RGB.

    Raises:
        ValueError: if ``frame_id`` is negative.
        OSError: if the video cannot be opened.
        IndexError: if the video ends before ``frame_id`` is reached.
    """
    if frame_id < 0:
        raise ValueError(f"frame_id must be non-negative, got {frame_id}")

    path = os.path.join(folder_path, filename)
    cap = cv2.VideoCapture(path)
    try:
        if not cap.isOpened():
            raise OSError(f"Could not open video: {path}")
        # Step through frames sequentially (no seeking) so the index matches a
        # plain read() loop; grab() advances without retrieving skipped frames.
        for _ in range(frame_id):
            if not cap.grab():
                raise IndexError(f"{path} has fewer than {frame_id + 1} frames")
        ok, frame = cap.read()
        if not ok:
            raise IndexError(f"{path} has fewer than {frame_id + 1} frames")
    finally:
        cap.release()
    return cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)


def get_area(bbox):
    """
    Return the area of an axis-aligned bounding box.

    Args:
        bbox: four values ``(x1, y1, x2, y2)``; ``x2`` must exceed ``x1`` and
            ``y2`` must exceed ``y1``.

    Returns:
        ``(x2 - x1) * (y2 - y1)``.

    Raises:
        ValueError: if the box has zero or negative width or height.
    """
    left, top, right, bottom = bbox
    degenerate = right <= left or bottom <= top
    if degenerate:
        raise ValueError("Invalid bounding box dimensions.")
    return (right - left) * (bottom - top)

## Load Molmo and the ground-truth annotations

In [None]:
processor = AutoProcessor.from_pretrained(
    'allenai/Molmo-7B-D-0924',
    trust_remote_code=True,
    torch_dtype='auto',
    device_map='auto'
)

# Load the model. device_map='auto' already places the weights across the
# available devices, so there is deliberately no `.to('cuda')` afterwards.
# Moving an accelerate-dispatched model can break multi-device placement,
# raises for offloaded weights, and fails outright on CPU-only machines.
model = AutoModelForCausalLM.from_pretrained(
    'allenai/Molmo-7B-D-0924',
    trust_remote_code=True,
    torch_dtype='auto',
    device_map='auto'
)


# Load ground-truth annotations.
# NOTE: pickle.load can execute arbitrary code; only load files from a trusted source.
with open('../annotations_public.pkl', 'rb') as f:
    anns = pickle.load(f)

# Flatten the nested {video: {frame: {'challenge_object': [track, ...]}}} mapping
# into one row per (video, frame, track), including the box area.
data = [
    {
        'video': video,
        'frame': frame,
        'track_id': track['track_id'],
        'bbox': track['bbox'],
        'area': get_area(track['bbox']),
    }
    for video, video_data in anns.items()
    for frame, frame_data in video_data.items()
    for track in frame_data['challenge_object']
]
df = pd.DataFrame(data)


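## Run Molmo
- For each tracked object (each `(video, track_id)` pair):
    - Run inference on the object's 5 largest bounding boxes (by area) across the frames it appears in
    - Crop a square region around each box to prevent distortion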

In [None]:
prompt = '''
    Propose 5 most likely class labels of the object, context of the image is traffic and unusual hazards such as various animals on the road. Write only the class names separated by spaces.
    '''

# Caption each tracked object from the crops of its 5 largest-area boxes.
# Result: {(video, track_id): [text, ...]}, texts ordered by decreasing box area.
data = defaultdict(list)
for (video, track_id), group in tqdm(df.groupby(['video', 'track_id'])):
    select = group.sort_values(by='area', ascending=False).head(5)
    for _, row in select.iterrows():
        # get_img reopens the video file on every call.
        img = get_img(f'{video}.mp4', frame_id=row['frame'])
        # context_percent=0.0: crop tightly around the squared box.
        img_crop = crop_square_with_context(img, row['bbox'], 0.0)
        text = get_text(img_crop, prompt, processor, model)
        # Flatten to one line first, then trim. Stripping spaces before the
        # newline replacement left leading/trailing newlines as stray spaces.
        text = text.replace('\n', ' ').strip()
        data[video, track_id].append(text)

data = dict(data)
os.makedirs('results', exist_ok=True)
torch.save(data, 'results/molmo-obj-cap-largest.pkl')