# In [4]:
# !pip install repeng accelerate datasets matplotlib seaborn vllm

# In [5]:
#!/usr/bin/env python
# coding: utf-8

"""
RASPID (Reasoning-Aware Steering with PID control):
A method for dynamic control of language model generation using chunk-level classification
and PID-based steering to optimize reasoning quality while reducing redundancy.

This implementation:
- Trains a chunk-level classifier on labeled reasoning chains
- Builds control vectors from required vs redundant thoughts
- Uses PID control to dynamically steer generation away from redundant patterns
- Evaluates performance on GSM8K mathematical reasoning tasks
"""

import os
import re
import math
import warnings
import numpy as np
import pandas as pd
import torch
from tqdm.auto import tqdm
from datasets import load_dataset
from transformers import AutoTokenizer, AutoModelForCausalLM
from repeng import ControlModel, ControlVector
from sklearn.linear_model import SGDClassifier
from sklearn.model_selection import train_test_split
from sklearn.metrics import accuracy_score
from torch.utils.data import Dataset, DataLoader, ConcatDataset

# Suppress warnings
# Silences PyTorch's "To copy construct from a tensor" UserWarning (emitted when
# torch.tensor() is handed an existing tensor). NOTE(review): presumably raised
# when the torch-tensor control-vector direction is re-wrapped; confirm.
warnings.filterwarnings("ignore", "To copy construct from a tensor", UserWarning)
# Turn off HuggingFace tokenizers' internal thread parallelism.
os.environ["TOKENIZERS_PARALLELISM"] = "false"


class RASPIDConfig:
    """Hyperparameters and file paths for the RASPID pipeline.

    Every setting is a plain class attribute, so values can be read from the
    class itself or from any instance of it.
    """

    # ---- model -----------------------------------------------------------
    MODEL_NAME = "deepseek-ai/DeepSeek-R1-Distill-Qwen-1.5B"
    DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
    DTYPE = torch.float32

    # ---- data ------------------------------------------------------------
    LABELED_CSV = "results/gsm8k_chains_labeled_with_tokens.csv"  # labelled reasoning chains
    CTRL_VEC_PATH = "ctrl_vector.pt"  # where the built control vector is saved

    # ---- chunk classifier ------------------------------------------------
    EMB_LAYER = 20           # layer index used for embeddings and steering
    CHUNK_SIZES = [16, 24]   # candidate chunk lengths (tokens); best is kept
    BATCH_SIZE = 32          # batch size for embedding extraction
    FLUFF_STAR = 0.5         # PID set-point for P(redundant)

    # ---- PID steering ----------------------------------------------------
    INIT_FREE = 80           # tokens generated before steering may start
    STEER_WINDOW = 60        # length (tokens) of one steering window
    KP = 0.05                # proportional gain
    KI = 0.001               # integral gain
    KD = 0.001               # weight of the newest error delta in the derivative
    MAX_I = 0.20             # clamp for the integral term
    MAX_ALPHA = 0.40         # clamp for the steering coefficient
    STEER_MARGIN = 0.20      # PID updates only when P(redundant) > FLUFF_STAR + margin

    # ---- sampling --------------------------------------------------------
    BASE_TEMP = 0.60         # temperature when no steering is applied
    STEER_TEMP = 0.30        # temperature reached at alpha == MAX_ALPHA
    MAX_REPEAT = 8           # stop after a token repeats this many extra times in a row
    MAX_TOKENS = 4096 * 8    # default generation budget
    MAX_RAW = 50.0           # bound used to squash classifier raw scores

    # ---- runtime ---------------------------------------------------------
    ENABLE_COMPILE = True        # try torch.compile on the base model
    MEMORY_EFFICIENT = True      # request flash-attention on CUDA
    PROGRESS_UPDATE_FREQ = 10    # progress-print interval (problems)


class ChunkDataset(Dataset):
    """Dataset of fixed-size token chunks cut from text sequences.

    Each text is tokenised without special tokens and split into windows of
    exactly ``chunk_size`` tokens, one window every ``stride`` tokens.

    - With the default ``stride=None`` (meaning ``stride == chunk_size``) the
      chunks are contiguous and non-overlapping.
    - A smaller stride yields overlapping chunks.
    - Trailing tokens that do not fill a whole chunk are dropped, so texts
      shorter than ``chunk_size`` contribute nothing.

    Each item is ``(token_ids, label)``.

    Args:
        texts: Iterable of strings to chunk.
        label: Label attached to every chunk produced from *texts*.
        chunk_size: Number of tokens per chunk (must be positive).
        tokenizer: Object exposing ``encode(text, add_special_tokens=False)``.
        stride: Step between chunk start positions; defaults to ``chunk_size``.

    Raises:
        ValueError: If ``chunk_size`` or ``stride`` is not positive.
    """

    def __init__(self, texts, label, chunk_size, tokenizer, stride=None):
        if stride is None:
            stride = chunk_size
        if chunk_size <= 0 or stride <= 0:
            raise ValueError(
                f"chunk_size and stride must be positive, got {chunk_size} and {stride}"
            )
        self.chunk_size = chunk_size
        self.stride = stride
        self.chunks = []
        self.labels = []

        for text in texts:
            tok_ids = tokenizer.encode(text, add_special_tokens=False)
            # Window starts; the upper bound guarantees every chunk is full length.
            for start in range(0, len(tok_ids) - chunk_size + 1, stride):
                self.chunks.append(tok_ids[start:start + chunk_size])
                self.labels.append(label)

    def __len__(self):
        return len(self.chunks)

    def __getitem__(self, idx):
        return self.chunks[idx], self.labels[idx]


