In [None]:
import numpy as np
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM, GenerationConfig

# Checkpoint selection. The commented-out line is an alternative model kept for quick switching.
# model_name = "deepseek-ai/deepseek-math-7b-instruct"
model_name = "mistralai/Mistral-7B-Instruct-v0.2"
tokenizer = AutoTokenizer.from_pretrained(model_name)
# Reuse the EOS token as the padding token; later cells tokenize batches with padding=True,
# which needs a pad token to be set. Consequence: padding cannot be told apart from EOS by
# token id, only through the attention mask.
tokenizer.pad_token = tokenizer.eos_token
# Load the weights in bfloat16 and let device_map="auto" decide the device placement.
model = AutoModelForCausalLM.from_pretrained(model_name, torch_dtype=torch.bfloat16, device_map="auto")
# Same load without the explicit dtype, kept for reference:
# model = AutoModelForCausalLM.from_pretrained(model_name, device_map="auto")
model.generation_config = GenerationConfig.from_pretrained(model_name)
# Mirror the tokenizer choice above for generation: pad with the EOS id.
model.generation_config.pad_token_id = model.generation_config.eos_token_id

# NOTE(review): NUM_COT is not referenced by any of the cells visible below.
NUM_COT = 3

In [None]:
"""[INST] Q: There are 15 trees in the grove. Grove workers will plant trees in the grove today. After they are done, there will be 21 trees. How many trees did the grove workers plant today?
A: We start with 15 trees. Later we have 21 trees. The difference must be the number of trees they planted. So, they must have planted 21 - 15 = 6 trees. The answer is 6
Q: If there are 3 cars in the parking lot and 2 more cars arrive, how many cars are in the parking lot?
A: There are 3 cars in the parking lot already. 2 more arrive. Now there are 3 + 2 = 5 cars. The answer is 5.
Q: Natalia sold clips to 48 of her friends in April, and then she sold half as many clips in May. How many clips did Natalia sell altogether in April and May?
A: [/INST] In April, Natalia sold clips to 48 friends. In May, she sold half as many as in April, so she sold 48 / 2 = 24 clips. In total, she sold 48 + 24 = 72 clips. The answer is 72.

Or, we could use the following equation:
Total Clips Sold = Clips Sold in April + Clips Sold in May
Total Clips Sold = 48 + 24
Total Clips Sold = 72.

Therefore, the answer is 72 clips sold altogether in April and May.
"""

Messing Around

In [None]:
import pandas as pd

gpt35_df = pd.read_csv('../conditional/data/112_gsm8k_gpt35_cot_onesent_responses.csv')
gpt35_df.head(2)

In [None]:
questions = gpt35_df['Question'].to_list()[:2]
questions

In [None]:
question = questions[1]
question

In [None]:
question_tokens = tokenizer(question, add_special_tokens=False)
question_input_ids = question_tokens.input_ids
print(len(question_tokens.input_ids), type(question_tokens), type(question_tokens.input_ids))
question_tokens

In [None]:
question_tokens = {f"question_{k}": v for k, v in question_tokens.items()}

In [None]:
answers = gpt35_df['Answer'].to_list()[:2]
answers

In [None]:
answer = answers[1]
answer

In [None]:
full_tokenized = tokenizer(question + answer, add_special_tokens=False)
print(len(full_tokenized.input_ids), len(question_input_ids))

In [None]:
answer_input_ids = full_tokenized["input_ids"][len(question_input_ids) :]
answer_attention_mask = full_tokenized["attention_mask"][len(question_input_ids) :]

