# Entropy Based Sampling With Uncertainty Tokens

In [None]:
! git clone https://github.com/genlm/llamppl
! cd hfppl && pip install .

## Imports, Constants, and Utils

In [None]:
import asyncio
import math
import numpy as np
import re
import torch
import torch.nn.functional as F

from llamppl import smc_steer
from llamppl import Model, LMContext, CachedCausalLM, TokenCategorical, Token
from llamppl import log_softmax

from scipy.stats import entropy
from scipy.spatial.distance import cosine

from typing import List, Set

# ANSI escape sequences used to colorize printed output; END resets the color.
GREEN = "\033[92m"
RED = "\033[91m"
YELLOW = "\033[93m"
END = "\033[0m"


def pretty_format(particle, uncertainty_token=None):
    """Render a particle's text and weight as a color-annotated string.

    Args:
        particle: Object exposing ``context`` (converted with ``str()``) and
            ``weight`` (formatted as a float with three decimals).
        uncertainty_token: Substring to highlight in yellow. When omitted,
            a module-level ``uncertainty_token`` is used if one is defined;
            if there is none (or it is empty) nothing is highlighted.

    Returns:
        The context text with every occurrence of the token wrapped in
        yellow, followed by the weight in red and a trailing newline.
    """
    if uncertainty_token is None:
        # This name used to be read as a bare global, which raises NameError
        # when no cell defines it. Look it up defensively instead.
        uncertainty_token = globals().get("uncertainty_token")

    context_str = str(particle.context)
    if uncertainty_token:
        # Literal replacement of each occurrence; no regex needed.
        context_str = context_str.replace(
            uncertainty_token, f"{YELLOW}{uncertainty_token}{END}"
        )
    return f"{context_str} (weight: {RED}{particle.weight:.3f}{END})\n"

## Parameters

In [None]:
# Prompt text handed to the model (placeholder value).
prompt = "Vignette goes here"
# Cap on the number of tokens generated per particle.
MAX_TOKENS = 50

# Number of times the final cell repeats the smc_steer run. The run loop
# iterates over range(ITERATIONS), so the name must be defined here.
ITERATIONS = 1

# Arguments passed to smc_steer: particle count and beam factor.
NUM_PARTICLES = 3
BEAM_FACTOR = 3

## Model Definition

In [None]:
class NightwingLengthModel(Model):
    """Model that samples tokens from a language model up to a length cap.

    Each ``step`` samples one token from the context's next-token
    distribution and finishes the particle once the tokenizer's EOS token is
    sampled or ``max_tokens`` tokens have been generated.
    """

    def __init__(
        self,
        lm: CachedCausalLM,
        prompt: str,
        max_tokens: int = 100
    ):
        """Create the model.

        Args:
            lm: Language model used to build the context; its tokenizer
                supplies the EOS token id checked in ``step``.
            prompt: Text used to initialise the ``LMContext``.
            max_tokens: Maximum number of tokens to sample before finishing.
        """
        super().__init__()

        print("\nInitializing model...")

        self.lm = lm
        self.context = LMContext(lm, prompt)
        self.max_tokens = max_tokens
        # Ids of the tokens sampled so far. ``step`` reads the length of this
        # list to enforce ``max_tokens``; it was previously never created,
        # which made the length check raise AttributeError.
        self.generated_tokens = []

    # Step method for SMC
    async def step(self):
        """Sample one token and finish on EOS or when the length cap is hit."""
        # Normal sampling behavior
        next_dist = self.context.next_token()
        token = await self.sample(next_dist)
        # Store plain ids only, so the history stays cheap to copy.
        self.generated_tokens.append(token.token_id)

        hit_eos = token.token_id == self.lm.tokenizer.eos_token_id
        # ``>=`` rather than ``==`` so a cap of 0 (or any overshoot) still stops.
        if hit_eos or len(self.generated_tokens) >= self.max_tokens:
            self.finish()

## Run SMC

In [None]:
lm = CachedCausalLM.from_pretrained("NousResearch/Hermes-3-Llama-3.2-3B", backed='hf')
lm.batch_size = 16

model = NightwingLengthModel(
    lm=lm,
    prompt=prompt,
    max_tokens=MAX_TOKENS
)

print(f"\nSteering with smc_steer with {GREEN}{NUM_PARTICLES}{END} particles and beam factor {GREEN}{BEAM_FACTOR}{END}")

for i in range(ITERATIONS):
    particles = await smc_steer(model, NUM_PARTICLES, BEAM_FACTOR)

    print(f"\n{GREEN}{prompt}{END}")
    for particle in particles:
        print(pretty_format(particle))