In [38]:
# Bayesian NanoGPT with Posteriors Library
# Import necessary libraries
import sys
import time
from pathlib import Path
from typing import Dict, Optional, Tuple, List

import matplotlib.pyplot as plt
import numpy as np
import posteriors
import seaborn as sns
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchopt

# Set up paths and import config
sys.path.append(str(Path().resolve().parent))
import config
from utils import load_model, load_tokenizer, load_shakespeare_dataset, generate_text

# Choose the compute device: CUDA when PyTorch can see a GPU, otherwise the CPU
_device_name = 'cuda' if torch.cuda.is_available() else 'cpu'
device = torch.device(_device_name)

# Record the runtime environment next to the results
print("Using device:", device)
print("PyTorch version:", torch.__version__)
print("Posteriors available:", posteriors.__version__)

# Seed the PyTorch and NumPy generators with the same value for reproducibility
SEED = 509
torch.manual_seed(SEED)
np.random.seed(SEED)

Using device: cpu
PyTorch version: 2.8.0+cpu
Posteriors available: 0.1.1


In [39]:
# Load the checkpoint, tokenizer and dataset from the locations declared in config.py
def _print_banner(title: str, width: int = 50) -> None:
    """Print ``title`` framed by two ``=`` rules, preceded by a blank line."""
    rule = "=" * width
    print("\n" + rule)
    print(title)
    print(rule)


print("Loading configuration and paths from config.py...")
for _label, _attr in (
    ("Base directory", "BASE_DIR"),
    ("Model path", "MODEL_PATH"),
    ("Meta path", "META_PATH"),
    ("Dataset path", "DATASET_PATH"),
):
    print(f"{_label}: {getattr(config, _attr)}")

# Pre-trained NanoGPT weights; eval() switches the module to inference mode
_print_banner("LOADING NANOGPT MODEL")
model, checkpoint = load_model(config.MODEL_PATH, device=device)
model.eval()

# Character-level tokenizer lookup tables
_print_banner("LOADING TOKENIZER")
stoi, itos = load_tokenizer(config.META_PATH)
vocab_size = len(itos)
print(f"Vocabulary size: {vocab_size}")

# Raw Shakespeare text plus the evaluation prompts and references
_print_banner("LOADING SHAKESPEARE DATASET")
full_text, prompts, references = load_shakespeare_dataset(config.DATASET_PATH)

print("\nDataset loaded successfully!")
print(f"Total text length: {len(full_text):,} characters")
print(f"Number of test prompts: {len(prompts)}")
print(f"Number of references: {len(references)}")

# Summarise the architecture and its hyper-parameters
_print_banner("MODEL INFORMATION")
print(f"Model architecture: {model}")
_parameter_count = sum(p.numel() for p in model.parameters())
print(f"Number of parameters: {_parameter_count:,}")
for _label, _attr in (
    ("Block size", "block_size"),
    ("Number of layers", "n_layer"),
    ("Number of heads", "n_head"),
    ("Embedding dimension", "n_embd"),
):
    print(f"{_label}: {getattr(model.config, _attr)}")

# Training metadata is optional in the checkpoint, so report it only when present
if 'iter_num' in checkpoint:
    print(f"Training iterations: {checkpoint['iter_num']}")
if 'best_val_loss' in checkpoint:
    print(f"Best validation loss: {checkpoint['best_val_loss']:.4f}")

Loading configuration and paths from config.py...
Base directory: C:\Users\hayk_\OneDrive\Desktop\05_LMU_Masters\04_applied_dl\adl-bnn-textgen
Model path: C:\Users\hayk_\OneDrive\Desktop\05_LMU_Masters\04_applied_dl\adl-bnn-textgen\checkpoints\baseline_nanogpt\baseline_nanogpt.pt
Meta path: C:\Users\hayk_\OneDrive\Desktop\05_LMU_Masters\04_applied_dl\adl-bnn-textgen\checkpoints\baseline_nanogpt\nanogpt_meta.pkl
Dataset path: C:\Users\hayk_\OneDrive\Desktop\05_LMU_Masters\04_applied_dl\adl-bnn-textgen\baselines\nanogpt\dataset.txt

LOADING NANOGPT MODEL
Loading model from: C:\Users\hayk_\OneDrive\Desktop\05_LMU_Masters\04_applied_dl\adl-bnn-textgen\checkpoints\baseline_nanogpt\baseline_nanogpt.pt
Model arguments: {'n_layer': 6, 'n_head': 6, 'n_embd': 384, 'block_size': 256, 'bias': False, 'vocab_size': 65, 'dropout': 0.2}
number of parameters: 10.65M
Model loaded successfully!
Number of parameters: 10,745,088

LOADING TOKENIZER
Vocabulary size: 65

LOADING SHAKESPEARE DATASET
Successful

# Bayesian Neural Network Text Generation with Posteriors

This notebook demonstrates how to convert a pre-trained NanoGPT model into a Bayesian neural network using the **posteriors** library. We'll explore different posterior approximation methods and analyze uncertainty in text generation.

## Overview

- **Model**: Pre-trained character-level NanoGPT on Shakespeare text
- **Methods**: Laplace approximation, Variational Inference, SGMCMC
- **Goal**: Quantify uncertainty in text generation and compare different Bayesian approaches

Let's start by setting up the environment and loading our pre-trained model.

In [None]:
# Quick smoke test: sample one continuation straight from model.generate.
# The original cell relied on encode/decode/DEVICE/MAX_NEW_TOKENS/TEMPERATURE/TOP_K,
# none of which is defined in this notebook; it now uses the stoi/itos tables and the
# `device` created by the earlier cells and declares its own sampling constants.
START_PROMPT = "To be or not to be"
# NOTE(review): the intended values were never defined here; these are placeholders to tune.
MAX_NEW_TOKENS = 100
TEMPERATURE = 0.8
TOP_K = 40

