In [7]:
import os
import torch
from transformers import (
    AutoTokenizer,
    AutoModelForCausalLM,
    BitsAndBytesConfig,
    TrainingArguments,
    Trainer,
    GenerationConfig
)
from peft import LoraConfig, get_peft_model, prepare_model_for_kbit_training
from datasets import load_dataset
from evaluate import load

os.environ["CUDA_VISIBLE_DEVICES"] = "0"
torch.cuda.set_device(0)
device = torch.device("cuda:0")

torch.cuda.empty_cache()

# ========== 1. 模型与分词器加载 ==========
model_name = "distilgpt2"
tokenizer = AutoTokenizer.from_pretrained(model_name)
tokenizer.pad_token = tokenizer.eos_token

In [8]:
# ========== 2. 4-bit量化配置 ==========
quant_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_use_double_quant=True,
    bnb_4bit_compute_dtype=torch.float16
)

In [9]:
# ========== 3. 加载量化模型 ==========
try:
    model = AutoModelForCausalLM.from_pretrained(
        model_name,
        quantization_config=quant_config,
        device_map={"": torch.cuda.current_device()},
        torch_dtype=torch.float16
    )
    model = prepare_model_for_kbit_training(
        model,
        use_gradient_checkpointing=True
    )
except RuntimeError as e:
    print(f"模型加载失败: {str(e)}")
    raise

In [10]:
# ========== 4. LoRA配置 ==========
lora_config = LoraConfig(
    r=16,
    lora_alpha=32,
    target_modules=["c_attn"],
    lora_dropout=0.05,
    bias="none",
    task_type="CAUSAL_LM",
    modules_to_save=["lm_head"]
)

try:
    model = get_peft_model(model, lora_config)
    model.print_trainable_parameters()
except ValueError as e:
    print(f"LoRA配置错误: {str(e)}")
    raise

# ========== 5. 数据集预处理 ==========
def safe_preprocess(examples):
    try:
        texts = [f"Instruction: {ins}\nOutput: {out}"
                for ins, out in zip(examples["instruction"], examples["output"])]
        tokenized = tokenizer(
            texts,
            padding="max_length",
            truncation=True,
            max_length=128,
            return_tensors=None
        )
        tokenized["labels"] = tokenized["input_ids"].copy()
        return tokenized
    except KeyError as e:
        print(f"数据字段错误: {str(e)}")
        raise

try:
    dataset = load_dataset("tatsu-lab/alpaca", split="train[:10000]", keep_in_memory=False)
    tokenized_dataset = dataset.map(
        safe_preprocess,
        batched=True,
        batch_size=4,
        remove_columns=dataset.column_names
    )
except Exception as e:
    print(f"数据加载失败: {str(e)}")
    raise


trainable params: 38,892,288 || all params: 120,804,864 || trainable%: 32.1943


Map: 100%|██████████| 10000/10000 [00:05<00:00, 1973.34 examples/s]


In [11]:
# ========== 6. 训练参数配置 ==========
training_args = TrainingArguments(
    output_dir="./qlora_output_v2",
    per_device_train_batch_size=1,
    gradient_accumulation_steps=4,
    num_train_epochs=5,
    learning_rate=1e-4,
    fp16=True,
    logging_steps=5,
    save_steps=20,
    remove_unused_columns=False,
    gradient_checkpointing=True,
    optim="paged_adamw_8bit",
    report_to="none",
    dataloader_pin_memory=False,
    dataloader_num_workers=0
)

In [12]:
# ========== 7. 数据整理 ==========
class SafeDataCollator:
    def __init__(self, device):
        self.device = device

    def __call__(self, features):
        try:
            return {
                "input_ids": torch.stack([torch.tensor(f["input_ids"])
                            for f in features]).to(self.device),
                "attention_mask": torch.stack([torch.tensor(f["attention_mask"])
                            for f in features]).to(self.device),
                "labels": torch.stack([torch.tensor(f["labels"])
                            for f in features]).to(self.device)
            }
        except Exception as e:
            print(f"数据整理错误: {str(e)}")
            raise

In [13]:
# ========== 8. 训练 ==========
try:
    trainer = Trainer(
        model=model,
        args=training_args,
        train_dataset=tokenized_dataset,
        data_collator=SafeDataCollator(device)
    )
    print("开始训练")
    trainer.train()
    print("训练成功完成")

