In [1]:
import os
import json
import re
import time

from datasets import Dataset
from pyprojroot import here
from pydantic import BaseModel, Field
from langchain_mistralai import ChatMistralAI
from langchain_core.rate_limiters import InMemoryRateLimiter
from langchain_core.messages import SystemMessage, HumanMessage
from langchain_core.prompts import ChatPromptTemplate
from dotenv import load_dotenv

# Load variables from a local .env file into the process environment.
# NOTE(review): presumably this supplies the Mistral API key used by
# ChatMistralAI below — confirm against the .env in use.
load_dotenv()

True

In [2]:
# Saved evaluation outputs to be graded (one row per MATH problem).
EVAL_DATASET_PATH = here("evals/logs/phased_self_discover/llama/unstructured/few_shot_0/math/math_eval")
dataset = Dataset.load_from_disk(EVAL_DATASET_PATH)
dataset

Dataset({
    features: ['problem', 'level', 'type', 'solution', 'self_discover_input', 'task_description', 'selected_modules', 'adapted_modules', 'reasoning_plan', 'reasoning', 'trajectory', 'answer_pred'],
    num_rows: 200
})

In [3]:
# Rows with no predicted answer; use an identity check for None rather than `==`.
dataset.filter(lambda row: row["answer_pred"] is None)

Filter:   0%|          | 0/200 [00:00<?, ? examples/s]

Dataset({
    features: ['problem', 'level', 'type', 'solution', 'self_discover_input', 'task_description', 'selected_modules', 'adapted_modules', 'reasoning_plan', 'reasoning', 'trajectory', 'answer_pred'],
    num_rows: 3
})

In [4]:
# Peek at one predicted answer; selecting the row first avoids materializing the whole column.
dataset[2]["answer_pred"]

'$y^4 - 2y^3 + 7y^2 + y - 5$.'

In [5]:
JUDGE_MODEL = "mistral-large-2407"

# Client-side throttle for the judge: 0.5 requests/s (about one call every 2 s),
# bucket capacity 1 so calls cannot burst, token availability checked every second.
rate_limiter = InMemoryRateLimiter(
    requests_per_second=0.5,
    check_every_n_seconds=1,
    max_bucket_size=1,
)
llm = ChatMistralAI(model=JUDGE_MODEL, rate_limiter=rate_limiter)

class ReasoningState(BaseModel):
    """Structured verdict the LLM judge returns for one graded row."""

    # `description=` must be given by keyword: Field's first positional argument
    # is the default value, so passing these strings positionally would make them
    # defaults (a string default on a bool field) and leave the schema sent to the
    # model without any field descriptions. Without a default, both fields are
    # required, so the model must always fill them in.
    is_correct: bool = Field(
        description="True if reasoning answer is same as the answer in the solution, False otherwise"
    )
    correction_reasoning: str = Field(
        description="Your reasoning to why you determined if your answer is True or False"
    )

# Judge wrapper whose responses are parsed into ReasoningState instances.
structured_llm = llm.with_structured_output(ReasoningState)

In [6]:
# Judge instructions, one entry per paragraph; paragraphs are separated by a blank line.
_SYSTEM_PARAGRAPHS = (
    "You are an expert mathematician who is tasked with comparing a student's reasoning and final answer with a given reference answer.",
    "You don't need to be worried about the math problem. Only if the student's final answer is correct.",
    "The final answer in the student's answer is given inside latex \\boxed command and their total reasoning is wrapped within <reasoning> tags.",
    "The reference answer is given within <reference> tags.",
    "You will provide your answer in the given schema. You will determine True if the student's final answer(within \\boxed latex) is same as the <reference> answer.",
)
system_prompt = "\n\n".join(_SYSTEM_PARAGRAPHS)

# Human turn: `{reasoning}` and `{reference}` are prompt-template variables filled per row.
user_message = "\n".join([
    "<reasoning>",
    "{reasoning}",
    "</reasoning>",
    "",
    "<reference>",
    "{reference}",
    "</reference>",
])

