In [1]:
import polars as pl
from datatrove.utils.dataset import DatatroveFolderDataset
from tqdm.auto import tqdm

from tokenizers import Tokenizer

In [25]:
from fsspec.core import url_to_fs

In [26]:
fs, file_path = url_to_fs("hf://datasets/pietrolesci/fineweb-edu-10BT/bpe32000")

In [34]:
fs.glob("hf://datasets/pietrolesci/fineweb-edu-10BT/bpe32000/*.ds")

['datasets/pietrolesci/fineweb-edu-10BT/bpe32000/000_bpe32000.ds',
 'datasets/pietrolesci/fineweb-edu-10BT/bpe32000/001_bpe32000.ds']

In [41]:
# Get tokenized data
VOCAB_SIZE = 32_000  # size of the bpe32000 tokenizer the shards were produced with
ds = DatatroveFolderDataset(
    folder_path="hf://datasets/pietrolesci/fineweb-edu-10BT/bpe32000",
    filename_pattern="hf://datasets/pietrolesci/fineweb-edu-10BT/bpe32000/*.ds",
    seq_len=512,
    shuffle=False,
    seed=42,
    # Ids 0..VOCAB_SIZE-1 fit in an unsigned 16-bit integer (2 bytes) iff VOCAB_SIZE <= 2**16.
    token_size=2 if VOCAB_SIZE <= 2**16 else 4,
)

# Get original data: only the last few rows, used below to compare against the last tokenized sample
df = pl.scan_parquet("hf://datasets/HuggingFaceFW/fineweb-edu/sample/10BT/*.parquet").tail().collect()

# Load tokenizer
tok = Tokenizer.from_file("outputs/tokenizers/bpe32000/tokenizer.json")
# NOTE(review): nothing in this notebook reads `eos_token_id`; confirm this attribute is needed
# (and that `tokenizers.Tokenizer` accepts ad-hoc attributes).
tok.eos_token_id = 0

In [42]:
# Show the shard files the dataset picked up, one path per line.
for shard in ds.files:
    print(shard.file_path)

datasets/pietrolesci/fineweb-edu-10BT/bpe32000/000_bpe32000.ds
datasets/pietrolesci/fineweb-edu-10BT/bpe32000/001_bpe32000.ds


In [43]:
# Number of samples in the tokenized dataset (with seq_len=512, presumably ~total tokens / 512 — confirm).
len(ds)

19507780

In [None]:
# check batches with ids >= vocab_size
# Walk the dataset in order and stop at the first sample holding a token id the tokenizer cannot map.
vocab_size = tok.get_vocab_size()

wrong_batch = []
first_wrong_index = None  # dataset index of the first offending sample; stays None if all ids are in range
for i, b in enumerate(tqdm(ds)):
    if b["input_ids"].max() >= vocab_size:
        wrong_batch.append(b)
        first_wrong_index = i
        break

In [44]:
# Dataset index where suspicious samples begin (hard-coded; presumably found with the scan above).
START_INDEX = 19493178  #  from here the problem starts
# Largest token id in that sample.
ds[START_INDEX]["input_ids"].max()

tensor(31979)

In [7]:
# Pull the suspicious sample's token ids out as a numpy array.
batch = ds[START_INDEX]["input_ids"].numpy()

In [10]:
# Decode with special tokens kept, so any <|endoftext|> ids stay visible in the text.
tok.decode(batch.tolist(), skip_special_tokens=False)