except RuntimeError as e:
    print(f"训练失败: {str(e)}")
    raise

No label_names provided for model class `PeftModelForCausalLM`. Since `PeftModel` hides base models input arguments, if label_names is not given, label_names can't be set automatically within `Trainer`. Note that empty label_names list will be used instead.


开始训练


Step,Training Loss
5,6.1351
10,6.5073
15,6.3293
20,6.0073
25,6.1247
30,5.4093
35,5.2589
40,4.8045
45,4.2606
50,3.5028


训练成功完成


In [None]:
# ========== 9. 推理生成 ==========
def safe_generate(prompt, max_length=500):
    try:
        model.eval()
        inputs = tokenizer(
            f"Instruction: {prompt}\nOutput:",
            return_tensors="pt",
            max_length=128,
            truncation=True
        ).to(device)

        generation_config = GenerationConfig(
            max_new_tokens=max_length,
            do_sample=True,
            top_k=30,
            temperature=0.7,
            repetition_penalty=1.2
        )

        with torch.no_grad(), torch.cuda.amp.autocast():
            outputs = model.generate(
                **inputs,
                generation_config=generation_config
            )
        return tokenizer.decode(outputs[0], skip_special_tokens=True).split("Output:")[-1].strip()

    except RuntimeError as e:
        print(f"生成失败: {str(e)}")
        return ""

In [2]:
# ========== 10. 评估模块 ==========
import evaluate
import numpy as np
from nltk.translate.bleu_score import sentence_bleu, SmoothingFunction
from nltk.tokenize import word_tokenize
from datasets import load_dataset
import torch
from typing import List, Dict, Any, Optional
import time
import json
from datetime import datetime

