In [55]:
from mr_eval.utils.utils import *
import os
from copy import deepcopy
import numpy as np

# Directory holding the classified PRM test split as *.jsonl files.
# Set the PRMTEST_DATA_DIR environment variable to point elsewhere instead of
# editing this absolute, machine-specific path; the original path stays the default.
data_dir = os.environ.get(
    "PRMTEST_DATA_DIR",
    "/mnt/petrelfs/songmingyang/code/reasoning/MR_Hallucination/mr_eval/tasks/prmtest_classified/data",
)
# NOTE(review): not referenced by any cell below — confirm whether still needed.
dataset_type = "dir_of_jsonl"

# Raw classification name (value of each record's "classification" field) ->
# short column label used in the printed tables.
# "redundency" is misspelled, but the NR. column in the outputs below is populated,
# so it presumably matches the spelling in the data files — do not "fix" it here alone.
classification_name_dict = dict(
    domain_inconsistency="DC.",
    redundency="NR.",
    multi_solutions="MS.",
    deception="DR.",
    confidence="CI.",
    step_contradiction="SC.",
    circular="NCL.",
    missing_condition="PS.",
    counterfactual="ES."
)
# Column order for the printed tables (bound as print_res_to_excel's default).
classifications = ["redundency", "circular", "counterfactual", "step_contradiction", "domain_inconsistency",  "confidence", "missing_condition", "deception", "multi_solutions", ]


In [58]:


def get_steps_info(raw_data,classifications):
    """Average total-step and error-step counts, overall and per classification.

    Args:
        raw_data: iterable of record dicts. Each record is read for
            "modified_process" and "error_steps" (only their ``len`` is used)
            and "classification".
        classifications: classification names to report. Every record's
            "classification" must be one of them, otherwise ``KeyError``.

    Returns:
        ``(total, per_classification)``. ``total`` is
        ``{"total_step_length": mean, "error_step_length": mean}`` over all
        records. ``per_classification`` maps each name to a dict with the same
        keys. A mean over zero records is reported as ``-1``. Lengths are
        never negative, so -1 cannot be a real mean.
    """
    meta_res_sample = dict(total_step_length=[], error_step_length=[],)
    total_list = deepcopy(meta_res_sample)
    classification_list = {classification: deepcopy(meta_res_sample) for classification in classifications}

    for item in raw_data:
        steps_num = len(item["modified_process"])
        error_num = len(item["error_steps"])
        # Each record contributes to the overall bucket and to its own classification's bucket.
        for bucket in (total_list, classification_list[item["classification"]]):
            bucket["total_step_length"].append(steps_num)
            bucket["error_step_length"].append(error_num)

    # Collapse the collected lists to their means (-1 when nothing was collected).
    for k,v in total_list.items():
        total_list[k] = np.mean(v) if len(v) > 0 else -1
    for classification in classifications:
        for k,v in classification_list[classification].items():
            classification_list[classification][k] = np.mean(v) if len(v) > 0 else -1

    return total_list, classification_list

def get_first_error_loc(raw_data,classifications):
    """Average index of the earliest error step, overall and per classification.

    Records with an empty "error_steps" list are skipped, because they have no
    first error. The mean is taken only over records that do have one.

    Args:
        raw_data: iterable of record dicts with "error_steps" (a sequence of
            step indices) and "classification".
        classifications: classification names to report. Every record's
            "classification" must be one of them, otherwise ``KeyError``.

    Returns:
        ``(total, per_classification)``. ``total`` is
        ``{"first_error_loc": mean}``. ``per_classification`` maps each name to a
        dict with the same key. ``-1`` is reported when no record in the group
        has an error step.
    """
    meta_res_sample = dict(first_error_loc=[])
    total_list = deepcopy(meta_res_sample)
    classification_list = {classification: deepcopy(meta_res_sample) for classification in classifications}

    for item in raw_data:
        # Look the bucket up first so an unknown classification fails loudly,
        # even for records without errors.
        class_bucket = classification_list[item["classification"]]["first_error_loc"]
        error_steps = item["error_steps"]
        if len(error_steps) == 0:
            continue
        # min() instead of [0]: the earliest error does not depend on
        # error_steps being stored in ascending order.
        first_error_loc = min(error_steps)
        total_list["first_error_loc"].append(first_error_loc)
        class_bucket.append(first_error_loc)

    for k,v in total_list.items():
        total_list[k] = np.mean(v) if len(v) > 0 else -1
    for classification in classifications:
        for k,v in classification_list[classification].items():
            classification_list[classification][k] = np.mean(v) if len(v) > 0 else -1

    return total_list, classification_list

