In [1]:
# !pip install repeng accelerate datasets matplotlib seaborn

In [152]:
import warnings
warnings.filterwarnings("ignore", "To copy construct from a tensor", UserWarning)

In [2]:
#!/usr/bin/env python3
"""
RASPID with dynamic PID-steering:
- Train chunk-level classifier and control vector on the first 1000 labeled chains
- Hold out the last 200 chains for final GSM8K evaluation only
"""

import os, re, math, warnings
import numpy as np
import pandas as pd
import torch
from tqdm.auto import tqdm
from datasets import load_dataset
from transformers import AutoTokenizer, AutoModelForCausalLM
from repeng import ControlModel
from sklearn.linear_model import SGDClassifier
from sklearn.model_selection import train_test_split
from sklearn.metrics import accuracy_score

# ─── 1) CONFIG & LOAD ──────────────────────────────────────────────────────

MODEL_NAME    = "deepseek-ai/DeepSeek-R1-Distill-Qwen-1.5B"
DEVICE        = "cuda" if torch.cuda.is_available() else "cpu"
DTYPE         = torch.float32

LABELED_CSV   = "gsm8k_chains_labeled_with_tokens.csv"
CTRL_VEC_PATH = "ctrl_vector.pt"

# classifier hyperparams
EMB_LAYER     = 20
CHUNK_SIZES   = [16,24]
BATCH_SIZE    = 32
FLUFF_STAR    = 0.5   # target probability for “redundant”

# PID steering hyperparams
INIT_FREE     = 40
STEER_WINDOW  = 60
KP, KI, KD    = 0.05, 0.001, 0.001
MAX_I, DERIV  = 0.20, 0.01
MAX_ALPHA     = 0.40
BASE_TEMP     = 0.70
STEER_TEMP    = 0.20
MAX_REPEAT    = 8

# load tokenizer & models
tokenizer    = AutoTokenizer.from_pretrained(MODEL_NAME, trust_remote_code=True)
base_model   = AutoModelForCausalLM.from_pretrained(
    MODEL_NAME, torch_dtype=DTYPE,
    device_map="auto" if DEVICE=="cuda" else None
).eval()
control_model = ControlModel(base_model, [EMB_LAYER])

Sliding Window Attention is enabled but not implemented for `sdpa`; unexpected results may be encountered.


In [3]:
# ─── 2) SPLIT LABELED DATA ─────────────────────────────────────────────────

df_all = pd.read_csv(LABELED_CSV)

# first 1000 for training classifier & control vector
df_ctrl = df_all.iloc[:1000]
# last 200 reserved for later (but not used to train classifier)
df_eval = df_all.iloc[1000:1200]

required_ctrl  = df_ctrl["required_thoughts"].fillna("")
redundant_ctrl = df_ctrl["redundant_thoughts"].fillna("")

In [4]:
import os
from torch.utils.data import Dataset, DataLoader, ConcatDataset
import torch
from sklearn.linear_model import SGDClassifier
from sklearn.metrics import accuracy_score
from sklearn.model_selection import train_test_split
from tqdm.auto import tqdm
import numpy as np

# Disable HuggingFace tokenizer parallel warnings
os.environ["TOKENIZERS_PARALLELISM"] = "false"

class ChunkDataset(Dataset):
    def __init__(self, texts, label, cs):
        self.cs = cs
        self.chunks = []
        self.labels = []
        for t in texts:
            tok_ids = tokenizer.encode(t, add_special_tokens=False)
            for i in range(0, len(tok_ids) - cs + 1, cs):
                self.chunks.append(tok_ids[i : i + cs])
                self.labels.append(label)
    def __len__(self):
        return len(self.chunks)
    def __getitem__(self, idx):
        return self.chunks[idx], self.labels[idx]

def collate_fn(batch):
    input_ids, labels = zip(*batch)
    # pad on CPU so pin_memory works
    seqs = [torch.tensor(ids, dtype=torch.long) for ids in input_ids]
    padded = torch.nn.utils.rnn.pad_sequence(
        seqs, batch_first=True, padding_value=tokenizer.pad_token_id
    )
    attention_mask = (padded != tokenizer.pad_token_id).long()
    return {"input_ids": padded, "attention_mask": attention_mask}, torch.tensor(labels, dtype=torch.long)

best_cs, best_acc, best_clf = None, 0.0, None

