In [1]:
import json
import pandas as pd
import numpy as np
from sklearn.model_selection import KFold
from sklearn.metrics import accuracy_score, classification_report
import torch
from torch.utils.data import Dataset, DataLoader
# from transformers import DebertaV3Tokenizer, DebertaV3ForSequenceClassification, AdamW, get_linear_schedule_with_warmup
from transformers import DebertaV2Tokenizer, DebertaV2ForSequenceClassification, AdamW, get_linear_schedule_with_warmup

from tqdm import tqdm
import re



# 添加数据文件路径到配置
CONFIG = {
    # 数据参数
    'data_file': '/mnt/cfs/huangzhiwei/BAE2025/data/mrbench_v3_devset.json',            # 数据文件路径
    
    # 模型参数
    'model_name': '/mnt/cfs/huangzhiwei/BAE2025/models/deberta-v3-base',  # 模型名称
    'max_length': 512,                          # 文本最大长度
    'dropout': 0.2,                             # dropout率
    
    # 训练参数
    'learning_rate': 2e-5,                      # 学习率
    'batch_size': 16,                           # 批次大小
    'epochs': 40,                               # 最大训练轮次
    'warmup_steps': 0,                          # 预热步数
    'weight_decay': 0.01,                       # 权重衰减
    'gradient_clip': 1.0,                       # 梯度裁剪
    
    # 早停参数
    'early_stopping': True,                     # 是否启用早停
    'patience': 6,                              # 容忍轮次
    'min_delta': 0.001,                         # 最小改进阈值
    'metric_for_early_stopping': 'macro_f1',    # 早停指标 ('macro_f1', 'weighted_f1', 'accuracy')
    
    # 交叉验证参数
    'n_splits': 5,                              # 交叉验证折数
    'random_state': 42,                         # 随机种子
    
    # 其他配置
    'save_model': False,                        # 是否保存最佳模型
    'model_path': '/mnt/cfs/huangzhiwei/BAE2025/projects/checkpoints_debertav3/best_model.pt',              # 模型保存路径
}



# 数据加载函数
def load_data(file_path):
    """
    从文件加载数据
    
    参数:
        file_path: 文件路径，可以是JSON文件或包含JSON的文本文件
        
    返回:
        数据列表
    """
    print(f"Loading data from {file_path}")
    
    # 检查文件扩展名
    if file_path.endswith('.json'):
        # 直接作为JSON文件读取
        try:
            with open(file_path, 'r', encoding='utf-8') as f:
                data = json.load(f)
                print(f"Successfully loaded JSON data with {len(data)} entries")
                return data
        except json.JSONDecodeError as e:
            print(f"Error parsing JSON file: {e}")
            raise
    else:
        # 尝试从文本文件中提取JSON内容
        with open(file_path, 'r', encoding='utf-8') as f:
            content = f.read()
        
        # 尝试从文本中提取JSON数组
        try:
            # 查找方括号包围的内容
            json_pattern = re.compile(r'\[\s*\{.*?\}\s*(?:,\s*\.{3,}\s*)?\]', re.DOTALL)
            json_match = json_pattern.search(content)
            
            if json_match:
                # 替换省略号为空列表
                json_str = json_match.group().replace('......', '[]')
                data = json.loads(json_str)
                print(f"Successfully extracted JSON data from text file with {len(data)} entries")
                return data
            else:
                # 尝试解析整个文件作为JSON
                try:
                    data = json.loads(content)
                    print(f"Successfully parsed entire file as JSON with {len(data)} entries")
                    return data
                except json.JSONDecodeError:
                    raise ValueError("Could not extract JSON data from the input file")
        except Exception as e:
            print(f"Error processing file: {e}")
            raise# 早停类
class EarlyStopping:
    def __init__(self, patience=CONFIG['patience'], min_delta=CONFIG['min_delta'], metric=CONFIG['metric_for_early_stopping']):
        """
        初始化早停机制
        
        参数:
            patience: 在停止前等待的轮次数量
            min_delta: 被视为改进的最小变化量
            metric: 用于早停的指标 ('macro_f1', 'weighted_f1', 'accuracy')
        """
        self.patience = patience
        self.min_delta = min_delta
        self.metric = metric
        self.counter = 0
        self.best_score = None
        self.early_stop = False
        self.best_model_state = None
        
    def __call__(self, model, current_metrics):
        """
        检查是否应该停止训练
        
        参数:
            model: 当前模型
            current_metrics: 当前评估指标字典 (包含 'accuracy', 'macro_f1', 'weighted_f1')
            
        返回:
            early_stop: 是否应该停止训练
        """
        score = current_metrics[self.metric]
        
        if self.best_score is None:
            self.best_score = score
            self.best_model_state = model.state_dict().copy()
            return False
        
        # 如果当前分数没有显著改进
        if score < self.best_score + self.min_delta:
            self.counter += 1
            print(f"EarlyStopping counter: {self.counter} out of {self.patience}")
            if self.counter >= self.patience:
                self.early_stop = True
        else:
            self.best_score = score
            self.best_model_state = model.state_dict().copy()
            self.counter = 0
            
        return self.early_stop


# Load the data
data = load_data(CONFIG['data_file'])

# Prepare dataset
texts = []
labels = []

label_map = {"Yes": 2, "To some extent": 1, "No": 0}
rev_label_map = {2: "Yes", 1: "To some extent", 0: "No"}

