In [68]:
model_name_or_path = "openai/whisper-large-v2"
model_dir = "models/whisper-large-v2-asr-int8"

language = "Chinese (China)"
language_abbr = "zh-CN"
# language_abbr = "chinese"
task = "transcribe"
dataset_name = "mozilla-foundation/common_voice_11_0"

batch_size=64

In [69]:
from datasets import load_dataset, DatasetDict

common_voice = DatasetDict()

common_voice["train"] = load_dataset(dataset_name, language_abbr, split="train", trust_remote_code=True)
common_voice["validation"] = load_dataset(dataset_name, language_abbr, split="validation", trust_remote_code=True)

common_voice["train"][0]

{'client_id': '95368aab163e0387e4fd4991b4f2d8ccfbd4364bf656c860230501fd27dcedf087773e4695a6cf5de9c4f1d406d582283190d065cdfa36b0e2b060cffaca977e',
 'path': '/root/.cache/huggingface/datasets/downloads/extracted/dcc5967c754d4c815fc005d6e297d84537028996cbcf6b34190517630cbc40b4/zh-CN_train_0/common_voice_zh-CN_33211332.mp3',
 'audio': {'path': '/root/.cache/huggingface/datasets/downloads/extracted/dcc5967c754d4c815fc005d6e297d84537028996cbcf6b34190517630cbc40b4/zh-CN_train_0/common_voice_zh-CN_33211332.mp3',
  'array': array([-6.82121026e-13, -2.27373675e-12, -2.27373675e-12, ...,
          1.21667399e-05,  3.23003678e-06, -2.43066324e-07], shape=(255744,)),
  'sampling_rate': 48000},
 'sentence': '性喜温暖润湿气候且耐寒。',
 'up_votes': 2,
 'down_votes': 0,
 'age': '',
 'gender': '',
 'accent': '',
 'locale': 'zh-CN',
 'segment': ''}

In [70]:
common_voice

DatasetDict({
    train: Dataset({
        features: ['client_id', 'path', 'audio', 'sentence', 'up_votes', 'down_votes', 'age', 'gender', 'accent', 'locale', 'segment'],
        num_rows: 29056
    })
    validation: Dataset({
        features: ['client_id', 'path', 'audio', 'sentence', 'up_votes', 'down_votes', 'age', 'gender', 'accent', 'locale', 'segment'],
        num_rows: 10581
    })
})

In [71]:
from transformers import AutoFeatureExtractor, AutoTokenizer, AutoProcessor

# 从预训练模型加载特征提取器
feature_extractor = AutoFeatureExtractor.from_pretrained(model_name_or_path)

# 从预训练模型加载分词器，可以指定语言和任务以获得最适合特定需求的分词器配置
tokenizer = AutoTokenizer.from_pretrained(model_name_or_path, language=language, task=task)

# 从预训练模型加载处理器，处理器通常结合了特征提取器和分词器，为特定任务提供一站式的数据预处理
processor = AutoProcessor.from_pretrained(model_name_or_path, language=language, task=task)

In [72]:
common_voice = common_voice.remove_columns(
    ["accent", "age", "client_id", "down_votes", "gender", "locale", "path", "segment", "up_votes"]
)

In [73]:
common_voice["train"][0]

{'audio': {'path': '/root/.cache/huggingface/datasets/downloads/extracted/dcc5967c754d4c815fc005d6e297d84537028996cbcf6b34190517630cbc40b4/zh-CN_train_0/common_voice_zh-CN_33211332.mp3',
  'array': array([-6.82121026e-13, -2.27373675e-12, -2.27373675e-12, ...,
          1.21667399e-05,  3.23003678e-06, -2.43066324e-07], shape=(255744,)),
  'sampling_rate': 48000},
 'sentence': '性喜温暖润湿气候且耐寒。'}

In [74]:
from datasets import Audio

common_voice = common_voice.cast_column("audio", Audio(sampling_rate=16000))

