In [172]:
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM
import numpy as np
import torch.nn.functional as F

In [12]:
from huggingface_hub import notebook_login, login

import os

# Authenticate with the Hugging Face Hub without embedding a credential in the notebook.
# The token is read from the HF_TOKEN environment variable; when that variable is unset,
# login(None) falls back to huggingface_hub's interactive prompt.
login(os.environ.get('HF_TOKEN'))

The token has not been saved to the git credentials helper. Pass `add_to_git_credential=True` in this function directly or `--add-to-git-credential` if using via `huggingface-cli` if you want to set the git credential as well.
Token is valid (permission: write).
Your token has been saved to /Users/pau/.cache/huggingface/token
Login successful


In [13]:
# Checkpoint identifier; the same name is passed to both from_pretrained calls so the
# tokenizer and the causal-LM weights come from the same repository.
model_name = 'meta-llama/Llama-3.2-1B'
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForCausalLM.from_pretrained(model_name)

In [14]:
# Literal text of the "pause and reason" marker; a later cell can register it as a
# tokenizer vocabulary entry.
wait_token = "...Wait! Let's think step by step..."

In [31]:
# If you want to add the CoT token into the vocab.
# The token is added only when it is not already in the vocabulary, and the model's
# token embeddings are then resized to the new tokenizer length.
if wait_token not in tokenizer.get_vocab():
    tokenizer.add_tokens([wait_token])
    model.resize_token_embeddings(len(tokenizer))
    
# First id produced when encoding wait_token without special tokens.
# NOTE(review): LLMSampler.sample does not read WAIT_TOKEN_ID; it encodes its own
# prompt string instead.
WAIT_TOKEN_ID = tokenizer.encode(wait_token, add_special_tokens=False)[0]

In [17]:
# End-of-sequence token id as reported by the tokenizer.
EOT_TOKEN_ID = tokenizer.eos_token_id

In [130]:
# MAX_ERROR is a log-scale threshold: log(U) minus half the log of the tokenizer's vocab_size.
# NOTE(review): the derivation of this formula is not shown here — confirm the intended meaning of U.
U = 2e-7
MAX_ERROR = np.log(U) - np.log(tokenizer.vocab_size)/2
# Display the threshold together with the reciprocal of its square.
MAX_ERROR, 1/MAX_ERROR**2

(-21.304841241849253, 0.002203146594003764)

