In [1]:
from vllm import LLM, SamplingParams
from datasets import load_dataset
import pandas as pd
from xml.etree import ElementTree as ET
import re
import json

## get gsm8k dataset

In [2]:
# Load the "main" configuration of openai/gsm8k through `datasets.load_dataset`
# and keep its 'train' split as a pandas DataFrame. Later cells read the
# 'question' and 'answer' columns from `data`.
dataset = load_dataset("openai/gsm8k", "main")
data = dataset['train'].to_pandas()

In [3]:
# Split every GSM8K solution into its reasoning and its final answer.
# A solution looks like "<reasoning>#### <final answer>": the text before the
# first "#### " is stored in 'cot' and the text after it replaces 'answer'.
# NOTE: 'answer' is overwritten, so this cell can only run once per load; a
# second run raises ValueError because the delimiter is no longer present.
solution_parts = data["answer"].str.partition("#### ")

# Column 1 holds the separator itself; anything else means it was not found.
missing_delim = solution_parts[1] != "#### "
if missing_delim.any():
    raise ValueError(
        f"{int(missing_delim.sum())} solution(s) lack the '#### ' answer delimiter"
    )

data["cot"] = solution_parts[0]
data["answer"] = solution_parts[2]

In [51]:
# Persist the frame with the added 'cot' / final 'answer' columns, without the index.
# NOTE(review): this writes to the working directory, while the generator-SFT
# section reads './gsm8k/gsm8k_train.csv' - confirm the file is moved in between.
data.to_csv('gsm8k_train.csv', index=False)

## parse the gpt-4o evaluator data

In [None]:
# Rebind `data` to the GPT-4o evaluation table. The cells below read its
# 'response', 'evaluate' and 'answer' columns.
# NOTE(review): absolute cluster path - not portable to other machines.
data = pd.read_csv("/scratch/hd2584/llm_marl/sft/data/gpt-4o/gsm8k_train_eval.csv")

In [7]:
def extract_tags(text, agent='generator'):
    """
    Collect the contents of the tag pairs an agent is expected to emit.

    The generator writes <think>/<answer> pairs; any other `agent` value is
    treated as the evaluator, which writes <evaluate>/<verify> pairs.

    Args:
        text: raw model output to scan.
        agent: 'generator' selects think/answer, anything else selects
            evaluate/verify.

    Returns:
        dict mapping each tag name to the list of whitespace-stripped inner
        texts of every complete <tag>...</tag> pair, in order of appearance.
        A tag that never appears, or is never closed, maps to an empty list.
    """
    tag_names = ("think", "answer") if agent == 'generator' else ("evaluate", "verify")

    found = {}
    for name in tag_names:
        # Non-greedy so consecutive pairs stay separate; DOTALL so the inner
        # text may span several lines.
        tag_re = re.compile(rf"<{name}>(.*?)</{name}>", re.DOTALL)
        found[name] = [inner.strip() for inner in tag_re.findall(text)]
    return found

In [8]:
def _normalize_evaluation(evaluate):
    """Repair the tag markup of one raw evaluator completion.

    Steps, in order:
      1. rename <answer>/</answer> to <verify>/</verify>;
      2. fix closing tags written with a stray space ("< /evaluate>", "< /verify>");
      3. if no 'evaluate>' tag is present at all, close the evaluation right
         before <verify>;
      4. drop the "<<" / ">>" markers.

    The renames run before step 3 so that a completion which used <answer>
    and forgot </evaluate> is repaired as well; checking first (as the loop
    did previously) left such completions without a closing tag.
    """
    evaluate = evaluate.replace("<answer>", "<verify>").replace("</answer>", "</verify>")
    evaluate = evaluate.replace("< /evaluate>", "</evaluate>").replace("< /verify>", "</verify>")
    if 'evaluate>' not in evaluate:
        evaluate = evaluate.replace("<verify>", "</evaluate> <verify>")
    return evaluate.replace("<<", "").replace(">>", "")


# prepare the training data: one vectorised pass instead of per-row .loc writes
data['evaluate'] = data['evaluate'].map(_normalize_evaluation)

In [12]:
# accuracy
# Score each generator response against the gold answer and cross-tabulate the
# result with the evaluator's verdict. Writes two columns back to `data`:
#   res_correct  - 1.0 if any <answer> contains the gold answer, else 0.0
#   eval_correct - 1.0 if the verdict agrees with res_correct, 0.0 if not,
#                  NaN when the verdict is missing or unrecognised
correct_cnt = 0
gen_c_ver_c, gen_c_ver_w, gen_w_ver_c, gen_w_ver_w, no_ver = 0, 0, 0, 0, 0
# Verdicts that are neither 'correct' nor 'wrong'. They used to fall through
# every branch uncounted, so the printed buckets did not add up to len(data).
bad_ver = 0

