In [1]:
"""
data_seed用于控制数据加载的随机性，保证实验可复现。
早停止记得开(EarlyStoppingCallback)。
"""
# train.py
import random
import numpy as np
import torch
from transformers import BertConfig, Trainer, TrainingArguments, EarlyStoppingCallback, EvalPrediction
from seqeval.metrics import classification_report
from transformers.trainer_utils import EvalLoopOutput
from transformers import AutoTokenizer, get_scheduler
from bert_crf_data_processing import prepare_datasets, generate_label_map
from bert_crf_model import BERT_CRF
import matplotlib.pyplot as plt
import os
from sklearn.metrics import confusion_matrix
import pandas as pd

# 生成标签映射
label_map = generate_label_map()
label_map_inv = {v: k for k, v in label_map.items()}

class Config:
    # 设置随机数种子
    seed_val = 42
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    max_length = 512
    batch_size = 32
    num_labels = 49  # 假设标签数是12*4+1=49，请根据实际标签数修改
    model_name = "../models/bert_pretrained"
    data_path = "../datasets/train/ablation_trainset_total.txt"
    output_dir = "../../../hy-tmp/models/Total_bertlr5e-5_crflr5e-3_cosine" # 修改输出目录

    num_epochs = 35
    weight_decay = 0.01
    bert_lr = 5e-5 # BERT learning rate
    crf_lr = 5e-3 # CRF learning rate
    learning_rate = 5e-5 # 额外学习率参数，如果需要可以调整
    warmup_ratio = 0.1

    lr_scheduler_type = "cosine"
    early_stopping_patience = 7
    max_grad_norm = 1.0

