# Multimodal Language Models

## Overview

Multimodal LLMs can process and generate content across different modalities. This notebook covers:

- **Vision-Language Models**: Text-image understanding and generation
- **Audio-Language Models**: Speech and text integration
- **Cross-Modal Attention**: Attention mechanisms across modalities
- **Multimodal Fusion**: Combining information from different sources

Let's implement simplified, illustrative versions of these multimodal architectures and their processing pipelines in PyTorch.

In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
import matplotlib.pyplot as plt
from typing import Dict, List, Any, Optional, Tuple
from dataclasses import dataclass
import math

# Notebook sanity check: reaching this line means every import above resolved.
print("Libraries imported successfully!", sep=" ", end="\n")

## 1. Vision-Language Model

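Architectures for processing text together with images — a ViT-style `VisionEncoder`, a `CrossModalAttention` layer, and the combined `VisionLanguageModel` — plus an `AudioLanguageModel` that pairs speech features with text: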

In [None]:
class VisionEncoder(nn.Module):
    """ViT-style image encoder: patchify, prepend a CLS token, add positions, encode.

    A strided ``Conv2d`` cuts each image into non-overlapping
    ``patch_size x patch_size`` patches and projects each one to ``d_model``.
    A learned CLS token is prepended and learned position embeddings are added.
    The sequence then runs through a stack of ``nn.TransformerEncoderLayer``.

    Args:
        image_size: Side length of the square input images, in pixels. The
            position embeddings are sized for exactly
            ``(image_size // patch_size) ** 2`` patches.
        patch_size: Patch side length; also the convolution kernel and stride.
        d_model: Width of every output token.
        n_heads: Attention heads per encoder layer; must divide ``d_model``.
            Defaults to 12, the value this encoder previously hard-coded.
        num_layers: Number of encoder layers (default 6).
        in_channels: Channels of the input images (default 3).
    """

    def __init__(self, image_size=224, patch_size=16, d_model=768,
                 n_heads=12, num_layers=6, in_channels=3):
        super().__init__()
        self.image_size = image_size
        self.patch_size = patch_size
        self.d_model = d_model

        # Patches per image. A conv with stride == kernel and no padding
        # floors any remainder, so this matches the conv's output grid.
        self.num_patches = (image_size // patch_size) ** 2

        # Patch embedding: one conv output position per patch.
        self.patch_embed = nn.Conv2d(
            in_channels, d_model, kernel_size=patch_size, stride=patch_size
        )

        # Learned position embeddings for the CLS token plus every patch.
        self.pos_embed = nn.Parameter(torch.randn(1, self.num_patches + 1, d_model))

        # Learned CLS token, broadcast across the batch in forward().
        self.cls_token = nn.Parameter(torch.randn(1, 1, d_model))

        # Transformer layers
        self.transformer = nn.TransformerEncoder(
            nn.TransformerEncoderLayer(d_model, nhead=n_heads, batch_first=True),
            num_layers=num_layers,
        )

    def forward(self, images):
        """Encode a batch of images.

        Args:
            images: Tensor ``[B, in_channels, H, W]``. H and W must yield the
                same patch grid as ``image_size`` did at construction.

        Returns:
            Tensor ``[B, num_patches + 1, d_model]``. Position 0 along dim 1 is
            the CLS token.

        Raises:
            ValueError: if the input yields a different number of patches than
                the position embeddings were built for.
        """
        batch_size = images.shape[0]

        x = self.patch_embed(images)        # [B, d_model, H/P, W/P]
        x = x.flatten(2).transpose(1, 2)    # [B, num_patches, d_model]

        # Fail with a clear message instead of an opaque broadcast error below.
        if x.shape[1] != self.num_patches:
            raise ValueError(
                f"expected {self.num_patches} patches (image_size={self.image_size}, "
                f"patch_size={self.patch_size}), got {x.shape[1]} from input of "
                f"shape {tuple(images.shape)}"
            )

        cls_tokens = self.cls_token.expand(batch_size, -1, -1)
        x = torch.cat([cls_tokens, x], dim=1)

        x = x + self.pos_embed

        return self.transformer(x)

class CrossModalAttention(nn.Module):
    """One layer of bidirectional cross-attention between text and vision tokens.

    Text tokens first attend to the vision tokens. The vision tokens then
    attend to the *already updated* text tokens, so the two directions are
    applied one after the other rather than symmetrically.

    Each direction computes ``x = LayerNorm(x + attn(x, other))`` followed by
    ``x = x + FFN(x)``. The FFN residual is not followed by a second
    LayerNorm.

    Args:
        d_model: Token width shared by both modalities.
        n_heads: Attention heads per direction; must divide ``d_model``.
    """

    def __init__(self, d_model, n_heads):
        super().__init__()
        self.d_model = d_model
        self.n_heads = n_heads
        self.d_head = d_model // n_heads  # per-head width (not used inside this layer)

        # One attention module per direction: queries from one modality,
        # keys/values from the other.
        self.text_to_vision_attn = nn.MultiheadAttention(d_model, n_heads, batch_first=True)
        self.vision_to_text_attn = nn.MultiheadAttention(d_model, n_heads, batch_first=True)

        # Post-attention layer norms
        self.text_norm = nn.LayerNorm(d_model)
        self.vision_norm = nn.LayerNorm(d_model)

        # Position-wise feed-forward networks (4x expansion, ReLU)
        self.text_ffn = nn.Sequential(
            nn.Linear(d_model, 4 * d_model),
            nn.ReLU(),
            nn.Linear(4 * d_model, d_model)
        )

        self.vision_ffn = nn.Sequential(
            nn.Linear(d_model, 4 * d_model),
            nn.ReLU(),
            nn.Linear(4 * d_model, d_model)
        )

    def forward(self, text_features, vision_features,
                text_padding_mask=None, vision_padding_mask=None):
        """Exchange information between the two token sequences.

        Args:
            text_features: ``[B, T, d_model]``.
            vision_features: ``[B, V, d_model]``.
            text_padding_mask: Optional bool ``[B, T]``. ``True`` marks text
                positions that vision tokens must NOT attend to (e.g. padding).
            vision_padding_mask: Optional bool ``[B, V]``. ``True`` marks vision
                positions that text tokens must NOT attend to.
                If a mask hides every key for some batch row, that row's
                attention output is NaN, as with ``nn.MultiheadAttention``.

        Returns:
            ``(text_features, vision_features)``, with their input shapes
            unchanged.
        """
        # Text attending to vision. The attention weights are never used,
        # so skip computing them.
        text_attn_out, _ = self.text_to_vision_attn(
            text_features, vision_features, vision_features,
            key_padding_mask=vision_padding_mask,
            need_weights=False,
        )
        text_features = self.text_norm(text_features + text_attn_out)
        text_features = text_features + self.text_ffn(text_features)

        # Vision attending to the updated text. The mask keeps padded text
        # tokens from leaking into the vision representation.
        vision_attn_out, _ = self.vision_to_text_attn(
            vision_features, text_features, text_features,
            key_padding_mask=text_padding_mask,
            need_weights=False,
        )
        vision_features = self.vision_norm(vision_features + vision_attn_out)
        vision_features = vision_features + self.vision_ffn(vision_features)

        return text_features, vision_features

class VisionLanguageModel(nn.Module):
    """Vision-language model combining a ViT-style image encoder with text embeddings.

    Text tokens are embedded with learned token and position embeddings; there
    is no text self-attention. When both modalities are present, ``n_layers``
    ``CrossModalAttention`` layers exchange information between them.

    ``forward(task=...)`` selects one of three heads:

    * ``'text_generation'`` -> per-token vocabulary logits ``[B, T, vocab_size]``.
      No causal mask is applied anywhere, so this is not autoregressive as-is.
    * ``'image_classification'`` -> logits ``[B, num_image_classes]`` from the
      vision CLS token.
    * ``'multimodal'`` -> vocabulary logits ``[B, vocab_size]`` from the fused
      representation: mean-pooled text concatenated with the vision CLS token.

    Args:
        vocab_size: Size of the text vocabulary.
        d_model: Shared embedding width. ``VisionEncoder`` is built with only
            ``d_model`` passed, so ``d_model`` must also be divisible by that
            encoder's default head count (12 as written), as well as by
            ``n_heads``.
        n_heads: Attention heads in each cross-modal layer.
        n_layers: Number of cross-modal layers.
        max_seq_len: Longest text sequence supported by the learned position
            embeddings (default 512).
        num_image_classes: Output size of the image-classification head
            (default 1000, as for ImageNet).
    """

    # Task names accepted by forward().
    TASKS = ('text_generation', 'image_classification', 'multimodal')

    def __init__(self, vocab_size, d_model=768, n_heads=12, n_layers=6,
                 max_seq_len=512, num_image_classes=1000):
        super().__init__()
        self.d_model = d_model
        self.max_seq_len = max_seq_len

        # Text components: token embeddings + learned absolute positions.
        self.text_embed = nn.Embedding(vocab_size, d_model)
        self.text_pos_embed = nn.Parameter(torch.randn(1, max_seq_len, d_model))

        # Vision components (VisionEncoder's default image/patch settings).
        self.vision_encoder = VisionEncoder(d_model=d_model)

        # Cross-modal layers
        self.cross_modal_layers = nn.ModuleList([
            CrossModalAttention(d_model, n_heads) for _ in range(n_layers)
        ])

        # Output heads
        self.text_head = nn.Linear(d_model, vocab_size)
        self.vision_head = nn.Linear(d_model, num_image_classes)

        # Multimodal fusion: [pooled text ; vision CLS] -> d_model -> vocab
        self.fusion_layer = nn.Linear(2 * d_model, d_model)
        self.multimodal_head = nn.Linear(d_model, vocab_size)

    def forward(self, text_tokens=None, images=None, task='multimodal'):
        """Run the requested task head.

        Args:
            text_tokens: Optional ``LongTensor`` ``[B, T]`` with
                ``T <= max_seq_len``.
            images: Optional image batch for ``self.vision_encoder``
                (``[B, 3, 224, 224]`` with VisionEncoder's defaults).
            task: One of ``TASKS``.

        Returns:
            Logits for the requested task, or ``None`` when a modality the
            task needs was not supplied.

        Raises:
            ValueError: on an unknown ``task``, or when the text is longer than
                ``max_seq_len``.
        """
        if task not in self.TASKS:
            raise ValueError(f"unknown task {task!r}; expected one of {self.TASKS}")

        text_features = None
        vision_features = None

        # Process text
        if text_tokens is not None:
            _, seq_len = text_tokens.shape  # also enforces a 2-D [B, T] input
            if seq_len > self.max_seq_len:
                raise ValueError(
                    f"text length {seq_len} exceeds max_seq_len={self.max_seq_len}"
                )
            text_features = self.text_embed(text_tokens)
            text_features = text_features + self.text_pos_embed[:, :seq_len, :]

        # Process images
        if images is not None:
            vision_features = self.vision_encoder(images)

        # Cross-modal processing (only when both modalities are present)
        if text_features is not None and vision_features is not None:
            for layer in self.cross_modal_layers:
                text_features, vision_features = layer(text_features, vision_features)

        # Task-specific outputs
        if task == 'text_generation' and text_features is not None:
            return self.text_head(text_features)
        elif task == 'image_classification' and vision_features is not None:
            return self.vision_head(vision_features[:, 0])  # CLS token
        elif task == 'multimodal' and text_features is not None and vision_features is not None:
            text_pooled = text_features.mean(dim=1)  # mean over all text positions
            vision_pooled = vision_features[:, 0]    # CLS token

            fused = torch.cat([text_pooled, vision_pooled], dim=-1)
            fused = self.fusion_layer(fused)

            return self.multimodal_head(fused)

        # Known task, but a required modality was missing.
        return None

class AudioLanguageModel(nn.Module):
    """Simplified audio-language model pairing mel-spectrogram clips with text tokens.

    The audio encoder is three 1-D convolutions over time followed by
    ``AdaptiveAvgPool1d(1)``. Each clip therefore becomes a single ``d_model``
    vector, i.e. one "audio token".

    Text uses learned token and position embeddings. When both modalities are
    given, the audio token attends to the text, and then the text attends to
    the updated audio token. Because there is only one audio key, every text
    position receives the same attention output.

    ``forward(task=...)`` selects one of two heads:

    * ``'speech_recognition'`` -> ``[B, 1, vocab_size]`` logits from the single
      audio token. This is one prediction per clip, not a token sequence.
    * ``'text_to_speech'`` -> ``[B, T, n_mels]`` per-token mel-frame outputs.

    Args:
        vocab_size: Size of the text vocabulary.
        d_model: Shared hidden width.
        n_heads: Attention heads; must divide ``d_model``.
        n_mels: Mel bins in the input spectrogram and in the TTS output
            (default 80).
        max_seq_len: Longest text sequence supported by the learned position
            embeddings (default 512).
    """

    # Task names accepted by forward().
    TASKS = ('speech_recognition', 'text_to_speech')

    def __init__(self, vocab_size, d_model=768, n_heads=12, n_mels=80, max_seq_len=512):
        super().__init__()
        self.d_model = d_model
        self.n_mels = n_mels
        self.max_seq_len = max_seq_len

        # Audio encoder (simplified): convs over time, then global average
        # pooling over time -> [B, d_model, 1].
        self.audio_encoder = nn.Sequential(
            nn.Conv1d(n_mels, 256, kernel_size=3, padding=1),
            nn.ReLU(),
            nn.Conv1d(256, 512, kernel_size=3, padding=1),
            nn.ReLU(),
            nn.Conv1d(512, d_model, kernel_size=3, padding=1),
            nn.AdaptiveAvgPool1d(1)
        )

        # Text components
        self.text_embed = nn.Embedding(vocab_size, d_model)
        self.text_pos_embed = nn.Parameter(torch.randn(1, max_seq_len, d_model))

        # Cross-modal attention, one module per direction
        self.audio_text_attn = nn.MultiheadAttention(d_model, n_heads, batch_first=True)
        self.text_audio_attn = nn.MultiheadAttention(d_model, n_heads, batch_first=True)

        # Output heads
        self.speech_recognition_head = nn.Linear(d_model, vocab_size)
        self.text_to_speech_head = nn.Linear(d_model, n_mels)

        # Layer norms
        self.audio_norm = nn.LayerNorm(d_model)
        self.text_norm = nn.LayerNorm(d_model)

    def forward(self, audio_features=None, text_tokens=None, task='speech_recognition'):
        """Run the requested task head.

        Args:
            audio_features: Optional spectrogram ``[B, n_mels, time_steps]``.
            text_tokens: Optional ``LongTensor`` ``[B, T]`` with
                ``T <= max_seq_len``.
            task: One of ``TASKS``.

        Returns:
            Task output, or ``None`` when the modality the task needs was not
            supplied.

        Raises:
            ValueError: on an unknown ``task``, or when the text is longer than
                ``max_seq_len``.
        """
        if task not in self.TASKS:
            raise ValueError(f"unknown task {task!r}; expected one of {self.TASKS}")

        audio_repr = None
        text_repr = None

        # Process audio: [B, n_mels, time] -> [B, d_model, 1] -> [B, 1, d_model]
        if audio_features is not None:
            audio_encoded = self.audio_encoder(audio_features)
            audio_repr = audio_encoded.squeeze(-1).unsqueeze(1)

        # Process text
        if text_tokens is not None:
            _, seq_len = text_tokens.shape  # also enforces a 2-D [B, T] input
            if seq_len > self.max_seq_len:
                raise ValueError(
                    f"text length {seq_len} exceeds max_seq_len={self.max_seq_len}"
                )
            text_repr = self.text_embed(text_tokens)
            text_repr = text_repr + self.text_pos_embed[:, :seq_len, :]

        # Cross-modal processing
        if audio_repr is not None and text_repr is not None:
            # Audio attending to text
            audio_attn_out, _ = self.audio_text_attn(audio_repr, text_repr, text_repr)
            audio_repr = self.audio_norm(audio_repr + audio_attn_out)

            # Text attending to the (updated) single audio token
            text_attn_out, _ = self.text_audio_attn(text_repr, audio_repr, audio_repr)
            text_repr = self.text_norm(text_repr + text_attn_out)

        # Task-specific outputs
        if task == 'speech_recognition' and audio_repr is not None:
            return self.speech_recognition_head(audio_repr)
        elif task == 'text_to_speech' and text_repr is not None:
            return self.text_to_speech_head(text_repr)

        # Known task, but its required modality was missing.
        return None

# Notebook progress marker: all model classes in this cell are now defined.
print("Multimodal model architectures implemented!", sep=" ", end="\n")