In [1]:
from pathlib import Path

import polars as pl
import srsly
from datasets import Dataset

In [2]:
# # Get dataset
# splits = {'validation': 'data/validation-00000-of-00001-a2192e61a091cecb.parquet', 'test': 'data/test-00000-of-00001-010a6231c4b54d31.parquet'}
# df = pl.concat([
#     pl.scan_parquet(f'hf://datasets/JeanKaddour/minipile/{v}').with_columns(split=pl.lit(k)) 
#     for k, v in splits.items()
# ]).collect()
# df = df.with_row_index().rename({"index": "uid"})
# df.write_parquet("data/minipile-eval.parquet")

In [33]:
# Experiment configuration: which tokenizer (type, vocabulary size, training dataset) to evaluate.
tok_type = "bpe"
vocab_size = 32_000
dataset = "minipile"

# Tokenizer directory and evaluation-output directory share the "<type><vocab size><dataset>" name.
tok_name = f"{tok_type}{vocab_size}{dataset}"
tok_path = Path(f"./outputs/tokenizers/{tok_name}")
out_path = Path(f"data/{dataset}-eval-{tok_name}")
out_path.mkdir(parents=True, exist_ok=True)

In [4]:
# # Load tokenized corpus
# file_path = out_path / "data.parquet"
# if file_path.exists():
#     df = pl.read_parquet(file_path)
# else:
#     df = (
#         pl.scan_parquet(f"hf://datasets/pietrolesci/slim-pajama-eval/{tok_type}{vocab_size}/train-*.parquet")
#         .with_columns(tok_pos=pl.int_ranges(pl.col("input_ids").list.len()), seq_len=pl.col("input_ids").list.len())
#         .collect()
#     )
#     df.write_parquet(file_path)

In [17]:
# Load the tokenized evaluation corpus and add, per document, the position of every token
# (`tok_pos`, 0..len-1) and the document length in tokens (`seq_len`).
# The file name is derived from the config cell instead of being hard-coded, so changing
# `dataset` / `tok_type` / `vocab_size` reads the matching corpus (or fails loudly if it is missing)
# rather than silently evaluating against the minipile/bpe32000 one.
corpus_path = Path(f"data/{dataset}-eval-{tok_type}{vocab_size}{dataset}.parquet")
df = pl.read_parquet(corpus_path).with_columns(
    tok_pos=pl.int_ranges(pl.col("input_ids").list.len()),
    seq_len=pl.col("input_ids").list.len(),
)

In [18]:
# Load merges.
# `raw_tok_path.txt` stores the directory of the raw tokenizer run, which contains
# `implemented_merges.jsonl`. Surrounding whitespace (e.g. a trailing newline written with the file)
# is stripped so it does not become part of the path.
raw_tok_path = Path((tok_path / "raw_tok_path.txt").read_text().strip())

merges_df = pl.DataFrame(srsly.read_jsonl(raw_tok_path / "implemented_merges.jsonl")).with_columns(
    pl.col("new_token_id").cast(pl.Int32)
)

In [19]:
# Keep only merges whose new token id lies within `num_tok_window` ids on either side of the
# vocabulary cutoff, i.e. in the half-open range [vocab_size - num_tok_window, vocab_size + num_tok_window).
num_tok_window = 1_500
window_lo, window_hi = vocab_size - num_tok_window, vocab_size + num_tok_window
window_df = merges_df.filter(pl.col("new_token_id").is_between(window_lo, window_hi, closed="left"))

In [20]:
# Split the window at the cutoff. Ids below `vocab_size` are part of the tokenizer's vocabulary.
in_vocab = window_df.filter(pl.col("new_token_id") < vocab_size)

# Merges at or beyond the cutoff are described by their constituent pair: the two-element
# `pair` list is unpacked into Int32 columns `tok_a` (first) and `tok_b` (second).
out_vocab = (
    window_df.filter(pl.col("new_token_id") >= vocab_size)
    .with_columns(pl.col("pair").list.to_struct(fields=["tok_a", "tok_b"]))
    .unnest("pair")
    .with_columns(pl.col("tok_a", "tok_b").cast(pl.Int32))
)

