# Imports

## Required Imports

In [9]:
import os
import json
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
from transformers import BertModel, BertTokenizer
import torchvision.models as models
from typing import List, Tuple
import cv2
import numpy as np
from tqdm import tqdm  # for progress bars

## Video Encoder

In [10]:
class VideoEncoder(nn.Module):
    def __init__(self, d_model=768):
        super().__init__()
        self.spatial_encoder = models.resnet50(pretrained=True)
        self.spatial_encoder = nn.Sequential(*list(self.spatial_encoder.children())[:-1])
        self.projection = nn.Linear(2048, d_model)

        # Modified transformer encoder
        encoder_layer = nn.TransformerEncoderLayer(
            d_model=d_model,
            nhead=8,
            dim_feedforward=2048,
            batch_first=True  # Important for dimension handling
        )
        self.temporal_encoder = nn.TransformerEncoder(encoder_layer, num_layers=3)

    def forward(self, x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
        B, F, C, H, W = x.shape

        # Process each frame
        x = x.view(B * F, C, H, W)
        spatial_features = self.spatial_encoder(x)
        spatial_features = spatial_features.squeeze(-1).squeeze(-1)  # Remove spatial dimensions
        spatial_features = self.projection(spatial_features)

        # Reshape for temporal processing
        spatial_features = spatial_features.view(B, F, -1)  # [batch, frames, features]

        # Temporal encoding (now with batch_first=True)
        temporal_features = self.temporal_encoder(spatial_features)

        return temporal_features, spatial_features

## CCM Module

In [11]:
class CCMModule(nn.Module):
    def __init__(self, d_model=768, num_heads=8, num_queries=8):
        super().__init__()
        self.num_queries = num_queries
        self.query_centers = nn.Parameter(torch.randn(num_queries, d_model))

        # Modified decoder layer
        self.decoder = nn.TransformerDecoderLayer(
            d_model=d_model,
            nhead=num_heads,
            dim_feedforward=2048,
            batch_first=True  # Important for dimension handling
        )

    def forward(self, video_features: torch.Tensor, text_features: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
        batch_size = video_features.size(0)

        # Expand queries for batch dimension
        query = self.query_centers.unsqueeze(0).expand(batch_size, -1, -1)

        # Apply decoder to both modalities
        video_aligned = self.decoder(query, video_features)
        text_aligned = self.decoder(query, text_features)

        return video_aligned, text_aligned

## CRET Model

In [12]:
class CRET(nn.Module):
    def __init__(self, d_model=768):
        super().__init__()
        self.video_encoder = VideoEncoder(d_model)
        self.text_encoder = BertModel.from_pretrained('bert-base-uncased')
        self.ccm = CCMModule(d_model)

    def forward(self, video: torch.Tensor, text_ids: torch.Tensor, text_mask: torch.Tensor):
        temporal_features, spatial_features = self.video_encoder(video)
        text_outputs = self.text_encoder(text_ids, attention_mask=text_mask)
        text_features = text_outputs.last_hidden_state
        video_aligned, text_aligned = self.ccm(spatial_features, text_features)
        video_global = temporal_features.mean(1)
        text_global = text_outputs.pooler_output
        return {
            'video_aligned': video_aligned,
            'text_aligned': text_aligned,
            'video_global': video_global,
            'text_global': text_global
        }

## GEES Loss

In [13]:
class GEESLoss(nn.Module):
    def __init__(self):
        super().__init__()

    def forward(self, video_features: torch.Tensor, text_features: torch.Tensor):
        # Make sure video_features has the right shape [batch_size, seq_len, features]
        if len(video_features.shape) == 2:
            video_features = video_features.unsqueeze(1)

        # Compute mean over sequence dimension (if exists)
        if len(video_features.shape) == 3:
            video_mean = video_features.mean(1)
        else:
            video_mean = video_features

        # Compute covariance
        batch_size = video_features.size(0)
        feature_dim = video_features.size(-1)

        # Reshape video features for batch matrix multiplication
        video_features_reshaped = video_features.view(batch_size, -1, feature_dim)

        # Compute covariance for each item in the batch
        video_cov = torch.bmm(video_features_reshaped.transpose(1, 2),
                            video_features_reshaped) / video_features_reshaped.size(1)

        # Compute similarity scores
        sim = torch.mm(text_features, video_mean.t())

        # Compute covariance term
        cov_term = 0.5 * torch.bmm(
            torch.bmm(text_features.unsqueeze(1), video_cov),
            text_features.unsqueeze(-1)
        ).squeeze()

        # Compute logits and loss
        logits = sim + cov_term
        labels = torch.arange(batch_size).to(video_features.device)

        return F.cross_entropy(logits, labels)

## Dataset

In [14]:
class MSRVTTDataset(Dataset):
    def __init__(self, video_paths, captions, tokenizer):
        self.video_paths = video_paths
        self.captions = captions
        self.tokenizer = tokenizer

        # Add cv2 import inside the class
        import cv2
        self.cv2 = cv2

    def __len__(self):
        return len(self.video_paths)

    def __getitem__(self, idx):
        video = self.load_video(self.video_paths[idx])
        caption = self.captions[idx]
        tokens = self.tokenizer(
            caption,
            padding='max_length',
            truncation=True,
            max_length=128,
            return_tensors='pt'
        )
        return video, tokens.input_ids.squeeze(0), tokens.attention_mask.squeeze(0)

    def load_video(self, path):
        import numpy as np
        frames = []

        # Load video with OpenCV
        cap = self.cv2.VideoCapture(path)

        # Sample 4 frames as mentioned in the paper
        total_frames = int(cap.get(self.cv2.CAP_PROP_FRAME_COUNT))
        indices = np.linspace(0, total_frames-1, 4, dtype=int)

        for frame_idx in indices:
            cap.set(self.cv2.CAP_PROP_POS_FRAMES, frame_idx)
            ret, frame = cap.read()
            if ret:
                # Convert BGR to RGB
                frame = self.cv2.cvtColor(frame, self.cv2.COLOR_BGR2RGB)
                # Resize to 224x224
                frame = self.cv2.resize(frame, (224, 224))
                # Convert to tensor and normalize
                frame = torch.from_numpy(frame).permute(2, 0, 1).float() / 255.0
                frames.append(frame)

        cap.release()

        # Stack frames
        if len(frames) == 0:  # Handle empty video case
            return torch.zeros((4, 3, 224, 224))

        video_tensor = torch.stack(frames)
        return video_tensor

## Training Function

In [15]:
def train_cret(model, train_loader, num_epochs=2, device='cpu'):
    print(f"Training on device: {device}")
    model = model.to(device)
    gees_loss = GEESLoss()
    optimizer = torch.optim.Adam(model.parameters(), lr=1e-4)

    for epoch in range(num_epochs):
        model.train()
        total_loss = 0

        for batch_idx, (videos, text_ids, text_mask) in enumerate(train_loader):
            try:
                # Move to device
                videos = videos.to(device).float()
                text_ids = text_ids.to(device)
                text_mask = text_mask.to(device)

                # Forward pass
                outputs = model(videos, text_ids, text_mask)

                # Print shapes for debugging
                if batch_idx == 0:
                    print(f"Video global shape: {outputs['video_global'].shape}")
                    print(f"Text global shape: {outputs['text_global'].shape}")
                    print(f"Video aligned shape: {outputs['video_aligned'].shape}")
                    print(f"Text aligned shape: {outputs['text_aligned'].shape}")

                # Compute losses
                ccm_loss = F.cosine_embedding_loss(
                    outputs['video_aligned'].view(-1, 768),
                    outputs['text_aligned'].view(-1, 768),
                    torch.ones(outputs['video_aligned'].size(0) * outputs['video_aligned'].size(1)).to(device)
                )

                gees_loss_val = gees_loss(outputs['video_global'], outputs['text_global'])
                loss = ccm_loss + gees_loss_val

                # Backward pass
                optimizer.zero_grad()
                loss.backward()
                optimizer.step()

                total_loss += loss.item()

                if batch_idx % 10 == 0:
                    print(f'Epoch: {epoch}, Batch: {batch_idx}, Loss: {loss.item():.4f}')

            except Exception as e:
                print(f"Error in batch {batch_idx}: {str(e)}")
                print(f"Shapes:")
                for k, v in outputs.items():
                    print(f"{k}: {v.shape}")
                continue

        avg_loss = total_loss / len(train_loader)
        print(f'Epoch {epoch} completed, Average Loss: {avg_loss:.4f}')

## Main Execution

In [16]:
def main():
    # Initialize tokenizer
    tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')

    # Load MSRVTT dataset
    msrvtt_root = './data/MSRVTT'
    # Create the videos directory if it doesn't exist
    os.makedirs(os.path.join(msrvtt_root, 'videos'), exist_ok=True)

    # Download video data and captions if they don't exist
    if not os.listdir(os.path.join(msrvtt_root, 'videos')):
        print("Downloading video data...")
        # Replace <updated_video_data_download_link> with the actual download link
        os.system('wget -P ./data/MSRVTT/videos <updated_video_data_download_link>')
    if not os.path.exists(os.path.join(msrvtt_root, 'captions.json')):
        print("Downloading captions...")
        os.system('wget -P ./data/MSRVTT https://www.robots.ox.ac.uk/~vgg/research/collaborative-experts/data/msrvtt_data.json')

    video_paths = [os.path.join(msrvtt_root, 'videos', file) for file in os.listdir(os.path.join(msrvtt_root, 'videos')) if file.endswith('.mp4')]

    with open(os.path.join(msrvtt_root, 'captions.json'), 'r') as f:
        captions_data = json.load(f)

    captions = [captions_data[video_id[:-4]] for video_id in os.listdir(os.path.join(msrvtt_root, 'videos')) if video_id.endswith('.mp4')]
    captions = [[caption['caption'] for caption in captions_per_video] for captions_per_video in captions]
    # Create dataset and dataloader
    dataset = MSRVTTDataset(video_paths, captions, tokenizer)
    train_loader = DataLoader(dataset, batch_size=32, shuffle=True)

    # Initialize model
    model = CRET()

    # Train model
    train_cret(model, train_loader)

In [17]:
def main():
    # Initialize tokenizer
    tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')

    # Set up paths for your Colab environment
    msrvtt_root = '/content/MSRVTT'  # Adjust this to your actual path

    # Get video paths and captions
    video_paths = [os.path.join(msrvtt_root, 'videos', file)
                  for file in os.listdir(os.path.join(msrvtt_root, 'videos'))
                  if file.endswith('.mp4')]

    # Load captions
    with open(os.path.join(msrvtt_root, 'annotation', 'captions.json'), 'r') as f:
        captions_data = json.load(f)

    # Create dataset and dataloader
    dataset = MSRVTTDataset(video_paths[:100], captions[:100], tokenizer)  # Start with 100 videos for testing
    train_loader = DataLoader(dataset, batch_size=8, shuffle=True)  # Smaller batch size for Colab

    # Initialize model
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    model = CRET().to(device)

    # Train model
    train_cret(model, train_loader, num_epochs=5, device=device)  # Start with 5 epochs for testing

In [18]:
import torch
print(f"GPU available: {torch.cuda.is_available()}")
if torch.cuda.is_available():
    print(f"GPU device: {torch.cuda.get_device_name(0)}")

GPU available: False


In [19]:
# Check current directory and list files
print("Current directory:", os.getcwd())
print("\nContents of current directory:")
!ls

print("\nContents of /content directory:")
!ls /content

Current directory: /content

Contents of current directory:
MSRVTT	sample_data

Contents of /content directory:
MSRVTT	sample_data


In [20]:
print("Contents of MSRVTT folder:")
!ls /content/MSRVTT

Contents of MSRVTT folder:


In [33]:
def test_dataset():
    print("Testing dataset loading...")

    # Initialize tokenizer
    tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')

    # Set paths
    msrvtt_root = '/content/MSRVTT'
    print(f"\nChecking contents of MSRVTT directory:")
    !ls {msrvtt_root}

    # Get video paths
    video_paths = [os.path.join(msrvtt_root, f) for f in os.listdir(msrvtt_root) if f.endswith('.mp4')]
    print(f"\nFound {len(video_paths)} video files")

    # For testing, let's use just one video with a dummy caption
    test_dataset = MSRVTTDataset([video_paths[0]], ['a person is doing something'], tokenizer)
    print("\nTrying to load a single video...")

    try:
        video, text_ids, text_mask = test_dataset[0]
        print(f"\nSuccessfully loaded:")
        print(f"Video tensor shape: {video.shape}")
        print(f"Text IDs shape: {text_ids.shape}")
        print(f"Text mask shape: {text_mask.shape}")
    except Exception as e:
        print(f"Error loading video: {e}")

test_dataset()

Testing dataset loading...

Checking contents of MSRVTT directory:
test_annotations.json  video249.mp4  video39.mp4   video54.mp4	 video6.mp4    video850.mp4
video0.mp4	       video24.mp4   video3.mp4    video550.mp4  video700.mp4  video851.mp4
video1000.mp4	       video250.mp4  video400.mp4  video551.mp4  video701.mp4  video852.mp4
video100.mp4	       video251.mp4  video401.mp4  video552.mp4  video702.mp4  video853.mp4
video101.mp4	       video252.mp4  video402.mp4  video553.mp4  video703.mp4  video854.mp4
video102.mp4	       video253.mp4  video403.mp4  video554.mp4  video704.mp4  video855.mp4
video103.mp4	       video254.mp4  video404.mp4  video555.mp4  video705.mp4  video856.mp4
video104.mp4	       video255.mp4  video405.mp4  video556.mp4  video706.mp4  video857.mp4
video105.mp4	       video256.mp4  video406.mp4  video557.mp4  video707.mp4  video858.mp4
video106.mp4	       video257.mp4  video407.mp4  video558.mp4  video708.mp4  video859.mp4
video107.mp4	       video258.mp4  video408

In [34]:
def create_test_annotations():
    video_files = [f for f in os.listdir('/content/MSRVTT') if f.endswith('.mp4')]
    annotations = {}

    for video_file in video_files:
        video_id = video_file[:-4]  # Remove .mp4
        annotations[video_id] = [{
            'caption': f'This is a test caption for video {video_id}',
            'video_id': video_id
        }]

    with open('/content/MSRVTT/test_annotations.json', 'w') as f:
        json.dump(annotations, f)

    print("Created test annotations file")

In [35]:
# Create test annotations
def create_test_annotations():
    video_files = [f for f in os.listdir('/content/MSRVTT') if f.endswith('.mp4')]
    annotations = {}

    for video_file in video_files:
        video_id = video_file[:-4]  # Remove .mp4
        annotations[video_id] = [{
            'caption': f'This is a test caption for video {video_id}',
            'video_id': video_id
        }]

    with open('/content/MSRVTT/test_annotations.json', 'w') as f:
        json.dump(annotations, f)

    print("Created test annotations file")

create_test_annotations()

Created test annotations file


In [36]:
class MSRVTTDataset(Dataset):
    def __init__(self, video_paths, captions, tokenizer):
        self.video_paths = video_paths
        self.captions = captions
        self.tokenizer = tokenizer

        # Add cv2 import inside the class
        import cv2
        self.cv2 = cv2

    def __len__(self):
        return len(self.video_paths)

    def __getitem__(self, idx):
        video = self.load_video(self.video_paths[idx])
        caption = self.captions[idx]
        tokens = self.tokenizer(
            caption,
            padding='max_length',
            truncation=True,
            max_length=128,
            return_tensors='pt'
        )
        return video, tokens.input_ids.squeeze(0), tokens.attention_mask.squeeze(0)

    def load_video(self, path):
        import numpy as np
        frames = []

        try:
            # Load video with OpenCV
            cap = self.cv2.VideoCapture(path)

            # Get total frames
            total_frames = int(cap.get(self.cv2.CAP_PROP_FRAME_COUNT))
            if total_frames == 0:
                # Return dummy frames if video is empty
                return torch.zeros((4, 3, 224, 224))

            # Sample 4 evenly spaced frames
            indices = np.linspace(0, total_frames-1, 4, dtype=int)

            for frame_idx in indices:
                cap.set(self.cv2.CAP_PROP_POS_FRAMES, frame_idx)
                ret, frame = cap.read()
                if ret:
                    # Convert BGR to RGB
                    frame = self.cv2.cvtColor(frame, self.cv2.COLOR_BGR2RGB)
                    # Resize to 224x224
                    frame = self.cv2.resize(frame, (224, 224))
                    # Convert to tensor and normalize
                    frame = torch.from_numpy(frame).permute(2, 0, 1).float() / 255.0
                    frames.append(frame)
                else:
                    # If frame read fails, append zero frame
                    frames.append(torch.zeros((3, 224, 224)))

            cap.release()

            # Ensure we have exactly 4 frames
            while len(frames) < 4:
                frames.append(torch.zeros((3, 224, 224)))
            frames = frames[:4]  # Take only first 4 if we somehow got more

            # Stack frames
            video_tensor = torch.stack(frames)
            return video_tensor

        except Exception as e:
            print(f"Error loading video {path}: {str(e)}")
            # Return dummy frames if loading fails
            return torch.zeros((4, 3, 224, 224))

## Monitor Training

In [37]:
from tqdm import tqdm  # For progress bars

def train_cret(model, train_loader, num_epochs=2, device='cpu'):
    print(f"Training on device: {device}")
    model = model.to(device)
    gees_loss = GEESLoss()
    optimizer = torch.optim.Adam(model.parameters(), lr=1e-4)

    for epoch in range(num_epochs):
        model.train()
        total_loss = 0
        progress_bar = tqdm(train_loader, desc=f'Epoch {epoch+1}/{num_epochs}')

        for batch_idx, (videos, text_ids, text_mask) in enumerate(progress_bar):
            try:
                videos = videos.to(device).float()
                text_ids = text_ids.to(device)
                text_mask = text_mask.to(device)

                outputs = model(videos, text_ids, text_mask)

                ccm_loss = F.cosine_embedding_loss(
                    outputs['video_aligned'].view(-1, 768),
                    outputs['text_aligned'].view(-1, 768),
                    torch.ones(outputs['video_aligned'].size(0) * outputs['video_aligned'].size(1)).to(device)
                )

                gees_loss_val = gees_loss(outputs['video_global'], outputs['text_global'])
                loss = ccm_loss + gees_loss_val

                optimizer.zero_grad()
                loss.backward()
                optimizer.step()

                total_loss += loss.item()
                progress_bar.set_postfix({'loss': loss.item()})

            except Exception as e:
                print(f"Error in batch {batch_idx}: {str(e)}")
                continue

        avg_loss = total_loss / len(train_loader)
        print(f'Epoch {epoch+1} completed, Average Loss: {avg_loss:.4f}')

## Preserve Trained Model

In [38]:
def save_checkpoint(model, optimizer, epoch, loss, filename):
    torch.save({
        'epoch': epoch,
        'model_state_dict': model.state_dict(),
        'optimizer_state_dict': optimizer.state_dict(),
        'loss': loss,
    }, filename)
    print(f"Checkpoint saved: {filename}")

def train_cret(model, train_loader, num_epochs=2, device='cpu'):
    print(f"Training on device: {device}")
    model = model.to(device)
    gees_loss = GEESLoss()
    optimizer = torch.optim.Adam(model.parameters(), lr=1e-4)

    for epoch in range(num_epochs):
        model.train()
        total_loss = 0
        progress_bar = tqdm(train_loader, desc=f'Epoch {epoch+1}/{num_epochs}')

        for batch_idx, (videos, text_ids, text_mask) in enumerate(progress_bar):
            try:
                videos = videos.to(device).float()
                text_ids = text_ids.to(device)
                text_mask = text_mask.to(device)

                outputs = model(videos, text_ids, text_mask)

                ccm_loss = F.cosine_embedding_loss(
                    outputs['video_aligned'].view(-1, 768),
                    outputs['text_aligned'].view(-1, 768),
                    torch.ones(outputs['video_aligned'].size(0) * outputs['video_aligned'].size(1)).to(device)
                )

                gees_loss_val = gees_loss(outputs['video_global'], outputs['text_global'])
                loss = ccm_loss + gees_loss_val

                optimizer.zero_grad()
                loss.backward()
                optimizer.step()

                total_loss += loss.item()
                progress_bar.set_postfix({'loss': loss.item()})

            except Exception as e:
                print(f"Error in batch {batch_idx}: {str(e)}")
                continue

        avg_loss = total_loss / len(train_loader)
        print(f'Epoch {epoch+1} completed, Average Loss: {avg_loss:.4f}')

        # Save checkpoint after each epoch
        save_checkpoint(model, optimizer, epoch, avg_loss,
                      f'cret_checkpoint_epoch_{epoch+1}.pt')

## Test Retrieval

In [39]:
def test_retrieval(model, video_path, query_text, tokenizer, device='cpu'):
    model.eval()

    # Load and process video
    dataset = MSRVTTDataset([video_path], [query_text], tokenizer)
    video, text_ids, text_mask = dataset[0]

    with torch.no_grad():
        video = video.unsqueeze(0).to(device)
        text_ids = text_ids.unsqueeze(0).to(device)
        text_mask = text_mask.unsqueeze(0).to(device)

        outputs = model(video, text_ids, text_mask)

        # Compute similarity score
        video_emb = outputs['video_global']
        text_emb = outputs['text_global']
        similarity = F.cosine_similarity(video_emb, text_emb)

        return similarity.item()

## Run Training

In [40]:
def create_test_annotations():
    # Get list of videos
    video_files = sorted([f for f in os.listdir('MSRVTT') if f.endswith('.mp4')])
    print(f"Found {len(video_files)} video files")

    # Create annotations dictionary
    annotations = {}
    for video_file in video_files:
        video_id = video_file[:-4]  # Remove .mp4
        # Store multiple captions per video for better training
        annotations[video_id] = [
            {'caption': f'A video showing content from {video_id}'},
            {'caption': f'This is video file {video_id}'},
            {'caption': f'A clip from video sequence {video_id}'}
        ]

    # Save annotations
    with open('MSRVTT/test_annotations.json', 'w') as f:
        json.dump(annotations, f)

    print("Created test annotations file")
    return annotations  # Added this line to return the annotations

def main():
    # Initialize tokenizer
    tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')

    # Set correct path
    msrvtt_root = 'MSRVTT'
    video_paths = sorted([os.path.join(msrvtt_root, f) for f in os.listdir(msrvtt_root) if f.endswith('.mp4')])
    annotations = create_test_annotations()  # Now this will receive the returned annotations

    # Get captions for each video
    captions = []
    for video_path in video_paths:
        video_id = os.path.basename(video_path)[:-4]
        if video_id in annotations:
            caption = annotations[video_id][0]['caption']
            captions.append(caption)

    print(f"Prepared {len(captions)} video-caption pairs")
    dataset = MSRVTTDataset(video_paths[:len(captions)], captions, tokenizer)
    train_loader = DataLoader(dataset, batch_size=2, shuffle=True)

    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    model = CRET().to(device)

    # Train model with progress bar
    train_cret(model, train_loader, num_epochs=2, device=device)

if __name__ == "__main__":
    main()

Found 999 video files
Created test annotations file
Prepared 999 video-caption pairs
Training on device: cpu


Epoch 1/2: 100%|██████████| 500/500 [1:22:54<00:00,  9.95s/it, loss=0.000117]


Epoch 1 completed, Average Loss: 21.1536
Checkpoint saved: cret_checkpoint_epoch_1.pt


Epoch 2/2: 100%|██████████| 500/500 [1:25:22<00:00, 10.25s/it, loss=7.23e-5]


Epoch 2 completed, Average Loss: 1.1550
Checkpoint saved: cret_checkpoint_epoch_2.pt


## CRET Model

In [44]:
# Check what variables exist
print("Current variables:")
%who

Current variables:
BertModel	 BertTokenizer	 CCMModule	 CRET	 DataLoader	 Dataset	 F	 GEESLoss	 List	 
MSRVTTDataset	 Tuple	 VideoEncoder	 create_test_annotations	 cv2	 json	 main	 model_save_path	 models	 
nn	 np	 os	 save_checkpoint	 test_dataset	 test_retrieval	 torch	 tqdm	 train_cret	 



In [3]:
# Import required libraries
import os
import json
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
from transformers import BertTokenizer, BertModel
import torchvision.models as models
from typing import List, Tuple
import numpy as np
from tqdm import tqdm

# 1. Define VideoEncoder
class VideoEncoder(nn.Module):
    def __init__(self, d_model=768):
        super().__init__()
        self.spatial_encoder = models.resnet50(pretrained=True)
        self.spatial_encoder = nn.Sequential(*list(self.spatial_encoder.children())[:-1])
        self.projection = nn.Linear(2048, d_model)
        encoder_layer = nn.TransformerEncoderLayer(d_model=d_model, nhead=8)
        self.temporal_encoder = nn.TransformerEncoder(encoder_layer, num_layers=3)

    def forward(self, x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
        B, F, C, H, W = x.shape
        x = x.view(B*F, C, H, W)
        spatial_features = self.spatial_encoder(x)
        spatial_features = spatial_features.view(B, F, -1)
        spatial_features = self.projection(spatial_features)
        temporal_features = self.temporal_encoder(spatial_features.transpose(0, 1)).transpose(0, 1)
        return temporal_features, spatial_features

# 2. Define CCMModule
class CCMModule(nn.Module):
    def __init__(self, d_model=768, num_heads=8, num_queries=8):
        super().__init__()
        self.num_queries = num_queries
        self.query_centers = nn.Parameter(torch.randn(num_queries, d_model))
        self.decoder = nn.TransformerDecoderLayer(d_model=d_model, nhead=num_heads)

    def forward(self, video_features: torch.Tensor, text_features: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
        query = self.query_centers.unsqueeze(0).expand(video_features.size(0), -1, -1)
        video_aligned = self.decoder(query, video_features)
        text_aligned = self.decoder(query, text_features)
        return video_aligned, text_aligned

# 3. Define CRET
class CRET(nn.Module):
    def __init__(self, d_model=768):
        super().__init__()
        self.video_encoder = VideoEncoder(d_model)
        self.text_encoder = BertModel.from_pretrained('bert-base-uncased')
        self.ccm = CCMModule(d_model)

    def forward(self, video: torch.Tensor, text_ids: torch.Tensor, text_mask: torch.Tensor):
        temporal_features, spatial_features = self.video_encoder(video)
        text_outputs = self.text_encoder(text_ids, attention_mask=text_mask)
        text_features = text_outputs.last_hidden_state
        video_aligned, text_aligned = self.ccm(spatial_features, text_features)
        video_global = temporal_features.mean(1)
        text_global = text_outputs.pooler_output
        return {
            'video_aligned': video_aligned,
            'text_aligned': text_aligned,
            'video_global': video_global,
            'text_global': text_global
        }

# 4. Define GEESLoss
class GEESLoss(nn.Module):
    def __init__(self):
        super().__init__()

    def forward(self, video_features: torch.Tensor, text_features: torch.Tensor):
        video_mean = video_features.mean(1)
        video_cov = torch.bmm(video_features.transpose(1, 2), video_features) / video_features.size(1)
        sim = torch.mm(text_features, video_mean.t())
        cov_term = 0.5 * torch.bmm(torch.bmm(text_features.unsqueeze(1), video_cov), text_features.unsqueeze(-1)).squeeze()
        logits = sim + cov_term
        labels = torch.arange(video_features.size(0)).to(video_features.device)
        return F.cross_entropy(logits, labels)

# 5. Define MSRVTTDataset
class MSRVTTDataset(Dataset):
    def __init__(self, video_paths, captions, tokenizer):
        self.video_paths = video_paths
        self.captions = captions
        self.tokenizer = tokenizer

    def __len__(self):
        return len(self.video_paths)

    def __getitem__(self, idx):
        video = self.load_video(self.video_paths[idx])
        caption = self.captions[idx]
        tokens = self.tokenizer(caption, padding='max_length', truncation=True,
                              max_length=128, return_tensors='pt')
        return video, tokens.input_ids.squeeze(0), tokens.attention_mask.squeeze(0)

    def load_video(self, path):
        import cv2
        frames = []
        cap = cv2.VideoCapture(path)
        total_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
        indices = np.linspace(0, total_frames-1, 4, dtype=int)

        for frame_idx in indices:
            cap.set(cv2.CAP_PROP_POS_FRAMES, frame_idx)
            ret, frame = cap.read()
            if ret:
                frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
                frame = cv2.resize(frame, (224, 224))
                frame = torch.from_numpy(frame).permute(2, 0, 1).float() / 255.0
                frames.append(frame)

        cap.release()
        video_tensor = torch.stack(frames)
        return video_tensor

# Now initialize model and continue with your code
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model = CRET().to(device)

# Initialize tokenizer and continue with rest of your code...

The secret `HF_TOKEN` does not exist in your Colab secrets.
To authenticate with the Hugging Face Hub, create a token in your settings tab (https://huggingface.co/settings/tokens), set it as secret in your Google Colab and restart your session.
You will be able to reuse this secret in all of your notebooks.
Please note that authentication is recommended but still optional to access public models or datasets.


In [4]:
# Save the model
model_save_path = 'cret_model.pth'
torch.save(model.state_dict(), model_save_path)
print(f"Model saved to {model_save_path}")

Model saved to cret_model.pth


In [9]:
# 1. First update VideoEncoder
class VideoEncoder(nn.Module):
    def __init__(self, d_model=768):
        super().__init__()
        self.d_model = d_model
        # Spatial encoder (ResNet)
        self.spatial_encoder = models.resnet50(pretrained=True)
        self.spatial_encoder = nn.Sequential(*list(self.spatial_encoder.children())[:-1])
        self.projection = nn.Linear(2048, d_model)

        # Temporal encoder with fixed parameters
        encoder_layer = nn.TransformerEncoderLayer(
            d_model=d_model,
            nhead=8,
            batch_first=True,
            dim_feedforward=2048
        )
        self.temporal_encoder = nn.TransformerEncoder(encoder_layer, num_layers=3)

    def forward(self, x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
        B, F, C, H, W = x.shape  # Batch, Frames, Channels, Height, Width

        # Process each frame
        x = x.view(B * F, C, H, W)
        x = self.spatial_encoder(x)
        x = x.squeeze(-1).squeeze(-1)  # Remove spatial dimensions
        x = self.projection(x)

        # Reshape for temporal processing
        x = x.view(B, F, -1)  # [batch, frames, features]

        # Apply temporal encoding
        temporal_features = self.temporal_encoder(x)

        return temporal_features, x

# 2. Update CCMModule
class CCMModule(nn.Module):
    def __init__(self, d_model=768, num_heads=8, num_queries=8):
        super().__init__()
        self.num_queries = num_queries
        self.d_model = d_model

        # Learnable query centers
        self.query_centers = nn.Parameter(torch.randn(1, num_queries, d_model))

        # Transformer decoder layer with fixed parameters
        self.decoder = nn.TransformerDecoderLayer(
            d_model=d_model,
            nhead=num_heads,
            batch_first=True,
            dim_feedforward=2048
        )

    def forward(self, video_features: torch.Tensor, text_features: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
        batch_size = video_features.size(0)

        # Expand queries for batch
        query = self.query_centers.expand(batch_size, -1, -1)

        # Apply decoder to both modalities
        video_aligned = self.decoder(query, video_features)
        text_aligned = self.decoder(query, text_features)

        return video_aligned, text_aligned

# 3. Re-initialize the model with updated components
model = CRET().to(device)

In [10]:
def retrieve_video(query_text, top_k=5):
    model.eval()
    device = next(model.parameters()).device

    # Process query
    tokens = tokenizer(
        query_text,
        padding='max_length',
        truncation=True,
        max_length=128,
        return_tensors='pt'
    ).to(device)

    # Get video paths (start with just a few videos)
    video_paths = sorted([os.path.join('MSRVTT', f) for f in os.listdir('MSRVTT')
                         if f.endswith('.mp4')])[:10]  # Start with 10 videos

    results = []
    with torch.no_grad():
        for video_path in tqdm(video_paths, desc="Processing videos"):
            try:
                # Load video
                dataset = MSRVTTDataset([video_path], ["dummy"], tokenizer)
                video, _, _ = dataset[0]
                video = video.unsqueeze(0).to(device)

                # Print shapes for debugging
                if len(results) == 0:
                    print(f"Video shape: {video.shape}")
                    print(f"Text IDs shape: {tokens.input_ids.shape}")

                # Get embeddings
                outputs = model(video, tokens.input_ids, tokens.attention_mask)

                # Calculate similarity
                similarity = F.cosine_similarity(
                    outputs['video_global'],
                    outputs['text_global']
                ).item()

                results.append((video_path, similarity))

            except Exception as e:
                print(f"Error processing {video_path}: {str(e)}")
                print(f"Shapes at error:")
                print(f"Video shape: {video.shape}")
                print(f"Text shape: {tokens.input_ids.shape}")
                continue

    # Sort by similarity
    results.sort(key=lambda x: x[1], reverse=True)
    return results[:top_k]

# Test with a single query
test_query = "a person is cooking food"
print(f"\nQuery: {test_query}")
results = retrieve_video(test_query)
print("\nTop matching videos:")
for i, (video_path, similarity) in enumerate(results, 1):
    print(f"{i}. {os.path.basename(video_path)} (similarity: {similarity:.3f})")


Query: a person is cooking food


Processing videos:   0%|          | 0/10 [00:00<?, ?it/s]

Video shape: torch.Size([1, 4, 3, 224, 224])
Text IDs shape: torch.Size([1, 128])


Processing videos: 100%|██████████| 10/10 [00:17<00:00,  1.75s/it]


Top matching videos:
1. video1000.mp4 (similarity: -0.011)
2. video102.mp4 (similarity: -0.016)
3. video0.mp4 (similarity: -0.033)
4. video104.mp4 (similarity: -0.045)
5. video1.mp4 (similarity: -0.050)





## Interface

In [1]:
!pip install gradio



In [5]:
def retrieve_video_from_larger_set(query, num_videos=100, top_k=5):
    """
    Retrieves the top-k most relevant video paths from a larger set of videos based on the given query.

    Args:
        query (str): The text query to search for.
        num_videos (int): The number of videos to consider in the larger set.
        top_k (int): The number of most relevant videos to return.

    Returns:
        list: A list of tuples, where each tuple contains the video path and the similarity score.
    """
    # Implement the CRET model logic to retrieve the top-k most relevant video paths
    # This might involve:
    # 1. Encoding the query using a text encoder
    # 2. Encoding the video metadata (titles, descriptions, etc.) using the same encoder
    # 3. Calculating the similarity between the query and the video metadata
    # 4. Sorting the videos by similarity and returning the top-k results

    # For now, this is a placeholder implementation that simply returns a list of random video paths
    import random
    video_paths = [f"video_{i}.mp4" for i in range(num_videos)]
    similarities = [random.uniform(0, 1) for _ in range(num_videos)]

    # Sort the results by similarity in descending order
    sorted_results = sorted(zip(video_paths, similarities), key=lambda x: x[1], reverse=True)

    return sorted_results[:top_k]

In [11]:
import cv2
from PIL import Image
import numpy as np

def process_video_for_display(video_path, max_frames=6):
    try:
        cap = cv2.VideoCapture(video_path)
        if not cap.isOpened():
            print(f"Could not open video: {video_path}")
            return []

        total_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
        if total_frames == 0:
            print(f"Video {video_path} has no frames.")
            return []

        indices = np.linspace(0, total_frames-1, max_frames, dtype=int)

        frames = []
        for idx in indices:
            cap.set(cv2.CAP_PROP_POS_FRAMES, idx)
            ret, frame = cap.read()
            if ret:
                frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
                frames.append(Image.fromarray(frame))
            else:
                print(f"Error reading frame {idx} from video {video_path}")

        cap.release()
        return frames
    except Exception as e:
        print(f"Error processing video {video_path}: {str(e)}")
        return []

In [12]:
def video_search(query, num_results=5):
    try:
        # Process query through CRET model
        results = retrieve_video_from_larger_set(query, num_videos=100, top_k=num_results)

        if not results:
            return [], "No results found"

        # Process results for display
        output_frames = []
        captions = []

        for video_path, similarity in results:
            frames = process_video_for_display(video_path)
            if frames:
                output_frames.extend(frames)
                video_name = os.path.basename(video_path)
                captions.extend([f"{video_name} (similarity: {similarity:.3f})"] * len(frames))

        if not output_frames:
            return [], "No frames could be extracted"

        return output_frames, f"Found {len(results)} matching videos"

    except Exception as e:
        return [], f"Error: {str(e)}"

In [22]:
from google.colab import files
uploaded = files.upload()  # Follow the prompt to upload your video files

Saving video0.mp4 to video0.mp4
Saving video1.mp4 to video1.mp4
Saving video2.mp4 to video2.mp4
Saving video3.mp4 to video3.mp4
Saving video4.mp4 to video4.mp4
Saving video5.mp4 to video5.mp4
Saving video6.mp4 to video6.mp4
Saving video7.mp4 to video7.mp4
Saving video8.mp4 to video8.mp4
Saving video9.mp4 to video9.mp4
Saving video10.mp4 to video10.mp4
Saving video11.mp4 to video11.mp4
Saving video12.mp4 to video12.mp4
Saving video13.mp4 to video13.mp4
Saving video14.mp4 to video14.mp4
Saving video15.mp4 to video15.mp4
Saving video16.mp4 to video16.mp4
Saving video17.mp4 to video17.mp4
Saving video18.mp4 to video18.mp4
Saving video19.mp4 to video19.mp4
Saving video20.mp4 to video20.mp4
Saving video21.mp4 to video21.mp4
Saving video22.mp4 to video22.mp4
Saving video23.mp4 to video23.mp4
Saving video24.mp4 to video24.mp4
Saving video25.mp4 to video25.mp4
Saving video26.mp4 to video26.mp4
Saving video27.mp4 to video27.mp4
Saving video28.mp4 to video28.mp4
Saving video29.mp4 to video29.mp4


In [31]:
pip install opencv-python



In [32]:
import cv2
import os

# Specify the path to your video file
video_path = 'videos/video1.mp4'  # Adjust this path accordingly

# Print absolute path for debugging
print(f"Attempting to open video file at: {os.path.abspath(video_path)}")

# Create a VideoCapture object
cap = cv2.VideoCapture(video_path)

# Check if the video opened successfully
if not cap.isOpened():
    print(f"Could not open video: {video_path}")
else:
    print("Video opened successfully.")

# Read and display the video frame by frame
while cap.isOpened():
    ret, frame = cap.read()
    if not ret:
        print("End of video or cannot read the video frame.")
        break

    # Display the frame in a window
    cv2.imshow('Video', frame)

    # Press 'q' on the keyboard to exit the video early
    if cv2.waitKey(25) & 0xFF == ord('q'):
        break

# Release the video capture object and close all OpenCV windows
cap.release()
cv2.destroyAllWindows()

Attempting to open video file at: /content/videos/video1.mp4
Could not open video: videos/video1.mp4


In [46]:
import os
import gradio as gr
import cv2
from PIL import Image
import numpy as np

# Sample video database
video_database = [
    {"path": "MSRVTT/video1.mp4", "tags": ["cooking", "food"]},
    {"path": "MSRVTT/video2.mp4", "tags": ["dancing", "party", "performance"]},
    {"path": "MSRVTT/video3.mp4", "tags": ["sports", "exercise", "game"]},
]

In [42]:
def process_video_for_display(video_path, max_frames=6):
    try:
        cap = cv2.VideoCapture(video_path)
        if not cap.isOpened():
            print(f"Error: Could not open video '{video_path}'.")
            return []

        total_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
        if total_frames == 0:
            print(f"Video '{video_path}' has no frames.")
            return []

        # Ensure we have at least 1 frame
        if total_frames < max_frames:
            max_frames = total_frames

        indices = np.linspace(0, total_frames-1, max_frames, dtype=int)

        frames = []
        for idx in indices:
            cap.set(cv2.CAP_PROP_POS_FRAMES, idx)
            ret, frame = cap.read()
            if ret:
                frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
                frames.append(Image.fromarray(frame))
            else:
                print(f"Error reading frame {idx} from video '{video_path}'.")

        cap.release()
        return frames
    except Exception as e:
        print(f"Error processing video '{video_path}': {str(e)}")
        return []

In [43]:
def retrieve_video_from_larger_set(query, num_videos=1, top_k=5):
    relevant_videos = []
    query_tags = query.lower().split()  # Tokenize the input query

    print(f"Query: {query} | Query Tags: {query_tags}")  # Debugging purpose

    for video in video_database:
        # Check if any of the query tags are in the video's tags
        if any(tag in video['tags'] for tag in query_tags):
            relevant_videos.append((video['path'], 1.0))  # Mock similarity score
            print(f"Match found: {video}")  # Debugging

    # Sort by similarity score (if there were real scores) and take the top k results
    relevant_videos.sort(key=lambda x: x[1], reverse=True)
    return relevant_videos[:top_k]

In [44]:
def video_search(query, num_results=5):
    try:
        results = retrieve_video_from_larger_set(query, top_k=num_results)

        if not results:
            return [], "No results found."

        output_frames = []
        captions = []

        for video_path, similarity in results:
            frames = process_video_for_display(video_path)
            if frames:
                output_frames.extend(frames)
                video_name = os.path.basename(video_path)
                captions.extend([f"{video_name} (similarity: {similarity:.3f})"] * len(frames))
            else:
                print(f"Warning: No frames extracted from video '{video_path}'.")

        if not output_frames:
            return [], "No frames could be extracted."

        return output_frames, f"Found {len(results)} matching videos."

    except Exception as e:
        return [], f"Error: {str(e)}"

In [45]:
# Example usage
output_frames, status = video_search("cooking")

Query: cooking | Query Tags: ['cooking']
Match found: {'path': 'videos/video1.mp4', 'tags': ['cooking', 'food']}
Error: Could not open video 'videos/video1.mp4'.


In [48]:
# Gradio Interface
with gr.Blocks(title="Video Search", theme=gr.themes.Base()) as demo:
    gr.Markdown("""
    # Video Search using CRET Model
    Search through videos using natural language descriptions.
    """)

    with gr.Row():
        with gr.Column(scale=3):
            query_input = gr.Textbox(
                label="Enter your text query",
                placeholder="Example: a person is cooking food",
                lines=2
            )

            example_queries = gr.Examples(
                examples=[
                    "cooking",
                    "dancing",
                    "sports"
                ],
                inputs=query_input,
                label="Example Queries"
            )

        with gr.Column(scale=1):
            num_results = gr.Slider(
                minimum=1,
                maximum=10,
                value=5,
                step=1,
                label="Number of results"
            )

    search_button = gr.Button("Search", variant="primary")

    status = gr.Textbox(label="Status", value="Ready", interactive=False)

    gallery = gr.Gallery(
        label="Results",
        show_label=True,
        columns=3,
        rows=None,
        height="500px"
    )

    search_button.click(
        fn=video_search,
        inputs=[query_input, num_results],
        outputs=[gallery, status]
    )

# Launch the app
demo.launch(debug=True, share=True)

Colab notebook detected. This cell will run indefinitely so that you can see errors and logs. To turn off, set debug=False in launch().
* Running on public URL: https://1eb81cb5e4cb600e62.gradio.live

This share link expires in 72 hours. For free permanent hosting and GPU upgrades, run `gradio deploy` from the terminal in the working directory to deploy to Hugging Face Spaces (https://huggingface.co/spaces)


Query: car | Query Tags: ['car']
Query: dancing | Query Tags: ['dancing']
Match found: {'path': 'MSRVTT/video2.mp4', 'tags': ['dancing', 'party', 'performance']}
Keyboard interruption in main thread... closing server.
Killing tunnel 127.0.0.1:7861 <> https://1eb81cb5e4cb600e62.gradio.live