for item in data:
    conversation_history = item["conversation_history"]
    
    # Process each tutor response
    for tutor, response_data in item["tutor_responses"].items():
        if "annotation" in response_data and "Actionability" in response_data["annotation"]:
            # Include the tutor's response as the input text
            texts.append(response_data["response"])
            
            # Map the guidance label to a numeric value
            guidance_label = response_data["annotation"]["Actionability"]
            labels.append(label_map[guidance_label])

# Convert to pandas DataFrame for easier manipulation
df = pd.DataFrame({
    "text": texts,
    "label": labels
})

print(f"Total samples: {len(df)}")
print(f"Label distribution: {df['label'].value_counts().to_dict()}")

# Create a PyTorch dataset
class GuidanceDataset(Dataset):
    def __init__(self, texts, labels, tokenizer, max_length=CONFIG['max_length']):
        self.texts = texts
        self.labels = labels
        self.tokenizer = tokenizer
        self.max_length = max_length
        
    def __len__(self):
        return len(self.texts)
    
    def __getitem__(self, idx):
        text = self.texts[idx]
        label = self.labels[idx]
        
        encoding = self.tokenizer(
            text,
            truncation=True,
            padding='max_length',
            max_length=self.max_length,
            return_tensors='pt'
        )
        
        return {
            'input_ids': encoding['input_ids'].flatten(),
            'attention_mask': encoding['attention_mask'].flatten(),
            'labels': torch.tensor(label, dtype=torch.long)
        }

# Training function
def train_model(model, train_dataloader, val_dataloader, device, epochs=CONFIG['epochs']):
    
    # # 首先冻结所有层
    # for param in model.parameters():
    #     param.requires_grad = False

    # # 然后只解冻最后几层
    # # 解冻分类器
    # for param in model.classifier.parameters():
    #     param.requires_grad = True

    # # 解冻最后N个encoder层（例如最后3层）
    # for i in range(len(model.deberta.encoder.layer) - 4, len(model.deberta.encoder.layer)):
    #     for param in model.deberta.encoder.layer[i].parameters():
    #         param.requires_grad = True
    
    optimizer = AdamW(model.parameters(), 
                     lr=CONFIG['learning_rate'], 
                     weight_decay=CONFIG['weight_decay'])
    total_steps = len(train_dataloader) * epochs
    scheduler = get_linear_schedule_with_warmup(
        optimizer, 
        num_warmup_steps=CONFIG['warmup_steps'], 
        num_training_steps=total_steps
    )
    
    # 初始化最佳指标和模型状态
    best_val_metric = 0
    best_model_state = None
    
    # 初始化早停机制
    early_stopping = None
    if CONFIG['early_stopping']:
        early_stopping = EarlyStopping(
            patience=CONFIG['patience'],
            min_delta=CONFIG['min_delta'],
            metric=CONFIG['metric_for_early_stopping']
        )
    
    for epoch in range(epochs):
        print(f"Epoch {epoch + 1}/{epochs}")
        
        # Training
        model.train()
        train_loss = 0
        train_preds = []
        train_true = []
        
        for batch in tqdm(train_dataloader):
            input_ids = batch['input_ids'].to(device)
            attention_mask = batch['attention_mask'].to(device)
            labels = batch['labels'].to(device)
            
            model.zero_grad()
            outputs = model(input_ids=input_ids, attention_mask=attention_mask, labels=labels)
            loss = outputs.loss
            train_loss += loss.item()
            
            # Get predictions for metrics
            logits = outputs.logits
            preds = torch.argmax(logits, dim=1).cpu().numpy()
            train_preds.extend(preds)
            train_true.extend(labels.cpu().numpy())
            
            loss.backward()
            torch.nn.utils.clip_grad_norm_(model.parameters(), CONFIG['gradient_clip'])
            optimizer.step()
            scheduler.step()
        
        # Calculate training metrics
        avg_train_loss = train_loss / len(train_dataloader)
        train_acc = accuracy_score(train_true, train_preds)
        train_report = classification_report(train_true, train_preds, target_names=[rev_label_map[i] for i in range(3)], output_dict=True)
        train_f1_weighted = train_report['weighted avg']['f1-score']
        train_f1_macro = train_report['macro avg']['f1-score']
        
        print(f"Training loss: {avg_train_loss:.4f}")
        print(f"Training accuracy: {train_acc:.4f}")
        print(f"Training weighted F1 score: {train_f1_weighted:.4f}")
        print(f"Training macro F1 score: {train_f1_macro:.4f}")
        
        # Validation
        model.eval()
        val_preds = []
        val_true = []
        
        with torch.no_grad():
            for batch in tqdm(val_dataloader):
                input_ids = batch['input_ids'].to(device)
                attention_mask = batch['attention_mask'].to(device)
                labels = batch['labels'].to(device)
                
                outputs = model(input_ids=input_ids, attention_mask=attention_mask)
                logits = outputs.logits
                preds = torch.argmax(logits, dim=1).cpu().numpy()
                
                val_preds.extend(preds)
                val_true.extend(labels.cpu().numpy())
        
        val_acc = accuracy_score(val_true, val_preds)
        val_report = classification_report(val_true, val_preds, target_names=[rev_label_map[i] for i in range(3)], output_dict=True)
        val_f1_weighted = val_report['weighted avg']['f1-score']
        val_f1_macro = val_report['macro avg']['f1-score']
        
        print(f"Validation accuracy: {val_acc:.4f}")
        print(f"Validation weighted F1 score: {val_f1_weighted:.4f}")
        print(f"Validation macro F1 score: {val_f1_macro:.4f}")
        print(classification_report(val_true, val_preds, target_names=[rev_label_map[i] for i in range(3)]))
        
        # 跟踪当前指标，用于早停和保存最佳模型
        current_metrics = {
            'accuracy': val_acc,
            'weighted_f1': val_f1_weighted,
            'macro_f1': val_f1_macro
        }
        
        # 跟踪最佳模型（基于与早停相同的指标）
        current_metric = current_metrics[CONFIG['metric_for_early_stopping']]
        if current_metric > best_val_metric:
            best_val_metric = current_metric
            best_model_state = model.state_dict().copy()
            
            # 如果配置了保存模型，则保存模型
            if CONFIG['save_model']:
                torch.save(model.state_dict(), CONFIG['model_path'])
                print(f"Model saved to {CONFIG['model_path']} (best {CONFIG['metric_for_early_stopping']}: {best_val_metric:.4f})")
        
        # 早停检查
        if CONFIG['early_stopping'] and early_stopping(model, current_metrics):
            print(f"Early stopping triggered after epoch {epoch + 1}")
            # 恢复最佳模型状态
            model.load_state_dict(early_stopping.best_model_state)
            break
    
    # 如果训练完成但没有触发早停，确保加载最佳模型
    if best_model_state is not None and (not CONFIG['early_stopping'] or not early_stopping.early_stop):
        model.load_state_dict(best_model_state)
        print(f"Training completed. Loaded best model with {CONFIG['metric_for_early_stopping']}: {best_val_metric:.4f}")
    
    return model

