In [None]:
import torch
import torch.nn as nn
import torchvision.models as models
from transformers import AutoTokenizer, AutoModel
from torch.utils.data import Dataset, DataLoader
import torch.nn.functional as F
from PIL import Image
import torchvision.transforms as transforms
from typing import List, Dict, Tuple
import numpy as np
from sklearn.metrics import f1_score
from skimage.metrics import structural_similarity as ssim
from skimage.metrics import peak_signal_noise_ratio as psnr
import os
import random
from torch.optim.lr_scheduler import StepLR

# Set random seeds
def set_seed(seed: int = 42):
    """Seed every RNG this notebook touches and force deterministic cuDNN.

    Covers Python's ``random``, NumPy's global RNG, the PyTorch CPU RNG and
    the CUDA RNGs (current device and all devices).
    """
    seeders = (
        random.seed,
        np.random.seed,
        torch.manual_seed,
        torch.cuda.manual_seed,
        torch.cuda.manual_seed_all,
    )
    for seed_fn in seeders:
        seed_fn(seed)

    # Give up cuDNN autotuning in exchange for run-to-run reproducibility.
    cudnn_backend = torch.backends.cudnn
    cudnn_backend.deterministic = True
    cudnn_backend.benchmark = False


set_seed(42)

class MultiFrameScorer(nn.Module):
    """Score how well each frame of a sequence matches a text summary.

    Frames are encoded by a ResNet-50 backbone (classification head removed)
    and projected to 512-d. Summaries are encoded by BERT, mean-pooled over
    their real (non-padding) tokens and projected to 512-d. The per-frame
    score is the cosine similarity between the two embeddings.
    """

    def __init__(self, pretrained: bool = True):
        super().__init__()

        # Image encoder (ResNet-50) for frames extracted by gmflow. The final fc
        # layer is dropped, so the output is the pooled [N, 2048, 1, 1] feature.
        self.image_encoder = models.resnet50(pretrained=pretrained)
        self.image_encoder = nn.Sequential(*list(self.image_encoder.children())[:-1])

        # Text encoder (BERT) for summaries generated by Modified_DPO.
        self.tokenizer = AutoTokenizer.from_pretrained('bert-base-uncased')
        self.text_encoder = AutoModel.from_pretrained('bert-base-uncased')

        # Projection layers into the shared 512-d space.
        # NOTE: BatchNorm1d raises in training mode when it sees a single row,
        # so a training batch containing only one summary fails in text_projection.
        self.image_projection = nn.Sequential(
            nn.Linear(2048, 512),
            nn.BatchNorm1d(512),
            nn.ReLU(),
            nn.Dropout(0.2)
        )
        self.text_projection = nn.Sequential(
            nn.Linear(768, 512),
            nn.BatchNorm1d(512),
            nn.ReLU(),
            nn.Dropout(0.2)
        )

        # Frame sequence encoder (Transformer); sequence-first layout (batch_first=False).
        self.frame_sequence_encoder = nn.TransformerEncoder(
            nn.TransformerEncoderLayer(
                d_model=512,
                nhead=8,
                dim_feedforward=2048,
                dropout=0.2
            ),
            num_layers=4
        )

        # Frame selection attention (summary attends over encoded frames).
        self.frame_attention = nn.MultiheadAttention(
            embed_dim=512,
            num_heads=8,
            dropout=0.2
        )

        # Image preprocessing.
        # NOTE: RandomHorizontalFlip makes this transform stochastic; code that
        # needs deterministic inference should strip that step.
        self.image_transform = transforms.Compose([
            transforms.Resize(256),
            transforms.CenterCrop(224),
            transforms.RandomHorizontalFlip(),  # Data augmentation
            transforms.ToTensor(),
            transforms.Normalize(mean=[0.485, 0.456, 0.406],
                               std=[0.229, 0.224, 0.225])
        ])

    def encode_frames(self, frames: torch.Tensor) -> torch.Tensor:
        """Encode a batch of frame sequences.

        frames: [batch_size, num_frames, C, H, W]
        Returns: [batch_size, num_frames, 512]
        """
        batch_size, num_frames = frames.shape[:2]
        # Fold the frame axis into the batch axis so the CNN sees [N, C, H, W].
        frames_flat = frames.reshape(batch_size * num_frames, *frames.shape[2:])
        # flatten(1) (not squeeze) keeps the leading axis even when N == 1.
        features = self.image_encoder(frames_flat).flatten(1)  # [N, 2048]
        features = self.image_projection(features)
        # Reshape back to batch format.
        return features.view(batch_size, num_frames, -1)

    def encode_summary(self, summaries: List[str]) -> torch.Tensor:
        """Encode text summaries into [batch_size, 512] embeddings.

        Token states are averaged only over positions whose attention mask is 1,
        so padding added to shorter summaries does not leak into the embedding.
        """
        device = next(self.parameters()).device
        tokens = self.tokenizer(summaries, padding=True, truncation=True,
                                return_tensors="pt", max_length=128)
        tokens = {k: v.to(device) for k, v in tokens.items()}
        hidden = self.text_encoder(**tokens).last_hidden_state  # [batch, seq_len, 768]

        attention_mask = tokens.get('attention_mask')
        if attention_mask is None:
            text_features = hidden.mean(dim=1)
        else:
            weights = attention_mask.unsqueeze(-1).to(hidden.dtype)  # [batch, seq_len, 1]
            text_features = (hidden * weights).sum(dim=1) / weights.sum(dim=1).clamp(min=1.0)
        return self.text_projection(text_features)

    def compute_similarity(self, frame_features: torch.Tensor, summary_features: torch.Tensor) -> torch.Tensor:
        """
        Compute cosine similarity between frame features and summary features.
        frame_features: [batch_size, num_frames, D]
        summary_features: [batch_size, D]
        Returns: [batch_size, num_frames] similarity scores in [-1, 1]
        """
        # Normalize features
        frame_features = F.normalize(frame_features, p=2, dim=-1)
        summary_features = F.normalize(summary_features, p=2, dim=-1)

        # Broadcast the summary over the frame axis: [batch_size, 1, D]
        summary_features = summary_features.unsqueeze(1)

        # Dot product of unit vectors == cosine similarity
        return torch.sum(frame_features * summary_features, dim=-1)  # [batch_size, num_frames]

    def forward(self, frames: torch.Tensor, summaries: List[str]) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        Forward pass returning similarity scores and attention weights.
        frames: [batch_size, num_frames, C, H, W]
        summaries: one string per batch element
        Returns:
          similarity_scores: [batch_size, num_frames], cosine similarities in [-1, 1]
            (NOTE(review): callers that feed these to BCEWithLogitsLoss treat a
            bounded value as a logit; confirm that is intended)
          attention_weights: summary-to-frame weights from frame_attention,
            head-averaged, [batch_size, 1, num_frames]
        Padding frames are not masked out of the transformer or the attention.
        """
        # Encode frames and summary
        frame_features = self.encode_frames(frames)  # [batch_size, num_frames, 512]
        summary_features = self.encode_summary(summaries)  # [batch_size, 512]

        # Compute similarity scores
        similarity_scores = self.compute_similarity(frame_features, summary_features)  # [batch_size, num_frames]

        # Process frame sequence (only used for the attention weights).
        # The encoder and attention are sequence-first, hence the transposes.
        frame_seq = self.frame_sequence_encoder(frame_features.transpose(0, 1)).transpose(0, 1)
        summary_expanded = summary_features.unsqueeze(1)  # [batch_size, 1, 512]
        _, attention_weights = self.frame_attention(
            summary_expanded.transpose(0, 1),
            frame_seq.transpose(0, 1),
            frame_seq.transpose(0, 1)
        )

        return similarity_scores, attention_weights


class MultiFrameDataset(Dataset):
    """Dataset of (frame sequence, summary, target-frame mask) examples.

    ``__getitem__`` loads the image files of one sequence, keeps at most
    ``max_frames`` of them, pads shorter sequences with all-zero frames, and
    returns a 0/1 mask over the ``max_frames`` slots marking the target frames.
    ``transform`` must turn a PIL image into a tensor: padding is built with
    ``torch.zeros_like`` on the first loaded frame.
    """

    def __init__(self, frame_sequences: List[List[str]],
                 summaries: List[str],
                 target_frames: List[List[int]],
                 transform=None,
                 max_frames: int = 10):  # Maximum number of frames per sequence
        """
        frame_sequences: List of lists containing paths to frame sequences
        summaries: List of summary texts
        target_frames: List of lists containing indices of correct frames
        transform: Callable applied to each RGB PIL image (must return a tensor)
        max_frames: Maximum number of frames to pad/truncate to
        """
        self.frame_sequences = frame_sequences
        self.summaries = summaries
        self.target_frames = target_frames
        self.transform = transform
        self.max_frames = max_frames

    def __len__(self):
        return len(self.summaries)

    def __getitem__(self, idx):
        """Return one example as a dict.

        'frames': stacked frame tensor with ``max_frames`` entries on axis 0
        'summary': the summary string
        'target_mask': float tensor [max_frames], 1.0 at target frames

        Raises ValueError if the sequence has no frames, since there is no
        frame to derive the padding shape from.
        """
        frame_paths = self.frame_sequences[idx]
        if not frame_paths:
            raise ValueError(
                f"frame sequence {idx} is empty; at least one frame is needed "
                "to build the padded tensor"
            )

        # Frames past max_frames would be truncated anyway, so don't load them.
        frames = [self._load_frame(path) for path in frame_paths[:self.max_frames]]
        num_loaded = len(frames)

        # Pad with all-zero tensors. If the transform normalizes, zeros are the
        # normalization mean rather than black pixels.
        frames.extend(torch.zeros_like(frames[0]) for _ in range(self.max_frames - num_loaded))
        frames_tensor = torch.stack(frames)

        # Mark only targets that point at a real loaded frame. Out-of-range,
        # truncated or negative indices are dropped: clamping them onto a slot
        # would label the wrong frame as positive.
        target_mask = torch.zeros(self.max_frames)
        valid_targets = [t for t in self.target_frames[idx] if 0 <= t < num_loaded]
        if valid_targets:
            target_mask[valid_targets] = 1

        return {
            'frames': frames_tensor,
            'summary': self.summaries[idx],
            'target_mask': target_mask
        }

    def _load_frame(self, frame_path):
        """Load one image as RGB (closing the file) and apply the transform."""
        with Image.open(frame_path) as img:
            image = img.convert('RGB')
        return self.transform(image) if self.transform else image

def custom_collate_fn(batch):
    """
    Merge a list of dataset items into a single batch dict.

    'frames' and 'target_mask' are stacked along a new leading batch axis, so
    every item must already have the same shape (the dataset pads/truncates
    to ``max_frames``). 'summary' stays a plain list of strings for the
    tokenizer.
    """
    frame_tensors = []
    summary_texts = []
    mask_tensors = []
    for item in batch:
        frame_tensors.append(item['frames'])
        summary_texts.append(item['summary'])
        mask_tensors.append(item['target_mask'])

    return {
        'frames': torch.stack(frame_tensors),
        'summary': summary_texts,
        'target_mask': torch.stack(mask_tensors),
    }

class MultiFrameTrainer:
    """Owns a MultiFrameScorer on ``device`` and provides training and frame selection."""

    def __init__(self, device: str = 'cuda' if torch.cuda.is_available() else 'cpu'):
        self.device = device
        self.model = MultiFrameScorer().to(device)

    def train(self, train_dataset: 'MultiFrameDataset',
              val_dataset: 'MultiFrameDataset' = None,
              batch_size: int = 8, epochs: int = 10):
        """Train the scorer with BCE-with-logits on the per-frame similarity scores.

        Uses Adam (lr 1e-4), a StepLR schedule (x0.1 every 3 epochs) and gradient
        clipping at norm 1.0. Prints the mean training loss per epoch and, when
        ``val_dataset`` is non-empty, the metrics from ``evaluate_model``.

        NOTE(review): the model's scores are cosine similarities in [-1, 1], so
        sigmoid(score) stays within roughly [0.27, 0.73] and the BCE loss cannot
        get close to zero; a learnable scale/temperature may be needed.
        NOTE: batch_size=1 cannot train, because BatchNorm1d in training mode
        rejects single-row batches.
        """
        optimizer = torch.optim.Adam(self.model.parameters(), lr=1e-4)
        scheduler = StepLR(optimizer, step_size=3, gamma=0.1)  # Learning rate scheduler
        criterion = nn.BCEWithLogitsLoss()  # Binary cross-entropy loss

        # A trailing batch holding one summary would crash the BatchNorm1d in
        # the text projection, so drop it when there are full batches to use.
        num_examples = len(train_dataset)
        drop_last = num_examples > batch_size and num_examples % batch_size == 1

        train_loader = DataLoader(
            train_dataset,
            batch_size=batch_size,
            shuffle=True,
            num_workers=4,
            collate_fn=custom_collate_fn,
            drop_last=drop_last
        )

        for epoch in range(epochs):
            self.model.train()
            total_loss = 0.0

            for batch in train_loader:
                frames = batch['frames'].to(self.device)
                summaries = batch['summary']
                target_masks = batch['target_mask'].to(self.device)

                optimizer.zero_grad()

                # Get model predictions
                similarity_scores, _ = self.model(frames, summaries)

                # Compute BCE loss
                loss = criterion(similarity_scores, target_masks.float())

                loss.backward()
                torch.nn.utils.clip_grad_norm_(self.model.parameters(), 1.0)  # Gradient clipping
                optimizer.step()

                total_loss += loss.item()

            scheduler.step()  # Update learning rate
            # max(..., 1) avoids ZeroDivisionError on an empty loader.
            avg_loss = total_loss / max(len(train_loader), 1)
            print(f"Epoch {epoch+1}/{epochs}, Loss: {avg_loss:.4f}")

            # Evaluate on validation set
            if val_dataset:
                metrics = evaluate_model(self, val_dataset, threshold=0.5)
                print(f"Validation Metrics: {metrics}")

    def _inference_transform(self):
        """Return the model's image transform with RandomHorizontalFlip steps removed."""
        base = self.model.image_transform
        steps = getattr(base, 'transforms', None)
        if steps is None:
            return base
        return transforms.Compose([
            step for step in steps
            if not isinstance(step, transforms.RandomHorizontalFlip)
        ])

    def select_frames(self, frames: List[str], summary: str, threshold: float = 0.5, max_frames: int = 10) -> List[Dict]:
        """
        Select frames that best match the given summary based on similarity scores.

        frames: List of frame paths.
        summary: Text summary.
        threshold: Frames whose cosine similarity is strictly above this are selected.
        max_frames: Only the first ``max_frames`` frames are considered.
        Returns: List of dicts with 'frame_path', 'similarity_score' and
            'position' (index into ``frames``), sorted by descending score;
            empty if nothing passes the threshold.
        """
        if not frames or max_frames <= 0:
            return []

        self.model.eval()
        # The training transform flips at random; inference must be deterministic.
        transform = self._inference_transform()

        # In eval mode each frame's score does not depend on the other frames,
        # so frames past max_frames (never returned) need not be encoded.
        candidate_paths = frames[:max_frames]
        processed_frames = []
        for frame_path in candidate_paths:
            with Image.open(frame_path) as img:
                processed_frames.append(transform(img.convert('RGB')))

        frames_tensor = torch.stack(processed_frames).unsqueeze(0).to(self.device)

        with torch.no_grad():
            similarity_scores, _ = self.model(frames_tensor, [summary])
        # Index the batch axis instead of squeeze() so one frame still gives a 1-D array.
        scores = similarity_scores[0].cpu().numpy()

        selected = np.flatnonzero(scores > threshold)
        if selected.size == 0:
            return []

        # Highest score first.
        ranked = selected[np.argsort(-scores[selected], kind='stable')]
        return [
            {
                'frame_path': candidate_paths[i],
                'similarity_score': float(scores[i]),
                'position': int(i)
            }
            for i in ranked
        ]

