In [None]:
from pathlib import Path

import plotnine as pn
import polars as pl
import srsly

from src.utilities import load_tokenizer_with_vocab_size

In [None]:
# Tokenizer training run whose merge log is inspected below (absolute, machine-specific path).
tok_path = Path("/home/pl487/rdd/outputs/tokenizer_train/bpe_2024-09-04T12-59-54/")
# Local copy of the validation subset (not referenced elsewhere in this notebook).
data_path = Path("data/slim-pajama-subset-validation/")
# Vocabulary size the dataset was tokenized with; merge ids >= this are treated as out of vocab below.
vocab_size = 32_000

In [None]:
# Tokenized validation subset read from the Hub, with a `token_position` list holding
# 0..len(input_ids)-1 for each document so positions survive an explode.
data = (
    pl.scan_parquet(f"hf://datasets/pietrolesci/slim-pajama-subset-validation/bpe{vocab_size}/train-*.parquet")
    .with_columns(pl.int_ranges(0, pl.col("input_ids").list.len()).alias("token_position"))
    .collect()
)

In [None]:
# Merge log written by the tokenizer training run; `new_token_id` is cast to Int32
# (presumably to match the dtype of `input_ids` in the joins below — TODO confirm).
merges = pl.DataFrame(srsly.read_jsonl(tok_path / "implemented_merges.jsonl")).cast({"new_token_id": pl.Int32})

In [None]:
# get window around cutoff: merges whose new id is in [vocab_size - num_tok_window, vocab_size + num_tok_window)
num_tok_window = 1_000
df = merges.filter(
    pl.col("new_token_id").is_between(vocab_size - num_tok_window, vocab_size + num_tok_window, closed="left")
)

In [None]:
# split the window at the cutoff: ids below vocab_size are in the vocabulary, the rest are not
in_vocab, out_vocab = (
    df.filter(pl.col("new_token_id") < vocab_size),
    df.filter(pl.col("new_token_id") >= vocab_size),
)

In [None]:
# get document uid and position of the token in doc for tokens in vocab.
# how="right" keeps every in-vocab token id, including ids that never occur in `data`
# (those rows carry nulls for the document columns).
in_vocab_index = (
    data
    .explode(["input_ids", "token_position"])
    .join(in_vocab.select("new_token_id"), how="right", left_on="input_ids", right_on="new_token_id")
)

In [None]:
# get document uid and position of the token in doc for tokens not in vocab:
# split each merge `pair` into its two component token ids (tok_a, tok_b).
# Built from `df` rather than from `out_vocab` itself so re-running this cell gives the same
# result (deriving it from itself consumed the "pair" column and failed on a second run).
out_vocab = (
    df.filter(pl.col("new_token_id") >= vocab_size)
    .with_columns(
        tok_a=pl.col("pair").list.get(0).cast(pl.Int32),
        tok_b=pl.col("pair").list.get(1).cast(pl.Int32),
    )
    .drop("pair")
)

In [None]:
# Locate every adjacent occurrence of (tok_a, tok_b) for the out-of-vocab merges.
out_vocab_index = (
    data
    # Shift *inside* each document's list before exploding. Shifting the exploded frame would
    # pair the last token of one document with the first token of the next document.
    .with_columns(next_input_id=pl.col("input_ids").list.shift(-1))
    .explode(["input_ids", "next_input_id", "token_position"])
    # how="right" keeps every merge pair, including pairs never observed (null document columns);
    # the right join drops the left key columns input_ids / next_input_id
    .join(
        out_vocab.select(["tok_a", "tok_b", "new_token_id"]),
        left_on=["input_ids", "next_input_id"],
        right_on=["tok_a", "tok_b"],
        how="right",
    )
    .rename({"token_position": "token_position_a"})
    .with_columns(token_position_b=pl.col("token_position_a") + 1)
)
    

In [None]:
# get docs per each token
num_samples = 50

