In [15]:
import os
import json
import random
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
from transformers import (
    AutoModelForCausalLM,
    AutoTokenizer,
    AdamW,
    get_linear_schedule_with_warmup,
    T5ForConditionalGeneration,
    T5Tokenizer
)
from sklearn.metrics.pairwise import cosine_similarity
from datasets import load_dataset
from tqdm.auto import tqdm
import re
from torch.optim import Adam
import random

In [16]:
def set_seed(seed=11):
    """Seed every random generator this notebook uses so runs are repeatable.

    Seeds Python's ``random``, NumPy's legacy global RNG and PyTorch's default
    generators; when CUDA is present it also seeds all CUDA devices and
    configures cuDNN for deterministic kernel selection.

    Note: assigning ``PYTHONHASHSEED`` here does not change hash randomisation
    of the already-running interpreter; it only affects Python subprocesses
    started afterwards.

    Args:
        seed: Integer seed applied to every generator.
    """
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    os.environ['PYTHONHASHSEED'] = str(seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed_all(seed)
        torch.backends.cudnn.deterministic = True
        # cuDNN autotuning may pick different kernels between runs; turn it off
        # so that `deterministic = True` actually yields repeatable results.
        torch.backends.cudnn.benchmark = False

set_seed()

In [17]:
# Device configuration: prefer Apple-silicon MPS, then CUDA, else fall back to CPU.
device = torch.device(
    "mps" if torch.backends.mps.is_available()
    else "cuda" if torch.cuda.is_available()
    else "cpu"
)
print({
    "mps": "Using Apple M1 MPS device",
    "cuda": "Using CUDA device",
    "cpu": "Using CPU device",
}[device.type])

Using Apple M1 MPS device


In [18]:
# --- Model / data -------------------------------------------------------------
MODEL_NAME = "t5-base"   # default checkpoint/tokenizer name for the modules below
MAX_SAMPLES = 400        # cap on the number of GSM8K examples to use
MAX_LENGTH = 768         # total token budget (GSM8KDataset gives 1/3 to input, 2/3 to target)

# --- Optimisation ----------------------------------------------------------------
BATCH_SIZE = 4
LEARNING_RATE = 2e-5
EPOCHS = 3

# --- Prompting ---------------------------------------------------------------------
# Suffix appended to every formatted question to elicit step-by-step reasoning.
COT_PROMPT = "Let's solve this step-by-step. To find the answer, I'll break down the problem into smaller parts."

In [19]:
# Utility functions for extracting final answers and CoT steps
def extract_final_answer(answer_text):
    """Pull the final numeric answer out of a GSM8K or model-generated solution.

    Strategies, in priority order:
      1. GSM8K's ``#### <number>`` answer marker;
      2. explicit "The answer is" / "The final answer is" statements;
      3. the first number after a concluding word ("Therefore", "So", "Thus");
      4. the last number anywhere in the text.
    If the text contains no number at all, its last line is returned.

    Args:
        answer_text: Free-form solution text.

    Returns:
        The answer as a string (thousands separators kept, e.g. "1,200";
        a leading minus sign kept), or "" for empty input.
    """
    if not answer_text:
        return ""

    # Optional sign, digits, optional ",ddd" groups and decimals. Requiring a
    # leading digit means a trailing "." or "," is never returned as the answer;
    # the look-behind keeps the "-" in "48-24" from being read as a sign.
    number = r"(?<!\d)-?\d+(?:,\d+)*(?:\.\d+)?"
    patterns = [
        rf"####\s*\$?\s*({number})",
        rf"\bThe answer is\s*\$?\s*({number})",
        rf"\bThe final answer is\s*\$?\s*({number})",
        # \b prevents "so"/"thus" from matching inside words such as "also".
        rf"\bTherefore\b,?\s.*?\$?\s*({number})",
        rf"\bSo\b,?\s.*?\$?\s*({number})",
        rf"\bThus\b,?\s.*?\$?\s*({number})",
    ]

    for pattern in patterns:
        match = re.search(pattern, answer_text, re.DOTALL | re.IGNORECASE)
        if match:
            return match.group(1).strip()

    # No explicit answer statement: fall back to the last number in the text.
    numbers = re.findall(number, answer_text)
    if numbers:
        return numbers[-1]

    # Last resort: the final line, which is where answers usually sit.
    return answer_text.strip().split("\n")[-1]

def extract_cot_steps(answer_text):
    """Split a solution into its reasoning lines, dropping the final-answer part.

    Everything from the GSM8K ``####`` marker onward is removed. If that marker
    is absent, everything from the first "the answer is" onward is removed
    (case-insensitive, matching how extract_final_answer searches). If neither
    marker is present, the last line is assumed to be the answer and dropped.

    Args:
        answer_text: Free-form solution text.

    Returns:
        List of non-empty, stripped reasoning lines (possibly empty).
    """
    if not answer_text:
        return []

    marker = (re.search(r"####", answer_text)
              or re.search(r"\bthe answer is\b", answer_text, re.IGNORECASE))
    if marker:
        cot_text = answer_text[:marker.start()]
    else:
        lines = answer_text.strip().split("\n")
        cot_text = "\n".join(lines[:-1]) if len(lines) > 1 else ""

    return [line.strip() for line in cot_text.split("\n") if line.strip()]

# Safe device transfer helper (prefers MPS on Apple silicon)
def to_device(tensor_or_module, target_device=None):
    """Move a tensor or module to a compute device without raising.

    Args:
        tensor_or_module: Any object with a ``.to(device)`` method (tensor or
            ``nn.Module``). ``None`` is passed through unchanged.
        target_device: Explicit destination (``torch.device`` or string). When
            omitted, the best available backend is used, in the same order as
            the device-configuration cell: MPS, then CUDA, then CPU.

    Returns:
        The moved object. If the transfer fails (including for objects with no
        ``.to`` method), a warning is printed and the original object is returned.
    """
    if tensor_or_module is None:
        return None

    if target_device is None:
        target_device = torch.device(
            "mps" if torch.backends.mps.is_available()
            else "cuda" if torch.cuda.is_available()
            else "cpu"
        )

    try:
        return tensor_or_module.to(target_device)
    except Exception as e:
        print(f"Warning: Could not move to {target_device}: {e}")
        return tensor_or_module

In [20]:
class GSM8KDataset(Dataset):
    """GSM8K word problems prepared for T5 seq2seq fine-tuning.

    Each raw example is preprocessed once into a dict with the question, a
    prompt-formatted question, the reasoning lines, the extracted final answer
    and the full reference solution. With a tokenizer, ``__getitem__`` returns
    fixed-length tensors: the prompt gets ``max_length // 3`` tokens and the
    target solution ``max_length * 2 // 3`` tokens.
    """

    def __init__(self, split="train", tokenizer=None, max_length=768, max_samples=None):
        """
        Args:
            split: Dataset split to load (e.g. "train").
            tokenizer: Hugging Face tokenizer. If None, raw preprocessed dicts
                are returned by ``__getitem__``.
            max_length: Total token budget shared by the input and the target.
            max_samples: Keep only the first ``max_samples`` examples; a falsy
                value (None or 0) keeps the whole split.
        """
        self.data = load_dataset("gsm8k", "main")[split]
        if max_samples:
            self.data = self.data.select(range(min(max_samples, len(self.data))))
        self.tokenizer = tokenizer
        self.max_length = max_length
        self.processed_data = self.preprocess_data()

    def preprocess_data(self):
        """Extract final answer and reasoning lines for every raw example."""
        processed = []
        for item in tqdm(self.data, desc="Preprocessing data"):
            question = item["question"]
            answer_with_cot = item["answer"]

            final_answer = extract_final_answer(answer_with_cot)
            cot_steps = extract_cot_steps(answer_with_cot)

            # Prompt format used for T5 training.
            formatted_question = f"Solve this math problem step-by-step: {question} {COT_PROMPT}"

            processed.append({
                "question": question,
                "formatted_question": formatted_question,
                "cot_steps": cot_steps,
                "final_answer": final_answer,
                "full_answer": answer_with_cot
            })
        return processed

    def __len__(self):
        return len(self.processed_data)

    def __getitem__(self, idx):
        """Return one example; tokenized tensors plus raw fields when a tokenizer is set.

        Label positions that are padding are set to -100 so the seq2seq loss
        ignores them.
        """
        item = self.processed_data[idx]
        if not self.tokenizer:
            return item

        # Budget split: one third for the prompt, two thirds for the reasoning.
        input_len = self.max_length // 3
        target_len = self.max_length * 2 // 3

        try:
            inputs = self.tokenizer(
                item["formatted_question"],
                padding="max_length",
                truncation=True,
                max_length=input_len,
                return_tensors="pt"
            )
            targets = self.tokenizer(
                item["full_answer"],
                padding="max_length",
                truncation=True,
                max_length=target_len,
                return_tensors="pt"
            )
            input_ids = inputs.input_ids[0]
            attention_mask = inputs.attention_mask[0]
            labels = targets.input_ids[0].clone()
            # Hugging Face seq2seq models skip label id -100 in the loss; without
            # this the model is trained to emit padding tokens.
            labels[targets.attention_mask[0] == 0] = -100
        except Exception as e:
            print(f"Error tokenizing item {idx}: {e}")
            # Same shapes as a normal item so default batch collation still works.
            # The input is fully masked and every label is ignored. NOTE: a batch
            # made only of such items would contain no supervised tokens.
            input_ids = torch.zeros(input_len, dtype=torch.long)
            attention_mask = torch.zeros(input_len, dtype=torch.long)
            labels = torch.full((target_len,), -100, dtype=torch.long)

        return {
            "input_ids": input_ids,
            "attention_mask": attention_mask,
            "labels": labels,
            "raw_question": item["question"],
            "raw_cot": item["cot_steps"],
            "raw_answer": item["final_answer"]
        }

In [None]:
class CoTGenerator:
    """Step-wise chain-of-thought generator built on a T5 seq2seq model.

    Reasoning is produced one step at a time (``generate_step``). Each step is
    checked (``evaluate_step``) and rewritten up to twice (``refine_step``) if
    it fails. A closing answer step (``generate_final_answer``) is appended when
    no step contains "answer is".
    """

    # Weight files ``save_pretrained`` may write (legacy .bin or safetensors).
    _WEIGHT_FILES = ("pytorch_model.bin", "model.safetensors")

    def __init__(self, model_name="t5-base", local_dir="./models/t5_base_cache"):
        """Load from ``local_dir`` if a saved checkpoint exists there, else download.

        Args:
            model_name: Hugging Face model id used when downloading.
            local_dir: Directory holding the saved checkpoint and the download
                cache; created if missing.
        """
        self.model_name = model_name
        self.local_dir = local_dir

        os.makedirs(self.local_dir, exist_ok=True)

        print(f"Loading model {model_name}...")

        if self._has_local_model():
            print(f"Found existing model at {self.local_dir}. Loading locally...")
            self._load_local_model()
        else:
            print(f"Model not found locally. Downloading {model_name}...")
            self._download_model()

    @staticmethod
    def _default_device():
        """Best available backend: CUDA, then MPS, then CPU."""
        if torch.cuda.is_available():
            return torch.device("cuda")
        if torch.backends.mps.is_available():
            return torch.device("mps")
        return torch.device("cpu")

    def _has_local_model(self):
        """True if ``local_dir`` contains a flat checkpoint (config, tokenizer, weights)."""
        def present(name):
            return os.path.exists(os.path.join(self.local_dir, name))

        return (present("config.json")
                and present("tokenizer_config.json")
                and any(present(name) for name in self._WEIGHT_FILES))

    def _download_model(self):
        """Fetch tokenizer and model from the Hub, caching files under ``local_dir``.

        ``cache_dir`` stores files in the Hub cache layout rather than as a flat
        checkpoint, so ``_has_local_model`` only becomes true after ``save()``.
        Errors from the tokenizer download are re-raised. If placing the model
        on the accelerator fails, a CPU copy is loaded instead.
        """
        try:
            self.tokenizer = T5Tokenizer.from_pretrained(
                self.model_name,
                cache_dir=self.local_dir
            )
            print(f"Tokenizer downloaded (cached under {self.local_dir})")

            try:
                device = self._default_device()
                # Load in float32. Pure float16 weights cannot absorb small
                # optimizer updates, and T5 activations are prone to overflow
                # in float16.
                self.model = T5ForConditionalGeneration.from_pretrained(
                    self.model_name,
                    cache_dir=self.local_dir,
                    low_cpu_mem_usage=True
                )
                self.model = self.model.to(device)
                print(f"Model downloaded and moved to {device}")
            except Exception as e:
                print(f"Error loading model to device: {e}")
                print("Falling back to CPU")
                self.model = T5ForConditionalGeneration.from_pretrained(
                    self.model_name,
                    cache_dir=self.local_dir
                )
                print(f"Model downloaded (CPU version)")
        except Exception as e:
            print(f"Error downloading model: {e}")
            raise

    def _load_local_model(self):
        """Load tokenizer and float32 model from ``local_dir``; download on any failure."""
        try:
            self.tokenizer = T5Tokenizer.from_pretrained(self.local_dir)

            try:
                device = self._default_device()
                self.model = T5ForConditionalGeneration.from_pretrained(self.local_dir)
                self.model = self.model.to(device)
                print(f"Model loaded from {self.local_dir} and moved to {device}")
            except Exception as e:
                print(f"Error loading model to device: {e}")
                print("Falling back to CPU")
                self.model = T5ForConditionalGeneration.from_pretrained(self.local_dir)
                print(f"Model loaded from {self.local_dir} (CPU version)")
        except Exception as e:
            print(f"Error loading local model: {e}")
            print("Will attempt to download from source...")
            self._download_model()

    def generate_step(self, question, previous_steps=None, max_length=256, max_input_length=512):
        """Generate the next reasoning step for ``question``.

        Args:
            question: Problem text.
            previous_steps: Steps produced so far; numbered into the prompt.
            max_length: Maximum length of the generated sequence.
            max_input_length: Token cap for the prompt. This matches the 512 used
                by ``refine_step``/``generate_final_answer``. A smaller cap with
                right-side truncation cuts off the accumulated steps and the
                trailing "Next step:" cue.

        Returns:
            ``{"step": str, "is_final_step": bool}``. ``is_final_step`` is True
            when the step mentions "answer is", "therefore", "thus" or "final
            answer", or when generation failed.
        """
        if previous_steps is None:
            previous_steps = []

        if previous_steps:
            previous_steps_text = "Steps so far:\n" + "\n".join(
                f"{i+1}. {prev}" for i, prev in enumerate(previous_steps)
            )
            previous_steps_text += "\nNext step:"
            input_text = f"Solve this math problem step-by-step: {question}\n{previous_steps_text}"
        else:
            input_text = f"Solve this math problem step-by-step: {question}\nFirst step:"

        try:
            inputs = self.tokenizer(
                input_text,
                return_tensors="pt",
                padding=True,
                truncation=True,
                max_length=max_input_length
            )
            device = self.model.device
            inputs = {k: v.to(device) for k, v in inputs.items()}

            with torch.no_grad():
                outputs = self.model.generate(
                    input_ids=inputs["input_ids"],
                    attention_mask=inputs["attention_mask"],
                    max_length=max_length,
                    num_return_sequences=1,
                    do_sample=True,
                    temperature=0.7,
                    top_p=0.9,
                    top_k=40,
                    num_beams=3,
                    early_stopping=True
                )

            step = self.tokenizer.decode(outputs[0], skip_special_tokens=True).strip()

            lowered = step.lower()
            is_final_step = ("answer is" in lowered or
                             "therefore" in lowered or
                             "thus" in lowered or
                             "final answer" in lowered)

            return {
                "step": step,
                "is_final_step": is_final_step
            }
        except Exception as e:
            print(f"Error generating step: {e}")
            return {
                "step": "Error generating step",
                "is_final_step": True  # Force termination on error
            }

    def evaluate_step(self, step, question, reflection_module=None):
        """Return True if ``step`` looks acceptable.

        With a reflection module, the step passes when its score is >= 0.5.
        Otherwise a heuristic requires a digit, an arithmetic operator or "=",
        and 10-100 whitespace-separated words.
        """
        if reflection_module:
            score = reflection_module.evaluate_step(step, question)
            return score >= 0.5

        has_numbers = bool(re.search(r'\d+', step))
        has_math_ops = bool(re.search(r'[+\-*/=]', step))
        reasonable_length = 10 <= len(step.split()) <= 100

        return has_numbers and has_math_ops and reasonable_length

    def refine_step(self, question, step, previous_steps=None):
        """Ask the model to rewrite a step that failed ``evaluate_step``.

        Returns the rewritten step, or ``step`` unchanged if generation fails.
        """
        if previous_steps is None:
            previous_steps = []

        previous_steps_text = ""
        if previous_steps:
            previous_steps_text = "Steps so far:\n" + "\n".join(
                f"{i+1}. {prev}" for i, prev in enumerate(previous_steps)
            )

        input_text = (f"Solve this math problem step-by-step: {question}\n"
                      f"{previous_steps_text}\n"
                      f"The following step needs improvement: {step}\n"
                      f"Improved step:")

        try:
            inputs = self.tokenizer(
                input_text,
                return_tensors="pt",
                padding=True,
                truncation=True,
                max_length=512
            )
            device = self.model.device
            inputs = {k: v.to(device) for k, v in inputs.items()}

            with torch.no_grad():
                outputs = self.model.generate(
                    input_ids=inputs["input_ids"],
                    attention_mask=inputs["attention_mask"],
                    max_length=256,
                    do_sample=True,
                    temperature=0.6
                )

            return self.tokenizer.decode(outputs[0], skip_special_tokens=True).strip()
        except Exception as e:
            print(f"Error refining step: {e}")
            return step

    def generate(self, question, max_length=768, cot_prompt=None, reflection_module=None, max_steps=8):
        """Generate a full chain of thought for ``question`` step by step.

        Up to ``max_steps`` steps are generated. Each is refined at most twice
        if it fails evaluation. Generation stops early when a step looks final.
        ``max_length`` and ``cot_prompt`` are accepted for interface
        compatibility but are not used.

        Returns:
            Dict with ``cot_steps`` (list of str), ``final_answer`` (str, via
            ``extract_final_answer``) and ``full_output`` (steps joined by newlines).
        """
        cot_steps = []

        for _ in range(max_steps):
            step_result = self.generate_step(question, cot_steps)
            current_step = step_result["step"]

            step_is_good = self.evaluate_step(current_step, question, reflection_module)
            refinement_attempts = 0
            while not step_is_good and refinement_attempts < 2:
                current_step = self.refine_step(question, current_step, cot_steps)
                step_is_good = self.evaluate_step(current_step, question, reflection_module)
                refinement_attempts += 1

            cot_steps.append(current_step)

            # The final-step flag comes from the unrefined step.
            if step_result["is_final_step"]:
                break

        if not any("answer is" in step.lower() for step in cot_steps):
            final_step_result = self.generate_final_answer(question, cot_steps)
            cot_steps.append(final_step_result["step"])

        full_output = "\n".join(cot_steps)
        final_answer = extract_final_answer(full_output)

        return {
            "cot_steps": cot_steps,
            "final_answer": final_answer,
            "full_output": full_output
        }

    def generate_final_answer(self, question, cot_steps):
        """Generate a closing answer step from the steps so far (beam search, no sampling).

        "The answer is " is prepended unless the output already contains an
        answer phrase. Always returns ``is_final_step=True``.
        """
        steps_text = "\n".join(f"{i+1}. {step}" for i, step in enumerate(cot_steps))

        input_text = (f"Solve this math problem step-by-step: {question}\n"
                      f"Steps so far:\n{steps_text}\n"
                      f"Final answer:")

        try:
            inputs = self.tokenizer(
                input_text,
                return_tensors="pt",
                padding=True,
                truncation=True,
                max_length=512
            )
            device = self.model.device
            inputs = {k: v.to(device) for k, v in inputs.items()}

            with torch.no_grad():
                outputs = self.model.generate(
                    input_ids=inputs["input_ids"],
                    attention_mask=inputs["attention_mask"],
                    max_length=128,
                    do_sample=False,
                    num_beams=3
                )

            final_step = self.tokenizer.decode(outputs[0], skip_special_tokens=True).strip()

            if not any(phrase in final_step.lower() for phrase in ["answer is", "therefore", "thus", "final answer"]):
                final_step = "The answer is " + final_step

            return {
                "step": final_step,
                "is_final_step": True
            }
        except Exception as e:
            print(f"Error generating final answer: {e}")
            return {
                "step": "The answer is unknown due to an error.",
                "is_final_step": True
            }

    def save(self, path=None):
        """Save model and tokenizer to ``path`` (default ``local_dir``) via ``save_pretrained``.

        The model is saved from CPU and then returned to its previous device,
        even if saving fails. Errors are printed, not raised.
        """
        save_path = path if path else self.local_dir
        try:
            home_device = self.model.device
            self.model = self.model.to("cpu")
            try:
                self.model.save_pretrained(save_path)
                self.tokenizer.save_pretrained(save_path)
                print(f"Model saved to {save_path}")
            finally:
                self.model = self.model.to(home_device)
        except Exception as e:
            print(f"Error saving model: {e}")

    def load(self, path=None):
        """Load tokenizer and model from ``path`` (default ``local_dir``); errors are printed."""
        load_path = path if path else self.local_dir
        try:
            self.tokenizer = T5Tokenizer.from_pretrained(load_path)
            device = self._default_device()
            self.model = T5ForConditionalGeneration.from_pretrained(load_path)
            self.model = self.model.to(device)
            print(f"Model loaded from {load_path} and moved to {device}")
        except Exception as e:
            print(f"Error loading model: {e}")

In [22]:
class ReflectionModule(nn.Module):
    """Scores a reasoning step with a small Transformer encoder and a sigmoid head.

    NOTE(review): ``get_embeddings`` is still a placeholder that returns random
    vectors, so scores do not yet depend on the step text.
    """

    def __init__(self, embedding_dim=768):
        """
        Args:
            embedding_dim: Encoder width; must be divisible by the 8 attention heads.
        """
        super().__init__()
        self.encoder = nn.TransformerEncoder(
            nn.TransformerEncoderLayer(d_model=embedding_dim, nhead=8, batch_first=True),
            num_layers=2
        )
        self.fc = nn.Linear(embedding_dim, 1)
        self.tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)

    def forward(self, step_embeddings):
        """Score a batch of step embeddings.

        Args:
            step_embeddings: Tensor of shape (batch, seq_len, embedding_dim)
                (the encoder is ``batch_first``).

        Returns:
            Tensor of shape (batch, 1) with scores in (0, 1).
        """
        encoded = self.encoder(step_embeddings)
        pooled = encoded.mean(dim=1)  # mean over the sequence dimension
        return torch.sigmoid(self.fc(pooled))

    def _device(self):
        """Device holding this module's parameters."""
        return next(self.parameters()).device

    def evaluate_step(self, step_text, question_context):
        """Return a scalar score for ``step_text`` given the question; 0.5 on any error.

        Runs in eval mode (dropout off) and restores the previous train/eval
        mode afterwards. Inputs are placed on the module's own device, which
        avoids cross-device errors when the module has not been moved to the
        accelerator.
        """
        combined_text = f"{question_context} Step: {step_text}"
        was_training = self.training
        try:
            tokens = self.tokenizer(
                combined_text,
                return_tensors="pt",
                padding=True,
                truncation=True,
                max_length=MAX_LENGTH
            )
            device = self._device()
            tokens = {k: v.to(device) for k, v in tokens.items()}

            self.eval()
            with torch.no_grad():
                embeddings = self.get_embeddings(tokens)
                score = self.forward(embeddings)

            return score.item()
        except Exception as e:
            print(f"Error in evaluate_step: {e}")
            return 0.5
        finally:
            self.train(was_training)

    def get_embeddings(self, tokens):
        """Placeholder embeddings: random (batch, seq_len, embedding_dim) on the module's device.

        The width is taken from the scoring head, so it always matches the
        ``embedding_dim`` the module was built with.
        TODO: replace with real encoder token embeddings.
        """
        batch_size, seq_len = tokens["input_ids"].shape[:2]
        return torch.randn(batch_size, seq_len, self.fc.in_features, device=self._device())


