# NanoGPT Bayesian Neural Network Inference

This notebook applies Bayesian Neural Network (BNN) inference to a trained NanoGPT model using the `posteriors` library. We'll load a pre-trained model from the checkpoint folder and perform variational inference to learn a posterior distribution over the model parameters.

## Overview
- Load trained NanoGPT model and tokenizer
- Set up Bayesian inference with variational inference (VI)
- Train posterior distribution over model parameters
- Compare deterministic vs. Bayesian predictions
- Generate text with uncertainty quantification

---

## 1. Import Required Libraries

In [1]:
import os
import sys
import math
import numpy as np
from pathlib import Path
from typing import Dict, Tuple, List, Optional

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch import func

# Import Bayesian libraries
import torchopt
import posteriors

# Add paths for importing utilities and models
current_dir = Path.cwd()
sys.path.append(str(current_dir))
sys.path.append(str(current_dir / "baselines"))

# Import our utilities
from utils import load_model, load_tokenizer, encode, decode

print(f"Current directory: {current_dir}")
print(f"PyTorch version: {torch.__version__}")
print(f"CUDA available: {torch.cuda.is_available()}")

# Set device
device = 'cuda' if torch.cuda.is_available() else 'cpu'
print(f"Using device: {device}")

# Set random seeds for reproducibility
torch.manual_seed(42)
np.random.seed(42)

W0925 02:42:20.971000 5828 Lib\site-packages\torch\distributed\elastic\multiprocessing\redirects.py:29] NOTE: Redirects are currently not supported in Windows or MacOs.
  from optree.integration.torch import tree_ravel
  from optree.integration.torch import tree_ravel


Current directory: c:\Users\hayk_\OneDrive\Desktop\05_LMU_Masters\04_applied_dl\adl-bnn-textgen\notebooks
PyTorch version: 2.8.0+cpu
CUDA available: False
Using device: cpu


## 2. Configuration

In [4]:
# Configuration for Bayesian NanoGPT
CONFIG = {
    # Model paths - choose one of the available checkpoints
    'model_path': '../checkpoints/baseline_nanogpt/baseline_nanogpt.pt',
    'meta_path': '../checkpoints/baseline_nanogpt/nanogpt_meta.pkl',
    'data_dir': 'nanoGPT/data/shakespeare_char',
    
    # Alternative: Use token-level model if available
    # 'model_path': '../checkpoints/token_level_nanogpt/token_level_nanogpt.pt',
    # 'meta_path': '../checkpoints/token_level_nanogpt/token_level_meta.pkl',
    # 'data_dir': 'nanoGPT/data/shakespeare',
    
    # Bayesian inference parameters (reduced for debugging)
    'batch_size': 4,        # Reduced batch size
    'num_epochs': 2,        # Reduced epochs
    'learning_rate': 1e-3,  # Slightly higher learning rate
    'temperature': 1.0,     # Higher temperature
    'prior_std': 1.0,       # Prior standard deviation
    
    # Text generation parameters
    'max_new_tokens': 50,   # Reduced for faster testing
    'generation_temperature': 0.8,
    'num_samples': 3,       # Fewer samples for testing
    
    # Data parameters (much reduced for debugging)
    'max_seq_length': 64,   # Shorter sequences
    'train_samples': 20,    # Much fewer samples
}

print("Configuration (DEBUG MODE):")
for key, value in CONFIG.items():
    print(f"  {key}: {value}")

# Check if paths exist
print("\nChecking paths:")
for path_key in ['model_path', 'meta_path', 'data_dir']:
    path = Path(CONFIG[path_key])
    if path.exists():
        print(f"{path_key}: {path}")
    else:
        raise FileNotFoundError(f"{path_key} does not exist: {path}")


Configuration (DEBUG MODE):
  model_path: ../checkpoints/baseline_nanogpt/baseline_nanogpt.pt
  meta_path: ../checkpoints/baseline_nanogpt/nanogpt_meta.pkl
  data_dir: nanoGPT/data/shakespeare_char
  batch_size: 4
  num_epochs: 2
  learning_rate: 0.001
  temperature: 1.0
  prior_std: 1.0
  max_new_tokens: 50
  generation_temperature: 0.8
  num_samples: 3
  max_seq_length: 64
  train_samples: 20

Checking paths:
model_path: ..\checkpoints\baseline_nanogpt\baseline_nanogpt.pt
meta_path: ..\checkpoints\baseline_nanogpt\nanogpt_meta.pkl
data_dir: nanoGPT\data\shakespeare_char