In [21]:
# One row per token: explode the token ids together with their positions.
# Only positions >= 3 are kept, so every candidate has at least three earlier tokens in its
# document (the out-of-vocab case evaluates 2 tokens and needs context in front of them).
dfe = (
    df.lazy()
    .explode("input_ids", "tok_pos")
    .filter(pl.col("tok_pos").ge(3))
)

In [22]:
# Every eligible occurrence (uid, tok_pos, ...) of an in-vocab window token in the corpus.
in_vocab_index = (
    in_vocab.lazy()
    .select("new_token_id")
    .join(dfe, how="inner", left_on="new_token_id", right_on="input_ids")
)

In [23]:
# Every eligible occurrence of an out-of-vocab merge's constituent pair in the corpus:
# `tok_a` at position `tok_pos_a` immediately followed by `tok_b` at `tok_pos_b = tok_pos_a + 1`.
# The "next token" is computed per document (`.over("uid")`). As a result, a document's last token
# gets a null successor (nulls never match in the join) instead of the next document's token.
# The pairing no longer depends on implicit row order; the `seq_len` filter stays as a second guard.
out_vocab_index = (
    out_vocab.lazy()
    .select(["tok_a", "tok_b", "new_token_id"])
    .join(
        dfe.with_columns(next_input_id=pl.col("input_ids").shift(-1).over("uid")),
        left_on=["tok_a", "tok_b"],
        right_on=["input_ids", "next_input_id"],
        how="inner",
    )
    .rename({"tok_pos": "tok_pos_a"})
    .with_columns(tok_pos_b=pl.col("tok_pos_a") + 1)
    .filter(pl.col("tok_pos_b") < pl.col("seq_len"))
)

In [24]:
# # Check
# window_df.filter(
#     pl.col("new_token_id")
#     .is_in(out_vocab_index["new_token_id"].to_list() + in_vocab_index["new_token_id"].to_list())
#     .not_()
# )

In [26]:
# Keep at most `num_samples` randomly chosen occurrences per token (all of them if fewer exist).
num_samples = 100


def sample_per_token(index: pl.LazyFrame, n: int, seed: int = 42) -> pl.DataFrame:
    """Return up to `n` randomly selected rows per `new_token_id` from `index`, collected eagerly."""
    # Shuffle the row ranks within each token group and keep ranks below `n`
    # (https://stackoverflow.com/a/72636610).
    shuffled_rank = pl.int_range(pl.len()).shuffle(seed=seed).over("new_token_id")
    return index.filter(shuffled_rank < n).collect()


in_vocab_sample = sample_per_token(in_vocab_index, num_samples)
out_vocab_sample = sample_per_token(out_vocab_index, num_samples)

In [27]:
# Attach the context window to each sampled in-vocab occurrence. The window ends at (and includes)
# the target token and holds up to `context_length` tokens before it.
context_length = 2048

in_vocab_df = (
    in_vocab_sample.join(df.select("uid", "input_ids"), on="uid", how="inner")
    # First token of the window: `context_length` positions back, floored at the document start.
    .with_columns(context_start=(pl.col("tok_pos") - context_length).clip(lower_bound=0))
    .with_columns(
        context=pl.col("input_ids").list.slice(
            offset=pl.col("context_start"),
            length=pl.col("tok_pos") + 1 - pl.col("context_start"),
        )
    )
    .drop("input_ids")
)

# Check that the last token in each context is exactly the token we want to predict.
last_is_target = in_vocab_df.select(pl.col("context").list.get(-1) == pl.col("new_token_id")).to_series()
assert last_is_target.all()