# Keep at most `num_samples` rows per token: each row in a token's group gets a shuffled
# rank 0..n-1 (seeded, so the same input order yields the same picks) and only ranks
# below the cap survive. Trick from https://stackoverflow.com/a/72636610
in_vocab_sample, out_vocab_sample = (
    index.filter(pl.int_range(pl.len()).shuffle(seed=42).over("new_token_id") < num_samples)
    for index in (in_vocab_index, out_vocab_index)
)

In [None]:
# from each doc, get the context (with the token appended) of the required size (+1, since the token is appended)
context_length = 2048

in_vocab_df = (
    in_vocab_sample
    .join(data.select(["uid", "input_ids"]), on="uid", how="left")
    # window start: `context_length` tokens before the target, clipped at 0
    # (max_horizontal ignores nulls, so a null position also yields 0)
    .with_columns(context_start=pl.max_horizontal(pl.col("token_position") - context_length, 0))
    # everything from the window start up to and including the target token
    .with_columns(
        context=pl.col("input_ids").list.slice(
            offset=pl.col("context_start"),
            length=pl.col("token_position") - pl.col("context_start") + 1,
        )
    )
    .drop(["input_ids"])
)

# check that last token in context is exactly the token we want to predict.
# Rows with a null context (token id kept by the right join but never seen in `data`) have
# nothing to check and are excluded explicitly; every remaining comparison must be True
# (ignore_nulls=False so a null comparison fails instead of passing silently).
assert (
    in_vocab_df
    .filter(pl.col("context").is_not_null())
    .select((pl.col("context").list.get(-1) == pl.col("new_token_id")).all(ignore_nulls=False))
    .item()
), "last token in context is not the target token"

In [None]:
out_vocab_df = (
    out_vocab_sample
    .join(data.select(["uid", "input_ids"]), on="uid", how="left")
    # window starts `context_length` tokens before tok_a, clipped at 0 (a null position also gives 0)
    .with_columns(context_start=pl.max_horizontal(pl.col("token_position_a") - context_length, 0))
    # keep the window through tok_a *and* tok_b, hence the + 2
    .with_columns(
        context=pl.col("input_ids").list.slice(
            offset=pl.col("context_start"),
            length=(pl.col("token_position_a") + 2) - pl.col("context_start"),
        )
    )
    .drop("input_ids")
)

In [None]:
# check that every context ends with the merge pair: [..., tok_a, tok_b].
# Rows with a null context (pair never observed in `data`) have nothing to check.
# Asserts run in order, so the length guard protects the get(-2) below.
pair_contexts = out_vocab_df.filter(pl.col("context").is_not_null())
assert (pair_contexts["context"].list.len() >= 2).all(), "context shorter than the token pair"
assert (pair_contexts["context"].list.get(-1) == pair_contexts["tok_b"]).all(), "last context token is not tok_b"
assert (pair_contexts["context"].list.get(-2) == pair_contexts["tok_a"]).all(), "second-to-last context token is not tok_a"

In [None]:
# inspect rows whose context does not end with tok_b
out_vocab_df.filter(pl.col("context").list.get(-1).ne(pl.col("tok_b")))

In [None]:
# display the sampled out-of-vocab contexts
out_vocab_df

In [None]:
# small scratch frame (first 5 documents) for trying out list expressions
a = data.head(n=5)

In [None]:
# add a 0-based row number column named "index"
a = a.with_columns(pl.int_range(0, pl.len()).alias("index"))

In [None]:
# take 2 tokens starting at each row's own `index` (per-row offset from an expression).
# The cell was unfinished (`slice((, 2)` is a SyntaxError); this offset is the assumed intent — TODO confirm.
a.with_columns(other=pl.col("input_ids").list.slice(pl.col("index"), 2))

In [None]:
# peek at the first 5 out-of-vocab merges
out_vocab.head(5)

