In [6]:
import json
import numpy as np
from collections import defaultdict


# Path to the fine-tuning dataset in JSON Lines format (one JSON object per line).
DATA_PATH = "train_data/PMA_fine_tuning_FULL_SPIDER.jsonl"
# Candidate context lengths, listed in ascending order. The auto-selection later in
# this cell walks the list front to back and relies on that ordering.
SUPPORTED_CONTEXT_LENGTHS = [512, 1024, 2048, 4096, 8192, 16384, 32768]

# Import the tokenizer
from transformers import LlamaTokenizerFast
tokenizer = LlamaTokenizerFast.from_pretrained("hf-internal-testing/llama-tokenizer")
# Reuse EOS as the padding token so that padded batching has a pad token to use.
tokenizer.pad_token = tokenizer.eos_token

# Load the dataset.
# Whitespace-only lines (e.g. a trailing blank line at the end of the file) are
# skipped instead of being passed to json.loads, which would raise on them.
with open(DATA_PATH, 'r', encoding='utf-8') as f:
    items = [json.loads(line) for line in f if line.strip()]


# Utility function for proper formatting of the data
def convert_message_list_to_text(messages: list) -> str:
    """Render a chat conversation as one ``[INST] ... [/INST]`` formatted string.

    Args:
        messages: Sequence of ``{"role": ..., "content": ...}`` dicts. An optional
            leading ``"system"`` message is folded into the first user message,
            wrapped in ``<<SYS>>`` markers. The remaining messages must alternate
            user/assistant, start with a user turn and end with an assistant turn.

    Returns:
        The conversation as a single string. Each user/assistant pair is rendered
        as ``[INST] <prompt> [/INST] <answer> ``, pairs are joined with
        ``</s><s>``, and the whole text is wrapped in ``<s>`` ... `` </s>``.

    Raises:
        AssertionError: If ``messages`` is empty, consists of a system message
            only, does not alternate user/assistant, or does not end with an
            assistant message.
    """
    B_INST, E_INST = "[INST]", "[/INST]"
    B_SYS, E_SYS = "<<SYS>>\n", "\n<</SYS>>\n\n"

    # Fail with a descriptive message instead of an IndexError on messages[0].
    assert len(messages) > 0, "messages must contain at least one user/assistant exchange"

    if messages[0]["role"] == "system":
        # Fail with a descriptive message instead of an IndexError on messages[1].
        assert len(messages) >= 2, "a system message must be followed by a user message"
        # Merge the system prompt into the following message; the input list
        # itself is not modified, a new list is built.
        messages = [
            {
                "role": messages[1]["role"],
                "content": B_SYS
                + messages[0]["content"]
                + E_SYS
                + messages[1]["content"],
            }
        ] + messages[2:]

    assert all([msg["role"] == "user" for msg in messages[::2]]) and all(
            [msg["role"] == "assistant" for msg in messages[1::2]]
        ), (
            "model only supports 'system','user' and 'assistant' roles, "
            "starting with user and alternating (u/a/u/a/u...)"
        )
    # During training last message should be from assistant (not from a user).
    # Checked before rendering: zip() below would otherwise silently drop a
    # trailing user message before the error is reported.
    assert (
        messages[-1]["role"] == "assistant"
    ), f"Last message must be from assistant, got {messages[-1]['role']}"

    texts = [
        f"{B_INST} {(prompt['content']).strip()} {E_INST} {(answer['content']).strip()} "
        for prompt, answer in zip(messages[::2], messages[1::2])
    ]

    # add the bos and eos token at the beginning of the first turn and the end of the last turn
    # NOTE(review): every turn already ends with a space, so the closing " </s>"
    # is preceded by two spaces. Kept as-is to preserve the existing output.
    return "<s>" + "</s><s>".join(texts) + " </s>"


# Utility functions for calculating the statistics of the number of tokens in the dataset
def print_token_statistics(stats) -> None:
    """Print a statistics mapping in a human-readable form.

    Each top-level key gets a ``Statistics for <key>:`` header. When the value is
    a dict, its entries are printed one per line with three decimals; any other
    value is printed as-is. A blank line follows every section.
    """
    for name, entry in stats.items():
        print(f"Statistics for {name}:")
        if not isinstance(entry, dict):
            # Scalar entry: print it verbatim.
            print(f"\t{entry}")
        else:
            for label, number in entry.items():
                print(f"\t{label}: {number:.3f}")
        print("")

