<a href="https://colab.research.google.com/github/qu-romana/ML_Research_Code/blob/main/3_TokenGating_And_SparseAttention.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# Key Improvements to Enhance BLIP Metrics

The improved implementation incorporates several significant changes to enhance the metrics of the token-gated BLIP model for captioning:

## 1. Enhanced Token Gating Mechanism
- **Deeper MLP Architecture**: The token gating network now uses a 3-layer MLP instead of a 2-layer one, increasing its capacity to learn complex token importance patterns.
- **Layer Normalization**: Added LayerNorm before gate computation for more stable training.
- **Explicit Gate Regularization**: Incorporated L1 regularization on gate values to encourage sparse gating, allowing the model to focus more clearly on important tokens.

## 2. Integration with Training Process
- **Fully Integrated Forward Pass**: Previously, token gating was not being used during training and was essentially acting as a post-processing step. Now, token gating is fully incorporated into the forward pass, enabling the model to actually learn to optimize the gating mechanism.
- **Loss Modification**: The loss function now includes the gate regularization term, providing a direct signal to the model to balance between accurate caption generation and token sparsity.

## 3. Improved Sparse Attention
- **Learnable Temperature Parameter**: Added a learnable temperature parameter to the attention mechanism, allowing the model to adjust the sharpness of attention distributions.
- **Improved Parameter Initialization**: Better initialization of attention weights for more stable training.
- **Noise Addition**: Small noise added to sparse attention masks during training to help gradient flow and avoid getting stuck in local minima.

## 4. VAE Latent Handling
- **Dedicated VAE Latent Adapter**: Added a specialized adapter network that properly processes VAE latents into a format compatible with BLIP's vision transformer.
- **Improved Normalization**: Enhanced normalization strategy for VAE latents, using a combination of per-sample z-score normalization and global scaling.
- **Robust Error Handling**: Better handling of variable-sized latents, ensuring that inputs are properly processed even with inconsistent dimensions.

## 5. Enhanced Training Strategy
- **Learning Rate Schedule with Warmup**: Implemented a learning rate schedule with initial warmup and gradual decay, which typically improves convergence.
- **Focused Metrics Calculation**: Only computing the expensive metrics (CIDEr, SPICE) periodically to speed up training, while using BLEU for regular evaluations.
- **Weighted Metric Combination**: Using a weighted combination of metrics for model selection, rather than a simple sum.
- **Adaptive Early Stopping**: Implemented a patience mechanism that increases in later epochs, allowing the model more time to fine-tune.

## 6. Generation Improvements
- **Better Decoding Parameters**: Enhanced the generation process with nucleus sampling, temperature adjustment, and other parameters that typically lead to better quality captions.
- **Comprehensive Metric Evaluation**: Added METEOR and BLEU-1 to the evaluation metrics for more complete assessment.

These changes directly address the limitations in the original implementation, particularly by ensuring that the token gating and sparse attention mechanisms are actually used during training rather than just being passive components. The gate regularization also encourages the model to focus more explicitly on important tokens, which should lead to more accurate and focused captions.

In [1]:
# ================================
# Installation & Setup
# ================================
# Uninstall current numpy and force install numpy==1.21.6
!pip uninstall -y numpy
!pip install numpy==1.21.6 --force-reinstall

# Uninstall then install specific versions for other dependencies
!pip uninstall -y tensorflow transformers
!pip install tensorflow==2.15.0
!pip install transformers==4.49.0 datasets==3.6.0
!pip install torch torchvision nltk pycocotools
!pip install pycocoevalcap==1.2 matplotlib seaborn
!pip install accelerate
!pip install -U transformers datasets

Found existing installation: numpy 1.26.4
Uninstalling numpy-1.26.4:
  Successfully uninstalled numpy-1.26.4
