In [1]:
"""
Text Generation from Bayesian Neural Network (BNN) using Posteriors Library

This notebook shows how to generate diverse text by sampling different
parameter sets from the learned posterior distribution (VI and SGMCMC).
"""

import sys
import os

# Make the repository root (parent of the notebook's working directory)
# importable so that `config`, `baselines.*` and `src.*` resolve.
# Guarded so re-running this cell does not keep appending the same entry.
root_path = os.path.abspath(os.path.join(os.getcwd(), '..'))
if root_path not in sys.path:
    sys.path.append(root_path)

import torch
import pickle
from pathlib import Path
from contextlib import nullcontext
from torch import func
import posteriors

# Load your model architecture
from baselines.nanogpt.model import GPT, GPTConfig

# NOTE(review): tree_ravel is not used by any visible cell; kept in case other
# parts of the notebook rely on it. (Previously indented, which raised an
# IndentationError on a fresh kernel.)
from optree.integration.torch import tree_ravel


In [2]:
from config import MODEL_PATH, META_PATH, BNN_MODEL_PATH

# Generation settings.
# NOTE(review): NUM_SAMPLES, MAX_NEW_TOKENS, TEMPERATURE, TOP_K and SAMPLER_TYPE
# are not read by any visible cell -- the generation calls below pass their own
# literal values and also override START_PROMPT. Wire them through or remove them.
START_PROMPT = "I keep on burning deadlines,"
NUM_SAMPLES = 1  # intended number of posterior samples; currently unused (see note above)
MAX_NEW_TOKENS = 300
TEMPERATURE = 0.8
TOP_K = 200
SEED = 42  # seeds torch's CPU/CUDA RNGs in the next cell
DEVICE = 'cuda' if torch.cuda.is_available() else 'cpu'
SAMPLER_TYPE = 'vi'

In [3]:
# Seed torch's RNGs so posterior draws and token sampling are reproducible.
torch.manual_seed(SEED)
if torch.cuda.is_available():
    torch.cuda.manual_seed(SEED)

# Load checkpoint (onto CPU; tensors are moved to DEVICE later as needed).
# SECURITY: weights_only=False unpickles arbitrary Python objects -- only load
# checkpoints you produced yourself or otherwise trust.
checkpoint = torch.load(BNN_MODEL_PATH, map_location='cpu', weights_only=False)

# Initialize model. Falls back to a hard-coded architecture when the checkpoint
# has no 'model_args' entry.
model_args = checkpoint.get('model_args', {
    'n_layer': 6, 'n_head': 6, 'n_embd': 384,
    'block_size': 256, 'bias': False, 'vocab_size': 65, 'dropout': 0.0
})
gptconf = GPTConfig(**model_args)
model = GPT(gptconf)
model.eval()  # inference mode (disables dropout)
model.to(DEVICE)

# Load the character-level tokenizer tables.
# SECURITY: pickle.load can execute code -- META_PATH must be a trusted file.
with open(META_PATH, 'rb') as f:
    meta = pickle.load(f)
stoi, itos = meta['stoi'], meta['itos']


def encode(s):
    """Map a string to a list of token ids, one per character (KeyError on unknown chars)."""
    return [stoi[c] for c in s]


def decode(l):
    """Map an iterable of token ids back to the corresponding string."""
    return ''.join([itos[i] for i in l])

number of parameters: 10.65M


--

In [None]:
import numpy as np 
import torch.nn.functional as F

def _draw_param_samples(state, num_samples, use_uncertainty):
    """Return the list of parameter sets whose predictions are combined at each step."""
    if not (use_uncertainty and num_samples > 1):
        # A single forward pass with the posterior mean parameters.
        return [state.params]
    if hasattr(state, 'log_sd_diag'):
        # Treated as a diagonal-Gaussian VI state (posteriors.vi.diag): draw
        # independent parameter samples.
        # NOTE(review): other posteriors states (e.g. EKF / Laplace diag) may keep
        # their spread under a different attribute and fall through to the branch
        # below -- confirm and add their samplers if they are meant to be supported.
        return [posteriors.vi.diag.sample(state) for _ in range(num_samples)]
    # No known way to sample this state: every "sample" is the mean, so the
    # num_samples forward passes are identical.
    return [state.params] * num_samples


def _infer_device(model, params):
    """Device of the first tensor in ``params``, else of ``model``'s parameters, else CPU."""
    from collections.abc import Mapping

    if isinstance(params, Mapping):
        for value in params.values():
            if isinstance(value, torch.Tensor):
                return value.device
    for parameter in model.parameters():
        return parameter.device
    return torch.device('cpu')