## 3. Load Pre-trained NanoGPT Model

In [5]:
# Load the pre-trained NanoGPT model and tokenizer
print("Loading pre-trained NanoGPT model...")

try:
    # Load model
    model, checkpoint = load_model(Path(CONFIG['model_path']), device)
    print(f"Model loaded successfully!")
    print(f"Model parameters: {sum(p.numel() for p in model.parameters()):,}")
    
    # Load tokenizer
    stoi, itos = load_tokenizer(Path(CONFIG['meta_path']))
    vocab_size = len(itos)
    print(f"Vocabulary size: {vocab_size}")
    
    # Get model architecture info
    print(f"Model architecture:")
    print(f"  - Layers: {model.config.n_layer}")
    print(f"  - Heads: {model.config.n_head}")
    print(f"  - Embedding dim: {model.config.n_embd}")
    print(f"  - Block size: {model.config.block_size}")
    
    # Extract model parameters for posteriors
    params = dict(model.named_parameters())
    print(f"Number of parameter tensors: {len(params)}")
    
    # Show parameter tensor shapes
    print("Parameter shapes:")
    for name, param in list(params.items())[:5]:  # Show first 5
        print(f"  {name}: {param.shape}")
    if len(params) > 5:
        print(f"  ... and {len(params) - 5} more parameter tensors")
    
except Exception as e:
    print(f"Error loading model: {e}")
    model, params, stoi, itos = None, None, None, None

Loading pre-trained NanoGPT model...
Loading model from: ..\checkpoints\baseline_nanogpt\baseline_nanogpt.pt
Model arguments: {'n_layer': 6, 'n_head': 6, 'n_embd': 384, 'block_size': 256, 'bias': False, 'vocab_size': 65, 'dropout': 0.2}
Model arguments: {'n_layer': 6, 'n_head': 6, 'n_embd': 384, 'block_size': 256, 'bias': False, 'vocab_size': 65, 'dropout': 0.2}
number of parameters: 10.65M
Model loaded successfully!
Number of parameters: 10,745,088
Model loaded successfully!
Model parameters: 10,745,088
Vocabulary size: 65
Model architecture:
  - Layers: 6
  - Heads: 6
  - Embedding dim: 384
  - Block size: 256
Number of parameter tensors: 39
Parameter shapes:
  transformer.wte.weight: torch.Size([65, 384])
  transformer.wpe.weight: torch.Size([256, 384])
  transformer.h.0.ln_1.weight: torch.Size([384])
  transformer.h.0.attn.c_attn.weight: torch.Size([1152, 384])
  transformer.h.0.attn.c_proj.weight: torch.Size([384, 384])
  ... and 34 more parameter tensors
number of parameters: 10.

## 4. Prepare Training Data

In [6]:
# Prepare training data for Bayesian inference
print("Preparing training data...")

if model is not None:
    # Load training data
    train_data_path = Path(CONFIG['data_dir']) / 'train.bin'
    
    if train_data_path.exists():
        # Load binary data
        data = np.memmap(str(train_data_path), dtype=np.uint16, mode='r')
        print(f"Loaded data: {len(data):,} tokens")
        
        # Create training batches for next-token prediction
        def create_training_batches(data, batch_size, seq_length, num_samples):
            """Create training batches from the data for next-token prediction"""
            batches = []
            max_start = len(data) - seq_length - 1
            
            # Sample random starting positions
            start_indices = np.random.choice(max_start, size=num_samples, replace=False)
            
            for i in range(0, len(start_indices), batch_size):
                batch_starts = start_indices[i:i+batch_size]
                x_batch = []
                y_batch = []
                
                for start in batch_starts:
                    # For next-token prediction, x is the sequence and y is the next token
                    x_seq = data[start:start+seq_length].astype(np.int64)
                    # y is just the last token (next token prediction)
                    y_seq = data[start+seq_length:start+seq_length+1].astype(np.int64)
                    x_batch.append(x_seq)
                    y_batch.append(y_seq)
                
                x_tensor = torch.tensor(np.array(x_batch), device=device)
                y_tensor = torch.tensor(np.array(y_batch), device=device)
                batches.append((x_tensor, y_tensor))
            
            return batches
        
        # Create training batches
        training_batches = create_training_batches(
            data, 
            CONFIG['batch_size'], 
            CONFIG['max_seq_length'], 
            CONFIG['train_samples']
        )
        
        print(f"Created {len(training_batches)} training batches")
        print(f"Batch shape: {training_batches[0][0].shape}")
        print(f"Target shape: {training_batches[0][1].shape}")
        
        # Calculate number of data points for posteriors
        num_data = CONFIG['train_samples']
        print(f"Total training samples: {num_data}")
        
    else:
        print(f"Training data not found at {train_data_path}")
        training_batches = []
        num_data = 0
        
