In [2]:
import torch
from transformers import pipeline

model_id = "meta-llama/Llama-3.2-1B"

pipe = pipeline(
    "text-generation",
    model=model_id,
    torch_dtype=torch.bfloat16,
    device_map="auto"
)

pipe("The key to life is")



generation_config.json:   0%|          | 0.00/185 [00:00<?, ?B/s]

tokenizer_config.json:   0%|          | 0.00/50.5k [00:00<?, ?B/s]

tokenizer.json:   0%|          | 0.00/9.09M [00:00<?, ?B/s]

special_tokens_map.json:   0%|          | 0.00/301 [00:00<?, ?B/s]

Setting `pad_token_id` to `eos_token_id`:None for open-end generation.


RuntimeError: isin_Tensor_Tensor_out only works on floating types on MPS for pre MacOS_14_0. Received dtype: Long

In [4]:
import torch
import torch.nn.functional as F
import math
from transformers import AutoModelForCausalLM, AutoTokenizer

# Load the model and tokenizer
model_id = "meta-llama/Llama-3.2-1B"
tokenizer = AutoTokenizer.from_pretrained(model_id)
model = AutoModelForCausalLM.from_pretrained(model_id, torch_dtype=torch.bfloat16)
model.eval()

# Move the model to the appropriate device
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model.to(device)

# Define the prompt
prompt = "The capital of France is: "

# Tokenize the input prompt
input_ids = tokenizer(prompt, return_tensors="pt").input_ids.to(device)

# Initialize variables
generated = input_ids
past_key_values = None
max_length = 50  # Maximum length of the generated sequence
temperature = 0.7
eos_token_id = tokenizer.eos_token_id

# Generate tokens one by one
while generated.shape[1] < max_length:
    if past_key_values is None:
        # For the first step, provide the full input
        outputs = model(input_ids=generated)
    else:
        # For subsequent steps, use only the last token and past key values
        outputs = model(input_ids=generated[:, -1:], past_key_values=past_key_values)
    next_token_logits = outputs.logits[:, -1, :]
    past_key_values = outputs.past_key_values

    # Apply temperature scaling
    next_token_logits = next_token_logits / temperature

    # Compute log probabilities
    log_probs = F.log_softmax(next_token_logits, dim=-1)
    probs = torch.exp(log_probs)

    # Compute entropy and varentropy
    entropy = -torch.sum(probs * log_probs, dim=-1) / math.log(2)  # Convert to bits
    varentropy = torch.sum(probs * ((log_probs / math.log(2) + entropy.unsqueeze(-1)) ** 2), dim=-1)
    print(f"Entropy: {entropy.item():.4f}, Varentropy: {varentropy.item():.4f}")

    # Sample the next token
    next_token = torch.multinomial(probs, num_samples=1)

    # Append the next token to the generated sequence
    generated = torch.cat((generated, next_token), dim=-1)

    # Stop if the EOS token is generated
    if next_token.item() == eos_token_id:
        break

# Decode and print the generated text
output_text = tokenizer.decode(generated[0], skip_special_tokens=True)
print("\nGenerated Text:")
print(output_text)


Entropy: 1.3663, Varentropy: 6.2080
Entropy: 1.1530, Varentropy: 3.6669
Entropy: 0.9399, Varentropy: 7.2019
Entropy: 1.6958, Varentropy: 2.6731
Entropy: 0.2166, Varentropy: 1.5497
Entropy: 0.0032, Varentropy: 0.0458
Entropy: 0.0034, Varentropy: 0.0514
Entropy: 4.6963, Varentropy: 6.4685
Entropy: 0.0413, Varentropy: 0.3000
Entropy: 0.3627, Varentropy: 1.9567
Entropy: 0.4752, Varentropy: 1.2332
Entropy: 0.0017, Varentropy: 0.0243
Entropy: 0.0025, Varentropy: 0.0372
Entropy: 3.7137, Varentropy: 4.9036
Entropy: 0.6406, Varentropy: 3.0265
Entropy: 0.1926, Varentropy: 1.3964
Entropy: 0.0291, Varentropy: 0.2954
Entropy: 3.1535, Varentropy: 4.6297
Entropy: 0.0167, Varentropy: 0.1914
Entropy: 1.0265, Varentropy: 0.3667
Entropy: 2.0927, Varentropy: 2.3343
Entropy: 2.2296, Varentropy: 10.9174
Entropy: 1.0605, Varentropy: 7.5476
Entropy: 0.1166, Varentropy: 0.6634
Entropy: 2.0237, Varentropy: 7.3195
Entropy: 0.1069, Varentropy: 0.8749
Entropy: 0.2434, Varentropy: 1.5999
Entropy: 0.1350, Varentropy

In [9]:
import torch
import torch.nn.functional as F
import math
from transformers import AutoModelForCausalLM, AutoTokenizer

# Load the model and tokenizer
model_id = "meta-llama/Llama-3.2-1B-Instruct"
tokenizer = AutoTokenizer.from_pretrained(model_id)

# Add special tokens to the tokenizer
special_tokens = {
    "additional_special_tokens": ["<|begin_of_text|>", "<|start_header_id|>", "<|end_header_id|>", "<|eot_id|>"]
}
tokenizer.add_special_tokens(special_tokens)

model = AutoModelForCausalLM.from_pretrained(model_id, torch_dtype=torch.bfloat16)
model.resize_token_embeddings(len(tokenizer))
model.eval()

# Move the model to the appropriate device
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model.to(device)

# Define the conversation as a string prompt
def build_prompt(conversation):
    """
    Build the prompt string from a list of conversation turns.
    Each turn is a tuple: (role, content)
    """
    prompt = ""
    for role, content in conversation:
        if role == "system":
            prompt += f"<|begin_of_text|><|start_header_id|>system<|end_header_id|>\n{content}<|eot_id|>"
        elif role == "user":
            prompt += f"<|start_header_id|>user<|end_header_id|>\n{content}<|eot_id|>"
        elif role == "assistant":
            prompt += f"<|start_header_id|>assistant<|end_header_id|>\n{content}"
        else:
            raise ValueError(f"Unknown role: {role}")
    return prompt

# Example conversation
conversation = [
    ("system", "You are an intelligent, helpful AI assistant."),
    ("user", "What is the capital of France?"),
    ("assistant", "The capital of France is:")  # Start of assistant's reply
]

# Build the prompt
prompt = build_prompt(conversation)

# Tokenize the input prompt
input_ids = tokenizer(prompt, return_tensors="pt").input_ids.to(device)

# Initialize variables
generated = input_ids
past_key_values = None
max_new_tokens = 50  # Maximum number of new tokens to generate
temperature = 0.7