' mod<|endoftext|><|endoftext|><|endoftext|> cases<|endoftext|><|endoftext|><|endoftext|> stat<|endoftext|><|endoftext|><|endoftext|>iology<|endoftext|><|endoftext|><|endoftext|> neur<|endoftext|><|endoftext|><|endoftext|> Finally<|endoftext|><|endoftext|><|endoftext|>John<|endoftext|><|endoftext|><|endoftext|> stood<|endoftext|><|endoftext|><|endoftext|> contributing<|endoftext|><|endoftext|><|endoftext|> diets<|endoftext|><|endoftext|><|endoftext|> Guard<|endoftext|><|endoftext|><|endoftext|> rend<|endoftext|><|endoftext|><|endoftext|> parliament<|endoftext|><|endoftext|><|endoftext|> Educational<|endoftext|><|endoftext|><|endoftext|> billions<|endoftext|><|endoftext|><|endoftext|> priorit<|endoftext|><|endoftext|><|endoftext|>tem<|endoftext|><|endoftext|><|endoftext|> Wang<|endoftext|><|endoftext|><|endoftext|> pear<|endoftext|><|endoftext|><|endoftext|> twin<|endoftext|><|endoftext|><|endoftext|> apnea<|endoftext|><|endoftext|><|endoftext|> risen<|endoftext|><|endoftext|><|endoftex

In [None]:
print(tok.decode(ds[len(ds) - 1]["input_ids"].tolist()))

In [None]:
print(df[-1]["text"].item())

In [None]:
t = ds[len(ds) - 1]["input_ids"].tolist()
s = [tok.decode([i], skip_special_tokens=False) for i in t]

In [None]:
t[0]

In [None]:
tok.encode(tok.decode([t[1]])).ids

In [None]:
t[1], s[1]

In [None]:
print(df[0]["text"].item())

In [None]:
s = wrong_batch[0]["input_ids"]

In [None]:
import polars as pl

# One row per token of `s`: its vocab string next to its id. `id_to_token` returns None for ids
# outside the vocab, so out-of-range ids show up as nulls in the "token" column.
wrong_ids = s.tolist()
pl.DataFrame({"token": [tok.id_to_token(i) for i in wrong_ids], "token_id": wrong_ids})

In [None]:
# Id assigned to the empty string (None if the vocab has no such entry).
# NOTE(review): unclear which token this probe was meant to check — confirm the intended string.
tok.token_to_id("")

In [None]:
# Shape of `s` with a trailing singleton axis added.
s.unsqueeze(-1).shape

In [None]:
batch_size = 64
batch_idx = 58
# `ds` is indexed by sample (one seq_len-token window each), not by token. Assuming a sequential,
# unshuffled loader (shuffle=False above), batch `batch_idx` starts at sample batch_size * batch_idx.
# Multiplying by seq_len (512) as well would give a token offset, not a sample index.
s = ds[batch_size * batch_idx]["input_ids"]
s.max()

In [None]:
import srsly

conf: dict = srsly.read_json("outputs/tokenizers/bpe32000/tokenizer.json")  # type: ignore

In [None]:
toks_in_merge = [tok.token_to_id(i) for j in conf["model"]["merges"] for i in j]

In [None]:
max(toks_in_merge)

In [None]:
from pathlib import Path

import plotnine as pn
import polars as pl
import srsly
from transformers import AutoTokenizer

from src.utilities import load_tokenizer_with_vocab_size

In [None]:
import os

# Tokenizer-training run to inspect. Set TOKENIZER_RUN_DIR to point elsewhere instead of editing this
# machine-specific absolute path.
path = Path(os.environ.get("TOKENIZER_RUN_DIR", "/home/pl487/rdd/outputs/tokenizer_train/2024-08-30T12-00-43/"))

# Merges recorded during training; new_token_id cast to Int32 so it compares cleanly with vocab sizes.
merges = pl.DataFrame(srsly.read_jsonl(path / "implemented_merges.jsonl")).with_columns(
    pl.col("new_token_id").cast(pl.Int32)
)

tok = AutoTokenizer.from_pretrained(path / "tok-vocab32000")

In [None]:
window = 1000
df = merges.filter(
    (pl.col("new_token_id") >= tok.vocab_size - window) & (pl.col("new_token_id") < tok.vocab_size + window)
)

In [None]:
full_tokens = df.filter(pl.col("new_token_id") < tok.vocab_size)
split_tokens = df.filter(pl.col("new_token_id") >= tok.vocab_size)

In [None]:
full_tokens

In [None]:
# Tokenized documents (one per row) from the slim-pajama-subset-validation repo. token_index is a
# per-row integer range over the length of input_ids, so token positions survive a later explode.
token_positions = pl.int_ranges(pl.col("input_ids").list.len())
data = (
    pl.scan_parquet("hf://datasets/pietrolesci/slim-pajama-subset-validation/tok-vocab32000/train-*.parquet")
    .with_columns(token_index=token_positions)
    .collect()
)

