In [1]:
import os
from mistralai import Mistral

In [2]:
# API credentials are read from the environment (a missing variable raises
# KeyError), so the key itself never appears in the notebook.
api_key: str = os.environ["MISTRAL_API_KEY"]
# Name of the small Mistral model.
model: str = "mistral-small-latest"

# Authenticated client shared by the chat helper functions.
client = Mistral(api_key=api_key)

In [3]:
def call_strong(messages: list, *, model: str = "mistral-large-latest") -> str:
    """Send a chat ``messages`` list to the large Mistral model; return the reply text.

    Parameters
    ----------
    messages : list
        Chat messages, e.g. ``[{"role": "user", "content": "..."}]``.
    model : str, keyword-only
        Model name. Defaults to the originally hard-coded large model, so existing
        callers are unaffected.

    Returns
    -------
    str
        Content of the first choice. The SDK types ``message.content`` as
        nullable, so ``None`` is mapped to ``""``. This keeps downstream string
        and regex code from receiving ``None``.
    """
    chat_response = client.chat.complete(model=model, messages=messages)
    content = chat_response.choices[0].message.content
    return content if content is not None else ""

In [4]:
def call_mini(messages: list, *, model: str = "mistral-small-latest") -> str:
    """Send a chat ``messages`` list to the small Mistral model; return the reply text.

    Parameters
    ----------
    messages : list
        Chat messages, e.g. ``[{"role": "user", "content": "..."}]``.
    model : str, keyword-only
        Model name. Defaults to the originally hard-coded small model, so existing
        callers (including the reasoning agent, which receives this function) are
        unaffected.

    Returns
    -------
    str
        Content of the first choice. The SDK types ``message.content`` as
        nullable, so ``None`` is mapped to ``""``. This keeps downstream string
        and regex code from receiving ``None``.
    """
    chat_response = client.chat.complete(model=model, messages=messages)
    content = chat_response.choices[0].message.content
    return content if content is not None else ""

### Prerequisites: reasoning agent

In [5]:
from reasonable import DefaultReasoningAgent

In [6]:
# Reasoning agent under evaluation: the small model serves both as the main
# answering function and as the "thoughts" function.
# NOTE(review): the semantics and units of max_steps / timeout are defined by
# the `reasonable` package and are not visible here. They presumably cap the
# number of reasoning steps and the per-call time; confirm against its docs.
agent = DefaultReasoningAgent(
    verbose=True,
    main_function=call_mini,
    thoughts_function=call_mini,
    timeout=2,
    max_steps=10,
)

ReasoningAgent initialized. Verbose mode is on.


In [8]:
import re
import time

def _parse_score(evaluation_result, max_score=10):
    """Return the number in the judge's <score> tag as an int clamped to [0, max_score]; 0 if absent."""
    match = re.search(r"<score>(.*?)</score>", evaluation_result, re.DOTALL)
    if not match:
        return 0
    # Judges often write "8/10", "8.5" or "Score: 8". Take the first number
    # rather than int()-ing the raw text, which raised ValueError on those.
    number = re.search(r"\d+(?:\.\d+)?", match.group(1))
    if not number:
        return 0
    return max(0, min(max_score, int(float(number.group()))))


def _parse_verdict(evaluation_result):
    """Return 0 if the judge's <verdict> tag is missing or says 'incorrect', else 1."""
    match = re.search(r"<verdict>(.*?)</verdict>", evaluation_result, re.DOTALL)
    if not match:
        return 0
    # Case-insensitive: a verdict of "Incorrect" used to be scored as correct.
    return 0 if "incorrect" in match.group(1).strip().lower() else 1


