In [14]:
import os
import json
import random
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
from transformers import (
    AutoModelForCausalLM,
    AutoModelForSeq2SeqLM,
    AutoTokenizer,
    AdamW,
    get_linear_schedule_with_warmup,
    T5ForConditionalGeneration,
    T5Tokenizer
)
from sklearn.metrics.pairwise import cosine_similarity
from datasets import load_dataset
from tqdm.auto import tqdm
import re
from torch.optim import Adam
import random
import traceback

In [3]:
def set_seed(seed=11):
    """Seed every random-number source used in this notebook.

    Seeds Python's ``random``, NumPy's legacy global generator and torch, then
    exports ``PYTHONHASHSEED``. When CUDA is available it additionally seeds all
    CUDA devices and switches cuDNN to deterministic kernels.

    Note: setting ``PYTHONHASHSEED`` at runtime only affects processes started
    afterwards; the running interpreter's hash seed is fixed at startup.
    """
    for seeder in (random.seed, np.random.seed, torch.manual_seed):
        seeder(seed)
    os.environ['PYTHONHASHSEED'] = str(seed)

    if not torch.cuda.is_available():
        return
    torch.cuda.manual_seed_all(seed)
    torch.backends.cudnn.deterministic = True

set_seed()

In [4]:
# Device configuration: prefer Apple-silicon MPS, then CUDA, then fall back to CPU.
def _pick_device():
    """Return the preferred torch device and print which one was chosen."""
    if torch.backends.mps.is_available():
        print("Using Apple M1 MPS device")
        return torch.device("mps")
    if torch.cuda.is_available():
        print("Using CUDA device")
        return torch.device("cuda")
    print("Using CPU device")
    return torch.device("cpu")


device = _pick_device()

Using Apple M1 MPS device


In [5]:
# --- Model / prompting -------------------------------------------------------
MODEL_NAME = "t5-base"
COT_PROMPT = "Let's solve this step-by-step. To find the answer, I'll break down the problem into smaller parts."
MAX_LENGTH = 768    # per-example token budget; GSM8KDataset gives 1/3 of its max_length to the prompt, 2/3 to the target
MAX_SAMPLES = 400   # presumably the GSM8KDataset(max_samples=...) cap for quick runs - TODO confirm at call site

# --- Optimisation (presumably for the training loop; not used in the cells above) ---
BATCH_SIZE = 4
LEARNING_RATE = 2e-5
EPOCHS = 3

In [9]:
# Utility functions for extracting final answers and CoT steps
def extract_final_answer(answer_text):
    """Pull the final numeric answer out of a worked solution.

    Markers are tried in priority order:
      1. GSM8K's canonical ``#### <answer>`` line;
      2. explicit phrases: "The final answer is", "The answer is";
      3. the first number after a concluding word ("Therefore", "So", "Thus");
      4. otherwise the last number anywhere in the text.

    Numbers keep thousands separators and decimals (e.g. "1,250.50"), and
    trailing sentence punctuation is never included. A minus sign is kept only
    when it directly precedes the number and does not follow a word character
    or ")". So "-5" is read as -5, while "5-3" yields 3.

    Args:
        answer_text: Solution text (may span multiple lines).

    Returns:
        The answer as a string. If the text contains no number at all, its
        stripped last line is returned instead ("" for empty input).
    """
    # (sign)(digits). Requiring a leading digit stops a bare "." or "," from being
    # returned as the answer, and ",\d{3}" / "\.\d+" stop trailing punctuation from
    # being swallowed.
    number = r"((?<![\w)])-)?\$?\s*(\d+(?:,\d{3})*(?:\.\d+)?)"
    patterns = [
        r"####\s*" + number,
        r"The final answer is\s*" + number,
        r"The answer is\s*" + number,
        # \b so that e.g. the "so" inside "also" does not count as a conclusion word.
        r"\bTherefore,?\s.*?" + number,
        r"\bSo,?\s.*?" + number,
        r"\bThus,?\s.*?" + number,
    ]

    for pattern in patterns:
        match = re.search(pattern, answer_text, re.DOTALL | re.IGNORECASE)
        if match:
            return (match.group(1) or "") + match.group(2)

    # No marker: fall back to the last number in the text.
    numbers = re.findall(number, answer_text)
    if numbers:
        sign, digits = numbers[-1]
        return sign + digits

    # Last resort fallback
    return answer_text.strip().split("\n")[-1]

