In [1]:
from pathlib import Path

import causalpy as cp
import plotnine as pn
import polars as pl
import srsly
import statsmodels.formula.api as smf
import torch
from datasets import load_from_disk
from sklearn.linear_model import LinearRegression
from transformers import AutoModelForCausalLM, AutoTokenizer, PreTrainedModel


In [2]:
# Directory holding the outputs of one model-evaluation run.
# NOTE(review): absolute, user-specific path -- this cell only runs on a machine
# where this directory exists.
run_path = Path("/home/pl487/rdd/outputs/model_eval/smol_llama-81M-tied_bpe32000minipile_2024-09-30T19-42-18_last_2024-10-08T17-49-06")
# Read the run's `eval_last.parquet` into a polars DataFrame.
df = pl.read_parquet(run_path / "eval_last.parquet")

In [3]:
# Load the dataset saved at this (relative) path and hand its underlying Arrow
# table to polars.
eval_df = pl.from_arrow(load_from_disk("data/minipile-eval-bpe32000minipile/eval_samples/").data.table)

In [5]:
def load_hf_from_pl(checkpoint_path: str | Path) -> PreTrainedModel:
    """Rebuild a Hugging Face causal LM from a training checkpoint file.

    The checkpoint must be a mapping with a ``"state_dict"`` entry and a
    ``"hyper_parameters"`` entry that stores the model config under ``"config"``
    (presumably a PyTorch Lightning checkpoint -- confirm against the trainer).
    Only weights whose key starts with ``"model."`` are used; that prefix, and an
    ``"_orig_mod."`` prefix directly after it, are stripped so the keys match the
    freshly built model.

    Args:
        checkpoint_path: Path to the checkpoint file.

    Returns:
        A model created with ``AutoModelForCausalLM.from_config`` and populated
        with the checkpoint weights.

    Raises:
        KeyError: If ``"state_dict"`` or ``"hyper_parameters"`` is missing.
        ValueError: If the hyper-parameters do not contain a ``"config"`` entry.
    """
    # NOTE(security): weights_only=False unpickles arbitrary Python objects, so
    # only load checkpoints from a trusted source.
    # map_location="cpu" lets checkpoints written on a GPU machine load anywhere;
    # the weights are copied into the (CPU) model built below either way.
    checkpoint = torch.load(str(checkpoint_path), map_location="cpu", weights_only=False)

    config = checkpoint["hyper_parameters"].get("config")
    if config is None:
        # Fail here with a clear message instead of deep inside `from_config`.
        raise ValueError(
            f"Checkpoint {checkpoint_path} has no 'config' entry in its hyper-parameters; "
            "cannot rebuild the model architecture."
        )

    prefix = "model."
    state_dict = {
        key.removeprefix(prefix).removeprefix("_orig_mod."): weight
        for key, weight in checkpoint["state_dict"].items()
        if key.startswith(prefix)
    }

    model = AutoModelForCausalLM.from_config(config)
    model.load_state_dict(state_dict)
    return model


In [7]:
# Rebuild the model from the training run's `last.ckpt` via `load_hf_from_pl`.
model = load_hf_from_pl("outputs/model_train_pl/smol_llama-81M-tied_bpe32000minipile_2024-09-30T19-42-18/.checkpoints/last.ckpt")

In [10]:
# Read the training run's hyper-parameters and load the tokenizer they point to.
hparams = srsly.read_yaml("outputs/model_train_pl/smol_llama-81M-tied_bpe32000minipile_2024-09-30T19-42-18/hparams.yaml")
# Bound to `tokenizer` rather than `tok`: a later cell assigns `tok` to the polars
# DataFrame of merges that the join further down relies on, so sharing the name
# made the notebook's result depend on cell execution order.
tokenizer = AutoTokenizer.from_pretrained(hparams["tok_path"])

In [12]:
# Token ids of the first row of `eval_df` as an int64 tensor.
# NOTE(review): the following cell slices this as `t[:, :-1]`, so it is
# presumably 2-D, i.e. (1, seq_len) -- confirm.
t = torch.tensor(eval_df[0]["input_ids"], dtype=torch.long)

In [None]:
# Score every position of the sample: feed all tokens but the last one.
# The model is a freshly constructed nn.Module (which starts in training mode),
# so switch to eval mode to disable any train-only behaviour such as dropout.
model.eval()
# No gradients are needed for evaluation; calling the module (rather than
# `.forward`) also keeps any registered hooks working.
with torch.no_grad():
    logits = model(input_ids=t[:, :-1]).logits
logits.shape

In [22]:
# Normalise the logits over the last dimension, once as probabilities and once
# as log-probabilities.
probs = logits.softmax(-1)
logprobs = logits.log_softmax(-1)

