In [1]:
# !pip install accelerate scikit-learn #matplotlib seaborn datasets scikit-learn

In [2]:
# RAST  —  Redundancy-Aware Steering Technique (token-efficiency experiment)
# Uses your steerit.SteeringVector / SteeringModel definitions.
#
# PSEUDOCODE
# 1.  Load DeepSeek-R1-Distill-Qwen-1.5B and wrap with SteeringModel.
# 2.  For each difficulty level L in {1…5}:
#     a.  Generate N_TRAIN traces (step-by-step answers) with *no steering*.
#     b.  For every token t≥k:
#         • ΔKL = KL(p_t  ||  p_{t-k})   (compare the current next-token distribution with the one from k tokens earlier)
#         • If ΔKL < τ   → low-gain  → save hidden h_t in LOW
#           else          high-gain → save hidden h_t in HIGH
#     c.  Vector v_L = mean(HIGH) − mean(LOW)   (layer STEER_LAY only)
# 3.  Inference with ΔKL gate:
#       keep sliding buffer of logits; if current ΔKL<τ → set coeff α∈[α_lo,α_hi],
#       else coeff 0; SteeringModel hook adds α·v_L to layer activations.
# 4.  Record tokens/answer & accuracy for baseline vs RAST; plot %–saving vs level.
# ──────────────────────────────────────────────────────────────────────────────
#!/usr/bin/env python3
# RAST on GSM8K with automatic 5-level difficulty bins
# Requires steerit.SteeringVector and SteeringModel to be importable.

import os, random, math, time, warnings
import numpy as np
import torch, torch.nn.functional as F
import matplotlib.pyplot as plt
from datasets import load_dataset
from transformers import AutoTokenizer, AutoModelForCausalLM
from tqdm import tqdm
from steerit.steerit import SteeringVector, SteeringModel        # ← your library

# ─────────────────────────── settings ────────────────────────────
MODEL_NAME  = "deepseek-ai/DeepSeek-R1-Distill-Qwen-1.5B"
# Read the Hugging Face token from the environment rather than keeping a
# literal in the notebook, so a real credential is never saved with the file.
HF_TOKEN    = os.getenv("HF_TOKEN", "")
# Use the GPU when one is visible; half precision is only selected on CUDA.
DEVICE      = "cuda" if torch.cuda.is_available() else "cpu"
DTYPE       = torch.float16 if DEVICE == "cuda" else torch.float32

In [None]:
#!/usr/bin/env python3
import os
import random
import warnings
import re
import math
from collections import deque

import numpy as np
import torch
import torch.nn.functional as F
from tqdm import tqdm
from datasets import load_dataset
from transformers import AutoTokenizer, AutoModelForCausalLM

from steerit.steerit import SteeringVector, SteeringModel

# ───────── CONFIG ───────────────────────────────────────────────────────────
# Checkpoint to load and the Hugging Face token (taken from the environment).
MODEL_NAME = "deepseek-ai/DeepSeek-R1-Distill-Qwen-1.5B"
HF_TOKEN = os.environ.get("HF_TOKEN", "")

# Hardware selection: CUDA in float16 when available, otherwise CPU in float32.
_ON_GPU = torch.cuda.is_available()
DEVICE = "cuda" if _ON_GPU else "cpu"
DTYPE = torch.float16 if _ON_GPU else torch.float32

# Steering / gating hyper-parameters.
STEER_LAY = 20      # transformer layer whose activations are read and steered
K_WIN = 10          # lag (in tokens) between the two logit snapshots compared by the KL gate
DKL_THR = 0.05      # KL threshold separating "low-gain" from "high-gain" tokens
ALPHA_HI = 1.0      # steering coefficient applied while the gate is open

# Generation budgets and pool size.
MAX_TOKENS = 4096   # for fair baseline vs steered eval
GEN_LIMIT = 256     # shorter pass when building vector
MAX_POOL = 4000     # maximum number of hidden states kept per pool

