In [74]:
from pathlib import Path

import polars as pl
import srsly
from datasets import Dataset

In [28]:
# --- configuration -------------------------------------------------------------------------------------
# Tokenizer training run; its merges are read from `implemented_merges.jsonl` inside this directory.
# NOTE(review): absolute, machine-specific path -- consider making it configurable.
tok_path = Path("/home/pl487/rdd/outputs/tokenizer_train/bpe_2024-09-04T12-59-54/")
# Local dataset directory; the sampled outputs are written to a sibling directory derived from its name.
data_path = Path("data/slim-pajama-subset-validation/")
# Vocabulary cutoff: merges with `new_token_id < vocab_size` are in-vocab, the rest are out-of-vocab.
vocab_size = 32_000

In [67]:
# The tokenizer type is the prefix of the run directory name before the first "_" (e.g. "bpe_..." -> "bpe").
tok_type = tok_path.name.partition("_")[0]
assert tok_type in ("bpe",)

# e.g. "bpe32000"; identifies the tokenizer/vocab size in the output directory name.
tok_name = tok_type + str(vocab_size)
out_path = data_path.parent.joinpath(f"{data_path.name}-sample-{tok_name}")
out_path.mkdir(parents=True, exist_ok=True)

In [29]:
# Load the dataset tokenized with this tokenizer. The subfolder is named after `tok_name` (e.g. "bpe32000"),
# so the tokenizer type derived from `tok_path` is not hard-coded a second time here.
# For each document add:
#   - `tok_pos`: the list 0..len(input_ids)-1, i.e. the position of every token in the document
#   - `seq_len`: the number of tokens in the document
data = (
    pl.scan_parquet(f"hf://datasets/pietrolesci/slim-pajama-subset-validation/{tok_name}/train-*.parquet")
    .with_columns(tok_pos=pl.int_ranges(pl.col("input_ids").list.len()), seq_len=pl.col("input_ids").list.len())
    .collect()
)

In [30]:
# Merges recorded during tokenizer training, one JSON record per line; `new_token_id` is normalized to Int32.
merges_records = srsly.read_jsonl(tok_path / "implemented_merges.jsonl")
merges = pl.DataFrame(merges_records).with_columns(new_token_id=pl.col("new_token_id").cast(pl.Int32))

In [31]:
# get window around cutoff: the `num_tok_window` merges just below `vocab_size` (in vocab) and the
# `num_tok_window` merges from `vocab_size` upwards (out of vocab)
num_tok_window = 1_000
window_lo, window_hi = vocab_size - num_tok_window, vocab_size + num_tok_window
# closed="left": window_lo <= new_token_id < window_hi
df = merges.filter(pl.col("new_token_id").is_between(window_lo, window_hi, closed="left"))

In [32]:
is_in_vocab = pl.col("new_token_id") < vocab_size

# merges that made it into the vocabulary
in_vocab = df.filter(is_in_vocab)

# merges beyond the cutoff: split the 2-element `pair` list into Int32 columns `tok_a` and `tok_b`
out_vocab = (
    df.filter(~is_in_vocab)
    .with_columns(pl.col("pair").list.to_struct())
    .unnest("pair")
    .rename({"field_0": "tok_a", "field_1": "tok_b"})
    .with_columns(pl.col("tok_a", "tok_b").cast(pl.Int32))
)

In [33]:
# get document uid and position of the token in doc for tokens in vocab.
# The right join keeps every in-vocab token; a token that never occurs in `data` comes back as a single
# row with null `uid`/`tok_pos`. Such rows would flow through sampling and end up as null contexts in the
# saved outputs, so they are dropped here (the out-of-vocab index drops them via its `seq_len` filter).
in_vocab_index = (
    data.explode(["input_ids", "tok_pos"])
    .join(in_vocab.select(["new_token_id"]), left_on="input_ids", right_on="new_token_id", how="right")
    .filter(pl.col("tok_pos").is_not_null())
)

In [34]:
# Locate every adjacent occurrence of (tok_a, tok_b) for the out-of-vocab merges: one row per match with the
# document `uid`, the positions `tok_pos_a`/`tok_pos_b` of the two tokens, and the document `seq_len`.
out_vocab_index = (
    data.explode(["input_ids", "tok_pos"])
    # Next token in the flattened token stream. At the last token of a document this is the first token of
    # the following document; such cross-document pairs are removed by the `seq_len` filter at the end.
    .with_columns(next_input_id=pl.col("input_ids").shift(-1))
    .join(
        out_vocab.select(["tok_a", "tok_b", "new_token_id"]),
        left_on=["input_ids", "next_input_id"],
        right_on=["tok_a", "tok_b"],
        how="right",
    )
    # NOTE(review): whether the left key columns still exist after a right join depends on polars' join-key
    # coalescing; on polars >= 1.0 `drop` is strict and raises if they are already gone -- confirm version.
    .drop(["input_ids", "next_input_id"])
    .rename({"tok_pos": "tok_pos_a"})
    .with_columns(tok_pos_b=pl.col("tok_pos_a") + 1)
    # Drops pairs that straddle a document boundary. Merges with no match (null positions from the right
    # join) compare as null here and are dropped as well.
    .filter(pl.col("tok_pos_b") < pl.col("seq_len"))
)

