In [4]:
from pydoc import resolve
from typing import Any, Dict, List, Optional

from transformers import PreTrainedTokenizerBase, AutoTokenizer

In [36]:
tokenizer = AutoTokenizer.from_pretrained("allenai/OLMoE-1B-7B-0125-Instruct", trust_remote_code=True)
examples = {
    "question": ["Is 123 a prime?"],
    "response": ["No, 123 is not a prime number. It can be factored as 3 × 41."]
}

In [37]:
def apply_general_chat_template(
        question: str,
        tokenizer: PreTrainedTokenizerBase,
        response: Optional[str] = None,
):
    messages = [
        {"role": "system", "content": "You are a helpful assistant."},
        {"role": "user", "content": question}
    ]
    if response is None:
        return tokenizer.apply_chat_template(messages, add_generation_prompt=True, tokenize=False)
    else:
        messages.append({"role": "assistant", "content": response})
        return tokenizer.apply_chat_template(messages, add_generation_prompt=False, tokenize=False)


def sft_olmoe_train_batch_preprocess_fn(
        examples: Dict[str, List[Any]],
        tokenizer: PreTrainedTokenizerBase,
):
    if tokenizer is None:
        raise ValueError("Tokenizer is required for SFT training.")

    # 1. apply general chat template to each example
    all_chat_texts = []

    for question, response in zip(examples["question"], examples["response"]):
        chat_text = apply_general_chat_template(question, response=response, tokenizer=tokenizer)
        all_chat_texts.append(chat_text)

    # 2. Tokenize the chat
    all_input_ids = []
    all_attention_masks = []
    all_labels = []

    for chat_text in all_chat_texts:
        encoded = tokenizer(chat_text, padding=False, truncation=True)
        input_ids = encoded["input_ids"]
        attention_mask = encoded["attention_mask"]

        # 3. Only apply LM loss on the assistant's response & "<|endoftext|>"
        labels = [-100] * len(input_ids)

        assistant_token_id = tokenizer("<|assistant|>", add_special_tokens=False)["input_ids"]
        end_token_id = tokenizer.convert_tokens_to_ids("|||IP_ADDRESS|||")

        pos_assistant = -1
        pos_end_after_response = -1

        i = 0
        while i <= len(input_ids) - len(assistant_token_id):
            matched = True
            for j in range(len(assistant_token_id)):
                if input_ids[i + j] != assistant_token_id[j]:
                    matched = False
                    break

            if matched:
                pos_assistant = i + len(assistant_token_id) - 1
                break
            i += 1

        if pos_assistant != -1:
            for i in range(pos_assistant + 1, len(input_ids)):
                if input_ids[i] == end_token_id:
                    pos_end_after_response = i
                    break

        if pos_assistant != -1 and pos_end_after_response != -1:
            for i in range(pos_assistant + 1, pos_end_after_response):
                labels[i] = input_ids[i]

        all_input_ids.append(input_ids)
        all_attention_masks.append(attention_mask)
        all_labels.append(labels)
        print(pos_assistant, pos_end_after_response)

    return {
        "input_ids": all_input_ids,
        "attention_mask": all_attention_masks,
        "labels": all_labels
    }

In [38]:
results = sft_olmoe_train_batch_preprocess_fn(examples, tokenizer)

Asking to truncate to max_length but no maximum length is provided and the model has no predefined maximum length. Default to no truncation.


28 49


In [41]:
results

{'input_ids': [[50279,
   29,
   93,
   10394,
   49651,
   187,
   1394,
   403,
   247,
   9371,
   13372,
   15,
   187,
   29,
   93,
   4537,
   49651,
   187,
   2513,
   15567,
   247,
   4335,
   32,
   187,
   29,
   93,
   515,
   5567,
   49651,
   187,
   2302,
   13,
   15567,
   310,
   417,
   247,
   4335,
   1180,
   15,
   733,
   476,
   320,
   958,
   2149,
   347,
   495,
   6806,
   7609,
   15,
   50279]],
 'attention_mask': [[1,
   1,
   1,
   1,
   1,
   1,
   1,
   1,
   1,
   1,
   1,
   1,
   1,
   1,
   1,
   1,
   1,
   1,
   1,
   1,
   1,
   1,
   1,
   1,
   1,
   1,
   1,
   1,
   1,
   1,
   1,
   1,
   1,
   1,
   1,
   1,
   1,
   1,
   1,
   1,
   1,
   1,
   1,
   1,
   1,
   1,
   1,
   1,
   1,
   1]],
 'labels': [[-100,
   -100,
   -100,
   -100,
   -100,
   -100,
   -100,
   -100,
   -100,
   -100,
   -100,
   -100,
   -100,
   -100,
   -100,
   -100,
   -100,
   -100,
   -100,
   -100,
   -100,
   -100,
   -100,
   -100,
   -100,
   -100,
  