for cs in CHUNK_SIZES:
    # build datasets
    ds0 = ChunkDataset(required_ctrl, 0, cs)
    ds1 = ChunkDataset(redundant_ctrl, 1, cs)
    loader = DataLoader(
        ConcatDataset([ds0, ds1]),
        batch_size=BATCH_SIZE,
        collate_fn=collate_fn,
        shuffle=False,
        pin_memory=True,
        num_workers=0,  # no multiprocessing to avoid fork issues
    )

    # embed all chunks
    feats, labels = [], []
    base_model.eval()
    with torch.no_grad():
        for batch_tokens, batch_labels in tqdm(loader):
            # move to GPU
            batch_tokens = {
                k: v.to(DEVICE, non_blocking=True)
                for k, v in batch_tokens.items()
            }
            out = base_model(**batch_tokens, output_hidden_states=True)
            h = out.hidden_states[EMB_LAYER].mean(1)
            feats.append(h.cpu().numpy())
            labels.append(batch_labels.numpy())
    X = np.vstack(feats)
    y = np.concatenate(labels)

    # train/validation split
    Xtr, Xval, ytr, yval = train_test_split(
        X, y, test_size=0.2, random_state=42, stratify=y
    )

    # SGD training with tqdm progress
    clf = SGDClassifier(
        loss="log_loss", random_state=42, warm_start=True, max_iter=1, tol=None
    )
    prev_coef = None
    pbar = tqdm(range(1000), desc=f"Training clf @ cs={cs}", leave=False)
    for _ in pbar:
        clf.fit(Xtr, ytr)
        coef = clf.coef_
        if prev_coef is not None:
            delta = np.max(np.abs(coef - prev_coef))
            pbar.set_postfix(delta=delta)
            if delta < 1e-3:
                break
        prev_coef = coef.copy()
    pbar.close()

    acc = accuracy_score(yval, clf.predict(Xval))
    print(f"chunk_size={cs} → val_acc={acc:.3f}")
    if acc > best_acc:
        best_cs, best_acc, best_clf = cs, acc, clf

print(f"✔ Selected chunk_size={best_cs}, val_acc={best_acc:.3f}")

  0%|          | 0/1590 [00:00<?, ?it/s]

Training clf @ cs=16:   0%|          | 0/1000 [00:00<?, ?it/s]

chunk_size=16 → val_acc=0.291


  0%|          | 0/1049 [00:00<?, ?it/s]

Training clf @ cs=24:   0%|          | 0/1000 [00:00<?, ?it/s]

chunk_size=24 → val_acc=0.814
✔ Selected chunk_size=24, val_acc=0.814


In [5]:
# ─── 4) BUILD CONTROL VECTOR FROM CTRL SET ────────────────────────────────

def mean_hidden(texts):
    vs = []
    for t in tqdm(texts):
        toks = tokenizer(t, return_tensors="pt", truncation=True).to(DEVICE)
        with torch.inference_mode():
            h = base_model(**toks, output_hidden_states=True).hidden_states[EMB_LAYER][0]
        vs.append(h.mean(0).cpu())
    return torch.stack(vs).mean(0)

v_req = mean_hidden(required_ctrl)
v_red = mean_hidden(redundant_ctrl)
ctrl_vec = {EMB_LAYER: (v_req - v_red).to(DEVICE)}
torch.save(ctrl_vec, CTRL_VEC_PATH)
print("✅ Saved control vector")

# ─── 5) GENERATION ROUTINES ───────────────────────────────────────────────

  0%|          | 0/1000 [00:00<?, ?it/s]

  0%|          | 0/1000 [00:00<?, ?it/s]

✅ Saved control vector


In [6]:
from repeng import ControlVector

# model_type is a short string identifying your model, e.g. "qwen" or whatever base_model.config.model_type gives
model_type = base_model.config.model_type  # e.g. "DeepSeek-R1-Distill-Qwen-1.5B"

ctrl_vec = ControlVector(
    model_type = model_type,
    directions={EMB_LAYER: (v_req - v_red).to(DEVICE)}
)

# now you can save it
torch.save(ctrl_vec, "ctrl_vector.pt")

In [7]:
MAX_TOKENS = 4096

In [9]:
@torch.inference_mode()
def generate_baseline(prompt, max_new_tokens=MAX_TOKENS):
    inp = tokenizer(prompt, return_tensors="pt").to(DEVICE)
    out = base_model.generate(
        **inp,
        max_new_tokens=max_new_tokens,
        do_sample=True, temperature=0.6,
        top_p=0.9, repetition_penalty=1.2,
        pad_token_id=tokenizer.eos_token_id,
    )
    toks = out.shape[1] - inp.input_ids.shape[1]
    return tokenizer.decode(out[0], skip_special_tokens=True), toks