# PSNR is infinite for identical images (zero error); report this value
# instead so a single exact hit does not turn the mean PSNR into inf.
_PSNR_CAP_DB = 100.0


def _normalize_stats(transform):
    """Return (mean, std) arrays shaped [C, 1, 1] from the last Normalize-like
    step (an object with ``mean`` and ``std``) of ``transform`` or of its
    ``.transforms`` list, or None if there is no such step."""
    steps = [transform] + list(getattr(transform, 'transforms', None) or [])
    for step in reversed(steps):
        if hasattr(step, 'mean') and hasattr(step, 'std'):
            mean = np.asarray(step.mean, dtype=np.float64).reshape(-1, 1, 1)
            std = np.asarray(step.std, dtype=np.float64).reshape(-1, 1, 1)
            return mean, std
    return None


def _to_unit_range_hwc(frames, stats):
    """Convert a [F, C, H, W] frame tensor to a float64 [F, H, W, C] array in [0, 1],
    undoing the normalization described by ``stats`` when given."""
    arr = frames.detach().cpu().numpy().astype(np.float64)
    if stats is not None:
        mean, std = stats
        arr = arr * std + mean
    return np.moveaxis(np.clip(arr, 0.0, 1.0), 1, -1)


def _frame_similarity(reference, candidate):
    """Return (SSIM, PSNR) between two [H, W, C] images with values in [0, 1]."""
    min_side = min(reference.shape[0], reference.shape[1])
    win_size = min(7, min_side)
    if win_size % 2 == 0:
        win_size -= 1  # SSIM windows must have odd size
    if win_size < 3:  # too small for a meaningful SSIM window
        ssim_value = 0.0
    else:
        ssim_value = float(ssim(
            reference,
            candidate,
            win_size=win_size,
            channel_axis=-1,
            data_range=1.0
        ))

    if np.array_equal(reference, candidate):
        psnr_value = _PSNR_CAP_DB
    else:
        psnr_value = min(float(psnr(reference, candidate, data_range=1.0)), _PSNR_CAP_DB)
    return ssim_value, psnr_value


