In [None]:
# train_softmax.py
import random
import numpy as np
import torch
from transformers import BertConfig, Trainer, TrainingArguments, EarlyStoppingCallback, EvalPrediction
from seqeval.metrics import classification_report
from transformers.trainer_utils import EvalLoopOutput
from transformers import AutoTokenizer, get_scheduler
from bert_softmax_data_processing import prepare_datasets, generate_label_map
from bert_softmax_model import BERT_Softmax # 修改模型导入
import matplotlib.pyplot as plt
import os
from sklearn.metrics import confusion_matrix
import pandas as pd

# Build the tag -> id mapping once at import time, together with its inverse
# (id -> tag), which the evaluation code uses to decode label ids.
label_map = generate_label_map()
label_map_inv = {idx: tag for tag, idx in label_map.items()}

class Config:
    """Static hyper-parameters and paths for the BERT-Softmax training run."""

    # --- reproducibility / hardware ---
    seed_val = 42  # seed applied to random, numpy and torch in main()
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    # --- data and model ---
    model_name = "../models/bert_pretrained"
    data_path = "../datasets/train/ablation_trainset_total.txt"
    output_dir = "../../../hy-tmp/models/Total_bert_softmax_lr5e-5_linear"
    max_length = 512
    # Assumed to be 12 * 4 + 1 = 49; adjust to the real number of labels.
    num_labels = 49

    # --- optimisation ---
    batch_size = 32
    num_epochs = 30
    bert_lr = 5e-5
    # A single learning rate is usually enough for BERT-Softmax.
    learning_rate = 5e-5
    weight_decay = 0.01
    warmup_ratio = 0.1
    # "linear" appeared to work better than "cosine" in earlier runs.
    lr_scheduler_type = "linear"
    max_grad_norm = 1.0
    early_stopping_patience = 5