else:
    print("Model not loaded, skipping data preparation")
    training_batches = []
    num_data = 0

Preparing training data...
Loaded data: 1,003,854 tokens
Created 5 training batches
Batch shape: torch.Size([4, 64])
Target shape: torch.Size([4, 1])
Total training samples: 20


## 5. Define Log Posterior Function

In [8]:
# Define the log posterior function for posteriors library
def single_batch_loss(params, batch):
    """Compute loss for a single batch using functional_call"""
    x, y = batch
    
    # Forward pass through the model using functional call
    logits, _ = func.functional_call(model, params, (x,))
    
    # For next-token prediction, we expect:
    # - logits shape: [batch_size, 1, vocab_size] (predicting next token)
    # - targets shape: [batch_size, 1] (the actual next token)
    
    # Compute cross entropy loss
    batch_size, seq_length, vocab_size = logits.shape
    logits_flat = logits.view(-1, vocab_size)  # (batch_size * seq_length, vocab_size)
    targets_flat = y.view(-1)  # (batch_size * seq_length,)
    
    loss = F.cross_entropy(logits_flat, targets_flat, reduction='mean')
    
    return loss

def log_posterior_fn(params, batch):
    """
    Log posterior function compatible with posteriors library.
    
    Args:
        params: Model parameters dictionary
        batch: Tuple of (input_tokens, target_tokens)
        
    Returns:
        log_posterior_value: Scalar tensor
    """
    # Compute negative log likelihood
    nll = single_batch_loss(params, batch)
    
    # Compute log prior more safely (avoiding potential tensor iteration issues)
    log_prior = torch.tensor(0.0, device=device)
    for param in params.values():
        if param.requires_grad:
            # Use normal distribution for prior - more robust computation
            prior_dist = torch.distributions.Normal(0.0, CONFIG['prior_std'])
            param_log_prob = prior_dist.log_prob(param).sum()
            log_prior = log_prior + param_log_prob
    
    # Scale prior by number of data points (standard in Bayesian inference)
    # Ensure num_data is a proper scalar
    num_data_tensor = torch.tensor(float(num_data), device=device)
    log_posterior = -nll + log_prior / num_data_tensor
    
    return log_posterior


if model is not None and training_batches:
    print("Log posterior function defined for posteriors")
    
    # Test the log posterior function with a sample batch
    print("Testing log posterior function...")
    test_batch = training_batches[0]
    
    try:
        # Debug: Print shapes to verify they're correct
        x_test, y_test = test_batch
        print(f"Input shape: {x_test.shape}")
        print(f"Target shape: {y_test.shape}")
        print(f"num_data: {num_data}")
        
        # Test forward pass first
        with torch.no_grad():
            logits, _ = func.functional_call(model, params, (x_test,))
            print(f"Logits shape: {logits.shape}")
            
            # Check model configuration
            print(f"Model block size: {model.config.block_size}")
            print(f"Model vocab size: {model.config.vocab_size}")
        
        # Test both versions
        log_post_val = log_posterior_fn(params, test_batch)
        simple_log_post_val = log_posterior_fn(params, test_batch)
        loss_val = single_batch_loss(params, test_batch)
        
        print(f"Test successful!")
        print(f"Log posterior value: {log_post_val.item():.4f}")
        print(f"Simple log posterior value: {simple_log_post_val.item():.4f}")
        print(f"Loss value: {loss_val.item():.4f}")
        
        # Test sampling functionality placeholder
        print("Ready for posteriors VI setup")
        
    except Exception as e:
        print(f"Test failed: {e}")
        import traceback
        traceback.print_exc()
else:
    print("Cannot define log posterior - model or data not available")

