In [1]:
import os
# 设置环境变量，只让程序看到 GPU 2
os.environ['CUDA_VISIBLE_DEVICES'] = '1'


import torch
import torch.nn as nn
import wandb
import random
import argparse
import numpy as np
from tqdm import tqdm
from transformers import BertModel, AutoModel
from transformers import AdamW

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

  from .autonotebook import tqdm as notebook_tqdm


# 数据预处理函数

In [2]:
from torch.utils.data import Dataset
import json

class BAE2025Dataset(Dataset):
    def __init__(
            self,
            data_path,
            labels={
                "Yes": 0,
                "To some extent": 1, 
                "No": 2,
            }
    ):
        self.data_path = data_path
        self.labels = labels
        self._get_data()
        
    def _get_data(self):
        with open(self.data_path, 'r', encoding='utf-8') as f:
            data = json.load(f)
            
        self.data = []
        self.original_data = data  # 保存原始数据
        
        # 增加数据编号跟踪
        data_idx = 0
        for item_idx, item in enumerate(data):
            tutor_responses = item['tutor_responses']
            for response_id, response in tutor_responses.items():
                sent1 = item['conversation_history']
                sent2 = response['response']
                label = response['annotation']["Providing_Guidance"]
                if label in self.labels:
                    # 保存额外的索引信息，以便跟踪
                    self.data.append((
                        (sent1, sent2), 
                        self.labels[label],
                        {
                            'data_idx': data_idx,
                            'item_idx': item_idx,
                            'response_id': response_id
                        }
                    ))
                    data_idx += 1
        
    def __len__(self):
        return len(self.data)
        
    def get_labels(self):
        return self.labels
    
    def __getitem__(self, idx):
        return self.data[idx]

# 数据加载函数

In [3]:
import torch
from torch.utils.data import DataLoader
from transformers import AutoTokenizer

class BAE2025DataLoader:
    def __init__(
        self,
        dataset,
        batch_size=16,
        max_length=512,
        shuffle=True,
        drop_last=True,
        device=None,
        # tokenizer_name='chinese-bert-wwm-ext'
        # tokenizer_name='chinese-roberta-wwm-ext'
        # tokenizer_name='chinese-roberta-wwm-ext-large'
        # tokenizer_name='/mnt/cfs/huangzhiwei/pykt-moekt/SBM/bge-large-en-v1.5'
        # tokenizer_name='/mnt/cfs/huangzhiwei/BAE2025/models/bge-base-en-v1.5'
        tokenizer_name='/mnt/cfs/huangzhiwei/BAE2025/models/bert-base-uncased'
        # tokenizer_name='/mnt/cfs/huangzhiwei/BAE2025/models/deberta-v3-base'
    ):
        self.tokenizer = AutoTokenizer.from_pretrained(tokenizer_name)
        self.dataset = dataset
        self.batch_size = batch_size
        self.max_length = max_length
        self.shuffle = shuffle
        self.drop_last = drop_last
        
        if device is None:
            self.device = torch.device(
                'cuda' if torch.cuda.is_available() else 'cpu'
            )
        else:
            self.device = device
        
        self.loader = DataLoader(
            dataset=self.dataset,
            batch_size=self.batch_size,
            collate_fn=self.collate_fn,
            shuffle=self.shuffle,
            drop_last=self.drop_last
        )
    
    def collate_fn(self, data):
        sents = [i[0] for i in data]
        labels = [i[1] for i in data]
        indices_info = [i[2] for i in data]  # 获取索引信息
        
        # 处理两个句子的情况
        encoded_data = self.tokenizer.batch_encode_plus(
            batch_text_or_text_pairs=[(sent[0], sent[1]) for sent in sents],
            truncation=True,
            padding='max_length',
            max_length=self.max_length,
            return_tensors='pt',
            return_length=True
        )
        
        input_ids = encoded_data['input_ids'].to(self.device)
        attention_mask = encoded_data['attention_mask'].to(self.device)
        token_type_ids = encoded_data['token_type_ids'].to(self.device)
        labels = torch.LongTensor(labels).to(self.device)
        
        # 返回原始索引信息
        return input_ids, attention_mask, token_type_ids, labels, indices_info
    
    def __iter__(self):
        for data in self.loader:
            yield data
    
    def __len__(self):
        return len(self.loader)

# check 一下数据是不是对的

