# In [None]:
import os
import argparse
import pandas as pd

# Labels for the multiple-choice options, in column order: format_example
# labels option column j + 1 of a row with CHOICES[j].
CHOICES = ["A", "B", "C", "D"]

# Subcategory (subject) name -> list of category names.
# The keys double as CSV file stems: collect_attack_prompts looks for
# "<key>_dev.csv" under "dev/" and "<key>_test.csv" under "test/".
SUBCATEGORIES = {
    "abstract_algebra": ["math"],
    "anatomy": ["health"],
    "astronomy": ["physics"],
    "business_ethics": ["business"],
    "clinical_knowledge": ["health"],
    "college_biology": ["biology"],
    "college_chemistry": ["chemistry"],
    "college_computer_science": ["computer science"],
    "college_mathematics": ["math"],
    "college_medicine": ["health"],
    "college_physics": ["physics"],
    "computer_security": ["computer science"],
    "conceptual_physics": ["physics"],
    "econometrics": ["economics"],
    "electrical_engineering": ["engineering"],
    "elementary_mathematics": ["math"],
    "formal_logic": ["philosophy"],
    "global_facts": ["other"],
    "high_school_biology": ["biology"],
    "high_school_chemistry": ["chemistry"],
    "high_school_computer_science": ["computer science"],
    "high_school_european_history": ["history"],
    "high_school_geography": ["geography"],
    "high_school_government_and_politics": ["politics"],
    "high_school_macroeconomics": ["economics"],
    "high_school_mathematics": ["math"],
    "high_school_microeconomics": ["economics"],
    "high_school_physics": ["physics"],
    "high_school_psychology": ["psychology"],
    "high_school_statistics": ["math"],
    "high_school_us_history": ["history"],
    "high_school_world_history": ["history"],
    "human_aging": ["health"],
    "human_sexuality": ["culture"],
    "international_law": ["law"],
    "jurisprudence": ["law"],
    "logical_fallacies": ["philosophy"],
    "machine_learning": ["computer science"],
    "management": ["business"],
    "marketing": ["business"],
    "medical_genetics": ["health"],
    "miscellaneous": ["other"],
    "moral_disputes": ["philosophy"],
    "moral_scenarios": ["philosophy"],
    "nutrition": ["health"],
    "philosophy": ["philosophy"],
    "prehistory": ["history"],
    "professional_accounting": ["other"],
    "professional_law": ["law"],
    "professional_medicine": ["health"],
    "professional_psychology": ["psychology"],
    "public_relations": ["politics"],
    "security_studies": ["politics"],
    "sociology": ["culture"],
    "us_foreign_policy": ["politics"],
    "virology": ["health"],
    "world_religions": ["philosophy"],
}

# High-level category name -> list of category names (the values used in
# SUBCATEGORIES). get_subcategories_by_selected_categories uses this table to
# expand a high-level name such as "STEM" into concrete subcategories.
CATEGORIES = {
    "STEM": ["physics", "chemistry", "biology", "computer science", "math", "engineering"],
    "humanities": ["history", "philosophy", "law"],
    "social sciences": ["politics", "culture", "economics", "geography", "psychology"],
    "other (business, health, misc.)": ["other", "business", "health"],
}

def format_subject(subject):
    """
    Turn a snake_case subject name into a space-separated display name.

    Args:
        subject (str): Subject name such as "high_school_physics".

    Returns:
        str: The name with every underscore replaced by a space and
        surrounding whitespace removed, e.g. "high school physics".
    """
    readable = subject.replace("_", " ")
    return readable.strip()