res_flags = []
eval_flags = []

for response, evaluate, answer in zip(data['response'], data['evaluate'], data['answer']):
    # Both prompts end with an opening tag, so it is re-attached before parsing.
    # extract_tags is regex based and does not use ElementTree, so the former
    # `except ET.ParseError` guards could never trigger and were removed.
    tags_gen = extract_tags("<think>" + response.replace("<<", "").replace(">>", ""))
    tags_eval = extract_tags("<evaluate>" + evaluate, agent='evaluator')

    # Token-level match after stripping '$', 'ml' and a trailing '.00'.
    gen_is_correct = any(
        answer in attempt.replace('$', '').replace('ml', '').replace('.00', '').split()
        for attempt in tags_gen["answer"]
    )
    res_flags.append(1 if gen_is_correct else 0)
    if gen_is_correct:
        correct_cnt += 1

    if len(tags_eval["verify"]) == 0:
        no_ver += 1
        eval_flags.append(None)
        continue

    # Only the first <verify> block is considered.
    verification = tags_eval["verify"][0].lower()
    if verification not in ('correct', 'wrong'):
        bad_ver += 1
        eval_flags.append(None)
        continue

    says_correct = verification == 'correct'
    if gen_is_correct and says_correct:
        gen_c_ver_c += 1      # generator correct, evaluator correct
    elif gen_is_correct:
        gen_c_ver_w += 1      # generator correct, evaluator wrong
    elif not says_correct:
        gen_w_ver_c += 1      # generator wrong, evaluator correct
    else:
        gen_w_ver_w += 1      # generator wrong, evaluator wrong
    eval_flags.append(1 if says_correct == gen_is_correct else 0)

# dtype=float keeps the 1.0 / 0.0 / NaN encoding the scalar .loc writes produced.
data['res_correct'] = pd.Series(res_flags, index=data.index, dtype=float)
data['eval_correct'] = pd.Series(eval_flags, index=data.index, dtype=float)

# Every row must land in exactly one bucket.
assert gen_c_ver_c + gen_c_ver_w + gen_w_ver_c + gen_w_ver_w + no_ver + bad_ver == len(data)

print('- Accuracy:', f'{correct_cnt/len(data):.4f}')
print('- Generator correct, evaluator correct:', gen_c_ver_c, f'{gen_c_ver_c/len(data):.4f}')
print('- Generator correct, evaluator wrong:', gen_c_ver_w, f'{gen_c_ver_w/len(data):.4f}')
print('- Generator wrong, evaluator correct:', gen_w_ver_c, f'{gen_w_ver_c/len(data):.4f}')
print('- Generator wrong, evaluator wrong:', gen_w_ver_w, f'{gen_w_ver_w/len(data):.4f}')
print('- No verification:', no_ver, f'{no_ver/len(data):.4f}')
print('- Unrecognized verification:', bad_ver, f'{bad_ver/len(data):.4f}')

- Accuracy: 0.7274
- Generator correct, evaluator correct: 4707 0.6299
- Generator correct, evaluator wrong: 728 0.0974
- Generator wrong, evaluator correct: 1135 0.1519
- Generator wrong, evaluator wrong: 896 0.1199
- No verification: 0 0.0000


In [13]:
# Total number of scored rows (the denominator of the rates printed above).
len(data)

7473

## get evaluator sft dataset

In [14]:
# Keep only the rows where the evaluator's verdict agreed with the generator's
# actual correctness; rows with a missing eval_correct are dropped as well.
verdict_is_right = data['eval_correct'].eq(1)
df = data[verdict_is_right]
df.shape[0]

5842

In [15]:
# Reproducible 80/20 split: sample the training rows, the remainder is the test set.
df_train = df.sample(frac=0.8, random_state=42)
df_test = df.drop(index=df_train.index)

In [16]:
# Class balance of the training split (1.0 = generator answered correctly).
df_train['res_correct'].value_counts()

res_correct
1.0    3769
0.0     905
Name: count, dtype: int64

In [17]:
# Renumber the training rows 0..n-1, discarding the labels left over from sampling.
df_train = df_train.reset_index(drop=True)

In [180]:
# Write the evaluator training split to the working directory, without the index.
# NOTE(review): later cells read './eval_sft_data/eval_sft_train.csv' - confirm
# the file is moved there, or align the two paths.
df_train.to_csv('eval_sft_train.csv', index=False)