In [4]:
# # 校验上述数据加载器出来的数据是否正确，输出数据出来看一看
# train_data_path = '../data/train.json'
# val_data_path = '../data/valid.json'
# train_dataset = BAE2025Dataset(train_data_path)
# train_dataloader = BAE2025DataLoader(train_dataset, batch_size=1)
# val_dataset = BAE2025Dataset(val_data_path)
# val_dataloader = BAE2025DataLoader(val_dataset, batch_size=1)

# cnt_train=0
# for batch in train_dataloader:
#     cnt_train += 1
#     input_ids, attention_mask, token_type_ids, labels = batch
#     # print(input_ids.shape, attention_mask.shape, token_type_ids.shape, coarse_labels.shape, fine_labels.shape)
#     # print(input_ids.shape, attention_mask.shape, token_type_ids.shape, labels.shape)
#     # break
# print("train data size:", cnt_train)

# cnt_val=0
# for batch in val_dataloader:
#     cnt_val += 1
#     input_ids, attention_mask, token_type_ids, labels = batch
#     # print(input_ids.shape, attention_mask.shape, token_type_ids.shape, coarse_labels.shape, fine_labels.shape)
#     # print(input_ids.shape, attention_mask.shape, token_type_ids.shape, labels.shape)
#     # break
# print("val data size:", cnt_val)

In [5]:
# # 检查一下dataset中处理出来的数据是否正确
# for data in train_dataset:
#     sent, label = data
#     # print(sent, coarse_label, fine_label, sent_id)
#     print(sent, label)

# 模型代码

In [6]:
# import torch
# import torch.nn as nn
# from transformers import BertModel

# class BertClassifier(nn.Module):
#     def __init__(self, pretrained_model_name, num_classes=3, freeze_pooler=False, dropout=0.2):
#         super().__init__()

#         self.freeze_pooler = freeze_pooler
#         self.bert = BertModel.from_pretrained(pretrained_model_name)
#         self.dropout = nn.Dropout(dropout) if dropout > 0 else nn.Identity()
#         if freeze_pooler:
#             for param in self.bert.pooler.parameters():
#                 param.requires_grad = False

#         # self.classifier = nn.Linear(self.bert.config.hidden_size, num_classes)
#         self.classifier = nn.Sequential(
#             nn.Linear(self.bert.config.hidden_size, 256),
#             nn.ReLU(),
#             nn.Linear(256, num_classes)
#         )

#     def forward(self, input_ids, attention_mask, token_type_ids):
#     # def forward(self, input_ids, attention_mask):
#         outputs = self.bert(input_ids, attention_mask=attention_mask, token_type_ids=token_type_ids)
#         # outputs = self.bert(input_ids, attention_mask=attention_mask)

#         pooled_output = outputs.pooler_output

#         pooled_output = self.dropout(pooled_output)

#         logits = self.classifier(pooled_output)
        
#         # logits = torch.sigmoid(logits)

#         return logits





import torch
import torch.nn as nn
from transformers import BertModel


class BertClassificationHead(nn.Module):
    def __init__(self, hidden_size=1024, num_classes=3, dropout_prob=0.3):
        super().__init__()
        self.dense = nn.Linear(hidden_size, hidden_size)
        self.dropout = nn.Dropout(dropout_prob)
        self.out_proj = nn.Linear(hidden_size, num_classes)  # (输入维度，输出维度)

    def forward(self, features, **kwargs):
        x = features[-1][:, 0, :]  # features[-1]是一个三维张量，其维度为[批次大小, 序列长度, 隐藏大小]。
        x = self.dropout(x)  # 这是一种正则化技术，用于防止模型过拟合。在训练过程中，它通过随机将输入张量中的一部分元素设置为0，来增加模型的泛化能力。
        x = self.dense(x)  # 这是一个全连接层，它将输入特征映射到一个新的特征空间。这是通过学习一个权重矩阵和一个偏置向量，并使用它们对输入特征进行线性变换来实现的，方便后续可以引入非线性变换。
        x = torch.tanh(x)  # 这是一个激活函数，它将线性层的输出转换为非线性，使得模型可以学习并表示更复杂的模式。
        x = self.dropout(x)  # 增加模型的泛化能力。
        x = self.out_proj(x)  # 这是最后的全连接层，它将特征映射到最终的输出空间。在这个例子中，输出空间的维度等于分类任务的类别数量。
        return x
    