class ChunkClassifier:
    """Chunk-level classifier for identifying redundant vs required reasoning.

    For each candidate size in ``config.CHUNK_SIZES``:

    - the labelled texts are cut into fixed-size token chunks (label 0 =
      required, 1 = redundant);
    - each chunk is embedded by mean-pooling the ``config.EMB_LAYER`` hidden
      states of ``model``;
    - a logistic-loss ``SGDClassifier`` is fit on an 80/20 stratified split.

    The chunk size with the best validation accuracy is kept.
    """

    # A stratified 80/20 split needs enough chunks of each class so that both
    # classes can appear in the validation set.
    _MIN_CHUNKS_PER_CLASS = 5

    def __init__(self, config, tokenizer, model):
        self.config = config
        self.tokenizer = tokenizer
        self.model = model
        self.best_classifier = None
        self.best_chunk_size = None
        self.best_accuracy = 0.0

    def _pad_id(self):
        """Return the id used for padding: pad id, else EOS id, else 0."""
        pad_id = getattr(self.tokenizer, "pad_token_id", None)
        if pad_id is None:
            pad_id = getattr(self.tokenizer, "eos_token_id", None)
        return 0 if pad_id is None else pad_id

    def _collate_fn(self, batch):
        """Collate ``(token_ids, label)`` pairs into padded CPU model inputs.

        Returns ``({"input_ids": ..., "attention_mask": ...}, labels)`` as long
        tensors. The mask is derived from each sequence's length rather than by
        comparing ids with the pad id. Many tokenizers reuse a real token (such
        as EOS) for padding, and a value comparison would mask those real
        tokens inside a chunk.
        """
        input_ids, labels = zip(*batch)
        # Pad on CPU so the DataLoader can use pin_memory.
        seqs = [torch.tensor(ids, dtype=torch.long) for ids in input_ids]
        padded = torch.nn.utils.rnn.pad_sequence(
            seqs, batch_first=True, padding_value=self._pad_id()
        )
        lengths = torch.tensor([seq.size(0) for seq in seqs], dtype=torch.long)
        positions = torch.arange(padded.size(1), dtype=torch.long)
        attention_mask = (positions.unsqueeze(0) < lengths.unsqueeze(1)).long()
        return (
            {"input_ids": padded, "attention_mask": attention_mask},
            torch.tensor(labels, dtype=torch.long),
        )

    def _extract_embeddings(self, loader):
        """Embed every chunk in *loader*.

        Returns ``(features, labels)`` numpy arrays, where ``features`` is the
        masked mean of the ``EMB_LAYER`` hidden states.

        Raises:
            ValueError: If *loader* yields no batches.
        """
        features, labels = [], []
        self.model.eval()

        with torch.no_grad():
            for batch_tokens, batch_labels in tqdm(loader, desc="Extracting embeddings"):
                batch_tokens = {
                    k: v.to(self.config.DEVICE, non_blocking=True)
                    for k, v in batch_tokens.items()
                }
                outputs = self.model(**batch_tokens, output_hidden_states=True)
                hidden_states = outputs.hidden_states[self.config.EMB_LAYER]

                # Masked mean pooling: padded positions must not contribute.
                mask = batch_tokens["attention_mask"].unsqueeze(-1).to(hidden_states.dtype)
                summed = (hidden_states * mask).sum(dim=1)
                counts = mask.sum(dim=1).clamp(min=1.0)
                embeddings = summed / counts

                # Cast to float32 so .numpy() also works for bf16/fp16 models.
                features.append(embeddings.float().cpu().numpy())
                labels.append(batch_labels.numpy())

        if not features:
            raise ValueError("No chunks to embed: every text is shorter than the chunk size")
        return np.vstack(features), np.concatenate(labels)

    def _train_sgd_classifier(self, X_train, y_train, X_val, y_val, chunk_size):
        """Fit a logistic-loss SGDClassifier one epoch at a time.

        Fitting stops after 500 epochs, or earlier once the largest coefficient
        change between consecutive epochs drops below 1e-3.

        Returns:
            ``(classifier, validation_accuracy)``.
        """
        classifier = SGDClassifier(
            loss="log_loss", random_state=42, warm_start=True, max_iter=1, tol=None
        )

        prev_coef = None
        pbar = tqdm(range(500), desc=f"Training classifier (cs={chunk_size})", leave=False)

        for _ in pbar:
            classifier.fit(X_train, y_train)  # warm_start: continues from previous coef_
            current_coef = classifier.coef_

            if prev_coef is not None:
                delta = np.max(np.abs(current_coef - prev_coef))
                pbar.set_postfix(delta=delta)
                if delta < 1e-3:  # Convergence criterion
                    break
            prev_coef = current_coef.copy()

        pbar.close()

        val_accuracy = accuracy_score(y_val, classifier.predict(X_val))
        return classifier, val_accuracy

    def train(self, required_texts, redundant_texts):
        """Train classifiers for every chunk size and keep the best one.

        Returns:
            ``(best_classifier, best_chunk_size)``.

        Raises:
            RuntimeError: If no chunk size yields enough chunks of both classes.
        """
        print("Training chunk classifier...")
        # Reset so a repeated call does not compare against a stale best.
        self.best_classifier = None
        self.best_chunk_size = None
        self.best_accuracy = 0.0

        for chunk_size in self.config.CHUNK_SIZES:
            print(f"Testing chunk size: {chunk_size}")

            dataset_required = ChunkDataset(required_texts, 0, chunk_size, self.tokenizer)
            dataset_redundant = ChunkDataset(redundant_texts, 1, chunk_size, self.tokenizer)

            if min(len(dataset_required), len(dataset_redundant)) < self._MIN_CHUNKS_PER_CLASS:
                print(
                    f"⚠️  Skipping chunk size {chunk_size}: too few chunks "
                    f"(required={len(dataset_required)}, redundant={len(dataset_redundant)})"
                )
                continue

            loader = DataLoader(
                ConcatDataset([dataset_required, dataset_redundant]),
                batch_size=self.config.BATCH_SIZE,
                collate_fn=self._collate_fn,
                shuffle=False,  # order is irrelevant: the split below shuffles
                pin_memory=True,
                num_workers=0,
            )

            X, y = self._extract_embeddings(loader)

            X_train, X_val, y_train, y_val = train_test_split(
                X, y, test_size=0.2, random_state=42, stratify=y
            )

            classifier, accuracy = self._train_sgd_classifier(
                X_train, y_train, X_val, y_val, chunk_size
            )

            print(f"Chunk size {chunk_size} → Validation accuracy: {accuracy:.3f}")

            if self.best_classifier is None or accuracy > self.best_accuracy:
                self.best_chunk_size = chunk_size
                self.best_accuracy = accuracy
                self.best_classifier = classifier

        if self.best_classifier is None:
            raise RuntimeError(
                "No chunk size produced enough chunks of both classes to train a classifier"
            )

        print(f"✅ Best classifier: chunk_size={self.best_chunk_size}, accuracy={self.best_accuracy:.3f}")
        return self.best_classifier, self.best_chunk_size


class ControlVectorBuilder:
    """Build a steering direction from the contrast "required minus redundant".

    NOTE(review): ``hidden_states[0]`` is usually the embedding output, so
    ``hidden_states[EMB_LAYER]`` may correspond to decoder layer
    ``EMB_LAYER - 1`` rather than the layer the control vector is attached
    to. Confirm against how repeng indexes layers.
    """

    def __init__(self, config, tokenizer, model):
        self.config = config
        self.tokenizer = tokenizer
        self.model = model

    def _compute_mean_hidden_state(self, texts):
        """Return the average, over *texts*, of each text's token-mean ``EMB_LAYER`` activation.

        Each text is encoded on its own (``truncation=True``). Its activations
        at ``hidden_states[EMB_LAYER]`` are mean-pooled over tokens and moved to
        the CPU, and the per-text vectors are then averaged.
        """
        layer = self.config.EMB_LAYER
        per_text_means = []

        for text in tqdm(texts, desc="Computing hidden states"):
            encoded = self.tokenizer(text, return_tensors="pt", truncation=True)
            encoded = encoded.to(self.config.DEVICE)

            with torch.inference_mode():
                layer_acts = self.model(**encoded, output_hidden_states=True).hidden_states[layer]
                per_text_means.append(layer_acts[0].mean(dim=0).cpu())

        return torch.stack(per_text_means).mean(dim=0)

    def build_control_vector(self, required_texts, redundant_texts):
        """Create, save and return a ``ControlVector`` for ``EMB_LAYER``.

        The direction is the raw (unnormalised) difference between the mean
        hidden state of *required_texts* and that of *redundant_texts*. The
        vector is written to ``config.CTRL_VEC_PATH`` with ``torch.save``.
        """
        print("Building control vector...")

        required_mean = self._compute_mean_hidden_state(required_texts)
        redundant_mean = self._compute_mean_hidden_state(redundant_texts)
        direction = required_mean - redundant_mean

        control_vector = ControlVector(
            model_type=self.model.config.model_type,
            directions={self.config.EMB_LAYER: direction.to(self.config.DEVICE)},
        )

        torch.save(control_vector, self.config.CTRL_VEC_PATH)
        print("✅ Control vector saved")

        return control_vector