class NERTrainer(Trainer):
    """Trainer subclass for token-level NER with a softmax classification head.

    Overrides the prediction step and the evaluation loop so that evaluation
    works on argmax-decoded label ids, reports seqeval micro-averaged
    precision/recall/F1, and writes a confusion-matrix image on every
    evaluation pass.
    """

    def prediction_step(self, model, inputs, prediction_loss_only, ignore_keys=None):
        """Run one forward pass without gradients.

        Returns:
            A ``(loss, predictions, labels)`` tuple. ``loss`` is whatever the
            model output holds under ``"loss"`` (``None`` if absent),
            ``predictions`` is the argmax over the last logits dimension as a
            numpy array, and ``labels`` is ``inputs["labels"]`` as a numpy
            array.
        """
        with torch.no_grad():
            outputs = model(**inputs)
            loss = outputs.get("loss")
            # The softmax head is decoded directly with an argmax over logits.
            predictions = torch.argmax(outputs["logits"], dim=-1).cpu().numpy()
        return (loss, predictions, inputs["labels"].cpu().numpy())

    def compute_loss(self, model, inputs, return_outputs=False, **kwargs):
        """Return the loss computed by the model itself (``outputs["loss"]``).

        When ``return_outputs`` is true the full model output is returned
        alongside the loss.
        """
        outputs = model(**inputs)
        loss = outputs["loss"]
        return (loss, outputs) if return_outputs else loss

    def evaluation_loop(self, dataloader, description, prediction_loss_only=False, ignore_keys=None, metric_key_prefix="eval"):
        """Evaluate on ``dataloader`` and return an ``EvalLoopOutput``.

        The reported metrics are the mean batch loss and the seqeval
        micro-averaged precision, recall and F1, each keyed with
        ``metric_key_prefix``. As a side effect the classification report is
        printed and a confusion-matrix PNG is saved under
        ``<output_dir>/confusion_matrices``.
        """
        model = self.model.eval()
        total_loss = 0.0
        loss_batches = 0
        all_preds = []
        all_labels = []
        all_attention_masks = []

        for batch in dataloader:
            inputs = self._prepare_inputs(batch)
            with torch.no_grad():
                loss, predictions, labels = self.prediction_step(model, inputs, prediction_loss_only)
            # The model may return no loss; only batches that did are averaged.
            if loss is not None:
                # .mean() also covers a non-scalar (per-replica) loss tensor.
                total_loss += loss.mean().item()
                loss_batches += 1
            all_preds.extend(predictions)
            all_labels.extend(labels)
            all_attention_masks.extend(inputs["attention_mask"].cpu().numpy())

        true_labels, pred_labels = self._ner_decode_sequences(
            all_preds, all_labels, all_attention_masks
        )

        report = classification_report(true_labels, pred_labels, output_dict=True)
        # Print the human-readable report so it can be inspected per epoch.
        print(classification_report(true_labels, pred_labels))

        metrics = {
            # NaN rather than a ZeroDivisionError when no batch yielded a loss.
            f"{metric_key_prefix}_loss": total_loss / loss_batches if loss_batches else float("nan"),
            f"{metric_key_prefix}_precision": report["micro avg"]["precision"],
            f"{metric_key_prefix}_recall": report["micro avg"]["recall"],
            f"{metric_key_prefix}_f1": report["micro avg"]["f1-score"],
        }

        self._ner_save_confusion_matrix(true_labels, pred_labels)

        return EvalLoopOutput(
            predictions=all_preds,
            label_ids=all_labels,
            metrics=metrics,
            num_samples=len(all_labels),
        )

    @staticmethod
    def _ner_decode_sequences(all_preds, all_labels, all_masks):
        """Map label ids to tag strings, dropping padded and ignored positions.

        A position is kept only when its attention mask is 1 and its gold
        label is not -100, so the true and predicted sequences stay aligned.
        """
        true_labels = []
        pred_labels = []
        for pred_seq, true_seq, mask in zip(all_preds, all_labels, all_masks):
            kept = [
                (label_map_inv[int(l)], label_map_inv[int(p)])
                for l, p, m in zip(true_seq, pred_seq, mask)
                if m == 1 and l != -100
            ]
            true_labels.append([true_tag for true_tag, _ in kept])
            pred_labels.append([pred_tag for _, pred_tag in kept])
        return true_labels, pred_labels

    @staticmethod
    def _ner_entity_type(tag):
        """Return the part of ``tag`` after the first '-', or 'O' if there is none."""
        _, sep, entity = tag.partition('-')
        return entity if sep else 'O'

    def _ner_save_confusion_matrix(self, true_labels, pred_labels):
        """Plot the token-level entity-type confusion matrix and save it as PNG.

        Returns the path of the written image, or ``None`` when there is
        nothing to plot.
        """
        true_entities = []
        pred_entities = []
        for tl_seq, pl_seq in zip(true_labels, pred_labels):
            for tl, pl in zip(tl_seq, pl_seq):
                true_entities.append(self._ner_entity_type(tl))
                pred_entities.append(self._ner_entity_type(pl))

        entity_types = sorted(set(true_entities + pred_entities))
        if not entity_types:
            # No evaluated token: an empty matrix cannot be built or plotted.
            return None
        cm = confusion_matrix(true_entities, pred_entities, labels=entity_types)

        # state.epoch is None when evaluation runs before any training step.
        epoch_num = int(self.state.epoch or 0)

        fig = plt.figure(figsize=(12, 10))
        ax = fig.add_subplot(111)
        cax = ax.imshow(cm, interpolation='nearest', cmap=plt.cm.Blues)
        fig.colorbar(cax)
        tick_marks = np.arange(len(entity_types))
        ax.set_xticks(tick_marks)
        ax.set_yticks(tick_marks)
        # Rotate the x tick labels so they do not overlap.
        ax.set_xticklabels(entity_types, rotation=45, ha="right")
        ax.set_yticklabels(entity_types)

        # Write each count into its cell, white on dark cells for contrast.
        half_max = cm.max() / 2
        for i in range(cm.shape[0]):
            for j in range(cm.shape[1]):
                ax.text(j, i, cm[i, j], ha="center", va="center", color='black' if cm[i, j] < half_max else 'white')

        # Set labels on the matrix axes explicitly, then lay out the figure so
        # that the labels and title are taken into account.
        ax.set_ylabel('True Entity')
        ax.set_xlabel('Predicted Entity')
        ax.set_title(f'Confusion Matrix - Epoch {epoch_num}')
        fig.tight_layout()

        save_dir = os.path.join(self.args.output_dir, 'confusion_matrices')
        os.makedirs(save_dir, exist_ok=True)
        filename = f"confusion_matrix_epoch_{epoch_num}.png"
        filepath = os.path.join(save_dir, filename)
        fig.savefig(filepath)
        plt.close(fig)  # release the figure's memory
        print(f"\nConfusion matrix image for epoch {epoch_num} saved to {filepath}")
        return filepath

