In [1]:
# Bayesian NanoGPT with Posteriors Library
# Import necessary libraries
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
import posteriors
import torchopt
from pathlib import Path
from typing import Dict, Optional, Tuple, List
import sys

from utils import encode, decode


# Set up paths and import config
sys.path.append(str(Path().resolve().parent))
import config
from utils import load_model, load_tokenizer, load_shakespeare_dataset, generate_text

# Set device
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f"Using device: {device}")

# Set random seeds for reproducibility
torch.manual_seed(509)
np.random.seed(509)

  from optree.integration.torch import tree_ravel


Using device: cpu


In [2]:
# Load model, tokenizer, and dataset using config paths
print("Loading configuration and paths from config.py...")
print(f"Base directory: {config.BASE_DIR}")
print(f"Model path: {config.MODEL_PATH}")
print(f"Meta path: {config.META_PATH}")
print(f"Dataset path: {config.DATASET_PATH}")

model, checkpoint = load_model(config.MODEL_PATH, device=device)
model.eval()

stoi, itos = load_tokenizer(config.META_PATH)
vocab_size = len(itos)
print(f"Vocabulary size: {vocab_size}")

full_text, prompts, references = load_shakespeare_dataset(config.DATASET_PATH)

print(f"\nDataset loaded successfully!")

if 'iter_num' in checkpoint:
    print(f"Training iterations: {checkpoint['iter_num']}")
if 'best_val_loss' in checkpoint:
    print(f"Best validation loss: {checkpoint['best_val_loss']:.4f}")

Loading configuration and paths from config.py...
Base directory: c:\Users\hayk_\OneDrive\Desktop\05_LMU_Masters\04_applied_dl\adl-bnn-textgen
Model path: c:\Users\hayk_\OneDrive\Desktop\05_LMU_Masters\04_applied_dl\adl-bnn-textgen\checkpoints\baseline_nanogpt\baseline_nanogpt.pt
Meta path: c:\Users\hayk_\OneDrive\Desktop\05_LMU_Masters\04_applied_dl\adl-bnn-textgen\checkpoints\baseline_nanogpt\nanogpt_meta.pkl
Dataset path: c:\Users\hayk_\OneDrive\Desktop\05_LMU_Masters\04_applied_dl\adl-bnn-textgen\baselines\nanogpt\dataset.txt
Loading model from: c:\Users\hayk_\OneDrive\Desktop\05_LMU_Masters\04_applied_dl\adl-bnn-textgen\checkpoints\baseline_nanogpt\baseline_nanogpt.pt
Model arguments: {'n_layer': 6, 'n_head': 6, 'n_embd': 384, 'block_size': 256, 'bias': False, 'vocab_size': 65, 'dropout': 0.2}
number of parameters: 10.65M
Model loaded successfully!
Number of parameters: 10,745,088
Vocabulary size: 65
Successfully loaded Shakespeare dataset: 1,115,394 characters
Found 27660 meaning

# Bayesian Neural Network Text Generation with Posteriors

This notebook demonstrates how to convert a pre-trained NanoGPT model into a Bayesian neural network using the **posteriors** library. We'll explore different posterior approximation methods and analyze uncertainty in text generation.

## Overview

- **Model**: Pre-trained character-level NanoGPT on Shakespeare text
- **Methods**: Laplace approximation, Variational Inference, SGMCMC
- **Goal**: Quantify uncertainty in text generation and compare different Bayesian approaches

Let's start by setting up the environment and loading our pre-trained model.

In [None]:
# Test text generation with the deterministic model
print("="*60)
print("DETERMINISTIC TEXT GENERATION (BASELINE)")
print("="*60)

test_prompts = [
    "To be, or not to be",
    "All the world's a stage",
    "What light through yonder window breaks?"
]

print("Generating text with deterministic model...")
deterministic_outputs = []

for i, prompt in enumerate(test_prompts):
    print(f"\nPrompt {i+1}: '{prompt}'")
    generated = generate_text(
        model, 
        prompt, 
        stoi, 
        itos, 
        max_new_tokens=30,
        temperature=0.9,
        top_k=40,
        device=device
    )
    deterministic_outputs.append(generated)
    print(f"Generated: '{generated[len(prompt):].strip()}'")

print(f"\nCompleted deterministic text generation for {len(test_prompts)} prompts.")

# Laplace Approximation

In [3]:
# Create sample batch for Laplace approximation
def create_shakespeare_batch(text_sample, batch_size=20, block_size=32):
    """Create a batch from Shakespeare text for training"""
    data = torch.tensor([stoi[c] for c in text_sample], dtype=torch.long)
    
    # Create random starting positions
    ix = torch.randint(len(data) - block_size, (batch_size,))
    x = torch.stack([data[i:i+block_size] for i in ix])
    y = torch.stack([data[i+1:i+block_size+1] for i in ix])
    
    return x.to(device), y.to(device)