def extract_cot_steps(answer_text):
    """Split a worked solution into its chain-of-thought step lines.

    Final-answer phrases ("Therefore, the answer is ...", "The answer is ...",
    "Final answer: ...") are cut from the line that contains them through the
    end of that line. A remaining line is kept as a step when it contains an
    arithmetic symbol (+ - * / =) or one of the whole words "because",
    "means", "so" or "thus". GSM8K calculator annotations such as
    ``<<48/2=24>>`` are removed from kept lines.

    Args:
        answer_text: Solution text, one step per line.

    Returns:
        List of cleaned, non-empty step strings in their original order.
    """
    # Longest phrase first: otherwise "The answer is" matches inside
    # "Therefore, the answer is" and the longer pattern can never apply.
    final_answer_patterns = [
        r"Therefore, the answer is.*$",
        r"The answer is.*$",
        r"Final answer:.*$",
    ]
    for pattern in final_answer_patterns:
        answer_text = re.sub(pattern, '', answer_text, flags=re.IGNORECASE | re.MULTILINE)

    # \b around the keywords: without it "so" also fires on "some", "sold", "also", ...
    step_marker = re.compile(r'[+\-*/=]|\b(?:because|means|so|thus)\b', re.IGNORECASE)

    steps = []
    for line in answer_text.split('\n'):
        if not step_marker.search(line):
            continue
        # Remove inline calculator markers like <<48/2=24>>
        clean_line = re.sub(r'<<.*?>>', '', line).strip()
        if clean_line:
            steps.append(clean_line)

    return steps

def test_cot_steps():
    """Print extract_cot_steps() output for two GSM8K-style samples (visual check only)."""
    samples = (
        "Natalia sold 48/2 = <<48/2=24>>24 clips in May.\n"
        "Natalia sold 48+24 = <<48+24=72>>72 clips altogether in April and May.\n"
        "The answer is 72.",
        "Weng earns 12/60 = $<<12/60=0.2>>0.2 per minute.\n"
        "Working 50 minutes, she earned 0.2 x 50 = $<<0.2*50=10>>10.\n"
        "The answer is 10.",
    )

    for sample in samples:
        print(f"Original:\n{sample}")
        print(f"\nExtracted Steps:\n{extract_cot_steps(sample)}")
        print("\n---\n")

# Best-effort device transfer helper (Apple MPS, CUDA, or CPU).
def to_device(tensor_or_module, target_device=None):
    """Move a tensor or module to a device without ever raising.

    Args:
        tensor_or_module: Any object with a ``.to(device)`` method (tensor,
            ``nn.Module``, tokenizer ``BatchEncoding``, ...), or ``None``.
        target_device: Device to move to. When omitted, Apple MPS is preferred,
            then CUDA, then CPU. This matches the notebook's device cell.

    Returns:
        The moved object. ``None`` input returns ``None``. If ``.to()`` raises,
        a warning is printed and the original object is returned unchanged.
    """
    if tensor_or_module is None:
        return None

    if target_device is None:
        if torch.backends.mps.is_available():
            target_device = torch.device("mps")
        elif torch.cuda.is_available():
            target_device = torch.device("cuda")
        else:
            target_device = torch.device("cpu")

    try:
        return tensor_or_module.to(target_device)
    except Exception as e:
        # Deliberately non-fatal: e.g. a dtype the backend does not support.
        print(f"Warning: Could not move to {target_device}: {e}")
        return tensor_or_module

In [10]:
# Visual sanity check: print the steps extract_cot_steps() pulls out of two hand-written samples.
test_cot_steps()

Original:
Natalia sold 48/2 = <<48/2=24>>24 clips in May.
Natalia sold 48+24 = <<48+24=72>>72 clips altogether in April and May.
The answer is 72.

Extracted Steps:
['Natalia sold 48/2 = 24 clips in May.', 'Natalia sold 48+24 = 72 clips altogether in April and May.']

---