In [None]:
def build_tokenized_answer(tokenizer, prompt, answer):
    """Tokenize ``prompt + answer`` jointly and split the result into prompt and answer parts.

    Sub-word tokenizers such as Llama's do NOT in general satisfy
    ``enc(a + b) == enc(a) + enc(b)``: the last prompt token can merge with the
    first answer characters when both are tokenized together. The split is
    therefore always taken from the joint tokenization ``enc(prompt + answer)``,
    and the boundary is moved one token to the left when the stand-alone prompt
    tokenization disagrees with the corresponding prefix of the joint one.
    Reference:
        https://github.com/EleutherAI/lm-evaluation-harness/pull/531#issuecomment-1595586257

    Args:
        tokenizer: Callable ``tokenizer(text, add_special_tokens=False)`` returning a
            mapping with list-valued ``"input_ids"`` and ``"attention_mask"``.
        prompt: Prompt text.
        answer: Answer text, appended directly to ``prompt`` (no separator is added).

    Returns:
        dict with ``prompt_input_ids``, ``prompt_attention_mask`` (prompt part) and
        ``input_ids``, ``attention_mask`` (answer part).

    Raises:
        ValueError: If the stand-alone prompt tokenization is longer than the joint
            tokenization, or if ids and attention mask lengths disagree.
    """
    full_tokenized = tokenizer(prompt + answer, add_special_tokens=False)
    full_input_ids = full_tokenized["input_ids"]
    full_attention_mask = full_tokenized["attention_mask"]
    prompt_input_ids = tokenizer(prompt, add_special_tokens=False)["input_ids"]

    # enc(prompt) + enc(prompt + answer)[len(enc(prompt)):] has the same length as
    # enc(prompt + answer) exactly when the prompt alone is not longer than the joint
    # tokenization, so a plain length comparison replaces the array concatenation.
    if len(prompt_input_ids) > len(full_input_ids):
        raise ValueError(
            "Tokenized prompt is longer than tokenized prompt + answer; cannot split off the answer."
        )

    # Tentative boundary: the answer starts right after the stand-alone prompt tokens.
    response_token_ids_start_idx = len(prompt_input_ids)

    # If the prefix of the joint tokenization differs from the stand-alone prompt,
    # the last prompt token was merged with the answer; give that token to the answer.
    if prompt_input_ids != full_input_ids[:response_token_ids_start_idx]:
        response_token_ids_start_idx -= 1

    prompt_input_ids = full_input_ids[:response_token_ids_start_idx]
    prompt_attention_mask = full_attention_mask[:response_token_ids_start_idx]

    if len(prompt_input_ids) != len(prompt_attention_mask):
        raise ValueError("Prompt input ids and attention mask should have the same length.")

    return dict(
        prompt_input_ids=prompt_input_ids,
        prompt_attention_mask=prompt_attention_mask,
        input_ids=full_input_ids[response_token_ids_start_idx:],
        attention_mask=full_attention_mask[response_token_ids_start_idx:],
    )

In [None]:
tokenized_answer = build_tokenized_answer(tokenizer, question, answer)
len(tokenized_answer['prompt_input_ids'])

In [None]:
question_len_input_ids = len(question_tokens["question_input_ids"])
question_len_input_ids

In [None]:
tokenized_answer.keys()

In [None]:
type(tokenized_answer), type(tokenized_answer['prompt_input_ids'])

In [None]:
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from typing import Any, Callable, Dict, List, Literal, Optional, Tuple, Union
from util import pad_to_length