# Encode the prompt the same way the other cells do (unknown characters map to id 0)
start_ids = [stoi.get(c, 0) for c in START_PROMPT]
x = torch.tensor(start_ids, dtype=torch.long, device=device)[None, ...]

# Generating text samples
with torch.no_grad():
    y = model.generate(x, max_new_tokens=MAX_NEW_TOKENS, temperature=TEMPERATURE, top_k=TOP_K)
    print(''.join(itos.get(i, '?') for i in y[0].tolist()))
    print()


In [27]:
# Baseline: sample continuations from the unmodified (point-estimate) model
_RULE = "=" * 60
print(_RULE)
print("DETERMINISTIC TEXT GENERATION (BASELINE)")
print(_RULE)

test_prompts = [
    "To be, or not to be",
    "All the world's a stage",
    "What light through yonder window breaks?"
]

print("Generating text with deterministic model...")

# Sampling settings shared by every prompt
_generation_kwargs = dict(max_new_tokens=30, temperature=0.9, top_k=40, device=device)

deterministic_outputs = []
for _prompt_number, prompt in enumerate(test_prompts, start=1):
    print(f"\nPrompt {_prompt_number}: '{prompt}'")
    generated = generate_text(model, prompt, stoi, itos, **_generation_kwargs)
    deterministic_outputs.append(generated)
    # Slicing by len(prompt) presumes generate_text returns prompt + continuation
    # (NOTE(review): confirm against utils.generate_text).
    _continuation = generated[len(prompt):].strip()
    print(f"Generated: '{_continuation}'")

print(f"\nCompleted deterministic text generation for {len(test_prompts)} prompts.")

DETERMINISTIC TEXT GENERATION (BASELINE)
Generating text with deterministic model...

Prompt 1: 'To be, or not to be'
Generated: 'past.

NERRSAN:
The groan:
No'

Prompt 2: 'All the world's a stage'
Generated: 'past.

NERRSAN:
The groan:
No'

Prompt 2: 'All the world's a stage'
Generated: ';
And, that my gentleman have'

Prompt 3: 'What light through yonder window breaks?'
Generated: ';
And, that my gentleman have'

Prompt 3: 'What light through yonder window breaks?'
Generated: 'When have I truep it?

CLIFFO'

Completed deterministic text generation for 3 prompts.
Generated: 'When have I truep it?

CLIFFO'

Completed deterministic text generation for 3 prompts.


## Setting up Bayesian Inference

Now we'll convert our deterministic NanoGPT model into a Bayesian neural network using the posteriors library. We'll implement three different approaches:

1. **Laplace Approximation**: Quick Gaussian approximation around the MAP estimate
2. **Variational Inference**: Learn a parameterized posterior distribution
3. **SGMCMC**: Draw samples that approximate the true posterior (exactly only in the asymptotic limit) using stochastic gradient MCMC

Each method has different computational costs and approximation quality trade-offs.

In [31]:
# Set up the log posterior function for posteriors library
def prepare_data_for_posteriors(dataset_text: str, tokenizer, block_size: int = 256, 
                               batch_size: int = 32, num_batches: int = 10):
    """
    Prepare training data batches for posterior inference.

    Each batch holds ``batch_size`` randomly positioned windows of ``block_size``
    tokens; the target window is the input window shifted one token to the right.

    Args:
        dataset_text: Full Shakespeare text
        tokenizer: (stoi, itos) tuple; only ``stoi`` is used here
        block_size: Context length
        batch_size: Batch size
        num_batches: Number of batches to create

    Returns:
        List of (input, target) batches, each tensor of shape
        (batch_size, block_size), dtype long, moved to the notebook-level ``device``

    Raises:
        ValueError: if the text has fewer than ``block_size + 1`` characters, so
            that not even one (input, target) window pair fits
    """
    stoi, _ = tokenizer

    # Encode the full text once as a tensor (unknown characters map to id 0) so that
    # windows are gathered by indexing instead of re-slicing a Python list per sample.
    encoded_text = torch.tensor([stoi.get(c, 0) for c in dataset_text], dtype=torch.long)

    # A start s is valid when s + block_size + 1 <= len(text), i.e. s < len - block_size.
    num_valid_starts = encoded_text.numel() - block_size
    if num_valid_starts <= 0:
        raise ValueError(
            f"dataset_text has {encoded_text.numel()} characters but at least "
            f"{block_size + 1} are required for block_size={block_size}"
        )

    window_offsets = torch.arange(block_size)

    batches = []
    for _ in range(num_batches):
        # Random starting positions (upper bound is exclusive)
        start_indices = torch.randint(0, num_valid_starts, (batch_size,))

        # (batch_size, block_size) matrix of absolute token positions
        positions = start_indices[:, None] + window_offsets[None, :]

        batch_x = encoded_text[positions]
        batch_y = encoded_text[positions + 1]  # next-token targets

        batches.append((batch_x.to(device), batch_y.to(device)))

    return batches

# Draw the fixed set of mini-batches that every posterior method below reuses
print("Preparing data batches for Bayesian inference...")
data_batches = prepare_data_for_posteriors(
    dataset_text=full_text,
    tokenizer=(stoi, itos),
    block_size=model.config.block_size,
    batch_size=16,  # kept small for memory efficiency
    num_batches=20,
)

# Report how many batches were built and the input/target shapes of the first one
_first_inputs, _first_targets = data_batches[0][0], data_batches[0][1]
print(f"Created {len(data_batches)} batches for posterior inference")
print(f"Batch shape: {_first_inputs.shape} -> {_first_targets.shape}")