class BertClassifier(nn.Module):
    def __init__(self, pretrained_model_name, num_classes=3, freeze_pooler=0, dropout=0.3, hidden_size=768):
        super().__init__()
        
        self.bert = BertModel.from_pretrained(pretrained_model_name, output_hidden_states=True)
        
        # 冻结BERT底层，保留顶层微调
        if freeze_pooler > 0:
            modules = [self.bert.embeddings, *self.bert.encoder.layer[:freeze_pooler]]
            for module in modules:
                for param in module.parameters():
                    param.requires_grad = False
        
        self.dropout = nn.Dropout(dropout)
        
        # 获取bert隐藏层大小
        bert_hidden_size = self.bert.config.hidden_size
        
        self.classifier = BertClassificationHead(
            hidden_size=self.bert.config.hidden_size,
            num_classes=3,  # 三分类任务
            dropout_prob=dropout
        )
        
    def forward(self, input_ids, attention_mask, token_type_ids=None):
        # 获取BERT输出
        outputs = self.bert(
            input_ids, 
            attention_mask=attention_mask,
            token_type_ids=token_type_ids
            # output_hidden_states=True  # 获取所有隐藏层
        )
        
        # 使用[CLS]表示的序列表示
        # pooled_output = outputs.pooler_output
        
        # 可选：结合最后四层的[CLS]表示以获取更丰富的信息
        # last_4_layers = outputs.hidden_states[-4:]
        # cls_embeddings = torch.stack([layer[:, 0, :] for layer in last_4_layers], dim=0)
        # pooled_output = torch.mean(cls_embeddings, dim=0)  # 平均最后四层
        
        # 应用dropout
        # pooled_output = self.dropout1(pooled_output)
        
        # 分类
        logits = self.classifier(outputs.hidden_states)
        
        return logits

# 训练参数设置

In [7]:
# 如果在Jupyter Notebook中运行，可以使用这个自定义参数函数替代argparser
def get_default_configs():
    """在Jupyter环境中使用的默认配置，避免argparse解析错误"""
    class Args:
        def __init__(self):
            # self.model_name = '/mnt/cfs/huangzhiwei/pykt-moekt/SBM/bge-large-en-v1.5'
            # self.model_name = "/mnt/cfs/huangzhiwei/BAE2025/models/ModernBERT-large"
            # self.model_name = '/mnt/cfs/huangzhiwei/pykt-moekt/SBM/xlm-roberta-large'
            # self.model_name = '/mnt/cfs/huangzhiwei/BAE2025/models/bge-base-en-v1.5'
            self.model_name = '/mnt/cfs/huangzhiwei/BAE2025/models/bert-base-uncased'
            # self.model_name = '/mnt/cfs/huangzhiwei/BAE2025/models/deberta-v3-base'
            self.num_classes = 3
            self.dropout = 0.3
            self.freeze_pooler = 8
            self.batch_size = 16
            self.max_length = 512
            self.lr = 1e-5
            self.epochs = 50
            self.device = device
            self.name = None
            self.seed = 42
            self.data_path = '../data/train.json'
            self.val_data_path = '../data/valid.json'
            self.checkpoint_dir = 'checkpoints_error'
            self.patience = 8
            self.exp_name = 'BAE2025_track4_bert'
            
            # 新增参数 - 用于错误提取功能
            self.mode = 'train'  # 默认为训练模式，可选值: 'train', 'extract_errors'
            self.model_path = None  # 默认为None，当mode为'extract_errors'时使用
    
    return Args()

# 训练函数

