In [1]:
# !pip install accelerate scikit-learn #matplotlib seaborn datasets scikit-learn

In [2]:
# RAST  —  Redundancy-Aware Steering Technique (token-efficiency experiment)
# Uses your steerit.SteeringVector / SteeringModel definitions.
#
# PSEUDOCODE
# 1.  Load DeepSeek-R1-Distill-Qwen-1.5B and wrap with SteeringModel.
# 2.  For each difficulty level L in {1…5}:
#     a.  Generate N_TRAIN traces (step-by-step answers) with *no steering*.
#     b.  For every token t≥k:
#         • ΔKL = KL(p_t  ||  p_{t-k})   (# compare logits after rolling back k)
#         • If ΔKL < τ   → low-gain  → save hidden h_t in LOW
#           else          high-gain → save hidden h_t in HIGH
#     c.  Vector v_L = mean(HIGH) − mean(LOW)   (layer STEER_LAY only)
# 3.  Inference with ΔKL gate:
#       keep sliding buffer of logits; if current ΔKL<τ → set coeff α∈[α_lo,α_hi],
#       else coeff 0; SteeringModel hook adds α·v_L to layer activations.
# 4.  Record tokens/answer & accuracy for baseline vs RAST; plot %–saving vs level.
# ──────────────────────────────────────────────────────────────────────────────
#!/usr/bin/env python3
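# RAST on GSM8K. NOTE: the 5-level difficulty binning described in the pseudocode
# above is not implemented in the cells shown here; a single steering vector is
# built from one pooled set of traces.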
# Requires steerit.SteeringVector and SteeringModel to be importable.

import os, random, math, time, warnings
import numpy as np
import torch, torch.nn.functional as F
import matplotlib.pyplot as plt
from datasets import load_dataset
from transformers import AutoTokenizer, AutoModelForCausalLM
from tqdm import tqdm
from steerit.steerit import SteeringVector, SteeringModel        # ← your library

# ─────────────────────────── settings ────────────────────────────
MODEL_NAME  = "deepseek-ai/DeepSeek-R1-Distill-Qwen-1.5B"
# Credentials must not live in the notebook: read the Hugging Face token from
# the environment. An unset or empty variable becomes None so that no empty
# token string is passed to `from_pretrained`.
HF_TOKEN    = os.environ.get("HF_TOKEN") or None
DEVICE      = "cuda" if torch.cuda.is_available() else "cpu"
# Half precision only when running on GPU; CPU stays in float32.
DTYPE       = torch.float16 if DEVICE == "cuda" else torch.float32

In [18]:
# ---------- RAST HYPER-PARAMETERS --------------------------------------------
STEER_LAY  = 20      # index of the layer that receives the steering vector
K_WIN      = 10      # how many steps back the ΔKL comparison looks
DKL_THR    = 0.05    # τ: ΔKL below this marks a step as low-gain
ALPHA_HI   = 1.0     # steering coefficient used while the gate is open
MAX_TOKENS = 2048    # per-problem generation cap
N_TRAIN    = 50      # problems used to build the steering vector
N_EVAL     = 50      # problems used for evaluation
SEED       = 42

# Seed every RNG this notebook draws from.
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)

# NOTE: this silences *all* warnings for the rest of the session.
warnings.filterwarnings("ignore")

# ---------- LOAD MODEL -------------------------------------------------------
print(f"Loading {MODEL_NAME} …")
tok = AutoTokenizer.from_pretrained(MODEL_NAME, token=HF_TOKEN)
base_model = AutoModelForCausalLM.from_pretrained(
    MODEL_NAME,
    token=HF_TOKEN,
    torch_dtype=DTYPE,
    device_map=("auto" if DEVICE == "cuda" else None),
)
# Wrap the base model with the project's SteeringModel, registering STEER_LAY
# as the only steerable layer.
model = SteeringModel(base_model, [STEER_LAY], DEVICE)
print("Model ready.\n")

In [None]:
# ---------- UTILS -----------------------------------------------------------
# Token ids of the literal phrase "Final answer:", tokenized without special
# tokens. This is a *sequence* of ids, one per sub-token of the phrase.
STOP_IDS  = tok("Final answer:", add_special_tokens=False).input_ids
# Token cap for the vector-building pass.
# NOTE(review): this equals MAX_TOKENS (2048), so the building pass is not
# actually shorter than the evaluation pass.
GEN_LIMIT = 2048                      # token cap when BUILDING the vector

def kl_div(p, q):
    """Return KL(softmax(p) || softmax(q)) for two logit tensors, as a float.

    Args:
        p: logits of the reference ("current") distribution, shape (..., vocab).
        q: logits of the distribution compared against, same shape as ``p``.

    Returns:
        The divergence summed over the last dimension and averaged over any
        leading dimensions, as a Python float.

    The computation runs in float32 regardless of the input dtype, because
    softmax over a large vocabulary loses precision in float16.
    """
    log_p = F.log_softmax(p.float(), dim=-1)
    log_q = F.log_softmax(q.float(), dim=-1)
    # F.kl_div(input, target) evaluates KL(target || input), so the reference
    # distribution p must go in the *target* slot to obtain KL(p || q).
    per_row = F.kl_div(log_q, log_p, reduction="none", log_target=True).sum(dim=-1)
    return per_row.mean().item()

@torch.no_grad()
def stream(prompt, max_tokens):
    """Greedy decoding, one token per forward pass, keeping every step's logits.

    Args:
        prompt: text to condition on.
        max_tokens: maximum number of tokens to generate.

    Returns:
        (ids, logits): ``ids`` is the squeezed tensor of prompt + generated
        token ids; ``logits`` is a list holding, per generated token, the
        next-token logits that produced it.

    Generation stops at EOS, or once the generated tail equals the whole
    ``STOP_IDS`` sequence.
    """
    ids = tok(prompt, return_tensors="pt").to(DEVICE)["input_ids"]
    stop_seq = list(STOP_IDS)
    past, lg, generated = None, [], []
    for _ in range(max_tokens):
        # The first pass must see the whole prompt; afterwards the cache holds
        # the history and only the newest token is fed.
        step_input = ids if past is None else ids[:, -1:]
        # NOTE(review): `repetition_penalty` is normally a `generate()` option;
        # no penalty is applied to `logits` in this function. Confirm whether
        # SteeringModel consumes the keyword or silently ignores it.
        out = model(input_ids=step_input, past_key_values=past,
                    use_cache=True, repetition_penalty=1.1)
        # Keep the complete cache. Cutting it to the last K_WIN positions
        # limited attention to K_WIN tokens and discarded the question itself;
        # K_WIN is the ΔKL look-back over logits, not a context length.
        past = out.past_key_values
        logits = out.logits[:, -1, :]
        lg.append(logits.detach())
        nxt = logits.argmax(-1, keepdim=True)
        ids = torch.cat([ids, nxt], dim=-1)
        token_id = nxt.item()
        generated.append(token_id)
        if token_id == tok.eos_token_id:
            break
        # Stop on the full phrase, not on any one of its sub-tokens (a lone
        # ":" token would otherwise end generation).
        if stop_seq and generated[-len(stop_seq):] == stop_seq:
            break
    return ids.squeeze(), lg

def numeric_match(pred, gold):
    """Check whether the number ending ``pred`` equals the reference in ``gold``.

    Args:
        pred: generated text; only a number at the very end (optionally
            followed by whitespace) is considered.
        gold: reference answer. GSM8K answers are a rationale followed by
            ``#### <number>``; the text after the last ``####`` is used when
            present, otherwise the whole string. The last number found there
            is the reference value.

    Returns:
        True when both numbers parse and agree within a relative tolerance of
        1e-3, False otherwise (including when either side has no number).

    ``gold`` is parsed as a literal number rather than passed to ``eval``:
    evaluating dataset text is unsafe, and a GSM8K rationale is not a valid
    expression, so every comparison previously fell into the error path.
    Arithmetic expressions in ``gold`` are therefore not evaluated.
    """
    import re, math

    def _to_float(text):
        # Drop thousands separators and a sentence-final period ("1,250.").
        return float(text.replace(",", "").rstrip("."))

    tail = re.search(r'([-+]?\d[\d,\.]*)\s*$', pred)
    if not tail:
        return False
    gold_text = str(gold).rsplit("####", 1)[-1]
    gold_numbers = re.findall(r'[-+]?\d[\d,\.]*', gold_text)
    if not gold_numbers:
        return False
    try:
        return math.isclose(_to_float(tail.group(1)),
                            _to_float(gold_numbers[-1]), rel_tol=1e-3)
    except ValueError:
        # Malformed numerals such as "1.2.3" count as a mismatch.
        return False

# ---------- DATASET ---------------------------------------------------------
# Shuffle the GSM8K test split once with a fixed seed, then carve two
# non-overlapping slices out of it: the first N_TRAIN rows build the steering
# vector, the following N_EVAL rows are used for evaluation.
gsm_all    = load_dataset("gsm8k", "main")["test"].shuffle(seed=SEED)
train_rows = gsm_all.select(range(0, N_TRAIN))
eval_rows  = gsm_all.select(range(N_TRAIN, N_EVAL + N_TRAIN))

# ---------- BUILD RAST VECTOR ----------------------------------------------
MAX_POOL = 4000
# id(pool) -> number of items offered to that pool so far. Reservoir sampling
# needs the running count, which the pool's own length cannot provide once it
# is full.
_RS_SEEN = {}

def rs_add(pool, vec):
    """Offer ``vec`` to ``pool``, a uniform reservoir sample of at most MAX_POOL items.

    Args:
        pool: list that holds the sample; mutated in place.
        vec: the item being offered.

    While the pool has room the item is appended. Afterwards the i-th offered
    item replaces a random slot with probability MAX_POOL / i, so every item
    offered so far is equally likely to be in the pool. Drawing the index from
    ``len(pool) + 1`` instead would replace almost every time and bias the
    sample heavily toward the most recent items.
    """
    key = id(pool)
    if len(pool) < MAX_POOL:
        pool.append(vec)
        # While filling, "seen" equals the length. Re-syncing here also clears
        # any stale count left behind by an earlier list with the same id().
        _RS_SEEN[key] = len(pool)
        return
    seen = _RS_SEEN.get(key, len(pool)) + 1
    _RS_SEEN[key] = seen
    j = random.randrange(seen)
    if j < MAX_POOL:
        pool[j] = vec

# Collect layer activations for high-gain vs low-gain steps and take the
# difference of their means as the steering direction.
hi, lo = [], []
print(f"Building RAST vector from {N_TRAIN} traces …")
for row in tqdm(train_rows):
    prm = f"Question: {row['question']} Answer step by step."
    ids, lg = stream(prm, GEN_LIMIT)
    # Inference only: without no_grad this full-sequence forward pass would
    # record an autograd graph for every trace.
    with torch.no_grad():
        hs = model(input_ids=ids.unsqueeze(0).to(DEVICE),
                   output_hidden_states=True).hidden_states[STEER_LAY + 1][0]
    # One device-to-host copy per trace (not per token), in float32 so the
    # means below are not accumulated from half-precision values.
    hs = hs.detach().float().cpu().numpy()
    off = len(ids) - len(lg)                    # prompt offset
    for j in range(K_WIN, len(lg)):
        dkl = kl_div(lg[j], lg[j-K_WIN])
        # .copy() so a pooled row does not keep the whole per-trace array alive.
        vec = hs[off + j].copy()
        rs_add(hi if dkl >= DKL_THR else lo, vec)

print(f"  hi:{len(hi)}  lo:{len(lo)}")
# An empty pool would make np.mean return NaN and silently produce a NaN
# steering vector, so fail loudly instead.
if not hi or not lo:
    raise ValueError(
        f"Cannot build steering vector: hi={len(hi)} lo={len(lo)} samples; "
        "adjust DKL_THR or use more/longer traces."
    )
vec_dir = (np.mean(hi, axis=0) - np.mean(lo, axis=0)).astype(np.float32)
rast_vec = SteeringVector({STEER_LAY: vec_dir})
print("Vector built.\n")

# ---------- RAST GENERATION --------------------------------------------------
from collections import deque
@torch.no_grad()
def rast_generate(prompt, vec):
    """Greedy decoding with the ΔKL-gated steering vector applied.

    Args:
        prompt: text to condition on.
        vec: steering vector handed to ``model.set_steering``.

    Returns:
        The squeezed tensor of prompt + generated token ids.

    After each step the current logits are compared with those from K_WIN
    steps earlier. If the divergence is below DKL_THR the steering coefficient
    for the next step is set to ALPHA_HI, otherwise to 0. Generation stops at
    EOS, after MAX_TOKENS steps, or once the generated tail equals the whole
    ``STOP_IDS`` sequence. Steering is always reset on exit, even on error.
    """
    ids = tok(prompt, return_tensors="pt").to(DEVICE)["input_ids"]
    stop_seq = list(STOP_IDS)
    past, generated = None, []
    buf = deque(maxlen=K_WIN+1)     # buf[0] is K_WIN steps old once full
    model.set_steering(vec, coeff=0.0)
    try:
        for _ in range(MAX_TOKENS):
            # The first pass must see the whole prompt; afterwards the cache
            # holds the history and only the newest token is fed.
            step_input = ids if past is None else ids[:, -1:]
            # NOTE(review): `repetition_penalty` is normally a `generate()`
            # option; no penalty is applied to `logits` here. Confirm whether
            # SteeringModel consumes the keyword or silently ignores it.
            out = model(input_ids=step_input, past_key_values=past,
                        use_cache=True, repetition_penalty=1.1)
            # Keep the complete cache; cutting it to K_WIN positions limited
            # attention to K_WIN tokens and discarded the question.
            past = out.past_key_values
            logits = out.logits[:, -1, :]
            buf.append(logits.detach())
            gate_open = len(buf) > K_WIN and kl_div(logits, buf[0]) < DKL_THR
            # NOTE(review): assumes SteeringModel reads `.coeff` on each
            # forward pass — confirm against the library.
            model.coeff = ALPHA_HI if gate_open else 0.0
            nxt = logits.argmax(-1, keepdim=True)
            ids = torch.cat([ids, nxt], dim=-1)
            token_id = nxt.item()
            generated.append(token_id)
            if token_id == tok.eos_token_id:
                break
            # Stop on the full phrase, not on any one of its sub-tokens.
            if stop_seq and generated[-len(stop_seq):] == stop_seq:
                break
    finally:
        # Never leave steering active for later, unrelated calls.
        model.reset_steering()
    return ids.squeeze()

# ---------- EVALUATION ------------------------------------------------------
# Compare the unsteered baseline with RAST on the held-out problems:
# generated-token counts and end-of-text numeric accuracy.
print(f"Evaluating on {N_EVAL} held-out problems …")
tb, tr, ab, ar = [], [], [], []
for row in tqdm(eval_rows):
    prm = f"Question: {row['question']} Answer step by step."
    # Both runs return prompt + completion. The prompt is identical for both,
    # so it is subtracted; otherwise it dilutes the relative token saving.
    n_prompt = len(tok(prm).input_ids)
    ids_b, _ = stream(prm, MAX_TOKENS)
    ids_r    = rast_generate(prm, rast_vec)
    txt_b = tok.decode(ids_b, skip_special_tokens=True)
    txt_r = tok.decode(ids_r, skip_special_tokens=True)
    tb.append(ids_b.numel() - n_prompt)
    tr.append(ids_r.numel() - n_prompt)
    ab.append(numeric_match(txt_b, row["answer"]))
    ar.append(numeric_match(txt_r, row["answer"]))

mean_b, mean_r = np.mean(tb), np.mean(tr)
# Guard the percentage against an all-zero baseline.
saving = 100*(mean_b-mean_r)/mean_b if mean_b else float("nan")

print("──────── RESULTS ────────")
print(f"Mean tokens baseline : {mean_b:.1f}")
print(f"Mean tokens RAST     : {mean_r:.1f}")
print(f"Token saving         : {saving:.1f}%")
print(f"Accuracy baseline    : {np.mean(ab):.3f}")
print(f"Accuracy RAST        : {np.mean(ar):.3f}")
print("────────────────────────")


Building RAST vector from 50 traces …


 64%|██████▍   | 32/50 [14:09<08:01, 26.77s/it]