# Get the EOS token ID
eos_token = "<|eot_id|>"
eos_token_id = tokenizer.convert_tokens_to_ids(eos_token)

# Prepare to store entropy and varentropy values
entropies = []
varentropies = []

# Generate tokens one by one
for _ in range(max_new_tokens):
    if past_key_values is None:
        # For the first step, provide the full input
        outputs = model(input_ids=generated)
    else:
        # For subsequent steps, use only the last token and past key values
        outputs = model(input_ids=generated[:, -1:], past_key_values=past_key_values)
    next_token_logits = outputs.logits[:, -1, :]
    past_key_values = outputs.past_key_values

    # Apply temperature scaling
    next_token_logits = next_token_logits / temperature

    # Compute log probabilities
    log_probs = F.log_softmax(next_token_logits, dim=-1)
    probs = torch.exp(log_probs)

    # Compute entropy and varentropy
    entropy = -torch.sum(probs * log_probs, dim=-1) / math.log(2)  # Convert to bits
    varentropy = torch.sum(probs * ((log_probs / math.log(2) + entropy.unsqueeze(-1)) ** 2), dim=-1)
    entropies.append(entropy.item())
    varentropies.append(varentropy.item())

    # Sample the next token
    next_token = torch.multinomial(probs, num_samples=1)

    # Append the next token to the generated sequence
    generated = torch.cat((generated, next_token), dim=-1)

    # Decode the newly generated token
    decoded_token = tokenizer.decode(next_token[0])

    # Print the entropy, varentropy, and the generated token together
    print(f"Entropy: {entropy.item():.4f}, Varentropy: {varentropy.item():.4f}, Token: '{decoded_token}'")

    # Stop if the EOS token is generated
    if next_token.item() == eos_token_id:
        break

print("\nGenerated Text:")
output_text = tokenizer.decode(generated[0], skip_special_tokens=False)
print(output_text)


Entropy: 0.2101, Varentropy: 1.2655, Token: ' Paris'
Entropy: 1.3736, Varentropy: 0.7257, Token: '<|eot_id|>'

Generated Text:
<|begin_of_text|><|begin_of_text|><|start_header_id|>system<|end_header_id|>
You are an intelligent, helpful AI assistant.<|eot_id|><|start_header_id|>user<|end_header_id|>
What is the capital of France?<|eot_id|><|start_header_id|>assistant<|end_header_id|>
The capital of France is: Paris<|eot_id|>


In [15]:
import torch
import torch.nn.functional as F
import math
from transformers import AutoModelForCausalLM, AutoTokenizer

# Load the model and tokenizer
model_id = "meta-llama/Llama-3.2-1B-Instruct"
tokenizer = AutoTokenizer.from_pretrained(model_id)

# Add special tokens to the tokenizer
special_tokens = {
    "additional_special_tokens": [
        "<|begin_of_text|>",
        "<|start_header_id|>",
        "<|end_header_id|>",
        "<|eot_id|>",
        "<thinking>",
        "</thinking>",
    ]
}
tokenizer.add_special_tokens(special_tokens)

model = AutoModelForCausalLM.from_pretrained(model_id, torch_dtype=torch.bfloat16)
model.resize_token_embeddings(len(tokenizer))
model.eval()

# Move the model to the appropriate device
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model.to(device)

# Define the conversation as a string prompt
def build_prompt(conversation):
    """
    Build the prompt string from a list of conversation turns.
    Each turn is a tuple: (role, content)
    """
    prompt = ""
    for role, content in conversation:
        if role == "system":
            prompt += f"<|begin_of_text|><|start_header_id|>system<|end_header_id|>\n{content}<|eot_id|>"
        elif role == "user":
            prompt += f"<|start_header_id|>user<|end_header_id|>\n{content}<|eot_id|>"
        elif role == "assistant":
            prompt += f"<|start_header_id|>assistant<|end_header_id|>\n{content}"
        else:
            raise ValueError(f"Unknown role: {role}")
    return prompt

# Example conversation
conversation = [
    (
        "system",
        "You are an expert in composing functions. You are given a question and a set of possible functions.\n"
        "Based on the question, you will need to make one or more function/tool calls to achieve the purpose.\n"
        "If none of the functions can be used, point it out. If the given question lacks the parameters required by the function, also point it out. You should only return the function call in tools call sections.\n"
        "If you decide to invoke any of the function(s), you MUST put it in the format of [func_name1(params_name1=params_value1, params_name2=params_value2...), func_name2(params)]\n"
        "You SHOULD NOT include any other text in the response.\n"
        "Here is a list of functions in JSON format that you can invoke.[\n"
        '    {\n        "name": "get_user_info",\n        "description": "Retrieve details for a specific user by their unique identifier. Note that the provided function is in Python 3 syntax.",\n        "parameters": {\n            "type": "dict",\n            "required": [\n                "user_id"\n            ],\n            "properties": {\n                "user_id": {\n                "type": "integer",\n                "description": "The unique identifier of the user. It is used to fetch the specific user details from the database."\n            },\n            "special": {\n                "type": "string",\n                "description": "Any special information or parameters that need to be considered while fetching user details.",\n                "default": "none"\n                }\n            }\n        }\n    }\n]',
    ),
    (
        "user",
        "Can you retrieve the details for the user with the ID 7890, who has black as their special request?",
    ),
    ("assistant", ""),  # Start of assistant's reply
]

# Build the prompt
prompt = build_prompt(conversation)

# Tokenize the input prompt
input_ids = tokenizer(prompt, return_tensors="pt").input_ids.to(device)

# Initialize variables
generated = input_ids
past_key_values = None
max_new_tokens = 50  # Maximum number of new tokens to generate
temperature = 0.7

# Get the EOS token ID
eos_token = "<|eot_id|>"
eos_token_id = tokenizer.convert_tokens_to_ids(eos_token)

# Prepare to store entropy and varentropy values
entropies = []
varentropies = []

# Tokens to check for (e.g., '<thinking>' token)
thinking_token = "<thinking>"
thinking_token_id = tokenizer.convert_tokens_to_ids(thinking_token)
end_thinking_token = "</thinking>"
end_thinking_token_id = tokenizer.convert_tokens_to_ids(end_thinking_token)

