In [1]:
from pathlib import Path

import plotnine as pn
import polars as pl
import srsly
from transformers import AutoTokenizer

from src.utilities import load_tokenizer_with_vocab_size


In [9]:
path = Path("/home/pl487/rdd/outputs/tokenizer_train/2024-08-30T12-00-43/")

merges = (
    pl.DataFrame(srsly.read_jsonl(path / "implemented_merges.jsonl"))
    .with_columns(pl.col("new_token_id").cast(pl.Int32))
)


tok = AutoTokenizer.from_pretrained(path / "tok-vocab32000")

In [10]:
window = 1000
df = merges.filter((pl.col("new_token_id") >= tok.vocab_size - window) & (pl.col("new_token_id") < tok.vocab_size + window))

In [11]:
full_tokens = df.filter(pl.col("new_token_id") < tok.vocab_size)
split_tokens = df.filter(pl.col("new_token_id") >= tok.vocab_size)

In [12]:
full_tokens

pair,part_a,part_b,count,new_token_id,new_token
list[i64],str,str,i64,i32,str
"[710, 73]","""ath""","""i""",25738,31000,"""athi"""
"[390, 41]","""ĠN""","""I""",25737,31001,"""ĠNI"""
"[1398, 266]","""hor""","""se""",25735,31002,"""horse"""
"[1037, 69]","""ĠHow""","""e""",25733,31003,"""ĠHowe"""
"[769, 464]","""Ġfl""","""ap""",25729,31004,"""Ġflap"""
…,…,…,…,…,…
"[248, 79]","""as""","""o""",24406,31995,"""aso"""
"[686, 1577]","""Ġpr""","""une""",24406,31996,"""Ġprune"""
"[298, 2867]","""el""","""ius""",24405,31997,"""elius"""
"[320, 717]","""ĠP""","""oss""",24403,31998,"""ĠPoss"""


In [156]:
data = (
    pl.scan_parquet('hf://datasets/pietrolesci/slim-pajama-subset-validation/tok-vocab32000/train-*.parquet')
    .with_columns(token_index=pl.int_ranges(pl.col("input_ids").list.len()))
    .collect()
)

In [157]:
# get document uid and position of the token in doc for tokens in vocab
token_doc_index = (
    data
    .explode(["input_ids", "token_index"])
    .join(full_tokens.select(["new_token_id"]), left_on="input_ids", right_on="new_token_id", how="right")
)

In [63]:
def loc_of(value) -> pl.Expr:
    # https://github.com/pola-rs/polars/issues/5503#issuecomment-1315401973
    # only execute if the item is contained in the list
    return (
        pl.when(pl.col("input_ids").list.contains(value))
        .then(
            # create array of True/False, then cast to 1's and 0's
            # arg_max() then finds the first occurrence of 1, i.e. the first occurence of value
            pl.col("input_ids").list.eval((pl.element() == value).cast(pl.UInt8).arg_max(), parallel=True).list.first()
        )
        .otherwise(None) # return null if not found
    )

In [68]:
min_context_len = 200
max_num_seq = 1_000
seq_with_token = (
    data
    .with_columns(loc=loc_of(14))
    .drop_nulls("loc")
    .filter(pl.col("loc") >= min_context_len)
    .sort("loc", descending=True)
    .head(max_num_seq)
    # .collect()
)

In [92]:
(
    seq_with_token
    .with_columns(pl.col("input_ids").list.slice(pl.col("loc") - min_context_len, min_context_len + 1))
)

uid,input_ids,loc
u32,list[i32],u32
447805,"[60, 60, … 14]",65554
72472,"[24625, 9297, … 14]",9663
62777,"[14909, 212, … 14]",5276
102865,"[2692, 4951, … 14]",4863
275010,"[26, 409, … 14]",4138
…,…,…
304352,"[17816, 2301, … 14]",376
312298,"[1944, 174, … 14]",376
349012,"[20, 15, … 14]",376
3096,"[1347, 6700, … 14]",375


In [113]:
q = split_tokens.with_columns(pl.col("pair").cast(pl.List(pl.String)).list.join(","))["pair"][0]

In [116]:
(
    data
    # .head()
    .filter(pl.col("input_ids").cast(pl.List(pl.String)).list.join(",").str.contains(q))
)

uid,input_ids
u32,list[i32]
5509,"[8, 30197, … 14]"
6084,"[2087, 9755, … 1074]"
7077,"[12588, 228, … 227]"
12641,"[15167, 8273, … 14]"
12979,"[18094, 256, … 4435]"
…,…
484729,"[4829, 785, … 1944]"
490542,"[2748, 2324, … 6588]"
493797,"[43, 13, … 16113]"
496269,"[2253, 238, … 14]"


In [45]:
data[1]["input_ids"].to_list()[0].index(14)

71

In [31]:
(
    data
    .explode("input_ids")
    .with_columns(j=pl.first().cumcount().over('i'))
)

uid,input_ids
u32,i32
1,18090
1,18834
1,13
1,35
1,1986
…,…
1080,268
1080,598
1080,787
1080,2434


In [None]:
merges

In [None]:
max(tok.vocab.values())

In [None]:
conf = srsly.read_json(path / "tokenizer.json")

In [None]:
pl.DataFrame([{"part_a": m[0], "part_b": m[1]} for m in conf["model"]["merges"]])

In [None]:
tok = load_tokenizer_with_vocab_size(path, 1000)

In [None]:
max(tok.vocab.values())

In [None]:
dir(tok.backend_tokenizer.model)

In [None]:
vocab = pl.DataFrame({"tokens": [i for i in conf["model"]["vocab"]]})

In [None]:
vocab.join(impl_merges, left_on="tokens", right_on="new_token", how="anti")

In [None]:
(pn.ggplot(impl_merges.with_row_index(), pn.aes(y="index", x="count")) + pn.geom_line() + pn.scale_x_log10())