# 14 - Multimodal Language Models

This notebook covers multimodal language models that process text, images, and audio.

## Topics Covered:
- Text-image models
- Text-audio models
- Vision-language transformers
- Multimodal fusion techniques

In [None]:
import numpy as np
import matplotlib.pyplot as plt
from typing import Tuple, List, Dict
import torch
import torch.nn as nn
import torch.nn.functional as F

# Seed NumPy's and PyTorch's global RNGs so the randomly generated demo inputs
# and randomly initialised model weights in the cells below are reproducible.
np.random.seed(42)
torch.manual_seed(42)

## 1. Text-Image Models

In [None]:
class VisionTransformer(nn.Module):
    """Simple Vision Transformer (ViT) image encoder.

    Splits each image into non-overlapping ``patch_size x patch_size`` patches,
    linearly embeds every flattened patch, prepends a learnable CLS token, adds
    learnable position embeddings and runs a Transformer encoder followed by a
    LayerNorm.

    Args:
        image_size: Side length in pixels of the square images the model expects.
        patch_size: Side length in pixels of each square patch. Pixels beyond the
            last full patch are dropped by ``Tensor.unfold``.
        d_model: Embedding / hidden size of the encoder.
        num_layers: Number of ``nn.TransformerEncoderLayer`` blocks.
        num_heads: Attention heads per layer (must divide ``d_model``).
        in_channels: Number of image channels (default 3, i.e. RGB).

    Input:
        ``images`` of shape ``(batch, in_channels, height, width)`` whose patch
        grid matches that of an ``image_size x image_size`` image.

    Output:
        Features of shape ``(batch, num_patches + 1, d_model)``. Position 0 along
        dim 1 is the CLS token.
    """

    def __init__(self, image_size=224, patch_size=16, d_model=512, num_layers=6, num_heads=8,
                 in_channels=3):
        super().__init__()
        self.image_size = image_size
        self.patch_size = patch_size
        self.d_model = d_model
        self.in_channels = in_channels

        # Number of patches (the patch grid is square)
        self.num_patches = (image_size // patch_size) ** 2

        # Linear projection of each flattened (in_channels * P * P) patch
        self.patch_embedding = nn.Linear(patch_size * patch_size * in_channels, d_model)

        # Learnable position embeddings: one per patch plus one for the CLS token
        self.position_embeddings = nn.Parameter(torch.randn(1, self.num_patches + 1, d_model))

        # Learnable CLS token, prepended to every patch sequence
        self.cls_token = nn.Parameter(torch.randn(1, 1, d_model))

        # Transformer layers
        encoder_layer = nn.TransformerEncoderLayer(
            d_model=d_model,
            nhead=num_heads,
            dim_feedforward=d_model * 4,
            dropout=0.1,
            batch_first=True
        )
        self.transformer = nn.TransformerEncoder(encoder_layer, num_layers=num_layers)

        # Final layer norm
        self.layer_norm = nn.LayerNorm(d_model)

    def patchify(self, images):
        """Split images into flattened, non-overlapping patches.

        Args:
            images: Tensor of shape ``(batch, channels, height, width)``.

        Returns:
            Tensor of shape ``(batch, num_patches, channels * patch_size ** 2)``.
            Patches are in row-major grid order (left to right, then top to
            bottom). Each patch is flattened channel-first, then row, then column.
        """
        batch_size, channels, _, _ = images.shape
        p = self.patch_size

        # (B, C, H, W) -> (B, C, H // p, W // p, p, p)
        patches = images.unfold(2, p, p).unfold(3, p, p)
        # -> (B, C, num_patches, p, p)
        patches = patches.contiguous().view(batch_size, channels, -1, p, p)
        # -> (B, num_patches, C, p, p) -> (B, num_patches, C * p * p)
        patches = patches.permute(0, 2, 1, 3, 4).contiguous()
        return patches.view(batch_size, -1, channels * p * p)

    def forward(self, images):
        """Encode a batch of images.

        Args:
            images: Tensor of shape ``(batch, in_channels, height, width)``.

        Returns:
            Tensor of shape ``(batch, num_patches + 1, d_model)``.

        Raises:
            ValueError: if ``images`` is not 4-D, has the wrong number of
                channels, or yields a patch grid different from the one the
                position embeddings were built for.
        """
        if images.dim() != 4:
            raise ValueError(
                f"expected images of shape (batch, channels, height, width), got {tuple(images.shape)}"
            )
        batch_size, channels, height, width = images.shape
        if channels != self.in_channels:
            raise ValueError(f"expected {self.in_channels} image channels, got {channels}")
        grid = self.image_size // self.patch_size
        if height // self.patch_size != grid or width // self.patch_size != grid:
            raise ValueError(
                f"expected {self.image_size}x{self.image_size} images "
                f"({grid}x{grid} patches of size {self.patch_size}), got {height}x{width}"
            )

        # (B, num_patches, patch_dim) -> (B, num_patches, d_model)
        patch_embeddings = self.patch_embedding(self.patchify(images))

        # Prepend the CLS token -> (B, num_patches + 1, d_model)
        cls_tokens = self.cls_token.expand(batch_size, -1, -1)
        embeddings = torch.cat([cls_tokens, patch_embeddings], dim=1)

        # Add position embeddings (broadcast over the batch)
        embeddings = embeddings + self.position_embeddings

        features = self.transformer(embeddings)
        return self.layer_norm(features)

class TextImageModel(nn.Module):
    """Text-image model in which text tokens cross-attend to ViT image features.

    Text is embedded (token embeddings plus learned absolute position
    embeddings) and encoded with a Transformer encoder. Images are encoded with
    ``VisionTransformer``. Every text position then attends over all image
    tokens (CLS + patches). The attended features are added back to the text
    features (residual connection) and projected to vocabulary logits.

    Args:
        vocab_size: Token vocabulary size (also the output dimension).
        d_model: Hidden size shared by the text and image encoders.
        num_heads: Attention heads for every attention layer (must divide d_model).
        num_layers: Depth of both the text and the image encoder.
        max_text_len: Number of rows in the text position-embedding table, i.e.
            the longest supported text (default 1000).
        image_size: Image side length passed to ``VisionTransformer``.
        patch_size: Patch side length passed to ``VisionTransformer``.
    """

    def __init__(self, vocab_size=50000, d_model=512, num_heads=8, num_layers=6,
                 max_text_len=1000, image_size=224, patch_size=16):
        super().__init__()
        self.d_model = d_model

        # Text encoder: token embeddings + learned absolute position embeddings
        self.text_embedding = nn.Embedding(vocab_size, d_model)
        self.text_position_embedding = nn.Embedding(max_text_len, d_model)

        text_encoder_layer = nn.TransformerEncoderLayer(
            d_model=d_model,
            nhead=num_heads,
            dim_feedforward=d_model * 4,
            dropout=0.1,
            batch_first=True
        )
        self.text_encoder = nn.TransformerEncoder(text_encoder_layer, num_layers=num_layers)

        # Image encoder
        self.image_encoder = VisionTransformer(
            image_size=image_size,
            patch_size=patch_size,
            d_model=d_model,
            num_layers=num_layers,
            num_heads=num_heads,
        )

        # Cross-modal attention: text (query) attends to image tokens (key/value)
        self.cross_attention = nn.MultiheadAttention(d_model, num_heads, batch_first=True)

        # Projection from fused features to vocabulary logits
        self.output_projection = nn.Linear(d_model, vocab_size)

        # Layer norms applied to each encoder's output
        self.text_layer_norm = nn.LayerNorm(d_model)
        self.image_layer_norm = nn.LayerNorm(d_model)

    def encode_text(self, text_tokens):
        """Encode token ids of shape ``(batch, seq_len)`` into ``(batch, seq_len, d_model)``.

        Raises:
            ValueError: if ``seq_len`` exceeds the size of the position table.
        """
        batch_size, seq_len = text_tokens.shape
        max_len = self.text_position_embedding.num_embeddings
        if seq_len > max_len:
            raise ValueError(
                f"text sequence length {seq_len} exceeds the maximum of {max_len} "
                "supported by the position embeddings"
            )

        text_emb = self.text_embedding(text_tokens)

        # Positions 0..seq_len-1, repeated for every sequence in the batch
        positions = torch.arange(seq_len, device=text_tokens.device).unsqueeze(0).expand(batch_size, -1)
        pos_emb = self.text_position_embedding(positions)

        text_features = self.text_encoder(text_emb + pos_emb)
        return self.text_layer_norm(text_features)

    def encode_image(self, images):
        """Encode images into ``(batch, num_patches + 1, d_model)`` features.

        ``VisionTransformer`` already ends in its own LayerNorm. ``image_layer_norm``
        is applied on top of it.
        """
        image_features = self.image_encoder(images)
        return self.image_layer_norm(image_features)

    def forward(self, text_tokens, images):
        """Fuse text with images and predict vocabulary logits.

        Args:
            text_tokens: Token ids of shape ``(batch, text_len)``.
            images: Images accepted by ``VisionTransformer``.

        Returns:
            Tuple ``(logits, attention_weights)``:
            - ``logits`` has shape ``(batch, text_len, vocab_size)``.
            - ``attention_weights`` holds the cross-attention weights returned by
              ``nn.MultiheadAttention``. With its default settings they are
              averaged over heads, giving ``(batch, text_len, num_image_tokens)``.
        """
        text_features = self.encode_text(text_tokens)
        image_features = self.encode_image(images)

        # Text queries attend over all image tokens
        attended_text, attention_weights = self.cross_attention(
            query=text_features,
            key=image_features,
            value=image_features
        )

        # Residual connection keeps the original text information
        fused_features = text_features + attended_text

        output = self.output_projection(fused_features)
        return output, attention_weights

# Demonstrate text-image model
def demonstrate_text_image_model():
    """Run a small TextImageModel on random inputs and plot its cross-modal attention.

    Builds a model with vocab 10,000 and d_model 256 and feeds it random token
    ids and random 224x224 "images". It then prints the tensor shapes and plots:
    - left: the text-to-image attention map of the first sample;
    - right: the total attention each image token receives, summed over text
      positions.

    Returns:
        The initialised ``TextImageModel``.
    """
    print("Text-Image Multimodal Model Demo:")

    # Model parameters
    vocab_size = 10000
    d_model = 256
    batch_size = 2
    seq_len = 20
    image_size = 224

    model = TextImageModel(vocab_size=vocab_size, d_model=d_model)

    # Random sample data
    text_tokens = torch.randint(0, vocab_size, (batch_size, seq_len))
    images = torch.randn(batch_size, 3, image_size, image_size)

    with torch.no_grad():
        output, attention_weights = model(text_tokens, images)

    print("Input shapes:")
    print(f"  Text tokens: {text_tokens.shape}")
    print(f"  Images: {images.shape}")
    print("Output shapes:")
    print(f"  Model output: {output.shape}")
    print(f"  Attention weights: {attention_weights.shape}")

    # By default nn.MultiheadAttention already averages over heads, so one
    # sample's weights are (text_len, num_image_tokens). Average over heads only
    # if per-head weights (num_heads, text_len, num_image_tokens) come back.
    # Averaging over dim 0 unconditionally would collapse the text axis and give
    # imshow a 1-D array.
    attention_map = attention_weights[0]
    if attention_map.dim() == 3:
        attention_map = attention_map.mean(dim=0)
    attention_map = attention_map.numpy()

    plt.figure(figsize=(12, 4))

    # Attention heat map. Column 0 is the image CLS token; the rest are patches.
    plt.subplot(1, 2, 1)
    plt.imshow(attention_map, cmap='Blues', aspect='auto')
    plt.title('Cross-Modal Attention\n(Text → Image)')
    plt.xlabel('Image Patches')
    plt.ylabel('Text Tokens')
    plt.colorbar()

    # Total attention each image token receives, summed over text positions
    plt.subplot(1, 2, 2)
    attention_sum = attention_map.sum(axis=0)
    plt.bar(range(len(attention_sum)), attention_sum)
    plt.title('Attention Distribution\nAcross Image Patches')
    plt.xlabel('Image Patch Index')
    plt.ylabel('Total Attention')

    plt.tight_layout()
    plt.show()

    return model

text_image_model = demonstrate_text_image_model()

## 2. Text-Audio Models

In [None]:
class AudioEncoder(nn.Module):
    """Encode frame-level audio features with a Conv1d front end and a Transformer.

    Three ``kernel_size=3, padding=1`` convolutions (channels input_dim -> 256 ->
    512 -> d_model, each followed by ReLU) keep the number of time steps
    unchanged. Learned absolute position embeddings (a table of 5000 rows) are
    then added, followed by a Transformer encoder and a final LayerNorm.

    Input:  ``(batch, input_dim, seq_len)``, channel-first as ``nn.Conv1d`` expects.
    Output: ``(batch, seq_len, d_model)``.
    """

    def __init__(self, input_dim=80, d_model=512, num_layers=6, num_heads=8):
        super().__init__()
        self.d_model = d_model

        # Conv front end: each consecutive pair of widths becomes Conv1d + ReLU
        widths = [input_dim, 256, 512, d_model]
        front_end = []
        for in_ch, out_ch in zip(widths[:-1], widths[1:]):
            front_end.append(nn.Conv1d(in_ch, out_ch, kernel_size=3, padding=1))
            front_end.append(nn.ReLU())
        self.conv_layers = nn.Sequential(*front_end)

        # Learned absolute positions (at most 5000 time steps)
        self.position_embedding = nn.Embedding(5000, d_model)

        self.transformer = nn.TransformerEncoder(
            nn.TransformerEncoderLayer(
                d_model=d_model,
                nhead=num_heads,
                dim_feedforward=4 * d_model,
                dropout=0.1,
                batch_first=True,
            ),
            num_layers=num_layers,
        )

        self.layer_norm = nn.LayerNorm(d_model)

    def forward(self, audio_features):
        """Map ``(batch, input_dim, seq_len)`` features to ``(batch, seq_len, d_model)``."""
        n_batch, _, n_frames = audio_features.shape

        # Conv1d is channel-first; the Transformer wants (batch, time, d_model)
        frames = self.conv_layers(audio_features).permute(0, 2, 1)

        frame_ids = torch.arange(n_frames, device=audio_features.device)
        frame_ids = frame_ids.unsqueeze(0).expand(n_batch, -1)
        hidden = frames + self.position_embedding(frame_ids)

        return self.layer_norm(self.transformer(hidden))

class TextAudioModel(nn.Module):
    """Text-audio model with bidirectional cross-attention and late fusion.

    Text is embedded and Transformer-encoded; audio features are encoded with
    ``AudioEncoder``. Text attends to audio and audio attends to text. The
    audio-side result is mean-pooled over time, broadcast to every text
    position, concatenated with the text-side result, passed through an MLP and
    projected to vocabulary logits.

    Args:
        vocab_size: Token vocabulary size (also the output dimension).
        audio_dim: Number of input audio feature channels (e.g. mel bins).
        d_model: Hidden size shared by all components.
        num_heads: Attention heads for every attention layer, including those
            inside the audio encoder (must divide d_model).
        num_layers: Depth of both the text and the audio encoder.
        max_text_len: Number of rows in the text position-embedding table, i.e.
            the longest supported text (default 1000).
    """

    def __init__(self, vocab_size=50000, audio_dim=80, d_model=512, num_heads=8, num_layers=6,
                 max_text_len=1000):
        super().__init__()
        self.d_model = d_model

        # Text encoder: token embeddings + learned absolute position embeddings
        self.text_embedding = nn.Embedding(vocab_size, d_model)
        self.text_position_embedding = nn.Embedding(max_text_len, d_model)

        text_encoder_layer = nn.TransformerEncoderLayer(
            d_model=d_model,
            nhead=num_heads,
            dim_feedforward=d_model * 4,
            dropout=0.1,
            batch_first=True
        )
        self.text_encoder = nn.TransformerEncoder(text_encoder_layer, num_layers=num_layers)

        # Audio encoder. num_heads is forwarded so that the audio encoder uses
        # the requested head count instead of its own default of 8.
        self.audio_encoder = AudioEncoder(
            input_dim=audio_dim,
            d_model=d_model,
            num_layers=num_layers,
            num_heads=num_heads,
        )

        # Cross-modal attention in both directions
        self.text_to_audio_attention = nn.MultiheadAttention(d_model, num_heads, batch_first=True)
        self.audio_to_text_attention = nn.MultiheadAttention(d_model, num_heads, batch_first=True)

        # Fusion MLP over [text-side, pooled audio-side] features
        self.fusion_layer = nn.Sequential(
            nn.Linear(d_model * 2, d_model),
            nn.ReLU(),
            nn.Dropout(0.1),
            nn.Linear(d_model, d_model)
        )

        # Projection to vocabulary logits
        self.output_projection = nn.Linear(d_model, vocab_size)

        # Layer norms applied to each encoder's output
        self.text_layer_norm = nn.LayerNorm(d_model)
        self.audio_layer_norm = nn.LayerNorm(d_model)

    def encode_text(self, text_tokens):
        """Encode token ids of shape ``(batch, seq_len)`` into ``(batch, seq_len, d_model)``.

        Raises:
            ValueError: if ``seq_len`` exceeds the size of the position table.
        """
        batch_size, seq_len = text_tokens.shape
        max_len = self.text_position_embedding.num_embeddings
        if seq_len > max_len:
            raise ValueError(
                f"text sequence length {seq_len} exceeds the maximum of {max_len} "
                "supported by the position embeddings"
            )

        text_emb = self.text_embedding(text_tokens)

        # Positions 0..seq_len-1, repeated for every sequence in the batch
        positions = torch.arange(seq_len, device=text_tokens.device).unsqueeze(0).expand(batch_size, -1)
        pos_emb = self.text_position_embedding(positions)

        text_features = self.text_encoder(text_emb + pos_emb)
        return self.text_layer_norm(text_features)

    def encode_audio(self, audio_features):
        """Encode ``(batch, audio_dim, audio_len)`` features into ``(batch, audio_len, d_model)``."""
        audio_encoded = self.audio_encoder(audio_features)
        return self.audio_layer_norm(audio_encoded)

    def forward(self, text_tokens, audio_features):
        """Fuse text with audio and predict vocabulary logits.

        Args:
            text_tokens: Token ids of shape ``(batch, text_len)``.
            audio_features: Audio features of shape ``(batch, audio_dim, audio_len)``.

        Returns:
            Logits of shape ``(batch, text_len, vocab_size)``.
        """
        text_encoded = self.encode_text(text_tokens)
        audio_encoded = self.encode_audio(audio_features)

        # Text queries attend over audio frames -> (batch, text_len, d_model)
        text_attended, _ = self.text_to_audio_attention(
            query=text_encoded,
            key=audio_encoded,
            value=audio_encoded
        )

        # Audio frames attend over text tokens -> (batch, audio_len, d_model)
        audio_attended, _ = self.audio_to_text_attention(
            query=audio_encoded,
            key=text_encoded,
            value=text_encoded
        )

        # Mean-pool the audio side over time and broadcast it to every text position
        audio_pooled = audio_attended.mean(dim=1, keepdim=True)  # (batch, 1, d_model)
        audio_pooled = audio_pooled.expand(-1, text_encoded.size(1), -1)

        combined_features = torch.cat([text_attended, audio_pooled], dim=-1)
        fused_features = self.fusion_layer(combined_features)

        return self.output_projection(fused_features)

# Demonstrate text-audio model
def demonstrate_text_audio_model():
    """Run a small TextAudioModel on random inputs and visualise it.

    Builds a model with vocab 10,000, 80 audio channels and d_model 256. It is
    fed random token ids and a random stand-in for a mel spectrogram, and the
    tensor shapes are printed. A 2x3 figure then shows:
    - the input spectrogram;
    - a histogram of token ids 0-99;
    - the first 100 output logits at position 0;
    - a list of the model's components;
    - the share of parameters held by each group of sub-modules;
    - a schematic of the audio pipeline.

    Returns:
        The initialised ``TextAudioModel``.
    """
    print("\nText-Audio Multimodal Model Demo:")

    # Model parameters
    vocab_size = 10000
    audio_dim = 80  # Mel spectrogram features
    d_model = 256
    batch_size = 2
    text_seq_len = 20
    audio_seq_len = 100

    model = TextAudioModel(vocab_size=vocab_size, audio_dim=audio_dim, d_model=d_model)

    # Random sample data
    text_tokens = torch.randint(0, vocab_size, (batch_size, text_seq_len))
    audio_features = torch.randn(batch_size, audio_dim, audio_seq_len)  # Mel spectrogram

    with torch.no_grad():
        output = model(text_tokens, audio_features)

    print("Input shapes:")
    print(f"  Text tokens: {text_tokens.shape}")
    print(f"  Audio features: {audio_features.shape}")
    print(f"Output shape: {output.shape}")

    plt.figure(figsize=(15, 8))

    # Sample audio features (random values standing in for a mel spectrogram)
    plt.subplot(2, 3, 1)
    plt.imshow(audio_features[0].numpy(), cmap='viridis', aspect='auto')
    plt.title('Sample Audio Features\n(Mel Spectrogram)')
    plt.xlabel('Time Steps')
    plt.ylabel('Mel Frequency Bins')
    plt.colorbar()

    # Counts of token ids 0-99 (minlength guarantees at least 100 bins)
    plt.subplot(2, 3, 2)
    token_counts = torch.bincount(text_tokens.flatten(), minlength=100)[:100]
    plt.bar(range(len(token_counts)), token_counts.numpy())
    plt.title('Text Token Distribution\n(First 100 tokens)')
    plt.xlabel('Token ID')
    plt.ylabel('Count')

    # Output logits
    plt.subplot(2, 3, 3)
    output_sample = output[0, 0, :100].numpy()  # First position, first 100 logits
    plt.plot(output_sample)
    plt.title('Output Logits\n(First position, 100 dims)')
    plt.xlabel('Vocabulary Index')
    plt.ylabel('Logit Value')

    # Model architecture visualization
    plt.subplot(2, 3, 4)
    components = ['Text\nEncoder', 'Audio\nEncoder', 'Cross\nAttention', 'Fusion\nLayer', 'Output\nProjection']
    positions = range(len(components))
    plt.barh(positions, [1] * len(components), alpha=0.7)
    plt.yticks(positions, components)
    plt.title('Model Architecture\nComponents')
    plt.xlabel('Processing Stage')

    # Parameter distribution. Every direct child module of TextAudioModel is
    # assigned to exactly one group, so the percentages describe the whole
    # model rather than a subset. The embedding tables and cross-attention
    # blocks are large (about 19% of parameters at these sizes).
    plt.subplot(2, 3, 5)

    def count_params(*modules):
        """Return the total number of parameters across ``modules``."""
        return sum(p.numel() for module in modules for p in module.parameters())

    param_counts = {
        'Text Embeddings': count_params(model.text_embedding, model.text_position_embedding),
        'Text Encoder': count_params(model.text_encoder, model.text_layer_norm),
        'Audio Encoder': count_params(model.audio_encoder, model.audio_layer_norm),
        'Cross Attention': count_params(model.text_to_audio_attention, model.audio_to_text_attention),
        'Fusion': count_params(model.fusion_layer),
        'Output': count_params(model.output_projection),
    }

    plt.pie(list(param_counts.values()), labels=list(param_counts.keys()), autopct='%1.1f%%')
    plt.title('Parameter Distribution')

    # Audio processing pipeline (schematic: arrows between labelled dots)
    plt.subplot(2, 3, 6)
    pipeline_steps = ['Raw\nAudio', 'Mel\nSpectrogram', 'Conv\nLayers', 'Transformer\nEncoder', 'Features']
    step_positions = range(len(pipeline_steps))

    for i in range(len(pipeline_steps) - 1):
        plt.arrow(i, 0, 0.8, 0, head_width=0.1, head_length=0.1, fc='blue', ec='blue')

    plt.scatter(step_positions, [0] * len(pipeline_steps), s=100, c='red')
    for i, step in enumerate(pipeline_steps):
        plt.text(i, 0.2, step, ha='center', va='bottom')

    plt.xlim(-0.5, len(pipeline_steps) - 0.5)
    plt.ylim(-0.5, 0.5)
    plt.title('Audio Processing Pipeline')
    plt.axis('off')

    plt.tight_layout()
    plt.show()

    return model

text_audio_model = demonstrate_text_audio_model()