In [275]:
class LLMSampler:
    """Entropy- and varentropy-guided next-token sampler for a causal language model.

    At every step the (filtered) next-token distribution is summarised by its
    entropy and varentropy, and one of four strategies is chosen:

    * low entropy,  low varentropy  -> greedy argmax
    * high entropy, low varentropy  -> insert a "think step by step" prompt
      (at most once per 100 steps, enforced by a cooldown)
    * high entropy, high varentropy -> multinomial sampling
    * low entropy,  high varentropy -> short beam search when ``use_beam_search``
      is set, otherwise multinomial sampling

    The sampler works on a single sequence (batch size one): several decisions
    call ``.item()`` on per-step tensors.

    Attributes:
        verbose: when True, print entropy diagnostics for every step.
        use_beam_search: when True, use ``beam_search`` in the branching case.
        log_prob_floor: tokens whose log-probability is below this value are
            ignored when computing entropy and varentropy.
    """

    def __init__(self, model, tokenizer, entropy_threshold=1.0, varentropy_threshold=0.25, prob_floor_u=2e-7):
        """Store the model/tokenizer pair and the decision thresholds.

        Args:
            model: callable taking a tensor of token ids and returning an object
                with a ``.logits`` tensor indexed as ``[:, -1, :]``.
            tokenizer: object providing ``encode``, ``decode``, ``eos_token_id``
                and ``vocab_size``.
            entropy_threshold: entropy above this value counts as "high".
            varentropy_threshold: varentropy above this value counts as "high".
            prob_floor_u: constant ``u`` in the log-probability floor
                ``log(u) - log(vocab_size) / 2``.
        """
        self.model = model
        self.tokenizer = tokenizer
        self.entropy_threshold = entropy_threshold
        self.varentropy_threshold = varentropy_threshold
        self.verbose = False
        self.use_beam_search = False
        # Derived from this instance's tokenizer so the class does not depend on
        # module-level state defined in another notebook cell.
        self.log_prob_floor = float(np.log(prob_floor_u) - np.log(tokenizer.vocab_size) / 2)

    def _kept_probs(self, logits):
        """Return the renormalised probabilities of tokens above ``log_prob_floor``.

        The result is a flat 1-D tensor, so it is only meaningful for a single
        distribution. Tokens masked to ``-inf`` are dropped automatically
        because their log-probability is below any finite floor.
        """
        log_probs = F.log_softmax(logits, dim=-1)
        kept = log_probs[log_probs >= self.log_prob_floor].exp()
        return kept / torch.sum(kept, dim=-1)

    def calculate_entropy(self, logits):
        """Shannon entropy (in nats) of the filtered, renormalised distribution."""
        kept = self._kept_probs(logits)
        return -torch.sum(kept * torch.log(kept), dim=-1)

    def calculate_varentropy(self, logits, entropies):
        """Variance of the surprisal ``-log p`` under the filtered distribution.

        Computed as ``E[(log p)^2] - H^2`` where ``H`` is ``entropies``, the value
        returned by ``calculate_entropy`` for the same ``logits``.
        """
        kept = self._kept_probs(logits)
        log_p = torch.log(kept)
        second_moment = torch.sum(kept * log_p**2, dim=-1)
        return second_moment - entropies**2

    def beam_search(self, logits, beam_width, max_beam_steps, temperature, context_ids=None):
        """Look a few tokens ahead and return the highest-scoring continuation.

        Args:
            logits: next-token logits of shape (1, vocab) for the first step.
            beam_width: number of hypotheses kept per step (clamped to vocab size).
            max_beam_steps: maximum continuation length in tokens.
            temperature: divisor applied to logits before ``log_softmax``.
            context_ids: optional (1, seq) tensor of the tokens generated so far.
                When given, every look-ahead model call is conditioned on it;
                when omitted the model sees only the candidate continuation.

        Returns:
            List of token ids of the best continuation; callers that need just
            the next token take element ``[0]``.

        Note:
            ``sample`` passes logits that ``apply_sampling_parameters`` already
            divided by the temperature, so the first step is tempered twice.
        """
        scores = F.log_softmax(logits / temperature, dim=-1)
        beam_width = min(beam_width, scores.size(-1))
        first = torch.topk(scores[0], beam_width)
        beam = [(s, [i]) for s, i in zip(first.values.tolist(), first.indices.tolist())]
        device = context_ids.device if context_ids is not None else logits.device

        for _ in range(max_beam_steps - 1):
            candidates = []
            for score, sequence in beam:
                continuation = torch.tensor([sequence], dtype=torch.long, device=device)
                if context_ids is not None:
                    model_input = torch.cat([context_ids, continuation], dim=-1)
                else:
                    model_input = continuation
                # Look-ahead is inference only; no autograd graph is needed.
                with torch.no_grad():
                    next_logits = self.model(model_input).logits[:, -1, :]
                next_scores = F.log_softmax(next_logits / temperature, dim=-1)
                top_scores, top_indices = next_scores[0].topk(beam_width)
                candidates.extend(
                    (score + s, sequence + [idx])
                    for s, idx in zip(top_scores.tolist(), top_indices.tolist())
                )

            beam = sorted(candidates, key=lambda c: c[0], reverse=True)[:beam_width]

            # Stop early once every surviving hypothesis ends in the same token.
            if all(sequence[-1] == beam[0][1][-1] for _, sequence in beam):
                break
        return beam[0][1]

    def apply_sampling_parameters(self, logits, temperature=1.0, top_k=0, top_p=1.0, min_p=0.0):
        """Apply temperature, top-k, top-p and min-p filtering to ``logits``.

        Filtered-out tokens are set to ``-inf``. The input tensor is not modified.

        Args:
            logits: tensor whose last dimension is the vocabulary.
            temperature: positive divisor applied first.
            top_k: keep only the ``top_k`` largest logits (0 disables).
            top_p: nucleus threshold; keep the smallest prefix of the sorted
                distribution whose cumulative probability exceeds ``top_p``
                (1.0 disables).
            min_p: drop tokens whose probability is below ``min_p`` times the
                probability of the most likely token (0.0 disables). Scaling by
                the top probability guarantees at least one token survives for
                ``min_p <= 1``.

        Raises:
            ValueError: if ``temperature`` is not strictly positive.
        """
        if temperature <= 0:
            raise ValueError("temperature must be positive")
        neg_inf = float('-inf')
        logits = logits / temperature

        if top_k > 0:
            top_k = min(top_k, logits.size(-1))
            kth_best = torch.topk(logits, top_k)[0][..., -1, None]
            logits = logits.masked_fill(logits < kth_best, neg_inf)

        if top_p < 1.0:
            sorted_logits, sorted_indices = torch.sort(logits, descending=True, dim=-1)
            cumulative_probs = torch.cumsum(torch.softmax(sorted_logits, dim=-1), dim=-1)
            sorted_remove = cumulative_probs > top_p
            # Shift right so the token that crosses the threshold is kept, and
            # always keep the most likely token.
            sorted_remove[..., 1:] = sorted_remove[..., :-1].clone()
            sorted_remove[..., 0] = False
            remove = sorted_remove.scatter(dim=-1, index=sorted_indices, src=sorted_remove)
            logits = logits.masked_fill(remove, neg_inf)

        if min_p > 0.0:
            probs = torch.softmax(logits, dim=-1)
            threshold = min_p * probs.max(dim=-1, keepdim=True).values
            logits = logits.masked_fill(probs < threshold, neg_inf)

        return logits

    def torch_uniform(self, minval, maxval):
        """Draw one Python float between ``minval`` and ``maxval`` using torch's RNG."""
        return ((minval - maxval) * torch.rand(1) + maxval).item()

    def sample(self, input_text, max_length=100, temperature=1.0, top_k=0, top_p=1.0, min_p=0.0):
        """Generate a continuation of ``input_text`` and return the decoded sequence.

        Each chosen token is also printed to stdout as it is produced. Generation
        stops after ``max_length`` steps or as soon as the end-of-sequence token
        is the most likely (or the sampled) next token.

        Args:
            input_text: prompt string.
            max_length: maximum number of generation steps (a step may append
                several tokens when the reasoning prompt is inserted).
            temperature, top_k, top_p, min_p: forwarded to
                ``apply_sampling_parameters``.

        Returns:
            Decoded text of the prompt plus everything generated.
        """
        eos_id = self.tokenizer.eos_token_id
        input_ids = self.tokenizer.encode(input_text, return_tensors='pt')
        # Keep every tensor on the model's device when the model exposes one.
        model_device = getattr(self.model, "device", None)
        if model_device is not None:
            input_ids = input_ids.to(model_device)
        device = input_ids.device

        output_ids = input_ids.clone()
        cot_cooldown = 0
        # The first encoded token of the prompt is dropped before insertion.
        cot_tokens = self.tokenizer.encode("... Wait! Let's think step by step:", add_special_tokens=False)[1:]

        for _ in range(max_length):
            with torch.no_grad():
                logits = self.model(output_ids).logits[:, -1, :]

            logits = self.apply_sampling_parameters(logits, temperature, top_k, top_p, min_p)

            if torch.argmax(logits, dim=-1).item() == eos_id:
                eos = torch.tensor([[eos_id]], dtype=torch.long, device=device)
                output_ids = torch.cat([output_ids, eos], dim=-1)
                break

            entropies = self.calculate_entropy(logits)
            varentropy = self.calculate_varentropy(logits, entropies)
            high_entropy = entropies.item() > self.entropy_threshold
            high_varentropy = varentropy.item() > self.varentropy_threshold
            if self.verbose:
                print(entropies, entropies.item(), varentropy)
                print(high_entropy, high_varentropy)

            if not high_entropy and not high_varentropy:
                # If the model is very confident, and we don't have many confident options we just take the argmax.
                next_token = torch.argmax(logits, dim=-1).unsqueeze(0)

            elif high_entropy and not high_varentropy and cot_cooldown <= 0:
                # If the model is not very confident, but the uncertainty is spread equally (meaning probs are similar across), we insert a CoT token.
                next_token = torch.tensor([cot_tokens], dtype=torch.long, device=device)
                cot_cooldown = 100

            elif high_entropy:
                # High entropy with high varentropy, or high entropy during the
                # CoT cooldown: multiple options are acceptable, so sample.
                probs = torch.softmax(logits, dim=-1)
                next_token = torch.multinomial(probs, num_samples=1).t()

            else:
                # The model is confident, but still uncertain. Beam search looks at possible futures and returns the token that results in the likeliest option.
                # If you are gpu poor this can be heavy.
                if self.verbose:
                    print(" | branching | ", end='')
                if self.use_beam_search:
                    jittered_temperature = temperature * (1 + self.torch_uniform(-0.1, 0.1))
                    best = self.beam_search(logits, 50, 3, jittered_temperature, context_ids=output_ids)
                    next_token = torch.tensor([[best[0]]], dtype=torch.long, device=device)
                else:
                    probs = torch.softmax(logits, dim=-1)
                    next_token = torch.multinomial(probs, num_samples=1).t()

            output_ids = torch.cat([output_ids, next_token], dim=-1)
            cot_cooldown -= 1
            if next_token.size(1) == 1 and next_token.item() == eos_id:
                break
            print(self.tokenizer.decode(next_token[0]), end='')

        return self.tokenizer.decode(output_ids[0])