def get_tokenized_stats(items: list, print_stats: bool = True):
    """Compute message-count and token-count statistics over a dataset.

    Args:
        items: Dataset records; each record's ``"messages"`` entry is rendered
            with ``convert_message_list_to_text`` and tokenized with the
            module-level ``tokenizer``.
        print_stats: When true, the result is also printed via
            ``print_token_statistics``.

    Returns:
        A dict with one sub-dict of max/min/median/mean/p95/p5 (as floats) per
        collected measure (``"message"`` and ``"token"``), plus ``"ds_size"``
        holding the number of records.
    """
    samples = defaultdict(list)
    for record in items:
        conversation = record["messages"]

        # number of messages in this conversation
        samples["message"].append(len(conversation))

        # number of tokens of the rendered conversation
        rendered = convert_message_list_to_text(conversation)
        samples["token"].append(len(tokenizer(rendered)['input_ids']))

    # Summary name -> reducer; insertion order defines the order in the output.
    reducers = {
        "max": np.max,
        "min": np.min,
        "median": np.median,
        "mean": np.mean,
        "p95": lambda values: np.percentile(values, 95),
        "p5": lambda values: np.percentile(values, 5),
    }
    stats = {
        measure: {label: float(reduce(values)) for label, reduce in reducers.items()}
        for measure, values in samples.items()
    }
    stats["ds_size"] = len(items)

    if print_stats:
        print_token_statistics(stats)

    return stats

# Auto calculate the context length: pick the first supported length that is
# strictly larger than the 95th percentile of per-example token counts.
stats = get_tokenized_stats(items, print_stats=True)
p95_tokens = stats["token"]["p95"]
ctx_length = next(
    (length for length in SUPPORTED_CONTEXT_LENGTHS if length > p95_tokens),
    None,
)
if ctx_length is None:
    # No supported length exceeds p95. Fall back to the last (largest) entry,
    # as before, but say so explicitly instead of silently picking it.
    ctx_length = SUPPORTED_CONTEXT_LENGTHS[-1]
    print(
        f"WARNING: p95 token count ({p95_tokens}) is not below any supported "
        f"context length; falling back to {ctx_length}."
    )

print("Automatically selected context length: ", ctx_length)


Statistics for message:
	max: 3.000
	min: 3.000
	median: 3.000
	mean: 3.000
	p95: 3.000
	p5: 3.000

Statistics for token:
	max: 318.000
	min: 56.000
	median: 93.000
	mean: 100.954
	p95: 156.000
	p5: 65.000

Statistics for ds_size:
	7000

Automatically selected context length:  512


In [7]:
# Must be one of the keys of DS_MAX_SIZE_LIMITS below,
# e.g. "mistral-7b", "llama-8b", "llama-70b" or "mixtral-8x7b".
MODEL_SIZE = "llama-8b"
# Maximum number of dataset rows accepted, keyed by model and then by context length.
# Not every model lists every context length.
DS_MAX_SIZE_LIMITS = {
    "mistral-7b": {
        512: 150_000,
        1024: 50_000,
        2048: 25_000,
        4096: 10_000,
        8192: 5_000,
    },
    "llama-8b": {
        512: 90_000,
        1024: 32_000,
        2048: 15_000,
        4096: 5_000,
        8192: 5_000,
        16384: 5_000,
        32768: 2_500,
    },
    "llama-70b": {
        512: 25_000,
        1024: 10_000,
        2048: 5_000,
        4096: 5_000,
        8192: 3_000,
        16384: 1_500,
    },
    "mixtral-8x7b": {
        512: 25_000,
        1024: 10_000,
        2048: 5_000,
        4096: 5_000,
        8192: 2_500,
        16384: 1_000,
        32768: 1_000,
    },
}
# Context length chosen by the auto-selection in the previous cell.
CONTEXT_LENGTH = ctx_length