# Text appended to every question, and the global RNG seed.
SUFFIX = " Answer step by step and end with: Final answer:"
SEED = 42

# Seed Python, NumPy and PyTorch (in that order) and silence all warnings.
for _seed_fn in (random.seed, np.random.seed, torch.manual_seed):
    _seed_fn(SEED)
warnings.filterwarnings("ignore")

# ───────── LOAD MODEL ────────────────────────────────────────────────────────
# Load the tokenizer and the causal LM, then wrap the LM in SteeringModel so
# that layer STEER_LAY can be steered later on.
print(f"Loading {MODEL_NAME} …")
# `use_auth_token` is deprecated in transformers in favour of `token`.
# An empty HF_TOKEN is mapped to None so that no bogus empty credential is sent
# and the hub client can fall back to its default credential lookup.
_hf_token = HF_TOKEN or None
tok = AutoTokenizer.from_pretrained(MODEL_NAME, token=_hf_token)
base_model = AutoModelForCausalLM.from_pretrained(
    MODEL_NAME,
    torch_dtype=DTYPE,
    # Let accelerate place the weights on the GPU; plain CPU load otherwise.
    device_map="auto" if DEVICE == "cuda" else None,
    token=_hf_token,
)
model = SteeringModel(base_model, [STEER_LAY], DEVICE)
print("Model ready.\n")

# ───────── PREPARE STOP IDS ─────────────────────────────────────────────────
# Token ids of the literal marker "Final answer:", encoded without any
# special tokens (BOS/EOS) added by the tokenizer.
STOP_IDS = tok("Final answer:", add_special_tokens=False)["input_ids"]

# ───────── HELPERS ──────────────────────────────────────────────────────────
def kl_div(p, q):
    """Return KL(softmax(p) || softmax(q)) for two logit tensors.

    Args:
        p: logits of the reference ("current") distribution; the last
            dimension is the vocabulary. May be 1-D or batched.
        q: logits of the distribution `p` is compared against; same shape.

    Returns:
        A Python float: the divergence summed over the last dimension and
        averaged over any leading dimensions.
    """
    # Work in float32: softmax over a large vocabulary in float16 loses
    # precision for small probabilities.
    p_probs = F.softmax(p.float(), dim=-1)
    log_q = F.log_softmax(q.float(), dim=-1)
    # F.kl_div(input, target) evaluates KL(target || input) with `input` given
    # as log-probabilities, so q goes in as `input` to obtain KL(p || q).
    per_row = F.kl_div(log_q, p_probs, reduction="none").sum(dim=-1)
    return per_row.mean().item()

def numeric_match(text, gold):
    """Check whether the number that ends `text` equals the reference answer.

    Args:
        text: generated text; only a number at the very end (optionally
            followed by whitespace) is considered.
        gold: reference answer. Either a bare number / simple fraction
            ("18", "1,000", "3/4") or a GSM8K-style solution whose final
            line is "#### <number>".

    Returns:
        True when both numbers parse and agree within a relative tolerance
        of 1e-3, otherwise False.
    """
    from fractions import Fraction  # stdlib parser for "18", "1.5", "3/4"

    # Drop thousands separators ("1,000" -> "1000") so the trailing-number
    # regex sees the whole value instead of only its last group of digits.
    cleaned = re.sub(r'(?<=\d),(?=\d{3}\b)', '', text)
    found = re.search(r'([-+]?\d+(?:\.\d+)?)\s*$', cleaned)
    if found is None:
        return False

    # GSM8K answers hold the worked solution followed by "#### <number>";
    # keep only what follows the last marker. The text is parsed, never
    # evaluated as code.
    target = str(gold).rsplit("####", 1)[-1].strip().replace(",", "")
    try:
        expected = float(Fraction(target))
    except (ValueError, ZeroDivisionError, OverflowError):
        return False
    return math.isclose(float(found.group(1)), expected, rel_tol=1e-3)

