In [None]:
!pip install transformers
!pip install datasets
!pip install nltk
!pip install pycocoevalcap



In [None]:
!pip install matplotlib pandas tqdm transformers



In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import DataLoader
from torch.cuda.amp import autocast, GradScaler
from torch.nn.utils import clip_grad_norm_

from transformers import BlipForConditionalGeneration, BlipProcessor
from datasets import load_dataset
from nltk.translate.bleu_score import corpus_bleu
from pycocoevalcap.cider.cider import Cider
from pycocoevalcap.spice.spice import Spice
from pycocoevalcap.rouge.rouge import Rouge
from pycocoevalcap.meteor.meteor import Meteor

import math
import os
import time
import numpy as np
from tqdm.auto import tqdm

# 1. Token Gating Implementation
class TokenGating(nn.Module):
    """
    Per-token soft gate.

    A small two-layer MLP scores every token; the sigmoid of that score,
    doubled so it lies in (0, 2), rescales the token's hidden vector.
    """
    def __init__(self, hidden_dim, dropout=0.1):
        super().__init__()
        bottleneck = hidden_dim // 2
        # Layers are built in the same order as before so parameter
        # initialisation and state_dict keys stay unchanged.
        scorer_layers = [
            nn.Linear(hidden_dim, bottleneck),
            nn.GELU(),
            nn.Dropout(dropout),
            nn.Linear(bottleneck, 1),
        ]
        self.gate_transform = nn.Sequential(*scorer_layers)
        self.sigmoid = nn.Sigmoid()

    def forward(self, hidden_states, attention_mask=None):
        """
        Gate ``hidden_states`` token by token.

        Returns a tuple ``(gated_output, gate_scores)``; ``gate_scores`` has a
        trailing dimension of 1 and broadcasts over the feature dimension.
        When ``attention_mask`` is given, it is unsqueezed on the last axis and
        multiplied into the gated output (the scores themselves are not masked).
        """
        # One raw logit per token -> squashed and stretched to (0, 2)
        token_logits = self.gate_transform(hidden_states)
        gate_scores = 2.0 * self.sigmoid(token_logits)

        # Rescale every feature of a token by that token's gate value
        gated_output = gate_scores * hidden_states

        if attention_mask is None:
            return gated_output, gate_scores

        # Zero out positions the mask marks with 0
        return gated_output * attention_mask.unsqueeze(-1), gate_scores

# 2. Sparse Attention Implementation
class SparseAttention(nn.Module):
    """
    Multi-head self-attention with top-k sparsification.

    In training mode every query keeps only its ``k`` highest-scoring keys,
    where ``k = max(1, floor((1 - sparsity) * seq_len))``; all other scores
    are pushed to the most negative representable value before the softmax.
    In eval mode the attention is dense.

    Args:
        hidden_dim: Size of the input/output feature dimension.
        num_heads: Number of attention heads; must divide ``hidden_dim``.
        dropout: Dropout probability applied to the attention weights.
        sparsity: Fraction of attention connections to prune per query.
    """
    def __init__(self, hidden_dim, num_heads=8, dropout=0.1, sparsity=0.8):
        super().__init__()
        self.hidden_dim = hidden_dim
        self.num_heads = num_heads
        self.head_dim = hidden_dim // num_heads
        assert self.head_dim * num_heads == hidden_dim, "hidden_dim must be divisible by num_heads"

        self.sparsity = sparsity  # Percent of attention connections to prune

        self.q_proj = nn.Linear(hidden_dim, hidden_dim)
        self.k_proj = nn.Linear(hidden_dim, hidden_dim)
        self.v_proj = nn.Linear(hidden_dim, hidden_dim)
        self.out_proj = nn.Linear(hidden_dim, hidden_dim)

        self.dropout = nn.Dropout(dropout)

    def _num_kept(self, seq_len):
        """Return how many keys each query keeps for a sequence of ``seq_len``."""
        # The small epsilon absorbs binary floating-point error: without it
        # (1 - 0.8) * 10 evaluates to 1.9999999999999996 and floors to 1.
        return max(1, int((1 - self.sparsity) * seq_len + 1e-6))

    def _split_heads(self, projected, batch_size, seq_len):
        """Reshape [batch, seq, hidden] -> [batch, heads, seq, head_dim]."""
        return projected.reshape(batch_size, seq_len, self.num_heads, self.head_dim).transpose(1, 2)

    def forward(self, hidden_states, attention_mask=None):
        """
        Apply (sparse) self-attention.

        Args:
            hidden_states: Tensor unpacked as (batch_size, seq_len, hidden).
            attention_mask: Optional tensor indexed as (batch_size, seq_len);
                key positions equal to 0 are excluded from attention.

        Returns:
            Tensor of shape (batch_size, seq_len, hidden_dim).
        """
        batch_size, seq_len, _ = hidden_states.shape

        q = self._split_heads(self.q_proj(hidden_states), batch_size, seq_len)
        k = self._split_heads(self.k_proj(hidden_states), batch_size, seq_len)
        v = self._split_heads(self.v_proj(hidden_states), batch_size, seq_len)

        # Scaled dot-product scores: [batch, heads, seq, seq]
        attn_scores = torch.matmul(q, k.transpose(-2, -1)) / (self.head_dim ** 0.5)

        # Use the dtype's own minimum as the "masked" value. A hard-coded
        # -1e10 is not representable in float16 and breaks under autocast.
        neg_fill = torch.finfo(attn_scores.dtype).min

        if attention_mask is not None:
            # Broadcast the key mask over heads and queries: [batch, 1, 1, seq]
            masked_keys = (attention_mask == 0).unsqueeze(1).unsqueeze(2)
            attn_scores = attn_scores.masked_fill(masked_keys, neg_fill)

        if self.training:
            # The k-th largest score of each query row is the keep threshold.
            top_scores, _ = torch.topk(attn_scores, k=self._num_kept(seq_len), dim=-1)
            threshold = top_scores[..., -1:]
            # masked_fill (instead of multiply-and-add) never computes
            # 0 * inf, so pruned entries cannot turn into NaN.
            attn_scores = attn_scores.masked_fill(attn_scores < threshold, neg_fill)

        attn_weights = F.softmax(attn_scores, dim=-1)
        attn_weights = self.dropout(attn_weights)

        # Weighted sum of values, then merge heads back to [batch, seq, hidden]
        context = torch.matmul(attn_weights, v)
        context = context.transpose(1, 2).reshape(batch_size, seq_len, self.hidden_dim)

        return self.out_proj(context)