def generate_text_bayesian(model, state, start_prompt, encode_fn, decode_fn,
                           max_new_tokens=500, temperature=1.0, top_k=None,
                           num_samples=1, use_uncertainty=True, device=None):
    """
    Generate text from a Bayesian model, optionally combining several posterior draws.

    Args:
        model: The base model. It is evaluated with ``torch.func.functional_call``
            using each parameter set, must return ``(logits, _)`` with logits indexed
            as ``[batch, position, vocab]``, and must expose ``model.config.block_size``.
        state: Sampler state exposing ``params`` (passed to ``functional_call``).
            If it also has ``log_sd_diag`` it is sampled with
            ``posteriors.vi.diag.sample``.
        start_prompt: Starting text string (must encode to at least one token).
        encode_fn: Function to encode text to a list of token ids.
        decode_fn: Function to decode a list of token ids to text.
        max_new_tokens: Number of tokens to generate.
        temperature: Sampling temperature, > 0 (higher = more random).
        top_k: If set, only sample from the top k tokens.
        num_samples: Number of posterior parameter draws to combine. With 1, or
            with use_uncertainty=False, the posterior mean ``state.params`` is used.
        use_uncertainty: If True and num_samples > 1, draw parameters from the
            posterior; otherwise use the mean parameters.
        device: Device for the token tensors. Defaults to the device of the
            parameters used for the forward pass.

    Returns:
        full_text: ``start_prompt`` followed by the generated text.
        uncertainty_info: Dict with the per-step entropies of the sample-averaged
            predictive distribution (``token_uncertainties``) and their mean/max.
            The list is empty and both summaries are 0.0 when only one parameter
            set is used.

    Raises:
        ValueError: if ``temperature <= 0`` or ``start_prompt`` encodes to no tokens.
    """
    if temperature <= 0:
        raise ValueError(f"temperature must be > 0, got {temperature}")

    model.eval()
    param_samples = _draw_param_samples(state, num_samples, use_uncertainty)
    if device is None:
        device = _infer_device(model, param_samples[0])

    prompt_ids = encode_fn(start_prompt)
    if len(prompt_ids) == 0:
        raise ValueError("start_prompt must encode to at least one token")
    context = torch.tensor(prompt_ids, dtype=torch.long, device=device).unsqueeze(0)
    block_size = model.config.block_size

    generated_tokens = []
    token_uncertainties = []

    with torch.no_grad():
        for _ in range(max_new_tokens):
            # Crop BEFORE the forward pass: the model only accepts block_size
            # positions, and the prompt itself may already be longer than that.
            context = context[:, -block_size:]

            # Temperature-scaled next-token logits from every parameter set.
            all_logits = []
            for params in param_samples:
                logits, _ = func.functional_call(model, params, (context,))
                all_logits.append(logits[:, -1, :] / temperature)

            # NOTE(review): averaging logits is a geometric mean of the members'
            # distributions, not the Bayesian posterior predictive (a mean of
            # probabilities, as used for the entropy below) -- confirm intent.
            avg_logits = torch.stack(all_logits).mean(dim=0)

            if len(all_logits) > 1:
                # Entropy of the sample-averaged, temperature-scaled predictive
                # distribution (computed before top-k filtering).
                mean_probs = torch.stack(
                    [F.softmax(member_logits, dim=-1) for member_logits in all_logits]
                ).mean(dim=0)
                entropy = -(mean_probs * torch.log(mean_probs + 1e-8)).sum(dim=-1)
                token_uncertainties.append(entropy.item())

            if top_k is not None:
                v, _ = torch.topk(avg_logits, min(top_k, avg_logits.size(-1)))
                avg_logits[avg_logits < v[:, [-1]]] = -float('Inf')

            probs = F.softmax(avg_logits, dim=-1)
            next_token = torch.multinomial(probs, num_samples=1)
            generated_tokens.append(next_token.item())
            context = torch.cat([context, next_token], dim=1)

    generated_text = decode_fn(generated_tokens)
    uncertainty_info = {
        'token_uncertainties': token_uncertainties,
        'avg_uncertainty': np.mean(token_uncertainties) if token_uncertainties else 0.0,
        'max_uncertainty': np.max(token_uncertainties) if token_uncertainties else 0.0
    }
    return start_prompt + generated_text, uncertainty_info

In [None]:
START_PROMPT = "to be, or"
from posteriors.vi.diag import VIDiagState

