In [7]:
import torch
import torch.nn.functional as F
import math
from transformers import AutoModelForCausalLM, AutoTokenizer

# For colored output in Jupyter
from IPython.display import clear_output, display, HTML

# Define colors for small and large models
SMALL_MODEL_COLOR = 'blue'
LARGE_MODEL_COLOR = 'red'

# Load the tokenizer
tokenizer = AutoTokenizer.from_pretrained("meta-llama/Llama-3.2-1B-Instruct")

# Add special tokens to the tokenizer if needed
special_tokens = {
    "additional_special_tokens": [
        "<|begin_of_text|>",
        "<|start_header_id|>",
        "<|end_header_id|>",
        "<|eot_id|>",
    ]
}
tokenizer.add_special_tokens(special_tokens)

# Load both models
model_small_id = "meta-llama/Llama-3.2-1B-Instruct"
model_large_id = "meta-llama/Llama-3.2-3B-Instruct"

model_small = AutoModelForCausalLM.from_pretrained(
    model_small_id, torch_dtype=torch.bfloat16
)
model_large = AutoModelForCausalLM.from_pretrained(
    model_large_id, torch_dtype=torch.bfloat16
)

# Resize token embeddings if special tokens were added
model_small.resize_token_embeddings(len(tokenizer))
model_large.resize_token_embeddings(len(tokenizer))

model_small.eval()
model_large.eval()

# Move models to the appropriate device
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model_small.to(device)
model_large.to(device)

# Define the conversation as a string prompt
def build_prompt(conversation):
    """
    Build the prompt string from a list of conversation turns.
    Each turn is a tuple: (role, content)
    """
    prompt = ""
    for role, content in conversation:
        if role == "system":
            prompt += (
                f"<|begin_of_text|><|start_header_id|>system<|end_header_id|>\n"
                f"{content}<|eot_id|>"
            )
        elif role == "user":
            prompt += (
                f"<|start_header_id|>user<|end_header_id|>\n{content}<|eot_id|>"
            )
        elif role == "assistant":
            prompt += (
                f"<|start_header_id|>assistant<|end_header_id|>\n{content}"
            )
        else:
            raise ValueError(f"Unknown role: {role}")
    return prompt

# Example conversation
conversation = [
    (
        "system",
        "You are an intelligent AI assistant trained to do complex reasoning. You should carefully think about the question before outputting your final response.",
    ),
    (
        "user",
        "How many 'r's are there in the word strawberry?",
    ),
    ("assistant", ""),  # Start of assistant's reply
]

# Build the prompt
prompt = build_prompt(conversation)

# Tokenize the input prompt
input_ids = tokenizer(prompt, return_tensors="pt").input_ids.to(device)

# Initialize variables
generated = input_ids
max_new_tokens = 100  # Maximum number of new tokens to generate
temperature = 0.7

# Get the EOS token ID
eos_token = "<|eot_id|>"
eos_token_id = tokenizer.convert_tokens_to_ids(eos_token)

# Prepare to store entropy and varentropy values
entropies = []
varentropies = []

# Flags and variables to manage model switching
current_model = model_small

is_first_pass = True
small_model_past_key_values = None
large_model_past_key_values = None
tokens_since_last_switch = 0  # Number of tokens generated since last model switch

# For storing generated tokens and their colors
generated_tokens = []
generated_colors = []

