In [None]:
import argparse
import json
import pickle as pkl
from typing import Dict, Iterable, List, Optional, Sequence, Tuple

import matplotlib.pyplot as plt
import numpy as np
import torch
import torch.nn.functional as F
from torch.utils.data import DataLoader
from tqdm import tqdm
from transformers import AutoModelForCausalLM, AutoTokenizer

In [4]:
def load_safeset(path: str) -> set:
    """Unpickle and return the object stored at ``path``.

    The result is expected to be the set of allowed characters used by the
    filtering helpers; no type check is performed here.

    NOTE: unpickling can execute arbitrary code, so ``path`` must point to a
    trusted file.
    """
    with open(path, mode="rb") as handle:
        safeset = pkl.load(handle)
    return safeset


def filterize(texts: Iterable[str], safeset: set) -> List[str]:
    """Keep only the non-empty texts whose characters all belong to ``safeset``.

    Input order is preserved; falsy entries (such as empty strings) are dropped.
    """
    kept: List[str] = []
    for text in texts:
        if not text:
            continue
        if set(text).issubset(safeset):
            kept.append(text)
    return kept


def load_texts(files: Sequence[str]) -> List[str]:
    """Read the ``"text"`` field of every record in the given JSONL files.

    Args:
        files: Paths of JSON-lines files; each non-blank line must be a JSON
            object with a ``"text"`` key.

    Returns:
        The texts in file order, then line order.

    Raises:
        json.JSONDecodeError: If a non-blank line is not valid JSON.
        KeyError: If a record has no ``"text"`` field.
    """
    texts: List[str] = []
    for path in files:
        # Decode explicitly as UTF-8 so the result does not depend on the
        # platform's default locale encoding.
        with open(path, "r", encoding="utf-8") as f:
            for line in f:
                # Tolerate blank lines (e.g. a trailing newline at end of file).
                if not line.strip():
                    continue
                texts.append(json.loads(line)["text"])
    return texts


def chunk_texts(texts: Sequence[str], splitlen: int) -> List[str]:
    """Split each text into consecutive, non-overlapping character windows.

    Args:
        texts: Texts to split; empty entries are skipped.
        splitlen: Maximum window length in characters. The last window of a
            text may be shorter.

    Returns:
        All windows, in the order of ``texts`` and left to right within a text.

    Raises:
        ValueError: If ``splitlen`` is not positive.
    """
    if splitlen <= 0:
        raise ValueError(f"splitlen must be a positive integer, got {splitlen}")

    windows: List[str] = []
    for t in texts:
        if not t:
            continue
        # Every start index is < len(t), so each slice is non-empty by construction.
        windows.extend(t[i : i + splitlen] for i in range(0, len(t), splitlen))
    return windows


def prepare_samples(
    files: Sequence[str],
    safeset: set,
    splitlen: int,
    nsamples: int,
    buffer: int,
    seed: int,
) -> List[str]:
    """Build the text samples: load, filter, chunk, then keep the longest windows.

    The texts from ``files`` are filtered against ``safeset`` and cut into
    windows of ``splitlen`` characters. The ``nsamples`` longest windows (or
    all of them, if fewer exist) are kept, and each one is trimmed from the
    left by a random offset drawn from ``range(buffer)`` using a generator
    seeded with ``seed``. When ``buffer`` is not positive, no trimming occurs.

    Returns:
        The trimmed windows, ordered from longest to shortest original window.
    """
    windows = chunk_texts(filterize(load_texts(files), safeset), splitlen)

    # Window indices ordered from longest to shortest.
    window_lengths = np.array([len(w) for w in windows])
    order = np.flip(np.argsort(window_lengths))

    sample_count = min(nsamples, len(order))
    chosen = order[:sample_count]

    rng = np.random.default_rng(seed)
    if buffer > 0:
        starts = rng.choice(buffer, size=sample_count)
    else:
        starts = np.zeros(sample_count, dtype=int)

    return [windows[idx][start:] for idx, start in zip(chosen, starts)]


def build_safe_token_ids(tokenizer, safeset: set) -> np.ndarray:
    """Return the ids of all tokens whose decoded text uses only safeset characters.

    Every id in ``range(len(tokenizer))`` is decoded on its own (without
    tokenization-space clean-up); ids that decode to an empty string are
    excluded. The result is an ``int64`` array in ascending id order.
    """

    def _is_safe(token_id: int) -> bool:
        decoded = tokenizer.decode(token_id, clean_up_tokenization_spaces=False)
        return bool(decoded) and set(decoded).issubset(safeset)

    vocab_size = len(tokenizer)
    return np.array([tid for tid in range(vocab_size) if _is_safe(tid)], dtype=np.int64)