# 3. Enhanced BLIP Model with Token Gating and Sparse Attention
class EnhancedBLIP(nn.Module):
    """
    Wrapper around a pretrained BLIP captioning model that also owns token
    gating / sparse attention sub-modules for text and vision features.

    NOTE(review): as written, ``forward`` and ``generate`` delegate directly
    to the base model; ``_enhance_text_features`` and
    ``_enhance_vision_features`` are defined but not called anywhere in this
    class, so the extra sub-modules do not influence the loss or the
    generated captions. Confirm whether wiring them in is still pending.
    """
    def __init__(self, sparsity=0.8):
        """
        Load the base model/processor and build the enhancement sub-modules.

        Args:
            sparsity: Passed to both SparseAttention layers.
        """
        super(EnhancedBLIP, self).__init__()

        # Load base model and its matching processor (tokenizer + image preprocessing)
        print("Loading base BLIP model...")
        self.base_model = BlipForConditionalGeneration.from_pretrained("Salesforce/blip-image-captioning-base")
        self.processor = BlipProcessor.from_pretrained("Salesforce/blip-image-captioning-base")

        # Hidden size is read from the text decoder config and reused for every
        # sub-module below, including the vision ones.
        # NOTE(review): assumes vision features have the same width as the
        # text decoder — confirm against the vision config.
        hidden_dim = self.base_model.text_decoder.config.hidden_size
        print(f"Model hidden dimension: {hidden_dim}")

        # NOTE(review): attribute names and creation order below determine the
        # state_dict keys and parameter order; checkpoints written by
        # train_model store both model and optimizer state, so renaming or
        # reordering these presumably breaks resuming — verify before changing.

        # Add token gating layers
        self.text_gate = TokenGating(hidden_dim)
        self.vision_gate = TokenGating(hidden_dim)

        # Add sparse attention layers
        self.text_sparse_attn = SparseAttention(hidden_dim, sparsity=sparsity)
        self.vision_sparse_attn = SparseAttention(hidden_dim, sparsity=sparsity)

        # Layer norms applied after each residual connection
        self.text_ln1 = nn.LayerNorm(hidden_dim)
        self.text_ln2 = nn.LayerNorm(hidden_dim)
        self.vision_ln1 = nn.LayerNorm(hidden_dim)
        self.vision_ln2 = nn.LayerNorm(hidden_dim)

        # Feed-forward networks (4x expansion) for the final residual path
        self.text_ffn = nn.Sequential(
            nn.Linear(hidden_dim, hidden_dim * 4),
            nn.GELU(),
            nn.Dropout(0.1),
            nn.Linear(hidden_dim * 4, hidden_dim)
        )

        self.vision_ffn = nn.Sequential(
            nn.Linear(hidden_dim, hidden_dim * 4),
            nn.GELU(),
            nn.Dropout(0.1),
            nn.Linear(hidden_dim * 4, hidden_dim)
        )

        # Flag checked in forward(); both of its branches currently return the
        # same base-model outputs.
        self.apply_enhancements = True
        print("Enhanced BLIP model initialized")

    def _enhance_text_features(self, hidden_states, attention_mask=None):
        """
        Apply gating -> sparse attention -> feed-forward to text features.

        Each of the first two stages is followed by a residual add and a
        LayerNorm; the feed-forward output is added to the *original*
        ``hidden_states``. Returns a tensor of the same shape as the input.
        """
        # Token gating, residual with the ungated input, then norm
        gated_states, _ = self.text_gate(hidden_states, attention_mask)
        gated_states = self.text_ln1(gated_states + hidden_states)  # Residual connection

        # Sparse self-attention over the gated states, residual, then norm
        sparse_states = self.text_sparse_attn(gated_states, attention_mask)
        sparse_states = self.text_ln2(sparse_states + gated_states)  # Residual connection

        # Feed-forward branch added onto the untouched input
        output = hidden_states + self.text_ffn(sparse_states)

        return output

    def _enhance_vision_features(self, hidden_states):
        """
        Vision counterpart of ``_enhance_text_features``.

        Identical pipeline, but no attention mask is passed to the gate or the
        sparse attention. Returns a tensor of the same shape as the input.
        """
        # Token gating (no mask), residual, then norm
        gated_states, _ = self.vision_gate(hidden_states, None)
        gated_states = self.vision_ln1(gated_states + hidden_states)  # Residual connection

        # Sparse self-attention (no mask), residual, then norm
        sparse_states = self.vision_sparse_attn(gated_states, None)
        sparse_states = self.vision_ln2(sparse_states + gated_states)  # Residual connection

        # Feed-forward branch added onto the untouched input
        output = hidden_states + self.vision_ffn(sparse_states)

        return output

    def forward(self, pixel_values=None, input_ids=None, attention_mask=None, labels=None, return_dict=True):
        """
        Run the base BLIP model and return its outputs unchanged.

        ``return_dict`` is accepted for interface compatibility but is not
        forwarded: the base model is always called with ``return_dict=True``
        and ``output_hidden_states=True``. The returned object is whatever the
        base model returns (the training loop reads ``.loss`` from it).
        """
        # Single pass through the base model; loss is computed there when labels are given
        outputs = self.base_model(
            pixel_values=pixel_values,
            input_ids=input_ids,
            attention_mask=attention_mask,
            labels=labels,
            output_hidden_states=True,
            return_dict=True
        )

        # Return base model outputs if enhancements are disabled
        if not self.apply_enhancements:
            return outputs

        # NOTE(review): no enhancement is applied on this path either — the
        # gating / sparse-attention helpers are never invoked here, so the
        # result is identical to the disabled branch above.
        return outputs

    def generate(self, pixel_values=None, input_ids=None, attention_mask=None, **kwargs):
        """
        Generate caption token ids by delegating to the base model's
        ``generate``; extra keyword arguments are passed through untouched.
        """
        return self.base_model.generate(
            pixel_values=pixel_values,
            input_ids=input_ids,
            attention_mask=attention_mask,
            **kwargs
        )