In [18]:
# Class balance of the held-out split (1.0 = generator answered correctly).
df_test['res_correct'].value_counts()

res_correct
1.0    938
0.0    230
Name: count, dtype: int64

In [177]:
# Renumber the held-out rows 0..n-1, discarding the labels inherited from `df`.
df_test = df_test.reset_index(drop=True)

In [181]:
# Write the evaluator test split to the working directory, without the index.
df_test.to_csv('eval_sft_test.csv', index=False)

In [19]:
# Rebind `data` to the evaluator training split read back from disk.
# NOTE(review): the split was saved as 'eval_sft_train.csv' in the working
# directory; this reads it from './eval_sft_data/' - confirm the location.
data = pd.read_csv('./eval_sft_data/eval_sft_train.csv')

In [20]:
# Generator prompt template: %s receives the question; the prompt ends with an
# opening <think> tag for the model to continue from.
BASE_PROMPT = (
    "A conversation between User and Assistant. The user asks a question, and the Assistant solves it. "
    "The assistant first thinks about the reasoning process in the mind and then provides the user with the answer. "
    "The reasoning process and answer are enclosed within <think></think> and <answer></answer> tags, respectively, "
    "i.e., <think>reasoning process here</think> <answer>answer here</answer>. User: %s. Assistant: <think>"
)


# Evaluator prompt template: %s receives the generator prompt followed by the
# generator's response; the prompt ends with an opening <evaluate> tag.
EVALUATION_PROMPT = (
    "%s. "
    "The Assistant evaluates the problem and the solution. The assistant first evaluates and finds out where the solution is wrong. "
    "Then the assistant provides the verification. "
    "The evaluation process and verification are enclosed within <evaluate></evaluate> and <verify></verify> tags, respectively, "
    "i.e., <evaluate>evaluation process here</evaluate> <verify>verification here</verify>. Verification is limited to 'correct' or 'wrong'. Assistant:<evaluate>"
)

data['gen_prompt'] = data['question'].map(lambda question: BASE_PROMPT % question)
# The evaluator sees the full generator transcript (prompt + response).
data['eval_prompt'] = [
    EVALUATION_PROMPT % (gen_prompt + response)
    for gen_prompt, response in zip(data['gen_prompt'], data['response'])
]

In [21]:
# Number of evaluator training examples after the prompt columns were added.
len(data)

4674

In [None]:
# Export the evaluator SFT examples as a JSON list with one object per row.
# The records get their own name so that `data` stays a DataFrame and the cell
# can be re-run; rebinding `data` to a list made a second run fail on .to_dict.
eval_sft_records = data.to_dict(orient="records")
path = "/scratch/hd2584/llm_marl/sft/data/eval_sft_data/eval_sft_train.json"
with open(path, "w", encoding="utf-8") as f:
    # ensure_ascii=False writes non-ASCII characters as-is instead of \u escapes.
    json.dump(eval_sft_records, f, indent=4, ensure_ascii=False)

## get generator sft dataset

In [48]:
# Rebind `data` to the GSM8K training table; the next cell reads its
# 'question', 'cot' and 'answer' columns.
# NOTE(review): the first section saved this file as 'gsm8k_train.csv' in the
# working directory, not under './gsm8k/' - confirm the location.
data = pd.read_csv('./gsm8k/gsm8k_train.csv')

In [49]:
# Generator prompt template: %s receives the question; the prompt ends with an
# opening <think> tag, so the target starts directly with the reasoning.
BASE_PROMPT = (
    "A conversation between User and Assistant. The user asks a question, and the Assistant solves it. "
    "The assistant first thinks about the reasoning process in the mind and then provides the user with the answer. "
    "The reasoning process and answer are enclosed within <think></think> and <answer></answer> tags, respectively, "
    "i.e., <think>reasoning process here</think> <answer>answer here</answer>. User: %s. Assistant: <think>"
)

data['prompt'] = data['question'].map(lambda question: BASE_PROMPT % question)
# Target = gold reasoning, closing </think>, then the gold answer in <answer> tags.
data['target'] = [
    f"{cot}</think> <answer>{answer}</answer>"
    for cot, answer in zip(data['cot'], data['answer'])
]

In [None]:
# Export the generator SFT examples as a JSON list with one object per row.
# The records get their own name so that `data` stays a DataFrame and the cell
# can be re-run; rebinding `data` to a list made a second run fail on .to_dict.
gen_sft_records = data.to_dict(orient="records")
path = "/scratch/hd2584/llm_marl/sft/data/gen_sft_data/gen_sft_train.json"
with open(path, "w", encoding="utf-8") as f:
    # ensure_ascii=False writes non-ASCII characters as-is instead of \u escapes.
    json.dump(gen_sft_records, f, indent=4, ensure_ascii=False)