# Define log posterior function for posteriors library
def log_posterior(params, batch):
    """
    Compute the (per-token normalised) log posterior for a batch of data.

    Args:
        params: Model parameters (dict mapping parameter name -> tensor)
        batch: (input, target) batch of token ids

    Returns:
        log_posterior_value: Scalar tensor, mean token log-likelihood plus the
            Gaussian log prior divided by the number of target tokens
        model_output: Logits returned by the model, passed along as auxiliary
            information
    """
    inputs, targets = batch

    # Use functional API to compute model output with given parameters.
    # The targets are forwarded as well: when called with inputs only, this model
    # returns logits for the last position only (shape [B, 1, vocab], as printed by
    # the test output of this notebook), which cannot be scored against (B, T) targets.
    # NOTE(review): relies on a nanoGPT-style forward(idx, targets=None) — confirm in utils.
    outputs, _ = torch.func.functional_call(model, params, (inputs, targets))

    # Compute log likelihood (negative cross entropy), averaged over tokens
    log_likelihood = -F.cross_entropy(
        outputs.reshape(-1, outputs.size(-1)),
        targets.reshape(-1),
        reduction='mean'
    )

    # Add log prior (simple Gaussian prior, standard deviation 1).
    # The keyword is `sd_diag` (there is no `sigma` argument). Because the likelihood
    # is a per-token mean, the prior is divided by the total number of target tokens
    # so that both terms are on the same "log posterior / N" scale.
    num_target_tokens = sum(batch_targets.numel() for _, batch_targets in data_batches)
    log_prior = posteriors.diag_normal_log_prob(params, sd_diag=1.0) / num_target_tokens

    log_post = log_likelihood + log_prior

    return log_post, outputs

print("Log posterior function defined successfully!")

Fixed log posterior function defined!
Testing fixed log posterior function...
  Log posterior value: -713.5050
  Output shape: torch.Size([4, 1, 65])
  ‚úÖ Log posterior function working correctly!


## Laplace Approximation for Bayesian Neural Networks

### What is Laplace Approximation?

The **Laplace approximation** is a method for approximating complex posterior distributions with a Gaussian distribution. It's particularly useful for Bayesian neural networks because:

1. **Computationally Efficient**: Only requires computing the Hessian at the MAP (Maximum A Posteriori) estimate
2. **Post-hoc Method**: Can be applied to any pre-trained neural network
3. **Analytical Uncertainty**: Provides closed-form uncertainty estimates

### Mathematical Foundation

Given a neural network with parameters θ and data D, the posterior distribution is:

```
p(θ|D) ∝ p(D|θ) × p(θ)
```

The Laplace approximation approximates this posterior as a Gaussian centered at the MAP estimate θ*:

```
p(θ|D) ≈ N(θ*, Σ)
```

Where:
- **θ\*** = argmax p(θ|D) (the MAP estimate, i.e., our pre-trained weights)
- **Σ** = H⁻¹ (inverse Hessian of the negative log posterior at θ*)

### Hessian Approximations

Computing the full Hessian is expensive, so we use approximations:

1. **Fisher Information Matrix**: Uses first-order gradients only
   - `H ≈ E[∇ log p(D|θ) ∇ log p(D|θ)ᵀ]`
   - More stable and faster to compute

2. **Diagonal Approximation**: Assumes parameter independence
   - Only computes diagonal elements of the Hessian
   - Much more memory efficient

### Implementation with Posteriors

The `posteriors` library provides several Laplace variants:
- `posteriors.laplace.diag_fisher`: Diagonal Fisher information matrix
- `posteriors.laplace.dense_fisher`: Full Fisher information matrix  
- `posteriors.laplace.diag_ggn`: Diagonal Gauss-Newton approximation

We'll use the **diagonal Fisher** approach for computational efficiency while still capturing parameter uncertainties.

In [34]:
# Implement Laplace Approximation using Diagonal Fisher Information
print("="*60)
print("LAPLACE APPROXIMATION - DIAGONAL FISHER")
print("="*60)

# Get model parameters as a dictionary
model_params = dict(model.named_parameters())
print(f"Number of parameter tensors: {len(model_params)}")
total_params = sum(p.numel() for p in model_params.values())
print(f"Total parameters: {total_params:,}")

# Initialize Laplace transform with diagonal Fisher information.
# The second argument of build() is the boolean `per_sample` flag, not a data count:
# passing len(data_batches) there silently switched it on. log_posterior returns one
# scalar per batch, so per_sample must be False (posteriors then evaluates it per sample).
print("\nInitializing Laplace approximation...")
laplace_transform = posteriors.laplace.diag_fisher.build(
    log_posterior,
    per_sample=False,
)

# Initialize the Laplace state
print("Initializing Laplace state...")
laplace_state = laplace_transform.init(model_params)

# The state is not guaranteed to be a NamedTuple, so report its type instead of `_fields`
print(f"Laplace state initialized: {type(laplace_state).__name__}")
print(f"State params shape info:")
for name, param in laplace_state.params.items():
    if hasattr(param, 'shape'):
        print(f"  {name}: {param.shape}")

# Fit the Laplace approximation by computing Fisher information
print("\nFitting Laplace approximation (computing Fisher information)...")
print("This may take a few minutes...")

# Wall-clock timing that works on CPU as well as GPU. CUDA kernels run
# asynchronously, so synchronize before reading the clock on a GPU.
if device.type == 'cuda':
    torch.cuda.synchronize()
fit_start_time = time.perf_counter()