# 4. Data Processing Functions
def closest_factors(n):
    """
    Return the factor pair ``(h, w)`` of ``n`` that is closest to a square.

    ``h`` is the largest divisor of ``n`` not exceeding ``sqrt(n)`` and
    ``w = n // h``, so ``h * w == n`` and ``h <= w`` always hold
    (a prime ``n`` yields ``(1, n)``).

    Args:
        n: Positive integer number of elements to arrange in a grid.

    Returns:
        Tuple ``(h, w)`` of ints.

    Raises:
        ValueError: If ``n`` is smaller than 1.
    """
    if n < 1:
        raise ValueError(f"closest_factors expects a positive integer, got {n!r}")

    # math.isqrt is exact for arbitrarily large ints (no float rounding).
    height = math.isqrt(n)
    # 1 divides every n, so this walk always terminates with an exact factor;
    # no power-of-two fallback is needed.
    while n % height:
        height -= 1
    return height, n // height

# Lazily-created BLIP processor shared by every collate_fn call in this
# process (each DataLoader worker process ends up with its own copy).
_COLLATE_PROCESSOR = None


def _get_collate_processor():
    """Return the cached BLIP processor, loading it on first use only."""
    global _COLLATE_PROCESSOR
    if _COLLATE_PROCESSOR is None:
        _COLLATE_PROCESSOR = BlipProcessor.from_pretrained("Salesforce/blip-image-captioning-base")
    return _COLLATE_PROCESSOR


def _build_batch_encoding(captions, images):
    """Tokenize ``captions`` and pack them with ``images`` into the model input dict."""
    encoded = _get_collate_processor()(text=captions, padding="max_length", truncation=True,
                                       max_length=128, return_tensors="pt")
    return {
        "pixel_values": images.to(torch.float32),
        "input_ids": encoded["input_ids"],
        "attention_mask": encoded["attention_mask"],
        # NOTE(review): labels are an unmasked copy of input_ids, so padded
        # positions are presumably part of the loss — confirm this is intended.
        "labels": encoded["input_ids"].clone(),
    }


def _latents_to_images(normalized_latents, batch_size):
    """Reshape (batch, feature_dim) latents into 3-channel (batch, 3, H, W) images."""
    feature_dim = normalized_latents.shape[1]

    # closest_factors always returns an exact factor pair, so no padding or
    # truncation of the feature dimension is required before the view.
    height, width = closest_factors(feature_dim)

    try:
        images = normalized_latents.view(batch_size, 1, height, width)
        return images.repeat(1, 3, 1, 1)  # Repeat along channel dimension for RGB
    except RuntimeError as e:
        print(f"Error reshaping latents: {e}")
        print(f"Using fallback reshaping method")

    # Fallback: pad/truncate every latent to the next square and reshape
    side = int(math.ceil(math.sqrt(feature_dim)))
    images = torch.zeros((batch_size, 3, side, side), dtype=torch.float32)
    for i, latent in enumerate(normalized_latents):
        if latent.shape[0] < side * side:
            latent = torch.cat([latent, torch.zeros(side * side - latent.shape[0])])
        else:
            latent = latent[:side * side]
        images[i] = latent.view(1, side, side).repeat(3, 1, 1)
    return images


def collate_fn(batch):
    """
    Collate function for the VAE-latents caption dataset.

    Each item must provide ``item["caption"]`` and ``item["vae_latent"]``.
    Latents are zero-padded to the longest one in the batch, z-score
    normalised per sample, reshaped to a near-square single-channel image and
    repeated to 3 channels. Captions are tokenized to a fixed length of 128.

    Returns:
        Dict with ``pixel_values``, ``input_ids``, ``attention_mask`` and
        ``labels`` tensors.

    If anything fails, the error is printed and a placeholder batch
    (zero images of size 32x32, caption "dummy caption") is returned so the
    training loop keeps running.
    """
    try:
        captions = [item["caption"] for item in batch]
        vae_latents = [torch.tensor(item["vae_latent"], dtype=torch.float32) for item in batch]

        # Zero-pad every latent up to the longest one in the batch
        max_len = max(latent.shape[0] for latent in vae_latents)
        padded_latents = []
        for latent in vae_latents:
            pad_size = max_len - latent.shape[0]
            if pad_size > 0:
                latent = torch.cat([latent, torch.zeros(pad_size, dtype=torch.float32)])
            padded_latents.append(latent)
        latents = torch.stack(padded_latents)

        # Per-sample z-score normalisation (epsilon avoids division by zero)
        means = latents.mean(dim=1, keepdim=True)
        stds = latents.std(dim=1, keepdim=True) + 1e-6
        normalized_latents = (latents - means) / stds

        images = _latents_to_images(normalized_latents, len(batch))
        return _build_batch_encoding(captions, images)

    except Exception as e:
        print(f"Error in collate_fn: {e}")
        # Deliberate best-effort: return a minimal valid batch to avoid
        # aborting training. NOTE(review): such batches carry no real data;
        # watch the log for this message.
        dummy_captions = ["dummy caption"] * len(batch)
        dummy_images = torch.zeros((len(batch), 3, 32, 32), dtype=torch.float32)
        return _build_batch_encoding(dummy_captions, dummy_images)

