In [3]:
#!/usr/bin/env python3
# ===============================================================
#   RASPID 2 — Concise-Reasoning Steering with PID control
#   ▸ KL vs cached baseline   ▸ gated steering window
#   ▸ length-overshoot term   ▸ low-gain PID (α≤0.04)
#   ▸ broad stop regex        ▸ 10-problem GSM8K benchmark
# ===============================================================
import torch, math, re, time, random, pandas as pd, numpy as np
from tqdm import tqdm
from transformers import AutoTokenizer, AutoModelForCausalLM
from datasets import load_dataset
from repeng import ControlModel, ControlVector, DatasetEntry

# ─── model & tokenizer ────────────────────────────────────────────────────
MODEL_NAME = "deepseek-ai/DeepSeek-R1-Distill-Qwen-1.5B"
DEVICE     = "cuda" if torch.cuda.is_available() else "cpu"
DTYPE      = torch.float32
tokenizer  = AutoTokenizer.from_pretrained(MODEL_NAME)
base_model = AutoModelForCausalLM.from_pretrained(
    MODEL_NAME, torch_dtype=DTYPE, device_map="auto" if DEVICE == "cuda" else None
).eval()
control_model = ControlModel(base_model, [20]).to(DEVICE)

# ─── helper metrics ───────────────────────────────────────────────────────
def sanitize(t): return torch.nan_to_num(torch.clamp(t, -100.0, 100.0), nan=0.0)
    
def safe_entropy(v):
    p = torch.softmax(v, dim=-1) + 1e-10
    return float((-p * torch.log(p)).sum())
    
def safe_kl_div(a, b):
    pa, pb = (torch.softmax(x, dim=-1) + 1e-10 for x in (a, b))
    return float((pa * (torch.log(pa) - torch.log(pb))).sum())

# ─── control vector: train on many trimmed examples (you supply dataset) ──
ds = load_dataset("rb/aime_reasoning", "default")["train"].select(range(100))
def trim(txt):  # keep only first 3 lines & boxed answer
    txt = re.split(r"\\boxed", txt)[0]
    return "\n".join(txt.splitlines()[:3]).strip()
    
pairs = [DatasetEntry(r["refined_reasoning"], r["reasoning_content"]) for r in ds]
ctrl_vec = ControlVector.train(control_model, tokenizer, pairs, batch_size=1)

# ─── baseline generator (unchanged) ───────────────────────────────────────
def generate_baseline(prompt, max_tokens=2048, temp=0.3, top_p=0.9):
    inp = tokenizer(prompt, return_tensors="pt").to(DEVICE)
    out = base_model.generate(
        **inp, max_new_tokens=max_tokens, do_sample=True, temperature=temp,
        top_p=top_p, pad_token_id=tokenizer.eos_token_id
    )
    return tokenizer.decode(out[0], skip_special_tokens=True), out.shape[1]-inp.input_ids.shape[1]

