In [2]:
import torch

def calculate_perplexity(logits, target):
    """
    Calculate perplexity from logits and target labels.

    Args:
    - logits (torch.Tensor): Logits output from the model (batch_size, seq_length, vocab_size).
    - target (torch.Tensor): Ground truth labels (batch_size, seq_length).

    Returns:
    - perplexity (float): The perplexity score.
    """

    # Convert logits to log probabilities
    log_probs = torch.nn.functional.log_softmax(logits, dim=-1)

    # Gather the log probabilities for the correct target tokens
    # log_probs has shape (batch_size, seq_length, vocab_size)
    # target has shape (batch_size, seq_length)
    # The gather method will pick the log probabilities of the true target tokens
    target_log_probs = log_probs.gather(dim=-1, index=target.unsqueeze(-1)).squeeze(-1)

    # Calculate the negative log likelihood
    negative_log_likelihood = -target_log_probs

    # Calculate the mean negative log likelihood over all tokens
    mean_nll = negative_log_likelihood.mean()

    # Calculate perplexity as exp(mean negative log likelihood)
    perplexity = torch.exp(mean_nll)

    return perplexity.item()

In [3]:

# Example usage
# Simulate a batch of logits (batch_size=2, seq_length=4, vocab_size=10)
logits = torch.randn(2, 4, 10)
# Simulate ground truth target tokens
target = torch.tensor([[1, 2, 3, 4], [4, 3, 2, 1]])

# Calculate perplexity
perplexity = calculate_perplexity(logits, target)
print(f'Perplexity: {perplexity}')

Perplexity: 11.153843879699707


In [5]:
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM

# Load the model and tokenizer (e.g., GPT-2)
model_name = "gpt2"
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForCausalLM.from_pretrained(model_name)

# Assign the EOS token as the padding token
tokenizer.pad_token = tokenizer.eos_token

def calculate_batch_perplexity(input_texts):
    """
    Calculate perplexity for a batch of input texts using a pretrained language model.

    Args:
    - input_texts (List[str]): A list of input texts to evaluate.

    Returns:
    - List[float]: A list of perplexity scores, one for each input text.
    """
    # Tokenize the batch of texts with padding for uniform length
    inputs = tokenizer(
        input_texts, return_tensors="pt", padding=True, truncation=True
    )

    input_ids = inputs["input_ids"]
    attention_mask = inputs["attention_mask"]

    # Pass the input batch through the model to get logits
    with torch.no_grad():
        outputs = model(input_ids, attention_mask=attention_mask)
        logits = outputs.logits

    # Shift the logits and input_ids to align targets correctly
    # Logits dimensions are: (batch_size, seq_length, vocab_size) 
    shift_logits = logits[:, :-1, :]  # Ignore the last token's logits
    shift_labels = input_ids[:, 1:]   # Skip the first token in the labels

    # Compute log probabilities
    log_probs = torch.nn.functional.log_softmax(shift_logits, dim=-1)

    # Gather the log probabilities for the correct tokens
    target_log_probs = log_probs.gather(dim=-1, index=shift_labels.unsqueeze(-1)).squeeze(-1)

    # Mask out positions corresponding to padding tokens
    target_log_probs = target_log_probs * attention_mask[:, 1:].to(log_probs.dtype)

    # Compute the mean negative log-likelihood for each sequence
    negative_log_likelihood = -target_log_probs.sum(dim=-1) / attention_mask[:, 1:].sum(dim=-1)

    # Compute perplexity for each sequence
    perplexities = torch.exp(negative_log_likelihood)
    perplexities = perplexities.tolist()
    
    # Take mean of perplexities of each batch
    mean_perplexity_score = torch.mean(perplexities)

    return {"perplexities": perplexities, "mean_perplexity": mean_perplexity_score}

# Example usage
texts = [
    "The quick brown fox jumps over the lazy dog.",
    "A journey of a thousand miles begins with a single step."
]
print(f"Perplexity scores: {calculate_batch_perplexity(texts)}")

  from .autonotebook import tqdm as notebook_tqdm
Xet Storage is enabled for this repo, but the 'hf_xet' package is not installed. Falling back to regular HTTP download. For better performance, install the package with: `pip install huggingface_hub[hf_xet]` or `pip install hf_xet`


TypeError: mean(): argument 'input' (position 1) must be Tensor, not list