# BERT Fine-Tuning

This notebook fine-tunes a DistilBERT model (`distilbert-base-uncased`) for three-class text classification using HuggingFace Transformers.

In [17]:
import sys
import os
# Add project root to Python path
project_root = os.path.dirname(os.getcwd())
sys.path.insert(0, project_root)

import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader
from transformers import get_scheduler
import numpy as np
from src.data_utils import load_text_classification_data
from src.text_preprocess import basic_clean
from src.model.bert_model import BertClassifier, get_tokenizer
from src.train_nn import train_epoch_with_scheduler, eval_epoch_bert
from src.evaluate import evaluate_classification

# Run on the GPU when CUDA is usable, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device('cuda')
else:
    device = torch.device('cpu')
print(f"Using device: {device}")


Using device: cuda


## 1. Load Data

In [18]:
# Read the train / val / test splits. Only the training call's label maps are kept;
# the maps returned for the other two splits are discarded.
train_texts, train_labels, label2id, id2label = load_text_classification_data('train')
val_texts, val_labels, _, _ = load_text_classification_data('val')
test_texts, test_labels, _, _ = load_text_classification_data('test')


def _strip_each(texts):
    """Return a new list with leading/trailing whitespace removed from every text."""
    return [raw.strip() for raw in texts]


# BERT consumes the raw text with only surrounding whitespace removed;
# basic_clean remains available separately for the TF-IDF / CNN pipelines.
train_texts_clean = _strip_each(train_texts)
val_texts_clean = _strip_each(val_texts)
test_texts_clean = _strip_each(test_texts)

print(f"Training samples: {len(train_texts_clean)}")
print(f"Validation samples: {len(val_texts_clean)}")
print(f"Test samples: {len(test_texts_clean)}")
print(f"Label mapping: {label2id}")


Training samples: 19782
Validation samples: 4239
Test samples: 4240
Label mapping: {'high': 0, 'low': 1, 'medium': 2}


## 2. Initialize Tokenizer

In [19]:
from transformers import AutoTokenizer, AutoModelForSequenceClassification

# Local checkpoint directory (the one confirmed in the previous step); loaded offline below.
model_path = "/root/AI_support_tickets/src/model/distilbert-base-uncased"

# Short checkpoint name used for logging and for the saved-checkpoint metadata.
# Derived from model_path so this cell works on a fresh kernel instead of relying on a
# `model_name` variable left over from an earlier session.
model_name = os.path.basename(os.path.normpath(model_path))

# local_files_only=True: never contact the Hub, read the files from model_path only.
tokenizer = AutoTokenizer.from_pretrained(model_path, local_files_only=True)
# NOTE: no AutoModelForSequenceClassification is instantiated here. The classifier that is
# actually trained is the BertClassifier built in the "Initialize Model" cell, which rebinds
# `model`, so loading a second copy of the weights at this point was wasted work.

# Raised from 256 to 512 to keep more of each text (average text length is 411 characters).
# NOTE(review): 512 is presumably also the checkpoint's maximum sequence length - confirm
# against its config before raising this further.
max_length = 512

print(f"Tokenizer loaded: {model_name}")
print(f"Max length: {max_length}")

Tokenizer loaded: distilbert-base-uncased
Max length: 512


## 3. Create Dataset Class