# Metric callback operating on already-decoded label ids (not logits).
def compute_metrics(eval_preds: EvalPrediction):
    """Compute seqeval micro-averaged precision, recall and F1.

    Args:
        eval_preds: ``predictions`` holds predicted label ids per sequence and
            ``label_ids`` holds the gold label ids, with -100 marking
            positions to ignore.

    Returns:
        A dict with ``eval_precision``, ``eval_recall``, ``eval_f1`` and the
        full text report under ``eval_report_string``.
    """
    true_labels_list = []
    pred_labels_list = []
    for true_seq, pred_seq in zip(eval_preds.label_ids, eval_preds.predictions):
        # Drop ignored positions from BOTH sequences; filtering only the gold
        # labels would leave the two sequences with different lengths.
        kept = [
            (label_map_inv[int(l)], label_map_inv[int(p)])
            for l, p in zip(true_seq, pred_seq)
            if l != -100
        ]
        true_labels_list.append([true_tag for true_tag, _ in kept])
        pred_labels_list.append([pred_tag for _, pred_tag in kept])

    report_dict = classification_report(true_labels_list, pred_labels_list, output_dict=True)
    report_string = classification_report(true_labels_list, pred_labels_list)

    metrics = {
        "eval_precision": report_dict["micro avg"]["precision"],
        "eval_recall": report_dict["micro avg"]["recall"],
        "eval_f1": report_dict["micro avg"]["f1-score"],
        "eval_report_string": report_string,
    }
    return metrics

def collate_fn(batch):
    """Stack the per-example tensors of ``batch`` into batched tensors.

    Only the ``input_ids``, ``attention_mask`` and ``labels`` entries of each
    example are used; the result maps each of those keys to the stacked tensor.
    """
    fields = ("input_ids", "attention_mask", "labels")
    return {
        field: torch.stack([example[field] for example in batch])
        for field in fields
    }