def format_example(df, idx, include_answer=True):
    """
    Format one row of a DataFrame as a multiple-choice prompt string.

    Assumed row layout: column 0 is the question, the last column is the
    answer, and every column in between is an answer option. Options are
    labelled with CHOICES in column order.

    The result is the question, then one line per option ("A. <text>"),
    then a line "Answer:" (followed by " <answer>" when include_answer is
    True), terminated by a blank line.

    Args:
        df (pd.DataFrame): DataFrame containing the data.
        idx (int): Positional row index (used with ``iloc``).
        include_answer (bool): Whether to append the answer after "Answer:".

    Returns:
        str: Formatted prompt string.

    Raises:
        ValueError: If the row has more option columns than CHOICES labels.
    """
    num_choices = df.shape[1] - 2  # all columns except question and answer
    if num_choices > len(CHOICES):
        raise ValueError(
            f"Row has {num_choices} option columns but CHOICES only has "
            f"{len(CHOICES)} labels."
        )

    # str() guards against non-string question cells (e.g. NaN from an empty
    # CSV field), which would otherwise make string concatenation raise.
    parts = [str(df.iloc[idx, 0])]
    for j in range(num_choices):
        parts.append(f"\n{CHOICES[j]}. {df.iloc[idx, j + 1]}")
    parts.append("\nAnswer:")
    if include_answer:
        parts.append(f" {df.iloc[idx, num_choices + 1]}")
    parts.append("\n\n")
    return "".join(parts)

def get_subcategories_by_selected_categories(selected_categories):
    """
    Expand high-level category names into the matching subcategory names.

    A subcategory is selected when any of its categories (from SUBCATEGORIES)
    belongs to one of the requested high-level categories (from CATEGORIES).
    Unrecognized high-level names are reported with a warning and ignored.

    Args:
        selected_categories (list): High-level category names, e.g. "STEM".

    Returns:
        set: Names of all matching subcategories (possibly empty).
    """
    matched = set()
    for group_name in selected_categories:
        if group_name in CATEGORIES:
            members = CATEGORIES[group_name]
            matched.update(
                subject
                for subject, labels in SUBCATEGORIES.items()
                if any(label in members for label in labels)
            )
        else:
            print(f"Warning: High-level category '{group_name}' not recognized. Skipping.")
    return matched

def collect_attack_prompts(data_dir, selected_subcategories, ntrain=5):
    """
    Collect formatted prompts for the selected subcategories.

    For each subcategory, reads ``<data_dir>/dev/<subcat>_dev.csv`` and
    ``<data_dir>/test/<subcat>_test.csv`` (no header row) and formats every
    row with its answer included. Dev prompts come before test prompts.
    A subcategory missing either file is skipped with a message.

    Args:
        data_dir (str): Path to the data directory containing 'dev' and 'test' folders.
        selected_subcategories (Iterable[str]): Subcategory names to include.
            They are processed in sorted order so the output is reproducible.
        ntrain (int): Maximum number of dev rows per subcategory; a value
            <= 0 keeps every dev row.

    Returns:
        list: Aggregated list of formatted prompt strings.
    """
    attack_prompts = []

    # Iterating a set of strings follows hash order, which changes between
    # interpreter runs (str hash randomization); sort for a stable output.
    for subcat in sorted(selected_subcategories):
        dev_path = os.path.join(data_dir, "dev", f"{subcat}_dev.csv")
        test_path = os.path.join(data_dir, "test", f"{subcat}_test.csv")

        # Both files are required; skip the subcategory otherwise.
        if not os.path.isfile(dev_path):
            print(f"Development file for '{subcat}' not found at {dev_path}. Skipping this subcategory.")
            continue
        if not os.path.isfile(test_path):
            print(f"Test file for '{subcat}' not found at {test_path}. Skipping this subcategory.")
            continue

        # Read every cell as literal text. With pandas' default NA handling,
        # answer options such as "None", "NA" or "null" would be parsed as
        # NaN and rendered as "nan" in the prompt.
        dev_df = pd.read_csv(dev_path, header=None, dtype=str, keep_default_na=False)
        test_df = pd.read_csv(test_path, header=None, dtype=str, keep_default_na=False)

        if ntrain > 0:
            dev_df = dev_df.iloc[:ntrain]

        attack_prompts.extend(
            format_example(dev_df, i, include_answer=True) for i in range(len(dev_df))
        )
        attack_prompts.extend(
            format_example(test_df, i, include_answer=True) for i in range(len(test_df))
        )

        print(f"Collected {len(dev_df) + len(test_df)} prompts from subcategory '{subcat}'.")

    print(f"\nTotal prompts collected: {len(attack_prompts)}")
    return attack_prompts