def sample_next(logits, temperature=0.7, top_p=0.9):
    """Draw one token id per row with temperature + nucleus (top-p) sampling.

    Args:
        logits: tensor whose last dimension is the vocabulary (the callers in
            this file pass one row per batch element).
        temperature: softmax temperature; values <= 0 select the arg-max
            token deterministically instead of dividing by zero.
        top_p: cumulative-probability mass of the nucleus.

    Returns:
        Integer tensor with the same leading dimensions as `logits` and a
        trailing dimension of size 1 holding the sampled vocabulary ids.
    """
    if temperature <= 0:
        return logits.argmax(dim=-1, keepdim=True)

    # float32 keeps the cumulative sum accurate for half-precision logits.
    scaled = logits.float() / temperature
    sorted_logits, sorted_idx = torch.sort(scaled, descending=True, dim=-1)
    probs = torch.softmax(sorted_logits, dim=-1)
    cum_probs = torch.cumsum(probs, dim=-1)

    # The nucleus is the smallest prefix whose mass reaches top_p, so the token
    # that crosses the threshold is kept: shift the mask right by one position.
    drop = cum_probs > top_p
    drop[..., 1:] = drop[..., :-1].clone()
    drop[..., 0] = False  # the most likely token is always available

    probs = probs.masked_fill(drop, 0.0)
    probs = probs / probs.sum(dim=-1, keepdim=True)
    choice = torch.multinomial(probs, num_samples=1)
    # Map the position in sorted order back to the vocabulary id.
    return sorted_idx.gather(-1, choice)

@torch.no_grad()
def stream(prompt, max_len):
    """Generate up to `max_len` tokens for `prompt` with nucleus sampling.

    Args:
        prompt: text to condition on.
        max_len: maximum number of tokens to generate.

    Returns:
        (ids, logs): `ids` is the squeezed tensor of prompt + generated token
        ids; `logs` is a list with one next-token logits tensor per generated
        token, in generation order.

    Generation stops at EOS, or as soon as the generated text (the prompt is
    excluded) contains the stop marker encoded by STOP_IDS.
    """
    ids = tok(prompt, return_tensors="pt").to(DEVICE)["input_ids"]
    prompt_len = ids.shape[1]
    # Compare decoded text rather than single ids: the marker spans several
    # tokens, and stopping on any one of them would end generation far too early.
    stop_text = tok.decode(STOP_IDS)
    tail_len = len(STOP_IDS) + 2  # small margin for in-context tokenisation
    past, logs = None, []
    for _ in range(max_len):
        # The first step must consume the whole prompt; afterwards only the
        # newest token is fed because the cache already covers the rest.
        step_input = ids if past is None else ids[:, -1:]
        # `repetition_penalty` is a generate() option, not a forward() argument,
        # so it is not passed here.
        out = model(input_ids=step_input, past_key_values=past, use_cache=True)
        # Keep the full cache: trimming it to the last K_WIN entries discarded
        # the prompt and all but the most recent tokens from the context.
        past = out.past_key_values
        logits = out.logits[:, -1, :]
        logs.append(logits.detach())
        nxt = sample_next(logits)
        ids = torch.cat([ids, nxt], dim=-1)
        if nxt.item() == tok.eos_token_id:
            break
        generated = ids[0, prompt_len:]
        if stop_text in tok.decode(generated[-tail_len:]):
            # NOTE(review): stopping at the marker ends the trace before any
            # number that would follow it — confirm this is the intended cut-off.
            break
    return ids.squeeze(), logs

# ───────── LOAD DATASET ─────────────────────────────────────────────────────
# GSM8K "main" test split, shuffled with the global seed. The first block of
# rows is used to build the steering vector, the following block for a quick
# evaluation; the two ranges do not overlap.
_N_TRAIN, _N_EVAL = 50, 10
_gsm_splits = load_dataset("gsm8k", "main")
gsm = _gsm_splits["test"].shuffle(seed=SEED)
train_rows = gsm.select(range(_N_TRAIN))
eval_rows = gsm.select(range(_N_TRAIN, _N_TRAIN + _N_EVAL))

