In [1]:
import numpy as np
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM, GenerationConfig

# model_name = "deepseek-ai/deepseek-math-7b-instruct"
model_name = "mistralai/Mistral-7B-Instruct-v0.2"
tokenizer = AutoTokenizer.from_pretrained(model_name)
tokenizer.pad_token = tokenizer.eos_token
model = AutoModelForCausalLM.from_pretrained(model_name, torch_dtype=torch.bfloat16, device_map="auto")
# model = AutoModelForCausalLM.from_pretrained(model_name, device_map="auto")
model.generation_config = GenerationConfig.from_pretrained(model_name)
model.generation_config.pad_token_id = model.generation_config.eos_token_id

NUM_COT = 3

  from .autonotebook import tqdm as notebook_tqdm
Loading checkpoint shards: 100%|██████████| 3/3 [00:16<00:00,  5.47s/it]


My Attempt

In [36]:
def generate_text_and_scores(model, tokenizer, question):
    messages = [
        {"role": "user", "content": question + "\nPlease reason step by step, and put your final answer within \\boxed{}."}
    ]
    input_tensor = tokenizer.apply_chat_template(messages, add_generation_prompt=True, return_tensors="pt")
    outputs = model.generate(input_tensor.to(model.device), max_new_tokens=1000, return_dict_in_generate=True, output_scores=True)
    transition_scores = model.compute_transition_scores(
        outputs.sequences, outputs.scores, normalize_logits=True
    )
    print(outputs.sequences)

    input_length = input_tensor.shape[1]
    generated_tokens = outputs.sequences[:, input_length:]
    print(generated_tokens)
    for tok, score in zip(generated_tokens[0], transition_scores[0]):
        # | token | token string | log probability | probability
        print(f"| {tok:5d} | {tokenizer.decode(tok):8s} | {score.cpu().numpy()} | {np.exp(score.cpu().numpy()):.2%}")

    return outputs.sequences, generated_tokens[0], transition_scores[0]

In [33]:
def to_tokens_and_logprobs(model, tokenizer, input_ids):
    outputs = model(input_ids)
    probs = torch.log_softmax(outputs.logits, dim=-1).detach()

    # Adjust indices to ignore the first token's log prob as it corresponds to the second token
    probs = probs[:, :-1, :]
    input_ids = input_ids[:, 1:]
    gen_probs = torch.gather(probs, 2, input_ids[:, :, None]).squeeze(-1)

    batch = []
    for input_sentence, input_probs in zip(input_ids, gen_probs):
        text_sequence = []
        for token, p in zip(input_sentence, input_probs):
            if token not in tokenizer.all_special_ids:
                text_sequence.append((tokenizer.decode(token), p.item()))
        batch.append(text_sequence)
    return batch

In [37]:
# Load your model and tokenizer
# model_name = "mistralai/Mistral-7B-Instruct-v0.2"
# tokenizer = AutoTokenizer.from_pretrained(model_name)
# model = AutoModelForCausalLM.from_pretrained(model_name, torch_dtype=torch.bfloat16, device_map="auto")
model.eval()

# Generate text and scores, and calculate log probabilities
question = "What is 9 + 10 * 2?"
total_tokens, generated_tokens, generated_logprobs = generate_text_and_scores(model, tokenizer, question)
print("Generated Text Log Probabilities:")
for i, (token, score) in enumerate(zip(generated_tokens, generated_logprobs)):
    print(f"Token: {token}, Generated Log Prob: {score.item()}")
    
log_probs = to_tokens_and_logprobs(model, tokenizer, total_tokens) 

print("\nCalculated Log Probabilities from Concatenated Text:")
for sequence in log_probs:
    for token, log_prob in sequence:
        print(f"Token: {token}, Calculated Log Prob: {log_prob}")