# 5. Training Function with Mixed Precision and Gradient Accumulation
def train_model(model, train_loader, val_loader=None, num_epochs=20,
                lr=2e-5, device="cuda", checkpoint_dir="checkpoints"):
    """
    Train ``model`` with mixed precision, gradient accumulation, cosine LR
    decay, per-epoch checkpointing and early stopping.

    Args:
        model: Module whose forward accepts the batch dict as keyword
            arguments and returns an object with a ``.loss`` attribute.
        train_loader: Iterable of dict batches mapping names to tensors.
        val_loader: Optional validation loader; when given, the early-stopping
            score is CIDEr + BLEU-4 from ``evaluate_model``, otherwise the
            negative average training loss.
        num_epochs: Last epoch number to run (epochs are 1-based).
        lr: AdamW learning rate.
        device: ``str`` or ``torch.device``; mixed precision is enabled when
            its type is ``"cuda"``.
        checkpoint_dir: Directory for ``latest_model.pth`` (resume state) and
            ``best_model.pth`` (best-scoring weights).

    Returns:
        The trained ``model`` (same object that was passed in).
    """
    print(f"Starting training for {num_epochs} epochs")

    os.makedirs(checkpoint_dir, exist_ok=True)

    # `device` may arrive as a torch.device; comparing that to the string
    # "cuda" is not reliable, so normalise and inspect the device type.
    use_amp = torch.device(device).type == "cuda"

    optimizer = optim.AdamW(model.parameters(), lr=lr, weight_decay=0.01)
    scheduler = optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=num_epochs)
    scaler = GradScaler(enabled=use_amp)

    # Effective batch size = loader batch size * accum_steps
    accum_steps = 4

    # Early stopping state
    patience = 3
    best_score = float('-inf')
    no_improve_epochs = 0

    checkpoint_path = os.path.join(checkpoint_dir, "latest_model.pth")
    best_model_path = os.path.join(checkpoint_dir, "best_model.pth")
    start_epoch = 1

    # Resume from checkpoint if available
    if os.path.exists(checkpoint_path):
        try:
            print(f"Loading checkpoint from {checkpoint_path}")
            checkpoint = torch.load(checkpoint_path, map_location=device)
            model.load_state_dict(checkpoint["model_state"])
            optimizer.load_state_dict(checkpoint["optimizer_state"])

            # Older checkpoints lack these keys; only restore what is there.
            if "scheduler_state" in checkpoint:
                scheduler.load_state_dict(checkpoint["scheduler_state"])
            # An empty dict means it was saved from a disabled scaler.
            if checkpoint.get("scaler_state"):
                scaler.load_state_dict(checkpoint["scaler_state"])

            start_epoch = checkpoint["epoch"] + 1
            best_score = checkpoint.get("best_score", float('-inf'))
            no_improve_epochs = checkpoint.get("no_improve_epochs", 0)

            # Move optimizer states to right device
            for state in optimizer.state.values():
                for k, v in state.items():
                    if isinstance(v, torch.Tensor):
                        state[k] = v.to(device)

            print(f"Resuming training from epoch {start_epoch}")
        except Exception as e:
            print(f"Error loading checkpoint: {e}")
            print("Starting training from scratch")

    for epoch in range(start_epoch, num_epochs + 1):
        model.train()
        # Drop any partial gradients left over from a failed final batch
        optimizer.zero_grad()
        total_loss = 0.0
        completed_batches = 0
        start_time = time.time()

        progress_bar = tqdm(train_loader, desc=f"Epoch {epoch}/{num_epochs}")

        for batch_idx, batch in enumerate(progress_bar):
            try:
                batch = {k: v.to(device) for k, v in batch.items()}

                with autocast(enabled=use_amp):
                    outputs = model(**batch)
                    loss = outputs.loss / accum_steps  # Scale loss for accumulation

                scaler.scale(loss).backward()

                # Step after every `accum_steps` batches and on the last batch
                if (batch_idx + 1) % accum_steps == 0 or (batch_idx + 1 == len(train_loader)):
                    # Gradients must be unscaled before clipping their norm
                    scaler.unscale_(optimizer)
                    clip_grad_norm_(model.parameters(), max_norm=1.0)

                    scaler.step(optimizer)
                    scaler.update()
                    optimizer.zero_grad()

                batch_loss = loss.item() * accum_steps
                total_loss += batch_loss
                completed_batches += 1

                progress_bar.set_postfix({"loss": f"{batch_loss:.4f}"})

            except Exception as e:
                # Best-effort: skip the offending batch and keep training
                print(f"Error in batch {batch_idx}: {e}")
                continue

        # Average over batches that actually contributed a loss
        avg_loss = total_loss / max(1, completed_batches)
        train_time = time.time() - start_time

        print(f"Epoch {epoch} - Avg. Training Loss: {avg_loss:.4f} (Time: {train_time:.2f}s)")

        if val_loader is not None:
            val_metrics = evaluate_model(model, val_loader, device=device)

            print(f"Validation BLEU-4: {val_metrics['bleu4']:.4f}")
            print(f"Validation CIDEr: {val_metrics['cider']:.4f}")
            print(f"Validation SPICE: {val_metrics['spice']:.4f}")
            print(f"Validation ROUGE: {val_metrics['rouge']:.4f}")
            if 'meteor' in val_metrics and val_metrics['meteor'] > 0:
                print(f"Validation METEOR: {val_metrics['meteor']:.4f}")

            # Use CIDEr + BLEU-4 as overall score for early stopping
            current_score = val_metrics['cider'] + val_metrics['bleu4']
        else:
            # If no validation set, use negative training loss as score
            current_score = -avg_loss

        scheduler.step()

        if current_score > best_score:
            best_score = current_score
            no_improve_epochs = 0

            torch.save({
                "epoch": epoch,
                "model_state": model.state_dict(),
                "best_score": best_score
            }, best_model_path)

            print(f"New best model saved at epoch {epoch}")
        else:
            no_improve_epochs += 1

        # Save the resume checkpoint *after* the bookkeeping above so that
        # best_score / no_improve_epochs reflect this epoch, and include the
        # scheduler and scaler so a resumed run continues the same schedule.
        torch.save({
            "epoch": epoch,
            "model_state": model.state_dict(),
            "optimizer_state": optimizer.state_dict(),
            "scheduler_state": scheduler.state_dict(),
            "scaler_state": scaler.state_dict(),
            "best_score": best_score,
            "no_improve_epochs": no_improve_epochs
        }, checkpoint_path)

        if no_improve_epochs >= patience:
            print(f"No improvement for {patience} epochs. Early stopping.")
            break

    print("Training completed!")
    return model

