In [3]:
import json
import pandas as pd

In [4]:
with open('llama3_70b_dpo_results.json', 'r') as f:
    data = json.load(f)

In [5]:
data.keys()

dict_keys(['llama3_70b_a_low', 'llama3_70b_o_low', 'llama3_70b_n_high', 'llama3_70b_o_high', 'llama3_70b_c_high', 'llama3_70b_n_low', 'llama3_70b_c_low', 'llama3_70b_a_high', 'llama3_70b_e_high', 'llama3_70b_e_low'])

In [6]:
def extract_metrics(data, metrics):
    results = {}
    for model in data:
        results[model] = {}
        for metric, specific_metric in metrics.items():
            for test_suite in data[model]:
                if metric in test_suite:
                    if metric == "truthfulqa":
                        value = data[model][test_suite]["truthfulqa_mc2"]["acc,none"]
                    elif metric == "gsm8k":
                        value = data[model][test_suite]["gsm8k"]["exact_match,flexible-extract"]
                    else:
                        value = data[model][test_suite][metric][specific_metric]
                    results[model][metric] = value
                    break
    return results

metrics = {
    "truthfulqa": "truthfulqa_mc2", 
    "gpqa_main_zeroshot": "acc,none", 
    "gpqa_main_n_shot": "acc,none", 
    "social_iqa": "acc,none", 
    "commonsense_qa": "acc,none", 
    "gsm8k": "exact_match,flexible-extract", 
    "mathqa": "acc,none",
    "mmlu": "acc,none",
    "piqa": "acc,none",
}

results = extract_metrics(data, metrics)

# 打印结果
for model, model_results in results.items():
    print(f'"{model}": {json.dumps(model_results)}')

