In [1]:
import os
# 设置环境变量，只让程序看到 GPU 2
os.environ['CUDA_VISIBLE_DEVICES'] = '1'


import torch
import torch.nn as nn
import wandb
import random
import argparse
import numpy as np
from tqdm import tqdm
from transformers import BertModel, AutoModel
from transformers import AdamW

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
from torch.utils.data import Dataset
import json

class BAE2025Dataset(Dataset):
    def __init__(
            self,
            data_path,
            labels={
                "Yes": 0,
                "To some extent": 1, 
                "No": 2,
            }
    ):
        self.data_path = data_path
        self.labels = labels
        self._get_data()
    
    def _get_data(self):
        with open(self.data_path, 'r', encoding='utf-8') as f:
            data = json.load(f)
        
        self.data = []
        for item in data:
            tutor_responses = item['tutor_responses']
            for response in tutor_responses.values():
                sent1 = item['conversation_history']
                sent2 = response['response']
                label = response['annotation']["Providing_Guidance"]
                if label in self.labels:
                    self.data.append(((sent1, sent2), self.labels[label]))
    
    def __len__(self):
        return len(self.data)
    
    def get_labels(self):
        return self.labels

    def __getitem__(self, idx):
        return self.data[idx]

In [3]:
import torch
from torch.utils.data import DataLoader
from transformers import AutoTokenizer

class BAE2025DataLoader:
    def __init__(
        self,
        dataset,
        batch_size=16,
        max_length=512,
        shuffle=True,
        drop_last=True,
        device=None,
        # tokenizer_name='chinese-bert-wwm-ext'
        # tokenizer_name='chinese-roberta-wwm-ext'
        # tokenizer_name='chinese-roberta-wwm-ext-large'
        # tokenizer_name='/mnt/cfs/huangzhiwei/pykt-moekt/SBM/bge-large-en-v1.5'
        # tokenizer_name='/mnt/cfs/huangzhiwei/BAE2025/models/bge-base-en-v1.5'
        tokenizer_name='/mnt/cfs/huangzhiwei/BAE2025/models/bert-base-uncased'
        # tokenizer_name='/mnt/cfs/huangzhiwei/BAE2025/models/deberta-v3-base'
    ):
        self.tokenizer = AutoTokenizer.from_pretrained(tokenizer_name)
        self.dataset = dataset
        self.batch_size = batch_size
        self.max_length = max_length
        self.shuffle = shuffle
        self.drop_last = drop_last

        if device is None:
            self.device = torch.device(
                'cuda' if torch.cuda.is_available() else 'cpu'
            )
        else:
            self.device = device

        self.loader = DataLoader(
            dataset=self.dataset,
            batch_size=self.batch_size,
            collate_fn=self.collate_fn,
            shuffle=self.shuffle,
            drop_last=self.drop_last
        )

    def collate_fn(self, data):
        sents = [i[0] for i in data]
        labels = [i[1] for i in data]

        # 修改这里，处理两个句子的情况
        data = self.tokenizer.batch_encode_plus(
            batch_text_or_text_pairs=[(sent[0], sent[1]) for sent in sents],
            truncation=True,
            padding='max_length',
            max_length=self.max_length,
            return_tensors='pt',
            return_length=True
        )
        input_ids = data['input_ids'].to(self.device)
        attention_mask = data['attention_mask'].to(self.device)
        token_type_ids = data['token_type_ids'].to(self.device)
        labels = torch.LongTensor(labels).to(self.device)

        return input_ids, attention_mask, token_type_ids, labels
        # return input_ids, attention_mask, labels


    def __iter__(self):
        for data in self.loader:
            yield data

    def __len__(self):
        return len(self.loader)



In [4]:
import torch
import torch.nn as nn
from transformers import BertModel