def rows_all_safe(tok_batch: np.ndarray, attn_batch: np.ndarray, safe_id_set: set) -> np.ndarray:
    """Flag, per row, whether every non-padding token id is in ``safe_id_set``.

    Positions where ``attn_batch`` is not 1 are treated as padding and ignored,
    so a row consisting only of padding is reported as safe.
    """
    flags = []
    for token_row, mask_row in zip(tok_batch, attn_batch == 1):
        unpadded = token_row[mask_row]
        flags.append(all(tid in safe_id_set for tid in unpadded))
    return np.array(flags)


def compute_entropies(
    dataloader: DataLoader,
    tokenizer,
    model,
    device: torch.device,
    shapecut: int,
    context_lengths: Sequence[int],
    near_zero_thresh: float,
    safe_id_set: set,
    top_k: int,
) -> Tuple[Dict[int, np.ndarray], List[Tuple[int, int, str]]]:
    """Collect the model's output entropy at selected token positions.

    Each batch from ``dataloader`` is tokenized (padded/truncated to
    ``shapecut`` tokens) and run through ``model``. Rows containing any
    non-padding token id outside ``safe_id_set`` are dropped. For every ``k``
    in ``context_lengths``, the entropy (in bits) of the softmax distribution
    at position index ``k - 1`` is recorded for the rows whose attention mask
    is 1 at that position.

    Args:
        dataloader: Yields batches of documents that ``tokenizer`` accepts.
        tokenizer: Callable tokenizer; also used to decode token ids.
        model: Model whose output exposes ``.logits``; put in eval mode here.
        device: Device the tokenized inputs are moved to (presumably the one
            the model lives on -- not checked here).
        shapecut: Token length every sample is padded/truncated to.
        context_lengths: Values of ``k`` to sample; entries larger than the
            sequence length are skipped.
        near_zero_thresh: Entropies strictly below this count as "near zero".
        safe_id_set: Token ids allowed in a row's non-padding positions.
        top_k: Maximum number of near-zero-entropy tokens to report.

    Returns:
        ``(ent_by_ctx, top_zero_tokens)``: ``ent_by_ctx`` maps each ``k`` to a
        1-D array of collected entropies (an empty array if none were
        collected); ``top_zero_tokens`` lists up to ``top_k``
        ``(token id, count, decoded text)`` tuples, most frequent first.
    """
    ent_by_ctx: Dict[int, List[np.ndarray]] = {k: [] for k in context_lengths}
    zero_bin_tokens: List[int] = []

    model.eval()
    with torch.no_grad():
        for docs in tqdm(dataloader):
            # Every sample is padded/truncated to exactly `shapecut` tokens.
            inputs = tokenizer(
                docs,
                return_tensors="pt",
                return_token_type_ids=False,
                padding="max_length",
                max_length=shapecut,
                truncation=True,
            ).to(device)

            logits = model(**inputs).logits  # [B, T, V]
            probs = F.softmax(logits, dim=-1)

            tok = inputs["input_ids"].cpu().numpy()
            attn = inputs["attention_mask"].cpu().numpy()

            # Drop rows that contain any non-padding token outside the safe set.
            keep_rows = rows_all_safe(tok, attn, safe_id_set)
            if not keep_rows.any():
                continue

            # Apply the same row mask everywhere so tok/attn/probs stay aligned.
            tok = tok[keep_rows]
            attn = attn[keep_rows]
            probs = probs[keep_rows]

            # Entropy in bits per position; the clamp keeps log2 finite when a
            # probability is 0.
            token_entropy = -(probs * torch.log2(probs.clamp_min(1e-10))).sum(dim=-1)

            for k in context_lengths:
                # Skip context lengths that exceed the tokenized sequence length.
                if k > token_entropy.shape[1]:
                    continue
                # Index k - 1 is presumably the prediction made after seeing k
                # tokens -- assumes a causal LM and no left padding; TODO confirm.
                ent_slice = token_entropy[:, k - 1].cpu().numpy()
                attn_slice = attn[:, k - 1]
                # Only keep rows where position k - 1 is a real (non-padding) token.
                ent_valid = ent_slice[attn_slice == 1]
                if not ent_valid.size:
                    continue
                ent_by_ctx[k].append(ent_valid)
                near_zero_mask = ent_valid < near_zero_thresh
                if near_zero_mask.any():
                    # NOTE(review): this records the input token at index k - 1,
                    # not the token being predicted -- confirm this is intended.
                    zero_bin_tokens.extend(tok[:, k - 1][attn_slice == 1][near_zero_mask])

    # Flatten the per-batch arrays into one array per context length.
    ent_by_ctx = {k: (np.concatenate(v) if v else np.array([])) for k, v in ent_by_ctx.items()}

    if zero_bin_tokens:
        # Count how often each token id fell into the near-zero bin and keep
        # the top_k most frequent.
        uniq, counts = np.unique(zero_bin_tokens, return_counts=True)
        zero_entropy_tokens = sorted(zip(uniq, counts), key=lambda x: x[1], reverse=True)
        top_zero_tokens = [
            (tid, cnt, tokenizer.decode(tid, clean_up_tokenization_spaces=False))
            for tid, cnt in zero_entropy_tokens[:top_k]
        ]
    else:
        top_zero_tokens = []

    return ent_by_ctx, top_zero_tokens