def concatenated_inputs(
        batch: Dict[str, Union[List, torch.LongTensor]],
        is_encoder_decoder: bool = False,
        label_pad_token_id: int = -100,
        padding_value: int = 0,
        device: Optional[torch.device] = None,
    ) -> Dict[str, torch.LongTensor]:
    """Concatenate the chosen and rejected inputs into a single tensor.

    Every tensor-valued entry whose key starts with "chosen" is padded to a common
    length and stored under the same key with "chosen" replaced by "concatenated".
    Every tensor-valued entry whose key starts with "rejected" is padded the same way
    and appended below the matching chosen entry along dim 0, so the chosen rows come
    first and the rejected rows second. Padding is delegated to `pad_to_length`
    (imported from `util`; its exact semantics are not visible in this notebook).

    Args:
        batch: A batch of data. Must contain the keys 'chosen_input_ids' and 'rejected_input_ids', which are tensors of shape (batch_size, sequence_length).
            When `is_encoder_decoder` is True the keys 'chosen_labels', 'rejected_labels',
            'prompt_input_ids' and 'prompt_attention_mask' are read as well.
        is_encoder_decoder: Whether the model is an encoder-decoder model.
        label_pad_token_id: The label pad token id.
        padding_value: The padding value to use for the concatenated inputs_ids.
        device: The device for the concatenated inputs.

    Returns:
        A dictionary containing the concatenated inputs under the key 'concatenated_input_ids'
        (plus one 'concatenated_*' entry per chosen/rejected tensor key in `batch`).
    """
    concatenated_batch = {}

    # Common target length: the longer of the chosen and rejected sequences
    # (measured on the labels for encoder-decoder models, on the input ids otherwise).
    if is_encoder_decoder:
        max_length = max(batch["chosen_labels"].shape[1], batch["rejected_labels"].shape[1])
    else:
        max_length = max(batch["chosen_input_ids"].shape[1], batch["rejected_input_ids"].shape[1])

    # First pass: pad the chosen tensors and register them under "concatenated*" keys.
    # NOTE(review): pad_value is not reset per key. A tensor key that matches none of the
    # branches below reuses the value from the previous key, or raises UnboundLocalError
    # if it is the first one encountered.
    for k in batch:
        if k.startswith("chosen") and isinstance(batch[k], torch.Tensor):
            if "labels" in k or is_encoder_decoder:
                pad_value = label_pad_token_id
            elif k.endswith("_input_ids"):
                pad_value = padding_value
            elif k.endswith("_attention_mask"):
                pad_value = 0
            concatenated_key = k.replace("chosen", "concatenated")
            concatenated_batch[concatenated_key] = pad_to_length(batch[k], max_length, pad_value=pad_value)
    # Second pass: pad the rejected tensors and stack them under the chosen rows.
    # The matching "concatenated*" entry must already exist from the first pass
    # (otherwise the lookup below raises KeyError). Only here is the result moved to `device`.
    for k in batch:
        if k.startswith("rejected") and isinstance(batch[k], torch.Tensor):
            if "labels" in k or is_encoder_decoder:
                pad_value = label_pad_token_id
            elif k.endswith("_input_ids"):
                pad_value = padding_value
            elif k.endswith("_attention_mask"):
                pad_value = 0
            concatenated_key = k.replace("rejected", "concatenated")
            concatenated_batch[concatenated_key] = torch.cat(
                (
                    concatenated_batch[concatenated_key],
                    pad_to_length(batch[k], max_length, pad_value=pad_value),
                ),
                dim=0,
            ).to(device=device)

    # Encoder-decoder models feed the prompt to the encoder: the same prompt rows are
    # repeated twice along dim 0 so they line up with the chosen and rejected halves.
    if is_encoder_decoder:
        concatenated_batch["concatenated_input_ids"] = batch["prompt_input_ids"].repeat(2, 1).to(device=device)
        concatenated_batch["concatenated_attention_mask"] = (
            batch["prompt_attention_mask"].repeat(2, 1).to(device=device)
        )

    return concatenated_batch

In [None]:
question

In [None]:
answer1 = " The answer is $12."
answer2 = " The answer is 10."
answer3 = " The answer is 3."
sample_answers = [answer1, answer2, answer3]

In [None]:
import torch
from typing import List

