In [1]:
from vllm import LLM, SamplingParams
from datasets import load_dataset
import pandas as pd
from xml.etree import ElementTree as ET
import re
import json

In [None]:
# Read the eval-SFT training JSON and tabulate how often each `res_correct`
# label occurs (the value counts are the cell's displayed result).
data = pd.read_json("/scratch/hd2584/llm_marl/sft/data/eval_sft_data/eval_sft_train.json")
data["res_correct"].value_counts()

res_correct
1    3769
0     905
Name: count, dtype: int64

In [2]:
def extract_tags(text, agent='generator'):
    """
    Collect the contents of XML-like tag pairs found in ``text`` using regex.

    Which tags are searched depends on ``agent``:
      - 'generator'                     -> "think" and "answer"
      - any other value (the evaluator) -> "evaluate" and "verify"

    Returns a dict mapping each tag name to the list of inner texts of its
    ``<tag>...</tag>`` occurrences, in order of appearance and with
    surrounding whitespace stripped. A tag that is absent, or opened but
    never closed, maps to an empty list.
    """
    tag_names = ("think", "answer") if agent == 'generator' else ("evaluate", "verify")

    def _inner_texts(tag):
        # Non-greedy body so consecutive pairs stay separate; DOTALL lets a
        # body span several lines.
        found = re.findall(rf"<{tag}>(.*?)</{tag}>", text, flags=re.DOTALL)
        return [piece.strip() for piece in found]

    return {tag: _inner_texts(tag) for tag in tag_names}

In [13]:
# Results file analysed by the cells below. NOTE: this rebinds `data`, replacing
# whatever an earlier cell loaded under the same name.
# Alternative input (swap with the active line to analyse it instead):
# data = pd.read_csv("/scratch/hd2584/llm_marl/results/test_eval/final_test_base.csv")
data = pd.read_csv("/scratch/hd2584/llm_marl/results/final_eval/final_train_gen_sft.csv")

## generator

In [14]:
# Generator accuracy.
# A row counts as correct when the reference answer appears as a
# whitespace-delimited token inside any <answer>...</answer> span of the
# response (after dropping '$', 'ml' and '.00' from the span).
# Writes the 0/1 outcome to data['res_correct'] and prints the overall rate.
correct_cnt = 0

for idx, row in data.iterrows():
    response = row['response']
    # The candidate tokens are strings, so compare the reference as text; a
    # numeric `answer` cell would otherwise never match.
    answer = str(row['answer'])

    # extract_tags is regex-based and cannot raise ET.ParseError. The failure
    # that can actually occur is a non-string response (e.g. a missing value),
    # which simply has no tags to extract.
    if isinstance(response, str):
        # The responses shown in this notebook contain a closing </think> but
        # no opening tag, hence the prepended "<think>". The "<<" / ">>"
        # markers are removed before matching.
        tags_gen = extract_tags("<think>" + response.replace("<<", "").replace(">>", ""))
    else:
        tags_gen = {"think": [], "answer": []}

    gen_is_correct = any(
        answer in attempt.replace('$', '').replace('ml', '').replace('.00', '').split()
        for attempt in tags_gen["answer"]
    )
    data.loc[idx, 'res_correct'] = int(gen_is_correct)
    correct_cnt += gen_is_correct

print('- Accuracy:', correct_cnt/len(data))

- Accuracy: 0.714177407126611


In [15]:
# Print the first few rows (reference answer, reference chain of thought and
# generator response) for a manual sanity check.
N_EXAMPLES = 15  # number of leading rows to show

for idx, row in data.head(N_EXAMPLES).iterrows():
    print('-answer:', row['answer'])
    print('-cot:', row['cot'])
    print('-generator:', row['response'])
    print('----')

-answer: 18
-cot: Janet sells 16 - 3 - 4 = <<16-3-4=9>>9 duck eggs a day.
She makes 9 * 2 = $<<9*2=18>>18 every day at the farmer’s market.

-generator: Janet’s ducks lay 16 eggs per day. She eats 3 for breakfast every morning. So, she has 16 - 3 = 13 eggs left. She bakes muffins for her friends every day with 4 eggs. So, she has 13 - 4 = 9 eggs left. She sells the remainder at the farmers' market daily for $2 per fresh duck egg. So, she makes 9 * $2 = $18 every day at the farmers' market.</think> <answer>$18</answer>
----
-answer: 3
-cot: It takes 2/2=<<2/2=1>>1 bolt of white fiber
So the total amount of fabric is 2+1=<<2+1=3>>3 bolts of fabric

-generator: First, we need to find out how much white fiber is needed. Since the robe takes half as much white fiber as blue fiber, we can calculate the white fiber by dividing the blue fiber by 2. So, the white fiber is 2/2 = 1 bolt. Now, to find the total bolts of fiber needed, we add the blue and white fiber together. So, the total bolts of

## evaluator