In [41]:
# tuned hyper-params
INIT_FREE      = 80          # let the model reason first
STEER_WINDOW   = 60
BASE_TEMP      = 0.60        # same as baseline
STEER_TEMP     = 0.30
STEER_MARGIN   = 0.20        # p_red must exceed 0.5+0.2
MAX_ALPHA      = 0.40
MAX_RAW        = 50.0        # clamp for exp()

In [85]:
@torch.inference_mode()
def generate_raspid(prompt, max_new_tokens=MAX_TOKENS, debug=True):
    """
    Fixed RASPID generation function that properly applies the control vector
    and handles numerical stability issues.
    """
    if debug:
        print(f"\n=== RASPID GENERATION ===")
        print(f"Original prompt: '{prompt}'")
        print(f"MAX_TOKENS: {MAX_TOKENS}, INIT_FREE: {INIT_FREE}, STEER_WINDOW: {STEER_WINDOW}")
        print(f"Temperatures: BASE_TEMP: {BASE_TEMP}, STEER_TEMP: {STEER_TEMP}")
        print(f"PID: KP={KP}, KI={KI}, KD={KD}, MAX_I={MAX_I}, MAX_ALPHA={MAX_ALPHA}")
        print(f"FLUFF_STAR: {FLUFF_STAR}, EMB_LAYER: {EMB_LAYER}")
    
    # For structured control vectors (like ControlVector with directions dict),
    # we should leave the structure intact
    ctrl_vec_normalized = ctrl_vec

    # Setup generation
    stop_re = re.compile(r"\\boxed\{[^{}]{1,12}\}")
    ids = tokenizer(prompt, return_tensors="pt").to(DEVICE).input_ids[0]
    out_ids = ids.clone()
    past = None
    alpha = I = D = prev_err = 0.0
    chunk_h = None
    tok_in_chunk = 0
    steering = False
    steer_start = 0
    last_tok = None
    rep_ctr = 0
    generated_text = ""
    
    if debug:
        print("\n--- RASPID TRACE ---")
        print("step | on | p_red |  err  |   α   |   I   |   D   | temp | token")
    
    pbar = tqdm(range(max_new_tokens), desc="RASPID gen", leave=False)
    for step in pbar:
        gen_len = out_ids.size(0) - ids.size(0)
        
        # Check if steering should be activated/deactivated
        if not steering and gen_len >= INIT_FREE:
            steering, steer_start = True, gen_len
            if debug:
                print(f"[INFO] Activating steering at gen_len={gen_len}")
        if steering and gen_len - steer_start > STEER_WINDOW:
            if debug:
                print(f"[INFO] Deactivating steering at gen_len={gen_len}")
            steering, alpha, I, D = False, 0.0, 0.0, 0.0
        
        # Get coefficient and apply control
        coeff = alpha if steering else 0.0
        control_model.set_control(ctrl_vec_normalized, coeff=coeff)
        
        # CRITICAL FIX: Process the entire prompt on the first step
        if gen_len == 0 and step == 0:
            # First step - process the entire prompt
            out = control_model(
                input_ids=out_ids.unsqueeze(0),  # Use the full prompt
                use_cache=True,
                output_hidden_states=True
            )
        else:
            # Subsequent steps - process only the new token with the cached state
            out = control_model(
                input_ids=out_ids[-1:].unsqueeze(0),  # Only the last token
                past_key_values=past,
                use_cache=True,
                output_hidden_states=True
            )
        
        past, logits = out.past_key_values, out.logits[0, -1]
        h_last = out.hidden_states[EMB_LAYER][0, -1]
        
        # Check for NaN values
        if torch.isnan(logits).any():
            logits = torch.nan_to_num(logits)
        if torch.isnan(h_last).any():
            h_last = torch.nan_to_num(h_last)
        
        # Check for token repetition
        tok = out_ids[-1].item()
        if tok == last_tok:
            rep_ctr += 1
            if rep_ctr >= MAX_REPEAT:
                if debug:
                    print(f"[INFO] Hit MAX_REPEAT={MAX_REPEAT}, stopping generation")
                break
        else:
            rep_ctr, last_tok = 0, tok
        
        # Normalize hidden states for stability
        h_last_norm = h_last / (torch.norm(h_last) + 1e-8)
        
        # Track hidden states with normalized values
        chunk_h = h_last_norm if chunk_h is None else chunk_h + h_last_norm
        tok_in_chunk += 1
        
        # Classifier and PID controller
        p_red = err = 0.0
        if tok_in_chunk >= best_cs:
            try:
                # Use normalized chunk_h for classifier
                classifier_input = (chunk_h / best_cs).cpu().unsqueeze(0).numpy()
                
                # Check for NaN values
                if np.isnan(classifier_input).any():
                    classifier_input = np.nan_to_num(classifier_input)
                
                # Get raw classifier output
                raw = best_clf.decision_function(classifier_input)[0]
                
                # Scale down extreme classifier values
                if abs(raw) > 50.0:
                    scaled_raw = 50.0 * (np.sign(raw) * np.log(1 + abs(raw) / 50.0) / np.log(1 + abs(raw) / 50.0 * 20))
                    if debug:
                        print(f"[INFO] Scaling down extreme classifier value: {raw:.2f} -> {scaled_raw:.2f}")
                    raw = scaled_raw
                else:
                    raw = max(-50.0, min(50.0, raw))
                
                p_red = 1.0 / (1.0 + math.exp(-raw))
                
                if p_red > FLUFF_STAR + 0.20:
                    err = p_red - FLUFF_STAR
                    
                    # Update PID controller
                    I = max(-MAX_I, min(MAX_I, I + KI * err))
                    D = KD * (err - prev_err) + (1 - KD) * D
                    prev_err = err
                    alpha = max(0.0, min(MAX_ALPHA, alpha + KP * err + I + D))
                
                # Reset chunk
                chunk_h = None
                tok_in_chunk = 0
                
            except Exception as e:
                if debug:
                    print(f"[ERROR] Error in classifier/PID section: {e}")
                chunk_h = h_last_norm
                tok_in_chunk = 1
        
        # Temperature and sampling
        temp = BASE_TEMP * (1 - coeff / MAX_ALPHA) + STEER_TEMP * (coeff / MAX_ALPHA)
        
        # Apply temperature and get probabilities
        try:
            logits_safe = logits.clamp(-100, 100)
            probs = torch.softmax(logits_safe / temp, dim=-1)
            
            if torch.isnan(probs).any():
                probs = torch.ones_like(probs) / probs.size(0)
            
            # Sample token
            nxt = torch.multinomial(probs, 1).item()
            token_str = tokenizer.decode([nxt], skip_special_tokens=True).replace("\n","\\n")
            
            # Add token to output
            generated_text += token_str
            out_ids = torch.cat([out_ids, torch.tensor([nxt], device=DEVICE)])
            
        except Exception as e:
            if debug:
                print(f"[ERROR] Error in sampling: {e}")
            # Fallback to argmax sampling
            nxt = torch.argmax(logits).item()
            token_str = tokenizer.decode([nxt], skip_special_tokens=True).replace("\n","\\n")
            generated_text += token_str
            out_ids = torch.cat([out_ids, torch.tensor([nxt], device=DEVICE)])
        
        # Print trace line if in debug mode
        if debug:
            print(f"{gen_len:4d} | {int(steering)} | {p_red:5.3f} | {err:5.3f} | "
                  f"{alpha:5.3f} | {I:5.3f} | {D:5.3f} | {temp:5.3f} | '{token_str}'")
        
        # Check stop condition
        if stop_re.search(generated_text) or "Final answer:" in generated_text:
            if debug:
                print(f"[INFO] Stop condition met, ending generation")
            break
    
    pbar.close()
    if debug:
        print("--- END TRACE ---\n")
        print(f"=== GENERATION RESULT ===")
        print(f"Total tokens generated: {out_ids.size(0) - ids.size(0)}")
        print(f"Final text: {tokenizer.decode(out_ids, skip_special_tokens=True)}")
    
    return tokenizer.decode(out_ids, skip_special_tokens=True), out_ids.size(0) - ids.size(0)