class NERTrainer(Trainer):
    def __init__(self, label_map_inv, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.label_map_inv = label_map_inv
    def prediction_step(self, model, inputs, prediction_loss_only, ignore_keys=None):
        with torch.no_grad():
            outputs = model(**inputs)
            loss = outputs.get("loss")
            decoded_tags = model.crf.decode(outputs["logits"], mask=inputs["attention_mask"].bool())
        return (loss, decoded_tags, inputs["labels"].cpu().numpy())

    def compute_loss(self, model, inputs, return_outputs=False, **kwargs):
        outputs = model(**inputs)
        loss = outputs["loss"]
        return (loss, outputs) if return_outputs else loss

    def evaluation_loop(self, dataloader, description, prediction_loss_only=False, ignore_keys=None, metric_key_prefix="eval"):
        model = self.model.eval()
        total_loss = 0.0
        all_preds = []
        all_labels = []
        all_masks = []  # 新增列表保存所有attention_mask

        for batch in dataloader:
            inputs = self._prepare_inputs(batch)
            with torch.no_grad():
                loss, decoded_tags, labels = self.prediction_step(model, inputs, prediction_loss_only)
                total_loss += loss.item()
                all_preds.extend(decoded_tags)
                all_labels.extend(labels)
                all_masks.extend(inputs["attention_mask"].cpu().numpy())  # 保存所有mask

        # 处理时使用所有保存的mask
        true_labels = []
        pred_labels = []
        for pred_seq, true_seq, mask in zip(all_preds, all_labels, all_masks):
            seq_len = np.sum(mask)
            true_labels.append([self.label_map_inv[l] for l in true_seq[:seq_len]])
            pred_labels.append([self.label_map_inv[p] for p in pred_seq[:seq_len]])

        report = classification_report(true_labels, pred_labels, output_dict=True)
        # 输出指标查看
        print(classification_report(true_labels, pred_labels))

        metrics = {
            f"{metric_key_prefix}_loss": total_loss / len(dataloader),
            f"{metric_key_prefix}_precision": report["micro avg"]["precision"],
            f"{metric_key_prefix}_recall": report["micro avg"]["recall"],
            f"{metric_key_prefix}_f1": report["micro avg"]["f1-score"],
        }

        # 生成混淆矩阵 (绘制图像并保存)
        true_entities = []
        pred_entities = []
        for tl_seq, pl_seq in zip(true_labels, pred_labels):
            for tl, pl in zip(tl_seq, pl_seq):
                if tl == 'O':
                    true_ent = 'O'
                else:
                    parts = tl.split('-', 1)
                    true_ent = parts[1] if len(parts) > 1 else 'O'
                if pl == 'O':
                    pred_ent = 'O'
                else:
                    parts = pl.split('-', 1)
                    pred_ent = parts[1] if len(parts) > 1 else 'O'
                true_entities.append(true_ent)
                pred_entities.append(pred_ent)

        entity_types = sorted(list(set(true_entities + pred_entities)))
        cm = confusion_matrix(true_entities, pred_entities, labels=entity_types)
        cm_df = pd.DataFrame(cm, index=entity_types, columns=entity_types)

        # ---  绘制混淆矩阵热图 ---
        fig = plt.figure(figsize=(12, 10)) # 可调整图像大小
        ax = fig.add_subplot(111)
        cax = ax.imshow(cm, interpolation='nearest', cmap=plt.cm.Blues) # 使用蓝色的 colormap
        fig.colorbar(cax) # 添加颜色栏
        tick_marks = np.arange(len(entity_types))
        ax.set_xticks(tick_marks)
        ax.set_yticks(tick_marks)
        ax.set_xticklabels(entity_types, rotation=45, ha="right") # 旋转 x 轴标签，使其不重叠
        ax.set_yticklabels(entity_types)

        # 在混淆矩阵上显示数值
        for i in range(cm.shape[0]):
            for j in range(cm.shape[1]):
                ax.text(j, i, cm[i, j], ha="center", va="center", color='black' if cm[i,j] < cm.max()/2 else 'white') # 根据数值大小调整颜色

        plt.tight_layout() # 自动调整子图参数，提供图像边缘周围的填充
        plt.ylabel('True Entity')
        plt.xlabel('Predicted Entity')
        plt.title(f'Confusion Matrix - Epoch {int(self.state.epoch)}') # 添加标题

        # 获取当前epoch数
        epoch_num = int(self.state.epoch)
        save_dir = os.path.join(self.args.output_dir, 'confusion_matrices')
        os.makedirs(save_dir, exist_ok=True)
        filename = f"confusion_matrix_epoch_{epoch_num}.png" # 保存为 PNG 图片
        filepath = os.path.join(save_dir, filename)
        plt.savefig(filepath) # 保存图像
        plt.close(fig) # 关闭图像，释放内存
        print(f"\nConfusion matrix image for epoch {epoch_num} saved to {filepath}")

        # --- (保留保存 CSV 的代码，如果需要的话，可以注释掉) ---
        # filename_csv = f"confusion_matrix_epoch_{epoch_num}.csv"
        # filepath_csv = os.path.join(save_dir, filename_csv)
        # cm_df.to_csv(filepath_csv)
        # print(f"\nConfusion matrix for epoch {epoch_num} saved to {filepath_csv}")
        return EvalLoopOutput(
            predictions=all_preds,
            label_ids=all_labels,
            metrics=metrics,
            num_samples=len(all_labels),
        )



def collate_fn(batch):
    return {
        "input_ids": torch.stack([x["input_ids"] for x in batch]),
        "attention_mask": torch.stack([x["attention_mask"] for x in batch]),
        "labels": torch.stack([x["labels"] for x in batch]),
    }

def main():
    cfg = Config()
    random.seed(cfg.seed_val)
    np.random.seed(cfg.seed_val)
    torch.manual_seed(cfg.seed_val)
    torch.cuda.manual_seed_all(cfg.seed_val)

    # 初始化模型
    pre_config = BertConfig.from_pretrained(cfg.model_name, num_labels=cfg.num_labels)
    model = BERT_CRF.from_pretrained(cfg.model_name, config=pre_config).to(cfg.device)

    # 加载Tokenizer
    tokenizer = AutoTokenizer.from_pretrained(cfg.model_name)
    train_ds, val_ds = prepare_datasets(cfg.data_path, tokenizer)

    optimizer_grouped_parameters = [
        {
            "name": "bert_base",
            "params": [p for n, p in model.bert.named_parameters() if p.requires_grad],
            "lr": cfg.bert_lr,
        },
        {
            "name": "crf_base",
            "params": [p for n, p in model.crf.named_parameters() if p.requires_grad],
            "lr": cfg.crf_lr,
        },
        {
            "params": [p for n, p in model.classifier.named_parameters() if p.requires_grad], # 假设分类器层名为 classifier
            "lr": cfg.learning_rate, # 使用 learning_rate
        }
    ]
    # 创建优化器，使用 AdamW，并将分组参数传递进去
    optimizer = torch.optim.AdamW(optimizer_grouped_parameters, weight_decay=cfg.weight_decay)

    num_training_steps = len(train_ds) // cfg.batch_size * cfg.num_epochs # 粗略估计，根据实际情况调整
    num_warmup_steps = int(cfg.warmup_ratio * num_training_steps)

    lr_scheduler = get_scheduler(
        name=cfg.lr_scheduler_type, # 可以尝试 "cosine"
        optimizer=optimizer,
        num_warmup_steps=num_warmup_steps,
        num_training_steps=num_training_steps,
    )

    # 训练参数
    training_args = TrainingArguments(
        output_dir=cfg.output_dir,
        per_device_train_batch_size=cfg.batch_size,
        per_device_eval_batch_size=cfg.batch_size * 2,
        num_train_epochs=cfg.num_epochs,
        weight_decay=cfg.weight_decay,
        evaluation_strategy="epoch",
        save_strategy="epoch",
        load_best_model_at_end=True,
        metric_for_best_model="eval_f1",
        remove_unused_columns=False,
        fp16=False,
        max_grad_norm=cfg.max_grad_norm,
        logging_steps=20,
        seed=cfg.seed_val,
        data_seed=cfg.seed_val,
    )

    trainer = NERTrainer(
        label_map_inv=label_map_inv,
        model=model,
        args=training_args,
        train_dataset=train_ds,
        eval_dataset=val_ds,
        compute_metrics=None,
        data_collator=collate_fn,
        callbacks=[EarlyStoppingCallback(early_stopping_patience=cfg.early_stopping_patience)], # 确保早停回调被正确添加
        optimizers=(optimizer, lr_scheduler),
    )

    # 存储训练过程中的 loss 和 metrics
    train_history = trainer.train()

    # 从 trainer 的 state 中获取日志信息
    log_history = trainer.state.log_history

    # 提取训练 loss 和 验证 metrics
    train_losses = []
    eval_losses = []
    eval_f1s = []
    eval_precisions = []
    eval_recalls = []
    epochs = []
    logged_epochs = set()

    for log_data in log_history:
        if 'loss' in log_data:
            train_losses.append(log_data['loss'])
            epoch_num = log_data['epoch']
            if epoch_num not in logged_epochs:
                epochs.append(epoch_num)
                logged_epochs.add(epoch_num)
        if 'eval_loss' in log_data:
            eval_losses.append(log_data['eval_loss'])
            epoch_num = log_data['epoch']
            if epoch_num not in logged_epochs:
                epochs.append(epoch_num)
                logged_epochs.add(epoch_num)
        if 'eval_f1' in log_data:
            eval_f1s.append(log_data['eval_f1'])
        if 'eval_precision' in log_data:
            eval_precisions.append(log_data['eval_precision'])
        if 'eval_recall' in log_data:
            eval_recalls.append(log_data['eval_recall'])

    epochs = sorted(list(set(epochs)))

    model.save_pretrained(cfg.output_dir)
    tokenizer.save_pretrained(cfg.output_dir)

    final_eval_metrics = trainer.state.log_history[-1]
    final_report_string = final_eval_metrics.get("eval_report_string")
    if final_report_string:
        print("\nFinal Evaluation Report:\n", final_report_string)

if __name__ == "__main__":
    main()

  from .autonotebook import tqdm as notebook_tqdm
Some weights of BERT_CRF were not initialized from the model checkpoint at ../models/bert_pretrained and are newly initialized: ['bert.pooler.dense.bias', 'bert.pooler.dense.weight', 'classifier.bias', 'classifier.weight', 'crf.end_transitions', 'crf.start_transitions', 'crf.transitions']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.
2025-03-06 21:08:33 - INFO - 数据集划分结果：
- 总序列数：756
- 训练集序列数：604
- 验证集序列数：152


训练集标签分布: Counter({0: 239305, 3: 8373, 1: 8370, 7: 3353, 5: 3351, 37: 3065, 39: 3065, 2: 2551, 20: 2539, 21: 1955, 23: 1955, 4: 1923, 34: 1457, 42: 1159, 43: 1084, 41: 1083, 8: 930, 45: 923, 47: 923, 13: 877, 15: 874, 35: 827, 33: 824, 24: 712, 38: 598, 6: 521, 25: 514, 27: 514, 40: 435, 22: 403, 14: 393, 29: 376, 30: 376, 31: 376, 16: 209, 26: 165, 17: 163, 19: 162, 46: 129, 28: 68, 32: 45, 44: 39, 9: 29, 11: 29, 12: 21, 10: 7, 48: 3, 18: 1})
验证集标签分布: Counter({0: 60126, 1: 2180, 3: 2177, 37: 1285, 39: 1285, 5: 1003, 7: 1001, 20: 652, 2: 617, 21: 576, 23: 576, 4: 495, 34: 419, 42: 339, 38: 319, 41: 294, 43: 293, 8: 266, 33: 249, 35: 247, 24: 209, 40: 198, 15: 191, 13: 188, 45: 183, 47: 183, 6: 153, 22: 132, 29: 96, 31: 96, 30: 95, 14: 76, 27: 56, 25: 55, 16: 44, 46: 39, 19: 27, 17: 26, 26: 24, 32: 19, 9: 11, 11: 11, 44: 8, 10: 3, 28: 3, 48: 2})


Epoch,Training Loss,Validation Loss,Precision,Recall,F1
1,No log,651.35732,0.0,0.0,0.0
2,1489.211500,411.737345,0.36462,0.192379,0.251869
3,495.519100,187.08522,0.683447,0.672335,0.677845
4,247.709000,117.309672,0.782515,0.813206,0.797565
5,117.007000,96.968114,0.821626,0.828969,0.825281
6,72.522200,76.789101,0.839323,0.855157,0.847166
7,50.163300,75.780355,0.830058,0.865086,0.84721
8,40.954500,80.875638,0.845136,0.85615,0.850607
9,31.071100,74.667219,0.843071,0.872161,0.857369
10,26.660400,75.319303,0.847912,0.879484,0.863409


  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))


              precision    recall  f1-score   support

          NB       0.00      0.00      0.00        11
          NG       0.00      0.00      0.00       679
          NO       0.00      0.00      0.00       235
          NR       0.00      0.00      0.00      2680
          NS       0.00      0.00      0.00      1269
           T       0.00      0.00      0.00       787
          ZA       0.00      0.00      0.00       185
          ZD       0.00      0.00      0.00        59
          ZF       0.00      0.00      0.00       249
          ZP       0.00      0.00      0.00      1485
          ZS       0.00      0.00      0.00       303
          ZZ       0.00      0.00      0.00       115

   micro avg       0.00      0.00      0.00      8057
   macro avg       0.00      0.00      0.00      8057