In [23]:
class RetrievalModule:
    """In-memory exemplar bank with nearest-neighbour lookup by question text.

    Exemplars are stored as ``(question, embedding, cot_sequence)`` tuples and
    ranked by cosine similarity between question embeddings.
    """

    def __init__(self, embedding_model_name=MODEL_NAME):
        """Load the tokenizer used for embeddings; on failure fall back to random embeddings."""
        try:
            self.tokenizer = AutoTokenizer.from_pretrained(embedding_model_name)
        except Exception as e:
            print(f"Error loading tokenizer: {e}")
            self.tokenizer = None
        self.exemplar_bank = []  # list of (question, embedding, cot_sequence)

    def add_exemplar(self, question, cot_sequence):
        """Embed ``question`` and append it with its reasoning to the bank."""
        embedding = self.compute_embedding(question)
        self.exemplar_bank.append((question, embedding, cot_sequence))

    def initialize_from_dataset(self, dataset, max_exemplars=200):
        """Add up to ``max_exemplars`` items from ``dataset``.

        Each item must provide ``raw_question`` and ``raw_cot``. Per-item
        errors are printed and skipped.
        """
        for i, item in enumerate(tqdm(dataset, desc="Building exemplar bank")):
            if i >= max_exemplars:
                break
            try:
                self.add_exemplar(item["raw_question"], item["raw_cot"])
            except Exception as e:
                print(f"Error adding exemplar {i}: {e}")

    def compute_embedding(self, text):
        """Embed ``text`` as an L2-normalised, position-weighted bag of token ids.

        Each token id is hashed into one of 768 buckets (``id % 768``), so texts
        that share vocabulary share buckets. Tokens in the first and last 20% of
        the sequence count 1.5x. Returns a random 768-vector if no tokenizer is
        available or embedding fails.
        """
        try:
            if not self.tokenizer:
                return np.random.randn(768)

            tokens = self.tokenizer(
                text,
                padding=True,
                truncation=True,
                max_length=MAX_LENGTH,
                return_tensors="pt"
            )
            token_ids = tokens["input_ids"].numpy().flatten()

            n = len(token_ids)
            positions = np.arange(n)
            weights = np.where((positions < n * 0.2) | (positions > n * 0.8), 1.5, 1.0)

            embedding = np.zeros(768)
            # Unbuffered add so repeated bucket indices accumulate.
            np.add.at(embedding, token_ids % 768, weights)

            norm = np.linalg.norm(embedding)
            if norm > 0:
                embedding = embedding / norm
            return embedding
        except Exception as e:
            print(f"Error computing embedding: {e}")
            return np.random.randn(768)

    def retrieve_similar_exemplars(self, question, k=5):
        """Return the CoT sequences of the ``k`` exemplars most similar to ``question``.

        The results are ordered from most to least similar. Returns [] when the
        bank is empty or ``k <= 0``. Zero-norm embeddings get similarity 0.
        """
        if not self.exemplar_bank or k <= 0:
            return []

        query = np.asarray(self.compute_embedding(question), dtype=float)
        bank = np.stack([np.asarray(emb, dtype=float) for _, emb, _ in self.exemplar_bank])

        denom = np.linalg.norm(bank, axis=1) * np.linalg.norm(query)
        similarities = np.divide(bank @ query, denom, out=np.zeros(len(bank)), where=denom > 0)

        top_indices = np.argsort(similarities)[-k:][::-1]
        return [self.exemplar_bank[i][2] for i in top_indices]