In [124]:
@torch.inference_mode()
def generate_raspid_fixed_v2(prompt, max_new_tokens=MAX_TOKENS, debug=True):
    """
    Enhanced RASPID generation function with improved parameters for more effective steering.
    
    Key improvements:
    1. Start steering earlier (reduced INIT_FREE)
    2. Stronger steering effect (increased KP and MAX_ALPHA)
    3. Consistent steering (longer STEER_WINDOW)
    4. Better control application
    """
    # Override default parameters with improved values
    INIT_FREE_ORIG = INIT_FREE  # Store original value
    STEER_WINDOW_ORIG = STEER_WINDOW  # Store original value
    KP_ORIG = KP  # Store original value
    MAX_ALPHA_ORIG = MAX_ALPHA  # Store original value
    
    # Override with enhanced parameters for better steering
    INIT_FREE_ENHANCED = 15  # Start steering much earlier (was 80)
    STEER_WINDOW_ENHANCED = 300  # Longer window to maintain steering (was 60)
    KP_ENHANCED = 0.15  # Stronger proportional control (was 0.05)
    MAX_ALPHA_ENHANCED = 1.0  # Allow stronger steering effect (was 0.4)
    
    # Use the enhanced parameters
    INIT_FREE_USED = INIT_FREE_ENHANCED
    STEER_WINDOW_USED = STEER_WINDOW_ENHANCED
    KP_USED = KP_ENHANCED
    MAX_ALPHA_USED = MAX_ALPHA_ENHANCED
    
    if debug:
        print(f"\n=== ENHANCED RASPID GENERATION ===")
        print(f"Original prompt: '{prompt}'")
        print(f"MAX_TOKENS: {max_new_tokens}")
        print(f"INIT_FREE: {INIT_FREE_USED} (was {INIT_FREE_ORIG}) - Start steering earlier")
        print(f"STEER_WINDOW: {STEER_WINDOW_USED} (was {STEER_WINDOW_ORIG}) - More consistent steering")
        print(f"KP: {KP_USED} (was {KP_ORIG}) - Stronger steering response")
        print(f"MAX_ALPHA: {MAX_ALPHA_USED} (was {MAX_ALPHA_ORIG}) - Higher maximum steering")
        print(f"Other parameters: KI={KI}, KD={KD}, MAX_I={MAX_I}")
        print(f"BASE_TEMP: {BASE_TEMP}, STEER_TEMP: {STEER_TEMP}")
        print(f"FLUFF_STAR: {FLUFF_STAR}, EMB_LAYER: {EMB_LAYER}")
    
    # For structured control vectors (like ControlVector with directions dict),
    # we should leave the structure intact
    ctrl_vec_normalized = ctrl_vec

    # Setup generation
    stop_re = re.compile(r"\\boxed\{[^{}]{1,12}\}")
    ids = tokenizer(prompt, return_tensors="pt").to(DEVICE).input_ids[0]
    out_ids = ids.clone()
    past = None
    alpha = I = D = prev_err = 0.0
    chunk_h = None
    tok_in_chunk = 0
    steering = False
    steer_start = 0
    last_tok = None
    rep_ctr = 0
    generated_text = ""
    
    if debug:
        print("\n--- RASPID TRACE ---")
        print("step | on | p_red |  err  |   α   |   I   |   D   | temp | token")
    
    pbar = tqdm(range(max_new_tokens), desc="RASPID gen", leave=False)
    for step in pbar:
        gen_len = out_ids.size(0) - ids.size(0)
        
        # Check if steering should be activated/deactivated
        if not steering and gen_len >= INIT_FREE_USED:
            steering, steer_start = True, gen_len
            if debug:
                print(f"[INFO] Activating steering at gen_len={gen_len}")
        if steering and gen_len - steer_start > STEER_WINDOW_USED:
            if debug:
                print(f"[INFO] Deactivating steering at gen_len={gen_len}")
            steering, alpha, I, D = False, 0.0, 0.0, 0.0
        
        # Get coefficient and apply control
        coeff = alpha if steering else 0.0
        
        # Apply control - now using the correct coefficient
        control_model.set_control(ctrl_vec_normalized, coeff=coeff)
        
        # CRITICAL FIX: Process the entire prompt on the first step
        if gen_len == 0 and step == 0:
            # First step - process the entire prompt
            out = control_model(
                input_ids=out_ids.unsqueeze(0),  # Use the full prompt
                use_cache=True,
                output_hidden_states=True
            )
        else:
            # Subsequent steps - process only the new token with the cached state
            out = control_model(
                input_ids=out_ids[-1:].unsqueeze(0),  # Only the last token
                past_key_values=past,
                use_cache=True,
                output_hidden_states=True
            )
        
        past, logits = out.past_key_values, out.logits[0, -1]
        h_last = out.hidden_states[EMB_LAYER][0, -1]
        
        # Check for NaN values
        if torch.isnan(logits).any():
            logits = torch.nan_to_num(logits)
        if torch.isnan(h_last).any():
            h_last = torch.nan_to_num(h_last)
        
        # Check for token repetition
        tok = out_ids[-1].item()
        if tok == last_tok:
            rep_ctr += 1
            if rep_ctr >= MAX_REPEAT:
                if debug:
                    print(f"[INFO] Hit MAX_REPEAT={MAX_REPEAT}, stopping generation")
                break
        else:
            rep_ctr, last_tok = 0, tok
        
        # Normalize hidden states for stability
        h_last_norm = h_last / (torch.norm(h_last) + 1e-8)
        
        # Track hidden states with normalized values
        chunk_h = h_last_norm if chunk_h is None else chunk_h + h_last_norm
        tok_in_chunk += 1
        
        # Classifier and PID controller
        p_red = err = 0.0
        if tok_in_chunk >= best_cs:
            try:
                # Use normalized chunk_h for classifier
                classifier_input = (chunk_h / best_cs).cpu().unsqueeze(0).numpy()
                
                # Check for NaN values
                if np.isnan(classifier_input).any():
                    classifier_input = np.nan_to_num(classifier_input)
                
                # Get raw classifier output
                raw = best_clf.decision_function(classifier_input)[0]
                
                # Scale down extreme classifier values
                if abs(raw) > 50.0:
                    scaled_raw = 50.0 * (np.sign(raw) * np.log(1 + abs(raw) / 50.0) / np.log(1 + abs(raw) / 50.0 * 20))
                    if debug:
                        print(f"[INFO] Scaling down extreme classifier value: {raw:.2f} -> {scaled_raw:.2f}")
                    raw = scaled_raw
                else:
                    raw = max(-50.0, min(50.0, raw))
                
                p_red = 1.0 / (1.0 + math.exp(-raw))
                
                if p_red > FLUFF_STAR + 0.20:
                    err = p_red - FLUFF_STAR
                    
                    # Update PID controller with enhanced KP
                    I = max(-MAX_I, min(MAX_I, I + KI * err))
                    D = KD * (err - prev_err) + (1 - KD) * D
                    prev_err = err
                    
                    # Use enhanced KP for stronger steering
                    alpha = max(0.0, min(MAX_ALPHA_USED, alpha + KP_USED * err + I + D))
                
                # Reset chunk
                chunk_h = None
                tok_in_chunk = 0
                
            except Exception as e:
                if debug:
                    print(f"[ERROR] Error in classifier/PID section: {e}")
                chunk_h = h_last_norm
                tok_in_chunk = 1
        
        # Temperature and sampling - updated for enhanced MAX_ALPHA
        temp = BASE_TEMP * (1 - coeff / MAX_ALPHA_USED) + STEER_TEMP * (coeff / MAX_ALPHA_USED)
        
        # Apply temperature and get probabilities
        try:
            logits_safe = logits.clamp(-100, 100)
            probs = torch.softmax(logits_safe / temp, dim=-1)
            
            if torch.isnan(probs).any():
                probs = torch.ones_like(probs) / probs.size(0)
            
            # Sample token
            nxt = torch.multinomial(probs, 1).item()
            token_str = tokenizer.decode([nxt], skip_special_tokens=True).replace("\n","\\n")
            
            # Add token to output
            generated_text += token_str
            out_ids = torch.cat([out_ids, torch.tensor([nxt], device=DEVICE)])
            
        except Exception as e:
            if debug:
                print(f"[ERROR] Error in sampling: {e}")
            # Fallback to argmax sampling
            nxt = torch.argmax(logits).item()
            token_str = tokenizer.decode([nxt], skip_special_tokens=True).replace("\n","\\n")
            generated_text += token_str
            out_ids = torch.cat([out_ids, torch.tensor([nxt], device=DEVICE)])
        
        # Print trace line if in debug mode
        if debug:
            print(f"{gen_len:4d} | {int(steering)} | {p_red:5.3f} | {err:5.3f} | "
                  f"{alpha:5.3f} | {I:5.3f} | {D:5.3f} | {temp:5.3f} | '{token_str}'")
        
        # Check stop condition
        if stop_re.search(generated_text) or "Final answer:" in generated_text:
            if debug:
                print(f"[INFO] Stop condition met, ending generation")
            break
    
    pbar.close()
    if debug:
        print("--- END TRACE ---\n")
        print(f"=== GENERATION RESULT ===")
        print(f"Total tokens generated: {out_ids.size(0) - ids.size(0)}")
        print(f"Final text: {tokenizer.decode(out_ids, skip_special_tokens=True)}")
    
    return tokenizer.decode(out_ids, skip_special_tokens=True), out_ids.size(0) - ids.size(0)

