In [1]:
# --- Input data and sampling limits ---
# JSONL file holding the seed questions that the notebook loads.
TOY_AIME_24_QUESTIONS_FILE = "ripe_dataset.jsonl"
# Upper bound on how many questions are processed; a falsy value disables the cap.
QUESTIONS_LIM = 10
# Upper bound on how many extracted concepts are processed; a falsy value disables the cap.
CONCEPT_LIM = 100
# Number of concepts / properties requested from the model in each prompt.
CONCEPTS_ASKED_PER_QUESTION = 7
PROPERTIES_ASKED_PER_CONCEPTS = 7
# Number of properties sampled as inspiration for one generated problem.
NUM_PROPERTIES_FOR_PROBLEM_GEN = 5
# Generated problems are stated over integers in [1, PROBLEM_RANDOM_INTEGER_INPUT_RANGE].
PROBLEM_RANDOM_INTEGER_INPUT_RANGE = 30
# Length of the integer list that forms a problem's solution.
PROBLEM_NUM_RANDOM_INTEGERS = 5
# Number of random input integers drawn per problem, and the arity used by the brute-force search.
MAX_PROBLEM_NUM_INTEGER_INPUT = PROBLEM_NUM_RANDOM_INTEGERS
# Attempt budget shared by every LLM retry loop.
MAX_LLM_RETRIES = 5
# NOTE(review): not referenced in the visible part of the notebook — confirm it is still needed.
NUM_CHECKS_ALL_IMPLEMENTATIONS_GEN_SAME_RESULT = 100

# Phrases substituted into the final question ("What is the <aggregator> of the list ...?").
INT_AGGREGATORS = [
    "sum of squares",
    "sum of cubes",
    "alternating sum",
    "the sum of values modulo 7",
    "the sum of values modulo 11",
]

# --- Output locations (JSONL) ---
CONCEPTS_OUTPUT_FILE = "../synth-data/syn-concepts.jsonl"
PROPERTIES_OUTPUT_FILE = "../synth-data/syn-prop.jsonl"
# NOTE(review): the two paths below are not written to in the visible part of the notebook — confirm usage.
CODE_OUTPUT_FILE = "../synth-data/syn-gen-code.jsonl"
OUTPUT_FILE = "../synth-data/syn-problems.jsonl"

In [2]:
import json
import ast
import os
import itertools
import re
import random
from dotenv import load_dotenv
from groq import Groq
# from cerebras.cloud.sdk import Cerebras

In [3]:
# Load environment variables from the local .env file into the process environment.
# NOTE(review): presumably this is where the Groq API key comes from — confirm.
load_dotenv()

True

In [4]:
def load_jsonl(file_path):
    """Read a JSON Lines file into a list of decoded records.

    Args:
        file_path: Path of the JSONL file to read.

    Returns:
        A list with one decoded JSON value per non-blank line, in file order.

    Raises:
        json.JSONDecodeError: If a non-blank line is not valid JSON.
    """
    # Read as UTF-8 explicitly so the result does not depend on the platform locale.
    with open(file_path, 'r', encoding='utf-8') as f:
        # Blank lines (typically a trailing newline at end of file) carry no
        # record and would make json.loads raise, so they are skipped.
        return [json.loads(line) for line in f if line.strip()]

def save_to_jsonl(file_path, a_list):
    """Write each item of ``a_list`` to ``file_path`` as one JSON document per line.

    The parent directory is created when the path names one and it does not
    exist yet. An existing file is overwritten. A one-line summary is printed
    once the file has been written.

    Args:
        file_path: Destination path of the JSONL file.
        a_list: Sized iterable of JSON-serialisable items.
    """
    parent_dir = os.path.dirname(file_path)
    # A bare file name has no directory component, so there is nothing to create.
    if parent_dir.strip():
        os.makedirs(parent_dir, exist_ok=True)

    with open(file_path, 'w') as out_file:
        out_file.writelines(json.dumps(record) + '\n' for record in a_list)

    print(f"Saved {len(a_list)} records to {file_path}")

### Debug Toy AIME 24 Questions

In [5]:
# Load the seed questions and keep only the first one while debugging the pipeline.
input_questions = load_jsonl(TOY_AIME_24_QUESTIONS_FILE)
data_with_answers = input_questions[:1]
# Show the retained problem text as a sanity check. (The bare `data_with_answers`
# expression that used to sit here was a no-op: only a cell's last expression is displayed.)
print(data_with_answers[0]['problem'])