Log posterior function defined for posteriors
Testing log posterior function...
Input shape: torch.Size([4, 64])
Target shape: torch.Size([4, 1])
num_data: 20
Logits shape: torch.Size([4, 1, 65])
Model block size: 256
Model vocab size: 65
Test successful!
Log posterior value: -494057.5312
Simple log posterior value: -494057.5938
Loss value: 0.8578
Ready for posteriors VI setup
Test successful!
Log posterior value: -494057.5312
Simple log posterior value: -494057.5938
Loss value: 0.8578
Ready for posteriors VI setup


## 6. Setup Variational Inference

In [9]:
# Setup variational inference using posteriors library
if model is not None and training_batches:
    print("Setting up variational inference with posteriors...")
    
    try:
        # Use the simplified log posterior function to avoid tensor iteration issues
        print("Using simplified log posterior function for debugging...")
        
        # Build the variational inference transform using posteriors.vi.diag
        print("Building VI transform with posteriors.vi.diag...")
        
        # Create optimizer
        optimizer = torchopt.adam(lr=CONFIG['learning_rate'])
        
        # Build the VI transform with simplified log posterior
        vi_transform = posteriors.vi.diag.build(
            log_posterior=log_posterior_fn,  # Use simplified version
            optimizer=optimizer,
            temperature=CONFIG['temperature'],
            n_samples=1  # Number of samples per VI update
        )
        
        print("VI transform created successfully!")
        
        # Initialize the variational state
        print("Initializing VI state...")
        vi_state = vi_transform.init(params)
        
        print("VI state initialized!")
        print(f"State keys: {list(vi_state._asdict().keys()) if hasattr(vi_state, '_asdict') else 'Custom state object'}")
        
        print("Ready for Bayesian training with posteriors!")
        
    except Exception as e:
        print(f"Error setting up posteriors VI: {e}")
        import traceback
        traceback.print_exc()
        vi_transform = None
        vi_state = None
        
else:
    print("Cannot setup VI - model or data not available")
    vi_transform = None
    vi_state = None

Setting up variational inference with posteriors...
Using simplified log posterior function for debugging...
Building VI transform with posteriors.vi.diag...
VI transform created successfully!
Initializing VI state...
VI state initialized!
State keys: Custom state object
Ready for Bayesian training with posteriors!


## 7. Run Bayesian Training

In [10]:
# Run Bayesian training using posteriors VI
if vi_transform is not None and vi_state is not None:
    print("Starting Bayesian training with posteriors...")
    print("=" * 50)
    
    # Training metrics tracking
    training_losses = []
    log_posterior_values = []
    
    try:
        for epoch in range(CONFIG['num_epochs']):
            epoch_losses = []
            epoch_log_posts = []
            print(f"\nEpoch {epoch + 1}/{CONFIG['num_epochs']}")
            
            for batch_idx, batch in enumerate(training_batches):
                try:
                    # Update the variational state using posteriors
                    vi_state = vi_transform.update(vi_state, batch)
                    
                    # Evaluate current performance
                    with torch.no_grad():
                        # Get current parameter estimates (posterior mean)
                        current_loss = single_batch_loss(vi_state.params, batch)
                        current_log_post = log_posterior_fn(vi_state.params, batch)
                        
                        epoch_losses.append(current_loss.item())
                        epoch_log_posts.append(current_log_post.item())
                    
                    # Print progress every 2 batches (since we have fewer batches)
                    if (batch_idx + 1) % 2 == 0 or batch_idx == 0:
                        recent_loss = np.mean(epoch_losses[-2:]) if len(epoch_losses) >= 2 else epoch_losses[-1]
                        recent_log_post = np.mean(epoch_log_posts[-2:]) if len(epoch_log_posts) >= 2 else epoch_log_posts[-1]
                        print(f"  Batch {batch_idx + 1}/{len(training_batches)}: Loss = {recent_loss:.4f}, Log Post = {recent_log_post:.4f}")
                        
                except Exception as batch_error:
                    print(f"Error in batch {batch_idx + 1}: {batch_error}")
                    # Continue with next batch
                    continue
            
            if epoch_losses:  # Only proceed if we have some losses
                # Calculate epoch averages
                avg_epoch_loss = np.mean(epoch_losses)
                avg_log_post = np.mean(epoch_log_posts)
                
                training_losses.append(avg_epoch_loss)
                log_posterior_values.append(avg_log_post)
                
                print(f"Epoch {epoch + 1} completed:")
                print(f"   Average Loss: {avg_epoch_loss:.4f}")
                print(f"   Average Log Posterior: {avg_log_post:.4f}")
            else:
                print(f"Epoch {epoch + 1} failed - no successful batches")
                break
        
        if training_losses:
            print(f"\nBayesian training with posteriors completed!")
            print(f"Final loss: {training_losses[-1]:.4f}")
            print(f"Final log posterior: {log_posterior_values[-1]:.4f}")

            # Plot training progress
            try:
                import matplotlib.pyplot as plt
                
                fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(15, 5))
                
                # Plot loss
                epochs = range(1, len(training_losses) + 1)
                ax1.plot(epochs, training_losses, 'b-', linewidth=2, label='Loss')
                ax1.set_title('Training Loss')
                ax1.set_xlabel('Epoch')
                ax1.set_ylabel('Loss')
                ax1.grid(True, alpha=0.3)
                ax1.legend()
                
                # Plot log posterior
                ax2.plot(epochs, log_posterior_values, 'r-', linewidth=2, label='Log Posterior')
                ax2.set_title('Log Posterior')
                ax2.set_xlabel('Epoch')
                ax2.set_ylabel('Log Posterior')
                ax2.grid(True, alpha=0.3)
                ax2.legend()
                
                plt.tight_layout()
                plt.show()
                
            except ImportError:
                print("Matplotlib not available for plotting")
        else:
            print(" Training failed - no successful epochs")
        
    except Exception as e:
        print(f" Error during posteriors training: {e}")
        import traceback
        traceback.print_exc()
        
