In [None]:
!pip install -q transformers torch datasets scikit-learn tqdm accelerate optuna

In [None]:
import logging
import sys
import os

# Location of the log file (Kaggle working directory).
LOG_DIR = "/kaggle/working"
LOG_FILE = os.path.join(LOG_DIR, "train.log")

# Route every INFO-and-above record both to LOG_FILE and to stdout.
logging.basicConfig(
    level=logging.INFO,
    format="%(asctime)s | %(message)s",
    handlers=[
        # mode="w" truncates the log file each time this cell is executed.
        logging.FileHandler(LOG_FILE, mode="w", encoding="utf-8"),
        logging.StreamHandler(sys.stdout),
    ],
    force=True,   # IMPORTANT: remove handlers already attached to the root logger so this configuration takes effect
)

logger = logging.getLogger(__name__)

In [None]:
import os
import random
import numpy as np
import torch

# Setup: create the "models" and "outputs" folders relative to the current
# working directory; exist_ok=True makes re-running this cell harmless.
os.makedirs("models", exist_ok=True)
os.makedirs("outputs", exist_ok=True)

def set_seed(seed=42):
    """Seed Python's `random`, NumPy and PyTorch so that results are reproducible.

    Args:
        seed: Integer seed applied to every generator (default 42).
    """
    # Seeding order: random, then numpy, then torch.
    for seed_fn in (random.seed, np.random.seed, torch.manual_seed):
        seed_fn(seed)
    # Also seed every CUDA device when CUDA is available.
    if torch.cuda.is_available():
        torch.cuda.manual_seed_all(seed)

# Seed the random number generators for this session.
set_seed(42)

# Use the GPU when CUDA is available, otherwise fall back to the CPU.
device = "cuda" if torch.cuda.is_available() else "cpu"
logger.info(f"‚úì Setup complete. Device: {device}")
if torch.cuda.is_available():
    # Log the name of CUDA device 0.
    logger.info(f"  GPU: {torch.cuda.get_device_name(0)}")

# Thiết kế chính

- TaskConfig: Encapsulate toàn bộ thông tin về label schema
- Dễ dàng thêm tác vụ mới bằng cách tạo config mới
- ignore_labels: Linh hoạt định nghĩa labels không tính trong evaluation

In [None]:
from dataclasses import dataclass
from typing import List, Dict
import torch

@dataclass
class TaskConfig:
    """Label schema and evaluation settings for one token-classification task.

    Attributes:
        task_name: Short identifier of the task (e.g. "punctuation").
        labels: Label strings. When built with `create`, a label's position
            in this list is its integer id.
        label2id: Mapping label string -> integer id.
        id2label: Mapping integer id -> label string.
        num_labels: Number of labels.
        ignore_labels: Labels skipped during evaluation. ``None`` is
            normalised to an empty list.
    """
    task_name: str
    labels: List[str]
    label2id: Dict[str, int]
    id2label: Dict[int, str]
    num_labels: int
    ignore_labels: List[str] = None  # None means "ignore nothing"; normalised in __post_init__

    def __post_init__(self):
        # Code that consumes the config tests membership with `in`, which
        # raises TypeError on None, so always store a list.
        if self.ignore_labels is None:
            self.ignore_labels = []

    @classmethod
    def create(cls, task_name: str, labels: List[str], ignore_labels: List[str] = None):
        """Factory that derives the id mappings from the ordered label list.

        Args:
            task_name: Short identifier of the task.
            labels: Ordered label strings; index in the list becomes the id.
            ignore_labels: Labels to skip during evaluation (default: none).

        Returns:
            A fully populated TaskConfig.

        Raises:
            ValueError: If `labels` contains duplicates, because `num_labels`
                would then disagree with the size of the id mappings.
        """
        if len(set(labels)) != len(labels):
            raise ValueError(f"Duplicate labels in task '{task_name}': {labels}")

        label2id = {label: idx for idx, label in enumerate(labels)}
        id2label = {idx: label for label, idx in label2id.items()}
        return cls(
            task_name=task_name,
            labels=labels,
            label2id=label2id,
            id2label=id2label,
            num_labels=len(labels),
            ignore_labels=ignore_labels or []
        )

# Config for the sentence punctuation task
PUNCTUATION_CONFIG = TaskConfig.create(
    task_name="punctuation",
    labels=['O', 'Ôºå', '„ÄÇ', 'Ôºö', '„ÄÅ', 'Ôºõ', 'Ôºü', 'ÔºÅ'],
    ignore_labels=['O']  # Skip tokens that carry no punctuation mark when evaluating
)

# Config for the sentence segmentation task
SEGMENTATION_CONFIG = TaskConfig.create(
    task_name="segmentation",
    labels=['B', 'M', 'E', 'S'],
    ignore_labels=[]  # Evaluate every label
)

# Training hyperparameters
@dataclass
class TrainingConfig:
    """Hyperparameters shared by all tasks.

    Every attribute is an annotated dataclass field, so each one can be
    overridden through the constructor and is included in
    ``dataclasses.asdict``.
    """
    model_name: str = "SIKU-BERT/sikubert"
    max_length: int = 256

    # Hyperparameters to be tuned
    batch_size: int = 64
    learning_rate: float = 2e-5
    num_epochs: int = 5
    warmup_ratio: float = 0.1
    weight_decay: float = 0.01
    dropout: float = 0.1
    max_grad_norm: float = 1.0

    # Early stopping
    early_stopping_patience: int = 3  # Stop after 3 consecutive validations without improvement

    # Fixed
    # NOTE: this default is evaluated once, when the class is defined.
    device: str = "cuda" if torch.cuda.is_available() else "cpu"
    seed: int = 42
    gradient_accumulation_steps: int = 1  # Increase when GPU memory is insufficient
    # Annotated on purpose: without the annotation this would be a plain class
    # attribute, absent from asdict() and not settable via the constructor.
    fp16: bool = True

# Log marker: the configuration objects of this cell have been created.
logger.info("‚úÖ Configurations defined")

In [None]:
# 1. Select task
# Any TASK value other than "punctuation" selects the segmentation config.
TASK = "segmentation"  # or "punctuation"
task_config = PUNCTUATION_CONFIG if TASK == "punctuation" else SEGMENTATION_CONFIG
# NOTE(review): the three paths below always point at the segmentation_* files
# and do not follow TASK; update them by hand when switching TASK.
train_path='/kaggle/input/tbnl-sliding-window-256-128/segmentation_train.json'
val_path='/kaggle/input/tbnl-sliding-window-256-128/segmentation_val.json'
test_path='/kaggle/input/tbnl-sliding-window-256-128/segmentation_test.json'

# Dataset & Preprocessing
- Hỗ trợ character-level alignment (quan trọng cho Classical Chinese)
- Xử lý special tokens ([CLS], [SEP], [PAD]) bằng label -100
- Validate input để phát hiện lỗi sớm

In [None]:
from torch.utils.data import Dataset, DataLoader
from transformers import AutoTokenizer
from typing import List, Tuple
import numpy as np

class ClassicalChineseDataset(Dataset):
    """
    Dataset for character-level token classification tasks.

    Input format:
        texts: List[str] - one string of characters per sample
        labels: List[List[str]] - one label per character of the matching text

    Example:
        texts = ["天地玄黃", "宇宙洪荒"]
        labels = [['B', 'M', 'M', 'E'], ['B', 'M', 'M', 'E']]
    """

    def __init__(
        self,
        texts: List[str],
        labels: List[List[str]],
        tokenizer,
        config: TaskConfig,
        max_length: int = 256
    ):
        """
        Store the samples and validate them up front.

        Args:
            texts: Sample strings.
            labels: Per-character label sequences, parallel to `texts`.
            tokenizer: Callable tokenizer whose result exposes `word_ids()`
                and the keys 'input_ids' / 'attention_mask'.
            config: Task configuration; only `config.label2id` is read here.
            max_length: Padded/truncated sequence length passed to the tokenizer.

        Raises:
            AssertionError: If the number of texts and label sequences differ,
                or a text and its label sequence have different lengths.
            ValueError: If a label is not a key of `config.label2id`.
        """
        self.texts = texts
        self.labels = labels
        self.tokenizer = tokenizer
        self.config = config
        self.max_length = max_length

        # Validate data
        assert len(texts) == len(labels), "texts v√† labels ph·∫£i c√πng ƒë·ªô d√†i"
        known_labels = config.label2id
        for sample_idx, (text, label_seq) in enumerate(zip(texts, labels)):
            assert len(text) == len(label_seq), \
                f"Text v√† labels kh√¥ng kh·ªõp: {len(text)} vs {len(label_seq)}"
            # Fail early with a readable message instead of a KeyError raised
            # later from __getitem__ while batches are being built.
            unknown = set(label_seq).difference(known_labels)
            if unknown:
                raise ValueError(
                    f"Sample {sample_idx} contains labels missing from config.label2id: "
                    f"{sorted(unknown, key=str)}"
                )

    def __len__(self):
        """Return the number of samples."""
        return len(self.texts)

    def __getitem__(self, idx):
        """
        Tokenize sample `idx` and align its labels with the produced tokens.

        Returns:
            dict with 'input_ids', 'attention_mask' and 'labels'; positions
            without a source character get the label -100.
        """
        text = self.texts[idx]
        label_seq = self.labels[idx]

        # Pass the text as a list of characters with is_split_into_words=True
        # so the tokenizer reports which character each token comes from.
        # (The model is presumably tokenized one character per token; the
        # alignment below does not depend on that.)
        tokenized = self.tokenizer(
            list(text),
            is_split_into_words=True,
            max_length=self.max_length,
            padding='max_length',
            truncation=True,
            return_tensors='pt'
        )

        # word_ids() gives, per token, the index of its source character, or
        # None for tokens without one ([CLS], [SEP], [PAD]).
        word_ids = tokenized.word_ids(batch_index=0)
        label2id = self.config.label2id
        label_ids = [
            # -100 is the default ignore_index of CrossEntropyLoss
            -100 if word_id is None else label2id[label_seq[word_id]]
            for word_id in word_ids
        ]

        return {
            'input_ids': tokenized['input_ids'].squeeze(0),
            'attention_mask': tokenized['attention_mask'].squeeze(0),
            'labels': torch.tensor(label_ids, dtype=torch.long)
        }