# Move BOTH the posterior mean and its log-std to DEVICE. The checkpoint was
# loaded with map_location='cpu', and posterior sampling combines the two trees,
# so leaving log_sd_diag on the CPU breaks generation when DEVICE is 'cuda'.
# NOTE(review): assumes log_sd_diag is saved as a name -> tensor dict, like
# sampler_state_params -- confirm against the training script.
params = {k: v.to(DEVICE) for k, v in checkpoint['sampler_state_params'].items()}
log_sd_diag = {k: v.to(DEVICE) for k, v in checkpoint['log_sd_diag'].items()}

state = VIDiagState(
    params=params,
    log_sd_diag=log_sd_diag,
    opt_state=checkpoint['opt_state']
)

# NOTE(review): unc_info is reassigned by the SGMCMC generation cell further down;
# re-running cells out of order changes what the summary cell prints.
text_v2, unc_info = generate_text_bayesian(
    model, state, START_PROMPT, encode, decode,
    max_new_tokens=600,
    temperature=0.4,
    top_k=10,
    num_samples=5,
    use_uncertainty=True)

In [9]:
# Summarise the per-token predictive entropies held in unc_info (the VI run when
# the notebook is executed top to bottom).
print("\n=== Bayesian Generation with Uncertainty ===")
for label, key in (("\nAverage uncertainty", 'avg_uncertainty'),
                   ("Max uncertainty", 'max_uncertainty')):
    print(f"{label}: {unc_info[key]:.4f}")


=== Bayesian Generation with Uncertainty ===

Average uncertainty: 0.5176
Max uncertainty: 2.4535


In [10]:
# Full completion from the VI run: the prompt followed by the generated continuation.
print(text_v2, end="\n")

To be or not to be, the story Capulets and Saint George by the princes
Of all the world at hands me: the heavens are by the
That may strange her made me to the gates of a man the
And bite makes a fall of any thing and the guest content be abroad;
And yet I'll be ready to me; and therefore I do beseech you,
Who doth not proclaim your son of lend and cry your lordship.

LUCIO:
No, by your honour, I know your honest sir; I pray your son
Where is not advantage of such a hands.

LUCIO:
You shall not be so. You have been a man of the world and the state of your
proclaim of a strange any thing, you say your son your son


Example: loading an SGMCMC checkpoint (SGLD / SGHMC / BAOA samplers) and generating text from its collected posterior samples

In [4]:
from src.generation_utils import load_checkpoint_for_generation, generate_text_bayesian_sgmcmc, save_generation_result


# Load the SGMCMC checkpoint and pull out its collected posterior samples.
# NOTE(review): BNN_MODEL_PATH is also read by the VI cell above, which expects a
# 'log_sd_diag' entry rather than 'collected_samples' -- confirm config points at
# the sampler you intend to use here.
checkpoint_data = load_checkpoint_for_generation(BNN_MODEL_PATH, device=DEVICE)
collected_samples = checkpoint_data['collected_samples']

# Prompt for the SGMCMC generation.
START_PROMPT = "to be, or not to be;"

Loaded 100 collected samples from checkpoint


In [5]:
# Generation settings for the SGMCMC ensemble; num_samples draws are taken from
# the ~100 collected samples loaded above.
sgmcmc_generation_kwargs = dict(
    max_new_tokens=600,
    temperature=0.3,
    top_k=10,
    num_samples=20,
)

# NOTE(review): this overwrites unc_info from the VI run.
text, unc_info = generate_text_bayesian_sgmcmc(
    model, collected_samples, START_PROMPT, encode, decode,
    **sgmcmc_generation_kwargs,
)


Using 20 SGMCMC samples for generation


In [6]:
# Persist the SGMCMC generation together with the settings that produced it.
# NOTE(review): these settings are copied by hand from the generation cell above;
# keep them in sync.
# The output path is built from root_path (the repo root computed in the first
# cell) instead of a user-specific absolute path. This assumes the notebook runs
# one directory below the repo root, as root_path itself does -- confirm.
save_generation_result(START_PROMPT, text, unc_info,
    max_new_tokens=600,
    temperature=0.3,
    top_k=10,
    num_samples=20,
    collected_samples=collected_samples,
    sample_id="sgmcmc_example_3",
    save_path=os.path.join(root_path, "checkpoints", "generation_results",
                           "generation_results_sgmcmc.json"),
    )

Saved generation result with ID: sgmcmc_example_3


'sgmcmc_example_3'