In [20]:
class BertDataset(Dataset):
    """Map-style dataset that tokenizes one text per lookup for BERT classification.

    Args:
        texts: Indexable collection of input texts (each item is passed through ``str``).
        labels: Indexable collection of integer class ids, aligned with ``texts``.
        tokenizer: Callable invoked as
            ``tokenizer(text, truncation=..., padding=..., max_length=..., return_tensors='pt')``
            and returning a mapping with ``input_ids`` and ``attention_mask`` tensors.
        max_length: Length every sequence is truncated / padded to (default raised
            from 256 to 512).
    """

    def __init__(self, texts, labels, tokenizer, max_length=512):
        self.texts, self.labels = texts, labels
        self.tokenizer = tokenizer
        self.max_length = max_length

    def __len__(self):
        """Number of samples, taken from the text collection."""
        return len(self.texts)

    def __getitem__(self, idx):
        """Return the tokenized sample at ``idx`` as a dict of tensors.

        Keys: ``input_ids``, ``attention_mask`` (leading batch axis of the tokenizer
        output removed) and ``labels`` (scalar ``torch.long`` tensor).
        """
        raw_text, target = str(self.texts[idx]), self.labels[idx]

        encoded = self.tokenizer(
            raw_text,
            truncation=True,
            padding='max_length',
            max_length=self.max_length,
            return_tensors='pt',
        )

        # The tokenizer output carries a leading batch axis of size 1; drop it so the
        # DataLoader can stack samples itself.
        sample = {
            field: encoded[field].squeeze(0)
            for field in ('input_ids', 'attention_mask')
        }
        sample['labels'] = torch.tensor(target, dtype=torch.long)
        return sample

## 4. Create DataLoaders

In [21]:
# Mini-batch size shared by all three loaders.
batch_size = 16


def _build_split(texts, labels, shuffle):
    """Wrap one split in a BertDataset and return it with a DataLoader over it."""
    dataset = BertDataset(texts, labels, tokenizer, max_length=max_length)
    loader = DataLoader(dataset, batch_size=batch_size, shuffle=shuffle)
    return dataset, loader


# Only the training split is shuffled; val/test keep their order.
train_dataset, train_loader = _build_split(train_texts_clean, train_labels, True)
val_dataset, val_loader = _build_split(val_texts_clean, val_labels, False)
test_dataset, test_loader = _build_split(test_texts_clean, test_labels, False)

print(f"Train batches: {len(train_loader)}")
print(f"Val batches: {len(val_loader)}")
print(f"Test batches: {len(test_loader)}")

Train batches: 1237
Val batches: 265
Test batches: 265


## 5. Initialize Model

In [22]:
# Initialize the BERT classifier from the local checkpoint and move it to the target device.
# The number of output classes is taken from the label mapping produced by the data loader
# instead of a hard-coded 3, so the head cannot drift out of sync with the labels.
num_classes = len(label2id)

model = BertClassifier(
    model_name=model_path,
    num_classes=num_classes,
    dropout=0.3,
    freeze_bert=False
).to(device)

# Report total vs. trainable parameter counts (they differ only if some weights are frozen).
total_params = sum(p.numel() for p in model.parameters())
trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
print(f"Model parameters: {total_params:,}")
print(f"Trainable parameters: {trainable_params:,}")

Model parameters: 66,365,187
Trainable parameters: 66,365,187


## 6. Setup Training

In [23]:
# Setup optimizer, class-weighted loss and learning-rate schedule.

# Peak learning rate, lowered from 5e-5 to 2e-5 for more stable training.
learning_rate = 2e-5
optimizer = torch.optim.AdamW(model.parameters(), lr=learning_rate, weight_decay=0.01)

# "Balanced" class weights: n_samples / (n_classes * n_samples_in_class), so rarer classes
# weigh more in the loss. minlength keeps one slot per known label even if the highest
# label id never occurs in the training split.
class_counts = np.bincount(train_labels, minlength=len(label2id))
empty_classes = np.flatnonzero(class_counts == 0)
if empty_classes.size:
    # Dividing by a zero count would silently yield an infinite weight.
    raise ValueError(
        f"No training samples for class id(s) {empty_classes.tolist()}; "
        "cannot compute class weights"
    )
class_weights = (len(train_labels) / (len(class_counts) * class_counts)).astype(np.float32)
class_weights_tensor = torch.tensor(class_weights, device=device)
criterion = nn.CrossEntropyLoss(weight=class_weights_tensor)

# Learning-rate scheduler with warmup.
# Epochs raised from 3 to 10; the scheduler is stepped per batch, hence steps = batches * epochs.
num_epochs = 10
num_training_steps = len(train_loader) * num_epochs
num_warmup_steps = int(0.1 * num_training_steps)  # first 10% of the steps are warmup