def evaluate_llm_response(question, llm_answer_func, evaluator_func, delay=5):
    """Ask a model one benchmark question and have a judge model grade the reply.

    Parameters
    ----------
    question : dict
        Must contain the keys ``"question"``, ``"solution"`` and ``"answer"``.
    llm_answer_func : callable
        Takes a chat ``messages`` list and returns the answer text to grade.
    evaluator_func : callable
        Same calling convention. Acts as the judge and is asked to reply with
        ``<evaluation>``, ``<verdict>`` and ``<score>`` tags.
    delay : float, optional
        Seconds to sleep before each of the two model calls (crude pacing of
        API requests). Defaults to 5, the previously hard-coded value.

    Returns
    -------
    dict
        ``question``, ``correct_answer``, ``llm_answer``, ``score`` (int in
        [0, 10]; 0 if missing or unparseable) and ``verdict`` (1 = judged
        correct, 0 otherwise).
    """
    question_text = question["question"]
    correct_solution = question["solution"]
    correct_answer = question["answer"]

    time.sleep(delay)
    llm_answer = llm_answer_func([{"role": "user", "content": question_text}])

    # Every fragment ends in "\n". Previously the instruction sentences and the
    # tag template were concatenated into one unbroken line.
    evaluation_prompt = (
        f"Question: {question_text}\n"
        f"Correct solution: {correct_solution}\n"
        f"Correct answer: {correct_answer}\n"
        f"Answer by student: {llm_answer}\n"
        "Please evaluate the correctness of the student's answer above.\n"
        "Evaluate the student's answer in <evaluation> tag and provide a verdict ('correct' or 'incorrect') in <verdict> tag and score (from 0 to 10) in <score> tag.\n"
        "Respond in this format:\n"
        "<evaluation>\n"
        "...\n"
        "</evaluation>\n"
        "<verdict>\n"
        "...\n"
        "</verdict>\n"
        "<score>\n"
        "...\n"
        "</score>"
    )

    time.sleep(delay)
    # A judge may return None; treat that as an empty, ungradable reply
    # instead of crashing inside re.search.
    evaluation_result = evaluator_func([{"role": "user", "content": evaluation_prompt}]) or ""

    return {
        "question": question_text,
        "correct_answer": correct_answer,
        "llm_answer": llm_answer,
        "score": _parse_score(evaluation_result),
        "verdict": _parse_verdict(evaluation_result),
    }

### Dataset

In [9]:
from datasets import load_dataset

# Download / load MATH-500 from the Hugging Face Hub. Only its "test" split is
# used below.
ds = load_dataset(path="HuggingFaceH4/MATH-500")

In [10]:
import random

In [11]:
# Keep only the three fields the evaluator reads, renaming "problem" to "question".
MATH50_ds = [
    {
        "question": row["problem"],
        "solution": row["solution"],
        "answer": row["answer"],
    }
    for row in ds["test"]
]

In [12]:
# Seeded shuffle so the sampled subset is reproducible across kernel restarts.
# The unseeded global shuffle picked a different subset on every run.
SHUFFLE_SEED = 42
random.Random(SHUFFLE_SEED).shuffle(MATH50_ds)

In [13]:
# Size of the evaluation subset. Keep it in sync with the 50 used for the
# max_total / max_score bookkeeping in the evaluation cells.
N_QUESTIONS = 50
MATH50_ds = MATH50_ds[:N_QUESTIONS]

In [14]:
# Preview a few items rather than dumping the whole subset; the LaTeX
# solutions are long and flood the output.
MATH50_ds[:3]