In [282]:
# Build the sampler with explicit thresholds (entropy > 1, varentropy > 0.75 count as "high").
sampler = LLMSampler(
    model,
    tokenizer,
    entropy_threshold=1,
    varentropy_threshold=0.75,
)

In [None]:
# Run one generation for the decimal-comparison prompt. sample() prints tokens as they are
# chosen and returns the decoded prompt plus continuation.
output = sampler.sample("Which number is larger: 9.9 or 9.11?", temperature=1.1, top_k=3000, top_p=0.8, min_p=0.0)

In [278]:
# Display the decoded text returned by sampler.sample.
output

"<|begin_of_text|>Which number is larger: 9.9 or 9.11? Are they the same Wait! Let's think step by step: 9.9 = 9 + 0.9 = 9.1 9.11 = 9 + 0.11 = 9.11 The numbers are the same.<|end_of_text|>"

In [171]:
# Display the decoded text returned by sampler.sample.
# NOTE(review): the execution count differs from the cell above, so this output presumably
# comes from a separate run — sampling is stochastic and reruns can differ.
output

"<|begin_of_text|>Which number is larger: 9.9 or 9.11? Wait! Let's think step by step: 9.9 = 9 + 0.9 9.11 = 9 + 0.11 9.9 is larger because 0.9 > 0.11.<|end_of_text|>"

In [243]:
# Display the decoded text returned by sampler.sample.
# NOTE(review): presumably the result of yet another run of the same prompt — confirm.
output

"<|begin_of_text|>Which number is larger: 9.9 or 9.11? Wait! Let's think step by step: First, 9.9 is larger than 9.11 because 9.9 has the greater value of the two numbers.<|end_of_text|>"