In [1]:
#!/usr/bin/env python3
# ===============================================================
#         RASPID  –  Reasoning-Aware Steering with PID control
#         Simple Implementation
# ===============================================================
#  pip install transformers datasets repeng tqdm

import os, re, random, warnings
import numpy as np, torch
from datasets import load_dataset
from transformers import AutoTokenizer, AutoModelForCausalLM
from repeng import ControlModel, ControlVector, DatasetEntry
from tqdm import tqdm

# ---------------- CONFIG ---------------------------------------------------
MODEL_NAME = "deepseek-ai/DeepSeek-R1-Distill-Qwen-1.5B"

# Half precision only when a GPU is present; full precision on CPU.
if torch.cuda.is_available():
    DEVICE, DTYPE = "cuda", torch.float16
else:
    DEVICE, DTYPE = "cpu", torch.float32

# Steering / generation
STEER_LAYER = 20      # index of the layer wrapped by ControlModel (vector injection point)
TRAIN_N = 600         # number of (positive, negative) pairs used to train the vector
CHUNK = 16            # tokens per "thought" chunk evaluated by the controller
MAX_TOKENS = 4096     # generation length cap (2 * 2048)
DEBUG = True          # print controller diagnostics

# PID controller: setpoints, gains, output range
KL_SETPOINT = 0.03    # desired KL between consecutive chunk distributions
ENT_SETPOINT = 4.0    # desired next-token entropy (nats)
KP = 6.0
KI = 0.15
KD = 0.8
ALPHA_MIN = 0.0
ALPHA_MAX = 2.0

# Sampling
TOP_P = 0.92
TEMP = 0.75
REP_PENALTY = 1.15

# Reproducibility and environment
random.seed(42)
np.random.seed(42)
torch.manual_seed(42)
warnings.filterwarnings("ignore")
os.environ["TOKENIZERS_PARALLELISM"] = "false"

# ---------------- LOAD MODEL ----------------------------------------------
print("⇢ Loading model …")
tok = AutoTokenizer.from_pretrained(MODEL_NAME, use_fast=True)

# On GPU let accelerate place the weights (device_map="auto"); on CPU load plainly.
base = AutoModelForCausalLM.from_pretrained(
    MODEL_NAME,
    torch_dtype=DTYPE,
    device_map=("auto" if DEVICE == "cuda" else None),
)
base.eval()  # switch to eval mode; same object is kept as `base`

# Wrap the network so a control vector can be added at STEER_LAYER.
# NOTE(review): repeng documents that ControlModel mutates the wrapped model,
# so calls through `base` also pass through the control hooks — confirm for
# the installed repeng version.
model = ControlModel(base, [STEER_LAYER]).to(DEVICE)

def kv_trim(cache, k_win=256):
    """Keep only the last ``k_win`` positions of every layer's key/value tensors.

    Expects the legacy tuple-of-``(key, value)`` cache layout (works for
    HF <= 4.46) with the sequence axis second-to-last, as the
    ``[..., start:, :]`` slicing assumes. ``None`` is passed through unchanged.
    A window longer than the cache keeps everything; ``k_win <= 0`` keeps no
    positions.

    NOTE(review): not called in the code shown here. Depending on how the model
    derives position ids from the cache length, trimming may shift positions
    for later tokens — verify before using it in a generation loop.
    """
    if cache is None:
        return None

    def _tail(t):
        # Explicit start index: the plain ``-k_win`` slice used previously
        # becomes ``0:`` (keep everything) when k_win == 0.
        start = max(t.shape[-2] - k_win, 0)
        return t[..., start:, :].contiguous()

    return tuple((_tail(k), _tail(v)) for k, v in cache)

# ---------------- TRAIN DENSE-REASONING VECTOR ----------------------------
print("\n• building dense-reasoning control vector …")
ds = load_dataset("rb/aime_reasoning", "default")["train"]
# Dataset.select() rejects out-of-range indices, so never request more rows
# than the split actually contains.
ds = ds.select(range(min(TRAIN_N, len(ds))))

_NON_PRINTABLE_ASCII = re.compile(r"[^\x20-\x7E]")


def strip(txt):
    """Replace each character outside printable ASCII (0x20-0x7E) with a space.

    ASCII only for training. Note that newlines and tabs are outside that range
    too, so they also become spaces.
    """
    return _NON_PRINTABLE_ASCII.sub(" ", txt)