class RASPIDGenerator:
    """RASPID generator with PID-controlled steering.

    Samples one token at a time through a repeng ``ControlModel``. Every
    ``chunk_size`` tokens, the mean of the L2-normalised ``config.EMB_LAYER``
    hidden states is scored by the chunk classifier. The resulting
    P(redundant) updates a PID controller. Once steering is active, the
    controller's output ``alpha`` is used as the control-vector coefficient
    and also lowers the sampling temperature from ``BASE_TEMP`` toward
    ``STEER_TEMP``.
    """

    # Longest string ``stop_pattern`` can match: "\boxed{" (7) + 12 + "}" (1).
    _STOP_LOOKBACK = 20

    def __init__(self, config, tokenizer, base_model, control_model, classifier, chunk_size, control_vector):
        self.config = config
        self.tokenizer = tokenizer
        self.base_model = base_model
        self.control_model = control_model
        self.classifier = classifier
        self.chunk_size = chunk_size
        self.control_vector = control_vector
        # A complete short boxed answer, e.g. "\boxed{42}".
        self.stop_pattern = re.compile(r"\\boxed\{[^{}]{1,12}\}")

        # Kept for external callers. It is not a stop condition because the
        # phrase precedes the boxed answer that evaluation needs.
        self.final_answer_pattern = re.compile(r"Final answer:", re.IGNORECASE)

        # Pre-computed constants for the temperature schedule.
        self.max_alpha_inv = 1.0 / self.config.MAX_ALPHA
        self.temp_diff = self.config.BASE_TEMP - self.config.STEER_TEMP

    def _handle_numerical_stability(self, tensor, name="tensor"):
        """Replace non-finite entries (NaN -> 0, +inf -> 100, -inf -> -100).

        Tensors that are already finite are returned unchanged. Infinities must
        be handled too: a single +inf logit turns softmax into NaN, and an inf
        hidden state makes the L2 normalisation produce NaN.
        """
        if not torch.isfinite(tensor).all():
            return torch.nan_to_num(tensor, nan=0.0, posinf=100.0, neginf=-100.0)
        return tensor

    def _scale_extreme_classifier_output(self, raw_score):
        """Squash classifier raw scores whose magnitude exceeds ``MAX_RAW``.

        NOTE(review): values just above ``MAX_RAW`` map to roughly
        ``0.76 * MAX_RAW`` (tanh(1)), so the mapping is not monotonic at the
        bound. This is harmless here because any |score| >= 5 is treated as
        saturated downstream.
        """
        abs_raw = abs(raw_score)
        if abs_raw > self.config.MAX_RAW:
            return self.config.MAX_RAW * np.sign(raw_score) * np.tanh(abs_raw / self.config.MAX_RAW)
        return np.clip(raw_score, -self.config.MAX_RAW, self.config.MAX_RAW)

    def _should_apply_steering(self, generation_length, steering_active, steering_start):
        """Return ``(steering_active, steering_start)`` for the current step.

        NOTE(review): when a window longer than ``STEER_WINDOW`` ends, this
        returns False for one step. On the next call, ``generation_length >=
        INIT_FREE`` re-activates steering immediately. So after ``INIT_FREE``
        tokens, steering is effectively continuous with a one-token gap per
        window. Confirm whether a single window or a cooldown was intended.
        """
        if not steering_active:
            return generation_length >= self.config.INIT_FREE, generation_length
        elif generation_length - steering_start > self.config.STEER_WINDOW:
            return False, steering_start
        return True, steering_start

    def _update_pid_controller(self, p_redundant, alpha, integral, derivative, prev_error):
        """Update the PID state from a new P(redundant).

        When ``p_redundant <= FLUFF_STAR + STEER_MARGIN`` the state is left
        untouched and the returned error is 0.

        Returns:
            ``(alpha, integral, derivative, error)``.
        """
        threshold = self.config.FLUFF_STAR + self.config.STEER_MARGIN
        if p_redundant <= threshold:
            return alpha, integral, derivative, 0.0

        error = p_redundant - self.config.FLUFF_STAR

        integral = np.clip(integral + self.config.KI * error, -self.config.MAX_I, self.config.MAX_I)
        # KD is used as the smoothing weight of the error difference.
        derivative = self.config.KD * (error - prev_error) + (1 - self.config.KD) * derivative

        pid_output = self.config.KP * error + integral + derivative
        alpha = np.clip(alpha + pid_output, 0.0, self.config.MAX_ALPHA)

        return alpha, integral, derivative, error

    def _fast_temperature_calc(self, coefficient):
        """Interpolate linearly from ``BASE_TEMP`` (coef 0) to ``STEER_TEMP`` (coef ``MAX_ALPHA``)."""
        ratio = coefficient * self.max_alpha_inv
        return self.config.BASE_TEMP - self.temp_diff * ratio

    def _early_stop_check(self, generated_text):
        """Return True once a complete ``\\boxed{...}`` answer appears in *generated_text*.

        "Final answer:" alone is deliberately not a stop condition. The prompt
        asks for "Final answer: \\boxed{...}", so stopping at the phrase would
        cut generation off before the boxed value the evaluator extracts.
        """
        return self.stop_pattern.search(generated_text) is not None

    @staticmethod
    def _token_probabilities(logits, temperature):
        """Return the temperature-scaled softmax distribution over *logits*.

        The max logit is subtracted for stability. The scaled logits are not
        clamped to a fixed range: clamping the top end makes every token above
        the bound equally likely, which badly distorts low-temperature sampling.
        """
        scaled = logits.float() / temperature
        scaled = scaled - scaled.max()
        return torch.softmax(scaled, dim=-1)

    @torch.inference_mode()
    def generate(self, prompt, max_new_tokens=None, debug=False):
        """Generate a completion for *prompt* with PID-controlled steering.

        Args:
            prompt: Prompt text (tokenised with the tokenizer's defaults).
            max_new_tokens: Token budget; defaults to ``config.MAX_TOKENS``.
            debug: Print a per-token trace instead of showing a progress bar.

        Returns:
            ``(text, tokens_generated)``.
            - ``text`` is the decoded prompt plus completion, with special
              tokens skipped.
            - ``tokens_generated`` counts the sampled tokens, including a final
              EOS if one was sampled.

        Generation stops when any of the following happens:
            - EOS is sampled;
            - the same token repeats ``MAX_REPEAT`` more times in a row;
            - a complete ``\\boxed{...}`` answer appears;
            - the budget runs out.

        The control model is reset on exit, so no steering coefficient stays
        installed for later callers.
        """
        if max_new_tokens is None:
            max_new_tokens = self.config.MAX_TOKENS

        if debug:
            print(f"\n=== RASPID GENERATION ===")
            print(f"Prompt: '{prompt}'")
            print(f"Max tokens: {max_new_tokens}")
            print("--- GENERATION TRACE ---")
            print("step | steering | p_red |  err  |   α   |   I   |   D   | temp | token")

        device = self.config.DEVICE
        input_ids = self.tokenizer(prompt, return_tensors="pt").to(device).input_ids[0]
        generated_ids = []  # python list: O(1) append instead of torch.cat per step
        eos_token_id = getattr(self.tokenizer, "eos_token_id", None)
        past_key_values = None

        # PID controller state
        alpha = integral = derivative = prev_error = 0.0
        chunk_hidden = None
        tokens_in_chunk = 0
        steering_active = False
        steering_start = 0

        # Repetition detection / stop-condition state
        last_token = None
        repetition_count = 0
        generated_text = ""
        applied_coeff = None  # coefficient currently installed on the control model

        if debug:
            pbar = range(max_new_tokens)
        else:
            pbar = tqdm(range(max_new_tokens), desc="RASPID generation", leave=False)

        try:
            for step in pbar:
                generation_length = len(generated_ids)

                steering_active, steering_start = self._should_apply_steering(
                    generation_length, steering_active, steering_start
                )

                # Only touch the control model when the coefficient changes.
                coefficient = alpha if steering_active else 0.0
                if coefficient != applied_coeff:
                    self.control_model.set_control(self.control_vector, coeff=coefficient)
                    applied_coeff = coefficient

                if generation_length == 0:
                    # First step: run the whole prompt to fill the KV cache.
                    outputs = self.control_model(
                        input_ids=input_ids.unsqueeze(0),
                        use_cache=True,
                        output_hidden_states=True,
                    )
                else:
                    # Later steps: feed only the newest token.
                    last_input = torch.tensor(
                        [[generated_ids[-1]]], dtype=torch.long, device=input_ids.device
                    )
                    outputs = self.control_model(
                        input_ids=last_input,
                        past_key_values=past_key_values,
                        use_cache=True,
                        output_hidden_states=True,
                    )

                past_key_values = outputs.past_key_values
                logits = self._handle_numerical_stability(outputs.logits[0, -1])
                hidden_state = self._handle_numerical_stability(
                    outputs.hidden_states[self.config.EMB_LAYER][0, -1]
                )

                # Repetition check on the most recent token (the last prompt token at step 0).
                current_token = generated_ids[-1] if generated_ids else input_ids[-1].item()
                if current_token == last_token:
                    repetition_count += 1
                    if repetition_count >= self.config.MAX_REPEAT:
                        if debug:
                            print(f"[INFO] Maximum repetition reached, stopping generation")
                        break
                else:
                    repetition_count = 0
                    last_token = current_token

                # Accumulate L2-normalised hidden states for the current chunk.
                hidden_norm = torch.norm(hidden_state)
                if hidden_norm > 1e-8:
                    hidden_normalized = hidden_state / hidden_norm
                else:
                    hidden_normalized = hidden_state

                chunk_hidden = hidden_normalized if chunk_hidden is None else chunk_hidden + hidden_normalized
                tokens_in_chunk += 1

                p_redundant = error = 0.0
                if tokens_in_chunk >= self.chunk_size:
                    try:
                        classifier_input = (chunk_hidden / self.chunk_size).cpu().numpy().reshape(1, -1)
                        if np.isnan(classifier_input).any():
                            classifier_input = np.nan_to_num(classifier_input)

                        raw_score = self.classifier.decision_function(classifier_input)[0]
                        raw_score = self._scale_extreme_classifier_output(raw_score)

                        # Exact sigmoid for moderate scores; saturate beyond |5|.
                        if abs(raw_score) < 5.0:
                            p_redundant = 1.0 / (1.0 + math.exp(-raw_score))
                        else:
                            p_redundant = 1.0 if raw_score > 0 else 0.0

                        alpha, integral, derivative, error = self._update_pid_controller(
                            p_redundant, alpha, integral, derivative, prev_error
                        )
                        prev_error = error

                        chunk_hidden = None
                        tokens_in_chunk = 0

                    except Exception as e:
                        if debug:
                            print(f"[ERROR] Classifier/PID error: {e}")
                        chunk_hidden = hidden_normalized
                        tokens_in_chunk = 1

                temperature = self._fast_temperature_calc(coefficient)

                try:
                    probabilities = self._token_probabilities(logits, temperature)
                    next_token = torch.multinomial(probabilities, 1).item()
                except Exception as e:
                    if debug:
                        print(f"[ERROR] Sampling error: {e}, using argmax")
                    next_token = torch.argmax(logits).item()

                token_str = self.tokenizer.decode([next_token], skip_special_tokens=True)
                generated_text += token_str
                generated_ids.append(next_token)

                if debug:
                    token_display = token_str.replace("\n", "\\n")
                    print(f"{generation_length:4d} | {int(steering_active):8d} | {p_redundant:5.3f} | {error:5.3f} | "
                          f"{alpha:5.3f} | {integral:5.3f} | {derivative:5.3f} | {temperature:5.3f} | '{token_display}'")

                # Stop at end-of-sequence, as HF generate() does for the baseline.
                if eos_token_id is not None and next_token == eos_token_id:
                    if debug:
                        print("[INFO] EOS token generated, stopping generation")
                    break

                # Any new match must end inside the just-appended text, so only
                # the tail needs scanning (avoids O(n^2) rescans of the text).
                tail = generated_text[-(len(token_str) + self._STOP_LOOKBACK):]
                if self._early_stop_check(tail):
                    if debug:
                        print(f"[INFO] Stop condition met")
                    break

                if not debug and hasattr(pbar, 'set_postfix'):
                    if step % 10 == 0:  # throttle progress-bar updates
                        pbar.set_postfix({
                            'α': f'{alpha:.2f}',
                            'p_red': f'{p_redundant:.2f}',
                            'steering': steering_active
                        })
        finally:
            if not debug:
                pbar.close()
            # Do not leave a steering coefficient installed on the control model.
            self.control_model.reset()

        if debug:
            print("--- END TRACE ---")
            print(f"Total tokens generated: {len(generated_ids)}")

        final_text = self.tokenizer.decode(input_ids.tolist() + generated_ids, skip_special_tokens=True)
        return final_text, len(generated_ids)