# ───────── BUILD RAST VECTOR ────────────────────────────────────────────────
# Number of items offered so far to each pool, keyed by id(pool). Reservoir
# sampling needs this count; the length of a full pool never changes.
_RS_SEEN = {}


def rs_add(pool, vec, cap=None):
    """Offer `vec` to `pool`, keeping a uniform random sample of bounded size.

    Args:
        pool: list used as the reservoir; modified in place.
        vec: item to offer.
        cap: maximum pool size; defaults to the module-level MAX_POOL.

    While the pool is below `cap` the item is appended. Afterwards the n-th
    offered item replaces a random slot with probability cap / n, so every
    item offered so far is equally likely to be in the pool.
    """
    limit = MAX_POOL if cap is None else cap
    key = id(pool)
    if len(pool) < limit:
        pool.append(vec)
        # (Re)synchronise the counter; this also resets a stale entry left by
        # an earlier list that happened to share the same id.
        _RS_SEEN[key] = len(pool)
        return
    seen = _RS_SEEN.get(key, len(pool)) + 1
    _RS_SEEN[key] = seen
    # Drawing from the running count makes the acceptance probability decay;
    # drawing from len(pool) + 1 would replace almost every time.
    slot = random.randrange(seen)
    if slot < limit:
        pool[slot] = vec

# Build the steering direction: generate one unsteered trace per training
# question, label every generated position as high- or low-gain by the KL
# between its logits and the logits K_WIN steps earlier, and pool the layer
# STEER_LAY hidden states of each group.
hi_vecs, lo_vecs = [], []
print("Building RAST vector …")
for row in tqdm(train_rows):
    prompt = f"Question: {row['question']}.{SUFFIX}"
    ids, logs = stream(prompt, GEN_LIMIT)
    # One extra pass over the finished trace to read the hidden states. No
    # gradients are needed, so autograd bookkeeping is switched off.
    with torch.no_grad():
        out = model(
            input_ids=ids.unsqueeze(0).to(DEVICE),
            output_hidden_states=True
        )
    # layer index +1 for embeddings layer; moved to the CPU once per trace
    # instead of once per token.
    hs = out.hidden_states[STEER_LAY + 1][0].float().cpu().numpy()
    # `logs` has one entry per generated token, so this is the prompt length.
    offset = len(ids) - len(logs)
    for j in range(K_WIN, len(logs)):
        dkl = kl_div(logs[j], logs[j - K_WIN])
        # copy() so a pooled row does not keep the whole trace array alive.
        vec = hs[offset + j].copy()
        rs_add(hi_vecs if dkl >= DKL_THR else lo_vecs, vec)

print(f"  hi:{len(hi_vecs)}  lo:{len(lo_vecs)} (capped {MAX_POOL})")
# With an empty pool np.mean returns NaN, and the global warning filter would
# hide that, so fail loudly instead of building an unusable vector.
if not hi_vecs or not lo_vecs:
    raise RuntimeError(
        "Cannot build the RAST vector: one of the high/low pools is empty "
        f"(hi={len(hi_vecs)}, lo={len(lo_vecs)}); adjust DKL_THR or GEN_LIMIT."
    )
vec_dir  = (np.mean(hi_vecs, axis=0) - np.mean(lo_vecs, axis=0)).astype(np.float32)
rast_vec = SteeringVector({STEER_LAY: vec_dir})
print("Vector ready.\n")

