# Apply probes for early stopping

Pretrained probes and calibrated decision thresholds may be downloaded [here](https://figshare.com/articles/dataset/s1K_calibrated_probes/29242328). These files should be placed under `PROBE_DIR`.

Code is provided for reproducibility.

In [1]:
import os
import glob
import json
import pickle

from collections import deque

import numpy as np

# Silence warnings (the original note here was "hi sklearn", so this is
# presumably aimed at sklearn's output): the `warnings.warn` attribute is
# rebound to the no-op below, so it applies to every caller that looks up
# `warnings.warn` after this point, not just sklearn.
def warn(*args, **kwargs):
    """Accept any arguments and do nothing; stand-in for `warnings.warn`."""
    pass
import warnings
warnings.warn = warn

from sklearn.decomposition import PCA
from sklearn.linear_model import LogisticRegression
from sklearn.preprocessing import StandardScaler

## Preliminaries

In [2]:
# Where the downloaded probes and calibrated thresholds live (you can customize),
# and where the LLM embeddings they are applied to are stored.
PROBE_DIR = "../probes"
PROBE_DATA_DIR = "../probes/data"

# Output folder per model; update these to wherever your outputs are saved.
model_to_folder = {
    "qwen2.5": "../outputs",
    "qwq": "../outputs-qwq",
    "llama3.3": "../outputs-llama",
}

# Single configuration used by the cells below; modify the code to loop over
# several configurations instead if you wish.
MODE = "supervised"  # supervised|consistent|novel|leaf
MODEL = "qwen2.5"    # qwen2.5|qwq|llama3.3

Our s1K splits

In [3]:
# Example-index ranges for each split: 0-499 train, 500-549 val, 550-999 test.
splits = dict(
    train=range(0, 500),
    val=range(500, 550),
    test=range(550, 1000),
)

## Determine stopping time

In [4]:
def load_probe_inputs(model):
    """
    Load the pickled probe inputs for one model.

    model    (str) model name; the file read is
             `PROBE_DATA_DIR/{model}_embed_steps.pkl`

    return   the unpickled object as stored in that file (the callers in this
             notebook index it by example index)
    """
    embed_path = os.path.join(PROBE_DATA_DIR, f"{model}_embed_steps.pkl")
    # NOTE: unpickling can execute arbitrary code; only load files you trust
    with open(embed_path, "rb") as handle:
        return pickle.load(handle)


def smooth(pred, window=1):
    """
    Rolling window for smoothing (trailing moving average)

    pred     (iterable[float]) values in step order
    window   (int) maximum number of most recent values to average; must be
             >= 1. window=1 returns the values unchanged

    return   (list) same length as `pred`; entry t is the mean of
             pred[max(0, t - window + 1) : t + 1], so the first window - 1
             entries average over fewer values

    raises ValueError if window < 1 (an empty window would average nothing
             and silently produce NaN)
    """
    if window < 1:
        raise ValueError(f"window must be >= 1, got {window}")
    # a bounded deque drops the oldest value by itself once it is full
    queue = deque(maxlen=window)
    pred_smooth = []
    for p in pred:
        queue.append(p)
        pred_smooth.append(np.mean(queue))
    return pred_smooth

## Supervised and consistent

In [5]:
def get_stops(model, mode, eps):
    """
    Stopping step for every example in splits["val"], using a single probe.

    model    (str) model name; selects the probe, threshold and embedding files
    mode     (str) probe type; selects `probe-{mode}-{model}.pkl` and
             `lambdas-{model}-{mode}.json` under PROBE_DIR
    eps      (str|float) tolerance. selected in `3-calibrate.ipynb`. Must be
             one of the keys of the lambdas file; a non-string value is looked
             up through its `str()` form (0.05 -> "0.05")

    return stop (list[int]) in units of steps: for each example, the first
             step whose smoothed probe probability reaches the threshold, or
             the last step if it never does

    raises ValueError if eps is not one of the calibrated tolerances
    """
    # JSON object keys are always strings, so floats are matched by their text
    eps_key = eps if isinstance(eps, str) else str(eps)

    with open(os.path.join(PROBE_DIR, f"lambdas-{model}-{mode}.json")) as f:
        ltt_lambdas = json.load(f)
    # validate first so that a bad eps fails before the larger files are read
    if eps_key not in ltt_lambdas:
        raise ValueError(f"Invalid eps. Choose from: {sorted(ltt_lambdas)}")
    threshold = ltt_lambdas[eps_key]

    # NOTE: unpickling can execute arbitrary code; only load files you trust
    with open(os.path.join(PROBE_DIR, f"probe-{mode}-{model}.pkl"), "rb") as f:
        lr, scaler, pca = pickle.load(f)
    step_embeddings = load_probe_inputs(model)

    # apply probe at threshold
    stop = []
    # original note: "val is test, test is cal" -- presumably the val split is
    # the one evaluated here and the test split was used for calibration
    for i in splits["val"]:
        ebds = step_embeddings[i]
        # probability of the positive class for every step of this example
        probs = lr.predict_proba(pca.transform(scaler.transform(ebds)))[:, 1]
        probs = smooth(probs, window=10)
        last_step = len(probs) - 1
        # first step at or above the threshold, else run to the last step
        stop.append(
            next((t for t, p in enumerate(probs) if p >= threshold), last_step)
        )
    return stop

Example usage. Options for eps:
    ['0.01', '0.025', '0.05', '0.1', '0.15', '0.2', '0.25', '0.3', '0.35', '0.4', '0.5']

In [6]:
# Sweep every calibrated tolerance; show how many stops were computed and the
# first five of them.
for eps in (
    '0.01', '0.025', '0.05', '0.1', '0.15', '0.2',
    '0.25', '0.3', '0.35', '0.4', '0.5',
):
    stop = get_stops(MODEL, MODE, eps)
    print(f"{len(stop)} {stop[:5]}")

50 [16, 46, 46, 56, 48]
50 [16, 46, 46, 56, 48]
50 [16, 46, 46, 56, 48]
50 [16, 46, 46, 56, 48]
50 [16, 35, 46, 56, 48]
50 [16, 34, 46, 56, 27]
50 [16, 32, 46, 56, 19]
50 [16, 31, 46, 56, 18]
50 [16, 30, 45, 56, 18]
50 [16, 28, 44, 54, 11]
50 [10, 26, 30, 49, 3]


## Novel leaf

In [10]:
def get_stops_boring(model, eps):
    """
    Stopping step for every example in splits["val"], combining the leaf and
    novel probes as p(boring) = p(leaf) * (1 - p(novel)).

    model    (str) model name; selects the probe, threshold and embedding files
    eps      (str|float) tolerance. Must be one of the keys of
             `lambdas-{model}-boring.json`; a non-string value is looked up
             through its `str()` form (0.05 -> "0.05")

    return stop (list[int]) in units of steps

    raises ValueError if eps is not one of the calibrated tolerances
    """
    # JSON object keys are always strings, so floats are matched by their text
    eps_key = eps if isinstance(eps, str) else str(eps)

    with open(os.path.join(PROBE_DIR, f"lambdas-{model}-boring.json")) as f:
        ltt_lambdas = json.load(f)
    # validate first so that a bad eps fails before the larger files are read
    if eps_key not in ltt_lambdas:
        raise ValueError(f"Invalid eps. Choose from: {sorted(ltt_lambdas)}")
    threshold = ltt_lambdas[eps_key]

    # NOTE: unpickling can execute arbitrary code; only load files you trust
    with open(os.path.join(PROBE_DIR, f"probe-leaf-{model}.pkl"), "rb") as f:
        lr_leaf, scaler_leaf, pca_leaf = pickle.load(f)
    with open(os.path.join(PROBE_DIR, f"probe-novel-{model}.pkl"), "rb") as f:
        lr_novel, scaler_novel, pca_novel = pickle.load(f)
    step_embeddings = load_probe_inputs(model)

    stop = []
    for i in splits["val"]:
        cur_reps = step_embeddings[i]
        if len(cur_reps) < 2:
            # the novel probe needs a previous step to look back at
            stop.append(len(cur_reps))
            continue

        # p(leaf) for every step
        leaf_preds = lr_leaf.predict_proba(pca_leaf.transform(scaler_leaf.transform(cur_reps)))[:, 1]
        # p(novel) for steps 1..n-1: each row is a step's embedding followed
        # by the previous step's embedding (look back)
        cur_reps_stacked = np.concatenate([cur_reps[1:], cur_reps[:-1]], axis=1)
        novel_preds = lr_novel.predict_proba(pca_novel.transform(scaler_novel.transform(cur_reps_stacked)))[:, 1]

        # entry t of p_boring describes step t + 1, since step 0 has no look-back
        p_boring = leaf_preds[1:] * (1 - novel_preds)
        probs = smooth(p_boring, window=10)

        # Fallback is len(probs), which equals the index of the last step
        # because probs starts from step 1.
        # NOTE(review): an early hit is reported as the index into `probs`,
        # i.e. one less than the step that entry describes. Kept as-is;
        # confirm it matches how the thresholds were calibrated before changing.
        early_t = next(
            (t for t, p in enumerate(probs) if p >= threshold), len(probs)
        )
        stop.append(early_t)
    return stop

In [11]:
# Same tolerance sweep as above, for the combined leaf/novel ("boring") rule.
for eps in (
    '0.01', '0.025', '0.05', '0.1', '0.15', '0.2',
    '0.25', '0.3', '0.35', '0.4', '0.5',
):
    stop = get_stops_boring(MODEL, eps)
    print(f"{len(stop)} {stop[:5]}")

50 [16, 46, 46, 56, 48]
50 [16, 46, 46, 56, 48]
50 [16, 46, 10, 56, 48]
50 [16, 46, 0, 56, 48]
50 [16, 18, 0, 56, 48]
50 [16, 17, 0, 56, 48]
50 [6, 16, 0, 56, 48]
50 [6, 16, 0, 56, 48]
50 [2, 5, 0, 56, 48]
50 [2, 5, 0, 55, 48]
50 [1, 1, 0, 5, 14]