## get mixed sft dataset

In [31]:
# Rebind `data` to the evaluator training split; the mixed dataset below is
# built from its 'question', 'response', 'res_correct', 'cot', 'answer' and
# 'evaluate' columns.
data = pd.read_csv('./eval_sft_data/eval_sft_train.csv')

In [23]:
# Generator prompt template: %s receives the question; the prompt ends with an
# opening <think> tag for the model to continue from.
BASE_PROMPT = (
    "A conversation between User and Assistant. The user asks a question, and the Assistant solves it. "
    "The assistant first thinks about the reasoning process in the mind and then provides the user with the answer. "
    "The reasoning process and answer are enclosed within <think></think> and <answer></answer> tags, respectively, "
    "i.e., <think>reasoning process here</think> <answer>answer here</answer>. User: %s. Assistant: <think>"
)


# Evaluator prompt template: %s receives the generator prompt followed by the
# generator's response; the prompt ends with an opening <evaluate> tag.
EVALUATION_PROMPT = (
    "%s. "
    "The Assistant evaluates the problem and the solution. The assistant first evaluates and finds out where the solution is wrong. "
    "Then the assistant provides the verification. "
    "The evaluation process and verification are enclosed within <evaluate></evaluate> and <verify></verify> tags, respectively, "
    "i.e., <evaluate>evaluation process here</evaluate> <verify>verification here</verify>. Verification is limited to 'correct' or 'wrong'. Assistant:<evaluate>"
)

data['gen_prompt'] = data['question'].map(lambda question: BASE_PROMPT % question)
# The evaluator sees the full generator transcript (prompt + response).
data['eval_prompt'] = [
    EVALUATION_PROMPT % (gen_prompt + response)
    for gen_prompt, response in zip(data['gen_prompt'], data['response'])
]

In [25]:
# Balance the two classes: keep every generator-wrong row and draw the same
# number of generator-correct rows with a fixed seed.
is_wrong = data['res_correct'].eq(0)
is_right = data['res_correct'].eq(1)

data_wrong = data[is_wrong]
n = data_wrong.shape[0]
data_correct = data[is_right].sample(n, random_state=42)
data_final = pd.concat([data_correct, data_wrong])

In [26]:
# Per-class size of the balanced sample (the number of generator-wrong rows).
n

905

In [27]:
# Shuffle the balanced rows with a fixed seed (frac=1 keeps every row), then
# renumber them 0..n-1.
shuffled = data_final.sample(frac=1, random_state=42)
data_final = shuffled.reset_index(drop=True)

In [28]:
# Half A - generator task: question prompt -> gold reasoning and answer.
data_final_a = data_final.copy()
data_final_a['prompt'] = data_final_a['gen_prompt']
data_final_a['target'] = [
    f"{cot}</think> <answer>{answer}</answer>"
    for cot, answer in zip(data_final_a['cot'], data_final_a['answer'])
]

# Half B - evaluator task: evaluation prompt -> cleaned evaluation text.
# Both halves are copies of the same frame, so reading the source columns from
# half B itself yields the same values as reading them from half A.
data_final_b = data_final.copy()
data_final_b['prompt'] = data_final_b['eval_prompt']
data_final_b['target'] = data_final_b['evaluate']

In [29]:
# Stack the generator half on top of the evaluator half and renumber the rows;
# the displayed length is twice the balanced sample size.
data_final = pd.concat([data_final_a, data_final_b], ignore_index=True)
data_final.shape[0]

3620

In [None]:
# Export the mixed generator/evaluator SFT examples as a JSON list with one
# object per row. The records get their own name so that `data_final` stays a
# DataFrame and the cell can be re-run; rebinding it to a list made a second
# run fail on .to_dict.
mix_sft_records = data_final.to_dict(orient="records")
path = "/scratch/hd2584/llm_marl/sft/data/eval_sft_data/mix_sft_train.json"
with open(path, "w", encoding="utf-8") as f:
    # ensure_ascii=False writes non-ASCII characters as-is instead of \u escapes.
    json.dump(mix_sft_records, f, indent=4, ensure_ascii=False)

In [None]:
# Rebind `data_final` to a DataFrame read back from the mixed SFT JSON file.
# NOTE(review): presumably a round-trip check of the export above - confirm.
data_final = pd.read_json("/scratch/hd2584/llm_marl/sft/data/eval_sft_data/mix_sft_train.json")