# Process batches to compute Fisher information
for i, batch in enumerate(data_batches):
    print(f"Processing batch {i+1}/{len(data_batches)}", end='\r')
    laplace_state, aux = laplace_transform.update(laplace_state, batch)

if device.type == 'cuda':
    torch.cuda.synchronize()
fit_time = time.perf_counter() - fit_start_time  # seconds
print(f"\nLaplace fitting completed in {fit_time:.2f} seconds")

# Examine the fitted Laplace approximation
# The diagonal-Fisher state stores the accumulated Fisher diagonal under `prec_diag`;
# `fisher` is kept as a fallback for hand-built states that use that name.
fisher_prec_diag = getattr(laplace_state, 'prec_diag', getattr(laplace_state, 'fisher', None))

print(f"\nFinal Laplace state:")
print(f"  Fitted parameters available: {hasattr(laplace_state, 'params')}")
print(f"  Fisher information available: {fisher_prec_diag is not None}")

if fisher_prec_diag is not None:
    # Materialise the entries first: len() of the container itself is not guaranteed
    # to be the number of parameter groups for every container type.
    fisher_entries = list(fisher_prec_diag.items())
    print(f"Fisher information computed for {len(fisher_entries)} parameter groups")

    # Print some statistics about the Fisher information
    fisher_norms = {
        name: fisher_diag.norm().item()
        for name, fisher_diag in fisher_entries
        if hasattr(fisher_diag, 'norm')
    }

    if fisher_norms:
        print(f"Fisher diagonal norms:")
        for name, norm in list(fisher_norms.items())[:5]:  # Show first 5
            print(f"  {name}: {norm:.6f}")
        if len(fisher_norms) > 5:
            print(f"  ... and {len(fisher_norms) - 5} more layers")

print("\nLaplace approximation fitted successfully! üéâ")

LAPLACE APPROXIMATION - DIAGONAL FISHER
Number of parameter tensors: 39
Total parameters: 10,745,088
Available Laplace methods:
  - dense_fisher
  - dense_ggn
  - dense_hessian
  - diag_fisher
  - diag_ggn

Initializing Laplace approximation...
Using posteriors.laplace.diag_fisher directly
‚úÖ Laplace transform initialized!
Testing parameter initialization...
‚úÖ State initialized with init method


In [37]:
# Function to sample from the Laplace posterior
def sample_laplace_weights(laplace_state, num_samples=5):
    """
    Sample parameter sets from the Laplace posterior.

    Each parameter is drawn independently from N(mean, 1 / (precision + 1e-8)),
    where the per-element precision is the accumulated diagonal Fisher information.

    Args:
        laplace_state: Fitted Laplace state exposing `params` and the Fisher diagonal
            as `prec_diag` (a `fisher` attribute is accepted as a fallback)
        num_samples: Number of parameter samples to draw

    Returns:
        List of parameter dictionaries sampled from posterior

    Raises:
        AttributeError: if the state exposes neither `prec_diag` nor `fisher`
    """
    # Locate the diagonal precision; `prec_diag` is tried first, then `fisher`.
    precision_tree = getattr(laplace_state, 'prec_diag', None)
    if precision_tree is None:
        precision_tree = getattr(laplace_state, 'fisher', None)
    if precision_tree is None:
        raise AttributeError("laplace_state has neither 'prec_diag' nor 'fisher'")

    # Copy into a plain dict so that membership tests behave the same for any mapping type
    precision_by_name = dict(precision_tree.items())

    samples = []

    print(f"Sampling {num_samples} parameter sets from Laplace posterior...")

    # Sampling must not build an autograd graph around the (trainable) mean parameters
    with torch.no_grad():
        for i in range(num_samples):
            # Sample from the Gaussian posterior
            sampled_params = {}

            for name, mean_param in laplace_state.params.items():
                fisher_diag = precision_by_name.get(name)
                if fisher_diag is not None:
                    # Fisher is precision (inverse variance), so std = 1/sqrt(Fisher);
                    # the small epsilon keeps the division finite where Fisher is zero
                    std = 1.0 / torch.sqrt(fisher_diag + 1e-8)

                    # Sample from Gaussian: N(mean, std^2)
                    noise = torch.randn_like(mean_param) * std
                    sampled_param = mean_param + noise
                else:
                    # If no Fisher information, just use the MAP estimate
                    sampled_param = mean_param.clone()

                sampled_params[name] = sampled_param

            samples.append(sampled_params)
            print(f"  Sample {i+1}/{num_samples} generated")

    return samples

# Sample from the Laplace posterior
print("="*60)
print("SAMPLING FROM LAPLACE POSTERIOR")
print("="*60)

num_posterior_samples = 8
laplace_samples = sample_laplace_weights(laplace_state, num_posterior_samples)

print(f"\nSuccessfully generated {len(laplace_samples)} parameter samples!")

# Verify sample structure
sample_keys = list(laplace_samples[0].keys())
print(f"Each sample contains {len(sample_keys)} parameter tensors")
print(f"Sample parameter names: {sample_keys[:5]}{'...' if len(sample_keys) > 5 else ''}")

# Check parameter differences between samples: distance between the first two draws
print(f"\nParameter variation analysis:")
inspected_keys = sample_keys[:3]  # Check first 3 layers
if len(laplace_samples) > 1 and inspected_keys:
    total_variation = 0.0

    for name in inspected_keys:
        reference = laplace_samples[0][name]
        diff = (laplace_samples[1][name] - reference).norm().item()
        reference_norm = reference.norm().item()
        # An all-zero reference tensor has no meaningful relative change
        relative_diff = diff / reference_norm if reference_norm > 0 else float('nan')

        print(f"  {name}: ||ŒîŒ∏|| = {diff:.6f}, relative = {relative_diff:.6f}")
        total_variation += diff

    print(f"  Average parameter variation: {total_variation/len(inspected_keys):.6f}")

