In [1]:
from transformers import AutoTokenizer, AutoModelForCausalLM
import torch

import matplotlib.pyplot as plt
import ipywidgets as widgets

In [2]:
# Checkpoint identifier handed to both `from_pretrained` calls below. Kept in
# its own constant so the name `model` only ever refers to the loaded model
# object, never to a string.
MODEL_ID = "google/gemma-3-1b-pt"

# Use the GPU when one is available; otherwise fall back to the CPU so this
# cell does not crash on machines without CUDA.
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"

tokenizer = AutoTokenizer.from_pretrained(MODEL_ID)
model = AutoModelForCausalLM.from_pretrained(MODEL_ID).to(DEVICE)

In [3]:
prompt = "The cat sat on the"

# Move the encoded prompt to whichever device the model is on instead of
# assuming "cuda", so the forward pass never mixes devices.
inputs = tokenizer(prompt, return_tensors="pt").to(model.device)

# Forward pass only; gradients are not needed to inspect the predictions.
with torch.no_grad():
    outputs = model(**inputs)

# Logits at the final sequence position of the first (only) batch item.
last_token_logits = outputs.logits[0, -1]

# Show the length of the logits vector (presumably one score per vocabulary
# entry). NOTE: the printed label says "probs", but no softmax has been
# applied at this point; these are raw logits.
print(f"Length of probs: {len(last_token_logits)}")

Length of probs: 262144


In [4]:
# Define the plotting function
def plot_top_tokens(temp):
    """Bar-plot the 10 highest-probability tokens at a given temperature.

    Reads the module-level ``last_token_logits`` and ``tokenizer``. The logits
    are divided by ``temp``, turned into probabilities with a softmax, and the
    ten largest probabilities are drawn as a bar chart labelled with the
    corresponding token strings.

    Args:
        temp: Softmax temperature. Must be strictly positive; the logits are
            divided by this value before the softmax.

    Raises:
        ValueError: If ``temp`` is zero, negative, or NaN.
    """
    # `not temp > 0` (rather than `temp <= 0`) also rejects NaN. Without this
    # guard a bad temperature silently produces NaN/inf probabilities and a
    # meaningless chart.
    if not temp > 0:
        raise ValueError(f"temp must be a positive number, got {temp!r}")

    # Upcast to float32 before scaling: dividing by a small temperature can
    # overflow half-precision logits, and bfloat16 tensors cannot be converted
    # to NumPy. This is a no-op when the logits are already float32.
    scaled_logits = last_token_logits.float() / temp
    probs = torch.softmax(scaled_logits, dim=-1)

    top_probs, top_indices = torch.topk(probs, k=10)
    top_tokens = tokenizer.convert_ids_to_tokens(top_indices.tolist())

    # Explicit figure/axes objects instead of the implicit pyplot state.
    fig, ax = plt.subplots(figsize=(16, 5))
    ax.bar(top_tokens, top_probs.cpu().numpy())
    ax.set_title(f"Top Predicted Next Tokens (Temp = {temp:.2f})")
    ax.set_xlabel("Token")
    ax.set_ylabel("Probability")
    plt.show()

# Temperature control: 0.05 to 4.0 in steps of 0.05, starting at 1.0.
# continuous_update=False is passed so the callback is not fired for every
# intermediate value while dragging (per ipywidgets' slider option).
temp_slider = widgets.FloatSlider(
    description='Temperature:',
    min=0.05,
    max=4.0,
    step=0.05,
    value=1.0,
    continuous_update=False,
)

# Bind the slider to the `temp` argument of the plotting function.
widgets.interact(plot_top_tokens, temp=temp_slider)

interactive(children=(FloatSlider(value=1.0, continuous_update=False, description='Temperature:', max=4.0, min…

<function __main__.plot_top_tokens(temp)>