# 🚀 Next-Generation Medical AI: Hands-on Practice

## Table of Contents
1. [Mixture of Experts (MoE) Basics](#practice-1-mixture-of-experts-moe-basics)
2. [Sparse Activation and Routing](#practice-2-sparse-activation-and-routing)
3. [Long-Context Attention Mechanisms](#practice-3-long-context-attention-mechanisms)
4. [Flash Attention Simulation](#practice-4-flash-attention-simulation)
5. [Retrieval-Augmented Generation (RAG)](#practice-5-retrieval-augmented-generation-rag)
6. [Graph Neural Networks for Medical Data](#practice-6-graph-neural-networks-for-medical-data)
7. [State Space Models (Mamba-style)](#practice-7-state-space-models-mamba-style)
8. [Performance Comparison](#practice-8-performance-comparison)

## Installing and Importing Essential Libraries

In [None]:
# Import essential libraries
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
import torch
import torch.nn as nn
import torch.nn.functional as F
from typing import List, Tuple, Dict
import warnings
warnings.filterwarnings('ignore')

# Visualization settings
plt.rcParams['figure.figsize'] = (12, 6)
plt.rcParams['font.size'] = 11
sns.set_style('whitegrid')

# Set random seeds for reproducibility
np.random.seed(42)
torch.manual_seed(42)

print("✅ All libraries loaded successfully!")
print(f"PyTorch version: {torch.__version__}")
print(f"Device: {'cuda' if torch.cuda.is_available() else 'cpu'}")

---
## Practice 1: Mixture of Experts (MoE) Basics

### 🎯 Learning Objectives
- Understand the MoE architecture and its gating mechanism
- Implement sparse expert selection (Top-K routing)
- Visualize expert specialization

### 📖 Key Concepts
**Mixture of Experts:** Multiple specialized sub-networks (experts) with a gating network that routes inputs to the most relevant experts.
- **Gating Network:** Selects top-K experts based on input
- **Sparse Activation:** Only K out of N experts are active per input
- **Efficiency:** Total parameter count grows with N, while per-input compute grows only with K
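
To make the gating arithmetic concrete, here is a minimal plain-Python sketch of top-K routing for a single input: softmax over the gate logits, keep the K largest scores, then renormalize them so they sum to 1. The logits and the helper names (`softmax`, `top_k_gate`) are hypothetical and chosen only for illustration.

```python
import math

def softmax(logits):
    """Numerically stable softmax over a list of floats."""
    m = max(logits)
    exps = [math.exp(v - m) for v in logits]
    total = sum(exps)
    return [e / total for e in exps]

def top_k_gate(logits, k):
    """Return [(expert_index, renormalized_weight), ...] for the k largest gate scores."""
    scores = softmax(logits)
    # Indices of the k highest-scoring experts, best first
    top = sorted(range(len(scores)), key=lambda i: scores[i], reverse=True)[:k]
    kept = sum(scores[i] for i in top)
    # Renormalize so the selected weights sum to 1
    return [(i, scores[i] / kept) for i in top]

# Hypothetical gate logits for one input and N = 4 experts
gate_logits = [2.0, 0.5, 1.0, -1.0]
routing = top_k_gate(gate_logits, k=2)
for expert, weight in routing:
    print(f"expert {expert}: weight {weight:.3f}")
# expert 0: weight 0.731
# expert 2: weight 0.269
```

Only experts 0 and 2 would run for this input; the other two contribute nothing to the output and cost no compute.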

In [None]:
# 1.1 Simple MoE Implementation
class SimpleMoE(nn.Module):
    """A simple Mixture of Experts model"""
    
    def __init__(self, input_dim: int, hidden_dim: int, num_experts: int = 8, top_k: int = 2):
        super().__init__()
        self.num_experts = num_experts
        self.top_k = top_k
        
        # Create expert networks
        self.experts = nn.ModuleList([
            nn.Sequential(
                nn.Linear(input_dim, hidden_dim),
                nn.ReLU(),
                nn.Linear(hidden_dim, input_dim)
            ) for _ in range(num_experts)
        ])
        
        # Gating network
        self.gate = nn.Linear(input_dim, num_experts)
    
    def forward(self, x):
        # Compute gating scores
        gate_logits = self.gate(x)  # [batch_size, num_experts]
        gate_scores = F.softmax(gate_logits, dim=-1)
        
        # Select top-k experts
        top_k_scores, top_k_indices = torch.topk(gate_scores, self.top_k, dim=-1)
        
        # Normalize top-k scores
        top_k_scores = top_k_scores / top_k_scores.sum(dim=-1, keepdim=True)
        
        # Compute expert outputs
        output = torch.zeros_like(x)
        for i in range(self.top_k):
            expert_idx = top_k_indices[:, i]
            expert_weight = top_k_scores[:, i].unsqueeze(-1)
            
            # Per-sample loop kept for readability; production MoE layers batch tokens per expert
            for batch_idx, exp_idx in enumerate(expert_idx.tolist()):
                expert_output = self.experts[exp_idx](x[batch_idx:batch_idx+1])
                output[batch_idx:batch_idx+1] += expert_weight[batch_idx] * expert_output
        
        return output, top_k_indices, top_k_scores

# Test the MoE model
def test_moe():
    """Test MoE with sample data"""
    
    # Create sample medical data (simulating patient features)
    batch_size = 5
    input_dim = 10
    
    # Simulate different types of medical data
    medical_specialties = ['Radiology', 'Pathology', 'Cardiology', 'Neurology', 
                          'Oncology', 'Genomics', 'Surgery', 'General']
    
    X = torch.randn(batch_size, input_dim)
    
    # Create MoE model
    model = SimpleMoE(input_dim=input_dim, hidden_dim=32, num_experts=8, top_k=2)
    
    # Forward pass
    output, expert_indices, expert_weights = model(X)
    
    print("🏥 Mixture of Experts - Medical Domain Routing")
    print("=" * 60)
    print(f"\nInput shape: {X.shape}")
    print(f"Output shape: {output.shape}")
    print(f"Number of experts: {model.num_experts}")
    print(f"Active experts per input: {model.top_k}")
    
    print("\n📊 Expert Selection for Each Sample:")
    print("-" * 60)
    for i in range(batch_size):
        print(f"\nSample {i+1}:")
        for j in range(model.top_k):
            expert_id = expert_indices[i, j].item()
            weight = expert_weights[i, j].item()
            specialty = medical_specialties[expert_id]
            print(f"  Expert {expert_id} ({specialty}): {weight:.4f} weight")
    
    # Calculate computation savings
    active_experts = model.top_k
    total_experts = model.num_experts
    compute_savings = (1 - active_experts / total_experts) * 100
    
    print("\n💡 Efficiency Gains:")
    print("=" * 60)
    print(f"Active experts: {active_experts}/{total_experts} ({active_experts/total_experts*100:.1f}%)")
    print(f"Computation savings: {compute_savings:.1f}%")
    print(f"Effective parameter scaling: {total_experts/active_experts:.1f}x")
    
    return model, X, output

moe_model, sample_input, moe_output = test_moe()

---
## Practice 2: Sparse Activation and Routing

### 🎯 Learning Objectives
- Visualize expert utilization patterns
- Understand load balancing in MoE
- Compare dense vs sparse activation

### 📖 Key Concepts
**Sparse Activation:** Only a small subset of experts process each input
- **Benefits:** Reduced computation, memory efficiency, specialized expertise
- **Challenge:** Load balancing - ensuring all experts are utilized
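
Load balancing is often encouraged with an auxiliary loss. The sketch below follows one widely used formulation (the one described for the Switch Transformer with top-1 routing): `N * sum_i f_i * P_i`, where `f_i` is the fraction of tokens sent to expert `i` and `P_i` is the mean router probability for expert `i`. It is an illustration only; the `SimpleMoE` class in this notebook does not implement it, and the function name and toy inputs are assumptions.

```python
def load_balance_loss(assignments, router_probs, num_experts):
    """Auxiliary balance term N * sum_i f_i * P_i for top-1 routing.

    assignments  : chosen expert index for each token
    router_probs : per-token list of router probabilities (each sums to 1)
    """
    num_tokens = len(assignments)
    # f_i: fraction of tokens dispatched to expert i
    f = [assignments.count(i) / num_tokens for i in range(num_experts)]
    # P_i: mean router probability assigned to expert i
    p = [sum(row[i] for row in router_probs) / num_tokens for i in range(num_experts)]
    return num_experts * sum(fi * pi for fi, pi in zip(f, p))

N = 4
# Balanced: each expert receives one token and the router is uniform
balanced = load_balance_loss([0, 1, 2, 3], [[0.25] * N] * 4, N)
# Collapsed: every token goes to expert 0 with probability 1
collapsed = load_balance_loss([0, 0, 0, 0], [[1.0, 0.0, 0.0, 0.0]] * 4, N)
print(f"balanced={balanced:.2f}, collapsed={collapsed:.2f}")
# balanced=1.00, collapsed=4.00
```

The term reaches its minimum of 1 under uniform routing and grows toward N as routing collapses onto a single expert, which is why adding it to the training loss pushes the gate toward even utilization.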

In [None]:
# 2.1 Visualize Expert Routing Patterns
def visualize_expert_routing(num_samples: int = 100):
    """Visualize which experts are selected for different inputs"""
    
    # Generate diverse medical data samples
    X = torch.randn(num_samples, 10)
    
    # Forward pass through MoE
    with torch.no_grad():
        _, expert_indices, expert_weights = moe_model(X)
    
    # Count expert usage
    expert_usage = torch.zeros(moe_model.num_experts)
    for i in range(num_samples):
        for j in range(moe_model.top_k):
            expert_usage[expert_indices[i, j]] += expert_weights[i, j].item()
    
    # Create visualization
    fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(14, 5))
    
    # Plot 1: Expert utilization
    specialties = ['Radiology', 'Pathology', 'Cardiology', 'Neurology', 
                   'Oncology', 'Genomics', 'Surgery', 'General']
    
    colors = plt.cm.viridis(np.linspace(0, 1, len(specialties)))
    bars = ax1.bar(range(moe_model.num_experts), expert_usage.numpy(), color=colors)
    ax1.set_xlabel('Expert ID')
    ax1.set_ylabel('Total Weight (Usage)')
    ax1.set_title('Expert Utilization Pattern')
    ax1.set_xticks(range(moe_model.num_experts))
    ax1.set_xticklabels([f'E{i}\n{s[:4]}' for i, s in enumerate(specialties)], rotation=0, fontsize=9)
    ax1.grid(axis='y', alpha=0.3)
    
    # Add average line
    avg_usage = expert_usage.mean().item()
    ax1.axhline(y=avg_usage, color='r', linestyle='--', linewidth=2, label=f'Average: {avg_usage:.1f}')
    ax1.legend()
    
    # Plot 2: Routing heatmap
    routing_matrix = torch.zeros(num_samples, moe_model.num_experts)
    for i in range(num_samples):
        for j in range(moe_model.top_k):
            routing_matrix[i, expert_indices[i, j]] = expert_weights[i, j].item()
    
    im = ax2.imshow(routing_matrix[:20].numpy(), aspect='auto', cmap='YlOrRd')
    ax2.set_xlabel('Expert ID')
    ax2.set_ylabel('Sample ID')
    ax2.set_title('Expert Routing Heatmap (First 20 samples)')
    ax2.set_xticks(range(moe_model.num_experts))
    ax2.set_xticklabels([f'E{i}' for i in range(moe_model.num_experts)])
    plt.colorbar(im, ax=ax2, label='Routing Weight')
    
    plt.tight_layout()
    plt.show()
    
    # Print statistics
    print("📈 Expert Usage Statistics:")
    print("=" * 60)
    print(f"Most used expert: Expert {expert_usage.argmax().item()} ({specialties[expert_usage.argmax().item()]})")
    print(f"Least used expert: Expert {expert_usage.argmin().item()} ({specialties[expert_usage.argmin().item()]})")
    print(f"Usage ratio (max/min): {expert_usage.max()/expert_usage.min():.2f}x")
    print(f"\nLoad balance score: {1 - (expert_usage.std() / expert_usage.mean()):.2%}")
    print("  (Higher is better, 100% = perfect balance)")

visualize_expert_routing()

---
## Practice 3: Long-Context Attention Mechanisms

### 🎯 Learning Objectives
- Compare standard attention vs efficient attention
- Understand complexity trade-offs: O(n²) vs O(n)
- Visualize attention patterns on long sequences

### 📖 Key Concepts
**Attention Complexity:**
- Standard Attention: O(n²) memory and compute
- Flash Attention: O(n) memory with block-wise computation
- Critical for processing 100K+ token sequences
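
The gap between quadratic and linear memory is easiest to appreciate with numbers. This short calculation assumes float32 values, a single attention head, a single batch element, and a hypothetical head dimension of `d = 64`; the function names are illustrative.

```python
BYTES_PER_FLOAT32 = 4
GIB = 1024 ** 3

def score_matrix_gib(n):
    """Memory for one n x n float32 attention score matrix."""
    return n * n * BYTES_PER_FLOAT32 / GIB

def linear_buffer_gib(n, d=64):
    """Memory for an n x d float32 buffer, i.e. storage that grows linearly in n."""
    return n * d * BYTES_PER_FLOAT32 / GIB

n = 100_000
quadratic = score_matrix_gib(n)
linear = linear_buffer_gib(n)
print(f"n x n scores: {quadratic:.2f} GiB")
print(f"n x d buffer: {linear:.4f} GiB")
print(f"ratio (n / d): {quadratic / linear:.1f}")
# n x n scores: 37.25 GiB
# n x d buffer: 0.0238 GiB
# ratio (n / d): 1562.5
```

The ratio simplifies to `n / d`, so the advantage of linear-memory methods keeps growing with sequence length.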

In [None]:
# 3.1 Compare Attention Mechanisms
def compare_attention_complexity():
    """Compare memory usage of different attention mechanisms"""
    
    sequence_lengths = [512, 1024, 2048, 4096, 8192, 16384, 32768, 65536, 100000]
    
    # Calculate memory usage (in GB)
    def standard_attention_memory(n, d=64):
        """Memory for standard attention: O(n²)"""
        # Attention matrix: n x n, each element is 4 bytes (float32)
        return (n * n * 4) / (1024**3)
    
    def linear_attention_memory(n, d=64):
        """Memory for linear attention: O(n)"""
        # Linear in sequence length
        return (n * d * 4) / (1024**3)
    
    standard_mem = [standard_attention_memory(n) for n in sequence_lengths]
    linear_mem = [linear_attention_memory(n) for n in sequence_lengths]
    
    # Visualization
    fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(14, 5))
    
    # Plot 1: Memory usage comparison
    ax1.plot(sequence_lengths, standard_mem, 'o-', linewidth=2, markersize=8, 
             label='Standard Attention O(n²)', color='#dc3545')
    ax1.plot(sequence_lengths, linear_mem, 's-', linewidth=2, markersize=8,
             label='Linear Attention O(n)', color='#28a745')
    
    ax1.set_xlabel('Sequence Length (tokens)', fontsize=12)
    ax1.set_ylabel('Memory Usage (GB)', fontsize=12)
    ax1.set_title('Attention Mechanism Memory Comparison', fontsize=14, fontweight='bold')
    ax1.set_xscale('log')
    ax1.set_yscale('log')
    ax1.grid(True, alpha=0.3)
    
    # Add 100K token marker
    idx_100k = sequence_lengths.index(100000)
    ax1.axvline(x=100000, color='blue', linestyle='--', alpha=0.5, label='100K tokens')
    ax1.annotate('100K tokens\n(Full patient history)', 
                xy=(100000, standard_mem[idx_100k]),
                xytext=(50000, standard_mem[idx_100k]*2),
                arrowprops=dict(arrowstyle='->', color='blue'),
                fontsize=10, color='blue')
    # Build the legend last so the 100K-token marker is included
    ax1.legend(fontsize=11)
    
    # Plot 2: Memory ratio (standard / linear); this measures memory saved, not runtime
    speedup = [s/l for s, l in zip(standard_mem, linear_mem)]
    ax2.plot(sequence_lengths, speedup, 'D-', linewidth=2, markersize=8, color='#1E64C8')
    ax2.set_xlabel('Sequence Length (tokens)', fontsize=12)
    ax2.set_ylabel('Memory Efficiency (x times)', fontsize=12)
    ax2.set_title('Linear Attention Efficiency Gain', fontsize=14, fontweight='bold')
    ax2.set_xscale('log')
    ax2.grid(True, alpha=0.3)
    ax2.fill_between(sequence_lengths, 1, speedup, alpha=0.3, color='#1E64C8')
    
    plt.tight_layout()
    plt.show()
    
    # Print comparison table
    print("\n📊 Memory Usage Comparison Table:")
    print("=" * 80)
    print(f"{'Seq Length':<12} {'Standard (GB)':<18} {'Linear (GB)':<15} {'Mem. ratio':<10}")
    print("-" * 80)
    
    for i, n in enumerate([512, 2048, 8192, 32768, 100000]):
        idx = sequence_lengths.index(n)
        print(f"{n:<12,} {standard_mem[idx]:<18.4f} {linear_mem[idx]:<15.6f} {speedup[idx]:<10.1f}x")
    
    print("\n💡 Key Insights:")
    print("=" * 80)
    idx_100k = sequence_lengths.index(100000)
    print(f"• At 100K tokens (full patient history):")
    print(f"  - Standard attention: {standard_mem[idx_100k]:.2f} GB")
    print(f"  - Linear attention: {linear_mem[idx_100k]:.4f} GB")
    print(f"  - Efficiency gain: {speedup[idx_100k]:.0f}x less memory!")
    print(f"\n• This is why Flash Attention and Mamba are crucial for medical AI")

compare_attention_complexity()

---
## Practice 4: Flash Attention Simulation

### 🎯 Learning Objectives
- Understand block-wise computation
- Simulate memory-efficient attention
- Visualize the difference in computation patterns

### 📖 Key Concepts
**Flash Attention Innovation:**
- Block-wise computation in SRAM (fast memory)
- Minimize HBM (slow memory) access
- Achieve exact attention with O(n) memory
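
Block-wise processing can still give exact attention because the softmax can be computed incrementally: keep a running maximum, a running denominator, and a running output, and rescale them whenever a new block raises the maximum. The sketch below shows that idea for a single query with scalar values, in plain Python. It is a simplification for readability, not the actual GPU kernel, and the function names and data are hypothetical.

```python
import math

def attention_full(scores, values):
    """Reference: softmax over all scores at once, then a weighted sum of the values."""
    m = max(scores)
    weights = [math.exp(s - m) for s in scores]
    total = sum(weights)
    return sum(w * v for w, v in zip(weights, values)) / total

def attention_blockwise(scores, values, block_size):
    """Same result, visiting keys one block at a time with an online softmax."""
    running_max = float("-inf")
    running_sum = 0.0   # softmax denominator, relative to running_max
    running_out = 0.0   # unnormalized output, relative to running_max
    for start in range(0, len(scores), block_size):
        s_blk = scores[start:start + block_size]
        v_blk = values[start:start + block_size]
        new_max = max(running_max, max(s_blk))
        # Rescale what has been accumulated so far to the new maximum
        scale = math.exp(running_max - new_max)
        running_sum *= scale
        running_out *= scale
        for s, v in zip(s_blk, v_blk):
            w = math.exp(s - new_max)
            running_sum += w
            running_out += w * v
        running_max = new_max
    return running_out / running_sum

scores = [0.1, 2.0, -1.0, 0.5, 3.0, 0.0, 1.5, -0.5]
values = [1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0]
full = attention_full(scores, values)
tiled = attention_blockwise(scores, values, block_size=3)
print(f"full={full:.6f}  blockwise={tiled:.6f}")
```

Only one block of scores is held at a time, yet the two results agree up to floating-point rounding, which is the sense in which the block-wise scheme is exact and not an approximation.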

In [None]:
# 4.1 Simulate Flash Attention Block-wise Processing
def simulate_flash_attention():
    """Simulate block-wise attention computation"""
    
    seq_length = 128
    block_size = 32
    num_blocks = seq_length // block_size
    
    print("⚡ Flash Attention: Block-wise Computation")
    print("=" * 60)
    print(f"Sequence length: {seq_length} tokens")
    print(f"Block size: {block_size} tokens")
    print(f"Number of blocks: {num_blocks}")
    
    # Simulate attention matrix computation
    # Standard: Load entire n×n matrix
    # Flash: Process in blocks
    
    fig, axes = plt.subplots(1, 2, figsize=(14, 6))
    
    # Standard Attention Pattern
    standard_pattern = np.random.rand(seq_length, seq_length)
    # Apply causal mask
    mask = np.triu(np.ones((seq_length, seq_length)), k=1)
    standard_pattern = np.where(mask, 0, standard_pattern)
    
    im1 = axes[0].imshow(standard_pattern, cmap='YlOrRd', aspect='auto')
    axes[0].set_title('Standard Attention\n(Load entire matrix)', fontsize=12, fontweight='bold')
    axes[0].set_xlabel('Key/Value Position')
    axes[0].set_ylabel('Query Position')
    plt.colorbar(im1, ax=axes[0], label='Attention Weight')
    
    # Flash Attention Pattern with blocks
    flash_pattern = standard_pattern.copy()
    
    # Draw block boundaries
    for i in range(0, seq_length, block_size):
        axes[1].axhline(y=i, color='blue', linewidth=2, alpha=0.7)
        axes[1].axvline(x=i, color='blue', linewidth=2, alpha=0.7)
    
    im2 = axes[1].imshow(flash_pattern, cmap='YlOrRd', aspect='auto')
    axes[1].set_title('Flash Attention\n(Block-wise processing)', fontsize=12, fontweight='bold')
    axes[1].set_xlabel('Key/Value Position')
    axes[1].set_ylabel('Query Position')
    plt.colorbar(im2, ax=axes[1], label='Attention Weight')
    
    # Highlight one block
    highlight_block = 1
    start = highlight_block * block_size
    end = start + block_size
    rect = plt.Rectangle((start, start), block_size, block_size, 
                         fill=False, edgecolor='green', linewidth=3)
    axes[1].add_patch(rect)
    axes[1].annotate('Current block\nin SRAM', xy=(start+block_size/2, start-5),
                    ha='center', fontsize=10, color='green', fontweight='bold')
    
    plt.tight_layout()
    plt.show()
    
    # Calculate memory savings
    standard_memory = seq_length * seq_length * 4 / (1024**2)  # MB
    flash_memory = block_size * block_size * 4 / (1024**2)  # MB
    
    print("\n💾 Memory Usage:")
    print("=" * 60)
    print(f"Standard Attention: {standard_memory:.2f} MB (entire matrix)")
    print(f"Flash Attention: {flash_memory:.2f} MB (one block at a time)")
    print(f"Memory reduction: {standard_memory/flash_memory:.1f}x")
    
    print("\n🚀 Speed Benefits:")
    print("=" * 60)
    print(f"• The n×n score matrix is handled as {num_blocks**2} tiles of {block_size}×{block_size} and is never written to HBM in full")
    print("• Computation stays in fast SRAM")
    print("• Result: faster wall-clock attention; the actual speedup depends on hardware and sequence length")

simulate_flash_attention()

---
## Practice 5: Retrieval-Augmented Generation (RAG)

### 🎯 Learning Objectives
- Implement a simple RAG pipeline
- Understand vector similarity search
- Combine retrieval with generation

### 📖 Key Concepts
**RAG Components:**
1. Document embedding and storage
2. Query encoding
3. Similarity-based retrieval
4. Context-augmented generation
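
The four components can be sketched end to end with a toy embedding. The example below uses word counts as a stand-in for a learned encoder, so unlike random vectors the ranking actually depends on the query text. All function names (`embed`, `cosine`, `retrieve`, `build_prompt`) are hypothetical, and the final "generation" step is reduced to assembling the prompt a language model would receive.

```python
import math
import re
from collections import Counter

def embed(text):
    """Toy embedding: lowercase word counts (a stand-in for a learned encoder)."""
    return Counter(re.findall(r"[a-z0-9]+", text.lower()))

def cosine(a, b):
    """Cosine similarity between two sparse count vectors."""
    dot = sum(a[w] * b[w] for w in a if w in b)
    norm_a = math.sqrt(sum(c * c for c in a.values()))
    norm_b = math.sqrt(sum(c * c for c in b.values()))
    return dot / (norm_a * norm_b) if norm_a and norm_b else 0.0

def retrieve(query, docs, top_k=2):
    """Return the top_k (score, document) pairs, most similar first."""
    q = embed(query)
    scored = [(cosine(q, embed(d)), d) for d in docs]
    scored.sort(key=lambda pair: pair[0], reverse=True)
    return scored[:top_k]

def build_prompt(query, retrieved):
    """Prepend retrieved documents as context for a downstream generator."""
    context = "\n".join(f"- {doc}" for _, doc in retrieved)
    return f"Context:\n{context}\n\nQuestion: {query}"

docs = [
    "Lab results: HbA1c 9.2% - poor glycemic control",
    "Chest X-ray shows opacity in lower right lobe - possible pneumonia",
    "Blood pressure 180/120 mmHg - hypertensive emergency",
]
hits = retrieve("chest x-ray opacity", docs, top_k=1)
print(build_prompt("chest x-ray opacity", hits))
```

Word overlap is a crude relevance signal (it would miss "high blood sugar" versus "elevated glucose"), which is the gap that dense encoders are meant to close.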

In [None]:
# 5.1 Simple RAG Implementation
class SimpleRAG:
    """Simple Retrieval-Augmented Generation system"""
    
    def __init__(self, embedding_dim: int = 128):
        self.embedding_dim = embedding_dim
        self.documents = []
        self.embeddings = []
    
    def add_documents(self, docs: List[str]):
        """Add documents to the knowledge base"""
        self.documents.extend(docs)
        
        # Placeholder embedding: a random unit vector per document, independent of its text
        # (in practice, use a learned encoder such as BERT)
        for doc in docs:
            # Simulate document embedding
            embedding = torch.randn(self.embedding_dim)
            # Normalize
            embedding = embedding / embedding.norm()
            self.embeddings.append(embedding)
    
    def retrieve(self, query: str, top_k: int = 3):
        """Retrieve most relevant documents"""
        # Encode query (simulated with a random unit vector, so `query` is ignored and rankings are arbitrary)
        query_embedding = torch.randn(self.embedding_dim)
        query_embedding = query_embedding / query_embedding.norm()
        
        # Compute similarities
        similarities = []
        for doc_emb in self.embeddings:
            similarity = torch.dot(query_embedding, doc_emb).item()
            similarities.append(similarity)
        
        # Get top-k
        top_k_indices = np.argsort(similarities)[-top_k:][::-1]
        top_k_docs = [self.documents[i] for i in top_k_indices]
        top_k_scores = [similarities[i] for i in top_k_indices]
        
        return top_k_docs, top_k_scores

# Test RAG system
def test_rag():
    """Test RAG with medical documents"""
    
    # Sample medical knowledge base
    medical_docs = [
        "Patient presents with elevated glucose levels and frequent urination - Type 2 Diabetes suspected",
        "Chest X-ray shows opacity in lower right lobe - possible pneumonia",
        "EKG reveals ST-segment elevation - acute myocardial infarction",
        "MRI scan indicates lesion in frontal lobe - further neurological assessment needed",
        "Blood pressure 180/120 mmHg - hypertensive emergency",
        "Biopsy results show abnormal cell growth - malignancy confirmed",
        "Patient history includes multiple hospital admissions for heart failure",
        "Lab results: HbA1c 9.2% - poor glycemic control",
        "Genetic testing reveals BRCA1 mutation - increased cancer risk",
        "Patient exhibits symptoms of depression and anxiety"
    ]
    
    # Create RAG system
    rag = SimpleRAG(embedding_dim=128)
    rag.add_documents(medical_docs)
    
    # Test queries
    queries = [
        "Patient with high blood sugar",
        "Heart attack symptoms",
        "Brain imaging results"
    ]
    
    print("🔍 Retrieval-Augmented Generation Demo")
    print("=" * 70)
    print(f"Knowledge Base: {len(medical_docs)} medical documents")
    
    for query in queries:
        print(f"\n{'='*70}")
        print(f"Query: '{query}'")
        print("-" * 70)
        
        retrieved_docs, scores = rag.retrieve(query, top_k=3)
        
        print("\nTop 3 Retrieved Documents:")
        for i, (doc, score) in enumerate(zip(retrieved_docs, scores), 1):
            print(f"\n{i}. [Similarity: {score:.4f}]")
            print(f"   {doc}")
    
    # Visualize retrieval
    fig, ax = plt.subplots(figsize=(12, 6))
    
    query = "Patient with high blood sugar"
    retrieved_docs, scores = rag.retrieve(query, top_k=len(medical_docs))
    
    colors = ['green' if score > 0.5 else 'orange' if score > 0.3 else 'lightgray' 
              for score in scores]
    
    bars = ax.barh(range(len(scores)), scores, color=colors)
    ax.set_yticks(range(len(scores)))
    # Scores arrive sorted by similarity, so the bars are ranks, not original document indices
    ax.set_yticklabels([f"Rank {i+1}" for i in range(len(scores))], fontsize=9)
    ax.set_xlabel('Cosine Similarity Score', fontsize=11)
    ax.set_title(f'Document Relevance for Query: "{query}"', fontsize=12, fontweight='bold')
    ax.axvline(x=0.5, color='red', linestyle='--', linewidth=2, alpha=0.5, label='High relevance threshold')
    ax.legend()
    ax.grid(axis='x', alpha=0.3)
    
    plt.tight_layout()
    plt.show()
    
    print("\n💡 RAG Benefits for Medical AI:")
    print("=" * 70)
    print("✓ Access to vast medical knowledge without model retraining")
    print("✓ Up-to-date information from latest research")
    print("✓ Explainable: can cite retrieved sources")
    print("✓ Scalable: the knowledge base can grow without changing model parameters")
    
    return rag

rag_system = test_rag()

---
## Practice 6: Graph Neural Networks for Medical Data

### 🎯 Learning Objectives
- Understand graph representation of medical relationships
- Implement simple message passing
- Visualize disease-symptom-drug networks

### 📖 Key Concepts
**Medical Knowledge Graphs:**
- Nodes: diseases, symptoms, drugs, genes, proteins
- Edges: relationships (causes, treats, interacts)
- GNN: Learn representations by aggregating neighbor information
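
"Aggregating neighbor information" can be illustrated with one round of mean aggregation on a three-node graph. This sketch covers only the aggregation step; a trained GNN layer would additionally apply a learned weight matrix and a nonlinearity. The one-hot node features are hypothetical, and the two edges mirror Type 2 Diabetes relationships used in this practice.

```python
def message_passing_step(features, neighbors):
    """One round of mean aggregation: average each node's vector with its neighbors' vectors."""
    updated = {}
    for node, vec in features.items():
        group = [vec] + [features[n] for n in neighbors.get(node, [])]
        # Column-wise mean over the node itself and its neighbors
        updated[node] = [sum(col) / len(group) for col in zip(*group)]
    return updated

# Undirected toy graph
neighbors = {
    "Type 2 Diabetes": ["High Blood Glucose", "Metformin"],
    "High Blood Glucose": ["Type 2 Diabetes"],
    "Metformin": ["Type 2 Diabetes"],
}
# Hypothetical one-hot type features: [is_disease, is_symptom, is_drug]
features = {
    "Type 2 Diabetes": [1.0, 0.0, 0.0],
    "High Blood Glucose": [0.0, 1.0, 0.0],
    "Metformin": [0.0, 0.0, 1.0],
}
after = message_passing_step(features, neighbors)
print(after["Metformin"])
# [0.5, 0.0, 0.5]
```

After one step the drug node's representation already carries information about the disease it is connected to; stacking more steps spreads information over longer paths.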

In [None]:
# 6.1 Medical Knowledge Graph Visualization
def visualize_medical_graph():
    """Create and visualize a medical knowledge graph"""
    
    # Define nodes
    nodes = {
        'diseases': ['Type 2 Diabetes', 'Hypertension', 'Heart Disease'],
        'symptoms': ['High Blood Glucose', 'Increased Thirst', 'High BP', 'Chest Pain'],
        'drugs': ['Metformin', 'Insulin', 'Lisinopril', 'Aspirin'],
        'genes': ['TCF7L2', 'PPARG', 'ACE']
    }
    
    # Define edges (relationships)
    edges = [
        ('Type 2 Diabetes', 'High Blood Glucose', 'causes'),
        ('Type 2 Diabetes', 'Increased Thirst', 'causes'),
        ('Type 2 Diabetes', 'Metformin', 'treated_by'),
        ('Type 2 Diabetes', 'Insulin', 'treated_by'),
        ('Type 2 Diabetes', 'TCF7L2', 'associated'),
        ('Type 2 Diabetes', 'PPARG', 'associated'),
        ('Hypertension', 'High BP', 'causes'),
        ('Hypertension', 'Lisinopril', 'treated_by'),
        ('Hypertension', 'ACE', 'associated'),
        ('Heart Disease', 'Chest Pain', 'causes'),
        ('Heart Disease', 'Aspirin', 'treated_by'),
    ]
    
    print("🕸️ Medical Knowledge Graph")
    print("=" * 60)
    print(f"Total nodes: {sum(len(v) for v in nodes.values())}")
    print(f"Total edges: {len(edges)}")
    print("\nNode types:")
    for node_type, node_list in nodes.items():
        print(f"  • {node_type}: {len(node_list)}")
    
    # Create adjacency matrix
    all_nodes = []
    for node_list in nodes.values():
        all_nodes.extend(node_list)
    
    node_to_idx = {node: i for i, node in enumerate(all_nodes)}
    n_nodes = len(all_nodes)
    
    adj_matrix = np.zeros((n_nodes, n_nodes))
    for src, dst, rel_type in edges:
        i, j = node_to_idx[src], node_to_idx[dst]
        adj_matrix[i, j] = 1
    
    # Visualize
    fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(15, 6))
    
    # Plot 1: Adjacency matrix
    im = ax1.imshow(adj_matrix, cmap='YlOrRd', aspect='auto')
    ax1.set_title('Knowledge Graph Adjacency Matrix', fontsize=12, fontweight='bold')
    ax1.set_xlabel('Node Index')
    ax1.set_ylabel('Node Index')
    
    # Add colorbar
    plt.colorbar(im, ax=ax1, label='Connection (1=connected)')
    
    # Plot 2: Node degree distribution
    degrees = adj_matrix.sum(axis=1) + adj_matrix.sum(axis=0)
    
    colors_by_type = []
    labels_by_type = []
    start_idx = 0
    color_map = {'diseases': '#ff6b6b', 'symptoms': '#4ecdc4', 
                 'drugs': '#95e1d3', 'genes': '#feca57'}
    
    for node_type, node_list in nodes.items():
        end_idx = start_idx + len(node_list)
        colors_by_type.extend([color_map[node_type]] * len(node_list))
        labels_by_type.extend([node_type] * len(node_list))
        start_idx = end_idx
    
    bars = ax2.bar(range(n_nodes), degrees, color=colors_by_type)
    ax2.set_xlabel('Node Index', fontsize=11)
    ax2.set_ylabel('Node Degree (# connections)', fontsize=11)
    ax2.set_title('Node Connectivity', fontsize=12, fontweight='bold')
    ax2.grid(axis='y', alpha=0.3)
    
    # Add legend
    from matplotlib.patches import Patch
    legend_elements = [Patch(facecolor=color, label=node_type) 
                      for node_type, color in color_map.items()]
    ax2.legend(handles=legend_elements, loc='upper right')
    
    plt.tight_layout()
    plt.show()
    
    print("\n📊 Graph Statistics:")
    print("=" * 60)
    print(f"Average degree: {degrees.mean():.2f}")
    print(f"Max degree: {degrees.max():.0f}")
    print(f"Most connected node: {all_nodes[int(degrees.argmax())]}")
    
    print("\n💡 GNN Benefits:")
    print("=" * 60)
    print("✓ Model complex medical relationships")
    print("✓ Drug repurposing through graph analysis")
    print("✓ Predict disease comorbidities")
    print("✓ Discover new biomarkers")
    
    return adj_matrix, all_nodes

adj_matrix, graph_nodes = visualize_medical_graph()

---
## Practice 7: State Space Models (Mamba-style)

### 🎯 Learning Objectives
- Understand linear complexity sequence modeling
- Compare RNN, Transformer, and State Space approaches
- Simulate continuous-time dynamics

### 📖 Key Concepts
**State Space Models (SSM):**
- Continuous-time dynamics: h'(t) = Ah(t) + Bx(t)
- Linear complexity: O(n) vs Transformer's O(n²)
- Selective mechanism: Input-dependent transitions
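
To run the continuous-time equation on sampled data it has to be discretized. The sketch below applies the zero-order-hold rule to a scalar system, giving `a_bar = exp(a*dt)` and `b_bar = (a_bar - 1) / a * b`, then runs the resulting recurrence. The scalar setting and the values of `a`, `b`, and `dt` are assumptions made for illustration; the `SimpleSSM` class in this practice skips this step and learns discrete-time matrices directly.

```python
import math

def discretize_zoh(a, b, dt):
    """Zero-order-hold discretization of h'(t) = a*h(t) + b*x(t) for scalar a != 0."""
    a_bar = math.exp(a * dt)
    b_bar = (a_bar - 1.0) / a * b
    return a_bar, b_bar

def scan(a_bar, b_bar, inputs, h0=0.0):
    """Sequential recurrence h_t = a_bar * h_{t-1} + b_bar * x_t."""
    h, states = h0, []
    for x in inputs:
        h = a_bar * h + b_bar * x
        states.append(h)
    return states

a, b, dt = -0.5, 1.0, 0.1                  # hypothetical stable system (a < 0)
a_bar, b_bar = discretize_zoh(a, b, dt)
states = scan(a_bar, b_bar, [1.0] * 200)   # constant input x(t) = 1 for 20 time units
print(f"a_bar={a_bar:.4f}, b_bar={b_bar:.4f}, final state={states[-1]:.4f}")
# a_bar=0.9512, b_bar=0.0975, final state=1.9999
```

With a constant input the state approaches the continuous-time steady state `-b/a = 2`, and each step costs the same amount of work no matter how long the sequence is, which is where the O(n) total cost comes from.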

In [None]:
# 7.1 Simple State Space Model
class SimpleSSM(nn.Module):
    """Simplified State Space Model"""
    
    def __init__(self, d_model: int, d_state: int = 16):
        super().__init__()
        self.d_model = d_model
        self.d_state = d_state
        
        # State space parameters
        self.A = nn.Parameter(torch.randn(d_state, d_state) * 0.01)
        self.B = nn.Parameter(torch.randn(d_state, d_model) * 0.01)
        self.C = nn.Parameter(torch.randn(d_model, d_state) * 0.01)
        self.D = nn.Parameter(torch.randn(d_model) * 0.01)
    
    def forward(self, x):
        """Forward pass with linear scan"""
        batch, length, dim = x.shape
        
        # Initialize hidden state
        h = torch.zeros(batch, self.d_state, device=x.device)
        outputs = []
        
        # Sequential scan: one fixed-cost update per time step, so total cost is linear in length
        for t in range(length):
            # State update: h(t) = A·h(t-1) + B·x(t)
            h = h @ self.A.T + x[:, t] @ self.B.T
            
            # Output: y(t) = C·h(t) + D·x(t)
            y = h @ self.C.T + x[:, t] * self.D
            outputs.append(y)
        
        return torch.stack(outputs, dim=1)

# Test SSM
def test_ssm():
    """Test State Space Model"""
    
    # Simulate medical time-series (e.g., continuous glucose monitoring)
    batch_size = 1
    seq_length = 1000  # 1000 time steps
    d_model = 8
    
    # Generate synthetic time-series
    t = torch.linspace(0, 10, seq_length)
    # Multiple sine waves with different frequencies (simulating physiological signals)
    x = torch.zeros(batch_size, seq_length, d_model)
    for i in range(d_model):
        freq = 0.5 + i * 0.3
        x[0, :, i] = torch.sin(2 * np.pi * freq * t) + torch.randn(seq_length) * 0.1
    
    # Create SSM
    ssm = SimpleSSM(d_model=d_model, d_state=32)
    
    print("🔄 State Space Model Demo")
    print("=" * 60)
    print(f"Input shape: {x.shape}")
    print(f"Model parameters: {sum(p.numel() for p in ssm.parameters())}")
    
    # Forward pass
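    import time
    # perf_counter is a monotonic, high-resolution clock intended for timing short code sections
    start = time.perf_counter()
    with torch.no_grad():
        output = ssm(x)
    elapsed = time.perf_counter() - start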
    
    print(f"Output shape: {output.shape}")
    print(f"Processing time: {elapsed*1000:.2f} ms")
    
    # Visualize
    fig, axes = plt.subplots(2, 2, figsize=(14, 8))
    
    # Plot input signals
    for i in range(min(3, d_model)):
        axes[0, 0].plot(t.numpy(), x[0, :, i].numpy(), label=f'Signal {i+1}', alpha=0.7)
    axes[0, 0].set_title('Input Time Series (Medical Signals)', fontsize=11, fontweight='bold')
    axes[0, 0].set_xlabel('Time')
    axes[0, 0].set_ylabel('Value')
    axes[0, 0].legend()
    axes[0, 0].grid(alpha=0.3)
    
    # Plot output
    for i in range(min(3, d_model)):
        axes[0, 1].plot(t.numpy(), output[0, :, i].detach().numpy(), 
                       label=f'Output {i+1}', alpha=0.7)
    axes[0, 1].set_title('SSM Output', fontsize=11, fontweight='bold')
    axes[0, 1].set_xlabel('Time')
    axes[0, 1].set_ylabel('Value')
    axes[0, 1].legend()
    axes[0, 1].grid(alpha=0.3)
    
    # State matrix A visualization
    im1 = axes[1, 0].imshow(ssm.A.detach().numpy(), cmap='RdBu', aspect='auto')
    axes[1, 0].set_title('State Transition Matrix A', fontsize=11, fontweight='bold')
    axes[1, 0].set_xlabel('State Dimension')
    axes[1, 0].set_ylabel('State Dimension')
    plt.colorbar(im1, ax=axes[1, 0])
    
    # Complexity comparison
    seq_lengths = [100, 500, 1000, 5000, 10000, 50000, 100000]
    transformer_complexity = [n**2 for n in seq_lengths]
    ssm_complexity = [n for n in seq_lengths]
    
    axes[1, 1].plot(seq_lengths, transformer_complexity, 'o-', label='Transformer O(n²)', 
                   color='#dc3545', linewidth=2)
    axes[1, 1].plot(seq_lengths, ssm_complexity, 's-', label='SSM O(n)', 
                   color='#28a745', linewidth=2)
    axes[1, 1].set_xscale('log')
    axes[1, 1].set_yscale('log')
    axes[1, 1].set_xlabel('Sequence Length', fontsize=11)
    axes[1, 1].set_ylabel('Computational Complexity', fontsize=11)
    axes[1, 1].set_title('Complexity Comparison', fontsize=11, fontweight='bold')
    axes[1, 1].legend()
    axes[1, 1].grid(True, alpha=0.3)
    
    plt.tight_layout()
    plt.show()
    
    print("\n💡 SSM Advantages for Medical AI:")
    print("=" * 60)
    print("✓ Linear complexity: O(n) vs Transformer's O(n²)")
    print("✓ Well suited to long medical time-series (ICU monitoring, EEG)")
    print("✓ Continuous-time modeling of physiological dynamics")
    print("✓ Efficient in both training and inference")
    
    return ssm

ssm_model = test_ssm()

---
## Practice 8: Performance Comparison

### 🎯 Learning Objectives
- Compare all architectures on key metrics
- Understand trade-offs between approaches
- Make informed architecture choices

### 📖 Key Metrics
- **Computational Complexity:** Time and memory
- **Accuracy:** Task-specific performance
- **Scalability:** How well does it scale?
- **Deployment:** Practical considerations
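
Time complexity claims are worth checking empirically. The helper below is a minimal timing sketch: it reports the median wall-clock time over several runs after a warm-up call, applied to two toy workloads that stand in for quadratic and linear models. The function names and workloads are hypothetical, and absolute timings will vary by machine.

```python
import statistics
import time

def median_runtime(fn, *args, repeats=5, warmup=1):
    """Median wall-clock seconds of fn(*args) over `repeats` runs, after `warmup` untimed runs."""
    for _ in range(warmup):
        fn(*args)
    samples = []
    for _ in range(repeats):
        start = time.perf_counter()
        fn(*args)
        samples.append(time.perf_counter() - start)
    return statistics.median(samples)

def pairwise_work(n):
    """Toy O(n^2) workload: touch every (i, j) pair once."""
    return sum(1 for i in range(n) for j in range(n))

def linear_work(n):
    """Toy O(n) workload: touch every index once."""
    return sum(1 for i in range(n))

for n in (100, 200, 400):
    t_quad = median_runtime(pairwise_work, n)
    t_lin = median_runtime(linear_work, n)
    print(f"n={n:>4}  quadratic={t_quad * 1e3:.3f} ms  linear={t_lin * 1e3:.3f} ms")
```

Doubling `n` should roughly quadruple the quadratic timing and roughly double the linear one; the median is used because single measurements are easily distorted by background activity.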

In [None]:
# 8.1 Comprehensive Architecture Comparison
def architecture_comparison():
    """Compare different architectures using hard-coded values (not measured in this notebook)"""
    
    architectures = {
        'Dense Transformer': {
            'complexity_time': 'O(n²)',
            'complexity_memory': 'O(n²)',
            'max_seq_length': 2048,
            'accuracy': 0.92,
            'inference_speed': 1.0,
            'color': '#dc3545'
        },
        'MoE': {
            'complexity_time': 'O(n²/N×K)',
            'complexity_memory': 'O(n²/N×K)',
            'max_seq_length': 2048,
            'accuracy': 0.94,
            'inference_speed': 2.5,
            'color': '#ffc107'
        },
        'Flash Attention': {
            'complexity_time': 'O(n²)',
            'complexity_memory': 'O(n)',
            'max_seq_length': 100000,
            'accuracy': 0.92,
            'inference_speed': 7.0,
            'color': '#1E64C8'
        },
        'Mamba (SSM)': {
            'complexity_time': 'O(n)',
            'complexity_memory': 'O(n)',
            'max_seq_length': 1000000,
            'accuracy': 0.91,
            'inference_speed': 10.0,
            'color': '#28a745'
        },
        'Graph Transformer': {
            'complexity_time': 'O(E)',
            'complexity_memory': 'O(V+E)',
            'max_seq_length': 'N/A',
            'accuracy': 0.89,
            'inference_speed': 3.0,
            'color': '#a29bfe'
        }
    }
    
    print("📊 Architecture Comparison Table")
    print("=" * 90)
    print(f"{'Architecture':<20} {'Time':<15} {'Memory':<15} {'Max Length':<12} {'Accuracy':<10} {'Speed':<8}")
    print("-" * 90)
    
    for name, specs in architectures.items():
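        # Format the speed with its "x" suffix first so the suffix is not separated by padding
        speed = f"{specs['inference_speed']:.1f}x"
        print(f"{name:<20} {specs['complexity_time']:<15} {specs['complexity_memory']:<15} "
              f"{str(specs['max_seq_length']):<12} {specs['accuracy']:<10.2f} {speed:<8}")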
    
    # Visualization
    fig = plt.figure(figsize=(16, 10))
    gs = fig.add_gridspec(3, 3, hspace=0.3, wspace=0.3)
    
    # Plot 1: Accuracy comparison
    ax1 = fig.add_subplot(gs[0, :])
    names = list(architectures.keys())
    accuracies = [specs['accuracy'] for specs in architectures.values()]
    colors = [specs['color'] for specs in architectures.values()]
    
    bars = ax1.bar(names, accuracies, color=colors, alpha=0.7, edgecolor='black', linewidth=1.5)
    ax1.set_ylabel('Accuracy', fontsize=12)
    ax1.set_title('Illustrative Model Accuracy on Medical Tasks', fontsize=14, fontweight='bold')
    ax1.set_ylim([0.85, 0.95])
    ax1.axhline(y=0.90, color='red', linestyle='--', alpha=0.5, label='90% threshold')
    ax1.legend()
    ax1.grid(axis='y', alpha=0.3)
    
    # Add value labels
    for bar in bars:
        height = bar.get_height()
        ax1.text(bar.get_x() + bar.get_width()/2., height,
                f'{height:.2f}', ha='center', va='bottom', fontsize=10, fontweight='bold')
    
    # Plot 2: Inference speed
    ax2 = fig.add_subplot(gs[1, 0])
    speeds = [specs['inference_speed'] for specs in architectures.values()]
    ax2.barh(names, speeds, color=colors, alpha=0.7, edgecolor='black', linewidth=1.5)
    ax2.set_xlabel('Relative Speed (× Dense Transformer)', fontsize=11)
    ax2.set_title('Inference Speed', fontsize=12, fontweight='bold')
    ax2.grid(axis='x', alpha=0.3)
    
    # Plot 3: Max sequence length
    ax3 = fig.add_subplot(gs[1, 1])
    max_lengths = [specs['max_seq_length'] if isinstance(specs['max_seq_length'], int) 
                   else 0 for specs in architectures.values()]
    valid_names = [name for name, length in zip(names, max_lengths) if length > 0]
    valid_lengths = [length for length in max_lengths if length > 0]
    valid_colors = [color for color, length in zip(colors, max_lengths) if length > 0]
    
    ax3.barh(valid_names, valid_lengths, color=valid_colors, alpha=0.7, 
             edgecolor='black', linewidth=1.5)
    ax3.set_xlabel('Max Sequence Length (tokens)', fontsize=11)
    ax3.set_title('Context Window Size', fontsize=12, fontweight='bold')
    ax3.set_xscale('log')
    ax3.grid(axis='x', alpha=0.3)
    
    # Add 100K marker
    ax3.axvline(x=100000, color='blue', linestyle='--', linewidth=2, alpha=0.5)
    ax3.text(100000, len(valid_names)-0.5, '100K\n(Patient history)', 
            ha='left', va='center', fontsize=9, color='blue')
    
    # Plot 4: Use case matrix
    ax4 = fig.add_subplot(gs[1, 2])
    use_cases = ['Short Seq', 'Long Seq', 'Graph Data', 'Time Series', 'Real-time']
    suitability = np.array([
        [1.0, 0.3, 0.2, 0.7, 0.4],  # Dense Transformer
        [1.0, 0.3, 0.2, 0.7, 0.6],  # MoE
        [0.9, 1.0, 0.2, 0.8, 0.7],  # Flash Attention
        [0.8, 1.0, 0.2, 1.0, 0.9],  # Mamba
        [0.6, 0.5, 1.0, 0.6, 0.7],  # Graph Transformer
    ])
    
    im = ax4.imshow(suitability, cmap='RdYlGn', aspect='auto', vmin=0, vmax=1)
    ax4.set_xticks(range(len(use_cases)))
    ax4.set_yticks(range(len(names)))
    ax4.set_xticklabels(use_cases, rotation=45, ha='right', fontsize=9)
    ax4.set_yticklabels(names, fontsize=9)
    ax4.set_title('Suitability Matrix', fontsize=12, fontweight='bold')
    plt.colorbar(im, ax=ax4, label='Suitability (0-1)')
    
    # Add text annotations
    for i in range(len(names)):
        for j in range(len(use_cases)):
            ax4.text(j, i, f'{suitability[i, j]:.1f}',
                     ha="center", va="center", color="black", fontsize=8)
    
    # Plot 5: Radar chart comparison
    ax5 = fig.add_subplot(gs[2, :], projection='polar')
    
    categories = ['Accuracy', 'Speed', 'Memory Eff.', 'Scalability', 'Versatility']
    N = len(categories)
    
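    # Scores on a 0-1 scale: accuracy is the raw value, speed is inference_speed / 10,
    # and the remaining axes are hand-assigned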
    scores = {
        'Dense Transformer': [0.92, 0.10, 0.30, 0.20, 0.90],
        'MoE': [0.94, 0.25, 0.35, 0.60, 0.85],
        'Flash Attention': [0.92, 0.70, 0.95, 0.90, 0.90],
        'Mamba (SSM)': [0.91, 1.00, 1.00, 1.00, 0.80],
    }
    
    angles = [n / float(N) * 2 * np.pi for n in range(N)]
    angles += angles[:1]
    
    for name, score in scores.items():
        values = score + score[:1]
        ax5.plot(angles, values, 'o-', linewidth=2, label=name, 
                color=architectures[name]['color'])
        ax5.fill(angles, values, alpha=0.15, color=architectures[name]['color'])
    
    ax5.set_xticks(angles[:-1])
    ax5.set_xticklabels(categories, fontsize=10)
    ax5.set_ylim(0, 1)
    ax5.set_title('Overall Performance Profile', fontsize=13, fontweight='bold', pad=20)
    ax5.legend(loc='upper right', bbox_to_anchor=(1.3, 1.1), fontsize=9)
    ax5.grid(True)
    
    plt.tight_layout()
    plt.show()
    
    print("\n💡 Architecture Selection Guide:")
    print("=" * 90)
    print("✓ Short sequences (<2K tokens): Dense Transformer or MoE")
    print("✓ Long context (100K+ tokens): Flash Attention or Mamba")
    print("✓ Graph/relational data: Graph Transformers")
    print("✓ Time-series (ICU monitoring): Mamba (SSM)")
    print("✓ Large-scale deployment: MoE with Flash Attention")
    print("✓ Real-time inference: Mamba or optimized MoE")

architecture_comparison()
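
The speed figures in the cell above are hard-coded. One way to replace them with measured values is a small timing helper like the sketch below. `median_latency` and `relative_speed` are hypothetical helper names introduced here, and the two toy workloads only stand in for real model forward passes.

```python
import statistics
import time
from typing import Callable

def median_latency(fn: Callable[[], object], repeats: int = 5, warmup: int = 1) -> float:
    """Return the median wall-clock time in seconds of calling fn()."""
    for _ in range(warmup):
        fn()  # warm-up calls are not timed
    timings = []
    for _ in range(repeats):
        start = time.perf_counter()
        fn()
        timings.append(time.perf_counter() - start)
    return statistics.median(timings)

def relative_speed(baseline_s: float, candidate_s: float) -> float:
    """Speed-up of the candidate over the baseline (2.0 means twice as fast)."""
    return baseline_s / candidate_s

# Toy stand-ins for a quadratic and a linear workload (assumption: swap in model calls).
n = 300
quadratic = lambda: sum(i * j for i in range(n) for j in range(n))
linear = lambda: sum(i * i for i in range(n))

base = median_latency(quadratic)
cand = median_latency(linear)
print(f"linear workload is ~{relative_speed(base, cand):.0f}x faster than quadratic")
```

When timing PyTorch models on a GPU, call `torch.cuda.synchronize()` before reading the timer, since CUDA kernels launch asynchronously.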

---
## 🎯 Practice Complete!

### Summary of What We Learned:

1. **Mixture of Experts (MoE)**
   - Sparse activation enables efficient scaling
   - Top-K routing selects specialized experts
   - 75% computation reduction with minimal accuracy loss

2. **Long-Context Models**
   - Flash Attention: O(n) memory vs standard O(n²)
   - Essential for processing 100K+ token patient histories
   - 5-9x faster than standard attention

3. **Novel Architectures**
   - **RAG**: Access to an external, updatable knowledge base
   - **Graph Transformers**: Model medical relationships
   - **Mamba (SSM)**: Linear complexity for time-series

4. **Performance Comparison**
   - Each architecture excels at different tasks
   - Trade-offs between accuracy, speed, and scalability
   - Hybrid approaches often yield best results
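
To see where a figure like the 75% reduction in item 1 can come from, the arithmetic below computes the share of expert compute used under top-K routing. It assumes 8 experts with top-2 routing, which is one configuration that yields 75% and is not read from the earlier cells. It also ignores the gating network and the attention layers, which MoE does not sparsify.

```python
def active_expert_fraction(num_experts: int, top_k: int) -> float:
    """Fraction of expert feed-forward compute used per token under top-k routing."""
    if not 0 < top_k <= num_experts:
        raise ValueError("top_k must be between 1 and num_experts")
    return top_k / num_experts

num_experts, top_k = 8, 2  # assumed configuration, not taken from the notebook
fraction = active_expert_fraction(num_experts, top_k)
print(f"Active expert compute: {fraction:.0%}, reduction: {1 - fraction:.0%}")
```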

### Key Insights:

- **Efficiency matters**: Flash Attention and Mamba enable practical long-context processing
- **Specialization helps**: MoE and Graph models leverage domain structure
- **No single winner**: Choose architecture based on your specific use case
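
As a rough illustration of choosing by use case, the selection guide printed in Practice 8 can be written as a small rule function. `suggest_architecture` and the `data_kind` labels are hypothetical, the 2,048-token cutoff is an assumed reading of "<2K tokens", and the rule order is an arbitrary choice. The rules mirror the guide and are not derived from measurements.

```python
from typing import Optional

def suggest_architecture(data_kind: str,
                         seq_len: Optional[int] = None,
                         realtime: bool = False) -> str:
    """Map a use case to an architecture family, following the selection guide."""
    if data_kind == "graph":
        return "Graph Transformer"
    if data_kind == "time_series":
        return "Mamba (SSM)"
    if realtime:
        return "Mamba (SSM) or optimized MoE"
    if seq_len is not None and seq_len >= 100_000:
        return "Flash Attention or Mamba (SSM)"
    if seq_len is not None and seq_len < 2_048:
        return "Dense Transformer or MoE"
    # The guide says nothing about the range between 2K and 100K tokens.
    return "No clear default: benchmark candidates on your data"

print(suggest_architecture("text", seq_len=512))
print(suggest_architecture("text", seq_len=200_000))
print(suggest_architecture("graph"))
```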

### Medical AI Applications:

- **MoE**: Multi-specialty diagnostic systems
- **Long-Context**: Comprehensive patient history analysis
- **RAG**: Always up-to-date medical knowledge
- **Graph**: Drug discovery and disease networks
- **SSM**: Continuous monitoring (ICU, wearables)

### Next Steps:

1. Experiment with pre-trained models (Hugging Face)
2. Fine-tune on medical datasets
3. Explore hybrid architectures
4. Stay updated on emerging techniques (2025+)

### 🔮 Future Outlook (2025-2030):

- **2025-2027**: 1M+ context models, clinical MoE deployment
- **2027-2029**: Quantum ML begins, neuromorphic devices
- **2030+**: Hybrid bio-AI systems, personalized AI for every patient

---

## 📚 Further Resources

### Papers to Read:
- Switch Transformers (Fedus et al., 2021)
- FlashAttention-2 (Dao, 2023)
- Mamba: Linear-Time Sequence Modeling with Selective State Spaces (Gu & Dao, 2023)
- A Generalization of Transformer Networks to Graphs (Dwivedi & Bresson, 2020)

### Libraries to Explore:
- **DeepSpeed-MoE**: Efficient MoE training
- **Flash Attention**: Fast attention implementation
- **PyTorch Geometric**: Graph neural networks
- **Hugging Face Transformers**: Pre-trained models

### Datasets:
- MIMIC-III / MIMIC-IV: ICU time-series
- MedQA: Medical question answering
- BioASQ: Biomedical semantic indexing

---

**🚀 Keep exploring and building the future of Medical AI!**