def get_quesiton_length(raw_data,classifications):
    """Mean ``len(record["modified_question"])``, overall and per classification.

    The name keeps its historical misspelling because callers use it.

    Args:
        raw_data: iterable of record dicts with "modified_question" and
            "classification".
        classifications: names to report. Every record's classification must be
            among them, otherwise ``KeyError``.

    Returns:
        ``({"question_length": mean}, {name: {"question_length": mean}})``.
        ``-1`` stands for a group with no records.
    """
    all_lengths = []
    lengths_by_class = {name: [] for name in classifications}

    for record in raw_data:
        n_chars = len(record["modified_question"])
        all_lengths.append(n_chars)
        lengths_by_class[record["classification"]].append(n_chars)

    def _avg(values):
        # Empty group -> -1 sentinel (rendered as "N/A" by the table printer).
        return np.mean(values) if values else -1

    overall = {"question_length": _avg(all_lengths)}
    per_class = {name: {"question_length": _avg(lengths)} for name, lengths in lengths_by_class.items()}
    return overall, per_class
        
def get_total_num(raw_data,classifications):
    """Count records overall and per classification.

    Args:
        raw_data: iterable of record dicts with a "classification" key.
        classifications: names to report. Every record's classification must be
            among them, otherwise ``KeyError``.

    Returns:
        ``({"total_num": n}, {name: {"total_num": n_name}})``. Classifications
        with no records have a count of 0.
    """
    per_class_counts = {name: {"total_num": 0} for name in classifications}
    seen = 0
    for seen, record in enumerate(raw_data, start=1):
        per_class_counts[record["classification"]]["total_num"] += 1
    return {"total_num": seen}, per_class_counts

def merge_res_to_base(base_total_dict, base_classification_dict, merge_total_dict, merge_classification_dict):
    """Copy metric entries from the ``merge_*`` dicts into the ``base_*`` dicts in place.

    Only classifications already present in ``base_classification_dict`` are
    merged. Each of them must exist in ``merge_classification_dict``, otherwise
    ``KeyError``. Existing keys are overwritten.

    Returns:
        The same (mutated) ``(base_total_dict, base_classification_dict)`` objects.
    """
    base_total_dict.update(merge_total_dict)
    for name, base_metrics in base_classification_dict.items():
        base_metrics.update(merge_classification_dict[name])
    return base_total_dict, base_classification_dict