# 修改模型类以支持分层分类
class HierarchicalBertClassifier(nn.Module):
    def __init__(self, pretrained_model_name, freeze_pooler=0, dropout=0.3):
        super().__init__()
        
        # 第一阶段分类器：Yes vs 非Yes
        self.bert_stage1 = BertModel.from_pretrained(pretrained_model_name, output_hidden_states=True)
        
        # 第二阶段分类器：To some extent vs No
        self.bert_stage2 = BertModel.from_pretrained(pretrained_model_name, output_hidden_states=True)
        
        # 冻结BERT底层，保留顶层微调
        if freeze_pooler > 0:
            # 冻结第一阶段模型的底层
            modules1 = [self.bert_stage1.embeddings, *self.bert_stage1.encoder.layer[:freeze_pooler]]
            for module in modules1:
                for param in module.parameters():
                    param.requires_grad = False
            
            # 冻结第二阶段模型的底层
            modules2 = [self.bert_stage2.embeddings, *self.bert_stage2.encoder.layer[:freeze_pooler]]
            for module in modules2:
                for param in module.parameters():
                    param.requires_grad = False
        
        # 获取bert隐藏层大小
        bert_hidden_size = self.bert_stage1.config.hidden_size
        
        # 第一阶段的分类头（二分类：Yes vs 非Yes）
        self.stage1_classifier = nn.Sequential(
            nn.Dropout(dropout),
            nn.Linear(bert_hidden_size, bert_hidden_size),
            nn.Tanh(),
            nn.Dropout(dropout),
            nn.Linear(bert_hidden_size, 2)  # 二分类
        )
        
        # 第二阶段的分类头（二分类：To some extent vs No）
        self.stage2_classifier = nn.Sequential(
            nn.Dropout(dropout),
            nn.Linear(bert_hidden_size, bert_hidden_size),
            nn.Tanh(),
            nn.Dropout(dropout),
            nn.Linear(bert_hidden_size, 2)  # 二分类
        )
    
    def forward_stage1(self, input_ids, attention_mask, token_type_ids=None):
        """第一阶段：预测是Yes还是非Yes"""
        outputs = self.bert_stage1(
            input_ids,
            attention_mask=attention_mask,
            token_type_ids=token_type_ids
        )
        
        # 使用[CLS]表示的序列表示
        cls_output = outputs.last_hidden_state[:, 0, :]
        logits = self.stage1_classifier(cls_output)
        return logits
    
    def forward_stage2(self, input_ids, attention_mask, token_type_ids=None):
        """第二阶段：预测是To some extent还是No"""
        outputs = self.bert_stage2(
            input_ids,
            attention_mask=attention_mask,
            token_type_ids=token_type_ids
        )
        
        # 使用[CLS]表示的序列表示
        cls_output = outputs.last_hidden_state[:, 0, :]
        logits = self.stage2_classifier(cls_output)
        return logits
    
    def forward(self, input_ids, attention_mask, token_type_ids=None, stage=None):
        """根据stage参数进行相应阶段的前向传播"""
        if stage == 1:
            return self.forward_stage1(input_ids, attention_mask, token_type_ids)
        elif stage == 2:
            return self.forward_stage2(input_ids, attention_mask, token_type_ids)
        else:
            # 默认行为：完整的两阶段预测
            # 第一阶段：预测是Yes还是非Yes
            stage1_logits = self.forward_stage1(input_ids, attention_mask, token_type_ids)
            stage1_preds = torch.argmax(stage1_logits, dim=1)
            
            # 第二阶段：对预测为"非Yes"的样本进行To some extent vs No预测
            # 创建一个全为0的三分类输出张量（Yes=0, To some extent=1, No=2）
            batch_size = input_ids.size(0)
            final_logits = torch.zeros(batch_size, 3, device=input_ids.device)
            
            # 设置Yes的logits值（从stage1获取）
            final_logits[:, 0] = stage1_logits[:, 0]  # Yes的logits
            
            # 获取预测为非Yes(1)的样本索引
            non_yes_indices = (stage1_preds == 1).nonzero(as_tuple=True)[0]
            
            if len(non_yes_indices) > 0:
                # 只对预测为"非Yes"的样本进行第二阶段预测
                non_yes_input_ids = input_ids[non_yes_indices]
                non_yes_attention_mask = attention_mask[non_yes_indices]
                non_yes_token_type_ids = None if token_type_ids is None else token_type_ids[non_yes_indices]
                
                stage2_logits = self.forward_stage2(non_yes_input_ids, non_yes_attention_mask, non_yes_token_type_ids)
                
                # 将第二阶段的预测结果（To some extent vs No）放入最终结果中
                final_logits[non_yes_indices, 1] = stage2_logits[:, 0]  # To some extent的logits
                final_logits[non_yes_indices, 2] = stage2_logits[:, 1]  # No的logits
            
            return final_logits

In [5]:
import os
import wandb
import random
import argparse
from tqdm import tqdm

import torch
import torch.nn as nn
import numpy as np
from transformers import AdamW
from sklearn.metrics import f1_score, precision_score, recall_score, accuracy_score