else:
    print("Cannot run training - posteriors VI setup failed")

Starting Bayesian training with posteriors...

Epoch 1/2
Error in batch 1: iteration over a 0-d tensor
Error in batch 1: iteration over a 0-d tensor
Error in batch 2: iteration over a 0-d tensor
Error in batch 2: iteration over a 0-d tensor
Error in batch 3: iteration over a 0-d tensor
Error in batch 3: iteration over a 0-d tensor
Error in batch 4: iteration over a 0-d tensor
Error in batch 4: iteration over a 0-d tensor
Error in batch 5: iteration over a 0-d tensor
Epoch 1 failed - no successful batches
 Training failed - no successful epochs
Error in batch 5: iteration over a 0-d tensor
Epoch 1 failed - no successful batches
 Training failed - no successful epochs


# NEVER GOT THIS FAR

## 8. Compare Deterministic vs Bayesian Predictions

In [None]:
# Compare deterministic vs Bayesian predictions using posteriors
if vi_state is not None and model is not None:
    print("Comparing Deterministic vs Bayesian Predictions (posteriors)")
    print("=" * 50)
    
    # Evaluate on a test batch
    test_batch = training_batches[0] if training_batches else None
    
    if test_batch is not None:
        x_test, y_test = test_batch
        
        print(f"Test batch shape: {x_test.shape}")
        
        # 1. Deterministic prediction (original model)
        print("\nDeterministic Prediction:")
        model.eval()
        with torch.no_grad():
            deterministic_logits, _ = model(x_test)
            deterministic_loss = F.cross_entropy(
                deterministic_logits.view(-1, deterministic_logits.size(-1)),
                y_test.view(-1)
            )
            print(f"  Loss: {deterministic_loss.item():.4f}")
            
            # Calculate perplexity
            deterministic_ppl = torch.exp(deterministic_loss)
            print(f"  Perplexity: {deterministic_ppl.item():.4f}")
        
        # 2. Bayesian prediction using posteriors (posterior mean)
        print("\nBayesian Prediction (Posterior Mean via posteriors):")
        try:
            with torch.no_grad():
                bayesian_logits, _ = func.functional_call(model, vi_state.params, (x_test,))
                bayesian_loss = F.cross_entropy(
                    bayesian_logits.view(-1, bayesian_logits.size(-1)),
                    y_test.view(-1)
                )
                print(f"  Loss: {bayesian_loss.item():.4f}")
                
                # Calculate perplexity
                bayesian_ppl = torch.exp(bayesian_loss)
                print(f"  Perplexity: {bayesian_ppl.item():.4f}")
                
                # Calculate improvement
                improvement = deterministic_loss.item() - bayesian_loss.item()
                print(f"\nLoss Improvement: {improvement:.4f}")
                
                if improvement > 0:
                    print("Bayesian model performs better!")
                else:
                    print("Deterministic model performs better")
                    
        except Exception as e:
            print(f"Error in Bayesian prediction: {e}")
        
        # 3. Sample multiple predictions from posterior using posteriors
        print("\nMultiple Posterior Samples (using posteriors.vi.diag.sample):")
        try:
            posterior_samples = []
            sample_losses = []
            
            # Generate multiple samples from the posterior using posteriors
            for i in range(CONFIG['num_samples']):
                # Use posteriors to sample from the variational distribution
                sample_params = posteriors.vi.diag.sample(vi_state)
                
                # Evaluate sample
                with torch.no_grad():
                    sample_logits, _ = func.functional_call(model, sample_params, (x_test,))
                    sample_loss = F.cross_entropy(
                        sample_logits.view(-1, sample_logits.size(-1)),
                        y_test.view(-1)
                    )
                    sample_losses.append(sample_loss.item())
                    posterior_samples.append(sample_logits)
            
            # Calculate statistics
            mean_loss = np.mean(sample_losses)
            std_loss = np.std(sample_losses)
            
            print(f"  Mean Loss: {mean_loss:.4f} ± {std_loss:.4f}")
            print(f"  Min Loss: {min(sample_losses):.4f}")
            print(f"  Max Loss: {max(sample_losses):.4f}")
            
            # Calculate predictive uncertainty
            if len(posterior_samples) > 1:
                # Stack samples and calculate variance
                stacked_logits = torch.stack(posterior_samples)  # (num_samples, batch, seq, vocab)
                pred_probs = F.softmax(stacked_logits, dim=-1)
                
                # Calculate predictive entropy (uncertainty)
                mean_probs = pred_probs.mean(dim=0)  # Average over samples
                pred_entropy = -(mean_probs * torch.log(mean_probs + 1e-8)).sum(dim=-1)
                
                avg_uncertainty = pred_entropy.mean().item()
                print(f"  Average Predictive Uncertainty: {avg_uncertainty:.4f}")
                
                # Show parameter uncertainty statistics
                print(f"\nParameter Uncertainty (from posteriors):")
                total_params = 0
                total_std = 0
                
                for name in vi_state.params:
                    if hasattr(vi_state, 'log_scale'):
                        # Extract standard deviation from log_scale
                        param_std = torch.exp(vi_state.log_scale[name]).mean().item()
                        param_count = vi_state.params[name].numel()
                        
                        total_params += param_count
                        total_std += param_std * param_count
                        
                        print(f"  {name}: std = {param_std:.6f}")
                
                if total_params > 0:
                    avg_param_std = total_std / total_params
                    print(f"  Average parameter std: {avg_param_std:.6f}")
                
        except Exception as e:
            print(f"Error in posterior sampling with posteriors: {e}")
            import traceback
            traceback.print_exc()
            
    else:
        print("No test data available")
        