In [9]:
def compare(instance):
    """Ask the LLM judge whether a row's reasoning reaches the reference answer.

    Meant to be passed to ``Dataset.map``: the keys of the returned dict become
    new columns (``is_correct``, ``correction_reasoning``) on each row.

    Args:
        instance: One dataset row. Reads ``"solution"`` (sent as the reference)
            and ``"reasoning"`` (the model's reasoning to be graded).

    Returns:
        dict: The ``ReasoningState`` fields on success. If the judge call fails
        or yields no structured output, ``is_correct`` is ``""`` (the sentinel
        the error-count cell filters on) and ``correction_reasoning`` starts
        with ``"error"`` followed by a short diagnostic.
    """
    answer = instance["solution"]
    reasoning = instance["reasoning"]

    prompt_template = ChatPromptTemplate([
        ("system", system_prompt),
        ("user", user_message),
    ])
    chain = prompt_template | structured_llm

    # Catch Exception, not everything: a bare `except:` would also swallow
    # KeyboardInterrupt, making the batch loop impossible to stop. Keep the
    # error text so failed rows can be diagnosed later.
    try:
        response = chain.invoke({"reasoning": reasoning, "reference": answer})
    except Exception as exc:
        return {
            "is_correct": "",
            "correction_reasoning": f"error: {exc!r}",
        }

    # The structured-output parser can return None when the model emits no
    # tool call; record it as a failed row instead of crashing on model_dump().
    if response is None:
        return {
            "is_correct": "",
            "correction_reasoning": "error: no structured output returned",
        }

    return response.model_dump()

In [10]:
batch_size = 10
output_dir = "math_processed_batches"
os.makedirs(output_dir, exist_ok=True)

# Resume support: indices of batch files (batch_<index>.json) already on disk
# are collected so those batches are not graded again.
saved_batches = set()
for fname in os.listdir(output_dir):
    if fname.startswith("batch_") and fname.endswith(".json"):
        saved_batches.add(int(fname.split("_")[1].split(".")[0]))

n_rows = len(dataset)
for start in range(0, n_rows, batch_size):
    batch_idx = start // batch_size
    if batch_idx in saved_batches:
        print(f"Skipping already processed batch {batch_idx}")
        continue

    stop = min(start + batch_size, n_rows)
    batch = dataset.select(range(start, stop))
    processed_batch = batch.map(compare)

    # Persist each batch as soon as it is graded so an interrupted run can resume.
    batch_path = os.path.join(output_dir, f"batch_{batch_idx}.json")
    processed_batch.to_json(batch_path)
    print(f"Saved batch {batch_idx} to {batch_path}")

    # Extra pause between batches, on top of the LLM rate limiter.
    time.sleep(5)

print("Processing complete!")

Skipping already processed batch 0
Skipping already processed batch 1
Skipping already processed batch 2
Skipping already processed batch 3
Skipping already processed batch 4
Skipping already processed batch 5
Skipping already processed batch 6
Skipping already processed batch 7
Skipping already processed batch 8
Skipping already processed batch 9
Skipping already processed batch 10
Skipping already processed batch 11


Map:   0%|          | 0/10 [00:00<?, ? examples/s]

Creating json from Arrow format:   0%|          | 0/1 [00:00<?, ?ba/s]

Saved batch 12 to math_processed_batches/batch_12.json


Map:   0%|          | 0/10 [00:00<?, ? examples/s]

Creating json from Arrow format:   0%|          | 0/1 [00:00<?, ?ba/s]

Saved batch 13 to math_processed_batches/batch_13.json


Map:   0%|          | 0/10 [00:00<?, ? examples/s]

Creating json from Arrow format:   0%|          | 0/1 [00:00<?, ?ba/s]

Saved batch 14 to math_processed_batches/batch_14.json


Map:   0%|          | 0/10 [00:00<?, ? examples/s]

Creating json from Arrow format:   0%|          | 0/1 [00:00<?, ?ba/s]

Saved batch 15 to math_processed_batches/batch_15.json


Map:   0%|          | 0/10 [00:00<?, ? examples/s]

Creating json from Arrow format:   0%|          | 0/1 [00:00<?, ?ba/s]

Saved batch 16 to math_processed_batches/batch_16.json


Map:   0%|          | 0/10 [00:00<?, ? examples/s]

Creating json from Arrow format:   0%|          | 0/1 [00:00<?, ?ba/s]