class EnhancedModelEvaluator:
    def __init__(self, model, tokenizer, device):
        self.model = model
        self.tokenizer = tokenizer
        self.device = device
        self.perplexity = evaluate.load("perplexity", module_type="metric")
        self.rouge = evaluate.load('rouge')
        self.bertscore = evaluate.load("bertscore")
        self.smoothing = SmoothingFunction().method1

    def calculate_perplexity(self, text: str) -> float:
        """计算困惑度分数"""
        try:
            results = self.perplexity.compute(
                predictions=[text],
                model_id=self.model.config._name_or_path,
                add_start_token=True
            )
            return results["mean_perplexity"]
        except Exception as e:
            print(f"困惑度计算错误: {str(e)}")
            return float("inf")

    def calculate_bleu(self, reference: str, candidate: str) -> float:
        """计算BLEU分数"""
        try:
            reference_tokens = word_tokenize(reference.lower())
            candidate_tokens = word_tokenize(candidate.lower())
            return sentence_bleu(
                [reference_tokens],
                candidate_tokens,
                smoothing_function=self.smoothing
            )
        except Exception as e:
            print(f"BLEU分数计算错误: {str(e)}")
            return 0.0

    def calculate_rouge(self, reference: str, candidate: str) -> Dict[str, float]:
        """计算ROUGE分数"""
        try:
            results = self.rouge.compute(
                predictions=[candidate],
                references=[reference]
            )
            return {
                'rouge1': results['rouge1'],
                'rouge2': results['rouge2'],
                'rougeL': results['rougeL']
            }
        except Exception as e:
            print(f"ROUGE分数计算错误: {str(e)}")
            return {'rouge1': 0.0, 'rouge2': 0.0, 'rougeL': 0.0}

    def calculate_bert_score(self, reference: str, candidate: str) -> float:
        """计算BERTScore"""
        try:
            results = self.bertscore.compute(
                predictions=[candidate],
                references=[reference],
                lang="en"
            )
            return np.mean(results['f1'])
        except Exception as e:
            print(f"BERTScore计算错误: {str(e)}")
            return 0.0

    def calculate_diversity_metrics(self, text: str) -> Dict[str, float]:
        """计算文本多样性指标"""
        words = word_tokenize(text.lower())
        unique_words = set(words)

        if len(words) > 0:
            type_token_ratio = len(unique_words) / len(words)
        else:
            type_token_ratio = 0

        return {
            "vocabulary_richness": type_token_ratio,
            "unique_words": len(unique_words),
            "total_words": len(words)
        }

    def calculate_response_time(self, start_time: float) -> float:
        """计算响应时间"""
        return time.time() - start_time

    def evaluate_generation(self,
                          generated_text: str,
                          reference_text: Optional[str] = None,
                          generation_time: Optional[float] = None) -> Dict[str, Any]:
        """评估生成文本的质量"""
        metrics = {}

        # 基础指标
        metrics["perplexity"] = self.calculate_perplexity(generated_text)
        metrics["length"] = len(word_tokenize(generated_text))

        # 多样性指标
        diversity_metrics = self.calculate_diversity_metrics(generated_text)
        metrics.update(diversity_metrics)

        # 如果有参考文本，计算相似度指标
        if reference_text:
            metrics["bleu"] = self.calculate_bleu(reference_text, generated_text)
            metrics.update(self.calculate_rouge(reference_text, generated_text))
            metrics["bert_score"] = self.calculate_bert_score(reference_text, generated_text)

        # 性能指标
        if generation_time is not None:
            metrics["generation_time"] = generation_time
            metrics["tokens_per_second"] = metrics["length"] / generation_time if generation_time > 0 else 0

        return metrics

    def batch_evaluate(self, test_cases: List[Dict[str, str]]) -> List[Dict[str, Any]]:
        """批量评估多个测试用例"""
        results = []
        total_start_time = time.time()

        for test_case in test_cases:
            prompt = test_case["prompt"]
            reference = test_case.get("reference")

            # 生成文本并记录时间
            start_time = time.time()
            generated = safe_generate(prompt)
            generation_time = self.calculate_response_time(start_time)

            # 评估指标
            metrics = self.evaluate_generation(
                generated,
                reference,
                generation_time
            )

            results.append({
                "prompt": prompt,
                "generated": generated,
                "reference": reference,
                "metrics": metrics
            })

        # 记录总体评估时间
        total_time = self.calculate_response_time(total_start_time)

        return results, total_time

    def print_evaluation_report(self, results: List[Dict[str, Any]], total_time: float):
        """打印详细评估报告"""
        timestamp = datetime.now().strftime("%Y-%m-%d %H:%M:%S")
        print(f"\n====== 模型评估报告 ({timestamp}) ======")

        # 初始化平均指标字典
        avg_metrics = {}

        # 详细结果
        for i, result in enumerate(results, 1):
            print(f"\n测试用例 {i}/{len(results)}:")
            print(f"提示: {result['prompt']}")
            print(f"生成: {result['generated']}")
            if result['reference']:
                print(f"参考: {result['reference']}")

            print("\n各项指标:")
            metrics = result['metrics']

            # 打印所有指标
            for metric_name, value in metrics.items():
                if isinstance(value, float):
                    print(f"- {metric_name}: {value:.4f}")
                else:
                    print(f"- {metric_name}: {value}")

                # 累积平均值
                if metric_name not in avg_metrics:
                    avg_metrics[metric_name] = []
                avg_metrics[metric_name].append(value)

        # 打印总体统计
        print("\n====== 总体统计 ======")
        print(f"总评估时间: {total_time:.2f}秒")
        print(f"平均每个样本评估时间: {total_time/len(results):.2f}秒")

        print("\n平均指标:")
        for metric_name, values in avg_metrics.items():
            if all(isinstance(x, (int, float)) for x in values):
                avg = np.mean(values)
                std = np.std(values)
                print(f"- {metric_name}:")
                print(f"  平均值: {avg:.4f}")
                print(f"  标准差: {std:.4f}")

    def save_report(self, results: List[Dict[str, Any]], total_time: float,
                   filename: str = "evaluation_report.json"):
        """保存评估报告到文件"""
        report = {
            "timestamp": datetime.now().strftime("%Y-%m-%d %H:%M:%S"),
            "total_time": total_time,
            "avg_time_per_sample": total_time/len(results),
            "results": results,
            "model_config": {
                "model_name": self.model.config._name_or_path,
                "device": str(self.device)
            }
        }

        with open(filename, 'w', encoding='utf-8') as f:
            json.dump(report, f, ensure_ascii=False, indent=2)

# ========== 使用示例 ==========
def run_evaluation(test_cases):
    """运行评估流程"""
    # 初始化评估器
    evaluator = EnhancedModelEvaluator(model, tokenizer, device)

    # 执行评估
    print("开始评估...")
    results, total_time = evaluator.batch_evaluate(test_cases)

    # 打印报告
    evaluator.print_evaluation_report(results, total_time)

    # 保存报告
    evaluator.save_report(results, total_time)

    return results, total_time