"llama3_70b_a_low": {"truthfulqa": 0.5058314086539912, "gpqa_main_zeroshot": 0.35714285714285715, "gpqa_main_n_shot": 0.32142857142857145, "social_iqa": 0.39048106448311154, "commonsense_qa": 0.3923013923013923, "gsm8k": 0.8999241849886277, "mathqa": 0.32763819095477387, "mmlu": 0.6251958410482837, "piqa": 0.7404787812840044}
"llama3_70b_o_low": {"truthfulqa": 0.5421157034230984, "gpqa_main_zeroshot": 0.31919642857142855, "gpqa_main_n_shot": 0.3125, "social_iqa": 0.44472876151484136, "commonsense_qa": 0.6592956592956593, "gsm8k": 0.8847611827141774, "mathqa": 0.34706867671691793, "mmlu": 0.6437117219769264, "piqa": 0.7676822633297062}
"llama3_70b_n_high": {"truthfulqa": 0.4304264387243563, "gpqa_main_zeroshot": 0.32589285714285715, "gpqa_main_n_shot": 0.36607142857142855, "social_iqa": 0.39969293756397134, "commonsense_qa": 0.20065520065520065, "gsm8k": 0.1516300227445034, "mathqa": 0.2887772194304858, "mmlu": 0.33186155818259505, "piqa": 0.7285092491838956}
"llama3_70b_o_high": {"trut

In [7]:
import json

def extract_metrics_with_stderr(data, metrics):
    results = {}
    for model in data:
        results[model] = {}
        for metric, specific_metric in metrics.items():
            for test_suite in data[model]:
                if metric in test_suite:
                    if metric == "truthfulqa":
                        value = data[model][test_suite]["truthfulqa_mc2"]["acc,none"]
                        stderr = data[model][test_suite]["truthfulqa_mc2"].get("acc_stderr,none", None)
                    elif metric == "gsm8k":
                        value = data[model][test_suite]["gsm8k"]["exact_match,flexible-extract"]
                        stderr = data[model][test_suite]["gsm8k"].get("exact_match_stderr,flexible-extract", None)
                    else:
                        value = data[model][test_suite][metric][specific_metric]
                        stderr = data[model][test_suite][metric].get(specific_metric.replace("acc,", "acc_stderr,"), None)
                    results[model][metric] = (value, stderr)
                    break
    return results

metrics = {
    "truthfulqa": "truthfulqa_mc2", 
    "gpqa_main_zeroshot": "acc,none", 
    "gpqa_main_n_shot": "acc,none", 
    "social_iqa": "acc,none", 
    "commonsense_qa": "acc,none", 
    "gsm8k": "exact_match,flexible-extract", 
    "mathqa": "acc,none",
    "mmlu": "acc,none"
}

results = extract_metrics_with_stderr(data, metrics)

# 打印结果
for model, model_results in results.items():
    formatted_results = {k: {"value": v[0], "stderr": v[1]} for k, v in model_results.items()}
    print(f'"{model}": {json.dumps(formatted_results)}')

"llama3_70b_a_low": {"truthfulqa": {"value": 0.5058314086539912, "stderr": 0.016427216245233724}, "gpqa_main_zeroshot": {"value": 0.35714285714285715, "stderr": 0.02266336846322688}, "gpqa_main_n_shot": {"value": 0.32142857142857145, "stderr": 0.022089519157170157}, "social_iqa": {"value": 0.39048106448311154, "stderr": 0.01103932371486307}, "commonsense_qa": {"value": 0.3923013923013923, "stderr": 0.01397893643494679}, "gsm8k": {"value": 0.8999241849886277, "stderr": 0.008266274528685637}, "mathqa": {"value": 0.32763819095477387, "stderr": 0.008592100906266604}, "mmlu": {"value": 0.6251958410482837, "stderr": 0.00392094851749934}}
"llama3_70b_o_low": {"truthfulqa": {"value": 0.5421157034230984, "stderr": 0.016755339076097026}, "gpqa_main_zeroshot": {"value": 0.31919642857142855, "stderr": 0.022048861164576057}, "gpqa_main_n_shot": {"value": 0.3125, "stderr": 0.021923384489444957}, "social_iqa": {"value": 0.44472876151484136, "stderr": 0.011244731148193177}, "commonsense_qa": {"value":

In [8]:
model_results

{'truthfulqa': (0.6528340114952984, 0.01595351060013572),
 'gpqa_main_zeroshot': (0.359375, 0.022694577961439925),
 'gpqa_main_n_shot': (0.3549107142857143, 0.022631623416326744),
 'social_iqa': (0.43551688843398156, 0.011219586604022594),
 'commonsense_qa': (0.7084357084357085, 0.013011802821401595),
 'gsm8k': (0.9044730856709629, 0.00809660577115574),
 'mathqa': (0.3504187604690117, 0.008733956045067806),
 'mmlu': (0.7230451502634953, 0.0035789783761124203)}

In [9]:
results

{'llama3_70b_a_low': {'truthfulqa': (0.5058314086539912, 0.016427216245233724),
  'gpqa_main_zeroshot': (0.35714285714285715, 0.02266336846322688),
  'gpqa_main_n_shot': (0.32142857142857145, 0.022089519157170157),
  'social_iqa': (0.39048106448311154, 0.01103932371486307),
  'commonsense_qa': (0.3923013923013923, 0.01397893643494679),
  'gsm8k': (0.8999241849886277, 0.008266274528685637),
  'mathqa': (0.32763819095477387, 0.008592100906266604),
  'mmlu': (0.6251958410482837, 0.00392094851749934)},
 'llama3_70b_o_low': {'truthfulqa': (0.5421157034230984, 0.016755339076097026),
  'gpqa_main_zeroshot': (0.31919642857142855, 0.022048861164576057),
  'gpqa_main_n_shot': (0.3125, 0.021923384489444957),
  'social_iqa': (0.44472876151484136, 0.011244731148193177),
  'commonsense_qa': (0.6592956592956593, 0.013569036984855006),
  'gsm8k': (0.8847611827141774, 0.00879538230154542),
  'mathqa': (0.34706867671691793, 0.00871449153541417),
  'mmlu': (0.6437117219769264, 0.0038392275190125944)},
 '

In [12]:
import re
import ast

def format_value(value, error):
    return f"{(value * 100):.1f} $\\pm$ {(error * 100):.1f}"

def get_benchmark_name(key):
    benchmark_map = {
        'truthfulqa': 'TruthfulQA',
        'gpqa_main_zeroshot': 'GPQA Zero Shot',
        'gpqa_main_n_shot': 'GPQA N Shot',
        'social_iqa': 'SocialIQA',
        'commonsense_qa': 'CommonsenseQA',
        'gsm8k': 'GSM8K',
        'mathqa': 'MathQA',
        'mmlu': 'MMLU'
    }
    return benchmark_map.get(key, key.capitalize())

def generate_latex_table(data):
    benchmarks = list(next(iter(data.values())).keys())
    personalities = ['Openness', 'Conscientiousness', 'Extraversion', 'Agreeableness', 'Neuroticism']
    trait_abbr = {'o': 'Openness', 'c': 'Conscientiousness', 'e': 'Extraversion', 'a': 'Agreeableness', 'n': 'Neuroticism'}
    
    table = "\\begin{table}[htbp]\n\\centering\n\\resizebox{\\textwidth}{!}{\n"
    table += "\\begin{tabular}{llc" + "cc" * len(personalities) + "c}\n"
    table += "\\toprule\n"
    table += "\\textbf{Benchmark} & \\textbf{Original} & \\textbf{Method} & "
    table += " & ".join([f"\\multicolumn{{2}}{{c}}{{\\textbf{{{p}}}}}" for p in personalities])
    table += " & \\multicolumn{2}{c}{\\textbf{Average}} \\\\\n"
    table += " & & & " + "High & Low & " * (len(personalities) + 1) + "\\\\\n"
    table += "\\midrule\n"
    
    for benchmark in benchmarks:
        row = f"\\textbf{{{get_benchmark_name(benchmark)}}} & - & DPO & "
        for trait in personalities:
            trait_abbr_key = trait[0].lower()
            high_key = f"llama3_70b_{trait_abbr_key}_high"
            low_key = f"llama3_70b_{trait_abbr_key}_low"
            if high_key in data and low_key in data and benchmark in data[high_key] and benchmark in data[low_key]:
                row += format_value(*data[high_key][benchmark]) + " & "
                row += format_value(*data[low_key][benchmark]) + " & "
            else:
                row += "- & - & "
        row += "- & - \\\\\n"
        table += row
    
    table += "\\bottomrule\n"
    table += "\\end{tabular}\n}\n"
    table += "\\caption{DPO Benchmark results for different personality traits on Llama3 70B.}\n"
    table += "\\label{tab:dpo_personality_results_llama3_70b}\n"
    table += "\\end{table}"
    
    return table

In [13]:
print(generate_latex_table(results))

\begin{table}[htbp]
\centering
\resizebox{\textwidth}{!}{
\begin{tabular}{llcccccccccccc}
\toprule
\textbf{Benchmark} & \textbf{Original} & \textbf{Method} & \multicolumn{2}{c}{\textbf{Openness}} & \multicolumn{2}{c}{\textbf{Conscientiousness}} & \multicolumn{2}{c}{\textbf{Extraversion}} & \multicolumn{2}{c}{\textbf{Agreeableness}} & \multicolumn{2}{c}{\textbf{Neuroticism}} & \multicolumn{2}{c}{\textbf{Average}} \\
 & & & High & Low & High & Low & High & Low & High & Low & High & Low & High & Low & \\
\midrule
\textbf{TruthfulQA} & - & DPO & 54.6 $\pm$ 1.6 & 54.2 $\pm$ 1.7 & 64.6 $\pm$ 1.6 & 38.5 $\pm$ 1.6 & 46.0 $\pm$ 1.7 & 65.3 $\pm$ 1.6 & 59.6 $\pm$ 1.6 & 50.6 $\pm$ 1.6 & 43.0 $\pm$ 1.7 & 65.8 $\pm$ 1.6 & - & - \\
\textbf{GPQA Zero Shot} & - & DPO & 36.8 $\pm$ 2.3 & 31.9 $\pm$ 2.2 & 35.7 $\pm$ 2.3 & 30.6 $\pm$ 2.2 & 35.9 $\pm$ 2.3 & 35.9 $\pm$ 2.3 & 35.5 $\pm$ 2.3 & 35.7 $\pm$ 2.3 & 32.6 $\pm$ 2.2 & 34.6 $\pm$ 2.2 & - & - \\
\textbf{GPQA N Shot} & - & DPO & 37.5 $\pm$ 2.3 & 31.2 $\p

In [14]:
### Merge

In [15]:
import pandas as pd

# Function to create LaTeX table
def create_latex_table(data):
    latex_table = r"\begin{table}[htbp]" + "\n"
    latex_table += r"\centering" + "\n"
    latex_table += r"\resizebox{\textwidth}{!}{%" + "\n"
    latex_table += r"\begin{tabular}{llcccccccccccc}" + "\n"
    latex_table += r"\toprule" + "\n"
    latex_table += r"\multirow[c]{2}{*}{\textbf{Benchmark}} & \multirow[c]{2}{*}{\textbf{Original}} & \multirow[c]{2}{*}{\textbf{Method}} & \multicolumn{2}{c}{\textbf{Openness}} & \multicolumn{2}{c}{\textbf{Conscientiousness}} & \multicolumn{2}{c}{\textbf{Extraversion}} & \multicolumn{2}{c}{\textbf{Agreeableness}} & \multicolumn{2}{c}{\textbf{Neuroticism}} & \multicolumn{2}{c}{\textbf{Average}} \\" + "\n"
    latex_table += r" & & & High & Low & High & Low & High & Low & High & Low & High & Low & \textbf{High} & \textbf{Low} \\" + "\n"
    latex_table += r"\midrule" + "\n"

    for benchmark in data:
        latex_table += r"\multirow[c]{3}{*}{\textbf{" + benchmark["name"] + r"}}" + "\n"
        latex_table += r" & \multirow[c]{3}{*}{" + benchmark["original"] + r"} & Prompt & " + " & ".join(benchmark["prompt"]) + r" \\" + "\n"
        latex_table += r" & & SFT & " + " & ".join(benchmark["sft"]) + r" \\" + "\n"
        latex_table += r" & & DPO & " + " & ".join(benchmark["dpo"]) + r" \\" + "\n"
        latex_table += r"\midrule" + "\n"

    latex_table += r"\bottomrule" + "\n"
    latex_table += r"\end{tabular}" + "\n"
    latex_table += r"}" + "\n"
    latex_table += r"\caption{Benchmark results for different personality traits on Llama3 70B.}" + "\n"
    latex_table += r"\label{tab:benchmark_results_llama3_70b}" + "\n"
    latex_table += r"\end{table}"

    return latex_table

# Data for the table
data = [
    {
        "name": "TruthfulQA",
        "original": "-",
        "prompt": ["", "", "", "", "", "", "", "", "", "", "", "", "", ""],
        "sft": ["55.2 $\pm$ 1.6", "52.8 $\pm$ 1.6", "55.6 $\pm$ 1.6", "50.8 $\pm$ 1.5", "54.5 $\pm$ 1.6", "56.7 $\pm$ 1.6", "54.4 $\pm$ 1.6", "51.6 $\pm$ 1.6", "52.4 $\pm$ 1.5", "56.7 $\pm$ 1.6", "-", "-"],
        "dpo": ["54.6 $\pm$ 1.6", "54.2 $\pm$ 1.7", "64.6 $\pm$ 1.6", "38.5 $\pm$ 1.6", "46.0 $\pm$ 1.7", "65.3 $\pm$ 1.6", "59.6 $\pm$ 1.6", "50.6 $\pm$ 1.6", "43.0 $\pm$ 1.7", "65.8 $\pm$ 1.6", "-", "-"]
    },
    {
        "name": "GPQA Zero Shot",
        "original": "-",
        "prompt": ["", "", "", "", "", "", "", "", "", "", "", "", "", ""],
        "sft": ["33.5 $\pm$ 2.2", "32.4 $\pm$ 2.2", "34.2 $\pm$ 2.2", "34.2 $\pm$ 2.2", "33.3 $\pm$ 2.2", "34.4 $\pm$ 2.2", "33.3 $\pm$ 2.2", "33.3 $\pm$ 2.2", "34.4 $\pm$ 2.2", "33.5 $\pm$ 2.2", "-", "-"],
        "dpo": ["36.8 $\pm$ 2.3", "31.9 $\pm$ 2.2", "35.7 $\pm$ 2.3", "30.6 $\pm$ 2.2", "35.9 $\pm$ 2.3", "35.9 $\pm$ 2.3", "35.5 $\pm$ 2.3", "35.7 $\pm$ 2.3", "32.6 $\pm$ 2.2", "34.6 $\pm$ 2.2", "-", "-"]
    },
    {
        "name": "GPQA N Shot",
        "original": "-",
        "prompt": ["", "", "", "", "", "", "", "", "", "", "", "", "", ""],
        "sft": ["32.4 $\pm$ 2.2", "32.8 $\pm$ 2.2", "34.4 $\pm$ 2.2", "33.7 $\pm$ 2.2", "33.0 $\pm$ 2.2", "33.9 $\pm$ 2.2", "33.7 $\pm$ 2.2", "32.8 $\pm$ 2.2", "33.7 $\pm$ 2.2", "34.8 $\pm$ 2.3", "-", "-"],
        "dpo": ["37.5 $\pm$ 2.3", "31.2 $\pm$ 2.2", "35.9 $\pm$ 2.3", "31.2 $\pm$ 2.2", "37.1 $\pm$ 2.3", "35.5 $\pm$ 2.3", "33.5 $\pm$ 2.2", "32.1 $\pm$ 2.2", "36.6 $\pm$ 2.3", "35.7 $\pm$ 2.3", "-", "-"]
    },
    {
        "name": "SocialIQA",
        "original": "-",
        "prompt": ["", "", "", "", "", "", "", "", "", "", "", "", "", ""],
        "sft": ["50.3 $\pm$ 1.1", "50.4 $\pm$ 1.1", "50.9 $\pm$ 1.1", "46.8 $\pm$ 1.1", "50.0 $\pm$ 1.1", "50.3 $\pm$ 1.1", "50.5 $\pm$ 1.1", "46.6 $\pm$ 1.1", "48.2 $\pm$ 1.1", "50.6 $\pm$ 1.1", "-", "-"],
        "dpo": ["41.5 $\pm$ 1.1", "44.5 $\pm$ 1.1", "44.7 $\pm$ 1.1", "37.6 $\pm$ 1.1", "43.0 $\pm$ 1.1", "43.6 $\pm$ 1.1", "44.8 $\pm$ 1.1", "39.0 $\pm$ 1.1", "40.0 $\pm$ 1.1", "45.3 $\pm$ 1.1", "-", "-"]
    },
    {
        "name": "CommonsenseQA",
        "original": "-",
        "prompt": ["", "", "", "", "", "", "", "", "", "", "", "", "", ""],
        "sft": ["77.7 $\pm$ 1.2", "78.8 $\pm$ 1.2", "77.6 $\pm$ 1.2", "66.0 $\pm$ 1.4", "75.7 $\pm$ 1.2", "78.9 $\pm$ 1.2", "77.0 $\pm$ 1.2", "73.8 $\pm$ 1.3", "79.1 $\pm$ 1.2", "78.5 $\pm$ 1.2", "-", "-"],
        "dpo": ["57.7 $\pm$ 1.4", "65.9 $\pm$ 1.4", "23.8 $\pm$ 1.2", "25.8 $\pm$ 1.3", "23.2 $\pm$ 1.2", "70.8 $\pm$ 1.3", "21.3 $\pm$ 1.2", "39.2 $\pm$ 1.4", "20.1 $\pm$ 1.1", "44.6 $\pm$ 1.4", "-", "-"]
    },
    {
        "name": "GSM8K",
        "original": "-",
        "prompt": ["", "", "", "", "", "", "", "", "", "", "", "", "", ""],
        "sft": ["85.8 $\pm$ 1.0", "76.2 $\pm$ 1.2", "86.4 $\pm$ 0.9", "81.7 $\pm$ 1.1", "85.1 $\pm$ 1.0", "86.7 $\pm$ 0.9", "87.0 $\pm$ 0.9", "74.5 $\pm$ 1.2", "76.0 $\pm$ 1.2", "87.3 $\pm$ 0.9", "-", "-"],
        "dpo": ["87.9 $\pm$ 0.9", "88.5 $\pm$ 0.9", "90.2 $\pm$ 0.8", "80.6 $\pm$ 1.1", "88.9 $\pm$ 0.9", "90.4 $\pm$ 0.8", "87.3 $\pm$ 0.9", "90.0 $\pm$ 0.8", "15.2 $\pm$ 1.0", "91.0 $\pm$ 0.8", "-", "-"]
    },
    {
        "name": "MathQA",
        "original": "-",
        "prompt": ["", "", "", "", "", "", "", "", "", "", "", "", "", ""],
        "sft": ["43.3 $\pm$ 0.9", "42.6 $\pm$ 0.9", "43.0 $\pm$ 0.9", "43.3 $\pm$ 0.9", "43.2 $\pm$ 0.9", "42.7 $\pm$ 0.9", "42.9 $\pm$ 0.9", "42.9 $\pm$ 0.9", "42.8 $\pm$ 0.9", "43.3 $\pm$ 0.9", "-", "-"],
        "dpo": ["33.9 $\pm$ 0.9", "34.7 $\pm$ 0.9", "32.9 $\pm$ 0.9", "28.1 $\pm$ 0.8", "30.5 $\pm$ 0.8", "35.0 $\pm$ 0.9", "31.3 $\pm$ 0.8", "32.8 $\pm$ 0.9", "28.9 $\pm$ 0.8", "34.0 $\pm$ 0.9", "-", "-"]
    },
    {
        "name": "MMLU",
        "original": "-",
        "prompt": ["", "", "", "", "", "", "", "", "", "", "", "", "", ""],
        "sft": ["72.5 $\pm$ 0.4", "72.0 $\pm$ 0.4", "73.1 $\pm$ 0.4", "68.6 $\pm$ 0.4", "72.1 $\pm$ 0.4", "73.5 $\pm$ 0.4", "72.8 $\pm$ 0.4", "70.7 $\pm$ 0.4", "72.5 $\pm$ 0.4", "73.8 $\pm$ 0.4", "-", "-"],
        "dpo": ["57.9 $\pm$ 0.4", "64.4 $\pm$ 0.4", "50.3 $\pm$ 0.4", "33.8 $\pm$ 0.4", "42.3 $\pm$ 0.4", "72.3 $\pm$ 0.4", "34.3 $\pm$ 0.4", "62.5 $\pm$ 0.4", "33.2 $\pm$ 0.4", "69.1 $\pm$ 0.4", "-", "-"]
    }
]

# Generate the LaTeX table
latex_table = create_latex_table(data)
print(latex_table)

\begin{table}[htbp]
\centering
\resizebox{\textwidth}{!}{%
\begin{tabular}{llcccccccccccc}
\toprule
\multirow[c]{2}{*}{\textbf{Benchmark}} & \multirow[c]{2}{*}{\textbf{Original}} & \multirow[c]{2}{*}{\textbf{Method}} & \multicolumn{2}{c}{\textbf{Openness}} & \multicolumn{2}{c}{\textbf{Conscientiousness}} & \multicolumn{2}{c}{\textbf{Extraversion}} & \multicolumn{2}{c}{\textbf{Agreeableness}} & \multicolumn{2}{c}{\textbf{Neuroticism}} & \multicolumn{2}{c}{\textbf{Average}} \\
 & & & High & Low & High & Low & High & Low & High & Low & High & Low & \textbf{High} & \textbf{Low} \\
\midrule
\multirow[c]{3}{*}{\textbf{TruthfulQA}}
 & \multirow[c]{3}{*}{-} & Prompt &  &  &  &  &  &  &  &  &  &  &  &  &  &  \\
 & & SFT & 55.2 $\pm$ 1.6 & 52.8 $\pm$ 1.6 & 55.6 $\pm$ 1.6 & 50.8 $\pm$ 1.5 & 54.5 $\pm$ 1.6 & 56.7 $\pm$ 1.6 & 54.4 $\pm$ 1.6 & 51.6 $\pm$ 1.6 & 52.4 $\pm$ 1.5 & 56.7 $\pm$ 1.6 & - & - \\
 & & DPO & 54.6 $\pm$ 1.6 & 54.2 $\pm$ 1.7 & 64.6 $\pm$ 1.6 & 38.5 $\pm$ 1.6 & 46.0 $\pm$ 1.7 & 65