In [24]:
class TextRefinementTransformer(nn.Module):
    """T5 wrapper that rewrites a math solution with more detailed reasoning.

    If the model or tokenizer fails to load, both are left as None and every
    method degrades to a harmless fallback.
    """

    def __init__(self, model_name=MODEL_NAME):
        super().__init__()
        try:
            self.tokenizer = T5Tokenizer.from_pretrained(model_name)
            loaded_model = T5ForConditionalGeneration.from_pretrained(model_name)
            self.model = to_device(loaded_model)
        except Exception as e:
            print(f"Error initializing refinement module: {e}")
            self.tokenizer = None
            self.model = None

    def forward(self, input_ids, attention_mask, labels=None):
        """Run the wrapped model; returns its output object or None on failure/no model."""
        if self.model is None:
            return None
        try:
            result = self.model(
                input_ids=input_ids,
                attention_mask=attention_mask,
                labels=labels
            )
        except Exception as e:
            print(f"Error in forward pass: {e}")
            result = None
        return result

    def refine_text(self, input_text, max_length=MAX_LENGTH):
        """Generate an improved version of ``input_text``; returns it unchanged on any failure."""
        if self.tokenizer is None or self.model is None:
            return input_text

        try:
            prompt = f"Improve this math solution with detailed step-by-step reasoning: {input_text}"
            encoded = self.tokenizer(
                prompt,
                return_tensors="pt",
                padding=True,
                truncation=True,
                max_length=max_length // 2
            )
            encoded = {name: to_device(tensor) for name, tensor in encoded.items()}

            generation_options = dict(
                max_length=max_length,
                num_return_sequences=1,
                do_sample=True,
                temperature=0.7,
                min_length=100,
                num_beams=4
            )
            with torch.no_grad():
                generated = self.model.generate(
                    input_ids=encoded["input_ids"],
                    attention_mask=encoded["attention_mask"],
                    **generation_options
                )

            return self.tokenizer.decode(generated[0], skip_special_tokens=True)
        except Exception as e:
            print(f"Error refining text: {e}")
            return input_text

    def train_step(self, batch, optimizer):
        """One optimisation step on ``batch``; returns the loss, or 0.0 if anything fails."""
        if self.model is None:
            return 0.0

        try:
            self.train()
            optimizer.zero_grad()

            tensors = {
                key: to_device(value)
                for key, value in batch.items()
                if isinstance(value, torch.Tensor)
            }
            result = self.forward(
                input_ids=tensors["input_ids"],
                attention_mask=tensors["attention_mask"],
                labels=tensors["labels"]
            )
            if result is None:
                return 0.0

            loss = result.loss
            loss.backward()
            optimizer.step()
            return loss.item()
        except Exception as e:
            print(f"Error in train_step: {e}")
            return 0.0