In [35]:
# get docs per each token: keep at most `num_samples` randomly chosen occurrences of every token
num_samples = 50

# Shuffle a 0..n-1 counter within each `new_token_id` group and keep rows whose shuffled rank is small.
# https://stackoverflow.com/a/72636610
sample_per_token = pl.int_range(pl.len()).shuffle(seed=42).over("new_token_id") < num_samples

in_vocab_sample = in_vocab_index.filter(sample_per_token)
out_vocab_sample = out_vocab_index.filter(sample_per_token)

In [36]:
# from each doc, get the context (with the token appended) of the required size (+1, since the token is appended)
context_length = 2048

in_vocab_df = (
    in_vocab_sample.join(data.select(["uid", "input_ids"]), on="uid", how="left")
    .with_columns(
        # first context position: at most `context_length` tokens before the target, clipped at 0
        context_start=(
            pl.when(pl.col("tok_pos") > context_length).then(pl.col("tok_pos") - context_length).otherwise(0)
        )
    )
    .with_columns(
        # tokens [context_start, tok_pos] inclusive: the preceding context followed by the target token
        context=pl.col("input_ids").list.slice(
            offset=pl.col("context_start"), length=pl.col("tok_pos") - pl.col("context_start") + 1
        )
    )
    .drop(["input_ids"])
)

# check that last token in context is exactly the token we want to predict.
# `ignore_nulls=False` makes a null context (e.g. from a token that never occurs in the data) fail the check;
# the default `Series.all()` would silently skip it.
is_target_last = in_vocab_df.select(pl.col("context").list.get(-1) == pl.col("new_token_id")).to_series()
assert is_target_last.all(ignore_nulls=False)

In [37]:
# For out-of-vocab merges the target is the pair (tok_a, tok_b): keep up to `context_length` tokens before
# tok_a and append both tok_a and tok_b (+2).
out_vocab_df = (
    out_vocab_sample.join(data.select(["uid", "input_ids"]), on="uid", how="left")
    .with_columns(
        # since tok_pos_b == tok_pos_a + 1, this is `tok_pos_a >= context_length`; the start is clipped at 0
        context_start=(
            pl.when(pl.col("tok_pos_b") > context_length).then(pl.col("tok_pos_a") - context_length).otherwise(0)
        )
    )
    .with_columns(
        # tokens [context_start, tok_pos_b] inclusive: the preceding context followed by tok_a and tok_b
        context=pl.col("input_ids").list.slice(
            offset=pl.col("context_start"), length=pl.col("tok_pos_a") - pl.col("context_start") + 2
        )
    )
    .drop(["input_ids"])
)

# `ignore_nulls=False` below: a null context must fail these checks instead of being skipped, which is what
# the default `Series.all()` does.

# check that the second-to-last token in context is exactly the first token we want to predict
is_tok_a = out_vocab_df.select(pl.col("context").list.get(-2) == pl.col("tok_a")).to_series()
assert is_tok_a.all(ignore_nulls=False)

# check that the last token in context is exactly the second token we want to predict
is_tok_b = out_vocab_df.select(pl.col("context").list.get(-1) == pl.col("tok_b")).to_series()
assert is_tok_b.all(ignore_nulls=False)

In [40]:
# Final column layout of the in-vocab sample, plus the number of tokens in each context.
in_vocab_columns = ["new_token_id", "uid", "seq_len", "tok_pos", "context_start", "context_len", "context"]
in_vocab_df = in_vocab_df.with_columns(pl.col("context").list.len().alias("context_len")).select(in_vocab_columns)

In [43]:
# Final column layout of the out-of-vocab sample, plus the number of tokens in each context.
out_vocab_columns = [
    "new_token_id",
    "tok_a",
    "tok_b",
    "uid",
    "seq_len",
    "tok_pos_a",
    "tok_pos_b",
    "context_start",
    "context_len",
    "context",
]
out_vocab_df = out_vocab_df.with_columns(pl.col("context").list.len().alias("context_len")).select(out_vocab_columns)

In [70]:
# Persist both samples as parquet files inside `out_path`.
for split_name, split_df in (("in_vocab", in_vocab_df), ("out_vocab", out_vocab_df)):
    split_df.write_parquet(out_path / f"{split_name}.parquet")

In [75]:
# Stack the contexts of both samples (in-vocab first) into one frame and wrap it as a HF dataset.
context_columns = ["new_token_id", "uid", "context"]
all_context = pl.concat([frame.select(context_columns) for frame in (in_vocab_df, out_vocab_df)])

ds = Dataset.from_polars(all_context)

In [78]:
# Save the combined contexts dataset next to the parquet files.
contexts_dir = out_path / "contexts"
ds.save_to_disk(contexts_dir)

Saving the dataset (0/1 shards):   0%|          | 0/99699 [00:00<?, ? examples/s]