Original:
Weng earns 12/60 = $<<12/60=0.2>>0.2 per minute.
Working 50 minutes, she earned 0.2 x 50 = $<<0.2*50=10>>10.
The answer is 10.

Extracted Steps:
['Weng earns 12/60 = $0.2 per minute.', 'Working 50 minutes, she earned 0.2 x 50 = $10.']

---



In [12]:
class GSM8KDataset(Dataset):
    """GSM8K word problems, optionally tokenized for T5-style seq2seq training.

    Every raw example is preprocessed once into a dict holding the question,
    a prompt-wrapped question, the extracted CoT steps, the extracted final
    answer and the untouched reference solution.

    Args:
        split: Dataset split to load (e.g. "train", "test").
        tokenizer: Hugging Face tokenizer. When falsy, ``__getitem__``
            returns the preprocessed dict instead of tensors.
        max_length: Total token budget per example. One third goes to the
            prompt (``input_ids``) and two thirds to the target (``labels``).
        max_samples: If truthy, keep only the first ``max_samples`` examples.
        mask_label_padding: Replace padding ids in ``labels`` with -100 so
            Hugging Face seq2seq losses ignore them.
    """

    # Label value that Hugging Face seq2seq models exclude from the loss.
    LABEL_IGNORE_INDEX = -100

    def __init__(self, split="train", tokenizer=None, max_length=768, max_samples=None,
                 mask_label_padding=True):
        self.data = load_dataset("gsm8k", "main")[split]
        if max_samples:
            self.data = self.data.select(range(min(max_samples, len(self.data))))
        self.tokenizer = tokenizer
        self.max_length = max_length
        self.mask_label_padding = mask_label_padding
        self.processed_data = self.preprocess_data()

    def preprocess_data(self):
        """Turn every raw GSM8K record into the dict served by ``__getitem__``.

        Returns:
            List of dicts with keys ``question``, ``formatted_question``,
            ``cot_steps``, ``final_answer`` and ``full_answer``.
        """
        processed = []
        for item in tqdm(self.data, desc="Preprocessing data"):
            question = item["question"]
            answer_with_cot = item["answer"]

            final_answer = extract_final_answer(answer_with_cot)
            cot_steps = extract_cot_steps(answer_with_cot)

            # Prompt wording nudges the model towards step-by-step output.
            formatted_question = f"Solve this math problem step-by-step: {question} {COT_PROMPT}"

            processed.append({
                "question": question,
                "formatted_question": formatted_question,
                "cot_steps": cot_steps,
                "final_answer": final_answer,
                "full_answer": answer_with_cot,
            })
        return processed

    def __len__(self):
        return len(self.processed_data)

    def _split_lengths(self):
        """Return (prompt_length, target_length): 1/3 and 2/3 of ``max_length``."""
        return self.max_length // 3, self.max_length * 2 // 3

    def __getitem__(self, idx):
        """Return tokenized tensors plus raw fields, or the raw dict if no tokenizer."""
        item = self.processed_data[idx]
        if not self.tokenizer:
            return item

        input_length, target_length = self._split_lengths()
        raw_fields = {
            "raw_question": item["question"],
            "raw_cot": item["cot_steps"],
            "raw_answer": item["final_answer"],
        }

        try:
            inputs = self.tokenizer(
                item["formatted_question"],
                padding="max_length",
                truncation=True,
                max_length=input_length,
                return_tensors="pt",
            )
            targets = self.tokenizer(
                item["full_answer"],
                padding="max_length",
                truncation=True,
                max_length=target_length,  # reasoning needs more room than the prompt
                return_tensors="pt",
            )

            # squeeze(0) removes only the batch dim added by return_tensors="pt";
            # a bare squeeze() would also collapse a length-1 sequence.
            labels = targets.input_ids.squeeze(0)
            pad_id = getattr(self.tokenizer, "pad_token_id", None)
            if self.mask_label_padding and pad_id is not None:
                labels = labels.masked_fill(labels == pad_id, self.LABEL_IGNORE_INDEX)

            return {
                "input_ids": inputs.input_ids.squeeze(0),
                "attention_mask": inputs.attention_mask.squeeze(0),
                "labels": labels,
                **raw_fields,
            }
        except Exception as e:
            print(f"Error tokenizing item {idx}: {e}")
            # Placeholder shaped exactly like a real example (prompt and target
            # lengths differ) so default collation can still stack it into a batch.
            label_fill = self.LABEL_IGNORE_INDEX if self.mask_label_padding else 0
            return {
                "input_ids": torch.zeros(input_length, dtype=torch.long),
                "attention_mask": torch.zeros(input_length, dtype=torch.long),
                "labels": torch.full((target_length,), label_fill, dtype=torch.long),
                **raw_fields,
            }

