In [None]:
# !pip install --upgrade "pyarrow>=21.0.0"
# !pip install -q "transformers>=4.57.0"
# !pip install -q datasets av
# !pip install -q bitsandbytes accelerate
# !pip install "pydantic<2.12" --no-deps

In [5]:
%pip freeze | grep pydantic

Note: you may need to restart the kernel to use updated packages.


In [1]:
# %pip install "pydantic<2.12" --no-deps

In [2]:
import torch
import numpy as np
import av
import pandas as pd
import pickle
from torch.utils.data import DataLoader, Dataset
from transformers import AutoModelForVision2Seq, AutoProcessor, BitsAndBytesConfig
from tqdm import tqdm

In [None]:
import pandas as pd
import requests
import subprocess
import os
import zipfile
from concurrent.futures import ThreadPoolExecutor, as_completed
from tqdm import tqdm
import threading
from datasets import load_dataset

# Default local filename of the MSR-VTT video archive.
# (Previously a bare `ZIP_PATH` expression, which raised NameError on a fresh
# kernel and aborted this cell before any of the functions below were defined.)
ZIP_PATH = "MSRVTT_Videos.zip"

# Maps the dataset's category id (as a string) to a human-readable label.
# Ids missing from this table are reported as "unknown" by the lookup code.
category_map = {
    "0": "music",
    "1": "people",
    "2": "gaming",
    "3": "sports/actions",
    "4": "news/events/politics",
    "5": "education",
    "6": "tv shows",
    "7": "movie/comedy",
    "8": "animation",
    "9": "vehicles/autos",
    "10": "howto",
    "11": "travel",
    "12": "science/technology",
    "13": "animals/pets",
    "14": "kids/family",
    "15": "documentary",
    "16": "food/drink",
    "17": "cooking",
    "18": "beauty/fashion",
    "19": "advertisement"
}

def download_msrvtt_zip(zip_path="MSRVTT_Videos.zip"):
    """Download the MSR-VTT video archive unless it is already on disk.

    The archive is streamed to ``<zip_path>.part`` in chunks and renamed to
    ``zip_path`` only after the transfer finished successfully, so an
    interrupted or failed download never leaves a truncated file that a later
    run would mistake for a valid cache.

    Args:
        zip_path: Destination path of the archive. Defaults to
            ``"MSRVTT_Videos.zip"`` in the current working directory.

    Returns:
        The path of the (existing or freshly downloaded) archive.

    Raises:
        requests.HTTPError: If the server answers with an error status.
        requests.RequestException: On connection problems or timeouts.
        OSError: If the file cannot be written.
    """
    if os.path.exists(zip_path):
        return zip_path

    url = "https://huggingface.co/datasets/friedrichor/MSR-VTT/resolve/main/MSRVTT_Videos.zip"
    partial_path = zip_path + ".part"
    print("Скачиваю архив...")
    try:
        # stream=True avoids holding the whole archive in memory at once.
        with requests.get(url, stream=True, timeout=60) as response:
            # Fail loudly instead of saving an HTML error page as a ".zip".
            response.raise_for_status()
            with open(partial_path, 'wb') as f:
                for chunk in response.iter_content(chunk_size=1024 * 1024):
                    if chunk:
                        f.write(chunk)
        os.replace(partial_path, zip_path)
    except BaseException:
        # Also covers KeyboardInterrupt: never keep a half-written download.
        if os.path.exists(partial_path):
            os.remove(partial_path)
        raise
    return zip_path

def extract_video(zip_ref, video_file, output_dir):
    """Extract one video from the archive into ``output_dir``.

    The member ``video/<video_file>`` is written to a thread-specific
    temporary file first and renamed into place only when the copy succeeded.
    A failed extraction therefore never leaves an empty or truncated file that
    the "already extracted" check would accept on a later run.

    Args:
        zip_ref: Open archive object exposing ``open(member_name)``.
        video_file: Video name as stored in the dataset; only its last
            ``/``-separated component is used for the output filename.
        output_dir: Existing directory that receives the extracted file.

    Returns:
        Path of the extracted (or previously extracted) file, or ``None`` if
        the member could not be read or written.
    """
    video_path = os.path.join(output_dir, video_file.split('/')[-1])
    if os.path.exists(video_path):
        return video_path

    # Thread id in the name keeps concurrent workers from sharing a temp file.
    partial_path = f"{video_path}.{threading.get_ident()}.part"
    try:
        with zip_ref.open(f"video/{video_file}") as src, open(partial_path, 'wb') as dst:
            dst.write(src.read())
        os.replace(partial_path, video_path)
    except Exception:
        # Best-effort extraction: report failure to the caller, but let
        # KeyboardInterrupt / SystemExit propagate (no bare ``except``).
        try:
            os.remove(partial_path)
        except OSError:
            pass
        return None

    return video_path