In [8]:
def train(configs):
    
    # 设置随机种子
    random.seed(configs.seed)
    np.random.seed(configs.seed)
    torch.manual_seed(configs.seed)
    
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False
    
    # 创建检查点目录
    checkpoint_dir = os.path.join(configs.checkpoint_dir, configs.exp_name)
    os.makedirs(checkpoint_dir, exist_ok=True)
    
    # 为保存混淆矩阵创建目录 - 分别为训练集和验证集创建
    train_plot_dir = os.path.join(checkpoint_dir, 'plots', 'train')
    val_plot_dir = os.path.join(checkpoint_dir, 'plots', 'val')
    os.makedirs(train_plot_dir, exist_ok=True)
    os.makedirs(val_plot_dir, exist_ok=True)
    
    # 创建错误预测保存目录
    error_analysis_dir = os.path.join(checkpoint_dir, 'error_analysis')
    os.makedirs(error_analysis_dir, exist_ok=True)
    
    # 加载数据集
    train_dataset = BAE2025Dataset(configs.data_path)
    val_dataset = BAE2025Dataset(configs.val_data_path)    

    # 创建数据加载器
    train_dataloader = BAE2025DataLoader(
        dataset=train_dataset,
        batch_size=configs.batch_size,
        max_length=configs.max_length,
        shuffle=True,
        drop_last=True,
        device=configs.device,
        tokenizer_name=configs.model_name
    )

    val_dataloader = BAE2025DataLoader(
        dataset=val_dataset,
        batch_size=configs.batch_size,
        max_length=configs.max_length,
        shuffle=False,
        drop_last=False,
        device=configs.device,
        tokenizer_name=configs.model_name
    )
    
    # 创建模型
    model = BertClassifier(
        pretrained_model_name=configs.model_name,
        num_classes=configs.num_classes,
        freeze_pooler=configs.freeze_pooler,
        dropout=configs.dropout
    ).to(configs.device)

    # 获取对应的tokenizer用于后面解码预测错误的样本
    tokenizer = AutoTokenizer.from_pretrained(configs.model_name)

    criterion = nn.CrossEntropyLoss()

    # 定义优化器
    optimizer = AdamW(
        filter(lambda p: p.requires_grad, model.parameters()),
        lr=configs.lr
    )

    # 初始化最佳验证损失和早停计数器
    best_val_acc = 0.0
    best_val_f1 = 0.0  # 添加F1分数作为评估指标
    best_val_loss = float('inf')
    patience_counter = 0
    
    # 定义类别名称
    class_names = ['Yes', 'To some extent', 'No']
    
    # 添加F1计算所需的库
    from sklearn.metrics import f1_score, confusion_matrix
    import matplotlib.pyplot as plt
    import seaborn as sns
    
    # 训练循环
    for epoch in range(configs.epochs):
        # 训练阶段
        model.train()
        train_loss = 0.0
        train_acc = 0.0
        train_preds = []
        train_labels_list = []
        
        with tqdm(
            train_dataloader,
            total=len(train_dataloader),
            desc=f'Epoch {epoch + 1}/{configs.epochs}',
            unit='batch',
            ncols=100
        ) as pbar:
            for input_ids, attention_mask, token_type_ids, labels, indices_info in pbar:
                optimizer.zero_grad()
                
                # 前向传播
                logits = model(input_ids, attention_mask, token_type_ids)
                
                # 计算损失 - 确保labels是长整型
                labels = labels.long()
                loss = criterion(logits, labels)
                
                # 反向传播
                loss.backward()
                optimizer.step()
                
                preds = logits.argmax(dim=1)
                accuracy = (preds == labels).float().mean()
                accuracy_all = (preds == labels).float().sum()
                
                # 收集预测结果和真实标签，用于计算F1
                train_preds.extend(preds.cpu().numpy())
                train_labels_list.extend(labels.cpu().numpy())
                
                train_loss += loss.item()
                train_acc += accuracy_all.item()
                
                pbar.set_postfix(
                    loss=f'{loss.item():.3f}',
                    accuracy=f'{accuracy.item():.3f}'
                )
        
        train_loss = train_loss / len(train_dataloader)
        train_acc = train_acc / len(train_dataset)
        
        # 计算训练集的F1分数 - 使用macro平均以处理多分类
        train_f1 = f1_score(train_labels_list, train_preds, average='macro')
        
        print(f'Training Loss: {train_loss:.4f}')
        print(f'Training Accuracy: {train_acc:.4f}')
        print(f'Training F1 Score: {train_f1:.4f}')
        
        # 创建训练集的混淆矩阵
        # 创建三个二分类混淆矩阵（两两类别之间）
        class_pairs = [
            ([0, 1], ['Yes', 'To some extent']),  # Yes vs To some extent
            ([0, 2], ['Yes', 'No']),              # Yes vs No
            ([1, 2], ['To some extent', 'No'])    # To some extent vs No
        ]
        
        for classes_idx, classes_names in class_pairs:
            # 筛选出对应两个类别的预测和标签
            mask = np.isin(np.array(train_labels_list), classes_idx)
            filtered_preds = np.array(train_preds)[mask]
            filtered_labels = np.array(train_labels_list)[mask]
            
            # 创建混淆矩阵
            cm = confusion_matrix(filtered_labels, filtered_preds, labels=classes_idx)
            
            # 计算此对类别的准确率和F1分数
            pair_mask = np.isin(np.array(train_labels_list), classes_idx)
            pair_acc = np.mean(np.array(train_preds)[pair_mask] == np.array(train_labels_list)[pair_mask])
            # 计算二分类F1分数
            pair_f1 = f1_score(
                np.array(train_labels_list)[pair_mask], 
                np.array(train_preds)[pair_mask], 
                labels=classes_idx, 
                average='macro'
            )
            
            # 绘制混淆矩阵
            plt.figure(figsize=(8, 6))
            sns.heatmap(cm, annot=True, fmt='d', cmap='Blues',
                        xticklabels=[classes_names[i == classes_idx[1]] for i in classes_idx],
                        yticklabels=[classes_names[i == classes_idx[1]] for i in classes_idx])
            plt.xlabel('Predicted')
            plt.ylabel('True')
            plt.title(f'Train: {classes_names[0]} vs {classes_names[1]}\nAcc: {pair_acc:.4f}, F1: {pair_f1:.4f}')
            
            # 保存图表
            matrix_path = os.path.join(train_plot_dir, f'cm_{classes_names[0].replace(" ", "_")}_{classes_names[1].replace(" ", "_")}_epoch_{epoch+1}.png')
            plt.savefig(matrix_path)
            plt.close()
            
        # 创建完整的三分类混淆矩阵
        cm_full = confusion_matrix(train_labels_list, train_preds, labels=[0, 1, 2])
        plt.figure(figsize=(10, 8))
        sns.heatmap(cm_full, annot=True, fmt='d', cmap='Blues',
                    xticklabels=class_names,
                    yticklabels=class_names)
        plt.xlabel('Predicted')
        plt.ylabel('True')
        plt.title(f'Train: Full Confusion Matrix\nAcc: {train_acc:.4f}, F1: {train_f1:.4f}')
        
        # 保存完整混淆矩阵
        matrix_path = os.path.join(train_plot_dir, f'cm_full_epoch_{epoch+1}.png')
        plt.savefig(matrix_path)
        plt.close()
        
        # 验证阶段
        model.eval()
        val_loss = 0.0
        val_corrects = 0.0
        val_preds = []
        val_labels_list = []
        
        # 用于存储错误预测的信息
        error_examples = []

        with torch.no_grad():
            for batch_idx, (input_ids, attention_mask, token_type_ids, labels, indices_info) in enumerate(val_dataloader):
                # 确保labels是长整型
                labels = labels.long()
                
                # 前向传播
                logits = model(input_ids, attention_mask, token_type_ids)
                
                loss = criterion(logits, labels)
                val_loss += loss.item()
                preds = logits.argmax(dim=1)
                
                # 找出本批次中预测错误的样本
                error_indices = (preds != labels).nonzero(as_tuple=True)[0]
                
                # 如果有错误预测，使用原始索引获取文本内容并保存
                if len(error_indices) > 0:
                    for idx in error_indices:
                        # 获取样本的原始索引信息
                        sample_indices = indices_info[idx]
                        data_idx = sample_indices['data_idx']
                        item_idx = sample_indices['item_idx']
                        response_id = sample_indices['response_id']
                        
                        # 从原始数据中获取文本，而不是解码
                        original_item = val_dataset.original_data[item_idx]
                        conversation_history = original_item['conversation_history']
                        response_text = original_item['tutor_responses'][response_id]['response']
                        
                        # 获取真实标签和预测标签
                        true_label = labels[idx].item()
                        pred_label = preds[idx].item()
                        
                        # 记录错误预测信息 - 使用原始文本
                        error_examples.append({
                            'first_sentence': conversation_history,
                            'second_sentence': response_text,
                            'true_label': class_names[true_label],
                            'pred_label': class_names[pred_label],
                            'data_idx': data_idx,
                            'item_idx': item_idx,
                            'response_id': response_id
                        })
                
                accuracy = (preds == labels).float().sum()
                val_corrects += accuracy
                
                # 收集预测结果和真实标签，用于计算F1和混淆矩阵
                val_preds.extend(preds.cpu().numpy())
                val_labels_list.extend(labels.cpu().numpy())
        
        val_loss = val_loss / len(val_dataloader)
        val_acc = val_corrects.double() / len(val_dataset)
        
        # 计算验证集的F1分数
        val_f1 = f1_score(val_labels_list, val_preds, average='macro')
        
        print('Validation Loss: {:.4f} Acc: {:.4f} F1: {:.4f}'.format(val_loss, val_acc, val_f1))
        print(f'Number of error examples: {len(error_examples)}')
        
        # 保存错误预测的样本到JSON文件
        error_file_path = os.path.join(error_analysis_dir, f'error_examples_epoch_{epoch+1}.json')
        with open(error_file_path, 'w', encoding='utf-8') as f:
            json.dump(error_examples, f, ensure_ascii=False, indent=4)
        
        print(f'Error examples saved to {error_file_path}')
        
        # 创建验证集三个二分类混淆矩阵（两两类别之间）
        class_pairs = [
            ([0, 1], ['Yes', 'To some extent']),  # Yes vs To some extent
            ([0, 2], ['Yes', 'No']),              # Yes vs No
            ([1, 2], ['To some extent', 'No'])    # To some extent vs No
        ]
        
        for classes_idx, classes_names in class_pairs:
            # 筛选出对应两个类别的预测和标签
            mask = np.isin(np.array(val_labels_list), classes_idx)
            filtered_preds = np.array(val_preds)[mask]
            filtered_labels = np.array(val_labels_list)[mask]
            
            # 创建混淆矩阵
            cm = confusion_matrix(filtered_labels, filtered_preds, labels=classes_idx)
            
            # 计算此对类别的准确率和F1分数
            pair_mask = np.isin(np.array(val_labels_list), classes_idx)
            pair_acc = np.mean(np.array(val_preds)[pair_mask] == np.array(val_labels_list)[pair_mask])
            # 计算二分类F1分数
            pair_f1 = f1_score(
                np.array(val_labels_list)[pair_mask], 
                np.array(val_preds)[pair_mask], 
                labels=classes_idx, 
                average='macro'
            )
            
            # 绘制混淆矩阵
            plt.figure(figsize=(8, 6))
            sns.heatmap(cm, annot=True, fmt='d', cmap='Blues',
                        xticklabels=[classes_names[i == classes_idx[1]] for i in classes_idx],
                        yticklabels=[classes_names[i == classes_idx[1]] for i in classes_idx])
            plt.xlabel('Predicted')
            plt.ylabel('True')
            plt.title(f'Val: {classes_names[0]} vs {classes_names[1]}\nAcc: {pair_acc:.4f}, F1: {pair_f1:.4f}')
            
            # 保存图表
            matrix_path = os.path.join(val_plot_dir, f'cm_{classes_names[0].replace(" ", "_")}_{classes_names[1].replace(" ", "_")}_epoch_{epoch+1}.png')
            plt.savefig(matrix_path)
            plt.close()
            
        # 创建完整的三分类混淆矩阵
        cm_full = confusion_matrix(val_labels_list, val_preds, labels=[0, 1, 2])
        plt.figure(figsize=(10, 8))
        sns.heatmap(cm_full, annot=True, fmt='d', cmap='Blues',
                    xticklabels=class_names,
                    yticklabels=class_names)
        plt.xlabel('Predicted')
        plt.ylabel('True')
        plt.title(f'Val: Full Confusion Matrix\nAcc: {val_acc:.4f}, F1: {val_f1:.4f}')
        
        # 保存完整混淆矩阵
        matrix_path = os.path.join(val_plot_dir, f'cm_full_epoch_{epoch+1}.png')
        plt.savefig(matrix_path)
        plt.close()
        
        # 检查是否保存模型并判断是否需要早停
        # 使用F1分数作为主要指标
        if val_f1 > best_val_f1:
            best_val_f1 = val_f1
            best_val_acc = val_acc
            
            # 保存模型
            # state_dict = model.state_dict()
            # torch.save(state_dict, os.path.join(checkpoint_dir, 'best_model_f1.pt'))
            print(f'New best model saved with F1: {best_val_f1:.4f}, Acc: {best_val_acc:.4f}')
            
            # 同时保存这个最佳模型的错误预测样本
            best_error_file_path = os.path.join(error_analysis_dir, 'best_model_error_examples.json')
            with open(best_error_file_path, 'w', encoding='utf-8') as f:
                json.dump(error_examples, f, ensure_ascii=False, indent=4)
            print(f'Best model error examples saved to {best_error_file_path}')
            
            patience_counter = 0
        else:
            patience_counter += 1
            if patience_counter >= configs.patience:
                print(f'Early stopping triggered after {epoch+1} epochs.')
                break

        model.train()


