In [1]:
import os
import json
import random
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
from transformers import (
    AutoModelForCausalLM,
    AutoTokenizer,
    AdamW,
    get_linear_schedule_with_warmup,
    T5ForConditionalGeneration,
    T5Tokenizer
)
from sklearn.metrics.pairwise import cosine_similarity
from datasets import load_dataset
from tqdm.auto import tqdm
import re
from torch.optim import Adam
import random

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
def set_seed(seed=42):
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    os.environ['PYTHONHASHSEED'] = str(seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed_all(seed)
        torch.backends.cudnn.deterministic = True

set_seed()

In [3]:
# Device configuration - M1 Mac specific
if torch.backends.mps.is_available():
    device = torch.device("mps")
    print("Using Apple M1 MPS device")
elif torch.cuda.is_available():
    device = torch.device("cuda")
    print("Using CUDA device")
else:
    device = torch.device("cpu")
    print("Using CPU device")

Using Apple M1 MPS device


In [4]:
# Constants
MAX_LENGTH = 512
BATCH_SIZE = 4  # Reduced batch size for M1
LEARNING_RATE = 5e-5
EPOCHS = 2  # Reduced epochs for faster testing on M1
COT_PROMPT = "Let's think step by step!"
MODEL_NAME = "t5-small"  # Using smaller model for M1 compatibility

In [5]:
# Utility functions for extracting final answers and CoT steps
def extract_final_answer(answer_text):
    # GSM8K typically has the answer after "The answer is"
    matches = re.search(r"The answer is\s*[-]?\s*\$?\s*([\d,\.]+)", answer_text, re.DOTALL | re.IGNORECASE)
    if matches:
        return matches.group(1).strip()
    
    # Look for "Therefore"
    matches = re.search(r"Therefore,?\s.*?[-]?\s*\$?\s*([\d,\.]+)", answer_text, re.DOTALL | re.IGNORECASE)
    if matches:
        return matches.group(1).strip()
    
    # Look for "So"
    matches = re.search(r"So,?\s.*?[-]?\s*\$?\s*([\d,\.]+)", answer_text, re.DOTALL | re.IGNORECASE)
    if matches:
        return matches.group(1).strip()
    
    # Fallback: last numeric value in the text
    numbers = re.findall(r"\d+(?:,\d+)*(?:\.\d+)?", answer_text)
    if numbers:
        return numbers[-1].strip()
    
    return answer_text.strip().split("\n")[-1]  # Last line as fallback


def extract_cot_steps(answer_text):
    # Remove the final answer part
    final_answer_match = re.search(r"The answer is(.*?)$", answer_text, re.DOTALL)
    if final_answer_match:
        cot_text = answer_text[:final_answer_match.start()].strip()
    else:
        # If no "The answer is" pattern, assume the last line is the answer
        lines = answer_text.strip().split("\n")
        cot_text = "\n".join(lines[:-1]) if len(lines) > 1 else ""
    
    # Split into steps
    steps = [step.strip() for step in cot_text.split("\n") if step.strip()]
    return steps

# Safe device transfer function for M1 MPS
def to_device(tensor_or_module):
    """Safely move tensors or modules to the selected device"""
    if tensor_or_module is None:
        return None
    
    try:
        return tensor_or_module.to(device)
    except Exception as e:
        print(f"Warning: Could not move to {device}: {e}")
        return tensor_or_module

In [6]:
# Load and prepare GSM8K dataset
class GSM8KDataset(Dataset):
    def __init__(self, split="train", tokenizer=None, max_length=512, max_samples=None):
        self.data = load_dataset("gsm8k", "main")[split]
        if max_samples:
            # Limit the number of samples for M1 Mac to reduce memory usage
            self.data = self.data.select(range(min(max_samples, len(self.data))))
        self.tokenizer = tokenizer
        self.max_length = max_length
        self.processed_data = self.preprocess_data()
        
    def preprocess_data(self):
        processed = []
        for item in tqdm(self.data, desc="Preprocessing data"):
            question = item["question"]
            answer_with_cot = item["answer"]
            
            # Extract the CoT steps and the final answer
            final_answer = extract_final_answer(answer_with_cot)
            cot_steps = extract_cot_steps(answer_with_cot)
            
            processed.append({
                "question": question,
                "cot_steps": cot_steps,
                "final_answer": final_answer,
                "full_answer": answer_with_cot
            })
        return processed
    
    def __len__(self):
        return len(self.processed_data)
    
    def __getitem__(self, idx):
        item = self.processed_data[idx]
        
        if self.tokenizer:
            # Prepare input (question + CoT prompt)
            input_text = item["question"] + " " + COT_PROMPT
            target_text = item["full_answer"]
            
            # Handle tokenization with mps device
            try:
                inputs = self.tokenizer(
                    input_text,
                    padding="max_length",
                    truncation=True,
                    max_length=self.max_length,
                    return_tensors="pt"
                )
                
                targets = self.tokenizer(
                    target_text,
                    padding="max_length",
                    truncation=True,
                    max_length=self.max_length,
                    return_tensors="pt"
                )
                
                return {
                    "input_ids": inputs.input_ids.squeeze(),
                    "attention_mask": inputs.attention_mask.squeeze(),
                    "labels": targets.input_ids.squeeze(),
                    "raw_question": item["question"],
                    "raw_cot": item["cot_steps"],
                    "raw_answer": item["final_answer"]
                }
            except Exception as e:
                print(f"Error tokenizing item {idx}: {e}")
                # Return a simple fallback
                dummy_tensor = torch.zeros(self.max_length, dtype=torch.long)
                return {
                    "input_ids": dummy_tensor,
                    "attention_mask": dummy_tensor,
                    "labels": dummy_tensor,
                    "raw_question": item["question"],
                    "raw_cot": item["cot_steps"],
                    "raw_answer": item["final_answer"]
                }
        else:
            return item


In [7]:
class CoTGenerator:
    def __init__(self, model_name="t5-base", local_dir="./models/t5_base_cache"):
        self.model_name = model_name
        self.local_dir = local_dir
        self.model_path = os.path.join(local_dir, model_name.split('/')[-1])
        
        # Create the directory if it doesn't exist
        os.makedirs(self.local_dir, exist_ok=True)
        
        print(f"Loading model {model_name}...")
        
        # Check if model is already saved locally - look for specific files
        if os.path.exists(os.path.join(self.local_dir, "pytorch_model.bin")) and \
           os.path.exists(os.path.join(self.local_dir, "tokenizer_config.json")):
            print(f"Found existing model at {self.local_dir}. Loading locally...")
            self._load_local_model()
        else:
            print(f"Model not found locally. Downloading {model_name}...")
            self._download_model()
        
    def _download_model(self):
        try:
            # Use the download function from your script
            print(f"Downloading model {self.model_name} to {self.local_dir}...")
            
            # Download tokenizer and save it immediately
            self.tokenizer = T5Tokenizer.from_pretrained(
                self.model_name,
                cache_dir=self.local_dir,
                use_fast=True
            )
            print(f"Tokenizer downloaded and saved to {self.local_dir}")
            
            # Download model and save it immediately
            try:
                device = torch.device("cuda" if torch.cuda.is_available() else 
                                     "mps" if torch.backends.mps.is_available() else "cpu")
                self.model = T5ForConditionalGeneration.from_pretrained(
                    self.model_name,
                    cache_dir=self.local_dir,
                    low_cpu_mem_usage=True,  # Helpful for M1 Macs
                    torch_dtype=torch.float16  # Use half precision
                )
                self.model = self.model.to(device)
                print(f"Model downloaded and moved to {device}")
            except Exception as e:
                print(f"Error loading model to device: {e}")
                print("Falling back to CPU")
                self.model = T5ForConditionalGeneration.from_pretrained(
                    self.model_name,
                    cache_dir=self.local_dir
                )
                print(f"Model downloaded (CPU version)")
        except Exception as e:
            print(f"Error downloading model: {e}")
            raise e
    
    def _load_local_model(self):
        try:
            # Load locally saved tokenizer
            self.tokenizer = T5Tokenizer.from_pretrained(self.local_dir)
            
            # Load locally saved model
            try:
                device = torch.device("cuda" if torch.cuda.is_available() else 
                                     "mps" if torch.backends.mps.is_available() else "cpu")
                self.model = T5ForConditionalGeneration.from_pretrained(
                    self.local_dir,
                    torch_dtype=torch.float16
                )
                self.model = self.model.to(device)
                print(f"Model loaded from {self.local_dir} and moved to {device}")
            except Exception as e:
                print(f"Error loading model to device: {e}")
                print("Falling back to CPU")
                self.model = T5ForConditionalGeneration.from_pretrained(self.local_dir)
                print(f"Model loaded from {self.local_dir} (CPU version)")
        except Exception as e:
            print(f"Error loading local model: {e}")
            print("Will attempt to download from source...")
            self._download_model()
        
    def generate(self, question, max_length=512, cot_prompt="Let's think step by step."):
        input_text = f"{question} {cot_prompt}"
        
        try:
            inputs = self.tokenizer(
                input_text, 
                return_tensors="pt", 
                padding=True, 
                truncation=True,
                max_length=max_length
            )
            device = self.model.device
            inputs = {k: v.to(device) for k, v in inputs.items()}
            
            # For MPS compatibility, use simpler generation parameters
            with torch.no_grad():
                outputs = self.model.generate(
                    input_ids=inputs["input_ids"],
                    attention_mask=inputs["attention_mask"],
                    max_length=max_length,
                    num_return_sequences=1,
                    do_sample=True,
                    temperature=0.7
                )
            
            # Decode the generated text
            decoded_output = self.tokenizer.decode(outputs[0], skip_special_tokens=True)
            
            # Extract the CoT steps and final answer
            try:
                final_answer = self._extract_final_answer(decoded_output)
                cot_steps = self._extract_cot_steps(decoded_output)
            except Exception as e:
                print(f"Error extracting steps: {e}")
                # Fallback if extraction fails
                cot_steps = [decoded_output.strip()]
                final_answer = "Unable to extract final answer"
            
            return {
                "cot_steps": cot_steps,
                "final_answer": final_answer,
                "full_output": decoded_output
            }
        except Exception as e:
            print(f"Error in generation: {e}")
            return {
                "cot_steps": ["Error generating steps"],
                "final_answer": "Error",
                "full_output": f"Error: {str(e)}"
            }
    
    def _extract_final_answer(self, text):
        # Simple extraction: look for "The answer is" or similar patterns
        # You may need to customize this based on your model's output format
        if "The answer is" in text:
            return text.split("The answer is")[-1].strip()
        elif "Therefore," in text:
            return text.split("Therefore,")[-1].strip()
        else:
            # Return the last sentence as a fallback
            sentences = text.split('.')
            if sentences:
                return sentences[-1].strip()
            return text.strip()
    
    def _extract_cot_steps(self, text):
        # Simple extraction: split by numbering or line breaks
        # You may need to customize this based on your model's output format
        if any(f"{i}." in text for i in range(1, 10)):
            # Try to split by numbered steps
            steps = []
            for i in range(1, 10):
                pattern = f"{i}."
                next_pattern = f"{i+1}."
                if pattern in text:
                    start = text.find(pattern)
                    if next_pattern in text:
                        end = text.find(next_pattern)
                        steps.append(text[start:end].strip())
                    else:
                        # Last step or only one step
                        steps.append(text[start:].strip())
            return steps if steps else text.split("\n")
        else:
            # Fallback to splitting by newlines
            return [s.strip() for s in text.split("\n") if s.strip()]
    
    def save(self, path=None):
        save_path = path if path else self.local_dir
        try:
            # Move to CPU before saving to avoid device-specific tensors in saved model
            cpu_model = self.model.to("cpu")
            cpu_model.save_pretrained(save_path)
            self.tokenizer.save_pretrained(save_path)
            print(f"Model saved to {save_path}")
            # Move back to device
            device = torch.device("cuda" if torch.cuda.is_available() else 
                               "mps" if torch.backends.mps.is_available() else "cpu")
            self.model = self.model.to(device)
        except Exception as e:
            print(f"Error saving model: {e}")
    
    def load(self, path=None):
        load_path = path if path else self.local_dir
        try:
            self.tokenizer = T5Tokenizer.from_pretrained(load_path)
            device = torch.device("cuda" if torch.cuda.is_available() else 
                               "mps" if torch.backends.mps.is_available() else "cpu")
            self.model = T5ForConditionalGeneration.from_pretrained(load_path)
            self.model = self.model.to(device)
            print(f"Model loaded from {load_path} and moved to {device}")
        except Exception as e:
            print(f"Error loading model: {e}")

In [8]:
# 2. Reflection Module
class ReflectionModule(nn.Module):
    def __init__(self, embedding_dim=512):  # Smaller embedding for M1
        super(ReflectionModule, self).__init__()
        self.encoder = nn.TransformerEncoder(
            nn.TransformerEncoderLayer(d_model=embedding_dim, nhead=4, batch_first=True),
            num_layers=1  # Reduced layers for M1
        )
        self.fc = nn.Linear(embedding_dim, 1)
        self.tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
        
    def forward(self, step_embeddings):
        # step_embeddings shape: [batch_size, seq_len, embedding_dim]
        encoded = self.encoder(step_embeddings)
        # Take the mean over the sequence dimension
        pooled = encoded.mean(dim=1)
        # Output a scalar reward score for each step
        score = torch.sigmoid(self.fc(pooled))
        return score
    
    def evaluate_step(self, step_text, question_context):
        # Tokenize and get embeddings for the reasoning step
        combined_text = f"{question_context} Step: {step_text}"
        try:
            tokens = self.tokenizer(
                combined_text, 
                return_tensors="pt", 
                padding=True, 
                truncation=True,
                max_length=MAX_LENGTH
            )
            tokens = {k: to_device(v) for k, v in tokens.items()}
            
            with torch.no_grad():
                embeddings = self.get_embeddings(tokens)
                score = self.forward(embeddings)
            
            return score.item()
        except Exception as e:
            print(f"Error in evaluate_step: {e}")
            return 0.5  # Default neutral score
    
    def get_embeddings(self, tokens):
        # This is a placeholder - in a real implementation, you'd use the 
        # encoder model to get token embeddings
        # For simplicity, we're using random embeddings of the right shape
        batch_size = tokens["input_ids"].shape[0]
        seq_len = tokens["input_ids"].shape[1]
        return to_device(torch.randn(batch_size, seq_len, 512))  # Changed to 512 dim

In [9]:
# 3. Retrieval Module
class RetrievalModule:
    def __init__(self, embedding_model_name=MODEL_NAME):
        try:
            self.tokenizer = AutoTokenizer.from_pretrained(embedding_model_name)
        except Exception as e:
            print(f"Error loading tokenizer: {e}")
            # Fallback for M1
            self.tokenizer = None
        self.exemplar_bank = []  # Will contain (question, embedding, cot_sequence) tuples
        
    def add_exemplar(self, question, cot_sequence):
        embedding = self.compute_embedding(question)
        self.exemplar_bank.append((question, embedding, cot_sequence))
    
    def initialize_from_dataset(self, dataset, max_exemplars=100):  # Reduced for M1
        """Initialize the retrieval module with examples from a dataset"""
        for i, item in enumerate(tqdm(dataset, desc="Building exemplar bank")):
            if i >= max_exemplars:
                break
            try:
                self.add_exemplar(item["raw_question"], item["raw_cot"])
            except Exception as e:
                print(f"Error adding exemplar {i}: {e}")
    
    def compute_embedding(self, text):
        # For M1 compatibility, use a simplified embedding approach
        try:
            if self.tokenizer:
                # Tokenize and encode the text
                tokens = self.tokenizer(
                    text, 
                    padding=True, 
                    truncation=True, 
                    max_length=MAX_LENGTH, 
                    return_tensors="pt"
                )
                
                # Generate a deterministic but simplified embedding based on token IDs
                token_ids = tokens["input_ids"].numpy().flatten()
                # Use a simple hash-based approach for embeddings
                embedding = np.zeros(512)  # Reduced dimension for M1
                for i, token_id in enumerate(token_ids):
                    if i < 512:
                        embedding[i % 512] += token_id
                
                # Normalize
                norm = np.linalg.norm(embedding)
                if norm > 0:
                    embedding = embedding / norm
                
                return embedding
            else:
                return np.random.randn(512)  # Fallback
        except Exception as e:
            print(f"Error computing embedding: {e}")
            return np.random.randn(512)
    
    def retrieve_similar_exemplars(self, question, k=2):  # Reduced k for M1
        query_embedding = self.compute_embedding(question)
        
        # Compute similarities
        similarities = []
        for _, exemplar_embedding, _ in self.exemplar_bank:
            try:
                sim = cosine_similarity([query_embedding], [exemplar_embedding])[0][0]
                similarities.append(sim)
            except Exception as e:
                print(f"Error computing similarity: {e}")
                similarities.append(0.0)
        
        # Get top-k indices
        if not similarities:
            return []
        
        top_indices = np.argsort(similarities)[-k:][::-1]
        
        # Return the corresponding exemplars
        return [self.exemplar_bank[i][2] for i in top_indices]

In [10]:
# 4. Transformer-based Refinement Module (replacing the GAN module)
class TextRefinementTransformer(nn.Module):
    def __init__(self, model_name=MODEL_NAME):
        super(TextRefinementTransformer, self).__init__()
        try:
            self.tokenizer = T5Tokenizer.from_pretrained(model_name)
            self.model = T5ForConditionalGeneration.from_pretrained(model_name)
            self.model = to_device(self.model)
        except Exception as e:
            print(f"Error initializing refinement module: {e}")
            # Fallback - dummy model
            self.tokenizer = None
            self.model = None
    
    def forward(self, input_ids, attention_mask, labels=None):
        if self.model is None:
            return None
        try:
            return self.model(
                input_ids=input_ids,
                attention_mask=attention_mask,
                labels=labels
            )
        except Exception as e:
            print(f"Error in forward pass: {e}")
            return None
    
    def refine_text(self, input_text, max_length=MAX_LENGTH):
        if self.model is None or self.tokenizer is None:
            return input_text  # Fallback: return original text
            
        try:
            inputs = self.tokenizer(
                input_text, 
                return_tensors="pt", 
                padding=True, 
                truncation=True,
                max_length=max_length
            )
            inputs = {k: to_device(v) for k, v in inputs.items()}
            
            with torch.no_grad():  # No grad for inference
                outputs = self.model.generate(
                    input_ids=inputs["input_ids"],
                    attention_mask=inputs["attention_mask"],
                    max_length=max_length,
                    num_return_sequences=1,
                    do_sample=True,
                    temperature=0.7
                )
            
            return self.tokenizer.decode(outputs[0], skip_special_tokens=True)
        except Exception as e:
            print(f"Error refining text: {e}")
            return input_text  # Fallback
    
    def train_step(self, batch, optimizer):
        if self.model is None:
            return 0.0  # Fallback
            
        try:
            self.train()
            optimizer.zero_grad()
            
            # Move batch to device
            device_batch = {k: to_device(v) for k, v in batch.items() if isinstance(v, torch.Tensor)}
            
            outputs = self.forward(
                input_ids=device_batch["input_ids"],
                attention_mask=device_batch["attention_mask"],
                labels=device_batch["labels"]
            )
            
            if outputs is None:
                return 0.0
                
            loss = outputs.loss
            loss.backward()
            optimizer.step()
            
            return loss.item()
        except Exception as e:
            print(f"Error in train_step: {e}")
            return 0.0

In [11]:
class RewardFunction:
    def __init__(self, reflection_module):
        self.reflection_module = reflection_module
        
    def outcome_reward(self, predicted_answer, true_answer):
        """Improved reward based on partial matching and numeric comparison"""
        try:
            # Clean up answers - remove $, commas, and whitespace
            pred_clean = re.sub(r'[$,\s]', '', predicted_answer).strip()
            true_clean = re.sub(r'[$,\s]', '', true_answer).strip()
            
            # Exact match
            if pred_clean == true_clean:
                return 1.0
                
            # Both are numbers - check if they're close
            if pred_clean.replace('.', '', 1).isdigit() and true_clean.replace('.', '', 1).isdigit():
                try:
                    pred_num = float(pred_clean)
                    true_num = float(true_clean)
                    
                    # If they're within 1% of each other
                    if abs(pred_num - true_num) / max(abs(true_num), 1) < 0.01:
                        return 0.9
                        
                    # If they're within 10% of each other
                    if abs(pred_num - true_num) / max(abs(true_num), 1) < 0.1:
                        return 0.7
                except:
                    pass
            
            # Partial string match - useful for non-numeric answers
            if pred_clean in true_clean or true_clean in pred_clean:
                return 0.8
                
            # No match
            return 0.0
        except Exception as e:
            print(f"Error in outcome_reward: {e}")
            return 0.0
    
    def process_reward(self, cot_steps, question):
        """Improved reward based on the quality of the reasoning process"""
        if not cot_steps:
            return 0.2  # Small baseline reward even with no steps
        
        try:
            # Evaluate each step and take the mean
            step_scores = []
            for step in cot_steps:
                score = self.reflection_module.evaluate_step(step, question)
                step_scores.append(score)
            
            # More generous length bonus
            length_bonus = min(len(step_scores) / 3, 1.0)  # Bonus for 3+ steps, capped at 1.0
            
            # Check for numerical values in steps (good sign for math problems)
            has_numbers = any(bool(re.search(r'\d+', step)) for step in cot_steps)
            number_bonus = 0.2 if has_numbers else 0.0
            
            # Return the mean score with bonuses
            base_score = sum(step_scores) / len(step_scores) if step_scores else 0.3
            return base_score * (1.0 + 0.5 * length_bonus + number_bonus)
        except Exception as e:
            print(f"Error in process_reward: {e}")
            return 0.3  # More generous fallback
    
    def combined_reward(self, cot_steps, predicted_answer, true_answer, question, alpha=0.6):
        """Combine process and outcome rewards with detailed logging"""
        try:
            outcome = self.outcome_reward(predicted_answer, true_answer)
            process = self.process_reward(cot_steps, question)
            
            # More weight on the process for early training
            combined = alpha * outcome + (1 - alpha) * process
            
            # Debug info
            if random.random() < 0.1:  # Only print for 10% of examples to avoid log flooding
                print(f"\nQuestion: {question[:50]}...")
                print(f"Predicted: {predicted_answer}")
                print(f"True: {true_answer}")
                print(f"Outcome reward: {outcome:.2f}, Process reward: {process:.2f}, Combined: {combined:.2f}")
                if cot_steps:
                    print(f"First reasoning step: {cot_steps[0][:50]}...")
            
            return combined
        except Exception as e:
            print(f"Error in combined_reward: {e}")
            return 0.3  # More generous fallback


In [12]:
# 6. Self-Training and Distillation Loop
class SelfTrainer:
    def __init__(
        self, 
        cot_generator,
        reflection_module,
        retrieval_module,
        refinement_module,
        reward_function,
        tokenizer=None
    ):
        self.cot_generator = cot_generator
        self.reflection_module = reflection_module
        self.retrieval_module = retrieval_module
        self.refinement_module = refinement_module
        self.reward_function = reward_function
        try:
            self.tokenizer = tokenizer or AutoTokenizer.from_pretrained(MODEL_NAME)
        except Exception as e:
            print(f"Error loading tokenizer: {e}")
            self.tokenizer = None
        
    def generate_pseudo_labels(self, dataset, threshold=0.4, max_samples=50):  # Lower threshold!
        """Generate high-quality pseudo-labeled examples with better logging"""
        pseudo_labeled = []
        rewards = []
        
        # Limit the number of samples for M1
        sample_count = min(max_samples, len(dataset))
        sample_indices = random.sample(range(len(dataset)), sample_count)
        
        for i in tqdm(sample_indices, desc="Generating pseudo-labels"):
            try:
                item = dataset[i]
                question = item["raw_question"]
                gold_answer = item["raw_answer"]
                
                # Generate CoT with the current model
                generated = self.cot_generator.generate(question)
                cot_steps = generated["cot_steps"]
                predicted_answer = generated["final_answer"]
                
                # Calculate reward
                reward = self.reward_function.combined_reward(
                    cot_steps, predicted_answer, gold_answer, question
                )
                rewards.append(reward)
                
                # If the reward is above threshold, add to pseudo-labeled data
                if reward >= threshold:
                    pseudo_labeled.append({
                        "question": question,
                        "cot_steps": cot_steps,
                        "final_answer": predicted_answer,
                        "reward": reward
                    })
            except Exception as e:
                print(f"Error generating pseudo-label for item {i}: {e}")
        
        # Print reward statistics
        if rewards:
            print(f"Reward stats - Min: {min(rewards):.2f}, Max: {max(rewards):.2f}, Avg: {sum(rewards)/len(rewards):.2f}")
            print(f"Rewards histogram: {[(r, sum(1 for x in rewards if r-0.1 <= x < r+0.1)) for r in [0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9]]}")
        
        # If we still don't have any examples, take the top 20% regardless of threshold
        if not pseudo_labeled and rewards:
            top_k = max(1, int(0.2 * len(rewards)))  # At least 1, up to 20% of examples
            top_indices = sorted(range(len(rewards)), key=lambda i: rewards[i], reverse=True)[:top_k]
            
            print(f"Forcing inclusion of top {top_k} examples regardless of threshold")
            for i in top_indices:
                if i < len(sample_indices):  # Make sure index is valid
                    idx = sample_indices[i]
                    if idx < len(dataset):  # Double check index
                        try:
                            item = dataset[idx]
                            question = item["raw_question"]
                            generated = self.cot_generator.generate(question)
                            
                            pseudo_labeled.append({
                                "question": question,
                                "cot_steps": generated["cot_steps"],
                                "final_answer": generated["final_answer"],
                                "reward": rewards[i]
                            })
                        except Exception as e:
                            print(f"Error adding forced example {i}: {e}")
        
        return pseudo_labeled
    
    def train_refinement_module(self, dataset, epochs=2):  # Reduced epochs for M1
        """Train the refinement module to improve CoT quality"""
        if self.refinement_module.model is None:
            print("Refinement module not initialized, skipping training")
            return
            
        # Create a smaller dataset for M1
        subset_size = min(100, len(dataset))  # Limit size for M1
        subset_indices = random.sample(range(len(dataset)), subset_size)
        subset = [dataset[i] for i in subset_indices]
        
        dataloader = DataLoader(subset, batch_size=BATCH_SIZE, shuffle=True)
        
        try:
            optimizer = AdamW(self.refinement_module.parameters(), lr=LEARNING_RATE)
            
            for epoch in range(epochs):
                total_loss = 0
                batch_count = 0
                
                for batch in tqdm(dataloader, desc=f"Training refinement module - Epoch {epoch+1}/{epochs}"):
                    try:
                        loss = self.refinement_module.train_step(batch, optimizer)
                        total_loss += loss
                        batch_count += 1
                    except Exception as e:
                        print(f"Error processing batch: {e}")
                        continue
                    
                if batch_count > 0:
                    avg_loss = total_loss / batch_count
                    print(f"Epoch {epoch+1}/{epochs}, Average Loss: {avg_loss:.4f}")
                else:
                    print(f"Epoch {epoch+1}/{epochs}, No valid batches processed")
        except Exception as e:
            print(f"Error in train_refinement_module: {e}")
    
    def train_loop(self, train_dataset, val_dataset, epochs=2):
        """Improved main training loop with better logging and fallbacks"""
        try:
            # Initialize retrieval module with a small subset of examples
            self.retrieval_module.initialize_from_dataset(train_dataset, max_exemplars=50)
            
            # Track best validation score for model saving
            best_val_reward = 0.0
            
            for epoch in range(epochs):
                print(f"\n=== Starting epoch {epoch+1}/{epochs} ===")
                
                # 1. Generate pseudo-labeled data (with lower threshold for M1)
                pseudo_labeled = self.generate_pseudo_labels(train_dataset, threshold=0.4, max_samples=50)
                print(f"Generated {len(pseudo_labeled)} pseudo-labeled examples")
                
                # Display top examples if available
                if pseudo_labeled:
                    top_example = max(pseudo_labeled, key=lambda x: x["reward"])
                    print(f"\nTop example (reward: {top_example['reward']:.2f}):")
                    print(f"Question: {top_example['question'][:100]}...")
                    print(f"Answer: {top_example['final_answer']}")
                    print(f"First step: {top_example['cot_steps'][0] if top_example['cot_steps'] else 'No steps'}")
                
                # 3. Validate on a small subset of validation set
                val_subset_size = min(20, len(val_dataset))  # Limit for M1
                val_subset_indices = random.sample(range(len(val_dataset)), val_subset_size)
                
                val_rewards = []
                correct = 0
                total = 0
                
                for i in tqdm(val_subset_indices, desc="Validating"):
                    try:
                        item = val_dataset[i]
                        question = item["raw_question"]
                        gold_answer = item["raw_answer"]
                        
                        # Generate CoT
                        generated = self.cot_generator.generate(question)
                        
                        # Calculate reward
                        reward = self.reward_function.combined_reward(
                            generated["cot_steps"], generated["final_answer"], gold_answer, question
                        )
                        val_rewards.append(reward)
                        
                        # Check if answer is correct (exact match)
                        pred_clean = re.sub(r'[$,\s]', '', generated["final_answer"]).strip()
                        true_clean = re.sub(r'[$,\s]', '', gold_answer).strip()
                        if pred_clean == true_clean:
                            correct += 1
                        total += 1
                    except Exception as e:
                        print(f"Error validating item {i}: {e}")
                
                if val_rewards:
                    avg_reward = sum(val_rewards) / len(val_rewards)
                    accuracy = correct / total if total > 0 else 0
                    print(f"Epoch {epoch+1}/{epochs}, Validation Reward: {avg_reward:.4f}, Accuracy: {accuracy:.4f} ({correct}/{total})")
                    
                    # Save best model
                    if avg_reward > best_val_reward:
                        best_val_reward = avg_reward
                        print(f"New best validation reward: {best_val_reward:.4f}, saving model...")
                        self.cot_generator.save("./model_output/best_cot_generator")
                else:
                    print(f"Epoch {epoch+1}/{epochs}, No valid validation rewards calculated")
                
            print(f"\nTraining complete. Best validation reward: {best_val_reward:.4f}")
            return best_val_reward
        except Exception as e:
            print(f"Error in train_loop: {e}")
            return 0.0

In [13]:
# Training and Evaluation Functions
def train_model(train_data, val_data):
    try:
        # Initialize all components with error handling
        print("Initializing model components...")
        
        cot_generator = CoTGenerator()
        print("CoT Generator initialized")
        
        reflection_module = ReflectionModule()
        reflection_module = to_device(reflection_module)
        print("Reflection Module initialized")
        
        retrieval_module = RetrievalModule()
        print("Retrieval Module initialized")
        
        refinement_module = TextRefinementTransformer()
        refinement_module = to_device(refinement_module)
        print("Refinement Module initialized")
        
        reward_function = RewardFunction(reflection_module)
        print("Reward Function initialized")
        
        # Create the self-trainer
        trainer = SelfTrainer(
            cot_generator=cot_generator,
            reflection_module=reflection_module,
            retrieval_module=retrieval_module,
            refinement_module=refinement_module,
            reward_function=reward_function
        )
        print("Self-Trainer initialized")
        
        # Train the model
        trainer.train_loop(train_data, val_data, epochs=EPOCHS)
        
        # Save the trained generator model
        print("Saving model...")
        os.makedirs("./model_output", exist_ok=True)
        cot_generator.save("./model_output/cot_generator_model")
        print("Model saved to ./model_output/cot_generator_model")
        
        return cot_generator
    except Exception as e:
        print(f"Error in train_model: {e}")
        # Return a default model anyway so the script can continue
        return CoTGenerator()

def evaluate_model(model, test_data, max_samples=20):  # Limited for M1
    try:
        correct = 0
        total = 0
        
        # Limit test samples for M1
        sample_indices = random.sample(range(len(test_data)), min(max_samples, len(test_data)))
        
        for i in tqdm(sample_indices, desc="Evaluating"):
            try:
                item = test_data[i]
                question = item["raw_question"]
                gold_answer = item["raw_answer"]
                
                generated = model.generate(question)
                predicted_answer = generated["final_answer"]
                
                # Simple exact match evaluation
                if predicted_answer.strip() == gold_answer.strip():
                    correct += 1
                total += 1
            except Exception as e:
                print(f"Error evaluating item {i}: {e}")
        
        accuracy = correct / total if total > 0 else 0
        print(f"Evaluation Accuracy: {accuracy:.4f} ({correct}/{total})")
        return accuracy
    except Exception as e:
        print(f"Error in evaluate_model: {e}")
        return 0.0

In [14]:
# print("Starting CoT reasoning framework for M1 Mac...")
# print(f"PyTorch version: {torch.__version__}")
# print(f"MPS available: {torch.backends.mps.is_available()}")
# print(f"Using device: {device}")

# print("Loading GSM8K dataset...")

# # Try loading tokenizer with error handling
# try:
#     tokenizer = T5Tokenizer.from_pretrained(MODEL_NAME)
#     print("Tokenizer loaded successfully")
# except Exception as e:
#     print(f"Error loading tokenizer: {e}")
#     print("Proceeding without tokenizer")
#     tokenizer = None

# # Load datasets with limited samples for M1
# try:
#     train_dataset = GSM8KDataset("train", tokenizer, max_samples=200)  # Limited for M1
#     val_dataset = GSM8KDataset("test", tokenizer, max_samples=50)  # Limited for M1
    
#     print(f"Loaded {len(train_dataset)} training examples and {len(val_dataset)} validation examples")
# except Exception as e:
#     print(f"Error loading datasets: {e}")
#     print("Using dummy datasets for testing")
#     # Create dummy datasets for testing
#     from torch.utils.data import TensorDataset
#     dummy_data = torch.zeros(10, MAX_LENGTH, dtype=torch.long)
#     train_dataset = TensorDataset(dummy_data, dummy_data, dummy_data)
#     val_dataset = TensorDataset(dummy_data, dummy_data, dummy_data)

# # Train the model
# print("Training model...")
# trained_model = train_model(train_dataset, val_dataset)

# # Evaluate the model
# print("Evaluating model...")
# evaluate_model(trained_model, val_dataset)

# # Example inference
# example_question = "Janet's ducks lay 16 eggs per day. She eats 3 for breakfast every morning and sells the rest at the farmers' market for $2 per egg. How much money does she make in a week?"

# print("\nTesting with example question:")
# print(f"Question: {example_question}")

# try:
#     # Generate answer with trained model
#     result = trained_model.generate(example_question)
    
#     print("\nGenerated Chain-of-Thought:")
#     for i, step in enumerate(result["cot_steps"]):
#         print(f"Step {i+1}: {step}")
    
#     print(f"\nFinal Answer: {result['final_answer']}")
# except Exception as e:
#     print(f"Error in example inference: {e}")

# print("\nCoT reasoning framework execution completed!")


In [15]:


print("Starting the training and evaluation process...")

# Set up tokenizer
tokenizer = T5Tokenizer.from_pretrained(MODEL_NAME)

# Create datasets with limited size
print("Creating datasets...")
train_dataset = GSM8KDataset("train", tokenizer, MAX_LENGTH, max_samples=200)
val_dataset = GSM8KDataset("test", tokenizer, MAX_LENGTH, max_samples=50)
print(f"Loaded {len(train_dataset)} training examples and {len(val_dataset)} validation examples")

# Initialize all components
print("Initializing model components...")
cot_generator = CoTGenerator()
print("CoT Generator initialized")

reflection_module = ReflectionModule()
reflection_module = to_device(reflection_module)
print("Reflection Module initialized")

retrieval_module = RetrievalModule()
print("Retrieval Module initialized")

refinement_module = TextRefinementTransformer()
refinement_module = to_device(refinement_module)
print("Refinement Module initialized")

# Use improved reward function
reward_function = RewardFunction(reflection_module)
print("Reward Function initialized")

# Create the self-trainer
trainer = SelfTrainer(
    cot_generator=cot_generator,
    reflection_module=reflection_module,
    retrieval_module=retrieval_module,
    refinement_module=refinement_module,
    reward_function=reward_function
)
print("Self-Trainer initialized")

# Train the model
print("Training model...")
best_reward = trainer.train_loop(train_dataset, val_dataset, epochs=EPOCHS)

# Save the trained generator model
print("Saving final model...")
os.makedirs("./model_output", exist_ok=True)
cot_generator.save("./model_output/final_cot_generator")
print("Model saved to ./model_output/final_cot_generator")

# Evaluate
print("\nFinal evaluation...")
accuracy = evaluate_model(cot_generator, val_dataset, max_samples=20)

print(f"\nTraining complete!")
print(f"Best validation reward: {best_reward:.4f}")
print(f"Final evaluation accuracy: {accuracy:.4f}")


Starting the training and evaluation process...


You are using the default legacy behaviour of the <class 'transformers.models.t5.tokenization_t5.T5Tokenizer'>. This is expected, and simply means that the `legacy` (previous) behavior will be used so nothing changes for you. If you want to use the new behaviour, set `legacy=False`. This should only be set if you understand what it means, and thoroughly read the reason why this was added as explained in https://github.com/huggingface/transformers/pull/24565


Creating datasets...


Preprocessing data: 100%|██████████| 200/200 [00:00<00:00, 20597.67it/s]
Preprocessing data: 100%|██████████| 50/50 [00:00<00:00, 11382.10it/s]


Loaded 200 training examples and 50 validation examples
Initializing model components...
Loading model t5-base...
Model not found locally. Downloading t5-base...
Downloading model t5-base to ./models/t5_base_cache...
Tokenizer downloaded and saved to ./models/t5_base_cache
Model downloaded and moved to mps
CoT Generator initialized
Reflection Module initialized
Retrieval Module initialized
Refinement Module initialized
Reward Function initialized
Self-Trainer initialized
Training model...


Building exemplar bank:  25%|██▌       | 50/200 [00:00<00:00, 301.62it/s]



=== Starting epoch 1/2 ===


Generating pseudo-labels:   6%|▌         | 3/50 [00:14<04:01,  5.15s/it]


Question: Albert is wondering how much pizza he can eat in o...
Predicted: What are the pieces he can eat in one day?
True: 48
Outcome reward: 0.00, Process reward: 0.66, Combined: 0.26
First reasoning step: Albert is a pizza person. Albert wants to eat 16 s...


Generating pseudo-labels:  18%|█▊        | 9/50 [00:25<01:26,  2.12s/it]


Question: Yesterday, David and William were invited to a par...
Predicted: 
True: 8
Outcome reward: 0.80, Process reward: 0.65, Combined: 0.74
First reasoning step: David broke 2 glasses, while his friend William br...


Generating pseudo-labels:  46%|████▌     | 23/50 [01:04<01:22,  3.07s/it]


Question: There are 16 people at a dinner party. There are 4...
Predicted: 
True: 24
Outcome reward: 0.80, Process reward: 0.64, Combined: 0.74
First reasoning step: 40 dinner rolls are available for them. each. The ...


Generating pseudo-labels:  68%|██████▊   | 34/50 [01:30<00:33,  2.06s/it]


Question: A concert ticket costs $40. Mr. Benson bought 12 t...
Predicted: How about you?
True: 40
Outcome reward: 0.00, Process reward: 0.65, Combined: 0.26
First reasoning step: 5. How about you?...


Generating pseudo-labels:  76%|███████▌  | 38/50 [01:39<00:25,  2.15s/it]


Question: Anna goes trick-or-treating in a subdivision where...
Predicted: How many more pieces of candy does Anna get?
True: 15
Outcome reward: 0.00, Process reward: 0.70, Combined: 0.28
First reasoning step: Her brother Billy goes trick-or-treating in a neig...


Generating pseudo-labels: 100%|██████████| 50/50 [02:13<00:00,  2.67s/it]


Reward stats - Min: 0.21, Max: 0.87, Avg: 0.59
Rewards histogram: [(0.1, 0), (0.2, 16), (0.3, 16), (0.4, 0), (0.5, 0), (0.6, 2), (0.7, 32), (0.8, 32), (0.9, 2)]
Generated 34 pseudo-labeled examples

Top example (reward: 0.87):
Question: There is very little car traffic on Happy Street. During the week, most cars pass it on Tuesday - 25...
Answer: 
First step: On Wednesday, 2 more cars than on Monday. On Thursday, 9 cars. On Friday, 10 cars. On the weekend, 5 cars.


Validating:  10%|█         | 2/20 [00:06<00:57,  3.17s/it]


Question: John plans to sell all his toys and use the money ...
Predicted: 
True: 2
Outcome reward: 0.80, Process reward: 0.66, Combined: 0.74
First reasoning step: John sells his toys. he will sell everything he ha...


Validating: 100%|██████████| 20/20 [00:47<00:00,  2.38s/it]


Epoch 1/2, Validation Reward: 0.6423, Accuracy: 0.0500 (1/20)
New best validation reward: 0.6423, saving model...
Model saved to ./model_output/best_cot_generator

=== Starting epoch 2/2 ===


Generating pseudo-labels:   6%|▌         | 3/50 [00:07<01:55,  2.45s/it]


Question: Mary bought 5 boxes of drinks at $6 each box and 1...
Predicted: 
True: 200
Outcome reward: 0.80, Process reward: 0.64, Combined: 0.74
First reasoning step: Mary paid $200 for everything she buys. Mary bough...


Generating pseudo-labels:  16%|█▌        | 8/50 [00:19<01:40,  2.38s/it]


Question: Tim's cat bit him.  He decided to get himself and ...
Predicted: 
True: 300
Outcome reward: 0.80, Process reward: 0.62, Combined: 0.73
First reasoning step: 5. his cat bit him. his cat insurance covered $60....


Generating pseudo-labels:  38%|███▊      | 19/50 [00:49<01:51,  3.60s/it]


Question: Bella bought stamps at the post office. Some of th...
Predicted: 
True: 38
Outcome reward: 0.80, Process reward: 0.63, Combined: 0.73
First reasoning step: Bella walked down the hall with her stamps. She bo...


Generating pseudo-labels:  52%|█████▏    | 26/50 [01:14<01:13,  3.08s/it]


Question: Silvia’s bakery is offering 10% on advanced orders...
Predicted: 
True: 2
Outcome reward: 0.80, Process reward: 0.61, Combined: 0.72
First reasoning step: Silvia’s bakery is offering 10% on advanced orders...


Generating pseudo-labels:  80%|████████  | 40/50 [02:03<00:35,  3.58s/it]


Question: Lilah's family gallery has 400 photos. On a two-da...
Predicted: 
True: 920
Outcome reward: 0.80, Process reward: 0.65, Combined: 0.74
First reasoning step: Lilah's family gallery has 400 photos. the family ...


Generating pseudo-labels: 100%|██████████| 50/50 [02:26<00:00,  2.93s/it]



Question: Larry spends half an hour twice a day walking and ...
Predicted: 
True: 72
Outcome reward: 0.80, Process reward: 0.56, Combined: 0.70
First reasoning step: Larry is a dedicated dog trainer.. Larry spends an...
Reward stats - Min: 0.21, Max: 0.82, Avg: 0.63
Rewards histogram: [(0.1, 0), (0.2, 11), (0.3, 11), (0.4, 0), (0.5, 0), (0.6, 3), (0.7, 38), (0.8, 36), (0.9, 1)]
Generated 39 pseudo-labeled examples

Top example (reward: 0.82):
Question: Mark has a garden with flowers. He planted plants of three different colors in it. Ten of them are y...
Answer: 
First step: He has a garden with flowers. Mark's garden is filled with flowers. He planted plants of three different colors in it. He planted about ten yellow flowers. The yellow flowers are mostly green. There are also green flowers.


Validating:  20%|██        | 4/20 [00:10<00:40,  2.53s/it]


Question: Two trains leave San Rafael at the same time. They...
Predicted: What's the distance covered?
True: 150
Outcome reward: 0.00, Process reward: 0.62, Combined: 0.25
First reasoning step: the next day, they travel northward. The next day,...


Validating:  35%|███▌      | 7/20 [00:18<00:37,  2.87s/it]


Question: Cynthia eats one serving of ice cream every night....
Predicted: 
True: 60
Outcome reward: 0.80, Process reward: 0.72, Combined: 0.77
First reasoning step: 1.00 per serving of ice cream. Cynthia's budget is...


Validating:  40%|████      | 8/20 [00:20<00:30,  2.54s/it]


Question: Mike plays ping pong for 40 minutes.  In the first...
Predicted: 
True: 4
Outcome reward: 0.80, Process reward: 0.66, Combined: 0.74
First reasoning step: In the second 20 minutes, he scores 4 points. 6 po...


Validating:  70%|███████   | 14/20 [00:33<00:12,  2.08s/it]


Question: James decides to run 3 sprints 3 times a week.  He...
Predicted: 
True: 9
Outcome reward: 0.80, Process reward: 0.65, Combined: 0.74
First reasoning step: . He runs 60 meters each sprint..? 668 meters each...


Validating:  75%|███████▌  | 15/20 [00:35<00:10,  2.11s/it]


Question: Siobhan has 2 fewer jewels than Aaron. Aaron has 5...
Predicted: 
True: 23
Outcome reward: 0.80, Process reward: 0.64, Combined: 0.73
First reasoning step: Aaron has 5 more jewels than Aaron's.. Aaron has 5...


Validating: 100%|██████████| 20/20 [00:48<00:00,  2.40s/it]


Epoch 2/2, Validation Reward: 0.6159, Accuracy: 0.0000 (0/20)

Training complete. Best validation reward: 0.6423
Saving final model...
Model saved to ./model_output/final_cot_generator
Model saved to ./model_output/final_cot_generator

Final evaluation...


Evaluating: 100%|██████████| 20/20 [00:47<00:00,  2.36s/it]

Evaluation Accuracy: 0.0000 (0/20)

Training complete!
Best validation reward: 0.6423
Final evaluation accuracy: 0.0000