# 6. Evaluation Function
def _compute_caption_metric(label, scorer_cls, refs, preds):
    """Run one pycocoevalcap-style scorer; log and return 0.0 if it fails."""
    try:
        return scorer_cls().compute_score(refs, preds)[0]
    except Exception as e:
        print(f"Error computing {label}: {e}")
        return 0.0


def evaluate_model(model, val_loader, device="cuda", max_samples=None):
    """
    Generate captions for ``val_loader`` and score them against the references.

    Args:
        model: Object exposing ``eval()``, ``generate(pixel_values=..., ...)``
            and a ``processor`` with ``batch_decode``.
        val_loader: Iterable of dict batches containing ``pixel_values`` and
            ``labels`` (token ids of the reference caption).
        device: Device the batch tensors are moved to.
        max_samples: Optional cap; evaluation stops once at least this many
            captions have been generated.

    Returns:
        Dict with float scores under ``bleu4``, ``cider``, ``spice``,
        ``rouge`` and ``meteor``. A metric whose scorer fails is reported as
        0.0 without affecting the others.
    """
    model.eval()
    predictions = []
    references = []

    sample_count = 0

    for batch_idx, batch in enumerate(tqdm(val_loader, desc="Evaluating")):
        batch = {k: v.to(device) for k, v in batch.items()}

        with torch.no_grad():
            try:
                # Beam-search generation from the image features only
                output_ids = model.generate(
                    pixel_values=batch["pixel_values"],
                    max_length=50,
                    num_beams=5,
                    length_penalty=1.0,
                    no_repeat_ngram_size=2
                )

                pred_captions = model.processor.batch_decode(output_ids, skip_special_tokens=True)
                ref_captions = model.processor.batch_decode(batch["labels"], skip_special_tokens=True)

                predictions.extend(pred_captions)
                references.extend(ref_captions)

                sample_count += len(pred_captions)

                # Print sample predictions (first batch only)
                if batch_idx == 0:
                    print("\nSample predictions:")
                    for i in range(min(3, len(pred_captions))):
                        print(f"  Reference: {ref_captions[i]}")
                        print(f"  Prediction: {pred_captions[i]}")
                        print()

                if max_samples is not None and sample_count >= max_samples:
                    break

            except Exception as e:
                print(f"Error generating captions for batch {batch_idx}: {e}")
                continue

    if not predictions:
        # Nothing to score (empty loader or every batch failed)
        print("No captions were generated; returning zero metrics")
        return {"bleu4": 0.0, "cider": 0.0, "spice": 0.0, "rouge": 0.0, "meteor": 0.0}

    # corpus_bleu works on token sequences: one list of tokenized references
    # per example and one tokenized hypothesis per example. Passing raw
    # strings would make it count character n-grams instead of word n-grams.
    bleu_refs = [[ref.split()] for ref in references]
    bleu_hyps = [pred.split() for pred in predictions]
    bleu4 = corpus_bleu(bleu_refs, bleu_hyps, weights=(0.25, 0.25, 0.25, 0.25))

    # The remaining scorers take {example_id: [caption, ...]} dicts
    metric_refs = {i: [ref] for i, ref in enumerate(references)}
    metric_preds = {i: [pred] for i, pred in enumerate(predictions)}

    # Each scorer is isolated so that one failing backend does not wipe out
    # the scores of the others.
    cider_score = _compute_caption_metric("CIDEr", Cider, metric_refs, metric_preds)
    spice_score = _compute_caption_metric("SPICE", Spice, metric_refs, metric_preds)
    rouge_score = _compute_caption_metric("ROUGE", Rouge, metric_refs, metric_preds)

    try:
        meteor_score = Meteor().compute_score(metric_refs, metric_preds)[0]
        print(f"METEOR score computed successfully: {meteor_score:.4f}")
    except Exception as e:
        print(f"Error computing METEOR: {e}")
        meteor_score = 0.0

    return {
        "bleu4": bleu4,
        "cider": cider_score,
        "spice": spice_score,
        "rouge": rouge_score,
        "meteor": meteor_score
    }