# 同样修改extract_error_examples函数
def extract_error_examples(configs, model_path):
    """
    从已训练的模型中提取验证集上的错误预测
    
    Args:
        configs: 配置对象
        model_path: 已训练模型的路径
    """
    import json
    from transformers import AutoTokenizer
    
    # 确保输出目录存在
    error_analysis_dir = os.path.join(configs.checkpoint_dir, configs.exp_name, 'error_analysis')
    os.makedirs(error_analysis_dir, exist_ok=True)
    
    # 加载验证数据集
    val_dataset = BAE2025Dataset(configs.val_data_path)
    
    # 创建数据加载器
    val_dataloader = BAE2025DataLoader(
        dataset=val_dataset,
        batch_size=configs.batch_size,
        max_length=configs.max_length,
        shuffle=False,
        drop_last=False,
        device=configs.device,
        tokenizer_name=configs.model_name
    )
    
    # 加载tokenizer
    tokenizer = AutoTokenizer.from_pretrained(configs.model_name)
    
    # 创建模型
    model = BertClassifier(
        pretrained_model_name=configs.model_name,
        num_classes=configs.num_classes,
        freeze_pooler=configs.freeze_pooler,
        dropout=configs.dropout
    ).to(configs.device)
    
    # 加载模型权重
    model.load_state_dict(torch.load(model_path, map_location=configs.device))
    model.eval()
    
    # 定义类别名称
    class_names = ['Yes', 'To some extent', 'No']
    
    # 用于存储错误预测的信息
    error_examples = []
    all_examples = []  # 存储所有样本信息（可选）
    
    # 预测和收集错误
    with torch.no_grad():
        for batch_idx, (input_ids, attention_mask, token_type_ids, labels, indices_info) in enumerate(tqdm(val_dataloader, desc="Evaluating")):
            # 确保labels是长整型
            labels = labels.long()
            
            # 前向传播
            logits = model(input_ids, attention_mask, token_type_ids)
            preds = logits.argmax(dim=1)
            
            # 对批次中的每个样本进行处理
            for idx in range(len(labels)):
                # 获取样本的原始索引信息
                sample_indices = indices_info[idx]
                data_idx = sample_indices['data_idx']
                item_idx = sample_indices['item_idx']
                response_id = sample_indices['response_id']
                
                # 从原始数据中获取文本
                original_item = val_dataset.original_data[item_idx]
                conversation_history = original_item['conversation_history']
                response_text = original_item['tutor_responses'][response_id]['response']
                
                # 获取真实标签和预测标签
                true_label = labels[idx].item()
                pred_label = preds[idx].item()
                
                # 创建样本信息
                example_info = {
                    'first_sentence': conversation_history,
                    'second_sentence': response_text,
                    'true_label': class_names[true_label],
                    'pred_label': class_names[pred_label],
                    'is_error': true_label != pred_label,
                    'data_idx': data_idx,
                    'item_idx': item_idx,
                    'response_id': response_id
                }
                
                # 如果是错误预测，添加到错误样本列表
                if true_label != pred_label:
                    error_examples.append(example_info)
                
                # 可选：存储所有样本信息
                all_examples.append(example_info)
    
    # 保存错误预测的样本到JSON文件
    error_file_path = os.path.join(error_analysis_dir, 'error_examples.json')
    with open(error_file_path, 'w', encoding='utf-8') as f:
        json.dump(error_examples, f, ensure_ascii=False, indent=4)
    
    print(f'Found {len(error_examples)} error examples out of {len(val_dataset)} samples ({len(error_examples)/len(val_dataset)*100:.2f}%)')
    print(f'Error examples saved to {error_file_path}')
    
    # 可选：保存所有样本的预测结果
    all_examples_path = os.path.join(error_analysis_dir, 'all_examples.json')
    with open(all_examples_path, 'w', encoding='utf-8') as f:
        json.dump(all_examples, f, ensure_ascii=False, indent=4)
    
    print(f'All examples with predictions saved to {all_examples_path}')
    
    return error_examples