weighted avg       0.00      0.00      0.00      8057


Confusion matrix image for epoch 1 saved to ../../../hy-tmp/models/Total_bertlr5e-5_crflr5e-3_cosine/confusion_matrices/confusion_ma

  _warn_prf(average, modifier, msg_start, len(result))


              precision    recall  f1-score   support

          NB       0.00      0.00      0.00        11
          NG       0.00      0.00      0.00       679
          NO       0.00      0.00      0.00       235
          NR       0.36      0.56      0.44      2680
          NS       1.00      0.00      0.00      1269
           T       0.00      0.00      0.00       787
          ZA       0.00      0.00      0.00       185
          ZD       0.00      0.00      0.00        59
          ZF       0.00      0.00      0.00       249
          ZP       0.36      0.04      0.06      1485
          ZS       0.00      0.00      0.00       303
          ZZ       0.00      0.00      0.00       115

   micro avg       0.36      0.19      0.25      8057
   macro avg       0.14      0.05      0.04      8057
weighted avg       0.34      0.19      0.16      8057


Confusion matrix image for epoch 2 saved to ../../../hy-tmp/models/Total_bertlr5e-5_crflr5e-3_cosine/confusion_matrices/confusion_ma

  _warn_prf(average, modifier, msg_start, len(result))


              precision    recall  f1-score   support

          NB       0.00      0.00      0.00        11
          NG       0.76      0.88      0.81       679
          NO       0.00      0.00      0.00       235
          NR       0.79      0.88      0.83      2680
          NS       0.68      0.68      0.68      1269
           T       0.68      0.50      0.58       787
          ZA       0.67      0.45      0.54       185
          ZD       0.00      0.00      0.00        59
          ZF       0.43      0.41      0.42       249
          ZP       0.53      0.68      0.60      1485
          ZS       0.35      0.02      0.04       303
          ZZ       0.00      0.00      0.00       115

   micro avg       0.68      0.67      0.68      8057
   macro avg       0.41      0.38      0.38      8057