def process_msrvtt_video(zip_ref, item, output_dir):
    """Turn one dataset record into a sample row, extracting its video.

    Records whose clip lasts longer than 30 seconds, or whose video cannot be
    extracted, yield ``None``. Otherwise the result is a dict with the keys
    ``video_path``, ``caption`` and ``category`` (a label looked up in
    ``category_map``, falling back to ``"unknown"``).
    """
    clip_name = item['video']
    clip_length = item['end time'] - item['start time']

    # Only clips of at most 30 seconds are extracted at all.
    extracted = None if clip_length > 30 else extract_video(zip_ref, clip_name, output_dir)
    if not extracted:
        return None

    record = {'video_path': extracted, 'caption': item['caption']}
    record['category'] = category_map.get(str(item['category']), "unknown")
    return record

def download_msrvtt_dataset(save_path, num_videos, max_workers=10):
    """Collect up to ``num_videos`` short MSR-VTT clips and extract them.

    Streams the ``test_1k`` test split, keeps records whose clip lasts at most
    30 seconds, extracts the matching videos from the archive in a thread pool
    and writes the resulting table to ``<save_path>/msrvtt_dataset.csv``.

    Rows are returned in dataset order regardless of which worker finishes
    first, so repeated runs produce the same table.

    Args:
        save_path: Directory for the extracted videos and the CSV (created if
            missing).
        num_videos: Maximum number of records to select.
        max_workers: Size of the extraction thread pool.

    Returns:
        DataFrame with columns ``video_path``, ``caption`` and ``category``;
        records whose video could not be extracted are omitted.
    """
    dataset = load_dataset("friedrichor/MSR-VTT", "test_1k", split="test", streaming=True)

    items = []
    for item in dataset:
        if item['end time'] - item['start time'] <= 30:
            items.append(item)
        if len(items) >= num_videos:
            break

    print(f"Найдено {len(items)} видео до 30 секунд")

    zip_path = download_msrvtt_zip()
    os.makedirs(save_path, exist_ok=True)

    # One slot per selected record so completion order cannot reorder rows.
    ordered = [None] * len(items)

    with zipfile.ZipFile(zip_path, 'r') as zip_ref, ThreadPoolExecutor(max_workers=max_workers) as executor:
        future_to_index = {
            executor.submit(process_msrvtt_video, zip_ref, item, save_path): index
            for index, item in enumerate(items)
        }

        # Results are gathered in this (single) thread, so no lock is needed.
        for future in tqdm(as_completed(future_to_index), total=len(future_to_index), desc="Извлечение"):
            ordered[future_to_index[future]] = future.result()

    samples = [sample for sample in ordered if sample]

    # Explicit columns keep the schema intact even when nothing was extracted.
    df = pd.DataFrame(samples, columns=['video_path', 'caption', 'category'])
    df.to_csv(os.path.join(save_path, "msrvtt_dataset.csv"), index=False)
    print(f"Обработано {len(df)} видео")
    return df

In [None]:
# Number of videos per batch; used as batch_size when the DataLoader is built.
BATCH_SIZE = 20
# Pixel budget forwarded to AutoProcessor.from_pretrained as max_pixels.
# NOTE(review): 298 * 224 is unusual -- confirm the intended resolution.
MAX_PIXELS = 298 * 224

In [None]:
def sample_frames(video_path, num_frames=10):
    """Sample ``num_frames`` evenly spaced frames from a video.

    Target frame positions are spread evenly over the frame count reported by
    the first video stream. If fewer distinct frames are available than
    requested (short clip, or the stream ends before the reported count), the
    last sampled frame is repeated to pad the list.

    Args:
        video_path: Path of the video file to open with PyAV.
        num_frames: Number of frames to return.

    Returns:
        A list of ``num_frames`` images (as produced by ``frame.to_image()``),
        or ``None`` when the stream reports no frames or none could be decoded.
    """
    # The context manager closes the container on every exit path.
    with av.open(video_path) as container:
        total_frames = container.streams.video[0].frames
        if total_frames <= 0:
            return None

        wanted = set(np.linspace(0, total_frames - 1, num_frames).astype(int).tolist())
        last_wanted = max(wanted) if wanted else -1

        frames = []
        container.seek(0)

        for i, frame in enumerate(container.decode(video=0)):
            if i in wanted:
                frames.append(frame.to_image())
            # Stop once everything requested was collected; nothing past the
            # last target index can be selected, so decoding further is waste.
            if len(frames) == num_frames or i >= last_wanted:
                break

    if not frames:
        # Return None (not an empty list) so callers that skip unreadable
        # videos handle this case the same way as a stream with no frames.
        return None

    # Pad with the last sampled frame when fewer frames than requested exist.
    frames.extend([frames[-1]] * (num_frames - len(frames)))
    return frames