In [16]:
def _normalize_evaluation(text):
    """Rewrite one raw evaluator output into the <evaluate>/<verify> tag format.

    Steps, in order:
      1. rename <answer>...</answer> to <verify>...</verify>;
      2. repair the "< /evaluate>" and "< /verify>" misspellings;
      3. if no evaluate tag is present at all, close the evaluation right
         before each <verify>;
      4. drop the "<<" / ">>" markers.

    The renames run before step 3 so that a verdict originally written as
    <answer> also receives the inserted </evaluate>. Doing step 3 first left
    such rows unclosed on the first run and only fixed them on a re-run.
    Non-string values (e.g. missing cells) are returned unchanged.
    """
    if not isinstance(text, str):
        return text
    text = text.replace("<answer>", "<verify>").replace("</answer>", "</verify>")
    text = text.replace("< /evaluate>", "</evaluate>").replace("< /verify>", "</verify>")
    if 'evaluate>' not in text:
        text = text.replace("<verify>", "</evaluate> <verify>")
    return text.replace("<<", "").replace(">>", "")


# Prepare the training data: overwrite the column with the normalised text.
# The transformation is idempotent, so re-running this cell is harmless.
data['evaluate'] = data['evaluate'].map(_normalize_evaluation)

In [17]:
# Joint generator / evaluator accuracy.
#
# Counter naming:
#   gen_c / gen_w : the generator's answer is correct / wrong
#   ver_c / ver_w : the evaluator's verdict is right / mistaken about that
#                   (e.g. gen_w_ver_c = wrong answer, evaluator said 'wrong')
#   no_ver        : the evaluator output contains no <verify>...</verify> span
#   other_ver     : a verdict exists but is neither 'correct' nor 'wrong'
#
# Side effects: writes data['res_correct'] for every row and
# data['eval_correct'] for rows with a recognised verdict (other rows are
# left untouched).
correct_cnt = 0
gen_c_ver_c, gen_c_ver_w, gen_w_ver_c, gen_w_ver_w, no_ver = 0, 0, 0, 0, 0
other_ver = 0

for idx, row in data.iterrows():
    response = row['response']
    evaluate = row['evaluate']
    # Candidate tokens are strings, so compare the reference answer as text.
    answer = str(row['answer'])

    # extract_tags is regex-based and cannot raise ET.ParseError; guard against
    # non-string cells (e.g. missing values) instead, treating them as tag-less.
    if isinstance(response, str):
        tags_gen = extract_tags("<think>" + response.replace("<<", "").replace(">>", ""))
    else:
        tags_gen = {"think": [], "answer": []}

    if isinstance(evaluate, str):
        tags_eval = extract_tags("<evaluate>" + evaluate, agent='evaluator')
    else:
        tags_eval = {"evaluate": [], "verify": []}

    gen_is_correct = any(
        answer in attempt.replace('$', '').replace('ml', '').replace('.00', '').split()
        for attempt in tags_gen["answer"]
    )
    data.loc[idx, 'res_correct'] = int(gen_is_correct)
    correct_cnt += gen_is_correct

    if len(tags_eval["verify"]) == 0:
        no_ver += 1
        continue

    # Only the first <verify> span is considered.
    verification = tags_eval["verify"][0].lower()
    if verification not in ('correct', 'wrong'):
        # Previously such rows were counted nowhere, so the printed fractions
        # did not add up to 1. Count them explicitly.
        other_ver += 1
        continue

    said_correct = verification == 'correct'
    if gen_is_correct and said_correct:
        gen_c_ver_c += 1
        data.loc[idx, 'eval_correct'] = 1
    elif gen_is_correct:
        gen_c_ver_w += 1
        data.loc[idx, 'eval_correct'] = 0
    elif not said_correct:
        gen_w_ver_c += 1
        data.loc[idx, 'eval_correct'] = 1
    else:
        gen_w_ver_w += 1
        data.loc[idx, 'eval_correct'] = 0


print('- Accuracy:', f'{correct_cnt/len(data):.4f}')
print('- Generator correct, evaluator correct:', gen_c_ver_c, f'{gen_c_ver_c/len(data):.4f}')
print('- Generator correct, evaluator wrong:', gen_c_ver_w, f'{gen_c_ver_w/len(data):.4f}')
print('- Generator wrong, evaluator correct:', gen_w_ver_c, f'{gen_w_ver_c/len(data):.4f}')
print('- Generator wrong, evaluator wrong:', gen_w_ver_w, f'{gen_w_ver_w/len(data):.4f}')
print('- No verification:', no_ver, f'{no_ver/len(data):.4f}')
print('- Unrecognized verdict:', other_ver, f'{other_ver/len(data):.4f}')

- Accuracy: 0.7142
- Generator correct, evaluator correct: 578 0.4382
- Generator correct, evaluator wrong: 345 0.2616
- Generator wrong, evaluator correct: 202 0.1531
- Generator wrong, evaluator wrong: 163 0.1236
- No verification: 26 0.0197


In [18]:
# Bare fractions of the dataset for each outcome counter, one per line, in the
# order: gen_c_ver_c, gen_c_ver_w, gen_w_ver_c, gen_w_ver_w, no_ver.
for count in (gen_c_ver_c, gen_c_ver_w, gen_w_ver_c, gen_w_ver_w, no_ver):
    print(f'{count/len(data):.4f}')