In [25]:
class RewardFunction:
    """Scores generated solutions by answer correctness and reasoning quality."""

    # Private RNG for sampling debug logs. This keeps logging from consuming
    # the global `random` stream that other code (e.g. `random.sample` in
    # SelfTrainer) relies on for seeded reproducibility.
    _debug_rng = random.Random()

    # Plain decimal numbers, optionally negative: "12", "-3.5", "5.", ".5".
    _NUMERIC_RE = re.compile(r"-?(?:\d+\.?\d*|\.\d+)")

    def __init__(self, reflection_module):
        """
        Args:
            reflection_module: Object with ``evaluate_step(step, question) -> float``.
        """
        self.reflection_module = reflection_module

    def outcome_reward(self, predicted_answer, true_answer):
        """Reward in {0, 0.2, 0.5, 0.9, 1.0} for how close the prediction is.

        "$", "," and whitespace are ignored. An exact string match scores 1.0.
        For numeric answers (including negatives), relative error < 1% scores
        0.9 and < 10% scores 0.5. Same sign and within 10x scores 0.2.
        Everything else, and any error, scores 0.0.
        """
        try:
            pred_clean = re.sub(r'[$,\s]', '', predicted_answer).strip()
            true_clean = re.sub(r'[$,\s]', '', true_answer).strip()

            if pred_clean == true_clean:
                return 1.0

            if not (self._NUMERIC_RE.fullmatch(pred_clean) and self._NUMERIC_RE.fullmatch(true_clean)):
                return 0.0

            pred_num = float(pred_clean)
            true_num = float(true_clean)

            # The scale is floored at 1 so tiny/zero true answers don't blow up the ratio.
            relative_error = abs(pred_num - true_num) / max(abs(true_num), 1)
            if relative_error < 0.01:
                return 0.9
            if relative_error < 0.1:
                return 0.5

            same_sign = (true_num > 0 and pred_num > 0) or (true_num < 0 and pred_num < 0)
            if same_sign and 0.1 <= (pred_num / true_num) <= 10:
                return 0.2

            return 0.0
        except Exception as e:
            print(f"Error in outcome_reward: {e}")
            return 0.0

    def process_reward(self, cot_steps, question):
        """Heuristic reasoning-quality score, capped at 1.0.

        The mean reflection score per step is scaled by bonuses:
        - up to +0.3 for having 3 or more steps;
        - +0.2 if any step contains a digit;
        - +0.3 if any step has an arithmetic operator or "=" together with a digit.
        Up to +0.2 is then added for reusing the question's words of 4+
        characters. Returns 0.2 for no steps and 0.3 if scoring raises.
        """
        if not cot_steps:
            return 0.2

        try:
            has_math_expressions = any(
                re.search(r'[=+\-*/]', step) and re.search(r'\d', step)
                for step in cot_steps
            )

            question_words = set(re.findall(r'\b\w{4,}\b', question.lower()))
            step_text = ' '.join(cot_steps).lower()
            relevant_word_count = sum(1 for word in question_words if word in step_text)
            relevance_score = min(relevant_word_count / max(len(question_words), 1), 1.0)

            step_scores = [self.reflection_module.evaluate_step(step, question) for step in cot_steps]

            length_bonus = min(len(step_scores) / 3, 1.0)
            number_bonus = 0.2 if any(re.search(r'\d+', step) for step in cot_steps) else 0.0
            math_bonus = 0.3 if has_math_expressions else 0.0

            base_score = sum(step_scores) / len(step_scores) if step_scores else 0.3
            return min(base_score * (1.0 + 0.3 * length_bonus + number_bonus + math_bonus) + 0.2 * relevance_score, 1.0)
        except Exception as e:
            print(f"Error in process_reward: {e}")
            return 0.3

    def combined_reward(self, cot_steps, predicted_answer, true_answer, question, alpha=0.7):
        """``alpha * outcome + (1 - alpha) * process``; 0.3 on error.

        About 5% of calls print a debug summary.
        """
        try:
            outcome = self.outcome_reward(predicted_answer, true_answer)
            process = self.process_reward(cot_steps, question)

            combined = alpha * outcome + (1 - alpha) * process

            if self._debug_rng.random() < 0.05:
                print(f"\nQuestion: {question[:50]}...")
                print(f"Predicted: {predicted_answer}")
                print(f"True: {true_answer}")
                print(f"Outcome reward: {outcome:.2f}, Process reward: {process:.2f}, Combined: {combined:.2f}")
                if cot_steps:
                    print(f"First reasoning step: {cot_steps[0][:50]}...")

            return combined
        except Exception as e:
            print(f"Error in combined_reward: {e}")
            return 0.3

