In [5]:
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer
import math

# Load model and tokenizer
model = AutoModelForCausalLM.from_pretrained("gpt2")
tokenizer = AutoTokenizer.from_pretrained("gpt2")

# Prepare input
prompt = "The quick brown fox"
input_ids = tokenizer(prompt, return_tensors="pt").input_ids

# Generate tokens
outputs = model.generate(
    input_ids,
    max_new_tokens=5,  # Generate 5 new tokens
    return_dict_in_generate=True,
    output_scores=True  # To get the logits for each token
)

# Compute transition scores (log probabilities of the generated tokens)
transition_scores = model.compute_transition_scores(
    outputs.sequences, outputs.scores, normalize_logits=True
)

# Get the generated tokens (excluding the prompt tokens)
generated_tokens = outputs.sequences[0, input_ids.shape[1]:]

# Print the generated sequence
print(f"Generated sequence: {tokenizer.decode(generated_tokens)}")

# Display transition scores for each token
print("\nToken\t| Next Token\t| Log Probability")
print("-" * 50)

for i in range(len(generated_tokens) - 1):
    current_token = tokenizer.decode(generated_tokens[i])
    next_token = tokenizer.decode(generated_tokens[i + 1])
    log_prob = transition_scores[0, i].item()
    print(f"{current_token}\t| {next_token}\t | {log_prob:.4f}")

# Calculate sequence log probability
sequence_log_prob = transition_scores[0].sum().item()
sequence_prob = math.exp(sequence_log_prob)

print(f"\nTotal log probability of the generated sequence: {sequence_log_prob:.4f}")
print(f"Total probability of the generated sequence: {sequence_prob:.4e}")

The attention mask and the pad token id were not set. As a consequence, you may observe unexpected behavior. Please pass your input's `attention_mask` to obtain reliable results.
Setting `pad_token_id` to `eos_token_id`:None for open-end generation.


Generated sequence: es are a great way

Token	| Next Token	| Log Probability
--------------------------------------------------
es	|  are	 | -2.0427
 are	|  a	 | -2.3210
 a	|  great	 | -2.9873
 great	|  way	 | -3.1264

Total log probability of the generated sequence: -12.4331
Total probability of the generated sequence: 3.9843e-06
