In [1]:
from transformers import PreTrainedTokenizerFast
import os

# Root directory holding the project's data, tokenizers and models.
# Set the REVERSE_LLM_ROOT environment variable to run on another machine;
# when it is unset the original hard-coded location is used.
PROJECT_ROOT = os.environ.get("REVERSE_LLM_ROOT", "/home/wyf/orcd/pool/reverse-llm")

DATA_DIR = f"{PROJECT_ROOT}/data"
TOKENIZER_DIR = f"{PROJECT_ROOT}/tokenizers"
MODEL_DIR = f"{PROJECT_ROOT}/models"

# Plain string literal: there are no placeholders, so no f-prefix is needed.
model_name = "reverse-gpt2-0.35B-fineweb-10BT-ctx-1024"

# Role names are stored character-reversed ("resu" / "tnatsissa").
USER_ROLE_NAME = "user"[::-1]
ASSISTANT_ROLE_NAME = "assistant"[::-1]

dataset_name = "alpaca"

In [2]:
# Load the pretrained fast tokenizer stored under TOKENIZER_DIR.
tokenizer_path = f"{TOKENIZER_DIR}/fineweb_bpe_200k"
tokenizer = PreTrainedTokenizerFast.from_pretrained(tokenizer_path)

# Register the chat turn delimiters as additional special tokens. The call is
# kept as the cell's last expression so its return value is still displayed.
chat_delimiters = ["<im_start>", "<im_end>"]
tokenizer.add_special_tokens({"additional_special_tokens": chat_delimiters})

2

In [3]:
from datasets import Dataset, load_dataset
# Load the "train" split of tatsu-lab/alpaca, then hold out 10% of it using a
# fixed seed so the train/test partition is reproducible.
raw_dataset = load_dataset(path="tatsu-lab/alpaca", split="train")
split_datasets = raw_dataset.train_test_split(seed=0, test_size=0.1)

In [4]:
# Display the DatasetDict to check the column names and per-split row counts.
split_datasets

DatasetDict({
    train: Dataset({
        features: ['instruction', 'input', 'output', 'text'],
        num_rows: 46801
    })
    test: Dataset({
        features: ['instruction', 'input', 'output', 'text'],
        num_rows: 5201
    })
})

In [5]:
def process_alapca_data(ds_split: Dataset):
    """Convert Alpaca rows into two-turn conversations with reversed text.

    For every row, the ``instruction``, ``input`` and ``output`` fields are
    stripped and reversed character by character. The reversed instruction
    and, when non-empty, the reversed input are joined with a blank line to
    form the user message; the reversed output becomes the assistant message.

    Args:
        ds_split: Iterable of dict-like rows exposing ``.get``.

    Returns:
        A list of conversations, each a two-element list of
        ``{"role": ..., "content": ...}`` dicts (user turn, then assistant
        turn). Rows whose user message or response is empty are skipped.
    """
    def _reversed_field(row, key):
        # A key that is present but holds None would make ``.strip()`` raise;
        # treat it the same as a missing key.
        return (row.get(key) or "").strip()[::-1]

    convos = []
    for ex in ds_split:
        instr = _reversed_field(ex, "instruction")
        ctx = _reversed_field(ex, "input")
        response = _reversed_field(ex, "output")

        # The optional input follows the instruction, separated by a blank line.
        user_msg_parts = [instr]
        if ctx != "":
            user_msg_parts.append(ctx)
        user_msg = "\n\n".join(user_msg_parts)

        # Only keep rows that have both a prompt and a response.
        if user_msg and response:
            convos.append([
                {"role": USER_ROLE_NAME, "content": user_msg},
                {"role": ASSISTANT_ROLE_NAME, "content": response},
            ])

    return convos

# Build the reversed conversations for each split. The held-out "test" split
# of `split_datasets` is stored under the "valid" key.
processed_alpaca = {
    target_split: process_alapca_data(split_datasets[source_split])
    for target_split, source_split in (("train", "train"), ("valid", "test"))
}

In [6]:
# Inspect the first raw (un-reversed) training row.
split_datasets["train"][0]

{'instruction': 'Write a request for further information',
 'input': '',
 'output': 'Could you please provide me with additional information that could help me understand this better?',
 'text': 'Below is an instruction that describes a task. Write a response that appropriately completes the request.\n\n### Instruction:\nWrite a request for further information\n\n### Response:\nCould you please provide me with additional information that could help me understand this better?'}

In [7]:
# Print one processed conversation (index 2) to eyeball the reversed roles and text.
print(processed_alpaca["train"][2])

