# Train probes based on s1K trajectories

Pretrained probes may be downloaded [here](https://figshare.com/articles/dataset/s1K_calibrated_probes/29242328) (skip to `3-calibrate.ipynb`).

The code and data for probe training are provided for reproducibility.
The data required to re-train our probes can be found [here](https://figshare.com/articles/dataset/s1K_step_embeddings/29230682) and should be placed under `PROBE_DATA_DIR`.

## Preliminaries

In [1]:
import os
import glob
import json
import pickle

from collections import Counter, deque

import numpy as np

# Silence sklearn's warning noise by replacing warnings.warn with a no-op.
# NOTE: this is a global patch, so warnings from every library are hidden.
import warnings


def warn(*args, **kwargs):
    """Accept any arguments and do nothing (stand-in for ``warnings.warn``)."""
    return None


warnings.warn = warn

from sklearn.decomposition import PCA
from sklearn.linear_model import LogisticRegression
from sklearn.preprocessing import StandardScaler

from sklearn.metrics import roc_auc_score

In [2]:
PROBE_DATA_DIR = "../probes/data"  # LLM step embeddings (probe inputs) and label files
PROBE_DIR = "../probes"  # trained probe pickles; customize as needed

# Generation outputs per model -- point these at wherever yours are saved.
model_to_folder = {
    "qwen2.5": "../outputs",
    "qwq": "../outputs-qwq",
    "llama3.3": "../outputs-llama",
}

# One model/mode per run; wrap the cells below in a loop if you need several.
MODEL = "qwen2.5"  # one of: qwen2.5 | qwq | llama3.3
MODE = "supervised"  # one of: supervised | consistent | novel | leaf

Our s1K splits (ranges of question indices):

In [3]:
# Question-index ranges: 500 train / 50 val / 450 test.
splits = dict(
    train=range(0, 500),
    val=range(500, 550),
    test=range(550, 1000),
)

## Steps to embeddings

Extract mean embeddings for each step. This script takes a lot of memory and time (depending on your file system speed).

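The script below averages the embeddings saved by the `s1_embed_{model}` prompt from `1-prepare_s1.ipynb` into one embedding per step. `load_probe_inputs` expects the result at `{PROBE_DATA_DIR}/{model}_embed_steps.pkl`. **Note:** the script deletes the raw `s1_embed_{model}-*.pkl` files once the aggregate has been written.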

```
for model in model_to_folder:

    # one list of (left, right) step boundaries per embedded sequence, in order
    with open(f"cache/s1_metadata_{model}.json") as f:
        idx_to_step_limits = json.load(f)["step_limits"]

    folder = model_to_folder[model]
    fps = glob.glob(f"{folder}/s1_embed_{model}-*.pkl")
    # sort shards numerically by their "-<n>.pkl" suffix so sequences stay in order
    fps = sorted(fps, key=lambda fp: int(fp.rsplit("-", 1)[1].split(".")[0]))
    print(len(fps))

    # write where load_probe_inputs() reads from
    fp_out = os.path.join(PROBE_DATA_DIR, f"{model}_embed_steps.pkl")

    if not os.path.exists(fp_out):
        step_embeddings = []
        for fp in fps:
            # NOTE: pickle.load can execute arbitrary code; only load shards you produced
            with open(fp, "rb") as f:
                embeddings = pickle.load(f)
            for ebd in embeddings:
                segments = idx_to_step_limits[len(step_embeddings)]
                cur_ebs = []
                for left, right in segments:
                    # need to +1 because limits are determined from length of previous
                    step = ebd[left + 1:right + 1]
                    if len(step) == 0:
                        # NOTE(review): skipping an empty step shifts later step indices;
                        # the novel/leaf cell asserts one rep per label -- confirm this is rare
                        continue
                    cur_ebs.append(np.mean(step, axis=0))
                step_embeddings.append(cur_ebs)
            print("done with", fp)

        os.makedirs(PROBE_DATA_DIR, exist_ok=True)
        with open(fp_out, "wb") as f:
            pickle.dump(step_embeddings, f)
        print(fp_out)

        # the raw shards are deleted once the aggregate has been written
        for fp in fps:
            os.remove(fp)
```

## Parse probe labels

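This section parses outputs of the `s1_verify_...` prompts from `1-prepare_s1.ipynb` into `{PROBE_DATA_DIR}/labels-{mode}-{model}.json` (supervised and consistent). Novel and leaf labels do not depend on the model and are stored in `{PROBE_DATA_DIR}/labels-{mode}.json`.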

In [4]:
def get_ambiguous(line):
    """
    Fallback heuristic label for a verifier output without a clear final yes/no.

    line  (str)  verifier output text

    Returns 0 if the text mentions a negative verdict word ("incorrect",
    "incomplete", "inaccura..."), else 1 if it contains "yes", else 0 if it
    contains "no", else -1 (unparseable). All checks are case-insensitive
    substring tests, so e.g. "not" also counts as "no".
    """
    text = line.lower()
    # Negative verdict words take priority over a stray "yes". Lowercase first
    # (like the yes/no checks) so "Incorrect." is caught too.
    if any(s in text for s in ("incorrect", "incomplete", "inaccura")):
        return 0
    if "yes" in text:
        return 1
    if "no" in text:
        return 0
    return -1


def parse_probe_labels(verify_outputs, return_outliers=False):
    """
    Convert raw LLM verifier outputs into integer probe labels.

    verify_outputs   (list[str])  LLM verifier outputs
    return_outliers  (bool)       also return the outputs that could not be parsed

    Each output becomes 1 (yes), 0 (no) or -1 (unparseable). For multi-line
    outputs only the final line is checked for "yes"/"no" first; single-line
    outputs, or final lines with neither, fall back to `get_ambiguous` on the
    full original text. Returns the label list, or (labels, outliers) when
    `return_outliers` is True.
    """
    labels = []
    unparsed = []

    def _fallback(text):
        # keyword heuristic over the whole output; -1 marks it as an outlier
        verdict = get_ambiguous(text)
        if verdict < 0:
            unparsed.append(text)
        return verdict

    for text in verify_outputs:
        _, newline, last_line = text.strip().rpartition("\n")
        if not newline:
            labels.append(_fallback(text))
            continue
        # only look at the final answer line
        final = last_line.strip().lower()
        if "yes" in final:
            labels.append(1)
        elif "no" in final:
            labels.append(0)
        else:
            labels.append(_fallback(text))

    return (labels, unparsed) if return_outliers else labels

Group inputs by original question

In [5]:
def prepare_linear_probe_data(verify_results, batch_index, idx_to_index):
    """
    Group parsed verifier labels by question, dropping unparseable (-1) entries.

    verify_results  (list[int])        output from parse_probe_labels
    batch_index     (list[int])        [0, 0, 0, ..., 1, 1, ...] etc from metadata
    idx_to_index    (list[list[int]])  [[0,1,2,3], [4,5,6], ...] etc from metadata

    Returns {"index": [...], "label": [...]}: entry q lists the indices into
    `verify_results` kept for question q and, in the same order, their labels.
    The number of questions is max(batch_index) + 1.
    """
    n_questions = max(batch_index) + 1
    kept = [
        [j for j in idx_to_index[q] if verify_results[j] >= 0]
        for q in range(n_questions)
    ]
    return {
        "index": kept,
        "label": [[verify_results[j] for j in js] for js in kept],
    }

## Train probe

This section trains probes using `{PROBE_DATA_DIR}/{model}_embed_steps.pkl` (inputs) and `{PROBE_DATA_DIR}/labels-{mode}-{model}.json` (labels).

In [6]:
def load_metadata(model, mode):
    """
    Load the s1 metadata JSON from ./cache.

    model  (str) model name; only used for supervised|consistent
    mode   (str) supervised|consistent|novel|leaf

    supervised|consistent read the per-model `s1_metadata_{model}.json`;
    novel|leaf read the shared `s1_metadata_step.json`.
    """
    suffix = model if mode in ("supervised", "consistent") else "step"
    with open(f"cache/s1_metadata_{suffix}.json") as fh:
        return json.load(fh)


def load_probe_inputs(model):
    """
    Load the pickled per-step embeddings `{model}_embed_steps.pkl` from PROBE_DATA_DIR.

    Elsewhere in this notebook the result is indexed as reps[question][step].
    NOTE: pickle.load can execute arbitrary code; only load files from a source you trust.
    """
    path = os.path.join(PROBE_DATA_DIR, f"{model}_embed_steps.pkl")
    with open(path, "rb") as fh:
        return pickle.load(fh)


def load_probe_labels(model, mode="supervised"):
    """
    Load probe labels and the indices they belong to from PROBE_DATA_DIR.

    model  (str) model name; only used for supervised|consistent
    mode   (str) supervised|consistent|novel|leaf

    Returns (index, label), two per-question lists of equal length. For
    supervised|consistent the labels are made cumulative in place with
    to_cumulative (for calibration validity).
    """
    per_model = mode in ("supervised", "consistent")
    # novel and leaf labels are not dependent on model
    fname = f"labels-{mode}-{model}.json" if per_model else f"labels-{mode}.json"
    with open(os.path.join(PROBE_DATA_DIR, fname)) as fh:
        payload = json.load(fh)
    index, label = payload["index"], payload["label"]
    assert len(index) == len(label)
    if per_model:
        to_cumulative(index, label)
    return index, label

def to_cumulative(index, label):
    """
    Make supervised/consistent labels monotone within each question.

    index  (list[list[int]])  kept indices per question
    label  (list[list[int]])  0/1 labels aligned with `index`

    Both lists are edited in place and nothing is returned. A question whose
    labels contain no 1 is emptied in both lists; otherwise every label from
    the first 1 onward is set to 1 (inner lists are mutated, not replaced).
    """
    for q, q_labels in enumerate(label):
        try:
            first_pos = q_labels.index(1)
        except ValueError:
            # never labelled 1: skip this question entirely
            label[q] = []
            index[q] = []
        else:
            q_labels[first_pos:] = [1] * (len(q_labels) - first_pos)

In [7]:
def get_last(abs_idx, rel_idx, batch_size=10):
    """
    Return the representation of the last step of a truncated prompt.

    abs_idx     absolute index; batch_index[abs_idx] gives its question
    rel_idx     which truncation of that question (0-based)
    batch_size  should match generate_truncated_prompts

    Reads the notebook globals `batch_index` and `reps`. The step used is
    (rel_idx + 1) * batch_size, clamped to the question's last step when
    batch_size overshoots.
    NOTE(review): get_up_to_last slices reps[q][:(rel_idx + 1) * batch_size],
    whose last element is one step *before* the one returned here -- confirm
    which convention generate_truncated_prompts uses.
    """
    steps = reps[batch_index[abs_idx]]
    pos = min((rel_idx + 1) * batch_size, len(steps) - 1)
    return steps[pos]


def get_up_to_last(abs_idx, rel_idx, batch_size=10):
    """
    Return the representations of all steps up to a truncated prompt's end.

    abs_idx     absolute index; batch_index[abs_idx] gives its question
    rel_idx     which truncation of that question (0-based)
    batch_size  should match generate_truncated_prompts

    Reads the notebook globals `batch_index` and `reps` and returns
    reps[q][:(rel_idx + 1) * batch_size]; slicing means short questions are
    returned whole rather than raising.
    """
    end = (rel_idx + 1) * batch_size
    return reps[batch_index[abs_idx]][:end]

### Supervised and consistent

In [8]:
# metadata: absolute index -> question (batch_idx) and question -> indices (idx_to_index)
info = load_metadata(MODEL, MODE)
batch_index, idx_to_index = info["batch_idx"], info["idx_to_index"]
reps = load_probe_inputs(MODEL)  # Xs: per-step embeddings
index, label = load_probe_labels(MODEL, MODE)  # indices of Xs and the corresponding ys

**Minor note**: PCA occasionally hangs even with a fixed `random_state`. Restarting the notebook kernel usually fixes it.

In [9]:
# cached (lr, scaler, pca) for this MODE/MODEL pair
fp_probes = os.path.join(PROBE_DIR, "probe-{}-{}.pkl".format(MODE, MODEL))

In [10]:
def _truncation_position(abs_idx):
    """Return which truncation (0-based) of its question prompt `abs_idx` is.

    `index[i]` only keeps prompts whose verifier output parsed (-1 results are
    dropped), so the enumerate() position inside `index[i]` falls behind the
    true truncation position after any dropped prompt. Look the prompt up in
    the unfiltered per-question list `idx_to_index` instead. This assumes
    `idx_to_index[q]` is in truncation order, as the enumerate() indexing did;
    when nothing was dropped both give the same position.
    """
    question_idx = batch_index[abs_idx]
    return idx_to_index[question_idx].index(abs_idx)


def _collect_all(question_ids):
    """Features/labels for every labelled truncation of each question."""
    feats, targets = [], []
    for i in question_ids:
        for j, idx in enumerate(index[i]):
            feats.append(get_last(idx, _truncation_position(idx)))
            targets.append(label[i][j])
    return feats, targets


def _collect_one_per_question(question_ids):
    """One random labelled truncation per question (test points stay exchangeable)."""
    feats, targets = [], []
    for i in question_ids:
        if len(index[i]) == 0:
            continue
        np.random.seed(i)  # per-question seed keeps the sampled test set reproducible
        j = np.random.choice(len(index[i]))
        idx = index[i][j]
        feats.append(get_last(idx, _truncation_position(idx)))
        targets.append(label[i][j])
    return feats, targets


X_train, y_train = _collect_all(splits["train"])
print(len(X_train), len(y_train))

X_val, y_val = _collect_all(splits["val"])
print(len(X_val), len(y_val))

X_test, y_test = _collect_one_per_question(splits["test"])
print(len(X_test), len(y_test))

X_train, y_train = np.array(X_train), np.array(y_train)
X_val, y_val = np.array(X_val), np.array(y_val)
X_test, y_test = np.array(X_test), np.array(y_test)

# Reuse a cached probe if present; otherwise fit scaler -> PCA -> logistic
# regression on train. The file is only written at the very end, so one
# existence check covers all three stages.
probe_cached = os.path.exists(fp_probes)
if probe_cached:
    # NOTE: pickle.load can execute arbitrary code; only load probe files you trust
    with open(fp_probes, "rb") as f:
        lr, scaler, pca = pickle.load(f)
else:
    scaler = StandardScaler().fit(X_train)

X_train = scaler.transform(X_train)
X_val = scaler.transform(X_val)
X_test = scaler.transform(X_test)

if not probe_cached:
    pca = PCA(n_components=256, random_state=0).fit(X_train)

X_train = pca.transform(X_train)
X_val = pca.transform(X_val)
X_test = pca.transform(X_test)

if not probe_cached:
    lr = LogisticRegression(max_iter=5000).fit(X=X_train, y=y_train)
    with open(fp_probes, "wb") as f:
        pickle.dump([lr, scaler, pca], f)

2038 2038
189 189
200 200


In [11]:
# ROC-AUC of the positive-class probability on train, val and test
y_train_pred, y_val_pred, y_test_pred = (
    lr.predict_proba(feats) for feats in (X_train, X_val, X_test)
)
for y_true, y_prob in ((y_train, y_train_pred), (y_val, y_val_pred), (y_test, y_test_pred)):
    print(roc_auc_score(y_true, y_prob[:, 1]))

0.9329294002156455
0.7431823304034257
0.7879735835940216


### Novel and leaf

In [12]:
# select the probe for this section: "novel" or "leaf"
MODE = "leaf"

In [13]:
# cached (lr, scaler, pca) for this MODE/MODEL pair
fp_probes = os.path.join(PROBE_DIR, "probe-{}-{}.pkl".format(MODE, MODEL))

In [14]:
# metadata: absolute index -> question (batch_idx) and question -> indices (idx_to_index)
info = load_metadata(MODEL, MODE)
batch_index, idx_to_index = info["batch_idx"], info["idx_to_index"]
reps = load_probe_inputs(MODEL)  # Xs: per-step embeddings
index, label = load_probe_labels(MODEL, MODE)  # indices of Xs and the corresponding ys

# question index -> name of the first split containing it (questions in no split are omitted)
splits_inv = {}
for q in range(len(idx_to_index)):
    owner = next((name for name, ids in splits.items() if q in ids), None)
    if owner is not None:
        splits_inv[q] = owner

In [15]:
def _scorable_positions(cur_labels, mode):
    """Step positions with a valid (>= 0) label that `mode` can turn into a probe input."""
    if mode == "leaf":
        return [pos for pos, lbl in enumerate(cur_labels) if lbl >= 0]
    if mode == "novel":
        # novel inputs pair a step with its predecessor, so step 0 has no input
        return [pos for pos, lbl in enumerate(cur_labels) if pos >= 1 and lbl >= 0]
    raise ValueError("Use the code above for other modes")


def _step_features(cur_reps, pos, mode):
    """Probe input for step `pos`: its embedding (leaf) or [step, previous step] (novel)."""
    if mode == "novel":
        return np.concatenate([cur_reps[pos], cur_reps[pos - 1]])
    return cur_reps[pos]


Xs = {split: [] for split in splits}
ys = {split: [] for split in splits}

for idx, (cur_labels, cur_reps) in enumerate(zip(label, reps)):
    cur_split = splits_inv[idx]
    assert len(cur_labels) == len(cur_reps), (idx, len(cur_labels), len(cur_reps))

    positions = _scorable_positions(cur_labels, MODE)
    if cur_split == "test":
        # calibration must be exchangeable: keep one random scorable step per
        # question; every question with at least one scorable step takes part
        if not positions:
            continue
        np.random.seed(idx)
        positions = [np.random.choice(positions)]

    # train/val: add all scorable steps
    for pos in positions:
        Xs[cur_split].append(_step_features(cur_reps, pos, MODE))
        ys[cur_split].append(cur_labels[pos])


Xs = {split: np.array(X) for split, X in Xs.items()}
ys = {split: np.array(y) for split, y in ys.items()}

for split, X in Xs.items():
    print(split, len(X))

# Reuse a cached probe if present; otherwise fit scaler -> PCA -> logistic
# regression on train. The file is only written at the very end, so one
# existence check covers all three stages.
probe_cached = os.path.exists(fp_probes)
if probe_cached:
    # NOTE: pickle.load can execute arbitrary code; only load probe files you trust
    with open(fp_probes, "rb") as f:
        lr, scaler, pca = pickle.load(f)
else:
    scaler = StandardScaler().fit(Xs["train"])

for split, X in Xs.items():
    Xs[split] = scaler.transform(X)

if not probe_cached:
    pca = PCA(n_components=256, random_state=0).fit(Xs["train"])

for split, X in Xs.items():
    Xs[split] = pca.transform(X)

if not probe_cached:
    lr = LogisticRegression(max_iter=5000).fit(X=Xs["train"], y=ys["train"])
    with open(fp_probes, "wb") as f:
        pickle.dump([lr, scaler, pca], f)

train 31015
val 2973
test 450


In [16]:
# ROC-AUC of the positive-class probability on train, val and test
y_train_pred, y_val_pred, y_test_pred = (
    lr.predict_proba(Xs[name]) for name in ("train", "val", "test")
)
for name, y_prob in zip(("train", "val", "test"), (y_train_pred, y_val_pred, y_test_pred)):
    print(roc_auc_score(ys[name], y_prob[:, 1]))

0.8680944749659607
0.854880753664281
0.8393614871691251
