<a href="https://colab.research.google.com/github/qu-romana/ML_Research_Code/blob/main/TokenGating_Sparse_Attention2.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
# Reset the environment and install compatible versions
!pip uninstall -y tensorflow numpy transformers
!pip install numpy==1.23.5
!pip install tensorflow==2.15.0
!pip install transformers==4.25.1 datasets
!pip install torch torchvision nltk pycocotools
# Install additional packages for evaluation and visualization
# (the SPICE and METEOR scorers in pycocoevalcap also require a Java runtime)
!pip install pycocoevalcap==1.2 matplotlib seaborn
!pip install accelerate  # For optimization with Hugging Face models


Found existing installation: numpy 1.23.5
Uninstalling numpy-1.23.5:
  Successfully uninstalled numpy-1.23.5
Collecting numpy==1.23.5
  Using cached numpy-1.23.5-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl.metadata (2.3 kB)
Using cached numpy-1.23.5-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (17.1 MB)
Installing collected packages: numpy
ERROR: pip's dependency resolver does not currently take into account all the packages that are installed. This behaviour is the source of the following dependency conflicts.
peft 0.14.0 requires transformers, which is not installed.
dopamine-rl 4.1.2 requires tensorflow>=2.2.0, which is not installed.
xarray 2025.1.2 requires numpy>=1.24, but you have numpy 1.23.5 which is incompatible.
langchain 0.3.19 requires numpy<2,>=1.26.4; python_version < "3.12", but you have numpy 1.23.5 which is incompatible.
albumentations 2.0.5 requires numpy>=1.24.4, but you have numpy 1.23.5 which is incompatible.
scikit-image 

Collecting tensorflow==2.15.0
  Using cached tensorflow-2.15.0-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl.metadata (4.4 kB)
Using cached tensorflow-2.15.0-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (475.3 MB)
Installing collected packages: tensorflow
ERROR: pip's dependency resolver does not currently take into account all the packages that are installed. This behaviour is the source of the following dependency conflicts.
tensorflow-text 2.18.1 requires tensorflow<2.19,>=2.18.0, but you have tensorflow 2.15.0 which is incompatible.
tf-keras 2.18.0 requires tensorflow<2.19,>=2.18, but you have tensorflow 2.15.0 which is incompatible.
Successfully installed tensorflow-2.15.0
Collecting transformers==4.25.1
  Using cached transformers-4.25.1-py3-none-any.whl.metadata (93 kB)
Collecting tokenizers!=0.11.3,<0.14,>=0.11.1 (from transformers==4.25.1)
  Using cached tokenizers-0.13.3-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl.me

In [None]:
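# Upgrade transformers and datasets; note that this overrides the transformers==4.25.1 pin from the previous cell
!pip install -U transformers datasets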

Collecting transformers
  Using cached transformers-4.49.0-py3-none-any.whl.metadata (44 kB)
Collecting tokenizers<0.22,>=0.21 (from transformers)
  Using cached tokenizers-0.21.0-cp39-abi3-manylinux_2_17_x86_64.manylinux2014_x86_64.whl.metadata (6.7 kB)
Using cached transformers-4.49.0-py3-none-any.whl (10.0 MB)
Using cached tokenizers-0.21.0-cp39-abi3-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (3.0 MB)
Installing collected packages: tokenizers, transformers
  Attempting uninstall: tokenizers
    Found existing installation: tokenizers 0.13.3
    Uninstalling tokenizers-0.13.3:
      Successfully uninstalled tokenizers-0.13.3
  Attempting uninstall: transformers
    Found existing installation: transformers 4.25.1
    Uninstalling transformers-4.25.1:
      Successfully uninstalled transformers-4.25.1
Successfully installed tokenizers-0.21.0 transformers-4.49.0

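The first cell pins `transformers==4.25.1`, but according to the log above the second cell's `pip install -U` replaced it with 4.49.0. A minimal way to confirm what actually ended up installed, before importing anything, is the standard-library `importlib.metadata`. `installed_versions` is a hypothetical helper name and the package list is only an example.

```python
from importlib.metadata import PackageNotFoundError, version


def installed_versions(packages):
    """Return {package: installed version string, or None if it is not installed}."""
    found = {}
    for name in packages:
        try:
            found[name] = version(name)
        except PackageNotFoundError:
            found[name] = None
    return found


if __name__ == "__main__":
    # Check which versions are actually present after both pip cells have run
    for pkg, ver in installed_versions(["transformers", "tokenizers", "datasets", "numpy", "torch"]).items():
        print(f"{pkg}: {ver}")
```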

In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import DataLoader
from torch.cuda.amp import autocast, GradScaler
from torch.nn.utils import clip_grad_norm_


In [None]:

from transformers import BlipForConditionalGeneration, BlipProcessor
from datasets import load_dataset
from nltk.translate.bleu_score import corpus_bleu
from pycocoevalcap.cider.cider import Cider
from pycocoevalcap.spice.spice import Spice
from pycocoevalcap.rouge.rouge import Rouge
from pycocoevalcap.meteor.meteor import Meteor


In [None]:

import math
import os
import time
import numpy as np
from tqdm.auto import tqdm

In [None]:

# 1. Token Gating Implementation
class TokenGating(nn.Module):
    """
    Token Gating mechanism that selectively focuses on important tokens
    while suppressing less relevant ones.
    """
    def __init__(self, hidden_dim, dropout=0.1):
        super(TokenGating, self).__init__()
        self.gate_transform = nn.Sequential(
            nn.Linear(hidden_dim, hidden_dim // 2),
            nn.GELU(),
            nn.Dropout(dropout),
            nn.Linear(hidden_dim // 2, 1)
        )
        self.sigmoid = nn.Sigmoid()

    def forward(self, hidden_states, attention_mask=None):
        # Calculate importance score for each token [batch_size, seq_len, 1]
        gate_scores = self.sigmoid(self.gate_transform(hidden_states))

        # Scale the gate to the range (0, 2): a raw score of 0 gives a gate of 1.0
        # (token unchanged), so the gate can both suppress and amplify tokens
        gate_scores = gate_scores * 2.0

        # Apply the gate - element-wise multiplication
        gated_output = hidden_states * gate_scores

        # Zero out padded positions if an attention mask is provided
        if attention_mask is not None:
            mask = attention_mask.unsqueeze(-1).to(gated_output.dtype)
            gated_output = gated_output * mask

        return gated_output, gate_scores

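As a minimal, dependency-free sketch of the gating arithmetic in `TokenGating.forward`, the function below replaces the learned `gate_transform` with given raw scores (`gate_logits`, a hypothetical stand-in). Because the sigmoid is multiplied by 2, a raw score of 0 yields a gate of exactly 1.0, positive scores amplify a token, negative scores suppress it, and masked positions are zeroed.

```python
import math


def gate_tokens(hidden_states, gate_logits, attention_mask=None):
    """Pure-Python mirror of TokenGating.forward's arithmetic (no learned layers).

    hidden_states:  list of token vectors (lists of floats)
    gate_logits:    one raw score per token, a stand-in for gate_transform's output
    attention_mask: optional list of 0/1 flags, one per token
    """
    gated, gates = [], []
    for i, (vec, logit) in enumerate(zip(hidden_states, gate_logits)):
        gate = 2.0 / (1.0 + math.exp(-logit))  # 2 * sigmoid(logit), in (0, 2)
        keep = 1.0 if attention_mask is None else float(attention_mask[i])
        gated.append([x * gate * keep for x in vec])
        gates.append(gate)
    return gated, gates


if __name__ == "__main__":
    out, gates = gate_tokens([[1.0, -2.0], [3.0, 4.0], [5.0, 6.0]],
                             [0.0, 3.0, -3.0], attention_mask=[1, 1, 0])
    print("gates:", gates)
    print("gated:", out)
```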

In [None]:

# 2. Sparse Attention Implementation
class SparseAttention(nn.Module):
    """
    Implements sparse attention by selecting only the top-k most relevant tokens
    for each position during attention computation.
    """
    def __init__(self, hidden_dim, num_heads=8, dropout=0.1, sparsity=0.8):
        super(SparseAttention, self).__init__()
        self.hidden_dim = hidden_dim
        self.num_heads = num_heads
        self.head_dim = hidden_dim // num_heads
        assert self.head_dim * num_heads == hidden_dim, "hidden_dim must be divisible by num_heads"

        self.sparsity = sparsity  # Fraction (0-1) of attention scores to drop for each query

        self.q_proj = nn.Linear(hidden_dim, hidden_dim)
        self.k_proj = nn.Linear(hidden_dim, hidden_dim)
        self.v_proj = nn.Linear(hidden_dim, hidden_dim)
        self.out_proj = nn.Linear(hidden_dim, hidden_dim)

        self.dropout = nn.Dropout(dropout)

    def forward(self, hidden_states, attention_mask=None):
        batch_size, seq_len, _ = hidden_states.shape

        # Linear projections and reshape to multi-head
        q = self.q_proj(hidden_states).reshape(batch_size, seq_len, self.num_heads, self.head_dim)
        k = self.k_proj(hidden_states).reshape(batch_size, seq_len, self.num_heads, self.head_dim)
        v = self.v_proj(hidden_states).reshape(batch_size, seq_len, self.num_heads, self.head_dim)

        # Transpose for attention computation [batch_size, num_heads, seq_len, head_dim]
        q = q.transpose(1, 2)
        k = k.transpose(1, 2)
        v = v.transpose(1, 2)

        # Compute scaled dot-product attention
        attn_weights = torch.matmul(q, k.transpose(-2, -1)) / (self.head_dim ** 0.5)

        # Apply attention mask if provided
        if attention_mask is not None:
            # Expand mask for multi-head attention [batch_size, 1, 1, seq_len]
            expanded_mask = attention_mask.unsqueeze(1).unsqueeze(2)
            # Use the dtype's most negative value; -1e10 cannot be represented in float16 under autocast
            attn_weights = attn_weights.masked_fill(expanded_mask == 0, torch.finfo(attn_weights.dtype).min)

        # Compute sparse attention by keeping only the top-k scores per query.
        # Note: this is applied only in training mode; in eval mode attention stays dense.
        if self.training:
            # Determine number of tokens to keep based on sparsity level
            k_tokens = max(1, int((1 - self.sparsity) * seq_len))

            # Get top-k values for each query token
            top_k_attn, _ = torch.topk(attn_weights, k=k_tokens, dim=-1)

            # Use smallest value from top-k as threshold
            sparse_threshold = top_k_attn[..., -1].unsqueeze(-1)

            # Create a binary mask for sparse attention
            sparse_mask = (attn_weights >= sparse_threshold).float()

            # Apply the sparse mask
            attn_weights = attn_weights * sparse_mask + -1e10 * (1 - sparse_mask)

        # Apply softmax to get normalized attention weights
        attn_weights = F.softmax(attn_weights, dim=-1)
        attn_weights = self.dropout(attn_weights)

        # Apply attention weights to values
        context = torch.matmul(attn_weights, v)

        # Reshape back to [batch_size, seq_len, hidden_dim]
        context = context.transpose(1, 2).reshape(batch_size, seq_len, self.hidden_dim)

        # Final output projection
        output = self.out_proj(context)

        return output

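The training-time top-k step in `SparseAttention.forward` can be illustrated for a single query row in plain Python. This is a sketch that mirrors the same `k` formula and the same `>=` threshold; it works on Python floats rather than tensors, and `sparse_softmax_row` is a hypothetical name.

```python
import math


def sparse_softmax_row(scores, sparsity):
    """Top-k sparse softmax over one query's attention scores.

    Mirrors the training-mode branch of SparseAttention.forward for a single row:
    k = max(1, int((1 - sparsity) * len(scores))), and every score >= the k-th
    largest is kept, so ties at the threshold can keep more than k entries.
    """
    k = max(1, int((1 - sparsity) * len(scores)))
    threshold = sorted(scores, reverse=True)[k - 1]
    kept = [s if s >= threshold else -math.inf for s in scores]
    peak = max(kept)  # subtract the max for a numerically stable softmax
    exps = [math.exp(s - peak) for s in kept]  # exp(-inf) is 0.0
    total = sum(exps)
    return [e / total for e in exps]


if __name__ == "__main__":
    row = [0.1, 2.0, -1.0, 0.5, 3.0, 0.0, 1.5, -0.5, 0.2, 0.3]
    print(sparse_softmax_row(row, sparsity=0.7))  # only 3 of 10 weights are non-zero
```

One detail this sketch makes easy to check: floating-point rounding makes `int((1 - 0.8) * 10)` evaluate to 1 rather than 2, because `1 - 0.8` is slightly below 0.2. Using `round(...)` instead of `int(...)` when computing `k` would avoid that off-by-one.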

In [None]:


# 3. Enhanced BLIP Model with Token Gating and Sparse Attention
class EnhancedBLIP(nn.Module):
    """
    Enhances BLIP model with token gating and sparse attention mechanisms.
    """
    def __init__(self, sparsity=0.8):
        super(EnhancedBLIP, self).__init__()

        # Load base model
        print("Loading base BLIP model...")
        self.base_model = BlipForConditionalGeneration.from_pretrained("Salesforce/blip-image-captioning-base")
        self.processor = BlipProcessor.from_pretrained("Salesforce/blip-image-captioning-base")

        # Get hidden dimensions from the base model (text decoder and vision encoder)
        hidden_dim = self.base_model.text_decoder.config.hidden_size
        vision_dim = self.base_model.config.vision_config.hidden_size
        print(f"Text hidden dimension: {hidden_dim}, vision hidden dimension: {vision_dim}")

        # Add token gating layers
        self.text_gate = TokenGating(hidden_dim)
        self.vision_gate = TokenGating(vision_dim)

        # Add sparse attention layers
        self.text_sparse_attn = SparseAttention(hidden_dim, sparsity=sparsity)
        self.vision_sparse_attn = SparseAttention(vision_dim, sparsity=sparsity)

        # Layer norms for stability
        self.text_ln1 = nn.LayerNorm(hidden_dim)
        self.text_ln2 = nn.LayerNorm(hidden_dim)
        self.vision_ln1 = nn.LayerNorm(vision_dim)
        self.vision_ln2 = nn.LayerNorm(vision_dim)

        # Feed-forward networks for residual paths
        self.text_ffn = nn.Sequential(
            nn.Linear(hidden_dim, hidden_dim * 4),
            nn.GELU(),
            nn.Dropout(0.1),
            nn.Linear(hidden_dim * 4, hidden_dim)
        )

        self.vision_ffn = nn.Sequential(
            nn.Linear(vision_dim, vision_dim * 4),
            nn.GELU(),
            nn.Dropout(0.1),
            nn.Linear(vision_dim * 4, vision_dim)
        )

        # Flag to control whether to apply enhancements
        self.apply_enhancements = True
        print("Enhanced BLIP model initialized")

    def _enhance_text_features(self, hidden_states, attention_mask=None):
        """Apply token gating and sparse attention to text features."""
        # Apply token gating
        gated_states, _ = self.text_gate(hidden_states, attention_mask)
        gated_states = self.text_ln1(gated_states + hidden_states)  # Residual connection

        # Apply sparse attention
        sparse_states = self.text_sparse_attn(gated_states, attention_mask)
        sparse_states = self.text_ln2(sparse_states + gated_states)  # Residual connection

        # Apply feed-forward network
        output = hidden_states + self.text_ffn(sparse_states)

        return output

    def _enhance_vision_features(self, hidden_states):
        """Apply token gating and sparse attention to vision features."""
        # Apply token gating
        gated_states, _ = self.vision_gate(hidden_states, None)
        gated_states = self.vision_ln1(gated_states + hidden_states)  # Residual connection

        # Apply sparse attention
        sparse_states = self.vision_sparse_attn(gated_states, None)
        sparse_states = self.vision_ln2(sparse_states + gated_states)  # Residual connection

        # Apply feed-forward network
        output = hidden_states + self.vision_ffn(sparse_states)

        return output

    def forward(self, pixel_values=None, input_ids=None, attention_mask=None, labels=None, return_dict=True):
        # First pass through base model
        outputs = self.base_model(
            pixel_values=pixel_values,
            input_ids=input_ids,
            attention_mask=attention_mask,
            labels=labels,
            output_hidden_states=True,
            return_dict=True
        )

        # Return base model outputs if enhancements are disabled
        if not self.apply_enhancements:
            return outputs

        # NOTE: the token gating and sparse attention modules are not yet wired into
        # this forward pass, so both branches return the base model's outputs (and
        # loss) unchanged, and the enhancement parameters receive no gradients.
        return outputs

    def generate(self, pixel_values=None, input_ids=None, attention_mask=None, **kwargs):
        """Generate captions using the base model's generation capability."""
        return self.base_model.generate(
            pixel_values=pixel_values,
            input_ids=input_ids,
            attention_mask=attention_mask,
            **kwargs
        )

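The two enhancement helpers in `EnhancedBLIP` share one residual layout: gate, add and normalize; sparse attention, add and normalize; then a feed-forward network whose output is added back to the original input. The sketch below reproduces only that data flow with plain lists. `gate`, `attn` and `ffn` are hypothetical stand-in callables, and `layer_norm` omits `nn.LayerNorm`'s learned scale and shift.

```python
import math


def layer_norm(vec, eps=1e-5):
    """LayerNorm over one vector without learned scale/shift (eps matches nn.LayerNorm's default)."""
    mean = sum(vec) / len(vec)
    var = sum((x - mean) ** 2 for x in vec) / len(vec)  # biased variance, as LayerNorm uses
    return [(x - mean) / math.sqrt(var + eps) for x in vec]


def add(a, b):
    """Element-wise sum of two equal-length vectors."""
    return [x + y for x, y in zip(a, b)]


def enhance(seq, gate, attn, ffn):
    """Data flow of _enhance_text_features / _enhance_vision_features.

    seq is a list of token vectors; gate, attn and ffn map a sequence to a
    sequence of the same shape (attn may mix information across tokens).
    """
    gated = [layer_norm(add(g, x)) for g, x in zip(gate(seq), seq)]          # LN1(gate(x) + x)
    attended = [layer_norm(add(a, g)) for a, g in zip(attn(gated), gated)]   # LN2(attn(g) + g)
    return [add(x, f) for x, f in zip(seq, ffn(attended))]                   # x + FFN(...)


if __name__ == "__main__":
    x = [[1.0, 2.0, 3.0], [0.5, -0.5, 2.0]]
    identity = lambda s: s
    print(enhance(x, identity, identity, identity))
```

Because the outer connection adds the FFN output to the untouched input, a block whose FFN outputs zeros reduces to the identity; the zero-FFN case in the checks demonstrates this.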
In [None]:
# 4. Data Processing Functions
def closest_factors(n):
    """Finds the closest factors of n to get an aspect ratio close to a square."""
    sqrt_n = int(math.sqrt(n))
    for i in range(sqrt_n, 0, -1):
        if n % i == 0:
            return i, n // i  # Return H, W such that H × W = n

    # If no exact factors, use power of 2 dimensions
    size = 2 ** int(math.log2(math.sqrt(n)))
    return size, size
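To see which (H, W) shapes `closest_factors` produces, here is the same search as a standalone function. It uses `math.isqrt` (exact integer square root, Python 3.8+) in place of `int(math.sqrt(n))`, which gives the same result for sizes like these, and it raises an error instead of keeping the power-of-two fallback, since that fallback is never reached for n >= 1. The example sizes are illustrative, not taken from the dataset.

```python
import math


def closest_factors(n):
    """Largest i <= sqrt(n) that divides n, returned as (i, n // i)."""
    for i in range(math.isqrt(n), 0, -1):
        if n % i == 0:
            return i, n // i
    raise ValueError("n must be a positive integer")


if __name__ == "__main__":
    for n in (16384, 3072, 4099):
        print(n, closest_factors(n))
```

A prime length such as 4099 yields a 1 x 4099 "image", which is far from square; only lengths with a divisor near sqrt(n) give shapes close to square.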

In [None]:




def collate_fn(batch):
    """
    Collate function for VAE latents dataset that properly handles
    variable length latents, applies normalization, and reshapes for BLIP.
    """
    try:
        # Extract captions and VAE latents
        captions = [item["caption"] for item in batch]
        vae_latents = [torch.tensor(item["vae_latent"], dtype=torch.float32) for item in batch]

        # Find maximum length for padding
        max_len = max([latent.shape[0] for latent in vae_latents])

        # Pad tensors
        padded_latents = []
        for latent in vae_latents:
            pad_size = max_len - latent.shape[0]
            if pad_size > 0:
                padding = torch.zeros(pad_size, dtype=torch.float32)
                padded = torch.cat([latent, padding])
            else:
                padded = latent
            padded_latents.append(padded)

        # Stack tensors
        latents = torch.stack(padded_latents)

        # Apply z-score normalization (per sample)
        means = latents.mean(dim=1, keepdim=True)
        stds = latents.std(dim=1, keepdim=True) + 1e-6  # Avoid division by zero
        normalized_latents = (latents - means) / stds

        # Reshape for BLIP
        batch_size = len(batch)
        feature_dim = normalized_latents.shape[1]

        # Get dimensions for reshaping
        height, width = closest_factors(feature_dim)

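        # Adjust dimensions if closest_factors did not find an exact factorization
        # (with the current closest_factors this cannot happen for feature_dim >= 1)
        if height * width != feature_dim: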
            # Use power of 2 dimensions and pad/truncate
            height = 2 ** int(math.log2(math.sqrt(feature_dim)))
            width = height
            padded_dim = height * width

            if padded_dim > feature_dim:
                # Pad each latent
                padding = torch.zeros((batch_size, padded_dim - feature_dim), dtype=torch.float32)
                normalized_latents = torch.cat([normalized_latents, padding], dim=1)
            else:
                # Truncate each latent
                normalized_latents = normalized_latents[:, :padded_dim]

        # Reshape latents to image format (batch_size, channels, height, width)
        try:
            images = normalized_latents.view(batch_size, 1, height, width)
            images = images.repeat(1, 3, 1, 1)  # Repeat along channel dimension for RGB
        except RuntimeError as e:
            print(f"Error reshaping latents: {e}")
            print(f"Using fallback reshaping method")

            # Fallback to simple square reshaping
            side = int(math.ceil(math.sqrt(feature_dim)))
            images = torch.zeros((batch_size, 3, side, side), dtype=torch.float32)

            for i, latent in enumerate(normalized_latents):
                # Pad if needed
                if latent.shape[0] < side * side:
                    latent = torch.cat([latent, torch.zeros(side * side - latent.shape[0])])
                else:
                    latent = latent[:side * side]

                # Reshape to square and repeat channels
                img = latent.view(1, side, side).repeat(3, 1, 1)
                images[i] = img

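        # Process captions; cache the processor on the function so it is loaded
        # once per process instead of once per batch
        if not hasattr(collate_fn, "processor"):
            collate_fn.processor = BlipProcessor.from_pretrained("Salesforce/blip-image-captioning-base")
        processor = collate_fn.processor
        encoded_captions = processor(text=captions, padding="max_length", truncation=True,
                                     max_length=128, return_tensors="pt")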

        # Combine images and captions
        batch_encoding = {
            "pixel_values": images.to(torch.float32),
            "input_ids": encoded_captions["input_ids"],
            "attention_mask": encoded_captions["attention_mask"],
            "labels": encoded_captions["input_ids"].clone()
        }

        return batch_encoding

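    except Exception as e:
        print(f"Error in collate_fn: {e}")
        # Return a dummy batch so training does not crash. Caution: this silently
        # trains on placeholder data, so errors printed here should be investigated.
        if not hasattr(collate_fn, "processor"):
            collate_fn.processor = BlipProcessor.from_pretrained("Salesforce/blip-image-captioning-base")
        processor = collate_fn.processor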
        dummy_captions = ["dummy caption"] * len(batch)
        encoded = processor(text=dummy_captions, padding="max_length", truncation=True,
                           max_length=128, return_tensors="pt")
        dummy_images = torch.zeros((len(batch), 3, 32, 32), dtype=torch.float32)

        batch_encoding = {
            "pixel_values": dummy_images,
            "input_ids": encoded["input_ids"],
            "attention_mask": encoded["attention_mask"],
            "labels": encoded["input_ids"].clone()
        }

        return batch_encoding

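The padding and per-sample z-scoring in `collate_fn` can be sketched with the standard library. As in the notebook, padding happens before normalization, so the zero padding is included in each sample's mean and standard deviation. `statistics.stdev` is the sample (n - 1) standard deviation, which matches the default of `torch.Tensor.std`; `pad_and_normalize` is a hypothetical name.

```python
import statistics


def pad_and_normalize(latents, eps=1e-6):
    """Right-pad 1-D latents with zeros to a common length, then z-score each one.

    Padding is applied before normalization (the same order as collate_fn),
    so padded zeros contribute to each sample's mean and std.
    """
    max_len = max(len(v) for v in latents)
    padded = [list(v) + [0.0] * (max_len - len(v)) for v in latents]
    normalized = []
    for v in padded:
        mean = statistics.fmean(v)
        std = statistics.stdev(v) + eps  # sample std, like torch.Tensor.std's default
        normalized.append([(x - mean) / std for x in v])
    return normalized


if __name__ == "__main__":
    for row in pad_and_normalize([[1.0, 2.0, 3.0, 4.0], [5.0, 7.0]]):
        print([round(x, 3) for x in row])
```

In the shorter sample the padded zeros do not stay at zero after normalization. Normalizing before padding, or computing the statistics over the unpadded length only, would keep padding out of the statistics if that is the intended behavior.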

In [None]:

# 5. Training Function with Mixed Precision and Gradient Accumulation
def train_model(model, train_loader, val_loader=None, num_epochs=5,
                lr=2e-5, device="cuda", checkpoint_dir="checkpoints"):
    """
    Train the enhanced BLIP model with mixed precision, gradient accumulation,
    and proper checkpointing.
    """
    print(f"Starting training for {num_epochs} epochs")

    # Create checkpoint directory
    os.makedirs(checkpoint_dir, exist_ok=True)

    # Prepare optimizer with weight decay
    optimizer = optim.AdamW(model.parameters(), lr=lr, weight_decay=0.01)

    # Learning rate scheduler
    scheduler = optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=num_epochs)

    # Mixed precision training (compare device.type so both "cuda" and torch.device("cuda") work)
    scaler = GradScaler(enabled=(torch.device(device).type == "cuda"))

    # Gradient accumulation steps (effective batch size = batch_size * accum_steps)
    accum_steps = 4

    # Early stopping parameters
    patience = 3
    best_score = float('-inf')
    no_improve_epochs = 0

    # Resume from checkpoint if available
    checkpoint_path = os.path.join(checkpoint_dir, "latest_model.pth")
    start_epoch = 1

    if os.path.exists(checkpoint_path):
        try:
            print(f"Loading checkpoint from {checkpoint_path}")
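            checkpoint = torch.load(checkpoint_path, map_location=device)
            model.load_state_dict(checkpoint["model_state"])
            optimizer.load_state_dict(checkpoint["optimizer_state"])
            # Restore the learning-rate schedule if it was saved with the checkpoint
            if "scheduler_state" in checkpoint:
                scheduler.load_state_dict(checkpoint["scheduler_state"])
            start_epoch = checkpoint["epoch"] + 1
            best_score = checkpoint.get("best_score", float('-inf'))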

            # Move optimizer states to right device
            for state in optimizer.state.values():
                for k, v in state.items():
                    if isinstance(v, torch.Tensor):
                        state[k] = v.to(device)

            print(f"Resuming training from epoch {start_epoch}")
        except Exception as e:
            print(f"Error loading checkpoint: {e}")
            print("Starting training from scratch")

    # Training loop
    for epoch in range(start_epoch, num_epochs + 1):
        model.train()
        total_loss = 0.0
        start_time = time.time()

        # Progress bar for training batches
        progress_bar = tqdm(train_loader, desc=f"Epoch {epoch}/{num_epochs}")

        for batch_idx, batch in enumerate(progress_bar):
            try:
                # Move batch to device
                batch = {k: v.to(device) for k, v in batch.items()}

                # Forward pass with mixed precision
                with autocast(enabled=(torch.device(device).type == "cuda")):
                    outputs = model(**batch)
                    loss = outputs.loss / accum_steps  # Scale loss for accumulation

                # Backward pass with gradient scaling
                scaler.scale(loss).backward()

                # Update weights after accumulation or at the end
                if (batch_idx + 1) % accum_steps == 0 or (batch_idx + 1 == len(train_loader)):
                    # Unscale gradients for clipping
                    scaler.unscale_(optimizer)

                    # Clip gradients to prevent explosive values
                    clip_grad_norm_(model.parameters(), max_norm=1.0)

                    # Optimizer step with scaling
                    scaler.step(optimizer)
                    scaler.update()
                    optimizer.zero_grad()

                # Track loss
                total_loss += loss.item() * accum_steps

                # Update progress bar
                progress_bar.set_postfix({"loss": f"{loss.item() * accum_steps:.4f}"})

            except Exception as e:
                print(f"Error in batch {batch_idx}: {e}")
                continue

        # Calculate average loss for the epoch
        avg_loss = total_loss / len(train_loader)
        train_time = time.time() - start_time

        print(f"Epoch {epoch} - Avg. Training Loss: {avg_loss:.4f} (Time: {train_time:.2f}s)")

        # Validation phase
        if val_loader is not None:
            val_metrics = evaluate_model(model, val_loader, device=device)

            # Log validation metrics
            print(f"Validation BLEU-4: {val_metrics['bleu4']:.4f}")
            print(f"Validation CIDEr: {val_metrics['cider']:.4f}")
            print(f"Validation SPICE: {val_metrics['spice']:.4f}")
            print(f"Validation ROUGE: {val_metrics['rouge']:.4f}")

            # Use CIDEr + BLEU-4 as overall score for early stopping
            current_score = val_metrics['cider'] + val_metrics['bleu4']
        else:
            # If no validation set, use negative training loss as score
            current_score = -avg_loss

        # Update learning rate
        scheduler.step()

        # Check for improvement
        if current_score > best_score:
            best_score = current_score
            no_improve_epochs = 0

            # Save best model
            best_model_path = os.path.join(checkpoint_dir, "best_model.pth")
            torch.save({
                "epoch": epoch,
                "model_state": model.state_dict(),
                "best_score": best_score
            }, best_model_path)

            print(f"New best model saved at epoch {epoch}")
        else:
            no_improve_epochs += 1

        # Save the latest checkpoint after best_score has been updated, so a
        # resumed run starts from the correct best score and LR schedule
        torch.save({
            "epoch": epoch,
            "model_state": model.state_dict(),
            "optimizer_state": optimizer.state_dict(),
            "scheduler_state": scheduler.state_dict(),
            "best_score": best_score
        }, checkpoint_path)

        # Early stopping
        if no_improve_epochs >= patience:
            print(f"No improvement for {patience} epochs. Early stopping.")
            break

    print("Training completed!")
    return model

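Two schedules in `train_model` are easy to check by hand: the batches at which the accumulation condition triggers an optimizer step, and the learning rate `CosineAnnealingLR` gives after each epoch-level `scheduler.step()`. The sketch uses the closed form from the PyTorch documentation, eta_min + (base_lr - eta_min) * (1 + cos(pi * t / T_max)) / 2, assuming nothing else modifies the learning rate. The function names are hypothetical, and 900 training samples is an arbitrary example.

```python
import math


def optimizer_step_indices(num_batches, accum_steps):
    """Batch indices at which train_model's accumulation condition calls scaler.step()."""
    return [i for i in range(num_batches)
            if (i + 1) % accum_steps == 0 or i + 1 == num_batches]


def optimizer_steps_per_epoch(num_train, batch_size=8, accum_steps=4):
    """Optimizer steps per epoch, assuming DataLoader's default drop_last=False."""
    num_batches = math.ceil(num_train / batch_size)
    return math.ceil(num_batches / accum_steps)


def cosine_lr(epochs_done, base_lr, t_max, eta_min=0.0):
    """Closed-form CosineAnnealingLR value after `epochs_done` scheduler.step() calls."""
    return eta_min + (base_lr - eta_min) * (1 + math.cos(math.pi * epochs_done / t_max)) / 2


if __name__ == "__main__":
    print(optimizer_step_indices(10, 4))
    print(optimizer_steps_per_epoch(900))
    print([f"{cosine_lr(t, 2e-5, 10):.2e}" for t in range(11)])
```

With `accum_steps = 4` and `batch_size = 8`, the effective batch size is 32, and a trailing group shorter than four batches still triggers a step at the end of the epoch.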
In [None]:


# 6. Evaluation Function
def evaluate_model(model, val_loader, device="cuda", max_samples=None):
    """
    Evaluate the model on validation set with BLEU, CIDEr, SPICE, and ROUGE metrics.
    """
    model.eval()
    predictions = []
    references = []

    # Count samples for potential limit
    sample_count = 0

    # Process validation batches
    for batch_idx, batch in enumerate(tqdm(val_loader, desc="Evaluating")):
        # Move batch to device
        batch = {k: v.to(device) for k, v in batch.items()}

        # Generate captions with beam search
        with torch.no_grad():
            try:
                output_ids = model.generate(
                    pixel_values=batch["pixel_values"],
                    max_length=50,
                    num_beams=5,
                    length_penalty=1.0,
                    no_repeat_ngram_size=2
                )

                # Decode generated captions
                pred_captions = model.processor.batch_decode(output_ids, skip_special_tokens=True)

                # Decode reference captions
                ref_captions = model.processor.batch_decode(batch["labels"], skip_special_tokens=True)

                # Store predictions and references
                predictions.extend(pred_captions)
                references.extend([[ref] for ref in ref_captions])  # BLEU expects list of references per example

                # Update sample count
                sample_count += len(pred_captions)

                # Print sample predictions (first batch only)
                if batch_idx == 0:
                    print("\nSample predictions:")
                    for i in range(min(3, len(pred_captions))):
                        print(f"  Reference: {ref_captions[i]}")
                        print(f"  Prediction: {pred_captions[i]}")
                        print()

                # Check if we've processed enough samples
                if max_samples is not None and sample_count >= max_samples:
                    break

            except Exception as e:
                print(f"Error generating captions for batch {batch_idx}: {e}")
                continue

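    # Guard against an empty prediction list (e.g. every batch failed)
    if not predictions:
        return {"bleu4": 0.0, "cider": 0.0, "spice": 0.0, "rouge": 0.0}

    # Compute BLEU-4; NLTK expects token lists, so split captions on whitespace
    # (passing raw strings would score characters instead of words)
    tokenized_refs = [[ref.split() for ref in refs] for refs in references]
    tokenized_preds = [pred.split() for pred in predictions]
    bleu4 = corpus_bleu(tokenized_refs, tokenized_preds, weights=(0.25, 0.25, 0.25, 0.25))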

    # Prepare data for other metrics
    metric_refs = {i: [ref[0]] for i, ref in enumerate(references)}
    metric_preds = {i: [pred] for i, pred in enumerate(predictions)}

    # Compute each metric separately so that one failure (e.g. SPICE without Java)
    # does not zero out the others
    scores = {}
    for name, scorer_cls in (("cider", Cider), ("spice", Spice), ("rouge", Rouge)):
        try:
            scores[name] = scorer_cls().compute_score(metric_refs, metric_preds)[0]
        except Exception as e:
            print(f"Error computing {name.upper()}: {e}")
            scores[name] = 0.0

    return {
        "bleu4": bleu4,
        "cider": scores["cider"],
        "spice": scores["spice"],
        "rouge": scores["rouge"]
    }

# 7. Main function to run the training pipeline
def main():
    # Set random seed for reproducibility
    torch.manual_seed(42)
    if torch.cuda.is_available():
        torch.cuda.manual_seed_all(42)

    # Set device
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print(f"Using device: {device}")

    # Load dataset
    print("Loading dataset...")
    try:
        dataset = load_dataset("SwayStar123/preprocessed_recap-coco30k-moondream")['train']
        print(f"Dataset loaded with {len(dataset)} samples")

        # Split dataset into train and validation
        split_ratio = 0.9
        train_size = int(split_ratio * len(dataset))
        train_ds = dataset.select(range(train_size))
        val_ds = dataset.select(range(train_size, len(dataset)))
        print(f"Training samples: {len(train_ds)}, Validation samples: {len(val_ds)}")

        # For faster development, uncomment to use a subset
        # train_ds = train_ds.select(range(min(5000, len(train_ds))))
        # val_ds = val_ds.select(range(min(500, len(val_ds))))
        # print(f"Using subset - Training: {len(train_ds)}, Validation: {len(val_ds)}")

        # Create data loaders with smaller batch size for L4 GPU
        train_loader = DataLoader(
            train_ds,
            batch_size=8,  # Smaller batch size for L4 GPU
            shuffle=True,
            collate_fn=collate_fn,
            num_workers=2,  # Parallel loading
            pin_memory=True  # Faster GPU transfer
        )

        val_loader = DataLoader(
            val_ds,
            batch_size=8,
            shuffle=False,
            collate_fn=collate_fn,
            num_workers=2,
            pin_memory=True
        )

        # Initialize enhanced model
        print("Initializing enhanced BLIP model...")
        model = EnhancedBLIP(sparsity=0.7)  # Adjust sparsity as needed
        model.to(device)

        # Train model
        print("Starting training...")
        train_model(
            model=model,
            train_loader=train_loader,
            val_loader=val_loader,
            num_epochs=10,  # Increase for better results
            lr=2e-5,
            device=device,
            checkpoint_dir="token_gating_checkpoints"
        )

        # Final evaluation
        print("Performing final evaluation...")
        final_metrics = evaluate_model(model, val_loader, device=device)

        print("Final Evaluation Results:")
        print(f"BLEU-4: {final_metrics['bleu4']:.4f}")
        print(f"CIDEr: {final_metrics['cider']:.4f}")
        print(f"SPICE: {final_metrics['spice']:.4f}")
        print(f"ROUGE: {final_metrics['rouge']:.4f}")

    except Exception as e:
        import traceback
        print(f"Error in main function: {e}")
        traceback.print_exc()  # Print the full stack trace, not just the message

if __name__ == "__main__":
    main()

Using device: cuda
Loading dataset...
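BLEU is built from clipped (modified) n-gram precision: each candidate n-gram counts only up to the maximum number of times it appears in any single reference. NLTK's `corpus_bleu` expects every hypothesis and reference as a list of tokens; given plain strings, it forms n-grams over characters. The sketch below implements only the clipped precision (no brevity penalty or geometric mean, so it is not full BLEU), under the hypothetical name `modified_precision`, and shows how character-level input can overstate a match.

```python
from collections import Counter


def modified_precision(references, hypothesis, n):
    """Clipped n-gram precision of one hypothesis against its references.

    Tokens can be a list of words or a plain string; a string is treated as a
    sequence of characters, which is what happens when untokenized captions
    are passed to a BLEU implementation that expects token lists.
    """
    def ngrams(tokens):
        return Counter(tuple(tokens[i:i + n]) for i in range(len(tokens) - n + 1))

    hyp_counts = ngrams(hypothesis)
    max_ref = Counter()
    for ref in references:
        for gram, count in ngrams(ref).items():
            max_ref[gram] = max(max_ref[gram], count)
    clipped = sum(min(count, max_ref[gram]) for gram, count in hyp_counts.items())
    total = sum(hyp_counts.values())
    return clipped / total if total else 0.0


if __name__ == "__main__":
    # Character "tokens": an anagram looks like a perfect unigram match
    print(modified_precision(["tac a"], "a cat", 1))
    # Word tokens: only "a" matches
    print(modified_precision(["tac a".split()], "a cat".split(), 1))
```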



train-00001-of-00004.parquet:   0%|          | 0.00/2.35M [00:00<?, ?B/s]

train-00003-of-00004.parquet:   0%|          | 0.00/2.44M [00:00<?, ?B/s]

train-00001-of-00004.parquet:   0%|          | 0.00/45.2M [00:00<?, ?B/s]

train-00002-of-00004.parquet:   0%|          | 0.00/2.21M [00:00<?, ?B/s]

train-00001-of-00004.parquet:   0%|          | 0.00/8.10M [00:00<?, ?B/s]

train-00002-of-00004.parquet:   0%|          | 0.00/7.28M [00:00<?, ?B/s]

train-00003-of-00004.parquet:   0%|          | 0.00/7.12M [00:00<?, ?B/s]

train-00000-of-00004.parquet:   0%|          | 0.00/1.84M [00:00<?, ?B/s]

train-00001-of-00004.parquet:   0%|          | 0.00/1.40M [00:00<?, ?B/s]

train-00003-of-00004.parquet:   0%|          | 0.00/1.65M [00:00<?, ?B/s]

train-00002-of-00004.parquet:   0%|          | 0.00/1.15M [00:00<?, ?B/s]

train-00000-of-00004.parquet:   0%|          | 0.00/7.50M [00:00<?, ?B/s]

train-00000-of-00004.parquet:   0%|          | 0.00/17.8M [00:00<?, ?B/s]

train-00000-of-00004.parquet:   0%|          | 0.00/14.6M [00:00<?, ?B/s]

train-00002-of-00004.parquet:   0%|          | 0.00/17.8M [00:00<?, ?B/s]

train-00001-of-00004.parquet:   0%|          | 0.00/14.4M [00:00<?, ?B/s]

train-00001-of-00004.parquet:   0%|          | 0.00/17.2M [00:00<?, ?B/s]

train-00002-of-00004.parquet:   0%|          | 0.00/14.3M [00:00<?, ?B/s]

train-00003-of-00004.parquet:   0%|          | 0.00/14.8M [00:00<?, ?B/s]

train-00000-of-00004.parquet:   0%|          | 0.00/1.14M [00:00<?, ?B/s]

train-00001-of-00004.parquet:   0%|          | 0.00/924k [00:00<?, ?B/s]

train-00000-of-00004.parquet:   0%|          | 0.00/553k [00:00<?, ?B/s]

train-00002-of-00004.parquet:   0%|          | 0.00/965k [00:00<?, ?B/s]

train-00003-of-00004.parquet:   0%|          | 0.00/1.02M [00:00<?, ?B/s]

train-00001-of-00004.parquet:   0%|          | 0.00/983k [00:00<?, ?B/s]

train-00002-of-00004.parquet:   0%|          | 0.00/814k [00:00<?, ?B/s]

train-00003-of-00004.parquet:   0%|          | 0.00/657k [00:00<?, ?B/s]

train-00000-of-00004.parquet:   0%|          | 0.00/109k [00:00<?, ?B/s]

train-00002-of-00004.parquet:   0%|          | 0.00/84.1k [00:00<?, ?B/s]

train-00003-of-00004.parquet:   0%|          | 0.00/108k [00:00<?, ?B/s]

train-00000-of-00004.parquet:   0%|          | 0.00/58.4k [00:00<?, ?B/s]

train-00001-of-00004.parquet:   0%|          | 0.00/214k [00:00<?, ?B/s]

train-00003-of-00004.parquet:   0%|          | 0.00/16.7M [00:00<?, ?B/s]

train-00001-of-00004.parquet:   0%|          | 0.00/32.0k [00:00<?, ?B/s]

train-00002-of-00004.parquet:   0%|          | 0.00/169k [00:00<?, ?B/s]

train-00003-of-00004.parquet:   0%|          | 0.00/58.5k [00:00<?, ?B/s]

train-00000-of-00004.parquet:   0%|          | 0.00/56.2k [00:00<?, ?B/s]

train-00002-of-00004.parquet:   0%|          | 0.00/30.1k [00:00<?, ?B/s]

train-00002-of-00004.parquet:   0%|          | 0.00/29.4k [00:00<?, ?B/s]

train-00001-of-00004.parquet:   0%|          | 0.00/31.2k [00:00<?, ?B/s]

train-00002-of-00004.parquet:   0%|          | 0.00/30.1k [00:00<?, ?B/s]

train-00000-of-00004.parquet:   0%|          | 0.00/58.2k [00:00<?, ?B/s]

train-00003-of-00004.parquet:   0%|          | 0.00/33.0k [00:00<?, ?B/s]

Generating train split:   0%|          | 0/30504 [00:00<?, ? examples/s]

Dataset loaded with 30504 samples
Training samples: 27453, Validation samples: 3051
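The reported sizes are consistent with a 90/10 train/validation split in which the training size is floored and validation takes the remainder. The split code itself is not part of this output, so the sketch below simply assumes the common `int(0.9 * len(dataset))` pattern and reproduces the arithmetic. Sizes like these would typically then be passed to `torch.utils.data.random_split`. The function name `split_sizes` is hypothetical.

```python
def split_sizes(n_samples, train_frac=0.9):
    """Return (train, val) sizes: floor the training share, give the rest to validation."""
    train = int(train_frac * n_samples)  # int() truncates toward zero, e.g. 27453.6 -> 27453
    val = n_samples - train              # validation receives the remainder
    return train, val

print(split_sizes(30504))  # (27453, 3051)
```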
Initializing enhanced BLIP model...
Loading base BLIP model...
Model hidden dimension: 768
Enhanced BLIP model initialized
Starting training...
Starting training for 10 epochs


  scaler = GradScaler(enabled=(device == "cuda"))


Epoch 1/10: 3432 batches

  with autocast(enabled=(device == "cuda")):


Epoch 1 - Avg. Training Loss: 1.5320 (Time: 2770.79s)
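The progress-bar totals, 3432 training batches per epoch and 382 evaluation batches, also pin down the batch size. Assuming a standard PyTorch `DataLoader` with one batch size shared by both splits and the default `drop_last=False`, the number of batches is `ceil(n / batch_size)`. Searching small batch sizes against the 27,453/3,051 split shows that only a batch size of 8 reproduces both totals. The helper `num_batches` is hypothetical, and the notebook's actual `batch_size` setting is not shown in this output.

```python
import math

def num_batches(n_samples, batch_size, drop_last=False):
    """Number of batches a DataLoader yields over a map-style dataset."""
    if drop_last:
        return n_samples // batch_size  # the incomplete final batch is dropped
    return math.ceil(n_samples / batch_size)  # the final partial batch is kept

# Batch sizes consistent with both progress-bar totals.
candidates = [
    bs for bs in range(1, 129)
    if num_batches(27453, bs) == 3432 and num_batches(3051, bs) == 382
]
print(candidates)  # [8]
```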


Evaluating: 382 batches


Sample predictions:
  Reference: the image depicts a bouquet of flowers in full bloom, including six white daisies with yellow centers and five pink lilies with orange centers. the daisies are arranged in a circular pattern around the center of the bouquet, while the lilies are scattered throughout the arrangement. the background is blurred, suggesting a lush garden or park setting.
  Prediction: in the image, a man and a woman are seated at a table in a restaurant. the man is wearing a blue shirt and glasses, while the woman is holding a plate of food in her lap. they are both smiling at the camera.

  Reference: in the vast expanse of the ocean, a person is seen kiteboarding. the individual, clad in a black wetsuit, is skillfully maneuvering a vibrant kite that dances in the sky. the kite, adorned with hues of green and purple, stands out against the backdrop of the cloudy sky. it ' s tethered to the person by a sturdy rope, which they hold firmly in their hands. the person is ridin
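In every epoch, the prediction printed under the bouquet reference describes an unrelated scene (diners at a table, a baseball game, plates of pizza). Either the model has not learned this image, or generated captions and references are misaligned in the evaluation loop; the evaluation code is not shown, so this is an open question. One quick diagnostic, sketched here as an assumption rather than the notebook's method, is the fraction of predicted content words that also occur in the reference. The stop-word list and function names are hypothetical, and the usage example uses shortened excerpts from the log.

```python
import re
from collections import Counter

# Small illustrative stop-word list (an assumption, not taken from the notebook).
STOP_WORDS = {"a", "an", "the", "and", "of", "in", "on", "is", "are", "at",
              "with", "while", "to", "it", "its", "s"}

def content_tokens(text):
    """Lower-case word tokens with punctuation and stop words removed."""
    return [t for t in re.findall(r"[a-z]+", text.lower()) if t not in STOP_WORDS]

def content_precision(reference, prediction):
    """Clipped fraction of predicted content words that also appear in the reference."""
    ref = Counter(content_tokens(reference))
    pred = Counter(content_tokens(prediction))
    total = sum(pred.values())
    if total == 0:
        return 0.0
    return sum(min(count, ref[word]) for word, count in pred.items()) / total

ref = "the image depicts a bouquet of flowers in full bloom"
pred = "in the image, a man and a woman are seated at a table in a restaurant."
print(round(content_precision(ref, pred), 3))  # 0.167 (only "image" overlaps)
```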

Epoch 2/10: 3432 batches

Epoch 2 - Avg. Training Loss: 0.8048 (Time: 2591.35s)


Evaluating: 382 batches


Sample predictions:
  Reference: the image depicts a bouquet of flowers in full bloom, including six white daisies with yellow centers and five pink lilies with orange centers. the daisies are arranged in a circular pattern around the center of the bouquet, while the lilies are scattered throughout the arrangement. the background is blurred, suggesting a lush garden or park setting.
  Prediction: the image depicts a baseball game in progress. a batter wearing a white uniform with red accents is mid - swing, while an umpire in a black uniform stands behind him. the field is a vibrant green, contrasting with the brown dirt around home plate

  Reference: in the vast expanse of the ocean, a person is seen kiteboarding. the individual, clad in a black wetsuit, is skillfully maneuvering a vibrant kite that dances in the sky. the kite, adorned with hues of green and purple, stands out against the backdrop of the cloudy sky. it ' s tethered to the person by a sturdy rope, which they hold fir

Epoch 3/10: 3432 batches

Epoch 3 - Avg. Training Loss: 0.7183 (Time: 2566.41s)


Evaluating: 382 batches


Sample predictions:
  Reference: the image depicts a bouquet of flowers in full bloom, including six white daisies with yellow centers and five pink lilies with orange centers. the daisies are arranged in a circular pattern around the center of the bouquet, while the lilies are scattered throughout the arrangement. the background is blurred, suggesting a lush garden or park setting.
  Prediction: in the image, a group of people are gathered around a table in what appears to be a restaurant or bar. the table is covered with a white tablecloth and features two plates of food - one containing a slice of pizza and the other holding

  Reference: in the vast expanse of the ocean, a person is seen kiteboarding. the individual, clad in a black wetsuit, is skillfully maneuvering a vibrant kite that dances in the sky. the kite, adorned with hues of green and purple, stands out against the backdrop of the cloudy sky. it ' s tethered to the person by a sturdy rope, which they hold firmly in thei

Epoch 4/10: 3432 batches

Epoch 4 - Avg. Training Loss: 0.6599 (Time: 3346.25s)


Evaluating: 382 batches


Sample predictions:
  Reference: the image depicts a bouquet of flowers in full bloom, including six white daisies with yellow centers and five pink lilies with orange centers. the daisies are arranged in a circular pattern around the center of the bouquet, while the lilies are scattered throughout the arrangement. the background is blurred, suggesting a lush garden or park setting.
  Prediction: the image shows a white plate on a wooden table, containing a variety of food items. on the left side of the plate is a piece of meat with a golden - brown crust and a vibrant red sauce drizzled over it. in

  Reference: in the vast expanse of the ocean, a person is seen kiteboarding. the individual, clad in a black wetsuit, is skillfully maneuvering a vibrant kite that dances in the sky. the kite, adorned with hues of green and purple, stands out against the backdrop of the cloudy sky. it ' s tethered to the person by a sturdy rope, which they hold firmly in their hands. the person is riding

Epoch 5/10: 3432 batches

Epoch 5 - Avg. Training Loss: 0.6142 (Time: 3302.33s)


Evaluating: 382 batches


Sample predictions:
  Reference: the image depicts a bouquet of flowers in full bloom, including six white daisies with yellow centers and five pink lilies with orange centers. the daisies are arranged in a circular pattern around the center of the bouquet, while the lilies are scattered throughout the arrangement. the background is blurred, suggesting a lush garden or park setting.
  Prediction: in the image, a group of people are gathered around a table in a restaurant. the table is draped with a white tablecloth and features an array of dishes and drinks, including plates of food, glasses filled with water, and bottles of wine

  Reference: in the vast expanse of the ocean, a person is seen kiteboarding. the individual, clad in a black wetsuit, is skillfully maneuvering a vibrant kite that dances in the sky. the kite, adorned with hues of green and purple, stands out against the backdrop of the cloudy sky. it ' s tethered to the person by a sturdy rope, which they hold firmly in th

Epoch 6/10: 3432 batches

Epoch 6 - Avg. Training Loss: 0.5759 (Time: 3482.97s)


Evaluating: 382 batches


Sample predictions:
  Reference: the image depicts a bouquet of flowers in full bloom, including six white daisies with yellow centers and five pink lilies with orange centers. the daisies are arranged in a circular pattern around the center of the bouquet, while the lilies are scattered throughout the arrangement. the background is blurred, suggesting a lush garden or park setting.
  Prediction: the image depicts a table set with a variety of food and drink items. on the left side, there is a white plate holding a slice of pizza topped with red sauce and melted cheese. in the center of the table, two glasses filled with

  Reference: in the vast expanse of the ocean, a person is seen kiteboarding. the individual, clad in a black wetsuit, is skillfully maneuvering a vibrant kite that dances in the sky. the kite, adorned with hues of green and purple, stands out against the backdrop of the cloudy sky. it ' s tethered to the person by a sturdy rope, which they hold firmly in their hands

Epoch 7/10: 3432 batches

Epoch 7 - Avg. Training Loss: 0.5443 (Time: 3533.00s)


Evaluating: 382 batches


Sample predictions:
  Reference: the image depicts a bouquet of flowers in full bloom, including six white daisies with yellow centers and five pink lilies with orange centers. the daisies are arranged in a circular pattern around the center of the bouquet, while the lilies are scattered throughout the arrangement. the background is blurred, suggesting a lush garden or park setting.
  Prediction: in the image, a person is seated at a table in what appears to be a restaurant or cafe. the table is covered with a white tablecloth and features two plates of pizza - one with red sauce and the other with green sauce. a

  Reference: in the vast expanse of the ocean, a person is seen kiteboarding. the individual, clad in a black wetsuit, is skillfully maneuvering a vibrant kite that dances in the sky. the kite, adorned with hues of green and purple, stands out against the backdrop of the cloudy sky. it ' s tethered to the person by a sturdy rope, which they hold firmly in their hands. the pe

Epoch 8/10: 3432 batches

Epoch 8 - Avg. Training Loss: 0.5198 (Time: 4052.45s)


Evaluating: 382 batches


Sample predictions:
  Reference: the image depicts a bouquet of flowers in full bloom, including six white daisies with yellow centers and five pink lilies with orange centers. the daisies are arranged in a circular pattern around the center of the bouquet, while the lilies are scattered throughout the arrangement. the background is blurred, suggesting a lush garden or park setting.
  Prediction: the image depicts a wooden table set with a variety of food and drink items. on the left side, there is a plate holding a slice of pizza topped with red sauce, melted cheese, and fresh basil leaves. in the center, another plate

  Reference: in the vast expanse of the ocean, a person is seen kiteboarding. the individual, clad in a black wetsuit, is skillfully maneuvering a vibrant kite that dances in the sky. the kite, adorned with hues of green and purple, stands out against the backdrop of the cloudy sky. it ' s tethered to the person by a sturdy rope, which they hold firmly in their hands.

Epoch 9/10: 3432 batches

Epoch 9 - Avg. Training Loss: 0.5028 (Time: 3019.38s)


Evaluating: 382 batches


Sample predictions:
  Reference: the image depicts a bouquet of flowers in full bloom, including six white daisies with yellow centers and five pink lilies with orange centers. the daisies are arranged in a circular pattern around the center of the bouquet, while the lilies are scattered throughout the arrangement. the background is blurred, suggesting a lush garden or park setting.
  Prediction: in the image, a person is seated at a table in what appears to be a restaurant or cafe. the individual is wearing a vibrant red shirt and has their head tilted back as they enjoy a slice of pizza. on the table, there are

  Reference: in the vast expanse of the ocean, a person is seen kiteboarding. the individual, clad in a black wetsuit, is skillfully maneuvering a vibrant kite that dances in the sky. the kite, adorned with hues of green and purple, stands out against the backdrop of the cloudy sky. it ' s tethered to the person by a sturdy rope, which they hold firmly in their hands. the pe

Epoch 10/10: 3432 batches

Epoch 10 - Avg. Training Loss: 0.4932 (Time: 3149.13s)


Evaluating: 382 batches


Sample predictions:
  Reference: the image depicts a bouquet of flowers in full bloom, including six white daisies with yellow centers and five pink lilies with orange centers. the daisies are arranged in a circular pattern around the center of the bouquet, while the lilies are scattered throughout the arrangement. the background is blurred, suggesting a lush garden or park setting.
  Prediction: in the image, a group of people are gathered around a table in a restaurant. the table is covered with a white tablecloth and features two pizzas - one with red sauce and toppings of pepperoni and mushrooms, and the other with

  Reference: in the vast expanse of the ocean, a person is seen kiteboarding. the individual, clad in a black wetsuit, is skillfully maneuvering a vibrant kite that dances in the sky. the kite, adorned with hues of green and purple, stands out against the backdrop of the cloudy sky. it ' s tethered to the person by a sturdy rope, which they hold firmly in their hands. 

Evaluating: 382 batches


Sample predictions:
  Reference: the image depicts a bouquet of flowers in full bloom, including six white daisies with yellow centers and five pink lilies with orange centers. the daisies are arranged in a circular pattern around the center of the bouquet, while the lilies are scattered throughout the arrangement. the background is blurred, suggesting a lush garden or park setting.
  Prediction: in the image, a group of people are gathered around a table in a restaurant. the table is covered with a white tablecloth and features two pizzas - one with red sauce and toppings of pepperoni and mushrooms, and the other with

  Reference: in the vast expanse of the ocean, a person is seen kiteboarding. the individual, clad in a black wetsuit, is skillfully maneuvering a vibrant kite that dances in the sky. the kite, adorned with hues of green and purple, stands out against the backdrop of the cloudy sky. it ' s tethered to the person by a sturdy rope, which they hold firmly in their hands. 
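The per-epoch summaries can be tabulated directly from the log. The sketch below parses the ten `Epoch N - Avg. Training Loss` lines with a regular expression and reports per-epoch wall time, approximate throughput (assuming the 3432 batches per epoch shown by the training progress bar), and total time. The loss falls monotonically from 1.5320 to 0.4932, while epoch times range from about 43 to 68 minutes. The constant `BATCHES_PER_EPOCH` and the variable names are illustrative, not taken from the notebook.

```python
import re

# The ten per-epoch summary lines, copied verbatim from the training output.
LOG = """\
Epoch 1 - Avg. Training Loss: 1.5320 (Time: 2770.79s)
Epoch 2 - Avg. Training Loss: 0.8048 (Time: 2591.35s)
Epoch 3 - Avg. Training Loss: 0.7183 (Time: 2566.41s)
Epoch 4 - Avg. Training Loss: 0.6599 (Time: 3346.25s)
Epoch 5 - Avg. Training Loss: 0.6142 (Time: 3302.33s)
Epoch 6 - Avg. Training Loss: 0.5759 (Time: 3482.97s)
Epoch 7 - Avg. Training Loss: 0.5443 (Time: 3533.00s)
Epoch 8 - Avg. Training Loss: 0.5198 (Time: 4052.45s)
Epoch 9 - Avg. Training Loss: 0.5028 (Time: 3019.38s)
Epoch 10 - Avg. Training Loss: 0.4932 (Time: 3149.13s)
"""

BATCHES_PER_EPOCH = 3432  # total shown by the training progress bar
PATTERN = re.compile(r"Epoch (\d+) - Avg\. Training Loss: ([\d.]+) \(Time: ([\d.]+)s\)")

rows = [(int(e), float(loss), float(t)) for e, loss, t in PATTERN.findall(LOG)]

for epoch, loss, secs in rows:
    print(f"epoch {epoch:>2}  loss {loss:.4f}  {secs / 60:5.1f} min  "
          f"{BATCHES_PER_EPOCH / secs:.2f} it/s")

total_secs = sum(t for _, _, t in rows)
print(f"total training time: {total_secs / 3600:.2f} h")  # total training time: 8.84 h
```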