In [3]:
test_prompts = [
    # 基础知识解释
    "How to Make a Sandwich",
    "Explanation of Quantum Mechanics",
    "Write a Poem About Spring",

    # 技术指导
    "Explain the concept of recursion in programming",
    "How to optimize Python code for better performance",
    "Steps to set up a basic machine learning pipeline",

    # 创意写作
    "Write a short story about time travel",
    "Create a character description for a fantasy novel",
    "Write a scene describing a sunset at the beach",

    # 分析类
    "Analyze the impact of social media on modern society",
    "Compare and contrast renewable and non-renewable energy sources",
    "Discuss the future of artificial intelligence",

    # 实用指南
    "Tips for effective time management",
    "How to start a small business",
    "Steps to improve public speaking skills",

    # 科学解释
    "Explain how black holes work",
    "Describe the process of photosynthesis",
    "How does the human immune system function",

    # 问题解决
    "How to debug a slow running computer",
    "Ways to reduce stress and anxiety",
    "Solutions for climate change"
]

print("开始生成测试...")
print("="*50)

for i, prompt in enumerate(test_prompts, 1):
    print(f"\n测试 {i}/{len(test_prompts)}")
    print(f"输入: {prompt}")
    print(f"输出: {safe_generate(prompt)}")
    print("="*50)

# 使用评估器评估全部结果
test_cases = [{"prompt": prompt} for prompt in test_prompts]

  with torch.no_grad(), torch.cuda.amp.autocast():
Setting `pad_token_id` to `eos_token_id`:50256 for open-end generation.


开始生成测试...

测试 1/21
输入: How to Make a Sandwich


Setting `pad_token_id` to `eos_token_id`:50256 for open-end generation.


输出: To make a sandwich, first you need some ingredients. 
- Heat a tablespoon of olive oil over medium heat and add to the sandwich with a pinch salt.
- In a large bowl, mix in the melted butter until fully combined (or softened) layer evenly.
- Add to an electric mixer for about 10 minutes until smoothies are incorporated into your sandwich.
- Transfer the cheese to a strainer and whisking well through it.

测试 2/21
输入: Explanation of Quantum Mechanics


Setting `pad_token_id` to `eos_token_id`:50256 for open-end generation.


输出: Quantum mechanics is a branch of mathematics that focuses on the properties and operations involved in quantum mechanics. It involves finding a physical relationship between two elements, which are known as qubits or 'qubits'. This means that each particle interacts with its own qubits by adding a bit more energy to it. In this way they can be measured, and then multiplied further until the qubits have been able find the same amount (and vice versa).

测试 3/21
输入: Write a Poem About Spring


Setting `pad_token_id` to `eos_token_id`:50256 for open-end generation.


输出: Spring is the perfect time to spend with a family and explore new places in your life. It's a great time if you're feeling overwhelmed or scared of something that has gone wrong.

测试 4/21
输入: Explain the concept of recursion in programming


Setting `pad_token_id` to `eos_token_id`:50256 for open-end generation.


输出: Recursion in programming is a process of combining different concepts together to form a single program. In programs, an array or multiple variable is used to store a list and store the elements it needs in order for each element's type to be represented. This can be a great way around complex problems, as there are many different types that can be used to store elements without any additional complexity. However if you're using a much simpler approach to pattern classification, recursion may seem like a more powerful technique than recursion.

测试 5/21
输入: How to optimize Python code for better performance


Setting `pad_token_id` to `eos_token_id`:50256 for open-end generation.


输出: Python code for better performance can be optimized in a variety of ways. For example, using a simple and efficient Python program can increase the overall efficiency by optimizing for different tasks in the task, while optimizing for specific applications is more efficient. Additionally, using Python code for more complex tasks can help minimize the need or time spent manually and efficiently. Furthermore, Python code can be executed on an iterative basis, making it easier than ever before when it comes to performance goals.

测试 6/21
输入: Steps to set up a basic machine learning pipeline


Setting `pad_token_id` to `eos_token_id`:50256 for open-end generation.


输出: Step 1: Create an external dataset and import it from Python. Create a query function and select the input field as necessary for generating output. This will be used for creating a loop that is able to print text in a form like this:

测试 7/21
输入: Write a short story about time travel


Setting `pad_token_id` to `eos_token_id`:50256 for open-end generation.