class BaselineGenerator:
    """Baseline generator without steering; supports HuggingFace and vLLM backends.

    Both backends sample with temperature 0.6, top_p 0.9 and repetition penalty
    1.2, and run until EOS or the token budget. Neither uses stop strings, so a
    final ``\\boxed{...}`` answer is kept in the returned text.
    """

    # Sampling settings shared by both backends.
    _TEMPERATURE = 0.6
    _TOP_P = 0.9
    _REPETITION_PENALTY = 1.2

    def __init__(self, config, tokenizer, model=None, use_vllm=False):
        self.config = config
        self.tokenizer = tokenizer
        self.model = model
        self.use_vllm = use_vllm
        self.vllm_engine = None
        self.sampling_params = None
        # vllm.SamplingParams, captured at init because it is imported locally.
        self._sampling_params_cls = None

        if use_vllm:
            self._initialize_vllm()

    def _initialize_vllm(self):
        """Initialise the vLLM engine; on failure fall back to HuggingFace."""
        try:
            from vllm import LLM, SamplingParams

            print("🚀 Initializing vLLM engine...")
            # NOTE(review): max_model_len (8192) is smaller than the default
            # MAX_TOKENS budget; confirm how vLLM treats max_tokens beyond it.
            self.vllm_engine = LLM(
                model=self.config.MODEL_NAME,
                dtype=self.config.DTYPE,
                trust_remote_code=True,
                gpu_memory_utilization=0.8,  # Leave some GPU memory for other operations
                max_model_len=8192,  # Adjust based on your model's context length
                tensor_parallel_size=1,  # Adjust for multi-GPU setups
            )

            # Other methods cannot name SamplingParams directly (local import),
            # so keep a reference to the class.
            self._sampling_params_cls = SamplingParams
            self.sampling_params = self._make_sampling_params(self.config.MAX_TOKENS)

            print("✅ vLLM engine initialized successfully")

        except ImportError:
            print("❌ vLLM not available. Install with: pip install vllm")
            print("🔄 Falling back to HuggingFace transformers")
            self.use_vllm = False
        except Exception as e:
            print(f"❌ Failed to initialize vLLM: {e}")
            print("🔄 Falling back to HuggingFace transformers")
            self.use_vllm = False

    def _make_sampling_params(self, max_tokens):
        """Build vLLM sampling parameters matching the HuggingFace baseline.

        No stop strings are used. Stop strings are excluded from vLLM output by
        default, so stopping at "\\boxed{" or "Final answer:" removed the answer
        and made ``\\boxed{}`` extraction always fail.
        """
        return self._sampling_params_cls(
            temperature=self._TEMPERATURE,
            top_p=self._TOP_P,
            repetition_penalty=self._REPETITION_PENALTY,
            max_tokens=max_tokens,
        )

    def _params_for(self, max_new_tokens):
        """Return the cached params, or fresh ones for a non-default budget."""
        if max_new_tokens and max_new_tokens != self.config.MAX_TOKENS:
            return self._make_sampling_params(max_new_tokens)
        return self.sampling_params

    @torch.inference_mode()
    def generate(self, prompt, max_new_tokens=None):
        """Generate ``(full_text, tokens_generated)`` with vLLM if available, else HuggingFace."""
        if self.use_vllm and self.vllm_engine:
            return self._generate_vllm(prompt, max_new_tokens)
        return self._generate_hf(prompt, max_new_tokens)

    def _generate_vllm(self, prompt, max_new_tokens=None):
        """Generate with vLLM; the returned text is the prompt plus completion."""
        outputs = self.vllm_engine.generate([prompt], self._params_for(max_new_tokens))
        completion = outputs[0].outputs[0]
        return prompt + completion.text, len(completion.token_ids)

    def _generate_hf(self, prompt, max_new_tokens=None):
        """Generate with HuggingFace ``model.generate`` (fallback backend)."""
        if max_new_tokens is None:
            max_new_tokens = self.config.MAX_TOKENS

        inputs = self.tokenizer(prompt, return_tensors="pt").to(self.config.DEVICE)

        outputs = self.model.generate(
            **inputs,
            max_new_tokens=max_new_tokens,
            do_sample=True,
            temperature=self._TEMPERATURE,
            top_p=self._TOP_P,
            repetition_penalty=self._REPETITION_PENALTY,
            pad_token_id=self.tokenizer.eos_token_id,
        )

        tokens_generated = outputs.shape[1] - inputs.input_ids.shape[1]
        generated_text = self.tokenizer.decode(outputs[0], skip_special_tokens=True)

        return generated_text, tokens_generated

    def generate_batch(self, prompts, max_new_tokens=None):
        """Generate for many prompts: one vLLM batch call, or sequential HF calls."""
        if not self.use_vllm or not self.vllm_engine:
            return [self._generate_hf(prompt, max_new_tokens) for prompt in prompts]

        outputs = self.vllm_engine.generate(prompts, self._params_for(max_new_tokens))

        results = []
        for prompt, output in zip(prompts, outputs):
            completion = output.outputs[0]
            results.append((prompt + completion.text, len(completion.token_ids)))
        return results