# 5-fold cross-validation
def perform_cross_validation(df, n_splits=CONFIG['n_splits'], batch_size=CONFIG['batch_size'], epochs=CONFIG['epochs']):
    kf = KFold(n_splits=n_splits, shuffle=True, random_state=CONFIG['random_state'])
    fold_results = []
    
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print(f"Using device: {device}")
    
    # Load the tokenizer
    # tokenizer = DebertaV3Tokenizer.from_pretrained(CONFIG['model_name'])
    tokenizer = DebertaV2Tokenizer.from_pretrained(CONFIG['model_name'])
    
    for fold, (train_idx, val_idx) in enumerate(kf.split(df)):
        print(f"Fold {fold + 1}/{n_splits}")
        
        train_df = df.iloc[train_idx]
        val_df = df.iloc[val_idx]
        
        # Create datasets
        train_dataset = GuidanceDataset(
            train_df['text'].tolist(),
            train_df['label'].tolist(),
            tokenizer
        )
        
        val_dataset = GuidanceDataset(
            val_df['text'].tolist(),
            val_df['label'].tolist(),
            tokenizer
        )
        
        # Create dataloaders
        train_dataloader = DataLoader(
            train_dataset,
            batch_size=batch_size,
            shuffle=True
        )
        
        val_dataloader = DataLoader(
            val_dataset,
            batch_size=batch_size
        )
        
        # Initialize model
        # model = DebertaV3ForSequenceClassification.from_pretrained(
        #     CONFIG['model_name'],
        #     num_labels=3,
        #     hidden_dropout_prob=CONFIG['dropout'],
        #     attention_probs_dropout_prob=CONFIG['dropout']
        # ).to(device)
        
        model = DebertaV2ForSequenceClassification.from_pretrained(
            CONFIG['model_name'],
            num_labels=3,
            hidden_dropout_prob=CONFIG['dropout'],
            attention_probs_dropout_prob=CONFIG['dropout']
        ).to(device)
        
        # Train and evaluate
        model = train_model(
            model,
            train_dataloader,
            val_dataloader,
            device,
            epochs=epochs
        )
        
        # Final evaluation on validation set
        model.eval()
        val_preds = []
        val_true = []
        
        with torch.no_grad():
            for batch in DataLoader(val_dataset, batch_size=batch_size):
                input_ids = batch['input_ids'].to(device)
                attention_mask = batch['attention_mask'].to(device)
                labels = batch['labels'].to(device)
                
                outputs = model(input_ids=input_ids, attention_mask=attention_mask)
                logits = outputs.logits
                preds = torch.argmax(logits, dim=1).cpu().numpy()
                
                val_preds.extend(preds)
                val_true.extend(labels.cpu().numpy())
        
        fold_acc = accuracy_score(val_true, val_preds)
        fold_report = classification_report(
            val_true, 
            val_preds, 
            target_names=[rev_label_map[i] for i in range(3)],
            output_dict=True
        )
        
        fold_f1_weighted = fold_report['weighted avg']['f1-score']
        fold_f1_macro = fold_report['macro avg']['f1-score']
        
        fold_results.append({
            'fold': fold + 1,
            'accuracy': fold_acc,
            'f1_weighted': fold_f1_weighted,
            'f1_macro': fold_f1_macro,
            'report': fold_report
        })
        
        print(f"Fold {fold + 1} accuracy: {fold_acc:.4f}")
        print(f"Fold {fold + 1} weighted F1 score: {fold_f1_weighted:.4f}")
        print(f"Fold {fold + 1} macro F1 score: {fold_f1_macro:.4f}")
    
    # Calculate average metrics across folds
    avg_acc = np.mean([result['accuracy'] for result in fold_results])
    avg_f1_weighted = np.mean([result['f1_weighted'] for result in fold_results])
    avg_f1_macro = np.mean([result['f1_macro'] for result in fold_results])
    
    print(f"Average accuracy across {n_splits} folds: {avg_acc:.4f}")
    print(f"Average weighted F1 score across {n_splits} folds: {avg_f1_weighted:.4f}")
    print(f"Average macro F1 score across {n_splits} folds: {avg_f1_macro:.4f}")
    
    # Calculate average per-class metrics
    class_metrics = {label: {'precision': [], 'recall': [], 'f1-score': []} for label in rev_label_map.values()}
    
    for result in fold_results:
        for label in rev_label_map.values():
            if label in result['report']:
                class_metrics[label]['precision'].append(result['report'][label]['precision'])
                class_metrics[label]['recall'].append(result['report'][label]['recall'])
                class_metrics[label]['f1-score'].append(result['report'][label]['f1-score'])
    
    print("\nAverage per-class metrics across all folds:")
    for label, metrics in class_metrics.items():
        print(f"{label}:")
        for metric_name, values in metrics.items():
            avg_value = np.mean(values) if values else 0
            print(f"  {metric_name}: {avg_value:.4f}")
    
    return fold_results