def main():
    """Train the BERT-Softmax NER model and save the model and tokenizer.

    Seeds the RNGs, builds the model, datasets, optimizer and LR scheduler
    from ``Config``, trains with early stopping on ``eval_f1``, then saves the
    model and tokenizer to ``Config.output_dir`` and prints the metrics of the
    last evaluation pass.

    Returns:
        A dict with the logged ``epochs``, ``train_loss``, ``eval_loss``,
        ``eval_f1``, ``eval_precision`` and ``eval_recall`` series taken from
        the trainer's log history.
    """
    cfg = Config()
    random.seed(cfg.seed_val)
    np.random.seed(cfg.seed_val)
    torch.manual_seed(cfg.seed_val)
    torch.cuda.manual_seed_all(cfg.seed_val)

    # Model
    pre_config = BertConfig.from_pretrained(cfg.model_name, num_labels=cfg.num_labels)
    model = BERT_Softmax.from_pretrained(cfg.model_name, config=pre_config).to(cfg.device)

    # Tokenizer and datasets
    tokenizer = AutoTokenizer.from_pretrained(cfg.model_name)
    train_ds, val_ds = prepare_datasets(cfg.data_path, tokenizer)

    optimizer_grouped_parameters = [
        {
            "name": "bert_base",
            "params": [p for n, p in model.bert.named_parameters() if p.requires_grad],
            "lr": cfg.bert_lr,
        },
        {
            # Assumes the classification head is exposed as `model.classifier`.
            "params": [p for n, p in model.classifier.named_parameters() if p.requires_grad],
            "lr": cfg.learning_rate,
        }
    ]
    optimizer = torch.optim.AdamW(optimizer_grouped_parameters, weight_decay=cfg.weight_decay)

    # Training arguments are built before the scheduler so that the step count
    # can use the effective batch size and gradient accumulation settings.
    training_args = TrainingArguments(
        output_dir=cfg.output_dir,
        per_device_train_batch_size=cfg.batch_size,
        per_device_eval_batch_size=cfg.batch_size * 2,
        num_train_epochs=cfg.num_epochs,
        weight_decay=cfg.weight_decay,
        evaluation_strategy="epoch",
        save_strategy="epoch",
        load_best_model_at_end=True,
        metric_for_best_model="eval_f1",
        remove_unused_columns=False,
        fp16=False,
        max_grad_norm=cfg.max_grad_norm,
        logging_steps=20,
        seed=cfg.seed_val,
        data_seed=cfg.seed_val,
    )

    # The last partial batch is still a training step, so round up; flooring
    # makes the schedule end (LR reaches 0) before training actually finishes.
    batches_per_epoch = -(-len(train_ds) // training_args.train_batch_size)
    updates_per_epoch = max(batches_per_epoch // training_args.gradient_accumulation_steps, 1)
    num_training_steps = updates_per_epoch * cfg.num_epochs
    num_warmup_steps = int(cfg.warmup_ratio * num_training_steps)

    lr_scheduler = get_scheduler(
        name=cfg.lr_scheduler_type,
        optimizer=optimizer,
        num_warmup_steps=num_warmup_steps,
        num_training_steps=num_training_steps,
    )

    trainer = NERTrainer(
        model=model,
        args=training_args,
        train_dataset=train_ds,
        eval_dataset=val_ds,
        compute_metrics=compute_metrics,
        data_collator=collate_fn,
        callbacks=[EarlyStoppingCallback(early_stopping_patience=cfg.early_stopping_patience)],
        optimizers=(optimizer, lr_scheduler),
    )

    trainer.train()

    # Collect the loss / metric curves from the trainer's log history.
    log_history = trainer.state.log_history
    history = {
        "epochs": sorted(
            {entry["epoch"] for entry in log_history if "loss" in entry or "eval_loss" in entry}
        ),
        "train_loss": [entry["loss"] for entry in log_history if "loss" in entry],
        "eval_loss": [entry["eval_loss"] for entry in log_history if "eval_loss" in entry],
        "eval_f1": [entry["eval_f1"] for entry in log_history if "eval_f1" in entry],
        "eval_precision": [entry["eval_precision"] for entry in log_history if "eval_precision" in entry],
        "eval_recall": [entry["eval_recall"] for entry in log_history if "eval_recall" in entry],
    }

    model.save_pretrained(cfg.output_dir)
    tokenizer.save_pretrained(cfg.output_dir)

    # The final log entry is the training summary, not an evaluation record,
    # so search backwards for the most recent entry that holds eval metrics.
    final_eval_metrics = next(
        (entry for entry in reversed(log_history) if "eval_f1" in entry), {}
    )
    final_report_string = final_eval_metrics.get("eval_report_string")
    if final_report_string:
        print("\nFinal Evaluation Report:\n", final_report_string)
    elif final_eval_metrics:
        print(
            "\nFinal Evaluation Metrics: "
            f"precision={final_eval_metrics.get('eval_precision')}, "
            f"recall={final_eval_metrics.get('eval_recall')}, "
            f"f1={final_eval_metrics.get('eval_f1')}"
        )

    return history

# Run the training pipeline only when executed as a script / notebook cell,
# not when this module is imported.
if __name__ == "__main__":
    main()

  from .autonotebook import tqdm as notebook_tqdm
Some weights of BERT_Softmax were not initialized from the model checkpoint at ../models/bert_pretrained and are newly initialized: ['bert.pooler.dense.bias', 'bert.pooler.dense.weight', 'classifier.bias', 'classifier.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


Epoch,Training Loss,Validation Loss,Precision,Recall,F1
1,No log,1.30317,0.0,0.0,0.0
2,2.899300,0.770489,0.361539,0.340449,0.350678
3,0.959600,0.369795,0.668624,0.664143,0.666376
4,0.489700,0.26119,0.744069,0.790245,0.766462
5,0.256600,0.211517,0.803589,0.817053,0.810265
6,0.175400,0.207994,0.778029,0.854412,0.814434
7,0.139300,0.178161,0.810547,0.856522,0.832901
8,0.117200,0.187531,0.813757,0.859005,0.835769
9,0.093300,0.192608,0.801514,0.867569,0.833234
10,0.083900,0.184402,0.819176,0.864217,0.841094


  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))


              precision    recall  f1-score   support

          NB       0.00      0.00      0.00        11
          NG       0.00      0.00      0.00       679
          NO       0.00      0.00      0.00       235
          NR       0.00      0.00      0.00      2680
          NS       0.00      0.00      0.00      1269
           T       0.00      0.00      0.00       787
          ZA       0.00      0.00      0.00       185
          ZD       0.00      0.00      0.00        59
          ZF       0.00      0.00      0.00       249
          ZP       0.00      0.00      0.00      1485
          ZS       0.00      0.00      0.00       303
          ZZ       0.00      0.00      0.00       115

   micro avg       0.00      0.00      0.00      8057
   macro avg       0.00      0.00      0.00      8057
weighted avg       0.00      0.00      0.00      8057


Confusion matrix image for epoch 1 saved to ../../hy-tmp/models/Total_bert_softmax_lr5e-5_linear/confusion_matrices/confusion_matrix

  _warn_prf(average, modifier, msg_start, len(result))


              precision    recall  f1-score   support

          NB       0.00      0.00      0.00        11
          NG       0.00      0.00      0.00       679
          NO       0.00      0.00      0.00       235
          NR       0.32      0.68      0.43      2680
          NS       0.00      0.00      0.00      1269
           T       0.00      0.00      0.00       787
          ZA       0.00      0.00      0.00       185
          ZD       0.00      0.00      0.00        59
          ZF       0.00      0.00      0.00       249
          ZP       0.50      0.62      0.55      1485
          ZS       0.00      0.00      0.00       303
          ZZ       0.00      0.00      0.00       115

   micro avg       0.36      0.34      0.35      8057
   macro avg       0.07      0.11      0.08      8057
weighted avg       0.20      0.34      0.25      8057


Confusion matrix image for epoch 2 saved to ../../hy-tmp/models/Total_bert_softmax_lr5e-5_linear/confusion_matrices/confusion_matrix

  _warn_prf(average, modifier, msg_start, len(result))


              precision    recall  f1-score   support

          NB       0.00      0.00      0.00        11
          NG       0.74      0.89      0.81       679
          NO       0.00      0.00      0.00       235
          NR       0.80      0.88      0.84      2680
          NS       0.68      0.67      0.67      1269
           T       0.62      0.55      0.58       787
          ZA       0.38      0.22      0.28       185
          ZD       0.00      0.00      0.00        59
          ZF       0.22      0.22      0.22       249
          ZP       0.52      0.68      0.59      1485
          ZS       0.25      0.02      0.04       303
          ZZ       0.00      0.00      0.00       115

   micro avg       0.67      0.66      0.67      8057
   macro avg       0.35      0.34      0.34      8057
weighted avg       0.62      0.66      0.63      8057


Confusion matrix image for epoch 3 saved to ../../hy-tmp/models/Total_bert_softmax_lr5e-5_linear/confusion_matrices/confusion_matrix

  _warn_prf(average, modifier, msg_start, len(result))


              precision    recall  f1-score   support

          NB       0.00      0.00      0.00        11
          NG       0.82      0.91      0.86       679
          NO       0.51      0.47      0.49       235
          NR       0.81      0.94      0.87      2680
          NS       0.68      0.75      0.71      1269
           T       0.66      0.70      0.68       787
          ZA       0.47      0.62      0.53       185
          ZD       0.07      0.02      0.03        59
          ZF       0.73      0.85      0.78       249
          ZP       0.84      0.84      0.84      1485
          ZS       0.26      0.18      0.22       303
          ZZ       0.00      0.00      0.00       115

   micro avg       0.74      0.79      0.77      8057
   macro avg       0.49      0.52      0.50      8057
weighted avg       0.72      0.79      0.75      8057


Confusion matrix image for epoch 4 saved to ../../hy-tmp/models/Total_bert_softmax_lr5e-5_linear/confusion_matrices/confusion_matrix

  _warn_prf(average, modifier, msg_start, len(result))


              precision    recall  f1-score   support

          NB       0.00      0.00      0.00        11
          NG       0.85      0.92      0.89       679
          NO       0.65      0.63      0.64       235
          NR       0.87      0.93      0.90      2680
          NS       0.82      0.79      0.80      1269
           T       0.80      0.79      0.79       787
          ZA       0.49      0.71      0.58       185
          ZD       0.24      0.34      0.28        59
          ZF       0.76      0.88      0.81       249
          ZP       0.86      0.84      0.85      1485
          ZS       0.28      0.20      0.23       303
          ZZ       0.44      0.26      0.33       115

   micro avg       0.80      0.82      0.81      8057
   macro avg       0.59      0.61      0.59      8057
weighted avg       0.80      0.82      0.81      8057


Confusion matrix image for epoch 5 saved to ../../hy-tmp/models/Total_bert_softmax_lr5e-5_linear/confusion_matrices/confusion_matrix

  _warn_prf(average, modifier, msg_start, len(result))


              precision    recall  f1-score   support

          NB       0.00      0.00      0.00        11
          NG       0.84      0.94      0.89       679
          NO       0.61      0.66      0.63       235
          NR       0.84      0.93      0.88      2680
          NS       0.78      0.83      0.80      1269
           T       0.78      0.82      0.80       787
          ZA       0.61      0.81      0.70       185
          ZD       0.26      0.56      0.35        59
          ZF       0.80      0.90      0.85       249
          ZP       0.87      0.84      0.85      1485
          ZS       0.39      0.63      0.48       303
          ZZ       0.47      0.42      0.44       115

   micro avg       0.78      0.85      0.81      8057
   macro avg       0.60      0.69      0.64      8057
weighted avg       0.79      0.85      0.82      8057


Confusion matrix image for epoch 6 saved to ../../hy-tmp/models/Total_bert_softmax_lr5e-5_linear/confusion_matrices/confusion_matrix

  _warn_prf(average, modifier, msg_start, len(result))


              precision    recall  f1-score   support

          NB       0.00      0.00      0.00        11
          NG       0.87      0.93      0.90       679
          NO       0.66      0.70      0.68       235
          NR       0.90      0.91      0.90      2680
          NS       0.80      0.83      0.82      1269
           T       0.80      0.83      0.82       787
          ZA       0.66      0.81      0.72       185
          ZD       0.35      0.56      0.43        59
          ZF       0.84      0.91      0.88       249
          ZP       0.86      0.87      0.87      1485
          ZS       0.41      0.65      0.50       303
          ZZ       0.48      0.48      0.48       115

   micro avg       0.81      0.86      0.83      8057
   macro avg       0.64      0.71      0.67      8057
weighted avg       0.82      0.86      0.84      8057


Confusion matrix image for epoch 7 saved to ../../hy-tmp/models/Total_bert_softmax_lr5e-5_linear/confusion_matrices/confusion_matrix

  _warn_prf(average, modifier, msg_start, len(result))


              precision    recall  f1-score   support

          NB       0.00      0.00      0.00        11
          NG       0.88      0.93      0.91       679
          NO       0.70      0.68      0.69       235
          NR       0.89      0.92      0.91      2680
          NS       0.81      0.83      0.82      1269
           T       0.80      0.84      0.82       787
          ZA       0.68      0.77      0.72       185
          ZD       0.45      0.66      0.54        59
          ZF       0.84      0.91      0.87       249
          ZP       0.87      0.87      0.87      1485
          ZS       0.38      0.67      0.49       303
          ZZ       0.49      0.54      0.51       115

   micro avg       0.81      0.86      0.84      8057
   macro avg       0.65      0.72      0.68      8057
weighted avg       0.83      0.86      0.84      8057


Confusion matrix image for epoch 8 saved to ../../hy-tmp/models/Total_bert_softmax_lr5e-5_linear/confusion_matrices/confusion_matrix

  _warn_prf(average, modifier, msg_start, len(result))


              precision    recall  f1-score   support

          NB       0.00      0.00      0.00        11
          NG       0.88      0.93      0.90       679
          NO       0.64      0.67      0.66       235
          NR       0.89      0.93      0.91      2680
          NS       0.80      0.82      0.81      1269
           T       0.79      0.84      0.82       787
          ZA       0.60      0.90      0.72       185
          ZD       0.45      0.66      0.54        59
          ZF       0.79      0.92      0.85       249
          ZP       0.87      0.88      0.87      1485
          ZS       0.40      0.68      0.50       303
          ZZ       0.46      0.59      0.52       115

   micro avg       0.80      0.87      0.83      8057
   macro avg       0.63      0.74      0.68      8057
weighted avg       0.81      0.87      0.84      8057


Confusion matrix image for epoch 9 saved to ../../hy-tmp/models/Total_bert_softmax_lr5e-5_linear/confusion_matrices/confusion_matrix

  _warn_prf(average, modifier, msg_start, len(result))


              precision    recall  f1-score   support

          NB       0.00      0.00      0.00        11
          NG       0.87      0.93      0.90       679
          NO       0.65      0.68      0.67       235
          NR       0.89      0.93      0.91      2680
          NS       0.80      0.84      0.82      1269
           T       0.82      0.83      0.82       787
          ZA       0.67      0.81      0.73       185
          ZD       0.51      0.63      0.56        59
          ZF       0.87      0.90      0.88       249
          ZP       0.87      0.87      0.87      1485
          ZS       0.45      0.67      0.54       303
          ZZ       0.55      0.55      0.55       115

   micro avg       0.82      0.86      0.84      8057
   macro avg       0.66      0.72      0.69      8057
weighted avg       0.83      0.86      0.84      8057


Confusion matrix image for epoch 10 saved to ../../hy-tmp/models/Total_bert_softmax_lr5e-5_linear/confusion_matrices/confusion_matri

  _warn_prf(average, modifier, msg_start, len(result))


              precision    recall  f1-score   support

          NB       0.00      0.00      0.00        11
          NG       0.89      0.93      0.91       679
          NO       0.64      0.72      0.68       235
          NR       0.91      0.91      0.91      2680
          NS       0.82      0.83      0.83      1269
           T       0.81      0.83      0.82       787
          ZA       0.73      0.70      0.71       185
          ZD       0.51      0.61      0.56        59
          ZF       0.84      0.90      0.87       249
          ZP       0.87      0.85      0.86      1485
          ZS       0.50      0.67      0.57       303
          ZZ       0.56      0.57      0.57       115

   micro avg       0.84      0.85      0.84      8057
   macro avg       0.67      0.71      0.69      8057
weighted avg       0.84      0.85      0.85      8057


Confusion matrix image for epoch 11 saved to ../../hy-tmp/models/Total_bert_softmax_lr5e-5_linear/confusion_matrices/confusion_matri

  _warn_prf(average, modifier, msg_start, len(result))


              precision    recall  f1-score   support

          NB       0.00      0.00      0.00        11
          NG       0.90      0.93      0.92       679
          NO       0.62      0.71      0.66       235
          NR       0.90      0.92      0.91      2680
          NS       0.83      0.84      0.83      1269
           T       0.81      0.84      0.82       787
          ZA       0.69      0.74      0.71       185
          ZD       0.52      0.78      0.63        59
          ZF       0.84      0.90      0.87       249
          ZP       0.87      0.87      0.87      1485
          ZS       0.46      0.66      0.55       303
          ZZ       0.53      0.63      0.57       115

   micro avg       0.83      0.86      0.85      8057
   macro avg       0.66      0.73      0.70      8057
weighted avg       0.83      0.86      0.85      8057


Confusion matrix image for epoch 12 saved to ../../hy-tmp/models/Total_bert_softmax_lr5e-5_linear/confusion_matrices/confusion_matri

  _warn_prf(average, modifier, msg_start, len(result))


              precision    recall  f1-score   support

          NB       0.00      0.00      0.00        11
          NG       0.91      0.93      0.92       679
          NO       0.62      0.72      0.67       235
          NR       0.90      0.93      0.91      2680
          NS       0.83      0.84      0.83      1269
           T       0.81      0.85      0.83       787
          ZA       0.66      0.81      0.73       185
          ZD       0.48      0.73      0.58        59
          ZF       0.84      0.90      0.87       249
          ZP       0.87      0.86      0.87      1485
          ZS       0.45      0.66      0.54       303
          ZZ       0.46      0.64      0.54       115

   micro avg       0.82      0.87      0.84      8057
   macro avg       0.65      0.74      0.69      8057
weighted avg       0.83      0.87      0.85      8057


Confusion matrix image for epoch 13 saved to ../../hy-tmp/models/Total_bert_softmax_lr5e-5_linear/confusion_matrices/confusion_matri

  _warn_prf(average, modifier, msg_start, len(result))


              precision    recall  f1-score   support

          NB       0.00      0.00      0.00        11
          NG       0.90      0.94      0.92       679
          NO       0.67      0.69      0.68       235
          NR       0.90      0.92      0.91      2680
          NS       0.83      0.83      0.83      1269
           T       0.80      0.84      0.82       787
          ZA       0.68      0.78      0.73       185
          ZD       0.53      0.75      0.62        59
          ZF       0.83      0.91      0.87       249
          ZP       0.87      0.87      0.87      1485
          ZS       0.51      0.64      0.57       303
          ZZ       0.59      0.59      0.59       115

   micro avg       0.84      0.86      0.85      8057
   macro avg       0.68      0.73      0.70      8057
weighted avg       0.84      0.86      0.85      8057


Confusion matrix image for epoch 14 saved to ../../hy-tmp/models/Total_bert_softmax_lr5e-5_linear/confusion_matrices/confusion_matri

  _warn_prf(average, modifier, msg_start, len(result))


              precision    recall  f1-score   support

          NB       0.00      0.00      0.00        11
          NG       0.93      0.93      0.93       679
          NO       0.67      0.68      0.68       235
          NR       0.91      0.91      0.91      2680
          NS       0.83      0.84      0.83      1269
           T       0.79      0.84      0.82       787
          ZA       0.70      0.76      0.73       185
          ZD       0.51      0.75      0.60        59
          ZF       0.84      0.90      0.87       249
          ZP       0.88      0.86      0.87      1485
          ZS       0.46      0.63      0.53       303
          ZZ       0.53      0.62      0.57       115

   micro avg       0.83      0.86      0.84      8057
   macro avg       0.67      0.73      0.69      8057
weighted avg       0.84      0.86      0.85      8057


Confusion matrix image for epoch 15 saved to ../../hy-tmp/models/Total_bert_softmax_lr5e-5_linear/confusion_matrices/confusion_matri