In [None]:
# in_vocab_index.group_by("new_token_id").agg(
#     q25=pl.col("token_position").quantile(.25),
#     median=pl.col("token_position").median(), 
#     q75=pl.col("token_position").quantile(.75),
# )

In [None]:
# small sample of at most 10 rows per in-vocab token for quick inspection;
# seeded so re-running the notebook picks the same rows
s = in_vocab_index.filter(
    # https://stackoverflow.com/a/72636610
    pl.int_range(pl.len()).shuffle(seed=42).over("new_token_id") < 10
)

In [None]:
# only keep unique doc-tok pairs (if a token appears multiple times in the same doc, keep only one).
# Sorting first and keeping the first row makes the choice deterministic: the earliest
# occurrence (smallest token_position) of each token in each doc is kept.
in_vocab_index = (
    in_vocab_index
    .sort(["uid", "new_token_id", "token_position"])
    .unique(subset=["uid", "new_token_id"], keep="first", maintain_order=True)
)

In [None]:
# tokens with fewer than 100 rows in the index (rows per token = docs per token once deduplicated), rarest first
(
    in_vocab_index
    .group_by("new_token_id")
    .agg(pl.len().alias("num_docs_per_token"))
    .filter(pl.col("num_docs_per_token") < 100)
    .sort("num_docs_per_token")
)

In [None]:
# distribution (across tokens) of the number of distinct documents each sampled token came from
(
    s.group_by("new_token_id")
    .agg(pl.col("uid").n_unique())
    .get_column("uid")
    .value_counts()
)

In [None]:
# list of document uids per in-vocab token
in_vocab_index.group_by("new_token_id").agg(
    uid=pl.col("uid"),
)

In [None]:
# inspect the index of (tok_a, tok_b) occurrences for the out-of-vocab merges
out_vocab_index

In [None]:
# rows of the out-of-vocab index in a random (seeded) order.
# The cell previously held an unfinished `.shu` attribute; a shuffle is the assumed intent — TODO confirm.
out_vocab_index.sample(fraction=1.0, shuffle=True, seed=42)

In [None]:
def loc_of(value, column: str = "input_ids") -> pl.Expr:
    """Build an expression giving the index of the first occurrence of ``value`` in a list column.

    Args:
        value: element to look for; compared with ``==`` against each list element.
        column: name of the list column to search. Defaults to ``"input_ids"``.

    Returns:
        A ``pl.Expr`` that evaluates, per row, to the 0-based position of the first element
        equal to ``value``, or null when the list does not contain it.
    """
    # https://github.com/pola-rs/polars/issues/5503#issuecomment-1315401973
    ids = pl.col(column)
    return (
        # only search lists that contain the value; on an all-zero mask arg_max would report 0
        pl.when(ids.list.contains(value))
        .then(
            # create array of True/False, then cast to 1's and 0's;
            # arg_max() then finds the first occurrence of 1, i.e. the first occurrence of value
            ids.list.eval((pl.element() == value).cast(pl.UInt8).arg_max(), parallel=True).list.first()
        )
        .otherwise(None)  # return null if not found
    )

In [None]:
# token -> (document uid, position) index for in-vocab tokens.
# `token_doc_index` is not defined anywhere in this notebook; `in_vocab_index` is the assumed intent — TODO confirm.
in_vocab_index

In [None]:
# first `max_num_seq` documents containing token id 14, with `loc` = position of its first occurrence
max_num_seq = 500
seq_with_token = (
    data
    .with_columns(loc=loc_of(14))
    .filter(pl.col("loc").is_not_null())
    .head(max_num_seq)
)

In [None]:
# context of up to `min_context_len` tokens ending at (and including) the first occurrence of the token
min_context_len = 200
seq_with_token.with_columns(
    pl.col("input_ids").list.slice(
        # `loc` comes from arg_max (an unsigned index dtype), so cast before subtracting to avoid
        # wrap-around, and clip at 0: a negative offset would make list.slice count from the end.
        pl.max_horizontal(pl.col("loc").cast(pl.Int64) - min_context_len, 0),
        # length = tokens from the window start through `loc` inclusive
        pl.min_horizontal(pl.col("loc").cast(pl.Int64), min_context_len) + 1,
    )
)