In [75]:
# sampling_rate 从 48KHZ 降为 16KHZ
common_voice["train"][0]

{'audio': {'path': '/root/.cache/huggingface/datasets/downloads/extracted/dcc5967c754d4c815fc005d6e297d84537028996cbcf6b34190517630cbc40b4/zh-CN_train_0/common_voice_zh-CN_33211332.mp3',
  'array': array([ 5.09317033e-11, -7.27595761e-12, -6.54836185e-11, ...,
         -5.96661994e-06,  2.71382887e-05,  1.29687978e-05], shape=(85248,)),
  'sampling_rate': 16000},
 'sentence': '性喜温暖润湿气候且耐寒。'}

In [76]:
def prepare_dataset(batch):
    audio = batch["audio"]
    batch["input_features"] = feature_extractor(audio["array"], sampling_rate=audio["sampling_rate"]).input_features[0]
    batch["labels"] = tokenizer(batch["sentence"]).input_ids
    return batch

In [77]:
small_common_voice = DatasetDict()

small_common_voice["train"] = common_voice["train"].shuffle(seed=16).select(range(640))
small_common_voice["validation"] = common_voice["validation"].shuffle(seed=16).select(range(320))

small_common_voice

DatasetDict({
    train: Dataset({
        features: ['audio', 'sentence'],
        num_rows: 640
    })
    validation: Dataset({
        features: ['audio', 'sentence'],
        num_rows: 320
    })
})

In [78]:
# 抽样数据处理
tokenized_common_voice = small_common_voice.map(prepare_dataset)

# 完整数据训练，尝试开启 `num_proc=8` 参数多进程并行处理（如阻塞无法运行，则不使用此参数）
# tokenized_common_voice = common_voice.map(prepare_dataset)

In [79]:
import torch

from dataclasses import dataclass
from typing import Any, Dict, List, Union

# 定义一个针对语音到文本任务的数据整理器类
@dataclass
class DataCollatorSpeechSeq2SeqWithPadding:
    processor: Any  # 处理器结合了特征提取器和分词器

    # 整理器函数，将特征列表处理成一个批次
    def __call__(self, features: List[Dict[str, Union[List[int], torch.Tensor]]]) -> Dict[str, torch.Tensor]:
        # 从特征列表中提取输入特征，并填充以使它们具有相同的形状
        input_features = [{"input_features": feature["input_features"]} for feature in features]
        batch = self.processor.feature_extractor.pad(input_features, return_tensors="pt")

        # 从特征列表中提取标签特征（文本令牌），并进行填充
        label_features = [{"input_ids": feature["labels"]} for feature in features]
        labels_batch = self.processor.tokenizer.pad(label_features, return_tensors="pt")

        # 使用-100替换标签中的填充区域，-100通常用于在损失计算中忽略填充令牌
        labels = labels_batch["input_ids"].masked_fill(labels_batch.attention_mask.ne(1), -100)

        # 如果批次中的所有序列都以句子开始令牌开头，则移除它
        if (labels[:, 0] == self.processor.tokenizer.bos_token_id).all().cpu().item():
            labels = labels[:, 1:]

        # 将处理过的标签添加到批次中
        batch["labels"] = labels

        return batch  # 返回最终的批次，准备好进行训练或评估

In [80]:
# 用给定的处理器实例化数据整理器
data_collator = DataCollatorSpeechSeq2SeqWithPadding(processor=processor)

In [81]:
# 1. 强制使用正确的模型类（WhisperForConditionalGeneration）
from transformers import WhisperForConditionalGeneration, BitsAndBytesConfig

# 2. 8位量化配置不变
model = AutoModelForSpeechSeq2Seq.from_pretrained(model_name_or_path, load_in_8bit=True, device_map="auto")



# 验证模型类型（可选，用于确认是否正确加载）
print(type(model))  # 应输出：<class 'transformers.models.whisper.modeling_whisper.WhisperForConditionalGeneration'>