print("\nLaplace posterior sampling complete! üé≤")

Setting up Laplace approximation with correct API...
Fitting Laplace approximation...
Processing batch with 4 samples...
‚ùå Error with posteriors API: 'NoneType' object has no attribute 'params'

FALLBACK: Manual Diagonal Fisher Implementation
Computing diagonal Fisher information manually...
‚úÖ Manual Laplace approximation created!
  Parameters: 39 tensors
  Fisher diagonal: 39 tensors
  Fisher statistics (first 3 layers):
    transformer.wte.weight: mean=0.000027, max=0.000501
    transformer.wpe.weight: mean=0.000013, max=0.000798
    transformer.h.0.ln_1.weight: mean=0.008387, max=0.010461

‚úÖ Laplace approximation setup complete!


Traceback (most recent call last):
  File "C:\Users\hayk_\AppData\Local\Temp\ipykernel_23412\2215269773.py", line 27, in <module>
    laplace_state = laplace_transform.update(laplace_state, batch_log_posterior, test_batch)
                    ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "c:\Users\hayk_\OneDrive\Desktop\05_LMU_Masters\04_applied_dl\adl-bnn-textgen\bnn\Lib\site-packages\posteriors\laplace\diag_fisher.py", line 125, in update
    jac, aux = jacrev(log_posterior, has_aux=True)(state.params, batch)
                                                   ^^^^^^^^^^^^
AttributeError: 'NoneType' object has no attribute 'params'


In [None]:
# Evaluate Bayesian Text Generation with Laplace Samples
print("="*60)
print("BAYESIAN TEXT GENERATION EVALUATION")
print("="*60)

def generate_text_with_sampled_params(model, params_dict, prompt, stoi, itos, 
                                    max_new_tokens=30, temperature=0.8, top_k=40, device='cpu'):
    """
    Generate text using a specific parameter sample.

    Tokens are sampled one at a time; the model is evaluated through
    torch.func.functional_call so that `params_dict` replaces its own weights.

    Args:
        model: Base model architecture
        params_dict: Sampled parameters
        prompt: Input text prompt (must not be empty)
        stoi, itos: Tokenizer lookup tables (unknown characters map to id 0,
            unknown ids decode to '?')
        max_new_tokens: Length of generation
        temperature: Sampling temperature
        top_k: Top-k sampling (None disables the filter)
        device: Device for computation

    Returns:
        Generated text string (prompt followed by the sampled continuation)

    Raises:
        ValueError: if `prompt` is empty
    """
    # Encode prompt
    encoded_prompt = [stoi.get(c, 0) for c in prompt]
    if not encoded_prompt:
        raise ValueError("prompt must contain at least one character")
    x = torch.tensor(encoded_prompt, dtype=torch.long, device=device)[None, ...]

    # Context limit, when the model advertises one via `config.block_size`
    block_size = getattr(getattr(model, 'config', None), 'block_size', None)

    with torch.no_grad():
        for _ in range(max_new_tokens):
            # Feed at most the last `block_size` tokens so long generations stay
            # within the model's context window
            if block_size is not None and x.size(1) > block_size:
                context = x[:, -block_size:]
            else:
                context = x

            # Use functional call with sampled parameters
            logits, _ = torch.func.functional_call(model, params_dict, (context,))
            logits = logits[:, -1, :] / temperature

            # Apply top-k filtering: everything below the k-th largest logit is masked
            if top_k is not None:
                v, _ = torch.topk(logits, min(top_k, logits.size(-1)))
                logits[logits < v[:, [-1]]] = -float('Inf')

            # Sample next token
            probs = F.softmax(logits, dim=-1)
            next_token = torch.multinomial(probs, num_samples=1)
            x = torch.cat((x, next_token), dim=1)

    # Decode generated sequence
    generated_tokens = x[0].tolist()
    return ''.join([itos.get(i, '?') for i in generated_tokens])

# Test Bayesian generation with multiple samples
test_prompt = "To be, or not to be"
print(f"Generating text with Bayesian NanoGPT...")
print(f"Prompt: '{test_prompt}'")
print(f"Using {len(laplace_samples)} posterior samples\n")

bayesian_outputs = []
generation_times = []  # wall-clock milliseconds, one entry per posterior sample

for i, params_sample in enumerate(laplace_samples):
    print(f"Sample {i+1}/{len(laplace_samples)}:")

    # Time every run on CPU and GPU alike. CUDA kernels are asynchronous, so
    # synchronize around the measured region when running on a GPU.
    if device.type == 'cuda':
        torch.cuda.synchronize()
    start_time = time.perf_counter()

    generated = generate_text_with_sampled_params(
        model, params_sample, test_prompt, stoi, itos,
        max_new_tokens=25, temperature=0.8, top_k=40, device=device
    )

    if device.type == 'cuda':
        torch.cuda.synchronize()
    gen_time = (time.perf_counter() - start_time) * 1000.0  # seconds -> milliseconds
    generation_times.append(gen_time)

    bayesian_outputs.append(generated)
    generated_part = generated[len(test_prompt):].strip()
    print(f"  Generated: '{generated_part}'")

if generation_times:
    avg_time = sum(generation_times) / len(generation_times)
    print(f"\nAverage generation time: {avg_time:.2f}ms")

print(f"\nBayesian text generation complete!")
print(f"Generated {len(bayesian_outputs)} diverse outputs from posterior samples.")

In [None]:
# Uncertainty Quantification Analysis
print("="*60)
print("UNCERTAINTY QUANTIFICATION ANALYSIS")
print("="*60)

