In [1]:
import os
import json
import random
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
from transformers import (
    AutoModel,
    AutoModelForCausalLM,
    AutoTokenizer,
    get_linear_schedule_with_warmup,
    T5ForConditionalGeneration,
    T5Tokenizer
)
from sklearn.metrics.pairwise import cosine_similarity
from datasets import load_dataset
from tqdm.auto import tqdm
import re
from torch.optim import AdamW
import random
import traceback
from sentence_transformers import SentenceTransformer
import numpy as np
import faiss
import random
from tqdm import tqdm

  from .autonotebook import tqdm as notebook_tqdm


In [20]:
%pip install faiss

huggingface/tokenizers: The current process just got forked, after parallelism has already been used. Disabling parallelism to avoid deadlocks...
	- Avoid using `tokenizers` before the fork if possible
	- Explicitly set the environment variable TOKENIZERS_PARALLELISM=(true | false)


[31mERROR: Could not find a version that satisfies the requirement faiss (from versions: none)[0m[31m
[0m[31mERROR: No matching distribution found for faiss[0m[31m
[0mNote: you may need to restart the kernel to use updated packages.


In [2]:
def set_seed(seed=11):
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    os.environ['PYTHONHASHSEED'] = str(seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed_all(seed)
        torch.backends.cudnn.deterministic = True

set_seed()

In [3]:
# Device configuration - M1 Mac specific
if torch.backends.mps.is_available():
    device = torch.device("mps")
    print("Using Apple M1 MPS device")
elif torch.cuda.is_available():
    device = torch.device("cuda")
    print("Using CUDA device")
else:
    device = torch.device("cpu")
    print("Using CPU device")

Using Apple M1 MPS device


In [4]:
# Constants
MAX_LENGTH = 768
BATCH_SIZE = 4
LEARNING_RATE = 2e-5
EPOCHS = 3
MODEL_NAME = "t5-base"
MAX_SAMPLES = 400

COT_PROMPT = """Solve this step by step:
1. Understand what the problem is asking for
2. Extract the relevant information and variables
3. Choose the appropriate mathematical operations
4. Perform the calculations step by step
5. Make sure your solution answers the question
6. Double-check your work

Format your answer like this:
Step 1: [First reasoning step with clear explanation]
Step 2: [Second reasoning step with calculations shown]
...
Final Answer: [The answer with units if applicable]"""

In [5]:
def extract_final_answer(answer_text):
    """Extract the final numerical answer from text."""
    # First try to find explicit "Final Answer: X" pattern
    final_answer_match = re.search(r"Final Answer:\s*([\d\.\$]+)", answer_text, re.IGNORECASE)
    if final_answer_match:
        return final_answer_match.group(1).strip().strip('$')
    
    # Try other common patterns
    patterns = [
        r"The answer is\s*[-]?\s*\$?\s*([\d,\.]+)",
        r"Therefore,?\s.*?[-]?\s*\$?\s*([\d,\.]+)",
        r"So,?\s.*?[-]?\s*\$?\s*([\d,\.]+)",
        r"Thus,?\s.*?[-]?\s*\$?\s*([\d,\.]+)",
        r".*?([\d,\.]+)$"  # Last number in text as fallback
    ]
    
    for pattern in patterns:
        matches = re.search(pattern, answer_text, re.DOTALL | re.IGNORECASE)
        if matches:
            return matches.group(1).strip().replace(',', '')
    
    # Extract the last number if nothing else works
    numbers = re.findall(r"\d+(?:,\d+)*(?:\.\d+)?", answer_text)
    if numbers:
        return numbers[-1].strip().replace(',', '')
    
    return ""

def extract_cot_steps(answer_text):
    """Extract structured reasoning steps from text."""
    # Look for step patterns
    step_pattern = r"Step\s*(\d+):\s*(.*?)(?=Step\s*\d+:|Final Answer:|$)"
    steps = re.findall(step_pattern, answer_text, re.DOTALL | re.IGNORECASE)
    
    # If we found structured steps, use them
    if steps:
        return [step[1].strip() for step in steps]
    
    # Otherwise, try to split by lines and process
    lines = answer_text.split('\n')
    cleaned_steps = []
    
    current_step = ""
    for line in lines:
        line = line.strip()
        # Skip empty lines or final answer lines
        if not line or re.search(r"(answer|final).*?is", line, re.IGNORECASE):
            continue
        
        # Check if this is a new step indicator
        new_step_match = re.match(r'(\d+\)|\(\d+\)|[\d]+\.)', line)
        if new_step_match and current_step:
            cleaned_steps.append(current_step)
            current_step = line
        else:
            # Look for lines with mathematical operations or reasoning words
            if re.search(r'[+\-*/=]|because|means|so|thus|therefore', line, re.IGNORECASE):
                # Clean the line
                clean_line = re.sub(r'<<.*?>>', '', line).strip()
                if clean_line:
                    if current_step:
                        current_step += " " + clean_line
                    else:
                        current_step = clean_line
    
    # Add the last step if there is one
    if current_step:
        cleaned_steps.append(current_step)
    
    return cleaned_steps if cleaned_steps else [answer_text.strip()]

def test_cot_steps():
    test_cases = [
        """Natalia sold 48/2 = <<48/2=24>>24 clips in May.
Natalia sold 48+24 = <<48+24=72>>72 clips altogether in April and May.
The answer is 72.""",
        
        """Weng earns 12/60 = $<<12/60=0.2>>0.2 per minute.
Working 50 minutes, she earned 0.2 x 50 = $<<0.2*50=10>>10.
The answer is 10."""
    ]
    
    for case in test_cases:
        print("Original:")
        print(case)
        print("\nExtracted Steps:")
        print(extract_cot_steps(case))
        print("\n---\n")

# Safe device transfer function for M1 MPS
def to_device(tensor_or_module):
    """Safely move tensors or modules to the selected device"""
    if tensor_or_module is None:
        return None
    
    device = torch.device("cuda" if torch.cuda.is_available() else 
                         "mps" if torch.backends.mps.is_available() else "cpu")
    
    try:
        return tensor_or_module.to(device)
    except Exception as e:
        print(f"Warning: Could not move to {device}: {e}")
        return tensor_or_module

In [6]:
test_cot_steps()

Original:
Natalia sold 48/2 = <<48/2=24>>24 clips in May.
Natalia sold 48+24 = <<48+24=72>>72 clips altogether in April and May.
The answer is 72.

Extracted Steps:
['Natalia sold 48/2 = 24 clips in May. Natalia sold 48+24 = 72 clips altogether in April and May.']

---

Original:
Weng earns 12/60 = $<<12/60=0.2>>0.2 per minute.
Working 50 minutes, she earned 0.2 x 50 = $<<0.2*50=10>>10.
The answer is 10.

Extracted Steps:
['Weng earns 12/60 = $0.2 per minute. Working 50 minutes, she earned 0.2 x 50 = $10.']

---



In [7]:
class GSM8KDataset(Dataset):
    def __init__(self, split="train", tokenizer=None, max_length=768, max_samples=None):
        self.data = load_dataset("gsm8k", "main")[split]
        if max_samples:
            self.data = self.data.select(range(min(max_samples, len(self.data))))
        self.tokenizer = tokenizer
        self.max_length = max_length
        self.processed_data = self.preprocess_data()
        
    def preprocess_data(self):
        processed = []
        for item in tqdm(self.data, desc="Preprocessing data"):
            question = item["question"]
            answer_with_cot = item["answer"]
            
            # Extract the CoT steps and the final answer
            final_answer = extract_final_answer(answer_with_cot)
            cot_steps = extract_cot_steps(answer_with_cot)
            
            # Reformat the CoT steps in a more structured way
            formatted_cot = ""
            for i, step in enumerate(cot_steps, 1):
                formatted_cot += f"Step {i}: {step}\n"
            formatted_cot += f"Final Answer: {final_answer}"
            
            # Format question with improved CoT prompt
            formatted_question = f"{question}\n\n{COT_PROMPT}"
            
            processed.append({
                "question": question,
                "formatted_question": formatted_question,
                "formatted_cot": formatted_cot,
                "cot_steps": cot_steps,
                "final_answer": final_answer,
                "full_answer": answer_with_cot
            })
        return processed
    
    def __len__(self):
        return len(self.processed_data)
    
    def __getitem__(self, idx):
        item = self.processed_data[idx]
        
        if self.tokenizer:
            input_text = item["formatted_question"]
            target_text = item["formatted_cot"]
            
            try:
                inputs = self.tokenizer(
                    input_text,
                    padding="max_length",
                    truncation=True,
                    max_length=self.max_length // 3,
                    return_tensors="pt"
                )
                
                targets = self.tokenizer(
                    target_text,
                    padding="max_length",
                    truncation=True,
                    max_length=self.max_length * 2 // 3,
                    return_tensors="pt"
                )
                
                return {
                    "input_ids": inputs.input_ids.squeeze(),
                    "attention_mask": inputs.attention_mask.squeeze(),
                    "labels": targets.input_ids.squeeze(),
                    "raw_question": item["question"],
                    "raw_cot": item["formatted_cot"],
                    "raw_answer": item["final_answer"]
                }
            except Exception as e:
                print(f"Error tokenizing item {idx}: {e}")
                # Return a fallback
                dummy_tensor = torch.zeros(self.max_length, dtype=torch.long)
                return {
                    "input_ids": dummy_tensor,
                    "attention_mask": dummy_tensor,
                    "labels": dummy_tensor,
                    "raw_question": item["question"],
                    "raw_cot": "",
                    "raw_answer": ""
                }
        else:
            return item

In [8]:
class ReflectionModule(nn.Module):
    def __init__(self, base_model_name="distilroberta-base", device=None):
        super(ReflectionModule, self).__init__()
        
        if device is None:
            self.device = torch.device("cuda" if torch.cuda.is_available() else 
                                      "mps" if torch.backends.mps.is_available() else "cpu")
        else:
            self.device = device
            
        # Load base transformer model
        self.encoder = AutoModel.from_pretrained(base_model_name)
        self.tokenizer = AutoTokenizer.from_pretrained(base_model_name)
        
        # Define evaluation heads
        self.coherence_head = nn.Sequential(
            nn.Linear(self.encoder.config.hidden_size, 256),
            nn.ReLU(),
            nn.Dropout(0.1),
            nn.Linear(256, 1),
            nn.Sigmoid()
        )
        
        self.progress_head = nn.Sequential(
            nn.Linear(self.encoder.config.hidden_size, 256),
            nn.ReLU(),
            nn.Dropout(0.1),
            nn.Linear(256, 1),
            nn.Sigmoid()
        )
        
        self.consistency_head = nn.Sequential(
            nn.Linear(self.encoder.config.hidden_size * 2, 256),  # Takes concatenated embeddings
            nn.ReLU(),
            nn.Dropout(0.1),
            nn.Linear(256, 1),
            nn.Sigmoid()
        )
        
        self.math_patterns = {
            'equation': r'(\d+\s*[\+\-\*\/]\s*\d+\s*=\s*\d+)',
            'numerical_equality': r'(\d+\s*=\s*\d+)',
            'variable_assignment': r'([a-zA-Z]\s*=\s*\d+)',
            'arithmetic_operation': r'(\d+\s*[\+\-\*\/]\s*\d+)'
        }
        
        self.to(self.device)
        
    def _encode_text(self, text):
        # Tokenize and encode text
        inputs = self.tokenizer(text, return_tensors="pt", 
                              padding=True, truncation=True, max_length=512).to(self.device)
        
        with torch.no_grad():
            outputs = self.encoder(**inputs)
            # Use [CLS] token embedding as the representation
            return outputs.last_hidden_state[:, 0, :]
    
    def evaluate_step(self, step, question, previous_steps=None):
        # Encode current step
        step_embedding = self._encode_text(step)
        
        coherence_score = self.coherence_head(step_embedding).item()
        
        question_embedding = self._encode_text(question)
        concat_embedding = torch.cat([question_embedding, step_embedding], dim=1)
        progress_score = self.progress_head(step_embedding).item()
        
        consistency_score = 1.0
        
        if previous_steps and len(previous_steps) > 0:
            prev_text = " ".join(previous_steps)
            prev_embedding = self._encode_text(prev_text)
            
            concat_embedding = torch.cat([prev_embedding, step_embedding], dim=1)
            consistency_score = self.consistency_head(concat_embedding).item()
        
        math_validity = self._check_math_validity(step)
        
        process_reward = (coherence_score * 0.3 + 
                          consistency_score * 0.3 + 
                          progress_score * 0.3 +
                          math_validity * 0.1)
        
        return {
            "coherence": coherence_score,
            "consistency": consistency_score,
            "progress": progress_score,
            "math_validity": math_validity,
            "process_reward": process_reward
        }
    
    def _check_math_validity(self, step):
        """Check if mathematical expressions in the step are valid"""
        expressions = []
        for pattern_name, pattern in self.math_patterns.items():
            matches = re.findall(pattern, step)
            expressions.extend(matches)
        
        if not expressions:
            return 0.5  # Default value if not valid
        
        valid_count = 0
        for expr in expressions:
            try:
                eval_expr = expr.replace("=", "==")
                if eval(eval_expr):
                    valid_count += 1
            except:
                pass
        
        return valid_count / len(expressions) if expressions else 0.5
    
    def refine_step(self, question, current_step, previous_steps=None):
        """Use the reflection module to refine a step"""
        eval_results = self.evaluate_step(current_step, question, previous_steps)
        
        # threshold
        if eval_results["process_reward"] > 0.8:
            return current_step
        
        fixed_step = current_step
        
        if eval_results["math_validity"] < 0.7:
            fixed_step = self._fix_math_expressions(fixed_step)
        
        if previous_steps and eval_results["consistency"] < 0.7:
            # Extract any variables or values from previous steps
            context = self._extract_context_from_steps(previous_steps)
            fixed_step = self._ensure_consistency(fixed_step, context)
        
        return fixed_step
    
    def _fix_math_expressions(self, step):
        """Attempt to fix mathematical expressions in a step"""
        # Implement simple fixes for common math errors, think of more
        
        step = re.sub(r'(\d+)([+\-*/])(\d+)', r'\1 \2 \3', step)
        
        # Fix incorrect equals signs (= vs ==)
        step = re.sub(r'(\d+)\s*==\s*(\d+)', r'\1 = \2', step)
        
        return step
    
    def _extract_context_from_steps(self, steps):
        """Extract variables and their values from previous steps"""
        context = {}
        for step in steps:
            var_matches = re.findall(r'([a-zA-Z])\s*=\s*(\d+(?:\.\d+)?)', step)
            for var, value in var_matches:
                context[var] = float(value)
        
        return context
    
    def _ensure_consistency(self, step, context):
        """Ensure step is consistent with the extracted context"""
        for var, value in context.items():
            pattern = rf'\b{var}\b(?!\s*=)'
            step = re.sub(pattern, str(value), step)
        
        return step

In [9]:
class RetrievalModule:
    """
    Maintains a memory bank of high-quality exemplar CoT sequences and retrieves similar exemplars for any input prompt.
    """
    def __init__(self, embedding_model="sentence-transformers/all-MiniLM-L6-v2"):
        self.embedding_model = SentenceTransformer(embedding_model)
        self.memory_bank = []
        self.embeddings = None
        self.index = None
        
    def add_exemplar(self, question, cot_steps, final_answer, quality_score=1.0):
        """Add a new exemplar to the memory bank"""
        exemplar = {
            "question": question,
            "cot_steps": cot_steps,
            "final_answer": final_answer,
            "quality_score": quality_score
        }
        self.memory_bank.append(exemplar)
        
        # Rebuild index if we have enough examples
        if len(self.memory_bank) % 100 == 0:
            self._build_index()
    
    def add_batch_exemplars(self, exemplars):
        """Add multiple exemplars at once"""
        self.memory_bank.extend(exemplars)
        self._build_index()
    
    def _build_index(self):
        """Build a FAISS index for fast similarity search"""
        if not self.memory_bank:
            return
        
        # Create embeddings for all questions
        questions = [item["question"] for item in self.memory_bank]
        self.embeddings = self.embedding_model.encode(questions)
        
        # Convert to numpy array with correct dtype
        embeddings_np = np.array(self.embeddings).astype('float32')
        
        # Build FAISS index
        dimension = embeddings_np.shape[1]
        self.index = faiss.IndexFlatL2(dimension)
        self.index.add(embeddings_np)
        
        print(f"Built index with {len(self.memory_bank)} exemplars")
    
    def retrieve_similar(self, question, top_k=5):
        """Retrieve most similar exemplars for a given question"""
        if not self.memory_bank or self.index is None:
            return []
        
        # Encode the question
        question_embedding = self.embedding_model.encode([question])
        question_embedding = np.array(question_embedding).astype('float32')
        
        # Search for similar questions
        distances, indices = self.index.search(question_embedding, min(top_k, len(self.memory_bank)))
        
        # Return the corresponding exemplars
        similar_exemplars = [self.memory_bank[idx] for idx in indices[0]]
        
        return similar_exemplars
    
    def save_memory_bank(self, file_path):
        """Save the memory bank to a file"""
        with open(file_path, 'w') as f:
            json.dump(self.memory_bank, f)
    
    def load_memory_bank(self, file_path):
        """Load the memory bank from a file"""
        with open(file_path, 'r') as f:
            self.memory_bank = json.load(f)
        self._build_index()

In [10]:
class CoTLossCalculator:
    """
    Calculates various loss functions for training the CoT generator
    """
    def __init__(self, tokenizer, reflection_module=None):
        self.tokenizer = tokenizer
        self.reflection_module = reflection_module
        
        # Weights for combining loss components
        self.weights = {
            "nll_loss": 1.0,
            "consistency_loss": 0.3,
            "step_quality_loss": 0.5,
            "answer_correctness_loss": 1.0
        }
    
    def calculate_nll_loss(self, logits, labels):
        """Standard negative log-likelihood loss"""
        # Shift logits and labels for next token prediction
        shift_logits = logits[..., :-1, :].contiguous()
        shift_labels = labels[..., 1:].contiguous()
        
        # Calculate cross entropy loss
        loss_fct = nn.CrossEntropyLoss(ignore_index=-100)
        nll_loss = loss_fct(shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1))
        
        return nll_loss
    
    def calculate_consistency_loss(self, generated_steps, question):
        """Loss based on consistency between steps"""
        if not self.reflection_module or not generated_steps:
            return torch.tensor(0.0)
        
        total_loss = 0.0
        for i in range(1, len(generated_steps)):
            eval_result = self.reflection_module.evaluate_step(
                generated_steps[i], question, generated_steps[:i]
            )
            step_loss = 1.0 - eval_result["consistency"]
            total_loss += step_loss
        
        # Average over all steps
        return torch.tensor(total_loss / len(generated_steps) if generated_steps else 0.0)
    
    def calculate_step_quality_loss(self, generated_steps, question):
        """Loss based on the quality of each step"""
        if not self.reflection_module or not generated_steps:
            return torch.tensor(0.0)
        
        total_loss = 0.0
        for step in generated_steps:
            eval_result = self.reflection_module.evaluate_step(step, question)
            step_loss = 1.0 - eval_result["process_reward"]
            total_loss += step_loss
        
        # Average over all steps
        return torch.tensor(total_loss / len(generated_steps) if generated_steps else 0.0)
    
    def calculate_answer_correctness_loss(self, predicted_answer, correct_answer):
        """Loss based on correctness of the final answer"""
        if predicted_answer == correct_answer:
            return torch.tensor(0.0)
        else:
            return torch.tensor(1.0)
    
    def calculate_combined_loss(self, outputs, labels, generated_steps=None, 
                               question=None, predicted_answer=None, correct_answer=None):
        """Combine multiple loss functions"""
        # Calculate NLL loss
        nll_loss = self.calculate_nll_loss(outputs.logits, labels)
        losses = {"nll_loss": nll_loss}
        
        # Calculate consistency loss if data is available
        if generated_steps and question:
            consistency_loss = self.calculate_consistency_loss(generated_steps, question)
            losses["consistency_loss"] = consistency_loss
        else:
            losses["consistency_loss"] = torch.tensor(0.0)
        
        # Calculate step quality loss if data is available
        if generated_steps and question:
            step_quality_loss = self.calculate_step_quality_loss(generated_steps, question)
            losses["step_quality_loss"] = step_quality_loss
        else:
            losses["step_quality_loss"] = torch.tensor(0.0)
        
        # Calculate answer correctness loss if data is available
        if predicted_answer is not None and correct_answer is not None:
            answer_loss = self.calculate_answer_correctness_loss(predicted_answer, correct_answer)
            losses["answer_correctness_loss"] = answer_loss
        else:
            losses["answer_correctness_loss"] = torch.tensor(0.0)
        
        # Combine all losses with their weights
        combined_loss = sum(self.weights[key] * losses[key] for key in losses)
        
        return combined_loss, losses

In [25]:
class EnhancedCoTGenerator(torch.nn.Module):
    def __init__(self, 
                 model_name="TinyLlama/TinyLlama-1.1B-Chat-v1.0", 
                 local_dir="./models/tinyllama_cache",
                 max_steps=8,
                 reflection_module=None,
                 retrieval_module=None,
                 device=None,
                 load_in_8bit=False,
                 hf_token=None):
        super(EnhancedCoTGenerator, self).__init__()
        
        self.max_steps = max_steps
        self.model_name = model_name
        self.local_dir = local_dir
        self.load_in_8bit = load_in_8bit
        self.hf_token = hf_token
        
        # Initialize reflection and retrieval modules
        self.reflection_module = reflection_module
        self.retrieval_module = retrieval_module
        
        # Set device
        if device is None:
            self.device = torch.device("cuda" if torch.cuda.is_available() else 
                                       "mps" if torch.backends.mps.is_available() else "cpu")
        else:
            self.device = device
        
        # Create directory if needed
        os.makedirs(self.local_dir, exist_ok=True)
        
        print(f"Loading model {model_name}...")
        
        # Authenticate with Hugging Face if token is provided
        if self.hf_token:
            try:
                from huggingface_hub import login
                login(token=self.hf_token)
                print("Successfully logged in to Hugging Face!")
            except Exception as e:
                print(f"Authentication failed: {e}")
                print("Will try to continue, but may fail if model requires authentication.")
        
        # Check for local model first
        if os.path.exists(os.path.join(self.local_dir, "pytorch_model.bin")) and \
           os.path.exists(os.path.join(self.local_dir, "tokenizer_config.json")):
            print(f"Found existing model at {self.local_dir}. Loading locally...")
            self._load_local_model()
        else:
            print(f"Model not found locally. Downloading {model_name}...")
            self._download_model()
    
    def _download_model(self):
        try:
            # Add token to kwargs if available
            token_kwargs = {"token": self.hf_token} if self.hf_token else {}
            
            # Download tokenizer
            self.tokenizer = AutoTokenizer.from_pretrained(
                self.model_name,
                cache_dir=self.local_dir,
                use_fast=True,
                **token_kwargs
            )
            
            # Make sure padding token is set
            if self.tokenizer.pad_token is None:
                self.tokenizer.pad_token = self.tokenizer.eos_token
                
            print(f"Tokenizer downloaded and saved to {self.local_dir}")
            
            # Download model with appropriate settings for device
            model_kwargs = {
                "cache_dir": self.local_dir,
                "low_cpu_mem_usage": True,
                **token_kwargs  # Include token if available
            }
            
            if self.load_in_8bit and self.device.type == "cuda":  # Only use 8-bit for CUDA
                model_kwargs["load_in_8bit"] = True
                print("Loading model in 8-bit quantization")
            elif self.device.type == "mps" or self.device.type == "cuda":
                model_kwargs["torch_dtype"] = torch.float16
                print(f"Loading model in float16 on {self.device.type}")
            else:
                print("Loading model in default precision on CPU")
            
            try:
                self.model = AutoModelForCausalLM.from_pretrained(
                    self.model_name,
                    **model_kwargs
                )
                if not self.load_in_8bit:  # Only move to device if not 8-bit (8-bit handles device placement)
                    self.model = self.model.to(self.device)
                print(f"Model downloaded and moved to {self.device}")
            except Exception as e:
                print(f"Error loading model to device: {e}")
                print("Falling back to CPU with reduced precision")
                self.model = AutoModelForCausalLM.from_pretrained(
                    self.model_name,
                    cache_dir=self.local_dir,
                    torch_dtype=torch.float16,
                    low_cpu_mem_usage=True,
                    **token_kwargs
                )
                print(f"Model downloaded (CPU version)")
        except Exception as e:
            print(f"Error downloading model: {e}")
            raise e
    
    def _load_local_model(self):
        try:
            # Load tokenizer
            self.tokenizer = AutoTokenizer.from_pretrained(self.local_dir)
            if self.tokenizer.pad_token is None:
                self.tokenizer.pad_token = self.tokenizer.eos_token
            
            # Load model with appropriate settings for device
            model_kwargs = {
                "low_cpu_mem_usage": True,
            }
            
            if self.load_in_8bit and self.device.type == "cuda":  # Only use 8-bit for CUDA
                from transformers import BitsAndBytesConfig
                model_kwargs["quantization_config"] = BitsAndBytesConfig(load_in_8bit=True)
                print("Loading model in 8-bit quantization")
            elif self.device.type == "mps" or self.device.type == "cuda":
                model_kwargs["torch_dtype"] = torch.float16
                print(f"Loading model in float16 on {self.device.type}")
            else:
                print("Loading model in default precision on CPU")
            
            try:
                self.model = AutoModelForCausalLM.from_pretrained(
                    self.local_dir, 
                    **model_kwargs
                )
                if not self.load_in_8bit:  # Only move to device if not 8-bit
                    self.model = self.model.to(self.device)
                print(f"Model loaded from {self.local_dir} and moved to {self.device}")
            except Exception as e:
                print(f"Error loading model to device: {e}")
                print("Falling back to CPU with reduced precision")
                self.model = AutoModelForCausalLM.from_pretrained(
                    self.local_dir,
                    torch_dtype=torch.float16,
                    low_cpu_mem_usage=True
                )
                print(f"Model loaded from {self.local_dir} (CPU version)")
        except Exception as e:
            print(f"Error loading local model: {e}")
            print("Will attempt to download from source...")
            self._download_model()
    
    def forward(self, input_ids, attention_mask=None, labels=None):
        """Forward pass for training"""
        return self.model(
            input_ids=input_ids,
            attention_mask=attention_mask,
            labels=labels
        )
    
    def _enhance_prompt_with_examples(self, question):
        """Enhance the prompt with similar examples from retrieval module"""
        if not self.retrieval_module:
            return question
        
        # Retrieve similar examples
        similar_examples = self.retrieval_module.retrieve_similar(question, top_k=2)
        
        if not similar_examples:
            return question
        
        # Format examples as part of the prompt
        enhanced_prompt = "Here are some examples of how to solve similar math problems:\n\n"
        
        for i, example in enumerate(similar_examples, 1):
            enhanced_prompt += f"Example {i}:\n"
            enhanced_prompt += f"Question: {example['question']}\n"
            
            # Format the chain of thought steps
            for j, step in enumerate(example['cot_steps'], 1):
                enhanced_prompt += f"Step {j}: {step}\n"
            
            enhanced_prompt += f"Final Answer: {example['final_answer']}\n\n"
        
        # Add the original question
        enhanced_prompt += f"Now solve this problem:\n{question}"
        
        return enhanced_prompt
    
    def _refine_steps_with_reflection(self, steps, question, final_answer=None):
        """Refine generated steps using the reflection module"""
        if not self.reflection_module:
            return steps
        
        refined_steps = []
        previous_steps = []
        
        for i, step in enumerate(steps):
            # Refine each step based on previous steps and question
            refined_step = self.reflection_module.refine_step(
                question=question,
                current_step=step,
                previous_steps=previous_steps
            )
            
            refined_steps.append(refined_step)
            previous_steps.append(refined_step)
        
        return refined_steps
    
    def generate(self, question, use_reflection=True, use_retrieval=True):
        """Generate a complete chain of thought reasoning for a math problem"""
        try:
            # 1. Use retrieval module to enhance prompt if available
            if use_retrieval and self.retrieval_module:
                enhanced_question = self._enhance_prompt_with_examples(question)
            else:
                enhanced_question = question
            
            # 2. Create a structured prompt for TinyLlama model format
            # TinyLlama chat format: <|im_start|>user\n{prompt}<|im_end|>\n<|im_start|>assistant\n
            prompt = f"<|im_start|>user\nPlease solve this math problem step by step.\n\n{enhanced_question}<|im_end|>\n<|im_start|>assistant\n"
            
            # 3. Encode input
            inputs = self.tokenizer(
                prompt, 
                return_tensors="pt", 
                padding=True, 
                truncation=True,
                max_length=512
            ).to(self.device)
            
            # 4. Generate initial output
            with torch.no_grad():
                outputs = self.model.generate(
                    input_ids=inputs.input_ids,
                    attention_mask=inputs.attention_mask,
                    max_length=inputs.input_ids.shape[1] + 512,  # Allow reasonable length for math reasoning
                    do_sample=True,
                    temperature=0.7,
                    top_p=0.95,
                    top_k=40,
                    repetition_penalty=1.1,
                    num_return_sequences=1,
                    pad_token_id=self.tokenizer.pad_token_id,
                    eos_token_id=self.tokenizer.eos_token_id
                )
            
            # 5. Decode output - only decode the new tokens, not the input prompt
            input_length = inputs.input_ids.shape[1]
            generated_ids = outputs[0][input_length:]
            generated_text = self.tokenizer.decode(generated_ids, skip_special_tokens=True)
            
            # 6. Process the generated text to extract steps and final answer
            cot_steps = self._extract_cot_steps(generated_text)
            final_answer = self._extract_final_answer(generated_text)
            
            # 7. Use reflection module to refine steps if available
            if use_reflection and self.reflection_module and cot_steps:
                refined_steps = self._refine_steps_with_reflection(
                    cot_steps, 
                    question, 
                    final_answer
                )
                cot_steps = refined_steps
            
            # 8. Format the output consistently
            formatted_output = ""
            for i, step in enumerate(cot_steps, 1):
                formatted_output += f"Step {i}: {step}\n"
            
            if final_answer:
                formatted_output += f"Final Answer: {final_answer}"
            else:
                # If no final answer was extracted, regenerate it from steps
                regenerated_answer = self._regenerate_final_answer(cot_steps, question)
                formatted_output += f"Final Answer: {regenerated_answer}"
                final_answer = regenerated_answer
            
            # 9. Create result dictionary
            result = {
                "question": question,
                "cot_steps": cot_steps,
                "final_answer": final_answer,
                "full_output": formatted_output
            }
            
            # 10. Store in retrieval module if available
            if self.retrieval_module:
                # Add to retrieval module
                self.retrieval_module.add_exemplar(
                    question=question,
                    cot_steps=cot_steps,
                    final_answer=final_answer,
                    quality_score=1.0  # Default score
                )
                
            return result
                
        except Exception as e:
            print(f"Error generating solution: {e}")
            return {
                "question": question,
                "cot_steps": ["Error generating steps"],
                "final_answer": "Error",
                "full_output": f"Error: {str(e)}"
            }
    
    def _extract_cot_steps(self, text):
        """Extract chain of thought steps from generated text"""
        # Pattern matching for steps
        step_pattern = r"Step\s*\d+:?\s*(.*?)(?=Step\s*\d+:|Final Answer:|$)"
        steps = re.findall(step_pattern, text, re.DOTALL)
        
        # If no steps are found, try to extract reasoning paragraphs
        if not steps:
            # Split into lines and look for reasoning
            lines = [line.strip() for line in text.split('\n') if line.strip()]
            # Filter out short lines and final answer
            steps = [line for line in lines if len(line) > 10 and not line.startswith("Final Answer")]
        
        # Clean up each step
        steps = [step.strip() for step in steps if step.strip()]
        
        return steps
    
    def _extract_final_answer(self, text):
        """Extract final answer from generated text"""
        # Look for explicit "Final Answer:" pattern
        answer_pattern = r"Final Answer:?\s*(.*?)(?=<|$)"
        matches = re.findall(answer_pattern, text, re.DOTALL)
        
        if matches:
            return matches[0].strip()
        
        # If no explicit final answer, look at the last numerical value
        numbers = re.findall(r"\d+(?:,\d+)*(?:\.\d+)?", text)
        if numbers:
            return numbers[-1].strip()
        
        # If still no answer, take the last sentence
        sentences = text.split('.')
        if sentences:
            return sentences[-1].strip()
        
        return ""
    
    def _regenerate_final_answer(self, steps, question):
        """Regenerate final answer if extraction failed"""
        try:
            # Combine steps into a context
            context = "\n".join(steps)
            
            # Create a prompt to regenerate the answer - adjusted for TinyLlama
            prompt = f"<|im_start|>user\nBased on these steps, what is the final numerical answer to this math problem? Only provide the numerical value.\n\nQuestion: {question}\n\nSteps:\n{context}<|im_end|>\n<|im_start|>assistant\nThe final answer is "
            
            # Encode input
            inputs = self.tokenizer(
                prompt, 
                return_tensors="pt", 
                padding=True, 
                truncation=True,
                max_length=512
            ).to(self.device)
            
            # Generate output
            with torch.no_grad():
                outputs = self.model.generate(
                    input_ids=inputs.input_ids,
                    attention_mask=inputs.attention_mask,
                    max_length=inputs.input_ids.shape[1] + 50,  # Just enough for the answer
                    do_sample=False,  # Deterministic for final answer
                    num_return_sequences=1,
                    pad_token_id=self.tokenizer.pad_token_id,
                    eos_token_id=self.tokenizer.eos_token_id
                )
            
            # Decode output - only the new tokens
            input_length = inputs.input_ids.shape[1]
            generated_ids = outputs[0][input_length:]
            answer_text = self.tokenizer.decode(generated_ids, skip_special_tokens=True)
            
            # Extract numerical answer
            numbers = re.findall(r"\d+(?:,\d+)*(?:\.\d+)?", answer_text)
            if numbers:
                return numbers[-1].strip().replace(',', '')
            
            return answer_text.strip()
                
        except Exception as e:
            print(f"Error regenerating final answer: {e}")
            return "Error"
    
    def train_with_examples(self, train_dataset, val_dataset=None, 
                            epochs=3, batch_size=2, learning_rate=5e-5):
        """Train the model with examples - simplified for TinyLlama on M1 Mac"""
        # 1. Setup optimizer with parameter efficient fine-tuning
        from peft import get_peft_model, LoraConfig, TaskType
        
        # Using LoRA for efficient fine-tuning
        peft_config = LoraConfig(
            task_type=TaskType.CAUSAL_LM,
            inference_mode=False,
            r=4,  # Smaller rank for M1 Mac
            lora_alpha=16,
            lora_dropout=0.1,
            # Target appropriate modules for TinyLlama
            target_modules=["q_proj", "k_proj", "v_proj", "o_proj"]
        )
        
        # Convert to PEFT model
        self.model = get_peft_model(self.model, peft_config)
        
        # Only optimize the LoRA parameters to save memory
        optimizer = torch.optim.AdamW(self.model.parameters(), lr=learning_rate)
        
        # 2. Create dataloaders with smaller batch size for M1 Mac
        train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
        if val_dataset:
            val_loader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False)
        
        # 3. Training loop
        for epoch in range(epochs):
            # Training phase
            self.model.train()
            train_loss = 0.0
            
            for batch_idx, batch in enumerate(tqdm(train_loader, 
                                        total=len(train_loader),
                                        desc=f"Epoch {epoch+1}/{epochs} Training")):
                # Move batch to device
                input_ids = batch["input_ids"].to(self.device)
                attention_mask = batch["attention_mask"].to(self.device)
                labels = batch["labels"].to(self.device)
                
                # Prepare inputs for causal LM training
                inputs = {
                    "input_ids": input_ids,
                    "attention_mask": attention_mask,
                    "labels": labels
                }
                
                # Forward pass
                outputs = self.model(**inputs)
                loss = outputs.loss
                
                # Backward pass and optimization
                optimizer.zero_grad()
                loss.backward()
                torch.nn.utils.clip_grad_norm_(self.model.parameters(), max_norm=1.0)
                optimizer.step()
                
                train_loss += loss.item()
                
                # Log every 10 batches - more frequent for visibility
                if (batch_idx + 1) % 10 == 0:
                    avg_loss = train_loss / (batch_idx + 1)
                    print(f"Batch {batch_idx+1}/{len(train_loader)}, Loss: {avg_loss:.4f}")
            
            # Calculate average training loss
            avg_train_loss = train_loss / len(train_loader)
            print(f"Epoch {epoch+1}/{epochs}, Train Loss: {avg_train_loss:.4f}")
            
            # Validation phase
            if val_dataset:
                self.model.eval()
                val_loss = 0.0
                
                with torch.no_grad():
                    for batch in tqdm(val_loader, desc="Validation"):
                        input_ids = batch["input_ids"].to(self.device)
                        attention_mask = batch["attention_mask"].to(self.device)
                        labels = batch["labels"].to(self.device)
                        
                        outputs = self.model(
                            input_ids=input_ids,
                            attention_mask=attention_mask,
                            labels=labels
                        )
                        
                        loss = outputs.loss
                        val_loss += loss.item()
                
                avg_val_loss = val_loss / len(val_loader)
                print(f"Epoch {epoch+1}/{epochs}, Val Loss: {avg_val_loss:.4f}")
            
            # Save model after each epoch
            self.save(f"{self.local_dir}_epoch_{epoch+1}")
    
    def save(self, path=None):
        """Save the model and tokenizer"""
        save_path = path if path else self.local_dir
        try:
            # For PEFT models, save the adapter
            if hasattr(self.model, "save_pretrained") and hasattr(self.model, "config") and hasattr(self.model.config, "peft_config_id"):
                self.model.save_pretrained(save_path)
                print(f"PEFT adapter saved to {save_path}")
            else:
                # Regular save for non-PEFT models
                self.model.save_pretrained(save_path)
            
            self.tokenizer.save_pretrained(save_path)
            print(f"Model and tokenizer saved to {save_path}")
        except Exception as e:
            print(f"Error saving model: {e}")
    
    def load(self, path=None):
        """Load the model and tokenizer"""
        load_path = path if path else self.local_dir
        try:
            self.tokenizer = AutoTokenizer.from_pretrained(load_path)
            if self.tokenizer.pad_token is None:
                self.tokenizer.pad_token = self.tokenizer.eos_token
                
            # Check if this is a PEFT adapter
            from peft import PeftModel, PeftConfig
            import os
            
            # Check if it's a PEFT model by looking for adapter_config.json
            if os.path.exists(os.path.join(load_path, "adapter_config.json")):
                # Need to load base model first, then adapter
                peft_config = PeftConfig.from_pretrained(load_path)
                base_model = AutoModelForCausalLM.from_pretrained(
                    peft_config.base_model_name_or_path,
                    torch_dtype=torch.float16 if self.device.type != "cpu" else torch.float32,
                    low_cpu_mem_usage=True
                )
                self.model = PeftModel.from_pretrained(base_model, load_path)
                print(f"PEFT model loaded from {load_path}")
            else:
                # Regular model load
                self.model = AutoModelForCausalLM.from_pretrained(
                    load_path,
                    torch_dtype=torch.float16 if self.device.type != "cpu" else torch.float32,
                    low_cpu_mem_usage=True
                )
            
            self.model = self.model.to(self.device)
            print(f"Model loaded from {load_path} and moved to {self.device}")
        except Exception as e:
            print(f"Error loading model: {e}")


In [26]:
# Initialize our modules
def initialize_modules():
    # 1. Create Reflection Module
    reflection_module = ReflectionModule(base_model_name="distilroberta-base")
    
    # 2. Create Retrieval Module
    retrieval_module = RetrievalModule(embedding_model="sentence-transformers/all-MiniLM-L6-v2")
    
    # 3. Create Enhanced CoT Generator
    cot_generator = EnhancedCoTGenerator(
        reflection_module=reflection_module,
        retrieval_module=retrieval_module,
        load_in_8bit=False  # Reduced precision for memory efficiency
    )
    
    return cot_generator, reflection_module, retrieval_module

In [27]:
# Example problems for testing
test_problems = [
    "Natalia sold clips in May. If she sold 48/2 clips in May, how many clips did she sell?",
    "Weng earns money per minute. If she earns 12/60 dollars per minute and works 50 minutes, how much does she earn?",
    "Betty wants to save $100. Her grandparents gave her $30. She already has $50. How much more does she need to save?",
    "Maila wants to read a 120-page book. She read 12 pages today and will read half the remaining pages tomorrow. How many pages will she read tomorrow?"
]

# Function to test the CoT generator
def test_cot_generator(cot_generator, problems):
    results = []
    
    for i, problem in enumerate(problems, 1):
        print(f"{'='*50}")
        print(f"Problem {i}: {problem}")
        print(f"{'='*50}")
        
        result = cot_generator.generate(problem)
        
        print("Chain of Thought Steps:")
        for j, step in enumerate(result["cot_steps"], 1):
            print(f"{j}. {step}")
        print(f"Final Answer: {result['final_answer']}")
        print(f"{'='*50}")
        
        results.append(result)
    
    return results

In [28]:
print("Initializing modules...")
cot_generator, reflection_module, retrieval_module = initialize_modules()

# Pre-seed the retrieval module with some examples
print("Pre-seeding retrieval module with examples...")
gsm8k_dataset = GSM8KDataset(split="train", max_samples=50)

# Add a few examples to the retrieval module
for i, item in enumerate(gsm8k_dataset.processed_data[:10]):
    retrieval_module.add_exemplar(
        question=item["question"],
        cot_steps=item["cot_steps"],
        final_answer=item["final_answer"],
        quality_score=1.0  # Assuming these are good examples
    )

print("Testing CoT generator...")
results = test_cot_generator(cot_generator, test_problems)

# Save results
with open("cot_results.json", "w") as f:
    json.dump(results, f, indent=2)

print("Done! Results saved to cot_results.json")

Initializing modules...
Loading model TinyLlama/TinyLlama-1.1B-Chat-v1.0...
Model not found locally. Downloading TinyLlama/TinyLlama-1.1B-Chat-v1.0...
Tokenizer downloaded and saved to ./models/tinyllama_cache
Loading model in float16 on mps
Model downloaded and moved to mps
Pre-seeding retrieval module with examples...


Preprocessing data: 100%|██████████| 50/50 [00:00<00:00, 7049.25it/s]


Testing CoT generator...
Problem 1: Natalia sold clips in May. If she sold 48/2 clips in May, how many clips did she sell?
Chain of Thought Steps:
1. Sure, here's the solution:
2. - Natalia sold clips in May
3. - She sold 48 clips in May
4. - This means that she sold a total of 48/2 clips in May
5. - Therefore, she sold 24 clips in total
Final Answer: 24
Problem 2: Weng earns money per minute. If she earns 12/60 dollars per minute and works 50 minutes, how much does she earn?
Chain of Thought Steps:
1. Convert the percentage to a fraction. The percentage "per minute" is represented by the term "12/60", so we need to divide it by 60 to get a fraction. So, we have:
- Percentage: 12/60 (in decimal form)
- Fractional part: 1/60 (in decimal form)
- Fraction: 1/60 (in decimal form) / 1 (in decimal form) = 1/60
2. Multiply the result by 50 minutes. The time duration of one work hour is 8 hours (or 48 minutes), so we multiply the resulting fraction by 50 minutes to find the amount earned in to