def evaluate_model(trainer, dataset, threshold: float = 0.5):
    """
    Evaluate frame selection on ``dataset``.

    For each example, ``trainer.select_frames`` picks the frames scoring above
    ``threshold``; the resulting 0/1 prediction mask is compared with the
    example's target mask. Examples where nothing is selected are kept as
    all-zero predictions, so missed examples lower F1 instead of being ignored.

    Returns a dict with:
      'RMSE'     - root-mean-square error between prediction and target masks
                   over all frame slots of all examples.
      'F1 Score' - binary F1 over all frame slots.
      'SSIM'/'PSNR' - mean image similarity between each selected frame and the
                   target frame of the same example closest to it in position
                   (earlier target on ties). Frames are mapped back to [0, 1] by
                   inverting the dataset transform's Normalize step when one is
                   found; PSNR is capped at ``_PSNR_CAP_DB``. Both are 0.0 when
                   no selected frame has a target to compare against.
    An empty dataset returns RMSE=inf and zeros for the other metrics.

    NOTE(review): frames for SSIM/PSNR come from ``dataset[idx]``, so any random
    augmentation in ``dataset.transform`` also affects these metrics.
    """
    trainer.model.eval()
    all_targets = []
    all_predictions = []
    ssim_scores = []
    psnr_scores = []
    norm_stats = _normalize_stats(getattr(dataset, 'transform', None))

    for idx in range(len(dataset)):
        sample = dataset[idx]
        target_mask = sample['target_mask']

        # Select frames based on similarity scores
        selected_frames = trainer.select_frames(
            dataset.frame_sequences[idx],
            sample['summary'],
            threshold=threshold,
            max_frames=dataset.max_frames
        )
        selected_positions = [int(frame['position']) for frame in selected_frames]

        prediction_mask = torch.zeros_like(target_mask)
        if selected_positions:
            prediction_mask[selected_positions] = 1

        target_array = target_mask.numpy()
        all_targets.append(target_array)
        all_predictions.append(prediction_mask.numpy())

        target_positions = np.flatnonzero(target_array)
        if not selected_positions or target_positions.size == 0:
            continue

        frames = _to_unit_range_hwc(sample['frames'], norm_stats)
        for position in selected_positions:
            reference = target_positions[np.argmin(np.abs(target_positions - position))]
            ssim_value, psnr_value = _frame_similarity(frames[reference], frames[position])
            ssim_scores.append(ssim_value)
            psnr_scores.append(psnr_value)

    if not all_targets:
        return {
            'RMSE': float('inf'),
            'F1 Score': 0.0,
            'SSIM': 0.0,
            'PSNR': 0.0
        }

    targets = np.concatenate(all_targets)
    predictions = np.concatenate(all_predictions)

    rmse = float(np.sqrt(np.mean((targets - predictions) ** 2)))
    f1 = float(f1_score(targets, predictions, average='binary'))

    return {
        'RMSE': rmse,
        'F1 Score': f1,
        'SSIM': float(np.mean(ssim_scores)) if ssim_scores else 0.0,
        'PSNR': float(np.mean(psnr_scores)) if psnr_scores else 0.0
    }