Saved batch 17 to math_processed_batches/batch_17.json


Map:   0%|          | 0/10 [00:00<?, ? examples/s]

Creating json from Arrow format:   0%|          | 0/1 [00:00<?, ?ba/s]

Saved batch 18 to math_processed_batches/batch_18.json


Map:   0%|          | 0/10 [00:00<?, ? examples/s]

Creating json from Arrow format:   0%|          | 0/1 [00:00<?, ?ba/s]

Saved batch 19 to math_processed_batches/batch_19.json
Processing complete!


# Load and calculate accuracy

In [11]:
from datasets import concatenate_datasets

In [12]:
processed_dataset_list = []

# Only per-batch files are loaded: the combined `math.json` export is written into
# this same directory, and picking it up on a re-run would duplicate every row.
batch_files = [
    fname for fname in os.listdir(output_dir)
    if fname.startswith("batch_") and fname.endswith(".json")
]
# os.listdir order is arbitrary; sort numerically by batch index so rows come
# back in the original dataset order (batch_2 before batch_10).
batch_files.sort(key=lambda fname: int(fname.split("_")[1].split(".")[0]))

for batch_file in batch_files:
    processed_dataset_list.append(Dataset.from_json(os.path.join(output_dir, batch_file)))

processed_dataset = concatenate_datasets(processed_dataset_list)

# Every source row should have been graded exactly once.
assert processed_dataset.num_rows == len(dataset), "missing or duplicated batches"

processed_dataset

Generating train split: 0 examples [00:00, ? examples/s]

Generating train split: 0 examples [00:00, ? examples/s]

Generating train split: 0 examples [00:00, ? examples/s]

Generating train split: 0 examples [00:00, ? examples/s]

Generating train split: 0 examples [00:00, ? examples/s]

Generating train split: 0 examples [00:00, ? examples/s]

Generating train split: 0 examples [00:00, ? examples/s]

Generating train split: 0 examples [00:00, ? examples/s]

Generating train split: 0 examples [00:00, ? examples/s]

Generating train split: 0 examples [00:00, ? examples/s]

Generating train split: 0 examples [00:00, ? examples/s]

Generating train split: 0 examples [00:00, ? examples/s]

Generating train split: 0 examples [00:00, ? examples/s]

Generating train split: 0 examples [00:00, ? examples/s]

Generating train split: 0 examples [00:00, ? examples/s]

Generating train split: 0 examples [00:00, ? examples/s]

Generating train split: 0 examples [00:00, ? examples/s]

Generating train split: 0 examples [00:00, ? examples/s]

Generating train split: 0 examples [00:00, ? examples/s]

Generating train split: 0 examples [00:00, ? examples/s]

Dataset({
    features: ['problem', 'level', 'type', 'solution', 'self_discover_input', 'task_description', 'selected_modules', 'adapted_modules', 'reasoning_plan', 'reasoning', 'trajectory', 'answer_pred', 'is_correct', 'correction_reasoning'],
    num_rows: 200
})

In [16]:
# Export all graded rows as one JSON file next to the per-batch files.
combined_path = os.path.join(output_dir, "math.json")
processed_dataset.to_json(combined_path)

Creating json from Arrow format:   0%|          | 0/1 [00:00<?, ?ba/s]

2335034

In [13]:
# Accuracy = fraction of rows the judge marked correct. Divide by the real row
# count instead of a hard-coded 200 so the metric stays right if the eval size changes.
processed_dataset.filter(lambda row: row["is_correct"] is True).num_rows / processed_dataset.num_rows

Filter:   0%|          | 0/200 [00:00<?, ? examples/s]

0.755

In [14]:
# Rows where the judge call failed: compare() stores "" in is_correct for those.
processed_dataset.filter(lambda is_correct: is_correct == "", input_columns="is_correct")

Filter:   0%|          | 0/200 [00:00<?, ? examples/s]

Dataset({
    features: ['problem', 'level', 'type', 'solution', 'self_discover_input', 'task_description', 'selected_modules', 'adapted_modules', 'reasoning_plan', 'reasoning', 'trajectory', 'answer_pred', 'is_correct', 'correction_reasoning'],
    num_rows: 0
})