A list of positive integers has the following properties:
$\bullet$ The sum of the items in the list is $30$.
$\bullet$ The unique mode of the list is $9$.
$\bullet$ The median of the list is a positive integer that does not appear in the list itself.
 Find the sum of the squares of all the items in the list. Let's think step by step and output the final answer within \boxed{}.


### Load S1 data with answers

In [6]:
# def get_questions(split="train", num_samples=None) -> Dataset:
#     data = load_dataset('simplescaling/data_ablation_full59K', split=split)
    
#     if num_samples is not None:
#         # Randomly sample the specified number of examples
#         data = data.shuffle(seed=42).select(range(min(num_samples, len(data))))
    
#     return data

In [7]:
# data_raw = get_questions()
# # data_raw = get_questions(num_samples=10) # Debug

# print(len(data_raw))

In [8]:
# data_with_answers = []
# for k in range(len(data_raw)):
#   answer = ast.literal_eval(data_raw[k]['metadata']).get('answer', None)
#   if answer is not None:
#     data_with_answers.append({
#       'problem': data_raw[k]['question'],
#       'answer': answer
#     })

# print(len(data_with_answers))

In [9]:
# print(data_raw[0].keys())
# print()

# for k in range(3):
#   print("Question:")
#   print(data_with_answers[k]['problem'])
#   print("Answer:")
#   print(data_with_answers[k]['answer'])
#   print()

### Use Inference APIs to generate math concepts used and properties of those concepts.

In [11]:
# Shared Groq API client passed to the inference helper below.
# NOTE(review): presumably picks up its API key from the environment loaded via load_dotenv() — confirm.
groq_client = Groq()

def run_groq_inference_qwq_32b(groq_client, prompt):
    """Send ``prompt`` as a single user message to ``qwen-qwq-32b`` and return the reply text.

    Args:
        groq_client: Client object exposing ``chat.completions.create``.
        prompt: Text of the user message.

    Returns:
        The ``message.content`` of the first choice of the completion.
    """
    user_message = {"role": "user", "content": prompt}
    request_options = dict(
        model="qwen-qwq-32b",
        messages=[user_message],
        temperature=0.6,
        max_completion_tokens=131072,
        top_p=0.9,
        stream=False,
        stop=None,
    )
    response = groq_client.chat.completions.create(**request_options)
    first_choice = response.choices[0]
    return first_choice.message.content

def extract_after_think(text):
    """Return the text that follows the last ``</think>`` tag, stripped of surrounding whitespace.

    Args:
        text: Raw model output, or None.

    Returns:
        The stripped text after the final ``</think>`` tag, or None when
        ``text`` is None or contains no such tag.
    """
    if text is None:
        return None
    # rpartition splits on the LAST occurrence, which is what the reasoning/answer
    # split needs; the earlier lazy regex anchored on the first tag instead.
    _, tag, tail = text.rpartition('</think>')
    if not tag:
        return None
    return tail.strip()

def _extract_last_fenced_block(text, language):
    """Return the stripped body of the last fenced block tagged ``language`` in ``text``.

    Returns an empty string when ``text`` is None/empty or holds no such block.
    """
    # Upstream extraction can yield None; treat that like "nothing found"
    # instead of letting re.findall raise TypeError.
    if not text:
        return ""

    # re.DOTALL lets the lazy body span multiple lines.
    pattern = rf"```{re.escape(language)}\s*(.*?)\s*```"
    matches = re.findall(pattern, text, re.DOTALL)
    return matches[-1].strip() if matches else ""

def extract_last_python_code(text):
    """Extract the code from the last ```python ... ``` code block in the text.

    Args:
        text: Text to search; None is accepted.

    Returns:
        The stripped block contents, or "" when no block is present.
    """
    return _extract_last_fenced_block(text, "python")

def extract_last_json_code(text):
    """Extract the code from the last ```json ... ``` code block in the text.

    Args:
        text: Text to search; None is accepted.

    Returns:
        The stripped block contents, or "" when no block is present.
    """
    return _extract_last_fenced_block(text, "json")

def replace_random_integer(a_list_of_str, orig_random_integer, new_random_integer):
  """Replace whole occurrences of one integer with another in every string.

  Args:
    a_list_of_str: Strings to rewrite.
    orig_random_integer: Integer whose occurrences are replaced.
    new_random_integer: Integer written in its place.

  Returns:
    A new list of rewritten strings, in the same order.
  """
  # Digit lookarounds keep the match from firing inside a longer number
  # (replacing 3 must not touch 13 or 30), which plain str.replace would do.
  pattern = re.compile(rf"(?<!\d){re.escape(str(orig_random_integer))}(?!\d)")
  replacement = str(new_random_integer)
  # A callable replacement is inserted verbatim (no backslash/group expansion).
  return [pattern.sub(lambda _match: replacement, a_str) for a_str in a_list_of_str]