# 7. Main function to run the training pipeline
def main():
    """
    Run the whole pipeline: seed, load and split the dataset, build the data
    loaders and the model, train, then print final validation metrics.

    Any exception is logged and then re-raised so that a failed run is
    visible to the caller (non-zero exit status / notebook traceback)
    instead of being reduced to a single printed line.
    """
    # Set random seed for reproducibility
    torch.manual_seed(42)
    if torch.cuda.is_available():
        torch.cuda.manual_seed_all(42)

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print(f"Using device: {device}")

    print("Loading dataset...")
    try:
        dataset = load_dataset("SwayStar123/preprocessed_recap-coco30k-moondream")['train']
        print(f"Dataset loaded with {len(dataset)} samples")

        # First 90% of the rows for training, the remainder for validation
        split_ratio = 0.9
        train_size = int(split_ratio * len(dataset))
        train_ds = dataset.select(range(train_size))
        val_ds = dataset.select(range(train_size, len(dataset)))
        print(f"Training samples: {len(train_ds)}, Validation samples: {len(val_ds)}")

        # Small batch size chosen for an L4 GPU; worker processes load in
        # parallel and pinned memory speeds up host-to-GPU copies.
        train_loader = DataLoader(
            train_ds,
            batch_size=8,
            shuffle=True,
            collate_fn=collate_fn,
            num_workers=2,
            pin_memory=True
        )

        val_loader = DataLoader(
            val_ds,
            batch_size=8,
            shuffle=False,
            collate_fn=collate_fn,
            num_workers=2,
            pin_memory=True
        )

        print("Initializing enhanced BLIP model...")
        model = EnhancedBLIP(sparsity=0.7)  # Adjust sparsity as needed
        model.to(device)

        print("Starting training...")
        train_model(
            model=model,
            train_loader=train_loader,
            val_loader=val_loader,
            num_epochs=20,
            lr=2e-5,
            device=device,
            checkpoint_dir="token_gating_checkpoints"
        )

        print("Performing final evaluation...")
        final_metrics = evaluate_model(model, val_loader, device=device)

        # (display label, key in the metrics dict) for the always-shown scores
        reported = (("BLEU-4", "bleu4"), ("CIDEr", "cider"), ("SPICE", "spice"), ("ROUGE", "rouge"))
        # METEOR is shown only when it was actually computed (non-zero)
        show_meteor = final_metrics.get('meteor', 0) > 0

        print("Final Evaluation Results:")
        for label, key in reported:
            print(f"{label}: {final_metrics[key]:.4f}")
        if show_meteor:
            print(f"METEOR: {final_metrics['meteor']:.4f}")

        # Same numbers scaled by 100 for easier comparison with papers
        print("\nFinal Evaluation Results (scaled by 100):")
        for label, key in reported:
            print(f"{label}: {final_metrics[key] * 100:.2f}")
        if show_meteor:
            print(f"METEOR: {final_metrics['meteor'] * 100:.2f}")

    except Exception as e:
        print(f"Error in main function: {e}")
        # Re-raise so the full traceback is shown and the failure is not
        # mistaken for a successful run.
        raise

# Run the training pipeline only when this file is executed directly
# (not when it is imported as a module).
if __name__ == "__main__":
    main()

Using device: cuda
Loading dataset...


Resolving data files:   0%|          | 0/96 [00:00<?, ?it/s]

Dataset loaded with 30504 samples
Training samples: 27453, Validation samples: 3051
Initializing enhanced BLIP model...
Loading base BLIP model...
Model hidden dimension: 768
Enhanced BLIP model initialized
Starting training...
Starting training for 20 epochs
Loading checkpoint from token_gating_checkpoints/latest_model.pth


  scaler = GradScaler(enabled=(device == "cuda"))


Resuming training from epoch 2


Epoch 2/20:   0%|          | 0/3432 [00:00<?, ?it/s]

  with autocast(enabled=(device == "cuda")):


Epoch 2 - Avg. Training Loss: 0.7994 (Time: 351.58s)


Evaluating:   0%|          | 0/382 [00:00<?, ?it/s]


Sample predictions:
  Reference: the image depicts a bouquet of flowers in full bloom, including six white daisies with yellow centers and five pink lilies with orange centers. the daisies are arranged in a circular pattern around the center of the bouquet, while the lilies are scattered throughout the arrangement. the background is blurred, suggesting a lush garden or park setting.
  Prediction: in the image, two men are seated at a table in a restaurant. the man on the left is wearing a blue shirt and has his arm around the other man ' s neck. they are both focused on their laptops, which are open

  Reference: in the vast expanse of the ocean, a person is seen kiteboarding. the individual, clad in a black wetsuit, is skillfully maneuvering a vibrant kite that dances in the sky. the kite, adorned with hues of green and purple, stands out against the backdrop of the cloudy sky. it ' s tethered to the person by a sturdy rope, which they hold firmly in their hands. the person is riding

Epoch 3/20:   0%|          | 0/3432 [00:00<?, ?it/s]

  with autocast(enabled=(device == "cuda")):
Exception ignored in: <function _MultiProcessingDataLoaderIter.__del__ at 0x7dbc9fff0e00>
Traceback (most recent call last):
  File "/usr/local/lib/python3.11/dist-packages/torch/utils/data/dataloader.py", line 1618, in __del__
    self._shutdown_workers()
  File "/usr/local/lib/python3.11/dist-packages/torch/utils/data/dataloader.py", line 1601, in _shutdown_workers
    if w.is_alive():
       ^^^^^^^^^^^^
  File "/usr/lib/python3.11/multiprocessing/process.py", line 160, in is_alive
    assert self._parent_pid == os.getpid(), 'can only test a child process'
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
AssertionError: can only test a child process
Exception ignored in: <function _MultiProcessingDataLoaderIter.__del__ at 0x7dbc9fff0e00>
Traceback (most recent call last):
  File "/usr/local/lib/python3.11/dist-packages/torch/utils/data/dataloader.py", line 1618, in __del__
    self._shutdown_workers()
  File "/usr/local/lib/python3.11/dist-pack

Epoch 3 - Avg. Training Loss: 0.7138 (Time: 381.62s)


Evaluating:   0%|          | 0/382 [00:00<?, ?it/s]


Sample predictions:
  Reference: the image depicts a bouquet of flowers in full bloom, including six white daisies with yellow centers and five pink lilies with orange centers. the daisies are arranged in a circular pattern around the center of the bouquet, while the lilies are scattered throughout the arrangement. the background is blurred, suggesting a lush garden or park setting.
  Prediction: the image depicts a baseball game in progress. a batter wearing a white uniform with red accents is mid - swing, while an umpire in a black uniform stands behind him. the field is a vibrant green, contrasting with the brown dirt around home plate

  Reference: in the vast expanse of the ocean, a person is seen kiteboarding. the individual, clad in a black wetsuit, is skillfully maneuvering a vibrant kite that dances in the sky. the kite, adorned with hues of green and purple, stands out against the backdrop of the cloudy sky. it ' s tethered to the person by a sturdy rope, which they hold fir

Epoch 4/20:   0%|          | 0/3432 [00:00<?, ?it/s]

  with autocast(enabled=(device == "cuda")):