# DatasetEntry(positive, negative): refined (dense) reasoning vs. the raw trace.
pairs = [
    DatasetEntry(strip(r["refined_reasoning"]), strip(r["reasoning_content"]))
    for r in ds
]

ctrl_vec = ControlVector.train(model, tok, pairs, batch_size=1)

# ---------------- HELPER FUNCTIONS ----------------------------------------
def sanitize(t):
    """Return ``t`` bounded to [-80, 80], with any NaN entries replaced by -80.

    Infinities are handled by the clamp (+inf -> 80, -inf -> -80); NaNs survive
    the clamp and are then mapped to -80.
    """
    bounded = torch.clamp(t, min=-80.0, max=80.0)
    return torch.nan_to_num(bounded, nan=-80.0)

def entropy(logits):
    """Return the Shannon entropy (nats) of ``softmax(logits, dim=-1)`` as a float.

    The sum runs over all elements, so this is intended for a single 1-D logit
    vector. It is computed from ``log_softmax`` in float32, and terms whose
    probability underflows to 0 contribute 0. The previous ``p * p.log()``
    form produced ``0 * -inf = NaN`` for those terms, which is common in fp16
    over a large vocabulary.
    """
    logp = torch.log_softmax(logits.float(), dim=-1)
    p = logp.exp()
    # Mask p == 0 so that -inf logits (logp = -inf) cannot produce NaN either.
    return float(-(p * logp).masked_fill(p == 0, 0.0).sum())

def kl_div(a, b):
    """Return KL(P || Q) as a float, where P = softmax(a) and Q = softmax(b) over dim -1.

    The computation uses ``log_softmax`` in float32 and masks terms where
    P == 0. The previous ``pa * (pa.log() - pb.log())`` form gave NaN whenever
    a probability underflowed to 0. The result is +inf only when P > 0 where
    Q == 0, which is the mathematically correct value.
    """
    logp = torch.log_softmax(a.float(), dim=-1)
    logq = torch.log_softmax(b.float(), dim=-1)
    p = logp.exp()
    return float((p * (logp - logq)).masked_fill(p == 0, 0.0).sum())

# ---------------- SIMPLE NUCLEUS SAMPLING ---------------------------------
def sample_token(logits_np, temperature=TEMP, top_p=TOP_P, rep_penalty=REP_PENALTY, token_history=None):
    """Sample one token id using a repetition penalty, temperature and nucleus (top-p) filtering.

    Args:
        logits_np: 1-D array-like of next-token logits. It is copied and never
            modified in place. On CPU, ``Tensor.numpy()`` shares memory with
            the tensor, so in-place edits would leak back to the caller.
        temperature: Softmax temperature. Values <= 0 return the arg-max
            (greedy decoding) after the repetition penalty is applied.
        top_p: Nucleus mass. The smallest set of most-probable tokens whose
            cumulative probability reaches ``top_p`` is kept, including the
            token that crosses the threshold. At least one token is always kept.
        rep_penalty: For every occurrence of a token among the last 50 entries
            of ``token_history``, ``log(rep_penalty)`` is subtracted from that
            token's logit.
        token_history: Previously generated token ids, or ``None``.

    Returns:
        The sampled token id as a Python ``int``.
    """
    # float64 copy: never mutates the caller's buffer and avoids fp16 round-off
    # when summing probabilities over the whole vocabulary.
    logits = np.array(logits_np, dtype=np.float64)

    if token_history:
        # Frequency-scaled penalty over a 50-token window; ids are unique, so
        # fancy-index subtraction applies each penalty exactly once.
        ids, counts = np.unique(np.asarray(token_history[-50:]), return_counts=True)
        logits[ids] -= np.log(rep_penalty) * counts

    if temperature <= 0:
        return int(np.argmax(logits))

    scaled = np.clip(logits / temperature, -80, 80)
    probs = np.exp(scaled - scaled.max())
    probs /= probs.sum()

    order = np.argsort(-probs)  # most probable first
    cumulative = np.cumsum(probs[order])
    # searchsorted gives the first index whose running mass reaches top_p; +1
    # keeps that crossing token. An index past the end simply keeps all tokens.
    n_keep = int(np.searchsorted(cumulative, top_p)) + 1
    keep = order[:n_keep]

    kept_probs = probs[keep]  # fancy indexing -> copy
    kept_probs /= kept_probs.sum()
    return int(np.random.choice(keep, p=kept_probs))