def create_dataloaders(
    train_texts: List[str],
    train_labels: List[List[str]],
    val_texts: List[str],
    val_labels: List[List[str]],
    tokenizer,
    config: TaskConfig,
    training_config: TrainingConfig
) -> Tuple[DataLoader, DataLoader]:
    """Build the training and validation DataLoaders.

    Both splits are wrapped in a ClassicalChineseDataset that shares the
    tokenizer, the task config and `training_config.max_length`. The training
    loader shuffles; the validation loader keeps the given order. Both use
    `training_config.batch_size`.

    Returns:
        (train_loader, val_loader)
    """
    seq_len = training_config.max_length

    train_set = ClassicalChineseDataset(train_texts, train_labels, tokenizer, config, seq_len)
    val_set = ClassicalChineseDataset(val_texts, val_labels, tokenizer, config, seq_len)

    per_batch = training_config.batch_size
    return (
        DataLoader(train_set, batch_size=per_batch, shuffle=True),
        DataLoader(val_set, batch_size=per_batch, shuffle=False),
    )

# Log marker: the dataset class and loader factory of this cell are defined.
logger.info("‚úÖ Dataset class defined")

# Model Definition

- Module hóa: Dễ dàng thêm BiLSTM/CNN qua parameter extra_layer_type
- Extensible: Placeholder cho CRF (sẽ return logits, CRF xử lý bên ngoài)
- Linear head đơn giản nhưng hiệu quả cho baseline

In [None]:
from transformers import AutoModel
import torch.nn as nn


class SikuBERTForTokenClassification(nn.Module):
    """
    SikuBERT encoder with an extensible token-classification head.

    Architecture:
        BERT encoder -> [optional BiLSTM or multi-kernel CNN] -> dropout -> linear classifier

    The optional layer is chosen with `use_extra_layer` / `extra_layer_type`.
    The module returns raw logits, so a different decoder (e.g. a CRF) can be
    applied outside of it.
    """

    def __init__(
        self,
        model_name: str,
        num_labels: int,
        dropout: float = 0.1,
        use_extra_layer: bool = False,
        extra_layer_type: str = None,  # 'lstm', 'cnn', None
        cnn_kernel_sizes: list = None,  # kernel sizes of the CNN branch
        cnn_num_filters: int = 128      # filters per kernel size of the CNN branch
    ):
        """
        Args:
            model_name: Name or path handed to `AutoModel.from_pretrained`.
            num_labels: Size of the classifier output.
            dropout: Dropout probability applied before the classifier.
            use_extra_layer: Whether to insert an extra layer after BERT.
            extra_layer_type: 'lstm' or 'cnn'; only read when
                `use_extra_layer` is True.
            cnn_kernel_sizes: Odd kernel sizes for the CNN branch
                (default [3, 5, 7]).
            cnn_num_filters: Output channels of every convolution.

        Raises:
            ValueError: If `use_extra_layer` is True and `extra_layer_type` is
                not 'lstm' or 'cnn', or a CNN kernel size is not a positive
                odd number.
        """
        super().__init__()

        # Backbone: SikuBERT
        self.bert = AutoModel.from_pretrained(model_name)
        self.hidden_size = self.bert.config.hidden_size

        # Dropout
        self.dropout = nn.Dropout(dropout)

        # Optional layers: at most one of the two is set.
        self.extra_layer = None   # BiLSTM ('lstm')
        self.convs = None         # bank of Conv1d modules ('cnn')
        self.extra_layer_type = extra_layer_type

        classifier_input_size = self.hidden_size

        if use_extra_layer:
            if extra_layer_type == 'lstm':
                # BiLSTM: two directions of hidden_size // 2 units each
                lstm_units = self.hidden_size // 2
                self.extra_layer = nn.LSTM(
                    self.hidden_size,
                    lstm_units,
                    batch_first=True,
                    bidirectional=True
                )
                # 2 * units equals hidden_size whenever hidden_size is even
                classifier_input_size = 2 * lstm_units

            elif extra_layer_type == 'cnn':
                if cnn_kernel_sizes is None:
                    cnn_kernel_sizes = [3, 5, 7]  # default
                cnn_kernel_sizes = list(cnn_kernel_sizes)

                # padding = k // 2 keeps the sequence length only for odd k;
                # an even k would yield seq_len + 1 positions and break the
                # concatenation / label alignment.
                if not cnn_kernel_sizes or any(k <= 0 or k % 2 == 0 for k in cnn_kernel_sizes):
                    raise ValueError(
                        f"cnn_kernel_sizes must be a non-empty list of positive odd integers, "
                        f"got {cnn_kernel_sizes}"
                    )

                self.cnn_kernel_sizes = cnn_kernel_sizes
                self.cnn_num_filters = cnn_num_filters

                # One Conv1d per kernel size, all applied to the same input
                self.convs = nn.ModuleList([
                    nn.Conv1d(
                        in_channels=self.hidden_size,
                        out_channels=cnn_num_filters,
                        kernel_size=k,
                        padding=k // 2  # same-length output for odd k
                    )
                    for k in cnn_kernel_sizes
                ])

                # Output size = filters per kernel x number of kernels
                classifier_input_size = cnn_num_filters * len(cnn_kernel_sizes)

            else:
                raise ValueError(
                    f"extra_layer_type must be 'lstm' or 'cnn' when use_extra_layer=True, "
                    f"got {extra_layer_type!r}"
                )

        # Classification head
        self.classifier = nn.Linear(classifier_input_size, num_labels)

        # Loss function (CrossEntropyLoss ignores targets equal to -100 by default)
        self.loss_fct = nn.CrossEntropyLoss()

    def forward(
        self,
        input_ids,
        attention_mask,
        labels=None
    ):
        """
        Forward pass.

        Args:
            input_ids: Token ids, passed to the encoder.
            attention_mask: Attention mask, passed to the encoder.
            labels: Optional target ids; when given, the loss is computed.

        Returns:
            dict with keys 'loss' (None when `labels` is None) and 'logits'.
        """
        # BERT encoding
        outputs = self.bert(
            input_ids=input_ids,
            attention_mask=attention_mask,
            return_dict=True
        )

        sequence_output = outputs.last_hidden_state  # (batch, seq_len, hidden)

        # Dispatch on the modules that were actually built. The CNN branch
        # stores its layers in `self.convs`, not in `self.extra_layer`.
        if self.extra_layer is not None:
            sequence_output, _ = self.extra_layer(sequence_output)

        elif self.convs is not None:
            # Conv1d expects (batch, channels, seq_len)
            cnn_input = sequence_output.transpose(1, 2)

            # Each conv output: (batch, num_filters, seq_len).
            # torch.relu is used so no extra functional import is required.
            conv_outputs = [torch.relu(conv(cnn_input)) for conv in self.convs]

            # Concatenate along channels, then back to (batch, seq_len, channels)
            sequence_output = torch.cat(conv_outputs, dim=1).transpose(1, 2)

        # Dropout + Classification
        sequence_output = self.dropout(sequence_output)
        logits = self.classifier(sequence_output)  # (batch, seq_len, num_labels)

        # Loss only when labels are provided
        loss = None
        if labels is not None:
            # Flatten tokens so every position is one classification example
            loss = self.loss_fct(
                logits.reshape(-1, self.classifier.out_features),
                labels.reshape(-1)
            )

        return {
            'loss': loss,
            'logits': logits
        }

# Log marker: the model class of this cell is defined.
logger.info("‚úÖ Model class with CNN support defined")

# Evaluation & EvalHan2024-style Scorer
- Stateful scorer: accumulate predictions qua batches
- Bỏ qua padding và ignore labels theo đúng EvalHan2024
- Per-label metrics + overall macro average
- Pretty printing cho dễ đọc

In [None]:
import json
import os
from datetime import datetime
from dataclasses import asdict

