In [None]:
import av
import torch
import numpy as np

from transformers import AutoProcessor, AutoModel

# Seed NumPy's global RNG so the randomly placed clip window is the same on
# every run of the notebook.
np.random.seed(0)

# Text prompts handed to the processor together with each clip.
categories = [
    'Soccer Corner',
    'Soccer Free Kick',
    'soccer Throw In',
]

# Video files to classify, grouped under the label they are expected to get.
video_categories = dict(
    Corner=["1_CornerKick_2.mp4"],
    FreeKick=["1_FreeKick_1.mp4"],
    ThrowIn=["1_ThrowIn_25.mp4"],
)

def read_video_pyav(container, indices):
    """Decode the frames at ``indices`` from an opened video container.

    Args:
        container: Object providing ``seek(0)`` and ``decode(video=0)``
            (for example the value returned by ``av.open``). Each decoded
            frame must offer ``to_ndarray(format="rgb24")``.
        indices: Non-empty sequence of frame indices in ascending order. The
            first and last entries bound the range that is decoded. An index
            may appear more than once.

    Returns:
        np.ndarray: The selected frames stacked along a new leading axis, one
        entry per requested index (a repeated index repeats its frame). Fewer
        entries are returned only if the stream ends before the last index.

    Raises:
        IndexError: If ``indices`` is empty.
        ValueError: If none of the requested frames could be decoded.
    """
    start_index = int(indices[0])
    end_index = int(indices[-1])

    # How many times each frame index was requested. A dict gives O(1) lookup
    # per decoded frame and lets repeated indices yield repeated frames, so the
    # clip length matches len(indices) instead of silently shrinking.
    repeats = {}
    for idx in indices:
        key = int(idx)
        repeats[key] = repeats.get(key, 0) + 1

    container.seek(0)
    frames = []
    for i, frame in enumerate(container.decode(video=0)):
        if i > end_index:
            # Everything requested has been passed; stop decoding early.
            break
        if i < start_index:
            continue
        count = repeats.get(i, 0)
        if count:
            # Convert right away so decoded frame objects are not kept alive.
            pixels = frame.to_ndarray(format="rgb24")
            frames.extend([pixels] * count)

    if not frames:
        raise ValueError(
            f"no frames decoded for indices in [{start_index}, {end_index}]; "
            "the video may be shorter than expected"
        )
    return np.stack(frames)

def sample_frame_indices(clip_len, frame_sample_rate, seg_len):
    """Pick ``clip_len`` frame indices from a video of ``seg_len`` frames.

    When the video is longer than the window of
    ``int(clip_len * frame_sample_rate)`` frames, the window's end is drawn
    from NumPy's global RNG and the indices are spaced evenly inside it.
    When the video is not longer than the window, the indices are spread
    evenly over the whole video instead (some frames repeat if the video has
    fewer than ``clip_len`` frames) and no random number is drawn.

    Args:
        clip_len: Number of indices to return.
        frame_sample_rate: Stride between sampled frames; scales the window.
        seg_len: Total number of frames available in the video.

    Returns:
        np.ndarray: ``clip_len`` non-decreasing ``int64`` indices, each within
        ``[0, seg_len - 1]``.

    Raises:
        ValueError: If ``seg_len`` is not positive.
    """
    if seg_len <= 0:
        raise ValueError(
            f"seg_len must be positive, got {seg_len}; "
            "the frame count may be missing from the video metadata"
        )

    converted_len = int(clip_len * frame_sample_rate)
    if seg_len <= converted_len:
        # np.random.randint(converted_len, seg_len) would fail with
        # "low >= high" here, so cover the whole video deterministically.
        return np.linspace(0, seg_len - 1, num=clip_len).astype(np.int64)

    end_idx = np.random.randint(converted_len, seg_len)
    start_idx = end_idx - converted_len
    indices = np.linspace(start_idx, end_idx, num=clip_len)
    # linspace includes end_idx itself; clip so the last index stays inside
    # the window [start_idx, end_idx - 1].
    indices = np.clip(indices, start_idx, end_idx - 1).astype(np.int64)
    return indices

# Load and process videos by category
# video_data maps each label to the list of decoded clips (one array per file).
video_data = {}
for category, file_names in video_categories.items():
    video_data[category] = []
    for file_name in file_names:
        # The with-block closes the container as soon as the clip has been
        # decoded, so file handles do not pile up across files.
        with av.open(file_name) as container:
            # NOTE(review): PyAV reports 0 frames when the count is not known
            # from the container metadata — confirm for these files.
            total_frames = container.streams.video[0].frames
            indices = sample_frame_indices(clip_len=32, frame_sample_rate=1, seg_len=total_frames)
            video = read_video_pyav(container, indices)
        video_data[category].append(video)

In [None]:
# Load model and processor
# A single checkpoint id keeps the processor and the model from drifting apart
# if the checkpoint is ever changed.
checkpoint = "microsoft/xclip-base-patch16-zero-shot"
processor = AutoProcessor.from_pretrained(checkpoint)
model = AutoModel.from_pretrained(checkpoint)

In [None]:
# Process each category
# For every label, score all of its clips against the text prompts in
# `categories` and print the resulting probabilities.
for category, videos in video_data.items():
    if not videos:
        # No clip was loaded for this label; there is nothing to score.
        continue

    print(videos[0].shape)
    inputs = processor(
        text=categories,
        # One list of frames per clip. Flattening the frames of all clips into
        # a single list would present them to the processor as one long video,
        # which only happens to work when a category holds exactly one clip.
        videos=[list(video) for video in videos],
        return_tensors="pt",
        padding=True,
    )

    # Inference only: skip gradient tracking.
    with torch.no_grad():
        outputs = model(**inputs)

    logits_per_video = outputs.logits_per_video
    # NOTE(review): dim=1 is expected to index the text prompts (one row per
    # clip) — confirm against the X-CLIP model documentation.
    probs = logits_per_video.softmax(dim=1)
    print(f"Category: {category}")
    print(probs)