## Classify objects using a CIFAR-100 pre-trained classifier -- ❕Runs in about 30 minutes❕

In [None]:
import torch
import numpy as np
from tqdm import tqdm
import os
from PIL import Image
from transformers import ViTForImageClassification, ViTImageProcessor
import cv2
import pickle

from utils import cifar100_classes, crop_with_context, get_img

### Classify each annotated object with the `vit_base-224-in21k-ft-cifar100` ViT checkpoint

In [None]:
# Preprocessing (resize/normalize) comes from the base ImageNet-21k ViT checkpoint;
# the classification head comes from a ViT-Base checkpoint that, per its name, was
# fine-tuned on CIFAR-100.
extractor = ViTImageProcessor.from_pretrained("google/vit-base-patch16-224-in21k")
model = ViTForImageClassification.from_pretrained(
    "edumunozsala/vit_base-224-in21k-ft-cifar100"
).to("cuda:1")  # nn.Module.to moves the weights in place and returns the same model

# Load the public annotations. They are indexed below as
# anns[video][frame_index]["challenge_object"].
# NOTE(review): pickle.load executes arbitrary code from the file -- only load trusted files.
with open("../../resources/annotations_public.pkl", "rb") as annotations_file:
    anns = pickle.load(annotations_file)


def process(image, objects):
    """Classify every annotated object in one frame with the CIFAR-100 ViT.

    Each object's bounding box is cropped from ``image`` with a 0.2 context
    margin via ``crop_with_context``. All crops of the frame are classified in a
    single batched forward pass, and the ten most likely classes are written
    back onto each object dict.

    Args:
        image: The frame in RGB channel order (the caller converts OpenCV's
            BGR output before calling).
        objects: List of annotation dicts. Each needs a ``"bbox"`` entry in
            whatever format ``crop_with_context`` expects.

    Returns:
        The same ``objects`` list, mutated in place. Every dict gains
        ``"top10_class"`` (class names from ``cifar100_classes``, most likely
        first) and ``"top10_probs"`` (the matching softmax probabilities as a
        NumPy array).
    """
    if not objects:
        # Nothing to classify; avoid sending an empty batch through the
        # processor and model.
        return objects

    crops = [crop_with_context(image, obj["bbox"], 0.2) for obj in objects]
    # Follow the model's device rather than repeating a hard-coded device string.
    inputs = extractor(images=crops, return_tensors="pt").to(model.device)

    with torch.no_grad():
        logits = model(**inputs).logits

    # One row of class probabilities per crop, in the same order as `objects`.
    probabilities = torch.nn.functional.softmax(logits.cpu(), dim=-1)
    top_probs, top_idx = probabilities.topk(10, dim=-1)

    for obj, probs, idx in zip(objects, top_probs, top_idx):
        # Class names go under "top10_class" and probabilities under "top10_probs".
        obj["top10_class"] = cifar100_classes[idx]
        obj["top10_probs"] = probs.numpy()
    return objects


folder_path = "../../dataset/coool-benchmark/"
output_path = "../../resources/cifar-classification/all-dense-test.pkl"

# video name -> {frame index -> list of annotated objects with classification results}
results = {}
# Sorted so videos are processed (and the tqdm bar advances) in a reproducible order.
for filename in tqdm(sorted(os.listdir(folder_path))):
    if not filename.endswith(".mp4"):
        continue

    # splitext keeps dots inside the stem (e.g. "a.b.mp4" -> "a.b"),
    # which split(".")[0] would truncate.
    video = os.path.splitext(filename)[0]
    cap = cv2.VideoCapture(os.path.join(folder_path, filename))

    video_results = {}
    try:
        frame_count = 0
        while cap.isOpened():
            ret, frame = cap.read()
            if not ret:
                break

            # Every frame is classified (no FPS subsampling). OpenCV decodes to
            # BGR, so convert to RGB before cropping and classifying.
            image = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
            objects = anns[video][frame_count]["challenge_object"]
            video_results[frame_count] = process(image, objects)
            frame_count += 1
    finally:
        # Release the decoder even if classifying a frame raises.
        cap.release()

    results[video] = video_results

os.makedirs(os.path.dirname(output_path), exist_ok=True)
# Written with torch.save despite the .pkl extension -- read it back with torch.load.
torch.save(results, output_path)

 21%|██▏       | 43/201 [06:08<16:59,  6.45s/it]  