# Report unsupported combinations explicitly rather than as a bare KeyError.
if MODEL_SIZE not in DS_MAX_SIZE_LIMITS:
    raise ValueError(
        f"Unknown MODEL_SIZE {MODEL_SIZE!r}; expected one of {sorted(DS_MAX_SIZE_LIMITS)}"
    )
model_limits = DS_MAX_SIZE_LIMITS[MODEL_SIZE]
if CONTEXT_LENGTH not in model_limits:
    raise ValueError(
        f"Context length {CONTEXT_LENGTH} has no size limit defined for {MODEL_SIZE!r}; "
        f"available context lengths: {sorted(model_limits)}"
    )

ds_max_size = model_limits[CONTEXT_LENGTH]
if len(items) > ds_max_size:
    raise ValueError(
        f"Dataset size ({len(items)}) exceeds the maximum allowable size ({ds_max_size})"
    )

In [8]:
# We will use ray data for batched data iteration
import ray # pip install ray[data]
import pandas as pd

# Per-device batch size used when iterating over the tokenized data; adjust as needed
BSIZE_PER_DEVICE = 16

# Wrap the loaded records in a pandas frame and hand it to Ray Data for batched processing
df = pd.DataFrame(items)
ds = ray.data.from_pandas(df)


def batched_convert_messages_to_text(batch: pd.DataFrame) -> pd.DataFrame:
    """Render the ``"messages"`` entry of every row as plain text.

    Each row's messages are passed (as a list) to
    ``convert_message_list_to_text``; the results are returned as a new frame
    with a single ``"input"`` column, one row per input row.
    """
    rendered_rows = [
        {"input": convert_message_list_to_text(list(row["messages"]))}
        for _, row in batch.iterrows()
    ]
    return pd.DataFrame(rendered_rows)

def collate_fn(batch: dict):
    """Tokenize the ``"input"`` texts of a batch with the module-level tokenizer.

    The texts are tokenized with padding to the longest sequence in the batch,
    truncation enabled with ``CONTEXT_LENGTH`` as the maximum length, and
    PyTorch tensors requested as the return type.
    """
    texts = list(batch["input"])
    tokenizer_options = {
        "padding": "longest",
        "max_length": CONTEXT_LENGTH,
        "truncation": True,
        "return_tensors": "pt",
    }
    return tokenizer(texts, **tokenizer_options)


# Data preprocessing pipeline: render every conversation into a single "input" text
flattened_ds = ds.map_batches(
    batched_convert_messages_to_text,
    batch_format="pandas",
    batch_size=16,
)

# Walk over one epoch of collated batches and accumulate two counts:
#   * trained_tokens_per_epoch: every position in input_ids, padding included
#   * data_set_tokens_per_epoch: sum of the attention mask
#     (presumably 1 for real tokens and 0 for padding -- confirm against the tokenizer)
trained_tokens_per_epoch = 0
data_set_tokens_per_epoch = 0
for batch in flattened_ds.iter_torch_batches(
    collate_fn=collate_fn,
    batch_size=BSIZE_PER_DEVICE,
):
    data_set_tokens_per_epoch += batch["attention_mask"].sum().item()
    trained_tokens_per_epoch += batch["input_ids"].numel()

print("Num tokens in dataset per epoch: ", data_set_tokens_per_epoch)
print("Num tokens trained per epoch: ", trained_tokens_per_epoch)
print("Padding inflation ratio: ", trained_tokens_per_epoch / data_set_tokens_per_epoch)

2024-07-08 10:24:43,064	INFO streaming_executor.py:108 -- Starting execution of Dataset. Full logs are in /tmp/ray/session_2024-07-07_18-52-12_756286_35179/logs/ray-data
2024-07-08 10:24:43,065	INFO streaming_executor.py:109 -- Execution plan of Dataset: InputDataBuffer[Input] -> TaskPoolMapOperator[MapBatches(batched_convert_messages_to_text)]
                                                

Num tokens in dataset per epoch:  706681
Num tokens trained per epoch:  1169216
Padding inflation ratio:  1.6545173847888934