In [23]:
# Targets are the input ids shifted by one: the logits came from `t[:, :-1]`,
# so position i is scored against token i + 1.
labels = t[..., 1:]
# At every position, pick the probability / log-probability of the target token.
token_prob = torch.take_along_dim(probs, labels.unsqueeze(-1), dim=-1).squeeze(-1)
token_logprobs = torch.take_along_dim(logprobs, labels.unsqueeze(-1), dim=-1).squeeze(-1)

In [None]:
# Sanity check: the gathered log-probabilities should equal the log of the
# gathered probabilities, within an absolute tolerance of 1e-6.
torch.allclose(token_logprobs, token_prob.log(), atol=0.000001)

In [37]:
# Per-token cross-entropy (no reduction). `cross_entropy` expects the class
# dimension second, so the last two axes of the 3-D logits are swapped first.
ce = torch.nn.functional.cross_entropy(logits.transpose(1, 2), labels, reduction="none")

In [None]:
# Display the per-token log-probabilities.
token_logprobs

In [None]:
# Display the per-token cross-entropy values.
ce

In [None]:
# Sanity check: the negated cross-entropy should match the gathered
# log-probabilities, within an absolute tolerance of 1e-3.
torch.allclose(ce.neg(), token_logprobs, atol=0.001)

In [4]:
# def flatten(x: list[list]) -> list:
#     return [i for j in x for i in j]


# def ld_to_dl(ld: list[dict]) -> dict[str, list]:
#     return {k: [dic[k] for dic in ld] for k in ld[0]}


# fl = srsly.read_jsonl(
#     "/home/pl487/rdd/outputs/model_eval/pythia-9M-bpe32000_checkpoint-50000_2024-09-17T17-15-44/pythia-9M-bpe32000_checkpoint-50000.jsonl"
# )
# df = pl.DataFrame({k: flatten(v) for k, v in ld_to_dl(line).items()} for line in fl)  # type: ignore
# df = df.explode(df.columns)

In [6]:
# Directory of the tokenizer-training run (absolute, user-specific path); its
# `implemented_merges.jsonl` file is loaded into a polars DataFrame.
raw_tok_path = Path("/home/pl487/rdd/outputs/tok_train/bpe_minipile_2024-09-22T17-58-54")
tok = pl.DataFrame(srsly.read_jsonl(raw_tok_path / "implemented_merges.jsonl"))

In [7]:
# Tokenizer-type label (not referenced by the other cells visible here).
tok_type = "bpe"
# Vocabulary size: rows whose `new_token_id` is >= this value are flagged
# `is_out_vocab`, and the later plots/regressions use vocab_size / 1000 as the threshold.
vocab_size = 32_000

In [None]:
EPS = 1e-8


def _merge_log_prob(prob_col: str) -> pl.Expr:
    """Build the log-probability expression for the list column ``prob_col``.

    Rows flagged ``is_out_vocab`` keep every probability in the list (per the
    original notes: the tokens that create the merge); all other rows keep only
    the last entry (the token itself). ``EPS`` is added to each kept probability
    before taking the log, and the logs are summed into a single value per row.
    """
    probs = pl.col(prob_col)
    kept = pl.when(pl.col("is_out_vocab")).then(probs).otherwise(probs.list.slice(-1, 1))
    return kept.list.eval((pl.element() + EPS).log()).list.sum()


# Flag out-of-vocab rows first, then derive both log-prob columns from that flag.
df = df.with_columns(is_out_vocab=pl.col("new_token_id") >= vocab_size).with_columns(
    log_prob_true=_merge_log_prob("prob_true"),
    log_prob_true_and_prefix=_merge_log_prob("prob_true_and_prefix"),
)

In [None]:
# Histogram of `log_prob_true`, with bars filled according to `is_out_vocab`.
(pn.ggplot(df, pn.aes("log_prob_true", fill="is_out_vocab")) + pn.geom_histogram(colour="black"))

In [33]:
# Average both log-prob columns over all rows that share a token id (and flag),
# giving one row per token.
avg_df = df.group_by("new_token_id", "is_out_vocab").agg(
    pl.col("log_prob_true", "log_prob_true_and_prefix").mean()
)

In [34]:
# Attach each token's `count` from the merges table `tok`; the inner join keeps
# only token ids present in both frames.
avg_df = avg_df.join(tok.select("new_token_id", "count"), on="new_token_id", how="inner")

In [35]:
# for idx, col in enumerate(df.columns):
#     if df.dtypes[idx] not in (pl.Boolean, pl.List):
#         print(col, df[col].is_infinite().sum())

In [36]:
# Prepare for regression making the scale more compatible: express token ids in
# thousands.
# The rescale overwrites `new_token_id` in place, so re-running this cell used to
# divide by 1000 again. Guard on the dtype: the division turns the column into
# floats, so an integer column means the rescale has not happened yet.
# NOTE(review): assumes the ids arrive as an integer column -- confirm.
if avg_df.schema["new_token_id"].is_integer():
    avg_df = avg_df.with_columns(pl.col("new_token_id") / 1_000)