# Save processed data to CSV (optional)
df.to_csv('/mnt/cfs/huangzhiwei/BAE2025/projects/checkpoints_debertav3/guidance_classification_data.csv', index=False)

# Execute cross-validation with the specified parameters
print("Starting cross-validation with the following parameters:")
for param, value in CONFIG.items():
    print(f"  {param}: {value}")

fold_results = perform_cross_validation(
    df, 
    n_splits=CONFIG['n_splits'],
    batch_size=CONFIG['batch_size'],
    epochs=CONFIG['epochs']
)

# Save the fold results (optional)
with open('/mnt/cfs/huangzhiwei/BAE2025/projects/checkpoints_debertav3/fold_results.json', 'w') as f:
    json.dump(fold_results, f, indent=2)

print("Classification task completed successfully.")

  from .autonotebook import tqdm as notebook_tqdm


Loading data from /mnt/cfs/huangzhiwei/BAE2025/data/mrbench_v3_devset.json
Successfully loaded JSON data with 300 entries
Total samples: 2476
Label distribution: {2: 1310, 0: 797, 1: 369}
Starting cross-validation with the following parameters:
  data_file: /mnt/cfs/huangzhiwei/BAE2025/data/mrbench_v3_devset.json
  model_name: /mnt/cfs/huangzhiwei/BAE2025/models/deberta-v3-base
  max_length: 512
  dropout: 0.2
  learning_rate: 2e-05
  batch_size: 16
  epochs: 40
  warmup_steps: 0
  weight_decay: 0.01
  gradient_clip: 1.0
  early_stopping: True
  patience: 6
  min_delta: 0.001
  metric_for_early_stopping: macro_f1
  n_splits: 5
  random_state: 42
  save_model: False
  model_path: /mnt/cfs/huangzhiwei/BAE2025/projects/checkpoints_debertav3/best_model.pt
Using device: cuda


Some weights of DebertaV2ForSequenceClassification were not initialized from the model checkpoint at /mnt/cfs/huangzhiwei/BAE2025/models/deberta-v3-base and are newly initialized: ['classifier.bias', 'classifier.weight', 'pooler.dense.bias', 'pooler.dense.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


Fold 1/5




Epoch 1/40


100%|██████████| 124/124 [00:43<00:00,  2.85it/s]
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))


Training loss: 0.9623
Training accuracy: 0.5359
Training weighted F1 score: 0.4214
Training macro F1 score: 0.2852


100%|██████████| 31/31 [00:04<00:00,  7.71it/s]
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))


Validation accuracy: 0.6794
Validation weighted F1 score: 0.6247
Validation macro F1 score: 0.4831
                precision    recall  f1-score   support

            No       0.74      0.66      0.70       176
To some extent       0.00      0.00      0.00        71
           Yes       0.65      0.89      0.75       249

      accuracy                           0.68       496
     macro avg       0.46      0.52      0.48       496
  weighted avg       0.59      0.68      0.62       496

Epoch 2/40


100%|██████████| 124/124 [00:43<00:00,  2.88it/s]


Training loss: 0.7622
Training accuracy: 0.7091
Training weighted F1 score: 0.6501
Training macro F1 score: 0.5053


100%|██████████| 31/31 [00:03<00:00,  7.76it/s]
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))


Validation accuracy: 0.6331
Validation weighted F1 score: 0.5852
Validation macro F1 score: 0.4548
                precision    recall  f1-score   support

            No       0.53      0.95      0.68       176
To some extent       0.00      0.00      0.00        71
           Yes       0.82      0.59      0.69       249

      accuracy                           0.63       496
     macro avg       0.45      0.51      0.45       496
  weighted avg       0.60      0.63      0.59       496

EarlyStopping counter: 1 out of 6
Epoch 3/40


100%|██████████| 124/124 [00:43<00:00,  2.87it/s]


Training loss: 0.6393
Training accuracy: 0.7556
Training weighted F1 score: 0.7101
Training macro F1 score: 0.5802


100%|██████████| 31/31 [00:04<00:00,  7.75it/s]


Validation accuracy: 0.7077
Validation weighted F1 score: 0.6652
Validation macro F1 score: 0.5387
                precision    recall  f1-score   support

            No       0.67      0.82      0.74       176