The `load_in_4bit` and `load_in_8bit` arguments are deprecated and will be removed in the future versions. Please, pass a `BitsAndBytesConfig` object in `quantization_config` argument instead.


<class 'transformers.models.whisper.modeling_whisper.WhisperForConditionalGeneration'>


In [82]:
# # 设置模型配置中的forced_decoder_ids属性为None
# model.config.forced_decoder_ids = None  # 这通常用于指定在解码（生成文本）过程中必须使用的特定token的ID，设置为None表示没有这样的强制要求

# # 设置模型配置中的suppress_tokens列表为空
# model.config.suppress_tokens = []  # 这用于指定在生成过程中应被抑制（不生成）的token的列表，设置为空列表表示没有要抑制的token


# 配置解码器提示ID（适配语音转文本任务）
model.config.forced_decoder_ids = processor.get_decoder_prompt_ids(
    language="chinese",  # 用全称，而非缩写
    task=task
)
model.config.suppress_tokens = []  # 不抑制任何token

In [85]:
from transformers import AutoModelForSpeechSeq2Seq, BitsAndBytesConfig
import torch

# 配置8位量化参数
bnb_config = BitsAndBytesConfig(
    load_in_8bit=True,  # 保持8位量化
    bnb_8bit_compute_dtype=torch.float16,  # 计算时使用float16
    bnb_8bit_use_double_quant=True,  # 双量化优化内存
    bnb_8bit_quant_type="nf8"  # 8位量化的优化类型
)

# 加载模型时应用量化配置，device_map="auto"会自动分配设备
model = AutoModelForSpeechSeq2Seq.from_pretrained(
    model_name_or_path,
    quantization_config=bnb_config,  # 应用8位量化配置
    device_map="auto"  # 自动分配设备，无需后续手动移动
)

# 获取模型所在设备（由device_map自动分配）
device = model.device

# 确保卷积层参数类型正确（针对Whisper的conv1/conv2）
for name, module in model.named_modules():
    if isinstance(module, torch.nn.Conv1d):
        # 强制卷积层权重和偏置为float16（与输入匹配）
        # 不移动设备，只转换类型
        module.weight = torch.nn.Parameter(module.weight.half())
        if module.bias is not None:
            module.bias = torch.nn.Parameter(module.bias.half())

# 处理输入数据类型 - 在使用时调用
def prepare_inputs(input_features):
    # 将输入转换为float16并移动到模型所在设备
    return input_features.half().to(device)
    

In [86]:
from peft import LoraConfig, PeftModel, LoraModel, LoraConfig, get_peft_model

# 创建一个LoraConfig对象，用于设置LoRA（Low-Rank Adaptation）的配置参数
config = LoraConfig(
    r=4,  # LoRA的秩，影响LoRA矩阵的大小
    lora_alpha=64,  # LoRA适应的比例因子
    # 指定将LoRA应用到的模型模块，通常是attention和全连接层的投影。
    target_modules=["q_proj", "v_proj"],
    lora_dropout=0.05,  # 在LoRA模块中使用的dropout率
    bias="none",  # 设置bias的使用方式，这里没有使用bias
)


In [87]:
peft_model = get_peft_model(model, config)

In [88]:
# 打印 LoRA 微调训练的模型参数
peft_model.print_trainable_parameters()

trainable params: 1,966,080 || all params: 1,545,271,040 || trainable%: 0.1272


In [89]:
from transformers import Seq2SeqTrainingArguments

# 设置序列到序列模型训练的参数
training_args = Seq2SeqTrainingArguments(
    output_dir=model_dir,  # 指定模型输出和保存的目录
    per_device_train_batch_size=batch_size,  # 每个设备上的训练批量大小
    learning_rate=1e-3,  # 学习率
    num_train_epochs=1,  # 训练的总轮数
    eval_strategy="epoch",  # 新版本中使用eval_strategy替代evaluation_strategy
    per_device_eval_batch_size=batch_size,  # 每个设备上的评估批量大小
    generation_max_length=128,  # 生成任务的最大长度
    logging_steps=10,  # 指定日志记录的步骤，用于跟踪训练进度
    remove_unused_columns=False,  # 是否删除不使用的列，以减少数据处理开销
    label_names=["labels"],  # 指定标签列的名称，用于训练过程中
)

