In [14]:
import torch
import torch.nn as nn
import os
import wandb
import random
import argparse
import numpy as np
from tqdm import tqdm
from transformers import BertModel, AutoModel
from transformers import AdamW

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# 数据预处理部分

In [4]:
from torch.utils.data import Dataset
import json

class ErrorDetectionDataset(Dataset):
    def __init__(
            self,
            data_path,
            coarse_labels={
                "字符级错误": 0,
                "成分残缺型错误": 1, 
                "成分赘余型错误": 2,
                "成分搭配不当型错误": 3
            },
            fine_labels={
                "缺字漏字": 0,
                "错别字错误": 1,
                "缺少标点": 2,
                "错用标点": 3,
                "主语不明": 4,
                "谓语残缺": 5,
                "宾语残缺": 6,
                "其他成分残缺": 7,
                "主语多余": 8,
                "虚词多余": 9,
                "其他成分多余": 10,
                "语序不当": 11,
                "动宾搭配不当": 12,
                "其他搭配不当": 13
            }
    ):
        self.data_path = data_path
        self.coarse_labels = coarse_labels
        self.fine_labels = fine_labels
        self._get_data()
    
    def _get_data(self):
        with open(self.data_path, 'r', encoding='utf-8') as f:
            data = json.load(f)
        
        self.data = []
        for item in data:
            sent = item['sent']
            coarse_types = item['CourseGrainedErrorType']
            fine_types = item['FineGrainedErrorType']
            
            # 构建粗粒度标签（多标签）
            coarse_label = [0] * len(self.coarse_labels)
            for c_type in coarse_types:
                if c_type in self.coarse_labels:
                    coarse_label[self.coarse_labels[c_type]] = 1
            
            # 构建细粒度标签（多标签）
            fine_label = [0] * len(self.fine_labels)
            for f_type in fine_types:
                if f_type in self.fine_labels:
                    fine_label[self.fine_labels[f_type]] = 1
            
            self.data.append((sent, coarse_label, fine_label, item.get('sent_id', -1)))
    
    def __len__(self):
        return len(self.data)
    
    def get_coarse_labels(self):
        return self.coarse_labels
    
    def get_fine_labels(self):
        return self.fine_labels
    
    def __getitem__(self, idx):
        return self.data[idx]

# 数据加载部分

In [5]:
import torch
from torch.utils.data import DataLoader
from transformers import AutoTokenizer

class ErrorDetectionDataLoader:
    def __init__(
        self,
        dataset,
        batch_size=16,
        max_length=128,
        shuffle=True,
        drop_last=True,
        device=None,
        tokenizer_name='./models/bge-large-zh-v1.5'
    ):
        self.tokenizer = AutoTokenizer.from_pretrained(tokenizer_name)
        self.dataset = dataset
        self.batch_size = batch_size
        self.max_length = max_length
        self.shuffle = shuffle
        self.drop_last = drop_last
        
        if device is None:
            self.device = torch.device(
                'cuda' if torch.cuda.is_available() else 'cpu'
            )
        else:
            self.device = device
        
        self.loader = DataLoader(
            dataset=self.dataset,
            batch_size=self.batch_size,
            collate_fn=self.collate_fn,
            shuffle=self.shuffle,
            drop_last=self.drop_last
        )
    
    def collate_fn(self, data):
        sents = [item[0] for item in data]
        coarse_labels = [item[1] for item in data]
        fine_labels = [item[2] for item in data]
        sent_ids = [item[3] for item in data]
        
        # 编码文本
        encoded = self.tokenizer.batch_encode_plus(
            batch_text_or_text_pairs=sents,
            truncation=True,
            padding='max_length',
            max_length=self.max_length,
            return_tensors='pt',
            return_length=True
        )
        
        input_ids = encoded['input_ids'].to(self.device)
        attention_mask = encoded['attention_mask'].to(self.device)
        token_type_ids = encoded.get('token_type_ids', None)
        
        if token_type_ids is not None:
            token_type_ids = token_type_ids.to(self.device)
        
        # 处理标签
        if coarse_labels[0] == -1:
            coarse_labels = None
            fine_labels = None
        else:
            coarse_labels = torch.tensor(coarse_labels, dtype=torch.float).to(self.device)
            fine_labels = torch.tensor(fine_labels, dtype=torch.float).to(self.device)
        
        return input_ids, attention_mask, token_type_ids, coarse_labels, fine_labels, sent_ids
    
    def __iter__(self):
        for data in self.loader:
            yield data
    
    def __len__(self):
        return len(self.loader)

# 数据处理和划分的 校验部分  
训练集 104个sentences   验证集  27个sentences  

In [33]:
# 校验上述数据加载器出来的数据是否正确，输出数据出来看一看
train_data_path = '../datas/train.json'
val_data_path = '../datas/val.json'
train_dataset = ErrorDetectionDataset(train_data_path)
train_dataloader = ErrorDetectionDataLoader(train_dataset, batch_size=1)
val_dataset = ErrorDetectionDataset(val_data_path)
val_dataloader = ErrorDetectionDataLoader(val_dataset, batch_size=1)

cnt_train=0
for batch in train_dataloader:
    cnt_train += 1
    input_ids, attention_mask, token_type_ids, coarse_labels, fine_labels, sent_ids = batch
    print(input_ids.shape, attention_mask.shape, token_type_ids.shape, coarse_labels.shape, fine_labels.shape)
    # break
print("train data size:", cnt_train)

cnt_val=0
for batch in val_dataloader:
    cnt_val += 1
    input_ids, attention_mask, token_type_ids, coarse_labels, fine_labels, sent_ids = batch
    print(input_ids.shape, attention_mask.shape, token_type_ids.shape, coarse_labels.shape, fine_labels.shape)
    # break
print("val data size:", cnt_val)

torch.Size([1, 128]) torch.Size([1, 128]) torch.Size([1, 128]) torch.Size([1, 4]) torch.Size([1, 14])
torch.Size([1, 128]) torch.Size([1, 128]) torch.Size([1, 128]) torch.Size([1, 4]) torch.Size([1, 14])
torch.Size([1, 128]) torch.Size([1, 128]) torch.Size([1, 128]) torch.Size([1, 4]) torch.Size([1, 14])
torch.Size([1, 128]) torch.Size([1, 128]) torch.Size([1, 128]) torch.Size([1, 4]) torch.Size([1, 14])
torch.Size([1, 128]) torch.Size([1, 128]) torch.Size([1, 128]) torch.Size([1, 4]) torch.Size([1, 14])
torch.Size([1, 128]) torch.Size([1, 128]) torch.Size([1, 128]) torch.Size([1, 4]) torch.Size([1, 14])
torch.Size([1, 128]) torch.Size([1, 128]) torch.Size([1, 128]) torch.Size([1, 4]) torch.Size([1, 14])
torch.Size([1, 128]) torch.Size([1, 128]) torch.Size([1, 128]) torch.Size([1, 4]) torch.Size([1, 14])
torch.Size([1, 128]) torch.Size([1, 128]) torch.Size([1, 128]) torch.Size([1, 4]) torch.Size([1, 14])
torch.Size([1, 128]) torch.Size([1, 128]) torch.Size([1, 128]) torch.Size([1, 4]) 

In [35]:
# 检查一下dataset中处理出来的数据是否正确
for data in train_dataset:
    sent, coarse_label, fine_label, sent_id = data
    print(sent, coarse_label, fine_label, sent_id)