def save_eval_results(
    results: dict,
    task_config: TaskConfig,
    training_config: TrainingConfig,
    split: str,                       # "val" | "test"
    output_dir: str = "eval_results"
):
    """
    Write evaluation results together with the hyperparameters to a JSON file.

    The file is named `<task>_<split>_<model>_<timestamp>.json` (with "/" in
    the model name replaced by "_") and is created inside `output_dir`, which
    is created when missing.

    Args:
        results: Metrics dictionary to store under "results".
        task_config: Supplies task name, labels and ignore labels.
        training_config: Dataclass instance; stored via `asdict`.
        split: "val" or "test".
        output_dir: Target directory.

    Raises:
        AssertionError: If `split` is neither "val" nor "test".
    """
    assert split in ["val", "test"], "split must be 'val' or 'test'"

    os.makedirs(output_dir, exist_ok=True)

    stamp = datetime.now().strftime("%Y%m%d_%H%M%S")
    task_name = task_config.task_name
    model_name = training_config.model_name
    model_tag = model_name.replace("/", "_")

    save_path = os.path.join(
        output_dir,
        f"{task_name}_{split}_{model_tag}_{stamp}.json",
    )

    # Assemble the document section by section (key order is preserved).
    payload = {}
    payload["meta"] = {"task": task_name, "split": split, "timestamp": stamp}
    payload["model"] = {
        "name": model_name,
        "num_labels": task_config.num_labels,
        "labels": task_config.labels,
        "ignore_labels": task_config.ignore_labels,
    }
    payload["training_config"] = asdict(training_config)
    payload["results"] = results

    with open(save_path, "w", encoding="utf-8") as sink:
        json.dump(payload, sink, ensure_ascii=False, indent=2)

    logger.info(f"[INFO] Saved EvalHan results to: {save_path}")

In [None]:
from sklearn.metrics import precision_recall_fscore_support
from collections import defaultdict

class EvalHanScorer:
    """
    EvalHan2024-style scorer.

    Computes precision / recall / F1 for:
        - every scored label individually
        - overall (macro average over the scored labels)

    Excluded from scoring:
        - positions whose gold label is -100 (special / padding tokens)
        - labels listed in `config.ignore_labels`: no metrics are reported for
          them, but predictions made at their positions still count as false
          positives of the predicted label.
    """

    def __init__(self, config: TaskConfig):
        """
        Args:
            config: Task configuration; `labels`, `label2id`, `id2label`,
                `ignore_labels` and `task_name` are read.
        """
        self.config = config
        self.reset()

    def reset(self):
        """Drop all accumulated predictions and labels."""
        self.all_predictions = []
        self.all_labels = []

    def add_batch(self, predictions, labels):
        """
        Accumulate one batch of predictions and gold labels.

        Only positions with gold label -100 are discarded here. Tokens whose
        gold label is in `ignore_labels` are deliberately kept: dropping them
        would hide every false positive predicted at those positions and
        inflate precision. The ignored labels are excluded later, in
        `compute`, through the `labels=` argument.

        Args:
            predictions: tensor (batch, seq_len) - predicted label ids
            labels: tensor (batch, seq_len) - gold label ids
        """
        predictions = predictions.detach().reshape(-1).cpu().numpy()
        labels = labels.detach().reshape(-1).cpu().numpy()

        # Vectorised filter; .tolist() stores plain Python ints
        valid_mask = labels != -100
        self.all_predictions.extend(predictions[valid_mask].tolist())
        self.all_labels.extend(labels[valid_mask].tolist())

    def compute(self):
        """
        Compute the metrics over everything accumulated so far.

        Returns:
            dict with structure:
                {
                    'per_label': {
                        'label_name': {'precision': ..., 'recall': ..., 'f1': ..., 'support': ...}
                    },
                    'overall': {'precision': ..., 'recall': ..., 'f1': ...}
                }
            The overall values are the unweighted mean over all scored labels,
            including labels with zero support.
        """
        empty_result = {'overall': {'precision': 0, 'recall': 0, 'f1': 0}, 'per_label': {}}
        if len(self.all_predictions) == 0:
            return empty_result

        # Ids of the labels that are scored (everything except ignore labels)
        unique_labels = [
            self.config.label2id[label_str]
            for label_str in self.config.labels
            if label_str not in self.config.ignore_labels
        ]
        if not unique_labels:
            return empty_result

        # Per-label metrics restricted to the scored labels
        precision, recall, f1, support = precision_recall_fscore_support(
            self.all_labels,
            self.all_predictions,
            labels=unique_labels,
            average=None,
            zero_division=0
        )

        # Format results
        results = {'per_label': {}}

        for idx, label_id in enumerate(unique_labels):
            label_name = self.config.id2label[label_id]
            results['per_label'][label_name] = {
                'precision': float(precision[idx]),
                'recall': float(recall[idx]),
                'f1': float(f1[idx]),
                'support': int(support[idx])
            }

        # Overall metrics (macro average)
        results['overall'] = {
            'precision': float(np.mean(precision)),
            'recall': float(np.mean(recall)),
            'f1': float(np.mean(f1))
        }

        return results

    def print_results(self, results):
        """Log `results` (as returned by `compute`) as an aligned table."""
        logger.info(f"\n{'='*70}")
        logger.info(f"EvalHan2024 Results - {self.config.task_name.upper()}")
        logger.info(f"{'='*70}")

        # Per-label results
        logger.info(f"\n{'Label':<10} {'Precision':<12} {'Recall':<12} {'F1':<12} {'Support':<10}")
        logger.info(f"{'-'*70}")

        for label_name, metrics in results['per_label'].items():
            logger.info(f"{label_name:<10} "
                  f"{metrics['precision']:<12.4f} "
                  f"{metrics['recall']:<12.4f} "
                  f"{metrics['f1']:<12.4f} "
                  f"{metrics['support']:<10}")

        # Overall results
        logger.info(f"{'-'*70}")
        logger.info(f"{'OVERALL':<10} "
              f"{results['overall']['precision']:<12.4f} "
              f"{results['overall']['recall']:<12.4f} "
              f"{results['overall']['f1']:<12.4f}")
        logger.info(f"{'='*70}\n")

def evaluate_model(
    model,
    dataloader,
    config: TaskConfig,
    device,
    split: str,
    training_config: TrainingConfig,
    output_dir: str = "eval_results"
):
    """
    Run the model over `dataloader`, score it, log the table and save the results.

    Args:
        model: Callable returning a dict with a 'logits' entry; switched to
            eval mode here.
        dataloader: Iterable of batches with 'input_ids', 'attention_mask'
            and 'labels'.
        config: Task configuration handed to the scorer and the result writer.
        device: Device the batch tensors are moved to.
        split: Split name forwarded to `save_eval_results`.
        training_config: Training configuration stored with the results.
        output_dir: Directory for the result file.

    Returns:
        dict: EvalHan2024-style results
    """
    model.eval()
    scorer = EvalHanScorer(config)

    with torch.no_grad():
        for batch in tqdm(dataloader, desc="Evaluating"):
            gold = batch['labels'].to(device)

            # Forward pass without labels: only the logits are needed
            logits = model(
                input_ids=batch['input_ids'].to(device),
                attention_mask=batch['attention_mask'].to(device),
            )['logits']

            # Most likely label per token
            scorer.add_batch(logits.argmax(dim=-1), gold)

    results = scorer.compute()
    scorer.print_results(results)

    save_eval_results(
        results=results,
        task_config=config,
        training_config=training_config,
        split=split,
        output_dir=output_dir,
    )

    return results

# Log marker: the evaluation helpers of this cell are defined.
logger.info("‚úÖ Evaluation functions defined")

In [None]:
"""
========================================================================
DATA LOADING UTILITIES
========================================================================
Supports several common formats for Classical Chinese data:
- JSON format
- CoNLL format (IOB style)
- Plain text with inline labels
- CSV format
========================================================================
"""

import json
import csv
from typing import List, Tuple
from pathlib import Path

# ============================================================================
# FORMAT 1: JSON Format
# ============================================================================

def load_json_format(file_path: str) -> Tuple[List[str], List[List[str]]]:
    """
    Read samples from a JSON file.

    Expected JSON structure (a list of objects):
    [
        {
            "text": "天地玄黃宇宙洪荒",
            "labels": ["O", "O", "O", "，", "O", "O", "O", "。"]
        },
        ...
    ]

    Args:
        file_path: Path of the UTF-8 encoded JSON file.

    Returns:
        texts: List[str] - the "text" entry of every object
        labels: List[List[str]] - the "labels" entry of every object
    """
    with open(file_path, 'r', encoding='utf-8') as handle:
        records = json.load(handle)

    texts = []
    labels = []
    for record in records:
        texts.append(record['text'])
        labels.append(record['labels'])

    logger.info(f"‚úì Loaded {len(texts)} samples from {file_path}")
    return texts, labels


def save_json_format(texts: List[str], labels: List[List[str]], output_path: str):
    """Write parallel texts/labels to `output_path` as a JSON list of
    {"text", "labels"} objects (UTF-8, non-ASCII kept as-is, indent 2)."""
    records = []
    for text, label_seq in zip(texts, labels):
        records.append({'text': text, 'labels': label_seq})

    with open(output_path, 'w', encoding='utf-8') as sink:
        json.dump(records, sink, ensure_ascii=False, indent=2)

    logger.info(f"‚úì Saved {len(texts)} samples to {output_path}")


# ============================================================================
# FORMAT 2: CoNLL Format (Character-level)
# ============================================================================