[{'question': 'Evaluate $\\log_264$.',
  'solution': 'We have $2^6=64$, so $\\log_2 64 = \\boxed{6}$.',
  'answer': '6'},
 {'question': 'What is $\\left(4\\dfrac{5}{8}\\right)^{55} \\cdot \\left(\\dfrac{8}{37}\\right)^{55}$?',
  'solution': 'First we convert $4\\dfrac{5}{8}$ into an improper fraction: \\[4\\dfrac{5}{8} = 4 + \\dfrac{5}{8} = \\dfrac{32}{8} + \\dfrac{5}{8} = \\dfrac{37}{8}.\\]We discover that $4\\dfrac{5}{8}$ and $\\dfrac{8}{37}$ are in fact reciprocals of each other. Using the fact that $(ab)^n = a^nb^n$, we get our answer: \\[\n\\left(4\\dfrac{5}{8}\\right)^{55} \\cdot \\left(\\dfrac{8}{37}\\right)^{55} = \\left(4\\dfrac{5}{8} \\cdot \\dfrac{8}{37}\\right)^{55} = 1^{55} = \\boxed{1}.\\]',
  'answer': '1'},
 {'question': 'Determine the modulo 4 remainder of the following sum: $$ 1 + 2 + 3 + 4 + 5 + 6 + 7 + 8 + 9 + 10 + 11 + 12. $$',
  'solution': 'Grouping residues helps make some series computations easier:  \\begin{align*}\n1 + 2 + 3 + 0 + 1 + 2& + 3 + 0 + 1 + 2 + 3 +

### Evaluating base model

In [9]:
from tqdm import tqdm

In [36]:
# Baseline run: the small model answers each question, and the same small model
# judges the answer. Running totals are printed after every question.
base_model_test_info = {"total": 0, "score": 0, "answers": []}

max_total = 50       # best possible sum of verdicts (one per question)
max_score = 50 * 10  # best possible sum of 0-10 scores

for question in tqdm(MATH50_ds):
    foo = evaluate_llm_response(question, call_mini, call_mini)
    base_model_test_info["total"] += foo["verdict"]
    base_model_test_info["score"] += foo["score"]
    base_model_test_info["answers"].append(foo)

    # One print call emitting two lines: running accuracy, then running score.
    print(
        f"Current accuracy: {base_model_test_info['total'] / len(base_model_test_info['answers'])}/1.0\n"
        f"Current score: {base_model_test_info['score']}/{10 * len(base_model_test_info['answers'])}"
    )
    time.sleep(1)  # short pause between questions, presumably for API rate limits

  0%|          | 0/50 [00:00<?, ?it/s]

Current accuracy: 1.0/1.0
Current score: 10/10


  2%|▏         | 1/50 [00:13<11:04, 13.56s/it]

Current accuracy: 1.0/1.0
Current score: 20/20


  4%|▍         | 2/50 [00:27<11:02, 13.81s/it]

Current accuracy: 0.6666666666666666/1.0
Current score: 27/30


  6%|▌         | 3/50 [00:42<11:18, 14.44s/it]

Current accuracy: 0.75/1.0
Current score: 37/40


  8%|▊         | 4/50 [00:58<11:21, 14.82s/it]

Current accuracy: 0.8/1.0
Current score: 47/50


 10%|█         | 5/50 [01:12<11:05, 14.78s/it]

Current accuracy: 0.8333333333333334/1.0
Current score: 55/60


 12%|█▏        | 6/50 [01:33<12:23, 16.91s/it]

Current accuracy: 0.8571428571428571/1.0
Current score: 65/70


 14%|█▍        | 7/50 [01:48<11:34, 16.14s/it]

Current accuracy: 0.875/1.0
Current score: 75/80


 16%|█▌        | 8/50 [02:03<11:03, 15.80s/it]

Current accuracy: 0.8888888888888888/1.0
Current score: 85/90


 18%|█▊        | 9/50 [02:18<10:40, 15.62s/it]

Current accuracy: 0.9/1.0
Current score: 95/100


 20%|██        | 10/50 [02:32<09:57, 14.94s/it]

Current accuracy: 0.9090909090909091/1.0
Current score: 105/110


 22%|██▏       | 11/50 [02:48<10:04, 15.50s/it]

Current accuracy: 0.9166666666666666/1.0
Current score: 115/120


 24%|██▍       | 12/50 [03:05<09:56, 15.70s/it]

Current accuracy: 0.9230769230769231/1.0
Current score: 125/130


 26%|██▌       | 13/50 [03:20<09:37, 15.62s/it]

Current accuracy: 0.9285714285714286/1.0
Current score: 135/140


 28%|██▊       | 14/50 [03:34<09:03, 15.09s/it]

Current accuracy: 0.8666666666666667/1.0
Current score: 141/150


 30%|███       | 15/50 [03:52<09:24, 16.13s/it]

Current accuracy: 0.875/1.0
Current score: 151/160


 32%|███▏      | 16/50 [04:06<08:43, 15.40s/it]

Current accuracy: 0.8823529411764706/1.0
Current score: 161/170


 34%|███▍      | 17/50 [04:19<08:07, 14.78s/it]

Current accuracy: 0.8333333333333334/1.0
Current score: 164/180


 36%|███▌      | 18/50 [04:33<07:45, 14.55s/it]

Current accuracy: 0.7894736842105263/1.0
Current score: 169/190


 38%|███▊      | 19/50 [04:49<07:37, 14.76s/it]

Current accuracy: 0.75/1.0
Current score: 172/200


 40%|████      | 20/50 [05:06<07:42, 15.43s/it]

Current accuracy: 0.7142857142857143/1.0
Current score: 180/210


 42%|████▏     | 21/50 [05:23<07:42, 15.94s/it]

Current accuracy: 0.6818181818181818/1.0
Current score: 185/220


 44%|████▍     | 22/50 [05:39<07:31, 16.11s/it]

Current accuracy: 0.6956521739130435/1.0
Current score: 195/230


 46%|████▌     | 23/50 [05:54<07:03, 15.68s/it]

Current accuracy: 0.7083333333333334/1.0
Current score: 202/240


 48%|████▊     | 24/50 [06:18<07:51, 18.14s/it]

Current accuracy: 0.68/1.0
Current score: 206/250


 50%|█████     | 25/50 [06:33<07:10, 17.21s/it]

Current accuracy: 0.6923076923076923/1.0
Current score: 216/260


 52%|█████▏    | 26/50 [06:50<06:54, 17.25s/it]

Current accuracy: 0.7037037037037037/1.0
Current score: 226/270


 54%|█████▍    | 27/50 [07:05<06:15, 16.34s/it]

Current accuracy: 0.6785714285714286/1.0
Current score: 233/280


 56%|█████▌    | 28/50 [07:19<05:46, 15.74s/it]

Current accuracy: 0.6551724137931034/1.0
Current score: 236/290


 58%|█████▊    | 29/50 [07:34<05:27, 15.58s/it]

Current accuracy: 0.6666666666666666/1.0
Current score: 246/300


 60%|██████    | 30/50 [07:48<05:03, 15.20s/it]

Current accuracy: 0.6774193548387096/1.0
Current score: 249/310


 62%|██████▏   | 31/50 [08:08<05:13, 16.48s/it]

Current accuracy: 0.6875/1.0
Current score: 259/320


 64%|██████▍   | 32/50 [08:22<04:46, 15.89s/it]

Current accuracy: 0.6666666666666666/1.0
Current score: 261/330


 66%|██████▌   | 33/50 [08:40<04:39, 16.44s/it]

Current accuracy: 0.6470588235294118/1.0
Current score: 263/340


 68%|██████▊   | 34/50 [08:57<04:27, 16.71s/it]

Current accuracy: 0.6285714285714286/1.0
Current score: 269/350


 70%|███████   | 35/50 [09:19<04:31, 18.07s/it]

Current accuracy: 0.6388888888888888/1.0
Current score: 279/360


 72%|███████▏  | 36/50 [09:33<03:56, 16.89s/it]

Current accuracy: 0.6486486486486487/1.0
Current score: 289/370


 74%|███████▍  | 37/50 [09:49<03:37, 16.70s/it]

Current accuracy: 0.6578947368421053/1.0
Current score: 299/380


 76%|███████▌  | 38/50 [10:02<03:06, 15.58s/it]

Current accuracy: 0.6666666666666666/1.0
Current score: 309/390


 78%|███████▊  | 39/50 [10:17<02:48, 15.35s/it]

Current accuracy: 0.675/1.0
Current score: 319/400


 80%|████████  | 40/50 [10:34<02:37, 15.76s/it]

Current accuracy: 0.6829268292682927/1.0
Current score: 329/410


 82%|████████▏ | 41/50 [10:47<02:16, 15.14s/it]

Current accuracy: 0.6904761904761905/1.0
Current score: 339/420


 84%|████████▍ | 42/50 [11:02<01:59, 14.96s/it]

Current accuracy: 0.6976744186046512/1.0
Current score: 349/430


 86%|████████▌ | 43/50 [11:15<01:41, 14.56s/it]

Current accuracy: 0.7045454545454546/1.0
Current score: 359/440


 88%|████████▊ | 44/50 [11:29<01:25, 14.25s/it]

Current accuracy: 0.6888888888888889/1.0
Current score: 366/450


 90%|█████████ | 45/50 [11:47<01:16, 15.35s/it]

Current accuracy: 0.6956521739130435/1.0
Current score: 376/460


 92%|█████████▏| 46/50 [12:00<00:59, 14.84s/it]

Current accuracy: 0.7021276595744681/1.0
Current score: 386/470


 94%|█████████▍| 47/50 [12:15<00:44, 14.83s/it]

Current accuracy: 0.7083333333333334/1.0
Current score: 396/480


 96%|█████████▌| 48/50 [12:30<00:29, 14.73s/it]

Current accuracy: 0.6938775510204082/1.0
Current score: 402/490


 98%|█████████▊| 49/50 [12:45<00:14, 14.94s/it]

Current accuracy: 0.68/1.0
Current score: 402/500


100%|██████████| 50/50 [13:01<00:00, 15.63s/it]


In [37]:
import json

# Persist the baseline run, including every per-question answer, so it can be
# compared with later runs without re-querying the API.
with open("mistral_base_model_test_info.json", "w", encoding="utf-8") as json_file:
    json_file.write(json.dumps(base_model_test_info, ensure_ascii=False, indent=4))

In [38]:
# Save the exact question subset so other runs can be scored on the same items.
with open("mistral_MATH50_ds.json", "w", encoding="utf-8") as json_file:
    json_file.write(json.dumps(MATH50_ds, ensure_ascii=False, indent=4))

### Load the saved question subset (JSON)

In [10]:
import json

# Reload the saved question subset so the reasoning agent is scored on the same
# questions, even in a fresh kernel.
with open("mistral_MATH50_ds.json", "r", encoding="utf-8") as json_file:
    MATH50_ds = json.load(json_file)

### Evaluate the reasoning agent

In [11]:
def reasonable_wrapper(messages):
    """Adapt the reasoning agent to the chat-function interface used by the evaluator.

    Only the ``"content"`` of the last message is forwarded to ``agent.reason``;
    earlier messages are ignored. Returns the ``"answer"`` entry of the
    agent's result.
    """
    last_content = messages[-1]["content"]
    return agent.reason(last_content)["answer"]

In [None]:
# Evaluate the reasoning agent. The judge is `call_mini`, the same judge used
# for the baseline run, so the two scores are directly comparable; previously
# the agent graded its own answers.
reasoning_model_test_info = {
    "total": 0,
    "score": 0,
    "answers": []
}
# (question text, repr(exception)) for each question that raised. These are
# reported rather than silently dropped.
reasoning_failures = []

max_total = 50
max_score = 50 * 10

for question in tqdm(MATH50_ds):
    try:
        foo = evaluate_llm_response(question, reasonable_wrapper, call_mini)
    except Exception as exc:
        # Keep the long run going, but record why this question was skipped.
        # A bare `except:` also swallowed KeyboardInterrupt, making the loop hard
        # to stop.
        reasoning_failures.append((question["question"], repr(exc)))
        print(f"Failed: {exc!r}")
    else:
        reasoning_model_test_info["total"] += foo["verdict"]
        reasoning_model_test_info["score"] += foo["score"]
        reasoning_model_test_info["answers"].append(foo)

    n_answered = len(reasoning_model_test_info["answers"])
    # Guard: if the first question(s) fail, n_answered is 0 and the division
    # raised ZeroDivisionError, aborting the whole run.
    accuracy = reasoning_model_test_info["total"] / n_answered if n_answered else 0.0
    print(f"Current accuracy: {accuracy}/1.0")
    print(f"Current score: {reasoning_model_test_info['score']}/{10*n_answered}")
    print(f"Failed so far: {len(reasoning_failures)}")

  0%|          | 0/50 [00:00<?, ?it/s]

Parsed from assistant:
  reasoning: This problem requires us to evaluate the logarithm base 2 of 64. To solve this, we need to find the exponent to which the base 2 must be raised to yield 64. Let's start by considering the properties of logarithms and exponential functions. First, I will use my knowledge of logarithms and exponentials to estimate the answer.
  next_action: continue
Parsed from assistant:
  reasoning: To evaluate $\log_264$, we need to find the exponent to which the base 2 must be raised to produce 64. This is a problem that can be solved by applying the definition of logarithms. Let's start by considering the prime factorization of 64 to determine the power of 2 that gives 64. The prime factorization of 64 is $2^6 = 64$.

We know that $2^6 = 64$, so $\log_264 = 6$. We also know that $2^5 = 32$ and $2^7 = 128$. Therefore, $6$ is the correct exponent.

To confirm this, we can use another method. We can use the change of base formula which states that $\log_b a = \frac{\