class AnswerExtractor:
    """Utility for extracting numerical answers from generated text.

    Extracted answers are normalised by ``normalize_answer``, so formatting
    differences such as ``1,000`` vs ``1000``, ``$18`` vs ``18`` or ``18.00`` vs
    ``18`` do not make a correct answer compare unequal with ``==``.
    """

    @staticmethod
    def normalize_answer(answer):
        """Normalise a GSM8K-style numeric answer string.

        The normalisation:
        - strips surrounding whitespace;
        - removes ``\\$``, ``$``, ``\\,`` and ``,`` (currency signs and
          thousands separators);
        - drops a trailing period;
        - for plain decimal numbers, removes trailing fractional zeros
          ("18.50" -> "18.5", "18.00" -> "18").

        Other strings are returned with only these removals applied.
        """
        text = str(answer).strip()
        for token in ("\\$", "$", "\\,", ","):
            text = text.replace(token, "")
        text = text.strip().rstrip(".")
        if re.fullmatch(r"-?\d+\.\d+", text):
            integer_part, _, fraction = text.partition(".")
            fraction = fraction.rstrip("0")
            text = f"{integer_part}.{fraction}" if fraction else integer_part
        return text

    @staticmethod
    def extract_boxed_answer(text):
        """Return the normalised content of the last ``\\boxed{...}`` in *text*, or "" if none.

        Nested braces are not supported; content stops at the first "}".
        """
        matches = re.findall(r"\\boxed\{([^}]+)\}", text)
        return AnswerExtractor.normalize_answer(matches[-1]) if matches else ""

    @staticmethod
    def extract_reference_answer(gsm8k_answer):
        """Return the normalised text after the last '#### ' marker (the whole string if absent)."""
        _, _, tail = gsm8k_answer.rpartition("#### ")
        return AnswerExtractor.normalize_answer(tail)