weighted avg       0.64      0.67      0.65      8057


Confusion matrix image for epoch 3 saved to ../../../hy-tmp/models/Total_bertlr5e-5_crflr5e-3_cosine/confusion_matrices/confusion_ma

  _warn_prf(average, modifier, msg_start, len(result))


              precision    recall  f1-score   support

          NB       0.00      0.00      0.00        11
          NG       0.82      0.91      0.86       679
          NO       0.69      0.59      0.64       235
          NR       0.83      0.93      0.88      2680
          NS       0.73      0.77      0.75      1269
           T       0.76      0.71      0.73       787
          ZA       0.52      0.68      0.59       185
          ZD       0.39      0.20      0.27        59
          ZF       0.78      0.86      0.82       249
          ZP       0.85      0.84      0.85      1485
          ZS       0.49      0.51      0.50       303
          ZZ       1.00      0.01      0.02       115

   micro avg       0.78      0.81      0.80      8057
   macro avg       0.66      0.58      0.57      8057
weighted avg       0.78      0.81      0.79      8057


Confusion matrix image for epoch 4 saved to ../../../hy-tmp/models/Total_bertlr5e-5_crflr5e-3_cosine/confusion_matrices/confusion_ma

  _warn_prf(average, modifier, msg_start, len(result))


              precision    recall  f1-score   support

          NB       0.00      0.00      0.00        11
          NG       0.88      0.91      0.90       679
          NO       0.75      0.62      0.68       235
          NR       0.86      0.93      0.90      2680
          NS       0.84      0.80      0.82      1269
           T       0.82      0.79      0.80       787
          ZA       0.71      0.73      0.72       185
          ZD       0.36      0.46      0.41        59
          ZF       0.82      0.88      0.85       249
          ZP       0.89      0.81      0.85      1485
          ZS       0.37      0.46      0.41       303
          ZZ       0.60      0.46      0.52       115

   micro avg       0.82      0.83      0.83      8057
   macro avg       0.66      0.65      0.65      8057