# Example usage
if __name__ == "__main__":
    # frame_sequences, summaries, target_frames, test_frame_seq, test_summary
    # and test_target_frames are presumably defined in an earlier notebook
    # cell (not visible here).

    # Build the trainer first and reuse its model's transform. Constructing
    # MultiFrameScorer() only to read `image_transform` would load ResNet-50
    # and BERT once per dataset.
    trainer = MultiFrameTrainer()
    frame_transform = trainer.model.image_transform

    # Create training dataset
    train_dataset = MultiFrameDataset(
        frame_sequences=frame_sequences,
        summaries=summaries,
        target_frames=target_frames,
        transform=frame_transform,
        max_frames=20  # Set a maximum number of frames
    )

    # Create validation dataset (built from the same lists as train_dataset)
    val_dataset = MultiFrameDataset(
        frame_sequences=frame_sequences,
        summaries=summaries,
        target_frames=target_frames,
        transform=frame_transform,
        max_frames=20
    )
    test_dataset = MultiFrameDataset(
        frame_sequences=test_frame_seq,
        summaries=test_summary,
        target_frames=test_target_frames,
        transform=frame_transform,
        max_frames=20  # Set a maximum number of frames
    )

    # Train model
    # NOTE(review): the test split is used as the per-epoch validation set and
    # val_dataset is unused; confirm that monitoring on test data is intended.
    trainer.train(train_dataset, test_dataset, epochs=5)





In [None]:
# Create test dataset (uses `trainer` and the test_* lists from earlier cells).
# Reuse the trained model's transform instead of building a new MultiFrameScorer.
test_dataset = MultiFrameDataset(
    frame_sequences=test_frame_seq,
    summaries=test_summary,
    target_frames=test_target_frames,
    transform=trainer.model.image_transform,
    max_frames=20  # Set a maximum number of frames
)

# Evaluate on test dataset
test_metrics = evaluate_model(trainer, test_dataset, threshold=0.5)
print("\nTest Metrics:")
print(f"RMSE: {test_metrics['RMSE']:.4f}")
print(f"F1 Score: {test_metrics['F1 Score']:.4f}")
print(f"SSIM: {test_metrics['SSIM']:.4f}")
print(f"PSNR: {test_metrics['PSNR']:.4f}")