输出: Once upon an old man from the South, he decided to embark on a journey of adventure. After months of trying out for his family and friends in search of new ways to make ends meet by traveling alone, he stumbled upon something magical - something that kept him going all over the world. He was determined not only to explore the world but also find peace and peace within himself. Along with his newfound friends – such as their own child – they traveled together to face off against an evil evil forces known as the Dark One! Together, they made it difficult for anyone who ventured into one vast city, and ultimately ended up being able take part in a grand epic ride.

测试 8/21
输入: Create a character description for a fantasy novel


Setting `pad_token_id` to `eos_token_id`:50256 for open-end generation.
Setting `pad_token_id` to `eos_token_id`:50256 for open-end generation.


输出: A young girl discovers a magical forest in the forest that is filled with magical energy. She finds a magical forest full of mythical creatures, including humans and animals who are enchanted by their power. The girls learn from them how to use magic to create magical objects, such as light bulbs or cones can fly through walls and doors like they did before!

测试 9/21
输入: Write a scene describing a sunset at the beach
输出: The sun is setting in the sand and it can be seen in the foreground.

测试 10/21
输入: Analyze the impact of social media on modern society


Setting `pad_token_id` to `eos_token_id`:50256 for open-end generation.


输出: Social Media has an impact in terms and impacts that could have a negative impact to society. It also highlights the importance people can make online, which is a great way for businesses owners who are not satisfied with their customers’s presence or their brand. Additionally it helps companies reduce costs associated with marketing campaigns and increase access by providing more content, allowing them better understanding the customer needs and how well they interact with their customers.

测试 11/21
输入: Compare and contrast renewable and non-renewable energy sources


Setting `pad_token_id` to `eos_token_id`:50256 for open-end generation.


输出: Renewable energy sources are more sustainable, cost efficient methods of generating electricity from wind, solar or other sources. These sources have higher efficiency as compared to traditional fossil fuels, such a process can be costly and inefficient, resulting in increased costs and resources. Additionally, renewable energy sources require less maintenance than traditional fossil fuels, while non-renowables typically produce lower emissions due to their use by humans.

测试 12/21
输入: Discuss the future of artificial intelligence


Setting `pad_token_id` to `eos_token_id`:50256 for open-end generation.
Setting `pad_token_id` to `eos_token_id`:50256 for open-end generation.


输出: Artificial Intelligence is a rapidly growing field of research, with an increasing number of applications. AI-enabled systems can be used to automate tasks and automate processes of decision making. These systems will enable governments for tasks such as finance, education management or financial services; and they can also make decisions based on human interaction and collaboration between humans and machines. The emergence in recent years has seen a rapid growth from robotics, computer vision devices that can provide personalized recommendations, predictive analytics and predictive analytics; this could lead us towards more efficient solutions.

测试 13/21
输入: Tips for effective time management
输出: Make sure you are using a combination of digital and physical tracking. If not, make sure to use a digital assistant to provide accurate and accurate information.

测试 14/21
输入: How to start a small business


Setting `pad_token_id` to `eos_token_id`:50256 for open-end generation.


输出: Start a small business. Begin by setting up a business that sells food, software and services, create an online store where you can buy products and services from anywhere. Then build a website with the information you need while you are creating your own website. Once there is enough information in it all along – including pricing, details of the product or service, payment processing fees for shipping, and shipping costs. In order make sure everything is simple and easy-to understand!

测试 15/21
输入: Steps to improve public speaking skills


Setting `pad_token_id` to `eos_token_id`:50256 for open-end generation.


输出: 1. Provide an effective communicator for the audience. 
2, Promote empathy and open dialogue by providing a platform that encourages conversations about important topics such as education, transportation issues or environmental issues.

测试 16/21
输入: Explain how black holes work


Setting `pad_token_id` to `eos_token_id`:50256 for open-end generation.


输出: Black holes are a type of small, gas-filled, and invisible object, usually referred to as the "Paloons". They are objects that are typically used for observation or analysis. When they are formed from water, they move in space when they are cooled or cooled by electricity, often resulting in stars and objects.

测试 17/21
输入: Describe the process of photosynthesis


Setting `pad_token_id` to `eos_token_id`:50256 for open-end generation.
Setting `pad_token_id` to `eos_token_id`:50256 for open-end generation.