tensor([[    1,   733, 16289, 28793,  1824,   349, 28705, 28774,   648, 28705,
         28740, 28734,   398, 28705, 28750, 28804,    13, 12069,  2611,  3707,
           486,  3707, 28725,   304,  1658,   574,  1480,  4372,  2373,   414,
          2858,   286, 28751,  2051,   733, 28748, 16289, 28793,  1791, 12049,
           272,  5782, 28705, 28774,   648, 28705, 28740, 28734,   398, 28705,
         28750, 28725,   478,   927,   298,  1372,   272,  1745,   302,  6933,
         28725,   690,   349,  2608, 10216,   486,   272,  1183,  1689,  1082,
         21025,  4915,  2109, 28747, 18712,  2053,   274, 28725,  1529,  6445,
         28725, 18317,  2459,   304,  8618,   325,  3211,  1749,   298,  1103,
           557,   304,  3301,   685,   304,  5078,   434,  1774,   325,  3211,
          1749,   298,  1103,   609,    13,    13,   657,   456,  1222, 28725,
           736,   460,   708,  2564,  2053,   274,   442,   439,  6445, 28725,
           579,   478,  2318,   356,   298,   272,  

Long Prompt Attempt

tensor([[    1,   733, 16289, 28793,  1824,   349, 28705, 28774,   648, 28705,
         28740, 28734,   398, 28705, 28750, 28804,    13, 12069,  2611,  3707,
           486,  3707, 28725,   304,  1658,   574,  1480,  4372,  2373,   414,
          2858,   286, 28751,  2051,   733, 28748, 16289, 28793,  1791, 12049,
           272,  5782, 28705, 28774,   648, 28705, 28740, 28734,   398, 28705,
         28750, 28725,   478,   927,   298,  1372,   272,  1745,   302,  6933,
         28725,   690,   349,  2608, 10216,   486,   272,  1183,  1689,  1082,
         21025,  4915,  2109, 28747, 18712,  2053,   274, 28725,  1529,  6445,
         28725, 18317,  2459,   304,  8618,   325,  3211,  1749,   298,  1103,
           557,   304,  3301,   685,   304,  5078,   434,  1774,   325,  3211,
          1749,   298,  1103,   609,    13,    13,   657,   456,  1222, 28725,
           736,   460,   708,  2564,  2053,   274,   442,   439,  6445, 28725,
           579,   478,  2318,   356,   298,   272,  

In [9]:
def generate_text_with_input_and_scores(question, model, tokenizer):
    messages = [
        {"role": "user", "content": question + "\nPlease reason step by step, and put your final answer within \\boxed{}."}
    ]
    input_tensor = tokenizer.apply_chat_template(messages, add_generation_prompt=True, return_tensors="pt")
    outputs = model.generate(input_tensor.to(model.device), max_new_tokens=1000, return_dict_in_generate=True, output_scores=True)
    
    generated_text = tokenizer.decode(outputs.sequences[0][input_tensor.shape[1]:], skip_special_tokens=True)
    input_text = tokenizer.decode(input_tensor[0], skip_special_tokens=True)
    
    # Compute transition scores (log probabilities)
    transition_scores = model.compute_transition_scores(
        outputs.sequences, outputs.scores, normalize_logits=True
    )
    
    # Extract log probabilities for the generated tokens
    token_log_probs = transition_scores[0][input_tensor.shape[1]:].tolist()  # Assuming single batch

    return input_text, generated_text, token_log_probs

In [14]:
def to_tokens_and_logprobs(model, tokenizer, input_text, generated_text):
    full_text = input_text + generated_text
    input_ids = tokenizer(full_text, padding=True, return_tensors="pt").input_ids.to(model.device)
    
    outputs = model(input_ids)
    probs = torch.log_softmax(outputs.logits, dim=-1).detach()

    # Adjust indices to ignore the first token's log prob as it corresponds to the second token
    probs = probs[:, :-1, :]
    input_ids = input_ids[:, 1:]
    gen_probs = torch.gather(probs, 2, input_ids[:, :, None]).squeeze(-1)

    batch = []
    for input_sentence, input_probs in zip(input_ids, gen_probs):
        text_sequence = []
        for token, p in zip(input_sentence, input_probs):
            if token not in tokenizer.all_special_ids:
                text_sequence.append((tokenizer.decode(token), p.item()))
        batch.append(text_sequence)
    return batch

In [15]:
# Load your model and tokenizer
# model_name = "mistralai/Mistral-7B-Instruct-v0.2"
# tokenizer = AutoTokenizer.from_pretrained(model_name)
# model = AutoModelForCausalLM.from_pretrained(model_name, torch_dtype=torch.bfloat16, device_map="auto")
model.eval()

# Generate text and scores, and calculate log probabilities
question = "What is 9 + 10 * 2?"
input_text, generated_text, generated_scores = generate_text_with_input_and_scores(question, model, tokenizer)
log_probs = to_tokens_and_logprobs(model, tokenizer, input_text, generated_text)

# Print the results and compare
print("Generated Text Log Probabilities:")
for i, (token, score) in enumerate(zip(tokenizer.tokenize(generated_text), generated_scores)):
    print(f"Token: {token}, Generated Log Prob: {score.item()}")

print("\nCalculated Log Probabilities from Concatenated Text:")
for sequence in log_probs:
    for token, log_prob in sequence:
        print(f"Token: {token}, Calculated Log Prob: {log_prob}")

Generated Text Log Probabilities:
Token: ▁To, Generated Log Prob: -19.497920989990234
Token: ▁solve, Generated Log Prob: -24.336538314819336
Token: ▁the, Generated Log Prob: -26.707199096679688
Token: ▁expression, Generated Log Prob: -27.356489181518555
Token: ▁, Generated Log Prob: -19.32635498046875
Token: 9, Generated Log Prob: -21.62142562866211
Token: ▁+, Generated Log Prob: -30.819366455078125
Token: ▁, Generated Log Prob: -29.140625
Token: 1, Generated Log Prob: -21.78125
Token: 0, Generated Log Prob: -21.6875
Token: ▁*, Generated Log Prob: -25.496135711669922
Token: ▁, Generated Log Prob: -27.359375
Token: 2, Generated Log Prob: -27.93359375
Token: ,, Generated Log Prob: -22.229019165039062
Token: ▁we, Generated Log Prob: -22.026994705200195
Token: ▁need, Generated Log Prob: -23.68183135986328
Token: ▁to, Generated Log Prob: -30.445314407348633
Token: ▁follow, Generated Log Prob: -22.601839065551758
Token: ▁the, Generated Log Prob: -23.01563835144043
Token: ▁order, Generated Lo

In [5]:
def debug_tokens_and_logprobs(model, tokenizer, input_text, generated_text):
    full_text = input_text + generated_text
    input_ids = tokenizer(full_text, padding=True, return_tensors="pt").input_ids.to(model.device)
    
    print("Full Text Token IDs:", input_ids)
    
    outputs = model(input_ids)
    probs = torch.log_softmax(outputs.logits, dim=-1).detach()

    # Adjust indices to ignore the first token's log prob as it corresponds to the second token
    probs = probs[:, :-1, :]
    input_ids = input_ids[:, 1:]
    gen_probs = torch.gather(probs, 2, input_ids[:, :, None]).squeeze(-1)

    batch = []
    for input_sentence, input_probs in zip(input_ids, gen_probs):
        text_sequence = []
        for token, p in zip(input_sentence, input_probs):
            decoded_token = tokenizer.decode([token])
            print(f"Token: {decoded_token}, ID: {token}, Log Prob: {p.item()}")
            if token not in tokenizer.all_special_ids:
                text_sequence.append((decoded_token, p.item()))
        batch.append(text_sequence)
    return batch

# Use this function to debug
input_text, generated_text, _ = generate_text_with_input_and_scores(question, model, tokenizer)
debug_tokens_and_logprobs(model, tokenizer, input_text, generated_text)

Full Text Token IDs: tensor([[    1,   733, 16289, 28793,  1824,   349, 28705, 28774,   648, 28705,
         28740, 28734,   398, 28705, 28750, 28804,    13, 12069,  2611,  3707,
           486,  3707, 28725,   304,  1658,   574,  1480,  4372,  2373,   414,
          2858,   286, 28751,  2051,   733, 28748, 16289, 28793,  1551, 12049,
           272,  5782, 28705, 28774,   648, 28705, 28740, 28734,   398, 28705,
         28750, 28725,   478,   927,   298,  1372,   272,  1745,   302,  6933,
         28725,   690,   349,  2608, 10216,   486,   272,  1183,  1689,  1082,
         21025,  4915,  2109, 28747, 18712,  2053,   274, 28725,  1529,  6445,
         28725, 18317,  2459,   304,  8618,   325,  3211,  1749,   298,  1103,
           557,   304,  3301,   685,   304,  5078,   434,  1774,   325,  3211,
          1749,   298,  1103,   609,    13,    13,   657,   456,  1222, 28725,
           736,   460,   708,  2564,  2053,   274,   442,   439,  6445, 28725,
           579,   478,  2318,  

[[('[', -9.379661560058594),
  ('INST', -11.709663391113281),
  (']', -12.90582275390625),
  ('What', -13.26082992553711),
  ('is', -0.6441332697868347),
  ('', -7.122815132141113),
  ('9', -3.511679172515869),
  ('+', -5.391874313354492),
  ('', -0.23111653327941895),
  ('1', -1.1535981893539429),
  ('0', -1.4491122961044312),
  ('*', -1.6907577514648438),
  ('', -0.11482894420623779),
  ('2', -0.3742263615131378),
  ('?', -0.12354228645563126),
  ('\n', -0.2805356979370117),
  ('Please', -9.997407913208008),
  ('reason', -12.95058536529541),
  ('step', -5.26135778427124),
  ('by', -0.062126148492097855),
  ('step', -7.497983460780233e-05),
  (',', -2.7327542304992676),
  ('and', -3.515874147415161),
  ('put', -8.407917022705078),
  ('your', -4.023624897003174),
  ('final', -4.041463851928711),
  ('answer', -0.032512418925762177),
  ('within', -7.389629364013672),
  ('\\', -12.62710952758789),
  ('box', -4.384269714355469),
  ('ed', -0.00013183678674977273),
  ('{', -0.071790717542171

Reference Code

In [22]:
def query_logprobs(question):
    messages = [
        {"role": "user", "content": question + "\nPlease reason step by step, and put your final answer within \\boxed{}."}
    ]
    input_tensor = tokenizer.apply_chat_template(messages, add_generation_prompt=True, return_tensors="pt")
    outputs = model.generate(input_tensor.to(model.device), max_new_tokens=1000, return_dict_in_generate=True, output_scores=True) # greedy sampling, should use do_sample for sampling
    transition_scores = model.compute_transition_scores(
        outputs.sequences, outputs.scores, normalize_logits=True
    )

    input_length = input_tensor.shape[1]
    generated_tokens = outputs.sequences[:, input_length:]
    print(generated_tokens)
    for tok, score in zip(generated_tokens[0], transition_scores[0]):
        # | token | token string | log probability | probability
        print(f"| {tok:5d} | {tokenizer.decode(tok):8s} | {score.cpu().numpy()} | {np.exp(score.cpu().numpy()):.2%}")

In [21]:
query_logprobs("What is 9 + 10 * 2?")

b 
tensor([[ 1791, 12049,   272,  5782, 28705, 28774,   648, 28705, 28740, 28734,
           398, 28705, 28750, 28725,   478,   927,   298,  1372,   272,  1745,
           302,  6933, 28725,   690,   349,  2608, 10216,   486,   272,  1183,
          1689,  1082, 21025,  4915,  2109, 28747, 18712,  2053,   274, 28725,
          1529,  6445, 28725, 18317,  2459,   304,  8618,   325,  3211,  1749,
           298,  1103,   557,   304,  3301,   685,   304,  5078,   434,  1774,
           325,  3211,  1749,   298,  1103,   609,    13,    13,   657,   456,
          1222, 28725,   736,   460,   708,  2564,  2053,   274,   442,   439,
          6445, 28725,   579,   478,  2318,   356,   298,   272,  1679,  3707,
         28723,    13,    13,  9977, 28705, 28740, 28747,  2744,   674,   272,
          6079,  2459,    13, 28740, 28734,   398, 28705, 28750,   327, 28705,
         28750, 28734,    13,    13,  9977, 28705, 28750, 28747,  2744,   674,
           272,  4518,    13, 28774,   648, 28705

In [None]:
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM

def to_tokens_and_logprobs(model, tokenizer, input_texts):
    # Tokenize input texts
    inputs = tokenizer(input_texts, return_tensors="pt", padding=True)
    input_ids = inputs.input_ids.to(model.device)

    # Generate outputs from the model
    with torch.no_grad():
        outputs = model(input_ids, labels=input_ids)
        logits = outputs.logits

    # Calculate log probabilities using log softmax
    log_probs = torch.nn.functional.log_softmax(logits, dim=-1)

    # Extract log probabilities for the actual tokens
    shifted_input_ids = input_ids[:, 1:]  # Shift input ids to the right to align with logits
    shifted_log_probs = log_probs[:, :-1]  # Align logits with shifted input ids

    # Gather the log probabilities for each token in the input
    gathered_log_probs = torch.gather(shifted_log_probs, 2, shifted_input_ids.unsqueeze(-1)).squeeze(-1)

    # Prepare output
    batch_results = []
    for i, sentence_log_probs in enumerate(gathered_log_probs):
        tokens = tokenizer.convert_ids_to_tokens(shifted_input_ids[i])
        token_log_probs = [(token, log_prob.item()) for token, log_prob in zip(tokens, sentence_log_probs)]
        batch_results.append(token_log_probs)

    return batch_results

# Example usage
model_name = "mistralai/Mistral-7B-Instruct-v0.2"
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForCausalLM.from_pretrained(model_name, torch_dtype=torch.bfloat16, device_map="auto")
model.eval()  # Set the model to evaluation mode

input_text = "Example text to process."
results = to_tokens_and_logprobs(model, tokenizer, [input_text])
for token, log_prob in results[0]:
    print(f"Token: {token}, Log Prob: {log_prob}")

Failed Attempt

In [None]:
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM

def get_token_log_probs(input_text: str, model, tokenizer) -> torch.Tensor:
    # Tokenize the input text
    inputs = tokenizer(input_text, return_tensors="pt")
    input_ids = inputs['input_ids']

    # Move input ids to the same device as the model
    input_ids = input_ids.to(model.device)

    # Generate logits from the model
    with torch.no_grad():
        outputs = model(input_ids)
        logits = outputs.logits

    # Calculate log probabilities using log softmax
    log_probs = torch.nn.functional.log_softmax(logits, dim=-1)

    # Gather the log probabilities for each token in the input
    # We use torch.gather here to pick out the log probabilities of the actual tokens
    token_log_probs = torch.gather(log_probs, 2, input_ids.unsqueeze(-1)).squeeze(-1)

    return token_log_probs

model.eval()  # Set the model to evaluation mode

input_text = "Example text to process."
log_probs = get_token_log_probs(input_text, model, tokenizer)

tokens = tokenizer.tokenize(input_text)
for token, log_prob in zip(tokens, log_probs[0]):
    print(f"Token: {token}, Log Prob: {log_prob.item()}")

In [None]:
from transformers import AutoTokenizer, AutoModelForCausalLM, LogitsProcessorList, StoppingCriteriaList

def generate_text_with_log_probs(input_text: str, model, tokenizer):
    # Tokenize the input text
    inputs = tokenizer(input_text, return_tensors="pt")
    input_ids = inputs['input_ids'].to(model.device)

    # Set up generation arguments
    logits_processor = LogitsProcessorList()
    stopping_criteria = StoppingCriteriaList()

    # Generate text with log probabilities
    output_sequences = model.generate(
        input_ids=input_ids,
        logits_processor=logits_processor,
        stopping_criteria=stopping_criteria,
        output_scores=True,
        return_dict_in_generate=True
    )

    # Extract log probabilities from the output
    scores = output_sequences.scores
    log_probs = [torch.nn.functional.log_softmax(score, dim=-1) for score in scores]
    token_log_probs = [torch.gather(log_prob, 1, output_sequences.sequences[:, i+1].unsqueeze(-1)).squeeze(-1) for i, log_prob in enumerate(log_probs)]

    return token_log_probs

In [None]:
# Input text
input_text = "Example text to process."

# Get log probabilities using the defined function
calculated_log_probs = get_token_log_probs(input_text, model, tokenizer)

# Generate text and get log probabilities during generation
generated_log_probs = generate_text_with_log_probs(input_text, model, tokenizer)

# Compare the results
tokens = tokenizer.tokenize(input_text)
for token, calc_log_prob, gen_log_prob in zip(tokens, calculated_log_probs[0], generated_log_probs):
    print(f"Token: {token}, Calculated Log Prob: {calc_log_prob.item()}, Generated Log Prob: {gen_log_prob.item()}")

Unsuccessful tinkering

In [None]:
import torch
from typing import List

def tokenize_sample_answers(tokenizer, prefix_text: str, suffixes: List[str], padding_value: int = 0):
    """Create tensor of tokenized suffixes

    Args:
        tokenizer: tokenizer.
        prefix_text: text kept constant that precedes the suffixes
        suffixes: List of answers to consider for given prefix_text
        padding_value: padding token to make suffix input ids equal length
    """
    combined_texts = [prefix_text + suffix for suffix in suffixes]
    full_tokenized = tokenizer(combined_texts, add_special_tokens=True, padding=True, return_tensors="pt")
    print("full shape: ", full_tokenized["input_ids"].shape)
    
    prefix_input_ids = tokenizer(prefix_text, add_special_tokens=True, return_tensors='pt')['input_ids']
    print("prefix shape: ", prefix_input_ids.shape)

    suffix_start_idx = prefix_input_ids.shape[1]
    suffix_input_ids = full_tokenized["input_ids"][:, suffix_start_idx:]

    print("suffix shape: ", suffix_input_ids.shape)
    print("suffix input ids: ", suffix_input_ids)

    repeated_prefix_input_ids = prefix_input_ids.repeat(full_tokenized["input_ids"].shape[0], 1)
    print("repeated prefix shape: ", repeated_prefix_input_ids.shape)

    # Check if the first 'prefix_length' tokens of each entry in full_tokenized are the same as prefix_input_ids
    is_prefix_equal = torch.all(full_tokenized["input_ids"][:, :suffix_start_idx] == repeated_prefix_input_ids, dim=1)
    print("Is prefix equal for all entries: ", is_prefix_equal)
    
    return full_tokenized, suffix_input_ids, suffix_start_idx 

# sample_answers = [' The answer is $12.', ' The answer is 10.', ' The answer is 3.']
# full_tokenized, suffix_input_ids, suffix_start_idx = tokenize_sample_answers(tokenizer, question, sample_answers)

In [None]:
def get_batch_logps(
        logits: torch.FloatTensor,
        token_mask: torch.LongTensor,
        average_log_prob: bool = False,
        label_pad_token_id: int = -100,
        is_encoder_decoder: bool = False,
    ) -> torch.FloatTensor:
    """Compute the log probabilities of the given labels under the given logits.

    Args:
        logits: Logits of the model (unnormalized). Shape: (batch_size, sequence_length, vocab_size)
        labels: Labels for which to compute the log probabilities. Label tokens with a value of label_pad_token_id are ignored. Shape: (batch_size, sequence_length)
        average_log_prob: If True, return the average log probability per (non-masked) token. Otherwise, return the sum of the log probabilities of the (non-masked) tokens.
        label_pad_token_id: The label pad token id.
        is_encoder_decoder: Whether the model is an encoder-decoder model.

    Returns:
        A tensor of shape (batch_size,) containing the average/sum log probabilities of the given labels under the given logits.
    """
    if logits.shape[:-1] != token_mask.shape:
        raise ValueError("Logits (batch and sequence length dim) and labels must have the same shape.")

    # dummy token; we'll ignore the losses on these tokens later
    token_mask[token_mask == label_pad_token_id] = 0

    per_token_logps = torch.gather(logits.log_softmax(-1), dim=2, index=token_mask.unsqueeze(2)).squeeze(2)
    print(per_token_logps.shape)
    print(token_mask)
    print(per_token_logps * token_mask)

    tokens = [tokenizer.decode(ids) for ids in token_mask]

    for i, token_sequence in enumerate(tokens):
        print(f"Tokens for sequence {i}: {token_sequence}")
        for j, token in enumerate(token_sequence.split()):
            token_log_prob = per_token_logps[i, j]#.max().item()  # Get the max log prob for this token
            print(f"Token: {token}, Log Prob: {token_log_prob}")

    if average_log_prob:
        return (per_token_logps * token_mask).sum(-1) / token_mask.sum(-1)
    else:
        return (per_token_logps * token_mask).sum(-1)

In [None]:
def logprobs_per_token(tokenizer, prefix, suffixes):
    full_tokenized, suffix_input_ids, suffix_start_idx = tokenize_sample_answers(tokenizer, prefix, suffixes)
    all_logits = model(full_tokenized['input_ids'].to(model.device)).logits
    get_batch_logps(all_logits.to("cpu"), full_tokenized.input_ids)

In [None]:
ex_prefix = """[INST] Q: There are 15 trees in the grove. Grove workers will plant trees in the grove today. After they are done, there will be 21 trees. How many trees did the grove workers plant today?
A: We start with 15 trees. Later we have 21 trees. The difference must be the number of trees they planted. So, they must have planted 21 - 15 = 6 trees. The answer is 6
Q: If there are 3 cars in the parking lot and 2 more cars arrive, how many cars are in the parking lot?
A: There are 3 cars in the parking lot already. 2 more arrive. Now there are 3 + 2 = 5 cars. The answer is 5.
Q: Natalia sold clips to 48 of her friends in April, and then she sold half as many clips in May. How many clips did Natalia sell altogether in April and May?
A: [/INST] In April, Natalia sold clips to 48 friends. In May, she sold half as many as in April, so she sold 48 / 2 = 24 clips. In total, she sold 48 + 24 = 72 clips. The answer is 72.

Or, we could use the following equation:
Total Clips Sold = Clips Sold in April + Clips Sold in May
Total Clips Sold = 48 + 24
Total Clips Sold = 72.

"""

In [None]:
ex_suffixes = ["The answer is 72.", "The answer is 64.", "Therefore, the answer is 72 clips sold altogether in April and May.", ]

In [None]:
logprobs_per_token(tokenizer, ex_prefix, ex_suffixes)

GPT Turbo Fix

In [None]:
import torch

def create_suffix_mask(full_tokenized, suffix_start_idx):
    """Create a mask for the suffix tokens."""
    batch_size, seq_length = full_tokenized['input_ids'].shape
    mask = torch.zeros_like(full_tokenized['input_ids'])
    mask[:, suffix_start_idx:] = 1
    return mask

def get_batch_logps(logits, token_mask, tokenizer, suffix_input_ids):
    """Calculate log probabilities for suffix tokens."""
    # Apply the mask to logits
    masked_logits = logits * token_mask.unsqueeze(-1)  # Extend mask for vocab size

    # Calculate log softmax
    log_probs = torch.nn.functional.log_softmax(masked_logits, dim=-1)

    # Gather log probabilities for actual token indices
    gathered_log_probs = torch.gather(log_probs, 2, suffix_input_ids.unsqueeze(-1)).squeeze(-1)

    # Decode tokens and print their log probabilities
    tokens = [tokenizer.decode(ids) for ids in suffix_input_ids]
    for i, token_sequence in enumerate(tokens):
        print(f"Tokens for sequence {i}: {token_sequence}")
        for j, token in enumerate(token_sequence.split()):
            token_log_prob = gathered_log_probs[i, j].item()
            print(f"Token: {token}, Log Prob: {token_log_prob}")

    return gathered_log_probs

def logprobs_per_token(tokenizer, prefix, suffixes):
    full_tokenized, suffix_input_ids, suffix_start_idx = tokenize_sample_answers(tokenizer, prefix, suffixes)
    suffix_mask = create_suffix_mask(full_tokenized, suffix_start_idx)
    all_logits = model(full_tokenized['input_ids'].to(model.device)).logits
    get_batch_logps(all_logits.to("cpu"), suffix_mask, tokenizer, suffix_input_ids)

# Example usage
logprobs_per_token(tokenizer, ex_prefix, ex_suffixes)

In [None]:
def get_batch_logps(logits, token_mask, average_log_prob=False):
    """Compute the log probabilities of the given tokens under the given logits.

    Args:
        logits: Logits of the model (unnormalized). Shape: (batch_size, sequence_length, vocab_size)
        token_mask: Indices of the tokens for which to compute the log probabilities. Shape: (batch_size, sequence_length)
        average_log_prob: If True, return the average log probability per token. Otherwise, return the sum of the log probabilities.

    Returns:
        A tensor of shape (batch_size,) containing the average/sum log probabilities of the given tokens under the given logits.
    """
    # Ensure logits and token_mask have compatible shapes
    if logits.shape[:-1] != token_mask.shape:
        raise ValueError("Logits and token mask must have the same shape.")

    # Compute log softmax over the vocabulary dimension
    log_probs = torch.nn.functional.log_softmax(logits, dim=-1)

    # Use gather to select the log probabilities of the actual tokens
    gathered_log_probs = torch.gather(log_probs, 2, token_mask.unsqueeze(-1)).squeeze(-1)

    # Print each token and its log probability
    for i, sequence in enumerate(token_mask):
        tokens = tokenizer.convert_ids_to_tokens(sequence)
        log_probs_sequence = gathered_log_probs[i]
        print(f"Sequence {i + 1}:")
        for token, log_prob in zip(tokens, log_probs_sequence):
            print(f"Token: {token}, Log Prob: {log_prob.item()}")

    # Compute the sum or average of the log probabilities
    if average_log_prob:
        token_count = (token_mask != -100).sum(dim=1).float()  # Assuming -100 is used to mask tokens
        return (gathered_log_probs * (token_mask != -100)).sum(dim=1) / token_count
    else:
        return (gathered_log_probs * (token_mask != -100)).sum(dim=1)

# Example usage
# Assuming `all_logits` and `full_tokenized['input_ids']` are available from your model's output
def logprobs_per_token(tokenizer, prefix, suffixes):
    full_tokenized, suffix_input_ids, suffix_start_idx = tokenize_sample_answers(tokenizer, prefix, suffixes)
    suffix_mask = create_suffix_mask(full_tokenized, suffix_start_idx)
    all_logits = model(full_tokenized['input_ids'].to(model.device)).logits
    get_batch_logps(all_logits.to("cpu"), suffix_mask)

# Example usage
logprobs_per_token(tokenizer, ex_prefix, ex_suffixes)

Messing Around

In [None]:
import pandas as pd

gpt35_df = pd.read_csv('../conditional/data/112_gsm8k_gpt35_cot_onesent_responses.csv')
gpt35_df.head(2)

In [None]:
questions = gpt35_df['Question'].to_list()[:2]
questions

In [None]:
question = questions[1]
question

In [None]:
question_tokens = tokenizer(question, add_special_tokens=False)
question_input_ids = question_tokens.input_ids
print(len(question_tokens.input_ids), type(question_tokens), type(question_tokens.input_ids))
question_tokens

In [None]:
question_tokens = {f"question_{k}": v for k, v in question_tokens.items()}

In [None]:
answers = gpt35_df['Answer'].to_list()[:2]
answers

In [None]:
answer = answers[1]
answer

In [None]:
full_tokenized = tokenizer(question + answer, add_special_tokens=False)
print(len(full_tokenized.input_ids), len(question_input_ids))

In [None]:
answer_input_ids = full_tokenized["input_ids"][len(question_input_ids) :]
answer_attention_mask = full_tokenized["attention_mask"][len(question_input_ids) :]

In [None]:
def build_tokenized_answer(tokenizer, prompt, answer):
        """
        Llama tokenizer does satisfy `enc(a + b) = enc(a) + enc(b)`.
        It does ensure `enc(a + b) = enc(a) + enc(a + b)[len(enc(a)):]`.
        Reference:
            https://github.com/EleutherAI/lm-evaluation-harness/pull/531#issuecomment-1595586257
        """

        full_tokenized = tokenizer(prompt + answer, add_special_tokens=False)
        prompt_input_ids = tokenizer(prompt, add_special_tokens=False)["input_ids"]

        answer_input_ids = full_tokenized["input_ids"][len(prompt_input_ids) :]
        answer_attention_mask = full_tokenized["attention_mask"][len(prompt_input_ids) :]

        # Concat tokens to form `enc(a) + enc(a + b)[len(enc(a)):]`
        full_concat_input_ids = np.concatenate([prompt_input_ids, answer_input_ids])

        # Prepare input tokens for token by token comparison
        full_input_ids = np.array(full_tokenized["input_ids"])

        if len(full_input_ids) != len(full_concat_input_ids):
            raise ValueError("Prompt input ids and answer input ids should have the same length.")

        # On some tokenizers, like Llama-2 tokenizer, there are occasions where tokens
        # can be merged together when tokenizing prompt+answer. This could result
        # on the last token from the prompt being different when tokenized on its own
        # vs when done as prompt+answer.
        response_token_ids_start_idx = len(prompt_input_ids)

        # If tokenized prompt is different than both prompt+answer, then it means the
        # last token has changed due to merging.
        if prompt_input_ids != full_tokenized["input_ids"][:response_token_ids_start_idx]:
            response_token_ids_start_idx -= 1

        prompt_input_ids = full_tokenized["input_ids"][:response_token_ids_start_idx]
        prompt_attention_mask = full_tokenized["attention_mask"][:response_token_ids_start_idx]

        if len(prompt_input_ids) != len(prompt_attention_mask):
            raise ValueError("Prompt input ids and attention mask should have the same length.")

        answer_input_ids = full_tokenized["input_ids"][response_token_ids_start_idx:]
        answer_attention_mask = full_tokenized["attention_mask"][response_token_ids_start_idx:]

        return dict(
            prompt_input_ids=prompt_input_ids,
            prompt_attention_mask=prompt_attention_mask,
            input_ids=answer_input_ids,
            attention_mask=answer_attention_mask,
        )

In [None]:
tokenized_answer = build_tokenized_answer(tokenizer, question, answer)
len(tokenized_answer['prompt_input_ids'])

In [None]:
question_len_input_ids = len(question_tokens["question_input_ids"])
question_len_input_ids

In [None]:
tokenized_answer.keys()

In [None]:
type(tokenized_answer), type(tokenized_answer['prompt_input_ids'])

In [None]:
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from typing import Any, Callable, Dict, List, Literal, Optional, Tuple, Union
from util import pad_to_length

def concatenated_inputs(
        batch: Dict[str, Union[List, torch.LongTensor]],
        is_encoder_decoder: bool = False,
        label_pad_token_id: int = -100,
        padding_value: int = 0,
        device: Optional[torch.device] = None,
    ) -> Dict[str, torch.LongTensor]:
    """Concatenate the chosen and rejected inputs into a single tensor.

    Args:
        batch: A batch of data. Must contain the keys 'chosen_input_ids' and 'rejected_input_ids', which are tensors of shape (batch_size, sequence_length).
        is_encoder_decoder: Whether the model is an encoder-decoder model.
        label_pad_token_id: The label pad token id.
        padding_value: The padding value to use for the concatenated inputs_ids.
        device: The device for the concatenated inputs.

    Returns:
        A dictionary containing the concatenated inputs under the key 'concatenated_input_ids'.
    """
    concatenated_batch = {}

    if is_encoder_decoder:
        max_length = max(batch["chosen_labels"].shape[1], batch["rejected_labels"].shape[1])
    else:
        max_length = max(batch["chosen_input_ids"].shape[1], batch["rejected_input_ids"].shape[1])

    for k in batch:
        if k.startswith("chosen") and isinstance(batch[k], torch.Tensor):
            if "labels" in k or is_encoder_decoder:
                pad_value = label_pad_token_id
            elif k.endswith("_input_ids"):
                pad_value = padding_value
            elif k.endswith("_attention_mask"):
                pad_value = 0
            concatenated_key = k.replace("chosen", "concatenated")
            concatenated_batch[concatenated_key] = pad_to_length(batch[k], max_length, pad_value=pad_value)
    for k in batch:
        if k.startswith("rejected") and isinstance(batch[k], torch.Tensor):
            if "labels" in k or is_encoder_decoder:
                pad_value = label_pad_token_id
            elif k.endswith("_input_ids"):
                pad_value = padding_value
            elif k.endswith("_attention_mask"):
                pad_value = 0
            concatenated_key = k.replace("rejected", "concatenated")
            concatenated_batch[concatenated_key] = torch.cat(
                (
                    concatenated_batch[concatenated_key],
                    pad_to_length(batch[k], max_length, pad_value=pad_value),
                ),
                dim=0,
            ).to(device=device)

    if is_encoder_decoder:
        concatenated_batch["concatenated_input_ids"] = batch["prompt_input_ids"].repeat(2, 1).to(device=device)
        concatenated_batch["concatenated_attention_mask"] = (
            batch["prompt_attention_mask"].repeat(2, 1).to(device=device)
        )

    return concatenated_batch

In [None]:
question

In [None]:
answer1 = " The answer is $12."
answer2 = " The answer is 10."
answer3 = " The answer is 3."
sample_answers = [answer1, answer2, answer3]

In [None]:
import torch

def create_labels_mask(full_tokenized, suffix_start_idx):
    """
    Create a binary mask for the suffix tokens in the tokenized batch.

    Args:
        full_tokenized: The tokenized data containing 'input_ids' and 'attention_mask'.
        suffix_start_idx: The start index of the suffix in the tokenized sequences.

    Returns:
        A binary mask tensor of the same shape as `full_tokenized['input_ids']` where suffix tokens are marked with 1 and others with 0.
    """
    batch_size, seq_length = full_tokenized['input_ids'].shape

    labels_mask = torch.zeros((batch_size, seq_length), dtype=torch.long)

    labels_mask[:, suffix_start_idx:] = 1

    labels_mask *= full_tokenized['attention_mask']

    return labels_mask

wrong_suffix_mask = create_labels_mask(full_tokenized, suffix_start_idx)
print("Labels mask shape:", wrong_suffix_mask.shape)
print("Labels mask:", wrong_suffix_mask)

In [None]:
all_logits = model(full_tokenized['input_ids'].to(model.device)).logits

In [None]:
all_logits.shape

In [None]:
def concatenated_forward(
        self, model: nn.Module, batch: Dict[str, Union[List, torch.LongTensor]]
    ) -> Tuple[torch.FloatTensor, torch.FloatTensor, torch.FloatTensor, torch.FloatTensor]:
    """Run the given model on the given batch of inputs, concatenating the chosen and rejected inputs together.

    We do this to avoid doing two forward passes, because it's faster for FSDP.
    """
    concatenated_batch = self.concatenated_inputs(
        batch,
        is_encoder_decoder=self.is_encoder_decoder,
        label_pad_token_id=self.label_pad_token_id,
        padding_value=self.padding_value,
        device=self.accelerator.device,
    )
    len_chosen = batch["chosen_labels"].shape[0]

    model_kwargs = (
        {
            "labels": concatenated_batch["concatenated_labels"],
            "decoder_input_ids": concatenated_batch.pop("concatenated_decoder_input_ids", None),
        }
        if self.is_encoder_decoder
        else {}
    )
    all_logits = model(
        concatenated_batch["concatenated_input_ids"],
        attention_mask=concatenated_batch["concatenated_attention_mask"],
        use_cache=False,
        **model_kwargs,
    ).logits

    all_logps = self.get_batch_logps(
        all_logits,
        concatenated_batch["concatenated_labels"],
        average_log_prob=self.loss_type == "ipo",
        is_encoder_decoder=self.is_encoder_decoder,
        label_pad_token_id=self.label_pad_token_id,
    )

    chosen_logps = all_logps[:len_chosen]
    rejected_logps = all_logps[len_chosen:]

    chosen_logits = all_logits[:len_chosen]
    rejected_logits = all_logits[len_chosen:]

    return (chosen_logps, rejected_logps, chosen_logits, rejected_logits)

In [None]:
wrong_suffix_mask.shape

In [None]:
full_tokenized.input_ids.shape

In [None]:
get_batch_logps(all_logits.to("cpu"), full_tokenized.input_ids)