def compute_predictive_uncertainty(model, laplace_samples, test_inputs, max_seq_len=20):
    """
    Compute predictive uncertainty for multiple test inputs.

    For every input, the next-token distribution is evaluated under each parameter
    sample and summarised with variance- and entropy-based metrics.

    Args:
        model: Base model
        laplace_samples: Non-empty list of parameter samples (name -> tensor mappings)
        test_inputs: List of encoded input sequences
        max_seq_len: Accepted for backward compatibility; currently not applied
            (inputs are used at their full length)

    Returns:
        List with one dictionary per input containing:
            'epistemic_variance': variance of the probabilities across samples,
                summed over the vocabulary (0.0 when only one sample is given)
            'entropy_of_mean': entropy of the sample-averaged distribution
            'mean_entropy': average entropy of the per-sample distributions
            'mutual_information': entropy_of_mean - mean_entropy
            'input_length': number of tokens in the input

    Raises:
        ValueError: if `laplace_samples` is empty
    """
    if len(laplace_samples) == 0:
        raise ValueError("laplace_samples must contain at least one parameter sample")

    # Run on the device the sampled parameters live on
    sample_device = next(iter(laplace_samples[0].values())).device

    uncertainties = []

    for input_seq in test_inputs:
        x = torch.tensor(input_seq, dtype=torch.long, device=sample_device)[None, :]

        # Collect the last-position logits from every parameter sample
        with torch.no_grad():
            all_logits = [
                torch.func.functional_call(model, params_sample, (x,))[0][:, -1, :]
                for params_sample in laplace_samples
            ]

        # Stack all predictions: [num_samples, vocab_size]
        logits_stack = torch.stack(all_logits, dim=0).squeeze(1)

        # Convert to probabilities
        probs_stack = F.softmax(logits_stack, dim=-1)

        # Posterior-averaged predictive distribution: [vocab_size]
        mean_probs = probs_stack.mean(dim=0)

        # Disagreement between samples; the unbiased variance is undefined (NaN)
        # for a single sample, so report 0.0 in that case
        if probs_stack.size(0) > 1:
            epistemic = probs_stack.var(dim=0).sum().item()
        else:
            epistemic = 0.0

        # Entropy of the mean prediction (total predictive uncertainty)
        entropy_mean = -(mean_probs * torch.log(mean_probs + 1e-8)).sum().item()

        # Mean entropy across samples (expected data / aleatoric uncertainty)
        entropies = -(probs_stack * torch.log(probs_stack + 1e-8)).sum(dim=-1)
        mean_entropy = entropies.mean().item()

        # Mutual information (epistemic uncertainty) = total - expected entropy.
        # The operands were previously swapped, which produced non-positive values.
        mutual_info = entropy_mean - mean_entropy

        uncertainties.append({
            'epistemic_variance': epistemic,
            'entropy_of_mean': entropy_mean,
            'mean_entropy': mean_entropy,
            'mutual_information': mutual_info,
            'input_length': len(input_seq)
        })

    return uncertainties

# Prompts of increasing length used to probe the predictive uncertainty
test_sequences = [
    "To be",
    "Romeo, Romeo",
    "All the world's a stage",
    "What light through yonder window",
    "Now is the winter of our discontent"
]

print("Computing predictive uncertainties...")
# Character-level encoding; characters missing from the vocabulary map to id 0
encoded_test_seqs = [[stoi.get(c, 0) for c in seq] for seq in test_sequences]
for seq, encoded in zip(test_sequences, encoded_test_seqs):
    print(f"  '{seq}' -> length {len(encoded)}")

uncertainties = compute_predictive_uncertainty(
    model, laplace_samples[:5], encoded_test_seqs  # first 5 samples only, for speed
)


def _format_uncertainty_row(label, epistemic, entropy, mutual_info):
    """Format one table row: left-aligned label followed by three 4-decimal metrics."""
    return f"{label:<30} {epistemic:<12.4f} {entropy:<10.4f} {mutual_info:<12.4f}"


# Per-prompt table
_TABLE_RULE = "-" * 60
print("\nUncertainty Analysis Results:")
print(_TABLE_RULE)
print(f"{'Prompt':<30} {'Epistemic':<12} {'Entropy':<10} {'Mutual Info':<12}")
print(_TABLE_RULE)

for seq, unc in zip(test_sequences, uncertainties):
    print(_format_uncertainty_row(
        seq[:28], unc['epistemic_variance'], unc['entropy_of_mean'], unc['mutual_information']
    ))

# Column-wise series, reused for the summary rows
epistemic_values = [u['epistemic_variance'] for u in uncertainties]
entropy_values = [u['entropy_of_mean'] for u in uncertainties]
mi_values = [u['mutual_information'] for u in uncertainties]

print(_TABLE_RULE)
for _row_label, _reduce in (('AVERAGE', np.mean), ('STD DEV', np.std)):
    print(_format_uncertainty_row(
        _row_label, _reduce(epistemic_values), _reduce(entropy_values), _reduce(mi_values)
    ))

print("\nKey Insights:")
for _insight in (
    "Higher epistemic variance indicates model uncertainty about predictions",
    "Mutual information quantifies information gain from additional samples",
    "Entropy measures overall prediction uncertainty",
):
    print(f"  ‚Ä¢ {_insight}")

print("\nUncertainty quantification analysis complete! üìä")

In [None]:
# Visualization and Comparison Analysis
_HEADER_RULE = "=" * 60
print(_HEADER_RULE)
print("VISUALIZATION AND COMPARISON ANALYSIS")
print(_HEADER_RULE)

# Plot styling: matplotlib defaults with seaborn's "husl" colour palette
plt.style.use('default')
sns.set_palette("husl")