def load_and_execute_code(code_string):
    """
    Run a string of Python source in a fresh namespace and hand that namespace back.

    NOTE(review): exec() runs the source with full interpreter privileges. This
    notebook feeds it model-generated code, so consider sandboxing before
    running on anything but a throwaway environment.

    Args:
        code_string: String containing Python code to execute

    Returns:
        Dictionary holding every name the executed code defined
    """
    definitions = dict()
    exec(code_string, definitions)
    return definitions

def find_integers_in_string(text):
    """
    Find all integers in a string using regular expressions.

    A leading minus sign is taken as part of the number only when it is not
    glued to a preceding word character, so "x = -5" yields -5 while "10-5"
    yields 10 and 5. Digits embedded in identifiers ("x7") are ignored.

    Args:
        text: String to search for integers

    Returns:
        List of integers found in the string, in order of appearance
    """
    # (?<!\w) : the number (or its sign) must not directly follow a letter,
    #           digit or underscore. The previous r'\b-?' form required the
    #           opposite for signed numbers, so " -5" lost its sign and
    #           "10-5" produced a spurious -5.
    # \b      : the number must end on a word boundary.
    pattern = r'(?<!\w)-?\d+\b'

    return [int(match) for match in re.findall(pattern, text)]

In [None]:
# Collect the distinct math concepts the model associates with each question.
all_concepts = set()
# A falsy QUESTIONS_LIM means "no limit".
num_questions = min(len(data_with_answers), QUESTIONS_LIM) if QUESTIONS_LIM else len(data_with_answers)

# Use the configured limit rather than a hard-coded 10.
for k in range(num_questions):
  prompt = f"""
  What are the math concepts used in the problem below and related math concepts? Try to list {CONCEPTS_ASKED_PER_QUESTION} concepts and if ambiguous, add their field of math in parenthesis.

  Don't solve problem, just give the math concepts used as a json list.

  Please only return the math concepts as a json list in ```json ``` tags.

  Example answers:
  ```json
  ["median (statistics)", "mode (statistics)", "mean (statistics)", "frequency (statistics)", "sum", "sum of squares", ...], ["set", "subsets", "permutations", "combinations", "probability", "coin flips", ...], ["linear equations", "systems of equations", "substitution", ...], etc.
  ```

  {data_with_answers[k]['problem']}
  """
  print("="*100)
  print("Prompt:")
  print(prompt)
  think_answer = run_groq_inference_qwq_32b(groq_client, prompt)
  # extract_after_think returns None when the reply has no </think> tag;
  # substitute "" so the JSON extraction does not fail on None.
  answer = extract_last_json_code(extract_after_think(think_answer) or "")
  print("\nAnswer:")
  print(answer)
  print()
  # The extractor signals "nothing found" with an empty string, so test truthiness.
  if answer:
    try:
      concepts = ast.literal_eval(answer) # ast doesn't error out on latex formulas.
    except (ValueError, SyntaxError) as e:
      print(f"Could not parse concepts for question {k}: {e}")
      concepts = []
    # Keep only strings: a literal such as `...` parses to Ellipsis, which
    # cannot be written out as JSON later.
    all_concepts.update(concept for concept in concepts if isinstance(concept, str))

  # Saved after every question so partial progress survives an interrupted run.
  save_to_jsonl(CONCEPTS_OUTPUT_FILE, all_concepts)


Prompt:

  What are the math concepts used in the problem below and related math concepts? Try to list 7 concepts and if ambiguous, add their field of math in parenthesis.

  Don't solve problem, just give the math concepts used as a json list.

  Please only return the math concepts as a json list in ```json ``` tags.

  Example answers:
  ```json
  ["median (statistics)", "mode (statistics)", "mean (statistics)", "frequency (statistics)", "sum", "sum of squares", ...], ["set", "subsets", "permutations", "combinations", "probability", "coin flips", ...], ["linear equations", "systems of equations", "substitution", ...], etc.
  ```

  A list of positive integers has the following properties:
$\bullet$ The sum of the items in the list is $30$.
$\bullet$ The unique mode of the list is $9$.
$\bullet$ The median of the list is a positive integer that does not appear in the list itself.
 Find the sum of the squares of all the items in the list. Let's think step by step and output the final 

In [12]:
# Ask the model for formula-based properties of every collected concept.
all_properties = set()
# Sort so that, when CONCEPT_LIM truncates the list, the processed subset does
# not depend on set iteration order.
all_concepts_list = sorted(all_concepts)
# A falsy CONCEPT_LIM means "no limit".
num_concepts = min(len(all_concepts_list), CONCEPT_LIM) if CONCEPT_LIM else len(all_concepts_list)

for k in range(num_concepts):
  prompt = f"""
  Please list key math properties leveraging a formula for the concept below. Try to list {PROPERTIES_ASKED_PER_CONCEPTS} math properties leveraging a formula.

  Just return all key math properties leveraging formulas as a json list, nothing else.
  
  Don't include a math properties that don't leverage a formula. Please only use precise formulas (no approximations).

  Please only return the math concepts as a json list in ```json ``` tags.

  Example answer for input "logarithm math properties":
  ```json
  ["log_b(a) = c if and only if b^c = a", "log_b(M*N) = log_b(M) + log_b(N) (Product Rule)", ...]
  ```

  {all_concepts_list[k]}
  """
  print("="*100)
  print("Prompt:")
  print(prompt)
  think_answer = run_groq_inference_qwq_32b(groq_client, prompt)
  # extract_after_think returns None when the reply has no </think> tag;
  # substitute "" so the JSON extraction does not fail on None.
  answer = extract_last_json_code(extract_after_think(think_answer) or "")
  print("\nAnswer:")
  print(answer)
  print()
  # The extractor signals "nothing found" with an empty string, so test truthiness.
  if answer:
    try:
      properties = ast.literal_eval(answer) # ast doesn't error out on latex formulas.
    except (ValueError, SyntaxError) as e:
      print(f"Could not parse properties for concept {all_concepts_list[k]!r}: {e}")
      properties = []
    for math_property in properties:
      # Skip non-string literals (e.g. `...` parses to Ellipsis).
      if not isinstance(math_property, str):
        continue
      # Prefix with the concept so the property is meaningful on its own.
      all_properties.add(f"Property of {all_concepts_list[k]}: {math_property}")
  
save_to_jsonl(PROPERTIES_OUTPUT_FILE, all_properties)

Prompt:

  Please list key math properties leveraging a formula for the concept below. Try to list 7 math properties leveraging a formula.

  Just return all key math properties leveraging formulas as a json list, nothing else.
  
  Don't include a math properties that don't leverage a formula. Please only use precise formulas (no approximations).

  Please only return the math concepts as a json list in ```json ``` tags.

  Example answer for input "logarithm math properties":
  ```json
  ["log_b(a) = c if and only if b^c = a", "log_b(M*N) = log_b(M) + log_b(N) (Product Rule)", ...]
  ```

  sum of squares (algebra)
  