# 如果在Jupyter Notebook中运行，可以使用这个自定义参数函数替代argparser
def get_default_configs():
    """在Jupyter环境中使用的默认配置，避免argparse解析错误"""
    class Args:
        def __init__(self):
            # self.model_name = '/mnt/cfs/huangzhiwei/pykt-moekt/SBM/bge-large-en-v1.5'
            # self.model_name = "/mnt/cfs/huangzhiwei/BAE2025/models/ModernBERT-large"
            # self.model_name = '/mnt/cfs/huangzhiwei/pykt-moekt/SBM/xlm-roberta-large'
            # self.model_name = '/mnt/cfs/huangzhiwei/BAE2025/models/bge-base-en-v1.5'
            self.model_name = '/mnt/cfs/huangzhiwei/BAE2025/models/bert-base-uncased'
            # self.model_name = '/mnt/cfs/huangzhiwei/BAE2025/models/deberta-v3-base'
            self.num_classes = 3
            self.dropout = 0.3
            self.freeze_pooler = 8
            self.batch_size = 16
            self.max_length = 512
            self.lr = 1e-5
            self.epochs = 50
            self.device = device
            self.name = None
            self.seed = 42
            self.data_path = './data/train_extend.json'
            self.val_data_path = './data/valid_extend.json'
            self.checkpoint_dir = 'checkpoints_2to2_extend'
            self.patience = 8
            self.exp_name = 'BAE2025_track4_bert'
    return Args()