def parse_args(argv: Optional[Sequence[str]] = None) -> argparse.Namespace:
    """Parse the command-line options of the entropy script.

    Args:
        argv: Argument strings to parse. ``None`` (the default) parses
            ``sys.argv[1:]``; pass an explicit list (e.g. ``[]``) when calling
            from a notebook, where ``sys.argv`` holds the kernel's own
            arguments.

    Returns:
        The populated namespace. ``context_lengths`` is left as the raw
        comma-separated string.
    """
    parser = argparse.ArgumentParser(description="Compute next-token entropy distributions with safelist filtering.")
    parser.add_argument(
        "--files",
        nargs="+",
        default=[
            "/scratch/gpfs/DATASETS/hugging_face/c4/en/c4-train.00217-of-01024.json",
        ],
        help="JSONL files to read (expects a 'text' field per line).",
    )
    parser.add_argument("--safeset-path", type=str, default="colin_files/safeset2.txt", help="Pickled safeset path.")
    parser.add_argument("--model", type=str, default="allenai/OLMo-2-0425-1B", help="HF model name.")
    parser.add_argument("--splitlen", type=int, default=15000, help="Non-overlapping window length for chunking text.")
    parser.add_argument("--nsamples", type=int, default=2000, help="How many windows to keep (after sorting by length).")
    parser.add_argument("--buffer", type=int, default=100, help="Random start offset range for each selected window.")
    parser.add_argument("--seed", type=int, default=223291, help="RNG seed for start offsets.")
    parser.add_argument("--batch-size", type=int, default=2, help="Dataloader batch size.")
    parser.add_argument("--shapecut", type=int, default=1024, help="Max tokens per sample (padding/truncation).")
    parser.add_argument(
        "--context-lengths",
        type=str,
        default="1,3,10,40,100,1000",
        help="Comma-separated context lengths to sample entropy from.",
    )
    parser.add_argument("--near-zero-thresh", type=float, default=1e-3, help="Threshold for ~zero-entropy bin.")
    parser.add_argument("--top-k", type=int, default=30, help="How many zero-entropy tokens to display.")
    parser.add_argument(
        "--save-ent",
        type=str,
        default="",
        help="Optional path to save entropies as npz (keys ctx_<len>).",
    )
    parser.add_argument(
        "--save-zero",
        type=str,
        default="",
        help="Optional path to save zero-entropy tokens as TSV (id, count, text).",
    )
    return parser.parse_args(argv)