# ───────── RAST-GENERATION ─────────────────────────────────────────────────
@torch.no_grad()
def rast_generate(prompt, vec):
    """Generate up to MAX_TOKENS tokens for `prompt` with KL-gated steering.

    Args:
        prompt: text to condition on.
        vec: steering vector handed to `model.set_steering`.

    Returns:
        Squeezed tensor of prompt + generated token ids.

    After each step the current logits are compared with the logits from
    K_WIN steps earlier. When the divergence is below DKL_THR the steering
    coefficient for the following step is set to ALPHA_HI, otherwise to 0.
    Generation stops at EOS, or as soon as the generated text (the prompt is
    excluded) contains the stop marker encoded by STOP_IDS.
    """
    ids = tok(prompt, return_tensors="pt").to(DEVICE)["input_ids"]
    prompt_len = ids.shape[1]
    # Compare decoded text rather than single ids: the marker spans several
    # tokens, and stopping on any one of them would end generation far too early.
    stop_text = tok.decode(STOP_IDS)
    tail_len = len(STOP_IDS) + 2  # small margin for in-context tokenisation
    # The buffer holds K_WIN + 1 snapshots, so buf[0] is K_WIN steps old once full.
    past, buf = None, deque(maxlen=K_WIN + 1)
    model.set_steering(vec, coeff=0.0)
    try:
        for _ in range(MAX_TOKENS):
            # The first step must consume the whole prompt; afterwards only
            # the newest token is fed because the cache covers the rest.
            step_input = ids if past is None else ids[:, -1:]
            # `repetition_penalty` is a generate() option, not a forward()
            # argument, so it is not passed here.
            out = model(input_ids=step_input, past_key_values=past, use_cache=True)
            # Keep the full cache: trimming it to the last K_WIN entries
            # discarded the prompt and most of the context.
            past = out.past_key_values
            logits = out.logits[:, -1, :]
            buf.append(logits.detach())
            if len(buf) > K_WIN and kl_div(logits, buf[0]) < DKL_THR:
                model.coeff = ALPHA_HI
            else:
                model.coeff = 0.0
            nxt = sample_next(logits)
            ids = torch.cat([ids, nxt], dim=-1)
            if nxt.item() == tok.eos_token_id:
                break
            generated = ids[0, prompt_len:]
            if stop_text in tok.decode(generated[-tail_len:]):
                # NOTE(review): stopping at the marker ends the trace before any
                # number that would follow it — confirm this is intended.
                break
    finally:
        # Always detach the steering vector, even if generation raised, so a
        # failure here cannot leave later calls steered.
        model.reset_steering()
    return ids.squeeze()

# ───────── EVALUATION ───────────────────────────────────────────────────────
# Compare unsteered and RAST-steered generation on the evaluation rows:
# generated-token counts and trailing-number accuracy.
print(f"Evaluating on {len(eval_rows)} problems …")
tok_base, tok_rast, acc_base, acc_rast = [], [], [], []
base_txts, rast_txts = [], []

for row in tqdm(eval_rows):
    prompt    = f"Question: {row['question']}.{SUFFIX}"
    # Both generators return prompt + generated ids; the prompt length is
    # subtracted so the identical prompt does not dilute the saving figure.
    prompt_len = len(tok(prompt)["input_ids"])
    base_ids, _ = stream(prompt, MAX_TOKENS)
    rast_ids    = rast_generate(prompt, rast_vec)

    base_txt = tok.decode(base_ids, skip_special_tokens=True)
    rast_txt = tok.decode(rast_ids, skip_special_tokens=True)

    base_txts.append(base_txt)
    # Store the decoded text; appending `rast_txts` to itself was a typo.
    rast_txts.append(rast_txt)

    tok_base.append(base_ids.numel() - prompt_len)
    tok_rast.append(rast_ids.numel() - prompt_len)

    acc_base.append(numeric_match(base_txt, row["answer"]))
    acc_rast.append(numeric_match(rast_txt, row["answer"]))

mean_base, mean_rast = np.mean(tok_base), np.mean(tok_rast)
print("──────── RESULTS ────────")
print(f"Mean tokens baseline : {mean_base:.1f}")
print(f"Mean tokens RAST     : {mean_rast:.1f}")
print(f"Token saving         : {100*(mean_base-mean_rast)/mean_base:.1f}%")
print(f"Accuracy baseline    : {np.mean(acc_base):.3f}")
print(f"Accuracy RAST        : {np.mean(acc_rast):.3f}")
print("────────────────────────")


Building RAST vector from 50 traces …


 64%|██████▍   | 32/50 [14:09<08:01, 26.77s/it]