else:
    print("Cannot compare predictions - models not available")

## 9. Text Generation with Uncertainty

In [None]:
# Generate text with uncertainty quantification using posteriors
if vi_state is not None and model is not None:
    print("Generating Text with Bayesian Uncertainty (posteriors)")
    print("=" * 50)
    
    # Define generation prompts
    prompts = [
        "To be or not to be,",
        "HAMLET:",
        "Fair is foul and foul is",
        "Tomorrow, and tomorrow,"
    ]
    
    for prompt in prompts:
        print(f"\nPrompt: '{prompt}'")
        print("-" * 40)
        
        try:
            # Encode prompt
            prompt_tokens = encode(prompt, stoi)
            x = torch.tensor(prompt_tokens, dtype=torch.long, device=device)[None, ...]
            
            # 1. Deterministic generation (original model)
            print("Deterministic Generation:")
            with torch.no_grad():
                deterministic_x = x.clone()
                for _ in range(CONFIG['max_new_tokens']):
                    # Crop if sequence gets too long
                    x_cond = deterministic_x if deterministic_x.size(1) <= model.config.block_size else deterministic_x[:, -model.config.block_size:]
                    
                    # Forward pass
                    logits, _ = model(x_cond)
                    logits = logits[:, -1, :] / CONFIG['generation_temperature']
                    
                    # Sample next token
                    probs = F.softmax(logits, dim=-1)
                    next_token = torch.multinomial(probs, num_samples=1)
                    deterministic_x = torch.cat((deterministic_x, next_token), dim=1)
                
                # Decode generated text
                deterministic_tokens = deterministic_x[0].tolist()
                deterministic_text = decode(deterministic_tokens, itos)
                generated_part = deterministic_text[len(prompt):].strip()[:200]  # Limit length
                print(f"  {generated_part}")
            
            # 2. Bayesian generation with multiple samples using posteriors
            print(f"\nBayesian Generation ({CONFIG['num_samples']} samples via posteriors):")
            bayesian_generations = []
            
            for sample_idx in range(CONFIG['num_samples']):
                # Sample parameters from posterior using posteriors
                try:
                    sampled_params = posteriors.vi.diag.sample(vi_state)
                except Exception as e:
                    print(f"Error sampling from posteriors, using mean: {e}")
                    # Fallback to using mean parameters
                    sampled_params = vi_state.params
                
                # Generate with sampled parameters
                with torch.no_grad():
                    bayesian_x = x.clone()
                    for _ in range(CONFIG['max_new_tokens']):
                        # Crop if sequence gets too long
                        x_cond = bayesian_x if bayesian_x.size(1) <= model.config.block_size else bayesian_x[:, -model.config.block_size:]
                        
                        # Forward pass with sampled parameters
                        logits, _ = func.functional_call(model, sampled_params, (x_cond,))
                        logits = logits[:, -1, :] / CONFIG['generation_temperature']
                        
                        # Sample next token
                        probs = F.softmax(logits, dim=-1)
                        next_token = torch.multinomial(probs, num_samples=1)
                        bayesian_x = torch.cat((bayesian_x, next_token), dim=1)
                    
                    # Decode generated text
                    bayesian_tokens = bayesian_x[0].tolist()
                    bayesian_text = decode(bayesian_tokens, itos)
                    generated_part = bayesian_text[len(prompt):].strip()[:200]  # Limit length
                    bayesian_generations.append(generated_part)
                    print(f"  Sample {sample_idx + 1}: {generated_part}")
            
            # 3. Analyze diversity using posteriors samples
            print(f"\nAnalysis (posteriors-based):")
            if len(bayesian_generations) > 1:
                # Calculate diversity metrics
                unique_starts = set()
                for gen in bayesian_generations:
                    words = gen.split()[:5]  # First 5 words
                    if words:
                        unique_starts.add(' '.join(words))
                
                diversity = len(unique_starts) / len(bayesian_generations)
                print(f"  Diversity (unique 5-word starts): {diversity:.2f}")
                
                # Character-level diversity
                avg_length = np.mean([len(gen) for gen in bayesian_generations])
                std_length = np.std([len(gen) for gen in bayesian_generations])
                print(f"  Average length: {avg_length:.1f} ± {std_length:.1f} chars")
                
                # Calculate generation uncertainty
                if len(set(bayesian_generations)) > 1:
                    unique_ratio = len(set(bayesian_generations)) / len(bayesian_generations)
                    print(f"  Uniqueness ratio: {unique_ratio:.2f}")
                else:
                    print(f"  All generations identical (low uncertainty)")
            
        except Exception as e:
            print(f"Error generating text for prompt '{prompt}': {e}")
            continue
    