In [145]:
def norm_answer(s: str) -> str:
    """
    Extracts the content of the last \boxed{} occurrence in the string.
    This handles cases where the model might have multiple boxed answers,
    ensuring we get the final one.
    
    Args:
        s: The string to extract the answer from
        
    Returns:
        The content inside the last \boxed{} occurrence, or empty string if none found
    """
    # Find all occurrences of \boxed{...}
    matches = list(re.finditer(r"\\boxed\{([^}]+)\}", s))
    
    # Return the last match, if any
    if matches:
        return matches[-1].group(1).strip()
    else:
        return ""

In [168]:
def run_gsm8k(n_probs, max_tokens, debug = False):
    gsm = load_dataset("gsm8k", "main")["test"].select(range(1000,1000+n_probs))
    rec = []
    baseline_total = 0
    raspid_total = 0
    for ex in tqdm(gsm):
        q   = ex["question"].strip()
        prompt = f"{q}\n\nAnswer step by step and end with: Final answer: \\boxed{{numeric_value}}"
        r_txt,r_tok = generate_raspid_fixed_v2(prompt, max_tokens, debug = debug)
        b_txt,b_tok = generate_baseline(prompt, max_tokens)
        

        rec.append({
            "reference_answer":ex["answer"],
            "baseline_correct": norm_answer(b_txt),
            "raspid_correct":  norm_answer(r_txt),
            "baseline_tokens": b_tok,
            "raspid_tokens":  r_tok,
            "baseline_txt": b_txt,
            "raspid_txt":  r_txt,

        })
        baseline_total += int(b_tok)
        raspid_total += int(r_tok)
        print(f'total-token-usage for baseline: {baseline_total} raspid: {raspid_total}')

    df = pd.DataFrame(rec)
    return df