In [None]:
class EnhancedCoTGenerator:
    """Step-by-step chain-of-thought generator built on a seq2seq (T5-family) model.

    ``generate`` re-prompts the model once per step with the question and all
    previous steps. It stops at the first step that looks final, or after
    ``max_steps`` steps. A simple in-memory bank of past CoT sequences is kept
    for later retrieval.

    Args:
        model_name: Hugging Face model id or local path; T5 / Flan-T5 expected.
        local_dir: Cache directory for downloaded tokenizer/model files
            (created if missing).
        max_steps: Maximum number of steps generated per question.
        torch_dtype: Optional dtype override for the model weights. ``None``
            keeps the default: float16 on MPS/CUDA, float32 on CPU.
    """

    # Substrings (checked on the lower-cased step) that mark a step as final.
    _FINAL_STEP_PHRASES = ('answer is', 'final answer', 'thus', 'therefore',
                           'the answer', 'so the answer', 'so,', 'equals')

    def __init__(self,
                 model_name="t5-base",
                 local_dir="./models/t5_base_cache",
                 max_steps=8,
                 torch_dtype=None):

        self.model_name = model_name
        self.local_dir = local_dir
        self.max_steps = max_steps
        self.torch_dtype = torch_dtype

        os.makedirs(self.local_dir, exist_ok=True)

        self._load_model()

        # Past CoT sequences in insertion order (oldest first).
        self.memory_bank = []

    @staticmethod
    def _select_device():
        """Pick Apple MPS, then CUDA, then CPU (same priority as the notebook's device cell)."""
        if torch.backends.mps.is_available():
            return torch.device("mps")
        if torch.cuda.is_available():
            return torch.device("cuda")
        return torch.device("cpu")

    def _load_model(self):
        """Load ``self.tokenizer`` and ``self.model`` and move the model to the selected device.

        Raises:
            Exception: any loading error is printed and re-raised.
        """
        try:
            device = self._select_device()

            self.tokenizer = AutoTokenizer.from_pretrained(
                self.model_name,
                cache_dir=self.local_dir,
                use_fast=True
            )

            dtype = self.torch_dtype
            if dtype is None:
                # NOTE(review): T5 checkpoints have a history of float16 overflow
                # issues; pass torch_dtype=torch.float32 if generations look broken.
                dtype = torch.float16 if device.type != "cpu" else torch.float32

            # T5 is encoder-decoder, so it needs the Seq2SeqLM auto class
            # (AutoModelForCausalLM rejects T5Config).
            self.model = AutoModelForSeq2SeqLM.from_pretrained(
                self.model_name,
                cache_dir=self.local_dir,
                torch_dtype=dtype,
                low_cpu_mem_usage=True
            )

            self.model = self.model.to(device)
            print(f"Model loaded and moved to {device}")

            # Any "flan-t5" name also contains "t5", so one substring check suffices.
            if "t5" not in self.model_name.lower():
                print("Warning: This generator is optimized for T5 or Flan-T5 models. Consider using t5-base, t5-large, flan-t5-base, or flan-t5-large for better mathematical reasoning.")

        except Exception as e:
            print(f"Error loading model: {e}")
            raise

    def _create_prompt(self, question, previous_steps=None):
        """Build the prompt: task prefix + question, then numbered prior steps and a "Next step:" cue."""
        # T5 models work best with task-specific prefixes
        prompt = f"Solve step by step: {question}"

        if previous_steps:
            steps_text = " ".join(f"Step {i}: {step}" for i, step in enumerate(previous_steps, 1))
            prompt = f"{prompt} {steps_text} Next step:"

        return prompt

    def _extract_next_step(self, generated_text, previous_steps):
        """Strip the generated text and flag whether it looks like the final step.

        ``previous_steps`` is currently unused; it is kept for interface stability.

        Returns:
            Dict with ``step`` (stripped text) and ``is_final_step`` (bool).
        """
        clean_text = generated_text.strip()
        lowered = clean_text.lower()

        # NOTE(review): any "= <number>" counts as final, so an intermediate
        # calculation such as "48/2 = 24" also ends generation - confirm intended.
        is_final_step = (
            any(phrase in lowered for phrase in self._FINAL_STEP_PHRASES)
            or re.search(r'=\s*-?\d+(\.\d+)?', clean_text) is not None
        )

        return {
            "step": clean_text,
            "is_final_step": is_final_step
        }

    @staticmethod
    def _answer_from_step(step):
        """Number after the first "=", else a trailing number, else the whole step text."""
        match = re.search(r'=\s*(-?\d+(?:\.\d+)?)', step)
        if match:
            return match.group(1)
        match = re.search(r'(-?\d+(?:\.\d+)?)\s*$', step)
        return match.group(1) if match else step

    def generate(self, question, reflection_module=None):
        """Generate a chain of thought for ``question``, one model call per step.

        Args:
            question: Problem text.
            reflection_module: Optional object with ``evaluate_step(step, question)``
                and ``refine_step(question, step, previous_steps)``. A step that
                fails evaluation is replaced by its refinement.

        Returns:
            Dict with ``cot_steps`` (list of step strings), ``final_answer``
            (string) and ``full_output`` (steps joined by newlines). An error
            while generating a step is printed and ends generation early.
        """
        cot_steps = []
        final_answer = None

        for step_num in range(self.max_steps):
            prompt = self._create_prompt(question, cot_steps)

            try:
                inputs = self.tokenizer(
                    prompt,
                    return_tensors="pt",
                    padding=True,
                    truncation=True,
                    max_length=512
                ).to(self.model.device)

                with torch.no_grad():
                    outputs = self.model.generate(
                        input_ids=inputs.input_ids,
                        attention_mask=inputs.attention_mask,
                        max_length=150,  # decoder output budget for a single step
                        num_return_sequences=1,
                        do_sample=True,
                        temperature=0.7,
                        top_p=0.9,
                        top_k=50
                    )

                generated_text = self.tokenizer.decode(outputs[0], skip_special_tokens=True)

                step_result = self._extract_next_step(generated_text, cot_steps)
                current_step = step_result["step"]

                if reflection_module and not reflection_module.evaluate_step(current_step, question):
                    current_step = reflection_module.refine_step(question, current_step, cot_steps)

                cot_steps.append(current_step)

                if step_result["is_final_step"]:
                    final_answer = self._answer_from_step(current_step)
                    break

            except Exception as e:
                print(f"Error generating step {step_num + 1}: {e}")
                break

        full_output = "\n".join(cot_steps)

        # No step was recognised as final: fall back to parsing the whole output.
        if final_answer is None:
            final_answer = extract_final_answer(full_output)

        return {
            "cot_steps": cot_steps,
            "final_answer": final_answer,
            "full_output": full_output
        }

    def add_to_memory_bank(self, cot_sequence):
        """Append a CoT sequence to the in-memory bank."""
        self.memory_bank.append(cot_sequence)

    def retrieve_similar_cot(self, question, top_k=3):
        """Return up to ``top_k`` stored CoT sequences as a new list.

        Placeholder implementation: ``question`` is ignored and the most recent
        entries are returned, oldest first. ``top_k <= 0`` returns ``[]``.
        """
        # TODO: rank by embedding similarity to `question` once an embedding model is added.
        if top_k <= 0:
            return []
        return self.memory_bank[-top_k:]