输出: Photosynthesis involves converting sunlight to a chemical reaction, releasing carbon dioxide and methane from nearby gases. The process involves using different materials such as oil sands or silos that trap sunlight in the atmosphere and generate energy through photosynthesis.

测试 18/21
输入: How does the human immune system function
输出: The human immune response is a response to environmental factors such as climate change and other environmental threats.

测试 19/21
输入: How to debug a slow running computer


Setting `pad_token_id` to `eos_token_id`:50256 for open-end generation.


输出: The main difference between the two approaches is that both are expensive and have limited power.

测试 20/21
输入: Ways to reduce stress and anxiety


Setting `pad_token_id` to `eos_token_id`:50256 for open-end generation.


输出: Strategies such as mindfulness help with the task. Mindfulness can be used in order to increase concentration and focus on the task. Furthermore, mindfulness can also provide a source of relaxation when focusing or facing situations. Finally, mindful meditation helps to foster balance between positive moods while avoiding worrying about your own health or any potential health issues.

测试 21/21
输入: Solutions for climate change
输出: One possible solution to the global warming problem is reducing emissions of carbon dioxide and releasing its carbon from oceans. This could include replacing existing fossil fuels with cleaner, more efficient sources such as tap water, burning coal, and transitioning to renewable energy sources.


In [5]:
# ========== 10. 增强评估模块 ==========
import re