def main():
    """Run the entropy pipeline end to end from command-line options.

    Steps: parse and validate the options, build the text samples, load the
    tokenizer and model, compute per-context-length entropies, print a summary,
    and optionally save the entropies (npz) and near-zero-entropy tokens (TSV).

    Raises:
        ValueError: If no context length is given, a context length is below 1,
            or ``--shapecut`` is smaller than the largest context length.
    """
    args = parse_args()

    context_lengths = [int(x) for x in args.context_lengths.split(",") if x.strip()]
    if not context_lengths:
        raise ValueError("--context-lengths must contain at least one integer")
    # A context length of k reads position index k - 1, so k < 1 would silently
    # wrap around to the end of the sequence.
    if min(context_lengths) < 1:
        raise ValueError(f"context lengths must be >= 1, got {min(context_lengths)}")
    if args.shapecut < max(context_lengths):
        raise ValueError(f"shapecut {args.shapecut} must be >= max context length {max(context_lengths)}")

    safeset = load_safeset(args.safeset_path)

    print("Preparing samples...")
    ftext_sorted = prepare_samples(
        files=args.files,
        safeset=safeset,
        splitlen=args.splitlen,
        nsamples=args.nsamples,
        buffer=args.buffer,
        seed=args.seed,
    )

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    # Tokenizers hold no weights, so no device placement is requested for them.
    tokenizer = AutoTokenizer.from_pretrained(args.model)
    tokenizer.pad_token = tokenizer.eos_token
    # Place the model explicitly on the single device the inputs are sent to,
    # rather than combining automatic dispatch with a manual move.
    model = AutoModelForCausalLM.from_pretrained(args.model)
    model.to(device)

    dataloader = DataLoader(ftext_sorted, batch_size=args.batch_size)

    safe_ids = build_safe_token_ids(tokenizer, safeset)
    safe_id_set = set(safe_ids.tolist())

    print("Computing entropies...")
    ent_by_ctx, top_zero_tokens = compute_entropies(
        dataloader=dataloader,
        tokenizer=tokenizer,
        model=model,
        device=device,
        shapecut=args.shapecut,
        context_lengths=context_lengths,
        near_zero_thresh=args.near_zero_thresh,
        safe_id_set=safe_id_set,
        top_k=args.top_k,
    )

    sizes = {k: v.size for k, v in ent_by_ctx.items()}
    print("Entropy counts per context length:", sizes)
    print("Top zero-entropy tokens (id, count, text):", top_zero_tokens)

    if args.save_ent:
        save_dict = {f"ctx_{k}": v for k, v in ent_by_ctx.items()}
        np.savez_compressed(args.save_ent, **save_dict)
        # NumPy appends ".npz" when the name lacks it; report the real file name.
        ent_path = args.save_ent if args.save_ent.endswith(".npz") else args.save_ent + ".npz"
        print(f"Saved entropies to {ent_path}")

    if args.save_zero and top_zero_tokens:
        with open(args.save_zero, "w", encoding="utf-8") as f:
            for tid, cnt, text in top_zero_tokens:
                # Keep each record on one line with exactly three tab-separated fields.
                safe_text = text.replace("\t", " ").replace("\n", " ")
                f.write(f"{tid}\t{cnt}\t{safe_text}\n")
        print(f"Saved zero-entropy tokens to {args.save_zero}")

In [6]:
# Interactive version of the sample-preparation pipeline; the intermediate
# variables are kept at module level so they can be inspected in the notebook.
files = [
    '/scratch/gpfs/DATASETS/hugging_face/c4/en/c4-train.00217-of-01024.json',
    #    '/scratch/gpfs/DATASETS/hugging_face/c4/en/c4-train.00023-of-01024.json',
    #    '/scratch/gpfs/DATASETS/hugging_face/c4/en/c4-train.00345-of-01024.json'
]

data = []
for file in files:
    with open(file, 'r', encoding='utf-8') as f:
        data += [json.loads(l) for l in f]

btext = [d['text'] for d in data]

# `filterize` needs the allowed-character set as its second argument. The path
# is the same default that `parse_args` uses for --safeset-path.
safeset = load_safeset('colin_files/safeset2.txt')
ftext = filterize(btext, safeset)

# chunk each filtered doc into non-overlapping windows so we maximize usable samples
# (uses the `chunk_texts` helper defined with the other functions in this file)
splitlen = 15000
ftext_chunked = chunk_texts(ftext, splitlen)

# Keep the `nsamples` longest windows and trim each from the left by a random
# offset in [0, buffer), drawn from a fixed-seed generator.
lengths = np.array([len(t) for t in ftext_chunked])
indsort = np.flip(np.argsort(lengths))
nsamples = 2000
buffer = 100
rng = np.random.default_rng(223291)
starts = rng.choice(buffer, size=nsamples)
ftext_sorted = [ftext_chunked[i][s:] for i, s in zip(indsort[:nsamples], starts)]


100%|██████████| 356317/356317 [00:07<00:00, 47594.82it/s]