[31mERROR: Ignored the following versions that require a different python version: 1.21.2 Requires-Python >=3.7,<3.11; 1.21.3 Requires-Python >=3.7,<3.11; 1.21.4 Requires-Python >=3.7,<3.11; 1.21.5 Requires-Python >=3.7,<3.11; 1.21.6 Requires-Python >=3.7,<3.11[0m[31m
[0m[31mERROR: Could not find a version that satisfies the requirement numpy==1.21.6 (from versions: 1.3.0, 1.4.1, 1.5.0, 1.5.1, 1.6.0, 1.6.1, 1.6.2, 1.7.0, 1.7.1, 1.7.2, 1.8.0, 1.8.1, 1.8.2, 1.9.0, 1.9.1, 1.9.2, 1.9.3, 1.10.0.post2, 1.10.1, 1.10.2, 1.10.4, 1.11.0, 1.11.1, 1.11.2, 1.11.3, 1.12.0, 1.12.1, 1.13.0, 1.13.1, 1.13.3, 1.14.0, 1.14.1, 1.14.2, 1.14.3, 1.14.4, 1.14.5, 1.14.6, 1.15.0, 1.15.1, 1.15.2, 1.15.3, 1.15.4, 1.16.0, 1.16.1, 1.16.2, 1.16.3, 1.16.4, 1.16.5, 1.16.6, 1.17.0, 1.17.1, 1.17.2, 1.17.3, 1.17.4, 1.17.5, 1.18.0, 1.18.1, 1.18.2, 1.18.3, 1.18.4, 1.18.5, 1.19.0, 1.19.1, 1.19.2, 1.19.3, 1.19.4, 

In [2]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import DataLoader
from torch.cuda.amp import autocast, GradScaler
from torch.nn.utils import clip_grad_norm_

from transformers import BlipForConditionalGeneration, BlipProcessor
from datasets import load_dataset
from nltk.translate.bleu_score import corpus_bleu
from pycocoevalcap.cider.cider import Cider
from pycocoevalcap.spice.spice import Spice
from pycocoevalcap.rouge.rouge import Rouge
from pycocoevalcap.meteor.meteor import Meteor

import math
import os
import time
import numpy as np
from tqdm.auto import tqdm


# 1. Improved Token Gating Implementation with Enhanced Regularization
class TokenGating(nn.Module):
    """
    Enhanced Token Gating mechanism that selectively focuses on important tokens
    while suppressing less relevant ones, with improved regularization.
    """
    def __init__(self, hidden_dim, dropout=0.1):
        super(TokenGating, self).__init__()
        # Deeper MLP for better feature extraction before gating
        self.gate_transform = nn.Sequential(
            nn.Linear(hidden_dim, hidden_dim // 2),
            nn.GELU(),
            nn.Dropout(dropout),
            nn.Linear(hidden_dim // 2, hidden_dim // 4),
            nn.GELU(),
            nn.Dropout(dropout),
            nn.Linear(hidden_dim // 4, 1)
        )
        self.sigmoid = nn.Sigmoid()
        self.layer_norm = nn.LayerNorm(hidden_dim)  # Add layer norm for stability

    def forward(self, hidden_states, attention_mask=None):
        # Apply layer normalization for more stable training
        normalized_states = self.layer_norm(hidden_states)

        # Calculate importance score for each token [batch_size, seq_len, 1]
        gate_scores = self.sigmoid(self.gate_transform(normalized_states))

        # Apply scaling factor with a learnable temperature
        gate_scores = gate_scores * 2.0

        # Apply the gate - element-wise multiplication
        gated_output = hidden_states * gate_scores

        # Use attention mask if provided
        if attention_mask is not None:
            mask = attention_mask.unsqueeze(-1)
            gated_output = gated_output * mask

        return gated_output, gate_scores


# 2. Enhanced Sparse Attention Implementation with improved efficiency
class SparseAttention(nn.Module):
    """
    Implements sparse attention by selecting only the top-k most relevant tokens
    for each position during attention computation, with improved routing.
    """
    def __init__(self, hidden_dim, num_heads=8, dropout=0.1, sparsity=0.8):
        super(SparseAttention, self).__init__()
        self.hidden_dim = hidden_dim
        self.num_heads = num_heads
        self.head_dim = hidden_dim // num_heads
        assert self.head_dim * num_heads == hidden_dim, "hidden_dim must be divisible by num_heads"

        self.sparsity = sparsity  # Percent of attention connections to prune

        # Improved parameter initialization
        stdv = 1. / math.sqrt(self.head_dim)

        # Key, Query, Value projections
        self.q_proj = nn.Linear(hidden_dim, hidden_dim)
        self.k_proj = nn.Linear(hidden_dim, hidden_dim)
        self.v_proj = nn.Linear(hidden_dim, hidden_dim)
        self.out_proj = nn.Linear(hidden_dim, hidden_dim)

        # Initialize weights for better convergence
        for proj in [self.q_proj, self.k_proj, self.v_proj, self.out_proj]:
            nn.init.uniform_(proj.weight, -stdv, stdv)
            nn.init.zeros_(proj.bias)

        self.dropout = nn.Dropout(dropout)

        # Add a learnable temperature parameter for scaled attention
        self.temperature = nn.Parameter(torch.ones(1))

    def forward(self, hidden_states, attention_mask=None):
        batch_size, seq_len, _ = hidden_states.shape

        # Linear projections and reshape to multi-head
        q = self.q_proj(hidden_states).reshape(batch_size, seq_len, self.num_heads, self.head_dim)
        k = self.k_proj(hidden_states).reshape(batch_size, seq_len, self.num_heads, self.head_dim)
        v = self.v_proj(hidden_states).reshape(batch_size, seq_len, self.num_heads, self.head_dim)

        # Transpose for attention computation [batch_size, num_heads, seq_len, head_dim]
        q = q.transpose(1, 2)
        k = k.transpose(1, 2)
        v = v.transpose(1, 2)

        # Compute scaled dot-product attention with learnable temperature
        attn_weights = torch.matmul(q, k.transpose(-2, -1)) / (self.head_dim ** 0.5 * self.temperature)

        # Apply attention mask if provided
        if attention_mask is not None:
            # Expand mask for multi-head attention [batch_size, 1, 1, seq_len]
            expanded_mask = attention_mask.unsqueeze(1).unsqueeze(2)
            attn_weights = attn_weights.masked_fill(expanded_mask == 0, -1e10)

        # Compute sparse attention by keeping only top-k values
        if self.training:
            # Determine number of tokens to keep based on sparsity level
            k_tokens = max(1, int((1 - self.sparsity) * seq_len))

            # Get top-k values for each query token
            top_k_attn, _ = torch.topk(attn_weights, k=k_tokens, dim=-1)

            # Use smallest value from top-k as threshold
            sparse_threshold = top_k_attn[..., -1].unsqueeze(-1)

            # Create a binary mask for sparse attention
            sparse_mask = (attn_weights >= sparse_threshold).float()

            # Add a small noise during training to avoid exact 0s (helps gradient flow)
            if self.training:
                noise = torch.randn_like(sparse_mask) * 1e-5
                sparse_mask = sparse_mask + noise * (1 - sparse_mask)

            # Apply the sparse mask
            attn_weights = attn_weights * sparse_mask + -1e10 * (1 - sparse_mask)

        # Apply softmax to get normalized attention weights
        attn_weights = F.softmax(attn_weights, dim=-1)
        attn_weights = self.dropout(attn_weights)

        # Apply attention weights to values
        context = torch.matmul(attn_weights, v)

        # Reshape back to [batch_size, seq_len, hidden_dim]
        context = context.transpose(1, 2).reshape(batch_size, seq_len, self.hidden_dim)

        # Final output projection
        output = self.out_proj(context)

        return output


# 3. Enhanced Adapter for VAE Latents
class VAELatentAdapter(nn.Module):
    """
    Enhanced adapter that properly processes VAE latents for compatibility with BLIP.
    """
    def __init__(self, latent_dim, output_channels=3, output_size=224):
        super(VAELatentAdapter, self).__init__()
        self.latent_dim = latent_dim
        self.output_size = output_size

        # More powerful transformation network
        self.adapter = nn.Sequential(
            nn.Linear(latent_dim, latent_dim * 2),
            nn.LayerNorm(latent_dim * 2),
            nn.GELU(),
            nn.Dropout(0.1),
            nn.Linear(latent_dim * 2, latent_dim),
            nn.LayerNorm(latent_dim),
            nn.GELU(),
            nn.Dropout(0.1),
            nn.Linear(latent_dim, output_channels * output_size * output_size)
        )

    def forward(self, x):
        batch_size = x.size(0)
        # Handle different input shapes
        if len(x.shape) > 2:
            # If more than 2D, flatten to [batch_size, flattened_dim]
            x = x.reshape(batch_size, -1)

        # Ensure correct input dimensionality
        if x.size(1) != self.latent_dim:
            # Pad or truncate as needed
            if x.size(1) < self.latent_dim:
                padding = torch.zeros(batch_size, self.latent_dim - x.size(1), device=x.device)
                x = torch.cat([x, padding], dim=1)
            else:
                x = x[:, :self.latent_dim]

        # Apply adaptation network
        x = self.adapter(x)

        # Reshape to expected image format
        return x.view(batch_size, 3, self.output_size, self.output_size)


# 4. Significantly Enhanced BLIP Model with Integrated Token Gating and Sparse Attention
class EnhancedBLIP(nn.Module):
    """
    Enhances BLIP model with token gating and sparse attention mechanisms
    with deep integration for better performance.
    """
    def __init__(self, sparsity=0.8, gate_reg_lambda=0.01, vae_latent_dim=4096):
        super(EnhancedBLIP, self).__init__()

        # Load base model
        print("Loading base BLIP model...")
        self.base_model = BlipForConditionalGeneration.from_pretrained("Salesforce/blip-image-captioning-base")
        self.processor = BlipProcessor.from_pretrained("Salesforce/blip-image-captioning-base")

        # Get hidden dimension from the base model
        hidden_dim = self.base_model.text_decoder.config.hidden_size
        print(f"Model hidden dimension: {hidden_dim}")

        # Add VAE latent adapter if needed
        self.use_latent_adapter = vae_latent_dim is not None
        if self.use_latent_adapter:
            print(f"Creating VAE latent adapter with input dim {vae_latent_dim}")
            self.latent_adapter = VAELatentAdapter(vae_latent_dim)

        # Add token gating layers
        self.text_gate = TokenGating(hidden_dim)
        self.vision_gate = TokenGating(hidden_dim)

        # Add sparse attention layers
        self.text_sparse_attn = SparseAttention(hidden_dim, sparsity=sparsity)
        self.vision_sparse_attn = SparseAttention(hidden_dim, sparsity=sparsity)

        # Layer norms for stability
        self.text_ln1 = nn.LayerNorm(hidden_dim)
        self.text_ln2 = nn.LayerNorm(hidden_dim)
        self.vision_ln1 = nn.LayerNorm(hidden_dim)
        self.vision_ln2 = nn.LayerNorm(hidden_dim)

        # Feed-forward networks for residual paths
        self.text_ffn = nn.Sequential(
            nn.Linear(hidden_dim, hidden_dim * 4),
            nn.GELU(),
            nn.Dropout(0.1),
            nn.Linear(hidden_dim * 4, hidden_dim)
        )

        self.vision_ffn = nn.Sequential(
            nn.Linear(hidden_dim, hidden_dim * 4),
            nn.GELU(),
            nn.Dropout(0.1),
            nn.Linear(hidden_dim * 4, hidden_dim)
        )

        # Flag to control whether to apply enhancements
        self.apply_enhancements = True

        # Gate regularization strength
        self.gate_reg_lambda = gate_reg_lambda

        print("Enhanced BLIP model initialized")

    def _enhance_text_features(self, hidden_states, attention_mask=None):
        """Apply token gating and sparse attention to text features."""
        # Apply token gating
        gated_states, gate_scores = self.text_gate(hidden_states, attention_mask)
        gated_states = self.text_ln1(gated_states + hidden_states)  # Residual connection

        # Apply sparse attention
        sparse_states = self.text_sparse_attn(gated_states, attention_mask)
        sparse_states = self.text_ln2(sparse_states + gated_states)  # Residual connection

        # Apply feed-forward network
        output = hidden_states + self.text_ffn(sparse_states)

        return output, gate_scores

    def _enhance_vision_features(self, hidden_states):
        """Apply token gating and sparse attention to vision features."""
        # Apply token gating
        gated_states, gate_scores = self.vision_gate(hidden_states, None)
        gated_states = self.vision_ln1(gated_states + hidden_states)  # Residual connection

        # Apply sparse attention
        sparse_states = self.vision_sparse_attn(gated_states, None)
        sparse_states = self.vision_ln2(sparse_states + gated_states)  # Residual connection

        # Apply feed-forward network
        output = hidden_states + self.vision_ffn(sparse_states)

        return output, gate_scores

    def forward(self, pixel_values=None, input_ids=None, attention_mask=None, labels=None, return_dict=True):
        # Apply VAE adapter if needed
        if self.use_latent_adapter and pixel_values is not None:
            # Check if the format needs adaptation
            if len(pixel_values.shape) != 4 or pixel_values.shape[1] != 3:
                pixel_values = self.latent_adapter(pixel_values)

        # Pass through base model
        outputs = self.base_model(
            pixel_values=pixel_values,
            input_ids=input_ids,
            attention_mask=attention_mask,
            labels=labels,
            output_hidden_states=True,
            return_dict=True
        )

        # Return base model outputs if enhancements are disabled
        if not self.apply_enhancements:
            return outputs

        # Apply token gating and sparse attention to improve results
        # Get the last hidden state from the decoder
        if hasattr(outputs, "decoder_hidden_states") and outputs.decoder_hidden_states is not None:
            last_hidden = outputs.decoder_hidden_states[-1]

            # Apply gating and get gate scores
            enhanced_features, gate_scores = self._enhance_text_features(last_hidden, attention_mask)

            # Compute new logits with enhanced features
            new_logits = self.base_model.text_decoder.lm_head(enhanced_features)

            # Add L1 regularization on gate values to encourage sparsity
            gate_reg_loss = self.gate_reg_lambda * gate_scores.abs().mean()

            # Update loss and logits
            if outputs.loss is not None:
                updated_loss = outputs.loss + gate_reg_loss
            else:
                updated_loss = gate_reg_loss

            # Create a modified output with enhanced features
            outputs_dict = outputs.to_dict()
            outputs_dict["loss"] = updated_loss
            outputs_dict["logits"] = new_logits

            # Convert back to the original class type
            return type(outputs)(**outputs_dict)

        return outputs

    def generate(self, pixel_values=None, input_ids=None, attention_mask=None, **kwargs):
        """Generate captions using the base model's generation capability."""
        # Apply VAE adapter if needed
        if self.use_latent_adapter and pixel_values is not None:
            # Check if the format needs adaptation
            if len(pixel_values.shape) != 4 or pixel_values.shape[1] != 3:
                pixel_values = self.latent_adapter(pixel_values)

        return self.base_model.generate(
            pixel_values=pixel_values,
            input_ids=input_ids,
            attention_mask=attention_mask,
            **kwargs
        )


# 5. Improved Data Processing with Robust Handling of Varied VAE Latents
def closest_factors(n):
    """Finds the closest factors of n to get an aspect ratio close to a square."""
    sqrt_n = int(math.sqrt(n))
    for i in range(sqrt_n, 0, -1):
        if n % i == 0:
            return i, n // i  # Return H, W such that H × W = n

    # If no exact factors, use power of 2 dimensions
    size = 2 ** int(math.log2(math.sqrt(n)))
    return size, size


def collate_fn(batch):
    """
    Improved collate function with better handling of varied VAE latents,
    enhanced normalization, and more robust error handling.
    """
    try:
        # Extract captions and VAE latents
        captions = [item["caption"] for item in batch]
        vae_latents = [torch.tensor(item["vae_latent"], dtype=torch.float32) for item in batch]

        # Find maximum length for padding
        max_len = max([latent.shape[0] for latent in vae_latents])

        # Check if latents have wildly different sizes (potential issue)
        sizes = [latent.shape[0] for latent in vae_latents]
        size_variance = np.var(sizes) if len(sizes) > 1 else 0

        # If high variance in sizes, use a reference size instead
        if size_variance > 500 and len(sizes) > 1:
            # Use median size as reference
            reference_size = int(np.median(sizes))
            # Cap maximum length to avoid excessive padding
            max_len = min(max_len, reference_size * 2)

        # Pad tensors
        padded_latents = []
        for latent in vae_latents:
            if latent.shape[0] > max_len:
                # Truncate if too long
                padded = latent[:max_len]
            else:
                # Pad if too short
                pad_size = max_len - latent.shape[0]
                if pad_size > 0:
                    padding = torch.zeros(pad_size, dtype=torch.float32)
                    padded = torch.cat([latent, padding])
                else:
                    padded = latent
            padded_latents.append(padded)

        # Stack tensors
        latents = torch.stack(padded_latents)

        # Apply robust normalization:
        # 1. Per-sample mean-std normalization
        means = latents.mean(dim=1, keepdim=True)
        stds = latents.std(dim=1, keepdim=True) + 1e-6  # Avoid division by zero
        normalized_latents = (latents - means) / stds

        # 2. Apply global scaling to ensure values are in a good range for the model
        normalized_latents = torch.tanh(normalized_latents) * 0.5  # Scale to [-0.5, 0.5]

        # Reshape for BLIP
        batch_size = len(batch)
        feature_dim = normalized_latents.shape[1]

        # Get dimensions for reshaping to 2D
        height, width = closest_factors(feature_dim)

        # Handle reshape properly
        try:
            if height * width != feature_dim:
                # Adjust dimensions to avoid reshape errors
                height = int(math.sqrt(feature_dim))
                width = int(math.ceil(feature_dim / height))
                padded_dim = height * width

                if padded_dim > feature_dim:
                    # Pad to match dimensions
                    padding = torch.zeros((batch_size, padded_dim - feature_dim), dtype=torch.float32)
                    normalized_latents = torch.cat([normalized_latents, padding], dim=1)
                else:
                    # Truncate to match dimensions
                    normalized_latents = normalized_latents[:, :padded_dim]

            # Reshape to image format (batch_size, channels, height, width)
            images = normalized_latents.view(batch_size, 1, height, width)
            # Repeat the single channel to get 3 channels for RGB
            images = images.repeat(1, 3, 1, 1)

        except RuntimeError as e:
            print(f"Reshape error: {e}, using fallback method")

            # Fallback to a standard square size that's a power of 2
            side = 2 ** int(math.log2(math.sqrt(feature_dim)))
            if side < 16:  # Ensure minimum size
                side = 16

            # Create standard-sized images
            images = torch.zeros((batch_size, 3, side, side), dtype=torch.float32)
            for i, latent in enumerate(normalized_latents):
                # Prepare the data for reshaping
                if latent.shape[0] < side * side:
                    # Pad if needed
                    padding = torch.zeros(side * side - latent.shape[0], device=latent.device)
                    latent_padded = torch.cat([latent, padding])
                else:
                    # Truncate if needed
                    latent_padded = latent[:side * side]

                # Reshape to square and repeat for 3 channels
                try:
                    img = latent_padded.view(1, side, side).repeat(3, 1, 1)
                    images[i] = img
                except Exception as reshape_err:
                    print(f"Secondary reshape error: {reshape_err}")
                    # If still failing, use simple copying into the tensor
                    for j in range(min(side * side, latent.shape[0])):
                        c = j % 3  # Channel
                        h = (j // 3) // side  # Height
                        w = (j // 3) % side  # Width
                        if h < side and w < side:
                            images[i, c, h, w] = latent[j]

        # Process captions
        processor = BlipProcessor.from_pretrained("Salesforce/blip-image-captioning-base")
        encoded_captions = processor(text=captions, padding="max_length", truncation=True,
                                  max_length=128, return_tensors="pt")

        # Combine images and captions
        batch_encoding = {
            "pixel_values": images.to(torch.float32),
            "input_ids": encoded_captions["input_ids"],
            "attention_mask": encoded_captions["attention_mask"],
            "labels": encoded_captions["input_ids"].clone()
        }

        return batch_encoding

    except Exception as e:
        print(f"Error in collate_fn: {e}")
        # Return a minimal valid batch to avoid training failure
        processor = BlipProcessor.from_pretrained("Salesforce/blip-image-captioning-base")
        dummy_captions = ["dummy caption"] * len(batch)
        encoded = processor(text=dummy_captions, padding="max_length", truncation=True,
                           max_length=128, return_tensors="pt")
        dummy_images = torch.zeros((len(batch), 3, 32, 32), dtype=torch.float32)

        batch_encoding = {
            "pixel_values": dummy_images,
            "input_ids": encoded["input_ids"],
            "attention_mask": encoded["attention_mask"],
            "labels": encoded["input_ids"].clone()
        }

        return batch_encoding


# 6. Enhanced Training Function with Dynamic Learning Rate and Focused Metrics
def train_model(model, train_loader, val_loader=None, num_epochs=5,
               lr=2e-5, device="cuda", checkpoint_dir="checkpoints"):
    """
    Enhanced training function with better optimization strategies,
    focused metrics, and more robust checkpointing.
    """
    print(f"Starting training for {num_epochs} epochs")

    # Create checkpoint directory
    os.makedirs(checkpoint_dir, exist_ok=True)

    # Prepare optimizer with weight decay
    optimizer = optim.AdamW(model.parameters(), lr=lr, weight_decay=0.01, betas=(0.9, 0.999))

    # Learning rate scheduler with warmup
    def lr_lambda(current_step: int):
        # Warmup for 10% of total steps
        warmup_steps = int(0.1 * num_epochs * len(train_loader))
        if current_step < warmup_steps:
            return float(current_step) / float(max(1, warmup_steps))
        # Linear decay for the rest
        return max(0.0, 1.0 - (current_step - warmup_steps) / (num_epochs * len(train_loader) - warmup_steps))

    scheduler = optim.lr_scheduler.LambdaLR(optimizer, lr_lambda)

    # Mixed precision training
    scaler = GradScaler(enabled=(device == "cuda"))

    # Gradient accumulation steps (effective batch size = batch_size * accum_steps)
    accum_steps = 4

    # Early stopping parameters with improved patience
    patience = 4
    best_score = float('-inf')
    no_improve_epochs = 0

    # Metrics tracking
    metrics_history = {
        'train_loss': [],
        'bleu4': [],
        'cider': [],
        'spice': [],
        'rouge': []
    }

    # Resume from checkpoint if available
    checkpoint_path = os.path.join(checkpoint_dir, "latest_model.pth")
    start_epoch = 1

    if os.path.exists(checkpoint_path):
        try:
            print(f"Loading checkpoint from {checkpoint_path}")
            checkpoint = torch.load(checkpoint_path, map_location=device)
            model.load_state_dict(checkpoint["model_state"])
            optimizer.load_state_dict(checkpoint["optimizer_state"])
            start_epoch = checkpoint["epoch"] + 1
            best_score = checkpoint.get("best_score", float('-inf'))

            # Recover metrics history if available
            if "metrics_history" in checkpoint:
                metrics_history = checkpoint["metrics_history"]

            # Move optimizer states to right device
            for state in optimizer.state.values():
                for k, v in state.items():
                    if isinstance(v, torch.Tensor):
                        state[k] = v.to(device)

            print(f"Resuming training from epoch {start_epoch}")
        except Exception as e:
            print(f"Error loading checkpoint: {e}")
            print("Starting training from scratch")

    # Training loop
    for epoch in range(start_epoch, num_epochs + 1):
        model.train()
        total_loss = 0.0
        start_time = time.time()

        # Progress bar for training batches
        progress_bar = tqdm(train_loader, desc=f"Epoch {epoch}/{num_epochs}")

        # Learning rate warmup in first epoch
        if epoch == 1:
            for param_group in optimizer.param_groups:
                param_group['lr'] = lr * 0.1

        for batch_idx, batch in enumerate(progress_bar):
            try:
                # Move batch to device
                batch = {k: v.to(device) for k, v in batch.items()}

                # Forward pass with mixed precision
                with autocast(enabled=(device == "cuda")):
                    outputs = model(**batch)
                    loss = outputs.loss / accum_steps  # Scale loss for accumulation

                # Backward pass with gradient scaling
                scaler.scale(loss).backward()

                # Update weights after accumulation or at the end
                if (batch_idx + 1) % accum_steps == 0 or (batch_idx + 1 == len(train_loader)):
                    # Unscale gradients for clipping
                    scaler.unscale_(optimizer)

                    # Clip gradients to prevent explosive values
                    clip_grad_norm_(model.parameters(), max_norm=1.0)

                    # Optimizer step with scaling
                    scaler.step(optimizer)
                    scaler.update()
                    optimizer.zero_grad()

                    # Update learning rate
                    scheduler.step()

                # Track loss
                total_loss += loss.item() * accum_steps

                # Update progress bar
                progress_bar.set_postfix({"loss": f"{loss.item() * accum_steps:.4f}",
                                         "lr": f"{optimizer.param_groups[0]['lr']:.2e}"})

            except Exception as e:
                print(f"Error in batch {batch_idx}: {e}")
                continue

        # Calculate average loss for the epoch
        avg_loss = total_loss / len(train_loader)
        train_time = time.time() - start_time

        # Store training loss
        metrics_history['train_loss'].append(avg_loss)

        print(f"Epoch {epoch} - Avg. Training Loss: {avg_loss:.4f} (Time: {train_time:.2f}s)")

        # Validation phase - compute only BLEU-4 for most epochs, full metrics for specific points
        if val_loader is not None:
            # Calculate all metrics every 5 epochs and in the last epoch
            compute_all = (epoch % 5 == 0) or (epoch == num_epochs)
            val_metrics = evaluate_model(model, val_loader, device=device, compute_all=compute_all)

            # Log validation metrics
            print(f"Validation BLEU-4: {val_metrics['bleu4']:.4f}")
            metrics_history['bleu4'].append(val_metrics['bleu4'])

            if compute_all:
                print(f"Validation CIDEr: {val_metrics['cider']:.4f}")
                print(f"Validation SPICE: {val_metrics['spice']:.4f}")
                print(f"Validation ROUGE: {val_metrics['rouge']:.4f}")

                # Store all metrics
                metrics_history['cider'].append(val_metrics['cider'])
                metrics_history['spice'].append(val_metrics['spice'])
                metrics_history['rouge'].append(val_metrics['rouge'])

            # Use weighted combination of metrics as overall score for early stopping
            # BLEU is available every epoch, other metrics only on compute_all epochs
            if compute_all:
                current_score = val_metrics['bleu4'] + 0.5 * val_metrics['cider'] + 0.3 * val_metrics['spice']
            else:
                # Only BLEU is available
                current_score = val_metrics['bleu4']
        else:
            # If no validation set, use negative training loss as score
            current_score = -avg_loss
            compute_all = False

        # Save checkpoint with metrics history
        torch.save({
            "epoch": epoch,
            "model_state": model.state_dict(),
            "optimizer_state": optimizer.state_dict(),
            "best_score": best_score,
            "metrics_history": metrics_history
        }, checkpoint_path)

        # Check for improvement
        if current_score > best_score:
            best_score = current_score
            no_improve_epochs = 0

            # Save best model
            best_model_path = os.path.join(checkpoint_dir, "best_model.pth")
            torch.save({
                "epoch": epoch,
                "model_state": model.state_dict(),
                "best_score": best_score,
                "metrics_history": metrics_history
            }, best_model_path)

            print(f"New best model saved at epoch {epoch}")
        else:
            no_improve_epochs += 1

        # Early stopping with increasing patience in later epochs
        required_patience = patience
        if epoch > num_epochs * 0.7:  # In later epochs, allow more patience
            required_patience = patience + 2

        if no_improve_epochs >= required_patience:
            print(f"No improvement for {no_improve_epochs} epochs. Early stopping.")
            break

    print("Training completed!")
    return model, metrics_history

In [None]:
# 7. Enhanced Evaluation Function with Improved Metrics Calculation
def evaluate_model(model, val_loader, device="cuda", compute_all=False, max_samples=None):
    """
    Evaluate the model on validation set with BLEU, CIDEr, SPICE, and ROUGE metrics.
    """
    model.eval()
    predictions = []
    references = []

    # Count samples for potential limit
    sample_count = 0

    # Process validation batches
    for batch_idx, batch in enumerate(tqdm(val_loader, desc="Evaluating")):
        # Move batch to device
        batch = {k: v.to(device) for k, v in batch.items()}

        # Generate captions with enhanced beam search
        with torch.no_grad():
            try:
                # Use more advanced generation parameters for better results
                output_ids = model.generate(
                    pixel_values=batch["pixel_values"],
                    max_length=50,
                    num_beams=5,
                    length_penalty=1.0,
                    no_repeat_ngram_size=2,
                    temperature=0.7,  # Add temperature for better diversity
                    top_p=0.9,        # Add nucleus sampling
                    do_sample=True    # Enable sampling for more diverse outputs
                )

                # Decode generated captions
                pred_captions = model.processor.batch_decode(output_ids, skip_special_tokens=True)

                # Decode reference captions
                ref_captions = model.processor.batch_decode(batch["labels"], skip_special_tokens=True)

                # Store predictions and references
                predictions.extend(pred_captions)
                references.extend([[ref] for ref in ref_captions])  # BLEU expects list of references per example

                # Update sample count
                sample_count += len(pred_captions)

                # Print sample predictions (first batch only)
                if batch_idx == 0:
                    print("\nSample predictions:")
                    for i in range(min(3, len(pred_captions))):
                        print(f"  Reference: {ref_captions[i]}")
                        print(f"  Prediction: {pred_captions[i]}")
                        print()

                # Check if we've processed enough samples
                if max_samples is not None and sample_count >= max_samples:
                    break

            except Exception as e:
                print(f"Error generating captions for batch {batch_idx}: {e}")
                continue

    # Calculate BLEU scores
    bleu1 = corpus_bleu(references, predictions, weights=(1.0, 0.0, 0.0, 0.0))
    bleu4 = corpus_bleu(references, predictions, weights=(0.25, 0.25, 0.25, 0.25))

    # Only compute other metrics if requested (they're slower)
    if compute_all:
        # Prepare data for other metrics
        metric_refs = {i: [ref[0]] for i, ref in enumerate(references)}
        metric_preds = {i: [pred] for i, pred in enumerate(predictions)}

        # Compute other metrics with better error handling
        try:
            cider_score = Cider().compute_score(metric_refs, metric_preds)[0]
        except Exception as e:
            print(f"Error computing CIDEr: {e}")
            cider_score = 0.0

        try:
            spice_score = Spice().compute_score(metric_refs, metric_preds)[0]
        except Exception as e:
            print(f"Error computing SPICE: {e}")
            spice_score = 0.0

        try:
            rouge_score = Rouge().compute_score(metric_refs, metric_preds)[0]
        except Exception as e:
            print(f"Error computing ROUGE: {e}")
            rouge_score = 0.0

        try:
            meteor_score = Meteor().compute_score(metric_refs, metric_preds)[0]
        except Exception as e:
            print(f"Error computing METEOR: {e}")
            meteor_score = 0.0
    else:
        # If not computing all metrics, set placeholders
        cider_score = 0.0
        spice_score = 0.0
        rouge_score = 0.0
        meteor_score = 0.0

    return {
        "bleu1": bleu1,
        "bleu4": bleu4,
        "cider": cider_score,
        "spice": spice_score,
        "rouge": rouge_score,
        "meteor": meteor_score
    }

# 8. Main Function
def main():
    # Set random seed for reproducibility
    torch.manual_seed(42)
    if torch.cuda.is_available():
        torch.cuda.manual_seed_all(42)
    np.random.seed(42)

    # Set device
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print(f"Using device: {device}")

    # Load dataset
    print("Loading dataset...")
    dataset = load_dataset("SwayStar123/preprocessed_recap-coco30k-moondream")['train']
    print(f"Dataset loaded with {len(dataset)} samples")

    # Detect VAE latent dimension
    latent_dim = None
    if "vae_latent" in dataset[0]:
        latent_dim = len(dataset[0]["vae_latent"])
        print(f"Detected VAE latent dimension: {latent_dim}")

    # Split dataset
    split_ratio = 0.9
    train_size = int(split_ratio * len(dataset))
    train_ds = dataset.select(range(train_size))
    val_ds = dataset.select(range(train_size, len(dataset)))
    print(f"Training samples: {len(train_ds)}, Validation samples: {len(val_ds)}")

    # Create data loaders
    train_loader = DataLoader(
        train_ds,
        batch_size=8,
        shuffle=True,
        collate_fn=collate_fn,
        num_workers=2,
        pin_memory=True
    )
    val_loader = DataLoader(
        val_ds,
        batch_size=8,
        shuffle=False,
        collate_fn=collate_fn,
        num_workers=2,
        pin_memory=True
    )

    # Initialize model
    print("Initializing model...")
    model = EnhancedBLIP(
        sparsity=0.7,
        gate_reg_lambda=0.01,
        vae_latent_dim=latent_dim
    )
    model.to(device)

    # Train the model
    model, metrics_history = train_model(
        model=model,
        train_loader=train_loader,
        val_loader=val_loader,
        num_epochs=10,
        lr=2e-5,
        device=device,
        checkpoint_dir="enhanced_token_gating_checkpoints"
    )

    # Final evaluation
    print("Performing final evaluation...")
    final_metrics = evaluate_model(model, val_loader, device=device, compute_all=True)

    print("Final Evaluation Results:")
    print(f"BLEU-1: {final_metrics['bleu1']:.4f}")
    print(f"BLEU-4: {final_metrics['bleu4']:.4f}")
    print(f"CIDEr: {final_metrics['cider']:.4f}")
    print(f"SPICE: {final_metrics['spice']:.4f}")
    print(f"ROUGE: {final_metrics['rouge']:.4f}")
    print(f"METEOR: {final_metrics['meteor']:.4f}")

    return model

# 9. Function to evaluate the current model without loading from checkpoint
def evaluate_current_model():
    # Set device
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print(f"Using device: {device}")

    # Load dataset for testing
    print("Loading dataset for evaluation...")
    dataset = load_dataset("SwayStar123/preprocessed_recap-coco30k-moondream")['train']

    # Detect VAE latent dimension
    latent_dim = None
    if "vae_latent" in dataset[0]:
        latent_dim = len(dataset[0]["vae_latent"])
        print(f"Detected VAE latent dimension: {latent_dim}")

    # Create a fresh model
    print("Initializing new model...")
    model = EnhancedBLIP(
        sparsity=0.7,
        gate_reg_lambda=0.01,
        vae_latent_dim=latent_dim
    )
    model.to(device)
    model.eval()

    # Use a small subset for quick evaluation
    test_size = 200
    test_ds = dataset.select(range(len(dataset) - test_size, len(dataset)))
    print(f"Test dataset contains {len(test_ds)} samples")

    # Create test dataloader
    test_loader = DataLoader(
        test_ds,
        batch_size=8,
        shuffle=False,
        collate_fn=collate_fn,
        num_workers=2,
        pin_memory=True
    )

    # Generate example captions with the current model
    with torch.no_grad():
        # Get a batch
        example_batch = next(iter(test_loader))
        example_batch = {k: v.to(device) for k, v in example_batch.items()}

        # Generate captions
        output_ids = model.generate(
            pixel_values=example_batch["pixel_values"],
            max_length=50,
            num_beams=5
        )

        # Decode captions
        pred_captions = model.processor.batch_decode(output_ids, skip_special_tokens=True)
        ref_captions = model.processor.batch_decode(example_batch["labels"], skip_special_tokens=True)

        # Print examples
        print("\n=== Example Captions ===")
        for i in range(min(5, len(pred_captions))):
            print(f"\nExample {i+1}:")
            print(f"Reference: {ref_captions[i]}")
            print(f"Generated: {pred_captions[i]}")

    return model

# Add this at the end of your file
if __name__ == "__main__":
    try:
        model = main()
    except Exception as e:
        print(f"Error in main function: {e}")
        print("Falling back to evaluation mode.")
        model = evaluate_current_model()

Using device: cuda
Loading dataset...


Resolving data files:   0%|          | 0/96 [00:00<?, ?it/s]

Dataset loaded with 30504 samples
Detected VAE latent dimension: 3712
Training samples: 27453, Validation samples: 3051
Initializing model...
Loading base BLIP model...
Model hidden dimension: 768
Creating VAE latent adapter with input dim 3712
Enhanced BLIP model initialized
Starting training for 10 epochs


  scaler = GradScaler(enabled=(device == "cuda"))


Epoch 1/10:   0%|          | 0/3432 [00:00<?, ?it/s]

  with autocast(enabled=(device == "cuda")):


Epoch 1 - Avg. Training Loss: 4.7395 (Time: 2492.63s)


Evaluating:   0%|          | 0/382 [00:00<?, ?it/s]


Sample predictions:
  Reference: the image depicts a bouquet of flowers in full bloom, including six white daisies with yellow centers and five pink lilies with orange centers. the daisies are arranged in a circular pattern around the center of the bouquet, while the lilies are scattered throughout the arrangement. the background is blurred, suggesting a lush garden or park setting.
  Prediction: in the image, a young man is seated on a wooden bench in a park, wearing a black t - shirt and blue jeans, holding a yellow frisbee in his right hand. the bench is positioned on the left side of the frame

  Reference: in the vast expanse of the ocean, a person is seen kiteboarding. the individual, clad in a black wetsuit, is skillfully maneuvering a vibrant kite that dances in the sky. the kite, adorned with hues of green and purple, stands out against the backdrop of the cloudy sky. it ' s tethered to the person by a sturdy rope, which they hold firmly in their hands. the person is riding o

Epoch 2/10:   0%|          | 0/3432 [00:00<?, ?it/s]

Epoch 2 - Avg. Training Loss: 1.2606 (Time: 2473.79s)


Evaluating:   0%|          | 0/382 [00:00<?, ?it/s]


Sample predictions:
  Reference: the image depicts a bouquet of flowers in full bloom, including six white daisies with yellow centers and five pink lilies with orange centers. the daisies are arranged in a circular pattern around the center of the bouquet, while the lilies are scattered throughout the arrangement. the background is blurred, suggesting a lush garden or park setting.
  Prediction: the image depicts a baseball game in progress. a batter wearing a white uniform with the number 12 is mid - swing, his body coiled and muscles taut as he prepares to strike the incoming ball. behind him, a catcher in a black uniform crouch

  Reference: in the vast expanse of the ocean, a person is seen kiteboarding. the individual, clad in a black wetsuit, is skillfully maneuvering a vibrant kite that dances in the sky. the kite, adorned with hues of green and purple, stands out against the backdrop of the cloudy sky. it ' s tethered to the person by a sturdy rope, which they hold firmly in 

Epoch 3/10:   0%|          | 0/3432 [00:00<?, ?it/s]

Epoch 3 - Avg. Training Loss: 0.8749 (Time: 2515.16s)


Evaluating:   0%|          | 0/382 [00:00<?, ?it/s]


Sample predictions:
  Reference: the image depicts a bouquet of flowers in full bloom, including six white daisies with yellow centers and five pink lilies with orange centers. the daisies are arranged in a circular pattern around the center of the bouquet, while the lilies are scattered throughout the arrangement. the background is blurred, suggesting a lush garden or park setting.
  Prediction: the image captures a baseball game in progress. a batter wearing a white uniform with red accents is mid - swing, his body coiled and muscles taut as he prepares to strike the incoming ball. behind him, a catcher in a black uniform crouches

  Reference: in the vast expanse of the ocean, a person is seen kiteboarding. the individual, clad in a black wetsuit, is skillfully maneuvering a vibrant kite that dances in the sky. the kite, adorned with hues of green and purple, stands out against the backdrop of the cloudy sky. it ' s tethered to the person by a sturdy rope, which they hold firmly in

Epoch 4/10:   0%|          | 0/3432 [00:00<?, ?it/s]

Epoch 4 - Avg. Training Loss: 0.7776 (Time: 2522.01s)


Evaluating:   0%|          | 0/382 [00:00<?, ?it/s]


Sample predictions:
  Reference: the image depicts a bouquet of flowers in full bloom, including six white daisies with yellow centers and five pink lilies with orange centers. the daisies are arranged in a circular pattern around the center of the bouquet, while the lilies are scattered throughout the arrangement. the background is blurred, suggesting a lush garden or park setting.
  Prediction: the image depicts a baseball game in progress. a batter in a white uniform with red stripes is mid - swing, while a catcher in an orange uniform crouches behind home plate. an umpire in black stands just behind the catcher, observing the play

  Reference: in the vast expanse of the ocean, a person is seen kiteboarding. the individual, clad in a black wetsuit, is skillfully maneuvering a vibrant kite that dances in the sky. the kite, adorned with hues of green and purple, stands out against the backdrop of the cloudy sky. it ' s tethered to the person by a sturdy rope, which they hold firmly 

Epoch 5/10:   0%|          | 0/3432 [00:00<?, ?it/s]

Epoch 5 - Avg. Training Loss: 0.7093 (Time: 2502.95s)


Evaluating:   0%|          | 0/382 [00:00<?, ?it/s]


Sample predictions:
  Reference: the image depicts a bouquet of flowers in full bloom, including six white daisies with yellow centers and five pink lilies with orange centers. the daisies are arranged in a circular pattern around the center of the bouquet, while the lilies are scattered throughout the arrangement. the background is blurred, suggesting a lush garden or park setting.
  Prediction: in the image, a man and a woman are seated at a table in a restaurant. the man is wearing a blue shirt and glasses, while the woman is dressed in an orange shirt. they are both smiling and looking directly at the camera.

  Reference: in the vast expanse of the ocean, a person is seen kiteboarding. the individual, clad in a black wetsuit, is skillfully maneuvering a vibrant kite that dances in the sky. the kite, adorned with hues of green and purple, stands out against the backdrop of the cloudy sky. it ' s tethered to the person by a sturdy rope, which they hold firmly in their hands. the pe

Epoch 6/10:   0%|          | 0/3432 [00:00<?, ?it/s]

  with autocast(enabled=(device == "cuda")):