# Generate tokens one by one
for _ in range(max_new_tokens):
    if past_key_values is None:
        # For the first step, provide the full input
        outputs = model(input_ids=generated)
    else:
        # For subsequent steps, use only the last token and past key values
        outputs = model(input_ids=generated[:, -1:], past_key_values=past_key_values)
    next_token_logits = outputs.logits[:, -1, :]
    past_key_values = outputs.past_key_values

    # Compute probabilities before temperature scaling for entropy calculation
    log_probs = F.log_softmax(next_token_logits, dim=-1)
    probs = torch.exp(log_probs)

    # Compute entropy and varentropy
    entropy = -torch.sum(probs * log_probs, dim=-1) / math.log(2)  # Convert to bits
    varentropy = torch.sum(
        probs * ((log_probs / math.log(2) + entropy.unsqueeze(-1)) ** 2), dim=-1
    )
    entropies.append(entropy.item())
    varentropies.append(varentropy.item())

    # Adjust sampling strategy based on entropy and varentropy
    if entropy.item() < 0.5 and varentropy.item() < 0.1:
        # Low entropy and low varentropy: model is confident
        # Take argmax (greedy)
        next_token = torch.argmax(next_token_logits, dim=-1, keepdim=True)
        sampling_info = "Greedy sampling (confident)"
    elif entropy.item() > 3.0 and varentropy.item() > 1.0:
        # High entropy and high varentropy: model is very uncertain
        # Inject '<thinking>' token
        if thinking_token_id not in generated:
            # Inject '<thinking>' token
            next_token = torch.tensor([[thinking_token_id]], device=device)
            sampling_info = "Injecting '<thinking>' token (high uncertainty)"
        else:
            # Continue sampling with adjusted parameters
            adjusted_temperature = min(1.5, temperature * 1.3)
            next_token_logits = next_token_logits / adjusted_temperature
            # Sample the next token
            probs = F.softmax(next_token_logits, dim=-1)
            next_token = torch.multinomial(probs, num_samples=1)
            sampling_info = f"Sampling with increased temperature {adjusted_temperature:.2f} (after thinking)"
    elif entropy.item() < 5.0 and varentropy.item() > 1.0:
        # Low entropy and high varentropy: model is exploring options
        # Adjust temperature and top_k
        adjusted_temperature = min(1.5, temperature * 1.2)
        top_k = 40  # Adjusted top_k
        # Apply temperature scaling
        next_token_logits = next_token_logits / adjusted_temperature
        # Apply top_k sampling
        values, indices = torch.topk(next_token_logits, k=top_k, dim=-1)
        probs = F.softmax(values, dim=-1)
        next_token = torch.multinomial(probs, num_samples=1)
        next_token = indices.gather(-1, next_token)
        sampling_info = (
            f"Sampling with adjusted temperature {adjusted_temperature:.2f} and top_k {top_k}"
        )
    else:
        # Default sampling
        next_token_logits = next_token_logits / temperature
        probs = F.softmax(next_token_logits, dim=-1)
        next_token = torch.multinomial(probs, num_samples=1)
        sampling_info = "Default sampling"

    # Append the next token to the generated sequence
    generated = torch.cat((generated, next_token), dim=-1)

    # Decode the newly generated token
    decoded_token = tokenizer.decode(next_token[0])

    # Print the entropy, varentropy, sampling info, and the generated token
    print(
        f"Entropy: {entropy.item():.4f}, Varentropy: {varentropy.item():.4f}, {sampling_info}, Token: '{decoded_token}'"
    )

    # Stop if the EOS token is generated
    if next_token.item() == eos_token_id:
        break

print("\nGenerated Text:")
output_text = tokenizer.decode(generated[0], skip_special_tokens=False)
print(output_text)


Entropy: 4.4438, Varentropy: 27.5838, Injecting '<thinking>' token (high uncertainty), Token: '<thinking>'
Entropy: 14.3639, Varentropy: 8.8004, Sampling with increased temperature 0.91 (after thinking), Token: ' importer'
Entropy: 9.1457, Varentropy: 34.3509, Sampling with increased temperature 0.91 (after thinking), Token: '_info'
Entropy: 4.4108, Varentropy: 25.9069, Sampling with increased temperature 0.91 (after thinking), Token: ' ='
Entropy: 3.8362, Varentropy: 14.1959, Sampling with increased temperature 0.91 (after thinking), Token: ' get'
Entropy: 0.0452, Varentropy: 0.7342, Default sampling, Token: '_user'
Entropy: 0.0115, Varentropy: 0.1962, Default sampling, Token: '_info'
Entropy: 0.2473, Varentropy: 2.7016, Sampling with adjusted temperature 0.84 and top_k 40, Token: '(user'
Entropy: 0.0102, Varentropy: 0.1279, Default sampling, Token: '_id'
Entropy: 0.0858, Varentropy: 0.8344, Default sampling, Token: '='
Entropy: 0.0496, Varentropy: 0.6552, Default sampling, Token: '78

In [17]:
import torch
import torch.nn.functional as F
import math
from transformers import AutoModelForCausalLM, AutoTokenizer

# Load the model and tokenizer
model_id = "meta-llama/Llama-3.2-1B-Instruct"
tokenizer = AutoTokenizer.from_pretrained(model_id)

# Add special tokens to the tokenizer if needed (e.g., for <|eot_id|>)
special_tokens = {
    "additional_special_tokens": [
        "<|begin_of_text|>",
        "<|start_header_id|>",
        "<|end_header_id|>",
        "<|eot_id|>"
    ]
}
tokenizer.add_special_tokens(special_tokens)

model = AutoModelForCausalLM.from_pretrained(model_id, torch_dtype=torch.bfloat16)
model.resize_token_embeddings(len(tokenizer))
model.eval()

# Move the model to the appropriate device
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model.to(device)

# Define the conversation as a string prompt
def build_prompt(conversation):
    """
    Build the prompt string from a list of conversation turns.
    Each turn is a tuple: (role, content)
    """
    prompt = ""
    for role, content in conversation:
        if role == "system":
            prompt += f"<|begin_of_text|><|start_header_id|>system<|end_header_id|>\n{content}<|eot_id|>"
        elif role == "user":
            prompt += f"<|start_header_id|>user<|end_header_id|>\n{content}<|eot_id|>"
        elif role == "assistant":
            prompt += f"<|start_header_id|>assistant<|end_header_id|>\n{content}"
        else:
            raise ValueError(f"Unknown role: {role}")
    return prompt

# Example conversation
conversation = [
    (
        "system", "You are an intelligent AI assistant trained to do complex reasoning. You should carefully think about problems and include your thoughts in a <thinking></thinking> tag before outputting your final response.",
    ),
    (
        "user",
        "How many ways can you arrange the letters in the word 'MISSISSIPPI' such that no two 'I's are adjacent?",
    ),
    ("assistant", ""),  # Start of assistant's reply
]

# Build the prompt
prompt = build_prompt(conversation)

# Tokenize the input prompt
input_ids = tokenizer(prompt, return_tensors="pt").input_ids.to(device)

# Initialize variables
generated = input_ids
past_key_values = None
max_new_tokens = 100  # Maximum number of new tokens to generate
temperature = 0.7