scheduler = get_scheduler(
    "cosine",  # switched from linear to cosine so the LR does not reach zero too early
    optimizer,
    num_warmup_steps=num_warmup_steps,
    num_training_steps=num_training_steps
)

print(f"Training for {num_epochs} epochs")
print(f"Total steps: {num_training_steps}")
print(f"Warmup steps: {num_warmup_steps}")
# Creating the scheduler already applies the warmup factor for step 0, so the optimizer's
# current LR reads 0.00e+00 here; the configured peak is printed alongside it.
print(f"Initial learning rate: {optimizer.param_groups[0]['lr']:.2e}")
print(f"Peak learning rate: {learning_rate:.2e}")
print(f"Class weights: {class_weights.tolist()}")


Training for 10 epochs
Total steps: 12370
Warmup steps: 1237
Initial learning rate: 0.00e+00
Class weights: [0.8565860986709595, 1.630967140197754, 0.8200472593307495]


## 7. Custom Training Loop with Scheduler

In [24]:
# Training and evaluation functions are now imported from src.train_nn
# train_epoch_with_scheduler and eval_epoch_bert are available

## 8. Training Loop

In [25]:
# Training loop with early stopping on validation accuracy.
import copy  # cell-local import, used only to snapshot the best weights below

best_val_acc = 0
best_val_loss = float('inf')  # validation loss recorded at the best-accuracy epoch
patience = 3                  # epochs without improvement tolerated before stopping
patience_counter = 0
best_model_state = None

print("Starting training with Early Stopping (patience=3)...")
print("=" * 60)

for epoch in range(num_epochs):
    # Train for one epoch (the helper also receives the LR scheduler)
    train_loss, train_acc = train_epoch_with_scheduler(
        train_loader, model, criterion, optimizer, scheduler, device
    )

    # Validate
    val_loss, val_acc, _, _ = eval_epoch_bert(val_loader, model, criterion, device)

    print(f"Epoch {epoch+1}/{num_epochs}")
    print(f"  Train Loss: {train_loss:.4f}, Train Acc: {train_acc:.4f}")
    print(f"  Val Loss: {val_loss:.4f}, Val Acc: {val_acc:.4f}")
    print(f"  Learning Rate: {scheduler.get_last_lr()[0]:.2e}")

    # Early stopping logic: only a strict accuracy improvement resets the counter
    if val_acc > best_val_acc:
        best_val_acc = val_acc
        best_val_loss = val_loss
        patience_counter = 0
        # state_dict() returns references to the live parameter tensors and dict.copy() is
        # shallow, so a plain .copy() would keep changing as training continues and the
        # "best" snapshot would end up equal to the final weights. Deep-copy it instead.
        best_model_state = copy.deepcopy(model.state_dict())
        print(f"  ✓ New best validation accuracy: {best_val_acc:.4f}")
    else:
        patience_counter += 1
        print(f"  - No improvement (patience: {patience_counter}/{patience})")

        if patience_counter >= patience:
            print(f"\n⚠ Early stopping triggered at epoch {epoch+1}")
            print(f"Best validation accuracy: {best_val_acc:.4f} (loss: {best_val_loss:.4f})")
            break

    print()

# Restore the weights from the best-accuracy epoch (if any epoch improved on the initial 0)
if best_model_state is not None:
    model.load_state_dict(best_model_state)
    print(f"\n✓ Loaded best model with validation accuracy: {best_val_acc:.4f}")
else:
    print("\n⚠ No improvement found, using final model")

Starting training with Early Stopping (patience=3)...
Epoch 1/10
  Train Loss: 1.0853, Train Acc: 0.3838
  Val Loss: 1.0740, Val Acc: 0.4588
  Learning Rate: 2.00e-05
  ✓ New best validation accuracy: 0.4588