class VideoDataset(Dataset):
    """Dataset yielding sampled frames and a chat prompt for each video row.

    Args:
        dataframe: Table with a ``video_path`` column.
        processor: Object exposing ``apply_chat_template``; used to render the
            captioning prompt as text.
        num_frames: Number of frames sampled per video (default 10).
    """

    def __init__(self, dataframe, processor, num_frames=10):
        self.df = dataframe
        self.processor = processor
        self.num_frames = num_frames

    def __len__(self):
        """Return the number of rows in the underlying table."""
        return len(self.df)

    def __getitem__(self, idx):
        """Build the model input for row ``idx``.

        Returns:
            A dict with ``frames``, ``prompt`` and ``video_path``, or ``None``
            when no frames could be sampled from the video.
        """
        row = self.df.iloc[idx]
        video_path = row['video_path']
        frames = sample_frames(video_path, num_frames=self.num_frames)
        # Treat an empty frame list like a missing one: there is nothing to
        # feed to the model in either case.
        if not frames:
            return None

        conversation = [{
            "role": "user",
            "content": [
                {"type": "video", "video": frames},
                {"type": "text", "text": "A very short, concise caption for this video."},
            ],
        }]

        # tokenize=False: only the prompt text is rendered here.
        prompt = self.processor.apply_chat_template(
            conversation, tokenize=False, add_generation_prompt=True
        )

        return {
            "frames": frames,
            "prompt": prompt,
            "video_path": video_path
        }

def collate_fn(batch):
    """Keep the batch as a plain list of items, dropping ``None`` entries.

    The dataset returns ``None`` for videos whose frames could not be sampled;
    passing those through would fail later when the item is indexed like a
    dict.
    """
    return [item for item in batch if item is not None]

class CaptionGenerationPipeline:
    """Generates short video captions with a 4-bit quantized vision-language model.

    Args:
        model_name: Hub id of the model and processor to load.
        device: Device the model is placed on and inputs are moved to.
    """

    def __init__(self, model_name="Qwen/Qwen3-VL-2B-Instruct", device="cuda:0"):
        self.device = device

        # 4-bit NF4 weights with double quantization; compute in bfloat16.
        bnb_config = BitsAndBytesConfig(
            load_in_4bit=True,
            bnb_4bit_use_double_quant=True,
            bnb_4bit_quant_type="nf4",
            bnb_4bit_compute_dtype=torch.bfloat16
        )

        self.processor = AutoProcessor.from_pretrained(
            model_name,
            max_pixels=MAX_PIXELS,
        )

        # Batched generation appends new tokens after the last input position,
        # so prompts must be padded on the left; with right padding, shorter
        # prompts in a batch would continue from pad tokens.
        tokenizer = getattr(self.processor, "tokenizer", None)
        if tokenizer is not None:
            tokenizer.padding_side = "left"

        self.model = AutoModelForVision2Seq.from_pretrained(
            model_name,
            quantization_config=bnb_config,
            device_map=device,
            dtype=torch.bfloat16,
            trust_remote_code=True
        )

        self.model.eval()

    @torch.no_grad()
    def process_batch_(self, batch_items):
        """Caption one batch.

        Args:
            batch_items: Non-empty list of dicts with ``prompt`` and ``frames``.

        Returns:
            List of decoded captions, one per item, in input order.
        """
        texts = [item["prompt"] for item in batch_items]
        videos = [item["frames"] for item in batch_items]

        inputs = self.processor(text=texts, videos=videos, padding=True, return_tensors="pt")
        inputs = inputs.to(self.device)

        output_ids = self.model.generate(**inputs, max_new_tokens=40, min_new_tokens=5)
        # generate() returns prompt + continuation; keep only the new tokens.
        output_ids = [out_ids[len(in_ids):] for in_ids, out_ids in zip(inputs.input_ids, output_ids)]
        output_texts = self.processor.batch_decode(output_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)

        return output_texts

    def generate_captions(self, dataloader):
        """Caption every video yielded by ``dataloader``.

        ``None`` items (videos that could not be read) are skipped, and batches
        left empty after that are not sent to the model.

        Returns:
            DataFrame with columns ``video_path`` and ``generated_caption``.
        """
        results = []

        for batch in tqdm(dataloader):
            valid_items = [item for item in batch if item is not None]
            if not valid_items:
                continue

            captions = self.process_batch_(valid_items)
            for item, caption in zip(valid_items, captions):
                results.append({
                    "video_path": item["video_path"],
                    "generated_caption": caption
                })

        # Explicit columns keep the schema intact when nothing was captioned.
        return pd.DataFrame(results, columns=["video_path", "generated_caption"])

In [None]:
# Build the evaluation table: select up to 1000 clips and extract them into
# the "MSRVTT_videos" directory (this triggers the archive download if needed).
df = download_msrvtt_dataset("MSRVTT_videos", 1000)

In [None]:
# Load the model and processor, then wrap the video table in a DataLoader.
# shuffle=False keeps the batches in the same order as the rows of `df`.
pipeline = CaptionGenerationPipeline()
dataset = VideoDataset(df, pipeline.processor)
dataloader = DataLoader(dataset, batch_size=BATCH_SIZE, shuffle=False, collate_fn=collate_fn)

In [None]:
# Caption every batch and persist the resulting DataFrame.
results = pipeline.generate_captions(dataloader)

# NOTE: pickle files can execute arbitrary code when loaded; only unpickle
# this file from a trusted location.
with open("msrvtt_captions.pkl", 'wb') as f:
    pickle.dump(results, f)