class GSM8KEvaluator:
    """Evaluator for GSM8K mathematical reasoning tasks.

    Runs the baseline and RASPID generators on the same GSM8K test problems,
    extracts ``\\boxed{}`` answers, and reports accuracy and token usage.
    """

    # First GSM8K test-split index used for evaluation.
    _START_INDEX = 1000

    # Columns of the DataFrame returned by ``evaluate``.
    _RESULT_COLUMNS = [
        "problem_id", "question", "reference_answer", "reference_correct",
        "baseline_correct", "raspid_correct", "baseline_tokens", "raspid_tokens",
        "baseline_txt", "raspid_txt", "baseline_is_correct", "raspid_is_correct",
    ]

    def __init__(self, config, raspid_generator, baseline_generator):
        self.config = config
        self.raspid_generator = raspid_generator
        self.baseline_generator = baseline_generator
        self.answer_extractor = AnswerExtractor()

    @staticmethod
    def _savings_pct(baseline_tokens, raspid_tokens):
        """Return the percentage of baseline tokens saved by RASPID (0.0 if the baseline used none)."""
        if baseline_tokens <= 0:
            return 0.0
        return (baseline_tokens - raspid_tokens) / baseline_tokens * 100

    def _progress_every(self):
        """Return the progress-print interval from PROGRESS_UPDATE_FREQ (default 10)."""
        freq = getattr(self.config, "PROGRESS_UPDATE_FREQ", 10)
        return freq if freq and freq > 0 else 10

    def evaluate(self, n_problems, max_tokens=None, debug=False, use_batch=True):
        """Evaluate baseline and RASPID on up to *n_problems* GSM8K test problems.

        Args:
            n_problems: Number of problems, taken from test index 1000 onward
                and capped at the end of the split.
            max_tokens: Token budget per problem; defaults to ``config.MAX_TOKENS``.
            debug: Print a RASPID token trace for the first problem.
            use_batch: Use vLLM batch generation for the baseline when available.

        Returns:
            A DataFrame with one row per problem (columns in ``_RESULT_COLUMNS``).
        """
        if max_tokens is None:
            max_tokens = self.config.MAX_TOKENS

        prompts, examples = self._load_problems(n_problems)
        print(f"🎯 Evaluating {len(prompts)} problems...")
        if not prompts:
            print("⚠️  No problems to evaluate")
            return pd.DataFrame(columns=self._RESULT_COLUMNS)

        print("🤖 Phase 1: Generating baseline responses...")
        baseline_results = self._generate_baseline(prompts, max_tokens, use_batch)

        print("🧠 Phase 2: Generating RASPID responses...")
        raspid_results = self._generate_raspid(prompts, max_tokens, debug, baseline_results)

        print("📊 Phase 3: Processing results...")
        results_df = self._build_results(examples, baseline_results, raspid_results)
        self._print_summary(results_df)
        return results_df

    def _load_problems(self, n_problems):
        """Load GSM8K test problems and build their prompts; returns ``(prompts, examples)``."""
        print("📚 Loading GSM8K test set...")
        test_split = load_dataset("gsm8k", "main")["test"]
        start = min(self._START_INDEX, len(test_split))
        stop = min(start + max(n_problems, 0), len(test_split))
        if stop - start < n_problems:
            print(f"⚠️  Only {stop - start} problems available from index {start}")
        gsm8k_test = test_split.select(range(start, stop))

        prompts, examples = [], []
        for example in tqdm(gsm8k_test, desc="Preparing prompts"):
            question = example["question"].strip()
            prompt = f"{question}\n\nAnswer step by step and end with: Final answer: \\boxed{{numeric_value}}"
            prompts.append(prompt)
            examples.append(example)
        return prompts, examples

    def _generate_baseline(self, prompts, max_tokens, use_batch):
        """Run the baseline on every prompt; returns a list of ``(text, tokens)``."""
        if use_batch and getattr(self.baseline_generator, "use_vllm", False):
            print("   Using vLLM batch generation for maximum efficiency...")
            results = self.baseline_generator.generate_batch(prompts, max_tokens)
            total = sum(tokens for _, tokens in results)
            print(f"   ✅ Baseline batch completed: {total} total tokens")
            return results

        print("   Using sequential generation...")
        every = self._progress_every()
        results, total = [], 0
        for i, prompt in enumerate(tqdm(prompts, desc="   Baseline generation")):
            result = self.baseline_generator.generate(prompt, max_tokens)
            results.append(result)
            total += result[1]
            if (i + 1) % every == 0:
                print(f"   Progress: {i+1}/{len(prompts)} problems, {total} tokens so far")
        return results

    def _generate_raspid(self, prompts, max_tokens, debug, baseline_results):
        """Run RASPID on every prompt; progress compares against the same baseline prefix."""
        if prompts:
            # The first generation is often slower; warm up on a short budget.
            print("   Pre-warming RASPID generator...")
            self.raspid_generator.generate(prompts[0], min(max_tokens, 100), debug=False)

        every = self._progress_every()
        results = []
        raspid_total = baseline_prefix = 0
        for i, prompt in enumerate(tqdm(prompts, desc="   RASPID generation")):
            result = self.raspid_generator.generate(prompt, max_tokens, debug and i == 0)  # Debug only first
            results.append(result)
            raspid_total += result[1]
            # Compare like with like: baseline tokens of the same first i+1 problems.
            baseline_prefix += baseline_results[i][1]

            if (i + 1) % every == 0:
                avg_baseline = baseline_prefix // (i + 1)
                avg_raspid = raspid_total // (i + 1)
                efficiency = self._savings_pct(avg_baseline, avg_raspid)
                print(f"   Progress: {i+1}/{len(prompts)} problems")
                print(f"   Average tokens - Baseline: {avg_baseline}, RASPID: {avg_raspid}")
                print(f"   Current efficiency: {efficiency:.1f}% token savings")
        return results

    def _build_results(self, examples, baseline_results, raspid_results):
        """Extract answers and assemble one result row per problem."""
        rows = []
        for i, (example, (baseline_text, baseline_tokens), (raspid_text, raspid_tokens)) in enumerate(
            zip(examples, baseline_results, raspid_results)
        ):
            baseline_answer = self.answer_extractor.extract_boxed_answer(baseline_text)
            raspid_answer = self.answer_extractor.extract_boxed_answer(raspid_text)
            reference_answer = self.answer_extractor.extract_reference_answer(example["answer"])

            rows.append({
                "problem_id": i,
                "question": example["question"],
                "reference_answer": example["answer"],
                "reference_correct": reference_answer,
                "baseline_correct": baseline_answer,
                "raspid_correct": raspid_answer,
                "baseline_tokens": baseline_tokens,
                "raspid_tokens": raspid_tokens,
                "baseline_txt": baseline_text,
                "raspid_txt": raspid_text,
                "baseline_is_correct": baseline_answer == reference_answer,
                "raspid_is_correct": raspid_answer == reference_answer,
            })
        return pd.DataFrame(rows, columns=self._RESULT_COLUMNS)

    def _print_summary(self, results_df):
        """Print accuracy and token statistics for a non-empty results frame."""
        n = len(results_df)
        baseline_total_tokens = int(results_df["baseline_tokens"].sum())
        raspid_total_tokens = int(results_df["raspid_tokens"].sum())

        baseline_accuracy = results_df['baseline_is_correct'].mean()
        raspid_accuracy = results_df['raspid_is_correct'].mean()
        token_efficiency = self._savings_pct(baseline_total_tokens, raspid_total_tokens) / 100
        avg_baseline_tokens = baseline_total_tokens / n
        avg_raspid_tokens = raspid_total_tokens / n

        print(f"\n🎉 EVALUATION COMPLETE!")
        print(f"   📈 Baseline Accuracy: {baseline_accuracy:.3f} ({results_df['baseline_is_correct'].sum()}/{n})")
        print(f"   🧠 RASPID Accuracy: {raspid_accuracy:.3f} ({results_df['raspid_is_correct'].sum()}/{n})")
        print(f"   ⚡ Token Efficiency: {token_efficiency:.3f} ({token_efficiency*100:.1f}% savings)")
        print(f"   📊 Average tokens per problem:")
        print(f"      Baseline: {avg_baseline_tokens:.1f}")
        print(f"      RASPID: {avg_raspid_tokens:.1f}")
        print(f"      Savings: {avg_baseline_tokens - avg_raspid_tokens:.1f} tokens per problem")
        print(f"   📝 Total tokens:")
        print(f"      Baseline: {baseline_total_tokens:,}")
        print(f"      RASPID: {raspid_total_tokens:,}")
        print(f"      Total savings: {baseline_total_tokens - raspid_total_tokens:,}")