# Generate tokens one by one
for _ in range(max_new_tokens):
    if is_first_pass:
        # For the first pass, generate tokens with the small model
        small_outputs = model_small(input_ids=generated)
        small_model_past_key_values = small_outputs.past_key_values
        large_outputs = model_large(input_ids=generated)
        large_model_past_key_values = large_outputs.past_key_values
        is_first_pass = False
        outputs = large_outputs
    else:
        # For subsequent passes, use the current model
        current_past_key_values = small_model_past_key_values if current_model == model_small else large_model_past_key_values
        outputs = current_model(
            input_ids=generated[:, -1:], past_key_values=current_past_key_values
        )
        if current_model == model_small:
            small_model_past_key_values = outputs.past_key_values
        else:
            large_model_past_key_values = outputs.past_key_values

    next_token_logits = outputs.logits[:, -1, :]

    # Scale the logits by the temperature parameter
    next_token_logits = next_token_logits / temperature

    # Apply softmax to convert to probabilities
    log_probs = F.log_softmax(next_token_logits, dim=-1)
    probs = torch.exp(log_probs)

    # Compute entropy and varentropy
    entropy = -torch.sum(probs * log_probs, dim=-1) / math.log(2)  # Convert to bits
    varentropy = torch.sum(
        probs * ((log_probs / math.log(2) + entropy.unsqueeze(-1)) ** 2), dim=-1
    )
    entropies.append(entropy.item())
    varentropies.append(varentropy.item())

    # Adjust sampling strategy based on entropy and varentropy
    if entropy.item() < 0.5 and varentropy.item() < 0.1:
        # Low entropy and low varentropy: model is confident
        # Use the smaller model if not already using it
        if current_model != model_small:
            print("\nSwitching to small model (confident).")
            current_model = model_small
            # Recompute past_key_values for the small model
            # We need to process the generated tokens since last switch
            # TODO make sure this logic is correct
            generated_tokens_to_reprocess = generated[:, -tokens_since_last_switch:]
            outputs = current_model(input_ids=generated_tokens_to_reprocess, past_key_values=small_model_past_key_values)
            small_model_past_key_values = outputs.past_key_values

            next_token_logits = outputs.logits[:, -1, :]
            tokens_since_last_switch = 0
        # Take argmax (greedy)
        next_token = torch.argmax(next_token_logits, dim=-1, keepdim=True)
        sampling_info = "Greedy sampling (confident)"
    elif entropy.item() > 3.0 and varentropy.item() > 1:
        # High entropy and high varentropy: model is very uncertain
        # Use the larger model if not already using it
        if current_model != model_large:
            print("\nSwitching to large model (uncertain).")
            current_model = model_large
            # Recompute past_key_values for the large model
            # We need to process all generated tokens
            # TODO make sure this logic is correct
            generated_tokens_to_reprocess = generated[:, -tokens_since_last_switch:]
            outputs = current_model(input_ids=generated_tokens_to_reprocess, past_key_values=large_model_past_key_values)
            large_model_past_key_values = outputs.past_key_values

            next_token_logits = outputs.logits[:, -1, :]
            tokens_since_last_switch = 0
        # Continue sampling with adjusted parameters
        adjusted_temperature = min(1.5, temperature * 1.3)
        next_token_logits = next_token_logits / adjusted_temperature
        # Sample the next token
        probs = F.softmax(next_token_logits, dim=-1)
        next_token = torch.multinomial(probs, num_samples=1)
        sampling_info = f"Sampling with increased temperature {adjusted_temperature:.2f} (uncertain)"
    else:
        # Default sampling
        next_token_logits = next_token_logits / temperature
        probs = F.softmax(next_token_logits, dim=-1)
        next_token = torch.multinomial(probs, num_samples=1)
        sampling_info = "Default sampling"

    # Append the next token to the generated sequence
    generated = torch.cat((generated, next_token), dim=-1)
    tokens_since_last_switch += 1

    # Decode the newly generated token
    decoded_token = tokenizer.decode(next_token[0], skip_special_tokens=False)

    # Store the token and its associated color
    generated_tokens.append(decoded_token)
    if current_model == model_small:
        generated_colors.append(SMALL_MODEL_COLOR)
    else:
        generated_colors.append(LARGE_MODEL_COLOR)

    # Display the generated text with colors
    # Build the HTML content
    html_content = ''
    for token, color in zip(generated_tokens, generated_colors):
        html_content += f'<span style="color: {color}">{token}</span>'
    clear_output(wait=True)
    display(HTML(html_content))

    # Stop if the EOS token is generated
    if next_token.item() == eos_token_id:
        break

# After generation, print the final generated text
print("\nGenerated Text:")
output_text = tokenizer.decode(generated[0], skip_special_tokens=False)
print(output_text)



Generated Text:
<|begin_of_text|><|begin_of_text|><|start_header_id|>system<|end_header_id|>
You are an intelligent AI assistant trained to do complex reasoning. You should carefully think about the question before outputting your final response.<|eot_id|><|start_header_id|>user<|end_header_id|>
How many 'r's are there in the word strawberry?<|eot_id|><|start_header_id|>assistant<|end_header_id|>
There are 2 'r's in the word "strawberry".<|eot_id|>