def tokenize_sample_answers(tokenizer, prefix_text: str, suffixes: List[str], padding_value: int = 0):
    """Tokenize ``prefix_text + suffix`` for every suffix as one right-padded batch.

    The batch is always padded on the right, regardless of the tokenizer's configured
    ``padding_side``, so that the prefix occupies the same leading columns in every row
    and ``suffix_start_idx`` is a valid suffix boundary for the whole batch. The
    tokenizer's original ``padding_side`` is restored before returning.

    Args:
        tokenizer: Tokenizer callable accepting ``add_special_tokens``, ``padding`` and
            ``return_tensors`` and returning a mapping with an ``"input_ids"`` tensor.
        prefix_text: Text kept constant that precedes the suffixes.
        suffixes: List of answers to consider for the given ``prefix_text``.
        padding_value: Currently unused; padding uses the tokenizer's own pad token.
            Kept so existing callers keep working.

    Returns:
        Tuple ``(full_tokenized, suffix_input_ids, suffix_start_idx)``:
        the tokenizer output for the whole batch, the columns of ``input_ids`` from
        ``suffix_start_idx`` onwards (right-padded), and the number of tokens the prefix
        occupies when tokenized on its own (special tokens included).

    Raises:
        ValueError: If ``suffixes`` is empty.
    """
    if not suffixes:
        raise ValueError("suffixes must contain at least one answer.")

    combined_texts = [prefix_text + suffix for suffix in suffixes]

    # Slicing every row at the same column is only meaningful with right padding;
    # with left padding shorter rows would be shifted and the "suffix" slice would
    # contain prefix tokens. Force right padding for this call only.
    has_padding_side = hasattr(tokenizer, "padding_side")
    if has_padding_side:
        saved_padding_side = tokenizer.padding_side
        tokenizer.padding_side = "right"
    try:
        full_tokenized = tokenizer(combined_texts, add_special_tokens=True, padding=True, return_tensors="pt")
    finally:
        if has_padding_side:
            tokenizer.padding_side = saved_padding_side
    print("full shape: ", full_tokenized["input_ids"].shape)

    prefix_input_ids = tokenizer(prefix_text, add_special_tokens=True, return_tensors='pt')['input_ids']
    print("prefix shape: ", prefix_input_ids.shape)

    suffix_start_idx = prefix_input_ids.shape[1]
    suffix_input_ids = full_tokenized["input_ids"][:, suffix_start_idx:]

    print("suffix shape: ", suffix_input_ids.shape)
    print("suffix input ids: ", suffix_input_ids)

    repeated_prefix_input_ids = prefix_input_ids.repeat(full_tokenized["input_ids"].shape[0], 1)
    print("repeated prefix shape: ", repeated_prefix_input_ids.shape)

    # Sanity check: the leading columns of every row should equal the stand-alone prefix.
    # A False entry means the last prefix token merged with the start of that suffix.
    is_prefix_equal = torch.all(full_tokenized["input_ids"][:, :suffix_start_idx] == repeated_prefix_input_ids, dim=1)
    print("Is prefix equal for all entries: ", is_prefix_equal)

    return full_tokenized, suffix_input_ids, suffix_start_idx

sample_answers = [' The answer is $12.', ' The answer is 10.', ' The answer is 3.']
full_tokenized, suffix_input_ids, suffix_start_idx = tokenize_sample_answers(tokenizer, question, sample_answers)

In [None]:
import torch

def create_labels_mask(full_tokenized, suffix_start_idx):
    """
    Create a binary mask for the suffix tokens in the tokenized batch.

    A position is marked 1 when its column index is at or beyond ``suffix_start_idx``
    AND the attention mask is non-zero there; every other position is 0. The same
    start column is used for every row, so the batch is assumed to be right-padded
    with a shared prefix length.

    Args:
        full_tokenized: The tokenized data containing 'input_ids' and 'attention_mask'
            (both 2-D tensors of the same shape).
        suffix_start_idx: The start index of the suffix in the tokenized sequences.

    Returns:
        A ``torch.long`` mask tensor of the same shape as `full_tokenized['input_ids']`,
        on the same device as the attention mask, where suffix tokens are marked with 1
        and others with 0.
    """
    attention_mask = full_tokenized['attention_mask']
    batch_size, seq_length = full_tokenized['input_ids'].shape

    # Allocate on the attention mask's device so the in-place product below does not
    # fail when the batch has been moved off the CPU.
    labels_mask = torch.zeros((batch_size, seq_length), dtype=torch.long, device=attention_mask.device)
    labels_mask[:, suffix_start_idx:] = 1

    # Drop padding positions. The cast keeps the in-place multiply valid for bool or
    # float attention masks as well as integer ones.
    labels_mask *= attention_mask.to(labels_mask.dtype)

    return labels_mask