def load_conll_format(file_path: str) -> Tuple[List[str], List[List[str]]]:
    """
    Read samples from a CoNLL-style file.

    Expected format (one "<char> <label>" pair per line, a blank line ends a sample):
    天 O
    地 O
    玄 O
    黃 ，

    宇 O
    宙 O
    ...

    Lines are split on whitespace; the first two fields are used and any
    further fields are ignored. Non-blank lines with fewer than two fields
    are skipped.

    Returns:
        texts: List[str]
        labels: List[List[str]]
    """
    texts: List[str] = []
    labels: List[List[str]] = []
    chars: List[str] = []
    tags: List[str] = []

    def _flush():
        # Close the sample collected so far, if any.
        if chars:
            texts.append(''.join(chars))
            labels.append(tags[:])
            chars.clear()
            tags.clear()

    with open(file_path, 'r', encoding='utf-8') as handle:
        for raw_line in handle:
            fields = raw_line.split()
            if not fields:
                # Whitespace-only line: sample boundary
                _flush()
            elif len(fields) >= 2:
                chars.append(fields[0])
                tags.append(fields[1])

        # The file may end without a trailing blank line
        _flush()

    logger.info(f"‚úì Loaded {len(texts)} samples from {file_path}")
    return texts, labels


def save_conll_format(texts: List[str], labels: List[List[str]], output_path: str):
    """Write samples as "<char> <label>" lines, each sample followed by a blank line."""
    with open(output_path, 'w', encoding='utf-8') as sink:
        for text, label_seq in zip(texts, labels):
            rows = [f"{char} {label}\n" for char, label in zip(text, label_seq)]
            rows.append("\n")  # blank line closes the sample
            sink.writelines(rows)

    logger.info(f"‚úì Saved {len(texts)} samples to {output_path}")


# ============================================================================
# FORMAT 3: Inline Format (text with embedded punctuation)
# ============================================================================

def load_inline_punctuation(file_path: str, 
                           punctuation_marks: List[str] = None) -> Tuple[List[str], List[List[str]]]:
    """
    Read text with inline punctuation and split it into characters and labels.

    Every punctuation mark is removed from the text and becomes the label of
    the character right before it; all other characters are labelled 'O'.

    Input line: "天地玄黃，宇宙洪荒。"
    Output:
        text: "天地玄黃宇宙洪荒"
        labels: ['O', 'O', 'O', '，', 'O', 'O', 'O', '。']

    A mark with no preceding character on its line is dropped; when several
    marks follow the same character, the last one wins. Lines that are empty
    or yield no characters are skipped.

    Args:
        file_path: path to file (one sample per line)
        punctuation_marks: marks to extract (default: the list in the body)

    Returns:
        texts: List[str] (without the extracted marks)
        labels: List[List[str]] (one label per remaining character)
    """
    if punctuation_marks is None:
        punctuation_marks = ['Ôºå', '„ÄÇ', 'Ôºö', '„ÄÅ', 'Ôºõ', 'Ôºü', 'ÔºÅ']
    marks = set(punctuation_marks)

    texts = []
    labels = []

    with open(file_path, 'r', encoding='utf-8') as handle:
        for raw_line in handle:
            stripped = raw_line.strip()
            if not stripped:
                continue

            chars = []
            tags = []
            for ch in stripped:
                if ch not in marks:
                    chars.append(ch)
                    tags.append('O')
                elif tags:
                    # Attach the mark to the previous character
                    tags[-1] = ch

            if chars:
                texts.append(''.join(chars))
                labels.append(tags)

    logger.info(f"‚úì Loaded {len(texts)} samples from {file_path}")
    return texts, labels


# ============================================================================
# FORMAT 4: CSV Format
# ============================================================================

def load_csv_format(file_path: str, 
                   text_column: str = 'text',
                   label_column: str = 'labels',
                   delimiter: str = ',') -> Tuple[List[str], List[List[str]]]:
    """
    Read samples from a CSV file with a header row.

    Expected CSV columns:
        text,labels
        "天地玄黃","O O O ，"

    The label cell holds whitespace-separated labels, one per character.

    Args:
        file_path: path to CSV
        text_column: name of text column
        label_column: name of labels column
        delimiter: CSV delimiter

    Returns:
        texts: List[str]
        labels: List[List[str]]

    Raises:
        KeyError: If `text_column` or `label_column` is not a column of the file.
    """
    texts = []
    labels = []

    # newline='' is what the csv module asks for: it lets the reader handle
    # line endings itself, so newlines inside quoted fields are not altered.
    with open(file_path, 'r', encoding='utf-8', newline='') as f:
        reader = csv.DictReader(f, delimiter=delimiter)
        for row in reader:
            texts.append(row[text_column])
            # Labels are whitespace-separated inside the cell
            labels.append(row[label_column].strip().split())

    logger.info(f"‚úì Loaded {len(texts)} samples from {file_path}")
    return texts, labels


# ============================================================================
# FORMAT 5: BEMS from sentence boundaries
# ============================================================================

def create_bems_labels_from_sentences(sentences: List[str]) -> Tuple[str, List[str]]:
    """
    Concatenate sentences and label every character with B/M/E/S.

    Input: ["天地玄黃", "宇宙洪荒"]
    Output:
        text: "天地玄黃宇宙洪荒"
        labels: ['B','M','M','E','B','M','M','E']

    A one-character sentence is labelled 'S'. Empty sentences contribute
    neither characters nor labels, so `len(text) == len(labels)` always holds.

    Args:
        sentences: List of sentences

    Returns:
        text: concatenated text
        labels: one B/M/E/S label per character of `text`
    """
    text = ''.join(sentences)
    labels = []

    for sentence in sentences:
        length = len(sentence)
        if length == 0:
            # Without this guard an empty sentence would emit 'B','E' for
            # zero characters and shift every following label.
            continue
        if length == 1:
            labels.append('S')
        else:
            labels.append('B')
            labels.extend(['M'] * (length - 2))
            labels.append('E')

    return text, labels


def load_sentence_file_to_bems(file_path: str, 
                               sentence_delimiter: str = '\n') -> Tuple[List[str], List[List[str]]]:
    """
    Load a file of sentences and convert every document to text + BEMS labels.

    Input file (default delimiter, one sentence per line):
        天地玄黃
        宇宙洪荒

        日月盈昃
        辰宿列張

    A blank line separates documents. With the default delimiter ('\\n')
    every non-blank line is one sentence. With any other non-empty delimiter
    each non-blank line is additionally split on it, and the stripped,
    non-empty pieces become the sentences.

    Args:
        file_path: Path of the UTF-8 text file.
        sentence_delimiter: Separator between sentences on a line.

    Returns:
        texts: List[str] - one concatenated text per document
        labels: List[List[str]] (BEMS) - one label per character
    """
    texts = []
    labels = []
    current_sentences = []

    # Lines never contain '\n' after stripping and str.split('') is invalid,
    # so only other delimiters trigger in-line splitting.
    split_within_line = sentence_delimiter not in ('', '\n')

    def _flush():
        # Turn the sentences collected so far into one sample.
        if current_sentences:
            text, label_seq = create_bems_labels_from_sentences(current_sentences)
            texts.append(text)
            labels.append(label_seq)
            current_sentences.clear()

    with open(file_path, 'r', encoding='utf-8') as f:
        for line in f:
            line = line.strip()

            if not line:  # Blank line = new document
                _flush()
            elif split_within_line:
                pieces = (piece.strip() for piece in line.split(sentence_delimiter))
                current_sentences.extend(piece for piece in pieces if piece)
            else:
                current_sentences.append(line)

        # Last document (file may not end with a blank line)
        _flush()

    logger.info(f"‚úì Loaded {len(texts)} samples from {file_path}")
    return texts, labels


# ============================================================================
# AUTO-DETECT FORMAT
# ============================================================================

def load_data_auto(file_path: str, **kwargs) -> Tuple[List[str], List[List[str]]]:
    """
    Pick a loader from the file extension (and, for text files, the first line).

    - ".json"  -> load_json_format
    - ".csv"   -> load_csv_format (receives **kwargs)
    - ".txt" / ".conll" -> load_conll_format when the first line contains a
      tab or exactly two whitespace-separated fields, otherwise
      load_inline_punctuation

    Args:
        file_path: path to data file
        **kwargs: extra arguments, forwarded to the CSV loader only

    Returns:
        texts: List[str]
        labels: List[List[str]]

    Raises:
        ValueError: For any other extension.
    """
    path = Path(file_path)
    ext = path.suffix.lower()

    logger.info(f"Auto-detecting format for: {path}")

    if ext == '.json':
        return load_json_format(path)
    if ext == '.csv':
        return load_csv_format(path, **kwargs)
    if ext not in ['.txt', '.conll']:
        raise ValueError(f"Unsupported format: {ext}")

    # Text file: peek at the first line to tell CoNLL from plain text
    with open(path, 'r', encoding='utf-8') as handle:
        head = handle.readline().strip()

    looks_like_conll = '\t' in head or len(head.split()) == 2
    if not looks_like_conll:
        logger.info("  Detected: Plain text format")
        logger.info("  Assuming inline punctuation - specify format if incorrect")
        return load_inline_punctuation(path)

    logger.info("  Detected: CoNLL format")
    return load_conll_format(path)


# ============================================================================
# VALIDATION UTILITIES
# ============================================================================

def validate_data(texts: List[str], labels: List[List[str]]) -> bool:
    """
    Check that texts and labels line up.

    Verifies that there are as many label sequences as texts and that every
    text has exactly one label per character.

    Returns:
        bool: True when every check passes.

    Raises:
        AssertionError: On the first mismatch, with a message naming the sample.
    """
    assert len(texts) == len(labels), \
        f"Length mismatch: {len(texts)} texts vs {len(labels)} label sequences"

    sample_idx = 0
    for text, label_seq in zip(texts, labels):
        assert len(text) == len(label_seq), (
            f"Sample {sample_idx}: {len(text)} chars vs {len(label_seq)} labels\n"
            f"  Text: {text[:50]}...\n"
            f"  Labels: {label_seq[:50]}..."
        )
        sample_idx += 1

    logger.info(f"‚úì Data validation passed: {len(texts)} samples")
    return True