def save_attack_prompts(attack_prompts, save_path):
    """
    Write the prompts to a UTF-8 text file, one extra newline after each.

    Any existing file at ``save_path`` is overwritten.

    Args:
        attack_prompts (list): Prompt strings to write, in order.
        save_path (str): Destination file path.
    """
    with open(save_path, "w", encoding="utf-8") as out_file:
        out_file.writelines(text + "\n" for text in attack_prompts)
    print(f"Attack prompts saved to {save_path}")

def main():
    """
    Command-line entry point: build the attack-prompt file from MMLU CSVs.

    Parses arguments and resolves the subcategories to use. Exactly one of
    --selected-subcategories or --selected-categories must be given.
    Then collects the prompts and writes them to --save-file. Returns early
    if the selected high-level categories resolve to no subcategories.
    """
    parser = argparse.ArgumentParser(description="Process MMLU data into attack_prompts list based on selected categories.")
    parser.add_argument(
        "--data-dir",
        type=str,
        required=True,
        help="Directory containing the 'dev' and 'test' subdirectories with CSV files."
    )
    parser.add_argument(
        "--save-file",
        type=str,
        default="attack_prompts.txt",
        help="File path to save the aggregated attack prompts."
    )
    group = parser.add_mutually_exclusive_group(required=True)
    group.add_argument(
        "--selected-subcategories",
        type=str,
        nargs='+',
        # Subcategory names are the SUBCATEGORIES keys (CSV file stems),
        # not category names such as "physics".
        help="List of subcategory names to include (e.g., college_physics high_school_mathematics)."
    )
    group.add_argument(
        "--selected-categories",
        type=str,
        nargs='+',
        help="List of higher-level category names to include (e.g., STEM humanities)."
    )
    parser.add_argument(
        "--ntrain",
        type=int,
        default=5,
        help="Number of training examples to include from each subcategory's dev set (<= 0 keeps all)."
    )

    args = parser.parse_args()

    # Resolve the set of subcategories to include.
    if args.selected_categories:
        selected_subcategories = get_subcategories_by_selected_categories(args.selected_categories)
        if not selected_subcategories:
            print("No valid subcategories found for the selected categories. Exiting.")
            return
    else:
        selected_subcategories = set(args.selected_subcategories)

    # Sorted so the log line is stable across runs (set order is not).
    print(f"Selected subcategories: {sorted(selected_subcategories)}")

    attack_prompts = collect_attack_prompts(
        data_dir=args.data_dir,
        selected_subcategories=selected_subcategories,
        ntrain=args.ntrain
    )

    save_attack_prompts(attack_prompts, args.save_file)

# In [None]:
from collections import defaultdict
import json
import os
def _entry_category(entry):
    """Return the category label of one result entry ('Unknown' if absent).

    A list value contributes its first element. A plain string is used whole,
    because indexing it would yield only its first character.
    """
    raw_category = entry.get('category', ['Unknown'])
    if isinstance(raw_category, str):
        return raw_category or 'Unknown'
    if raw_category:
        return raw_category[0]
    return 'Unknown'


def _normalize_label(value):
    """Return ``value`` as a stripped, upper-cased string ('' for None)."""
    if value is None:
        return ''
    return str(value).strip().upper()


def _extract_prediction(output, index):
    """Return the normalized item at ``output[index]``, or '' if unavailable."""
    if not output:
        return ''
    try:
        item = output[index]
    except IndexError:
        # Output shorter than expected: treat as "no answer" rather than crash.
        return ''
    return _normalize_label(item)


def _percentages(counts):
    """Map {key: {'total': t, 'correct': c}} to {key: 100*c/t rounded to 2 dp}."""
    return {
        key: round((tally['correct'] / tally['total']) * 100, 2) if tally['total'] > 0 else 0.0
        for key, tally in counts.items()
    }