To some extent       0.67      0.06      0.10        71
           Yes       0.74      0.81      0.77       249

      accuracy                           0.71       496
     macro avg       0.69      0.56      0.54       496
  weighted avg       0.70      0.71      0.67       496

Epoch 4/40


100%|██████████| 124/124 [00:43<00:00,  2.87it/s]


Training loss: 0.5784
Training accuracy: 0.7742
Training weighted F1 score: 0.7374
Training macro F1 score: 0.6173


100%|██████████| 31/31 [00:04<00:00,  7.73it/s]


Validation accuracy: 0.7339
Validation weighted F1 score: 0.7070
Validation macro F1 score: 0.6080
                precision    recall  f1-score   support

            No       0.77      0.78      0.77       176
To some extent       0.60      0.17      0.26        71
           Yes       0.72      0.86      0.79       249

      accuracy                           0.73       496
     macro avg       0.70      0.60      0.61       496
  weighted avg       0.72      0.73      0.71       496

Epoch 5/40


100%|██████████| 124/124 [00:43<00:00,  2.87it/s]


Training loss: 0.4886
Training accuracy: 0.8091
Training weighted F1 score: 0.7907
Training macro F1 score: 0.7000


100%|██████████| 31/31 [00:04<00:00,  7.73it/s]


Validation accuracy: 0.7097
Validation weighted F1 score: 0.6776
Validation macro F1 score: 0.5672
                precision    recall  f1-score   support

            No       0.66      0.84      0.74       176
To some extent       0.57      0.11      0.19        71
           Yes       0.76      0.79      0.77       249

      accuracy                           0.71       496
     macro avg       0.66      0.58      0.57       496
  weighted avg       0.70      0.71      0.68       496

EarlyStopping counter: 1 out of 6
Epoch 6/40


100%|██████████| 124/124 [00:43<00:00,  2.85it/s]


Training loss: 0.4183
Training accuracy: 0.8399
Training weighted F1 score: 0.8272
Training macro F1 score: 0.7479


100%|██████████| 31/31 [00:04<00:00,  7.71it/s]


Validation accuracy: 0.6915
Validation weighted F1 score: 0.6836
Validation macro F1 score: 0.6082
                precision    recall  f1-score   support

            No       0.67      0.84      0.75       176
To some extent       0.40      0.30      0.34        71
           Yes       0.77      0.70      0.74       249

      accuracy                           0.69       496
     macro avg       0.62      0.61      0.61       496
  weighted avg       0.69      0.69      0.68       496

EarlyStopping counter: 2 out of 6
Epoch 7/40


100%|██████████| 124/124 [00:43<00:00,  2.87it/s]


Training loss: 0.3591
Training accuracy: 0.8631
Training weighted F1 score: 0.8568
Training macro F1 score: 0.7931


100%|██████████| 31/31 [00:04<00:00,  7.73it/s]


Validation accuracy: 0.7218
Validation weighted F1 score: 0.7034
Validation macro F1 score: 0.6138
                precision    recall  f1-score   support

            No       0.77      0.73      0.75       176
To some extent       0.48      0.23      0.31        71
           Yes       0.72      0.86      0.78       249

      accuracy                           0.72       496
     macro avg       0.66      0.60      0.61       496
  weighted avg       0.70      0.72      0.70       496

Epoch 8/40


100%|██████████| 124/124 [00:43<00:00,  2.87it/s]


Training loss: 0.3425
Training accuracy: 0.8763
Training weighted F1 score: 0.8711
Training macro F1 score: 0.8157


100%|██████████| 31/31 [00:03<00:00,  7.75it/s]


Validation accuracy: 0.6956
Validation weighted F1 score: 0.6948
Validation macro F1 score: 0.6230
                precision    recall  f1-score   support

            No       0.70      0.85      0.77       176
To some extent       0.36      0.37      0.36        71
           Yes       0.81      0.68      0.74       249

      accuracy                           0.70       496
     macro avg       0.62      0.63      0.62       496
  weighted avg       0.71      0.70      0.69       496

Epoch 9/40


100%|██████████| 124/124 [00:43<00:00,  2.87it/s]


Training loss: 0.2752
Training accuracy: 0.9035
Training weighted F1 score: 0.9021
Training macro F1 score: 0.8626


100%|██████████| 31/31 [00:04<00:00,  7.68it/s]


Validation accuracy: 0.7077
Validation weighted F1 score: 0.6736
Validation macro F1 score: 0.5542
                precision    recall  f1-score   support

            No       0.68      0.88      0.76       176
To some extent       0.35      0.08      0.14        71
           Yes       0.76      0.77      0.76       249

      accuracy                           0.71       496
     macro avg       0.60      0.58      0.55       496
  weighted avg       0.67      0.71      0.67       496

EarlyStopping counter: 1 out of 6
Epoch 10/40


100%|██████████| 124/124 [00:43<00:00,  2.86it/s]


Training loss: 0.2292
Training accuracy: 0.9237
Training weighted F1 score: 0.9207
Training macro F1 score: 0.8850


100%|██████████| 31/31 [00:04<00:00,  7.73it/s]


Validation accuracy: 0.6774
Validation weighted F1 score: 0.6571
Validation macro F1 score: 0.5644
                precision    recall  f1-score   support

            No       0.61      0.88      0.72       176