In [None]:
class SelfTrainer:
    """Self-training loop for a chain-of-thought generator.

    Each iteration pseudo-labels training questions with the current model,
    fine-tunes on them, runs a PPO-style policy update, grows the retrieval
    exemplar bank, checkpoints, and evaluates.

    Collaborators are duck-typed; based on how they are called below this class expects:
      * ``cot_generator.generate(question, reflection_module=...)`` returning a dict with
        "cot_steps", "final_answer" and "full_output", plus ``.model`` and ``.save(path)``;
      * ``retrieval_module.retrieve_similar_exemplars(question, k=...)`` and ``.add_exemplar(question, steps)``;
      * ``refinement_module.refine_text(text)``;
      * ``reward_function.combined_reward(...)`` and ``.outcome_reward(...)``.
    """

    def __init__(
        self,
        cot_generator,
        reflection_module,
        retrieval_module,
        refinement_module,
        reward_function,
        tokenizer=None
    ):
        self.cot_generator = cot_generator
        self.reflection_module = reflection_module
        self.retrieval_module = retrieval_module
        self.refinement_module = refinement_module
        self.reward_function = reward_function
        try:
            # Only fall back to loading MODEL_NAME when no tokenizer was supplied.
            self.tokenizer = tokenizer if tokenizer is not None else AutoTokenizer.from_pretrained(MODEL_NAME)
        except Exception as e:
            print(f"Error loading tokenizer: {e}")
            self.tokenizer = None

    # ------------------------------------------------------------------ helpers

    @staticmethod
    def _mask_padding(labels, pad_token_id):
        """Return a copy of ``labels`` with every ``pad_token_id`` position set to -100.

        Hugging Face models ignore label positions equal to -100 in the
        cross-entropy loss; leaving pad ids in place trains the model to emit padding.
        """
        if pad_token_id is None:
            return labels
        return labels.masked_fill(labels == pad_token_id, -100)

    @staticmethod
    def _token_logprobs(logits, labels):
        """Log-probability of each label token under ``logits``.

        Args:
            logits: Float tensor of shape (batch, seq, vocab).
            labels: Integer tensor of shape (batch, seq); must not contain -100.

        Returns:
            Tensor of shape (batch, seq).
        """
        log_probs = torch.log_softmax(logits, dim=-1)
        return log_probs.gather(-1, labels.unsqueeze(-1)).squeeze(-1)

    @staticmethod
    def _ppo_clip_loss(new_logp, old_logp, advantage, clip_param):
        """Per-token PPO clipped surrogate loss (to be minimised)."""
        ratio = torch.exp(new_logp - old_logp)
        clipped_ratio = torch.clamp(ratio, 1.0 - clip_param, 1.0 + clip_param)
        return -torch.min(ratio * advantage, clipped_ratio * advantage).mean()

    def _score(self, generation, gold_answer, question):
        """Combined reward of a ``generate()`` result dict against the gold answer."""
        return self.reward_function.combined_reward(
            generation["cot_steps"],
            generation["final_answer"],
            gold_answer,
            question
        )

    # ------------------------------------------------------------------ pipeline stages

    def generate_pseudo_labels(self, dataset, threshold=0.3, max_samples=100):
        """Generate pseudo-labelled examples whose combined reward reaches ``threshold``.

        For each sampled item the current model generates a CoT. If retrieval returns
        exemplars, a second generation is made with them in the prompt, and the better
        scoring one is kept. Rewards in [0.3, 0.7] additionally trigger a refinement pass.

        Args:
            dataset: Indexable dataset whose items have "raw_question" and "raw_answer".
            threshold: Minimum combined reward needed to keep an example.
            max_samples: Upper bound on the number of dataset items sampled.

        Returns:
            List of dicts with keys "question", "cot_steps", "final_answer",
            "full_output", "gold_answer" and "reward".
        """
        pseudo_labeled = []
        rewards = []

        sample_count = min(max_samples, len(dataset))
        sample_indices = random.sample(range(len(dataset)), sample_count)

        for i in tqdm(sample_indices, desc="Generating pseudo-labels"):
            try:
                item = dataset[i]
                question = item["raw_question"]
                gold_answer = item["raw_answer"]

                generated = self.cot_generator.generate(
                    question,
                    reflection_module=self.reflection_module
                )

                # Retrieve similar exemplars to guide a second generation.
                similar_exemplars = self.retrieval_module.retrieve_similar_exemplars(question, k=3)
                exemplar_text = ""
                if similar_exemplars:
                    exemplar_text = "Here are examples of good reasoning:\n"
                    # `ex_num` must not reuse `i`: `i` is the dataset index reported on errors below.
                    for ex_num, exemplar in enumerate(similar_exemplars[:2]):  # Use top 2 exemplars
                        steps_text = "\n".join([f"Step {j+1}: {step}" for j, step in enumerate(exemplar)])
                        exemplar_text += f"Example {ex_num+1}:\n{steps_text}\n\n"

                if exemplar_text:
                    enhanced_prompt = f"{exemplar_text}\nNow solve this problem:\n{question}"
                    enhanced_generation = self.cot_generator.generate(
                        enhanced_prompt,
                        reflection_module=self.reflection_module
                    )

                    # Score both generations and keep the better one.
                    base_reward = self._score(generated, gold_answer, question)
                    enhanced_reward = self._score(enhanced_generation, gold_answer, question)

                    if enhanced_reward > base_reward:
                        generated = enhanced_generation
                        reward = enhanced_reward
                    else:
                        reward = base_reward
                else:
                    reward = self._score(generated, gold_answer, question)

                # Try to refine the generated text if the reward is mediocre.
                if 0.3 <= reward <= 0.7:
                    refined_text = self.refinement_module.refine_text(generated["full_output"])
                    refined_steps = extract_cot_steps(refined_text)
                    refined_answer = extract_final_answer(refined_text)

                    refined_reward = self.reward_function.combined_reward(
                        refined_steps,
                        refined_answer,
                        gold_answer,
                        question
                    )

                    if refined_reward > reward:
                        generated["cot_steps"] = refined_steps
                        generated["final_answer"] = refined_answer
                        generated["full_output"] = refined_text
                        reward = refined_reward

                if reward >= threshold:
                    pseudo_labeled.append({
                        "question": question,
                        "cot_steps": generated["cot_steps"],
                        "final_answer": generated["final_answer"],
                        "full_output": generated["full_output"],
                        "gold_answer": gold_answer,
                        "reward": reward
                    })
                    rewards.append(reward)
            except Exception as e:
                print(f"Error processing item {i}: {e}")

        if rewards:
            avg_reward = sum(rewards) / len(rewards)
            print(f"Generated {len(pseudo_labeled)} pseudo-labeled examples with average reward {avg_reward:.3f}")
        else:
            print("Failed to generate any valid pseudo-labeled examples")

        return pseudo_labeled

    def train_with_pseudo_labels(self, pseudo_labeled, learning_rate=1e-5, epochs=2, batch_size=4):
        """Supervised fine-tuning of the generator on pseudo-labelled examples.

        Inputs are "Solve this math problem step-by-step: <question>" padded/truncated
        to MAX_LENGTH // 3 tokens. Targets are each example's "full_output" padded/truncated
        to MAX_LENGTH * 2 // 3 tokens, with padding excluded from the loss. A checkpoint
        is saved after every epoch. Errors are printed rather than raised.
        """
        if not pseudo_labeled:
            print("No pseudo-labeled examples available for training")
            return
        if self.tokenizer is None:
            print("Error in train_with_pseudo_labels: tokenizer is not available")
            return

        try:
            mask_padding = self._mask_padding

            class PseudoLabeledDataset(Dataset):
                """Tokenises (question -> full_output) pairs for seq2seq fine-tuning."""

                def __init__(self, examples, tokenizer, max_length=MAX_LENGTH):
                    self.examples = examples
                    self.tokenizer = tokenizer
                    self.max_length = max_length

                def __len__(self):
                    return len(self.examples)

                def __getitem__(self, idx):
                    item = self.examples[idx]

                    input_text = f"Solve this math problem step-by-step: {item['question']}"
                    target_text = item["full_output"]

                    inputs = self.tokenizer(
                        input_text,
                        padding="max_length",
                        truncation=True,
                        max_length=self.max_length // 3,
                        return_tensors="pt"
                    )

                    targets = self.tokenizer(
                        target_text,
                        padding="max_length",
                        truncation=True,
                        max_length=self.max_length * 2 // 3,
                        return_tensors="pt"
                    )

                    # squeeze(0) drops only the batch dim added by return_tensors="pt".
                    labels = targets.input_ids.squeeze(0)
                    return {
                        "input_ids": inputs.input_ids.squeeze(0),
                        "attention_mask": inputs.attention_mask.squeeze(0),
                        # Padded label positions become -100 so they do not contribute to the loss.
                        "labels": mask_padding(labels, self.tokenizer.pad_token_id),
                    }

            pseudo_dataset = PseudoLabeledDataset(pseudo_labeled, self.tokenizer)
            pseudo_dataloader = DataLoader(
                pseudo_dataset,
                batch_size=batch_size,
                shuffle=True
            )

            model = self.cot_generator.model
            # torch.optim.AdamW replaces the deprecated transformers.AdamW; weight_decay=0.0
            # keeps the transformers default the original code relied on.
            optimizer = torch.optim.AdamW(model.parameters(), lr=learning_rate, weight_decay=0.0)

            total_steps = len(pseudo_dataloader) * epochs
            print(f"Starting training on {len(pseudo_labeled)} examples for {epochs} epochs ({total_steps} steps)")

            device = model.device

            for epoch in range(epochs):
                model.train()
                epoch_loss = 0.0

                for step, batch in enumerate(tqdm(pseudo_dataloader, desc=f"Epoch {epoch+1}/{epochs}")):
                    batch = {k: v.to(device) for k, v in batch.items()}

                    outputs = model(
                        input_ids=batch["input_ids"],
                        attention_mask=batch["attention_mask"],
                        labels=batch["labels"]
                    )
                    loss = outputs.loss

                    optimizer.zero_grad()
                    loss.backward()
                    optimizer.step()

                    epoch_loss += loss.item()

                    if step % 10 == 0:
                        print(f"Step {step}/{len(pseudo_dataloader)}, Loss: {loss.item():.4f}")

                avg_epoch_loss = epoch_loss / len(pseudo_dataloader)
                print(f"Epoch {epoch+1}/{epochs} completed. Average loss: {avg_epoch_loss:.4f}")

                self.cot_generator.save(f"./models/recot_checkpoint_epoch_{epoch+1}")

            print("Training completed successfully!")

        except Exception as e:
            print(f"Error in train_with_pseudo_labels: {e}")

    def run_ppo_update(self, batch, ppo_steps=3, clip_param=0.2, value_loss_coef=0.5, learning_rate=1e-5):
        """Simplified PPO update of the generator on freshly sampled trajectories.

        For each item in ``batch`` (dicts with "question" and "gold_answer") one CoT is
        sampled from the current policy. Its per-token log-probs are recorded as the
        old policy, and it is scored with ``reward_function.combined_reward``.
        Advantages are rewards minus the batch-mean reward. The model is then optimised
        for ``ppo_steps`` passes with the per-token clipped surrogate objective and a
        single AdamW optimizer.

        Args:
            batch: Items with "question" and "gold_answer".
            ppo_steps: Number of passes over the collected trajectories.
            clip_param: PPO clipping epsilon.
            value_loss_coef: Unused. There is no learned value head (the batch-mean
                reward is the baseline); kept for backward compatibility.
            learning_rate: AdamW learning rate.

        Dropout is disabled (``model.eval()``) for both collection and update, so the
        probability ratio starts at exactly 1. Errors are printed rather than raised.
        """
        if not batch:
            print("No examples provided for PPO update")
            return
        if self.tokenizer is None:
            print("Error in PPO update: tokenizer is not available")
            return

        try:
            model = self.cot_generator.model
            device = model.device
            decoder_start_id = getattr(model.config, "decoder_start_token_id", None)

            # ---- Collect trajectories with the current (old) policy.
            trajectories = []
            model.eval()
            for item in tqdm(batch, desc="Collecting trajectories"):
                question = item["question"]
                gold_answer = item["gold_answer"]

                input_text = f"Solve this math problem step-by-step: {question}"
                inputs = self.tokenizer(
                    input_text,
                    return_tensors="pt",
                    padding=True,
                    truncation=True,
                    max_length=MAX_LENGTH // 3
                ).to(device)

                with torch.no_grad():
                    sequences = model.generate(
                        input_ids=inputs["input_ids"],
                        attention_mask=inputs["attention_mask"],
                        max_length=MAX_LENGTH,
                        do_sample=True,
                        temperature=0.7
                    )

                    # Seq2seq generate() output begins with the decoder start token,
                    # which is not part of the label sequence the model predicts.
                    labels = sequences
                    if decoder_start_id is not None and labels.size(1) > 0 and labels[0, 0].item() == decoder_start_id:
                        labels = labels[:, 1:]
                    if labels.size(1) == 0:
                        continue

                    old_logits = model(
                        input_ids=inputs["input_ids"],
                        attention_mask=inputs["attention_mask"],
                        labels=labels
                    ).logits
                    old_logp = self._token_logprobs(old_logits, labels)

                generated_text = self.tokenizer.decode(sequences[0], skip_special_tokens=True)
                cot_steps = extract_cot_steps(generated_text)
                final_answer = extract_final_answer(generated_text)

                reward = self.reward_function.combined_reward(
                    cot_steps,
                    final_answer,
                    gold_answer,
                    question
                )

                trajectories.append({
                    "input_ids": inputs["input_ids"],
                    "attention_mask": inputs["attention_mask"],
                    "labels": labels,
                    "old_logp": old_logp,
                    "reward": reward,
                })

            if not trajectories:
                print("No trajectories collected; skipping PPO update")
                return

            # ---- Advantages: reward minus batch-mean baseline.
            mean_reward = sum(t["reward"] for t in trajectories) / len(trajectories)
            for t in trajectories:
                t["advantage"] = t["reward"] - mean_reward

            if all(abs(t["advantage"]) < 1e-12 for t in trajectories):
                print("All trajectories received the same reward; skipping PPO update")
                return

            # ---- Policy optimisation: one optimizer so Adam state persists across steps.
            optimizer = torch.optim.AdamW(model.parameters(), lr=learning_rate, weight_decay=0.0)

            for _ in range(ppo_steps):
                for idx in torch.randperm(len(trajectories)).tolist():
                    traj = trajectories[idx]

                    logits = model(
                        input_ids=traj["input_ids"],
                        attention_mask=traj["attention_mask"],
                        labels=traj["labels"]
                    ).logits
                    new_logp = self._token_logprobs(logits, traj["labels"])

                    policy_loss = self._ppo_clip_loss(
                        new_logp, traj["old_logp"], traj["advantage"], clip_param
                    )

                    optimizer.zero_grad()
                    policy_loss.backward()
                    torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
                    optimizer.step()

            print("PPO update completed")

        except Exception as e:
            print(f"Error in PPO update: {e}")

    def evaluate(self, test_dataset, num_samples=50):
        """Evaluate the current model on a random subset of ``test_dataset``.

        An example counts as correct when ``reward_function.outcome_reward`` is >= 0.9.
        Examples that raise are reported and excluded from ``total``.

        Returns:
            Dict with "accuracy", "avg_reward", "correct" and "total".
        """
        correct = 0
        total = 0
        rewards = []

        sample_indices = random.sample(range(len(test_dataset)), min(num_samples, len(test_dataset)))

        for idx in tqdm(sample_indices, desc="Evaluating"):
            try:
                item = test_dataset[idx]
                question = item["raw_question"]
                gold_answer = item["raw_answer"]

                generation = self.cot_generator.generate(question)
                predicted_answer = generation["final_answer"]

                reward = self.reward_function.outcome_reward(predicted_answer, gold_answer)
                rewards.append(reward)

                if reward >= 0.9:
                    correct += 1
                total += 1

            except Exception as e:
                print(f"Error evaluating example {idx}: {e}")

        accuracy = correct / total if total > 0 else 0
        avg_reward = sum(rewards) / len(rewards) if rewards else 0

        print(f"Evaluation results:")
        print(f"  Accuracy: {accuracy:.2f} ({correct}/{total})")
        print(f"  Average reward: {avg_reward:.3f}")

        return {
            "accuracy": accuracy,
            "avg_reward": avg_reward,
            "correct": correct,
            "total": total
        }

    def run_training_loop(self, train_dataset, test_dataset, num_iterations=5, pseudo_samples=100, batch_size=4):
        """Run the full self-training loop.

        Each iteration:
          1. pseudo-labels (reward >= 0.3),
          2. fine-tunes on them,
          3. runs PPO on examples with reward >= 0.5,
          4. adds examples with reward >= 0.7 to the retrieval bank,
          5. checkpoints,
          6. evaluates.

        Returns:
            Final evaluation dict (100 samples), or None if the loop raised.
        """
        try:
            results_history = []

            print("Initial model evaluation:")
            results = self.evaluate(test_dataset)
            results_history.append(results)

            for iteration in range(num_iterations):
                print(f"\n===== Iteration {iteration+1}/{num_iterations} =====")

                # 1. Generate pseudo-labeled examples
                print(f"Generating {pseudo_samples} pseudo-labeled examples...")
                pseudo_labeled = self.generate_pseudo_labels(
                    train_dataset,
                    threshold=0.3,  # Lower threshold initially to get more examples
                    max_samples=pseudo_samples
                )

                if not pseudo_labeled:
                    print("No pseudo-labeled examples generated. Skipping iteration.")
                    continue

                # 2. Train with pseudo-labeled examples
                print(f"Training with {len(pseudo_labeled)} pseudo-labeled examples...")
                self.train_with_pseudo_labels(
                    pseudo_labeled,
                    learning_rate=2e-5,
                    epochs=1,
                    batch_size=batch_size
                )

                # 3. Perform PPO update on higher-quality examples
                print("Performing PPO update...")
                high_quality = [ex for ex in pseudo_labeled if ex["reward"] >= 0.5]
                if high_quality:
                    self.run_ppo_update(high_quality)

                # 4. Update retrieval module with new exemplars
                print("Updating retrieval module...")
                for example in pseudo_labeled:
                    if example["reward"] >= 0.7:  # Only add high-quality examples
                        self.retrieval_module.add_exemplar(
                            example["question"],
                            example["cot_steps"]
                        )

                # 5. Save checkpoint
                checkpoint_path = f"./models/recot_checkpoint_iteration_{iteration+1}"
                print(f"Saving checkpoint to {checkpoint_path}...")
                self.cot_generator.save(checkpoint_path)

                # 6. Evaluate progress
                print("Evaluating current model:")
                results = self.evaluate(test_dataset)
                results_history.append(results)

                if len(results_history) > 1:
                    prev = results_history[-2]["accuracy"]
                    curr = results_history[-1]["accuracy"]
                    diff = curr - prev
                    print(f"Accuracy change: {diff:+.2f} ({prev:.2f} → {curr:.2f})")

            print("\n===== Final Evaluation =====")
            final_results = self.evaluate(test_dataset, num_samples=100)  # More samples for final eval

            print("\n===== Training Summary =====")
            print(f"Initial accuracy: {results_history[0]['accuracy']:.2f}")
            print(f"Final accuracy: {final_results['accuracy']:.2f}")
            print(f"Improvement: {final_results['accuracy'] - results_history[0]['accuracy']:+.2f}")

            return final_results

        except Exception as e:
            print(f"Error in training loop: {e}")
            return None

    def count_parameters(self):
        """Count the number of trainable parameters in the generator's model."""
        model = self.cot_generator.model
        return sum(p.numel() for p in model.parameters() if p.requires_grad)

