# Week 8.1: Large Vision Language Models - Core Intuitions

**Resource Requirements**: Google Colab (T4) GPU or Provisioned GPU with >8GB VRAM

## Learning Objectives

By the end of this notebook, you will understand:
- How Vision Transformers (ViT) convert images into sequences of patches
- The evolution from CLIP to SigLIP for vision-language alignment
- How vision and language models are integrated so that image and text tokens can attend to one another
- The architecture of PaliGemma as a concrete example of a Vision Language Model

---

## Setup and Installation

First, let's install the necessary packages for our exploration of Vision Language Models.

In [None]:
# Install required packages
!pip install torch torchvision transformers pillow numpy matplotlib -q

In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
import matplotlib.pyplot as plt
from PIL import Image
import math
from typing import Optional, Tuple

---

## Part 1: Understanding Vision Transformers (ViT)

Vision Transformers revolutionized computer vision by applying the transformer architecture to images. The key insight is treating an image as a sequence of patches, similar to how text is treated as a sequence of tokens.

### Core Concepts:
1. **Patch Embedding**: Divide image into fixed-size patches
2. **Positional Encoding**: Add spatial information to patches
3. **Self-Attention**: Allow patches to attend to each other
4. **Transformer Blocks**: Stack multiple attention + MLP layers
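
To make the first concept concrete, here is a small sketch of the arithmetic behind patching: how image size and patch size determine the sequence length the transformer sees. The helper name `patch_grid` is hypothetical and used only for illustration.

```python
from typing import Tuple

def patch_grid(image_size: int, patch_size: int) -> Tuple[int, int]:
    """Return (patches_per_side, total_patches) for a square image."""
    if image_size % patch_size != 0:
        raise ValueError("image_size must be divisible by patch_size")
    per_side = image_size // patch_size
    return per_side, per_side ** 2

# Sequence length grows quadratically as patches shrink or images grow
for image_size, patch_size in [(224, 16), (224, 14), (448, 14)]:
    per_side, total = patch_grid(image_size, patch_size)
    print(f"{image_size}px image, {patch_size}px patches -> "
          f"{per_side}x{per_side} grid = {total} tokens")
```

Because self-attention cost scales with the square of the sequence length, doubling the image side length quadruples the token count and increases attention cost roughly sixteenfold.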

### 1.1 Patch Embedding: Converting Images to Sequences

The first step in a Vision Transformer is converting an image into a sequence of patch embeddings. Let's implement this step by step:

In [None]:
class PatchEmbedding(nn.Module):
    """Converts an image into a sequence of patch embeddings"""
    
    def __init__(self, image_size=224, patch_size=16, num_channels=3, embed_dim=768):
        super().__init__()
        self.image_size = image_size
        self.patch_size = patch_size
        self.num_patches = (image_size // patch_size) ** 2
        
        # Use Conv2d to extract patches and project to embedding dimension
        # Kernel size = patch size, stride = patch size for non-overlapping patches
        self.patch_embedding = nn.Conv2d(
            in_channels=num_channels,
            out_channels=embed_dim,
            kernel_size=patch_size,
            stride=patch_size,
            padding="valid"  # No padding
        )
        
    def forward(self, pixel_values):
        # Input shape: [Batch_Size, Channels, Height, Width]
        batch_size, channels, height, width = pixel_values.shape
        
        # Apply convolution to extract patches
        # Output shape: [Batch_Size, Embed_Dim, Num_Patches_H, Num_Patches_W]
        patch_embeds = self.patch_embedding(pixel_values)
        
        # Flatten spatial dimensions and transpose
        # Shape: [Batch_Size, Embed_Dim, Num_Patches]
        patch_embeds = patch_embeds.flatten(2)
        
        # Transpose to sequence format
        # Final shape: [Batch_Size, Num_Patches, Embed_Dim]
        patch_embeds = patch_embeds.transpose(1, 2)
        
        return patch_embeds

# Demonstrate patch embedding
patch_embed = PatchEmbedding()
dummy_image = torch.randn(1, 3, 224, 224)  # Batch of 1 RGB image
patches = patch_embed(dummy_image)
print(f"Input image shape: {dummy_image.shape}")
print(f"Output patches shape: {patches.shape}")
print(f"Number of patches: {patches.shape[1]}")
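
The strided convolution above can look like a trick, so the following sketch checks that it is equivalent to cutting out each patch, flattening it, and applying one shared linear layer. It uses `F.unfold` for the explicit patch extraction; the small `embed_dim` is an arbitrary choice to keep the check fast.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

torch.manual_seed(0)
patch_size, embed_dim = 16, 32  # small embed_dim keeps the check fast
conv = nn.Conv2d(3, embed_dim, kernel_size=patch_size, stride=patch_size)
image = torch.randn(2, 3, 224, 224)

# Path 1: strided convolution, then flatten the patch grid into a sequence
conv_out = conv(image).flatten(2).transpose(1, 2)            # [2, 196, 32]

# Path 2: cut out patches explicitly, flatten each one, apply a linear map
patches = F.unfold(image, kernel_size=patch_size, stride=patch_size)  # [2, 768, 196]
patches = patches.transpose(1, 2)                            # [2, 196, 768]
weight = conv.weight.reshape(embed_dim, -1)                  # [32, 3 * 16 * 16]
linear_out = patches @ weight.t() + conv.bias                # [2, 196, 32]

max_diff = (conv_out - linear_out).abs().max().item()
print(f"Patches per image: {patches.shape[1]}, max abs difference: {max_diff:.2e}")
```

The two paths agree up to floating-point error, which is why a patch embedding is often described as a linear projection of flattened patches.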

### 1.2 Positional Embeddings: Adding Spatial Information

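Self-attention by itself is permutation-equivariant: shuffling the input patches just shuffles the outputs in the same way, so the model has no built-in notion of where a patch came from. We therefore add positional information to help the model understand the spatial arrangement of patches: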

In [None]:
class VisionEmbeddings(nn.Module):
    """Complete embedding layer with patch embedding and positional encoding"""
    
    def __init__(self, image_size=224, patch_size=16, num_channels=3, embed_dim=768):
        super().__init__()
        self.patch_embedding = PatchEmbedding(image_size, patch_size, num_channels, embed_dim)
        
        # Calculate number of patches
        self.num_patches = (image_size // patch_size) ** 2
        
        # Learnable positional embeddings for each patch position
        self.position_embedding = nn.Embedding(self.num_patches, embed_dim)
        
        # Register buffer for position IDs (not learnable)
        self.register_buffer(
            "position_ids",
            torch.arange(self.num_patches).expand((1, -1)),
            persistent=False,
        )
        
    def forward(self, pixel_values):
        # Get patch embeddings
        embeddings = self.patch_embedding(pixel_values)
        
        # Add positional embeddings
        embeddings = embeddings + self.position_embedding(self.position_ids)
        
        return embeddings

# Demonstrate complete embedding
vision_embed = VisionEmbeddings()
embedded_patches = vision_embed(dummy_image)
print(f"Embedded patches with positions shape: {embedded_patches.shape}")
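
To see why positional embeddings are needed, this sketch feeds a shuffled sequence through an attention layer with and without position vectors. It uses PyTorch's built-in `nn.MultiheadAttention` as a stand-in for the attention layer built later in this notebook, and the random `pos` tensor stands in for learned positional embeddings.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
attn = nn.MultiheadAttention(embed_dim=32, num_heads=4, batch_first=True).eval()
tokens = torch.randn(1, 10, 32)        # 10 "patches" with no position information
perm = torch.arange(10).flip(0)        # reverse the patch order

with torch.no_grad():
    out, _ = attn(tokens, tokens, tokens)
    out_perm, _ = attn(tokens[:, perm], tokens[:, perm], tokens[:, perm])

# Without positions, shuffling the input only shuffles the output rows
equivariant = torch.allclose(out[:, perm], out_perm, atol=1e-5)

# Adding a position-dependent vector before attention breaks that symmetry
pos = torch.randn(1, 10, 32)
with torch.no_grad():
    with_pos = tokens + pos
    out_pos, _ = attn(with_pos, with_pos, with_pos)
    shuffled = tokens[:, perm] + pos   # same patches, different slots
    out_pos_perm, _ = attn(shuffled, shuffled, shuffled)
position_aware = not torch.allclose(out_pos[:, perm], out_pos_perm, atol=1e-5)

print(f"Equivariant without positions: {equivariant}")
print(f"Order-sensitive with positions: {position_aware}")
```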

### 1.3 Multi-Head Self-Attention for Vision

The core of the Vision Transformer is the self-attention mechanism, which allows patches to attend to each other:

In [None]:
class VisionAttention(nn.Module):
    """Multi-head self-attention for vision transformers"""
    
    def __init__(self, embed_dim=768, num_heads=12, dropout=0.0):
        super().__init__()
        self.embed_dim = embed_dim
        self.num_heads = num_heads
        self.head_dim = embed_dim // num_heads
        self.scale = self.head_dim ** -0.5  # 1/sqrt(d_k) for scaled dot-product
        
        # Linear projections for Q, K, V
        self.q_proj = nn.Linear(embed_dim, embed_dim)
        self.k_proj = nn.Linear(embed_dim, embed_dim)
        self.v_proj = nn.Linear(embed_dim, embed_dim)
        self.out_proj = nn.Linear(embed_dim, embed_dim)
        
        self.dropout = nn.Dropout(dropout)
        
    def forward(self, hidden_states):
        batch_size, seq_len, _ = hidden_states.size()
        
        # Project to Q, K, V
        query_states = self.q_proj(hidden_states)
        key_states = self.k_proj(hidden_states)
        value_states = self.v_proj(hidden_states)
        
        # Reshape for multi-head attention
        # [Batch, Seq_Len, Embed_Dim] -> [Batch, Num_Heads, Seq_Len, Head_Dim]
        query_states = query_states.view(batch_size, seq_len, self.num_heads, self.head_dim).transpose(1, 2)
        key_states = key_states.view(batch_size, seq_len, self.num_heads, self.head_dim).transpose(1, 2)
        value_states = value_states.view(batch_size, seq_len, self.num_heads, self.head_dim).transpose(1, 2)
        
        # Calculate attention scores using Q * K^T / sqrt(d_k)
        attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) * self.scale
        
        # Apply softmax to get attention probabilities
        attn_weights = F.softmax(attn_weights, dim=-1)
        attn_weights = self.dropout(attn_weights)
        
        # Apply attention to values
        attn_output = torch.matmul(attn_weights, value_states)
        
        # Reshape back to [Batch, Seq_Len, Embed_Dim]
        attn_output = attn_output.transpose(1, 2).contiguous()
        attn_output = attn_output.reshape(batch_size, seq_len, self.embed_dim)
        
        # Final projection
        attn_output = self.out_proj(attn_output)
        
        return attn_output, attn_weights

# Demonstrate attention
attention = VisionAttention()
attended_patches, attention_weights = attention(embedded_patches)
print(f"Attention output shape: {attended_patches.shape}")
print(f"Attention weights shape: {attention_weights.shape}")
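
As a sanity check on the attention math, the sketch below compares the manual softmax(QK^T / sqrt(d_k))V computation with PyTorch's fused `F.scaled_dot_product_attention`. This assumes PyTorch 2.0 or newer, where that function is available; the tensor sizes mirror the defaults used above.

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
# [Batch, Num_Heads, Seq_Len, Head_Dim], matching the layout used in VisionAttention
q, k, v = (torch.randn(1, 12, 196, 64) for _ in range(3))

# Manual scaled dot-product attention
scores = q @ k.transpose(-2, -1) * (q.shape[-1] ** -0.5)
weights = F.softmax(scores, dim=-1)
manual_out = weights @ v

# Fused implementation (assumes PyTorch >= 2.0)
fused_out = F.scaled_dot_product_attention(q, k, v)

rows_sum_to_one = torch.allclose(weights.sum(dim=-1), torch.ones(1, 12, 196), atol=1e-5)
max_diff = (manual_out - fused_out).abs().max().item()
print(f"Attention rows sum to 1: {rows_sum_to_one}")
print(f"Max abs difference vs fused kernel: {max_diff:.2e}")
```

The manual version is easier to read and lets us return the attention weights for visualization; the fused call is usually the better choice when only the output is needed.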

### 1.4 Complete Vision Transformer Block

A Vision Transformer block combines attention with an MLP and uses residual connections:

In [None]:
class VisionMLP(nn.Module):
    """Feed-forward network (MLP) for vision transformer"""
    
    def __init__(self, embed_dim=768, mlp_ratio=4.0, dropout=0.0):
        super().__init__()
        hidden_dim = int(embed_dim * mlp_ratio)
        self.fc1 = nn.Linear(embed_dim, hidden_dim)
        self.gelu = nn.GELU()
        self.fc2 = nn.Linear(hidden_dim, embed_dim)
        self.dropout = nn.Dropout(dropout)
        
    def forward(self, x):
        x = self.fc1(x)
        x = self.gelu(x)
        x = self.dropout(x)
        x = self.fc2(x)
        x = self.dropout(x)
        return x


class VisionTransformerBlock(nn.Module):
    """Complete transformer block with attention and MLP"""
    
    def __init__(self, embed_dim=768, num_heads=12, mlp_ratio=4.0, dropout=0.0):
        super().__init__()
        # Pre-normalization architecture
        self.norm1 = nn.LayerNorm(embed_dim)
        self.attention = VisionAttention(embed_dim, num_heads, dropout)
        self.norm2 = nn.LayerNorm(embed_dim)
        self.mlp = VisionMLP(embed_dim, mlp_ratio, dropout)
        
    def forward(self, hidden_states):
        # Self-attention block with residual connection
        residual = hidden_states
        hidden_states = self.norm1(hidden_states)
        hidden_states, _ = self.attention(hidden_states)
        hidden_states = residual + hidden_states
        
        # MLP block with residual connection
        residual = hidden_states
        hidden_states = self.norm2(hidden_states)
        hidden_states = self.mlp(hidden_states)
        hidden_states = residual + hidden_states
        
        return hidden_states

# Stack multiple transformer blocks
class VisionTransformer(nn.Module):
    """Complete Vision Transformer model"""
    
    def __init__(self, image_size=224, patch_size=16, num_channels=3, 
                 embed_dim=768, depth=12, num_heads=12, mlp_ratio=4.0):
        super().__init__()
        self.embeddings = VisionEmbeddings(image_size, patch_size, num_channels, embed_dim)
        self.blocks = nn.ModuleList([
            VisionTransformerBlock(embed_dim, num_heads, mlp_ratio)
            for _ in range(depth)
        ])
        self.norm = nn.LayerNorm(embed_dim)
        
    def forward(self, pixel_values):
        # Embed patches
        hidden_states = self.embeddings(pixel_values)
        
        # Pass through transformer blocks
        for block in self.blocks:
            hidden_states = block(hidden_states)
            
        # Final normalization
        hidden_states = self.norm(hidden_states)
        
        return hidden_states

# Create a small ViT model
vit = VisionTransformer(depth=3)  # Using 3 layers for demonstration
vision_features = vit(dummy_image)
print(f"Vision Transformer output shape: {vision_features.shape}")
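
To get a feel for where the parameters in this architecture live, here is a sketch that counts them from the layer shapes. The formula assumes the exact layout used in this notebook (no class token, biases on every linear layer, learned positional embeddings); `vit_param_count` is a hypothetical helper, not part of any library.

```python
def vit_param_count(image_size=224, patch_size=16, num_channels=3,
                    embed_dim=768, depth=12, mlp_ratio=4.0):
    """Count parameters for the ViT layout used in this notebook (no class token)."""
    hidden_dim = int(embed_dim * mlp_ratio)
    num_patches = (image_size // patch_size) ** 2

    patch_embed = num_channels * patch_size ** 2 * embed_dim + embed_dim  # Conv2d weight + bias
    pos_embed = num_patches * embed_dim                                   # one vector per patch
    attention = 4 * (embed_dim * embed_dim + embed_dim)                   # q, k, v, out projections
    mlp = (embed_dim * hidden_dim + hidden_dim) + (hidden_dim * embed_dim + embed_dim)
    layer_norms = 2 * (2 * embed_dim)                                     # weight + bias, two per block
    final_norm = 2 * embed_dim

    per_block = attention + mlp + layer_norms
    return patch_embed + pos_embed + depth * per_block + final_norm

print(f"depth=12: {vit_param_count(depth=12):,} parameters")
print(f"depth=3:  {vit_param_count(depth=3):,} parameters")
```

The depth-3 figure can be compared against `sum(p.numel() for p in vit.parameters())` for the model created above. Almost all parameters sit in the transformer blocks, with the MLP roughly twice the size of the attention projections.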

### Visualizing Patch Extraction

In [None]:
# Visualize how an image is divided into patches
def visualize_patches(image_size=224, patch_size=16):
    fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(10, 5))
    
    # Create a sample image with gradient
    x = np.linspace(0, 1, image_size)
    y = np.linspace(0, 1, image_size)
    X, Y = np.meshgrid(x, y)
    image = np.sin(5 * X) * np.cos(5 * Y)
    
    # Show original image
    ax1.imshow(image, cmap='viridis')
    ax1.set_title('Original Image')
    ax1.axis('off')
    
    # Show patches
    ax2.imshow(image, cmap='viridis')
    
    # Draw patch boundaries
    for i in range(0, image_size, patch_size):
        ax2.axhline(i, color='red', linewidth=0.5)
        ax2.axvline(i, color='red', linewidth=0.5)
    
    ax2.set_title(f'Image divided into {patch_size}x{patch_size} patches')
    ax2.axis('off')
    
    num_patches = (image_size // patch_size) ** 2
    fig.suptitle(f'Total number of patches: {num_patches}')
    plt.tight_layout()
    plt.show()

visualize_patches()

---

## Part 2: Contrastive Learning - From CLIP to SigLIP

Contrastive Language-Image Pre-training (CLIP) and its successor SigLIP learn to align vision and language representations in a shared embedding space. This enables zero-shot image classification and forms the foundation for many vision-language models.

### Key Concepts:
1. **Dual Encoders**: Separate encoders for images and text
2. **Contrastive Loss**: Learn to match correct image-text pairs
3. **SigLIP Innovation**: Sigmoid loss instead of softmax for better efficiency

### 2.1 Building a CLIP-style Model

Let's implement a simplified CLIP architecture to understand the core concepts:

In [None]:
class CLIPModel(nn.Module):
    """Simplified CLIP model with vision and text encoders"""
    
    def __init__(self, vision_dim=768, text_dim=512, projection_dim=256):
        super().__init__()
        # Vision encoder (using our ViT)
        self.vision_encoder = VisionTransformer(depth=3)
        
        # Text encoder (simplified - in practice would be a transformer)
        self.text_encoder = nn.Sequential(
            nn.Embedding(50000, text_dim),  # Vocabulary size 50k
            nn.LSTM(text_dim, text_dim, batch_first=True),
        )
        
        # Projection heads to shared space
        self.vision_projection = nn.Linear(vision_dim, projection_dim)
        self.text_projection = nn.Linear(text_dim, projection_dim)
        
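        # Learnable log of the logit scale (inverse temperature), initialized to
        # log(1 / 0.07); forward() exponentiates it before scaling similarities
        self.temperature = nn.Parameter(torch.ones([]) * np.log(1 / 0.07))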
        
    def encode_image(self, pixel_values):
        # Encode image to features
        vision_features = self.vision_encoder(pixel_values)
        # Use mean pooling to get single vector per image
        vision_features = vision_features.mean(dim=1)
        # Project to shared space and normalize
        vision_embeds = self.vision_projection(vision_features)
        vision_embeds = F.normalize(vision_embeds, dim=-1)
        return vision_embeds
    
    def encode_text(self, input_ids):
        # Encode text to features
        text_embeds = self.text_encoder[0](input_ids)  # Embedding
        output, (h_n, c_n) = self.text_encoder[1](text_embeds)  # LSTM
        # Use the last hidden state from LSTM
        text_features = h_n[-1]  # Shape: [batch_size, text_dim]
        # Project to shared space and normalize
        text_embeds = self.text_projection(text_features)
        text_embeds = F.normalize(text_embeds, dim=-1)
        return text_embeds
    
    def forward(self, pixel_values, input_ids):
        # Encode both modalities
        image_embeds = self.encode_image(pixel_values)
        text_embeds = self.encode_text(input_ids)
        
        # Compute similarity matrix
        logit_scale = self.temperature.exp()
        logits_per_image = torch.matmul(image_embeds, text_embeds.t()) * logit_scale
        logits_per_text = logits_per_image.t()
        
        return logits_per_image, logits_per_text

# Create CLIP model
clip_model = CLIPModel()

# Dummy inputs
dummy_images = torch.randn(4, 3, 224, 224)  # Batch of 4 images
dummy_text = torch.randint(0, 50000, (4, 20))  # Batch of 4 text sequences

logits_per_image, logits_per_text = clip_model(dummy_images, dummy_text)
print(f"Similarity matrix shape (image->text): {logits_per_image.shape}")
print(f"Similarity matrix shape (text->image): {logits_per_text.shape}")
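
The similarity matrix is also what makes zero-shot classification possible: encode one text prompt per class, then pick the class whose embedding is closest to the image embedding. The sketch below illustrates this with synthetic embeddings rather than a trained model; `zero_shot_classify` is a hypothetical helper, and the default `logit_scale` simply reuses the 1 / 0.07 initialization from the model above.

```python
import torch
import torch.nn.functional as F

def zero_shot_classify(image_embeds, class_text_embeds, logit_scale=1 / 0.07):
    """Return class probabilities and predictions from cosine similarities."""
    image_embeds = F.normalize(image_embeds, dim=-1)
    class_text_embeds = F.normalize(class_text_embeds, dim=-1)
    logits = logit_scale * image_embeds @ class_text_embeds.t()  # [num_images, num_classes]
    probs = logits.softmax(dim=-1)
    return probs, probs.argmax(dim=-1)

torch.manual_seed(0)
# Stand-ins for encoded prompts such as "a photo of a cat" (synthetic, not real encodings)
class_embeds = torch.randn(3, 256)
# Synthetic image embeddings: each one is a noisy copy of one class embedding
true_classes = torch.tensor([2, 0, 1, 1])
image_embeds = class_embeds[true_classes] + 0.1 * torch.randn(4, 256)

probs, preds = zero_shot_classify(image_embeds, class_embeds)
print(f"Predicted classes: {preds.tolist()}")
print(f"True classes:      {true_classes.tolist()}")
```

With a trained model, `class_embeds` would come from `encode_text` and `image_embeds` from `encode_image`; no classifier head is trained for the new label set.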

### 2.2 Contrastive Loss Implementation

The key to CLIP's training is the contrastive loss that pushes matching pairs together and non-matching pairs apart:

In [None]:
def clip_contrastive_loss(logits_per_image, logits_per_text):
    """Compute CLIP's symmetric cross-entropy loss"""
    batch_size = logits_per_image.shape[0]
    
    # Labels: diagonal elements are positive pairs
    labels = torch.arange(batch_size, device=logits_per_image.device)
    
    # Cross entropy loss for image->text
    loss_i2t = F.cross_entropy(logits_per_image, labels)
    
    # Cross entropy loss for text->image  
    loss_t2i = F.cross_entropy(logits_per_text, labels)
    
    # Total loss is average of both directions
    loss = (loss_i2t + loss_t2i) / 2
    
    return loss

# Compute loss for our dummy batch
loss = clip_contrastive_loss(logits_per_image, logits_per_text)
print(f"Contrastive loss: {loss.item():.4f}")

# Visualize similarity matrix
plt.figure(figsize=(6, 6))
plt.imshow(logits_per_image.detach().numpy(), cmap='RdBu_r')
plt.colorbar(label='Similarity')
plt.xlabel('Text Index')
plt.ylabel('Image Index')
plt.title('Image-Text Similarity Matrix\n(Diagonal should be high after training)')
plt.show()
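
A useful reference point when reading the loss value: if the model cannot tell pairs apart at all, the symmetric cross-entropy equals log(batch size). The sketch below checks this with hand-built similarity matrices; `symmetric_cross_entropy` is a compact restatement of the loss above, written here so the block runs on its own.

```python
import math
import torch
import torch.nn.functional as F

def symmetric_cross_entropy(logits):
    """CLIP-style loss for a square image-text logit matrix."""
    labels = torch.arange(logits.shape[0])
    return (F.cross_entropy(logits, labels) + F.cross_entropy(logits.t(), labels)) / 2

batch_size = 4
uninformative = torch.zeros(batch_size, batch_size)  # every pair looks equally similar
aligned = 20.0 * torch.eye(batch_size)               # matching pairs score far higher

chance_loss = symmetric_cross_entropy(uninformative).item()
aligned_loss = symmetric_cross_entropy(aligned).item()
print(f"Chance-level loss: {chance_loss:.4f} (log {batch_size} = {math.log(batch_size):.4f})")
print(f"Well-aligned loss: {aligned_loss:.6f}")
```

An untrained model should start close to the chance-level value, and training should push the loss toward zero as the diagonal comes to dominate each row and column.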

### 2.3 SigLIP: Improving Efficiency with Sigmoid Loss

SigLIP replaces CLIP's softmax-based loss with a sigmoid loss, allowing more efficient training:

In [None]:
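def siglip_loss(logits_per_image, logits_per_text):
    """Simplified SigLIP-style pairwise sigmoid loss.

    Each image-text pair is treated as an independent binary classification,
    so no softmax normalization over the batch is needed. The SigLIP paper
    also adds a learnable bias to the logits, which is omitted here.
    """
    batch_size = logits_per_image.shape[0]

    # Target matrix: +1 for matching pairs (the diagonal), -1 for all other pairs
    targets = torch.eye(batch_size, device=logits_per_image.device) * 2 - 1

    # logits_per_text is the transpose of logits_per_image and targets is symmetric,
    # so both directions contain the same pairs; one pass over the matrix is enough
    pair_log_likelihood = F.logsigmoid(targets * logits_per_image)

    # Sum over all pairs and normalize by batch size, as in the SigLIP paper
    loss = -pair_log_likelihood.sum() / batch_size

    return loss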

# Compare losses
clip_loss = clip_contrastive_loss(logits_per_image, logits_per_text)
siglip_loss_value = siglip_loss(logits_per_image, logits_per_text)

print(f"CLIP loss: {clip_loss.item():.4f}")
print(f"SigLIP loss: {siglip_loss_value.item():.4f}")
print("\nKey advantage: SigLIP scores each image-text pair independently, so it needs no")
print("softmax normalization over the whole batch, which simplifies distributed training!")
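
Because the sigmoid loss is a plain sum over pairs, it can be computed in chunks and added up, which is what makes it friendly to distributed training. The sketch below demonstrates this on synthetic embeddings. The helper `pairwise_sigmoid_loss` is hypothetical, and the `scale` and `bias` values are illustrative constants rather than learned parameters.

```python
import torch
import torch.nn.functional as F

def pairwise_sigmoid_loss(image_embeds, text_embeds, scale, bias, text_offset=0):
    """Summed sigmoid loss of all images against one block of texts.

    text_offset is the index of the block's first text in the full batch,
    which tells us where the matching (positive) pairs are.
    """
    logits = scale * image_embeds @ text_embeds.t() + bias      # [num_images, block_size]
    targets = -torch.ones_like(logits)
    rows = torch.arange(text_offset, text_offset + text_embeds.shape[0])
    targets[rows, rows - text_offset] = 1.0                     # image i matches text i
    return -F.logsigmoid(targets * logits).sum()

torch.manual_seed(0)
n, d = 8, 16
img = F.normalize(torch.randn(n, d), dim=-1)
txt = F.normalize(torch.randn(n, d), dim=-1)
scale, bias = 10.0, -10.0   # illustrative values, assumed for this sketch

# Full batch in one go
full = pairwise_sigmoid_loss(img, txt, scale, bias) / n

# Same loss accumulated from four blocks of two texts each
chunked = sum(
    pairwise_sigmoid_loss(img, txt[start:start + 2], scale, bias, text_offset=start)
    for start in range(0, n, 2)
) / n

print(f"Full: {full.item():.6f}, chunked: {chunked.item():.6f}")
```

The softmax loss cannot be split this way without extra bookkeeping, since each row's normalizer depends on every other entry in that row.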

---

## Part 3: Cross-Attention and Multimodal Integration

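The final piece is understanding how vision and language models are integrated. PaliGemma demonstrates a simple yet effective approach: instead of adding separate cross-attention layers, it projects image patches into the language model's embedding space and treats them as tokens in the input sequence, so the model's ordinary self-attention mixes the two modalities.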

### Key Concepts:
1. **Projection Layer**: Align vision features to language model space
2. **Token Merging**: Combine image and text tokens in a single sequence  
3. **Unified Attention**: Let all tokens attend to each other

### 3.1 Multimodal Projector

First, we need to project vision features to the language model's embedding space:

In [None]:
class MultiModalProjector(nn.Module):
    """Projects vision features to language model embedding space"""
    
    def __init__(self, vision_hidden_size=768, text_hidden_size=2048):
        super().__init__()
        # Simple linear projection
        self.linear = nn.Linear(vision_hidden_size, text_hidden_size)
        
    def forward(self, image_features):
        # Project each patch embedding to text embedding dimension
        # Input: [Batch_Size, Num_Patches, Vision_Hidden_Size]
        # Output: [Batch_Size, Num_Patches, Text_Hidden_Size]
        hidden_states = self.linear(image_features)
        return hidden_states

# Example projection
projector = MultiModalProjector(vision_hidden_size=768, text_hidden_size=2048)
projected_features = projector(vision_features)
print(f"Vision features shape: {vision_features.shape}")
print(f"Projected features shape: {projected_features.shape}")

### 3.2 Token Merging Strategy

PaliGemma's approach treats image patches as special tokens that are prepended to the text sequence:

In [None]:
def merge_vision_and_text_tokens(
    image_features,
    text_embeddings,
    input_ids,
    image_token_id=32000,  # Special token ID for <image>
    pad_token_id=0
):
    """Merge image features with text embeddings based on special tokens"""
    
    batch_size, seq_length = input_ids.shape
    _, _, embed_dim = image_features.shape
    
    # Scale image features to match text embedding magnitude
    # This is important for stable training
    scaled_image_features = image_features / (embed_dim ** 0.5)
    
    # Create output tensor
    merged_embeddings = torch.zeros(
        batch_size, seq_length, embed_dim,
        dtype=text_embeddings.dtype,
        device=text_embeddings.device
    )
    
    # Create masks for different token types
    text_mask = (input_ids != image_token_id) & (input_ids != pad_token_id)
    image_mask = input_ids == image_token_id
    
    # Fill in text embeddings
    text_positions = text_mask.nonzero()
    if len(text_positions) > 0:
        merged_embeddings[text_mask] = text_embeddings[text_mask]
    
    # Fill in image features at image token positions
    # In practice, we need to handle the mapping carefully
    # This is a simplified version
    for batch_idx in range(batch_size):
        image_positions = (image_mask[batch_idx]).nonzero().squeeze(-1)
        if len(image_positions) > 0:
            num_image_tokens = min(len(image_positions), scaled_image_features.shape[1])
            merged_embeddings[batch_idx, image_positions[:num_image_tokens]] = \
                scaled_image_features[batch_idx, :num_image_tokens]
    
    return merged_embeddings

# Demonstrate token merging
# Create dummy inputs
batch_size = 2
num_patches = 16
text_seq_len = 10
embed_dim = 512

# Image features from vision encoder + projector
dummy_image_features = torch.randn(batch_size, num_patches, embed_dim)

# Text embeddings 
dummy_text_embeddings = torch.randn(batch_size, text_seq_len + num_patches, embed_dim)

# Input IDs with image tokens at the beginning
image_token_id = 32000
dummy_input_ids = torch.cat([
    torch.full((batch_size, num_patches), image_token_id),
    torch.randint(1, 30000, (batch_size, text_seq_len))
], dim=1)

merged = merge_vision_and_text_tokens(
    dummy_image_features,
    dummy_text_embeddings,
    dummy_input_ids,
    image_token_id
)

print(f"Input IDs shape: {dummy_input_ids.shape}")
print(f"Merged embeddings shape: {merged.shape}")
print(f"First few input IDs: {dummy_input_ids[0, :20]}")
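
The per-sample loop above is easy to follow, but the same placement can be done with a single boolean-mask assignment. The sketch below is one possible vectorized variant, assuming every sample contains exactly one `<image>` placeholder per patch feature. Unlike the function above, it does not rescale the image features or zero out padding; `merge_with_mask` is a hypothetical name.

```python
import torch

def merge_with_mask(image_features, text_embeddings, input_ids, image_token_id):
    """Replace embeddings at <image> positions with image features (no Python loop)."""
    num_patches, embed_dim = image_features.shape[1], image_features.shape[2]
    image_mask = input_ids == image_token_id                    # [Batch, Seq_Len]
    if (image_mask.sum(dim=1) != num_patches).any():
        raise ValueError("each sample needs one <image> token per patch feature")
    merged = text_embeddings.clone()
    # Boolean indexing visits positions in row-major order (sample 0 first, then
    # sample 1, ...), which matches the order of the flattened image features
    merged[image_mask] = image_features.reshape(-1, embed_dim).to(merged.dtype)
    return merged

torch.manual_seed(0)
B, P, T, D, IMG = 2, 4, 3, 8, 32000   # small illustrative sizes
ids = torch.cat([torch.full((B, P), IMG), torch.randint(1, 100, (B, T))], dim=1)
text = torch.randn(B, P + T, D)
feats = torch.randn(B, P, D)

merged = merge_with_mask(feats, text, ids, IMG)
print(f"Merged shape: {tuple(merged.shape)}")
print(f"Image slots filled correctly: {torch.equal(merged[:, :P], feats)}")
```

Raising an error on a count mismatch is a deliberate choice here: silently truncating would hide the case where the prompt has fewer placeholders than the vision encoder produces patches.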

### 3.3 Simplified Vision-Language Model

Let's put it all together in a simplified vision-language model inspired by PaliGemma:

In [None]:
class SimpleVisionLanguageModel(nn.Module):
    """Simplified vision-language model inspired by PaliGemma"""
    
    def __init__(self, vision_config, text_config):
        super().__init__()
        # Vision encoder
        self.vision_encoder = VisionTransformer(
            embed_dim=vision_config['hidden_size'],
            depth=vision_config['num_layers']
        )
        
        # Multimodal projector
        self.projector = MultiModalProjector(
            vision_hidden_size=vision_config['hidden_size'],
            text_hidden_size=text_config['hidden_size']
        )
        
        # Text embedding layer
        self.text_embeddings = nn.Embedding(
            text_config['vocab_size'],
            text_config['hidden_size']
        )
        
        # Transformer decoder (simplified)
        self.decoder = nn.TransformerDecoder(
            nn.TransformerDecoderLayer(
                d_model=text_config['hidden_size'],
                nhead=text_config['num_heads'],
                dim_feedforward=text_config['hidden_size'] * 4,
                batch_first=True
            ),
            num_layers=text_config['num_layers']
        )
        
        # Output projection
        self.lm_head = nn.Linear(
            text_config['hidden_size'],
            text_config['vocab_size']
        )
        
        self.image_token_id = text_config['image_token_id']
        
    def forward(self, pixel_values, input_ids):
        batch_size = pixel_values.shape[0]
        
        # 1. Encode images
        vision_features = self.vision_encoder(pixel_values)
        
        # 2. Project to text space  
        image_features = self.projector(vision_features)
        
        # 3. Get text embeddings
        text_embeds = self.text_embeddings(input_ids)
        
        # 4. Merge image and text embeddings
        merged_embeds = merge_vision_and_text_tokens(
            image_features,
            text_embeds,
            input_ids,
            self.image_token_id
        )
        
        # 5. Pass through decoder
        # Create causal mask for autoregressive generation
        seq_len = merged_embeds.shape[1]
        causal_mask = torch.triu(
            torch.ones(seq_len, seq_len) * float('-inf'),
            diagonal=1
        ).to(merged_embeds.device)
        
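        # nn.TransformerDecoder also cross-attends to `memory`; masking that
        # attention as well keeps future tokens from leaking into earlier positions
        hidden_states = self.decoder(
            merged_embeds,
            merged_embeds,
            tgt_mask=causal_mask,
            memory_mask=causal_mask
        )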
        
        # 6. Project to vocabulary
        logits = self.lm_head(hidden_states)
        
        return logits

# Configure and create model
vision_config = {
    'hidden_size': 768,
    'num_layers': 3
}

text_config = {
    'hidden_size': 512,
    'num_heads': 8,
    'num_layers': 3,
    'vocab_size': 32001,
    'image_token_id': 32000
}

vlm = SimpleVisionLanguageModel(vision_config, text_config)

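# Test forward pass
# The ViT emits (224 // 16) ** 2 = 196 patch features per image, so the prompt needs
# 196 <image> placeholders; with fewer, the merge step silently drops the extra patches
num_vit_patches = vlm.vision_encoder.embeddings.num_patches
vlm_input_ids = torch.cat([
    torch.full((2, num_vit_patches), text_config['image_token_id']),
    torch.randint(1, 30000, (2, 10))
], dim=1)
output_logits = vlm(dummy_images[:2], vlm_input_ids)
print(f"Output logits shape: {output_logits.shape}")
print("The logits at the last position are what autoregressive generation samples from.")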
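
To show how logits like these turn into generated text, here is a minimal greedy decoding loop. It is a sketch under stated assumptions: `logits_fn` is a hypothetical callable mapping token IDs of shape [Batch, Seq_Len] to logits of shape [Batch, Seq_Len, Vocab], and the toy function below stands in for a real model.

```python
import torch
import torch.nn.functional as F

def greedy_decode(logits_fn, input_ids, max_new_tokens, eos_token_id=None):
    """Append the highest-scoring next token, one step at a time."""
    for _ in range(max_new_tokens):
        logits = logits_fn(input_ids)                                # [B, L, vocab]
        next_token = logits[:, -1, :].argmax(dim=-1, keepdim=True)   # [B, 1]
        input_ids = torch.cat([input_ids, next_token], dim=1)
        # Stop early once every sequence in the batch has produced EOS
        if eos_token_id is not None and (next_token == eos_token_id).all():
            break
    return input_ids

def toy_logits_fn(ids, vocab_size=10):
    """Toy stand-in for a model: always predicts (previous token + 1) mod vocab_size."""
    return F.one_hot((ids + 1) % vocab_size, vocab_size).float()

prompt = torch.tensor([[3, 4]])
generated = greedy_decode(toy_logits_fn, prompt, max_new_tokens=4)
stopped = greedy_decode(toy_logits_fn, prompt, max_new_tokens=4, eos_token_id=6)
print(generated.tolist())  # [[3, 4, 5, 6, 7, 8]]
print(stopped.tolist())    # [[3, 4, 5, 6]]
```

With the model above, `logits_fn` could be something like `lambda ids: vlm(pixel_values, ids)`. That recomputes the whole sequence at every step; production implementations usually cache keys and values instead.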

### 3.4 Understanding Cross-Modal Attention

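Let's visualize the attention pattern over a merged image + text sequence. The weights below are random placeholders rather than model outputs, so only the masking structure is meaningful: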

In [None]:
def visualize_cross_modal_attention(num_image_tokens=16, num_text_tokens=10):
    """Visualize attention patterns between image and text tokens"""
    
    total_tokens = num_image_tokens + num_text_tokens
    
    # Create a sample attention matrix
    # In practice, this would come from the model
    attention_matrix = torch.rand(total_tokens, total_tokens)
    
    # Apply causal mask (each token can attend to itself and earlier tokens only)
    causal_mask = torch.triu(torch.ones_like(attention_matrix), diagonal=1).bool()
    # Masked scores become -inf so that softmax assigns them zero weight
    attention_matrix = attention_matrix.masked_fill(causal_mask, float('-inf'))

    # Normalize rows
    attention_matrix = F.softmax(attention_matrix, dim=-1)
    
    # Visualize
    fig, ax = plt.subplots(figsize=(8, 8))
    im = ax.imshow(attention_matrix.numpy(), cmap='Blues')
    
    # Add labels
    ax.set_xlabel('Tokens (Image + Text)')
    ax.set_ylabel('Tokens (Image + Text)')
    ax.set_title('Cross-Modal Attention Pattern\n(Causal: each token attends to itself and earlier tokens)')
    
    # Add separators
    ax.axvline(x=num_image_tokens-0.5, color='red', linestyle='--', alpha=0.5)
    ax.axhline(y=num_image_tokens-0.5, color='red', linestyle='--', alpha=0.5)
    
    # Add region labels
    ax.text(num_image_tokens/2, -1, 'Image Tokens', ha='center', va='bottom', color='red')
    ax.text(num_image_tokens + num_text_tokens/2, -1, 'Text Tokens', ha='center', va='bottom', color='blue')
    ax.text(-1, num_image_tokens/2, 'Image\nTokens', ha='right', va='center', color='red', rotation=90)
    ax.text(-1, num_image_tokens + num_text_tokens/2, 'Text\nTokens', ha='right', va='center', color='blue', rotation=90)
    
    plt.colorbar(im, ax=ax, label='Attention Weight')
    plt.tight_layout()
    plt.show()

visualize_cross_modal_attention()
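
The plot above uses a fully causal mask. The PaliGemma report describes a prefix-LM variant instead: image tokens and the text prompt (the prefix) attend to each other in both directions, and only the generated suffix is causal. Here is a minimal sketch of such a mask; `prefix_lm_mask` is a hypothetical helper, and the sizes are illustrative.

```python
import torch

def prefix_lm_mask(prefix_len, total_len):
    """Boolean [total_len, total_len] mask; True means 'row may attend to column'."""
    # Start from a causal mask: each token sees itself and earlier tokens
    allowed = torch.tril(torch.ones(total_len, total_len, dtype=torch.bool))
    # Every token may additionally see the whole prefix (image + prompt tokens)
    allowed[:, :prefix_len] = True
    return allowed

mask = prefix_lm_mask(prefix_len=3, total_len=5)
print(mask.int())
```

In the printed matrix, the top-left 3x3 block is all ones (bidirectional prefix), prefix rows never see suffix columns, and the suffix rows remain lower-triangular.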

---

## Summary and Key Takeaways

We've explored the three core components of Vision Language Models:

### 1. Vision Transformers (ViT)
- Images are divided into patches and treated as sequences
- Positional embeddings preserve spatial information
- Self-attention enables global reasoning across patches

### 2. Contrastive Learning (CLIP/SigLIP)
- Dual encoders learn aligned representations for images and text
- Contrastive loss pushes matching pairs together
- SigLIP improves efficiency with sigmoid loss

### 3. Multimodal Integration
- Vision features are projected to language model space
- Image patches become special tokens in the text sequence
- Unified attention enables cross-modal understanding

### Real-World Applications

This architecture enables:
- **Image Captioning**: Generate descriptions of images
- **Visual Question Answering**: Answer questions about images
- **Multimodal Reasoning**: Combine visual and textual information
- **Zero-shot Image Classification**: Classify images into new categories without task-specific training

The elegance of modern VLMs like PaliGemma lies in their simplicity - by treating images as token sequences, we can leverage the full power of language models for multimodal understanding.

### Further Reading

To dive deeper into these concepts:
- Original Vision Transformer paper: "An Image is Worth 16x16 Words"
- CLIP paper: "Learning Transferable Visual Models From Natural Language Supervision"
- SigLIP paper: "Sigmoid Loss for Language Image Pre-Training"
- PaliGemma technical report for a complete VLM implementation