wrong_suffix_mask = create_labels_mask(full_tokenized, suffix_start_idx)
print("Labels mask shape:", wrong_suffix_mask.shape)
print("Labels mask:", wrong_suffix_mask)

In [None]:
# Inference only: disable autograd so no graph is kept for the forward pass, and pass
# the attention mask because the batch was tokenized with padding.
with torch.no_grad():
    all_logits = model(
        full_tokenized['input_ids'].to(model.device),
        attention_mask=full_tokenized['attention_mask'].to(model.device),
    ).logits

In [None]:
all_logits.shape

In [None]:
def concatenated_forward(
        self, model: nn.Module, batch: Dict[str, Union[List, torch.LongTensor]]
    ) -> Tuple[torch.FloatTensor, torch.FloatTensor, torch.FloatTensor, torch.FloatTensor]:
    """Score the chosen and rejected sequences of ``batch`` with one forward pass.

    The chosen and rejected inputs are stacked into a single batch (chosen rows first)
    so the model is called once instead of twice, which is faster for FSDP. The stacked
    outputs are then split back at the number of chosen rows.

    Returns:
        ``(chosen_logps, rejected_logps, chosen_logits, rejected_logits)``.
    """
    merged = self.concatenated_inputs(
        batch,
        is_encoder_decoder=self.is_encoder_decoder,
        label_pad_token_id=self.label_pad_token_id,
        padding_value=self.padding_value,
        device=self.accelerator.device,
    )
    # Rows [0, num_chosen) of every stacked tensor belong to the chosen sequences.
    num_chosen = batch["chosen_labels"].shape[0]

    # Encoder-decoder models additionally receive the labels and (if present) decoder ids.
    extra_kwargs = {}
    if self.is_encoder_decoder:
        extra_kwargs["labels"] = merged["concatenated_labels"]
        extra_kwargs["decoder_input_ids"] = merged.pop("concatenated_decoder_input_ids", None)

    outputs = model(
        merged["concatenated_input_ids"],
        attention_mask=merged["concatenated_attention_mask"],
        use_cache=False,
        **extra_kwargs,
    )
    stacked_logits = outputs.logits

    stacked_logps = self.get_batch_logps(
        stacked_logits,
        merged["concatenated_labels"],
        average_log_prob=self.loss_type == "ipo",
        is_encoder_decoder=self.is_encoder_decoder,
        label_pad_token_id=self.label_pad_token_id,
    )

    return (
        stacked_logps[:num_chosen],
        stacked_logps[num_chosen:],
        stacked_logits[:num_chosen],
        stacked_logits[num_chosen:],
    )