# ─── RASPID 2 generator ───────────────────────────────────────────────────
@torch.inference_mode()
def generate_raspid(
    prompt: str,
    *,
    max_tokens:int = 2048,
    stop_regex:str = r"\\boxed\{.*?\}",
    # warm-up & steering window
    free_tokens:int = 40,
    steer_window:int = 120,
    chunk:int = 24,
    # PID
    kp=0.04, ki=0.004, kd=0.0,
    max_i=0.20, max_alpha=0.04, deriv_cap=0.005,
    # targets
    kl_star=0.10, h_star=5.2, w_kl=1.0, w_h=0.4,
    len_star=8, w_len=0.015,            # length penalty (chunks over desired)
    # sampling
    base_temp=0.65, steer_temp=0.50,
    # repetition
    ngram_n=3, ngram_win=15, repeat_pen=1.0,
    log_chunks=False
):
    stop_re = re.compile(stop_regex)
    ids = tokenizer(prompt, return_tensors="pt").to(DEVICE).input_ids[0]
    out_ids, past = ids.clone(), None

    # cache baseline logits for first pass (un-steered)
    baseline_logits_cache = []

    α=I=D=prev_err=0.0
    chunk_sum=baseline_sum=None; tok_in_chunk=0
    recent, ngram_mem=[],[]
    text_buf=""
    steering_enabled=False; steer_start=None

    for t in range(max_tokens):
        gen_len=len(out_ids)-len(ids)

        # decide if steering should be active
        if not steering_enabled and ("Answer step by step" in text_buf or gen_len>=free_tokens):
            steering_enabled=True; steer_start=gen_len
        if steering_enabled and gen_len-steer_start>steer_window:
            steering_enabled=False; α=0.0   # deactivate after window

        coeff=α if steering_enabled else 0.0
        control_model.set_control(ctrl_vec, coeff=coeff)

        # forward pass (steered or not)
        out=control_model(input_ids=out_ids[-1:].unsqueeze(0), past_key_values=past, use_cache=True)
        past,logits=out.past_key_values,sanitize(out.logits[0,-1,:])

        # store baseline logits (only need once per chunk)
        baseline_logits_cache.append(logits.cpu())
        if len(baseline_logits_cache)>chunk: baseline_logits_cache.pop(0)

        # repetition
        recent.append(int(torch.argmax(logits))); recent=recent[-ngram_win:]
        if len(recent)>=ngram_n and tuple(recent[-ngram_n:]) in ngram_mem:
            logits[recent[-1]]-=repeat_pen
        else:
            ngram_mem.append(tuple(recent[-ngram_n:])); ngram_mem=ngram_mem[-300:]

        # PID every chunk
        chunk_sum = logits if chunk_sum is None else chunk_sum+logits
        baseline_sum = baseline_logits_cache[0] if baseline_sum is None else baseline_sum+baseline_logits_cache[0]
        tok_in_chunk+=1
        if tok_in_chunk>=chunk:
            v     = chunk_sum/tok_in_chunk
            v_ref = baseline_sum/tok_in_chunk
            KL,H  = safe_kl_div(v,v_ref),safe_entropy(v)
            len_err = max(0, (gen_len//chunk) - len_star)
            err  = (w_kl*(KL-kl_star) + w_h*(H-h_star) + w_len*len_err)

            I = max(min(max_i, I*0.9+err), -max_i)
            d_err = max(min(err-prev_err, deriv_cap), -deriv_cap)
            D = 0.7*D+0.3*d_err; prev_err=err
            α = max(0.0, min(kp*err+ki*I+kd*D, max_alpha))

            if log_chunks:
                print(f"[{gen_len:3}] KL={KL:.3f} H={H:.2f} lenE={len_err:.2f} α={α:.3f}")
            chunk_sum=baseline_sum=None; tok_in_chunk=0; baseline_logits_cache.clear()

        temp = steer_temp if coeff>0.015 else base_temp
        probs=torch.softmax(logits/temp,dim=-1)
        nxt=torch.multinomial(probs,1).item()
        out_ids=torch.cat([out_ids,torch.tensor([nxt],device=DEVICE)])

        text_buf=(text_buf+tokenizer.decode([nxt]))[-300:]
        if stop_re.search(text_buf) or nxt==tokenizer.eos_token_id:
            break

    return tokenizer.decode(out_ids,skip_special_tokens=True),len(out_ids)-len(ids)

# ─── answer helpers ───────────────────────────────────────────────────────
def extract_answer(txt):
    m=re.search(r'\\boxed\{\s*([^{}]+?)\s*\}',txt)
    if m: return m.group(1).strip()
    m=re.search(r'Final answer:?\s*([-\d\.]+)',txt,re.I)
    if m: return m.group(1)
    nums=re.findall(r'(\d+\.?\d*)',txt)
    return nums[-1] if nums else None
def norm(ans):
    if ans is None: return None
    ans=re.sub(r'[$,\s]','',ans)
    try:f=float(ans);return str(int(f)) if f.is_integer() else str(f)
    except:return ans

# ─── mini-benchmark 10 GSM8K problems ─────────────────────────────────────
def run_gsm8k(n=10):
    gsm=load_dataset("gsm8k","main")["test"]
    sample=random.sample(list(gsm),n)
    rows=[]
    for i,p in enumerate(tqdm(sample,desc="GSM8K"),1):
        q=p["question"]; ref=norm(re.search(r'#### (\d+\.?\d*)',p["answer"]).group(1))
        pr=f"{q}\n\nAnswer step by step and end with: Final answer: \\boxed{{numeric_value}}"
        base_txt,base_tok=generate_baseline(pr,1024)
        rast_txt,rast_tok=generate_raspid(pr)
        rows.append({
            "idx":i,"ref":ref,
            "b_tok":base_tok,"r_tok":rast_tok,
            "b_ok":norm(extract_answer(base_txt))==ref,
            "r_ok":norm(extract_answer(rast_txt))==ref
        })
    df=pd.DataFrame(rows)
    print("\nBaseline acc {:.1f}% | RASPID acc {:.1f}%"
          .format(df.b_ok.mean()*100,df.r_ok.mean()*100))
    print("Avg tokens  base {:,.0f} | raspid {:,.0f}"
          .format(df.b_tok.mean(),df.r_tok.mean()))
    return df

if __name__=="__main__":
    run_gsm8k()

  6%|▌         | 11/200 [00:03<01:15,  2.50it/s]Token indices sequence length is longer than the specified maximum sequence length for this model (19690 > 16384). Running this sequence through the model will result in indexing errors
  6%|▌         | 11/200 [00:08<02:20,  1.34it/s]


OutOfMemoryError: CUDA out of memory. Tried to allocate 674.00 MiB. GPU 0 has a total capacity of 39.49 GiB of which 501.38 MiB is free. Including non-PyTorch memory, this process has 39.00 GiB memory in use. Of the allocated memory 36.73 GiB is allocated by PyTorch, and 1.77 GiB is reserved by PyTorch but unallocated. If reserved but unallocated memory is large try setting PYTORCH_CUDA_ALLOC_CONF=expandable_segments:True to avoid fragmentation.  See documentation for Memory Management  (https://pytorch.org/docs/stable/notes/cuda.html#environment-variables)