# Phase 2: Data Preparation for Vision-Language Alignment

**Goal:** Create a dataset that returns `(image, tokenized_caption)` pairs for training.

**Key insight:** We prepend a single `<image>` token to each caption. During the forward pass, this placeholder is **replaced** by 196 vision tokens from the ViT encoder.
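The figure of 196 follows from the patch grid, assuming the ViT-B/16 encoder sees 224x224 inputs and that the class token is not counted among the vision tokens. A quick check (`num_patch_tokens` is a hypothetical helper, not part of this notebook):

```python
def num_patch_tokens(image_size: int = 224, patch_size: int = 16) -> int:
    """Number of non-overlapping square patches a ViT cuts an image into."""
    if image_size % patch_size != 0:
        raise ValueError("image_size must be a multiple of patch_size")
    patches_per_side = image_size // patch_size  # 224 // 16 = 14
    return patches_per_side * patches_per_side   # 14 * 14 = 196


print(num_patch_tokens())  # 196
```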

In [1]:
import torch
from torch.utils.data import IterableDataset, DataLoader
from datasets import load_dataset
from tokenizers import Tokenizer
from torchvision import transforms
from pathlib import Path

## Step 3: Load Dataset

We load COCO Captions through the HuggingFace `datasets` library in streaming mode.

**Why streaming?**
- COCO has ~118K images (~13GB)
- Streaming loads data on-the-fly, no need to download everything upfront
- Memory efficient for large datasets
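The on-the-fly behaviour can be illustrated without downloading anything. The sketch below uses a plain Python generator as a stand-in for the streaming dataset. `fake_stream` is hypothetical and only mimics lazy iteration, not the HuggingFace API. Only the samples actually requested are ever produced:

```python
from itertools import islice


def fake_stream(produced):
    """Stand-in for a streaming dataset: yields samples one at a time."""
    i = 0
    while True:
        produced.append(i)  # record that this sample was actually materialized
        yield {"cocoid": i, "caption": f"caption {i}"}
        i += 1


produced = []
first_three = list(islice(fake_stream(produced), 3))

print(len(first_three))  # 3
print(len(produced))     # 3: nothing beyond the requested samples was produced
```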

In [2]:
# Load COCO captions with streaming (no full download needed)
ds = load_dataset("jxie/coco_captions", streaming=True)
print(ds)

IterableDatasetDict({
    train: IterableDataset({
        features: ['image', 'filename', 'cocoid', 'caption'],
        num_shards: 182
    })
    validation: IterableDataset({
        features: ['image', 'filename', 'cocoid', 'caption'],
        num_shards: 10
    })
    test: IterableDataset({
        features: ['image', 'filename', 'cocoid', 'caption'],
        num_shards: 9
    })
})


## Step 4: Load Tokenizer

Load our BPE tokenizer with the `<image>` token we added in Phase 1.

In [4]:
# Load tokenizer with <image> token
tokenizer = Tokenizer.from_file("../bpe_tokenizer_with_image_tag.json")

# Important IDs we'll need
IMAGE_TOKEN_ID = tokenizer.token_to_id("<image>")  # 32000
PAD_TOKEN_ID = tokenizer.token_to_id("<pad>")      # 3 for this tokenizer
EOS_TOKEN_ID = tokenizer.token_to_id("</s>")      # End of sequence

print(f"<image> token ID: {IMAGE_TOKEN_ID}")
print(f"<pad> token ID: {PAD_TOKEN_ID}")
print(f"</s> token ID: {EOS_TOKEN_ID}")
print(f"Vocab size: {tokenizer.get_vocab_size()}")

<image> token ID: 32000
<pad> token ID: 3
</s> token ID: 2
Vocab size: 32001


## Step 5: Image Transform Pipeline

ViT-B/16 was trained on ImageNet with specific preprocessing:

1. **Resize(256)** - Resize shortest edge to 256 pixels (maintains aspect ratio)
2. **CenterCrop(224)** - Crop center 224x224 (ViT input size)
3. **ToTensor()** - Convert PIL Image to tensor, scales pixels from [0,255] to [0,1]
4. **Normalize()** - Normalize with ImageNet statistics:
   - Mean: [0.485, 0.456, 0.406] (RGB channel means from ImageNet)
   - Std: [0.229, 0.224, 0.225] (RGB channel stds from ImageNet)
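The geometry of the first two steps can be traced by hand. The sketch below is a plain-Python approximation, not torchvision's implementation: the helper names are hypothetical, and the rounding (truncation for the long edge, floor division for the crop offset) is an assumption that may differ from the library by a pixel.

```python
def resize_shortest_edge(width, height, target=256):
    """Scale (width, height) so the shorter edge equals `target`, keeping aspect ratio."""
    if width <= height:
        return target, int(target * height / width)
    return int(target * width / height), target


def center_crop_box(width, height, crop=224):
    """Return (left, top, right, bottom) of a centered crop x crop window."""
    left = (width - crop) // 2
    top = (height - crop) // 2
    return left, top, left + crop, top + crop


resized = resize_shortest_edge(640, 480)   # a 640x480 landscape image
print(resized)                             # (341, 256)
print(center_crop_box(*resized))           # (58, 16, 282, 240)
```

For a landscape image, the crop discards the left and right margins, so content near the horizontal edges never reaches the encoder.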

**Why these specific values?**
- The pretrained ViT learned its features under this exact preprocessing
- A different normalization shifts the inputs away from the distribution the model was trained on, which degrades the features it extracts

In [5]:
# Image preprocessing pipeline - must match what ViT was trained with
image_transform = transforms.Compose([
    transforms.Resize(256),                    # Resize shortest edge to 256
    transforms.CenterCrop(224),                # Crop center 224x224
    transforms.ToTensor(),                     # [0,255] -> [0,1], HWC -> CHW
    transforms.Normalize(                      # ImageNet normalization
        mean=[0.485, 0.456, 0.406],
        std=[0.229, 0.224, 0.225]
    )
])

# Test on a sample image
for sample in ds['train'].take(1):
    img = sample['image']
    print(f"Original image size: {img.size}")  # PIL Image size (W, H)
    
    img_tensor = image_transform(img)
    print(f"Transformed tensor shape: {img_tensor.shape}")  # [3, 224, 224]
    print(f"Tensor value range: [{img_tensor.min():.2f}, {img_tensor.max():.2f}]")

Original image size: (640, 480)
Transformed tensor shape: torch.Size([3, 224, 224])
Tensor value range: [-2.12, 1.79]
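The printed range can be checked by hand. `Normalize` maps each channel value `x` in [0, 1] to `(x - mean) / std`, so the bounds per channel are fixed by the ImageNet statistics:

```python
MEAN = [0.485, 0.456, 0.406]
STD = [0.229, 0.224, 0.225]

# Normalize maps each channel value x in [0, 1] to (x - mean) / std.
lows = [(0.0 - m) / s for m, s in zip(MEAN, STD)]   # a pure black pixel
highs = [(1.0 - m) / s for m, s in zip(MEAN, STD)]  # a pure white pixel

print([round(v, 2) for v in lows])   # [-2.12, -2.04, -1.8]
print([round(v, 2) for v in highs])  # [2.25, 2.43, 2.64]
```

The observed minimum of -2.12 matches a fully black value in the red channel. The observed maximum of 1.79 is below every channel's upper bound, which suggests this sample has no fully saturated pixel after resizing.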


## Step 6: Dataset Class

Our dataset wraps HuggingFace's streaming dataset and:
1. Applies image transforms
2. Tokenizes captions with `<image>` prefix
3. Handles padding and truncation
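The padding and truncation step can be sketched in isolation with plain lists. `pad_or_truncate` is a hypothetical standalone helper. The IDs 2 and 3 are the `</s>` and `<pad>` IDs printed in Step 4, and the caption token IDs are made up for illustration.

```python
def pad_or_truncate(token_ids, max_seq_len, eos_id, pad_id):
    """Return (ids, mask), both exactly max_seq_len long, with EOS always kept."""
    ids = list(token_ids[: max_seq_len - 1]) + [eos_id]  # leave one slot for EOS
    mask = [1] * len(ids)                                # 1 marks real tokens
    n_pad = max_seq_len - len(ids)
    return ids + [pad_id] * n_pad, mask + [0] * n_pad    # 0 marks padding


# Short caption: padded on the right
print(pad_or_truncate([32000, 10, 11], 6, eos_id=2, pad_id=3))
# ([32000, 10, 11, 2, 3, 3], [1, 1, 1, 1, 0, 0])

# Long caption: cut to 5 tokens so EOS still fits
print(pad_or_truncate([32000, 10, 11, 12, 13, 14, 15], 6, eos_id=2, pad_id=3))
# ([32000, 10, 11, 12, 13, 2], [1, 1, 1, 1, 1, 1])
```

Truncating before appending EOS means every sequence ends with `</s>`, even when the caption is cut short.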

**Caption format:**
```
Raw caption:    "a brown dog running"
With prefix:    "<image> a brown dog running"
Token IDs:      [32000, 214, 7532, 4485, 4274, <eos>]
                  ^--- This position will be REPLACED by 196 vision tokens
```
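To see how the replacement changes the sequence length, the sketch below tracks positions only. It is an illustration under assumptions: the real replacement operates on embeddings in the model's forward pass, and `expand_image_placeholder` is a hypothetical helper, not the Phase 3 implementation.

```python
def expand_image_placeholder(token_ids, image_token_id, num_vision_tokens=196):
    """Replace each <image> ID by num_vision_tokens ("vision", i) slots."""
    slots = []
    for tok in token_ids:
        if tok == image_token_id:
            slots.extend(("vision", i) for i in range(num_vision_tokens))
        else:
            slots.append(("text", tok))
    return slots


ids = [32000, 214, 7532, 4485, 4274, 2]  # the example above, with </s> = 2
slots = expand_image_placeholder(ids, image_token_id=32000)

print(len(ids), "->", len(slots))  # 6 -> 201
print(slots[195], slots[196])      # ('vision', 195) ('text', 214)
```

A sequence of N token IDs therefore occupies N - 1 + 196 positions inside the language model, which matters when budgeting context length.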

In [6]:
class VisionLanguageDataset(IterableDataset):
    """
    Dataset for vision-language alignment training.
    
    Each sample contains:
    - image: preprocessed image tensor [3, 224, 224]
    - input_ids: [<image>, caption_tokens..., <eos>, <pad>...], length max_seq_len
    - attention_mask: 1 for real tokens, 0 for padding
    - labels: copy of input_ids (shifted and masked later, at loss time)
    """
    
    def __init__(
        self,
        hf_dataset,           # HuggingFace streaming dataset
        tokenizer,            # Our BPE tokenizer
        image_transform,      # torchvision transforms
        max_seq_len=64,       # Max caption length (including <image> and <eos>)
    ):
        self.dataset = hf_dataset
        self.tokenizer = tokenizer
        self.transform = image_transform
        self.max_seq_len = max_seq_len
        
        # Get special token IDs
        self.image_token_id = tokenizer.token_to_id("<image>")
        self.pad_token_id = tokenizer.token_to_id("<pad>") or 0
        self.eos_token_id = tokenizer.token_to_id("</s>")
    
    def __iter__(self):
        """Iterate through the streaming dataset."""
        for sample in self.dataset:
            try:
                processed = self._process_sample(sample)
                if processed is not None:
                    yield processed
            except Exception:
                # Skip samples that fail to process (e.g. corrupted images)
                continue
    
    def _process_sample(self, sample):
        """Process a single (image, caption) pair."""
        
        # ----- Image Processing -----
        image = sample['image']
        
        # Handle grayscale images (convert to RGB)
        if image.mode != 'RGB':
            image = image.convert('RGB')
        
        # Apply transforms: resize, crop, normalize
        image_tensor = self.transform(image)  # [3, 224, 224]
        
        # ----- Caption Processing -----
        caption = sample['caption']
        
        # Prepend <image> token to caption
        # This placeholder will be replaced by 196 vision tokens during forward pass
        caption_with_image = f"<image> {caption}"
        
        # Tokenize
        encoding = self.tokenizer.encode(caption_with_image)
        token_ids = encoding.ids
        
        # Truncate (leave room for EOS)
        if len(token_ids) > self.max_seq_len - 1:
            token_ids = token_ids[:self.max_seq_len - 1]

        # Add EOS
        token_ids = token_ids + [self.eos_token_id]

        
        # Create attention mask (1 for real tokens, 0 for padding)
        attention_mask = [1] * len(token_ids)
        
        # Pad if too short
        padding_length = self.max_seq_len - len(token_ids)
        if padding_length > 0:
            token_ids = token_ids + [self.pad_token_id] * padding_length
            attention_mask = attention_mask + [0] * padding_length
        
        # Convert to tensors
        input_ids = torch.tensor(token_ids, dtype=torch.long)
        attention_mask = torch.tensor(attention_mask, dtype=torch.long)
        
        # Labels start as a copy of input_ids (standard for language modeling).
        # The one-position shift and the masking of padding happen at loss time.
        labels = input_ids.clone()
        
        return {
            'image': image_tensor,           # [3, 224, 224]
            'input_ids': input_ids,          # [max_seq_len]
            'attention_mask': attention_mask, # [max_seq_len]
            'labels': labels                  # [max_seq_len]
        }
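The class leaves label masking to the loss computation. One common convention, shown here as an assumption rather than as this project's actual Phase 3 code, is to set positions that should not contribute to the loss to -100, the default `ignore_index` of PyTorch's cross-entropy loss. `mask_labels` is a hypothetical helper working on plain lists:

```python
IGNORE_INDEX = -100  # default ignore_index of torch.nn.functional.cross_entropy


def mask_labels(input_ids, attention_mask, image_token_id):
    """Copy input_ids, hiding padding and the <image> placeholder from the loss."""
    return [
        IGNORE_INDEX if (m == 0 or tok == image_token_id) else tok
        for tok, m in zip(input_ids, attention_mask)
    ]


labels = mask_labels(
    input_ids=[32000, 214, 7532, 2, 3, 3],
    attention_mask=[1, 1, 1, 1, 0, 0],
    image_token_id=32000,
)
print(labels)  # [-100, 214, 7532, 2, -100, -100]
```

Masking by `attention_mask` rather than by comparing against the pad ID avoids hiding a real token that happens to share that ID.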

## Step 7: Create DataLoader

The DataLoader handles batching. Because every sample is already padded to `max_seq_len`, all tensors under a given key have the same shape, so
we can use the default collate function.
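For dict samples, the default collate function batches key by key, stacking same-shaped tensors along a new leading dimension. The pure-Python stand-in below mimics that grouping with nested lists. `collate_dicts` is a hypothetical illustration, not the PyTorch implementation, and the token IDs are made up.

```python
def collate_dicts(samples):
    """Group a list of per-sample dicts into one dict of batched lists."""
    keys = samples[0].keys()
    return {key: [sample[key] for sample in samples] for key in keys}


samples = [
    {"input_ids": [32000, 10, 2, 3], "attention_mask": [1, 1, 1, 0]},
    {"input_ids": [32000, 11, 12, 2], "attention_mask": [1, 1, 1, 1]},
]
toy_batch = collate_dicts(samples)

print(toy_batch["input_ids"])  # [[32000, 10, 2, 3], [32000, 11, 12, 2]]
print(len(toy_batch["input_ids"]), len(toy_batch["input_ids"][0]))  # 2 4
```

If the rows had different lengths, stacking would fail, which is why the dataset pads every sample up front.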

In [7]:
# Create dataset
train_dataset = VisionLanguageDataset(
    hf_dataset=ds['train'],
    tokenizer=tokenizer,
    image_transform=image_transform,
    max_seq_len=128  # Plenty for COCO captions
)

# Create dataloader
train_loader = DataLoader(
    train_dataset,
    batch_size=4,
    num_workers=0  # Load in the main process (simplest setup for streaming)
)

print("Dataset and DataLoader created!")

Dataset and DataLoader created!


## Step 8: Test the Pipeline

Let's verify everything works correctly by inspecting a batch.

In [8]:
# Get one batch to verify
batch = next(iter(train_loader))

print("Batch contents:")
print(f"  image shape:          {batch['image'].shape}")          # [B, 3, 224, 224]
print(f"  input_ids shape:      {batch['input_ids'].shape}")      # [B, max_seq_len]
print(f"  attention_mask shape: {batch['attention_mask'].shape}") # [B, max_seq_len]
print(f"  labels shape:         {batch['labels'].shape}")         # [B, max_seq_len]

print("\nFirst sample in batch:")
print(f"  input_ids: {batch['input_ids'][0][:20].tolist()}...")   # First 20 tokens
print(f"  First token is <image>? {batch['input_ids'][0][0].item() == IMAGE_TOKEN_ID}")

Batch contents:
  image shape:          torch.Size([4, 3, 224, 224])
  input_ids shape:      torch.Size([4, 128])
  attention_mask shape: torch.Size([4, 128])
  labels shape:         torch.Size([4, 128])

First sample in batch:
  input_ids: [32000, 297, 5403, 7289, 214, 1827, 310, 663, 2519, 3958, 214, 9556, 17, 177, 2, 3, 3, 3, 3, 3]...
  First token is <image>? True


In [9]:
# Decode a sample to verify tokenization
sample_ids = batch['input_ids'][0].tolist()

# Remove padding for cleaner output
sample_ids_no_pad = [t for t in sample_ids if t != PAD_TOKEN_ID]

decoded = tokenizer.decode(sample_ids_no_pad)
print(f"Decoded caption: {decoded}")

Decoded caption:  A woman wearing a net on her head cutting a cake. 


## Summary

**What we built:**
- Image transform pipeline matching ViT preprocessing
- Dataset class that yields a dict with `image`, `input_ids`, `attention_mask`, and `labels`
- Each caption starts with `<image>` token (ID: 32000)

**Next steps (Phase 3):**
- Modify LLM embedding layer to handle the new vocab size (32001)
- Create function to replace `<image>` embedding with 196 vision tokens
- Build the full VLM forward pass