def print_data_stats(texts: List[str], labels: List[List[str]], task_config):
    """
    Log summary statistics about a dataset.

    Reports the sample count, text-length statistics and, for every label in
    `task_config.labels`, its absolute count and share of all labels.

    Args:
        texts: input strings
        labels: one label sequence per text
        task_config: object exposing a `labels` list (the label schema to report on)
    """
    from collections import Counter

    logger.info(f"\n{'='*70}")
    logger.info("DATA STATISTICS")
    logger.info(f"{'='*70}")
    logger.info(f"Total samples: {len(texts)}")

    if texts:
        lengths = [len(t) for t in texts]
        logger.info(f"Avg text length: {sum(lengths) / len(lengths):.1f} chars")
        logger.info(f"Min/Max length: {min(lengths)} / {max(lengths)}")
    else:
        # An empty dataset would otherwise divide by zero / call min() on nothing.
        logger.info("Avg text length: n/a (no samples)")
        logger.info("Min/Max length: n/a")

    # Label distribution
    label_counts = Counter(label for label_seq in labels for label in label_seq)
    total_labels = sum(label_counts.values())

    logger.info(f"\nLabel distribution:")
    for label in task_config.labels:
        count = label_counts.get(label, 0)
        pct = 100 * count / total_labels if total_labels else 0
        logger.info(f"  {label}: {count:>8} ({pct:>5.2f}%)")

    # Labels outside the schema usually mean the wrong config or the wrong
    # argument was passed; surface them instead of silently dropping them.
    unexpected = sorted((l for l in label_counts if l not in task_config.labels), key=str)
    if unexpected:
        n_unexpected = sum(label_counts[l] for l in unexpected)
        logger.warning(
            f"  {n_unexpected} labels not in task_config.labels "
            f"({len(unexpected)} distinct), e.g. {unexpected[:10]}"
        )

    logger.info(f"{'='*70}\n")


# ============================================================================
# EXAMPLE USAGE
# ============================================================================

"""
Example usage:

# Auto-detect format
texts, labels = load_data_auto('/kaggle/input/mydata/train.json')

# Specific format
texts, labels = load_json_format('/kaggle/input/mydata/train.json')
texts, labels = load_conll_format('/kaggle/input/mydata/train.conll')

# Validate
validate_data(texts, labels)
print_data_stats(texts, labels, PUNCTUATION_CONFIG)
"""

# Training Procedure
- Gradient clipping để stability
- Learning rate warmup (quan trọng cho BERT-based models)
- Early stopping để tránh overfitting
- Checkpoint best model theo F1 score
- Early Stopping

In [None]:
class EarlyStopping:
    """Tracks a monitored score and signals when training should stop."""

    def __init__(self, patience: int = 3, min_delta: float = 0.0, mode: str = 'max'):
        """
        Args:
            patience: How many consecutive non-improving calls are tolerated
            min_delta: Margin by which a score must beat the best one to count
            mode: 'max' when larger scores are better (F1); any other value
                treats smaller scores as better (loss)
        """
        self.patience, self.min_delta, self.mode = patience, min_delta, mode
        self.best_score = None
        self.best_epoch = 0
        self.counter = 0
        self.early_stop = False

    def __call__(self, score: float, epoch: int) -> bool:
        """
        Record the score of one epoch.

        Returns:
            True once `patience` calls in a row failed to improve, else False
        """
        is_first = self.best_score is None

        if not is_first:
            if self.mode == 'max':
                is_better = score > (self.best_score + self.min_delta)
            else:
                is_better = score < (self.best_score - self.min_delta)

            if not is_better:
                self.counter += 1
                if self.counter < self.patience:
                    return False
                self.early_stop = True
                return True

            # An improvement restarts the patience window.
            self.counter = 0

        # First observation or a new best: remember it.
        self.best_score = score
        self.best_epoch = epoch
        return False

In [None]:
from torch.optim import AdamW
from transformers import get_linear_schedule_with_warmup
from tqdm import tqdm
import random
from typing import Optional
import optuna

def train_with_early_stopping(
    model,
    train_loader,
    val_loader,
    task_config: TaskConfig,
    training_config: TrainingConfig,
    trial: Optional[optuna.Trial] = None,
    save_path: str = "models/best_model_cnn.pt"
):
    """
    Train a model, checkpoint its best epoch and stop early on a validation-F1 plateau.

    Every epoch runs one pass over `train_loader` (AdamW, linear warmup/decay
    schedule, gradient-norm clipping) and then scores the model on `val_loader`
    with `evaluate_model`. The state dict of the best-scoring epoch is written
    to `save_path` and loaded back into `model` before returning.

    Args:
        model: called as `model(input_ids=..., attention_mask=..., labels=...)`;
            the result must expose the loss under the key 'loss'
        train_loader: yields batches with 'input_ids', 'attention_mask' and 'labels'
        val_loader: validation batches, forwarded to `evaluate_model`
        task_config: label schema, forwarded to `evaluate_model`
        training_config: supplies seed, learning_rate, weight_decay, num_epochs,
            warmup_ratio, max_grad_norm, early_stopping_patience and device
        trial: Optuna trial (for pruning); None disables reporting/pruning
        save_path: where the best state dict is written

    Returns:
        best_val_f1: Best validation F1 score (0.0 when no epoch was run)

    Raises:
        optuna.TrialPruned: when `trial` is given and Optuna decides to prune
    """
    set_seed(training_config.seed)

    # Optimizer
    optimizer = AdamW(
        model.parameters(),
        lr=training_config.learning_rate,
        weight_decay=training_config.weight_decay
    )

    # Scheduler
    steps_per_epoch = len(train_loader)
    total_steps = steps_per_epoch * training_config.num_epochs
    warmup_steps = int(total_steps * training_config.warmup_ratio)
    scheduler = get_linear_schedule_with_warmup(optimizer, warmup_steps, total_steps)

    # Early stopping
    early_stopping = EarlyStopping(
        patience=training_config.early_stopping_patience,
        mode='max'
    )

    best_val_f1 = 0.0
    # Tracks whether THIS run wrote `save_path`. Without it, a run whose F1 never
    # exceeds 0.0 would save nothing and then fail on load (or silently load a
    # stale checkpoint left behind by an earlier run).
    checkpoint_saved = False

    for epoch in range(training_config.num_epochs):
        # ====================================================================
        # TRAINING
        # ====================================================================
        model.train()
        total_loss = 0.0

        progress_bar = tqdm(train_loader, desc=f"Epoch {epoch+1}/{training_config.num_epochs}")
        for step, batch in enumerate(progress_bar):
            input_ids = batch['input_ids'].to(training_config.device)
            attention_mask = batch['attention_mask'].to(training_config.device)
            labels = batch['labels'].to(training_config.device)

            # Forward
            outputs = model(input_ids=input_ids, attention_mask=attention_mask, labels=labels)
            loss = outputs['loss']

            # Backward
            optimizer.zero_grad()
            loss.backward()
            torch.nn.utils.clip_grad_norm_(model.parameters(), training_config.max_grad_norm)
            optimizer.step()
            scheduler.step()

            # Read the scalar once and reuse it for the running sum, bar and log.
            loss_value = loss.item()
            total_loss += loss_value
            progress_bar.set_postfix({'loss': loss_value})

            # Train log (step is reported 1-based so the last one reads N/N)
            logger.info(
                "Epoch %d | Step %d/%d | Loss %.4f",
                epoch + 1, step + 1, steps_per_epoch, loss_value
            )

        # max(..., 1) keeps an empty loader from dividing by zero.
        avg_train_loss = total_loss / max(steps_per_epoch, 1)

        # ====================================================================
        # VALIDATION
        # ====================================================================
        # NOTE(review): the split name passed here is "test" although this is the
        # validation loader - confirm what evaluate_model uses that argument for.
        val_results = evaluate_model(model, val_loader, task_config, training_config.device, "test", training_config)
        val_f1 = val_results['overall']['f1']

        logger.info(f"\nEpoch {epoch+1}/{training_config.num_epochs}:")
        logger.info(f"  Train Loss: {avg_train_loss:.4f}")
        logger.info(f"  Val F1:     {val_f1:.4f}")
        logger.info(f"  Val Prec:   {val_results['overall']['precision']:.4f}")
        logger.info(f"  Val Recall: {val_results['overall']['recall']:.4f}")

        # Save best model (the first evaluated epoch always counts as the best so far)
        if not checkpoint_saved or val_f1 > best_val_f1:
            best_val_f1 = val_f1
            torch.save(model.state_dict(), save_path)
            checkpoint_saved = True
            logger.info(f"  ‚úì Saved best model (F1: {best_val_f1:.4f})")

        # Optuna pruning (optional)
        if trial is not None:
            trial.report(val_f1, epoch)
            if trial.should_prune():
                logger.info(f"  ‚úÇÔ∏è Trial pruned at epoch {epoch+1}")
                raise optuna.TrialPruned()

        # Early stopping check
        if early_stopping(val_f1, epoch):
            logger.info(f"\n‚èπÔ∏è  Early stopping triggered!")
            logger.info(f"  No improvement for {early_stopping.patience} epochs")
            logger.info(f"  Best epoch: {early_stopping.best_epoch + 1}")
            logger.info(f"  Best Val F1: {best_val_f1:.4f}")
            break

    # Load best model (only if this run produced a checkpoint, i.e. num_epochs > 0)
    if checkpoint_saved:
        model.load_state_dict(torch.load(save_path, map_location=training_config.device))

    return best_val_f1