To some extent       0.43      0.17      0.24        71
           Yes       0.79      0.68      0.73       249

      accuracy                           0.68       496
     macro avg       0.61      0.58      0.56       496
  weighted avg       0.67      0.68      0.66       496

EarlyStopping counter: 2 out of 6
Epoch 11/40


100%|██████████| 124/124 [00:43<00:00,  2.85it/s]


Training loss: 0.1880
Training accuracy: 0.9354
Training weighted F1 score: 0.9337
Training macro F1 score: 0.9041


100%|██████████| 31/31 [00:04<00:00,  7.73it/s]


Validation accuracy: 0.6512
Validation weighted F1 score: 0.6479
Validation macro F1 score: 0.5859
                precision    recall  f1-score   support

            No       0.60      0.88      0.71       176
To some extent       0.37      0.35      0.36        71
           Yes       0.84      0.58      0.69       249

      accuracy                           0.65       496
     macro avg       0.60      0.60      0.59       496
  weighted avg       0.69      0.65      0.65       496

EarlyStopping counter: 3 out of 6
Epoch 12/40


100%|██████████| 124/124 [00:43<00:00,  2.86it/s]


Training loss: 0.1585
Training accuracy: 0.9525
Training weighted F1 score: 0.9521
Training macro F1 score: 0.9316


100%|██████████| 31/31 [00:04<00:00,  7.73it/s]


Validation accuracy: 0.6935
Validation weighted F1 score: 0.6918
Validation macro F1 score: 0.6192
                precision    recall  f1-score   support

            No       0.71      0.80      0.75       176
To some extent       0.37      0.35      0.36        71
           Yes       0.77      0.71      0.74       249

      accuracy                           0.69       496
     macro avg       0.62      0.62      0.62       496
  weighted avg       0.69      0.69      0.69       496

EarlyStopping counter: 4 out of 6
Epoch 13/40


100%|██████████| 124/124 [00:43<00:00,  2.86it/s]


Training loss: 0.1638
Training accuracy: 0.9495
Training weighted F1 score: 0.9491
Training macro F1 score: 0.9264


100%|██████████| 31/31 [00:04<00:00,  7.72it/s]


Validation accuracy: 0.6633
Validation weighted F1 score: 0.6617
Validation macro F1 score: 0.6061
                precision    recall  f1-score   support

            No       0.62      0.86      0.72       176
To some extent       0.40      0.41      0.40        71
           Yes       0.83      0.59      0.69       249

      accuracy                           0.66       496
     macro avg       0.62      0.62      0.61       496
  weighted avg       0.69      0.66      0.66       496

EarlyStopping counter: 5 out of 6
Epoch 14/40


100%|██████████| 124/124 [00:43<00:00,  2.86it/s]


Training loss: 0.1588
Training accuracy: 0.9530
Training weighted F1 score: 0.9525
Training macro F1 score: 0.9301


100%|██████████| 31/31 [00:04<00:00,  7.73it/s]


Validation accuracy: 0.6875
Validation weighted F1 score: 0.6658
Validation macro F1 score: 0.5722
                precision    recall  f1-score   support

            No       0.66      0.83      0.73       176
To some extent       0.44      0.17      0.24        71
           Yes       0.74      0.73      0.74       249

      accuracy                           0.69       496
     macro avg       0.61      0.58      0.57       496
  weighted avg       0.67      0.69      0.67       496

EarlyStopping counter: 6 out of 6
Early stopping triggered after epoch 14


Some weights of DebertaV2ForSequenceClassification were not initialized from the model checkpoint at /mnt/cfs/huangzhiwei/BAE2025/models/deberta-v3-base and are newly initialized: ['classifier.bias', 'classifier.weight', 'pooler.dense.bias', 'pooler.dense.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


Fold 1 accuracy: 0.6875
Fold 1 weighted F1 score: 0.6658
Fold 1 macro F1 score: 0.5722
Fold 2/5




Epoch 1/40


100%|██████████| 124/124 [00:43<00:00,  2.86it/s]


Training loss: 0.9395
Training accuracy: 0.5704
Training weighted F1 score: 0.4809
Training macro F1 score: 0.3520


100%|██████████| 31/31 [00:03<00:00,  7.76it/s]
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))


Validation accuracy: 0.7010
Validation weighted F1 score: 0.6550
Validation macro F1 score: 0.4949
                precision    recall  f1-score   support

            No       0.61      0.82      0.70       163
To some extent       0.00      0.00      0.00        64
           Yes       0.77      0.79      0.78       268

      accuracy                           0.70       495
     macro avg       0.46      0.54      0.49       495
  weighted avg       0.62      0.70      0.65       495

Epoch 2/40


100%|██████████| 124/124 [00:43<00:00,  2.86it/s]


Training loss: 0.7037
Training accuracy: 0.7183
Training weighted F1 score: 0.6704
Training macro F1 score: 0.5430


100%|██████████| 31/31 [00:03<00:00,  7.89it/s]
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))


Validation accuracy: 0.6788
Validation weighted F1 score: 0.6371
Validation macro F1 score: 0.4827
                precision    recall  f1-score   support

            No       0.55      0.93      0.69       163
To some extent       0.00      0.00      0.00        64
           Yes       0.84      0.69      0.76       268

      accuracy                           0.68       495
     macro avg       0.46      0.54      0.48       495
  weighted avg       0.64      0.68      0.64       495