Answer:
[
  "Sum_{k=1}^n k^2 = \\frac{n(n+1)(2n+1)}{6}",
  "Sum_{k=1}^n (2k)^2 = \\frac{2}{3}n(n+1)(2n+1)",
  "Sum_{k=1}^n (2k-1)^2 = \\frac{n(2n-1)(2n+1}{3}",
  "Sum_{i=1}^n x_i^2 = \\left(Sum_{i=1}^n x_i\\right)^2 - 2Sum_{1 \\le i < j \\le n} x_i x_j",
  "Sum_{k=0}^n \\binom{n}{k}^2 = \\binom{2n}{n}",
  "Sum_{k=0}^{n-1} (a + kd)^2 = n a^2 + 2a d \\frac{n(n-1)}{2} + d

In [102]:
def generate_problem(sampled_properties, sampled_aggregator, input_integers, print_prompt=False):
  """Ask the model to draft a problem over a list of integers with a unique solution.

  Args:
    sampled_properties: Math-property strings, joined one per line into the prompt.
    sampled_aggregator: Aggregation phrase used in the closing question.
    input_integers: Integers the problem is asked to make use of.
    print_prompt: When True, print the prompt before sending it.

  Returns:
    Whatever extract_after_think yields for the model reply.
  """
  properties_block = '\n'.join(sampled_properties)
  prompt = f"""
Here are math properties:
{properties_block}

Please create a problem where the underlying solution space is a list of exactly {PROBLEM_NUM_RANDOM_INTEGERS} random integers
in the range [1, {PROBLEM_RANDOM_INTEGER_INPUT_RANGE}] and the problem has a unique solution.

The problem should start with: A list of {PROBLEM_NUM_RANDOM_INTEGERS} integers in the range [1, {PROBLEM_RANDOM_INTEGER_INPUT_RANGE}].

The problem should finish with: 'What is the {sampled_aggregator} of the list of {PROBLEM_NUM_RANDOM_INTEGERS} integers?'

Solving the problem must include a non-trivial steps using at most 1 of the formulas in the math properties highlighted above and 1 or more input integers {input_integers}.

Please return the problem as a json list of strings where each string is a problem requirement, including start statement and the final question, and nothing else.
  """
  if print_prompt:
    print(prompt)
  raw_reply = run_groq_inference_qwq_32b(groq_client, prompt)
  return extract_after_think(raw_reply)

def iterate_problem(sampled_properties, sampled_aggregator, input_integers, problem, code_and_output, print_prompt=False):
  """Ask the model to revise an earlier problem so that it has a unique solution.

  Args:
    sampled_properties: Math-property strings, joined one per line into the prompt.
    sampled_aggregator: Aggregation phrase used in the closing question.
    input_integers: Integers the problem is asked to make use of.
    problem: The previous problem, interpolated into the prompt as-is.
    code_and_output: Checker code plus the solutions it found, interpolated as-is.
    print_prompt: When True, print the prompt before sending it.

  Returns:
    Whatever extract_after_think yields for the model reply.
  """
  properties_block = '\n'.join(sampled_properties)
  prompt = f"""
Here are math properties:
{properties_block}

Here is a problem:
{problem}
========================================

Here is code that checks if a solution is valid and some solutions found for the problem, if any exist:
{code_and_output}
========================================

Please iterate on the problem above where the underlying solution space is a list of exactly {PROBLEM_NUM_RANDOM_INTEGERS} random integers
in the range [1, {PROBLEM_RANDOM_INTEGER_INPUT_RANGE}] so that the problem has a unique solution.

The problem should start with statement: A list of {PROBLEM_NUM_RANDOM_INTEGERS} integers in the range [1, {PROBLEM_RANDOM_INTEGER_INPUT_RANGE}].

The problem should finish with: 'What is the {sampled_aggregator} of the list of {PROBLEM_NUM_RANDOM_INTEGERS} integers?'

Solving the problem must include a non-trivial steps using at most 1 of the formulas in the math properties highlighted above and 1 or more input integers {input_integers}.

Please return the problem as a json list of strings where each string is a problem requirement, including start statement and the final question, and nothing else.
  """
  if print_prompt:
    print(prompt)
  raw_reply = run_groq_inference_qwq_32b(groq_client, prompt)
  return extract_after_think(raw_reply)

def generate_check_solution_code(sampled_problem, activated_input_integers, previous_code_and_error=None, print_prompt=False):
  """Ask the model for a ``check_solution`` function that validates a candidate list.

  Args:
    sampled_problem: Problem text interpolated into the prompt.
    activated_input_integers: Parameter mapping shown to the model as "params".
    previous_code_and_error: Optional feedback from a failed attempt; appended
      to the prompt when it contains non-whitespace text.
    print_prompt: When True, print the final prompt before sending it.

  Returns:
    The contents of the last python code block in the reply, as produced by
    extract_last_python_code.
  """
  prompt = f"""
Here is a problem: {sampled_problem}
========================================

The problem uses the random integer parameters {activated_input_integers}, we'll call them params.

Please write code for a python method called 'check_solution' that takes as input the proposed list of exactly {PROBLEM_NUM_RANDOM_INTEGERS} integers
that form a solution to the problem along with an extra params dictionary, and checks if the solution is valid.

Example Problem: 'Median of the list of 5 integers is 10. Sum of the list of 5 integers is 55.'
Example params: {{'guiding_param_1': 10, 'guiding_param_2': 55}}

Example code:
def check_solution(i, j, k, l, m, **params):
  import statistics
  l = [i, j, k, l, m]
  m = statistics.median(l)
  s = sum(l)
  return m == params['guiding_param_1'] and s == params['guiding_param_2']

Please only return code in in ```python ``` tags.
  """
  has_feedback = previous_code_and_error is not None and previous_code_and_error.strip() != ""
  if has_feedback:
    # Give the model its earlier attempt and the resulting error to correct.
    prompt += f"""

    Previous Code and Error:
    {previous_code_and_error}
    """
  if print_prompt:
    print(prompt)
  reply_after_reasoning = extract_after_think(run_groq_inference_qwq_32b(groq_client, prompt))
  return extract_last_python_code(reply_after_reasoning)  # Using the previous function

def brute_force_solutions(check_solution, num_inputs, **params):
    """
    Parametrized brute force solution finder that works with any number of inputs.

    Every tuple of ``num_inputs`` integers drawn from the closed range
    [1, PROBLEM_RANDOM_INTEGER_INPUT_RANGE] is passed to ``check_solution``.
    The search stops early once 20 solutions have been collected, which is
    enough to tell "unique" from "not unique" without enumerating everything.

    Args:
        check_solution: Function that checks if a solution is valid
        num_inputs: Number of input variables to iterate through
        params: Additional parameters to pass to check_solution

    Returns:
        Set of solution tuples (at most 20)
    """
    solutions = set()
    STOP_AT_MAX_SOLUTIONS = 20

    # The prompts state problems over [1, RANGE] inclusive. range() excludes its
    # upper bound, so start at 1 and add 1 to the stop; the former range(RANGE)
    # searched 0..RANGE-1, admitting 0 and never trying RANGE itself.
    candidate_values = range(1, PROBLEM_RANDOM_INTEGER_INPUT_RANGE + 1)

    # Generate all combinations using itertools
    for values in itertools.product(candidate_values, repeat=num_inputs):
        if check_solution(*values, **params):
            solutions.add(values)
            if len(solutions) >= STOP_AT_MAX_SOLUTIONS:
                return solutions

    return solutions

In [None]:
# Smoke test for brute_force_solutions with a hand-written checker.
# Requires the cell defining brute_force_solutions to have been run first.
# Each "zero_eq_check" must compare its linear combination against 0; the bare
# expression is truthy precisely when the combination is NOT zero.
condition_method_str = """
def check_solution(i, j, k, l, **params):
  zero_eq_check_1 = params['guiding_param_1'] * i + params['guiding_param_2'] * j == 0
  zero_eq_check_2 = params['guiding_param_3'] * k + params['guiding_param_4'] * l == 0
  above_zero_check = i > 0 and j > 0 and k > 0 and l > 0
  return zero_eq_check_1 and zero_eq_check_2 and above_zero_check
"""
namespace = {}
exec(condition_method_str, namespace)
check_solution = namespace['check_solution']

print(brute_force_solutions(check_solution, 4, **{
  'guiding_param_1': 10, 'guiding_param_2': -5, 'guiding_param_3': 6, 'guiding_param_4': -9,
}))

NameError: name 'brute_force_solutions' is not defined

In [109]:
def gen_problem_with_retry(sampled_properties, sampled_aggregator, input_integers, max_retries=MAX_LLM_RETRIES, debug=False):
  """Generate a problem, retrying until the reply parses into a non-empty list of statements.

  Args:
    sampled_properties: Math-property strings forwarded to generate_problem.
    sampled_aggregator: Aggregation phrase forwarded to generate_problem.
    input_integers: Input integers forwarded to generate_problem.
    max_retries: Maximum number of generation attempts.
    debug: When True, print prompts, raw replies and skip reasons.

  Returns:
    The parsed list of problem statements, or None when every attempt failed.
  """
  total_iterations = 0

  while total_iterations < max_retries:
    total_iterations += 1

    try:
      problem = generate_problem(sampled_properties, sampled_aggregator, input_integers, print_prompt=debug)
    except Exception as e:
      if debug:
        print(f"Error generating problem: {e}")
      continue
    if problem is None or len(problem.strip()) == 0:
      if debug:
        print("No problem generated, skipping.")
      continue

    if debug:
      print("Problem:")
      print(problem)
      print()

    # Prefer the fenced json block; fall back to the whole reply when there is none.
    problem_statements_str = extract_last_json_code(problem)
    if problem_statements_str is None or len(problem_statements_str.strip()) == 0:
      problem_statements_str = problem

    try:
      problem_statements = ast.literal_eval(problem_statements_str) # ast doesn't error out on latex formulas.
    except Exception:
      # `Exception` rather than a bare except, so Ctrl-C still interrupts the loop.
      print("Error parsing problem statements.")
      continue
    # Callers index and join the result, so anything but a non-empty sequence is a failed attempt.
    if not isinstance(problem_statements, (list, tuple)) or len(problem_statements) == 0:
      if debug:
        print("No problem statements extracted into a list, skipping.")
      continue

    # Print problem.
    if debug:
      print("Problem statements:")
      for statement in problem_statements:
        print(statement)
      print()

    # Success: return immediately, including when this was the final attempt
    # (previously a valid result on the last try was discarded as a failure).
    return problem_statements

  print(f"Tried to generate problem {total_iterations} times, failed and stopped.\n")
  return None

def iterate_problem_with_retry(sampled_properties, sampled_aggregator, input_integers, previous_problem_statements, code_and_output, max_retries=MAX_LLM_RETRIES, debug=False):
  """Revise a previous problem, retrying until the reply parses into a non-empty list of statements.

  Args:
    sampled_properties: Math-property strings forwarded to iterate_problem.
    sampled_aggregator: Aggregation phrase forwarded to iterate_problem.
    input_integers: Input integers forwarded to iterate_problem.
    previous_problem_statements: The problem being revised.
    code_and_output: Checker code and the solutions it found for that problem.
    max_retries: Maximum number of revision attempts.
    debug: When True, print prompts, raw replies and skip reasons.

  Returns:
    The parsed list of problem statements, or None when every attempt failed.
  """
  total_iterations = 0

  while total_iterations < max_retries:
    total_iterations += 1

    try:
      problem = iterate_problem(sampled_properties, sampled_aggregator, input_integers, previous_problem_statements, code_and_output, print_prompt=debug)
    except Exception as e:
      if debug:
        print(f"Error iterating problem: {e}")
      continue
    if problem is None or len(problem.strip()) == 0:
      if debug:
        print("No problem generated during iteration, skipping.")
      continue

    if debug:
      print("Problem:")
      print(problem)
      print()

    # Prefer the fenced json block; fall back to the whole reply when there is none.
    problem_statements_str = extract_last_json_code(problem)
    if problem_statements_str is None or len(problem_statements_str.strip()) == 0:
      problem_statements_str = problem

    try:
      problem_statements = ast.literal_eval(problem_statements_str) # ast doesn't error out on latex formulas.
    except Exception:
      # `Exception` rather than a bare except, so Ctrl-C still interrupts the loop.
      print("Error parsing problem statements.")
      continue
    # Callers index and join the result, so anything but a non-empty sequence is a failed attempt.
    if not isinstance(problem_statements, (list, tuple)) or len(problem_statements) == 0:
      if debug:
        print("No problem statements extracted into a list, skipping.")
      continue

    # Print problem.
    if debug:
      print("Problem statements:")
      for statement in problem_statements:
        print(statement)
      print()

    # Success: return immediately, including when this was the final attempt
    # (previously a valid result on the last try was discarded as a failure).
    return problem_statements

  print(f"Tried to iterate on problem {total_iterations} times, failed and stopped.\n")
  return None

def generate_check_solution_code_with_retry(problem_statements, activated_input_integers, max_retries=MAX_LLM_RETRIES, debug=False):
  """Generate a working ``check_solution`` function and brute-force the problem with it.

  Each attempt asks the model for checker code, executes it, and runs the
  brute-force search. When an attempt fails after code was produced, the code
  and the error are fed into the next prompt so the model can correct itself.

  Args:
    problem_statements: Sequence of problem statement strings (joined with spaces for the prompt).
    activated_input_integers: Mapping of parameter names to integers, passed to the checker as keyword arguments.
    max_retries: Maximum number of attempts.
    debug: When True, print prompts, generated code and skip reasons.

  Returns:
    Tuple ``(possible_solutions, check_solution_function, check_solution_function_str)``
    on success, or ``(None, None, None)`` when every attempt failed.
  """
  total_iterations = 0
  previous_code_and_error = None
  while total_iterations < max_retries:
    total_iterations += 1

    try:
      check_solution_function_str = generate_check_solution_code(
        ' '.join(problem_statements), activated_input_integers, previous_code_and_error=previous_code_and_error, print_prompt=debug
      )
    except Exception as e:
      if debug:
        print(f"Error generating check solution code: {e}")
      continue
    if check_solution_function_str is None or len(check_solution_function_str.strip()) == 0:
      if debug:
        print("No check solution code generated, skipping.")
      continue

    if debug:
      print("Check solution code:")
      print(check_solution_function_str)
      print()

    try:
      # NOTE(review): this executes model-generated code in-process without sandboxing.
      check_solution_function = load_and_execute_code(check_solution_function_str).get('check_solution', None)
    except Exception as e:
      if debug:
        print(f"Error loading and executing check solution code: {e}")
      previous_code_and_error = f"{check_solution_function_str}\nError:{e}"
      continue
    if not callable(check_solution_function):
      if debug:
        print("No check solution function found, skipping.")
      # Record feedback here too, like the other failure branches, so the next
      # prompt tells the model what was wrong instead of retrying blind.
      previous_code_and_error = f"{check_solution_function_str}\nError:The code does not define a callable named 'check_solution'."
      continue

    # Try running the check solution function with the input integers.
    try:
      # Generate all solutions.
      possible_solutions = brute_force_solutions(check_solution_function, MAX_PROBLEM_NUM_INTEGER_INPUT, **activated_input_integers)
      if debug:
        print(f"possible_solutions: {possible_solutions}\n")
    except Exception as e:
      if debug:
        print(f"Error checking all brute force solutions: {e}")
      previous_code_and_error = f"{check_solution_function_str}\nError:{e}"
      continue

    return possible_solutions, check_solution_function, check_solution_function_str

  print(f"Tried to generate check solution for problem {total_iterations} times, failed and stopped.\n")
  return None, None, None

In [None]:
# Draw the random ingredients for one problem.
input_integers = [random.randint(1, PROBLEM_RANDOM_INTEGER_INPUT_RANGE) for _ in range(MAX_PROBLEM_NUM_INTEGER_INPUT)]
sampled_aggregator = random.choice(INT_AGGREGATORS)
# Sort first so the sampling population has a stable order (set order varies between runs),
# and cap the sample size so a small property pool does not make random.sample raise.
available_properties = sorted(all_properties)
sampled_properties = random.sample(available_properties, min(NUM_PROPERTIES_FOR_PROBLEM_GEN, len(available_properties)))

total_iterations = 0
previous_gens = None
while total_iterations < MAX_LLM_RETRIES:
  total_iterations += 1

  # Generate problem: fresh on the first pass, otherwise a revision of the last attempt.
  if previous_gens is None:
    problem_statements = gen_problem_with_retry(sampled_properties, sampled_aggregator, input_integers, max_retries=MAX_LLM_RETRIES, debug=True)
  else:
    problem_statements = iterate_problem_with_retry(
      sampled_properties, sampled_aggregator, input_integers,
      previous_gens['previous_problem_statements'], previous_gens['previous_check_code_and_possible_solutions'],
      max_retries=MAX_LLM_RETRIES, debug=True
    )

  # The retry helpers return None when they run out of attempts; slicing None would raise.
  if problem_statements is None:
    print("No problem statements generated, retrying.\n")
    continue

  # Get params that were actually used by LLM in generated problem statements.
  # Skip first and last statement since they set number of integers and how to aggregate them.
  problem_statements_with_activated_integers = ' '.join(problem_statements[1:-1])
  activated_integers = find_integers_in_string(problem_statements_with_activated_integers)
  activated_input_integers = {
    f"guiding_param_{idx+1}": i
    for idx, i in enumerate(activated_integers)
  }
  print(f"Activated input integers:\n{activated_input_integers}\n")

  # Generate code to check a possible solution.
  possible_solutions, check_solution_function, check_solution_function_str = generate_check_solution_code_with_retry(
    problem_statements, activated_input_integers, max_retries=MAX_LLM_RETRIES, debug=True
  )

  # (None, None, None) means no usable checker was produced; len(None) would raise.
  if possible_solutions is None:
    print("No working check solution code generated, retrying.\n")
    continue

  # If unique solution, break, we have a validated problem.
  if len(possible_solutions) == 1:
    # Print problem.
    print("Final Problem:")
    for statement in problem_statements:
      print(statement)
    print(f"\nFinal List of Integers Solution: {possible_solutions}")
    print(f"Final Check Code:\n{check_solution_function_str}")
    break

  if total_iterations >= MAX_LLM_RETRIES:
    print(f"Tried {total_iterations} times, failed and stopped.\n")
    break

  # Feed the checker and the solutions it found into the next revision prompt.
  check_code_and_possible_solutions = f"""
  Check Code:
  {check_solution_function_str}

  Possible solutions:
  {possible_solutions}
  """
  previous_gens = {
    'previous_problem_statements': problem_statements,
    'previous_check_code_and_possible_solutions': check_code_and_possible_solutions
  }

  print(f"Previous gens: {previous_gens}\n")
else:
  # Reached only when the attempt budget ran out through one of the `continue` paths above.
  print(f"Tried {total_iterations} times, failed and stopped.\n")
