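# 移除错误的生成结果 / Remove failed generation results

Rewrites the generation `.pkl` files under `output/` in place, dropping every response that contains an `'error'` key (and the stored `'args'` entry). This modifies results on disk — back up `output/` before running.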

In [2]:
%load_ext autoreload
%autoreload 2

import gc
import os
import logging
import random
from tqdm import tqdm
import argparse
import numpy as np
import torch
import pickle

# Load environment variables from .env file
from dotenv import load_dotenv
load_dotenv()

from core.data.data_utils import load_ds, load_ds_from_json
from core.utils import utils
from core.models.huggingface_models import HuggingfaceModel


def load_pickle_file(file_path):
    """Open ``file_path`` in binary mode and return the object unpickled from it.

    Warning: ``pickle.load`` can execute arbitrary code while deserializing —
    only call this on trusted, locally produced files.
    """
    with open(file_path, mode="rb") as handle:
        return pickle.load(handle)


The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


In [14]:
model_names = ["Qwen/Qwen2.5-3B-Instruct", "Qwen/Qwen2.5-7B-Instruct", "meta-llama/Llama-3.1-8B-Instruct"]
context_types = ["irrelevant", "golden", "without"]
sample_types = ["greedy", "sample"]
datasets = ["squad", "triviaqa", "bioasq"]
splits=["train", "test", "validation"]

from collections import Counter
total_error_counts = Counter()
for model_name in model_names:
    for context_type in context_types:
        for sample_type in sample_types:
            for dataset in datasets:
                for split in splits:
                    dir_path = f"output/{split}/generation/{model_name}/{dataset}/{sample_type}_{context_type}"
                    if not os.path.exists(dir_path):
                        continue
                    print(dir_path)
                    error_counts = Counter()
                    # print(error_counts)
                    # print(f"Model: {model_name}, Context: {context_type}, Sample: {sample_type}, Dataset: {dataset}, Split: {split}")
                    # 遍历
                    for file_name in os.listdir(dir_path):
                        if file_name.endswith(".pkl"):
                            file_path = os.path.join(dir_path, file_name)
                            data = load_pickle_file(file_path)
                            new_responses = []
                            error_count = 0
                            for response in data["responses"]:
                                if 'error' in response:
                                    error_count += 1
                                else:
                                    new_responses.append(response)
                            if 'args' in data:
                                del data['args']
                            data["responses"] = new_responses
                            error_counts.update({error_count})
                            with open(file_path, 'wb') as f:
                                pickle.dump(data, f)
                    print("Error counts:", error_counts)
                    total_error_counts.update(error_counts)

print("Total error counts:", total_error_counts)

output/train/generation/Qwen/Qwen2.5-3B-Instruct/squad/greedy_irrelevant
Error counts: Counter({0: 2000})
output/train/generation/Qwen/Qwen2.5-3B-Instruct/squad/sample_irrelevant
Error counts: Counter({0: 2000})
output/train/generation/Qwen/Qwen2.5-3B-Instruct/squad/greedy_golden
Error counts: Counter({0: 2000})
output/train/generation/Qwen/Qwen2.5-3B-Instruct/squad/sample_golden
Error counts: Counter({0: 2000})
output/train/generation/Qwen/Qwen2.5-3B-Instruct/squad/greedy_without
Error counts: Counter({0: 2000})
output/train/generation/Qwen/Qwen2.5-3B-Instruct/squad/sample_without
Error counts: Counter({0: 2000})
output/train/generation/Qwen/Qwen2.5-7B-Instruct/squad/greedy_irrelevant
Error counts: Counter({0: 10000})
output/test/generation/Qwen/Qwen2.5-7B-Instruct/squad/greedy_irrelevant
Error counts: Counter({0: 1000})
output/validation/generation/Qwen/Qwen2.5-7B-Instruct/squad/greedy_irrelevant
Error counts: Counter({0: 1000})
output/train/generation/Qwen/Qwen2.5-7B-Instruct/trivia