Exception ignored in: <function _MultiProcessingDataLoaderIter.__del__ at 0x7dbc9fff0e00>
Traceback (most recent call last):
  File "/usr/local/lib/python3.11/dist-packages/torch/utils/data/dataloader.py", line 1618, in __del__
    self._shutdown_workers()
  File "/usr/local/lib/python3.11/dist-packages/torch/utils/data/dataloader.py", line 1601, in _shutdown_workers
    if w.is_alive():
       ^^^^^^^^^^^^
  File "/usr/lib/python3.11/multiprocessing/process.py", line 160, in is_alive
    assert self._parent_pid == os.getpid(), 'can only test a child process'
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
AssertionError: can only test a child process
Exception ignored in: <function _MultiProcessingDataLoaderIter.__del__ at 0x7dbc9fff0e00>
Traceback (most recent call last):
  File "/usr/local/lib/python3.11/dist-packages/torch/utils/data/dataloader.py", line 1618, in __del__
    self._shutdown_workers()
  File "/usr/local/lib/python3.11/dist-pack

Epoch 4 - Avg. Training Loss: 0.6544 (Time: 390.25s)


Evaluating:   0%|          | 0/382 [00:00<?, ?it/s]


Sample predictions:
  Reference: the image depicts a bouquet of flowers in full bloom, including six white daisies with yellow centers and five pink lilies with orange centers. the daisies are arranged in a circular pattern around the center of the bouquet, while the lilies are scattered throughout the arrangement. the background is blurred, suggesting a lush garden or park setting.
  Prediction: a tennis player dressed in a white shirt and black shorts is captured mid - swing on a vibrant green tennis court. the player ' s right hand grips the racket, poised to strike an unseen ball that hangs in the air just above their head

  Reference: in the vast expanse of the ocean, a person is seen kiteboarding. the individual, clad in a black wetsuit, is skillfully maneuvering a vibrant kite that dances in the sky. the kite, adorned with hues of green and purple, stands out against the backdrop of the cloudy sky. it ' s tethered to the person by a sturdy rope, which they hold firmly in their

Epoch 5/20:   0%|          | 0/3432 [00:00<?, ?it/s]

  with autocast(enabled=(device == "cuda")):


Epoch 5 - Avg. Training Loss: 0.6055 (Time: 440.02s)


Evaluating:   0%|          | 0/382 [00:00<?, ?it/s]


Sample predictions:
  Reference: the image depicts a bouquet of flowers in full bloom, including six white daisies with yellow centers and five pink lilies with orange centers. the daisies are arranged in a circular pattern around the center of the bouquet, while the lilies are scattered throughout the arrangement. the background is blurred, suggesting a lush garden or park setting.
  Prediction: in the image, a man and a woman are engaged in an intense video game battle. the man is wearing a white t - shirt and blue jeans, while the woman is dressed in a black tank top and khaki shorts. they are

  Reference: in the vast expanse of the ocean, a person is seen kiteboarding. the individual, clad in a black wetsuit, is skillfully maneuvering a vibrant kite that dances in the sky. the kite, adorned with hues of green and purple, stands out against the backdrop of the cloudy sky. it ' s tethered to the person by a sturdy rope, which they hold firmly in their hands. the person is riding on

Epoch 6/20:   0%|          | 0/3432 [00:00<?, ?it/s]

  with autocast(enabled=(device == "cuda")):


Epoch 6 - Avg. Training Loss: 0.5618 (Time: 360.81s)


Evaluating:   0%|          | 0/382 [00:00<?, ?it/s]


Sample predictions:
  Reference: the image depicts a bouquet of flowers in full bloom, including six white daisies with yellow centers and five pink lilies with orange centers. the daisies are arranged in a circular pattern around the center of the bouquet, while the lilies are scattered throughout the arrangement. the background is blurred, suggesting a lush garden or park setting.
  Prediction: a tennis player dressed in a white shirt and black shorts is captured mid - swing on a vibrant blue tennis court. the player ' s right hand grips the racket, poised to strike an unseen ball, while their left arm extends behind them for

  Reference: in the vast expanse of the ocean, a person is seen kiteboarding. the individual, clad in a black wetsuit, is skillfully maneuvering a vibrant kite that dances in the sky. the kite, adorned with hues of green and purple, stands out against the backdrop of the cloudy sky. it ' s tethered to the person by a sturdy rope, which they hold firmly in thei

Epoch 7/20:   0%|          | 0/3432 [00:00<?, ?it/s]

  with autocast(enabled=(device == "cuda")):


Epoch 7 - Avg. Training Loss: 0.5213 (Time: 434.28s)


Evaluating:   0%|          | 0/382 [00:00<?, ?it/s]


Sample predictions:
  Reference: the image depicts a bouquet of flowers in full bloom, including six white daisies with yellow centers and five pink lilies with orange centers. the daisies are arranged in a circular pattern around the center of the bouquet, while the lilies are scattered throughout the arrangement. the background is blurred, suggesting a lush garden or park setting.
  Prediction: the image depicts a baseball game in progress. a batter in a white uniform with red stripes is mid - swing, while a catcher in black and an umpire in gray are crouched behind the batter. the field is a vibrant green, contrasting with the

  Reference: in the vast expanse of the ocean, a person is seen kiteboarding. the individual, clad in a black wetsuit, is skillfully maneuvering a vibrant kite that dances in the sky. the kite, adorned with hues of green and purple, stands out against the backdrop of the cloudy sky. it ' s tethered to the person by a sturdy rope, which they hold firmly in th

Epoch 8/20:   0%|          | 0/3432 [00:00<?, ?it/s]

  with autocast(enabled=(device == "cuda")):


Epoch 8 - Avg. Training Loss: 0.4821 (Time: 431.34s)


Evaluating:   0%|          | 0/382 [00:00<?, ?it/s]