# In [None]:
def _load_base_model(config):
    """Load ``config.MODEL_NAME`` as a causal LM in eval mode.

    Flash Attention 2 (plus ``low_cpu_mem_usage``) is requested only when
    ``config.MEMORY_EFFICIENT`` is set and ``config.DEVICE == "cuda"``. If that
    load fails, the model is reloaded with the default attention
    implementation. When Flash Attention was NOT requested, a load failure is
    a genuine error and is re-raised instead of being retried with identical
    arguments.

    Args:
        config: A ``RASPIDConfig``. Reads MODEL_NAME, DTYPE, DEVICE and
            MEMORY_EFFICIENT.

    Returns:
        The loaded model, already switched to ``.eval()`` mode.
    """
    model_kwargs = {
        "torch_dtype": config.DTYPE,
        "device_map": "auto" if config.DEVICE == "cuda" else None,
    }

    wants_flash_attn = config.MEMORY_EFFICIENT and config.DEVICE == "cuda"
    if wants_flash_attn:
        model_kwargs.update({
            "attn_implementation": "flash_attention_2",  # Use Flash Attention if available
            "low_cpu_mem_usage": True,
        })

    try:
        return AutoModelForCausalLM.from_pretrained(
            config.MODEL_NAME, **model_kwargs
        ).eval()
    except Exception as e:
        # Only a Flash Attention request justifies a retry; any other failure
        # would simply recur with the same kwargs, so surface it.
        if not wants_flash_attn:
            raise
        print(f"⚠️  Flash Attention not available: {e}")
        # Fallback to standard attention
        model_kwargs.pop("attn_implementation", None)
        return AutoModelForCausalLM.from_pretrained(
            config.MODEL_NAME, **model_kwargs
        ).eval()


def _maybe_compile(config, model):
    """Wrap ``model`` with ``torch.compile`` when ``config.ENABLE_COMPILE`` is set.

    Any failure while wrapping is reported and the uncompiled ``model`` is
    returned unchanged. Note that ``torch.compile`` compiles lazily, so
    errors from the compilation itself surface on the first forward call,
    not here.

    Returns:
        The compiled wrapper on success, otherwise ``model``.
    """
    if not config.ENABLE_COMPILE:
        return model
    try:
        import torch._dynamo  # noqa: F401  (availability probe for the PyTorch 2.x compiler stack)
        print("🔥 Applying torch.compile for potential speedup...")
        compiled = torch.compile(model, mode="reduce-overhead")
        print("✅ torch.compile applied successfully")
        return compiled
    except Exception as e:
        print(f"⚠️  torch.compile not available or failed: {e}")
        return model


def _relative_saving(baseline, raspid):
    """Return ``(baseline - raspid) / baseline``, or NaN when ``baseline`` is 0.

    The guard avoids a division by zero (inf/nan plus a runtime warning)
    when the baseline used no tokens at all.
    """
    if not baseline:
        return float("nan")
    return (baseline - raspid) / baseline


def _print_evaluation_summary(results_df, use_vllm, baseline_generator):
    """Print accuracy, token-efficiency and token-distribution statistics.

    Args:
        results_df: Per-problem evaluation results. Reads the columns
            ``baseline_is_correct``, ``raspid_is_correct``,
            ``baseline_tokens`` and ``raspid_tokens``.
        use_vllm: Whether vLLM batch generation was requested for the baseline.
        baseline_generator: Baseline generator; its ``use_vllm`` attribute
            reports whether vLLM was actually used.
    """
    # NOTE(review): the evaluator rows built elsewhere in this file use the keys
    # ``baseline_correct`` / ``raspid_correct`` (extracted answers), not the
    # ``*_is_correct`` flags read below. Confirm that ``evaluate`` adds these
    # columns, otherwise this raises KeyError.
    baseline_correct = results_df['baseline_is_correct']
    raspid_correct = results_df['raspid_is_correct']
    baseline_tokens = results_df['baseline_tokens']
    raspid_tokens = results_df['raspid_tokens']
    n_rows = len(results_df)

    print("\n" + "="*60)
    print("📊 DETAILED EVALUATION RESULTS")
    print("="*60)

    baseline_accuracy = baseline_correct.mean()
    raspid_accuracy = raspid_correct.mean()
    token_efficiency = _relative_saving(baseline_tokens.sum(), raspid_tokens.sum())

    print("📈 Accuracy Comparison:")
    print(f"   Baseline:  {baseline_accuracy:.3f} ({baseline_correct.sum()}/{n_rows})")
    print(f"   RASPID:    {raspid_accuracy:.3f} ({raspid_correct.sum()}/{n_rows})")
    print(f"   Difference: {raspid_accuracy - baseline_accuracy:+.3f}")

    print("\n⚡ Token Efficiency:")
    print(f"   Overall savings: {token_efficiency:.3f} ({token_efficiency*100:.1f}%)")
    print("   Avg tokens/problem:")
    print(f"      Baseline: {baseline_tokens.mean():.1f}")
    print(f"      RASPID:   {raspid_tokens.mean():.1f}")
    print(f"      Savings:  {baseline_tokens.mean() - raspid_tokens.mean():.1f}")

    print("\n📊 Token Distribution:")
    print(f"   Baseline - Min: {baseline_tokens.min()}, Max: {baseline_tokens.max()}, Std: {baseline_tokens.std():.1f}")
    print(f"   RASPID   - Min: {raspid_tokens.min()}, Max: {raspid_tokens.max()}, Std: {raspid_tokens.std():.1f}")

    # Savings restricted to problems both methods solved, so that shorter but
    # wrong RASPID answers do not inflate the figure.
    both_correct = baseline_correct & raspid_correct
    if both_correct.any():
        subset = results_df[both_correct]
        avg_savings_when_both_correct = _relative_saving(
            subset['baseline_tokens'].mean(), subset['raspid_tokens'].mean()
        )
        print(f"\n🎯 When both methods are correct ({both_correct.sum()} problems):")
        print(f"   Token savings: {avg_savings_when_both_correct:.3f} ({avg_savings_when_both_correct*100:.1f}%)")

    print("\n🚀 Generation Method Performance:")
    if use_vllm and baseline_generator.use_vllm:
        print("   ✅ Baseline generation: vLLM (batch processing)")
    else:
        print("   🔄 Baseline generation: HuggingFace transformers (sequential)")
    print("   🧠 RASPID generation: Custom PID-controlled (sequential)")