In [None]:
results_df = run_gsm8k(50, 4096)

  0%|          | 0/50 [00:00<?, ?it/s]

RASPID gen:   0%|          | 0/4096 [00:00<?, ?it/s]

total-token-usage for baseline: 1253 raspid: 1439


RASPID gen:   0%|          | 0/4096 [00:00<?, ?it/s]

In [None]:
results_df.to_csv('results_df_50.csv')

In [154]:
from tqdm.auto import tqdm
import pandas as pd

records = []
for i in tqdm(range(10), desc='Eval runs'):
    a = generate_baseline('what is square root of 256', 1024)
    b = generate_raspid('what is square root of 256', 1024, False)
    c = generate_raspid_fixed_v2('what is square root of 256', 1024, False)
    records.append({
        'run': i + 1,
        'baseline_tokens': a[1],
        'raspid_tokens': b[1],
        'raspid_fixed_tokens': c[1],
        'baseline_answer': norm_answer(a[0]),
        'raspid_answer': norm_answer(b[0]),
        'raspid_fixed_answer': norm_answer(c[0])

    })

records = pd.DataFrame(records)

Eval runs:   0%|          | 0/10 [00:00<?, ?it/s]

RASPID gen:   0%|          | 0/1024 [00:00<?, ?it/s]

RASPID gen:   0%|          | 0/1024 [00:00<?, ?it/s]