### Visualization
def print_res_to_excel(total_statistic_data, classification_statistic_data, split_token="\t", return_token="\n",classifications=classifications):
    """Print a metric-by-classification table for pasting into a spreadsheet or LaTeX.

    The table has one row per metric, in the key order of ``total_statistic_data``.
    The columns are the metric name, the overall value, then one value per
    classification. Cells are joined with ``split_token`` and every line
    (header included) ends with ``return_token``. Floats are rounded to one
    decimal, the ``-1`` sentinel prints as "N/A", and every "_" in the finished
    table becomes a space.

    Args:
        total_statistic_data: ``{metric: value}`` for all records.
        classification_statistic_data: ``{classification: {metric: value}}``.
            Each classification in ``classifications`` must be present and
            have every metric key.
        split_token: column separator ("\\t" for spreadsheets, "&" for LaTeX).
        return_token: line terminator.
        classifications: column order. The default is the module-level list
            as it was when this def ran; Python binds defaults at definition
            time. Pass ``None`` to use the key order of
            ``classification_statistic_data``.
    """
    metrics = list(total_statistic_data.keys())

    if classifications is None:
        classifications = list(classification_statistic_data.keys())

    def _format_cell(value):
        # np.mean yields np.float64 (a float subclass); also accept other numpy float widths.
        if isinstance(value, (float, np.floating)):
            value = round(float(value), 1)
        # -1 is the "no data" sentinel produced by the get_* helpers.
        return "N/A" if value == -1 else value

    # Unknown classifications fall back to their raw name instead of raising KeyError.
    # NOTE(review): the first header cell says "Models" although that column holds metric names.
    header = ["Models", "Total"] + [classification_name_dict.get(c, c) for c in classifications]
    lines = [split_token.join(header)]
    for metric in metrics:
        cells = [metric, _format_cell(total_statistic_data[metric])]
        cells.extend(_format_cell(classification_statistic_data[c][metric]) for c in classifications)
        lines.append(split_token.join(str(cell) for cell in cells))

    all_res_str = "".join(line + return_token for line in lines)
    all_res_str = all_res_str.replace("_"," ")
    print(all_res_str)
    

In [59]:
# Load every *.jsonl file in data_dir. os.listdir order is arbitrary, so sort
# to keep the record order reproducible across runs/machines.
data_files = sorted(
    os.path.join(data_dir, file) for file in os.listdir(data_dir) if file.endswith(".jsonl")
)
raw_data = []
for file in data_files:
    raw_data.extend(process_jsonl(file))

# Classifications actually present in the data. Keep this under its own name.
# Rebinding the global `classifications` (the ordered column list from the
# config cell) to an unordered set would scramble the table columns whenever
# the print_res_to_excel cell is re-run, because that def binds the current
# `classifications` as its default.
data_classifications = sorted({item["classification"] for item in raw_data})

total_statistic_data = {}
classification_statistic_data = {classification: {} for classification in data_classifications}

# Each helper returns (overall, per-classification) dicts. Merge them in this
# order so the table rows come out in the same order.
for metric_fn in (get_steps_info, get_first_error_loc, get_quesiton_length, get_total_num):
    metric_total, metric_by_classification = metric_fn(raw_data, data_classifications)
    total_statistic_data, classification_statistic_data = merge_res_to_base(
        total_statistic_data, classification_statistic_data, metric_total, metric_by_classification
    )



In [60]:
# Tab-separated table (default tokens), ready to paste into a spreadsheet.
print_res_to_excel(total_statistic_data, classification_statistic_data)

Models	Total	NR.	NCL.	ES.	SC.	DC.	CI.	PS.	DR.	MS.
total step length	13.4	15.3	10.3	13.8	14.2	13.3	14.2	12.7	13.4	14.1
error step length	2.1	2.0	2.8	2.8	1.6	1.8	1.7	2.5	2.3	0.0
first error loc	7.8	7.8	4.9	8.0	9.1	6.8	11.4	6.2	8.3	N/A
question length	152.7	153.6	152.5	153.5	149.7	152.5	152.7	158.0	153.5	132.2
total num	6216	758	758	757	758	757	757	756	750	165



In [62]:
# LaTeX tabular body: "&" separates columns and each row ends with "\\" plus a newline.
print_res_to_excel(
    total_statistic_data,
    classification_statistic_data,
    split_token="&",
    return_token="\\\\\n",
)

Models&Total&NR.&NCL.&ES.&SC.&DC.&CI.&PS.&DR.&MS.\\
total step length&13.4&15.3&10.3&13.8&14.2&13.3&14.2&12.7&13.4&14.1\\
error step length&2.1&2.0&2.8&2.8&1.6&1.8&1.7&2.5&2.3&0.0\\
first error loc&7.8&7.8&4.9&8.0&9.1&6.8&11.4&6.2&8.3&N/A\\
question length&152.7&153.6&152.5&153.5&149.7&152.5&152.7&158.0&153.5&132.2\\
total num&6216&758&758&757&758&757&757&756&750&165\\

