In [None]:
from transformers import (AutoTokenizer,
                          AutoModelForSeq2SeqLM, 
                          DataCollatorForSeq2Seq, 
                          Seq2SeqTrainingArguments, 
                          Seq2SeqTrainer, 
                          BartForConditionalGeneration)
import data_utils

In [None]:
# Optimisation settings
batch_size = 32   # per-device batch size (used for both training and evaluation)
epochs = 5        # number of training epochs
lr = 1e-4         # learning rate

# Sequence-length limits, in tokens
max_input_length = 512   # maximum encoder input length
max_target_length = 256  # maximum target / generated output length

In [None]:
# Load the raw dataset splits plus the pretrained tokenizer and model.
# The checkpoint is named once so the tokenizer and the model are always
# loaded from the same source and cannot silently diverge.
model_checkpoint = "fnlp/bart-base-chinese"

data_sets = data_utils.load_data()
tokenizer = AutoTokenizer.from_pretrained(model_checkpoint)

model = AutoModelForSeq2SeqLM.from_pretrained(model_checkpoint)

# Tokenize every split using the length limits from the hyperparameter cell.
# NOTE(review): how the limits are applied (truncation/padding) is defined in
# data_utils.tokenized — confirm there.
tokenized_datasets = data_utils.tokenized(
                data_sets, 
                tokenizer, 
                max_input_length=max_input_length,
                max_target_length=max_target_length)

In [None]:
# Training configuration.
# NOTE(review): newer transformers releases renamed `evaluation_strategy` to
# `eval_strategy` (the old name is deprecated there) — confirm against the
# installed version.
args = Seq2SeqTrainingArguments(
    output_dir="results",  # directory where checkpoints are written
    num_train_epochs=epochs,
    do_train=True,
    do_eval=True,
    per_device_train_batch_size=batch_size,
    per_device_eval_batch_size=batch_size,
    learning_rate=lr,
    warmup_steps=500,
    weight_decay=0.001,
    # Evaluate on generated sequences so compute_metrics can score text with ROUGE.
    predict_with_generate=True,
    logging_dir="logs",
    logging_steps=500,
    # Evaluate and checkpoint on the same explicit step schedule.
    # load_best_model_at_end needs saves to line up with evaluations. Setting
    # eval_steps explicitly stops it from implicitly following logging_steps,
    # so logging frequency can be tuned without breaking checkpoint selection.
    evaluation_strategy="steps",
    eval_steps=500,
    save_strategy="steps",
    save_steps=500,
    save_total_limit=3,
    generation_max_length=max_target_length,  # maximum length of generated sequences
    generation_num_beams=1,  # single beam, i.e. no beam search

    # Keep the checkpoint with the highest ROUGE-1 score (the "rouge-1" key
    # returned by compute_metrics) and reload it when training finishes.
    load_best_model_at_end=True,
    metric_for_best_model="rouge-1",
    greater_is_better=True,
)

In [None]:
import numpy as np
import lawrouge

def compute_metrics(eval_pred):
    """Compute corpus-averaged ROUGE F-scores (as percentages) for generated summaries.

    Args:
        eval_pred: ``(predictions, labels)`` — a plain tuple or the
            ``EvalPrediction`` passed by the Trainer. ``predictions`` are
            generated token ids (the training args set
            ``predict_with_generate=True``); ``labels`` are reference token ids
            in which ``-100`` marks positions to ignore.

    Returns:
        dict with keys ``"rouge-1"``, ``"rouge-2"`` and ``"rouge-l"``, each the
        averaged F-score from ``lawrouge`` multiplied by 100.
    """

    def _decode(eval_pred_):
        """Convert prediction and label id arrays to lists of strings."""
        # Index instead of unpacking so this also works when the Trainer
        # includes extra elements (e.g. inputs) in the EvalPrediction.
        predictions, labels = eval_pred_[0], eval_pred_[1]
        # Some models return extra outputs next to the generated ids.
        if isinstance(predictions, tuple):
            predictions = predictions[0]
        # -100 is not a valid token id and makes batch_decode fail. Labels use
        # it for ignored positions, and the evaluation loop may also pad
        # predictions with it when concatenating batches of different lengths.
        # Map both to the pad id, which skip_special_tokens then drops.
        predictions = np.where(predictions != -100,
                               predictions,
                               tokenizer.pad_token_id)
        labels = np.where(labels != -100,
                          labels,
                          tokenizer.pad_token_id)
        decoded_preds = tokenizer.batch_decode(predictions,
                                               skip_special_tokens=True)
        decoded_labels = tokenizer.batch_decode(labels,
                                                skip_special_tokens=True)
        return decoded_preds, decoded_labels

    def _remove_spaces(decoded):
        """Strip all spaces so each prediction/reference is one contiguous string.

        NOTE(review): presumes the spaces in decoded text are token separators
        added by decoding rather than meaningful content — confirm for this
        tokenizer.
        """
        decoded_preds, decoded_labels = decoded
        return ([pred.replace(" ", "") for pred in decoded_preds],
                [label.replace(" ", "") for label in decoded_labels])

    preds_text, refs_text = _remove_spaces(_decode(eval_pred))
    rouge = lawrouge.Rouge()
    # avg=True averages over the batch: one score dict per ROUGE variant.
    scores = rouge.get_scores(preds_text, refs_text, avg=True)
    # Keep only the F-scores and report them as percentages.
    return {key: scores[key]['f'] * 100
            for key in ('rouge-1', 'rouge-2', 'rouge-l')}

In [None]:
# Assemble the trainer from the model, training arguments, tokenized splits,
# the project's collate function and the ROUGE metric callback.
trainer = Seq2SeqTrainer(
    model=model,
    args=args,
    train_dataset=tokenized_datasets["train"],
    eval_dataset=tokenized_datasets["validation"],
    data_collator=data_utils.collate_fn,  # batching logic lives in data_utils
    tokenizer=tokenizer,
    compute_metrics=compute_metrics,
)

In [None]:
# Train, then report metrics on the validation and test splits.
trainer.train()

# Validation metrics (keys prefixed "eval_").
print(trainer.evaluate(tokenized_datasets["validation"]))
# Test metrics get their own "test_" prefix. This way they are printed and
# logged separately instead of being mixed into the "eval_*" validation history.
print(trainer.evaluate(tokenized_datasets["test"], metric_key_prefix="test"))

# Save the model currently held by the trainer. With load_best_model_at_end=True
# in the training args, this is the best checkpoint by ROUGE-1.
trainer.save_model("results/best")