In [None]:
# Set random seeds for reproducibility.
# Reuse set_seed() from the setup cell so this cell applies the same seeding
# (incl. cuDNN determinism on CUDA) instead of a partial hand-copied version.
import traceback

set_seed(11)

# Print system information.
# Same device priority as the device-configuration cell: MPS, then CUDA, then CPU.
if torch.backends.mps.is_available():
    device = torch.device("mps")
elif torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")
print(f"Using device: {device}")
print(f"PyTorch version: {torch.__version__}")
print(f"CUDA available: {torch.cuda.is_available()}")
if torch.cuda.is_available():
    print(f"CUDA device: {torch.cuda.get_device_name(0)}")
    print(f"CUDA memory available: {torch.cuda.get_device_properties(0).total_memory / 1e9:.2f} GB")

# Create output directories
os.makedirs("./models", exist_ok=True)
os.makedirs("./results", exist_ok=True)

print("\n===== Initializing Components =====")

# Initialize tokenizer (None on failure; downstream components handle a missing tokenizer)
print("Initializing tokenizer...")
try:
    tokenizer = T5Tokenizer.from_pretrained(MODEL_NAME)
except Exception as e:
    print(f"Error initializing tokenizer: {e}")
    tokenizer = None

# Initialize data (test split is a quarter the size of the training subset)
print("Loading data...")
try:
    train_dataset = GSM8KDataset(
        split="train",
        tokenizer=tokenizer,
        max_length=MAX_LENGTH,
        max_samples=MAX_SAMPLES
    )
    test_dataset = GSM8KDataset(
        split="test",
        tokenizer=tokenizer,
        max_length=MAX_LENGTH,
        max_samples=MAX_SAMPLES // 4
    )
    print(f"Loaded {len(train_dataset)} training examples and {len(test_dataset)} test examples")