EarlyStopping counter: 1 out of 6
Epoch 3/40


100%|██████████| 124/124 [00:43<00:00,  2.87it/s]


Training loss: 0.6403
Training accuracy: 0.7552
Training weighted F1 score: 0.7160
Training macro F1 score: 0.5994


100%|██████████| 31/31 [00:03<00:00,  7.81it/s]


Validation accuracy: 0.6808
Validation weighted F1 score: 0.6826
Validation macro F1 score: 0.5804
                precision    recall  f1-score   support

            No       0.68      0.86      0.76       163
To some extent       0.23      0.25      0.24        64
           Yes       0.82      0.68      0.74       268

      accuracy                           0.68       495
     macro avg       0.58      0.59      0.58       495
  weighted avg       0.70      0.68      0.68       495

Epoch 4/40


100%|██████████| 124/124 [00:43<00:00,  2.87it/s]


Training loss: 0.5396
Training accuracy: 0.7961
Training weighted F1 score: 0.7726
Training macro F1 score: 0.6797


100%|██████████| 31/31 [00:03<00:00,  7.81it/s]


Validation accuracy: 0.7434
Validation weighted F1 score: 0.7074
Validation macro F1 score: 0.5615
                precision    recall  f1-score   support

            No       0.71      0.82      0.76       163
To some extent       0.33      0.06      0.11        64
           Yes       0.78      0.86      0.82       268

      accuracy                           0.74       495
     macro avg       0.61      0.58      0.56       495
  weighted avg       0.70      0.74      0.71       495

EarlyStopping counter: 1 out of 6
Epoch 5/40


100%|██████████| 124/124 [00:43<00:00,  2.88it/s]


Training loss: 0.5057
Training accuracy: 0.8087
Training weighted F1 score: 0.8003
Training macro F1 score: 0.7300


100%|██████████| 31/31 [00:03<00:00,  7.81it/s]


Validation accuracy: 0.6949
Validation weighted F1 score: 0.6850
Validation macro F1 score: 0.5753
                precision    recall  f1-score   support

            No       0.64      0.88      0.74       163
To some extent       0.27      0.19      0.22        64
           Yes       0.83      0.70      0.76       268

      accuracy                           0.69       495
     macro avg       0.58      0.59      0.58       495
  weighted avg       0.70      0.69      0.69       495

EarlyStopping counter: 2 out of 6
Epoch 6/40


100%|██████████| 124/124 [00:43<00:00,  2.87it/s]


Training loss: 0.4177
Training accuracy: 0.8476
Training weighted F1 score: 0.8397
Training macro F1 score: 0.7807


100%|██████████| 31/31 [00:03<00:00,  7.80it/s]


Validation accuracy: 0.6121
Validation weighted F1 score: 0.5997
Validation macro F1 score: 0.5044
                precision    recall  f1-score   support

            No       0.53      0.94      0.68       163
To some extent       0.23      0.16      0.19        64
           Yes       0.88      0.52      0.65       268

      accuracy                           0.61       495
     macro avg       0.54      0.54      0.50       495
  weighted avg       0.68      0.61      0.60       495

EarlyStopping counter: 3 out of 6
Epoch 7/40


100%|██████████| 124/124 [00:43<00:00,  2.87it/s]


Training loss: 0.3843
Training accuracy: 0.8556
Training weighted F1 score: 0.8499
Training macro F1 score: 0.7933


100%|██████████| 31/31 [00:03<00:00,  7.78it/s]


Validation accuracy: 0.6909
Validation weighted F1 score: 0.6814
Validation macro F1 score: 0.5726
                precision    recall  f1-score   support

            No       0.64      0.90      0.75       163
To some extent       0.26      0.19      0.22        64
           Yes       0.84      0.68      0.75       268

      accuracy                           0.69       495
     macro avg       0.58      0.59      0.57       495
  weighted avg       0.70      0.69      0.68       495

EarlyStopping counter: 4 out of 6
Epoch 8/40


100%|██████████| 124/124 [00:43<00:00,  2.87it/s]


Training loss: 0.3161
Training accuracy: 0.8900
Training weighted F1 score: 0.8894
Training macro F1 score: 0.8485


100%|██████████| 31/31 [00:03<00:00,  7.78it/s]


Validation accuracy: 0.6848
Validation weighted F1 score: 0.6742
Validation macro F1 score: 0.5635
                precision    recall  f1-score   support

            No       0.62      0.90      0.74       163
To some extent       0.26      0.17      0.21        64
           Yes       0.84      0.68      0.75       268

      accuracy                           0.68       495
     macro avg       0.57      0.58      0.56       495
  weighted avg       0.69      0.68      0.67       495

EarlyStopping counter: 5 out of 6
Epoch 9/40


100%|██████████| 124/124 [00:43<00:00,  2.87it/s]


Training loss: 0.2440
Training accuracy: 0.9233
Training weighted F1 score: 0.9221
Training macro F1 score: 0.8933


100%|██████████| 31/31 [00:03<00:00,  7.77it/s]


Validation accuracy: 0.7030
Validation weighted F1 score: 0.6856
Validation macro F1 score: 0.5609
                precision    recall  f1-score   support

            No       0.65      0.86      0.74       163
To some extent       0.24      0.12      0.16        64
           Yes       0.81      0.75      0.78       268

      accuracy                           0.70       495
     macro avg       0.57      0.58      0.56       495
  weighted avg       0.68      0.70      0.69       495