logger.info("‚úÖ Training functions with early stopping defined")

# OPTUNA BAYESIAN OPTIMIZATION

In [None]:
from optuna.visualization import plot_optimization_history, plot_param_importances
from optuna.importance import MeanDecreaseImpurityImportanceEvaluator

def create_optuna_objective(
    train_texts, train_labels,
    val_texts, val_labels,
    tokenizer,
    task_config: TaskConfig,
    base_training_config: TrainingConfig
):
    """
    Create Optuna objective function.

    The datasets are built once here and shared by every trial, because nothing
    about them depends on the sampled hyperparameters (only the DataLoader
    batch size does).

    Args:
        train_texts, train_labels: training samples and their label sequences
        val_texts, val_labels: validation samples and their label sequences
        tokenizer: tokenizer handed to ClassicalChineseDataset
        task_config: label schema
        base_training_config: supplies the values that are not tuned
            (model_name, max_length, num_epochs, early_stopping_patience,
            device, seed)

    Returns:
        A function `objective(trial) -> float` that Optuna will maximize.
    """
    # ========================================================================
    # CREATE DATASETS (shared across trials)
    # ========================================================================
    train_dataset = ClassicalChineseDataset(
        train_texts, train_labels, tokenizer, task_config, base_training_config.max_length
    )
    val_dataset = ClassicalChineseDataset(
        val_texts, val_labels, tokenizer, task_config, base_training_config.max_length
    )

    def objective(trial: optuna.Trial):
        """
        Objective function for Optuna to maximize.

        Samples hyperparameters, trains one model and returns its best
        validation F1. A trial that fails with an unexpected error is logged
        with its traceback and scored 0.0 so the study keeps running.

        Raises:
            optuna.TrialPruned: propagated from training when the trial is pruned
        """

        # ====================================================================
        # SAMPLE HYPERPARAMETERS
        # ====================================================================
        learning_rate = trial.suggest_float('learning_rate', 1e-5, 5e-5, log=True)
        # batch_size = trial.suggest_categorical('batch_size', [8, 16, 32, 64])
        batch_size = trial.suggest_categorical('batch_size', [64])
        warmup_ratio = trial.suggest_float('warmup_ratio', 0.0, 0.2)
        weight_decay = trial.suggest_float('weight_decay', 0.0, 0.1)
        dropout = trial.suggest_float('dropout', 0.1, 0.3)

        logger.info(f"\n{'='*70}")
        logger.info(f"Trial {trial.number}")
        logger.info(f"{'='*70}")
        logger.info(f"Hyperparameters:")
        logger.info(f"  learning_rate: {learning_rate:.2e}")
        logger.info(f"  batch_size:    {batch_size}")
        logger.info(f"  warmup_ratio:  {warmup_ratio:.3f}")
        logger.info(f"  weight_decay:  {weight_decay:.3f}")
        logger.info(f"  dropout:       {dropout:.3f}")
        logger.info(f"{'='*70}")

        # ====================================================================
        # CREATE DATALOADERS
        # ====================================================================
        train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
        val_loader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False)

        # ====================================================================
        # CREATE MODEL
        # ====================================================================
        model = SikuBERTForTokenClassification(
            model_name=base_training_config.model_name,
            num_labels=task_config.num_labels,
            dropout=dropout,
            use_extra_layer=True,
            extra_layer_type='cnn',
            cnn_kernel_sizes=[3, 5, 7],  # custom kernels
            cnn_num_filters=256  # more filters
        ).to(base_training_config.device)

        # ====================================================================
        # CREATE TRAINING CONFIG
        # ====================================================================
        # NOTE(review): only the fields listed below are carried over from
        # base_training_config; anything else (e.g. max_grad_norm) falls back to
        # the TrainingConfig default - confirm that is intended.
        trial_config = TrainingConfig(
            model_name=base_training_config.model_name,
            max_length=base_training_config.max_length,
            learning_rate=learning_rate,
            batch_size=batch_size,
            num_epochs=base_training_config.num_epochs,
            warmup_ratio=warmup_ratio,
            weight_decay=weight_decay,
            dropout=dropout,
            early_stopping_patience=base_training_config.early_stopping_patience,
            device=base_training_config.device,
            seed=base_training_config.seed
        )

        # ====================================================================
        # TRAIN WITH EARLY STOPPING
        # ====================================================================
        try:
            best_val_f1 = train_with_early_stopping(
                model, train_loader, val_loader,
                task_config, trial_config,
                trial=trial,
                save_path=f"models/optuna_trial_{trial.number}_best_cnn.pt"
            )

            logger.info(f"\n‚úì Trial {trial.number} completed: Val F1 = {best_val_f1:.4f}")

            return best_val_f1

        except optuna.TrialPruned:
            # Trial was pruned by Optuna
            raise

        except Exception as e:
            # Best effort: keep the study alive, but record the traceback so the
            # failure can be diagnosed from the log.
            logger.exception(f"\n‚ùå Trial {trial.number} failed: {e}")
            return 0.0

        finally:
            # Drop this trial's weights so the next trial starts with free memory.
            del model
            if torch.cuda.is_available():
                torch.cuda.empty_cache()

    return objective


def run_optuna_optimization(
    train_texts, train_labels,
    val_texts, val_labels,
    tokenizer,
    task_config: TaskConfig,
    base_training_config: TrainingConfig,
    n_trials: int = 30,
    study_name: str = "sikubert_tuning"
):
    """
    Run Optuna hyperparameter optimization.

    Builds the objective, runs a maximizing study with median pruning, logs a
    summary and writes every trial to
    `outputs/optuna_<task_name>_results.json`.

    Args:
        train_texts, train_labels: training samples and their label sequences
        val_texts, val_labels: validation samples and their label sequences
        tokenizer: tokenizer forwarded to the objective
        task_config: label schema; `task_name` is used in the output file name
        base_training_config: non-tuned settings; its `seed` also seeds the sampler
        n_trials: Number of trials to run
        study_name: Name of the study

    Returns:
        study: Optuna study object with all results
    """
    # Imported locally so this function does not rely on imports made in other cells.
    import json
    from datetime import datetime

    logger.info(f"\n{'='*70}")
    logger.info(f"OPTUNA BAYESIAN OPTIMIZATION")
    logger.info(f"{'='*70}")
    logger.info(f"Task: {task_config.task_name}")
    logger.info(f"Number of trials: {n_trials}")
    logger.info(f"Early stopping patience: {base_training_config.early_stopping_patience}")
    logger.info(f"{'='*70}\n")

    # Create objective function
    objective = create_optuna_objective(
        train_texts, train_labels,
        val_texts, val_labels,
        tokenizer, task_config, base_training_config
    )

    # Create study
    study = optuna.create_study(
        study_name=study_name,
        direction='maximize',  # Maximize F1
        # Same sampler family Optuna uses by default, seeded so that the sequence
        # of suggested hyperparameters is repeatable across runs.
        sampler=optuna.samplers.TPESampler(seed=base_training_config.seed),
        pruner=optuna.pruners.MedianPruner(  # Prune unpromising trials
            n_startup_trials=5,
            n_warmup_steps=3
        )
    )

    # Run optimization
    study.optimize(
        objective,
        n_trials=n_trials,
        show_progress_bar=True
    )

    # ========================================================================
    # RESULTS
    # ========================================================================
    trial_state = optuna.trial.TrialState

    def _count(state):
        """Number of trials of this study that ended in `state`."""
        return sum(1 for t in study.trials if t.state == state)

    n_complete = _count(trial_state.COMPLETE)

    logger.info(f"\n{'='*70}")
    logger.info(f"OPTIMIZATION COMPLETE")
    logger.info(f"{'='*70}")

    logger.info(f"\nüìä Study Statistics:")
    logger.info(f"  Total trials:     {len(study.trials)}")
    logger.info(f"  Completed trials: {n_complete}")
    logger.info(f"  Pruned trials:    {_count(trial_state.PRUNED)}")
    logger.info(f"  Failed trials:    {_count(trial_state.FAIL)}")

    # study.best_trial raises when no trial completed; still save what we have.
    if n_complete:
        best_trial = study.best_trial
        logger.info(f"\nüèÜ Best Trial:")
        logger.info(f"  Trial number:  {best_trial.number}")
        logger.info(f"  Val F1:        {best_trial.value:.4f}")
        logger.info(f"\n  Best Hyperparameters:")
        for key, value in best_trial.params.items():
            if 'learning_rate' in key:
                logger.info(f"    {key}: {value:.2e}")
            else:
                logger.info(f"    {key}: {value}")
        best_summary = {
            'number': best_trial.number,
            'value': best_trial.value,
            'params': best_trial.params
        }
    else:
        logger.warning("  No trial completed - there is no best trial to report")
        best_summary = None

    # Save results
    results = {
        'study_name': study_name,
        'task': task_config.task_name,
        'n_trials': len(study.trials),
        'best_trial': best_summary,
        'all_trials': [
            {
                'number': t.number,
                'value': t.value,
                'params': t.params,
                'state': str(t.state)
            }
            for t in study.trials
        ],
        'timestamp': datetime.now().isoformat()
    }

    with open(f'outputs/optuna_{task_config.task_name}_results.json', 'w', encoding='utf-8') as f:
        json.dump(results, f, indent=2)

    logger.info(f"\n‚úì Results saved to outputs/optuna_{task_config.task_name}_results.json")

    return study