Sample predictions:
  Reference: the image depicts a bouquet of flowers in full bloom, including six white daisies with yellow centers and five pink lilies with orange centers. the daisies are arranged in a circular pattern around the center of the bouquet, while the lilies are scattered throughout the arrangement. the background is blurred, suggesting a lush garden or park setting.
  Prediction: in the image, a man is seated at a table in what appears to be a restaurant or cafe. he is wearing a white shirt and has a beard. the table is set with two plates of food - one containing a slice of cake and

  Reference: in the vast expanse of the ocean, a person is seen kiteboarding. the individual, clad in a black wetsuit, is skillfully maneuvering a vibrant kite that dances in the sky. the kite, adorned with hues of green and purple, stands out against the backdrop of the cloudy sky. it ' s tethered to the person by a sturdy rope, which they hold firmly in their hands. the person is ridin

Epoch 9/20:   0%|          | 0/3432 [00:00<?, ?it/s]

  with autocast(enabled=(device == "cuda")):


Epoch 9 - Avg. Training Loss: 0.4430 (Time: 440.04s)


Evaluating:   0%|          | 0/382 [00:00<?, ?it/s]


Sample predictions:
  Reference: the image depicts a bouquet of flowers in full bloom, including six white daisies with yellow centers and five pink lilies with orange centers. the daisies are arranged in a circular pattern around the center of the bouquet, while the lilies are scattered throughout the arrangement. the background is blurred, suggesting a lush garden or park setting.
  Prediction: in the image, two young men are engaged in a game of frisbee on a grassy field. the man on the left is wearing a white t - shirt and black shorts, and is holding a blue frosbee with both hands.

  Reference: in the vast expanse of the ocean, a person is seen kiteboarding. the individual, clad in a black wetsuit, is skillfully maneuvering a vibrant kite that dances in the sky. the kite, adorned with hues of green and purple, stands out against the backdrop of the cloudy sky. it ' s tethered to the person by a sturdy rope, which they hold firmly in their hands. the person is riding on a board, 

Epoch 10/20:   0%|          | 0/3432 [00:00<?, ?it/s]

  with autocast(enabled=(device == "cuda")):
Exception ignored in: <function _MultiProcessingDataLoaderIter.__del__ at 0x7dbc9fff0e00>
Traceback (most recent call last):
  File "/usr/local/lib/python3.11/dist-packages/torch/utils/data/dataloader.py", line 1618, in __del__
    self._shutdown_workers()
  File "/usr/local/lib/python3.11/dist-packages/torch/utils/data/dataloader.py", line 1601, in _shutdown_workers
    if w.is_alive():
       ^^^^^^^^^^^^
  File "/usr/lib/python3.11/multiprocessing/process.py", line 160, in is_alive
    assert self._parent_pid == os.getpid(), 'can only test a child process'
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
AssertionError: can only test a child process
Exception ignored in: <function _MultiProcessingDataLoaderIter.__del__ at 0x7dbc9fff0e00>
Traceback (most recent call last):
  File "/usr/local/lib/python3.11/dist-packages/torch/utils/data/dataloader.py", line 1618, in __del__
    self._shutdown_workers()
  File "/usr/local/lib/python3.11/dist-pack

Epoch 10 - Avg. Training Loss: 0.4050 (Time: 384.58s)


Evaluating:   0%|          | 0/382 [00:00<?, ?it/s]


Sample predictions:
  Reference: the image depicts a bouquet of flowers in full bloom, including six white daisies with yellow centers and five pink lilies with orange centers. the daisies are arranged in a circular pattern around the center of the bouquet, while the lilies are scattered throughout the arrangement. the background is blurred, suggesting a lush garden or park setting.
  Prediction: the image captures a baseball game in progress. a batter wearing a white uniform with red accents is mid - swing, his body coiled and muscles taut as he prepares to strike the incoming ball. behind him, a catcher in a black uniform crouches

  Reference: in the vast expanse of the ocean, a person is seen kiteboarding. the individual, clad in a black wetsuit, is skillfully maneuvering a vibrant kite that dances in the sky. the kite, adorned with hues of green and purple, stands out against the backdrop of the cloudy sky. it ' s tethered to the person by a sturdy rope, which they hold firmly in

Epoch 11/20:   0%|          | 0/3432 [00:00<?, ?it/s]

  with autocast(enabled=(device == "cuda")):


Epoch 11 - Avg. Training Loss: 0.3681 (Time: 363.47s)


Evaluating:   0%|          | 0/382 [00:00<?, ?it/s]


Sample predictions:
  Reference: the image depicts a bouquet of flowers in full bloom, including six white daisies with yellow centers and five pink lilies with orange centers. the daisies are arranged in a circular pattern around the center of the bouquet, while the lilies are scattered throughout the arrangement. the background is blurred, suggesting a lush garden or park setting.
  Prediction: in the image, a man and a woman are engaged in a game of frisbee on a grassy field. the man is wearing a white t - shirt and black shorts, while the woman is dressed in an orange tank top and white shorts

  Reference: in the vast expanse of the ocean, a person is seen kiteboarding. the individual, clad in a black wetsuit, is skillfully maneuvering a vibrant kite that dances in the sky. the kite, adorned with hues of green and purple, stands out against the backdrop of the cloudy sky. it ' s tethered to the person by a sturdy rope, which they hold firmly in their hands. the person is riding o

Evaluating:   0%|          | 0/382 [00:00<?, ?it/s]


Sample predictions:
  Reference: the image depicts a bouquet of flowers in full bloom, including six white daisies with yellow centers and five pink lilies with orange centers. the daisies are arranged in a circular pattern around the center of the bouquet, while the lilies are scattered throughout the arrangement. the background is blurred, suggesting a lush garden or park setting.
  Prediction: in the image, a man and a woman are engaged in a game of frisbee on a grassy field. the man is wearing a white t - shirt and black shorts, while the woman is dressed in an orange tank top and white shorts

  Reference: in the vast expanse of the ocean, a person is seen kiteboarding. the individual, clad in a black wetsuit, is skillfully maneuvering a vibrant kite that dances in the sky. the kite, adorned with hues of green and purple, stands out against the backdrop of the cloudy sky. it ' s tethered to the person by a sturdy rope, which they hold firmly in their hands. the person is riding o