# One 2x3 grid of panels shared by all the plots below
fig, axes = plt.subplots(2, 3, figsize=(18, 12))
fig.suptitle('Laplace Approximation: Bayesian NanoGPT Analysis', fontsize=16, fontweight='bold')

# Plot 1: Uncertainty vs Input Length (length measured in characters of the prompt)
ax1 = axes[0, 0]
input_lengths = list(map(len, test_sequences))
epistemic_vars = [entry['epistemic_variance'] for entry in uncertainties]
entropies = [entry['entropy_of_mean'] for entry in uncertainties]

for _series, _series_label in (
    (epistemic_vars, 'Epistemic Variance'),
    (entropies, 'Entropy of Mean'),
):
    ax1.scatter(input_lengths, _series, label=_series_label, alpha=0.7, s=100)
ax1.set(
    xlabel='Input Sequence Length',
    ylabel='Uncertainty',
    title='Uncertainty vs Input Length',
)
ax1.legend()
ax1.grid(True, alpha=0.3)

# Plot 2: Fisher Information Distribution (sample a few parameters)
ax2 = axes[0, 1]
sample_fisher_values = []
# The fitted state exposes the Fisher diagonal as `prec_diag`; `fisher` is kept as a
# fallback for hand-built states. Looking only for `fisher` always hit the
# "Not Available" branch with a posteriors state.
_fisher_tree = getattr(laplace_state, 'prec_diag', getattr(laplace_state, 'fisher', None))
if _fisher_tree is not None:
    for name, fisher_diag in list(_fisher_tree.items())[:3]:
        if hasattr(fisher_diag, 'flatten'):
            # detach() so the conversion to NumPy also works for tensors tracking gradients
            values = fisher_diag.detach().flatten().cpu().numpy()
            sample_fisher_values.extend(values[:1000])  # Sample first 1000 values

if sample_fisher_values:
    ax2.hist(sample_fisher_values, bins=50, alpha=0.7, density=True)
    ax2.set_xlabel('Fisher Information Value')
    ax2.set_ylabel('Density')
    ax2.set_title('Fisher Information Distribution')
    ax2.set_yscale('log')
else:
    ax2.text(0.5, 0.5, 'Fisher Info\nNot Available', 
             transform=ax2.transAxes, ha='center', va='center')
ax2.grid(True, alpha=0.3)

# Plot 3: Parameter Uncertainty Comparison
ax3 = axes[0, 2]
if len(laplace_samples) > 1:
    param_variations = []
    param_names = []

    for name in list(laplace_samples[0].keys())[:5]:  # First 5 layers
        base_param = laplace_samples[0][name]
        base_norm = base_param.norm().item()

        variations = []
        for sample in laplace_samples[1:4]:  # Compare with next 3 samples
            diff = (sample[name] - base_param).norm().item()
            # An all-zero base tensor has no meaningful relative change
            relative_diff = diff / base_norm if base_norm > 0 else float('nan')
            variations.append(relative_diff)

        param_variations.append(variations)
        param_names.append(name.split('.')[-1][:10])  # Short name

    # Box plot of parameter variations. Tick labels are set separately because the
    # `labels=` keyword of boxplot is deprecated in recent matplotlib releases.
    ax3.boxplot(param_variations)
    ax3.set_xticklabels(param_names)
    ax3.set_ylabel('Relative Parameter Variation')
    ax3.set_title('Parameter Uncertainty Across Layers')
    ax3.tick_params(axis='x', rotation=45)
else:
    ax3.text(0.5, 0.5, 'Need >1 Sample\nfor Comparison', 
             transform=ax3.transAxes, ha='center', va='center')
ax3.grid(True, alpha=0.3)

# Plot 4: Text Generation Diversity
ax4 = axes[1, 0]
prompt_for_analysis = "To be, or not to be"

# Diversity = distinct characters / total characters of each generated continuation
# (the prompt prefix is removed by length; empty continuations score 0)
_prompt_length = len(prompt_for_analysis)
_continuations = [output[_prompt_length:].strip() for output in bayesian_outputs]
char_diversities = [len(set(text)) / max(len(text), 1) for text in _continuations]

sample_indices = range(1, len(char_diversities) + 1)
ax4.bar(sample_indices, char_diversities, alpha=0.7)
ax4.axhline(
    np.mean(char_diversities),
    color='red',
    linestyle='--',
    label=f'Mean: {np.mean(char_diversities):.3f}',
)
ax4.set(
    xlabel='Posterior Sample',
    ylabel='Character Diversity Ratio',
    title='Text Generation Diversity',
)
ax4.legend()
ax4.grid(True, alpha=0.3)

# Plot 5: Uncertainty Metrics Comparison
ax5 = axes[1, 1]
metric_names = ['Epistemic\nVariance', 'Entropy\nof Mean', 'Mutual\nInformation']
_metric_series = (epistemic_values, entropy_values, mi_values)
# Bar heights are the per-metric means; error bars are the standard deviations
metric_values = [np.mean(series) for series in _metric_series]
metric_stds = [np.std(series) for series in _metric_series]

bars = ax5.bar(
    metric_names,
    metric_values,
    yerr=metric_stds,
    capsize=5,
    alpha=0.7,
    color=['skyblue', 'lightcoral', 'lightgreen'],
)
ax5.set_ylabel('Uncertainty Value')
ax5.set_title('Average Uncertainty Metrics')
ax5.grid(True, alpha=0.3)

# Annotate each bar with its value, just above the top edge
for _bar, _value in zip(bars, metric_values):
    _x_center = _bar.get_x() + _bar.get_width()/2.
    ax5.text(_x_center, _bar.get_height() + 0.01,
             f'{_value:.3f}', ha='center', va='bottom')