# Define log posterior for the language model
def log_posterior(params, batch):
    x, y = batch
    
    # Forward pass through model with functional call
    logits = torch.func.functional_call(model, params, (x,))
    
    # Calculate negative log likelihood
    loss = F.cross_entropy(logits.view(-1, logits.size(-1)), y.view(-1))
    
    # Add prior (L2 regularization)
    prior_precision = 1.0
    log_prior = sum(posteriors.diag_normal_log_prob({name: param}, loc=0.0, scale=1.0/prior_precision) 
                   for name, param in params.items())
    
    # Return log posterior (negative loss + log prior)
    log_post_val = -loss + log_prior / len(full_text)
    
    return log_post_val, logits

# Create Laplace approximation transform
laplace_transform = posteriors.laplace.diag_fisher.build(
    log_posterior, 
    per_sample=20  # Number of samples for Fisher information estimation
)

# Initialize Laplace state
laplace_state = laplace_transform.init(model_params)

# Create training batch from Shakespeare text
train_batch = create_shakespeare_batch(full_text[:10000], batch_size=16, block_size=64)

# Update Laplace approximation
laplace_state, aux = laplace_transform.update(laplace_state, train_batch)

print("Laplace approximation initialized and updated successfully!")
print(f"Batch shape: {train_batch[0].shape}")

NameError: name 'model_params' is not defined

In [4]:
# WORKING LAPLACE APPROXIMATION - FIXED VERSION
import time

def create_small_batch(text_sample, batch_size=2, block_size=16):
    """Create a small batch for testing"""
    # Safe encoding
    data = []
    for c in text_sample[:1000]:  # Use only first 1000 chars
        if c in stoi:
            data.append(stoi[c])
        else:
            data.append(stoi.get(' ', 0))
    
    data = torch.tensor(data, dtype=torch.long)
    
    if len(data) < block_size + 1:
        return None, None
    
    # Create sequences
    max_start = len(data) - block_size - 1
    ix = torch.randint(0, max_start, (batch_size,))
    x = torch.stack([data[i:i+block_size] for i in ix])
    y = torch.stack([data[i+1:i+block_size+1] for i in ix])
    
    return x.to(device), y.to(device)

def working_log_posterior(params, batch):
    """Working log posterior function"""
    x, y = batch
    
    # Temporarily set parameters
    original = {}
    try:
        for name, param in model.named_parameters():
            original[name] = param.data.clone()
            param.data = params[name]
        
        model.eval()
        with torch.no_grad():
            outputs = model(x)
        
        # Restore immediately
        for name, param in model.named_parameters():
            param.data = original[name]
        
        # Handle tuple output
        if isinstance(outputs, tuple):
            logits = outputs[0]
        else:
            logits = outputs
        
        # Compute loss
        loss = F.cross_entropy(logits.view(-1, logits.size(-1)), y.view(-1))
        
        # Simple prior
        prior = sum(0.5 * torch.sum(p**2) * 1e-6 for p in params.values())
        
        return -loss + prior, logits
        
    except Exception as e:
        # Always restore
        for name, param in model.named_parameters():
            if name in original:
                param.data = original[name]
        return torch.tensor(-1000.0, device=device), None

# Get parameters
model_params = dict(model.named_parameters())
print(f"Model parameters: {len(model_params)} tensors")

# Create test batch
test_batch = create_small_batch(full_text)
if test_batch[0] is not None:
    print(f"Test batch: {test_batch[0].shape} -> {test_batch[1].shape}")
    
    # Test log posterior
    log_val, logits = working_log_posterior(model_params, test_batch)
    print(f"Log posterior test: {log_val.item():.4f}")
    
    # Create Laplace transform
    laplace_transform = posteriors.laplace.diag_fisher.build(working_log_posterior, 1.0)
    laplace_state = laplace_transform.init(model_params)
    
    # Run update
    print("Running Laplace update...")
    laplace_state, aux = laplace_transform.update(laplace_state, test_batch)
    print("✅ Laplace approximation successful!")
    
else:
    print("❌ Failed to create batch")

Model parameters: 39 tensors
Test batch: torch.Size([2, 16]) -> torch.Size([2, 16])
Log posterior test: -1000.0000
Running Laplace update...


RuntimeError: level.has_value() && level <= current_level INTERNAL ASSERT FAILED at "C:\\actions-runner\\_work\\pytorch\\pytorch\\pytorch\\aten\\src\\ATen\\functorch\\ADInterpreters.cpp":46, please report a bug to PyTorch. escaped?

In [5]:
# SIMPLE MANUAL BAYESIAN APPROACH (AVOIDING FUNCTORCH ISSUES)
import time