weighted avg       0.82      0.83      0.83      8057


Confusion matrix image for epoch 5 saved to ../../../hy-tmp/models/Total_bertlr5e-5_crflr5e-3_cosine/confusion_matrices/confusion_ma

  _warn_prf(average, modifier, msg_start, len(result))


              precision    recall  f1-score   support

          NB       0.00      0.00      0.00        11
          NG       0.87      0.94      0.90       679
          NO       0.66      0.75      0.70       235
          NR       0.90      0.93      0.91      2680
          NS       0.83      0.82      0.82      1269
           T       0.84      0.81      0.82       787
          ZA       0.68      0.81      0.74       185
          ZD       0.52      0.63      0.57        59
          ZF       0.88      0.89      0.88       249
          ZP       0.88      0.85      0.87      1485
          ZS       0.50      0.59      0.54       303
          ZZ       0.68      0.50      0.57       115

   micro avg       0.84      0.86      0.85      8057
   macro avg       0.69      0.71      0.70      8057
weighted avg       0.84      0.86      0.85      8057


Confusion matrix image for epoch 6 saved to ../../../hy-tmp/models/Total_bertlr5e-5_crflr5e-3_cosine/confusion_matrices/confusion_ma

  _warn_prf(average, modifier, msg_start, len(result))


              precision    recall  f1-score   support

          NB       0.00      0.00      0.00        11
          NG       0.89      0.94      0.91       679
          NO       0.71      0.67      0.69       235
          NR       0.90      0.92      0.91      2680
          NS       0.83      0.83      0.83      1269
           T       0.79      0.85      0.82       787
          ZA       0.64      0.85      0.73       185
          ZD       0.48      0.69      0.57        59
          ZF       0.83      0.91      0.87       249
          ZP       0.88      0.87      0.87      1485
          ZS       0.51      0.64      0.57       303
          ZZ       0.47      0.61      0.53       115

   micro avg       0.83      0.87      0.85      8057
   macro avg       0.66      0.73      0.69      8057