# Get the EOS token ID
eos_token = "<|eot_id|>"
eos_token_id = tokenizer.convert_tokens_to_ids(eos_token)

# Prepare to store entropy and varentropy values
entropies = []
varentropies = []

# Flag to check if thinking phrase has been injected
thinking_injected = False

# Generate tokens one by one
for _ in range(max_new_tokens):
    if past_key_values is None:
        # For the first step, provide the full input
        outputs = model(input_ids=generated)
    else:
        # For subsequent steps, use only the last token and past key values
        outputs = model(input_ids=generated[:, -1:], past_key_values=past_key_values)
    next_token_logits = outputs.logits[:, -1, :]
    past_key_values = outputs.past_key_values

    # Compute probabilities before temperature scaling for entropy calculation
    log_probs = F.log_softmax(next_token_logits, dim=-1)
    probs = torch.exp(log_probs)

    # Compute entropy and varentropy
    entropy = -torch.sum(probs * log_probs, dim=-1) / math.log(2)  # Convert to bits
    varentropy = torch.sum(
        probs * ((log_probs / math.log(2) + entropy.unsqueeze(-1)) ** 2), dim=-1
    )
    entropies.append(entropy.item())
    varentropies.append(varentropy.item())

    # Adjust sampling strategy based on entropy and varentropy
    if entropy.item() < 0.5 and varentropy.item() < 0.1:
        # Low entropy and low varentropy: model is confident
        # Take argmax (greedy)
        next_token = torch.argmax(next_token_logits, dim=-1, keepdim=True)
        sampling_info = "Greedy sampling (confident)"
    elif entropy.item() > 3.0 and varentropy.item() > 1.0:
        # High entropy and high varentropy: model is very uncertain
        # Inject 'Hmm... let me think...'
        if not thinking_injected:
            # Inject 'Hmm... let me think...'
            thinking_phrase = "Hmm... let me think..."
            thinking_tokens = tokenizer.encode(thinking_phrase, return_tensors="pt").to(device)
            # Append thinking tokens to generated
            generated = torch.cat((generated, thinking_tokens[:, :1]), dim=-1)
            # Process the first token to update past_key_values
            outputs = model(input_ids=thinking_tokens[:, :1], past_key_values=past_key_values)
            past_key_values = outputs.past_key_values
            # Loop over the rest of the thinking tokens
            for idx in range(1, thinking_tokens.size(1)):
                next_token = thinking_tokens[:, idx:idx+1]
                generated = torch.cat((generated, next_token), dim=-1)
                outputs = model(input_ids=next_token, past_key_values=past_key_values)
                past_key_values = outputs.past_key_values
            sampling_info = "Injecting 'Hmm... let me think...' (high uncertainty)"
            thinking_injected = True
            continue  # Proceed to next iteration
        else:
            # Continue sampling with adjusted parameters
            adjusted_temperature = min(1.5, temperature * 1.3)
            next_token_logits = next_token_logits / adjusted_temperature
            # Sample the next token
            probs = F.softmax(next_token_logits, dim=-1)
            next_token = torch.multinomial(probs, num_samples=1)
            sampling_info = f"Sampling with increased temperature {adjusted_temperature:.2f} (after thinking)"
    elif entropy.item() < 5.0 and varentropy.item() > 1.0:
        # Low entropy and high varentropy: model is exploring options
        # Adjust temperature and top_k
        adjusted_temperature = min(1.5, temperature * 1.2)
        top_k = 40  # Adjusted top_k
        # Apply temperature scaling
        next_token_logits = next_token_logits / adjusted_temperature
        # Apply top_k sampling
        values, indices = torch.topk(next_token_logits, k=top_k, dim=-1)
        probs = F.softmax(values, dim=-1)
        next_token = torch.multinomial(probs, num_samples=1)
        next_token = indices.gather(-1, next_token)
        sampling_info = (
            f"Sampling with adjusted temperature {adjusted_temperature:.2f} and top_k {top_k}"
        )
    else:
        # Default sampling
        next_token_logits = next_token_logits / temperature
        probs = F.softmax(next_token_logits, dim=-1)
        next_token = torch.multinomial(probs, num_samples=1)
        sampling_info = "Default sampling"

    # Append the next token to the generated sequence
    generated = torch.cat((generated, next_token), dim=-1)

    # Decode the newly generated token
    decoded_token = tokenizer.decode(next_token[0])

    # Print the entropy, varentropy, sampling info, and the generated token
    print(
        f"Entropy: {entropy.item():.4f}, Varentropy: {varentropy.item():.4f}, {sampling_info}, Token: '{decoded_token}'"
    )

    # Stop if the EOS token is generated
    if next_token.item() == eos_token_id:
        break

print("\nGenerated Text:")
output_text = tokenizer.decode(generated[0], skip_special_tokens=False)
print(output_text)


Entropy: 3.5499, Varentropy: 12.1908, Sampling with increased temperature 0.91 (after thinking), Token: ' 

'
Entropy: 4.8438, Varentropy: 9.7476, Sampling with increased temperature 0.91 (after thinking), Token: 'We'
Entropy: 2.0792, Varentropy: 4.0139, Sampling with adjusted temperature 0.84 and top_k 40, Token: ' can'
Entropy: 4.2960, Varentropy: 7.4074, Sampling with increased temperature 0.91 (after thinking), Token: ' think'
Entropy: 0.8434, Varentropy: 3.8189, Sampling with adjusted temperature 0.84 and top_k 40, Token: ' of'
Entropy: 2.5318, Varentropy: 5.8520, Sampling with adjusted temperature 0.84 and top_k 40, Token: ' placing'
Entropy: 1.9199, Varentropy: 8.3964, Sampling with adjusted temperature 0.84 and top_k 40, Token: ' the'
Entropy: 4.0298, Varentropy: 7.2149, Sampling with increased temperature 0.91 (after thinking), Token: ' ''
Entropy: 1.6783, Varentropy: 3.7608, Sampling with adjusted temperature 0.84 and top_k 40, Token: 'I'
Entropy: 0.5115, Varentropy: 2.0260, 

In [18]:
import torch
import torch.nn.functional as F
import math
from transformers import AutoModelForCausalLM, AutoTokenizer

# Load the tokenizer
tokenizer = AutoTokenizer.from_pretrained("meta-llama/Llama-3.2-1B-Instruct")

# Add special tokens to the tokenizer if needed
special_tokens = {
    "additional_special_tokens": [
        "<|begin_of_text|>",
        "<|start_header_id|>",
        "<|end_header_id|>",
        "<|eot_id|>",
    ]
}
tokenizer.add_special_tokens(special_tokens)