RASPID gen:   0%|          | 0/1024 [00:00<?, ?it/s]

RASPID gen:   0%|          | 0/1024 [00:00<?, ?it/s]

RASPID gen:   0%|          | 0/1024 [00:00<?, ?it/s]

RASPID gen:   0%|          | 0/1024 [00:00<?, ?it/s]

RASPID gen:   0%|          | 0/1024 [00:00<?, ?it/s]

RASPID gen:   0%|          | 0/1024 [00:00<?, ?it/s]

RASPID gen:   0%|          | 0/1024 [00:00<?, ?it/s]

RASPID gen:   0%|          | 0/1024 [00:00<?, ?it/s]

RASPID gen:   0%|          | 0/1024 [00:00<?, ?it/s]

RASPID gen:   0%|          | 0/1024 [00:00<?, ?it/s]

RASPID gen:   0%|          | 0/1024 [00:00<?, ?it/s]

RASPID gen:   0%|          | 0/1024 [00:00<?, ?it/s]

RASPID gen:   0%|          | 0/1024 [00:00<?, ?it/s]

RASPID gen:   0%|          | 0/1024 [00:00<?, ?it/s]

RASPID gen:   0%|          | 0/1024 [00:00<?, ?it/s]

RASPID gen:   0%|          | 0/1024 [00:00<?, ?it/s]

RASPID gen:   0%|          | 0/1024 [00:00<?, ?it/s]

RASPID gen:   0%|          | 0/1024 [00:00<?, ?it/s]

In [157]:
token_columns = [        
    'baseline_tokens',
    'raspid_tokens',
    'raspid_fixed_tokens',
]

In [160]:
records