class EnhancedModelEvaluator:
    """增强型模型评估器"""
    def __init__(self, model, tokenizer, device):
        self.model = model
        self.tokenizer = tokenizer
        self.device = device
        self.perplexity = evaluate.load("perplexity", module_type="metric")

    def simple_tokenize(self, text: str) -> List[str]:
        """简单的分词函数"""
        # 将标点符号分开
        text = re.sub(r'([.,!?()])', r' \1 ', text)
        # 分词并移除空字符串
        return [token for token in text.split() if token.strip()]

    def calculate_perplexity(self, text: str) -> float:
        """计算困惑度分数"""
        try:
            results = self.perplexity.compute(
                predictions=[text],
                model_id=self.model.config._name_or_path,
                add_start_token=True
            )
            return results["mean_perplexity"]
        except Exception as e:
            print(f"困惑度计算错误: {str(e)}")
            return float("inf")

    def calculate_simple_similarity(self, reference: str, candidate: str) -> float:
        """计算简单的相似度分数"""
        try:
            ref_tokens = set(self.simple_tokenize(reference.lower()))
            cand_tokens = set(self.simple_tokenize(candidate.lower()))

            if not ref_tokens or not cand_tokens:
                return 0.0

            # 计算 Jaccard 相似度
            intersection = len(ref_tokens.intersection(cand_tokens))
            union = len(ref_tokens.union(cand_tokens))

            return intersection / union if union > 0 else 0.0

        except Exception as e:
            print(f"相似度计算错误: {str(e)}")
            return 0.0

    def calculate_diversity_metrics(self, text: str) -> Dict[str, float]:
        """计算文本多样性指标"""
        tokens = self.simple_tokenize(text.lower())
        unique_tokens = set(tokens)

        # 词汇丰富度
        type_token_ratio = len(unique_tokens) / len(tokens) if tokens else 0

        # 计算重复 bigram 的比例
        bigrams = list(zip(tokens[:-1], tokens[1:])) if len(tokens) > 1 else []
        unique_bigrams = set(bigrams)
        bigram_diversity = len(unique_bigrams) / len(bigrams) if bigrams else 0

        return {
            "vocabulary_richness": type_token_ratio,
            "unique_words": len(unique_tokens),
            "total_words": len(tokens),
            "bigram_diversity": bigram_diversity
        }

    def calculate_response_time(self, start_time: float) -> float:
        """计算响应时间"""
        return time.time() - start_time

    def evaluate_generation(self,
                          generated_text: str,
                          reference_text: Optional[str] = None,
                          generation_time: Optional[float] = None) -> Dict[str, Any]:
        """评估生成文本的质量"""
        metrics = {}

        # 基础指标
        metrics["perplexity"] = self.calculate_perplexity(generated_text)
        metrics["length"] = len(self.simple_tokenize(generated_text))

        # 多样性指标
        diversity_metrics = self.calculate_diversity_metrics(generated_text)
        metrics.update(diversity_metrics)

        # 如果有参考文本，计算相似度指标
        if reference_text:
            metrics["similarity_score"] = self.calculate_simple_similarity(reference_text, generated_text)

        # 性能指标
        if generation_time is not None:
            metrics["generation_time"] = generation_time
            metrics["tokens_per_second"] = metrics["length"] / generation_time if generation_time > 0 else 0

        return metrics

    def batch_evaluate(self, test_cases: List[Dict[str, str]]) -> List[Dict[str, Any]]:
        """批量评估多个测试用例"""
        results = []
        total_start_time = time.time()

        for i, test_case in enumerate(test_cases, 1):
            print(f"\n评估进度: {i}/{len(test_cases)}")

            prompt = test_case["prompt"]
            reference = test_case.get("reference")

            # 生成文本并记录时间
            start_time = time.time()
            generated = safe_generate(prompt)
            generation_time = self.calculate_response_time(start_time)

            # 评估指标
            metrics = self.evaluate_generation(
                generated,
                reference,
                generation_time
            )

            results.append({
                "prompt": prompt,
                "generated": generated,
                "reference": reference,
                "metrics": metrics
            })

        total_time = self.calculate_response_time(total_start_time)
        return results, total_time

    def print_evaluation_report(self, results: List[Dict[str, Any]], total_time: float):
        """打印详细评估报告"""
        timestamp = datetime.now().strftime("%Y-%m-%d %H:%M:%S")
        print(f"\n====== 模型评估报告 ({timestamp}) ======")

        # 初始化平均指标字典
        avg_metrics = {}

        # 详细结果
        for i, result in enumerate(results, 1):
            print(f"\n测试用例 {i}/{len(results)}:")
            print(f"提示: {result['prompt']}")
            print(f"生成: {result['generated']}")
            if result['reference']:
                print(f"参考: {result['reference']}")

            print("\n指标:")
            metrics = result['metrics']
            for metric_name, value in metrics.items():
                if isinstance(value, float):
                    print(f"- {metric_name}: {value:.4f}")
                else:
                    print(f"- {metric_name}: {value}")

                # 累积平均值
                if metric_name not in avg_metrics:
                    avg_metrics[metric_name] = []
                avg_metrics[metric_name].append(value)

        # 打印总体统计
        print("\n====== 总体统计 ======")
        print(f"总评估时间: {total_time:.2f}秒")
        print(f"平均每个样本评估时间: {total_time/len(results):.2f}秒")

        print("\n平均指标:")
        for metric_name, values in avg_metrics.items():
            if all(isinstance(x, (int, float)) for x in values):
                avg = np.mean(values)
                std = np.std(values)
                print(f"- {metric_name}:")
                print(f"  平均值: {avg:.4f}")
                print(f"  标准差: {std:.4f}")

    def save_report(self, results: List[Dict[str, Any]], total_time: float,
                   filename: str = "evaluation_report.json"):
        """保存评估报告到文件"""
        report = {
            "timestamp": datetime.now().strftime("%Y-%m-%d %H:%M:%S"),
            "total_time": total_time,
            "avg_time_per_sample": total_time/len(results),
            "results": results,
            "model_config": {
                "model_name": self.model.config._name_or_path,
                "device": str(self.device)
            }
        }

        with open(filename, 'w', encoding='utf-8') as f:
            json.dump(report, f, ensure_ascii=False, indent=2)

# ========== 使用示例 ==========
def run_evaluation(test_cases):
    """运行评估流程"""
    # 初始化评估器
    evaluator = EnhancedModelEvaluator(model, tokenizer, device)

    # 执行评估
    print("开始评估...")
    results, total_time = evaluator.batch_evaluate(test_cases)

    # 打印报告
    evaluator.print_evaluation_report(results, total_time)

    # 保存报告
    evaluator.save_report(results, total_time)

    return results, total_time

In [6]:
# 准备测试用例
test_cases = [{"prompt": prompt} for prompt in test_prompts]

# 初始化评估器并运行评估
results, total_time = run_evaluation(test_cases)

  with torch.no_grad(), torch.cuda.amp.autocast():
Setting `pad_token_id` to `eos_token_id`:50256 for open-end generation.


开始评估...

评估进度: 1/21


100%|██████████| 1/1 [00:00<00:00, 35.65it/s]
Setting `pad_token_id` to `eos_token_id`:50256 for open-end generation.



评估进度: 2/21


100%|██████████| 1/1 [00:00<00:00, 42.14it/s]
Setting `pad_token_id` to `eos_token_id`:50256 for open-end generation.