# Plot 6: Performance Summary
ax6 = axes[1, 2]
# Timing statistics fall back to 0 when no generation time was recorded
avg_gen_time = np.mean(generation_times) if generation_times else 0
std_gen_time = np.std(generation_times) if generation_times else 0

performance_metrics = [
    ('Samples', len(laplace_samples)),
    ('Avg Gen Time (ms)', avg_gen_time),
    ('Uncertainty Range', max(epistemic_values) - min(epistemic_values)),
    ('Diversity Score', np.mean(char_diversities))
]

y_pos = range(len(performance_metrics))
values = [metric_value for _, metric_value in performance_metrics]
# Every metric is divided by the largest raw value so that all bars fit on one axis
_largest_value = max(values)
normalized_values = [metric_value / _largest_value for metric_value in values]

bars = ax6.barh(y_pos, normalized_values, alpha=0.7)
ax6.set_yticks(y_pos)
ax6.set_yticklabels([metric_name for metric_name, _ in performance_metrics])
ax6.set_xlabel('Normalized Score')
ax6.set_title('Performance Summary')


def _format_metric(name, value):
    """Format a raw metric value: one decimal for times, integer for counts, else three decimals."""
    if 'Time' in name:
        return f'{value:.1f}'
    if 'Samples' in name:
        return f'{int(value)}'
    return f'{value:.3f}'


# Write the un-normalised value next to the end of each bar
for _bar, (_metric_name, _metric_value) in zip(bars, performance_metrics):
    ax6.text(_bar.get_width() + 0.01, _bar.get_y() + _bar.get_height()/2,
             _format_metric(_metric_name, _metric_value), ha='left', va='center')

plt.tight_layout()
plt.show()

# Print comprehensive summary
print("\n" + "="*60)
print("LAPLACE APPROXIMATION SUMMARY")
print("="*60)
for _completed_step in (
    "Fitted diagonal Fisher information matrix",
    f"Generated {len(laplace_samples)} posterior parameter samples",
    "Performed Bayesian text generation",
    "Quantified predictive uncertainty",
):
    print(f"‚úÖ {_completed_step}")

# Headline numbers: each line is the mean of one metric series
print("\nKey Results:")
for _result_label, _result_series in (
    ("Average epistemic uncertainty", epistemic_values),
    ("Average prediction entropy", entropy_values),
    ("Average mutual information", mi_values),
    ("Text generation diversity", char_diversities),
):
    print(f"  ‚Ä¢ {_result_label}: {np.mean(_result_series):.4f}")

# The timing line is printed only when generation times were recorded
if generation_times:
    print(f"  ‚Ä¢ Average generation time: {avg_gen_time:.1f}ms per sample")

print("\nBenefits of Laplace Approximation:")
for _benefit in (
    "üöÄ Fast post-hoc conversion of pre-trained models",
    "üìä Quantifies parameter and predictive uncertainty",
    "üéØ Provides calibrated confidence estimates",
    "üí° Reveals model uncertainty in different contexts",
):
    print(f"  {_benefit}")

print("\nLaplace approximation analysis complete! üéâ")

## 1. Laplace Approximation

The Laplace approximation provides a Gaussian approximation to the posterior distribution around the maximum a posteriori (MAP) estimate. This is computationally efficient and gives us uncertainty estimates with minimal additional computation.

In [None]:
# Apply Laplace approximation using diagonal Fisher information
print("="*60)
print("LAPLACE APPROXIMATION WITH DIAGONAL FISHER")
print("="*60)

# Get model parameters
params = dict(model.named_parameters())
print(f"Model has {len(params)} parameter groups")
print(f"Total parameters: {sum(p.numel() for p in params.values()):,}")

# Build the Laplace transform with diagonal Fisher information.
# The second argument of build() is the boolean `per_sample` flag, not a data count;
# log_posterior returns one scalar per batch, so the flag must be False.
laplace_transform = posteriors.laplace.diag_fisher.build(
    log_posterior,
    per_sample=False,
)

print("Initializing Laplace approximation...")
laplace_state = laplace_transform.init(params)

print("Computing diagonal Fisher information matrix...")
# Update with several batches to get good Fisher information estimate
fisher_batches = data_batches[:10]  # Use first 10 batches
for i, batch in enumerate(fisher_batches):
    print(f"Processing batch {i+1}/{len(fisher_batches)}", end='\r')
    laplace_state, _ = laplace_transform.update(laplace_state, batch)

# The accumulated Fisher diagonal (a precision) lives in `prec_diag`; `aux` is only
# auxiliary output of log_posterior and is not a variance tree.
posterior_means = dict(laplace_state.params.items())
posterior_precisions = dict(laplace_state.prec_diag.items())

print(f"\nLaplace approximation completed!")
print(f"Posterior mean computed: {len(posterior_means)} parameter groups")
print(f"Posterior precision (diagonal): {len(posterior_precisions)} parameter groups")

# Sample from the Laplace posterior
num_samples = 5
laplace_samples = []

print(f"\nSampling {num_samples} parameter sets from Laplace posterior...")
with torch.no_grad():  # sampling must not build an autograd graph around the parameters
    for i in range(num_samples):
        # theta ~ N(mean, 1 / precision); the epsilon keeps the standard deviation
        # finite for entries whose accumulated Fisher information is zero
        sample = {
            name: mean + torch.randn_like(mean) / torch.sqrt(posterior_precisions[name] + 1e-8)
            for name, mean in posterior_means.items()
        }
        laplace_samples.append(sample)
        print(f"Generated sample {i+1}/{num_samples}")

print(f"Successfully generated {len(laplace_samples)} Laplace posterior samples!")