logger.info("‚úÖ Optuna optimization functions defined")


# ============================================================================
# VISUALIZATION
# ============================================================================

def visualize_optuna_results(study, task_name: str):
    """
    Write the study's plots as standalone HTML files under `outputs/`.

    Produces the optimization history and the parameter importances. Any
    failure is logged and swallowed; if the first plot fails the second one is
    not attempted.

    Args:
        study: Optuna study to plot
        task_name: used in the output file names
    """
    try:
        # Plot 1: optimization history
        history_fig = plot_optimization_history(study)
        history_fig.write_html(f'outputs/optuna_{task_name}_history.html')
        logger.info(f"‚úì Saved optimization history plot")

        # Plot 2: parameter importances. The mean-decrease-impurity evaluator is
        # passed explicitly (the original author notes it avoids a NumPy/fANOVA
        # ValueError); HTML output avoids needing the 'kaleido' package.
        importance_evaluator = MeanDecreaseImpurityImportanceEvaluator()
        importance_fig = plot_param_importances(study, evaluator=importance_evaluator)
        importance_fig.write_html(f'outputs/optuna_{task_name}_importance.html')
        logger.info(f"‚úì Saved parameter importance plot")

    except Exception as e:
        if isinstance(e, ImportError):
            logger.info(f"‚ö†Ô∏è Visualization failed: {e}")
            logger.info("   Ensure 'plotly' is installed.")
        else:
            logger.info(f"‚ö†Ô∏è An unexpected error occurred during visualization: {e}")

logger.info("‚úÖ Visualization functions defined")

# Training

##  Load Config

In [None]:
# Base training configuration; only the early-stopping patience is set
# explicitly, every other field keeps its TrainingConfig default.
base_config = TrainingConfig(early_stopping_patience=3)

# Load the tokenizer that belongs to the configured pretrained model
tokenizer = AutoTokenizer.from_pretrained(base_config.model_name)
logger.info("‚úì Tokenizer loaded")

# Log a banner describing the task being trained
logger.info(f"\n{'='*70}")
logger.info(f"Training: {task_config.task_name.upper()}")
logger.info(f"Labels: {task_config.labels}")
logger.info(f"Device: {base_config.device}")
logger.info(f"{'='*70}\n")

## Load data

In [None]:
# Load your data
train_texts, train_labels = load_data_auto(train_path)
val_texts, val_labels = load_data_auto(val_path)

# Uncomment to smoke-test the pipeline on a handful of samples:
# train_texts=train_texts[:5]
# train_labels=train_labels[:5]
# val_texts=val_texts[:5]
# val_labels=val_labels[:5]

# Validate data
logger.info("\nüìä Validating data...")
validate_data(train_texts, train_labels)
validate_data(val_texts, val_labels)

logger.info(f"\n‚úì Data loaded successfully:")
logger.info(f"  - Training set: {len(train_texts)} samples")
logger.info(f"  - Validation set: {len(val_texts)} samples")
logger.info(f"  - Train/Val ratio: {len(train_texts)/len(val_texts):.1f}:1")

# Create dataloaders
train_loader, val_loader = create_dataloaders(
    train_texts, train_labels, val_texts, val_labels,
    tokenizer, task_config, training_config=base_config
)

# Report the label distribution of the training set. The labels (not the texts
# a second time) must be passed, together with the config of the task being
# trained, so the distribution matches what the model sees.
print_data_stats(train_texts, train_labels, task_config)

## Run Optuna to find the best hyperparameters

In [None]:
# study = run_optuna_optimization(
#     train_texts, train_labels,
#     val_texts, val_labels,
#     tokenizer, task_config, base_config,
#     n_trials=1
# )

## Visualize Optuna

In [None]:
# visualize_optuna_results(study, TASK)

## Train final model with best hyperparameters

In [None]:
# best_params = study.best_params
# final_config = TrainingConfig(
#     learning_rate=best_params['learning_rate'],
#     batch_size=best_params['batch_size'],
#     warmup_ratio=best_params['warmup_ratio'],
#     weight_decay=best_params['weight_decay'],
#     dropout=best_params['dropout']
# )

# model = SikuBERTForTokenClassification(
#     final_config.model_name,
#     task_config.num_labels,
#     use_extra_layer=True,
#     extra_layer_type='cnn',
#     cnn_kernel_sizes=[3, 5, 7], # Custom kernels
#     cnn_num_filters=256 # Nhi·ªÅu filters h∆°n
# ).to(final_config.device)

# logger.info(f"‚úì Model created ({sum(p.numel() for p in model.parameters()):,} parameters)")
# save_path = f"models/best_{task_config.task_name}_model_cnn.pt"

# best_val_f1 = train_with_early_stopping(
#     model, train_loader, val_loader,
#     task_config, final_config,
#     trial=None,
#     save_path=save_path
# )

# logger.info("\nüéâ Training complete!")

In [None]:
# Training only
# Train directly with the base config (no Optuna-tuned hyperparameters);
# final_config is the name the evaluation and demo cells read from.
final_config=base_config

model = SikuBERTForTokenClassification(
    base_config.model_name,
    task_config.num_labels,
    use_extra_layer=True,
    extra_layer_type='cnn',
    cnn_kernel_sizes=[3, 5, 7], # custom kernels
    cnn_num_filters=256 # more filters
).to(base_config.device)

logger.info(f"‚úì Model created ({sum(p.numel() for p in model.parameters()):,} parameters)")
# Checkpoint path for the best epoch; reused by the final-test cell
save_path = f"models/best_{task_config.task_name}_model_cnn.pt"

# trial=None disables Optuna reporting/pruning inside the training loop
best_val_f1 = train_with_early_stopping(
    model, train_loader, val_loader,
    task_config, base_config,
    trial=None,
    save_path=save_path
)

logger.info("\nüéâ Training complete!")

# Final Test

In [None]:
import json

logger.info("\n" + "="*70)
logger.info("‚≠ê FINAL TEST SET EVALUATION")
logger.info("="*70)
logger.info("This is the OFFICIAL performance evaluation")
logger.info("Model has NEVER seen test data during training!")
logger.info("="*70)

# Load best model (map_location keeps the load working when the checkpoint was
# written from a different device than the one in use now)
model.load_state_dict(torch.load(save_path, map_location=final_config.device))
model.eval()

test_texts, test_labels = load_data_auto(test_path)

# Create test loader
test_dataset = ClassicalChineseDataset(
    test_texts, test_labels, tokenizer, task_config, final_config.max_length
)
test_loader = DataLoader(
    test_dataset,
    batch_size=final_config.batch_size,
    shuffle=False
)

# Evaluate on test set
logger.info("\nüéØ Evaluating on TEST set...")
test_results = evaluate_model(model, test_loader, task_config, final_config.device, "test", final_config)

# Save results
final_results = {
    'task': task_config.task_name,
    'test': test_results['overall'],  # official result
    'test_per_label': test_results['per_label']
}

# ensure_ascii=False writes non-ASCII characters as-is, so the file encoding
# must be fixed to UTF-8 instead of depending on the platform default.
with open(f'outputs/{task_config.task_name}_test_results.json', 'w', encoding='utf-8') as f:
    json.dump(final_results, f, indent=2, ensure_ascii=False)

logger.info("\n" + "="*70)
logger.info("‚úÖ OFFICIAL TEST RESULTS:")
logger.info(f"   Precision: {test_results['overall']['precision']:.4f}")
logger.info(f"   Recall:    {test_results['overall']['recall']:.4f}")
logger.info(f"   F1 Score:  {test_results['overall']['f1']:.4f}")
logger.info("="*70)

# INFERENCE & DEMO

In [None]:
def predict_text(model, text, tokenizer, config, device):
    """
    Predict one label per character of `text`.

    The text is split into characters and tokenized with
    `is_split_into_words=True`. Only the first sub-token of each character
    contributes a label, so a character that the tokenizer splits into several
    pieces still yields exactly one label and the labels of later characters
    stay aligned.

    Args:
        model: called as `model(input_ids=..., attention_mask=...)`; the result
            must expose the logits under the key 'logits'
        text: input string
        tokenizer: callable tokenizer whose output provides `word_ids()`
        config: object exposing `id2label` (label id -> label string)
        device: device the input tensors are moved to

    Returns:
        List of label strings in character order. It is shorter than `text`
        when the tokenizer truncates the input (truncation=True) or produces
        no token for a character.
    """
    model.eval()

    if not text:
        # Nothing to label; avoids tokenizing an empty word list.
        return []

    tokenized = tokenizer(list(text), is_split_into_words=True,
                         return_tensors='pt', padding=True, truncation=True)
    input_ids = tokenized['input_ids'].to(device)
    attention_mask = tokenized['attention_mask'].to(device)

    with torch.no_grad():
        outputs = model(input_ids=input_ids, attention_mask=attention_mask)
        # One transfer for the whole sequence instead of one .item() per token.
        prediction_ids = torch.argmax(outputs['logits'], dim=-1)[0].tolist()

    predicted_labels = []
    previous_word_id = None
    for idx, word_id in enumerate(tokenized.word_ids(batch_index=0)):
        # word_id is None for special tokens; a repeated word_id is a
        # continuation piece of the character that was already labelled.
        if word_id is not None and word_id != previous_word_id:
            predicted_labels.append(config.id2label[prediction_ids[idx]])
        previous_word_id = word_id

    return predicted_labels