In [6]:
def train_hierarchical(configs):
    # 设置随机种子
    random.seed(configs.seed)
    np.random.seed(configs.seed)
    torch.manual_seed(configs.seed)
    
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False
    
    # 创建检查点目录
    checkpoint_dir = os.path.join(configs.checkpoint_dir, configs.exp_name)
    os.makedirs(checkpoint_dir, exist_ok=True)
    
    # 为保存混淆矩阵创建目录 - 分别为训练集和验证集创建
    train_plot_dir = os.path.join(checkpoint_dir, 'plots', 'train')
    val_plot_dir = os.path.join(checkpoint_dir, 'plots', 'val')
    os.makedirs(train_plot_dir, exist_ok=True)
    os.makedirs(val_plot_dir, exist_ok=True)
    
    # 加载数据集
    train_dataset = BAE2025Dataset(configs.data_path)
    val_dataset = BAE2025Dataset(configs.val_data_path)    

    # 创建数据加载器
    train_dataloader = BAE2025DataLoader(
        dataset=train_dataset,
        batch_size=configs.batch_size,
        max_length=configs.max_length,
        shuffle=True,
        drop_last=True,
        device=configs.device,
        tokenizer_name=configs.model_name
    )

    val_dataloader = BAE2025DataLoader(
        dataset=val_dataset,
        batch_size=configs.batch_size,
        max_length=configs.max_length,
        shuffle=False,
        drop_last=False,
        device=configs.device,
        tokenizer_name=configs.model_name
    )
    
    # 创建分层分类模型
    model = HierarchicalBertClassifier(
        pretrained_model_name=configs.model_name,
        freeze_pooler=configs.freeze_pooler,
        dropout=configs.dropout
    ).to(configs.device)

    # 定义两个阶段的损失函数
    criterion = nn.CrossEntropyLoss()

    # 定义优化器
    optimizer = AdamW(
        filter(lambda p: p.requires_grad, model.parameters()),
        lr=configs.lr
    )

    # 初始化最佳验证指标
    best_val_acc = 0.0
    best_val_f1 = 0.0
    patience_counter = 0
    
    # 定义原始类别名称和阶段类别名称
    class_names = ['Yes', 'To some extent', 'No']
    stage1_names = ['Yes', 'Non-Yes']
    stage2_names = ['To some extent', 'No']
    
    # 添加计算所需的库
    from sklearn.metrics import f1_score, confusion_matrix, accuracy_score
    import matplotlib.pyplot as plt
    import seaborn as sns
    
    # 训练循环
    for epoch in range(configs.epochs):
        print(f"\n===== Epoch {epoch + 1}/{configs.epochs} =====")
        
        # ======== 训练第一阶段模型：Yes vs 非Yes ========
        model.train()
        stage1_train_loss = 0.0
        stage1_train_preds = []
        stage1_train_labels = []
        
        print("Training Stage 1 (Yes vs Non-Yes)...")
        with tqdm(train_dataloader, total=len(train_dataloader), desc="Stage 1", unit="batch", ncols=100) as pbar:
            for input_ids, attention_mask, token_type_ids, labels in pbar:
                optimizer.zero_grad()
                
                # 将原始标签转换为二分类标签：0(Yes) 或 1(非Yes)
                stage1_labels = (labels > 0).long()  # Yes=0, 其他=1
                
                # 前向传播第一阶段
                stage1_logits = model(input_ids, attention_mask, token_type_ids, stage=1)
                
                # 计算损失
                loss = criterion(stage1_logits, stage1_labels)
                
                # 反向传播
                loss.backward()
                optimizer.step()
                
                # 收集预测和标签
                preds = torch.argmax(stage1_logits, dim=1)
                stage1_train_preds.extend(preds.cpu().numpy())
                stage1_train_labels.extend(stage1_labels.cpu().numpy())
                
                stage1_train_loss += loss.item()
                
                # 更新进度条
                pbar.set_postfix(loss=f'{loss.item():.3f}')
        
        # 计算第一阶段训练指标
        stage1_train_loss /= len(train_dataloader)
        stage1_train_acc = accuracy_score(stage1_train_labels, stage1_train_preds)
        stage1_train_f1 = f1_score(stage1_train_labels, stage1_train_preds, average='macro')
        
        print(f"Stage 1 Training - Loss: {stage1_train_loss:.4f}, Acc: {stage1_train_acc:.4f}, F1: {stage1_train_f1:.4f}")
        
        # 创建并保存第一阶段训练混淆矩阵
        cm = confusion_matrix(stage1_train_labels, stage1_train_preds)
        plt.figure(figsize=(8, 6))
        sns.heatmap(cm, annot=True, fmt='d', cmap='Blues', xticklabels=stage1_names, yticklabels=stage1_names)
        plt.xlabel('Predicted')
        plt.ylabel('True')
        plt.title(f'Train: Stage 1 (Yes vs Non-Yes)\nAcc: {stage1_train_acc:.4f}, F1: {stage1_train_f1:.4f}')
        matrix_path = os.path.join(train_plot_dir, f'stage1_cm_epoch_{epoch+1}.png')
        plt.savefig(matrix_path)
        plt.close()
        
        # ======== 训练第二阶段模型：To some extent vs No ========
        # 筛选出标签为To some extent或No的样本索引
        stage2_train_loss = 0.0
        stage2_train_preds = []
        stage2_train_labels = []
        stage2_sample_count = 0
        
        print("Training Stage 2 (To some extent vs No)...")
        with tqdm(train_dataloader, total=len(train_dataloader), desc="Stage 2", unit="batch", ncols=100) as pbar:
            for input_ids, attention_mask, token_type_ids, labels in pbar:
                # 筛选非Yes样本的索引
                non_yes_indices = (labels > 0).nonzero(as_tuple=True)[0]
                
                if len(non_yes_indices) == 0:
                    continue  # 如果批次中没有非Yes样本，跳过
                
                # 提取非Yes样本的数据
                non_yes_input_ids = input_ids[non_yes_indices]
                non_yes_attention_mask = attention_mask[non_yes_indices]
                non_yes_token_type_ids = token_type_ids[non_yes_indices]
                non_yes_labels = labels[non_yes_indices]
                
                # 将原始标签转换为二分类标签：0(To some extent) 或 1(No)
                # 原始：0=Yes, 1=To some extent, 2=No
                # 现在：0=To some extent, 1=No
                stage2_labels = (non_yes_labels == 2).long()  # To some extent=0, No=1
                
                optimizer.zero_grad()
                
                # 前向传播第二阶段
                stage2_logits = model(non_yes_input_ids, non_yes_attention_mask, non_yes_token_type_ids, stage=2)
                
                # 计算损失
                loss = criterion(stage2_logits, stage2_labels)
                
                # 反向传播
                loss.backward()
                optimizer.step()
                
                # 收集预测和标签
                preds = torch.argmax(stage2_logits, dim=1)
                stage2_train_preds.extend(preds.cpu().numpy())
                stage2_train_labels.extend(stage2_labels.cpu().numpy())
                
                stage2_train_loss += loss.item()
                stage2_sample_count += 1
                
                # 更新进度条
                pbar.set_postfix(loss=f'{loss.item():.3f}')
        
        # 计算第二阶段训练指标
        if stage2_sample_count > 0:
            stage2_train_loss /= stage2_sample_count
            stage2_train_acc = accuracy_score(stage2_train_labels, stage2_train_preds)
            stage2_train_f1 = f1_score(stage2_train_labels, stage2_train_preds, average='macro')
            
            print(f"Stage 2 Training - Loss: {stage2_train_loss:.4f}, Acc: {stage2_train_acc:.4f}, F1: {stage2_train_f1:.4f}")
            
            # 创建并保存第二阶段训练混淆矩阵
            cm = confusion_matrix(stage2_train_labels, stage2_train_preds)
            plt.figure(figsize=(8, 6))
            sns.heatmap(cm, annot=True, fmt='d', cmap='Blues', xticklabels=stage2_names, yticklabels=stage2_names)
            plt.xlabel('Predicted')
            plt.ylabel('True')
            plt.title(f'Train: Stage 2 (To some extent vs No)\nAcc: {stage2_train_acc:.4f}, F1: {stage2_train_f1:.4f}')
            matrix_path = os.path.join(train_plot_dir, f'stage2_cm_epoch_{epoch+1}.png')
            plt.savefig(matrix_path)
            plt.close()
        
        # ======== 训练集整体评估 ========
        model.eval()
        train_preds = []
        train_labels = []
        
        with torch.no_grad():
            for input_ids, attention_mask, token_type_ids, labels in train_dataloader:
                # 完整两阶段预测
                logits = model(input_ids, attention_mask, token_type_ids)
                preds = torch.argmax(logits, dim=1)
                
                train_preds.extend(preds.cpu().numpy())
                train_labels.extend(labels.cpu().numpy())
        
        # 计算整体训练集指标
        train_acc = accuracy_score(train_labels, train_preds)
        train_f1 = f1_score(train_labels, train_preds, average='macro')
        
        print(f"Overall Training - Acc: {train_acc:.4f}, F1: {train_f1:.4f}")
        
        # 创建完整的训练集混淆矩阵
        cm_full = confusion_matrix(train_labels, train_preds, labels=[0, 1, 2])
        plt.figure(figsize=(10, 8))
        sns.heatmap(cm_full, annot=True, fmt='d', cmap='Blues', xticklabels=class_names, yticklabels=class_names)
        plt.xlabel('Predicted')
        plt.ylabel('True')
        plt.title(f'Train: Full Hierarchical Confusion Matrix\nAcc: {train_acc:.4f}, F1: {train_f1:.4f}')
        matrix_path = os.path.join(train_plot_dir, f'full_cm_epoch_{epoch+1}.png')
        plt.savefig(matrix_path)
        plt.close()
        
        # ======== 验证集评估 ========
        model.eval()
        val_preds = []
        val_labels = []
        stage1_val_preds = []
        stage1_val_labels = []
        stage2_val_preds = []
        stage2_val_labels = []
        
        print("Evaluating on validation set...")
        with torch.no_grad():
            for input_ids, attention_mask, token_type_ids, labels in val_dataloader:
                # 第一阶段评估
                stage1_logits = model(input_ids, attention_mask, token_type_ids, stage=1)
                stage1_preds = torch.argmax(stage1_logits, dim=1)
                stage1_labels_binary = (labels > 0).long()
                
                stage1_val_preds.extend(stage1_preds.cpu().numpy())
                stage1_val_labels.extend(stage1_labels_binary.cpu().numpy())
                
                # 找出非Yes样本
                non_yes_indices = (labels > 0).nonzero(as_tuple=True)[0]
                
                if len(non_yes_indices) > 0:
                    # 第二阶段评估
                    non_yes_input_ids = input_ids[non_yes_indices]
                    non_yes_attention_mask = attention_mask[non_yes_indices]
                    non_yes_token_type_ids = token_type_ids[non_yes_indices]
                    non_yes_labels = labels[non_yes_indices]
                    
                    stage2_logits = model(non_yes_input_ids, non_yes_attention_mask, non_yes_token_type_ids, stage=2)
                    stage2_preds = torch.argmax(stage2_logits, dim=1)
                    
                    # 转换为二分类标签：0=To some extent, 1=No
                    stage2_labels_binary = (non_yes_labels == 2).long()
                    
                    stage2_val_preds.extend(stage2_preds.cpu().numpy())
                    stage2_val_labels.extend(stage2_labels_binary.cpu().numpy())
                
                # 完整两阶段预测
                logits = model(input_ids, attention_mask, token_type_ids)
                preds = torch.argmax(logits, dim=1)
                
                val_preds.extend(preds.cpu().numpy())
                val_labels.extend(labels.cpu().numpy())
        
        # 计算验证集指标
        # 阶段1
        stage1_val_acc = accuracy_score(stage1_val_labels, stage1_val_preds)
        stage1_val_f1 = f1_score(stage1_val_labels, stage1_val_preds, average='macro')
        
        print(f"Stage 1 Validation - Acc: {stage1_val_acc:.4f}, F1: {stage1_val_f1:.4f}")
        
        # 创建阶段1验证混淆矩阵
        cm = confusion_matrix(stage1_val_labels, stage1_val_preds)
        plt.figure(figsize=(8, 6))
        sns.heatmap(cm, annot=True, fmt='d', cmap='Blues', xticklabels=stage1_names, yticklabels=stage1_names)
        plt.xlabel('Predicted')
        plt.ylabel('True')
        plt.title(f'Val: Stage 1 (Yes vs Non-Yes)\nAcc: {stage1_val_acc:.4f}, F1: {stage1_val_f1:.4f}')
        matrix_path = os.path.join(val_plot_dir, f'stage1_cm_epoch_{epoch+1}.png')
        plt.savefig(matrix_path)
        plt.close()
        
        # 阶段2
        if len(stage2_val_labels) > 0:
            stage2_val_acc = accuracy_score(stage2_val_labels, stage2_val_preds)
            stage2_val_f1 = f1_score(stage2_val_labels, stage2_val_preds, average='macro')
            
            print(f"Stage 2 Validation - Acc: {stage2_val_acc:.4f}, F1: {stage2_val_f1:.4f}")
            
            # 创建阶段2验证混淆矩阵
            cm = confusion_matrix(stage2_val_labels, stage2_val_preds)
            plt.figure(figsize=(8, 6))
            sns.heatmap(cm, annot=True, fmt='d', cmap='Blues', xticklabels=stage2_names, yticklabels=stage2_names)
            plt.xlabel('Predicted')
            plt.ylabel('True')
            plt.title(f'Val: Stage 2 (To some extent vs No)\nAcc: {stage2_val_acc:.4f}, F1: {stage2_val_f1:.4f}')
            matrix_path = os.path.join(val_plot_dir, f'stage2_cm_epoch_{epoch+1}.png')
            plt.savefig(matrix_path)
            plt.close()
        
        # 整体验证集评估
        val_acc = accuracy_score(val_labels, val_preds)
        val_f1 = f1_score(val_labels, val_preds, average='macro')
        
        print(f"Overall Validation - Acc: {val_acc:.4f}, F1: {val_f1:.4f}")
        
        # 创建完整验证集混淆矩阵
        cm_full = confusion_matrix(val_labels, val_preds, labels=[0, 1, 2])
        plt.figure(figsize=(10, 8))
        sns.heatmap(cm_full, annot=True, fmt='d', cmap='Blues', xticklabels=class_names, yticklabels=class_names)
        plt.xlabel('Predicted')
        plt.ylabel('True')
        plt.title(f'Val: Full Hierarchical Confusion Matrix\nAcc: {val_acc:.4f}, F1: {val_f1:.4f}')
        matrix_path = os.path.join(val_plot_dir, f'full_cm_epoch_{epoch+1}.png')
        plt.savefig(matrix_path)
        plt.close()
        
        # 绘制两两类别的验证集混淆矩阵
        class_pairs = [
            ([0, 1], ['Yes', 'To some extent']),  # Yes vs To some extent
            ([0, 2], ['Yes', 'No']),              # Yes vs No
            ([1, 2], ['To some extent', 'No'])    # To some extent vs No
        ]
        
        for classes_idx, classes_names in class_pairs:
            # 筛选出对应两个类别的预测和标签
            mask = np.isin(np.array(val_labels), classes_idx)
            filtered_preds = np.array(val_preds)[mask]
            filtered_labels = np.array(val_labels)[mask]
            
            # 计算此对类别的准确率和F1分数
            if len(filtered_labels) > 0:
                pair_acc = accuracy_score(filtered_labels, filtered_preds)
                # 计算二分类F1分数
                pair_f1 = f1_score(filtered_labels, filtered_preds, average='macro')
                
                # 创建混淆矩阵
                cm = confusion_matrix(filtered_labels, filtered_preds, labels=classes_idx)
                
                # 绘制混淆矩阵
                plt.figure(figsize=(8, 6))
                sns.heatmap(cm, annot=True, fmt='d', cmap='Blues',
                            xticklabels=[classes_names[i == classes_idx[1]] for i in classes_idx],
                            yticklabels=[classes_names[i == classes_idx[1]] for i in classes_idx])
                plt.xlabel('Predicted')
                plt.ylabel('True')
                plt.title(f'Val: {classes_names[0]} vs {classes_names[1]}\nAcc: {pair_acc:.4f}, F1: {pair_f1:.4f}')
                
                # 保存图表
                matrix_path = os.path.join(val_plot_dir, f'cm_{classes_names[0].replace(" ", "_")}_{classes_names[1].replace(" ", "_")}_epoch_{epoch+1}.png')
                plt.savefig(matrix_path)
                plt.close()
        
        # 检查是否保存模型并判断是否需要早停
        if val_f1 > best_val_f1:
            best_val_f1 = val_f1
            best_val_acc = val_acc
            
            # 保存模型
            # torch.save(model.state_dict(), os.path.join(checkpoint_dir, 'best_hierarchical_model.pt'))
            print(f'New best model saved with F1: {best_val_f1:.4f}, Acc: {best_val_acc:.4f}')
            
            patience_counter = 0
        else:
            patience_counter += 1
            if patience_counter >= configs.patience:
                print(f'Early stopping triggered after {epoch+1} epochs.')
                break
        
        # 返回训练状态
        model.train()
    
    print("\nTraining complete!")
    print(f"Best validation accuracy: {best_val_acc:.4f}")
    print(f"Best validation F1 score: {best_val_f1:.4f}")
    
    return model