def main():
    """Run the full RASPID pipeline end to end.

    1. Load the tokenizer and base model, then wrap the model in a
       ``ControlModel`` at layer ``config.EMB_LAYER``.
    2. Optionally wrap the base model with ``torch.compile``.
    3. Read ``config.LABELED_CSV``. Its first 1000 rows supply the required /
       redundant thoughts used for training.
    4. Train the chunk classifier and build the control vector.
    5. Evaluate RASPID against the baseline generator on 100 GSM8K problems,
       save the per-problem DataFrame to ``results_df_100_<timestamp>.csv``
       and print a summary.

    Returns ``None``. If the labeled CSV is missing, returns early after
    printing a message.
    """
    print("🚀 Initializing RASPID...")
    config = RASPIDConfig()

    # --- Model and tokenizer ----------------------------------------------
    print(f"Loading model: {config.MODEL_NAME}")
    tokenizer = AutoTokenizer.from_pretrained(config.MODEL_NAME, trust_remote_code=True)
    base_model = _load_base_model(config)

    control_model = ControlModel(base_model, [config.EMB_LAYER])

    # NOTE(review): ``control_model`` wraps the uncompiled module, while the
    # classifier, vector builder and generators below receive the (possibly)
    # compiled wrapper around that same module. Confirm the two views stay
    # consistent when control vectors are set and reset.
    base_model = _maybe_compile(config, base_model)

    # --- Labeled data ------------------------------------------------------
    print("📚 Loading labeled data...")
    try:
        df_all = pd.read_csv(config.LABELED_CSV)
    except FileNotFoundError:
        print(f"❌ Could not find {config.LABELED_CSV}")
        print("   Please ensure the labeled dataset is available")
        return
    print(f"   Loaded {len(df_all)} labeled examples")

    # Only the first 1000 rows are used, for training the classifier and the
    # control vector. Later rows are not used by this function.
    df_ctrl = df_all.iloc[:1000]
    required_thoughts = df_ctrl["required_thoughts"].fillna("")
    redundant_thoughts = df_ctrl["redundant_thoughts"].fillna("")

    print(f"   Training set: {len(df_ctrl)} examples")
    print(f"   Required thoughts: {sum(1 for x in required_thoughts if x)}")
    print(f"   Redundant thoughts: {sum(1 for x in redundant_thoughts if x)}")

    # --- Classifier and control vector -------------------------------------
    print("🎯 Training chunk classifier...")
    classifier_trainer = ChunkClassifier(config, tokenizer, base_model)
    best_classifier, best_chunk_size = classifier_trainer.train(required_thoughts, redundant_thoughts)

    print("🎛️  Building control vector...")
    vector_builder = ControlVectorBuilder(config, tokenizer, base_model)
    control_vector = vector_builder.build_control_vector(required_thoughts, redundant_thoughts)

    # --- Generators --------------------------------------------------------
    print("⚙️  Initializing generators...")
    raspid_generator = RASPIDGenerator(
        config, tokenizer, base_model, control_model,
        best_classifier, best_chunk_size, control_vector,
    )

    use_vllm = True  # Set to False to use HuggingFace transformers for the baseline
    baseline_generator = BaselineGenerator(config, tokenizer, base_model, use_vllm=use_vllm)

    # --- Evaluation --------------------------------------------------------
    print("🧮 Starting GSM8K evaluation with optimized pipeline...")
    evaluator = GSM8KEvaluator(config, raspid_generator, baseline_generator)
    results_df = evaluator.evaluate(
        n_problems=100,
        debug=False,  # Set to True to debug first RASPID generation
        use_batch=use_vllm,
    )

    # Timestamped filename so repeated runs do not overwrite each other.
    timestamp = pd.Timestamp.now().strftime("%Y%m%d_%H%M%S")
    results_filename = f'results_df_100_{timestamp}.csv'
    results_df.to_csv(results_filename, index=False)
    print(f"💾 Results saved to {results_filename}")

    _print_evaluation_summary(results_df, use_vllm, baseline_generator)

    print("\n🎉 Evaluation completed successfully!")
    print("="*60)


# Script / notebook entry point. In a notebook kernel ``__name__`` is
# "__main__", so executing this cell runs the whole pipeline immediately
# (model load, training, and the GSM8K evaluation).
if __name__ == "__main__":
    main()

🚀 Initializing RASPID...
Loading model: deepseek-ai/DeepSeek-R1-Distill-Qwen-1.5B
⚠️  Flash Attention not available: FlashAttention2 has been toggled on, but it cannot be used due to the following error: the package flash_attn seems to be not installed. Please refer to the documentation of https://huggingface.co/docs/transformers/perf_infer_gpu_one#flashattention-2 to install Flash Attention 2.
🔥 Applying torch.compile for potential speedup...
✅ torch.compile applied successfully
📚 Loading labeled data...
   Loaded 1200 labeled examples
   Training set: 1000 examples
   Required thoughts: 999
   Redundant thoughts: 998
🎯 Training chunk classifier...
Training chunk classifier...
Testing chunk size: 16


Extracting embeddings:   0%|          | 0/1590 [00:00<?, ?it/s]



Training classifier (cs=16):   0%|          | 0/500 [00:00<?, ?it/s]