In [None]:
import os
import json
import random
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
from transformers import (
    AutoModel,
    AutoModelForSeq2SeqLM,
    AutoTokenizer,
    AdamW,
    get_linear_schedule_with_warmup,
    T5ForConditionalGeneration,
    T5Tokenizer
)
from sklearn.metrics.pairwise import cosine_similarity
from datasets import load_dataset
from tqdm.auto import tqdm
import re
from torch.optim import Adam
import random
from datetime import datetime
import traceback

In [11]:
def set_seed(seed=11):
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    os.environ['PYTHONHASHSEED'] = str(seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed_all(seed)
        torch.backends.cudnn.deterministic = True

set_seed()

In [12]:
# Device configuration - M1 Mac specific
if torch.backends.mps.is_available():
    device = torch.device("mps")
    print("Using Apple M1 MPS device")
elif torch.cuda.is_available():
    device = torch.device("cuda")
    print("Using CUDA device")
else:
    device = torch.device("cpu")
    print("Using CPU device")

Using Apple M1 MPS device


In [13]:
MAX_LENGTH = 768
BATCH_SIZE = 4
LEARNING_RATE = 2e-5
EPOCHS = 3
COT_PROMPT = "Let's solve this step-by-step. To find the answer, I'll break down the problem into smaller parts."
MODEL_NAME = "t5-base"
MAX_SAMPLES = 400

In [14]:
# Utility functions for extracting final answers and CoT steps
def extract_final_answer(text):
    """Extract the final answer from generated text"""
    # Look for common patterns that indicate the final answer
    patterns = [
        r"answer\s+is\s+([\d\.\-\+\/\*]+)",  # "answer is 42"
        r"answer\s*:\s*([\d\.\-\+\/\*]+)",   # "answer: 42"
        r"final\s+answer\s*[:\s]\s*([\d\.\-\+\/\*]+)", # "final answer: 42" 
        r"therefore[,\s]+(?:the\s+)?(?:answer\s+is\s+)?([\d\.\-\+\/\*]+)",  # "therefore, 42"
        r"thus[,\s]+(?:the\s+)?(?:answer\s+is\s+)?([\d\.\-\+\/\*]+)",      # "thus, 42"
        r"=\s*([\d\.\-\+\/\*]+)(?:\s*$|\s*\.\s*$)"  # "= 42" at the end
    ]
    
    for pattern in patterns:
        match = re.search(pattern, text, re.IGNORECASE)
        if match:
            return match.group(1).strip()
    
    # Look for the last number in the text as a fallback
    numbers = re.findall(r"(\d+(?:\.\d+)?)", text)
    if numbers:
        return numbers[-1]
    
    return ""

def extract_cot_steps(answer_text):
    # Remove the final answer part
    final_answer_match = re.search(r"The answer is(.*?)$", answer_text, re.DOTALL)
    if final_answer_match:
        cot_text = answer_text[:final_answer_match.start()].strip()
    else:
        # If no "The answer is" pattern, assume the last line is the answer
        lines = answer_text.strip().split("\n")
        cot_text = "\n".join(lines[:-1]) if len(lines) > 1 else ""
    
    # Split into steps
    steps = [step.strip() for step in cot_text.split("\n") if step.strip()]
    return steps

# Safe device transfer function for M1 MPS
def to_device(tensor_or_module):
    """Safely move tensors or modules to the selected device"""
    if tensor_or_module is None:
        return None
    
    device = torch.device("cuda" if torch.cuda.is_available() else 
                         "mps" if torch.backends.mps.is_available() else "cpu")
    
    try:
        return tensor_or_module.to(device)
    except Exception as e:
        print(f"Warning: Could not move to {device}: {e}")
        return tensor_or_module

In [15]:
class GSM8KDataset(Dataset):
    def __init__(self, split="train", tokenizer=None, max_length=768, max_samples=None):
        self.data = load_dataset("gsm8k", "main")[split]
        if max_samples:
            self.data = self.data.select(range(min(max_samples, len(self.data))))
        self.tokenizer = tokenizer
        self.max_length = max_length
        self.processed_data = self.preprocess_data()
        
    def preprocess_data(self):
        processed = []
        for item in tqdm(self.data, desc="Preprocessing data"):
            question = item["question"]
            answer_with_cot = item["answer"]
            
            # Extract the CoT steps and the final answer
            final_answer = extract_final_answer(answer_with_cot)
            cot_steps = extract_cot_steps(answer_with_cot)
            
            # Format for T5 training - improved prompt to guide the model
            formatted_question = f"Solve this math problem step-by-step: {question} {COT_PROMPT}"
            
            processed.append({
                "question": question,
                "formatted_question": formatted_question,
                "cot_steps": cot_steps,
                "final_answer": final_answer,
                "full_answer": answer_with_cot
            })
        return processed
    
    def __len__(self):
        return len(self.processed_data)
    
    def __getitem__(self, idx):
        item = self.processed_data[idx]
        
        if self.tokenizer:
            # Prepare input with task-specific prefix for T5
            input_text = item["formatted_question"]
            target_text = item["full_answer"]
            
            # Improved tokenization with more balanced token allocation
            try:
                inputs = self.tokenizer(
                    input_text,
                    padding="max_length",
                    truncation=True,
                    max_length=self.max_length // 3,
                    return_tensors="pt"
                )
                
                targets = self.tokenizer(
                    target_text,
                    padding="max_length",
                    truncation=True,
                    max_length=self.max_length * 2 // 3,
                    return_tensors="pt"
                )
                
                # Ensure raw_answer is not None or empty
                raw_answer = item["final_answer"]
                if not raw_answer:
                    # Try to extract it from full answer
                    raw_answer = extract_final_answer(item["full_answer"])
                    if not raw_answer:
                        raw_answer = "unknown"  # Fallback
                
                return {
                    "input_ids": inputs.input_ids.squeeze(),
                    "attention_mask": inputs.attention_mask.squeeze(),
                    "labels": targets.input_ids.squeeze(),
                    "raw_question": item["question"],
                    "raw_cot": item["cot_steps"],
                    "raw_answer": raw_answer
                }
            except Exception as e:
                print(f"Error tokenizing item {idx}: {e}")
                # Improved error handling with logging
                dummy_tensor = torch.zeros(self.max_length, dtype=torch.long)
                return {
                    "input_ids": dummy_tensor,
                    "attention_mask": dummy_tensor,
                    "labels": dummy_tensor,
                    "raw_question": item["question"],
                    "raw_cot": [],
                    "raw_answer": "unknown"
                }
        else:
            return item

In [31]:
class CoTGenerator:
    def __init__(self, model_name="t5-base", local_dir="./models/t5_base_cache"):
        self.model_name = model_name
        self.local_dir = local_dir
        
        # Create the directory if it doesn't exist
        os.makedirs(self.local_dir, exist_ok=True)
        
        print(f"Loading model {model_name}...")
        
        # Check if model is already saved locally
        if os.path.exists(os.path.join(self.local_dir, "pytorch_model.bin")) and \
           os.path.exists(os.path.join(self.local_dir, "tokenizer_config.json")):
            print(f"Found existing model at {self.local_dir}. Loading locally...")
            self._load_local_model()
        else:
            print(f"Model not found locally. Downloading {model_name}...")
            self._download_model()
    
    def _download_model(self):
        try:
            # Download tokenizer and save it immediately
            self.tokenizer = T5Tokenizer.from_pretrained(
                self.model_name,
                cache_dir=self.local_dir,
                use_fast=True
            )
            print(f"Tokenizer downloaded and saved to {self.local_dir}")
            
            # Download model and save it immediately
            try:
                device = torch.device("cuda" if torch.cuda.is_available() else 
                                    "mps" if torch.backends.mps.is_available() else "cpu")
                self.model = T5ForConditionalGeneration.from_pretrained(
                    self.model_name,
                    cache_dir=self.local_dir,
                    low_cpu_mem_usage=True,
                    torch_dtype=torch.float16 if device.type != "cpu" else torch.float32
                )
                self.model = self.model.to(device)
                print(f"Model downloaded and moved to {device}")
            except Exception as e:
                print(f"Error loading model to device: {e}")
                print("Falling back to CPU")
                self.model = T5ForConditionalGeneration.from_pretrained(
                    self.model_name,
                    cache_dir=self.local_dir
                )
                print(f"Model downloaded (CPU version)")
        except Exception as e:
            print(f"Error downloading model: {e}")
            raise e
    
    def _load_local_model(self):
        try:
            # Load locally saved tokenizer
            self.tokenizer = T5Tokenizer.from_pretrained(self.local_dir)
            
            # Load locally saved model
            try:
                device = torch.device("cuda" if torch.cuda.is_available() else 
                                    "mps" if torch.backends.mps.is_available() else "cpu")
                self.model = T5ForConditionalGeneration.from_pretrained(
                    self.local_dir,
                    torch_dtype=torch.float16 if device.type != "cpu" else torch.float32
                )
                self.model = self.model.to(device)
                print(f"Model loaded from {self.local_dir} and moved to {device}")
            except Exception as e:
                print(f"Error loading model to device: {e}")
                print("Falling back to CPU")
                self.model = T5ForConditionalGeneration.from_pretrained(self.local_dir)
                print(f"Model loaded from {self.local_dir} (CPU version)")
        except Exception as e:
            print(f"Error loading local model: {e}")
            print("Will attempt to download from source...")
            self._download_model()
    
    def generate_step(self, question, previous_steps=None, max_length=256):
        """Generate a single reasoning step based on the question and previous steps"""
        if previous_steps is None:
            previous_steps = []
        
        # Construct prompt with previous steps - REVISED
        previous_steps_text = ""
        if previous_steps:
            previous_steps_text = "Steps so far:\n" + "\n".join([f"Step {i+1}: {step}" for i, step in enumerate(previous_steps)])
            previous_steps_text += "\nNext step:"
        
        # Build the full prompt - REVISED
        if previous_steps:
            input_text = f"Solve this math problem: {question}\n{previous_steps_text}"
        else:
            input_text = f"Solve this math problem: {question}\nStart step-by-step reasoning:"
        
        try:
            inputs = self.tokenizer(
                input_text, 
                return_tensors="pt", 
                padding=True, 
                truncation=True,
                max_length=max_length // 2
            )
            device = self.model.device
            inputs = {k: v.to(device) for k, v in inputs.items()}
            
            # Generate with clear stopping criteria - REVISED
            with torch.no_grad():
                outputs = self.model.generate(
                    input_ids=inputs["input_ids"],
                    attention_mask=inputs["attention_mask"],
                    max_length=max_length,
                    num_return_sequences=1,
                    do_sample=True,
                    temperature=0.7,
                    top_p=0.9,
                    top_k=40,
                    num_beams=3,
                    early_stopping=True
                )
            
            # Decode and clean up the output - REVISED
            decoded_output = self.tokenizer.decode(outputs[0], skip_special_tokens=True)
            
            # Only take content after the prompt
            prompt_end = input_text.split("\n")[-1]
            if prompt_end in decoded_output:
                step = decoded_output.split(prompt_end, 1)[1].strip()
            else:
                step = decoded_output.strip()
            
            # Check if this is a final answer step
            is_final_step = ("answer is" in step.lower() or 
                             "therefore" in step.lower() or 
                             "thus" in step.lower() or
                             "final answer" in step.lower())
            
            return {
                "step": step,
                "is_final_step": is_final_step
            }
        except Exception as e:
            print(f"Error generating step: {e}")
            return {
                "step": "Error generating step",
                "is_final_step": True
            }
    
    def evaluate_step(self, step, question, reflection_module=None):
        """Evaluate if a generated step is good or needs refinement"""
        if reflection_module:
            # Use the reflection module if provided
            score = reflection_module.evaluate_step(step, question)
            return score >= 0.5  # Return True if score is good enough
        
        # Basic heuristic evaluation if no reflection module
        # Check if the step contains numbers and mathematical operations
        has_numbers = bool(re.search(r'\d+', step))
        has_math_ops = bool(re.search(r'[+\-*/=]', step))
        reasonable_length = 10 <= len(step.split()) <= 100
        
        return has_numbers and has_math_ops and reasonable_length
    
    def refine_step(self, question, step, previous_steps=None):
        """Refine a step that didn't pass evaluation"""
        if previous_steps is None:
            previous_steps = []
        
        # Create a prompt asking for improvement
        previous_steps_text = ""
        if previous_steps:
            previous_steps_text = "Steps so far:\n" + "\n".join([f"{i+1}. {step}" for i, step in enumerate(previous_steps)])
        
        input_text = (f"Solve this math problem step-by-step: {question}\n"
                     f"{previous_steps_text}\n"
                     f"The following step needs improvement: {step}\n"
                     f"Improved step:")
        
        try:
            inputs = self.tokenizer(
                input_text, 
                return_tensors="pt", 
                padding=True, 
                truncation=True,
                max_length=512
            )
            device = self.model.device
            inputs = {k: v.to(device) for k, v in inputs.items()}
            
            with torch.no_grad():
                outputs = self.model.generate(
                    input_ids=inputs["input_ids"],
                    attention_mask=inputs["attention_mask"],
                    max_length=256,
                    do_sample=True,
                    temperature=0.6
                )
            
            refined_step = self.tokenizer.decode(outputs[0], skip_special_tokens=True).strip()
            return refined_step
        except Exception as e:
            print(f"Error refining step: {e}")
            return step  # Return original if refinement fails
    
    def generate(self, question, max_length=768, cot_prompt=None, reflection_module=None, max_steps=8):
        """Generate a full chain-of-thought reasoning process step by step"""
        # Improved initial prompt to guide better reasoning
        if not cot_prompt:
            cot_prompt = (
                "Think step-by-step to solve this math problem. "
                "Write one reasoning step at a time. Each step should build on previous steps. "
                "Identify the relevant information, set up equations when needed, and solve them correctly."
            )
        
        enhanced_question = f"{question}\n{cot_prompt}"
        cot_steps = []
        final_answer = None
        
        # Generate steps iteratively
        for step_num in range(max_steps):
            # Generate the next reasoning step
            step_result = self.generate_step(question, cot_steps)
            current_step = step_result["step"]
            
            # Evaluate the step quality
            step_is_good = self.evaluate_step(current_step, question, reflection_module)
            
            # Refine the step if needed (up to 2 attempts)
            refinement_attempts = 0
            while not step_is_good and refinement_attempts < 2:
                refined_step = self.refine_step(question, current_step, cot_steps)
                current_step = refined_step
                step_is_good = self.evaluate_step(current_step, question, reflection_module)
                refinement_attempts += 1
            
            # Add the step to our chain of thought
            cot_steps.append(current_step)
            
            # Check if this is the final step
            if step_result["is_final_step"] or step_num == max_steps - 1:
                break
        
        # Generate a final answer if needed
        if not any("answer is" in step.lower() for step in cot_steps):
            final_step_result = self.generate_final_answer(question, cot_steps)
            cot_steps.append(final_step_result["step"])
        
        # Extract the final answer using the dedicated function
        full_output = "\n".join(cot_steps)
        final_answer = extract_final_answer(full_output)
        
        return {
            "cot_steps": cot_steps,
            "final_answer": final_answer,
            "full_output": full_output
        }
    
    def generate_final_answer(self, question, cot_steps):
        """Generate a final answer step based on previous reasoning steps"""
        steps_text = "\n".join([f"{i+1}. {step}" for i, step in enumerate(cot_steps)])
        
        input_text = (f"Solve this math problem step-by-step: {question}\n"
                     f"Steps so far:\n{steps_text}\n"
                     f"Final answer:")
        
        try:
            inputs = self.tokenizer(
                input_text, 
                return_tensors="pt", 
                padding=True, 
                truncation=True,
                max_length=512
            )
            device = self.model.device
            inputs = {k: v.to(device) for k, v in inputs.items()}
            
            with torch.no_grad():
                outputs = self.model.generate(
                    input_ids=inputs["input_ids"],
                    attention_mask=inputs["attention_mask"],
                    max_length=128,
                    do_sample=False,  # More deterministic for final answer
                    num_beams=3
                )
            
            final_step = self.tokenizer.decode(outputs[0], skip_special_tokens=True).strip()
            
            # Make sure it starts with "The answer is" if it doesn't already
            if not any(phrase in final_step.lower() for phrase in ["answer is", "therefore", "thus", "final answer"]):
                final_step = "The answer is " + final_step
                
            return {
                "step": final_step,
                "is_final_step": True
            }
        except Exception as e:
            print(f"Error generating final answer: {e}")
            return {
                "step": "The answer is unknown due to an error.",
                "is_final_step": True
            }
    
    def save(self, path=None):
        save_path = path if path else self.local_dir
        try:
            # Move to CPU before saving to avoid device-specific tensors in saved model
            cpu_model = self.model.to("cpu")
            cpu_model.save_pretrained(save_path)
            self.tokenizer.save_pretrained(save_path)
            print(f"Model saved to {save_path}")
            # Move back to device
            device = torch.device("cuda" if torch.cuda.is_available() else 
                               "mps" if torch.backends.mps.is_available() else "cpu")
            self.model = self.model.to(device)
        except Exception as e:
            print(f"Error saving model: {e}")
    
    def load(self, path=None):
        load_path = path if path else self.local_dir
        try:
            self.tokenizer = T5Tokenizer.from_pretrained(load_path)
            device = torch.device("cuda" if torch.cuda.is_available() else 
                               "mps" if torch.backends.mps.is_available() else "cpu")
            self.model = T5ForConditionalGeneration.from_pretrained(load_path)
            self.model = self.model.to(device)
            print(f"Model loaded from {load_path} and moved to {device}")
        except Exception as e:
            print(f"Error loading model: {e}")

In [32]:
class ReflectionModule(nn.Module):
    def __init__(self, model_name=MODEL_NAME, embedding_dim=768):
        super(ReflectionModule, self).__init__()
        try:
            # Set device property on self first
            self.device = torch.device("cuda" if torch.cuda.is_available() else 
                               "mps" if torch.backends.mps.is_available() else "cpu")
                               
            # Use pre-trained model for better embeddings instead of random
            self.tokenizer = AutoTokenizer.from_pretrained(model_name)
            self.base_model = AutoModel.from_pretrained(model_name)
            
            # Freeze base model for stability
            for param in self.base_model.parameters():
                param.requires_grad = False
                
            # Add trainable layers on top for step quality assessment
            self.encoder = nn.TransformerEncoder(
                nn.TransformerEncoderLayer(d_model=embedding_dim, nhead=8, dropout=0.1, batch_first=True),
                num_layers=2
            )
            self.fc1 = nn.Linear(embedding_dim, 256)
            self.fc2 = nn.Linear(256, 64)
            self.fc3 = nn.Linear(64, 1)
            self.dropout = nn.Dropout(0.2)
            
            # Rule-based heuristics lookup (for early training stability)
            self.heuristics = {
                r'\d+\s*[\+\-\*/]\s*\d+': 0.7,  # Mathematical operations
                r'therefore|thus|so': 0.6,      # Logical connections
                r'first|second|third|next': 0.5, # Sequential reasoning
                r'=\s*\d+': 0.6,               # Equation results
                r'equation|formula': 0.5,       # Mathematical concepts
                r'calculate|compute': 0.5,      # Calculation indicators
            }
            
            # Move model to device
            self.to(self.device)
            
        except Exception as e:
            print(f"Error initializing reflection module: {e}")
            self.tokenizer = None
            self.base_model = None
            # Ensure device is set even in error case
            self.device =torch.device("cuda" if torch.cuda.is_available() else 
                         "mps" if torch.backends.mps.is_available() else "cpu")
    
    def forward(self, input_ids, attention_mask):
        try:
            # Get embeddings from base model
            with torch.no_grad():
                outputs = self.base_model(
                    input_ids=input_ids,
                    attention_mask=attention_mask,
                    return_dict=True
                )
                embeddings = outputs.last_hidden_state
            
            # Process through transformer encoder
            encoded = self.encoder(embeddings)
            
            # Pool and process through FC layers
            pooled = encoded.mean(dim=1)
            x = F.relu(self.fc1(pooled))
            x = self.dropout(x)
            x = F.relu(self.fc2(x))
            x = self.dropout(x)
            score = torch.sigmoid(self.fc3(x))
            
            return score
        except Exception as e:
            print(f"Error in forward pass: {e}")
            return torch.tensor([[0.5]]).to(self.device)  # Default neutral score
    
    def evaluate_step(self, step_text, question_context):
        """Evaluate a single reasoning step with combined ML and rule-based approach"""
        try:
            # Apply rule-based heuristics as a fallback/supplement
            heuristic_score = self._apply_heuristics(step_text)
            
            # For very short or empty steps, rely more on heuristics
            if len(step_text.strip()) < 10:
                return max(0.2, heuristic_score * 0.8)  # Penalize very short steps but not too harshly
                
            # Prepare input for model
            combined_text = f"Question: {question_context} Step: {step_text}"
            
            # Tokenize
            tokens = self.tokenizer(
                combined_text, 
                return_tensors="pt", 
                padding=True, 
                truncation=True,
                max_length=MAX_LENGTH
            )
            tokens = {k: v.to(self.device) for k, v in tokens.items()}
            
            # Get score from model
            with torch.no_grad():
                score = self.forward(tokens["input_ids"], tokens["attention_mask"])
            
            # Blend ML score with heuristic score for robustness
            ml_score = score.item()
            blended_score = 0.7 * ml_score + 0.3 * heuristic_score
            
            return blended_score
            
        except Exception as e:
            print(f"Error in evaluate_step: {e}")
            return self._apply_heuristics(step_text)  # Fall back to heuristics
    
    def _apply_heuristics(self, text):
        """Apply rule-based heuristics for evaluation"""
        text = text.lower()
        
        # Start with a base score
        score = 0.4
        
        # Check for various quality indicators
        for pattern, bonus in self.heuristics.items():
            if re.search(pattern, text):
                score = max(score, bonus)
        
        # Penalize extremely short steps
        if len(text.strip()) < 5:
            score *= 0.5
            
        # Bonus for using numbers (important in math reasoning)
        if re.search(r'\d+', text):
            score = min(1.0, score + 0.1)
            
        # Highest bonus for complete calculations with result
        if re.search(r'\d+\s*[\+\-\*/]\s*\d+\s*=\s*\d+', text):
            score = min(1.0, score + 0.2)
            
        return score
    
    def train_step(self, batch, optimizer):
        try:
            self.train()
            optimizer.zero_grad()
            
            # Move batch to device
            input_ids = batch["input_ids"].to(self.device)
            attention_mask = batch["attention_mask"].to(self.device)
            labels = batch["labels"].to(self.device)
            
            # Forward pass
            scores = self.forward(input_ids, attention_mask)
            
            # Calculate loss (MSE against target quality scores)
            loss = F.mse_loss(scores.squeeze(), labels)
            
            # Backward pass
            loss.backward()
            
            # Gradient clipping
            torch.nn.utils.clip_grad_norm_(self.parameters(), 1.0)
            
            # Update weights
            optimizer.step()
            
            return loss.item()
        except Exception as e:
            print(f"Error in train_step: {e}")
            return 1.0  # High loss as default
    
    def save(self, path):
        """Save the reflection module"""
        try:
            torch.save({
                'encoder_state_dict': self.encoder.state_dict(),
                'fc1_state_dict': self.fc1.state_dict(),
                'fc2_state_dict': self.fc2.state_dict(),
                'fc3_state_dict': self.fc3.state_dict(),
            }, path)
            print(f"Reflection module saved to {path}")
        except Exception as e:
            print(f"Error saving reflection module: {e}")
    
    def load(self, path):
        """Load the reflection module"""
        try:
            checkpoint = torch.load(path)
            self.encoder.load_state_dict(checkpoint['encoder_state_dict'])
            self.fc1.load_state_dict(checkpoint['fc1_state_dict'])
            self.fc2.load_state_dict(checkpoint['fc2_state_dict'])
            self.fc3.load_state_dict(checkpoint['fc3_state_dict'])
            print(f"Reflection module loaded from {path}")
        except Exception as e:
            print(f"Error loading reflection module: {e}")

In [33]:
class RetrievalModule:
    def __init__(self, embedding_model_name=MODEL_NAME):
        try:
            # Set device property on self first
            self.device = torch.device("cuda" if torch.cuda.is_available() else 
                               "mps" if torch.backends.mps.is_available() else "cpu")
            
            # Initialize embedding cache before any usage
            self.embedding_cache = {}
            
            # Use a better embedding model if available
            # Switch to using a encoder-only model instead of seq2seq
            # This will prevent the "decoder_input_ids" error
            self.tokenizer = AutoTokenizer.from_pretrained(embedding_model_name)
            
            # Use AutoModelForPreTraining or AutoModelForMaskedLM instead of AutoModel
            # to ensure we get an encoder-only model
            try:
                # First try loading as an encoder model explicitly
                self.model = AutoModel.from_pretrained(
                    embedding_model_name,
                    is_decoder=False,  # Explicitly specify this is not a decoder
                    add_cross_attention=False  # Ensure no cross-attention is expected
                )
            except:
                # Fallback to a model we know works for embeddings
                print(f"Failed to load {embedding_model_name} as encoder-only, falling back to bert-base-uncased")
                self.tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")
                self.model = AutoModel.from_pretrained("bert-base-uncased")
            
            self.model.to(self.device)
            
            # Keep model in eval mode
            self.model.eval()
            
            # Exemplar bank with metadata for better filtering and retrieval
            # (question, embedding, cot_steps, metadata)
            self.exemplar_bank = []
            
            # Domain-specific keyword weights (math-focused)
            self.keyword_weights = {
                'equation': 1.5,
                'equal': 1.5,
                'solve': 1.3,
                'calculate': 1.3,
                'find': 1.2,
                'value': 1.2,
                'total': 1.2,
                'average': 1.2,
                'mean': 1.2,
                'median': 1.2,
                'probability': 1.5,
                'percent': 1.3,
                'increase': 1.2,
                'decrease': 1.2,
                'rate': 1.2,
                'ratio': 1.5,
                'proportion': 1.5,
                'fraction': 1.5,
                'decimal': 1.3
            }
            
        except Exception as e:
            print(f"Error initializing retrieval module: {e}")
            self.tokenizer = None
            self.model = None
            self.exemplar_bank = []
            # Ensure embedding_cache is set even in error case
            self.embedding_cache = {}
            # Ensure device is set even in error case
            self.device = torch.device("cpu")
    
    def compute_embedding(self, text):
        """Compute embedding for text using the model"""
        try:
            # Check cache first
            cache_key = text[:100]  # Use first 100 chars as key to save memory
            if cache_key in self.embedding_cache:
                return self.embedding_cache[cache_key]
                
            if self.model is None or self.tokenizer is None:
                return np.random.randn(768)  # Fallback
                
            # Preprocess text for embedding
            tokens = self.tokenizer(
                text, 
                padding=True, 
                truncation=True, 
                max_length=MAX_LENGTH, 
                return_tensors="pt"
            )
            tokens = {k: v.to(self.device) for k, v in tokens.items()}
            
            # Get embedding from model - using only encoder part
            with torch.no_grad():
                # Using model output without decoder inputs
                outputs = self.model(**tokens, return_dict=True)
                
                # Use CLS token or mean pooling based on model architecture
                if hasattr(outputs, 'pooler_output') and outputs.pooler_output is not None:
                    embedding = outputs.pooler_output.cpu().numpy()[0]
                else:
                    # Use mean of last hidden states as fallback (more universal)
                    embedding = outputs.last_hidden_state.mean(dim=1).cpu().numpy()[0]
            
            # Apply keyword weighting
            weighted_embedding = self.apply_keyword_weighting(text, embedding)
            
            # Normalize
            norm = np.linalg.norm(weighted_embedding)
            if norm > 0:
                weighted_embedding = weighted_embedding / norm
                
            # Cache the result
            self.embedding_cache[cache_key] = weighted_embedding
            
            return weighted_embedding
            
        except Exception as e:
            print(f"Error computing embedding: {e}")
            return np.random.randn(768)  # Fallback
    
    def add_exemplar(self, question, cot_steps, metadata=None):
        """Add an exemplar to the retrieval bank with metadata"""
        try:
            if not question or not cot_steps:
                return False
                
            # Create default metadata if none provided
            if metadata is None:
                metadata = {
                    'quality': 1.0,  # Default high quality
                    'topics': self.extract_topics(question),
                    'num_steps': len(cot_steps) if isinstance(cot_steps, list) else 1,
                    'date_added': datetime.now().isoformat()
                }
                
            # Compute embedding
            embedding = self.compute_embedding(question)
            
            # Add to exemplar bank
            self.exemplar_bank.append((question, embedding, cot_steps, metadata))
            
            # Limit size to prevent memory issues
            if len(self.exemplar_bank) > 1000:
                # Remove lowest quality exemplars
                self.exemplar_bank = sorted(
                    self.exemplar_bank, 
                    key=lambda x: x[3].get('quality', 0.0), 
                    reverse=True
                )[:1000]
                
            return True
            
        except Exception as e:
            print(f"Error adding exemplar: {e}")
            return False

    def apply_keyword_weighting(self, text, embedding):
        """Apply domain-specific keyword weighting to the embedding"""
        text_lower = text.lower()
        weighted = embedding.copy()
        
        weight_multiplier = 1.0
        for keyword, weight in self.keyword_weights.items():
            if keyword in text_lower:
                weight_multiplier += (weight - 1.0) * 0.1  # Gradual effect
        
        return weighted * weight_multiplier
    
    def extract_topics(self, text):
        """Extract likely math topics from the text"""
        topics = []
        topic_patterns = {
            'algebra': r'equation|variable|solve for|unknown|linear|quadratic',
            'geometry': r'triangle|circle|angle|area|volume|perimeter',
            'probability': r'probability|chance|likelihood|random|odds',
            'statistics': r'average|mean|median|mode|standard deviation|variance',
            'calculus': r'derivative|integral|rate of change|maximum|minimum',
            'arithmetic': r'add|subtract|multiply|divide|sum|difference|product|quotient',
            'word_problem': r'train|distance|time|speed|mixture|percent|increase|decrease'
        }
        
        for topic, pattern in topic_patterns.items():
            if re.search(pattern, text, re.IGNORECASE):
                topics.append(topic)
                
        return topics
    
    def retrieve_similar_exemplars(self, question, k=5, filter_criteria=None):
        """Retrieve similar exemplars with filtering options"""
        try:
            if not self.exemplar_bank:
                return []
                
            query_embedding = self.compute_embedding(question)
            query_topics = self.extract_topics(question)
            
            # Calculate similarities with metadata bonuses
            candidates = []
            for i, (exemplar_question, exemplar_embedding, cot_steps, metadata) in enumerate(self.exemplar_bank):
                try:
                    # Base similarity using cosine similarity
                    sim = cosine_similarity([query_embedding], [exemplar_embedding])[0][0]
                    
                    # Apply topic matching bonus
                    exemplar_topics = metadata.get('topics', [])
                    matching_topics = set(query_topics).intersection(set(exemplar_topics))
                    topic_bonus = len(matching_topics) * 0.05  # 5% bonus per matching topic
                    
                    # Apply quality bonus
                    quality_bonus = (metadata.get('quality', 1.0) - 0.5) * 0.1  # Up to 5% bonus for quality
                    
                    # Apply appropriate length bonus
                    steps_count = metadata.get('num_steps', 0)
                    length_bonus = min(steps_count / 10, 0.05)  # Up to 5% bonus for longer examples
                    
                    # Final adjusted similarity
                    adjusted_sim = sim + topic_bonus + quality_bonus + length_bonus
                    
                    # Apply filters if specified
                    if filter_criteria:
                        if 'min_quality' in filter_criteria and \
                           metadata.get('quality', 0) < filter_criteria['min_quality']:
                            continue
                        if 'min_steps' in filter_criteria and \
                           metadata.get('num_steps', 0) < filter_criteria['min_steps']:
                            continue
                    
                    candidates.append((i, adjusted_sim, cot_steps))
                    
                except Exception as e:
                    print(f"Error processing exemplar {i}: {e}")
            
            # Sort and get top-k
            candidates.sort(key=lambda x: x[1], reverse=True)
            top_k = candidates[:k]
            
            # Return the cot_steps for top matches
            return [steps for _, _, steps in top_k]
            
        except Exception as e:
            print(f"Error retrieving similar exemplars: {e}")
            return []
    
    def initialize_from_dataset(self, dataset, max_exemplars=300):
        """Initialize the retrieval module with examples from a dataset with metadata"""
        added_count = 0
        for i, item in enumerate(tqdm(dataset, desc="Building exemplar bank")):
            if added_count >= max_exemplars:
                break
                
            try:
                question = item.get("raw_question", "")
                cot = item.get("raw_cot", [])
                
                # Create metadata
                if isinstance(cot, list):
                    num_steps = len(cot)
                else:
                    # If cot is a string, try to split into steps
                    cot_list = re.split(r'Step \d+:|^\d+\.', cot)
                    cot_list = [s.strip() for s in cot_list if s.strip()]
                    cot = cot_list
                    num_steps = len(cot_list)
                
                metadata = {
                    'quality': 1.0,  # Assume high quality for initial dataset
                    'topics': self.extract_topics(question),
                    'num_steps': num_steps,
                    'date_added': datetime.now().isoformat(),
                    'source': 'initial_dataset'
                }
                
                success = self.add_exemplar(question, cot, metadata)
                if success:
                    added_count += 1
                    
            except Exception as e:
                print(f"Error adding exemplar {i} from dataset: {e}")
        
        print(f"Added {added_count} exemplars from dataset to retrieval module")
    
    def clear_cache(self):
        """Clear the embedding cache to free memory"""
        self.embedding_cache = {}
        
    def save(self, path):
        """Save the retrieval module state"""
        try:
            state = {
                'exemplar_bank': self.exemplar_bank,
                'keyword_weights': self.keyword_weights
            }
            with open(path, 'wb') as f:
                pickle.dump(state, f)
            print(f"Retrieval module saved to {path}")
        except Exception as e:
            print(f"Error saving retrieval module: {e}")
    
    def load(self, path):
        """Load the retrieval module state"""
        try:
            with open(path, 'rb') as f:
                state = pickle.load(f)
            self.exemplar_bank = state['exemplar_bank']
            self.keyword_weights = state.get('keyword_weights', self.keyword_weights)
            print(f"Retrieval module loaded from {path} with {len(self.exemplar_bank)} exemplars")
        except Exception as e:
            print(f"Error loading retrieval module: {e}")

In [34]:
class TextRefinementTransformer:
    def __init__(self, model_name=MODEL_NAME):
        try:
            self.tokenizer = AutoTokenizer.from_pretrained(model_name)
            self.model = AutoModelForSeq2SeqLM.from_pretrained(model_name)
            device = torch.device("cuda" if torch.cuda.is_available() else 
                               "mps" if torch.backends.mps.is_available() else "cpu")
            self.model.to(self.device)
            
            # Pattern library for common errors in reasoning
            self.error_patterns = {
                r'(\d+)\s*\+\s*(\d+)\s*=\s*(\d+)': self._check_addition,
                r'(\d+)\s*\-\s*(\d+)\s*=\s*(\d+)': self._check_subtraction,
                r'(\d+)\s*\*\s*(\d+)\s*=\s*(\d+)': self._check_multiplication,
                r'(\d+)\s*\/\s*(\d+)\s*=\s*(\d+)': self._check_division,
            }
            
            # Template prompts for different refinement strategies
            self.refinement_templates = {
                'low_reward': "Fix the following math solution by checking for calculation errors: {text}",
                'medium_reward': "Improve this math solution with more detailed step-by-step reasoning: {text}",
                'high_reward': "Further enhance this strong math solution with clearer explanations: {text}",
                'error_fix': "The following math solution contains errors in {error_types}. Please fix these errors: {text}",
                'elaborate': "This solution needs more steps. Please elaborate step {step_number} further: {text}",
                'conclude': "This solution needs a clear final answer. Complete it based on the work shown: {text}"
            }
            
        except Exception as e:
            print(f"Error initializing refinement module: {e}")
            self.tokenizer = None
            self.model = None
    
    def _check_addition(self, match):
        """Check if addition is correct and return error if not"""
        a, b, c = map(int, match.groups())
        return a + b != c
        
    def _check_subtraction(self, match):
        """Check if subtraction is correct and return error if not"""
        a, b, c = map(int, match.groups())
        return a - b != c
        
    def _check_multiplication(self, match):
        """Check if multiplication is correct and return error if not"""
        a, b, c = map(int, match.groups())
        return a * b != c
        
    def _check_division(self, match):
        """Check if division is approximately correct and return error if not"""
        a, b, c = map(int, match.groups())
        if b == 0:
            return True  # Division by zero is an error
        return abs(a / b - c) > 0.01  # Allow for small floating point differences
    
    def detect_errors(self, text):
        """Detect specific calculation errors in the text"""
        errors = []
        
        for pattern, check_func in self.error_patterns.items():
            for match in re.finditer(pattern, text):
                if check_func(match):
                    error_text = match.group(0)
                    errors.append(error_text)
        
        return errors
    
    def select_refinement_strategy(self, text, reward, has_final_answer):
        """Select the best refinement strategy based on text analysis"""
        errors = self.detect_errors(text)
        
        # If there are calculation errors, prioritize fixing them
        if errors:
            error_types = "calculations (" + ", ".join(errors[:2]) + ")" if len(errors) > 0 else "reasoning"
            return 'error_fix', {'error_types': error_types}
            
        # If no final answer, prioritize adding conclusion
        if not has_final_answer:
            return 'conclude', {}
            
        # Choose template based on reward level
        if reward < 0.4:
            return 'low_reward', {}
        elif reward < 0.7:
            return 'medium_reward', {}
        else:
            return 'high_reward', {}
    
    def refine_text(self, input_text, reward=0.5, max_length=MAX_LENGTH):
        """Refine the input text using the selected strategy"""
        if self.model is None or self.tokenizer is None:
            return input_text  # Fallback
            
        try:
            # Extract final answer to check if it exists
            final_answer = extract_final_answer(input_text)
            has_final_answer = bool(final_answer)
            
            # Select appropriate refinement strategy
            strategy, params = self.select_refinement_strategy(input_text, reward, has_final_answer)
            
            # Prepare prompt based on strategy
            template = self.refinement_templates[strategy]
            params['text'] = input_text
            prompt = template.format(**params)
            
            # Tokenize
            inputs = self.tokenizer(
                prompt, 
                return_tensors="pt", 
                padding=True, 
                truncation=True,
                max_length=max_length // 2
            )
            inputs = {k: v.to(self.device) for k, v in inputs.items()}
            
            # Generate improved text
            with torch.no_grad():
                outputs = self.model.generate(
                    input_ids=inputs["input_ids"],
                    attention_mask=inputs["attention_mask"],
                    max_length=max_length,
                    num_return_sequences=1,
                    do_sample=True,
                    temperature=0.7,
                    no_repeat_ngram_size=3,
                    num_beams=4
                )
            
            refined_text = self.tokenizer.decode(outputs[0], skip_special_tokens=True)
            
            # Fallback to original if refinement failed or got much shorter
            if len(refined_text) < len(input_text) * 0.5:
                print(f"Refinement produced much shorter text. Using original.")
                return input_text
                
            return refined_text
            
        except Exception as e:
            print(f"Error refining text: {e}")
            return input_text  # Fallback to original
    
    def batch_refine(self, text_list, rewards, max_length=MAX_LENGTH):
        """Refine a batch of texts in parallel"""
        refined_texts = []
        
        for text, reward in zip(text_list, rewards):
            refined = self.refine_text(text, reward, max_length)
            refined_texts.append(refined)
            
        return refined_texts
    
    def save(self, path):
        """Save the refinement module"""
        try:
            if self.model:
                self.model.save_pretrained(path)
                self.tokenizer.save_pretrained(path)
                print(f"Refinement module saved to {path}")
        except Exception as e:
            print(f"Error saving refinement module: {e}")
    
    def load(self, path):
        """Load the refinement module"""
        try:
            self.model = AutoModelForSeq2SeqLM.from_pretrained(path)
            self.tokenizer = AutoTokenizer.from_pretrained(path)
            self.model.to(self.device)
            print(f"Refinement module loaded from {path}")
        except Exception as e:
            print(f"Error loading refinement module: {e}")

In [35]:
class RewardFunction:
    def __init__(self, reflection_module):
        self.reflection_module = reflection_module
        
    def outcome_reward(self, predicted_answer, gold_answer):
        """Calculate a reward based on how close the predicted answer is to the gold answer"""
        if not predicted_answer or not gold_answer:
            return 0.0
        
        try:
            # Clean and normalize answers
            pred = self._normalize_answer(predicted_answer)
            gold = self._normalize_answer(gold_answer)
            
            # Exact match
            if pred == gold:
                return 1.0
            
            # Try to convert to numbers for numerical comparison
            try:
                pred_num = float(pred.replace(',', ''))
                gold_num = float(gold.replace(',', ''))
                
                # Calculate relative error
                if abs(gold_num) > 1e-6:  # Avoid division by zero
                    rel_error = abs(pred_num - gold_num) / abs(gold_num)
                    if rel_error < 0.01:  # Within 1%
                        return 0.9
                    elif rel_error < 0.05:  # Within 5%
                        return 0.7
                    elif rel_error < 0.1:  # Within 10%
                        return 0.5
                    elif rel_error < 0.2:  # Within 20%
                        return 0.3
                
                # If numbers are small, use absolute error
                abs_error = abs(pred_num - gold_num)
                if abs_error < 0.01:
                    return 0.8
            except:
                pass
            
            # Partial string match as fallback
            if pred in gold or gold in pred:
                return 0.3
                
            return 0.0
        except Exception as e:
            print(f"Error in outcome_reward: {e}")
            return 0.0
        
    def _normalize_answer(self, answer):
        """Normalize an answer string for comparison"""
        if not answer:
            return ""
            
        # Convert to string if needed
        answer = str(answer)
        
        # Remove punctuation and lowercase
        answer = answer.lower().strip()
        answer = re.sub(r'[^\w\s\.\-]', '', answer)
        
        # Remove units and common words
        units = ['dollars', 'dollar', '$', 'cents', 'apples', 'people', 'hours', 'km', 'meters', 'years']
        for unit in units:
            answer = answer.replace(unit, '').strip()
        
        return answer.strip()
    
    def process_reward(self, cot_steps, question):
        """Improved reward based on the quality of the reasoning process"""
        if not cot_steps:
            return 0.2  # Small baseline reward even with no steps
        
        try:
            # Check for math expressions, operations, and equations
            has_math_expressions = False
            for step in cot_steps:
                if re.search(r'[=+\-*/]', step) and re.search(r'\d', step):
                    has_math_expressions = True
                    break
            
            # Check for relevant keywords from the question in the steps
            question_words = set(re.findall(r'\b\w{4,}\b', question.lower()))
            step_text = ' '.join(cot_steps).lower()
            relevant_word_count = sum(1 for word in question_words if word in step_text)
            relevance_score = min(relevant_word_count / max(len(question_words), 1), 1.0)
            
            # Evaluate each step and take the mean
            step_scores = []
            for step in cot_steps:
                score = self.reflection_module.evaluate_step(step, question)
                step_scores.append(score)
            
            # More generous length bonus
            length_bonus = min(len(step_scores) / 3, 1.0)  # Bonus for 3+ steps, capped at 1.0
            
            # Bonus for using numbers in steps
            has_numbers = any(bool(re.search(r'\d+', step)) for step in cot_steps)
            number_bonus = 0.2 if has_numbers else 0.0
            
            # Bonus for mathematical expressions
            math_bonus = 0.3 if has_math_expressions else 0.0
            
            # Return the mean score with bonuses
            base_score = sum(step_scores) / len(step_scores) if step_scores else 0.3
            return min(base_score * (1.0 + 0.3 * length_bonus + number_bonus + math_bonus) + 0.2 * relevance_score, 1.0)
        except Exception as e:
            print(f"Error in process_reward: {e}")
            return 0.3  # More generous fallback
    
    def combined_reward(self, cot_steps, predicted_answer, true_answer, question, training_progress=0.0):
        """
        Combine process and outcome rewards with adaptive weighting based on training progress
        
        Args:
            cot_steps: List of reasoning steps
            predicted_answer: Model's predicted answer
            true_answer: Gold/reference answer
            question: Original question text
            training_progress: Float between 0-1 indicating training progress
        
        Returns:
            Float: Combined reward score between 0-1
        """
        try:
            # Calculate component rewards
            outcome = self.outcome_reward(predicted_answer, true_answer)
            process = self.process_reward(cot_steps, question)
            
            # Dynamic alpha weighting that shifts focus towards outcome as training progresses
            # Early training: focus more on process (reasoning quality)
            # Late training: focus more on outcome (correct answers)
            base_alpha = 0.7  # Same as original
            min_alpha = 0.5   # Don't go below this to maintain outcome importance
            
            # Calculate adaptive alpha, increasing with training progress
            # Start with lower alpha (more process focus) and gradually increase
            alpha = min_alpha + (base_alpha - min_alpha) * training_progress
            
            # Apply confidence-based adjustment
            # If outcome is very good but process is poor, we might have a lucky guess
            # If process is very good but outcome is poor, reasoning might be correct but execution failed
            confidence_adjustment = 0.0
            
            # If there's a big gap between process and outcome, adjust weighting
            gap = abs(outcome - process)
            if gap > 0.4:  # Significant disagreement between outcome and process
                if outcome > process + 0.4:  # Much better outcome than process - lucky?
                    confidence_adjustment = -0.1  # Reduce alpha slightly to trust process more
                elif process > outcome + 0.4:  # Much better process than outcome - calculation error?
                    confidence_adjustment = 0.1   # Increase alpha slightly to focus on outcomes
            
            # Apply adjustment but stay within bounds
            adjusted_alpha = max(min_alpha, min(base_alpha, alpha + confidence_adjustment))
            
            # Combine rewards with adjusted alpha
            combined = adjusted_alpha * outcome + (1 - adjusted_alpha) * process
            
            # For very short or empty reasoning chains with good outcomes, apply a penalty
            if outcome > 0.8 and (not cot_steps or len(cot_steps) < 2):
                combined *= 0.9  # Small penalty for correct answers with insufficient reasoning
            
            # Debug info with expanded metrics
            if random.random() < 0.05:  # Only print for 5% of examples to avoid log flooding
                print(f"\nQuestion: {question[:50]}...")
                print(f"Predicted: {predicted_answer}")
                print(f"True: {true_answer}")
                print(f"Outcome reward: {outcome:.2f}, Process reward: {process:.2f}")
                print(f"Training progress: {training_progress:.2f}, Alpha: {adjusted_alpha:.2f}")
                print(f"Combined reward: {combined:.2f}")
                if cot_steps:
                    print(f"First reasoning step: {cot_steps[0][:50]}...")
                    print(f"Step count: {len(cot_steps)}")
            
            return combined
        except Exception as e:
            print(f"Error in combined_reward: {e}")
            return 0.3  # More generous fallback

In [36]:

class SelfTrainer:
    def __init__(
        self, 
        cot_generator,
        reflection_module,
        retrieval_module,
        refinement_module,
        reward_function,
        tokenizer=None
    ):
        self.cot_generator = cot_generator
        self.reflection_module = reflection_module
        self.retrieval_module = retrieval_module
        self.refinement_module = refinement_module
        self.reward_function = reward_function
        try:
            self.tokenizer = tokenizer or AutoTokenizer.from_pretrained(MODEL_NAME)
        except Exception as e:
            print(f"Error loading tokenizer: {e}")
            self.tokenizer = None
        
    def generate_pseudo_labels(self, dataset, threshold=0.3, max_samples=100, training_progress=0.0):
        """
        Generate high-quality pseudo-labeled examples with adaptive reward calculations
        
        Args:
            dataset: Dataset to generate examples from
            threshold: Minimum reward threshold for keeping examples
            max_samples: Maximum number of samples to process
            training_progress: Float between 0-1 indicating training progress
            
        Returns:
            list: Pseudo-labeled examples that meet quality threshold
        """
        pseudo_labeled = []
        rewards = []
        retry_count = 0
        
        # Sample more examples
        sample_count = min(max_samples, len(dataset))
        sample_indices = random.sample(range(len(dataset)), sample_count)
        
        for i in tqdm(sample_indices, desc="Generating pseudo-labels"):
            try:
                item = dataset[i]
                question = item["raw_question"]
                gold_answer = item["raw_answer"]
                
                # Generate CoT with the current model using a stronger prompt
                generated = self.cot_generator.generate(
                    question,
                    reflection_module=self.reflection_module,
                    temperature=0.7  # Add some variation
                )
    
                # Retrieve similar exemplars to guide generation
                similar_exemplars = self.retrieval_module.retrieve_similar_exemplars(question, k=3)
                exemplar_text = ""
                if similar_exemplars:
                    exemplar_text = "Here are examples of good reasoning:\n"
                    for i, exemplar in enumerate(similar_exemplars[:2]):  # Use top 2 exemplars
                        steps_text = "\n".join([f"Step {j+1}: {step}" for j, step in enumerate(exemplar)])
                        exemplar_text += f"Example {i+1}:\n{steps_text}\n\n"
                
                # Try to improve generation with exemplars if available
                if exemplar_text:
                    enhanced_prompt = f"{exemplar_text}\nNow solve this problem:\n{question}"
                    enhanced_generation = self.cot_generator.generate(
                        enhanced_prompt,
                        reflection_module=self.reflection_module,
                        temperature=0.65  # Slightly lower temperature for guided generation
                    )
                    
                    # Calculate rewards for both generations and use the better one
                    base_reward = self.reward_function.combined_reward(
                        generated["cot_steps"], 
                        generated["final_answer"], 
                        gold_answer, 
                        question,
                        training_progress=training_progress
                    )
                    
                    enhanced_reward = self.reward_function.combined_reward(
                        enhanced_generation["cot_steps"], 
                        enhanced_generation["final_answer"], 
                        gold_answer, 
                        question,
                        training_progress=training_progress
                    )
                    
                    if enhanced_reward > base_reward:
                        generated = enhanced_generation
                        reward = enhanced_reward
                    else:
                        reward = base_reward
                else:
                    reward = self.reward_function.combined_reward(
                        generated["cot_steps"], 
                        generated["final_answer"], 
                        gold_answer, 
                        question,
                        training_progress=training_progress
                    )
                
                # Apply refinement with different strategies based on reward quality
                if reward < 0.8:  # Apply to all but the very best examples
                    refinement_strategy = "major" if reward < 0.5 else "minor"
                    refined_text = self.refinement_module.refine_text(
                        generated["full_output"], 
                        strategy=refinement_strategy
                    )
                    refined_steps = extract_cot_steps(refined_text)
                    refined_answer = extract_final_answer(refined_text)
                    
                    refined_reward = self.reward_function.combined_reward(
                        refined_steps, 
                        refined_answer, 
                        gold_answer, 
                        question,
                        training_progress=training_progress
                    )
                    
                    # Use the refined version if it's better
                    if refined_reward > reward:
                        generated["cot_steps"] = refined_steps
                        generated["final_answer"] = refined_answer
                        generated["full_output"] = refined_text
                        reward = refined_reward
                
                # Keep examples that exceed the quality threshold
                if reward >= threshold:
                    pseudo_labeled.append({
                        "question": question,
                        "cot_steps": generated["cot_steps"],
                        "final_answer": generated["final_answer"],
                        "full_output": generated["full_output"],
                        "gold_answer": gold_answer,
                        "reward": reward,
                        "source_index": i  # Track which example this came from
                    })
                    rewards.append(reward)
                else:
                    # For examples that almost meet the threshold, try one more generation with different parameters
                    if reward >= threshold - 0.1 and retry_count < max_samples * 0.2:  # Limit retries to 20%
                        retry_count += 1
                        
                        # Try again with different temperature
                        retry_generated = self.cot_generator.generate(
                            question,
                            reflection_module=self.reflection_module,
                            temperature=0.9  # Higher temperature for more exploration
                        )
                        
                        retry_reward = self.reward_function.combined_reward(
                            retry_generated["cot_steps"], 
                            retry_generated["final_answer"], 
                            gold_answer, 
                            question,
                            training_progress=training_progress
                        )
                        
                        if retry_reward >= threshold:
                            pseudo_labeled.append({
                                "question": question,
                                "cot_steps": retry_generated["cot_steps"],
                                "final_answer": retry_generated["final_answer"],
                                "full_output": retry_generated["full_output"],
                                "gold_answer": gold_answer,
                                "reward": retry_reward,
                                "source_index": i,
                                "is_retry": True
                            })
                            rewards.append(retry_reward)
            except Exception as e:
                print(f"Error processing item {i}: {e}")
        
        # Log statistics
        if rewards:
            avg_reward = sum(rewards) / len(rewards)
            median_reward = sorted(rewards)[len(rewards) // 2]
            print(f"Generated {len(pseudo_labeled)} pseudo-labeled examples:")
            print(f"  Average reward: {avg_reward:.3f}")
            print(f"  Median reward: {median_reward:.3f}")
            print(f"  Min/Max reward: {min(rewards):.3f}/{max(rewards):.3f}")
            print(f"  Retries attempted: {retry_count}")
        else:
            print("Failed to generate any valid pseudo-labeled examples")
            
        return pseudo_labeled
    
    def train_with_pseudo_labels(self, pseudo_labeled, learning_rate=1e-5, epochs=2, batch_size=4):
        """Train the model with pseudo-labeled examples"""
        if not pseudo_labeled:
            print("No pseudo-labeled examples available for training")
            return
        
        try:
            # Create a simple dataset from pseudo-labeled examples
            class PseudoLabeledDataset(Dataset):
                def __init__(self, examples, tokenizer, max_length=MAX_LENGTH):
                    self.examples = examples
                    self.tokenizer = tokenizer
                    self.max_length = max_length
                
                def __len__(self):
                    return len(self.examples)
                
                def __getitem__(self, idx):
                    item = self.examples[idx]
                    
                    input_text = f"Solve this math problem step-by-step: {item['question']}"
                    target_text = item["full_output"]
                    
                    inputs = self.tokenizer(
                        input_text,
                        padding="max_length",
                        truncation=True,
                        max_length=self.max_length // 3,
                        return_tensors="pt"
                    )
                    
                    targets = self.tokenizer(
                        target_text,
                        padding="max_length",
                        truncation=True,
                        max_length=self.max_length * 2 // 3,
                        return_tensors="pt"
                    )
                    
                    return {
                        "input_ids": inputs.input_ids.squeeze(),
                        "attention_mask": inputs.attention_mask.squeeze(),
                        "labels": targets.input_ids.squeeze(),
                    }
            
            # Initialize dataset and dataloader
            pseudo_dataset = PseudoLabeledDataset(pseudo_labeled, self.tokenizer)
            pseudo_dataloader = DataLoader(
                pseudo_dataset, 
                batch_size=batch_size,
                shuffle=True
            )
            
            # Set up optimizer
            optimizer = AdamW(self.cot_generator.model.parameters(), lr=learning_rate)
            
            # Training loop
            total_steps = len(pseudo_dataloader) * epochs
            print(f"Starting training on {len(pseudo_labeled)} examples for {epochs} epochs ({total_steps} steps)")
            
            device = self.cot_generator.model.device
            
            for epoch in range(epochs):
                epoch_loss = 0.0
                
                for step, batch in enumerate(tqdm(pseudo_dataloader, desc=f"Epoch {epoch+1}/{epochs}")):
                    # Move batch to device
                    batch = {k: v.to(device) for k, v in batch.items()}
                    
                    # Forward pass
                    self.cot_generator.model.train()
                    outputs = self.cot_generator.model(
                        input_ids=batch["input_ids"],
                        attention_mask=batch["attention_mask"],
                        labels=batch["labels"]
                    )
                    
                    loss = outputs.loss
                    
                    # Backward pass
                    optimizer.zero_grad()
                    loss.backward()
                    optimizer.step()
                    
                    epoch_loss += loss.item()
                    
                    # Log progress
                    if step % 10 == 0:
                        print(f"Step {step}/{len(pseudo_dataloader)}, Loss: {loss.item():.4f}")
                
                # End of epoch stats
                avg_epoch_loss = epoch_loss / len(pseudo_dataloader)
                print(f"Epoch {epoch+1}/{epochs} completed. Average loss: {avg_epoch_loss:.4f}")
                
                # Save checkpoint
                self.cot_generator.save(f"./models/recot_checkpoint_epoch_{epoch+1}")
                
            print("Training completed successfully!")
            
        except Exception as e:
            print(f"Error in train_with_pseudo_labels: {e}")
    
    def run_ppo_update(self, batch, ppo_epochs=4, mini_batch_size=2, clip_param=0.2, 
                  value_loss_coef=0.5, entropy_coef=0.01, max_grad_norm=0.5):
        """
        Perform PPO update on the model using the given batch with improved implementation
        
        Args:
            batch: List of examples with questions, gold answers, etc.
            ppo_epochs: Number of epochs to run over the entire batch
            mini_batch_size: Size of mini-batches for updates
            clip_param: PPO clipping parameter
            value_loss_coef: Value function loss coefficient
            entropy_coef: Entropy bonus coefficient
            max_grad_norm: Maximum gradient norm for clipping
        """
        try:
            device = self.cot_generator.model.device
            
            # Prepare data structures
            trajectories = []
            
            # Collect trajectories
            print(f"Collecting trajectories from {len(batch)} examples...")
            for item in tqdm(batch, desc="Collecting trajectories"):
                question = item["question"]
                gold_answer = item["gold_answer"]
                
                # Generate CoT with current policy and record log probs
                with torch.no_grad():
                    input_text = f"Solve this math problem step-by-step: {question}"
                    inputs = self.tokenizer(
                        input_text, 
                        return_tensors="pt",
                        padding=True,
                        truncation=True,
                        max_length=MAX_LENGTH // 3
                    ).to(device)
                    
                    # Generate with model (record log probs)
                    self.cot_generator.model.eval()
                    outputs = self.cot_generator.model.generate(
                        input_ids=inputs["input_ids"],
                        attention_mask=inputs["attention_mask"],
                        max_length=MAX_LENGTH,
                        output_scores=True,
                        return_dict_in_generate=True,
                        do_sample=True,
                        temperature=0.7
                    )
                    
                    # Extract sequences and scores
                    sequences = outputs.sequences
                    log_probs = torch.stack(outputs.scores, dim=1)
                    generated_text = self.tokenizer.decode(sequences[0], skip_special_tokens=True)
                    
                    # Extract CoT steps and final answer
                    cot_steps = extract_cot_steps(generated_text)
                    final_answer = extract_final_answer(generated_text)
                    
                    # Calculate reward
                    reward = self.reward_function.combined_reward(
                        cot_steps, 
                        final_answer, 
                        gold_answer, 
                        question,
                        training_progress=0.5  # Provide estimated training progress
                    )
                    
                    # Store trajectory
                    trajectories.append({
                        "input_ids": inputs["input_ids"],
                        "attention_mask": inputs["attention_mask"],
                        "generated_ids": sequences,
                        "log_probs": log_probs,
                        "reward": torch.tensor([reward], device=device)
                    })
            
            if not trajectories:
                print("No valid trajectories collected. Skipping PPO update.")
                return
                
            # Create optimizer with smaller learning rate for PPO
            optimizer = AdamW(self.cot_generator.model.parameters(), lr=5e-6)
            
            # Initialize PPO statistics
            stats = {
                "policy_loss": [],
                "value_loss": [],
                "entropy": [],
                "total_loss": [],
                "approx_kl": [],
                "clip_fraction": [],
                "explained_variance": []
            }
            
            # PPO update loop
            print(f"Running PPO for {ppo_epochs} epochs with mini-batch size {mini_batch_size}...")
            for epoch in range(ppo_epochs):
                # Shuffle trajectories for this epoch
                random.shuffle(trajectories)
                epoch_stats = []
                
                # Process mini-batches
                for i in range(0, len(trajectories), mini_batch_size):
                    mini_batch = trajectories[i:i+mini_batch_size]
                    if not mini_batch:
                        continue
                        
                    # Initialize batch tensors
                    batch_input_ids = torch.cat([item["input_ids"] for item in mini_batch], dim=0)
                    batch_attention_mask = torch.cat([item["attention_mask"] for item in mini_batch], dim=0)
                    batch_generated_ids = torch.cat([item["generated_ids"] for item in mini_batch], dim=0)
                    batch_rewards = torch.cat([item["reward"] for item in mini_batch], dim=0)
                    
                    # Normalize rewards for stability
                    batch_rewards = (batch_rewards - batch_rewards.mean()) / (batch_rewards.std() + 1e-8)
                    
                    # Create labels for the model by right-shifting generated ids
                    batch_labels = batch_generated_ids.clone()
                    batch_labels[:, :-1] = batch_generated_ids[:, 1:]
                    batch_labels[:, -1] = self.tokenizer.pad_token_id
                    
                    # Compute value predictions and action log probs with current policy
                    self.cot_generator.model.train()
                    
                    # Forward pass with current policy
                    outputs = self.cot_generator.model(
                        input_ids=batch_input_ids,
                        attention_mask=batch_attention_mask,
                        labels=batch_labels
                    )
                    
                    # Get log probs of current policy
                    logits = outputs.logits
                    current_log_probs = F.log_softmax(logits, dim=-1)
                    
                    # Get old log probs
                    old_log_probs = torch.cat([item["log_probs"] for item in mini_batch], dim=0)
                    
                    # Calculate ratio and clipped surrogate objective (simplification)
                    # Note: In a full implementation, we'd need to match token by token
                    ratio = torch.exp(current_log_probs.mean(dim=1) - old_log_probs.mean(dim=1))
                    surr1 = ratio * batch_rewards
                    surr2 = torch.clamp(ratio, 1.0 - clip_param, 1.0 + clip_param) * batch_rewards
                    
                    # Policy loss
                    policy_loss = -torch.min(surr1, surr2).mean()
                    
                    # Value loss (using model loss as a proxy for value)
                    value_pred = -outputs.loss.detach()
                    value_targets = batch_rewards
                    value_loss = F.mse_loss(value_pred, value_targets)
                    
                    # Entropy bonus (encourage exploration)
                    entropy = -(F.softmax(logits, dim=-1) * F.log_softmax(logits, dim=-1)).sum(dim=-1).mean()
                    
                    # Total loss
                    loss = policy_loss + value_loss_coef * value_loss - entropy_coef * entropy
                    
                    # Backward pass
                    optimizer.zero_grad()
                    loss.backward()
                    
                    # Gradient clipping
                    grad_norm = torch.nn.utils.clip_grad_norm_(
                        self.cot_generator.model.parameters(), 
                        max_grad_norm
                    )
                    
                    # Optimizer step
                    optimizer.step()
                    
                    # Calculate statistics
                    clip_fraction = ((ratio - 1.0).abs() > clip_param).float().mean().item()
                    approx_kl = (ratio.log() * (ratio - 1)).mean().item()
                    
                    # Store batch statistics
                    batch_stats = {
                        "policy_loss": policy_loss.item(),
                        "value_loss": value_loss.item(),
                        "entropy": entropy.item(),
                        "total_loss": loss.item(),
                        "approx_kl": approx_kl,
                        "clip_fraction": clip_fraction,
                        "grad_norm": grad_norm.item()
                    }
                    epoch_stats.append(batch_stats)
                
                # Calculate epoch statistics
                if epoch_stats:
                    epoch_mean_stats = {k: sum(d[k] for d in epoch_stats) / len(epoch_stats) for k in epoch_stats[0]}
                    for k, v in epoch_mean_stats.items():
                        stats[k] = stats.get(k, []) + [v]
                    
                    # Print epoch statistics
                    print(f"PPO Epoch {epoch+1}/{ppo_epochs} stats:")
                    print(f"  Policy loss: {epoch_mean_stats['policy_loss']:.4f}")
                    print(f"  Value loss: {epoch_mean_stats['value_loss']:.4f}")
                    print(f"  Entropy: {epoch_mean_stats['entropy']:.4f}")
                    print(f"  Approx KL: {epoch_mean_stats['approx_kl']:.4f}")
                    print(f"  Clip fraction: {epoch_mean_stats['clip_fraction']:.4f}")
                    print(f"  Gradient norm: {epoch_mean_stats['grad_norm']:.4f}")
            
            # Final stats summary
            if all(len(v) > 0 for v in stats.values()):
                print("\nPPO update completed. Final statistics:")
                for k, v in stats.items():
                    if v:
                        print(f"  {k}: {v[-1]:.4f} (started: {v[0]:.4f})")
            
        except Exception as e:
            print(f"Error in PPO update: {e}")
            traceback.print_exc()
    
    def evaluate(self, test_dataset, num_samples=50):
        """Evaluate the current model on a test dataset"""
        correct = 0
        total = 0
        rewards = []
        
        # Sample a subset for efficient evaluation
        sample_indices = random.sample(range(len(test_dataset)), min(num_samples, len(test_dataset)))
        
        for idx in tqdm(sample_indices, desc="Evaluating"):
            try:
                item = test_dataset[idx]
                question = item["raw_question"]
                gold_answer = item["raw_answer"]
                
                # Generate answer
                generation = self.cot_generator.generate(question)
                predicted_answer = generation["final_answer"]
                
                # Check if correct
                reward = self.reward_function.outcome_reward(predicted_answer, gold_answer)
                rewards.append(reward)
                
                # Binary correctness (1.0 means exactly correct)
                if reward >= 0.9:
                    correct += 1
                total += 1
                
            except Exception as e:
                print(f"Error evaluating example {idx}: {e}")
        
        # Calculate metrics
        accuracy = correct / total if total > 0 else 0
        avg_reward = sum(rewards) / len(rewards) if rewards else 0
        
        print(f"Evaluation results:")
        print(f"  Accuracy: {accuracy:.2f} ({correct}/{total})")
        print(f"  Average reward: {avg_reward:.3f}")
        
        return {
            "accuracy": accuracy,
            "avg_reward": avg_reward,
            "correct": correct,
            "total": total
        }
    
    def run_training_loop(self, train_dataset, test_dataset, num_iterations=5, pseudo_samples=100, batch_size=4):
        """
        Run the full training loop with pseudo-labeling and improved RL updates
        
        Args:
            train_dataset: Dataset to generate pseudo-labels from
            test_dataset: Dataset for evaluation
            num_iterations: Number of training iterations
            pseudo_samples: Maximum number of samples to generate pseudo-labels for
            batch_size: Batch size for supervised training
            
        Returns:
            dict: Final evaluation results
        """
        try:
            results_history = []
            
            # Initial evaluation
            print("Initial model evaluation:")
            results = self.evaluate(test_dataset)
            results_history.append(results)
            initial_accuracy = results["accuracy"]
            
            for iteration in range(num_iterations):
                print(f"\n===== Iteration {iteration+1}/{num_iterations} =====")
                
                # Calculate training progress for adaptive reward weighting
                training_progress = iteration / (num_iterations - 1) if num_iterations > 1 else 0.5
                
                # 1. Generate pseudo-labeled examples with adaptive threshold
                # Lower threshold initially, increase as training progresses
                threshold = 0.3 + (0.2 * training_progress)  # Starts at 0.3, increases to 0.5
                
                print(f"Generating {pseudo_samples} pseudo-labeled examples (threshold={threshold:.2f})...")
                pseudo_labeled = self.generate_pseudo_labels(
                    train_dataset, 
                    threshold=threshold,
                    max_samples=pseudo_samples
                )
                
                if not pseudo_labeled:
                    print("No pseudo-labeled examples generated. Adjusting threshold and retrying...")
                    # Retry with lower threshold if no examples meet criteria
                    retry_threshold = max(0.2, threshold - 0.1)
                    pseudo_labeled = self.generate_pseudo_labels(
                        train_dataset, 
                        threshold=retry_threshold,
                        max_samples=pseudo_samples * 2  # Sample more to increase chances
                    )
                    
                    if not pseudo_labeled:
                        print("Still no pseudo-labeled examples. Skipping iteration.")
                        continue
                
                # 2. Analyze pseudo-labeled examples
                rewards = [ex["reward"] for ex in pseudo_labeled]
                avg_reward = sum(rewards) / len(rewards) if rewards else 0
                reward_std = (sum((r - avg_reward) ** 2 for r in rewards) / len(rewards)) ** 0.5 if rewards else 0
                
                print(f"Generated {len(pseudo_labeled)} examples:")
                print(f"  Average reward: {avg_reward:.3f} (std: {reward_std:.3f})")
                print(f"  Reward range: {min(rewards):.3f} - {max(rewards):.3f}")
                
                # 3. Train with pseudo-labeled examples - adaptive learning rate
                # Decrease learning rate as training progresses for finer adjustments
                base_lr = 2e-5
                decay_factor = 1.0 - (0.5 * training_progress)  # Starts at 1.0, decreases to 0.5
                adaptive_lr = base_lr * decay_factor
                
                print(f"Training with {len(pseudo_labeled)} examples (lr={adaptive_lr:.2e})...")
                self.train_with_pseudo_labels(
                    pseudo_labeled,
                    learning_rate=adaptive_lr,
                    epochs=1 + (1 if iteration > num_iterations // 2 else 0),  # Extra epoch later in training
                    batch_size=batch_size
                )
                
                # 4. Perform PPO update with improved implementation
                # Detect if we have enough high-quality examples
                high_quality = [ex for ex in pseudo_labeled if ex["reward"] >= threshold + 0.1]
                if high_quality:
                    print(f"Performing PPO update with {len(high_quality)} high-quality examples...")
                    # Adjust mini-batch size based on available examples
                    optimal_mini_batch = max(2, len(high_quality) // 4)
                    mini_batch_size = min(optimal_mini_batch, 8)  # Cap at 8
                    
                    # Dynamically set epochs based on batch size
                    ppo_epochs = 6 - (mini_batch_size // 2)  # More epochs for smaller batches
                    ppo_epochs = max(3, min(ppo_epochs, 5))  # Between 3-5
                    
                    self.run_ppo_update(
                        high_quality,
                        ppo_epochs=ppo_epochs,
                        mini_batch_size=mini_batch_size,
                        # Increase exploration early, focus more on exploitation later
                        entropy_coef=0.02 * (1.0 - training_progress)
                    )
                else:
                    print("No high-quality examples found for PPO. Skipping reinforcement learning step.")
                
                # 5. Update retrieval module with new exemplars
                print("Updating retrieval module...")
                exemplar_threshold = 0.6 + (0.1 * training_progress)  # Increase quality threshold over time
                quality_exemplars = [ex for ex in pseudo_labeled if ex["reward"] >= exemplar_threshold]
                
                if quality_exemplars:
                    print(f"Adding {len(quality_exemplars)} new exemplars to retrieval module...")
                    for example in quality_exemplars:
                        self.retrieval_module.add_exemplar(
                            example["question"],
                            example["cot_steps"]
                        )
                
                # 6. Save checkpoint with iteration info
                checkpoint_path = f"./models/recot_checkpoint_iteration_{iteration+1}"
                print(f"Saving checkpoint to {checkpoint_path}...")
                try:
                    # Save with metadata
                    metadata = {
                        "iteration": iteration + 1,
                        "timestamp": datetime.datetime.now().isoformat(),
                        "train_examples": len(pseudo_labeled),
                        "avg_reward": avg_reward,
                        "exemplars_added": len(quality_exemplars) if quality_exemplars else 0
                    }
                    self.cot_generator.save(checkpoint_path, metadata=metadata)
                except Exception as e:
                    print(f"Error saving checkpoint: {e}")
                    # Fallback to simple save
                    self.cot_generator.save(checkpoint_path)
                
                # 7. Evaluate progress
                print("Evaluating current model:")
                results = self.evaluate(test_dataset)
                results_history.append(results)
                
                # Print improvement
                if len(results_history) > 1:
                    prev = results_history[-2]["accuracy"]
                    curr = results_history[-1]["accuracy"]
                    diff = curr - prev
                    rel_improvement = f"{(diff / prev * 100):.1f}%" if prev > 0 else "N/A"
                    print(f"Accuracy change: {diff:+.3f} ({prev:.3f} → {curr:.3f}, {rel_improvement} relative)")
                    
                    # Early stopping check - if performance degraded significantly
                    if diff < -0.05 and iteration > 0:
                        print("Warning: Performance degraded significantly. Consider restoring previous checkpoint.")
            
            # Final comprehensive evaluation
            print("\n===== Final Evaluation =====")
            final_results = self.evaluate(test_dataset, num_samples=min(200, len(test_dataset)))
            
            # Print training summary
            print("\n===== Training Summary =====")
            print(f"Initial accuracy: {initial_accuracy:.3f}")
            print(f"Final accuracy: {final_results['accuracy']:.3f}")
            absolute_improvement = final_results['accuracy'] - initial_accuracy
            relative_improvement = (absolute_improvement / initial_accuracy * 100) if initial_accuracy > 0 else float('inf')
            print(f"Absolute improvement: {absolute_improvement:+.3f}")
            print(f"Relative improvement: {relative_improvement:+.1f}%")
            
            # Plot training curve if matplotlib is available
            try:
                import matplotlib.pyplot as plt
                
                iterations = list(range(num_iterations + 1))
                accuracies = [r["accuracy"] for r in results_history] + [final_results["accuracy"]]
                rewards = [r["avg_reward"] for r in results_history] + [final_results["avg_reward"]]
                
                plt.figure(figsize=(10, 6))
                plt.plot(iterations, accuracies, 'b-o', label='Accuracy')
                plt.plot(iterations, rewards, 'r-o', label='Avg Reward')
                plt.xlabel('Iteration')
                plt.ylabel('Score')
                plt.title('Training Progress')
                plt.legend()
                plt.grid(True)
                
                # Save plot
                plt.savefig('./training_progress.png')
                print("Training progress plot saved to ./training_progress.png")
            except ImportError:
                print("Matplotlib not available. Skipping training curve plot.")
            
            return final_results
            
        except Exception as e:
            print(f"Error in training loop: {e}")
            traceback.print_exc()
            return None
        
    def count_parameters(self):
        """Count the number of trainable parameters in the model"""
        model = self.cot_generator.model
        return sum(p.numel() for p in model.parameters() if p.requires_grad)

In [37]:
# Set random seeds for reproducibility
random.seed(11)
np.random.seed(11)
torch.manual_seed(11)
if torch.cuda.is_available():
    torch.cuda.manual_seed_all(11)

# Print system information
device = torch.device("cuda" if torch.cuda.is_available() else 
                        "mps" if torch.backends.mps.is_available() else "cpu")
print(f"Using device: {device}")
print(f"PyTorch version: {torch.__version__}")
print(f"CUDA available: {torch.cuda.is_available()}")
if torch.cuda.is_available():
    print(f"CUDA device: {torch.cuda.get_device_name(0)}")
    print(f"CUDA memory available: {torch.cuda.get_device_properties(0).total_memory / 1e9:.2f} GB")

# Create output directories
os.makedirs("./models", exist_ok=True)
os.makedirs("./results", exist_ok=True)

print("\n===== Initializing Components =====")

# Initialize tokenizer
print("Initializing tokenizer...")
try:
    tokenizer = T5Tokenizer.from_pretrained(MODEL_NAME)
except Exception as e:
    print(f"Error initializing tokenizer: {e}")
    tokenizer = None

# Initialize data
print("Loading data...")
try:
    train_dataset = GSM8KDataset(
        split="train", 
        tokenizer=tokenizer, 
        max_length=MAX_LENGTH, 
        max_samples=MAX_SAMPLES
    )
    test_dataset = GSM8KDataset(
        split="test", 
        tokenizer=tokenizer, 
        max_length=MAX_LENGTH, 
        max_samples=MAX_SAMPLES // 4
    )
    print(f"Loaded {len(train_dataset)} training examples and {len(test_dataset)} test examples")
except Exception as e:
    print(f"Error loading datasets: {e}")
    raise e

# Initialize CoT Generator (main model)
print("Initializing CoT Generator...")
cot_generator = CoTGenerator(MODEL_NAME)

# Initialize Reflection Module
print("Initializing Reflection Module...")
reflection_module = ReflectionModule()
reflection_module = to_device(reflection_module)

# Initialize Retrieval Module
print("Initializing Retrieval Module...")
retrieval_module = RetrievalModule()
# Populate retrieval module with examples from the dataset
retrieval_module.initialize_from_dataset(train_dataset, max_exemplars=100)

# Initialize Refinement Module (transformer for text refinement)
print("Initializing Refinement Module...")
refinement_module = TextRefinementTransformer()

# Initialize Reward Function
print("Initializing Reward Function...")
reward_function = RewardFunction(reflection_module)

# Initialize Self-Trainer
print("Initializing Self-Trainer...")
self_trainer = SelfTrainer(
    cot_generator=cot_generator,
    reflection_module=reflection_module,
    retrieval_module=retrieval_module,
    refinement_module=refinement_module,
    reward_function=reward_function,
    tokenizer=tokenizer
)

num_params = self_trainer.count_parameters()
print(f"Model has {num_params:,} trainable parameters")


# Run a quick test to ensure everything is working
print("\n===== Running Quick Test =====")
try:
    # Test CoT generation
    test_question = "John has 5 apples. Mary gives him 3 more apples. How many apples does John have now?"
    print(f"Test question: {test_question}")
    
    generation = cot_generator.generate(test_question, reflection_module=reflection_module)
    print(f"Generated CoT:")
    for i, step in enumerate(generation["cot_steps"]):
        print(f"  Step {i+1}: {step}")
    print(f"Final answer: {generation['final_answer']}")
    
    # Test evaluation
    print("Testing evaluation...")
    test_results = self_trainer.evaluate(test_dataset, num_samples=5)
    print(f"Test evaluation completed with accuracy: {test_results['accuracy']:.2f}")
except Exception as e:
    print(f"Error in quick test: {e}")
    import traceback
    traceback.print_exc()

# Run the main training loop
print("\n===== Starting Main Training Loop =====")
try:
    training_results = self_trainer.run_training_loop(
        train_dataset=train_dataset,
        test_dataset=test_dataset,
        num_iterations=3,  # Reduced number of iterations for initial testing
        pseudo_samples=50,  # Start with fewer samples
        batch_size=BATCH_SIZE
    )
    
    # Save final model
    print("Saving final model...")
    cot_generator.save("./models/recot_final_model")
    
    print("\n===== Training Complete =====")
    if training_results:
        print(f"Final accuracy: {training_results['accuracy']:.4f}")
    
except Exception as e:
    print(f"Error in training loop: {e}")
    import traceback
    traceback.print_exc()

# Example of using the trained model
print("\n===== Example Usage =====")
try:
    example_questions = [
        "Janet's ducks lay 16 eggs per day. She eats 3 for breakfast every morning and bakes muffins for her friends every day with 4 eggs per batch. She bakes 2 batches of muffins per day. How many eggs does Janet have left each day?",
        "Roger has 5 tennis balls. He buys 2 more cans of tennis balls. Each can has 3 tennis balls. How many tennis balls does he have now?",
        "Leah had 32 chocolates and her sister had 42. If they ate 35, how many pieces do they have left in total?"
    ]
    
    for i, question in enumerate(example_questions):
        print(f"\nExample {i+1}: {question}")
        result = cot_generator.generate(question, reflection_module=reflection_module)
        print("Generated reasoning:")
        for j, step in enumerate(result["cot_steps"]):
            print(f"  Step {j+1}: {step}")
        print(f"Final answer: {result['final_answer']}")
        
except Exception as e:
    print(f"Error in example usage: {e}")

Using device: mps
PyTorch version: 2.6.0
CUDA available: False

===== Initializing Components =====
Initializing tokenizer...
Loading data...


Preprocessing data: 100%|██████████| 400/400 [00:00<00:00, 12794.34it/s]
Preprocessing data: 100%|██████████| 100/100 [00:00<00:00, 11158.03it/s]


Loaded 400 training examples and 100 test examples
Initializing CoT Generator...
Loading model t5-base...
Model not found locally. Downloading t5-base...
Tokenizer downloaded and saved to ./models/t5_base_cache
Model downloaded and moved to mps
Initializing Reflection Module...
Initializing Retrieval Module...


Building exemplar bank:   1%|▏         | 5/400 [00:00<00:17, 22.07it/s]

Error computing embedding: You have to specify either decoder_input_ids or decoder_inputs_embeds
Error computing embedding: You have to specify either decoder_input_ids or decoder_inputs_embeds
Error computing embedding: You have to specify either decoder_input_ids or decoder_inputs_embeds
Error computing embedding: You have to specify either decoder_input_ids or decoder_inputs_embeds
Error computing embedding: You have to specify either decoder_input_ids or decoder_inputs_embeds
Error computing embedding: You have to specify either decoder_input_ids or decoder_inputs_embeds
Error computing embedding: You have to specify either decoder_input_ids or decoder_inputs_embeds


Building exemplar bank:   3%|▎         | 12/400 [00:00<00:14, 27.39it/s]

Error computing embedding: You have to specify either decoder_input_ids or decoder_inputs_embeds
Error computing embedding: You have to specify either decoder_input_ids or decoder_inputs_embeds
Error computing embedding: You have to specify either decoder_input_ids or decoder_inputs_embeds
Error computing embedding: You have to specify either decoder_input_ids or decoder_inputs_embeds
Error computing embedding: You have to specify either decoder_input_ids or decoder_inputs_embeds
Error computing embedding: You have to specify either decoder_input_ids or decoder_inputs_embeds


Building exemplar bank:   5%|▌         | 20/400 [00:00<00:12, 30.82it/s]

Error computing embedding: You have to specify either decoder_input_ids or decoder_inputs_embeds
Error computing embedding: You have to specify either decoder_input_ids or decoder_inputs_embeds
Error computing embedding: You have to specify either decoder_input_ids or decoder_inputs_embeds
Error computing embedding: You have to specify either decoder_input_ids or decoder_inputs_embeds
Error computing embedding: You have to specify either decoder_input_ids or decoder_inputs_embeds
Error computing embedding: You have to specify either decoder_input_ids or decoder_inputs_embeds
Error computing embedding: You have to specify either decoder_input_ids or decoder_inputs_embeds


Building exemplar bank:   6%|▌         | 24/400 [00:00<00:12, 30.78it/s]

Error computing embedding: You have to specify either decoder_input_ids or decoder_inputs_embeds
Error computing embedding: You have to specify either decoder_input_ids or decoder_inputs_embeds
Error computing embedding: You have to specify either decoder_input_ids or decoder_inputs_embeds
Error computing embedding: You have to specify either decoder_input_ids or decoder_inputs_embeds
Error computing embedding: You have to specify either decoder_input_ids or decoder_inputs_embeds
Error computing embedding: You have to specify either decoder_input_ids or decoder_inputs_embeds
Error computing embedding: You have to specify either decoder_input_ids or decoder_inputs_embeds


Building exemplar bank:   8%|▊         | 32/400 [00:01<00:12, 30.20it/s]

Error computing embedding: You have to specify either decoder_input_ids or decoder_inputs_embeds
Error computing embedding: You have to specify either decoder_input_ids or decoder_inputs_embeds
Error computing embedding: You have to specify either decoder_input_ids or decoder_inputs_embeds
Error computing embedding: You have to specify either decoder_input_ids or decoder_inputs_embeds
Error computing embedding: You have to specify either decoder_input_ids or decoder_inputs_embeds
Error computing embedding: You have to specify either decoder_input_ids or decoder_inputs_embeds


Building exemplar bank:   9%|▉         | 36/400 [00:01<00:13, 27.14it/s]

Error computing embedding: You have to specify either decoder_input_ids or decoder_inputs_embeds
Error computing embedding: You have to specify either decoder_input_ids or decoder_inputs_embeds
Error computing embedding: You have to specify either decoder_input_ids or decoder_inputs_embeds
Error computing embedding: You have to specify either decoder_input_ids or decoder_inputs_embeds
Error computing embedding: You have to specify either decoder_input_ids or decoder_inputs_embeds


Building exemplar bank:  11%|█         | 43/400 [00:01<00:12, 28.90it/s]

Error computing embedding: You have to specify either decoder_input_ids or decoder_inputs_embeds
Error computing embedding: You have to specify either decoder_input_ids or decoder_inputs_embeds
Error computing embedding: You have to specify either decoder_input_ids or decoder_inputs_embeds
Error computing embedding: You have to specify either decoder_input_ids or decoder_inputs_embeds
Error computing embedding: You have to specify either decoder_input_ids or decoder_inputs_embeds
Error computing embedding: You have to specify either decoder_input_ids or decoder_inputs_embeds
Error computing embedding: You have to specify either decoder_input_ids or decoder_inputs_embeds
Error computing embedding: You have to specify either decoder_input_ids or decoder_inputs_embeds


Building exemplar bank:  13%|█▎        | 51/400 [00:01<00:11, 31.16it/s]

Error computing embedding: You have to specify either decoder_input_ids or decoder_inputs_embeds
Error computing embedding: You have to specify either decoder_input_ids or decoder_inputs_embeds
Error computing embedding: You have to specify either decoder_input_ids or decoder_inputs_embeds
Error computing embedding: You have to specify either decoder_input_ids or decoder_inputs_embeds
Error computing embedding: You have to specify either decoder_input_ids or decoder_inputs_embeds
Error computing embedding: You have to specify either decoder_input_ids or decoder_inputs_embeds
Error computing embedding: You have to specify either decoder_input_ids or decoder_inputs_embeds


Building exemplar bank:  15%|█▍        | 59/400 [00:02<00:10, 31.21it/s]

Error computing embedding: You have to specify either decoder_input_ids or decoder_inputs_embeds
Error computing embedding: You have to specify either decoder_input_ids or decoder_inputs_embeds
Error computing embedding: You have to specify either decoder_input_ids or decoder_inputs_embeds
Error computing embedding: You have to specify either decoder_input_ids or decoder_inputs_embeds
Error computing embedding: You have to specify either decoder_input_ids or decoder_inputs_embeds
Error computing embedding: You have to specify either decoder_input_ids or decoder_inputs_embeds
Error computing embedding: You have to specify either decoder_input_ids or decoder_inputs_embeds
Error computing embedding: You have to specify either decoder_input_ids or decoder_inputs_embeds
Error computing embedding: You have to specify either decoder_input_ids or decoder_inputs_embeds


Building exemplar bank:  16%|█▌        | 63/400 [00:02<00:18, 18.41it/s]

Error computing embedding: You have to specify either decoder_input_ids or decoder_inputs_embeds
Error computing embedding: You have to specify either decoder_input_ids or decoder_inputs_embeds
Error computing embedding: You have to specify either decoder_input_ids or decoder_inputs_embeds


Building exemplar bank:  16%|█▋        | 66/400 [00:02<00:19, 16.96it/s]

Error computing embedding: You have to specify either decoder_input_ids or decoder_inputs_embeds
Error computing embedding: You have to specify either decoder_input_ids or decoder_inputs_embeds
Error computing embedding: You have to specify either decoder_input_ids or decoder_inputs_embeds


Building exemplar bank:  18%|█▊        | 71/400 [00:03<00:22, 14.66it/s]

Error computing embedding: You have to specify either decoder_input_ids or decoder_inputs_embeds
Error computing embedding: You have to specify either decoder_input_ids or decoder_inputs_embeds
Error computing embedding: You have to specify either decoder_input_ids or decoder_inputs_embeds


Building exemplar bank:  18%|█▊        | 73/400 [00:03<00:24, 13.14it/s]

Error computing embedding: You have to specify either decoder_input_ids or decoder_inputs_embeds
Error computing embedding: You have to specify either decoder_input_ids or decoder_inputs_embeds


Building exemplar bank:  19%|█▉        | 75/400 [00:03<00:23, 14.13it/s]

Error computing embedding: You have to specify either decoder_input_ids or decoder_inputs_embeds
Error computing embedding: You have to specify either decoder_input_ids or decoder_inputs_embeds
Error computing embedding: You have to specify either decoder_input_ids or decoder_inputs_embeds


Building exemplar bank:  20%|█▉        | 79/400 [00:04<00:37,  8.56it/s]

Error computing embedding: You have to specify either decoder_input_ids or decoder_inputs_embeds
Error computing embedding: You have to specify either decoder_input_ids or decoder_inputs_embeds
Error computing embedding: You have to specify either decoder_input_ids or decoder_inputs_embeds


Building exemplar bank:  20%|██        | 81/400 [00:04<00:34,  9.16it/s]

Error computing embedding: You have to specify either decoder_input_ids or decoder_inputs_embeds
Error computing embedding: You have to specify either decoder_input_ids or decoder_inputs_embeds
Error computing embedding: You have to specify either decoder_input_ids or decoder_inputs_embeds


Building exemplar bank:  21%|██▏       | 85/400 [00:04<00:31, 10.00it/s]

Error computing embedding: You have to specify either decoder_input_ids or decoder_inputs_embeds
Error computing embedding: You have to specify either decoder_input_ids or decoder_inputs_embeds
Error computing embedding: You have to specify either decoder_input_ids or decoder_inputs_embeds
Error computing embedding: You have to specify either decoder_input_ids or decoder_inputs_embeds


Building exemplar bank:  22%|██▏       | 88/400 [00:04<00:27, 11.37it/s]

Error computing embedding: You have to specify either decoder_input_ids or decoder_inputs_embeds
Error computing embedding: You have to specify either decoder_input_ids or decoder_inputs_embeds
Error computing embedding: You have to specify either decoder_input_ids or decoder_inputs_embeds


Building exemplar bank:  24%|██▍       | 96/400 [00:05<00:15, 19.95it/s]

Error computing embedding: You have to specify either decoder_input_ids or decoder_inputs_embeds
Error computing embedding: You have to specify either decoder_input_ids or decoder_inputs_embeds
Error computing embedding: You have to specify either decoder_input_ids or decoder_inputs_embeds
Error computing embedding: You have to specify either decoder_input_ids or decoder_inputs_embeds
Error computing embedding: You have to specify either decoder_input_ids or decoder_inputs_embeds
Error computing embedding: You have to specify either decoder_input_ids or decoder_inputs_embeds
Error computing embedding: You have to specify either decoder_input_ids or decoder_inputs_embeds
Error computing embedding: You have to specify either decoder_input_ids or decoder_inputs_embeds


Building exemplar bank:  25%|██▍       | 99/400 [00:05<00:16, 18.01it/s]

Error computing embedding: You have to specify either decoder_input_ids or decoder_inputs_embeds
Error computing embedding: You have to specify either decoder_input_ids or decoder_inputs_embeds


Building exemplar bank:  25%|██▌       | 100/400 [00:05<00:16, 18.21it/s]


Error computing embedding: You have to specify either decoder_input_ids or decoder_inputs_embeds
Added 100 exemplars from dataset to retrieval module
Initializing Refinement Module...
Error initializing refinement module: 'TextRefinementTransformer' object has no attribute 'device'
Initializing Reward Function...
Initializing Self-Trainer...
Model has 222,903,552 trainable parameters

===== Running Quick Test =====
Test question: John has 5 apples. Mary gives him 3 more apples. How many apples does John have now?


KeyboardInterrupt: 