# Entropy Based Sampling With Uncertainty Tokens And Control Vectors
Tested on a T4 GPU instance :-)

This notebook contains a simple example of entropy-based sampling with [hfppl](https://github.com/probcomp/hfppl) (installed below from `genlm/llamppl` and imported as `llamppl`) and control-vector steering with the repeng library.

The prompt and uncertainty token are configurable, as are the temperature, the number of iterations for which `smc_steer` is run, the beam factor, the number of particles, and the entropy threshold. There is also a placeholder option for applying the control vector only when the assistant outputs a specific token, but it is not wired up yet: by default, the control vector is applied right after the uncertainty token is inserted.
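The trigger rule used in this notebook (inject the uncertainty token once the next-token entropy exceeds a threshold, but never at the very first token and not again until a minimum number of tokens has passed) can be sketched in plain Python. `trigger_positions` is a hypothetical helper written for illustration, not part of hfppl/llamppl; it mirrors the checks in the model's `step()` method but, as a simplification, ignores the positions taken up by the injected tokens themselves.

```python
from typing import List


def trigger_positions(entropies: List[float], threshold: float, min_gap: int) -> List[int]:
    """Return the positions at which an uncertainty token would be injected.

    Hypothetical, simplified stand-in for the check in `step()`: a trigger fires
    when the entropy at `pos` exceeds `threshold`, `pos` is not 0, and at least
    `min_gap` positions have passed since the previous trigger.
    """
    triggers: List[int] = []
    last = -min_gap  # mirrors last_uncertainty_pos = -min_tokens_between_uncertainty
    for pos, h in enumerate(entropies):
        if pos - last >= min_gap and pos != 0 and h > threshold:
            triggers.append(pos)
            last = pos
    return triggers


# Entropies (in nats) for eight made-up decoding steps.
steps = [3.1, 0.4, 2.7, 2.9, 1.0, 0.2, 2.6, 0.5]
print(trigger_positions(steps, threshold=2.5, min_gap=3))  # [2, 6]
```

Position 0 is skipped even though its entropy is high, and position 3 is suppressed because it falls within `min_gap` of the trigger at position 2.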

In [None]:
! git clone https://github.com/genlm/llamppl
! cd llamppl && pip install .
! pip install repeng

## Imports, Constants, and Utils

Below are the required imports, some ANSI color constants for coloring terminal output, and a `pretty_format` helper that highlights the uncertainty token for easier visualization.

In [None]:
import json
import re

import numpy as np
import torch

from llamppl import smc_steer
from llamppl import Model, LMContext, CachedCausalLM, Token

from repeng import ControlVector, ControlModel, DatasetEntry

GREEN = "\033[92m"
RED = "\033[91m"
YELLOW = "\033[93m"
END = "\033[0m"

def pretty_format(particle):
    context_str = str(particle.context)
    new_context_str = re.sub(f"({re.escape(uncertainty_token)})", f"{YELLOW}\\1{END}", context_str)
    return f"{new_context_str} (weight: {RED}{particle.weight:.3f}{END})\n"
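`pretty_format` passes the uncertainty token through `re.escape` before using it as a pattern. That matters for a token like `wait...`, because an unescaped `.` matches any character. The stand-alone `highlight` function below is a hypothetical extraction of the same substitution (it works on a plain string rather than a particle) and shows the difference.

```python
import re

YELLOW = "\033[93m"
END = "\033[0m"


def highlight(text: str, token: str) -> str:
    """Wrap every literal occurrence of `token` in ANSI yellow, as pretty_format does."""
    return re.sub(f"({re.escape(token)})", f"{YELLOW}\\1{END}", text)


text = "9.11 is larger... wait... no, 9.9 is larger. waitABC"
print(highlight(text, "wait..."))

# Without re.escape, each "." matches any character, so "waitABC" is highlighted too.
unescaped = re.sub("(wait...)", f"{YELLOW}\\1{END}", text)
print(unescaped != highlight(text, "wait..."))  # True
```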

## Parameters

Here are the configurable pieces of this notebook:

- prompt: the prompt given to the LLM

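- uncertainty_token: the string injected (with a leading space) to induce backtracking when entropy is high. An earlier version of this implementation accepted multiple uncertainty tokens; this notebook has diverged from it a little, but adding that back is very doable.

- NUM_PARTICLES: number of particles to use for SMC - this is the number of different outputs you will see below

- BEAM_FACTOR: the beam factor passed to `smc_steer`; at each step every particle is extended into this many candidates before resampling back down to `NUM_PARTICLES`

- ENTROPY_THRESHOLD: a floating point number > 0.0; the uncertainty token is injected when the next-token entropy (in nats) exceeds this value. Values under 4.0 trigger more often, and values above 4.0 trigger much less often. I've found 2.5-3.0 to be a good range for Llama-3.2-1B, though I'd love to also use varentropy (the variance of the token surprisal) as a higher-confidence signal.

- MAX_TOKENS: the maximum number of tokens to generate

- ITERATIONS: how many times to run `smc_steer` in a single call (you probably won't change this one; I used it for some eval scripts)

- USE_TOKEN_TRIGGER / TOKEN_TRIGGER: intended to apply the control vector only after the assistant outputs `TOKEN_TRIGGER`. The model below does not read `TOKEN_TRIGGER` yet, so setting `USE_TOKEN_TRIGGER = True` currently just disables the control vector.

- TEMPERATURE: the sampling temperature passed to the model

- CVEC_TOKENS: how many tokens are generated with the control vector applied after each uncertainty token

- CVEC_STRENGTH: the coefficient the control vector is scaled by while it is applied

- POSITIVE_PERSONA / NEGATIVE_PERSONA: the contrasting personas used to build the control-vector training data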
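To get a feel for the `ENTROPY_THRESHOLD` scale, and for what a varentropy signal would add, here is a small NumPy sketch. `entropy_and_varentropy` is a hypothetical helper, not something this notebook or llamppl defines: entropy is H = -Σ p log p in nats (matching the model's `calculate_entropy`, which also uses the natural log), and varentropy is taken to be the variance of the surprisal -log p. A uniform distribution over k tokens has H = ln k and zero varentropy, so a threshold of 2.5 nats is roughly as uncertain as a uniform choice among ~12 tokens, and 4.0 nats roughly ~55 tokens.

```python
from typing import Tuple

import numpy as np


def entropy_and_varentropy(logprobs: np.ndarray) -> Tuple[float, float]:
    """Return (entropy, varentropy) in nats for a next-token distribution.

    Entropy is H = -sum(p * log p); varentropy is the variance of the surprisal
    -log p, i.e. sum(p * (log p + H) ** 2). `logprobs` may be unnormalized but
    must be finite.
    """
    shifted = logprobs - np.max(logprobs)              # numerical stability
    log_p = shifted - np.log(np.sum(np.exp(shifted)))  # normalized log-probabilities
    p = np.exp(log_p)
    h = float(-np.sum(p * log_p))
    varentropy = float(np.sum(p * (log_p + h) ** 2))
    return h, varentropy


# Uniform distributions over k tokens: H = ln k, varentropy ~ 0.
for k in (12, 55):
    h, v = entropy_and_varentropy(np.zeros(k))
    print(f"k={k}: H={h:.3f} nats, varentropy={v:.3f}")
```

Two distributions can share the same entropy while differing in varentropy, which is why it is a candidate for a second, higher-confidence signal.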

In [None]:
NUM_PARTICLES = 3
BEAM_FACTOR = 3
ENTROPY_THRESHOLD = 2.5
MAX_TOKENS = 200
ITERATIONS = 1

USE_TOKEN_TRIGGER = False
TOKEN_TRIGGER = "?"

TEMPERATURE = 1.0
CVEC_TOKENS = 10
CVEC_STRENGTH = 1.2

prompt = "Which number is larger, 9.9 or 9.11?"
uncertainty_token = "wait..."

POSITIVE_PERSONA = "thoughtful and genuine in your responses"
NEGATIVE_PERSONA = "superficial and empty in your responses"
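`CVEC_STRENGTH` is the coefficient handed to repeng's `ControlModel.set_control`. As a simplified mental model (an assumption about repeng's internals, not its actual code), steering adds `strength * direction` to the hidden state at every position of each wrapped layer, where `direction` is the trained control vector's direction for that layer. The hypothetical `apply_control` below shows that arithmetic on random NumPy arrays standing in for one layer's activations.

```python
import numpy as np


def apply_control(hidden: np.ndarray, direction: np.ndarray, strength: float) -> np.ndarray:
    """Hypothetical illustration: add `strength * direction` to every position.

    hidden:    (seq_len, hidden_dim) activations of one layer
    direction: (hidden_dim,) control direction for that layer
    """
    return hidden + strength * direction  # broadcasts over seq_len


rng = np.random.default_rng(0)
hidden = rng.normal(size=(4, 8))  # made-up 4 tokens with an 8-dim hidden state
direction = rng.normal(size=8)
direction /= np.linalg.norm(direction)  # unit length, so the shift is exactly `strength`

steered = apply_control(hidden, direction, 1.2)  # 1.2 == CVEC_STRENGTH
# Every position moves by 1.2 along the direction.
print(np.allclose((steered - hidden) @ direction, 1.2))  # True
```

With a strength of `0.0` the hidden states are unchanged, which is why the model turns steering off by calling `set_control(..., 0.0)`.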

## Model Definition
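- A LLaMPPL `Model` subclass that implements entropy-based sampling and control-vector steering.
- I'm including the `score` method here: `score(x)` adds `x` to the particle's log weight, so normal sampling steps add 1.0 and uncertainty-token steps add 0.2. To be honest, I'm not totally sure this has the intended effect.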
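In LLaMPPL, `observe` adds the observed value's log-probability to the particle's log weight and `score(x)` adds `x` directly; resampling in `smc_steer` then depends on the normalized weights. The toy sketch below (made-up numbers, NumPy only, ignoring the `observe` term) shows what the two `score` values in this model can and cannot do. `normalized_weights` is a hypothetical helper, not a llamppl function.

```python
import numpy as np


def normalized_weights(log_weights: np.ndarray) -> np.ndarray:
    """Convert particle log weights into normalized probabilities (softmax)."""
    shifted = log_weights - np.max(log_weights)
    w = np.exp(shifted)
    return w / w.sum()


# Made-up log weights of three particles before a step.
log_w = np.array([-1.0, -1.5, -2.0])

# Every particle gets the same score (+1.0): the normalized weights do not change.
same = normalized_weights(log_w + 1.0)

# Particle 0 observed an uncertainty token (+0.2); the others sampled normally (+1.0).
mixed = normalized_weights(log_w + np.array([0.2, 1.0, 1.0]))

print(np.allclose(same, normalized_weights(log_w)))  # True
print(mixed[0] < normalized_weights(log_w)[0])       # True
```

In other words, a constant bonus shared by all particles cancels out, while a smaller score on uncertainty-token steps lowers the relative weight of particles that just injected the token.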

In [None]:
class NightwingEntropyControlModel(Model):
    def __init__(
        self,
        lm: CachedCausalLM,
        prompt: str,
        uncertainty_token: str = "Wait...",
        temperature: float = 1.0,
        entropy_threshold: float = 3.0,
        max_tokens: int = 100,
        min_tokens_between_uncertainty: int = 50,
        cvec_tokens: int = 10,
        cvec_strength: float = 1.2,
        use_token_trigger: bool = False
    ):
        super().__init__()

        print("\nInitializing model...")

        self.lm = lm
        # Apply the sampling temperature to the context's next-token distribution
        self.context = LMContext(lm, prompt, temp=temperature)
        self.uncertainty_tokens = self.lm.tokenizer.encode(
            f" {uncertainty_token}",  # Encode the uncertainty token with a leading space
            add_special_tokens=False
        )
        self.entropy_threshold = entropy_threshold
        self.max_tokens = max_tokens

        # Tracking uncertainty
        self.min_tokens_between_uncertainty = min_tokens_between_uncertainty
        self.generated_tokens = []
        self.last_uncertainty_pos = -self.min_tokens_between_uncertainty
        self.is_generating_uncertainty = False
        self.is_generating_cvec = False
        self.control_vector = self.train_control_vector(lm)
        self.use_token_trigger = use_token_trigger
        self.cvec_tokens = cvec_tokens
        self.cvec_strength = cvec_strength

    # Shannon entropy (in nats) of the next-token distribution
    def calculate_entropy(self, logprobs: np.ndarray) -> float:
        probs = np.exp(logprobs - np.max(logprobs))  # Subtract the max for numerical stability
        probs = probs / np.sum(probs)  # Renormalize into a probability distribution

        return float(-np.sum(probs * np.log(probs + 1e-10)))  # Small constant avoids log(0)

    # Step method for SMC
    async def step(self):
        if len(self.generated_tokens) >= self.max_tokens:
            self.finish()
            return

        # Current position in generated output
        current_pos = len(self.generated_tokens)
        
        # If we aren't currently inserting uncertainty tokens
        # And we're outside of the minimum distance to insert tokens
        if (not self.is_generating_uncertainty and 
            current_pos - self.last_uncertainty_pos >= self.min_tokens_between_uncertainty):
            
            # Calculate entropy values
            logprobs = await self.lm.next_token_logprobs(self.context.tokens)
            current_entropy = self.calculate_entropy(logprobs)
            
            # If entropy is above the threshold and it's not the first token
            if current_entropy > self.entropy_threshold and current_pos != 0:
                self.is_generating_uncertainty = True
                self.last_uncertainty_pos = current_pos

        # Observe the uncertainty tokens, one per step
        if self.is_generating_uncertainty:
            token_idx = len(self.generated_tokens) - self.last_uncertainty_pos
            if token_idx < len(self.uncertainty_tokens):
                # Create the token to force
                token_id = self.uncertainty_tokens[token_idx]
                token = Token(self.lm, token_id, self.lm.tokenizer.decode(token_id))

                # Observe the token under the next-token distribution
                next_dist = self.context.next_token()
                await self.observe(next_dist, token)
                self.score(0.2)

                self.generated_tokens.append(token)

                # After the last uncertainty token, hand off to control-vector steering
                if token_idx == len(self.uncertainty_tokens) - 1:
                    self.is_generating_uncertainty = False
                    if not self.use_token_trigger:
                        self.is_generating_cvec = True
            else:
                # Defensive reset so a particle can never get stuck in this branch
                self.is_generating_uncertainty = False
            return

        if self.is_generating_cvec:
            # Note: set_control mutates the shared model, so it may also steer other
            # particles' forward passes that run while this loop awaits, and
            # CachedCausalLM may cache logits computed with steering on.
            self.lm.model.set_control(self.control_vector, self.cvec_strength)
            try:
                for _ in range(self.cvec_tokens):
                    token = await self.sample(self.context.next_token())
                    self.generated_tokens.append(token)  # Count toward max_tokens
                    if token.token_id == self.lm.tokenizer.eos_token_id:
                        self.finish()
                        break
            finally:
                self.lm.model.set_control(self.control_vector, 0.0)  # Turn steering off
                self.is_generating_cvec = False
            return

        # Normal sampling behavior
        next_dist = self.context.next_token()
        token = await self.sample(next_dist)
        self.score(1.0)
        self.generated_tokens.append(token)

        if token.token_id == self.lm.tokenizer.eos_token_id:
            self.finish()

    def make_dataset(self, template: str, pos_personas: list[str], neg_personas: list[str], suffixes: list[str]):
        dataset = []
        user_tag, asst_tag = "", ""
        for suffix in suffixes:
            for positive_persona, negative_persona in zip(pos_personas, neg_personas):
                positive_template = template.format(persona=positive_persona)
                negative_template = template.format(persona=negative_persona)
                dataset.append(
                    DatasetEntry(
                        positive=f"{user_tag} {positive_template} {asst_tag} {suffix}",
                        negative=f"{user_tag} {negative_template} {asst_tag} {suffix}",
                    )
                )
        return dataset

    def train_control_vector(self, lm):
        print("Training control vector...")
        lm.tokenizer.pad_token_id = 0

        # Suffix data from the repeng repository (notebooks/data/all_truncated_outputs.json);
        # update this path to point at your own copy.
        with open("/home/nightwing/Desktop/Projects/repeng/notebooks/data/all_truncated_outputs.json") as f:
            output_suffixes = json.load(f)

        truncated_output_suffixes = [
            lm.tokenizer.convert_tokens_to_string(tokens[:i])
            for tokens in (lm.tokenizer.tokenize(s) for s in output_suffixes)
            for i in range(1, len(tokens))
        ]

        cvec_dataset = self.make_dataset(
            "Act as if you're extremely {persona}.",
            [POSITIVE_PERSONA],
            [NEGATIVE_PERSONA],
            truncated_output_suffixes,
        )

        control_vector = ControlVector.train(lm.model, lm.tokenizer, cvec_dataset)

        return control_vector
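`train_control_vector` builds a contrastive dataset in two stages: every output suffix is expanded into all of its proper token prefixes, and each prefix is then appended to both the positive and the negative persona prompt. The sketch below reproduces that structure with whitespace splitting standing in for the tokenizer and a hypothetical `PairEntry` dataclass standing in for repeng's `DatasetEntry`; it also drops the (empty) user/assistant tags.

```python
from dataclasses import dataclass
from typing import List


@dataclass
class PairEntry:
    """Hypothetical stand-in for repeng's DatasetEntry (positive/negative strings)."""
    positive: str
    negative: str


def truncated_prefixes(texts: List[str]) -> List[str]:
    """All proper prefixes of each text, using whitespace tokens as a stand-in tokenizer."""
    return [
        " ".join(tokens[:i])
        for tokens in (t.split() for t in texts)
        for i in range(1, len(tokens))  # excludes the full text, as in the notebook
    ]


def make_pairs(template: str, pos: str, neg: str, suffixes: List[str]) -> List[PairEntry]:
    """Pair the positive and negative persona prompts, sharing each suffix."""
    return [
        PairEntry(
            positive=f"{template.format(persona=pos)} {s}",
            negative=f"{template.format(persona=neg)} {s}",
        )
        for s in suffixes
    ]


suffixes = truncated_prefixes(["I think the answer", "Sure"])
pairs = make_pairs("Act as if you're extremely {persona}.", "thoughtful", "superficial", suffixes)
print(suffixes)           # ['I', 'I think', 'I think the']
print(pairs[1].positive)  # Act as if you're extremely thoughtful. I think
```

Because each pair differs only in the persona, the difference between the two activations is what the control vector is meant to capture; the one-token text `"Sure"` contributes no prefixes.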

## Run SMC
- Running the cell below calls `smc_steer` `ITERATIONS` times (once by default) with the parameters set above.

In [None]:
lm = CachedCausalLM.from_pretrained("NousResearch/Hermes-3-Llama-3.2-3B", backend='hf')

# Wrap the model so control vectors can be applied to layers -5 through -11
lm.model = ControlModel(lm.model, list(range(-5, -12, -1)))

lm.batch_size = 8

model = NightwingEntropyControlModel(
    lm=lm,
    prompt=prompt,
    uncertainty_token=uncertainty_token,
    temperature=TEMPERATURE,
    entropy_threshold=ENTROPY_THRESHOLD,
    max_tokens=MAX_TOKENS,
    cvec_tokens=CVEC_TOKENS,
    cvec_strength=CVEC_STRENGTH,
    use_token_trigger=USE_TOKEN_TRIGGER
)

print(f"\nSteering with smc_steer with {GREEN}{NUM_PARTICLES}{END} particles and beam factor {GREEN}{BEAM_FACTOR}{END}")

for i in range(ITERATIONS):
    particles = await smc_steer(model, NUM_PARTICLES, BEAM_FACTOR)

    print(f"\n{GREEN}{prompt}{END}")
    for particle in particles:
        print(pretty_format(particle))