Epoch 2/10
  Train Loss: 0.9829, Train Acc: 0.4906
  Val Loss: 0.9488, Val Acc: 0.5188
  Learning Rate: 1.94e-05
  ✓ New best validation accuracy: 0.5188

Epoch 3/10
  Train Loss: 0.7683, Train Acc: 0.6411
  Val Loss: 0.9079, Val Acc: 0.6112
  Learning Rate: 1.77e-05
  ✓ New best validation accuracy: 0.6112

Epoch 4/10
  Train Loss: 0.5464, Train Acc: 0.7621
  Val Loss: 0.8580, Val Acc: 0.6719
  Learning Rate: 1.50e-05
  ✓ New best validation accuracy: 0.6719

Epoch 5/10
  Train Loss: 0.3686, Train Acc: 0.8453
  Val Loss: 0.9080, Val Acc: 0.7087
  Learning Rate: 1.17e-05
  ✓ New best validation accuracy: 0.7087

Epoch 6/10
  Train Loss: 0.2364, Train Acc: 0.9030
  Val Loss: 1.0805, Val Acc: 0.7318
  Learning Rate: 8.26e-06
  ✓ New best validation accuracy: 0.7318

Epoch 7/10
  Tra

## 9. Evaluate on Validation Set

In [26]:
# Score the current model on the validation split and summarise the predictions.
val_loss, val_acc, val_pred, val_true = eval_epoch_bert(
    val_loader, model, criterion, device
)
val_results = evaluate_classification(val_true, val_pred)

# One print call, newline-separated: emits the same lines as five separate prints.
print(
    "Validation Results:",
    f"Accuracy: {val_results['accuracy']:.4f}",
    f"F1 Macro: {val_results['f1_macro']:.4f}",
    "\nClassification Report:",
    val_results['report'],
    sep="\n",
)

Validation Results:
Accuracy: 0.7521
F1 Macro: 0.7466

Classification Report:
              precision    recall  f1-score   support

           0       0.77      0.77      0.77      1615
           1       0.78      0.67      0.72       855
           2       0.73      0.78      0.75      1769

    accuracy                           0.75      4239
   macro avg       0.76      0.74      0.75      4239
weighted avg       0.75      0.75      0.75      4239



## 10. Evaluate on Test Set

In [27]:
# Score the current model on the held-out test split and summarise the predictions.
test_loss, test_acc, test_pred, test_true = eval_epoch_bert(
    test_loader, model, criterion, device
)
test_results = evaluate_classification(test_true, test_pred)

# One print call, newline-separated: emits the same lines as five separate prints.
print(
    "Test Results:",
    f"Accuracy: {test_results['accuracy']:.4f}",
    f"F1 Macro: {test_results['f1_macro']:.4f}",
    "\nClassification Report:",
    test_results['report'],
    sep="\n",
)

Test Results:
Accuracy: 0.7580
F1 Macro: 0.7512

Classification Report:
              precision    recall  f1-score   support

           0       0.78      0.78      0.78      1604
           1       0.76      0.67      0.72       876
           2       0.74      0.78      0.76      1760

    accuracy                           0.76      4240
   macro avg       0.76      0.74      0.75      4240
weighted avg       0.76      0.76      0.76      4240



## 11. Save Model

In [28]:
# Save the fine-tuned weights together with the metadata needed to rebuild the classifier.
save_dir = '../src/model'  # relative to the notebook's working directory
save_path = '../src/model/bert_finetuned.pt'
os.makedirs(save_dir, exist_ok=True)

torch.save({
    'model_state_dict': model.state_dict(),
    # Derived from model_path: no `model_name` variable is defined by the notebook's cells,
    # so referencing one here fails on a fresh kernel.
    'model_name': os.path.basename(os.path.normpath(model_path)),
    'label2id': label2id,
    'id2label': id2label,
    # Taken from the label mapping so the stored value always matches the saved labels.
    'num_classes': len(label2id)
}, save_path)

print(f"Model saved to {save_path}")

Model saved to ../src/model/bert_finetuned.pt
