In [1]:
import html
from pathlib import Path

import plotnine as pn
import polars as pl
import statsmodels.formula.api as smf
from datasets import load_from_disk
from IPython.display import HTML, display
from transformers import AutoTokenizer, PreTrainedTokenizer


In [2]:
def decode_sequence(tok: PreTrainedTokenizer, input_ids: list[int], highlight_ids: list[int], highlight_color: str = "green") -> None:
    """Render a token sequence as HTML with selected token ids highlighted.

    Args:
        tok: Tokenizer used to map ids to tokens and tokens back to a string.
        input_ids: Token ids of the sequence to display.
        highlight_ids: Token ids to highlight; every position of ``input_ids``
            whose id is in this collection is wrapped in a coloured ``<span>``.
        highlight_color: CSS colour used as the span background.

    Returns:
        None. The checked ids and their tokens are printed, and the decoded
        sequence is displayed as HTML in the notebook.
    """
    print(f"Checking tokens ids: {highlight_ids}")
    print(f"Checking tokens: {tok.convert_ids_to_tokens(highlight_ids)}")

    # Convert token IDs to tokens
    tokens = tok.convert_ids_to_tokens(input_ids, skip_special_tokens=False)

    # A set gives constant-time membership checks for every position.
    highlight_set = set(highlight_ids)

    # Highlight tokens that are in highlight_ids. The token text is escaped
    # first so that "<", ">" and "&" coming from the data are shown literally
    # instead of being interpreted as markup by the HTML renderer.
    highlighted_tokens = []
    for token_id, token in zip(input_ids, tokens):
        rendered = html.escape(token, quote=False)
        if token_id in highlight_set:
            rendered = f"<span style='background-color:{highlight_color}'>{rendered}</span>"
        highlighted_tokens.append(rendered)

    # Convert tokens back to a single string
    decoded_string = tok.convert_tokens_to_string(highlighted_tokens)

    # Display the result in Jupyter notebook
    display(HTML(decoded_string))

In [3]:
# Tokenizer configuration; the three parts are concatenated to form the
# tokenizer directory name.
tok_type = "bpe"
vocab_size = 32_000
dataset = "minipile"
tok_path = Path(f"./outputs/tokenizers/{tok_type}{vocab_size}{dataset}")

tok = AutoTokenizer.from_pretrained(tok_path)

# Load the saved evaluation samples and view their Arrow table as a polars
# frame, with both key columns cast to Int64.
eval_samples = load_from_disk("data/minipile-eval-bpe32000minipile/eval_samples/")
eval_df = pl.from_arrow(eval_samples.data.table).with_columns(
    pl.col("uid", "new_token_id").cast(pl.Int64)
)

In [4]:
# Out-of-vocabulary samples, with both key columns cast to Int64.
out_vocab_path = Path("/home/pl487/rdd/data/minipile-eval-bpe32000minipile/out_vocab_samples.parquet")
out_df = pl.read_parquet(out_vocab_path).with_columns(
    pl.col("uid", "new_token_id").cast(pl.Int64)
)

In [5]:
# Evaluation output of a single experiment run.
base_path = Path("/home/pl487/rdd/outputs/model_eval/")
exp_name = "smol_llama-81M-tied_bpe32000minipile_2024-09-30T19-42-18_last_2024-10-08T17-49-06"
path = base_path.joinpath(exp_name, "eval_last.parquet")

df = pl.read_parquet(path)
# Disabled: attach a `doc_idx` row index.
# df = df.with_row_index(name="doc_idx").with_columns(pl.col("doc_idx").cast(pl.UInt64))

In [6]:
# avg_pos = (
#     df
#     # .drop("doc_idx")
#     .with_columns(tok_pos=pl.int_ranges(pl.col("token_logprob").list.len()))
#     .explode(["token_logprob", "tok_pos"])
#     .group_by("tok_pos")
#     .agg(
#         mean=pl.col("token_logprob").mean(),
#         std=pl.col("token_logprob").std(),
#     )
#     .sort(["mean", "std"], descending=False)
# )
# avg_pos