In [None]:
# Demo
#"Á∂±ÈëëÊúÉÁ∑®Âç∑‰∫îÂçÅ‰∏â„ÄÇÊñáÂÆâÂäâÂæ∑Ëä≥Âå°Ë®èÊ≠£Â¥ëÂ±±ËëâÊæêÈ∫ÅËºØÈåÑÂîêÁ¥ÄÈ´òÂÆóÁöáÂ∏ùË´±Ê≤ªÔºåÂ≠óÁà≤ÂñÑÔºåÂ§™ÂÆóÁ¨¨‰πùÂ≠ê„ÄÇÂàùÂ∞Å£àÜÁèèÂæåÁ´ãÁà≤ÁöáÂ§™Â≠êÂú®‰Ωç‰∏âÂçÅÂõõÂπ¥Â¥©Â£Ω‰∫îÂçÅÂÖ≠Á∂±„ÄÇ"
#"Á∂±ÈëëÊúÉÁ∑®Âç∑‰∫îÂçÅ‰∏â | ÊñáÂÆâÂäâÂæ∑Ëä≥Âå°Ë®èÊ≠£Â¥ëÂ±±ËëâÊæêÈ∫ÅËºØÈåÑÂîêÁ¥ÄÈ´òÂÆóÁöáÂ∏ùË´±Ê≤ª | Â≠óÁà≤ÂñÑ | Â§™ÂÆóÁ¨¨‰πùÂ≠ê/ÂàùÂ∞Å£àÜÁèèÂæåÁ´ãÁà≤ÁöáÂ§™Â≠êÂú®‰Ωç‰∏âÂçÅÂõõÂπ¥Â¥©Â£Ω‰∫îÂçÅÂÖ≠Á∂± | "

test_text = "Á∂±ÈëëÊúÉÁ∑®Âç∑‰∫îÂçÅ‰∏âÊñáÂÆâÂäâÂæ∑Ëä≥Âå°Ë®èÊ≠£Â¥ëÂ±±ËëâÊæêÈ∫ÅËºØÈåÑÂîêÁ¥ÄÈ´òÂÆóÁöáÂ∏ùË´±Ê≤ªÂ≠óÁà≤ÂñÑÂ§™ÂÆóÁ¨¨‰πùÂ≠êÂàùÂ∞Å£àÜÁèèÂæåÁ´ãÁà≤ÁöáÂ§™Â≠êÂú®‰Ωç‰∏âÂçÅÂõõÂπ¥Â¥©Â£Ω‰∫îÂçÅÂÖ≠Á∂±"
predicted = predict_text(model, test_text, tokenizer, task_config, final_config.device)

# Log the demo input together with the labels predicted for it
logger.info(f"\n{'='*70}")
logger.info("DEMO INFERENCE")
logger.info(f"{'='*70}")
logger.info(f"\nText: {test_text}")
logger.info(f"Labels: {predicted}")  # logs the complete label list (no truncation)

if TASK == "punctuation":
    # Insert each non-'O' label (the punctuation mark) right after its character
    result = ''.join([c if l == 'O' else c+l for c, l in zip(test_text, predicted)])
    logger.info(f"\nPunctuated: {result}")
else:
    # Any other task is rendered as segmentation: a label of 'E' or 'S'
    # closes the current segment after its character.
    sentences = []
    current = []
    for c, l in zip(test_text, predicted):
        current.append(c)
        if l in ['E', 'S']:
            sentences.append(''.join(current))
            current = []
    # Keep the trailing characters that were not closed by an 'E'/'S' label
    if current:
        sentences.append(''.join(current))
    logger.info(f"\nSegmented: {' | '.join(sentences)}")

logger.info(f"\n{'='*70}")

In [None]:
def apply_punctuation_labels(text, labels):
    """
    Re-insert predicted punctuation into a text.

    Args:
        text: str
        labels: List[str], one entry per character; either the punctuation
            symbol to place after that character or 'O' for none

    Returns:
        The text with every non-'O' label appended after its character.
        Pairing stops at the shorter of `text` and `labels`.
    """
    return "".join(
        ch if label == "O" else ch + label
        for ch, label in zip(text, labels)
    )

In [None]:
def apply_segmentation_inline(text, labels, sep=" | "):
    """
    Render segmentation labels as separators inside the text.

    Args:
        text: str
        labels: List[str], one entry per character; 'E' and 'S' mark the last
            character of a segment
        sep: string inserted after each segment-final character

    Returns:
        The text with `sep` between segments. A separator that would follow
        the final character is left out. Pairing stops at the shorter of
        `text` and `labels`.
    """
    pieces = []
    ends_with_sep = False

    for ch, label in zip(text, labels):
        pieces.append(ch)
        ends_with_sep = label in ("E", "S")
        if ends_with_sep:
            pieces.append(sep)

    # Remove exactly the separator this function appended last. str.rstrip(sep)
    # would instead strip every trailing character that occurs in `sep`, which
    # can delete genuine text (e.g. a trailing space or '|').
    if ends_with_sep:
        pieces.pop()

    return "".join(pieces)

In [None]:
def predict_labels(model, text, tokenizer, config, device):
    """
    Predict labels for ONE text (character-level).

    The text is split into characters and tokenized with
    `is_split_into_words=True`. Only the first sub-token of each character
    contributes a label, so a character that the tokenizer splits into several
    pieces still yields exactly one label and later labels stay aligned.

    Args:
        model: called as `model(input_ids=..., attention_mask=...)`; the result
            must expose the logits under the key "logits"
        text: input string
        tokenizer: callable tokenizer whose output provides `word_ids()`
        config: object exposing `id2label` (label id -> label string)
        device: device the input tensors are moved to

    Returns:
        List of label strings in character order. It is shorter than `text`
        when the tokenizer truncates the input (truncation=True) or produces
        no token for a character.
    """
    model.eval()

    chars = list(text)
    if not chars:
        # Nothing to label; avoids tokenizing an empty word list.
        return []

    tokenized = tokenizer(
        chars,
        is_split_into_words=True,
        return_tensors="pt",
        truncation=True
    )

    input_ids = tokenized["input_ids"].to(device)
    attention_mask = tokenized["attention_mask"].to(device)

    with torch.no_grad():
        outputs = model(
            input_ids=input_ids,
            attention_mask=attention_mask
        )
        # One transfer for the whole sequence instead of one .item() per token.
        pred_ids = torch.argmax(outputs["logits"], dim=-1)[0].tolist()

    pred_labels = []
    previous_word_id = None

    for idx, word_id in enumerate(tokenized.word_ids()):
        # word_id is None for special tokens; a repeated word_id is a
        # continuation piece of the character that was already labelled.
        if word_id is not None and word_id != previous_word_id:
            pred_labels.append(config.id2label[pred_ids[idx]])
        previous_word_id = word_id

    return pred_labels

In [None]:
def run_test_set(
    model,
    tokenizer,
    config,
    device,
    test_path,
    output_path,
):
    """
    Predict every sample of a JSON test file and dump gold vs. predicted output.

    Args:
        model, tokenizer, config, device: forwarded to `predict_labels`;
            `config.task_name` selects how labels are rendered into text
        test_path: UTF-8 JSON file holding a sequence of objects with the keys
            "text" and "labels"
        output_path: UTF-8 JSON file to write; one record per sample with the
            text, gold/predicted labels and both label sets rendered inline

    Raises:
        ValueError: if `config.task_name` is neither "punctuation" nor
            "segmentation" (checked per sample, after its prediction)
    """
    import json
    from tqdm import tqdm

    with open(test_path, "r", encoding="utf-8") as src:
        samples = json.load(src)

    records = []

    for sample in tqdm(samples):
        text, gold_labels = sample["text"], sample["labels"]

        pred_labels = predict_labels(
            model=model,
            text=text,
            tokenizer=tokenizer,
            config=config,
            device=device
        )

        # Gold and predicted label lists are not required to have the same
        # length here; the renderers pair characters and labels with zip().
        task = config.task_name
        if task == "punctuation":
            render = apply_punctuation_labels
        elif task == "segmentation":
            render = apply_segmentation_inline
        else:
            raise ValueError(f"Unknown task: {config.task_name}")

        gold_rendered = render(text, gold_labels)
        pred_rendered = render(text, pred_labels)

        records.append({
            "text": text,
            "gold_labels": gold_labels,
            "pred_labels": pred_labels,
            "gold_text_labeled": gold_rendered,
            "pred_text_labeled": pred_rendered,
        })

    with open(output_path, "w", encoding="utf-8") as dst:
        json.dump(records, dst, ensure_ascii=False, indent=2)

In [None]:
# Predict every sample of the test file and write gold vs. predicted labels
# (plus their inline renderings) to the Kaggle working directory.
# Note: this passes the notebook-level `device`, not final_config.device.
run_test_set(
    model=model,
    tokenizer=tokenizer,
    config=task_config,
    device=device,
    test_path=test_path,
    output_path="/kaggle/working/test_pred.json"
)