# Load both models
model_small_id = "meta-llama/Llama-3.2-1B-Instruct"
model_large_id = "meta-llama/Llama-3.2-3B-Instruct"

model_small = AutoModelForCausalLM.from_pretrained(
    model_small_id, torch_dtype=torch.bfloat16
)
model_large = AutoModelForCausalLM.from_pretrained(
    model_large_id, torch_dtype=torch.bfloat16
)

# Resize token embeddings if special tokens were added
model_small.resize_token_embeddings(len(tokenizer))
model_large.resize_token_embeddings(len(tokenizer))

model_small.eval()
model_large.eval()

# Move models to the appropriate device
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model_small.to(device)
model_large.to(device)

# Define the conversation as a string prompt
def build_prompt(conversation):
    """
    Build the prompt string from a list of conversation turns.
    Each turn is a tuple: (role, content)
    """
    prompt = ""
    for role, content in conversation:
        if role == "system":
            prompt += (
                f"<|begin_of_text|><|start_header_id|>system<|end_header_id|>\n"
                f"{content}<|eot_id|>"
            )
        elif role == "user":
            prompt += (
                f"<|start_header_id|>user<|end_header_id|>\n{content}<|eot_id|>"
            )
        elif role == "assistant":
            prompt += (
                f"<|start_header_id|>assistant<|end_header_id|>\n{content}"
            )
        else:
            raise ValueError(f"Unknown role: {role}")
    return prompt

# Example conversation
conversation = [
    (
        "system",
        "You are an intelligent AI assistant trained to do complex reasoning. You should carefully think about problems and include your thoughts in a <thinking></thinking> tag before outputting your final response.",
    ),
    (
        "user",
        "How many ways can you arrange the letters in the word 'MISSISSIPPI' such that no two 'I's are adjacent?",
    ),
    ("assistant", ""),  # Start of assistant's reply
]

# Build the prompt
prompt = build_prompt(conversation)

# Tokenize the input prompt
input_ids = tokenizer(prompt, return_tensors="pt").input_ids.to(device)

# Initialize variables
generated = input_ids
max_new_tokens = 100  # Maximum number of new tokens to generate
temperature = 0.7

# Get the EOS token ID
eos_token = "<|eot_id|>"
eos_token_id = tokenizer.convert_tokens_to_ids(eos_token)

# Prepare to store entropy and varentropy values
entropies = []
varentropies = []

# Flags and variables to manage model switching
thinking_injected = False
current_model = model_small
current_past_key_values = None
tokens_since_last_switch = 0  # Number of tokens generated since last model switch

# Generate tokens one by one
for _ in range(max_new_tokens):
    if current_past_key_values is None:
        # For the first step or after switching models, provide the full input
        outputs = current_model(input_ids=generated)
        tokens_since_last_switch = 0
    else:
        # For subsequent steps, use only the last token and past key values
        outputs = current_model(
            input_ids=generated[:, -1:], past_key_values=current_past_key_values
        )
    next_token_logits = outputs.logits[:, -1, :]
    current_past_key_values = outputs.past_key_values

    # Compute probabilities before temperature scaling for entropy calculation
    log_probs = F.log_softmax(next_token_logits, dim=-1)
    probs = torch.exp(log_probs)

    # Compute entropy and varentropy
    entropy = -torch.sum(probs * log_probs, dim=-1) / math.log(2)  # Convert to bits
    varentropy = torch.sum(
        probs * ((log_probs / math.log(2) + entropy.unsqueeze(-1)) ** 2), dim=-1
    )
    entropies.append(entropy.item())
    varentropies.append(varentropy.item())

    # Adjust sampling strategy based on entropy and varentropy
    if entropy.item() < 0.5 and varentropy.item() < 0.1:
        # Low entropy and low varentropy: model is confident
        # Use the smaller model if not already using it
        if current_model != model_small:
            print("Switching to small model (confident).")
            current_model = model_small
            # Recompute past_key_values for the small model
            # We need to process the generated tokens since last switch
            generated_tokens_to_reprocess = generated[:, -tokens_since_last_switch:]
            outputs = current_model(input_ids=generated_tokens_to_reprocess)
            current_past_key_values = outputs.past_key_values
            next_token_logits = outputs.logits[:, -1, :]
            tokens_since_last_switch = 0
        # Take argmax (greedy)
        next_token = torch.argmax(next_token_logits, dim=-1, keepdim=True)
        sampling_info = "Greedy sampling (confident)"
    elif entropy.item() > 3.0 and varentropy.item() > 1.0:
        # High entropy and high varentropy: model is very uncertain
        # Use the larger model if not already using it
        if current_model != model_large:
            print("Switching to large model (uncertain).")
            current_model = model_large
            # Recompute past_key_values for the large model
            # We need to process all generated tokens
            outputs = current_model(input_ids=generated)
            current_past_key_values = outputs.past_key_values
            next_token_logits = outputs.logits[:, -1, :]
            tokens_since_last_switch = 0
        # Inject 'Hmm... let me think...'
        if not thinking_injected:
            thinking_phrase = "<thinking>Hmm... let me think...</thinking>\n"
            thinking_tokens = tokenizer.encode(
                thinking_phrase, return_tensors="pt"
            ).to(device)
            # Append thinking tokens to generated
            generated = torch.cat((generated, thinking_tokens[:, :1]), dim=-1)
            # Process the first token to update past_key_values
            outputs = current_model(
                input_ids=thinking_tokens[:, :1], past_key_values=current_past_key_values
            )
            current_past_key_values = outputs.past_key_values
            # Loop over the rest of the thinking tokens
            for idx in range(1, thinking_tokens.size(1)):
                next_token = thinking_tokens[:, idx : idx + 1]
                generated = torch.cat((generated, next_token), dim=-1)
                outputs = current_model(
                    input_ids=next_token, past_key_values=current_past_key_values
                )
                current_past_key_values = outputs.past_key_values
            sampling_info = "Injecting 'Hmm... let me think...' (high uncertainty)"
            thinking_injected = True
            tokens_since_last_switch += thinking_tokens.size(1)
            continue  # Proceed to next iteration
        else:
            # Continue sampling with adjusted parameters
            adjusted_temperature = min(1.5, temperature * 1.3)
            next_token_logits = next_token_logits / adjusted_temperature
            # Sample the next token
            probs = F.softmax(next_token_logits, dim=-1)
            next_token = torch.multinomial(probs, num_samples=1)
            sampling_info = f"Sampling with increased temperature {adjusted_temperature:.2f} (after thinking)"
    else:
        # Default sampling
        next_token_logits = next_token_logits / temperature
        probs = F.softmax(next_token_logits, dim=-1)
        next_token = torch.multinomial(probs, num_samples=1)
        sampling_info = "Default sampling"

    # Append the next token to the generated sequence
    generated = torch.cat((generated, next_token), dim=-1)
    tokens_since_last_switch += 1

    # Decode the newly generated token
    decoded_token = tokenizer.decode(next_token[0])

    # Print the entropy, varentropy, sampling info, and the generated token
    print(
        f"Entropy: {entropy.item():.4f}, Varentropy: {varentropy.item():.4f}, {sampling_info}, Token: '{decoded_token}'"
    )

    # Stop if the EOS token is generated
    if next_token.item() == eos_token_id:
        break