In [90]:
from transformers import Seq2SeqTrainer



trainer = Seq2SeqTrainer(
    args=training_args,
    model=peft_model,
    train_dataset=tokenized_common_voice["train"],
    eval_dataset=tokenized_common_voice["validation"],
    data_collator=data_collator,
    tokenizer=processor.feature_extractor,
)
peft_model.config.use_cache = False

  trainer = Seq2SeqTrainer(


In [91]:
print(type(model))  # 应输出：<class 'transformers.models.whisper.modeling_whisper.WhisperForConditionalGeneration'>

<class 'transformers.models.whisper.modeling_whisper.WhisperForConditionalGeneration'>


In [None]:
import torch
import numpy as np
import matplotlib.pyplot as plt
from datasets import load_dataset, DatasetDict, Audio
from transformers import (
    AutoProcessor,
    AutoModelForSpeechSeq2Seq,
    Seq2SeqTrainingArguments,
    Seq2SeqTrainer,
    BitsAndBytesConfig
)
from peft import LoraConfig, get_peft_model, prepare_model_for_kbit_training
from dataclasses import dataclass
from typing import Any, Dict, List, Union
import evaluate
import os

# 中文显示配置
plt.rcParams["font.family"] = ["SimHei", "WenQuanYi Micro Hei", "Heiti TC"]
plt.rcParams["axes.unicode_minus"] = False

# -------------- 关键加速配置（必须依赖高端GPU）--------------
model_name_or_path = "openai/whisper-large-v2"  # 保留大模型
batch_size = 64  # 超大批次（需24GB+显存）
gradient_accumulation_steps = 1  # 不累积梯度，直接大批次
num_train_epochs = 1  # 仅训练1轮（全量数据）
learning_rate = 3e-3  # 高学习率加速收敛
logging_steps = 20  # 减少日志频率
eval_steps = 200  # 减少验证频率
# -----------------------------------------------------------

output_dir = "./whisper_large_fast"
language = "Chinese"
language_abbr = "zh-CN"
task = "transcribe"
dataset_name = "mozilla-foundation/common_voice_11_0"

os.makedirs(output_dir, exist_ok=True)

# 1. 加载全量数据集（保留全量数据）
print("加载全量数据集...")
common_voice = DatasetDict()
common_voice["train"] = load_dataset(dataset_name, language_abbr, split="train", trust_remote_code=True)
common_voice["validation"] = load_dataset(dataset_name, language_abbr, split="validation", trust_remote_code=True)
common_voice["test"] = load_dataset(dataset_name, language_abbr, split="test", trust_remote_code=True)

# 移除冗余字段加速处理
common_voice = common_voice.remove_columns(
    ["accent", "age", "client_id", "down_votes", "gender", "locale", "path", "segment", "up_votes"]
)
common_voice = common_voice.cast_column("audio", Audio(sampling_rate=16000))

# 2. 加载处理器（复用大模型处理器）
processor = AutoProcessor.from_pretrained(model_name_or_path, language=language, task=task)
feature_extractor = processor.feature_extractor
tokenizer = processor.tokenizer

# 3. 预处理（多进程+精简逻辑）
def prepare_dataset(batch):
    audio = batch["audio"]
    batch["input_features"] = feature_extractor(audio["array"], sampling_rate=16000).input_features[0]
    batch["labels"] = tokenizer(batch["sentence"]).input_ids
    return batch

print("预处理全量数据（多进程加速）...")
tokenized_common_voice = common_voice.map(
    prepare_dataset,
    remove_columns=common_voice["train"].column_names,
    num_proc=16  # 最大程度利用CPU核心
)

# 4. 数据整理器（复用逻辑）
@dataclass
class DataCollatorSpeechSeq2SeqWithPadding:
    processor: Any
    def __call__(self, features):
        input_features = [{"input_features": f["input_features"]} for f in features]
        batch = self.processor.feature_extractor.pad(input_features, return_tensors="pt")
        
        label_features = [{"input_ids": f["labels"]} for f in features]
        labels_batch = self.processor.tokenizer.pad(label_features, return_tensors="pt")
        labels = labels_batch["input_ids"].masked_fill(labels_batch.attention_mask.ne(1), -100)
        
        if (labels[:, 0] == self.processor.tokenizer.bos_token_id).all().cpu().item():
            labels = labels[:, 1:]
        batch["labels"] = labels
        return batch

data_collator = DataCollatorSpeechSeq2SeqWithPadding(processor=processor)

# 5. 评估指标（仅保留WER）
wer_metric = evaluate.load("wer")
def compute_metrics(pred):
    pred_ids = pred.predictions
    label_ids = pred.label_ids
    label_ids[label_ids == -100] = tokenizer.pad_token_id
    pred_str = tokenizer.batch_decode(pred_ids, skip_special_tokens=True)
    label_str = tokenizer.batch_decode(label_ids, skip_special_tokens=True)
    return {"wer": 100 * wer_metric.compute(predictions=pred_str, references=label_str)}

# 6. 模型配置（4bit量化+LoRA极致加速）
print("加载whisper-large-v2（4bit量化）...")
# 4bit量化（比8bit快2倍，显存占用减少50%）
quantization_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_use_double_quant=True,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_compute_dtype=torch.float16  # 计算精度
)