In [7]:
# Score every sample by the summed log-probability of its trailing token(s).
df = (
    df
    # Ids below `vocab_size` are flagged as in-vocabulary; the threshold comes
    # from the config constant rather than a repeated literal.
    .with_columns(in_vocab=pl.col("new_token_id") < vocab_size)
    .with_columns(
        # In-vocabulary rows keep the last log-prob, all other rows keep the
        # last two (presumably one per piece of a two-token split -- TODO confirm).
        token_logps=(
            pl.when(pl.col("in_vocab"))
            .then(pl.col("token_logprob").list.tail(1))
            .otherwise(pl.col("token_logprob").list.tail(2))
        )
    )
    .with_columns(logp=pl.col("token_logps").list.sum())
    # Ascending sort: rows with the lowest logp come first.
    .sort("logp")
)

In [8]:
# Distinct (uid, new_token_id, tok_a, tok_b) rows from the out-of-vocabulary
# samples, with the join keys cast to Int64.
split_pairs = (
    out_df
    .select(["uid", "new_token_id", "tok_a", "tok_b"])
    .unique()
    .with_columns(pl.col("uid").cast(pl.Int64), pl.col("new_token_id").cast(pl.Int64))
)

# Rows without a match (null `tok_a`) get a one-element list holding their own
# id; matched rows get the pair [tok_a, tok_b].
toks_expr = (
    pl.when(pl.col("tok_a").is_null())
    .then(pl.concat_list("new_token_id"))
    .otherwise(pl.concat_list(["tok_a", "tok_b"]))
)

df = (
    df.join(split_pairs, on=["uid", "new_token_id"], how="left")
    .with_columns(toks=toks_expr)
    .drop(["tok_a", "tok_b"])
    # `in_vocab` is redefined here as "represented by exactly one token id".
    .with_columns(in_vocab=pl.col("toks").list.len() == 1)
)

In [None]:
# Preview the first rows of `df`.
df.head()

In [None]:
# Line plot of `logp` against the row position in `df`.
(
    pn.ggplot(df.with_row_index(), pn.aes(x="index", y="logp"))
    + pn.geom_line()
)

In [None]:
# Pair each scored row with its evaluation sample and keep the distinct
# combinations of ids, log-probs and input ids.
# NOTE: `dd` is reassigned by later cells.
dd = (
    df.join(eval_df, on=["new_token_id", "uid"])
    .select("new_token_id", "uid", "token_logprob", "input_ids")
    .unique()
)

In [None]:
# Display the evaluation samples frame.
eval_df

In [None]:
# Inner-join the first 10 rows of `df` with the evaluation samples on
# (new_token_id, uid) and display the result.
dd = df.head(10).join(eval_df, on=["new_token_id", "uid"], how="inner")
dd

In [None]:
# Render one row of `dd`: its input sequence with that row's `toks` highlighted.
idx = 4
sample = dd[idx]
input_ids = sample["input_ids"].to_list()[0]
highlight_ids = sample["toks"].to_list()[0]
decode_sequence(tok, input_ids, highlight_ids)

In [None]:
# Last three ids of `input_ids`.
input_ids[-3:]

In [100]:
# Keep token ids strictly within `window` of the vocabulary-size threshold and
# average logp per (token id, in_vocab) group. The threshold uses the
# `vocab_size` config constant instead of a repeated literal.
window = 100
dd = (
    df.filter(
        # (pl.col("logp") > -3.5)
        # & (pl.col("logp") < -2.5)
        (pl.col("new_token_id") < vocab_size + window)
        & (pl.col("new_token_id") > vocab_size - window)
    )
    .group_by(["new_token_id", "in_vocab"])
    .agg(avg_logp=pl.col("logp").mean())
    # Converted to pandas for the statsmodels formula API used on `dd` afterwards.
    .to_pandas()
)