# 修改主函数
if __name__ == '__main__':
    # 判断是否在Jupyter环境中运行
    try:
        # 检查是否在Jupyter中运行
        get_ipython = globals().get('get_ipython', None)
        if get_ipython and 'IPKernelApp' in get_ipython().config:
            # 在Jupyter环境中运行，使用默认配置
            print("Running in Jupyter environment, using default configs")
            configs = get_default_configs()
        else:
            # 在命令行环境中运行，使用argparse
            configs = argparser()
    except:
        # 任何异常都使用argparse处理
        configs = argparser()
    
    # 设置实验名称
    if configs.name is None:
        configs.exp_name = \
            f'hierarchical_{os.path.basename(configs.model_name)}' + \
            f'{"_fp" if configs.freeze_pooler else ""}' + \
            f'_b{configs.batch_size}_e{configs.epochs}' + \
            f'_len{configs.max_length}_lr{configs.lr}'
    else:
        configs.exp_name = configs.name
    
    # 设置设备
    if configs.device is None:
        configs.device = torch.device(
            'cuda' if torch.cuda.is_available() else 'cpu'
        )
    
    # 调用分层训练函数
    trained_model = train_hierarchical(configs)

Running in Jupyter environment, using default configs





===== Epoch 1/50 =====
Training Stage 1 (Yes vs Non-Yes)...