else:
    print("Cannot generate text - Bayesian model not available")

## 10. Uncertainty Quantification Analysis

In [None]:
# Analyze uncertainty in predictions using posteriors
if vi_state is not None and model is not None:
    print("Uncertainty Quantification Analysis (posteriors)")
    print("=" * 50)
    
    # Test sequence for uncertainty analysis
    test_sequence = "To be or not to be, that is the"
    test_tokens = encode(test_sequence, stoi)
    x_test = torch.tensor(test_tokens, dtype=torch.long, device=device)[None, ...]
    
    print(f"Test sequence: '{test_sequence}'")
    print(f"Analyzing next token predictions using posteriors...")
    
    # Collect predictions from multiple posterior samples
    next_token_logits = []
    
    with torch.no_grad():
        for sample_idx in range(20):  # More samples for better uncertainty estimation
            try:
                # Sample parameters from posterior using posteriors
                sampled_params = posteriors.vi.diag.sample(vi_state)
                
                # Forward pass with sampled parameters
                logits, _ = func.functional_call(model, sampled_params, (x_test,))
                next_token_logit = logits[:, -1, :]  # Last token predictions
                next_token_logits.append(next_token_logit.cpu())
                
            except Exception as e:
                print(f"Error in sample {sample_idx}: {e}")
                continue
    
    if next_token_logits:
        # Stack predictions
        logits_stack = torch.stack(next_token_logits, dim=0)  # [n_samples, 1, vocab_size]
        
        # Convert to probabilities
        probs_stack = F.softmax(logits_stack, dim=-1)  # [n_samples, 1, vocab_size]
        
        # Calculate statistics
        mean_probs = probs_stack.mean(dim=0)[0]  # [vocab_size]
        std_probs = probs_stack.std(dim=0)[0]    # [vocab_size]
        
        # Find top predictions with uncertainty
        top_k = 10
        top_indices = mean_probs.argsort(descending=True)[:top_k]
        
        print(f"\nTop {top_k} next token predictions with uncertainty (posteriors):")
        print("-" * 60)
        print(f"{'Rank':<4} {'Token':<15} {'Mean Prob':<10} {'Std Prob':<10} {'CV':<8}")
        print("-" * 60)
        
        for rank, idx in enumerate(top_indices):
            token_idx = idx.item()
            if token_idx < len(itos):
                token = itos[token_idx]
                mean_p = mean_probs[token_idx].item()
                std_p = std_probs[token_idx].item()
                cv = std_p / mean_p if mean_p > 0 else float('inf')  # Coefficient of variation
                
                print(f"{rank+1:<4} '{token}':<15} {mean_p:.6f} {std_p:.6f} {cv:.3f}")
        
        # Overall uncertainty metrics
        entropy = -(mean_probs * torch.log(mean_probs + 1e-10)).sum()
        pred_uncertainty = std_probs.mean()
        max_prob = mean_probs.max()
        
        print(f"\nUncertainty Metrics (posteriors-based):")
        print(f"  Predictive Entropy: {entropy:.4f}")
        print(f"  Average Std: {pred_uncertainty:.6f}")
        print(f"  Max Probability: {max_prob:.6f}")
        print(f"  Confidence: {1 - entropy/torch.log(torch.tensor(len(itos), dtype=torch.float)):.4f}")
        
        # Extract parameter uncertainty from posteriors VI state
        print(f"\nParameter Uncertainty (from posteriors VI):")
        try:
            if hasattr(vi_state, 'log_scale'):
                param_uncertainties = {}
                for name, log_scale in vi_state.log_scale.items():
                    param_std = torch.exp(log_scale).mean().item()
                    param_uncertainties[name] = param_std
                
                # Show top 5 most uncertain parameters
                sorted_params = sorted(param_uncertainties.items(), key=lambda x: x[1], reverse=True)[:5]
                for name, std in sorted_params:
                    print(f"  {name}: {std:.6f}")
                    
                avg_param_uncertainty = np.mean(list(param_uncertainties.values()))
                print(f"  Average parameter uncertainty: {avg_param_uncertainty:.6f}")
            else:
                print("  Parameter uncertainties not directly accessible")
        except Exception as e:
            print(f"  Error extracting parameter uncertainties: {e}")
        
        # Visualize uncertainty if matplotlib is available
        try:
            import matplotlib.pyplot as plt
            
            # Plot top predictions with uncertainty bars
            fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(15, 5))
            
            # Top predictions with error bars
            top_5_indices = top_indices[:5]
            top_5_tokens = [itos[idx.item()] if idx.item() < len(itos) else f"[{idx.item()}]" 
                           for idx in top_5_indices]
            top_5_means = [mean_probs[idx].item() for idx in top_5_indices]
            top_5_stds = [std_probs[idx].item() for idx in top_5_indices]
            
            ax1.bar(range(5), top_5_means, yerr=top_5_stds, capsize=5, alpha=0.7)
            ax1.set_xlabel('Token Rank')
            ax1.set_ylabel('Probability')
            ax1.set_title('Top 5 Predictions with Uncertainty (posteriors)')
            ax1.set_xticks(range(5))
            ax1.set_xticklabels([f"'{token}'" for token in top_5_tokens], rotation=45)
            ax1.grid(True, alpha=0.3)
            
            # Uncertainty distribution
            uncertainty_values = std_probs[mean_probs > 0.001]  # Only consider tokens with some probability
            ax2.hist(uncertainty_values.numpy(), bins=30, alpha=0.7, edgecolor='black')
            ax2.set_xlabel('Standard Deviation')
            ax2.set_ylabel('Frequency')
            ax2.set_title('Distribution of Prediction Uncertainties (posteriors)')
            ax2.grid(True, alpha=0.3)
            
            plt.tight_layout()
            plt.show()
            
        except ImportError:
            print("Matplotlib not available for visualization")
        
    else:
        print("No valid predictions collected for uncertainty analysis")
        
else:
    print("Cannot analyze uncertainty - Bayesian model not available")