model = AutoModelForSpeechSeq2Seq.from_pretrained(
    model_name_or_path,
    quantization_config=quantization_config,
    device_map="auto",  # 自动分配到GPU
    torch_dtype=torch.float16
)

# 关键：冻结大部分参数，仅训练LoRA（加速5-10倍）
model = prepare_model_for_kbit_training(model, use_gradient_checkpointing=True)  # 梯度检查点省显存

# LoRA配置（仅适配关键模块）
lora_config = LoraConfig(
    r=8,
    lora_alpha=32,
    target_modules=["q_proj", "v_proj"],  # 仅注意力Q/V投影层
    lora_dropout=0.05,
    bias="none",
    task_type="SEQ_2_SEQ_LM",
)

model = get_peft_model(model, lora_config)
print("可训练参数:")
model.print_trainable_parameters()  # 仅0.05%参数可训练

# 7. 训练参数（极限加速配置）
training_args = Seq2SeqTrainingArguments(
    output_dir=output_dir,
    per_device_train_batch_size=batch_size,
    per_device_eval_batch_size=batch_size,
    learning_rate=learning_rate,
    num_train_epochs=num_train_epochs,
    logging_steps=logging_steps,
    eval_steps=eval_steps,
    evaluation_strategy="steps",
    save_strategy="no",  # 不保存中间模型（省时间）
    load_best_model_at_end=False,
    fp16=True,  # 强制FP16
    gradient_accumulation_steps=gradient_accumulation_steps,
    optim="paged_adamw_8bit",  # 8bit优化器（省显存）
    generation_max_length=128,
    remove_unused_columns=False,
    label_names=["labels"],
    report_to="none",  # 关闭日志报告
    dataloader_num_workers=8,
    dataloader_pin_memory=True,  # 内存pin住加速数据传输
)

# 8. 自定义训练器（确保参数正确）
class CustomTrainer(Seq2SeqTrainer):
    def training_step(self, model, inputs, num_items_in_batch):
        outputs = model(input_features=inputs["input_features"], labels=inputs["labels"])
        loss = outputs.loss
        self.accelerator.backward(loss)
        return loss.detach()