Stage 1: 100%|█████████████████████████████████████| 198/198 [00:17<00:00, 11.36batch/s, loss=0.267]


Stage 1 Training - Loss: 0.5325, Acc: 0.7061, F1: 0.6469
Training Stage 2 (To some extent vs No)...


Stage 2: 100%|█████████████████████████████████████| 198/198 [00:12<00:00, 16.47batch/s, loss=0.455]


Stage 2 Training - Loss: 0.4443, Acc: 0.8047, F1: 0.7973
Overall Training - Acc: 0.7585, F1: 0.7634
Evaluating on validation set...
Stage 1 Validation - Acc: 0.8045, F1: 0.7993
Stage 2 Validation - Acc: 0.8682, F1: 0.8680
Overall Validation - Acc: 0.7755, F1: 0.7769
New best model saved with F1: 0.7769, Acc: 0.7755

===== Epoch 2/50 =====
Training Stage 1 (Yes vs Non-Yes)...


Stage 1: 100%|█████████████████████████████████████| 198/198 [00:17<00:00, 11.60batch/s, loss=0.407]


Stage 1 Training - Loss: 0.4017, Acc: 0.7920, F1: 0.7809
Training Stage 2 (To some extent vs No)...


Stage 2: 100%|█████████████████████████████████████| 198/198 [00:11<00:00, 16.55batch/s, loss=0.273]