In [28]:
# Attach the context window to each sampled out-of-vocab occurrence. It is built like the in-vocab
# window, but the target is the two-token span (tok_a at tok_pos_a, tok_b at tok_pos_b). The window
# ends at tok_pos_b and holds up to `context_length` tokens before it.
out_vocab_df = (
    # Inner join, consistent with the in-vocab cell; the row-count check below catches any drop.
    out_vocab_sample.join(df.select(["uid", "input_ids"]), on="uid", how="inner")
    # First token of the window: `context_length` positions before tok_pos_b, floored at 0.
    .with_columns(context_start=(pl.col("tok_pos_b") - context_length).clip(lower_bound=0))
    .with_columns(
        context=pl.col("input_ids").list.slice(
            offset=pl.col("context_start"),
            length=pl.col("tok_pos_b") - pl.col("context_start") + 1,
        )
    )
    .drop(["input_ids"])
)

# Every sampled occurrence must have found its document.
assert out_vocab_df.height == out_vocab_sample.height, "some sampled uids are missing from df"

# The penultimate context token must be tok_a and the last must be tok_b. `fill_null(False)`
# makes a null context/token fail the check; plain `.all()` would skip the null.
pair_checks = out_vocab_df.select(
    a_ok=(pl.col("context").list.get(-2) == pl.col("tok_a")).fill_null(False).all(),
    b_ok=(pl.col("context").list.get(-1) == pl.col("tok_b")).fill_null(False).all(),
)
assert pair_checks["a_ok"].item(), "penultimate context token is not tok_a"
assert pair_checks["b_ok"].item(), "last context token is not tok_b"

In [29]:
# Final column layout of the saved samples, with `context_len` (number of tokens in the window) added.
in_vocab_columns = ["new_token_id", "uid", "seq_len", "tok_pos", "context_start", "context_len", "context"]
out_vocab_columns = [
    "new_token_id",
    "tok_a",
    "tok_b",
    "uid",
    "seq_len",
    "tok_pos_a",
    "tok_pos_b",
    "context_start",
    "context_len",
    "context",
]
context_len_expr = pl.col("context").list.len()

in_vocab_df = in_vocab_df.with_columns(context_len=context_len_expr).select(in_vocab_columns)
out_vocab_df = out_vocab_df.with_columns(context_len=context_len_expr).select(out_vocab_columns)

In [34]:
# Persist both sample tables under `out_path`.
for file_name, samples in (
    ("in_vocab_samples.parquet", in_vocab_df),
    ("out_vocab_samples.parquet", out_vocab_df),
):
    samples.write_parquet(out_path / file_name)

In [6]:
# Reload both sample tables from `out_path`, so later cells can run without recomputing them.
in_vocab_df, out_vocab_df = (
    pl.read_parquet(out_path / file_name)
    for file_name in ("in_vocab_samples.parquet", "out_vocab_samples.parquet")
)

In [35]:
# Stack the in- and out-of-vocab contexts into a single table, with `context` renamed to `input_ids`.
# Sort longest-first (NOTE(review): presumably so long sequences / OOMs surface early when
# batching — confirm), then save as a Hugging Face dataset.
shared_columns = ["new_token_id", "uid", "input_ids"]
all_context = pl.concat(
    [frame.rename({"context": "input_ids"}).select(shared_columns) for frame in (in_vocab_df, out_vocab_df)]
).sort(pl.col("input_ids").list.len(), descending=True)

ds = Dataset.from_polars(all_context)
ds.save_to_disk(out_path / "eval_samples")

Saving the dataset (0/2 shards):   0%|          | 0/100708 [00:00<?, ? examples/s]

In [36]:
# Check: tokens that received fewer than `num_samples` contexts, i.e. those with fewer than
# `num_samples` eligible occurrences in the corpus. Tokens with no occurrences at all do not appear.
# (The stray `s` that preceded this comment was a bare undefined name and raised NameError.)
all_context["new_token_id"].value_counts().filter(pl.col("count") < num_samples)

new_token_id,count
i32,u64
31822,54
33238,20
31975,20
32997,22
31161,39
…,…
33022,2
33420,6
33300,23
32742,16