In [16]:
def test_cot_generator():
    """Smoke-test EnhancedCoTGenerator on a few hand-written word problems.

    Loads the model, generates a chain of thought for each problem, and prints
    the numbered steps plus the extracted final answer. Any exception is
    reported with its traceback instead of propagating.
    """
    try:
        problems = (
            "Natalia sold clips in May. If she sold 48/2 clips in May, how many clips did she sell?",
            "Weng earns money per minute. If she earns 12/60 dollars per minute and works 50 minutes, how much does she earn?",
            "Betty wants to save $100. Her grandparents gave her $30. She already has $50. How much more does she need to save?",
            "Maila wants to read a 120-page book. She read 12 pages today and will read half the remaining pages tomorrow. How many pages will she read tomorrow?",
        )

        generator = EnhancedCoTGenerator(
            model_name="t5-base",
            local_dir="./models/t5-base-cache"
        )

        rule = "=" * 50
        for number, problem in enumerate(problems, 1):
            print(f"\n{rule}")
            print(f"Problem {number}: {problem}")
            print(rule)

            result = generator.generate(problem)

            print("\nChain of Thought Steps:")
            for idx, step in enumerate(result['cot_steps'], 1):
                print(f"{idx}. {step}")

            print(f"\nFinal Answer: {result['final_answer']}")
            print(f"{rule}\n")

    except Exception:
        print("Error running test:")
        traceback.print_exc()

