In [39]:
import torch
import numpy as np
from PIL import Image
from transformers import AutoProcessor, AutoModelForZeroShotObjectDetection 
from tqdm import tqdm

import sys
sys.path.insert(0, '/root/videollm-online/data')
from utils import load_frames_mp4

class videoObjectDetector():
    """Text-prompted (zero-shot) object detector for images and video frames.

    Holds a processor/model pair loaded with ``AutoProcessor`` and
    ``AutoModelForZeroShotObjectDetection`` and offers helpers to run
    detection on a single image or on every loaded frame of a video file.
    """

    def __init__(self, model_id, device) -> None:
        """Load the processor and the model and move the model to ``device``.

        Args:
            model_id: Identifier handed to both ``from_pretrained`` calls.
            device: Device the model is moved to; inputs are sent there too.
        """
        self.processor = AutoProcessor.from_pretrained(model_id)
        self.model = AutoModelForZeroShotObjectDetection.from_pretrained(model_id).to(device)
        self.device = device

    @staticmethod
    def noun2class(nouns):
        """Turn an iterable of nouns into a single text query.

        Every noun becomes the phrase ``"a <noun>. "`` and the phrases are
        concatenated in order, e.g. ``["Cat", "dog"]`` -> ``"a cat. a dog. "``.

        Args:
            nouns: Iterable of nouns; each one is converted with ``str``.

        Returns:
            The concatenated query string (empty for an empty iterable).
        """
        # VERY important: text queries need to be lowercased + end with a dot
        return "".join(f"a {str(noun).lower()}. " for noun in nouns)

    @staticmethod
    def _target_size(image):
        """Return ``(height, width)`` for an array-like or a PIL-style image."""
        shape = getattr(image, "shape", None)
        if shape is not None:
            # Arrays/tensors: the caller supplies height-width-channel data.
            return tuple(shape[:2])
        # PIL-style images have no ``shape``; their ``size`` is (width, height).
        width, height = image.size
        return (height, width)

    def detect_objects(self, image, nouns, box_threshold=0.5, text_threshold=0.5):
        """Detect the given nouns in one image.

        Args:
            image: Either an array-like exposing ``shape`` whose first two
                entries are height and width, or a PIL-style image exposing
                ``size`` as (width, height).
            nouns: Nouns to look for; see ``noun2class``.
            box_threshold: Forwarded to the processor's post-processing step.
            text_threshold: Forwarded to the processor's post-processing step.

        Returns:
            Whatever ``post_process_grounded_object_detection`` returns for a
            single target size; callers in this notebook read ``result[0]``
            with the keys ``'boxes'``, ``'labels'`` and ``'scores'``.
        """
        text = videoObjectDetector.noun2class(nouns)
        inputs = self.processor(images=image, text=text, return_tensors="pt").to(self.device)

        # Inference only: no gradients are needed.
        with torch.no_grad():
            outputs = self.model(**inputs)

        return self.processor.post_process_grounded_object_detection(
            outputs,
            inputs.input_ids,
            box_threshold=box_threshold,
            text_threshold=text_threshold,
            target_sizes=[videoObjectDetector._target_size(image)]
        )

    def detect_video(self, path, nouns, load_range):
        """Run ``detect_objects`` on every frame loaded from a video file.

        Args:
            path: Video path handed to ``load_frames_mp4``.
            nouns: Nouns to look for in every frame.
            load_range: Forwarded unchanged to ``load_frames_mp4``.

        Returns:
            ``(results, frames)``: the per-frame detection results, in frame
            order, and the frames exactly as the loader returned them.
        """
        frames = load_frames_mp4(path, load_range)
        # Each frame is permuted (1, 2, 0) before detection -- presumably the
        # loader yields channel-first CPU tensors; TODO confirm.
        results = [
            self.detect_objects(frame.permute(1, 2, 0).numpy(), nouns)
            for frame in tqdm(frames)
        ]
        return results, frames

# Checkpoint name handed to the `from_pretrained` loaders inside the detector.
model_id = "IDEA-Research/grounding-dino-base"

# Use the GPU when CUDA is available, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"

# Shared detector instance; the demo cells further down call it directly.
video_detector = videoObjectDetector(model_id, device)




100%|██████████| 100/100 [00:23<00:00,  4.32it/s]


In [None]:
import matplotlib.pyplot as plt
import matplotlib.patches as patches
from PIL import Image

def visualize_detections(image, boxes, labels, scores):
    """
    Show an image with detection boxes, labels and scores drawn on top.

    Args:
    - image: the image to display; passed straight to ``ax.imshow``
    - boxes: box coordinates (tensor) of shape [N, 4]; each row is
      [x_min, y_min, x_max, y_max]
    - labels: per-box labels, indexable, length N; each label is formatted
      as-is into the caption
    - scores: per-box confidence scores, indexable, length N; each entry must
      be convertible with ``float`` (e.g. a 0-dim tensor)
    """
    fig, ax = plt.subplots(1)
    ax.imshow(image)

    for i, box in enumerate(boxes):
        x_min, y_min, x_max, y_max = box.tolist()
        label = labels[i]
        score = float(scores[i])

        # Outline of the detection box.
        rect = patches.Rectangle((x_min, y_min), x_max - x_min, y_max - y_min,
                                 linewidth=0.5, edgecolor='r', facecolor='none')
        ax.add_patch(rect)

        # Caption placed slightly above the box's top-left corner.
        label_text = f'{label}: {score:.2f}'
        ax.text(x_min, y_min - 10, label_text, color='white', fontsize=10,
                bbox=dict(facecolor='red', alpha=0.5))

    # Hide the axes and display the figure.
    ax.axis('off')
    plt.show()
    # Release the figure: this function is called once per video frame, and
    # unclosed figures would otherwise pile up in memory.
    plt.close(fig)

In [1]:
# Video to analyse, the nouns to detect, and the range handed to the frame loader.
path = '/root/videollm-online/datasets/ego4d/v2/full_scale_1fps/9ff8c35c-bd28-436f-b35b-ee460f983a67.mp4'
nouns = ['vehicles']
load_range = range(100)

# Requires the cell that builds `video_detector` to have been run first.
results, frames = video_detector.detect_video(path, nouns, load_range)

# Draw every frame together with the detections found on it.
for result, frame in zip(results, frames):
    image = frame.permute(1, 2, 0).numpy()
    print(image.shape)
    detections = result[0]  # each per-frame result holds one entry for the one image passed
    visualize_detections(
        image,
        boxes=detections['boxes'],
        labels=detections['labels'],
        scores=detections['scores'],
    )

NameError: name 'video_detector' is not defined

In [36]:
# Single-image check; reuses `nouns` and `video_detector` from earlier cells.
image = Image.open('/root/videollm-online/data/preprocess/output.png')
image_array = np.array(image)
print(image_array.shape)

result = video_detector.detect_objects(image, nouns)
print(image.size)

# Keep only the first three channels for display (the array printed above has four).
detections = result[0]
visualize_detections(image_array[:, :, :3], detections['boxes'], detections['labels'], detections['scores'])

(402, 512, 4)
(512, 402)