In [None]:
# comma-joined token ids of the first out-of-vocab merge pair in the window (search pattern).
# `split_tokens` is not defined in this notebook; the out-of-vocab merges in `df` are the assumed source — TODO confirm.
q = (
    df.filter(pl.col("new_token_id") >= vocab_size)
    .with_columns(pl.col("pair").cast(pl.List(pl.String)).list.join(","))["pair"][0]
)

In [None]:
# documents in which the ids in `q` occur as consecutive tokens.
# Both the joined ids and the pattern are padded with commas so "12,34" cannot match
# inside "112,345"; the pattern is matched literally rather than as a regex.
(
    data
    .filter(
        pl.concat_str(
            pl.lit(","),
            pl.col("input_ids").cast(pl.List(pl.String)).list.join(","),
            pl.lit(","),
        ).str.contains(f",{q},", literal=True)
    )
)

In [None]:
# position of the first token id 14 in the second document (raises ValueError if absent)
data.row(1, named=True)["input_ids"].index(14)

In [None]:
# one row per token with `j` = 0-based position of the token within its document.
# Assumes `uid` identifies a document (it is the join key used above) — TODO confirm.
data.explode("input_ids").with_columns(j=pl.int_range(pl.len()).over("uid"))

In [None]:
# display the merge log read from implemented_merges.jsonl
merges

In [None]:
# largest token id of the tokenizer loaded with `vocab_size`.
# `tok` is only assigned in a later cell, so it is loaded here to keep this cell runnable on a fresh kernel.
max(load_tokenizer_with_vocab_size(tok_path, vocab_size).vocab.values())

In [None]:
# raw tokenizer.json of the training run (`path` was undefined; the run directory is `tok_path`)
conf = srsly.read_json(tok_path / "tokenizer.json")

In [None]:
# one row per BPE merge rule in tokenizer.json, in stored order.
# Depending on the `tokenizers` version a merge is serialized either as a "left right" string
# or as a [left, right] list; indexing a string directly would yield single characters.
pl.DataFrame(
    [
        {"part_a": part_a, "part_b": part_b}
        for part_a, part_b in (
            m.split(" ", 1) if isinstance(m, str) else m for m in conf["model"]["merges"]
        )
    ]
)

In [None]:
# tokenizer from the training run loaded via the project helper with size argument 1000
# (`path` was undefined; the run directory is `tok_path`)
tok = load_tokenizer_with_vocab_size(tok_path, 1000)

In [None]:
# largest token id in `tok`, the tokenizer loaded in the previous cell via load_tokenizer_with_vocab_size(..., 1000)
max(tok.vocab.values())

In [None]:
# exploration: list the attributes of the backend tokenizer's model object
dir(tok.backend_tokenizer.model)

In [None]:
# one row per entry of conf["model"]["vocab"] (iterating it yields the token strings if it is a mapping)
vocab = pl.DataFrame({"tokens": list(conf["model"]["vocab"])})

In [None]:
# vocab tokens that are not produced by any implemented merge.
# `impl_merges` is not defined in this notebook; `merges` (implemented_merges.jsonl) is the assumed
# intent, and it is assumed to have a `new_token` column — TODO confirm.
vocab.join(merges, left_on="tokens", right_on="new_token", how="anti")

In [None]:
# merge frequency against merge rank (row order of the merge log).
# `impl_merges` is not defined in this notebook; `merges` with a `count` column is assumed — TODO confirm.
(
    pn.ggplot(merges.with_row_index(), pn.aes(x="count", y="index"))
    + pn.geom_line()
    + pn.scale_x_log10()
    + pn.labs(x="merge count (log scale)", y="merge rank", title="Implemented merges: count vs. rank")
)