# 9. 开始训练
trainer = CustomTrainer(
    model=model,
    args=training_args,
    train_dataset=tokenized_common_voice["train"],
    eval_dataset=tokenized_common_voice["validation"],
    data_collator=data_collator,
    compute_metrics=compute_metrics,
    processing_class=processor
)

print("开始训练（依赖高端GPU，预计1-2小时）...")
train_result = trainer.train()

# 10. 绘制损失曲线
logs = trainer.state.log_history
train_losses = [log["loss"] for log in logs if "loss" in log and "eval_loss" not in log]
train_steps = [log["step"] for log in logs if "loss" in log and "eval_loss" not in log]
val_losses = [log["eval_loss"] for log in logs if "eval_loss" in log]
val_steps = [log["step"] for log in logs if "eval_loss" in log]

plt.figure(figsize=(8, 5))
plt.plot(train_steps, train_losses, label="训练损失")
plt.plot(val_steps, val_losses, label="验证损失")
plt.xlabel("训练步数")
plt.ylabel("损失值")
plt.title("训练与验证损失变化")
plt.legend()
plt.savefig(f"{output_dir}/loss_curve.png")

# 11. 测试集评估
test_results = trainer.evaluate(eval_dataset=tokenized_common_voice["test"])
print(f"测试集WER: {test_results['wer']:.2f}%")

print(f"训练完成，结果保存至 {output_dir}")

加载全量数据集...
预处理全量数据（多进程加速）...


Map (num_proc=16):   0%|          | 0/29056 [00:00<?, ? examples/s]

In [None]:
#以上代码导致：服务器磁盘读写/内存出现了跑高的情况，主要是磁盘读写/内存的读占用过高导致系统夯住了，导致系统无法远程连接。

In [None]:
import torch
import numpy as np
import matplotlib.pyplot as plt
from datasets import load_dataset, DatasetDict, Audio
from transformers import (
    AutoProcessor,
    AutoModelForSpeechSeq2Seq,
    Seq2SeqTrainingArguments,
    Seq2SeqTrainer,
    BitsAndBytesConfig
)
from peft import LoraConfig, get_peft_model, prepare_model_for_kbit_training
from dataclasses import dataclass
from typing import Any, Dict, List, Union
import evaluate
import os
import gc  # 垃圾回收

# 中文显示配置
plt.rcParams["font.family"] = ["SimHei", "WenQuanYi Micro Hei", "Heiti TC"]
plt.rcParams["axes.unicode_minus"] = False

# -------------- 低资源配置（减少磁盘/内存压力）--------------
model_name_or_path = "openai/whisper-large-v2"
batch_size = 16  # 降低批次，减少内存占用
gradient_accumulation_steps = 4  # 梯度累积替代大批次，平衡显存
num_train_epochs = 1
learning_rate = 2e-3
logging_steps = 50
eval_steps = 500
max_preprocess_workers = 4  # 降低预处理并行度，减少磁盘IO
# -----------------------------------------------------------

output_dir = "./whisper_large_low_resource"
language = "Chinese"
language_abbr = "zh-CN"
task = "transcribe"
dataset_name = "mozilla-foundation/common_voice_11_0"

os.makedirs(output_dir, exist_ok=True)

# 1. 分阶段加载数据集（减少一次性内存占用）
print("分阶段加载数据集...")
common_voice = DatasetDict()

# 训练集分块加载（避免一次性读入内存）
train_dataset = load_dataset(
    dataset_name, language_abbr, split="train", trust_remote_code=True, streaming=False
)
common_voice["train"] = train_dataset

# 验证集和测试集正常加载
common_voice["validation"] = load_dataset(
    dataset_name, language_abbr, split="validation", trust_remote_code=True
)
common_voice["test"] = load_dataset(
    dataset_name, language_abbr, split="test", trust_remote_code=True
)