Stage 2 Training - Loss: 0.2756, Acc: 0.8745, F1: 0.8723
Overall Training - Acc: 0.8005, F1: 0.8040
Evaluating on validation set...
Stage 1 Validation - Acc: 0.8184, F1: 0.8162
Stage 2 Validation - Acc: 0.8722, F1: 0.8721
Overall Validation - Acc: 0.8033, F1: 0.8042
New best model saved with F1: 0.8042, Acc: 0.8033

===== Epoch 3/50 =====
Training Stage 1 (Yes vs Non-Yes)...


Stage 1: 100%|█████████████████████████████████████| 198/198 [00:17<00:00, 11.55batch/s, loss=0.419]


Stage 1 Training - Loss: 0.3497, Acc: 0.8153, F1: 0.8058
Training Stage 2 (To some extent vs No)...


Stage 2: 100%|█████████████████████████████████████| 198/198 [00:12<00:00, 16.47batch/s, loss=0.567]


Stage 2 Training - Loss: 0.2143, Acc: 0.8963, F1: 0.8947
Overall Training - Acc: 0.8327, F1: 0.8361
Evaluating on validation set...
Stage 1 Validation - Acc: 0.8260, F1: 0.8183
Stage 2 Validation - Acc: 0.8803, F1: 0.8803
Overall Validation - Acc: 0.7793, F1: 0.7805