EarlyStopping counter: 6 out of 6
Early stopping triggered after epoch 9


Some weights of DebertaV2ForSequenceClassification were not initialized from the model checkpoint at /mnt/cfs/huangzhiwei/BAE2025/models/deberta-v3-base and are newly initialized: ['classifier.bias', 'classifier.weight', 'pooler.dense.bias', 'pooler.dense.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


Fold 2 accuracy: 0.7030
Fold 2 weighted F1 score: 0.6856
Fold 2 macro F1 score: 0.5609
Fold 3/5




Epoch 1/40


100%|██████████| 124/124 [00:43<00:00,  2.86it/s]


Training loss: 0.9617
Training accuracy: 0.5462
Training weighted F1 score: 0.4835
Training macro F1 score: 0.3759


100%|██████████| 31/31 [00:04<00:00,  7.74it/s]
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))


Validation accuracy: 0.6808
Validation weighted F1 score: 0.6321
Validation macro F1 score: 0.4846
                precision    recall  f1-score   support

            No       0.57      0.89      0.70       145
To some extent       0.00      0.00      0.00        70
           Yes       0.77      0.74      0.76       280

      accuracy                           0.68       495
     macro avg       0.45      0.54      0.48       495
  weighted avg       0.60      0.68      0.63       495

Epoch 2/40


100%|██████████| 124/124 [00:43<00:00,  2.86it/s]


Training loss: 0.7236
Training accuracy: 0.7077
Training weighted F1 score: 0.6532
Training macro F1 score: 0.5155


100%|██████████| 31/31 [00:03<00:00,  7.76it/s]
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))


Validation accuracy: 0.7434
Validation weighted F1 score: 0.6875
Validation macro F1 score: 0.5280
                precision    recall  f1-score   support

            No       0.69      0.86      0.76       145
To some extent       0.00      0.00      0.00        70
           Yes       0.78      0.87      0.82       280

      accuracy                           0.74       495
     macro avg       0.49      0.58      0.53       495
  weighted avg       0.64      0.74      0.69       495

Epoch 3/40


100%|██████████| 124/124 [00:43<00:00,  2.86it/s]


Training loss: 0.6205
Training accuracy: 0.7617
Training weighted F1 score: 0.7280
Training macro F1 score: 0.6162


100%|██████████| 31/31 [00:03<00:00,  7.77it/s]


Validation accuracy: 0.7374
Validation weighted F1 score: 0.6894
Validation macro F1 score: 0.5404
                precision    recall  f1-score   support

            No       0.68      0.83      0.75       145
To some extent       0.40      0.03      0.05        70
           Yes       0.77      0.86      0.82       280

      accuracy                           0.74       495
     macro avg       0.62      0.58      0.54       495
  weighted avg       0.69      0.74      0.69       495

Epoch 4/40


100%|██████████| 124/124 [00:43<00:00,  2.87it/s]


Training loss: 0.5430
Training accuracy: 0.7940
Training weighted F1 score: 0.7666
Training macro F1 score: 0.6650


100%|██████████| 31/31 [00:03<00:00,  7.76it/s]


Validation accuracy: 0.7333
Validation weighted F1 score: 0.7018
Validation macro F1 score: 0.5744
                precision    recall  f1-score   support

            No       0.68      0.86      0.76       145
To some extent       0.37      0.10      0.16        70
           Yes       0.79      0.83      0.81       280

      accuracy                           0.73       495
     macro avg       0.61      0.59      0.57       495
  weighted avg       0.70      0.73      0.70       495

Epoch 5/40


100%|██████████| 124/124 [00:43<00:00,  2.86it/s]


Training loss: 0.4762
Training accuracy: 0.8163
Training weighted F1 score: 0.8041
Training macro F1 score: 0.7259


100%|██████████| 31/31 [00:03<00:00,  7.77it/s]


Validation accuracy: 0.7455
Validation weighted F1 score: 0.7060
Validation macro F1 score: 0.5698
                precision    recall  f1-score   support

            No       0.69      0.86      0.76       145
To some extent       0.45      0.07      0.12        70
           Yes       0.79      0.85      0.82       280

      accuracy                           0.75       495
     macro avg       0.64      0.60      0.57       495
  weighted avg       0.71      0.75      0.71       495

EarlyStopping counter: 1 out of 6
Epoch 6/40


100%|██████████| 124/124 [00:43<00:00,  2.87it/s]


Training loss: 0.4137
Training accuracy: 0.8491
Training weighted F1 score: 0.8393
Training macro F1 score: 0.7713


100%|██████████| 31/31 [00:03<00:00,  7.78it/s]


Validation accuracy: 0.7131
Validation weighted F1 score: 0.6909
Validation macro F1 score: 0.5772
                precision    recall  f1-score   support

            No       0.63      0.89      0.74       145
To some extent       0.37      0.14      0.21        70
           Yes       0.81      0.76      0.79       280

      accuracy                           0.71       495
     macro avg       0.60      0.60      0.58       495
  weighted avg       0.70      0.71      0.69       495

Epoch 7/40


100%|██████████| 124/124 [00:43<00:00,  2.87it/s]


Training loss: 0.3771
Training accuracy: 0.8693
Training weighted F1 score: 0.8647
Training macro F1 score: 0.8110


 10%|▉         | 3/31 [00:00<00:04,  5.83it/s]


KeyboardInterrupt: 