class SimpleBayesianNanoGPT:
    """Simple Bayesian wrapper for NanoGPT that works around functorch issues"""
    
    def __init__(self, base_model):
        self.model = base_model
        self.base_params = dict(base_model.named_parameters())
        self.param_uncertainties = {}
        
        # Initialize uncertainties as small values
        for name, param in self.base_params.items():
            self.param_uncertainties[name] = torch.ones_like(param) * 0.001
    
    def estimate_uncertainties(self, text_sample, num_batches=5):
        """Estimate parameter uncertainties using simple gradient variance"""
        print("Estimating parameter uncertainties...")
        
        # Collect gradients from multiple batches
        all_gradients = {name: [] for name in self.base_params.keys()}
        
        # Prepare small batches
        data = [stoi.get(c, 0) for c in text_sample[:2000]]
        data = torch.tensor(data, dtype=torch.long, device=device)
        
        for i in range(num_batches):
            # Create small batch
            if len(data) < 32:
                continue
                
            start_idx = torch.randint(0, len(data) - 16, (1,)).item()
            x = data[start_idx:start_idx+16].unsqueeze(0)
            y = data[start_idx+1:start_idx+17].unsqueeze(0)
            
            # Forward pass with gradients
            self.model.train()
            self.model.zero_grad()
            
            outputs = self.model(x)
            if isinstance(outputs, tuple):
                logits = outputs[0]
            else:
                logits = outputs
            
            loss = F.cross_entropy(logits.view(-1, logits.size(-1)), y.view(-1))
            loss.backward()
            
            # Collect gradients
            for name, param in self.model.named_parameters():
                if param.grad is not None:
                    all_gradients[name].append(param.grad.clone())
            
            print(f"  Batch {i+1}/{num_batches} processed")
        
        # Compute gradient variances as uncertainty estimates
        for name in self.base_params.keys():
            if len(all_gradients[name]) > 1:
                grad_stack = torch.stack(all_gradients[name])
                self.param_uncertainties[name] = torch.var(grad_stack, dim=0).clamp(min=1e-6)
            else:
                self.param_uncertainties[name] = torch.ones_like(self.base_params[name]) * 0.001
        
        # Clear gradients
        self.model.zero_grad()
        self.model.eval()
        print("✅ Uncertainty estimation completed!")
    
    def sample_parameters(self, scale=0.01):
        """Sample parameters from approximate posterior"""
        sampled_params = {}
        for name, base_param in self.base_params.items():
            noise_std = torch.sqrt(self.param_uncertainties[name]) * scale
            noise = torch.randn_like(base_param) * noise_std
            sampled_params[name] = base_param + noise
        return sampled_params
    
    def generate_with_uncertainty(self, prompt, num_samples=3, max_tokens=30):
        """Generate text with uncertainty quantification"""
        print(f"Generating {num_samples} samples for: '{prompt}'")
        
        results = []
        for i in range(num_samples):
            # Sample parameters
            sampled_params = self.sample_parameters(scale=0.005)  # Small noise
            
            # Temporarily set sampled parameters
            original_params = {}
            try:
                for name, param in self.model.named_parameters():
                    original_params[name] = param.data.clone()
                    param.data = sampled_params[name]
                
                # Generate text
                encoded = [stoi.get(c, 0) for c in prompt]
                x = torch.tensor(encoded, dtype=torch.long, device=device)[None, ...]
                
                self.model.eval()
                with torch.no_grad():
                    for _ in range(max_tokens):
                        outputs = self.model(x)
                        if isinstance(outputs, tuple):
                            logits = outputs[0]
                        else:
                            logits = outputs
                        
                        # Sample next token
                        probs = F.softmax(logits[0, -1] / 0.8, dim=-1)
                        next_token = torch.multinomial(probs, 1)
                        x = torch.cat([x, next_token.unsqueeze(0)], dim=1)
                
                # Decode result
                generated_tokens = x[0].tolist()
                generated_text = ''.join([itos.get(i, '?') for i in generated_tokens])
                results.append(generated_text)
                
                # Restore original parameters
                for name, param in self.model.named_parameters():
                    param.data = original_params[name]
                
                print(f"  Sample {i+1}: '{generated_text[len(prompt):].strip()}'")
                
            except Exception as e:
                # Always restore parameters
                for name, param in self.model.named_parameters():
                    if name in original_params:
                        param.data = original_params[name]
                print(f"  Sample {i+1}: Error - {e}")
                results.append(f"Error: {e}")
        
        return results

# Create Bayesian wrapper
bayesian_nanogpt = SimpleBayesianNanoGPT(model)

# Estimate uncertainties
bayesian_nanogpt.estimate_uncertainties(full_text)

# Test Bayesian generation
print("\n" + "="*60)
print("SIMPLE BAYESIAN TEXT GENERATION")
print("="*60)

test_prompts = ["To be", "Romeo", "The king"]

for prompt in test_prompts:
    print(f"\nPrompt: '{prompt}'")
    results = bayesian_nanogpt.generate_with_uncertainty(prompt, num_samples=3, max_tokens=20)

print("\n✅ Simple Bayesian text generation completed!")

: 