===== Epoch 4/50 =====
Training Stage 1 (Yes vs Non-Yes)...


Stage 1: 100%|█████████████████████████████████████| 198/198 [00:17<00:00, 11.55batch/s, loss=0.278]


Stage 1 Training - Loss: 0.3162, Acc: 0.8456, F1: 0.8369
Training Stage 2 (To some extent vs No)...


Stage 2: 100%|█████████████████████████████████████| 198/198 [00:12<00:00, 16.44batch/s, loss=0.184]


Stage 2 Training - Loss: 0.1874, Acc: 0.9220, F1: 0.9210
Overall Training - Acc: 0.8415, F1: 0.8434
Evaluating on validation set...
Stage 1 Validation - Acc: 0.8310, F1: 0.8288
Stage 2 Validation - Acc: 0.8661, F1: 0.8647
Overall Validation - Acc: 0.8071, F1: 0.8060
New best model saved with F1: 0.8060, Acc: 0.8071

===== Epoch 5/50 =====
Training Stage 1 (Yes vs Non-Yes)...


Stage 1: 100%|█████████████████████████████████████| 198/198 [00:17<00:00, 11.54batch/s, loss=0.087]


Stage 1 Training - Loss: 0.2832, Acc: 0.8671, F1: 0.8581
Training Stage 2 (To some extent vs No)...


Stage 2: 100%|█████████████████████████████████████| 198/198 [00:11<00:00, 16.50batch/s, loss=0.211]


Stage 2 Training - Loss: 0.1524, Acc: 0.9360, F1: 0.9353
Overall Training - Acc: 0.8829, F1: 0.8843
Evaluating on validation set...
Stage 1 Validation - Acc: 0.8298, F1: 0.8258
Stage 2 Validation - Acc: 0.8600, F1: 0.8583
Overall Validation - Acc: 0.8045, F1: 0.8030

===== Epoch 6/50 =====
Training Stage 1 (Yes vs Non-Yes)...


Stage 1: 100%|█████████████████████████████████████| 198/198 [00:17<00:00, 11.51batch/s, loss=0.353]


Stage 1 Training - Loss: 0.2467, Acc: 0.8870, F1: 0.8782
Training Stage 2 (To some extent vs No)...


Stage 2: 100%|█████████████████████████████████████| 198/198 [00:12<00:00, 16.42batch/s, loss=0.160]


Stage 2 Training - Loss: 0.1177, Acc: 0.9554, F1: 0.9549
Overall Training - Acc: 0.9167, F1: 0.9190
Evaluating on validation set...
Stage 1 Validation - Acc: 0.8272, F1: 0.8225
Stage 2 Validation - Acc: 0.8742, F1: 0.8742
Overall Validation - Acc: 0.7957, F1: 0.7953

===== Epoch 7/50 =====
Training Stage 1 (Yes vs Non-Yes)...


Stage 1: 100%|█████████████████████████████████████| 198/198 [00:32<00:00,  6.19batch/s, loss=0.095]


Stage 1 Training - Loss: 0.2091, Acc: 0.9085, F1: 0.9010
Training Stage 2 (To some extent vs No)...


Stage 2: 100%|█████████████████████████████████████| 198/198 [00:24<00:00,  8.23batch/s, loss=0.035]


Stage 2 Training - Loss: 0.1015, Acc: 0.9593, F1: 0.9589
Overall Training - Acc: 0.9485, F1: 0.9495
Evaluating on validation set...
Stage 1 Validation - Acc: 0.8235, F1: 0.8163
Stage 2 Validation - Acc: 0.8763, F1: 0.8761
Overall Validation - Acc: 0.7957, F1: 0.7969

===== Epoch 8/50 =====
Training Stage 1 (Yes vs Non-Yes)...


Stage 1: 100%|█████████████████████████████████████| 198/198 [00:33<00:00,  5.84batch/s, loss=0.288]


Stage 1 Training - Loss: 0.1754, Acc: 0.9236, F1: 0.9171
Training Stage 2 (To some extent vs No)...


Stage 2:  70%|█████████████████████████▊           | 138/198 [00:13<00:05, 10.20batch/s, loss=0.003]


KeyboardInterrupt: 