Unnamed: 0,run,baseline_tokens,raspid_tokens,raspid_fixed_tokens,baseline_answer,raspid_answer,raspid_fixed_answer
0,1,178,69,408,16,16.0,16
1,2,486,340,287,16,16.0,16
2,3,282,249,224,16,16.0,16
3,4,253,378,442,16,16.0,16
4,5,409,1024,161,16,,16
5,6,75,210,166,16,16.0,16
6,7,367,359,332,\sqrt{256,16.0,16
7,8,450,205,275,16,16.0,16
8,9,81,539,198,16,16.0,16
9,10,1024,1024,311,,,16


In [159]:
records[token_columns].mean()

baseline_tokens        360.5
raspid_tokens          439.7
raspid_fixed_tokens    280.4
dtype: float64

In [163]:
a = generate_raspid_fixed_v2('what is square root of 81')
b = generate_baseline('what is square root of 81')


=== ENHANCED RASPID GENERATION ===
Original prompt: 'what is square root of 81'
MAX_TOKENS: 4096
INIT_FREE: 15 (was 80) - Start steering earlier
STEER_WINDOW: 300 (was 60) - More consistent steering
KP: 0.15 (was 0.05) - Stronger steering response
MAX_ALPHA: 1.0 (was 0.4) - Higher maximum steering
Other parameters: KI=0.001, KD=0.001, MAX_I=0.2
BASE_TEMP: 0.6, STEER_TEMP: 0.3
FLUFF_STAR: 0.5, EMB_LAYER: 20

--- RASPID TRACE ---
step | on | p_red |  err  |   α   |   I   |   D   | temp | token


RASPID gen:   0%|          | 0/4096 [00:00<?, ?it/s]

   0 | 0 | 0.000 | 0.000 | 0.000 | 0.000 | 0.000 | 0.600 | '?\n\n'
   1 | 0 | 0.000 | 0.000 | 0.000 | 0.000 | 0.000 | 0.600 | '<think>'
   2 | 0 | 0.000 | 0.000 | 0.000 | 0.000 | 0.000 | 0.600 | '\n'
   3 | 0 | 0.000 | 0.000 | 0.000 | 0.000 | 0.000 | 0.600 | 'To'
   4 | 0 | 0.000 | 0.000 | 0.000 | 0.000 | 0.000 | 0.600 | ' find'
   5 | 0 | 0.000 | 0.000 | 0.000 | 0.000 | 0.000 | 0.600 | ' the'
   6 | 0 | 0.000 | 0.000 | 0.000 | 0.000 | 0.000 | 0.600 | ' square'
   7 | 0 | 0.000 | 0.000 | 0.000 | 0.000 | 0.000 | 0.600 | ' root'
   8 | 0 | 0.000 | 0.000 | 0.000 | 0.000 | 0.000 | 0.600 | ' of'
   9 | 0 | 0.000 | 0.000 | 0.000 | 0.000 | 0.000 | 0.600 | ' '
  10 | 0 | 0.000 | 0.000 | 0.000 | 0.000 | 0.000 | 0.600 | '8'
  11 | 0 | 0.000 | 0.000 | 0.000 | 0.000 | 0.000 | 0.600 | '1'
  12 | 0 | 0.000 | 0.000 | 0.000 | 0.000 | 0.000 | 0.600 | ','
  13 | 0 | 0.000 | 0.000 | 0.000 | 0.000 | 0.000 | 0.600 | ' I'
  14 | 0 | 0.000 | 0.000 | 0.000 | 0.000 | 0.000 | 0.600 | ' need'
[INFO] Activating s

In [164]:
a

("what is square root of 81?\n\n<think>\nTo find the square root of 81, I need to determine the number that, when multiplied by itself, equals 81.\n\nI'll start by testing some whole numbers. \n\nFirst, I'll try 9. Multiplying 9 by itself gives 81.\n\nSince 9 squared equals 81, the square root of 81 is 9.\n</think>\n\nTo find the square root of 81, we need to determine a number that, when multiplied by itself, equals 81.\n\nLet's test some whole numbers:\n\n- \\(9 \\times 9 = 81\\)\n\nSince \\(9^2 = 81\\), the square root of 81 is:\n\n\\[\n\\boxed{9}\n",
 157)

In [165]:
b

("what is square root of 81?\n\n<think>\nTo find the square root of 81, I need to determine which number multiplied by itself equals 81.\n\nI know that multiplying a number by itself results in its square. For example:\n\n- \\(9 \\times 9 = 81\\)\n\nTherefore, the square root of 81 is 9.\n</think>\n\n**Solution:**\n\nTo find the **square root** of 81, we look for a number that, when multiplied by itself, gives us 81.\n\nLet's denote this unknown number as \\( x \\). Therefore, according to our equation:\n\\[ \nx^2 = 81 \n\\]\n\nWe can solve for \\( x \\) by taking the square root of both sides:\n\\[\nx = \\sqrt{81}\n\\]\n\nSince \\( 9 \\times 9 = 81 \\), it follows that:\n\\[\n\\sqrt{81} = 9\n\\]\n\nThus, the square root of 81 is:\n\n\\[\n\\boxed{9}\n\\]",
 217)