In [None]:
# Fit model
rdd = smf.ols("avg_logp ~ new_token_id + in_vocab", dd).fit(cov_type="HC3")
rdd.summary()

In [None]:
# Scatter of the per-token averages with the fitted values drawn as lines
# coloured by `in_vocab`; the colour legend is hidden.
p = pn.ggplot(dd, pn.aes(x="new_token_id", y="avg_logp"))
p = p + pn.geom_point(alpha=0.15)
p = p + pn.geom_line(pn.aes(y=rdd.fittedvalues, color="in_vocab"), size=2)
# Disabled layers:
# + pn.coord_cartesian(ylim=(-8, -5))
# + pn.geom_vline(xintercept=vocab_size / 1000, linetype="dashed", color="black")
# + pn.scale_x_reverse()
# + pn.labs(x="", y="", colour="", title=f"Discontinuity at threshold: {discontinuity_at_threshold:.2f}")
# pn.scale_y_log10()
p = p + pn.scale_colour_discrete(guide=None)
p = p + pn.theme_bw()

p

In [None]:
# Regression with an interaction term on rows whose `new_token_id` lies
# strictly within `win` of 32.
# NOTE(review): `avg_df` is not assigned in any cell visible here, and no
# visible cell creates the `zscore` or `is_out_vocab` columns -- confirm where
# they come from; unless defined elsewhere, this cell fails on a fresh kernel.
# NOTE(review): the 32 +/- 0.1 bounds suggest `new_token_id` is expressed in
# thousands in `avg_df` -- TODO confirm.
win = .1
regdata = avg_df.filter((pl.col("new_token_id") < 32 + win) & (pl.col("new_token_id") > 32 - win))

# Fit model
# This rebinds `rdd`, replacing the result of the earlier fit on `dd`.
rdd = smf.ols(
    "zscore ~ new_token_id + is_out_vocab + new_token_id:is_out_vocab", 
    regdata.to_pandas()
).fit(cov_type="HC3")

In [42]:
# In-vocabulary samples, with both key columns cast to Int64.
in_vocab_path = Path("/home/pl487/rdd/data/minipile-eval-bpe32000minipile/in_vocab_samples.parquet")
in_df = pl.read_parquet(in_vocab_path).with_columns(
    pl.col("uid", "new_token_id").cast(pl.Int64)
)


In [None]:
# Display the out-of-vocabulary samples frame.
out_df

In [19]:
# Attach the token ids that represent each sample's target token.
df = (
    df.join(
        (
            out_df
            .select(["uid", "new_token_id", "tok_a", "tok_b"])
            # De-duplicate the lookup rows: in a left join, repeated
            # (uid, new_token_id) pairs on the right side would multiply the
            # matching rows of `df`.
            .unique()
            .with_columns(pl.col("uid").cast(pl.Int64), pl.col("new_token_id").cast(pl.Int64))
        ),
        on=["uid", "new_token_id"],
        how="left",
    )
    # Rows without a match (null `tok_a`) get a one-element list holding their
    # own id; matched rows get the pair [tok_a, tok_b].
    .with_columns(
        toks=(
            pl.when(pl.col("tok_a").is_null())
            .then(pl.concat_list("new_token_id"))
            .otherwise(pl.concat_list(["tok_a", "tok_b"]))
        )
    )
    .drop(["tok_a", "tok_b"])
    # `in_vocab` means "represented by exactly one token id".
    .with_columns(in_vocab=pl.col("toks").list.len() == 1)
)

In [None]:
# Show the rows that are not in-vocabulary, with `logp` replaced (in this view
# only) by the trailing log-prob list: last one for in-vocabulary rows, last
# two otherwise.
tail_logps = (
    pl.when(pl.col("in_vocab"))
    .then(pl.col("token_logprob").list.tail(1))
    .otherwise(pl.col("token_logprob").list.tail(2))
)
df.with_columns(logp=tail_logps).filter(~pl.col("in_vocab"))