In [17]:
# Smoke-test the step-by-step generator on the sample problems (loads the model first).
test_cot_generator()

Error loading model: Unrecognized configuration class <class 'transformers.models.t5.configuration_t5.T5Config'> for this kind of AutoModel: AutoModelForCausalLM.
Model type should be one of AriaTextConfig, BambaConfig, BartConfig, BertConfig, BertGenerationConfig, BigBirdConfig, BigBirdPegasusConfig, BioGptConfig, BlenderbotConfig, BlenderbotSmallConfig, BloomConfig, CamembertConfig, LlamaConfig, CodeGenConfig, CohereConfig, Cohere2Config, CpmAntConfig, CTRLConfig, Data2VecTextConfig, DbrxConfig, DiffLlamaConfig, ElectraConfig, Emu3Config, ErnieConfig, FalconConfig, FalconMambaConfig, FuyuConfig, GemmaConfig, Gemma2Config, GitConfig, GlmConfig, GotOcr2Config, GPT2Config, GPT2Config, GPTBigCodeConfig, GPTNeoConfig, GPTNeoXConfig, GPTNeoXJapaneseConfig, GPTJConfig, GraniteConfig, GraniteMoeConfig, GraniteMoeSharedConfig, HeliumConfig, JambaConfig, JetMoeConfig, LlamaConfig, MambaConfig, Mamba2Config, MarianConfig, MBartConfig, MegaConfig, MegatronBertConfig, MistralConfig, MixtralConfig

Traceback (most recent call last):
  File "/var/folders/_2/xx5z8xdj6j98wh4vt2b59jz80000gp/T/ipykernel_19783/2416326621.py", line 12, in test_cot_generator
    generator = EnhancedCoTGenerator(
  File "/var/folders/_2/xx5z8xdj6j98wh4vt2b59jz80000gp/T/ipykernel_19783/1004135450.py", line 13, in __init__
    self._load_model()
  File "/var/folders/_2/xx5z8xdj6j98wh4vt2b59jz80000gp/T/ipykernel_19783/1004135450.py", line 31, in _load_model
    self.model = AutoModelForCausalLM.from_pretrained(
  File "/Users/Viku/GitHub/ReCoT/.venv/lib/python3.10/site-packages/transformers/models/auto/auto_factory.py", line 567, in from_pretrained
    raise ValueError(
ValueError: Unrecognized configuration class <class 'transformers.models.t5.configuration_t5.T5Config'> for this kind of AutoModel: AutoModelForCausalLM.
Model type should be one of AriaTextConfig, BambaConfig, BartConfig, BertConfig, BertGenerationConfig, BigBirdConfig, BigBirdPegasusConfig, BioGptConfig, BlenderbotConfig, BlenderbotSmallCon

In [None]:
# Re-run of the smoke test. The traceback shown for the previous run came from an
# older version of the class that used AutoModelForCausalLM, which rejects T5 configs.
test_cot_generator()

You are using the default legacy behaviour of the <class 'transformers.models.t5.tokenization_t5.T5Tokenizer'>. This is expected, and simply means that the `legacy` (previous) behavior will be used so nothing changes for you. If you want to use the new behaviour, set `legacy=False`. This should only be set if you understand what it means, and thoroughly read the reason why this was added as explained in https://github.com/huggingface/transformers/pull/24565