except Exception as e:
    print(f"Error loading datasets: {e}")
    raise  # bare raise re-raises with the original traceback

# Initialize CoT Generator (main model)
print("Initializing CoT Generator...")
cot_generator = CoTGenerator(MODEL_NAME)

# Initialize Reflection Module
print("Initializing Reflection Module...")
reflection_module = ReflectionModule()
reflection_module = to_device(reflection_module)

# Initialize Retrieval Module and seed it with exemplars from the training data
print("Initializing Retrieval Module...")
retrieval_module = RetrievalModule()
retrieval_module.initialize_from_dataset(train_dataset, max_exemplars=100)

# Initialize Refinement Module (transformer for text refinement)
print("Initializing Refinement Module...")
refinement_module = TextRefinementTransformer()

# Initialize Reward Function
print("Initializing Reward Function...")
reward_function = RewardFunction(reflection_module)

# Initialize Self-Trainer
print("Initializing Self-Trainer...")
self_trainer = SelfTrainer(
    cot_generator=cot_generator,
    reflection_module=reflection_module,
    retrieval_module=retrieval_module,
    refinement_module=refinement_module,
    reward_function=reward_function,
    tokenizer=tokenizer
)

num_params = self_trainer.count_parameters()
print(f"Model has {num_params:,} trainable parameters")