作为一名英语课代表，对徐老师的认识可能比同学更清楚。 [0, 1, 0, 0] [0, 0, 0, 0, 1, 0, 0, 1, 0, 0, 0, 0, 0, 0] 5059
因而，保存良好的家风，摒弃有害而无益的家风，是有助于人成长的一大益事。 [0, 0, 0, 1] [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0] 254
早上阳光明媚，如我的心灵，丝毫不被外界的阴云所引响，只觉得蔚蓝的天空上漂浮着的云，星星点点素雅的小花，让人倍感舒爽到了中午，远处传来一丝丝的锣鼓声，一众人从小到大到我眼前，看清上面有一个方形的木盒，很大，盖住了天上的太阳。 [1, 1, 0, 0] [0, 1, 1, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0] 5612
周末时，老师邀请我与其他的同学去为母校的同学们界绍一位著名的冬奥运动员。 [1, 0, 0, 0] [0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0] 4614
每次当我回到家时，从来也没有再见到那个向我飞扑过来的狗，也从没忘记那个向我撒娇。认人心软的狗，我家的那地地毯也被妈妈洗了，也看不见那让我永远忘不掉的梅花。 [1, 0, 1, 0] [0, 0, 0, 1, 0, 0, 0, 0, 0, 1, 1, 0, 0, 0] 5735
老师刚说，话音未落就见我们消失在果树之中，果树上都是橙子，我们8人分成4组，两人两把剪刀一个栏子，我们两人一起行动，另一人是小宇，我的好兄弟， [1, 0, 0, 0] [0, 1, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0] 8885
我不仅从一个小学生变成了初中生，我也懂得了如何去诊惜生活中的每分每秒。 [1, 0, 0, 0] [0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0] 5342
记忆的盒子瞬打开，把我拉回了相片中的时间。 [1, 0, 0, 0] [1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0] 3774
先从后面偷偷摸摸进去，先是杀了两个厨房里的仆人，后是把”老的小的“通通“砍掉”，刀都能砍卷了。 [0, 1, 0, 0] [0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0

# 处理对最后14个错误的描述的embedding

In [27]:
def get_fine_description_embedding(model_path):
    tokenizer = AutoTokenizer.from_pretrained(model_path)
    model = AutoModel.from_pretrained(model_path)

    # 设置为评估模式
    model.eval()

    # 存储所有句子的向量表示
    fine_description = []

    # 获取每一个细粒度错误的详细描述
    sentences = [
        "句子中缺少字（需要添加）",
        "句子中出现错别字（需要修改或删除）",
        "应断句的地方没有使用标点把句子断开",
        "标点使用错误，如本来应使用句号却分号或括号使用错误",
        "句子缺少主语，或主语不清晰，修改是要增加主语或使主语显现",
        "句子缺少谓语，修改是要增加谓语",
        "句子缺少宾语，修改是要增加宾语",
        "修改是要增加除主语、谓语、宾语之外的其他情况",
        "一般是句子较长，前一个主语说出后，紧接着有一个较长、较复杂的形容词修饰过了一会儿，我忘记了测试，老师说等消息，（去掉主语：修改是要删除主语）",
        "副词\"的\"、\"所\"的多余，需改掉或删除副词",
        "除主语、谓词之外的成分多余，修改特别容易混淆区，都是一些平时见不到的洲狮，孟加拉虎等猛兽，据说这里有300多只老虎生活在这里呢是不是很惊叹呢！（多个\"这里\"，表达重复）",
        "句子中词语或子句的顺序不合理，修改是调换某几个词汇或子句的顺序",
        "谓语与宾语搭配不当，修改是要用其他词替换句子的谓语或宾语",
        "除动宾，语序不当不当之外的其他搭配不当情况，修改是要用其他词替换句子中的某个成分"
    ]
    encoded_input = tokenizer(sentences, padding=True, truncation=True, return_tensors='pt')
    with torch.no_grad():
        model_output = model(**encoded_input)
        
    fine_description = model_output.last_hidden_state[:, 0, :]
    return fine_description
        

 # 模型部分

In [22]:
import torch
import torch.nn as nn
from transformers import AutoModel

class HierarchicalErrorClassifier(nn.Module):
    def __init__(
        self, 
        pretrained_model_name, 
        num_coarse_labels=4, 
        num_fine_labels=14, 
        freeze_pooler=False, 
        dropout=0.2,
        fine_description=None
    ):
        super().__init__()
        
        self.freeze_pooler = freeze_pooler
        self.bert = AutoModel.from_pretrained(pretrained_model_name)
        self.dropout = nn.Dropout(dropout) if dropout > 0 else nn.Identity()
        self.fine_description = fine_description.to(device)
        
        if freeze_pooler:
            for param in self.bert.pooler.parameters():
                param.requires_grad = False
        
        # 粗粒度分类器
        self.coarse_classifier = nn.Linear(self.bert.config.hidden_size, num_coarse_labels)
        
        # 细粒度分类器
        self.fine_classifier = nn.Linear(self.bert.config.hidden_size, num_fine_labels)
        
        # 定义粗粒度类别和对应的细粒度索引的映射
        self.coarse_to_fine_indices = {
            0: [0, 1, 2, 3],    # 字符级错误
            1: [4, 5, 6, 7],    # 成分残缺型错误
            2: [8, 9, 10],      # 成分冗余型错误
            3: [11, 12, 13]     # 成分搭配不当型错误
        }
    
    def forward(self, input_ids, attention_mask, token_type_ids=None):
        # BERT编码
        if token_type_ids is not None:
            outputs = self.bert(input_ids, attention_mask=attention_mask, token_type_ids=token_type_ids)
        else:
            outputs = self.bert(input_ids, attention_mask=attention_mask)
        
        pooled_output = outputs.pooler_output
        pooled_output = self.dropout(pooled_output)
        
        # 粗粒度分类
        coarse_logits = self.coarse_classifier(pooled_output)
        coarse_probs = torch.sigmoid(coarse_logits)
        
        # # 细粒度分类
        # fine_logits = self.fine_classifier(pooled_output)
        # fine_probs = torch.sigmoid(fine_logits)
        
        # 细粒度分类，包含14个补充信息
        fine_logits = torch.matmul(pooled_output, self.fine_description.T)
        fine_probs = torch.sigmoid(fine_logits)
        
        return coarse_probs, fine_probs
    
    def apply_hierarchical_constraint(self, coarse_preds, fine_preds):
        """
        应用层次约束：如果粗粒度类别预测为负，则该粗粒度下的所有细粒度类别均设为负
        
        Args:
            coarse_preds: 粗粒度预测结果，shape [batch_size, num_coarse_labels]
            fine_preds: 细粒度预测结果，shape [batch_size, num_fine_labels]
            
        Returns:
            应用约束后的细粒度预测结果
        """
        constrained_fine_preds = fine_preds.clone()
        
        # 遍历每个样本
        for i in range(coarse_preds.size(0)):  # 遍历每个样本（第一个维度是批次大小）
            # 对每个粗粒度类别
            for coarse_idx, fine_indices in self.coarse_to_fine_indices.items():
                # 如果粗粒度为负，则对应的细粒度全部设为负
                if coarse_preds[i, coarse_idx] == 0:
                    constrained_fine_preds[i, fine_indices] = 0
        
        return constrained_fine_preds

# 训练、验证和测试部分

包的导入、参数定义、计算最终指标的函数定义

In [None]:
import os
import wandb
import random
import argparse
from tqdm import tqdm

import torch
import torch.nn as nn
import numpy as np
from transformers import AdamW
from sklearn.metrics import f1_score, precision_score, recall_score, accuracy_score


# 在代码开始处禁用 wandb 记录
os.environ["WANDB_MODE"] = "disabled"

# 导入自定义模块
# from error_detection_dataset import ErrorDetectionDataset
# from error_detection_dataloader import ErrorDetectionDataLoader
# from hierarchical_classifier_model import HierarchicalErrorClassifier

# 如果分项目的时候就可以使用这个参数解析函数
def argparser():
    parser = argparse.ArgumentParser()
    parser.add_argument('--model_name', type=str, default='./models/bge-large-zh-v1.5')
    parser.add_argument('--num_coarse_labels', type=int, default=4)
    parser.add_argument('--num_fine_labels', type=int, default=14)
    parser.add_argument('--dropout', type=float, default=0.3)
    parser.add_argument('--freeze_pooler', action='store_true', help='Flag to freeze the pooler layer')
    parser.add_argument('--batch_size', type=int, default=16)
    parser.add_argument('--max_length', type=int, default=128)
    parser.add_argument('--lr', type=float, default=1e-5)
    parser.add_argument('--epochs', type=int, default=40)
    parser.add_argument('--device', type=str, required=False)
    parser.add_argument('--project', type=str, default='hierarchical_error_detection')
    parser.add_argument('--entity', type=str, default='akccc')
    parser.add_argument('--name', type=str, required=False)
    parser.add_argument('--seed', type=int, default=42)
    parser.add_argument('--data_path', type=str, default='../datas/train.json')
    parser.add_argument('--val_data_path', type=str, default='../datas/val.json')
    parser.add_argument('--test_data_path', type=str, default='../dats/val.json')
    parser.add_argument('--checkpoint_dir', type=str, default='checkpoints')
    parser.add_argument('--threshold', type=float, default=0.75)
    parser.add_argument('--patience', type=int, default=5)
    
    return parser.parse_args()



# 如果在Jupyter Notebook中运行，可以使用这个自定义参数函数替代argparser
def get_default_configs():
    """在Jupyter环境中使用的默认配置，避免argparse解析错误"""
    class Args:
        def __init__(self):
            self.model_name = './models/bge-large-zh-v1.5'
            self.num_coarse_labels = 4
            self.num_fine_labels = 14
            self.dropout = 0.3
            self.freeze_pooler = False
            self.batch_size = 8
            self.max_length = 128
            self.lr = 1e-5
            self.epochs = 40
            self.device = device
            self.project = 'hierarchical_error_detection'
            self.entity = 'akccc'
            self.name = None
            self.seed = 3407
            self.data_path = '../datas/train.json'
            self.val_data_path = '../datas/val.json'
            self.test_data_path = '../datas/val.json'
            self.checkpoint_dir = 'checkpoints'
            self.threshold = 0.5
            self.patience = 5
            self.exp_name = 'default_run'
    return Args()


def calculate_metrics(labels, predictions, average='micro'):
    """
    计算各种评估指标
    
    Args:
        labels: 真实标签
        predictions: 预测标签
        average: 平均方法，'micro'或'macro'
        
    Returns:
        包含各种指标的字典
    """
    # 将数组转换为numpy格式以确保兼容性
    labels = np.array(labels)
    predictions = np.array(predictions)
    
    # 计算微平均和宏平均的F1分数
    micro_f1 = f1_score(labels, predictions, average='micro')
    macro_f1 = f1_score(labels, predictions, average='macro')
    
    # 计算样本级别的准确率（每个样本的所有标签都要正确）
    sample_acc = accuracy_score(labels, predictions)
    
    return {
        'micro_f1': micro_f1 * 100,  # 转换为百分比
        'macro_f1': macro_f1 * 100,
        'accuracy': sample_acc * 100
    }

In [37]:
def train(configs):
    # 初始化wandb
    wandb.init(
        project=configs.project,
        entity=configs.entity,
        name=configs.exp_name,
    )

    # 配置wandb
    wandb_config = wandb.config
    wandb_config.model_name = configs.model_name
    wandb_config.num_coarse_labels = configs.num_coarse_labels
    wandb_config.num_fine_labels = configs.num_fine_labels
    wandb_config.dropout = configs.dropout
    wandb_config.freeze_pooler = configs.freeze_pooler
    wandb_config.batch_size = configs.batch_size
    wandb_config.max_length = configs.max_length
    wandb_config.lr = configs.lr
    wandb_config.epochs = configs.epochs
    wandb_config.device = configs.device
    wandb_config.seed = configs.seed
    wandb_config.threshold = configs.threshold

    # 设置随机种子
    random.seed(configs.seed)
    np.random.seed(configs.seed)
    torch.manual_seed(configs.seed)
    
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False
    
    # 创建检查点目录
    checkpoint_dir = os.path.join(configs.checkpoint_dir, configs.exp_name)
    os.makedirs(checkpoint_dir, exist_ok=True)
    
    # 加载数据集
    train_dataset = ErrorDetectionDataset(configs.data_path)
    val_dataset = ErrorDetectionDataset(configs.val_data_path)
    test_dataset = ErrorDetectionDataset(configs.test_data_path)
    

    # 创建数据加载器
    train_dataloader = ErrorDetectionDataLoader(
        dataset=train_dataset,
        batch_size=configs.batch_size,
        max_length=configs.max_length,
        shuffle=True,
        drop_last=True,
        device=configs.device,
        tokenizer_name=configs.model_name
    )

    val_dataloader = ErrorDetectionDataLoader(
        dataset=val_dataset,
        batch_size=configs.batch_size,
        max_length=configs.max_length,
        shuffle=False,
        drop_last=False,
        device=configs.device,
        tokenizer_name=configs.model_name
    )

    test_dataloader = ErrorDetectionDataLoader(
        dataset=test_dataset,
        batch_size=configs.batch_size,
        max_length=configs.max_length,
        shuffle=False,
        drop_last=False,
        device=configs.device,
        tokenizer_name=configs.model_name
    )
    
    fine_description = get_fine_description_embedding(model_path=configs.model_name)

    # 创建模型
    model = HierarchicalErrorClassifier(
        pretrained_model_name=configs.model_name,
        num_coarse_labels=configs.num_coarse_labels,
        num_fine_labels=configs.num_fine_labels,
        dropout=configs.dropout,
        freeze_pooler=configs.freeze_pooler,
        fine_description=fine_description
    ).to(configs.device)

    # 定义损失函数
    criterion = nn.BCELoss()

    # 定义优化器
    optimizer = AdamW(
        filter(lambda p: p.requires_grad, model.parameters()),
        lr=configs.lr
    )

    # 初始化最佳验证损失和早停计数器
    best_val_f1 = 0
    patience_counter = 0
    
    # 监控模型
    wandb.watch(model, log='all')
    
    # 获取标签映射，用于后续预测结果记录
    coarse_label_map = {v: k for k, v in val_dataset.get_coarse_labels().items()}
    fine_label_map = {v: k for k, v in val_dataset.get_fine_labels().items()}
    
    # print("coarse_label_map:", coarse_label_map)
    # print("fine_label_map:", fine_label_map)    
    # # 终止运行，用来debug
    # return
    
    
    # 训练循环
    for epoch in range(configs.epochs):
        # 训练阶段
        model.train()
        train_loss = 0.0
        all_coarse_preds = []
        all_coarse_labels = []
        all_fine_preds = []
        all_fine_labels = []
        all_constrained_fine_preds = []
        
        with tqdm(
            train_dataloader,
            total=len(train_dataloader),
            desc=f'Epoch {epoch + 1}/{configs.epochs}',
            unit='batch',
            ncols=100
        ) as pbar:
            for input_ids, attention_mask, token_type_ids, coarse_labels, fine_labels, sent_ids in pbar:
                optimizer.zero_grad()
                
                # 前向传播
                coarse_probs, fine_probs = model(input_ids, attention_mask, token_type_ids)
                
                # print("coarse_probs:", coarse_probs)
                
                # 计算损失
                coarse_loss = criterion(coarse_probs, coarse_labels)
                fine_loss = criterion(fine_probs, fine_labels)
                loss = coarse_loss + fine_loss
                
                # 反向传播
                loss.backward()
                optimizer.step()
                
                train_loss += loss.item()
                
                # 收集预测结果
                coarse_preds = (coarse_probs > configs.threshold).float().cpu().numpy()
                fine_preds = (fine_probs > configs.threshold).float().cpu().numpy()
                constrained_fine_preds = model.apply_hierarchical_constraint(
                    (coarse_probs > configs.threshold).float(), 
                    (fine_probs > configs.threshold).float()
                ).cpu().numpy()
                
                all_coarse_preds.extend(coarse_preds)
                all_coarse_labels.extend(coarse_labels.cpu().numpy())
                all_fine_preds.extend(fine_preds)
                all_fine_labels.extend(fine_labels.cpu().numpy())
                all_constrained_fine_preds.extend(constrained_fine_preds)
                
                # 更新进度条
                pbar.set_postfix(
                    loss=f'{loss.item():.3f}',
                    coarse_loss=f'{coarse_loss.item():.3f}',
                    fine_loss=f'{fine_loss.item():.3f}'
                )
                
                # 记录到wandb
                wandb.log({
                    'batch_loss': loss.item(),
                    'batch_coarse_loss': coarse_loss.item(),
                    'batch_fine_loss': fine_loss.item()
                })
        
        # 计算训练指标
        train_loss = train_loss / len(train_dataloader)
        
        # 计算各种评估指标
        train_coarse_metrics_micro = calculate_metrics(all_coarse_labels, all_coarse_preds, average='micro')
        train_coarse_metrics_macro = calculate_metrics(all_coarse_labels, all_coarse_preds, average='macro')
        train_fine_metrics_micro = calculate_metrics(all_fine_labels, all_fine_preds, average='micro')
        train_fine_metrics_macro = calculate_metrics(all_fine_labels, all_fine_preds, average='macro')
        train_constrained_fine_metrics_micro = calculate_metrics(all_fine_labels, all_constrained_fine_preds, average='micro')
        train_constrained_fine_metrics_macro = calculate_metrics(all_fine_labels, all_constrained_fine_preds, average='macro')
        
        # 保存模型检查点
        checkpoint_path = os.path.join(checkpoint_dir, f'epoch_{epoch + 1}.pt')
        torch.save(model.state_dict(), checkpoint_path)
        wandb.save(checkpoint_path)
        
        # 验证阶段
        model.eval()
        val_loss = 0.0
        all_coarse_preds = []
        all_coarse_labels = []
        all_fine_preds = []
        all_fine_labels = []
        all_constrained_fine_preds = []
        
        # 记录验证集和测试集每个句子的真实标签和预测结果
        val_sentence_predictions = []

        with torch.no_grad():
            for input_ids, attention_mask, token_type_ids, coarse_labels, fine_labels, sent_ids in val_dataloader:
                # 前向传播
                coarse_probs, fine_probs = model(input_ids, attention_mask, token_type_ids)
                
                # 应用层次约束
                coarse_preds = (coarse_probs > configs.threshold).float()
                fine_preds = (fine_probs > configs.threshold).float()
                constrained_fine_preds = model.apply_hierarchical_constraint(coarse_preds, fine_preds)
                
                # 计算损失
                coarse_loss = criterion(coarse_probs, coarse_labels)
                fine_loss = criterion(fine_probs, fine_labels)
                loss = coarse_loss + fine_loss
                
                val_loss += loss.item()
                
                # 收集预测结果
                all_coarse_preds.extend(coarse_preds.cpu().numpy())
                all_coarse_labels.extend(coarse_labels.cpu().numpy())
                all_fine_preds.extend(fine_preds.cpu().numpy())
                all_constrained_fine_preds.extend(constrained_fine_preds.cpu().numpy())
                all_fine_labels.extend(fine_labels.cpu().numpy())
                
                # 记录当前 batch 中每个样本的预测结果和真实标签
                coarse_preds_np = coarse_preds.cpu().numpy()
                fine_preds_np = constrained_fine_preds.cpu().numpy()
                
                # 获取当前批次的原始句子
                # 由于sent_ids可能是数字或者直接是句子标识符，视情况处理
                # 这里我们使用索引从数据集中获取句子
                batch_sentences = []
                for sid in sent_ids:
                    # 如果sent_ids是数字索引
                    if isinstance(sid, int) or (isinstance(sid, torch.Tensor) and sid.numel() == 1):
                        idx = sid if isinstance(sid, int) else sid.item()
                        if idx >= 0 and idx < len(val_dataset):
                            batch_sentences.append(val_dataset.data[idx][0])
                        else:
                            # 如果索引无效，使用一个默认值
                            batch_sentences.append(f"[Unknown sentence with ID {sid}]")
                    else:
                        # 如果sent_ids本身就是句子标识符（字符串）
                        batch_sentences.append(str(sid))
                
                for i in range(len(batch_sentences)):
                    coarse_indices = np.where(coarse_preds_np[i] == 1)[0]
                    fine_indices = np.where(fine_preds_np[i] == 1)[0]
                    predicted_coarse = [coarse_label_map[idx] for idx in coarse_indices]
                    predicted_fine = [fine_label_map[idx] for idx in fine_indices]
                    
                    # 真实标签
                    true_coarse_indices = np.where(coarse_labels[i].cpu().numpy() == 1)[0]
                    true_fine_indices = np.where(fine_labels[i].cpu().numpy() == 1)[0]
                    true_coarse = [coarse_label_map[idx] for idx in true_coarse_indices]
                    true_fine = [fine_label_map[idx] for idx in true_fine_indices]

                    val_sentence_predictions.append({
                        'sentence': batch_sentences[i],
                        'predicted_coarse': predicted_coarse,
                        'predicted_fine': predicted_fine,
                        'true_coarse': true_coarse,
                        'true_fine': true_fine
                    })
                    
        # 保存验证集预测结果和真实标签到文件
        val_pred_file = os.path.join(checkpoint_dir, f"val_predictions_epoch_{epoch+1}.json")
        with open(val_pred_file, "w", encoding="utf-8") as f:
            json.dump(val_sentence_predictions, f, ensure_ascii=False, indent=4)
        
        # 计算验证指标
        val_loss = val_loss / len(val_dataloader)
        
        # 计算各种评估指标
        val_coarse_metrics_micro = calculate_metrics(all_coarse_labels, all_coarse_preds, average='micro')
        val_coarse_metrics_macro = calculate_metrics(all_coarse_labels, all_coarse_preds, average='macro')
        val_fine_metrics_micro = calculate_metrics(all_fine_labels, all_fine_preds, average='micro')
        val_fine_metrics_macro = calculate_metrics(all_fine_labels, all_fine_preds, average='macro')
        val_constrained_fine_metrics_micro = calculate_metrics(all_fine_labels, all_constrained_fine_preds, average='micro')
        val_constrained_fine_metrics_macro = calculate_metrics(all_fine_labels, all_constrained_fine_preds, average='macro')
        
        # 输出训练和验证指标
        print(f'\nEpoch {epoch+1}/{configs.epochs}:')
        print(f'  Train Loss: {train_loss:.4f}')
        print(f'  Train Coarse-grained Metrics:')
        print(f'    Micro F1: {train_coarse_metrics_micro["micro_f1"]:.2f}')
        print(f'    Macro F1: {train_coarse_metrics_macro["macro_f1"]:.2f}')
        print(f'    Accuracy: {train_coarse_metrics_micro["accuracy"]:.2f}')
        print(f'  Train Fine-grained Metrics (Unconstrained):')
        print(f'    Micro F1: {train_fine_metrics_micro["micro_f1"]:.2f}')
        print(f'    Macro F1: {train_fine_metrics_macro["macro_f1"]:.2f}')
        print(f'    Accuracy: {train_fine_metrics_micro["accuracy"]:.2f}')
        print(f'  Train Fine-grained Metrics (Constrained):')
        print(f'    Micro F1: {train_constrained_fine_metrics_micro["micro_f1"]:.2f}')
        print(f'    Macro F1: {train_constrained_fine_metrics_macro["macro_f1"]:.2f}')
        print(f'    Accuracy: {train_constrained_fine_metrics_micro["accuracy"]:.2f}')
        print("+"*50)
        print(f'  Val Loss: {val_loss:.4f}')
        print(f'  Val Coarse-grained Metrics:')
        print(f'    Micro F1: {val_coarse_metrics_micro["micro_f1"]:.2f}')
        print(f'    Macro F1: {val_coarse_metrics_macro["macro_f1"]:.2f}')
        print(f'    Accuracy: {val_coarse_metrics_micro["accuracy"]:.2f}')
        print(f'  Val Fine-grained Metrics (Unconstrained):')
        print(f'    Micro F1: {val_fine_metrics_micro["micro_f1"]:.2f}')
        print(f'    Macro F1: {val_fine_metrics_macro["macro_f1"]:.2f}')
        print(f'    Accuracy: {val_fine_metrics_micro["accuracy"]:.2f}')
        print(f'  Val Fine-grained Metrics (Constrained):')
        print(f'    Micro F1: {val_constrained_fine_metrics_micro["micro_f1"]:.2f}')
        print(f'    Macro F1: {val_constrained_fine_metrics_macro["macro_f1"]:.2f}')
        print(f'    Accuracy: {val_constrained_fine_metrics_micro["accuracy"]:.2f}')
        
        # 记录到wandb
        wandb.log({
            'train_loss': train_loss,
            'train_coarse_micro_f1': train_coarse_metrics_micro["micro_f1"],
            'train_coarse_macro_f1': train_coarse_metrics_macro["macro_f1"],
            'train_coarse_accuracy': train_coarse_metrics_micro["accuracy"],
            'train_fine_micro_f1': train_fine_metrics_micro["micro_f1"],
            'train_fine_macro_f1': train_fine_metrics_macro["macro_f1"],
            'train_fine_accuracy': train_fine_metrics_micro["accuracy"],
            'train_constrained_fine_micro_f1': train_constrained_fine_metrics_micro["micro_f1"],
            'train_constrained_fine_macro_f1': train_constrained_fine_metrics_macro["macro_f1"],
            'train_constrained_fine_accuracy': train_constrained_fine_metrics_micro["accuracy"],
            'val_loss': val_loss,
            'val_coarse_micro_f1': val_coarse_metrics_micro["micro_f1"],
            'val_coarse_macro_f1': val_coarse_metrics_macro["macro_f1"],
            'val_coarse_accuracy': val_coarse_metrics_micro["accuracy"],
            'val_fine_micro_f1': val_fine_metrics_micro["micro_f1"],
            'val_fine_macro_f1': val_fine_metrics_macro["macro_f1"],
            'val_fine_accuracy': val_fine_metrics_micro["accuracy"],
            'val_constrained_fine_micro_f1': val_constrained_fine_metrics_micro["micro_f1"],
            'val_constrained_fine_macro_f1': val_constrained_fine_metrics_macro["macro_f1"],
            'val_constrained_fine_accuracy': val_constrained_fine_metrics_micro["accuracy"],
            'epoch': epoch + 1
        })
        
        # 检查是否保存最佳模型并应用早停
        if val_constrained_fine_metrics_micro["micro_f1"] > best_val_f1:
            best_val_f1 = val_constrained_fine_metrics_micro["micro_f1"]
            torch.save(model.state_dict(), os.path.join(checkpoint_dir, 'best_model.pt'))
            patience_counter = 0
        else:
            patience_counter += 1
            if patience_counter >= configs.patience:
                print('Early stopping triggered.')
                break
    
    # 加载最佳模型进行测试
    print("\n===== Testing best model =====")
    model.load_state_dict(torch.load(os.path.join(checkpoint_dir, 'best_model.pt')))
    model.eval()
    
    all_coarse_preds = []
    all_coarse_labels = []
    all_fine_preds = []
    all_fine_labels = []
    all_constrained_fine_preds = []
    
    # 记录测试集预测
    test_sentence_predictions = []
    
    with torch.no_grad():
        for input_ids, attention_mask, token_type_ids, coarse_labels, fine_labels, sent_ids in test_dataloader:
            # 前向传播
            coarse_probs, fine_probs = model(input_ids, attention_mask, token_type_ids)
            
            # 应用层次约束
            coarse_preds = (coarse_probs > configs.threshold).float()
            fine_preds = (fine_probs > configs.threshold).float()
            constrained_fine_preds = model.apply_hierarchical_constraint(coarse_preds, fine_preds)
            
            # 收集预测结果
            all_coarse_preds.extend(coarse_preds.cpu().numpy())
            all_coarse_labels.extend(coarse_labels.cpu().numpy())
            all_fine_preds.extend(fine_preds.cpu().numpy())
            all_constrained_fine_preds.extend(constrained_fine_preds.cpu().numpy())
            all_fine_labels.extend(fine_labels.cpu().numpy())
            
            # 获取当前批次的原始句子
            batch_sentences = []
            for sid in sent_ids:
                # 如果sent_ids是数字索引
                if isinstance(sid, int) or (isinstance(sid, torch.Tensor) and sid.numel() == 1):
                    idx = sid if isinstance(sid, int) else sid.item()
                    if idx >= 0 and idx < len(test_dataset):
                        batch_sentences.append(test_dataset.data[idx][0])
                    else:
                        # 如果索引无效，使用一个默认值
                        batch_sentences.append(f"[Unknown sentence with ID {sid}]")
                else:
                    # 如果sent_ids本身就是句子标识符（字符串）
                    batch_sentences.append(str(sid))
            
            # 记录当前 batch 中每个样本的预测结果和真实标签
            coarse_preds_np = coarse_preds.cpu().numpy()
            fine_preds_np = constrained_fine_preds.cpu().numpy()
            for i in range(len(batch_sentences)):
                coarse_indices = np.where(coarse_preds_np[i] == 1)[0]
                fine_indices = np.where(fine_preds_np[i] == 1)[0]
                predicted_coarse = [coarse_label_map[idx] for idx in coarse_indices]
                predicted_fine = [fine_label_map[idx] for idx in fine_indices]

                # 真实标签
                true_coarse_indices = np.where(coarse_labels[i].cpu().numpy() == 1)[0]
                true_fine_indices = np.where(fine_labels[i].cpu().numpy() == 1)[0]
                true_coarse = [coarse_label_map[idx] for idx in true_coarse_indices]
                true_fine = [fine_label_map[idx] for idx in true_fine_indices]

                test_sentence_predictions.append({
                    'sentence': batch_sentences[i],
                    'predicted_coarse': predicted_coarse,
                    'predicted_fine': predicted_fine,
                    'true_coarse': true_coarse,
                    'true_fine': true_fine
                })
                
    # 保存测试集预测结果和真实标签到文件
    test_pred_file = os.path.join(checkpoint_dir, "test_predictions.json")
    with open(test_pred_file, "w", encoding="utf-8") as f:
        json.dump(test_sentence_predictions, f, ensure_ascii=False, indent=4)
    
    # 计算测试指标
    test_coarse_metrics_micro = calculate_metrics(all_coarse_labels, all_coarse_preds, average='micro')
    test_coarse_metrics_macro = calculate_metrics(all_coarse_labels, all_coarse_preds, average='macro')
    test_fine_metrics_micro = calculate_metrics(all_fine_labels, all_fine_preds, average='micro')
    test_fine_metrics_macro = calculate_metrics(all_fine_labels, all_fine_preds, average='macro')
    test_constrained_fine_metrics_micro = calculate_metrics(all_fine_labels, all_constrained_fine_preds, average='micro')
    test_constrained_fine_metrics_macro = calculate_metrics(all_fine_labels, all_constrained_fine_preds, average='macro')
    
    # 输出测试结果
    print("\n===== Final Test Results =====")
    print(f"Final micro f1: {test_constrained_fine_metrics_micro['micro_f1']:.2f}")
    print(f"Final macro f1: {test_constrained_fine_metrics_macro['macro_f1']:.2f}")
    
    print("\nCoarse-grained micro f1: {:.2f}".format(test_coarse_metrics_micro['micro_f1']))
    print("Fine-grained micro f1: {:.2f}".format(test_constrained_fine_metrics_micro['micro_f1']))
    
    print("\nCoarse-grained macro f1: {:.2f}".format(test_coarse_metrics_macro['macro_f1']))
    print("Fine-grained macro f1: {:.2f}".format(test_constrained_fine_metrics_macro['macro_f1']))
    
    print("\nAccuracy: {:.2f}".format(test_constrained_fine_metrics_micro['accuracy']))
    
    # 以表格形式输出所有指标（与给定的评估表格格式一致）
    print("\n")
    print("+" + "-"*15 + "+" + "-"*15 + "+" + "-"*15 + "+" + "-"*15 + "+" + "-"*15 + "+" + "-"*15 + "+")
    print("| {:<13} | {:<13} | {:<13} | {:<13} | {:<13} | {:<13} |".format(
        "Final", "Final", "Course-", "Fine-grained", "Course-", "Fine-grained"))
    print("| {:<13} | {:<13} | {:<13} | {:<13} | {:<13} | {:<13} |".format(
        "micro f1", "macro f1", "grained micro f1", "micro f1", "grained macro f1", "macro f1"))
    print("+" + "-"*15 + "+" + "-"*15 + "+" + "-"*15 + "+" + "-"*15 + "+" + "-"*15 + "+" + "-"*15 + "+")
    print("| {:<13.2f} | {:<13.2f} | {:<13.2f} | {:<13.2f} | {:<13.2f} | {:<13.2f} |".format(
        test_constrained_fine_metrics_micro['micro_f1'],
        test_constrained_fine_metrics_macro['macro_f1'],
        test_coarse_metrics_micro['micro_f1'],
        test_constrained_fine_metrics_micro['micro_f1'],
        test_coarse_metrics_macro['macro_f1'],
        test_constrained_fine_metrics_macro['macro_f1']
    ))
    print("+" + "-"*15 + "+" + "-"*15 + "+" + "-"*15 + "+" + "-"*15 + "+" + "-"*15 + "+" + "-"*15 + "+")
    
    # 记录最终结果到wandb
    wandb.log({
        'test_coarse_micro_f1': test_coarse_metrics_micro["micro_f1"],
        'test_coarse_macro_f1': test_coarse_metrics_macro["macro_f1"],
        'test_coarse_accuracy': test_coarse_metrics_micro["accuracy"],
        'test_fine_micro_f1': test_fine_metrics_micro["micro_f1"],
        'test_fine_macro_f1': test_fine_metrics_macro["macro_f1"],
        'test_fine_accuracy': test_fine_metrics_micro["accuracy"],
        'test_constrained_fine_micro_f1': test_constrained_fine_metrics_micro["micro_f1"],
        'test_constrained_fine_macro_f1': test_constrained_fine_metrics_macro["macro_f1"],
        'test_constrained_fine_accuracy': test_constrained_fine_metrics_micro["accuracy"],
        'final_micro_f1': test_constrained_fine_metrics_micro["micro_f1"],
        'final_macro_f1': test_constrained_fine_metrics_macro["macro_f1"]
    })
    
    # 完成wandb记录
    wandb.finish()

def predict(model, text, tokenizer, device, threshold=0.5):
    """
    对单个文本进行预测
    """
    model.eval()
    
    # 编码文本
    encoded = tokenizer(
        text,
        truncation=True,
        padding='max_length',
        max_length=128,
        return_tensors='pt'
    )
    
    input_ids = encoded['input_ids'].to(device)
    attention_mask = encoded['attention_mask'].to(device)
    token_type_ids = encoded.get('token_type_ids', None)
    
    if token_type_ids is not None:
        token_type_ids = token_type_ids.to(device)
    
    # 获取预测结果
    with torch.no_grad():
        coarse_probs, fine_probs = model(input_ids, attention_mask, token_type_ids)
        
        # 应用阈值
        coarse_preds = (coarse_probs > threshold).float()
        fine_preds = (fine_probs > threshold).float()
        
        # 应用层次约束
        constrained_fine_preds = model.apply_hierarchical_constraint(coarse_preds, fine_preds)
    
    # 映射结果到标签
    coarse_indices = torch.nonzero(coarse_preds[0]).cpu().numpy().flatten()
    fine_indices = torch.nonzero(constrained_fine_preds[0]).cpu().numpy().flatten()
    
    # 将索引转换为标签名称（需要模型中有这些映射）
    coarse_label_map = {v: k for k, v in model.coarse_labels.items()}
    fine_label_map = {v: k for k, v in model.fine_labels.items()}
    
    predicted_coarse = [coarse_label_map[idx] for idx in coarse_indices]
    predicted_fine = [fine_label_map[idx] for idx in fine_indices]
    
    return predicted_coarse, predicted_fine




# 在以下主函数中添加判断Jupyter环境的逻辑
if __name__ == '__main__':
    # 判断是否在Jupyter环境中运行
    try:
        # 检查是否在Jupyter中运行
        get_ipython = globals().get('get_ipython', None)
        if get_ipython and 'IPKernelApp' in get_ipython().config:
            # 在Jupyter环境中运行，使用默认配置
            print("Running in Jupyter environment, using default configs")
            configs = get_default_configs()
        else:
            # 在命令行环境中运行，使用argparse
            configs = argparser()
    except:
        # 任何异常都使用argparse处理
        configs = argparser()
    
    # 设置实验名称
    if configs.name is None:
        configs.exp_name = \
            f'{os.path.basename(configs.model_name)}' + \
            f'{"_fp" if configs.freeze_pooler else ""}' + \
            f'_b{configs.batch_size}_e{configs.epochs}' + \
            f'_len{configs.max_length}_lr{configs.lr}'
    else:
        configs.exp_name = configs.name
    
    # 设置设备
    if configs.device is None:
        configs.device = torch.device(
            'cuda' if torch.cuda.is_available() else 'cpu'
        )
    
    # 调用训练函数
    train(configs)

Running in Jupyter environment, using default configs


Asking to truncate to max_length but no maximum length is provided and the model has no predefined maximum length. Default to no truncation.
Epoch 1/40: 100%|█| 13/13 [00:00<00:00, 13.67batch/s, coarse_loss=0.641, fine_loss=1.064, loss=1.705
  _warn_prf(average, "true nor predicted", "F-score is", len(true_sum))
  _warn_prf(average, "true nor predicted", "F-score is", len(true_sum))



Epoch 1/40:
  Train Loss: 2.1052
  Train Coarse-grained Metrics:
    Micro F1: 39.44
    Macro F1: 30.83
    Accuracy: 11.54
  Train Fine-grained Metrics (Unconstrained):
    Micro F1: 16.76
    Macro F1: 12.88
    Accuracy: 0.00
  Train Fine-grained Metrics (Constrained):
    Micro F1: 11.81
    Macro F1: 7.32
    Accuracy: 0.96
++++++++++++++++++++++++++++++++++++++++++++++++++
  Val Loss: 1.3504
  Val Coarse-grained Metrics:
    Micro F1: 56.72
    Macro F1: 20.65
    Accuracy: 29.63
  Val Fine-grained Metrics (Unconstrained):
    Micro F1: 9.88
    Macro F1: 5.40
    Accuracy: 0.00
  Val Fine-grained Metrics (Constrained):
    Micro F1: 9.68
    Macro F1: 3.02
    Accuracy: 0.00


Epoch 2/40: 100%|█| 13/13 [00:00<00:00, 13.79batch/s, coarse_loss=0.650, fine_loss=0.939, loss=1.589
  _warn_prf(average, "true nor predicted", "F-score is", len(true_sum))
  _warn_prf(average, "true nor predicted", "F-score is", len(true_sum))
  _warn_prf(average, "true nor predicted", "F-score is", len(true_sum))
  _warn_prf(average, "true nor predicted", "F-score is", len(true_sum))



Epoch 2/40:
  Train Loss: 1.5290
  Train Coarse-grained Metrics:
    Micro F1: 45.21
    Macro F1: 20.77
    Accuracy: 15.38
  Train Fine-grained Metrics (Unconstrained):
    Micro F1: 23.14
    Macro F1: 12.53
    Accuracy: 0.00
  Train Fine-grained Metrics (Constrained):
    Micro F1: 22.47
    Macro F1: 7.44
    Accuracy: 0.96
++++++++++++++++++++++++++++++++++++++++++++++++++
  Val Loss: 1.4677
  Val Coarse-grained Metrics:
    Micro F1: 57.58
    Macro F1: 20.65
    Accuracy: 29.63
  Val Fine-grained Metrics (Unconstrained):
    Micro F1: 0.00
    Macro F1: 0.00
    Accuracy: 0.00
  Val Fine-grained Metrics (Constrained):
    Micro F1: 0.00
    Macro F1: 0.00
    Accuracy: 0.00


Epoch 3/40: 100%|█| 13/13 [00:00<00:00, 13.76batch/s, coarse_loss=0.639, fine_loss=0.626, loss=1.265
  _warn_prf(average, "true nor predicted", "F-score is", len(true_sum))
  _warn_prf(average, "true nor predicted", "F-score is", len(true_sum))
  _warn_prf(average, "true nor predicted", "F-score is", len(true_sum))
  _warn_prf(average, "true nor predicted", "F-score is", len(true_sum))



Epoch 3/40:
  Train Loss: 1.4089
  Train Coarse-grained Metrics:
    Micro F1: 48.67
    Macro F1: 21.76
    Accuracy: 20.19
  Train Fine-grained Metrics (Unconstrained):
    Micro F1: 20.57
    Macro F1: 14.59
    Accuracy: 0.00
  Train Fine-grained Metrics (Constrained):
    Micro F1: 18.44
    Macro F1: 6.02
    Accuracy: 0.00
++++++++++++++++++++++++++++++++++++++++++++++++++
  Val Loss: 1.5886
  Val Coarse-grained Metrics:
    Micro F1: 57.58
    Macro F1: 20.65
    Accuracy: 29.63
  Val Fine-grained Metrics (Unconstrained):
    Micro F1: 0.00
    Macro F1: 0.00
    Accuracy: 0.00
  Val Fine-grained Metrics (Constrained):
    Micro F1: 0.00
    Macro F1: 0.00
    Accuracy: 0.00


Epoch 4/40: 100%|█| 13/13 [00:00<00:00, 13.78batch/s, coarse_loss=0.656, fine_loss=0.857, loss=1.513
  _warn_prf(average, "true nor predicted", "F-score is", len(true_sum))
  _warn_prf(average, "true nor predicted", "F-score is", len(true_sum))
  _warn_prf(average, "true nor predicted", "F-score is", len(true_sum))
  _warn_prf(average, "true nor predicted", "F-score is", len(true_sum))



Epoch 4/40:
  Train Loss: 1.4906
  Train Coarse-grained Metrics:
    Micro F1: 48.06
    Macro F1: 19.02
    Accuracy: 20.19
  Train Fine-grained Metrics (Unconstrained):
    Micro F1: 16.51
    Macro F1: 11.86
    Accuracy: 0.00
  Train Fine-grained Metrics (Constrained):
    Micro F1: 15.72
    Macro F1: 6.15
    Accuracy: 1.92
++++++++++++++++++++++++++++++++++++++++++++++++++
  Val Loss: 1.2465
  Val Coarse-grained Metrics:
    Micro F1: 57.58
    Macro F1: 20.65
    Accuracy: 29.63
  Val Fine-grained Metrics (Unconstrained):
    Micro F1: 20.34
    Macro F1: 4.76
    Accuracy: 3.70
  Val Fine-grained Metrics (Constrained):
    Micro F1: 20.34
    Macro F1: 4.76
    Accuracy: 3.70


Epoch 5/40: 100%|█| 13/13 [00:00<00:00, 13.85batch/s, coarse_loss=0.666, fine_loss=0.509, loss=1.175
  _warn_prf(average, "true nor predicted", "F-score is", len(true_sum))
  _warn_prf(average, "true nor predicted", "F-score is", len(true_sum))
  _warn_prf(average, "true nor predicted", "F-score is", len(true_sum))
  _warn_prf(average, "true nor predicted", "F-score is", len(true_sum))



Epoch 5/40:
  Train Loss: 1.3501
  Train Coarse-grained Metrics:
    Micro F1: 48.28
    Macro F1: 18.98
    Accuracy: 21.15
  Train Fine-grained Metrics (Unconstrained):
    Micro F1: 23.76
    Macro F1: 15.61
    Accuracy: 1.92
  Train Fine-grained Metrics (Constrained):
    Micro F1: 22.37
    Macro F1: 8.12
    Accuracy: 1.92
++++++++++++++++++++++++++++++++++++++++++++++++++
  Val Loss: 1.0523
  Val Coarse-grained Metrics:
    Micro F1: 57.58
    Macro F1: 20.65
    Accuracy: 29.63
  Val Fine-grained Metrics (Unconstrained):
    Micro F1: 7.69
    Macro F1: 2.86
    Accuracy: 0.00
  Val Fine-grained Metrics (Constrained):
    Micro F1: 7.69
    Macro F1: 2.86
    Accuracy: 0.00


Epoch 6/40: 100%|█| 13/13 [00:00<00:00, 13.65batch/s, coarse_loss=0.618, fine_loss=0.586, loss=1.204
  _warn_prf(average, "true nor predicted", "F-score is", len(true_sum))
  _warn_prf(average, "true nor predicted", "F-score is", len(true_sum))
  _warn_prf(average, "true nor predicted", "F-score is", len(true_sum))
  _warn_prf(average, "true nor predicted", "F-score is", len(true_sum))



Epoch 6/40:
  Train Loss: 1.1867
  Train Coarse-grained Metrics:
    Micro F1: 48.67
    Macro F1: 19.05
    Accuracy: 21.15
  Train Fine-grained Metrics (Unconstrained):
    Micro F1: 29.48
    Macro F1: 23.64
    Accuracy: 1.92
  Train Fine-grained Metrics (Constrained):
    Micro F1: 25.30
    Macro F1: 10.41
    Accuracy: 0.00
++++++++++++++++++++++++++++++++++++++++++++++++++
  Val Loss: 1.2001
  Val Coarse-grained Metrics:
    Micro F1: 57.58
    Macro F1: 20.65
    Accuracy: 29.63
  Val Fine-grained Metrics (Unconstrained):
    Micro F1: 7.69
    Macro F1: 2.60
    Accuracy: 0.00
  Val Fine-grained Metrics (Constrained):
    Micro F1: 7.69
    Macro F1: 2.60
    Accuracy: 0.00


Epoch 7/40: 100%|█| 13/13 [00:00<00:00, 13.72batch/s, coarse_loss=0.629, fine_loss=0.617, loss=1.247
  _warn_prf(average, "true nor predicted", "F-score is", len(true_sum))
  _warn_prf(average, "true nor predicted", "F-score is", len(true_sum))
  _warn_prf(average, "true nor predicted", "F-score is", len(true_sum))
  _warn_prf(average, "true nor predicted", "F-score is", len(true_sum))



Epoch 7/40:
  Train Loss: 1.2442
  Train Coarse-grained Metrics:
    Micro F1: 48.67
    Macro F1: 19.05
    Accuracy: 21.15
  Train Fine-grained Metrics (Unconstrained):
    Micro F1: 22.22
    Macro F1: 15.88
    Accuracy: 1.92
  Train Fine-grained Metrics (Constrained):
    Micro F1: 19.64
    Macro F1: 9.00
    Accuracy: 1.92
++++++++++++++++++++++++++++++++++++++++++++++++++
  Val Loss: 1.4964
  Val Coarse-grained Metrics:
    Micro F1: 57.58
    Macro F1: 20.65
    Accuracy: 29.63
  Val Fine-grained Metrics (Unconstrained):
    Micro F1: 0.00
    Macro F1: 0.00
    Accuracy: 0.00
  Val Fine-grained Metrics (Constrained):
    Micro F1: 0.00
    Macro F1: 0.00
    Accuracy: 0.00


Epoch 8/40: 100%|█| 13/13 [00:00<00:00, 13.71batch/s, coarse_loss=0.654, fine_loss=0.490, loss=1.144
  _warn_prf(average, "true nor predicted", "F-score is", len(true_sum))
  _warn_prf(average, "true nor predicted", "F-score is", len(true_sum))
  _warn_prf(average, "true nor predicted", "F-score is", len(true_sum))
  _warn_prf(average, "true nor predicted", "F-score is", len(true_sum))



Epoch 8/40:
  Train Loss: 1.2849
  Train Coarse-grained Metrics:
    Micro F1: 48.85
    Macro F1: 19.16
    Accuracy: 21.15
  Train Fine-grained Metrics (Unconstrained):
    Micro F1: 23.56
    Macro F1: 15.46
    Accuracy: 3.85
  Train Fine-grained Metrics (Constrained):
    Micro F1: 22.86
    Macro F1: 8.17
    Accuracy: 2.88
++++++++++++++++++++++++++++++++++++++++++++++++++
  Val Loss: 1.2551
  Val Coarse-grained Metrics:
    Micro F1: 57.58
    Macro F1: 20.65
    Accuracy: 29.63
  Val Fine-grained Metrics (Unconstrained):
    Micro F1: 11.32
    Macro F1: 3.57
    Accuracy: 3.70
  Val Fine-grained Metrics (Constrained):
    Micro F1: 11.32
    Macro F1: 3.57
    Accuracy: 3.70


Epoch 9/40: 100%|█| 13/13 [00:00<00:00, 13.73batch/s, coarse_loss=0.637, fine_loss=0.334, loss=0.971
  _warn_prf(average, "true nor predicted", "F-score is", len(true_sum))
  _warn_prf(average, "true nor predicted", "F-score is", len(true_sum))
  _warn_prf(average, "true nor predicted", "F-score is", len(true_sum))
  _warn_prf(average, "true nor predicted", "F-score is", len(true_sum))
  model.load_state_dict(torch.load(os.path.join(checkpoint_dir, 'best_model.pt')))



Epoch 9/40:
  Train Loss: 1.1070
  Train Coarse-grained Metrics:
    Micro F1: 48.85
    Macro F1: 19.16
    Accuracy: 21.15
  Train Fine-grained Metrics (Unconstrained):
    Micro F1: 38.20
    Macro F1: 26.89
    Accuracy: 7.69
  Train Fine-grained Metrics (Constrained):
    Micro F1: 32.75
    Macro F1: 12.41
    Accuracy: 4.81
++++++++++++++++++++++++++++++++++++++++++++++++++
  Val Loss: 1.4199
  Val Coarse-grained Metrics:
    Micro F1: 57.58
    Macro F1: 20.65
    Accuracy: 29.63
  Val Fine-grained Metrics (Unconstrained):
    Micro F1: 0.00
    Macro F1: 0.00
    Accuracy: 0.00
  Val Fine-grained Metrics (Constrained):
    Micro F1: 0.00
    Macro F1: 0.00
    Accuracy: 0.00
Early stopping triggered.

===== Testing best model =====

===== Final Test Results =====
Final micro f1: 20.34
Final macro f1: 4.76

Coarse-grained micro f1: 57.58
Fine-grained micro f1: 20.34

Coarse-grained macro f1: 20.65
Fine-grained macro f1: 4.76

Accuracy: 3.70


+---------------+---------------+-

  _warn_prf(average, "true nor predicted", "F-score is", len(true_sum))
  _warn_prf(average, "true nor predicted", "F-score is", len(true_sum))
  _warn_prf(average, "true nor predicted", "F-score is", len(true_sum))
  _warn_prf(average, "true nor predicted", "F-score is", len(true_sum))
