# Multimodal Example


In [9]:
# !pip install numpy torch transformers datasets opencv-python

import os
import cv2
import random
import subprocess
import numpy as np
from typing import Dict, List, Optional

import torch
from torch.utils.data import DataLoader, Dataset
from transformers import AutoTokenizer
from datasets import load_dataset
from utils.video_downloader import download_videos, get_available_formats
from utils.dataloader import create_dataloader


os.environ["TOKENIZERS_PARALLELISM"] = "false"

### First things first: load a multimodal dataset

In [10]:
def load_multimodal_dataset(dataset_name: str, split: str = "train"):
    """
    Fetch a multimodal dataset from the Hugging Face Hub and return one split.

    Args:
        dataset_name: Hub repository id of the dataset.
        split: Forwarded to ``load_dataset`` as its second argument, which the
            ``datasets`` library treats as the configuration (subset) name,
            not as a split selector.

    Returns:
        The ``'train'`` entry of the loaded collection when it exists,
        otherwise the first entry the collection lists.
    """
    # Second argument of load_dataset is the config name; spell it out.
    splits = load_dataset(dataset_name, name=split)

    # print(f"Available splits: {list(splits.keys())}")

    if 'train' not in splits:
        # No 'train' entry: fall back to whichever split is listed first.
        fallback_name = list(splits.keys())[0]
        selected = splits[fallback_name]
        print(f"Using split '{fallback_name}' (no 'train' split found)")
    else:
        selected = splits['train']

    print(f"Original dataset size: {len(selected)}")

    return selected


# Load the 7k-video MSR-VTT training configuration and show its summary.
dataset = load_multimodal_dataset(
    dataset_name="friedrichor/MSR-VTT",
    split="train_7k",
)

print(dataset)

Original dataset size: 7010
Dataset({
    features: ['video_id', 'video', 'caption', 'source', 'category', 'url', 'start time', 'end time', 'id'],
    num_rows: 7010
})


### Let's analyse the dataset structure

In [11]:
def analyze_dataset_structure(dataset):
    """Print summary statistics for a captioned video dataset and return them.

    The dataset must expose ``features`` (a mapping keyed by column name),
    ``len()``, integer indexing, and iteration over dict-like rows containing
    at least ``'caption'`` and ``'video_id'``.

    One randomly chosen row is printed as an example. If that row's caption is
    a list, every row contributes only its first caption (empty or non-list
    values are stringified); otherwise captions are used as-is.

    Args:
        dataset: The dataset to inspect.

    Returns:
        dict with keys:
            ``caption_lengths``    word count of each processed caption,
            ``unique_videos``      number of distinct ``video_id`` values,
            ``sample_fields``      keys of the randomly sampled row,
            ``total_samples``      number of rows,
            ``processed_captions`` one caption per row.

    Raises:
        ValueError: If the dataset has no rows.
    """
    total_samples = len(dataset)
    print(f"Dataset Features: {list(dataset.features.keys())}")
    print(f"Dataset size: {total_samples}")

    if total_samples == 0:
        raise ValueError("Cannot analyze an empty dataset")

    # randrange excludes the upper bound. randint(0, len(dataset)) is inclusive
    # and could return len(dataset), which is one past the last valid index.
    random_sample_data = dataset[random.randrange(total_samples)]

    print(f"Sample keys: {random_sample_data.keys()}")
    print(f"Sample values: {random_sample_data}")

    # First, let's examine the structure of captions
    random_sample_data_caption = random_sample_data['caption']
    print(f"\nCaption field type: {type(random_sample_data_caption)}")
    print(f"Sample caption: {random_sample_data_caption}")

    captions_are_lists = isinstance(random_sample_data_caption, list)
    has_category = 'category' in dataset.features

    # Single pass over the dataset: iterating rows is the expensive part, so
    # collect captions, video ids and categories together.
    captions = []
    video_ids = []
    categories = []
    for item in dataset:
        caption = item['caption']
        if captions_are_lists:
            if isinstance(caption, list) and len(caption) > 0:
                caption = caption[0]  # Take first caption
            else:
                caption = str(caption)
        captions.append(caption)
        video_ids.append(item['video_id'])
        if has_category:
            categories.append(item['category'])

    # Caption length is measured in whitespace-separated words
    caption_lengths = [len(str(caption).split()) for caption in captions]

    print(f"\n=== Caption Analysis ===")
    print(f"Average caption length: {np.mean(caption_lengths):.2f} words")
    print(f"Min caption length: {min(caption_lengths)} words")
    print(f"Max caption length: {max(caption_lengths)} words")
    print(f"Median caption length: {np.median(caption_lengths):.2f} words")

    # Analyze video IDs
    unique_videos = len(set(video_ids))
    print(f"\n=== Video Analysis ===")
    print(f"Total samples: {len(video_ids)}")
    print(f"Unique videos: {unique_videos}")
    print(f"Average captions per video: {len(video_ids) / unique_videos:.2f}")

    # Check for categories if available
    if has_category:
        unique_categories = set(categories)
        print(f"\n=== Category Analysis ===")
        print(f"Unique categories: {len(unique_categories)}")
        print(f"Categories: {sorted(unique_categories)}")

    # Show the type and a truncated preview of every field in the sampled row
    print(f"\n=== Available Fields ===")
    for key, value in random_sample_data.items():
        print(f"{key}: {type(value)} - {str(value)[:100]}...")

    return {
        'caption_lengths': caption_lengths,
        'unique_videos': unique_videos,
        'sample_fields': list(random_sample_data.keys()),
        'total_samples': total_samples,
        'processed_captions': captions
    }


