# Module 24: Multimodal Learning

**Bridging Vision and Language**

---

## 1. Objectives

- ✅ Understand multimodal architectures
- ✅ Implement Vision Transformer (ViT) from scratch
- ✅ Master CLIP for vision-language understanding
- ✅ Build image captioning pipeline
- ✅ Use HuggingFace multimodal models

## 2. Prerequisites

- [Module 14: Transformer Architecture](../14_transformer_architecture/14_transformer_architecture.ipynb)
- [Module 17: HuggingFace Ecosystem](../17_huggingface/17_huggingface.ipynb)

## 3. What is Multimodal Learning?

### Definition

Multimodal learning combines **multiple data types** (modalities) to create richer representations:

```
┌─────────────────────────────────────────────────────────┐
│                    Multimodal AI                         │
├─────────────────────────────────────────────────────────┤
│                                                          │
│   [Image]  ──┐                                          │
│              ├──→  [Joint Embedding]  ──→  [Output]     │
│   [Text]   ──┘         Space                            │
│                                                          │
└─────────────────────────────────────────────────────────┘
```

### Key Applications

| Task | Input | Output |
|------|-------|--------|
| Image Captioning | Image | Text description |
| Visual QA | Image + Question | Answer |
| Text-to-Image | Text prompt | Generated image |
| Image-Text Matching | Image + Text | Similarity score |
| Document Understanding | Document image | Structured data |

In [None]:
# Install required packages
# !pip install torch torchvision transformers pillow requests timm

import torch
import torch.nn as nn
import torch.nn.functional as F
import math
from PIL import Image
import requests
from io import BytesIO

device = 'cuda' if torch.cuda.is_available() else 'cpu'
print(f"Device: {device}")

## 4. Vision Transformer (ViT) - Theory

### The Key Insight

**"An image is worth 16x16 words"** - Split image into patches and treat them like tokens!

### Architecture Overview

```
┌─────────────────────────────────────────────────────────┐
│                  Vision Transformer                      │
├─────────────────────────────────────────────────────────┤
│                                                          │
│  Input Image (224×224×3)                                 │
│       ↓                                                  │
│  Split into Patches (14×14 = 196 patches of 16×16)      │
│       ↓                                                  │
│  Linear Projection (flatten 16×16×3 → 768)              │
│       ↓                                                  │
│  Add [CLS] Token + Positional Embeddings                │
│       ↓                                                  │
│  Transformer Encoder (L layers)                          │
│       ↓                                                  │
│  [CLS] Token → Classification Head                       │
│                                                          │
└─────────────────────────────────────────────────────────┘
```

### Patch Embedding Math

For an image of size $H \times W \times C$:
- Patch size: $P \times P$
- Number of patches: $N = \frac{H \times W}{P^2}$
- Each patch: $P^2 \cdot C$ values → projected to $D$ dimensions

Example: 224×224 image with 16×16 patches = 196 patches
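The arithmetic above is easy to check in a few lines. This is a minimal sketch; the helper name `patch_grid` is made up for illustration, and it assumes the image height and width are exact multiples of the patch size, as the formula does.

```python
def patch_grid(height, width, patch_size, channels=3):
    """Return (number of patches N, values per flattened patch P^2 * C)."""
    if height % patch_size or width % patch_size:
        raise ValueError("image height and width must be divisible by patch_size")
    n_patches = (height // patch_size) * (width // patch_size)
    patch_dim = patch_size * patch_size * channels
    return n_patches, patch_dim


n_patches, patch_dim = patch_grid(224, 224, 16)
print(n_patches, patch_dim)  # 196 768

n_patches_32, _ = patch_grid(224, 224, 32)
print(n_patches_32)  # 49
```

Note that a flattened 16×16 RGB patch holds 768 values, which happens to equal the ViT-Base embedding size $D = 768$. The projection is still a learned linear layer; the two numbers are independent in general.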

In [None]:
class PatchEmbedding(nn.Module):
    """Split image into patches and embed them.
    
    Args:
        img_size: Input image size (assumes square)
        patch_size: Size of each patch
        in_channels: Number of input channels (3 for RGB)
        embed_dim: Embedding dimension
    """
    
    def __init__(self, img_size=224, patch_size=16, in_channels=3, embed_dim=768):
        super().__init__()
        self.img_size = img_size
        self.patch_size = patch_size
        self.n_patches = (img_size // patch_size) ** 2
        
        # Linear projection of flattened patches
        # Conv2d with kernel=patch_size, stride=patch_size does the same
        self.proj = nn.Conv2d(
            in_channels, 
            embed_dim, 
            kernel_size=patch_size, 
            stride=patch_size
        )
    
    def forward(self, x):
        """x: [batch, channels, height, width]"""
        x = self.proj(x)  # [batch, embed_dim, n_patches_h, n_patches_w]
        x = x.flatten(2)  # [batch, embed_dim, n_patches]
        x = x.transpose(1, 2)  # [batch, n_patches, embed_dim]
        return x

# Test
patch_embed = PatchEmbedding()
img = torch.randn(1, 3, 224, 224)
patches = patch_embed(img)
print(f"Input: {img.shape}")
print(f"Patches: {patches.shape}  # [batch, 196 patches, 768 dim]")
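The comment in `PatchEmbedding` says a strided `Conv2d` "does the same" as flattening each patch and applying a linear layer. The sketch below checks that claim numerically under the default settings used above. It rebuilds the convolution on its own rather than reusing the class, and the variable names are illustrative.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

torch.manual_seed(0)
patch_size, in_channels, embed_dim = 16, 3, 768
conv = nn.Conv2d(in_channels, embed_dim, kernel_size=patch_size, stride=patch_size)
img = torch.randn(2, in_channels, 224, 224)

with torch.no_grad():
    # Path 1: strided convolution, as in PatchEmbedding.forward
    out_conv = conv(img).flatten(2).transpose(1, 2)  # [2, 196, 768]

    # Path 2: cut out non-overlapping patches, flatten them, apply the same weights
    patches = F.unfold(img, kernel_size=patch_size, stride=patch_size)  # [2, C*P*P, 196]
    patches = patches.transpose(1, 2)                                   # [2, 196, C*P*P]
    weight = conv.weight.reshape(embed_dim, -1)                         # [768, C*P*P]
    out_linear = patches @ weight.T + conv.bias                         # [2, 196, 768]

print(out_conv.shape, out_linear.shape)
print(torch.allclose(out_conv, out_linear, atol=1e-4))
```

The two paths agree up to floating-point rounding, so the convolution is simply a convenient way to express the per-patch linear projection.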

In [None]:
class VisionTransformer(nn.Module):
    """Complete Vision Transformer implementation.
    
    This follows the original ViT paper architecture.
    """
    
    def __init__(
        self,
        img_size=224,
        patch_size=16,
        in_channels=3,
        n_classes=1000,
        embed_dim=768,
        depth=12,
        n_heads=12,
        mlp_ratio=4.0,
        dropout=0.1
    ):
        super().__init__()
        
        # Patch embedding
        self.patch_embed = PatchEmbedding(
            img_size, patch_size, in_channels, embed_dim
        )
        n_patches = self.patch_embed.n_patches
        
        # Learnable [CLS] token
        self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim))
        
        # Positional embeddings (learnable)
        self.pos_embed = nn.Parameter(torch.zeros(1, n_patches + 1, embed_dim))
        self.pos_drop = nn.Dropout(dropout)
        
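        # Transformer encoder blocks (pre-norm, as in the ViT paper)
        encoder_layer = nn.TransformerEncoderLayer(
            d_model=embed_dim,
            nhead=n_heads,
            dim_feedforward=int(embed_dim * mlp_ratio),
            dropout=dropout,
            activation='gelu',
            batch_first=True,
            norm_first=True  # LayerNorm before attention/MLP; the PyTorch default is post-norm
        )
        self.transformer = nn.TransformerEncoder(
            encoder_layer, num_layers=depth, enable_nested_tensor=False
        )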
        
        # Classification head
        self.norm = nn.LayerNorm(embed_dim)
        self.head = nn.Linear(embed_dim, n_classes)
        
        # Initialize weights
        nn.init.normal_(self.pos_embed, std=0.02)
        nn.init.normal_(self.cls_token, std=0.02)
    
    def forward(self, x):
        """x: [batch, channels, height, width]"""
        batch_size = x.shape[0]
        
        # Create patch embeddings
        x = self.patch_embed(x)  # [batch, n_patches, embed_dim]
        
        # Prepend [CLS] token
        cls_tokens = self.cls_token.expand(batch_size, -1, -1)
        x = torch.cat([cls_tokens, x], dim=1)  # [batch, n_patches+1, embed_dim]
        
        # Add positional embeddings
        x = x + self.pos_embed
        x = self.pos_drop(x)
        
        # Transformer encoder
        x = self.transformer(x)
        
        # Classification: use [CLS] token output
        x = self.norm(x[:, 0])  # Take first token
        x = self.head(x)
        
        return x

# Create ViT-Base
vit = VisionTransformer(n_classes=10)  # 10 classes for demo
print(f"ViT parameters: {sum(p.numel() for p in vit.parameters()):,}")

# Test forward pass
img = torch.randn(2, 3, 224, 224)
output = vit(img)
print(f"Output: {output.shape}  # [batch, n_classes]")
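The parameter count printed above can be reproduced by hand, which is a useful way to see where a ViT's weights live. The sketch below uses only arithmetic and mirrors the layers in `VisionTransformer`; the function name `vit_param_count` is an assumption for illustration.

```python
def vit_param_count(img_size=224, patch_size=16, in_channels=3, n_classes=1000,
                    embed_dim=768, depth=12, mlp_ratio=4.0):
    """Count learnable parameters for the VisionTransformer layout used above."""
    n_patches = (img_size // patch_size) ** 2
    hidden = int(embed_dim * mlp_ratio)

    patch_embed = embed_dim * in_channels * patch_size ** 2 + embed_dim  # conv weight + bias
    tokens = embed_dim + (n_patches + 1) * embed_dim                     # [CLS] + positions
    attention = 4 * (embed_dim * embed_dim + embed_dim)                  # Q, K, V, output proj
    mlp = (embed_dim * hidden + hidden) + (hidden * embed_dim + embed_dim)
    layer_norms = 2 * (2 * embed_dim)                                    # two LayerNorms per block
    block = attention + mlp + layer_norms
    head = 2 * embed_dim + (embed_dim * n_classes + n_classes)           # final norm + classifier

    return patch_embed + tokens + depth * block + head


print(f"{vit_param_count(n_classes=10):,}")    # 85,806,346
print(f"{vit_param_count(n_classes=1000):,}")  # 86,567,656
```

Almost all of the weights sit in the 12 encoder blocks (about 7.1M each). The patch projection and positional embeddings together contribute less than 1M.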

## 5. CLIP - Contrastive Language-Image Pre-training

### Theory

CLIP learns a **joint embedding space** for images and text through contrastive learning:

```
┌────────────────────────────────────────────────────────────┐
│                      CLIP Architecture                      │
├────────────────────────────────────────────────────────────┤
│                                                             │
│   Image ──→ [Image Encoder (ViT/ResNet)] ──→ Image Embed   │
│                                                    ↓        │
│                                              Cosine Sim     │
│                                                    ↑        │
│   Text  ──→ [Text Encoder (Transformer)]  ──→ Text Embed   │
│                                                             │
└────────────────────────────────────────────────────────────┘

Training: Match N images with their N correct captions
          Minimize distance for matches, maximize for non-matches
```

### Contrastive Loss

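For a batch of N (image, text) pairs, the image-to-text loss is:

$$\mathcal{L}_{i2t} = -\frac{1}{N}\sum_{i=1}^{N} \log \frac{\exp(sim(I_i, T_i)/\tau)}{\sum_{j=1}^{N}\exp(sim(I_i, T_j)/\tau)}$$

The text-to-image loss $\mathcal{L}_{t2i}$ has the same form, but its denominator sums over images: $\sum_{j=1}^{N}\exp(sim(I_j, T_i)/\tau)$. CLIP trains on the average of the two, $\mathcal{L} = \frac{1}{2}(\mathcal{L}_{i2t} + \mathcal{L}_{t2i})$.

Where:
- $sim(I, T)$ = cosine similarity between image and text embeddings
- $\tau$ = temperature parameter (learned in CLIP, initialized to 0.07)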
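To make the formula concrete, here is a minimal pure-Python sketch that evaluates it on tiny hand-written similarity matrices. The function names and the toy matrices are assumptions for illustration; real training code would use tensors and `F.cross_entropy`.

```python
import math


def info_nce(sim, tau=0.07):
    """Image-to-text loss from the formula above; sim is an N x N cosine-similarity matrix."""
    n = len(sim)
    total = 0.0
    for i in range(n):
        logits = [s / tau for s in sim[i]]
        peak = max(logits)  # subtract the max before exp() for numerical stability
        log_denominator = peak + math.log(sum(math.exp(l - peak) for l in logits))
        total += log_denominator - logits[i]  # -log softmax of the matching pair
    return total / n


def symmetric_info_nce(sim, tau=0.07):
    """Average the image-to-text and text-to-image directions."""
    sim_t = [list(column) for column in zip(*sim)]  # transpose: rows become texts
    return 0.5 * (info_nce(sim, tau) + info_nce(sim_t, tau))


aligned = [[1.0, 0.0, 0.0], [0.0, 1.0, 0.0], [0.0, 0.0, 1.0]]
uninformative = [[0.5, 0.5, 0.5] for _ in range(3)]

print(symmetric_info_nce(aligned))        # close to 0: every pair is matched correctly
print(symmetric_info_nce(uninformative))  # equals log(3): no better than guessing
print(math.log(3))
```

Dividing by $\tau = 0.07$ is the same as multiplying the similarities by about 14.3, which sharpens the softmax so that small differences in cosine similarity produce confident predictions.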

In [None]:
class SimpleCLIP(nn.Module):
    """Simplified CLIP implementation for understanding.
    
    In practice, you'd use the pretrained CLIP from OpenAI.
    """
    
    def __init__(self, embed_dim=512, vocab_size=10000, max_seq_len=77):
        super().__init__()
        
        # Image encoder (simplified ViT)
        self.image_encoder = nn.Sequential(
            PatchEmbedding(embed_dim=embed_dim),
            nn.TransformerEncoder(
                nn.TransformerEncoderLayer(
                    d_model=embed_dim, nhead=8, batch_first=True
                ),
                num_layers=6
            )
        )
        self.image_proj = nn.Linear(embed_dim, embed_dim)
        
        # Text encoder
        self.token_embed = nn.Embedding(vocab_size, embed_dim)
        self.pos_embed = nn.Parameter(torch.zeros(1, max_seq_len, embed_dim))
        self.text_encoder = nn.TransformerEncoder(
            nn.TransformerEncoderLayer(
                d_model=embed_dim, nhead=8, batch_first=True
            ),
            num_layers=6
        )
        self.text_proj = nn.Linear(embed_dim, embed_dim)
        
        # Learnable temperature
        self.logit_scale = nn.Parameter(torch.ones([]) * math.log(1 / 0.07))
    
    def encode_image(self, image):
        """Encode images to embeddings."""
        x = self.image_encoder(image)
        x = x.mean(dim=1)  # Global average pooling
        x = self.image_proj(x)
        return F.normalize(x, dim=-1)  # L2 normalize
    
    def encode_text(self, text_ids):
        """Encode text to embeddings."""
        x = self.token_embed(text_ids) + self.pos_embed[:, :text_ids.shape[1]]
        x = self.text_encoder(x)
        x = x[:, 0]  # Pool at the first token (the original CLIP pools at the end-of-text token)
        x = self.text_proj(x)
        return F.normalize(x, dim=-1)  # L2 normalize
    
    def forward(self, image, text_ids):
        """Compute similarity matrix."""
        image_embeds = self.encode_image(image)
        text_embeds = self.encode_text(text_ids)
        
        # Cosine similarity with temperature scaling
        logit_scale = self.logit_scale.exp()
        logits = logit_scale * image_embeds @ text_embeds.T
        
        return logits  # [batch_images, batch_texts]

print("SimpleCLIP architecture ready!")

In [None]:
def clip_loss(logits):
    """Symmetric contrastive loss for CLIP.
    
    Args:
        logits: Similarity matrix [batch, batch]
    
    Returns:
        Average of image-to-text and text-to-image losses
    """
    batch_size = logits.shape[0]
    labels = torch.arange(batch_size, device=logits.device)
    
    # Image-to-text loss (which text matches each image?)
    loss_i2t = F.cross_entropy(logits, labels)
    
    # Text-to-image loss (which image matches each text?)
    loss_t2i = F.cross_entropy(logits.T, labels)
    
    return (loss_i2t + loss_t2i) / 2

# Example
logits = torch.randn(4, 4)  # 4 images, 4 texts
loss = clip_loss(logits)
print(f"CLIP Loss: {loss:.4f}")

## 6. Using Pretrained CLIP (HuggingFace)

In [None]:
from transformers import CLIPProcessor, CLIPModel

# Load pretrained CLIP
model = CLIPModel.from_pretrained("openai/clip-vit-base-patch32")
processor = CLIPProcessor.from_pretrained("openai/clip-vit-base-patch32")

print("CLIP loaded! Image encoder: ViT-B/32")

In [None]:
# Download a sample image
url = "https://upload.wikimedia.org/wikipedia/commons/thumb/4/43/Cute_dog.jpg/320px-Cute_dog.jpg"
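response = requests.get(url, timeout=10)
response.raise_for_status()  # Fail early if the download did not succeed
image = Image.open(BytesIO(response.content)).convert("RGB")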

# Candidate texts
texts = [
    "a photo of a dog",
    "a photo of a cat",
    "a photo of a bird",
    "a photo of a car"
]

# Process inputs
inputs = processor(
    text=texts,
    images=image,
    return_tensors="pt",
    padding=True
)

# Get similarity scores
with torch.no_grad():
    outputs = model(**inputs)
    logits_per_image = outputs.logits_per_image
    probs = logits_per_image.softmax(dim=1)

print("Zero-shot classification results:")
for text, prob in zip(texts, probs[0]):
    print(f"  '{text}': {prob:.2%}")

## 7. Image Captioning

### Architecture

```
Image ──→ [ViT Encoder] ──→ Image Features
                                   ↓
              [Cross-Attention in Decoder]
                                   ↓
                         [GPT-like Decoder] ──→ Caption
```

In [None]:
from transformers import VisionEncoderDecoderModel, ViTImageProcessor, AutoTokenizer

# Load image captioning model
caption_model = VisionEncoderDecoderModel.from_pretrained(
    "nlpconnect/vit-gpt2-image-captioning"
)
feature_extractor = ViTImageProcessor.from_pretrained(
    "nlpconnect/vit-gpt2-image-captioning"
)
tokenizer = AutoTokenizer.from_pretrained(
    "nlpconnect/vit-gpt2-image-captioning"
)

print("Image captioning model loaded!")

In [None]:
def generate_caption(image, max_length=16):
    """Generate caption for an image."""
    pixel_values = feature_extractor(
        images=image, return_tensors="pt"
    ).pixel_values
    
    output_ids = caption_model.generate(
        pixel_values,
        max_length=max_length,
        num_beams=4
    )
    
    caption = tokenizer.decode(output_ids[0], skip_special_tokens=True)
    return caption

# Test with our image
caption = generate_caption(image)
print(f"Generated caption: '{caption}'")
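The `num_beams=4` argument above asks `generate` to use beam search instead of picking the single most likely token at each step. The toy sketch below shows the idea only. The bigram table `NEXT` is a made-up stand-in for a decoder, and the real HuggingFace implementation adds batching, length penalties, and other details that are not modeled here.

```python
import math

# Hypothetical "decoder": next-token probabilities that depend only on the last token.
NEXT = {
    "<bos>": {"a": 0.6, "the": 0.4},
    "a": {"dog": 0.55, "cat": 0.45},
    "the": {"dog": 0.9, "cat": 0.1},
    "dog": {"<eos>": 1.0},
    "cat": {"<eos>": 1.0},
}


def beam_search(num_beams, max_length=4):
    """Keep the num_beams highest-scoring partial captions at every step."""
    beams = [(0.0, ["<bos>"])]  # (cumulative log-probability, tokens so far)
    for _ in range(max_length):
        candidates = []
        for score, tokens in beams:
            if tokens[-1] == "<eos>":
                candidates.append((score, tokens))  # finished captions carry over
                continue
            for token, prob in NEXT[tokens[-1]].items():
                candidates.append((score + math.log(prob), tokens + [token]))
        beams = sorted(candidates, key=lambda c: c[0], reverse=True)[:num_beams]
    best_score, best_tokens = beams[0]
    words = [t for t in best_tokens if t not in ("<bos>", "<eos>")]
    return " ".join(words), math.exp(best_score)


print(beam_search(num_beams=1))  # greedy: "a dog", probability about 0.33
print(beam_search(num_beams=2))  # beam:   "the dog", probability about 0.36
```

Greedy decoding commits to "a" because it is the likeliest first word, and so never discovers that "the dog" is the more probable caption overall. Wider beams trade extra computation for a better chance of finding such sequences.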

## 8. Visual Question Answering (VQA)

In [None]:
from transformers import ViltProcessor, ViltForQuestionAnswering

# Load VQA model (ViLT - Vision-and-Language Transformer)
vqa_processor = ViltProcessor.from_pretrained("dandelin/vilt-b32-finetuned-vqa")
vqa_model = ViltForQuestionAnswering.from_pretrained("dandelin/vilt-b32-finetuned-vqa")

def answer_question(image, question):
    """Answer a question about an image."""
    inputs = vqa_processor(image, question, return_tensors="pt")
    
    with torch.no_grad():
        outputs = vqa_model(**inputs)
    
    # Get top answer
    idx = outputs.logits.argmax(-1).item()
    return vqa_model.config.id2label[idx]

# Test
questions = [
    "What animal is in the image?",
    "What color is the animal?",
    "Is the animal sitting or standing?"
]

print("Visual Question Answering:")
for q in questions:
    answer = answer_question(image, q)
    print(f"  Q: {q}")
    print(f"  A: {answer}\n")

## 9. BLIP-2 - State of the Art

### Architecture

BLIP-2 bridges frozen image encoders and LLMs with a lightweight Q-Former:

```
Image ──→ [Frozen ViT] ──→ [Q-Former] ──→ [Frozen LLM] ──→ Response
                              ↑
                        Learnable queries
```
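The core mechanism in the diagram, a fixed set of learnable queries that read from frozen image features, can be sketched with a single cross-attention layer. This is a heavily simplified illustration, not the actual Q-Former (which is a multi-layer transformer that also has self-attention and text inputs). The class name `TinyQueryBridge` and all dimensions are assumptions chosen for the example.

```python
import torch
import torch.nn as nn


class TinyQueryBridge(nn.Module):
    """Learnable queries cross-attend to image features, then project to the LLM width."""

    def __init__(self, n_queries=32, query_dim=768, image_dim=1408, llm_dim=2560, n_heads=8):
        super().__init__()
        self.queries = nn.Parameter(torch.randn(1, n_queries, query_dim) * 0.02)
        self.cross_attn = nn.MultiheadAttention(
            query_dim, n_heads, kdim=image_dim, vdim=image_dim, batch_first=True
        )
        self.to_llm = nn.Linear(query_dim, llm_dim)

    def forward(self, image_feats):
        """image_feats: [batch, n_image_tokens, image_dim] from a frozen image encoder."""
        queries = self.queries.expand(image_feats.shape[0], -1, -1)
        attended, _ = self.cross_attn(queries, image_feats, image_feats)
        return self.to_llm(attended)  # [batch, n_queries, llm_dim]


bridge = TinyQueryBridge()
image_feats = torch.randn(2, 257, 1408)  # stand-in for frozen ViT outputs
with torch.no_grad():
    soft_prompt = bridge(image_feats)
print(soft_prompt.shape)  # torch.Size([2, 32, 2560])
```

However many patch tokens the image encoder produces, the LLM only ever sees `n_queries` vectors, which is what keeps the bridge lightweight. Only the bridge's parameters would be trained; the image encoder and the LLM stay frozen.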

In [None]:
from transformers import Blip2Processor, Blip2ForConditionalGeneration

# Load BLIP-2 (requires significant memory; float16 assumes a CUDA GPU)
# blip_processor = Blip2Processor.from_pretrained("Salesforce/blip2-opt-2.7b")
# blip_model = Blip2ForConditionalGeneration.from_pretrained(
#     "Salesforce/blip2-opt-2.7b", 
#     torch_dtype=torch.float16
# ).to(device)

# Usage (if loaded):
# prompt = "Question: What is in the image? Answer:"
# inputs = blip_processor(images=image, text=prompt, return_tensors="pt").to(device, torch.float16)
# output = blip_model.generate(**inputs)
# print(blip_processor.decode(output[0], skip_special_tokens=True).strip())

print("BLIP-2 example (commented out due to memory requirements)")
print("Uncomment the lines above to try Salesforce/blip2-opt-2.7b on a GPU")

## 10. Interview Questions

**Q1: How does ViT differ from CNNs for image processing?**
<details><summary>Answer</summary>

- ViT: Splits image into patches, treats them as tokens, uses self-attention
- CNN: Uses convolutional kernels that slide across image
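- ViT can relate any two patches from the first layer; a CNN builds global context hierarchically as its receptive field grows
- ViT needs more data to train from scratch because it has weaker built-in inductive biases (locality, translation equivariance)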
</details>

**Q2: Explain CLIP's contrastive learning objective.**
<details><summary>Answer</summary>

CLIP computes the cosine similarity between every image and every text in a batch, then applies a symmetric cross-entropy loss that pushes each matching (image, text) pair above all the non-matching pairs in its row and column. This creates a shared embedding space where related concepts are close together regardless of modality.
</details>

**Q3: What is zero-shot classification with CLIP?**
<details><summary>Answer</summary>

Using CLIP to classify images without task-specific training:
1. Encode image with image encoder
2. Encode class names as text (e.g., "a photo of a cat")
3. Compute cosine similarity
4. Highest similarity = predicted class
</details>
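The four zero-shot steps from Q3 can be sketched without any model at all. The embeddings below are made-up 3-dimensional vectors standing in for encoder outputs, and the function names and the scale of 100 are assumptions for illustration.

```python
import math


def normalize(vector):
    """Scale a vector to unit length so dot products become cosine similarities."""
    norm = math.sqrt(sum(x * x for x in vector))
    return [x / norm for x in vector]


def zero_shot_classify(image_embed, text_embeds, class_names, logit_scale=100.0):
    """Return (predicted class, probabilities) from cosine similarity plus softmax."""
    image_embed = normalize(image_embed)
    sims = [sum(a * b for a, b in zip(image_embed, normalize(t))) for t in text_embeds]
    logits = [logit_scale * s for s in sims]
    peak = max(logits)
    exps = [math.exp(l - peak) for l in logits]  # stable softmax
    total = sum(exps)
    probs = [e / total for e in exps]
    best = max(range(len(probs)), key=probs.__getitem__)
    return class_names[best], probs


class_names = ["dog", "cat", "car"]
text_embeds = [[1.0, 0.0, 0.0], [0.0, 1.0, 0.0], [0.0, 0.0, 1.0]]  # toy "a photo of a ..." embeddings
image_embed = [0.9, 0.1, 0.2]                                      # toy image embedding

label, probs = zero_shot_classify(image_embed, text_embeds, class_names)
print(label)
print([round(p, 4) for p in probs])
```

Adding a new class only requires encoding one more text prompt; no weights are retrained, which is what makes the approach "zero-shot".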

## 11. Summary

| Model | Task | Key Insight |
|-------|------|-------------|
| ViT | Image classification | Images as sequences of patches |
| CLIP | Vision-language | Contrastive learning for joint space |
| ViT-GPT2 | Image captioning | ViT encoder feeds a GPT-2 decoder via cross-attention |
| ViLT | Visual question answering | Fused image-text transformer |
| BLIP-2 | Multi-task | Q-Former bridges vision & LLM |

## 12. References

- [ViT Paper](https://arxiv.org/abs/2010.11929)
- [CLIP Paper](https://arxiv.org/abs/2103.00020)
- [BLIP-2 Paper](https://arxiv.org/abs/2301.12597)
- [ViLT Paper](https://arxiv.org/abs/2102.03334)

---
**Next:** [Module 25: Stable Diffusion](../25_stable_diffusion/25_stable_diffusion.ipynb)