In [None]:
# Peek at the first rows of the per-token averages.
avg_df.head()

In [40]:
# Standardise `log_prob_true` (subtract the column mean, divide by the column
# standard deviation) into a new `zscore` column.
avg_df = avg_df.with_columns(
    ((pl.col("log_prob_true") - pl.col("log_prob_true").mean()) / pl.col("log_prob_true").std()).alias("zscore")
)

In [None]:
# Show the rows whose (rescaled) token id lies strictly between 31.5 and 32.5.
avg_df.filter(pl.col("new_token_id").is_between(31.5, 32.5, closed="none"))

In [None]:
# Half-width of the window around the threshold (in thousands of token ids).
win = .1
# Derive the threshold from `vocab_size` (token ids were rescaled to thousands)
# instead of hard-coding 32, so it stays consistent with the other cells that
# use `vocab_size / 1000`.
threshold = vocab_size / 1_000
# Keep only tokens strictly inside the window on either side of the threshold.
regdata = avg_df.filter(
    (pl.col("new_token_id") < threshold + win) & (pl.col("new_token_id") > threshold - win)
)

# Fit model: intercept and slope are allowed to differ on each side of the
# threshold (main effect + interaction with `is_out_vocab`); HC3 gives
# heteroskedasticity-robust standard errors.
rdd = smf.ols(
    "zscore ~ new_token_id + is_out_vocab + new_token_id:is_out_vocab",
    regdata.to_pandas(),
).fit(cov_type="HC3")

rdd.summary()

In [None]:
# Scatter of the windowed data plus the fitted regression values, drawn as one
# line per `is_out_vocab` group; the colour legend is hidden.
p = pn.ggplot(regdata, pn.aes(x="new_token_id", y="zscore"))
p = p + pn.geom_point(alpha=0.15)
p = p + pn.geom_line(pn.aes(y=rdd.fittedvalues, color="is_out_vocab"), size=2)
p = p + pn.scale_colour_discrete(guide=None)
p = p + pn.theme_bw()

p

In [None]:
# Mean of the per-token `log_prob_true` averages within each `is_out_vocab` group.
avg_df.group_by("is_out_vocab").agg(pl.col("log_prob_true").mean())

In [None]:
# Fit model on all tokens: intercept and slope may differ on each side of the
# threshold; HC3 gives heteroskedasticity-robust standard errors.
rdd = smf.ols(
    "log_prob_true_and_prefix ~ new_token_id + is_out_vocab + new_token_id:is_out_vocab",
    avg_df.to_pandas(),
).fit(cov_type="HC3")

# Compute discontinuity at threshold: evaluate BOTH fitted lines at the same
# value of the running variable. Comparing 31.999 (in-vocab) against 32.0
# (out-of-vocab) would also add slope * 0.001 to the reported jump.
# The threshold comes from `vocab_size` (token ids were rescaled to thousands).
threshold = vocab_size / 1_000
at_threshold = rdd.predict({"new_token_id": [threshold, threshold], "is_out_vocab": [False, True]})
discontinuity_at_threshold = float(at_threshold.iloc[1] - at_threshold.iloc[0])
print(discontinuity_at_threshold)

rdd.summary()

In [None]:
# Scatter of all tokens plus the fitted regression values (one line per
# `is_out_vocab` group), a dashed vertical line at the threshold, and the
# estimated discontinuity in the title; the colour legend is hidden.
p = pn.ggplot(avg_df, pn.aes(x="new_token_id", y="log_prob_true_and_prefix"))
p = p + pn.geom_point(alpha=0.15)
p = p + pn.geom_line(pn.aes(y=rdd.fittedvalues, color="is_out_vocab"), size=2)
p = p + pn.geom_vline(xintercept=vocab_size / 1000, linetype="dashed", color="black")
p = p + pn.labs(x="", y="", colour="", title=f"Discontinuity at threshold: {discontinuity_at_threshold:.2f}")
p = p + pn.scale_colour_discrete(guide=None)
p = p + pn.theme_bw()

p

----

In [102]:
# Fit CausalPy's regression-discontinuity experiment on the per-token averages,
# with a scikit-learn linear regression as the model. `is_out_vocab` is renamed
# to `treated` to match the term used in the formula, and the threshold is the
# vocabulary size on the rescaled (thousands) token-id axis.
result = cp.RegressionDiscontinuity(
    data=avg_df.to_pandas().rename(columns={"is_out_vocab": "treated"}),
    formula="log_prob_true ~ 1 + new_token_id + treated",
    model=LinearRegression(),
    treatment_threshold=vocab_size / 1_000,
    running_variable_name="new_token_id",
)

# fig, ax = result.plot()

In [None]:
# Draw the experiment's built-in plot.
result.plot()

In [None]:
# Show the experiment's summary.
result.summary()

In [None]:
# Print the fitted model's coefficients.
result.print_coefficients()