# Run the structural analysis; the returned dict (caption lengths, processed
# captions, counts) is reused by the preprocessing cell below.
analysis = analyze_dataset_structure(dataset)

# print(analysis)

Dataset Features: ['video_id', 'video', 'caption', 'source', 'category', 'url', 'start time', 'end time', 'id']
Dataset size: 7010
Sample keys: dict_keys(['video_id', 'video', 'caption', 'source', 'category', 'url', 'start time', 'end time', 'id'])
Sample values: {'video_id': 'video4907', 'video': 'video4907.mp4', 'caption': ['two guys dressed as spiderman are on a basketball court', 'a man in superman mask is talking to another masked man', 'there is a man is walking through the street', 'there is a man with mask is talking with the same', 'there is a mask man is talking with the same', 'spider man starts a fight with a guy in a costume', 'two spidermen meet in the court with ball and have an argument', 'a man in a spiderman costume is standing on a basketball court with another man in a deadpool costume', 'a man walks by a rough-textured wall with a door which turns wavy and results in two superheroes revealing their identity in a park', 'a man with a spiderman costume interacts with

### Now, we have a basic understanding of the data

In [12]:
def preprocess_text_captions(captions: List[str], tokenizer_name: str = "bert-base-uncased", max_length: int = 77) -> Dict:
    """Tokenize a batch of captions and print a short report about the result.

    Args:
        captions: Caption strings to tokenize together as one batch.
        tokenizer_name: Identifier passed to ``AutoTokenizer.from_pretrained``.
        max_length: Upper bound on tokens per caption; longer captions are truncated.

    Returns:
        dict with keys:
            ``input_ids``       token-id tensor for the batch,
            ``attention_mask``  mask tensor matching ``input_ids``,
            ``tokenizer``       the tokenizer instance that was loaded,
            ``token_lengths``   per-caption sum of the attention mask.
    """
    print(f"Preprocessing {len(captions)} captions with {tokenizer_name}")

    tokenizer = AutoTokenizer.from_pretrained(tokenizer_name)

    # Preview the raw text before it is converted to ids
    print("\nOriginal captions (first 3):")
    for position, text in enumerate(captions[:3], start=1):
        print(f"{position}: {text}")

    # Encode the whole batch at once, padded to a common length
    encoding = tokenizer(
        captions,
        padding=True,
        truncation=True,
        max_length=max_length,
        return_tensors="pt",
    )
    input_ids = encoding['input_ids']
    attention_mask = encoding['attention_mask']

    print("\nTokenized output shape:")
    print(f"Input IDs: {input_ids.shape}")
    print(f"Attention mask: {attention_mask.shape}")

    # Decode the first row back to token strings as a sanity check
    print("\nTokenized example (first caption):")
    first_row_tokens = tokenizer.convert_ids_to_tokens(input_ids[0])
    print(f"Tokens: {first_row_tokens[:15]}...")  # Show first 15 tokens

    # Summing the mask along each row gives that caption's token count
    token_lengths = attention_mask.sum(dim=1)
    print("\nTokenization statistics:")
    print(f"Average tokens per caption: {token_lengths.float().mean():.2f}")
    print(f"Max tokens used: {token_lengths.max().item()}")
    print(f"Min tokens used: {token_lengths.min().item()}")

    return {
        'input_ids': input_ids,
        'attention_mask': attention_mask,
        'tokenizer': tokenizer,
        'token_lengths': token_lengths,
    }

# Test text preprocessing with actual captions from dataset
# Tokenize only the first 10 processed captions (one per sample, produced by
# analyze_dataset_structure) to keep this demonstration quick.
sample_captions = analysis['processed_captions'][:10]  # Use processed captions from analysis
text_data = preprocess_text_captions(sample_captions)

Preprocessing 10 captions with bert-base-uncased

Original captions (first 3):
1: a car is shown
2: in a kitchen a woman adds different ingredients into the pot and stirs it
3: a guying showing a tool

Tokenized output shape:
Input IDs: torch.Size([10, 17])
Attention mask: torch.Size([10, 17])

Tokenized example (first caption):
Tokens: ['[CLS]', 'a', 'car', 'is', 'shown', '[SEP]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]']...

Tokenization statistics:
Average tokens per caption: 10.50
Max tokens used: 17
Min tokens used: 6


### Download a video to analyse later

In [13]:
# Pick one random record and download its video.
# random.randrange excludes the upper bound, so the index is always valid;
# random.randint(0, len(dataset)) is inclusive and could return len(dataset),
# which raises IndexError.
random_sample_data = dataset[random.randrange(len(dataset))]
downloaded_path = download_videos(random_sample_data)
print(downloaded_path)

Getting available formats for video2383...
Available formats: ['269', '230']
Trying format ID: 269
Downloaded: video2383 with format 269
videos/video2383.mp4


---
## So far:
* Loaded the multimodal dataset "friedrichor/MSR-VTT" from Hugging Face
* Analysed the dataset structure:
    * Dataset features: the columns/keys present in the dataset
    * Dataset size
    * Inspected a random sample record and its values
    * Examined the `caption` field, found that it is a list of captions, and computed the average, minimum, maximum and median length (in words) of the first caption of each sample
    * Counted total samples and unique videos, and computed the average number of samples per unique video
    * Listed the video categories
* Preprocessed the text (here, the captions) by tokenizing and padding them to a common length
* Downloaded one of the videos referenced in the dataset

### Time To Create Custom **MultimodalDataset** Class!

In [14]:
class MultimodalDataset(Dataset):
    """Torch dataset pairing a tokenized caption with frames sampled from its video.

    Only rows whose ``<video_dir>/<video_id>.mp4`` file exists on disk are
    exposed; ``valid_indices`` maps this dataset's indices back to rows of the
    wrapped dataset.

    Args:
        dataset: Indexable collection of dict-like rows with the keys read in
            ``__getitem__`` (``video_id``, ``video``, ``caption``, ``source``,
            ``category``, ``url``, ``start time``, ``end time``, ``id``).
        video_dir: Directory containing the downloaded ``.mp4`` files.
        max_frames: Number of frames returned per video.
        tokenizer_name: Identifier passed to ``AutoTokenizer.from_pretrained``.
        max_text_length: Fixed token length every caption is padded/truncated to.
        frame_size: Side length in pixels of the square frames returned.
    """

    def __init__(self, dataset, video_dir="videos", max_frames=8,
                 tokenizer_name="bert-base-uncased", max_text_length=77, frame_size=224):
        self.dataset = dataset
        self.video_dir = video_dir
        self.max_frames = max_frames
        self.max_text_length = max_text_length
        self.frame_size = frame_size
        self.tokenizer = AutoTokenizer.from_pretrained(tokenizer_name)

        # Keep only rows whose video file has actually been downloaded
        self.valid_indices = [
            i for i in range(len(dataset))
            if os.path.exists(self._video_path(dataset[i]['video_id']))
        ]

        print(f"Found {len(self.valid_indices)} valid samples")

    def __len__(self):
        """Return the number of rows that have a video file on disk."""
        return len(self.valid_indices)

    def __getitem__(self, idx):
        """Return the tokenized caption, video frames and metadata for one sample.

        ``video_frames`` is whatever ``_load_video_cv2`` returns, which is
        ``None`` when the video cannot be decoded; the collate function used
        with this dataset has to cope with that.
        """
        sample = self.dataset[self.valid_indices[idx]]

        # A row may carry several captions; use the first one. An empty list
        # becomes an empty string instead of raising IndexError.
        caption = sample.get('caption', '')
        if isinstance(caption, list):
            caption = caption[0] if caption else ''

        # padding='max_length' gives every sample the same text length
        tokenized = self.tokenizer(
            caption, padding='max_length', truncation=True,
            max_length=self.max_text_length, return_tensors='pt'
        )

        video_frames = self._load_video_cv2(self._video_path(sample['video_id']))

        return {
            # The tokenizer returns a leading batch axis of size 1; drop it
            'text_ids': tokenized['input_ids'].squeeze(0),
            'text_mask': tokenized['attention_mask'].squeeze(0),
            'caption': caption,
            'video_frames': video_frames,
            'video_id': sample['video_id'],
            'video': sample['video'],
            'source': sample['source'],
            'category': sample['category'],
            'url': sample['url'],
            'start_time': sample['start time'],
            'end_time': sample['end time'],
            'id': sample['id']
        }

    def _video_path(self, video_id):
        """Build the expected on-disk path of the video with the given id."""
        return os.path.join(self.video_dir, f"{video_id}.mp4")

    def _sample_frame_indices(self, total_frames):
        """Return up to ``max_frames`` frame indices spread evenly over the video."""
        if total_frames <= self.max_frames:
            # Short video: take every frame; the caller pads the remainder
            return list(range(total_frames))
        return [
            int(i * total_frames / self.max_frames)
            for i in range(self.max_frames)
        ]

    def _preprocess_frame(self, frame):
        """Convert one BGR frame array into a float RGB tensor of shape (3, S, S) in [0, 1]."""
        # OpenCV decodes to BGR; convert to RGB
        frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
        # No interpolation flag is passed, so OpenCV's default (bilinear) is used
        frame = cv2.resize(frame, (self.frame_size, self.frame_size))
        # (H, W, C) uint8 -> (C, H, W) float32 scaled from [0, 255] to [0, 1]
        return torch.tensor(frame, dtype=torch.float32).permute(2, 0, 1) / 255.0

    def _load_video_cv2(self, video_path):
        """Load a fixed number of preprocessed frames from a video with OpenCV.

        Frames are sampled at evenly spaced positions, converted to RGB,
        resized to ``frame_size`` x ``frame_size`` and scaled to [0, 1]. If
        fewer than ``max_frames`` frames are read (short video, or a read
        failure part-way through), the last good frame is repeated so the
        output always has ``max_frames`` frames.

        Args:
            video_path (str): Path to the video file.

        Returns:
            torch.Tensor | None: Tensor of shape
            ``(max_frames, 3, frame_size, frame_size)``, or ``None`` when the
            file cannot be opened, reports no frames or a non-positive fps,
            yields no readable frame, or any exception occurs while decoding.
            No exception propagates to the caller.
        """
        cap = None
        try:
            cap = cv2.VideoCapture(video_path)

            if not cap.isOpened():
                print(f"Could not open video: {video_path}")
                return None

            total_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
            fps = cap.get(cv2.CAP_PROP_FPS)

            # A video reporting no frames or a non-positive fps is treated as unusable
            if total_frames <= 0 or fps <= 0:
                print(f"Invalid video properties: frames={total_frames}, fps={fps}")
                return None

            frames = []
            for frame_idx in self._sample_frame_indices(total_frames):
                # Seek to the 0-based frame position, then decode that frame
                cap.set(cv2.CAP_PROP_POS_FRAMES, frame_idx)
                ret, frame = cap.read()

                if not ret or frame is None:
                    # Stop at the first unreadable frame; padding below fills the rest
                    print(f"Could not read frame {frame_idx}")
                    break

                frames.append(self._preprocess_frame(frame))

            if not frames:
                # Nothing could be decoded, so there is no frame to pad with
                return None

            # Repeat the last good frame so every sample has the same shape
            while len(frames) < self.max_frames:
                frames.append(frames[-1].clone())

            return torch.stack(frames)

        except Exception as e:
            # Best effort: decoding problems are reported and signalled with None
            print(f"Error loading video {video_path}: {e}")
            return None
        finally:
            # Release the capture on every path, including exceptions,
            # so the file handle is never leaked
            if cap is not None:
                cap.release()

### Let's use MultimodalDataset to analyse videos

In [15]:
# Wrap the Hugging Face dataset and batch it two samples at a time
multimodal_dataset = MultimodalDataset(dataset, video_dir="videos")
dataloader = create_dataloader(multimodal_dataset, batch_size=2)

# Pull a single batch as a smoke test and list the fields it carries
batch = next(iter(dataloader))
print("Available keys:")
for key in batch.keys():
    print(f"  - {key}")

num_videos = len(batch['captions'])
video_frames = batch['video_frames']

print(f"\nBatch size: {num_videos}")
print(f"Video shape: {video_frames.shape}")
print(f"Text shape: {batch['text_ids'].shape}")
print(f"Video frames range: [{video_frames.min():.3f}, {video_frames.max():.3f}]")

# Display data for ALL videos in the batch
for i in range(num_videos):
    start_time = batch['start_times'][i]
    end_time = batch['end_times'][i]

    print(f"\n=== Video {i+1} ===")
    print(f"URL: {batch['urls'][i]}")
    print(f"Caption: {batch['captions'][i]}")
    print(f"Video ID: {batch['video_ids'][i]}")
    print(f"Source: {batch['sources'][i]}")
    print(f"Category: {batch['categories'][i]}")
    print(f"Start time: {start_time}")
    print(f"End time: {end_time}")
    print(f"Duration: {end_time - start_time:.1f}s")

Found 3 valid samples
Available keys:
  - text_ids
  - text_masks
  - captions
  - video_frames
  - video_ids
  - videos
  - sources
  - categories
  - urls
  - start_times
  - end_times
  - ids

Batch size: 2
Video shape: torch.Size([2, 8, 3, 224, 224])
Text shape: torch.Size([2, 77])
Video frames range: [0.000, 1.000]

=== Video 1 ===
URL: https://www.youtube.com/watch?v=trHNRK7NfUc
Caption: a clip of a football team celebrating
Video ID: video2383
Source: MSR-VTT
Category: 3
Start time: 17.75
End time: 28.81
Duration: 11.1s

=== Video 2 ===
URL: https://www.youtube.com/watch?v=c0uW5eQqQjM
Caption: a boy is playing a video game
Video ID: video4471
Source: MSR-VTT
Category: 2
Start time: 524.24
End time: 535.65
Duration: 11.4s