# 在以下主函数中添加判断Jupyter环境的逻辑
if __name__ == '__main__':
    # 判断是否在Jupyter环境中运行
    try:
        # 检查是否在Jupyter中运行
        get_ipython = globals().get('get_ipython', None)
        if get_ipython and 'IPKernelApp' in get_ipython().config:
            # 在Jupyter环境中运行，使用默认配置
            print("Running in Jupyter environment, using default configs")
            configs = get_default_configs()
        else:
            # 在命令行环境中运行，使用argparse
            configs = argparser()
    except:
        # 任何异常都使用argparse处理
        configs = argparser()
    
    # 设置实验名称
    if configs.name is None:
        configs.exp_name = \
            f'{os.path.basename(configs.model_name)}' + \
            f'{"_fp" if configs.freeze_pooler else ""}' + \
            f'_b{configs.batch_size}_e{configs.epochs}' + \
            f'_len{configs.max_length}_lr{configs.lr}'
    else:
        configs.exp_name = configs.name
    
    # 设置设备
    if configs.device is None:
        configs.device = torch.device(
            'cuda' if torch.cuda.is_available() else 'cpu'
        )
    
    # 根据命令行参数选择操作模式
    if hasattr(configs, 'mode') and configs.mode == 'extract_errors':
        # 如果是提取错误模式，则运行错误提取函数
        if not hasattr(configs, 'model_path') or configs.model_path is None:
            # 如果没有指定模型路径，使用默认的最佳模型路径
            configs.model_path = os.path.join(configs.checkpoint_dir, configs.exp_name, 'best_model_f1.pt')
        
        print(f"Extracting error examples using model: {configs.model_path}")
        extract_error_examples(configs, configs.model_path)
    else:
        # 正常训练模式
        train(configs)