In [None]:
# get document uid and position of the token in doc for tokens in vocab
# Exploding input_ids and token_index together gives one row per token occurrence. The right join then
# keeps only ids listed in `full_tokens`; ids that never occur in `data` still appear, with null doc fields.
token_doc_index = data.explode(["input_ids", "token_index"]).join(
    full_tokens.select(["new_token_id"]), left_on="input_ids", right_on="new_token_id", how="right"
)

In [None]:
def loc_of(value, col: str = "input_ids") -> pl.Expr:
    """Build an expression for the index of the first occurrence of ``value`` in a list column.

    Args:
        value: Element to search for; compared with ``==`` against each list element.
        col: Name of the list column to search. Defaults to ``"input_ids"``.

    Returns:
        A ``pl.Expr`` giving the 0-based position of the first element equal to ``value``,
        or null when the list does not contain ``value``.
    """
    # https://github.com/pola-rs/polars/issues/5503#issuecomment-1315401973
    lists = pl.col(col)
    # Turn each list into a 0/1 mask; arg_max() then finds the first 1, i.e. the first occurrence of value.
    # An all-zero mask would also give 0, so the `contains` guard below returns null instead.
    first_match = lists.list.eval((pl.element() == value).cast(pl.UInt8).arg_max(), parallel=True).list.first()
    return pl.when(lists.list.contains(value)).then(first_match).otherwise(None)

In [None]:
min_context_len = 200
max_num_seq = 1_000
seq_with_token = (
    data.with_columns(loc=loc_of(14))
    .drop_nulls("loc")
    .filter(pl.col("loc") >= min_context_len)
    .sort("loc", descending=True)
    .head(max_num_seq)
    # .collect()
)

In [None]:
(seq_with_token.with_columns(pl.col("input_ids").list.slice(pl.col("loc") - min_context_len, min_context_len + 1)))

In [None]:
q = split_tokens.with_columns(pl.col("pair").cast(pl.List(pl.String)).list.join(","))["pair"][0]

In [None]:
(
    data
    # .head()
    .filter(pl.col("input_ids").cast(pl.List(pl.String)).list.join(",").str.contains(q))
)

In [None]:
# Position of the first token id 14 in the second row of `data` (list.index raises ValueError if absent).
data[1]["input_ids"].to_list()[0].index(14)

In [None]:
# Running count per group after exploding tokens.
# NOTE(review): no column named "i" is created in this notebook — confirm `data` has it, otherwise
# `.over("i")` fails. `Expr.cumcount` is also renamed `cum_count` in newer polars; check the installed version.
(data.explode("input_ids").with_columns(j=pl.first().cumcount().over("i")))

In [None]:
# Display the implemented merges loaded from the training run.
merges

In [None]:
# Largest token id in the loaded tokenizer's vocab (tok.vocab maps token -> id).
max(tok.vocab.values())

In [None]:
conf = srsly.read_json(path / "tokenizer.json")

In [None]:
pl.DataFrame([{"part_a": m[0], "part_b": m[1]} for m in conf["model"]["merges"]])

In [None]:
# Rebind `tok` via the project helper, presumably a tokenizer from this run limited to 1000 ids — confirm.
tok = load_tokenizer_with_vocab_size(path, 1000)

In [None]:
# Largest id in the reloaded vocab (presumably expected to stay below 1000 — confirm).
max(tok.vocab.values())

In [None]:
# Attributes exposed by the underlying `tokenizers` model object.
dir(tok.backend_tokenizer.model)

In [None]:
vocab = pl.DataFrame({"tokens": [i for i in conf["model"]["vocab"]]})

In [None]:
vocab.join(impl_merges, left_on="tokens", right_on="new_token", how="anti")

In [None]:
(pn.ggplot(impl_merges.with_row_index(), pn.aes(y="index", x="count")) + pn.geom_line() + pn.scale_x_log10())