print("\nGenerated Text:")
output_text = tokenizer.decode(generated[0], skip_special_tokens=False)
print(output_text)


config.json:   0%|          | 0.00/878 [00:00<?, ?B/s]

model.safetensors.index.json:   0%|          | 0.00/20.9k [00:00<?, ?B/s]

Downloading shards:   0%|          | 0/2 [00:00<?, ?it/s]

model-00001-of-00002.safetensors:   0%|          | 0.00/4.97G [00:00<?, ?B/s]

model-00002-of-00002.safetensors:   0%|          | 0.00/1.46G [00:00<?, ?B/s]

Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s]

generation_config.json:   0%|          | 0.00/189 [00:00<?, ?B/s]

Switching to large model (uncertain).
Entropy: 2.2021, Varentropy: 6.6188, Default sampling, Token: 'To'
Entropy: 1.6461, Varentropy: 4.6722, Default sampling, Token: ' solve'
Switching to small model (confident).
Entropy: 0.0041, Varentropy: 0.0492, Greedy sampling (confident), Token: ' this'
Entropy: 1.9102, Varentropy: 11.9209, Default sampling, Token: ' problem'
Entropy: 0.3162, Varentropy: 2.8725, Default sampling, Token: ','
Entropy: 1.4102, Varentropy: 3.1123, Default sampling, Token: ' we'
Entropy: 1.4473, Varentropy: 3.5699, Default sampling, Token: ' can'
Switching to large model (uncertain).
Entropy: 3.4094, Varentropy: 10.7237, Sampling with increased temperature 0.91 (after thinking), Token: ' use'
Entropy: 0.7832, Varentropy: 1.9896, Default sampling, Token: ' the'
Entropy: 0.9345, Varentropy: 5.7506, Default sampling, Token: ' concept'
Switching to small model (confident).
Entropy: 0.0002, Varentropy: 0.0041, Greedy sampling (confident), Token: ' of'
Switching to large m

In [26]:
import torch
import torch.nn.functional as F
import math
from transformers import AutoModelForCausalLM, AutoTokenizer

# For colored output in Jupyter
from IPython.display import clear_output, display, HTML

# Define colors for small and large models
SMALL_MODEL_COLOR = 'blue'
LARGE_MODEL_COLOR = 'red'

# Load the tokenizer
tokenizer = AutoTokenizer.from_pretrained("meta-llama/Llama-3.2-1B-Instruct")

# Add special tokens to the tokenizer if needed
special_tokens = {
    "additional_special_tokens": [
        "<|begin_of_text|>",
        "<|start_header_id|>",
        "<|end_header_id|>",
        "<|eot_id|>",
    ]
}
tokenizer.add_special_tokens(special_tokens)

# Load both models
model_small_id = "meta-llama/Llama-3.2-1B-Instruct"
model_large_id = "meta-llama/Llama-3.2-3B-Instruct"

model_small = AutoModelForCausalLM.from_pretrained(
    model_small_id, torch_dtype=torch.bfloat16
)
model_large = AutoModelForCausalLM.from_pretrained(
    model_large_id, torch_dtype=torch.bfloat16
)

# Resize token embeddings if special tokens were added
model_small.resize_token_embeddings(len(tokenizer))
model_large.resize_token_embeddings(len(tokenizer))

model_small.eval()
model_large.eval()

# Move models to the appropriate device
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model_small.to(device)
model_large.to(device)

# Define the conversation as a string prompt
def build_prompt(conversation):
    """
    Build the prompt string from a list of conversation turns.
    Each turn is a tuple: (role, content)
    """
    prompt = ""
    for role, content in conversation:
        if role == "system":
            prompt += (
                f"<|begin_of_text|><|start_header_id|>system<|end_header_id|>\n"
                f"{content}<|eot_id|>"
            )
        elif role == "user":
            prompt += (
                f"<|start_header_id|>user<|end_header_id|>\n{content}<|eot_id|>"
            )
        elif role == "assistant":
            prompt += (
                f"<|start_header_id|>assistant<|end_header_id|>\n{content}"
            )
        else:
            raise ValueError(f"Unknown role: {role}")
    return prompt

# Example conversation
conversation = [
    (
        "system",
        "You are an intelligent AI assistant trained to do complex reasoning. You should carefully think about the question before outputting your final response.",
    ),
    (
        "user",
        "Tell me a good defensive coverage against a hailmary in NFL",
    ),
    ("assistant", ""),  # Start of assistant's reply
]

# Build the prompt
prompt = build_prompt(conversation)

# Tokenize the input prompt
input_ids = tokenizer(prompt, return_tensors="pt").input_ids.to(device)

# Initialize variables
generated = input_ids
max_new_tokens = 100  # Maximum number of new tokens to generate
temperature = 0.7

# Get the EOS token ID
eos_token = "<|eot_id|>"
eos_token_id = tokenizer.convert_tokens_to_ids(eos_token)

# Prepare to store entropy and varentropy values
entropies = []
varentropies = []

# Flags and variables to manage model switching
current_model = model_small
current_past_key_values = None
tokens_since_last_switch = 0  # Number of tokens generated since last model switch

# For storing generated tokens and their colors
generated_tokens = []
generated_colors = []