Running in Jupyter environment, using default configs


Epoch 1/50: 100%|██████████████████| 123/123 [00:10<00:00, 11.31batch/s, accuracy=0.688, loss=0.930]


Training Loss: 0.9581
Training Accuracy: 0.5878
Training F1 Score: 0.3560
Validation Loss: 0.8338 Acc: 0.6687 F1: 0.4265
Number of error examples: 166
Error examples saved to checkpoints_error/bert-base-uncased_fp_b16_e50_len512_lr1e-05/error_analysis/error_examples_epoch_1.json
New best model saved with F1: 0.4265, Acc: 0.6687
Best model error examples saved to checkpoints_error/bert-base-uncased_fp_b16_e50_len512_lr1e-05/error_analysis/best_model_error_examples.json


Epoch 2/50: 100%|██████████████████| 123/123 [00:10<00:00, 11.71batch/s, accuracy=0.500, loss=1.000]


Training Loss: 0.8540
Training Accuracy: 0.6466
Training F1 Score: 0.4491
Validation Loss: 0.8012 Acc: 0.6946 F1: 0.4789
Number of error examples: 153
Error examples saved to checkpoints_error/bert-base-uncased_fp_b16_e50_len512_lr1e-05/error_analysis/error_examples_epoch_2.json
New best model saved with F1: 0.4789, Acc: 0.6946
Best model error examples saved to checkpoints_error/bert-base-uncased_fp_b16_e50_len512_lr1e-05/error_analysis/best_model_error_examples.json


Epoch 3/50: 100%|██████████████████| 123/123 [00:10<00:00, 11.67batch/s, accuracy=0.625, loss=0.739]


Training Loss: 0.8079
Training Accuracy: 0.6592
Training F1 Score: 0.4715
Validation Loss: 0.7987 Acc: 0.7126 F1: 0.5541
Number of error examples: 144
Error examples saved to checkpoints_error/bert-base-uncased_fp_b16_e50_len512_lr1e-05/error_analysis/error_examples_epoch_3.json
New best model saved with F1: 0.5541, Acc: 0.7126
Best model error examples saved to checkpoints_error/bert-base-uncased_fp_b16_e50_len512_lr1e-05/error_analysis/best_model_error_examples.json


Epoch 4/50:  41%|███████▋           | 50/123 [00:04<00:06, 11.42batch/s, accuracy=0.750, loss=0.679]


KeyboardInterrupt: 