评估进度: 3/21


100%|██████████| 1/1 [00:00<00:00, 169.69it/s]
Setting `pad_token_id` to `eos_token_id`:50256 for open-end generation.



评估进度: 4/21


100%|██████████| 1/1 [00:00<00:00, 41.94it/s]
Setting `pad_token_id` to `eos_token_id`:50256 for open-end generation.



评估进度: 5/21


100%|██████████| 1/1 [00:00<00:00, 42.19it/s]
Setting `pad_token_id` to `eos_token_id`:50256 for open-end generation.



评估进度: 6/21


100%|██████████| 1/1 [00:00<00:00, 42.21it/s]
Setting `pad_token_id` to `eos_token_id`:50256 for open-end generation.



评估进度: 7/21


100%|██████████| 1/1 [00:00<00:00, 41.86it/s]
Setting `pad_token_id` to `eos_token_id`:50256 for open-end generation.



评估进度: 8/21


100%|██████████| 1/1 [00:00<00:00, 42.20it/s]
Setting `pad_token_id` to `eos_token_id`:50256 for open-end generation.



评估进度: 9/21


100%|██████████| 1/1 [00:00<00:00, 166.16it/s]
Setting `pad_token_id` to `eos_token_id`:50256 for open-end generation.



评估进度: 10/21


100%|██████████| 1/1 [00:00<00:00, 42.15it/s]
Setting `pad_token_id` to `eos_token_id`:50256 for open-end generation.



评估进度: 11/21


100%|██████████| 1/1 [00:00<00:00, 174.07it/s]
Setting `pad_token_id` to `eos_token_id`:50256 for open-end generation.



评估进度: 12/21


100%|██████████| 1/1 [00:00<00:00, 42.07it/s]
Setting `pad_token_id` to `eos_token_id`:50256 for open-end generation.



评估进度: 13/21


100%|██████████| 1/1 [00:00<00:00, 164.51it/s]
Setting `pad_token_id` to `eos_token_id`:50256 for open-end generation.



评估进度: 14/21


100%|██████████| 1/1 [00:00<00:00, 206.12it/s]
Setting `pad_token_id` to `eos_token_id`:50256 for open-end generation.



评估进度: 15/21


100%|██████████| 1/1 [00:00<00:00, 165.00it/s]
Setting `pad_token_id` to `eos_token_id`:50256 for open-end generation.



评估进度: 16/21


100%|██████████| 1/1 [00:00<00:00, 42.17it/s]
Setting `pad_token_id` to `eos_token_id`:50256 for open-end generation.



评估进度: 17/21


100%|██████████| 1/1 [00:00<00:00, 152.67it/s]
Setting `pad_token_id` to `eos_token_id`:50256 for open-end generation.



评估进度: 18/21


100%|██████████| 1/1 [00:00<00:00, 42.05it/s]
Setting `pad_token_id` to `eos_token_id`:50256 for open-end generation.



评估进度: 19/21


100%|██████████| 1/1 [00:00<00:00, 171.34it/s]
Setting `pad_token_id` to `eos_token_id`:50256 for open-end generation.



评估进度: 20/21


100%|██████████| 1/1 [00:00<00:00, 177.97it/s]
Setting `pad_token_id` to `eos_token_id`:50256 for open-end generation.



评估进度: 21/21


100%|██████████| 1/1 [00:00<00:00, 192.94it/s]



测试用例 1/21:
提示: How to Make a Sandwich
生成: To make a sandwich, first create a slice of bread. Place the slices in a bowl and top with a layer of butter or baking soda. When the bread is baked, add an inch and a half of the butter and bake for 30 minutes. In another step, place on parchment paper and mix the butter and sugar into the batter.

指标:
- perplexity: 9.6638
- length: 68
- vocabulary_richness: 0.6471
- unique_words: 44
- total_words: 68
- bigram_diversity: 0.9701
- generation_time: 0.4455
- tokens_per_second: 152.6489

测试用例 2/21:
提示: Explanation of Quantum Mechanics
生成: Quantum mechanics is a type of physics used to explain how an atom interacts with molecules and their energy. It is used in many applications, such as healthcare, transportation systems or even medical diagnostics. Quantum mechanics states that quantum particles interact with each other in a way that can be considered to help them better understand the properties required for their interactions over time. It ha