# 移除冗余字段，减少数据体积
common_voice = common_voice.remove_columns(
    ["accent", "age", "client_id", "down_votes", "gender", "locale", "path", "segment", "up_votes"]
)
common_voice = common_voice.cast_column("audio", Audio(sampling_rate=16000))

# 手动触发垃圾回收，释放未使用内存
gc.collect()

# 2. 加载处理器（复用）
processor = AutoProcessor.from_pretrained(model_name_or_path, language=language, task=task)
feature_extractor = processor.feature_extractor
tokenizer = processor.tokenizer

# 3. 低IO预处理（减少并行度+分块处理）
def prepare_dataset(batch):
    audio = batch["audio"]
    # 音频特征提取（避免临时文件堆积）
    batch["input_features"] = feature_extractor(
        audio["array"], sampling_rate=16000, return_tensors="pt"  # 直接返回Tensor，减少中间转换
    ).input_features[0]
    # 标签处理
    batch["labels"] = tokenizer(batch["sentence"], truncation=True, max_length=128).input_ids
    return batch

print("低并行度预处理数据...")
tokenized_common_voice = common_voice.map(
    prepare_dataset,
    remove_columns=common_voice["train"].column_names,
    num_proc=max_preprocess_workers,  # 降低并行进程数，减少磁盘读写
    batch_size=100,  # 分小批处理，避免内存峰值
    load_from_cache_file=True,  # 缓存预处理结果，避免重复IO
    cache_file_names={  # 自定义缓存路径，集中管理
        "train": f"{output_dir}/train_cache.arrow",
        "validation": f"{output_dir}/val_cache.arrow",
        "test": f"{output_dir}/test_cache.arrow"
    }
)

# 再次回收内存
gc.collect()

# 4. 轻量数据整理器（避免冗余操作）
@dataclass
class DataCollatorSpeechSeq2SeqWithPadding:
    processor: Any
    def __call__(self, features):
        # 仅处理必要字段，减少中间变量
        input_features = [{"input_features": f["input_features"]} for f in features]
        batch = self.processor.feature_extractor.pad(input_features, return_tensors="pt")
        
        label_features = [{"input_ids": f["labels"]} for f in features]
        labels_batch = self.processor.tokenizer.pad(label_features, return_tensors="pt")
        labels = labels_batch["input_ids"].masked_fill(labels_batch.attention_mask.ne(1), -100)
        
        if (labels[:, 0] == self.processor.tokenizer.bos_token_id).all().cpu().item():
            labels = labels[:, 1:]
        batch["labels"] = labels
        return batch

data_collator = DataCollatorSpeechSeq2SeqWithPadding(processor=processor)

# 5. 评估指标（延迟加载，减少初始化内存）
def compute_metrics(pred):
    wer_metric = evaluate.load("wer")  # 评估时才加载，节省内存
    pred_ids = pred.predictions
    label_ids = pred.label_ids
    label_ids[label_ids == -100] = tokenizer.pad_token_id
    pred_str = tokenizer.batch_decode(pred_ids, skip_special_tokens=True)
    label_str = tokenizer.batch_decode(label_ids, skip_special_tokens=True)
    return {"wer": 100 * wer_metric.compute(predictions=pred_str, references=label_str)}

# 6. 模型配置（平衡速度和资源）
print("加载模型（8bit量化，低内存模式）...")
# 8bit量化（比4bit更稳定，内存占用适中）
quantization_config = BitsAndBytesConfig(
    load_in_8bit=True,
    bnb_8bit_use_double_quant=True,
    bnb_8bit_compute_dtype=torch.float16
)

model = AutoModelForSpeechSeq2Seq.from_pretrained(
    model_name_or_path,
    quantization_config=quantization_config,
    device_map="auto",
    torch_dtype=torch.float16,
    low_cpu_mem_usage=True  # 启用低CPU内存模式
)

# 配置模型，减少不必要计算
model.config.forced_decoder_ids = None
model.config.suppress_tokens = []
model.config.pad_token_id = model.config.eos_token_id
model = prepare_model_for_kbit_training(model, use_gradient_checkpointing=True)  # 梯度检查点省显存