0.4382
0.2616
0.1531
0.1236
0.0197


## multi-round data

In [19]:
# Multi-round results; rebinds `data` for the cells of this section.
# (Plain string literal: the path has no placeholders, so no f-string is needed.)
data = pd.read_csv("/scratch/hd2584/llm_marl/results/final_eval_multi_round/final_test_gen_sft.csv")

In [20]:
# Extract the evaluator's verdict for every row into data['verification'].
# The verdict is the lower-cased text of the first <verify>...</verify> span;
# rows without such a span get NaN and are tallied in `no_ver`.
no_ver = 0
verdicts = []

for idx, row in data.iterrows():
    evaluate = row['evaluate']

    # extract_tags is regex-based and cannot raise ET.ParseError; a non-string
    # cell (e.g. a missing value) is treated as having no tags.
    if isinstance(evaluate, str):
        tags_eval = extract_tags("<evaluate>" + evaluate, agent='evaluator')
    else:
        tags_eval = {"evaluate": [], "verify": []}

    if len(tags_eval["verify"]) == 0:
        no_ver += 1
        verdicts.append(float('nan'))
    else:
        verdicts.append(tags_eval["verify"][0].lower())

# Assign the column in one go so that it always exists (even when no row has a
# verdict) and never keeps stale values from a previous run or from the CSV.
data['verification'] = verdicts

In [21]:
# Multi-round selection: take the second-round response where the evaluator's
# verdict was exactly 'wrong'; otherwise (including rows with no verdict) keep
# the first-round response.
use_second_round = data['verification'] == 'wrong'
data['response_multi'] = data['response_2'].where(use_second_round, data['response'])

In [22]:
# Number of rows whose evaluator verdict is exactly 'wrong'.
int((data['verification'] == 'wrong').sum())

643

In [23]:
# Number of rows whose second-round response differs from the first-round one.
# The element-wise comparison does not rely on the index being 0..n-1, unlike
# label lookups with data[col][i]. As before, a pair of missing values counts
# as "different" because NaN != NaN.
cnt = int((data['response'] != data['response_2']).sum())
cnt

1247

In [25]:
# Accuracy of the multi-round responses (data['response_multi']).
# A row counts as correct when the reference answer appears as a
# whitespace-delimited token inside any <answer>...</answer> span of the
# response (after dropping '$', 'ml' and '.00' from the span).
# Overwrites data['res_correct'] with the 0/1 outcome.
correct_cnt = 0

for idx, row in data.iterrows():
    response = row['response_multi']
    # Candidate tokens are strings, so compare the reference answer as text.
    answer = str(row['answer'])

    # Explicit type check instead of a bare `except:`: only a non-string
    # response (e.g. a missing value) is treated as "no tags", while genuine
    # errors and KeyboardInterrupt are no longer swallowed.
    if isinstance(response, str):
        tags_gen = extract_tags("<think>" + response.replace("<<", "").replace(">>", ""))
    else:
        tags_gen = {"think": [], "answer": []}

    gen_is_correct = any(
        answer in attempt.replace('$', '').replace('ml', '').replace('.00', '').split()
        for attempt in tags_gen["answer"]
    )
    data.loc[idx, 'res_correct'] = int(gen_is_correct)
    correct_cnt += gen_is_correct

print('- Accuracy:', correct_cnt/len(data))

- Accuracy: 0.5655799848369977


## example

In [59]:
# Results file used for the worked example below; rebinds `data`.
data = pd.read_csv("/scratch/hd2584/llm_marl/results/final_eval/final_test_selfplay.csv")

In [60]:
# Show the generator response and the evaluator output of the first row, each
# under a banner line. Further columns can be inspected by adding their names
# to the tuple (the earlier draft also looked at 'question', 'cot' and
# 'answer' - assuming the file has them).
row = data.iloc[0]
for column in ('response', 'evaluate'):
    print(f'----- {column}: -----')
    print(row[column])

----- response: -----
 
Let’s think through this problem step by step. 
1. First, we need to determine how many eggs are left after Janet eats and bakes with her ducks' daily production. 
2. The ducks lay 16 eggs per day. 
3. Janet eats 3 eggs for breakfast each day and uses 4 eggs to make muffins. 
4. The total number of eggs she consumes and uses for muffins is 3 + 4 = <<3+4=7>>7 eggs.
5. The number of eggs left after she consumes and uses some is 16 - 7 = <<16-7=9>>9 eggs.
6. Janet sells each of the remaining eggs for $2. 
7. Therefore, her total sales would be 9 * 2 = $<<9*2=18>>18 dollars per day.
</think> <answer>18</answer>
----- evaluate: -----
 
Let's evaluate the problem step by step to identify any errors in the solution provided. 

1. The ducks lay 16 eggs per day.
2. Janet eats 3 eggs for breakfast and uses 4 eggs for muffins. This totals to 3 + 4 = 7 eggs.
3. The number of eggs left after consumption is 16 - 7 = 9 eggs.
4. Janet sells each of the remaining eggs for $2. Th