def compute_accuracies(data, *, answer_char_index=3):
    """
    Compute accuracy per subcategory, per category, and overall.

    Each entry is scored by comparing ``output[answer_char_index]`` (stripped
    and upper-cased) with the entry's 'answer' field (stripped and
    upper-cased). An entry with no gold answer is always counted as wrong.

    Args:
        data (list): Result entries (dicts). Keys read: 'category' (list whose
            first element is used, or a plain string), 'subcategory',
            'answer', and 'output'. Missing category/subcategory -> 'Unknown'.
        answer_char_index (int): Position in 'output' that holds the predicted
            answer letter. Defaults to 3 (the previously hard-coded value);
            presumably tied to the model's output format — confirm.

    Returns:
        tuple: (subcategory accuracy dict, category accuracy dict, overall
        accuracy float). All values are percentages rounded to 2 decimals.
    """
    subcat_counts = defaultdict(lambda: {'total': 0, 'correct': 0})
    cat_counts = defaultdict(lambda: {'total': 0, 'correct': 0})
    overall_total = 0
    overall_correct = 0

    for entry in data:
        category = _entry_category(entry)
        subcategory = entry.get('subcategory', 'Unknown')
        correct_answer = _normalize_label(entry.get('answer'))
        predicted_answer = _extract_prediction(entry.get('output', ''), answer_char_index)
        # Without the bool(correct_answer) guard, an empty prediction would
        # "match" a missing answer and be counted as correct.
        hit = int(bool(correct_answer) and predicted_answer == correct_answer)

        for tally in (subcat_counts[subcategory], cat_counts[category]):
            tally['total'] += 1
            tally['correct'] += hit
        overall_total += 1
        overall_correct += hit

    overall_accuracy = (
        round((overall_correct / overall_total) * 100, 2) if overall_total > 0 else 0.0
    )
    return _percentages(subcat_counts), _percentages(cat_counts), overall_accuracy

def save_accuracies(subcat_accuracy, cat_accuracy, overall_accuracy, folder_path, save_name, output_file='accuracies.json'):
    """
    Save subcategory, category, and overall accuracies to one JSON file.

    The file is written to ``<folder_path>/<save_name>_<output_file>`` with
    keys "subcategory_accuracy", "category_accuracy" and "overall_accuracy".
    An IOError during writing is printed rather than raised.

    Args:
        subcat_accuracy (dict): Accuracy per subcategory.
        cat_accuracy (dict): Accuracy per category.
        overall_accuracy (float): Overall accuracy.
        folder_path (str): Folder to save into; created if missing. An empty
            string means the current working directory.
        save_name (str): Prefix for the output filename.
        output_file (str): Output JSON filename suffix (default 'accuracies.json').
    """
    # os.makedirs('') raises FileNotFoundError, so only create real folders.
    if folder_path:
        os.makedirs(folder_path, exist_ok=True)

    output_path = os.path.join(folder_path, f"{save_name}_{output_file}")

    accuracies = {
        "subcategory_accuracy": subcat_accuracy,
        "category_accuracy": cat_accuracy,
        "overall_accuracy": overall_accuracy
    }

    try:
        # ensure_ascii=False keeps non-ASCII keys readable in the file.
        with open(output_path, 'w', encoding='utf-8') as f:
            json.dump(accuracies, f, ensure_ascii=False, indent=4)
        print(f"所有准确率已保存到 '{output_path}'.")
    except IOError as e:
        print(f"保存准确率时出错: {e}")

# In [5]:
 
# Score the SafeDecoding + Mistral MMLU run and write the accuracy summary
# to ./SafeDecoding_mistral_MMLU_accuracies.json.
results_path = '/home/kz34/Yang_Ouyang_Projects/ICLR2025/jailbreaking_related/SafeDecoding/exp_outputs_new_new_new/SafeDecoding_mistral_MMLU_7723_2024-11-25 08:50:38/SafeDecoding_mistral_MMLU_7723_2024-11-25 08:50:38.json'
with open(results_path, 'r') as f:
    output_json = json.load(f)
results = output_json['data']
accuracy_report = compute_accuracies(results)
subcat_acc, cat_acc, overall_acc = accuracy_report
save_accuracies(
    subcat_acc,
    cat_acc,
    overall_acc,
    './',
    "SafeDecoding_mistral_MMLU",
)


# Output: 所有准确率已保存到 './SafeDecoding_mistral_MMLU_accuracies.json'. (All accuracies saved to './SafeDecoding_mistral_MMLU_accuracies.json'.)