weighted avg       0.84      0.87      0.85      8057


Confusion matrix image for epoch 7 saved to ../../../hy-tmp/models/Total_bertlr5e-5_crflr5e-3_cosine/confusion_matrices/confusion_ma

  _warn_prf(average, modifier, msg_start, len(result))


              precision    recall  f1-score   support

          NB       0.00      0.00      0.00        11
          NG       0.92      0.93      0.92       679
          NO       0.75      0.67      0.71       235
          NR       0.91      0.92      0.91      2680
          NS       0.84      0.84      0.84      1269
           T       0.84      0.83      0.83       787
          ZA       0.78      0.75      0.76       185
          ZD       0.46      0.81      0.59        59
          ZF       0.88      0.92      0.90       249
          ZP       0.89      0.84      0.86      1485
          ZS       0.46      0.69      0.55       303
          ZZ       0.56      0.62      0.59       115

   micro avg       0.85      0.86      0.85      8057
   macro avg       0.69      0.73      0.71      8057
weighted avg       0.85      0.86      0.85      8057


Confusion matrix image for epoch 8 saved to ../../../hy-tmp/models/Total_bertlr5e-5_crflr5e-3_cosine/confusion_matrices/confusion_ma

  _warn_prf(average, modifier, msg_start, len(result))


              precision    recall  f1-score   support

          NB       0.00      0.00      0.00        11
          NG       0.90      0.95      0.93       679
          NO       0.72      0.70      0.71       235
          NR       0.90      0.92      0.91      2680
          NS       0.85      0.84      0.84      1269
           T       0.83      0.84      0.83       787
          ZA       0.69      0.84      0.76       185
          ZD       0.55      0.78      0.65        59
          ZF       0.87      0.92      0.89       249
          ZP       0.88      0.89      0.88      1485
          ZS       0.50      0.64      0.56       303
          ZZ       0.55      0.62      0.58       115

   micro avg       0.84      0.87      0.86      8057
   macro avg       0.69      0.74      0.71      8057
weighted avg       0.85      0.87      0.86      8057


Confusion matrix image for epoch 9 saved to ../../../hy-tmp/models/Total_bertlr5e-5_crflr5e-3_cosine/confusion_matrices/confusion_ma