# ---------------- BASE GENERATION FUNCTION -------------------------------
@torch.inference_mode()
def base_generate(prompt, max_tokens=MAX_TOKENS):
    """Generate up to ``max_tokens`` tokens from ``prompt`` without steering.

    This uses the same logit sanitisation (``sanitize``) and sampler
    (``sample_token``) as the RASPID path, so the two runs are comparable.

    Args:
        prompt: Raw prompt text, tokenised with the module-level ``tok``.
        max_tokens: Maximum number of new tokens to sample.

    Returns:
        1-D tensor of token ids: the prompt followed by the generated tokens,
        including the EOS token if one was sampled.
    """
    print("\nRunning baseline generation...")

    prompt_ids = tok(prompt, return_tensors="pt").to(DEVICE).input_ids[0]
    prompt_len = len(prompt_ids)
    token_history = []
    past = None

    # `base` is wrapped by the ControlModel `model`; make sure no steering
    # vector is left active from an earlier (possibly interrupted) run.
    model.reset()

    print(f"Prompt length: {prompt_len} tokens")
    print("Starting generation...")

    # The first pass must feed the WHOLE prompt to build the KV cache; feeding
    # only its last token (as before) meant the model never saw the prompt.
    # Afterwards only the newest token is fed.
    step_ids = prompt_ids.unsqueeze(0)
    for _ in tqdm(range(max_tokens)):
        outputs = base(input_ids=step_ids, past_key_values=past, use_cache=True)
        past = outputs.past_key_values

        logits = sanitize(outputs.logits[0, -1, :])
        next_token = sample_token(logits.cpu().numpy(), token_history=token_history)

        token_history.append(next_token)
        step_ids = torch.tensor([[next_token]], device=prompt_ids.device, dtype=prompt_ids.dtype)

        if next_token == tok.eos_token_id:
            print("EOS token detected, stopping generation")
            break

    # Concatenate once instead of re-allocating the sequence on every step.
    generated = torch.tensor(token_history, device=prompt_ids.device, dtype=prompt_ids.dtype)
    input_ids = torch.cat([prompt_ids, generated], dim=0)

    print(f"Base generation complete. Total tokens: {len(input_ids)}")
    return input_ids

# ---------------- RASPID GENERATION FUNCTION -----------------------------
def _raspid_pid_step(err, integ, prev_err):
    """Apply one PID update; return ``(alpha, integ, prev_err)`` for the next call.

    The integral term is clamped to [-2, 2] (anti-windup) and the output to
    [ALPHA_MIN, ALPHA_MAX].
    """
    integ = float(np.clip(integ + err, -2.0, 2.0))
    deriv = err - prev_err
    alpha = KP * err + KI * integ + KD * deriv
    return float(np.clip(alpha, ALPHA_MIN, ALPHA_MAX)), integ, err