[{'role': 'resu', 'content': '.si eert noisiced a tahw nialpxE'}, {'role': 'tnatsissa', 'content': '.snoisiced demrofni ekam ot redro ni semoctuo tciderp dna etaulave ot desu si eert noisiced ehT .eert eht fo dne eht yfingis sevael eht dna ,noisiced nevig a no desab neppah nac taht semoctuo elbissop eht tneserper sedon eht morf sehcnarb ehT .tniop noisiced a tneserper sedon eseht dna ,sehcnarb eht ni noitcesretni na si eert eht ni edon hcaE .emoctuo ro noisiced laitnetop a gniniatnoc hcae ,sedon dne erom ro eno ot edon toor elgnis a morf tuo sehcnarb eert ehT .ssecorp gnikam-noisiced eht fo pam lausiv dezinagro na gnitaerc yb snoitpo elpitlum neewteb esoohc elpoep pleh ot desu loot a si tI .snoitidnoc niatrec no desab ,melborp a ot snoitulos elbissop fo noitatneserper lacihparg a si eert noisiced A'}]


In [8]:
# Chat template: every turn renders as "<im_start>{role}\n{content}<im_end>".
#
# The assistant role is substituted in from ASSISTANT_ROLE_NAME because the
# roles used in the conversations are stored reversed. A literal 'assistant'
# would never equal the last message's role, and the generation prompt would
# open a turn with a role name that never appears in the data.
#
# The closing `{% endif %}` deliberately has no leading `-`, so the newline
# after the role is kept and the generation prompt ends exactly like the
# header of a rendered turn.
_ASSISTANT_PLACEHOLDER = "__ASSISTANT_ROLE__"
_chat_template = """{% for message in messages -%}
<im_start>{{ message['role'] }}
{{ message['content'] }}<im_end>
{%- endfor -%}
{% if add_generation_prompt and messages[-1]['role'] != '__ASSISTANT_ROLE__' -%}
<im_start>__ASSISTANT_ROLE__
{% endif %}"""
tokenizer.chat_template = _chat_template.replace(_ASSISTANT_PLACEHOLDER, ASSISTANT_ROLE_NAME)

# Round-trip one conversation through the template to check the rendering.
print(tokenizer.decode(tokenizer.apply_chat_template(processed_alpaca["train"][2])))

<im_start>resu
.si eert noisiced a tahw nialpxE<im_end><im_start>tnatsissa
.snoisiced demrofni ekam ot redro ni semoctuo tciderp dna etaulave ot desu si eert noisiced ehT .eert eht fo dne eht yfingis sevael eht dna ,noisiced nevig a no desab neppah nac taht semoctuo elbissop eht tneserper sedon eht morf sehcnarb ehT .tniop noisiced a tneserper sedon eseht dna ,sehcnarb eht ni noitcesretni na si eert eht ni edon hcaE .emoctuo ro noisiced laitnetop a gniniatnoc hcae ,sedon dne erom ro eno ot edon toor elgnis a morf tuo sehcnarb eert ehT .ssecorp gnikam-noisiced eht fo pam lausiv dezinagro na gnitaerc yb snoitpo elpitlum neewteb esoohc elpoep pleh ot desu loot a si tI .snoitidnoc niatrec no desab ,melborp a ot snoitulos elbissop fo noitatneserper lacihparg a si eert noisiced A<im_end>


In [9]:
from tqdm import tqdm

# Token budget per example: used below as the filtering limit and as the
# tokenizer's max_length / padding length.
context_length = 1024

def filter_and_prepare_conversations(convos, tokenizer, max_len):
    """Keep only conversations whose rendered chat text fits in ``max_len`` tokens.

    Each conversation is rendered with the tokenizer's chat template and then
    encoded without truncation to measure its length.

    Args:
        convos: Iterable of conversations; empty ones are skipped.
        tokenizer: Object providing ``apply_chat_template`` and ``encode``.
        max_len: Maximum allowed number of tokens (inclusive).

    Returns:
        ``{"conversations": [...]}`` holding the kept conversations in their
        original order. Conversations that encode to zero tokens are reported
        on stdout and dropped.
    """
    filtered_convos = []
    for convo in tqdm(convos, desc="Filtering long conversations"):
        if not convo:
            continue

        prompt_text = tokenizer.apply_chat_template(
            convo,
            tokenize=False,
            add_generation_prompt=False,
        )
        tokenized_len = len(tokenizer.encode(prompt_text, truncation=False))

        if tokenized_len == 0:
            print(f"Zero length: {convo}")
        elif tokenized_len <= max_len:
            filtered_convos.append(convo)

    return {"conversations": filtered_convos}