# Run a quick smoke test before the (expensive) main loop
print("\n===== Running Quick Test =====")
try:
    test_question = "John has 5 apples. Mary gives him 3 more apples. How many apples does John have now?"
    print(f"Test question: {test_question}")

    generation = cot_generator.generate(test_question, reflection_module=reflection_module)
    print(f"Generated CoT:")
    for i, step in enumerate(generation["cot_steps"]):
        print(f"  Step {i+1}: {step}")
    print(f"Final answer: {generation['final_answer']}")

    print("Testing evaluation...")
    test_results = self_trainer.evaluate(test_dataset, num_samples=5)
    print(f"Test evaluation completed with accuracy: {test_results['accuracy']:.2f}")
except Exception as e:
    print(f"Error in quick test: {e}")
    traceback.print_exc()

# Run the main training loop
print("\n===== Starting Main Training Loop =====")
try:
    training_results = self_trainer.run_training_loop(
        train_dataset=train_dataset,
        test_dataset=test_dataset,
        num_iterations=3,  # Reduced number of iterations for initial testing
        pseudo_samples=50,  # Start with fewer samples
        batch_size=BATCH_SIZE
    )

    print("Saving final model...")
    cot_generator.save("./models/recot_final_model")

    print("\n===== Training Complete =====")
    if training_results:
        print(f"Final accuracy: {training_results['accuracy']:.4f}")

except Exception as e:
    print(f"Error in training loop: {e}")
    traceback.print_exc()

# Example of using the trained model
print("\n===== Example Usage =====")
try:
    example_questions = [
        "Janet's ducks lay 16 eggs per day. She eats 3 for breakfast every morning and bakes muffins for her friends every day with 4 eggs per batch. She bakes 2 batches of muffins per day. How many eggs does Janet have left each day?",
        "Roger has 5 tennis balls. He buys 2 more cans of tennis balls. Each can has 3 tennis balls. How many tennis balls does he have now?",
        "Leah had 32 chocolates and her sister had 42. If they ate 35, how many pieces do they have left in total?"
    ]

    for i, question in enumerate(example_questions):
        print(f"\nExample {i+1}: {question}")
        result = cot_generator.generate(question, reflection_module=reflection_module)
        print("Generated reasoning:")
        for j, step in enumerate(result["cot_steps"]):
            print(f"  Step {j+1}: {step}")
        print(f"Final answer: {result['final_answer']}")

except Exception as e:
    print(f"Error in example usage: {e}")
    traceback.print_exc()

Using device: mps
PyTorch version: 2.6.0
CUDA available: False

===== Initializing Components =====
Initializing tokenizer...
Loading data...


Preprocessing data: 100%|██████████| 400/400 [00:00<00:00, 14236.92it/s]
Preprocessing data: 100%|██████████| 100/100 [00:00<00:00, 10647.06it/s]


Loaded 400 training examples and 100 test examples
Initializing CoT Generator...
Loading model t5-base...
Model not found locally. Downloading t5-base...
Tokenizer downloaded and saved to ./models/t5_base_cache
Model downloaded and moved to mps
Initializing Reflection Module...
Initializing Retrieval Module...


Building exemplar bank:  25%|██▌       | 100/400 [00:00<00:00, 318.88it/s]


Initializing Refinement Module...
Initializing Reward Function...
Initializing Self-Trainer...
Model has 222,903,552 trainable parameters