# Generate tokens one by one
for _ in range(max_new_tokens):
    if current_past_key_values is None:
        # For the first step or after switching models, provide the full input
        outputs = current_model(input_ids=generated)
        tokens_since_last_switch = 0
    else:
        # For subsequent steps, use only the last token and past key values
        outputs = current_model(
            input_ids=generated[:, -1:], past_key_values=current_past_key_values
        )
    next_token_logits = outputs.logits[:, -1, :]
    current_past_key_values = outputs.past_key_values

    # Compute probabilities before temperature scaling for entropy calculation
    log_probs = F.log_softmax(next_token_logits, dim=-1)
    probs = torch.exp(log_probs)

    # Compute entropy and varentropy
    entropy = -torch.sum(probs * log_probs, dim=-1) / math.log(2)  # Convert to bits
    varentropy = torch.sum(
        probs * ((log_probs / math.log(2) + entropy.unsqueeze(-1)) ** 2), dim=-1
    )
    entropies.append(entropy.item())
    varentropies.append(varentropy.item())

    # Adjust sampling strategy based on entropy and varentropy
    if entropy.item() < 0.75 and varentropy.item() < 0.3:
        # Low entropy and low varentropy: model is confident
        # Use the smaller model if not already using it
        if current_model != model_small:
            print("\nSwitching to small model (confident).")
            current_model = model_small
            # Recompute past_key_values for the small model
            # We need to process the generated tokens since last switch
            generated_tokens_to_reprocess = generated[:, -tokens_since_last_switch:]
            outputs = current_model(input_ids=generated_tokens_to_reprocess)
            current_past_key_values = outputs.past_key_values
            next_token_logits = outputs.logits[:, -1, :]
            tokens_since_last_switch = 0
        # Take argmax (greedy)
        next_token = torch.argmax(next_token_logits, dim=-1, keepdim=True)
        sampling_info = "Greedy sampling (confident)"
    elif entropy.item() > 4.0 and varentropy.item() > 1.5:
        # High entropy and high varentropy: model is very uncertain
        # Use the larger model if not already using it
        if current_model != model_large:
            print("\nSwitching to large model (uncertain).")
            current_model = model_large
            # Recompute past_key_values for the large model
            # We need to process all generated tokens
            outputs = current_model(input_ids=generated)
            current_past_key_values = outputs.past_key_values
            next_token_logits = outputs.logits[:, -1, :]
            tokens_since_last_switch = 0
        # Continue sampling with adjusted parameters
        adjusted_temperature = min(1.5, temperature * 1.3)
        next_token_logits = next_token_logits / adjusted_temperature
        # Sample the next token
        probs = F.softmax(next_token_logits, dim=-1)
        next_token = torch.multinomial(probs, num_samples=1)
        sampling_info = f"Sampling with increased temperature {adjusted_temperature:.2f} (uncertain)"
    else:
        # Default sampling
        next_token_logits = next_token_logits / temperature
        probs = F.softmax(next_token_logits, dim=-1)
        next_token = torch.multinomial(probs, num_samples=1)
        sampling_info = "Default sampling"

    # Append the next token to the generated sequence
    generated = torch.cat((generated, next_token), dim=-1)
    tokens_since_last_switch += 1

    # Decode the newly generated token
    decoded_token = tokenizer.decode(next_token[0], skip_special_tokens=False)

    # Store the token and its associated color
    generated_tokens.append(decoded_token)
    if current_model == model_small:
        generated_colors.append(SMALL_MODEL_COLOR)
    else:
        generated_colors.append(LARGE_MODEL_COLOR)

    # Display the generated text with colors
    # Build the HTML content
    html_content = ''
    for token, color in zip(generated_tokens, generated_colors):
        html_content += f'<span style="color: {color}">{token}</span>'
    clear_output(wait=True)
    display(HTML(html_content))

    # Stop if the EOS token is generated
    if next_token.item() == eos_token_id:
        break

# After generation, print the final generated text
print("\nGenerated Text:")
output_text = tokenizer.decode(generated[0], skip_special_tokens=False)
print(output_text)



Generated Text:
<|begin_of_text|><|begin_of_text|><|start_header_id|>system<|end_header_id|>
You are an intelligent AI assistant trained to do complex reasoning. You should carefully think about the question before outputting your final response.<|eot_id|><|start_header_id|>user<|end_header_id|>
Tell me a good defensive coverage against a hailmary in NFL<|eot_id|><|start_header_id|>assistant<|end_header_id|>
A hail mary in the NFL can be a thrilling play to watch and defend against, but ultimately, its unpredictability makes it challenging. However, here's a general defensive strategy to consider:

**Basic Principles:**

1. **Defensive alignment**: Line up in a nearly shell or Cover 2 deep safety look, with one safety deep (e.g., a single-high safety). This alignment encourages the defense of the ball and limits the quarterback's options.
* slot corner(s) should be assigned


In [24]:
import torch
import torch.nn.functional as F
import math
from transformers import AutoModelForCausalLM, AutoTokenizer

# For colored output in Jupyter
from IPython.display import display, HTML, clear_output
from ipywidgets import Output
import time

# Define colors for small and large models
SMALL_MODEL_COLOR = 'blue'
LARGE_MODEL_COLOR = 'red'

# Load the tokenizer
tokenizer = AutoTokenizer.from_pretrained("meta-llama/Llama-3.2-1B-Instruct")

# Add special tokens to the tokenizer if needed
special_tokens = {
    "additional_special_tokens": [
        "<|begin_of_text|>",
        "<|start_header_id|>",
        "<|end_header_id|>",
        "<|eot_id|>",
    ]
}
tokenizer.add_special_tokens(special_tokens)

# Load both models
model_small_id = "meta-llama/Llama-3.2-1B-Instruct"
model_large_id = "meta-llama/Llama-3.2-3B-Instruct"

model_small = AutoModelForCausalLM.from_pretrained(
    model_small_id, torch_dtype=torch.bfloat16
)
model_large = AutoModelForCausalLM.from_pretrained(
    model_large_id, torch_dtype=torch.bfloat16
)

# Resize token embeddings if special tokens were added
model_small.resize_token_embeddings(len(tokenizer))
model_large.resize_token_embeddings(len(tokenizer))

model_small.eval()
model_large.eval()

# Move models to the appropriate device
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model_small.to(device)
model_large.to(device)

# Define presets for model switching
presets = {
    'conservative': {
        'small_model_entropy_threshold': 1.0,
        'small_model_varentropy_threshold': 0.2,
        'large_model_entropy_threshold': 2.5,
        'large_model_varentropy_threshold': 0.8,
    },
    'balanced': {
        'small_model_entropy_threshold': 0.5,
        'small_model_varentropy_threshold': 0.1,
        'large_model_entropy_threshold': 3.0,
        'large_model_varentropy_threshold': 1.0,
    },
    'aggressive': {
        'small_model_entropy_threshold': 0.3,
        'small_model_varentropy_threshold': 0.05,
        'large_model_entropy_threshold': 4.0,
        'large_model_varentropy_threshold': 1.5,
    }
}

# Select a preset
selected_preset = 'balanced'  # Change to 'conservative' or 'aggressive' as needed
preset = presets[selected_preset]

# Define the conversation as a string prompt
def build_prompt(conversation):
    """
    Build the prompt string from a list of conversation turns.
    Each turn is a tuple: (role, content)
    """
    prompt = ""
    for role, content in conversation:
        if role == "system":
            prompt += (
                f"<|begin_of_text|><|start_header_id|>system<|end_header_id|>\n"
                f"{content}<|eot_id|>"
            )
        elif role == "user":
            prompt += (
                f"<|start_header_id|>user<|end_header_id|>\n{content}<|eot_id|>"
            )
        elif role == "assistant":
            prompt += (
                f"<|start_header_id|>assistant<|end_header_id|>\n{content}"
            )
        else:
            raise ValueError(f"Unknown role: {role}")
    return prompt