@torch.inference_mode()
def rast_generate(prompt, max_tokens=MAX_TOKENS):
    """Generate from ``prompt`` while a PID controller modulates the steering strength.

    Every ``CHUNK`` tokens, the mean logits of the chunk are compared with
    those of the previous chunk (``kl_div``), and the entropy of the latest
    next-token distribution is measured. The error
    ``(KL - KL_SETPOINT) + 0.3 * (H - ENT_SETPOINT)`` drives a PID controller.
    Its output ``alpha`` is passed as ``coeff`` to
    ``model.set_control(ctrl_vec, ...)``. Values at or below 1e-3 remove
    steering via ``model.reset()``.

    Args:
        prompt: Raw prompt text, tokenised with the module-level ``tok``.
        max_tokens: Maximum number of new tokens to sample.

    Returns:
        1-D tensor of token ids: the prompt followed by the generated tokens,
        including the EOS token if one was sampled.
    """
    print("\nRunning RASPID generation...")

    prompt_ids = tok(prompt, return_tensors="pt").to(DEVICE).input_ids[0]
    prompt_len = len(prompt_ids)
    token_history = []
    past = None

    # PID state
    integ = prev_err = 0.0
    alpha = 0.0

    # Chunk tracking
    chunk_sum = None
    tok_in_chunk = 0
    last_vec = None

    print(f"Prompt length: {prompt_len} tokens")
    print("Starting generation with RASPID steering...")

    # Start unsteered, and always leave the shared model unsteered even if
    # generation raises (e.g. out-of-memory or KeyboardInterrupt).
    model.reset()
    # The first pass feeds the WHOLE prompt to build the KV cache; feeding only
    # its last token (as before) meant the model never saw the prompt.
    step_ids = prompt_ids.unsqueeze(0)
    try:
        for i in tqdm(range(max_tokens)):
            outputs = model(input_ids=step_ids, past_key_values=past, use_cache=True)
            past = outputs.past_key_values
            logits = sanitize(outputs.logits[0, -1, :])

            # Accumulate logits for the chunk-level statistics.
            chunk_sum = logits.clone() if chunk_sum is None else chunk_sum + logits
            tok_in_chunk += 1

            if tok_in_chunk == CHUNK:
                vec = chunk_sum / CHUNK
                H = entropy(logits)

                if last_vec is not None:
                    kl = kl_div(vec, last_vec)
                    err = (kl - KL_SETPOINT) + 0.3 * (H - ENT_SETPOINT)

                    # A NaN/inf error would make the integrator NaN for the
                    # rest of the run; skip the update and keep current steering.
                    if np.isfinite(err):
                        alpha, integ, prev_err = _raspid_pid_step(err, integ, prev_err)
                        if alpha > 1e-3:
                            model.set_control(ctrl_vec, coeff=alpha)
                            if DEBUG and i % 50 < 5:  # Only print occasionally to avoid spam
                                print(f"Step {i}: KL={kl:.4f}, H={H:.2f}, α={alpha:.4f}")
                        else:
                            model.reset()
                    elif DEBUG:
                        print(f"Step {i}: non-finite PID error (KL={kl}, H={H}); skipping update")

                # Reset for next chunk
                last_vec = vec
                chunk_sum = None
                tok_in_chunk = 0

            # Control changes above take effect from the next forward pass.
            next_token = sample_token(logits.cpu().numpy(), token_history=token_history)
            token_history.append(next_token)
            step_ids = torch.tensor([[next_token]], device=prompt_ids.device, dtype=prompt_ids.dtype)

            if next_token == tok.eos_token_id:
                print("EOS token detected, stopping generation")
                break
    finally:
        model.reset()

    # Concatenate once instead of re-allocating the sequence on every step.
    generated = torch.tensor(token_history, device=prompt_ids.device, dtype=prompt_ids.dtype)
    input_ids = torch.cat([prompt_ids, generated], dim=0)

    print(f"RASPID generation complete. Total tokens: {len(input_ids)}")
    return input_ids

# ---------------- BENCHMARK ----------------------------------------------
print("\n• Running benchmark with GSM8K example...")
gsm = load_dataset("gsm8k", "main")["test"][0]
SUFFIX = "Answer step by step and end with: Final answer:"

# Add a small hint in the prompt to keep it on track (SUFFIX was previously
# defined but never appended).
prompt = f"Question: {gsm['question']}\n{SUFFIX}"

# Run both generators
ids_base = base_generate(prompt)
ids_rast = rast_generate(prompt)

# Display results
baseline_text = tok.decode(ids_base, skip_special_tokens=True)
raspid_text = tok.decode(ids_rast, skip_special_tokens=True)

# Both returned sequences start with the same prompt tokens. Compare only the
# newly generated tokens so the shared prompt does not dilute the saving.
prompt_len = len(tok(prompt).input_ids)
base_new = len(ids_base) - prompt_len
rast_new = len(ids_rast) - prompt_len
token_saving = (1 - rast_new / base_new) * 100 if base_new > 0 else float("nan")

print("\n──────── BASELINE OUTPUT ───────────────")
print(baseline_text)
print("\n──────── RASPID OUTPUT ───────────────")
print(raspid_text)
print("─────────────────────────────────")

print("\n──────── SUMMARY ───────────────")
print(f"BASE: {len(ids_base)} tokens ({base_new} generated)")
print(f"RAST: {len(ids_rast)} tokens ({rast_new} generated)")
print(f"Token saving (generated tokens): {token_saving:.1f}%")
print("─────────────────────────────────")