# LoRA配置（适度适配，避免过多计算）
lora_config = LoraConfig(
    r=8,
    lora_alpha=32,
    target_modules=["q_proj", "v_proj"],
    lora_dropout=0.05,
    bias="none",
    task_type="SEQ_2_SEQ_LM",
)

model = get_peft_model(model, lora_config)
print("可训练参数:")
model.print_trainable_parameters()

# 回收模型加载后的冗余内存
gc.collect()

# 7. 训练参数（减少IO操作）
training_args = Seq2SeqTrainingArguments(
    output_dir=output_dir,
    per_device_train_batch_size=batch_size,
    per_device_eval_batch_size=batch_size,
    learning_rate=learning_rate,
    num_train_epochs=num_train_epochs,
    logging_steps=logging_steps,
    eval_steps=eval_steps,
    evaluation_strategy="steps",
    save_strategy="steps",
    save_steps=1000,  # 减少保存频率，降低磁盘IO
    save_total_limit=1,  # 只保留1个模型文件
    load_best_model_at_end=True,
    metric_for_best_model="wer",
    greater_is_better=False,
    fp16=True,
    gradient_accumulation_steps=gradient_accumulation_steps,
    optim="paged_adamw_8bit",  # 分页优化器，减少内存碎片
    generation_max_length=128,
    remove_unused_columns=False,
    label_names=["labels"],
    report_to="none",
    dataloader_num_workers=2,  # 降低数据加载线程，减少磁盘压力
    dataloader_pin_memory=False,  # 关闭pin内存，避免内存锁定
    disable_tqdm=False,  # 保留进度条，方便观察状态
)

# 8. 自定义训练器（定期回收内存）
class CustomTrainer(Seq2SeqTrainer):
    def training_step(self, model, inputs, num_items_in_batch):
        outputs = model(input_features=inputs["input_features"], labels=inputs["labels"])
        loss = outputs.loss
        self.accelerator.backward(loss)
        # 每10步回收一次内存，避免内存泄漏
        if self.state.global_step % 10 == 0:
            gc.collect()
            torch.cuda.empty_cache()  # 清空CUDA缓存
        return loss.detach()

# 9. 开始训练
trainer = CustomTrainer(
    model=model,
    args=training_args,
    train_dataset=tokenized_common_voice["train"],
    eval_dataset=tokenized_common_voice["validation"],
    data_collator=data_collator,
    compute_metrics=compute_metrics,
    processing_class=processor
)

print("开始训练（低资源模式，预计2-3小时）...")
train_result = trainer.train()

# 10. 绘制损失曲线（轻量化）
logs = trainer.state.log_history
train_losses = [log["loss"] for log in logs if "loss" in log and "eval_loss" not in log]
train_steps = [log["step"] for log in logs if "loss" in log and "eval_loss" not in log]
val_losses = [log["eval_loss"] for log in logs if "eval_loss" in log]
val_steps = [log["step"] for log in logs if "eval_loss" in log]

plt.figure(figsize=(8, 5))
plt.plot(train_steps, train_losses, label="训练损失")
plt.plot(val_steps, val_losses, label="验证损失")
plt.xlabel("训练步数")
plt.ylabel("损失值")
plt.title("训练与验证损失变化")
plt.legend()
plt.savefig(f"{output_dir}/loss_curve.png")
plt.close()  # 关闭图像，释放内存

# 11. 测试集评估（分批次评估，减少内存）
test_results = trainer.evaluate(
    eval_dataset=tokenized_common_voice["test"],
    max_length=128,
    num_beams=1  # 简化生成策略，减少计算
)
print(f"测试集WER: {test_results['wer']:.2f}%")

# 最终清理
gc.collect()
torch.cuda.empty_cache()
print(f"训练完成，结果保存至 {output_dir}")