In [None]:
def get_batch_logps(
        logits: torch.FloatTensor,
        token_mask: torch.LongTensor,
        average_log_prob: bool = False,
        label_pad_token_id: int = -100,
        is_encoder_decoder: bool = False,
        verbose: bool = True,
        debug_tokenizer=None,
    ) -> torch.FloatTensor:
    """Compute the log probabilities of the given labels under the given logits.

    Args:
        logits: Logits of the model (unnormalized). Shape: (batch_size, sequence_length, vocab_size)
        token_mask: Despite its name, the label TOKEN IDS for which to compute the log
            probabilities (callers pass ``input_ids``). Positions equal to
            ``label_pad_token_id`` are ignored. Shape: (batch_size, sequence_length).
            The tensor is not modified.
        average_log_prob: If True, return the average log probability per (non-masked) token. Otherwise, return the sum of the log probabilities of the (non-masked) tokens.
        label_pad_token_id: The label pad token id.
        is_encoder_decoder: Whether the model is an encoder-decoder model. For decoder-only
            models the logits at position ``t`` score the label at position ``t + 1``, so
            labels are shifted left by one and the last logit is dropped.
        verbose: If True (default), print one ``token, log prob`` line per scored token.
        debug_tokenizer: Optional tokenizer used to decode token ids in the verbose output.
            When omitted, a notebook-level global named ``tokenizer`` is used if one exists;
            otherwise raw token ids are printed.

    Returns:
        A tensor of shape (batch_size,) containing the average/sum log probabilities of the given labels under the given logits.

    Raises:
        ValueError: If the leading dimensions of ``logits`` differ from ``token_mask.shape``.
    """
    if logits.shape[:-1] != token_mask.shape:
        raise ValueError("Logits (batch and sequence length dim) and labels must have the same shape.")

    # Work on a copy so the caller's tensor is left untouched.
    if is_encoder_decoder:
        labels = token_mask.clone()
    else:
        # Next-token prediction: logits[:, t] is the distribution over the token at t + 1.
        labels = token_mask[:, 1:].clone()
        logits = logits[:, :-1, :]
    labels = labels.to(logits.device)

    # The loss mask is derived from the labels; it is NOT the labels themselves.
    loss_mask = labels != label_pad_token_id

    # dummy token so gather gets a valid index; these positions are zeroed by loss_mask
    labels[~loss_mask] = 0

    per_token_logps = torch.gather(logits.log_softmax(-1), dim=2, index=labels.unsqueeze(2)).squeeze(2)

    if verbose:
        decoder = debug_tokenizer if debug_tokenizer is not None else globals().get("tokenizer")
        rows = zip(labels.tolist(), loss_mask.tolist(), per_token_logps.detach().float().tolist())
        for i, (row_ids, row_keep, row_logps) in enumerate(rows):
            print(f"Tokens for sequence {i}:")
            # Iterate token by token so every log prob is printed next to its own token.
            for token_id, keep, token_log_prob in zip(row_ids, row_keep, row_logps):
                if not keep:
                    continue
                token = decoder.decode([token_id]) if decoder is not None else token_id
                print(f"Token: {token!r}, Log Prob: {token_log_prob}")

    masked_logps = per_token_logps * loss_mask
    if average_log_prob:
        return masked_logps.sum(-1) / loss_mask.sum(-1)
    return masked_logps.sum(-1)

In [None]:
wrong_suffix_mask.shape

In [None]:
full_tokenized.input_ids.shape

In [None]:
get_batch_logps(all_logits.to("cpu"), full_tokenized.input_ids)

In [None]:
def logprobs_per_token(tokenizer, question, answers):
    """Score each candidate answer appended to ``question`` with the notebook-level ``model``.

    Args:
        tokenizer: Tokenizer passed through to ``tokenize_sample_answers``.
        question: Prefix text shared by all candidates.
        answers: List of candidate answer strings appended to ``question``.

    Returns:
        Whatever ``get_batch_logps`` returns for the batch (one value per candidate).
        Padding positions are marked with -100, the default ``label_pad_token_id`` of
        ``get_batch_logps``.
    """
    full_tokenized, _, _ = tokenize_sample_answers(tokenizer, question, answers)
    input_ids = full_tokenized["input_ids"]
    attention_mask = full_tokenized["attention_mask"]

    # Inference only: no autograd graph, and the model is told which positions are padding.
    with torch.no_grad():
        all_logits = model(
            input_ids.to(model.device),
            attention_mask=attention_mask.to(model.device),
        ).logits

    # The pad token is the EOS token in this notebook, so padding can only be identified
    # through the attention mask. Mark it with the ignore value on a copy of the ids.
    labels = input_ids.clone()
    labels[attention_mask == 0] = -100

    return get_batch_logps(all_logits.to("cpu"), labels)