In [None]:
# Load the tokenizer and causal LM. The checkpoint name gets its own variable
# so `model` only ever refers to the model object.
model_name = 'allenai/OLMo-2-0425-1B'
# Tokenizers hold no weights, so no device placement is requested for them.
tokenizer = AutoTokenizer.from_pretrained(model_name)
# Reuse EOS as the padding token so batches can be padded to a common length.
tokenizer.pad_token = tokenizer.eos_token
model = AutoModelForCausalLM.from_pretrained(model_name, device_map='auto')
model.eval();

Loading checkpoint shards: 100%|██████████| 2/2 [00:06<00:00,  3.24s/it]


Olmo2ForCausalLM(
  (model): Olmo2Model(
    (embed_tokens): Embedding(100352, 2048, padding_idx=100277)
    (layers): ModuleList(
      (0-15): 16 x Olmo2DecoderLayer(
        (self_attn): Olmo2Attention(
          (q_proj): Linear(in_features=2048, out_features=2048, bias=False)
          (k_proj): Linear(in_features=2048, out_features=2048, bias=False)
          (v_proj): Linear(in_features=2048, out_features=2048, bias=False)
          (o_proj): Linear(in_features=2048, out_features=2048, bias=False)
          (q_norm): Olmo2RMSNorm((2048,), eps=1e-06)
          (k_norm): Olmo2RMSNorm((2048,), eps=1e-06)
        )
        (mlp): Olmo2MLP(
          (gate_proj): Linear(in_features=2048, out_features=8192, bias=False)
          (up_proj): Linear(in_features=2048, out_features=8192, bias=False)
          (down_proj): Linear(in_features=8192, out_features=2048, bias=False)
          (act_fn): SiLUActivation()
        )
        (post_attention_layernorm): Olmo2RMSNorm((2048,), eps=1e-06

In [None]:
# Batch the prepared text windows for inference. No shuffling is requested, so
# the samples are served in the order of `ftext_sorted`.
batch_size = 2
dataloader = DataLoader(ftext_sorted, batch_size=batch_size)

In [10]:
# TODO: plot a histogram of token entropies for each amount of context given to the model.

# The plotting code should look similar to this snippet:
# for step,ax in zip(steps,axes[:,i]):       
#     msk=data['m'][:,step+1].astype(bool)
#     ax.hist(data['e'][msk,step+1],density=True,color='k',bins=bins);
#     ax.axvline(avs['e'][step+1],c='gray',zorder=-1,ls='--')

In [None]:
def load_safeset(path: str) -> set:
    """Unpickle and return the object stored at ``path``.

    The result is expected to be the set of allowed characters used by the
    filtering helpers; no type check is performed here.

    NOTE: unpickling can execute arbitrary code, so ``path`` must point to a
    trusted file.
    """
    with open(path, mode="rb") as handle:
        safeset = pkl.load(handle)
    return safeset


def filterize(texts: Iterable[str], safeset: set) -> List[str]:
    """Keep only the non-empty texts whose characters all belong to ``safeset``.

    Input order is preserved; falsy entries (such as empty strings) are dropped.
    """
    kept: List[str] = []
    for text in texts:
        if not text:
            continue
        if set(text).issubset(safeset):
            kept.append(text)
    return kept


def load_texts(files: Sequence[str]) -> List[str]:
    """Read the ``"text"`` field of every record in the given JSONL files.

    Args:
        files: Paths of JSON-lines files; each non-blank line must be a JSON
            object with a ``"text"`` key.

    Returns:
        The texts in file order, then line order.

    Raises:
        json.JSONDecodeError: If a non-blank line is not valid JSON.
        KeyError: If a record has no ``"text"`` field.
    """
    texts: List[str] = []
    for path in files:
        # Decode explicitly as UTF-8 so the result does not depend on the
        # platform's default locale encoding.
        with open(path, "r", encoding="utf-8") as f:
            for line in f:
                # Tolerate blank lines (e.g. a trailing newline at end of file).
                if not line.strip():
                    continue
                texts.append(json.loads(line)["text"])
    return texts


def chunk_texts(texts: Sequence[str], splitlen: int) -> List[str]:
    """Split each text into consecutive, non-overlapping character windows.

    Args:
        texts: Texts to split; empty entries are skipped.
        splitlen: Maximum window length in characters. The last window of a
            text may be shorter.

    Returns:
        All windows, in the order of ``texts`` and left to right within a text.

    Raises:
        ValueError: If ``splitlen`` is not positive.
    """
    if splitlen <= 0:
        raise ValueError(f"splitlen must be a positive integer, got {splitlen}")

    windows: List[str] = []
    for t in texts:
        if not t:
            continue
        # Every start index is < len(t), so each slice is non-empty by construction.
        windows.extend(t[i : i + splitlen] for i in range(0, len(t), splitlen))
    return windows


def prepare_samples(
    files: Sequence[str],
    safeset: set,
    splitlen: int,
    nsamples: int,
    buffer: int,
    seed: int,
) -> List[str]:
    """Build the text samples: load, filter, chunk, then keep the longest windows.

    The texts from ``files`` are filtered against ``safeset`` and cut into
    windows of ``splitlen`` characters. The ``nsamples`` longest windows (or
    all of them, if fewer exist) are kept, and each one is trimmed from the
    left by a random offset drawn from ``range(buffer)`` using a generator
    seeded with ``seed``. When ``buffer`` is not positive, no trimming occurs.

    Returns:
        The trimmed windows, ordered from longest to shortest original window.
    """
    windows = chunk_texts(filterize(load_texts(files), safeset), splitlen)

    # Window indices ordered from longest to shortest.
    window_lengths = np.array([len(w) for w in windows])
    order = np.flip(np.argsort(window_lengths))

    sample_count = min(nsamples, len(order))
    chosen = order[:sample_count]

    rng = np.random.default_rng(seed)
    if buffer > 0:
        starts = rng.choice(buffer, size=sample_count)
    else:
        starts = np.zeros(sample_count, dtype=int)

    return [windows[idx][start:] for idx, start in zip(chosen, starts)]


def build_safe_token_ids(tokenizer, safeset: set) -> np.ndarray:
    """Return the ids of all tokens whose decoded text uses only safeset characters.

    Every id in ``range(len(tokenizer))`` is decoded on its own (without
    tokenization-space clean-up); ids that decode to an empty string are
    excluded. The result is an ``int64`` array in ascending id order.
    """

    def _is_safe(token_id: int) -> bool:
        decoded = tokenizer.decode(token_id, clean_up_tokenization_spaces=False)
        return bool(decoded) and set(decoded).issubset(safeset)

    vocab_size = len(tokenizer)
    return np.array([tid for tid in range(vocab_size) if _is_safe(tid)], dtype=np.int64)


def rows_all_safe(tok_batch: np.ndarray, attn_batch: np.ndarray, safe_id_set: set) -> np.ndarray:
    """Flag, per row, whether every non-padding token id is in ``safe_id_set``.

    Positions where ``attn_batch`` is not 1 are treated as padding and ignored,
    so a row consisting only of padding is reported as safe.
    """
    flags = []
    for token_row, mask_row in zip(tok_batch, attn_batch == 1):
        unpadded = token_row[mask_row]
        flags.append(all(tid in safe_id_set for tid in unpadded))
    return np.array(flags)


def compute_entropies(
    dataloader: DataLoader,
    tokenizer,
    model,
    device: torch.device,
    shapecut: int,
    context_lengths: Sequence[int],
    near_zero_thresh: float,
    safe_id_set: set,
    top_k: int,
) -> Tuple[Dict[int, np.ndarray], List[Tuple[int, int, str]]]:
    """Collect the model's output entropy at selected token positions.

    Each batch from ``dataloader`` is tokenized (padded/truncated to
    ``shapecut`` tokens) and run through ``model``. Rows containing any
    non-padding token id outside ``safe_id_set`` are dropped. For every ``k``
    in ``context_lengths``, the entropy (in bits) of the softmax distribution
    at position index ``k - 1`` is recorded for the rows whose attention mask
    is 1 at that position.

    Args:
        dataloader: Yields batches of documents that ``tokenizer`` accepts.
        tokenizer: Callable tokenizer; also used to decode token ids.
        model: Model whose output exposes ``.logits``; put in eval mode here.
        device: Device the tokenized inputs are moved to (presumably the one
            the model lives on -- not checked here).
        shapecut: Token length every sample is padded/truncated to.
        context_lengths: Values of ``k`` to sample; entries larger than the
            sequence length are skipped.
        near_zero_thresh: Entropies strictly below this count as "near zero".
        safe_id_set: Token ids allowed in a row's non-padding positions.
        top_k: Maximum number of near-zero-entropy tokens to report.

    Returns:
        ``(ent_by_ctx, top_zero_tokens)``: ``ent_by_ctx`` maps each ``k`` to a
        1-D array of collected entropies (an empty array if none were
        collected); ``top_zero_tokens`` lists up to ``top_k``
        ``(token id, count, decoded text)`` tuples, most frequent first.
    """
    ent_by_ctx: Dict[int, List[np.ndarray]] = {k: [] for k in context_lengths}
    zero_bin_tokens: List[int] = []

    model.eval()
    with torch.no_grad():
        for docs in tqdm(dataloader):
            # Every sample is padded/truncated to exactly `shapecut` tokens.
            inputs = tokenizer(
                docs,
                return_tensors="pt",
                return_token_type_ids=False,
                padding="max_length",
                max_length=shapecut,
                truncation=True,
            ).to(device)

            logits = model(**inputs).logits  # [B, T, V]
            probs = F.softmax(logits, dim=-1)

            tok = inputs["input_ids"].cpu().numpy()
            attn = inputs["attention_mask"].cpu().numpy()

            # Drop rows that contain any non-padding token outside the safe set.
            keep_rows = rows_all_safe(tok, attn, safe_id_set)
            if not keep_rows.any():
                continue

            # Apply the same row mask everywhere so tok/attn/probs stay aligned.
            tok = tok[keep_rows]
            attn = attn[keep_rows]
            probs = probs[keep_rows]

            # Entropy in bits per position; the clamp keeps log2 finite when a
            # probability is 0.
            token_entropy = -(probs * torch.log2(probs.clamp_min(1e-10))).sum(dim=-1)

            for k in context_lengths:
                # Skip context lengths that exceed the tokenized sequence length.
                if k > token_entropy.shape[1]:
                    continue
                # Index k - 1 is presumably the prediction made after seeing k
                # tokens -- assumes a causal LM and no left padding; TODO confirm.
                ent_slice = token_entropy[:, k - 1].cpu().numpy()
                attn_slice = attn[:, k - 1]
                # Only keep rows where position k - 1 is a real (non-padding) token.
                ent_valid = ent_slice[attn_slice == 1]
                if not ent_valid.size:
                    continue
                ent_by_ctx[k].append(ent_valid)
                near_zero_mask = ent_valid < near_zero_thresh
                if near_zero_mask.any():
                    # NOTE(review): this records the input token at index k - 1,
                    # not the token being predicted -- confirm this is intended.
                    zero_bin_tokens.extend(tok[:, k - 1][attn_slice == 1][near_zero_mask])

    # Flatten the per-batch arrays into one array per context length.
    ent_by_ctx = {k: (np.concatenate(v) if v else np.array([])) for k, v in ent_by_ctx.items()}

    if zero_bin_tokens:
        # Count how often each token id fell into the near-zero bin and keep
        # the top_k most frequent.
        uniq, counts = np.unique(zero_bin_tokens, return_counts=True)
        zero_entropy_tokens = sorted(zip(uniq, counts), key=lambda x: x[1], reverse=True)
        top_zero_tokens = [
            (tid, cnt, tokenizer.decode(tid, clean_up_tokenization_spaces=False))
            for tid, cnt in zero_entropy_tokens[:top_k]
        ]
    else:
        top_zero_tokens = []

    return ent_by_ctx, top_zero_tokens


def parse_args(argv: Optional[Sequence[str]] = None) -> argparse.Namespace:
    """Parse the command-line options of the entropy script.

    Args:
        argv: Argument strings to parse. ``None`` (the default) parses
            ``sys.argv[1:]``; pass an explicit list (e.g. ``[]``) when calling
            from a notebook, where ``sys.argv`` holds the kernel's own
            arguments.

    Returns:
        The populated namespace. ``context_lengths`` is left as the raw
        comma-separated string.
    """
    parser = argparse.ArgumentParser(description="Compute next-token entropy distributions with safelist filtering.")
    parser.add_argument(
        "--files",
        nargs="+",
        default=[
            "/scratch/gpfs/DATASETS/hugging_face/c4/en/c4-train.00217-of-01024.json",
        ],
        help="JSONL files to read (expects a 'text' field per line).",
    )
    parser.add_argument("--safeset-path", type=str, default="colin_files/safeset2.txt", help="Pickled safeset path.")
    parser.add_argument("--model", type=str, default="allenai/OLMo-2-0425-1B", help="HF model name.")
    parser.add_argument("--splitlen", type=int, default=15000, help="Non-overlapping window length for chunking text.")
    parser.add_argument("--nsamples", type=int, default=2000, help="How many windows to keep (after sorting by length).")
    parser.add_argument("--buffer", type=int, default=100, help="Random start offset range for each selected window.")
    parser.add_argument("--seed", type=int, default=223291, help="RNG seed for start offsets.")
    parser.add_argument("--batch-size", type=int, default=2, help="Dataloader batch size.")
    parser.add_argument("--shapecut", type=int, default=1024, help="Max tokens per sample (padding/truncation).")
    parser.add_argument(
        "--context-lengths",
        type=str,
        default="1,3,10,40,100,1000",
        help="Comma-separated context lengths to sample entropy from.",
    )
    parser.add_argument("--near-zero-thresh", type=float, default=1e-3, help="Threshold for ~zero-entropy bin.")
    parser.add_argument("--top-k", type=int, default=30, help="How many zero-entropy tokens to display.")
    parser.add_argument(
        "--save-ent",
        type=str,
        default="",
        help="Optional path to save entropies as npz (keys ctx_<len>).",
    )
    parser.add_argument(
        "--save-zero",
        type=str,
        default="",
        help="Optional path to save zero-entropy tokens as TSV (id, count, text).",
    )
    return parser.parse_args(argv)


def main():
    """Run the entropy pipeline end to end from command-line options.

    Steps: parse and validate the options, build the text samples, load the
    tokenizer and model, compute per-context-length entropies, print a summary,
    and optionally save the entropies (npz) and near-zero-entropy tokens (TSV).

    Raises:
        ValueError: If no context length is given, a context length is below 1,
            or ``--shapecut`` is smaller than the largest context length.
    """
    args = parse_args()

    context_lengths = [int(x) for x in args.context_lengths.split(",") if x.strip()]
    if not context_lengths:
        raise ValueError("--context-lengths must contain at least one integer")
    # A context length of k reads position index k - 1, so k < 1 would silently
    # wrap around to the end of the sequence.
    if min(context_lengths) < 1:
        raise ValueError(f"context lengths must be >= 1, got {min(context_lengths)}")
    if args.shapecut < max(context_lengths):
        raise ValueError(f"shapecut {args.shapecut} must be >= max context length {max(context_lengths)}")

    safeset = load_safeset(args.safeset_path)

    print("Preparing samples...")
    ftext_sorted = prepare_samples(
        files=args.files,
        safeset=safeset,
        splitlen=args.splitlen,
        nsamples=args.nsamples,
        buffer=args.buffer,
        seed=args.seed,
    )

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    # Tokenizers hold no weights, so no device placement is requested for them.
    tokenizer = AutoTokenizer.from_pretrained(args.model)
    tokenizer.pad_token = tokenizer.eos_token
    # Place the model explicitly on the single device the inputs are sent to,
    # rather than combining automatic dispatch with a manual move.
    model = AutoModelForCausalLM.from_pretrained(args.model)
    model.to(device)

    dataloader = DataLoader(ftext_sorted, batch_size=args.batch_size)

    safe_ids = build_safe_token_ids(tokenizer, safeset)
    safe_id_set = set(safe_ids.tolist())

    print("Computing entropies...")
    ent_by_ctx, top_zero_tokens = compute_entropies(
        dataloader=dataloader,
        tokenizer=tokenizer,
        model=model,
        device=device,
        shapecut=args.shapecut,
        context_lengths=context_lengths,
        near_zero_thresh=args.near_zero_thresh,
        safe_id_set=safe_id_set,
        top_k=args.top_k,
    )

    sizes = {k: v.size for k, v in ent_by_ctx.items()}
    print("Entropy counts per context length:", sizes)
    print("Top zero-entropy tokens (id, count, text):", top_zero_tokens)

    if args.save_ent:
        save_dict = {f"ctx_{k}": v for k, v in ent_by_ctx.items()}
        np.savez_compressed(args.save_ent, **save_dict)
        # NumPy appends ".npz" when the name lacks it; report the real file name.
        ent_path = args.save_ent if args.save_ent.endswith(".npz") else args.save_ent + ".npz"
        print(f"Saved entropies to {ent_path}")

    if args.save_zero and top_zero_tokens:
        with open(args.save_zero, "w", encoding="utf-8") as f:
            for tid, cnt, text in top_zero_tokens:
                # Keep each record on one line with exactly three tab-separated fields.
                safe_text = text.replace("\t", " ").replace("\n", " ")
                f.write(f"{tid}\t{cnt}\t{safe_text}\n")
        print(f"Saved zero-entropy tokens to {args.save_zero}")