# Example conversation
conversation = [
    (
        "system",
        "You are an intelligent AI assistant trained to do complex reasoning. You should carefully think about problems and include your thoughts in a <thinking></thinking> tag before outputting your final response.",
    ),
    (
        "user",
        "Tell me a good defensive coverage against a Hail Mary in the NFL.",
    ),
    ("assistant", ""),  # Start of assistant's reply
]

# Build the prompt
prompt = build_prompt(conversation)

# Tokenize the input prompt
input_ids = tokenizer(prompt, return_tensors="pt").input_ids.to(device)

# Initialize variables
generated = input_ids
max_new_tokens = 100  # Maximum number of new tokens to generate
temperature = 0.7

# Get the EOS token ID
eos_token = "<|eot_id|>"
eos_token_id = tokenizer.convert_tokens_to_ids(eos_token)

# Prepare to store entropy and varentropy values
entropies = []
varentropies = []

# Flags and variables to manage model switching
current_model = model_small
current_past_key_values = None
tokens_since_last_switch = 0  # Number of tokens generated since last model switch

# For storing generated tokens and their colors
generated_tokens = []
generated_colors = []

# Initialize the Output widget
out = Output()
display(out)

# Generate tokens one by one
for _ in range(max_new_tokens):
    if current_past_key_values is None:
        # For the first step or after switching models, provide the full input
        outputs = current_model(input_ids=generated)
        tokens_since_last_switch = 0
    else:
        # For subsequent steps, use only the last token and past key values
        outputs = current_model(
            input_ids=generated[:, -1:], past_key_values=current_past_key_values
        )
    next_token_logits = outputs.logits[:, -1, :]
    current_past_key_values = outputs.past_key_values

    # Compute probabilities before temperature scaling for entropy calculation
    log_probs = F.log_softmax(next_token_logits, dim=-1)
    probs = torch.exp(log_probs)

    # Compute entropy and varentropy
    entropy = -torch.sum(probs * log_probs, dim=-1) / math.log(2)  # Convert to bits
    varentropy = torch.sum(
        probs * ((log_probs / math.log(2) + entropy.unsqueeze(-1)) ** 2), dim=-1
    )
    entropies.append(entropy.item())
    varentropies.append(varentropy.item())

    # Adjust sampling strategy based on entropy and varentropy
    if entropy.item() < preset['small_model_entropy_threshold'] and varentropy.item() < preset['small_model_varentropy_threshold']:
        # Low entropy and low varentropy: model is confident
        # Use the smaller model if not already using it
        if current_model != model_small:
            print("\nSwitching to small model (confident).")
            current_model = model_small
            # Recompute past_key_values for the small model
            # We need to process the generated tokens since last switch
            generated_tokens_to_reprocess = generated[:, -tokens_since_last_switch:]
            outputs = current_model(input_ids=generated_tokens_to_reprocess)
            current_past_key_values = outputs.past_key_values
            next_token_logits = outputs.logits[:, -1, :]
            tokens_since_last_switch = 0
        # Take argmax (greedy)
        next_token = torch.argmax(next_token_logits, dim=-1, keepdim=True)
        sampling_info = "Greedy sampling (confident)"
    elif entropy.item() > preset['large_model_entropy_threshold'] and varentropy.item() > preset['large_model_varentropy_threshold']:
        # High entropy and high varentropy: model is very uncertain
        # Use the larger model if not already using it
        if current_model != model_large:
            print("\nSwitching to large model (uncertain).")
            current_model = model_large
            # Recompute past_key_values for the large model
            # We need to process all generated tokens
            outputs = current_model(input_ids=generated)
            current_past_key_values = outputs.past_key_values
            next_token_logits = outputs.logits[:, -1, :]
            tokens_since_last_switch = 0
        # Continue sampling with adjusted parameters
        adjusted_temperature = min(1.5, temperature * 1.3)
        next_token_logits = next_token_logits / adjusted_temperature
        # Sample the next token
        probs = F.softmax(next_token_logits, dim=-1)
        next_token = torch.multinomial(probs, num_samples=1)
        sampling_info = f"Sampling with increased temperature {adjusted_temperature:.2f} (uncertain)"
    else:
        # Default sampling
        next_token_logits = next_token_logits / temperature
        probs = F.softmax(next_token_logits, dim=-1)
        next_token = torch.multinomial(probs, num_samples=1)
        sampling_info = "Default sampling"

    # Append the next token to the generated sequence
    generated = torch.cat((generated, next_token), dim=-1)
    tokens_since_last_switch += 1

    # Decode the newly generated token
    decoded_token = tokenizer.decode(next_token[0], skip_special_tokens=False)

    # Store the token and its associated color
    generated_tokens.append(decoded_token)
    if current_model == model_small:
        generated_colors.append(SMALL_MODEL_COLOR)
    else:
        generated_colors.append(LARGE_MODEL_COLOR)

    # Display the generated text with colors
    # Build the HTML content
    html_content = ''
    for token, color in zip(generated_tokens, generated_colors):
        html_content += f'<span style="color: {color}">{token}</span>'

    # Update the output
    with out:
        clear_output(wait=True)
        display(HTML(html_content))

    # Optional sleep to reduce choppiness
    time.sleep(0.05)

    # Stop if the EOS token is generated
    if next_token.item() == eos_token_id:
        break

# After generation, print the final generated text
print("\nGenerated Text:")
output_text = tokenizer.decode(generated[0], skip_special_tokens=False)
print(output_text)


Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s]

Output()


Switching to large model (uncertain).

Switching to small model (confident).

Switching to large model (uncertain).

Switching to small model (confident).

Switching to large model (uncertain).

Switching to small model (confident).

Switching to large model (uncertain).

Generated Text:
<|begin_of_text|><|begin_of_text|><|start_header_id|>system<|end_header_id|>
You are an intelligent AI assistant trained to do complex reasoning. You should carefully think about problems and include your thoughts in a <thinking></thinking> tag before outputting your final response.<|eot_id|><|start_header_id|>user<|end_header_id|>
Tell me a good defensive coverage against a Hail Mary in the NFL.<|eot_id|><|start_header_id|>assistant<|end_header_id|>
</thinking>

To defend against a H1N1 (Hailong or Hail Mary), the key is to identify the possible rookie quarterback and assign an extra defender to be the last line of defense. Here are some strategies that can be employed to defend against a Hail Mary:
