# Synthetic CoT Generation

This notebook generates synthetic chain-of-thought (CoT) solutions for problems from the MATH dataset, using DeepSeek-R1-Distill-Qwen-1.5B loaded through TransformerLens.

In [14]:
from transformer_lens import HookedTransformer
import torch
import circuitsvis as cv
import einops
from IPython.display import display
import numpy as np
from pprint import pprint
from datasets import load_dataset
import random
from tqdm import tqdm
import os
import json

In [2]:
device = torch.device(
    "mps" if torch.backends.mps.is_available() else 
    "cuda" if torch.cuda.is_available() else 
    "cpu"
)
print('Got device:', device)

Got device: cuda


Disable gradient computation for all tensors to speed up inference and save memory.

In [3]:
torch.set_grad_enabled(False)

<torch.autograd.grad_mode.set_grad_enabled at 0x7f637c91eb90>

For this notebook to work, add

```python
"deepseek-ai/DeepSeek-R1-Distill-Qwen-1.5B"
```

to the `OFFICIAL_MODEL_NAMES` list in `loading_from_pretrained.py` inside your local copy of the `TransformerLens` library.

It also appears possible to raise `n_ctx` for the `"QWenLMHeadModel"` and `"Qwen2ForCausalLM"` architectures from 2048 to 4096: the documentation mentions that the value is capped because of memory constraints, not because of a model limit.
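Raising `n_ctx` matters because the prompt and the generated tokens have to fit in the same context window. A minimal sketch of that budget arithmetic follows; `max_prompt_tokens` is a hypothetical helper (not part of TransformerLens), and the numbers mirror the `n_ctx` and `max_new_tokens` settings used in this notebook.

```python
def max_prompt_tokens(n_ctx: int, max_new_tokens: int) -> int:
    """Return how many prompt tokens fit once max_new_tokens are reserved.

    Hypothetical helper for illustration; assumes the prompt and the
    generated continuation share a single context window of n_ctx tokens.
    """
    if max_new_tokens >= n_ctx:
        raise ValueError("max_new_tokens must be smaller than n_ctx")
    return n_ctx - max_new_tokens


# With the default cap, a 1500-token generation leaves room for the prompt.
print(max_prompt_tokens(n_ctx=2048, max_new_tokens=1500))  # 548
# With the raised cap and a 3600-token generation, the prompt budget is tighter.
print(max_prompt_tokens(n_ctx=4096, max_new_tokens=3600))  # 496
```

A prompt longer than this budget would push the total sequence past `n_ctx`, so it may be worth checking the tokenized prompt length before generating.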

In [4]:
model = HookedTransformer.from_pretrained("deepseek-ai/DeepSeek-R1-Distill-Qwen-1.5B")
model = model.to(device)
model.cfg.n_ctx = 2048

Sliding Window Attention is enabled but not implemented for `sdpa`; unexpected results may be encountered.


Loaded pretrained model deepseek-ai/DeepSeek-R1-Distill-Qwen-1.5B into HookedTransformer
Moving model to device:  cuda


In [5]:
print(f"📏 Model context length: {model.cfg.n_ctx}")
print(f"🧠 Model layers: {model.cfg.n_layers}")
print(f"🔤 Vocabulary size: {model.cfg.d_vocab}")
print(f"📊 Hidden dimension: {model.cfg.d_model}")
print(f"🧩 Attention heads: {model.cfg.n_heads}")
print(f"🏷️ Model name: {model.cfg.model_name}")

📏 Model context length: 2048
🧠 Model layers: 28
🔤 Vocabulary size: 151936
📊 Hidden dimension: 1536
🧩 Attention heads: 12
🏷️ Model name: DeepSeek-R1-Distill-Qwen-1.5B


In [8]:
# Load the MATH dataset
math_dataset = load_dataset("fdyrd/math")

# Examine the dataset structure
print(f"Available splits: {list(math_dataset.keys())}")
print(f"Available fields: {list(math_dataset['train'].features.keys())}")
print(f"Number of examples in train: {len(math_dataset['train'])}")
print(f"Number of examples in test: {len(math_dataset['test'])}")

# Look at the first example to understand the format
print("\nExample from the dataset:")
example = math_dataset['train'][0]
print(f"Problem: {example['problem']}")
print(f"Level: {example['level']}")
print(f"Type: {example['type']}")
print(f"Solution: {example['solution']}")

Available splits: ['train', 'test']
Available fields: ['problem', 'level', 'type', 'solution']
Number of examples in train: 7500
Number of examples in test: 5000

Example from the dataset:
Problem: Let \[f(x) = \left\{
\begin{array}{cl} ax+3, &\text{ if }x>2, \\
x-5 &\text{ if } -2 \le x \le 2, \\
2x-b &\text{ if } x <-2.
\end{array}
\right.\]Find $a+b$ if the piecewise function is continuous (which means that its graph can be drawn without lifting your pencil from the paper).
Level: Level 5
Type: Algebra
Solution: For the piecewise function to be continuous, the cases must "meet" at $2$ and $-2$. For example, $ax+3$ and $x-5$ must be equal when $x=2$. This implies $a(2)+3=2-5$, which we solve to get $2a=-6 \Rightarrow a=-3$. Similarly, $x-5$ and $2x-b$ must be equal when $x=-2$. Substituting, we get $-2-5=2(-2)-b$, which implies $b=3$. So $a+b=-3+3=\boxed{0}$.


In [9]:
# Function to sample problems from the dataset
def sample_math_problems(dataset, n=5, level=None, problem_type=None):
    """
    Sample n problems from the dataset, optionally filtering by level or type.
    
    Args:
        dataset: The MATH dataset
        n: Number of problems to sample
        level: Optional filter for problem difficulty (e.g., "Level 1")
        problem_type: Optional filter for problem type (e.g., "Algebra")
    
    Returns:
        List of sampled problems
    """
    filtered_dataset = dataset['train']
    
    if level:
        filtered_dataset = [ex for ex in filtered_dataset if ex['level'] == level]
    
    if problem_type:
        filtered_dataset = [ex for ex in filtered_dataset if ex['type'] == problem_type]
    
    filtered_dataset = list(filtered_dataset)  # Convert to list to ensure it's a sequence
    return random.sample(filtered_dataset, min(n, len(filtered_dataset)))
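When the sampled subset needs to be reproducible without reseeding the global `random` module, a dedicated `random.Random` instance is one option. The sketch below is a hypothetical variant, `sample_with_seed`, run on toy records that only mimic the dataset's `level` and `type` fields.

```python
import random


def sample_with_seed(examples, n=5, level=None, problem_type=None, seed=0):
    """Sample up to n examples with a local RNG (hypothetical variant).

    examples is any iterable of dicts with 'level' and 'type' keys.
    """
    pool = [
        ex for ex in examples
        if (level is None or ex["level"] == level)
        and (problem_type is None or ex["type"] == problem_type)
    ]
    rng = random.Random(seed)  # local RNG, leaves the global state untouched
    return rng.sample(pool, min(n, len(pool)))


# Toy records standing in for MATH examples (not real dataset rows).
toy = [
    {"problem": f"p{i}", "level": f"Level {i % 5 + 1}", "type": "Algebra"}
    for i in range(20)
]
first = sample_with_seed(toy, n=2, level="Level 3", seed=42)
second = sample_with_seed(toy, n=2, level="Level 3", seed=42)
print(first == second)  # True: same seed, same sample
```

Passing the seed explicitly keeps the sample stable even if other cells draw random numbers in between.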

In [11]:
sampled_problems = sample_math_problems(math_dataset, n=3, level="Level 3")
print("\nSampled problems for testing:")
for i, problem in enumerate(sampled_problems):
    print(f"\nProblem {i+1}:")
    print(f"Type: {problem['type']}, Level: {problem['level']}")
    print(f"Problem statement: {problem['problem']}")


Sampled problems for testing:

Problem 1:
Type: Number Theory, Level: Level 3
Problem statement: How many different positive values of $x$ will make this statement true: there are exactly $2$ positive two-digit multiples of $x$.

Problem 2:
Type: Algebra, Level: Level 3
Problem statement: Find the value of $t$ that satisfies $\frac{1}{t+2} + \frac{2t}{t+2} - \frac{3}{t+2} = 3$.

Problem 3:
Type: Prealgebra, Level: Level 3
Problem statement: The scores on a $110$-point test were organized in the stem-and-leaf plot shown. $9 | 6$ represents $96$ points. What is the mode of the scores? \begin{tabular}{c|lllllll}
\multicolumn{8}{c}{\underline{Points on the Test}}\\
5 &0 & 0 & & & & &\\
6 &3 & & & & & &\\
7 &7 & 8 & & & & &\\
8 &2 & 6 & 7 & 9 & 9 & 9 & 9\\
9 &1 & 4 & 4 & 4 & 6 & &\\
10 &0 & 0 & 0 & & & &\\
\end{tabular}


In [12]:
# Function to generate CoT using the model
def generate_cot_for_problem(
    model: HookedTransformer, 
    problem: str, 
    temperature: float = 0.4, 
    max_new_tokens: int = 1500, 
    top_p: float = 0.92
):
    """
    Generate a chain-of-thought solution for a given math problem.
    
    Args:
        model: The HookedTransformer model
        problem: The math problem text
        temperature: The temperature for the model
        max_new_tokens: The maximum number of tokens to generate
        top_p: The top-p value for the model
    Returns:
        The generated chain-of-thought solution
    """
    prompt = f"""Solve this math problem step by step. Put your final answer in \\boxed{{}}. Problem: {problem} Solution: \n<think>\n"""
    result = model.generate(prompt, 
                            temperature=temperature,
                            max_new_tokens=max_new_tokens,
                            top_p=top_p)
    return result

In [None]:
# Select a problem
problem_text = sampled_problems[0]['problem']

# Generate CoT
cot_solution = generate_cot_for_problem(
    model, 
    problem_text, 
    temperature=0.6, 
    max_new_tokens=1500, 
    top_p=0.92
)
print("\nGenerated Chain-of-Thought solution:")
print(cot_solution)
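Since the prompt asks for the final answer in `\boxed{}`, the generated text can later be compared with the ground-truth solution by pulling out the last boxed expression. A minimal sketch, assuming a hypothetical helper `extract_boxed` that matches braces by counting:

```python
def extract_boxed(text):
    """Return the contents of the last \\boxed{...} in text, or None.

    Hypothetical helper; braces are matched by counting so that nested
    groups such as \\frac{1}{2} are kept whole.
    """
    start = text.rfind("\\boxed{")
    if start == -1:
        return None
    i = start + len("\\boxed{")
    depth = 1
    chars = []
    while i < len(text):
        ch = text[i]
        if ch == "{":
            depth += 1
        elif ch == "}":
            depth -= 1
            if depth == 0:
                return "".join(chars)
        chars.append(ch)
        i += 1
    return None  # unbalanced braces, e.g. a truncated generation


print(extract_boxed(r"So $a+b=-3+3=\boxed{0}$."))       # 0
print(extract_boxed(r"Thus $x=\boxed{\frac{1}{2}}$."))  # \frac{1}{2}
print(extract_boxed("no final answer yet"))              # None
```

String equality on the extracted answers is only a rough check, because equivalent answers can be written in different LaTeX forms.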

In [17]:
# Function to batch process multiple problems
def batch_generate_cot(
    model, 
    problems, 
    batch_size=4,  # Process this many problems in parallel
    temperature=0.6, 
    max_new_tokens=1500, 
    top_p=0.92, 
    save_every=5,
    save_path=None
):
    """
    Generate CoT solutions for multiple problems in parallel batches.
    
    Args:
        model: The HookedTransformer model
        problems: List of problem dictionaries
        batch_size: Number of problems to process in parallel
        temperature: The temperature for the model
        max_new_tokens: The maximum number of tokens to generate
        top_p: The top-p value for the model
        save_every: Save intermediate results every this many batches
        save_path: Optional path to save results
    
    Returns:
        List of dictionaries containing problems and their CoT solutions
    """
    results = []
    
    # Check if save_path exists and load existing results
    if save_path and os.path.exists(save_path):
        print(f"Loading existing results from {save_path}...")
        with open(save_path, 'r') as f:
            results = json.load(f)
        
        # Get existing problem IDs
        existing_ids = set(result["problem_id"] for result in results)
        print(f"Found {len(existing_ids)} existing problems. Continuing from where we left off.")
    else:
        existing_ids = set()
    
    # Process problems in batches
    for i in tqdm(range(0, len(problems), batch_size), desc="Processing batches"):
        batch_problems = problems[i:i+batch_size]
        
        # Filter out problems that have already been processed
        filtered_batch = []
        filtered_indices = []
        for j, problem in enumerate(batch_problems):
            problem_id = i + j
            if problem_id not in existing_ids:
                filtered_batch.append(problem)
                filtered_indices.append(j)
        
        # Skip if all problems in this batch have been processed
        if not filtered_batch:
            continue
            
        # Prepare prompts for the filtered batch
        prompts = [
            f"""Solve this math problem step by step. Put your final answer in \\boxed{{}}. Problem: {problem['problem']} Solution: \n<think>\n"""
            for problem in filtered_batch
        ]
        
        # Generate solutions for the batch in parallel
        try:
            batch_solutions = model.generate(
                prompts,
                temperature=temperature,
                max_new_tokens=max_new_tokens,
                top_p=top_p,
            )
        except Exception as e:
            print(f"Error generating solutions for batch {i}: {e}")
            continue
        
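        # Process and store results. generate() returns a bare string for a
        # single prompt, so wrap it to keep the loop uniform.
        solutions_list = [batch_solutions] if isinstance(batch_solutions, str) else list(batch_solutions)
        for j, problem, solution in zip(filtered_indices, filtered_batch, solutions_list):
            # Use the stored offset j; list.index() returns the first match
            # and would mislabel duplicate problems within a batch.
            problem_id = i + j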
            results.append({
                "problem_id": problem_id,
                "problem_text": problem['problem'],
                "problem_type": problem['type'],
                "problem_level": problem['level'],
                "ground_truth_solution": problem['solution'],
                "generated_cot": solution
            })
        
        # Save intermediate results
        if (i // batch_size) % save_every == 0 and save_path:
            print(f"Saving results to {save_path}...")
            with open(save_path, 'w') as f:
                json.dump(results, f, indent=2)
    
    # Save final results
    if save_path:
        with open(save_path, 'w') as f:
            json.dump(results, f, indent=2)
        print(f"Results saved to {save_path}")
    
    return results
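The function above rewrites the results file in place, so an interruption in the middle of a write could leave a truncated JSON file that the resume logic then fails to load. One possible hardening, sketched here with a hypothetical helper `save_json_atomic`, is to write to a temporary file and swap it in with `os.replace`:

```python
import json
import os
import tempfile


def save_json_atomic(obj, path):
    """Write obj as JSON to path via a temp file and os.replace (sketch)."""
    directory = os.path.dirname(os.path.abspath(path))
    fd, tmp_path = tempfile.mkstemp(dir=directory, suffix=".tmp")
    try:
        with os.fdopen(fd, "w", encoding="utf-8") as f:
            json.dump(obj, f, indent=2)
        os.replace(tmp_path, path)  # atomic rename on the same filesystem
    except BaseException:
        # Clean up the partial temp file, then re-raise the original error.
        if os.path.exists(tmp_path):
            os.remove(tmp_path)
        raise


# Demo in a throwaway directory.
with tempfile.TemporaryDirectory() as d:
    target = os.path.join(d, "results.json")
    save_json_atomic([{"problem_id": 0, "generated_cot": "..."}], target)
    with open(target, encoding="utf-8") as f:
        loaded = json.load(f)
    print(loaded[0]["problem_id"])  # 0
```

The temporary file is created in the target's directory so that the final rename stays on one filesystem.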

In [None]:
# Sample 1,000 training problems with a fixed seed for reproducibility
random.seed(42)
np.random.seed(42)
test_problems = sample_math_problems(math_dataset, n=1000)

temperature = 0.8
max_new_tokens = 3600
top_p = 0.92
batch_size = 1
model.cfg.n_ctx = 4096

# Generate CoT solutions for the test problems
cot_results = batch_generate_cot(
    model, 
    test_problems, 
    batch_size=batch_size,
    temperature=temperature, 
    max_new_tokens=max_new_tokens, 
    top_p=top_p, 
    save_path=f"math_cot_results_t={temperature}_mnt={max_new_tokens}_tp={top_p}.json",
    save_every=1
)
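Once the run finishes, a quick sanity check is to count how many generations actually finished within `max_new_tokens`. The sketch below assumes the model closes its reasoning with a `</think>` tag and uses a hypothetical helper, `summarize_cot_results`, on made-up records with the same keys as the saved JSON.

```python
def summarize_cot_results(results):
    """Count generations that closed </think> and contain a \\boxed answer."""
    total = len(results)
    closed = sum("</think>" in r["generated_cot"] for r in results)
    boxed = sum("\\boxed{" in r["generated_cot"] for r in results)
    return {"total": total, "closed_think": closed, "has_boxed": boxed}


# Toy records using the same keys as the saved JSON (contents are made up).
toy_results = [
    {"problem_id": 0, "generated_cot": "<think>\n2+2=4\n</think>\nThe answer is \\boxed{4}."},
    {"problem_id": 1, "generated_cot": "<think>\nLet me consider the cases"},
]
print(summarize_cot_results(toy_results))
# {'total': 2, 'closed_think': 1, 'has_boxed': 1}
```

A low `closed_think` count would suggest that many generations were cut off by the token limit.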