# Drop conversations that do not fit in the context window, one split at a time.
filtered_convos = {
    split_name: filter_and_prepare_conversations(
        processed_alpaca[split_name], tokenizer, context_length
    )
    for split_name in ("train", "valid")
}
# Number of training conversations that survived the filter.
print(len(filtered_convos["train"]["conversations"]))

Filtering long conversations: 100%|██████████| 46775/46775 [00:07<00:00, 6491.64it/s]
Filtering long conversations: 100%|██████████| 5199/5199 [00:00<00:00, 6430.61it/s]

46774





In [10]:
def formatting_func(example):
    """Tokenize one conversation and build labels for the last assistant turn only.

    The conversation is rendered with the chat template, tokenized, and
    truncated/padded to ``context_length``. Every label is ``-100`` except the
    tokens of the last assistant turn's content and its closing ``<im_end>``,
    which are copied from ``input_ids``.

    Args:
        example: Mapping with a ``"conversations"`` entry: a list of
            ``{"role": ..., "content": ...}`` dicts.

    Returns:
        A dict with ``input_ids``, ``attention_mask`` and ``labels``, or
        ``None`` (after printing a warning) when a turn's ``<im_end>`` token
        cannot be found in the token ids.
    """
    conversation = example["conversations"]

    prompt_text = tokenizer.apply_chat_template(
        conversation,
        tokenize=False,
        add_generation_prompt=False,
    )

    tokenized_inputs = tokenizer(
        prompt_text,
        truncation=True,  # safety net only: over-long conversations are filtered out beforehand
        max_length=context_length,
        return_attention_mask=True,
        padding="max_length",
    )
    input_ids = tokenized_inputs["input_ids"]

    # Start fully masked; only the supervised span is filled in below.
    labels = [-100] * len(input_ids)

    # Index of the last assistant turn, or -1 when there is none (in which
    # case every label stays masked instead of max() raising on an empty input).
    last_assistant_idx = max(
        (
            idx for idx, turn in enumerate(conversation)
            if turn["role"] == ASSISTANT_ROLE_NAME
        ),
        default=-1,
    )

    # Look the delimiter ids up from the tokenizer rather than hard-coding
    # them, so the function keeps working if the vocabulary changes.
    im_start_token_id = tokenizer.convert_tokens_to_ids("<im_start>")
    im_end_token_id = tokenizer.convert_tokens_to_ids("<im_end>")

    current_token_idx = 0
    for turn_idx, turn in enumerate(conversation):
        role = turn["role"]

        try:
            start_of_turn_idx = input_ids.index(im_start_token_id, current_token_idx)
        except ValueError:
            # No further turns are present in the token ids.
            break

        try:
            end_of_turn_idx = input_ids.index(im_end_token_id, start_of_turn_idx + 1)
        except ValueError:
            # NOTE(review): this returns None to the caller; confirm that is
            # the desired handling where the function is used with Dataset.map.
            print(f"Warning: Could not find <im_end> token for turn: {turn}")
            return None

        if role == ASSISTANT_ROLE_NAME and turn_idx == last_assistant_idx:
            # Skip "<im_start>" plus the "{role}\n" header to reach the content.
            # NOTE(review): the header is tokenized on its own here; this
            # assumes it yields the same number of tokens as it does inside
            # the full text - confirm for this tokenizer.
            role_and_newline_tokens = tokenizer.encode(f"{role}\n", add_special_tokens=False)
            start_of_content_idx = start_of_turn_idx + 1 + len(role_and_newline_tokens)

            # Unmask the content tokens and the closing <im_end> (inclusive).
            supervised = slice(start_of_content_idx, end_of_turn_idx + 1)
            labels[supervised] = input_ids[supervised]

        current_token_idx = end_of_turn_idx + 1

    return {
        "input_ids": input_ids,
        "attention_mask": tokenized_inputs["attention_mask"],
        "labels": labels,
    }

# Tokenize each split, keep only the model inputs, and write the result to disk.
tokenized = {}
for split in ("train", "valid"):
    conversations_ds = Dataset.from_dict(filtered_convos[split])
    features_ds = conversations_ds.map(formatting_func)
    tokenized[split] = features_ds.select_columns(["input_ids", "attention_mask", "labels"])

    output_path = f"{DATA_DIR}/{dataset_name}/tokenized_{context_length}_{split}"
    tokenized[split].save_to_disk(output_path)

Map:   0%|          | 0/46774 [00:00<?, ? examples/s]

Saving the dataset (0/2 shards):   0%|          | 0/46774 [00:00<?, ? examples/s]

Map:   0%|          | 0/5199 [00:00<?, ? examples/s]

Saving the dataset (0/1 shards):   0%|          | 0/5199 [00:00<?, ? examples/s]