# Style Classifier

In [None]:
import random
import numpy as np
import torch
from pathlib import Path

# --- Paths -------------------------------------------------------------
# Pretrained backbone directory, stored as an absolute path string.
BACKBONE_PATH = str(Path('../Models/chinese-roberta-wwm-ext').resolve())
# JSON dataset read by the data-loading cell.
DATASET_JSON = Path('data/llm_train.json')
# Output directory for this notebook; created up front if it is missing.
OUTPUT_DIR = Path('outputs/style-classifier')
OUTPUT_DIR.mkdir(parents=True, exist_ok=True)

# --- Reproducibility ---------------------------------------------------
SEED = 42
for _seed_fn in (random.seed, np.random.seed, torch.manual_seed):
    _seed_fn(SEED)

# --- Echo the resolved configuration -----------------------------------
for _label, _location in (
    ('Backbone path:', BACKBONE_PATH),
    ('Dataset json:', DATASET_JSON.resolve()),
    ('Output dir:', OUTPUT_DIR.resolve()),
):
    print(_label, _location)

## Load Dataset

In [2]:
# Data loading: unified from JSON only
from typing import Any
import pandas as pd
from typing import List, Dict
import json
from sklearn.model_selection import train_test_split

def load_text_role_dataframe() -> pd.DataFrame:
    path = DATASET_JSON
    if not path.exists():
        raise FileNotFoundError(
            f"未找到 {path}. 请先在 LLMFine-tuning.ipynb 中运行“导出训练集文件”单元格生成 llm_train.json。"
        )
    data: List[Dict[str, Any]] = json.loads(path.read_text(encoding='utf-8'))
    rows = []
    for it in data:
        text = it.get('output')
        role = it.get('character')
        if not text or not role:
            continue
        rows.append({'text': text, 'role': role})
    df = pd.DataFrame(rows).dropna().reset_index(drop=True)
    # 基础清洗：去除极短文本
    df = df[df['text'].str.len() > 5].reset_index(drop=True)
    return df

# Load the full dataset and print a short overview
full_df = load_text_role_dataframe()

# Schema sanity check before splitting
required_columns = {'text', 'role'}
assert required_columns <= set(full_df.columns)
print('Total samples:', len(full_df))
print('Class distribution:\n', full_df['role'].value_counts())

# Stratified 80/20 train/validation split; both parts get a fresh 0..n-1 index
train_df, val_df = (
    part.reset_index(drop=True)
    for part in train_test_split(
        full_df,
        test_size=0.2,
        random_state=SEED,
        stratify=full_df['role'],
    )
)

print('\nTrain size:', len(train_df), 'Val size:', len(val_df))

Total samples: 5699
Class distribution:
 role
沐雪      2119
神里绫华    1358
凉宫春日     933
胡桃       816
钟离       473
Name: count, dtype: int64

Train size: 4559 Val size: 1140


## Define Model Structure

In [None]:
from transformers import AutoModel
from torch.utils.data import Dataset
import torch.nn as nn

# Dataset class
class TextRoleDataset(Dataset):
    """Map-style dataset that tokenizes one text per item and pairs it with a label.

    Args:
        texts: Input texts.
        roles: Integer class ids, aligned index-by-index with ``texts``.
        tokenizer: Callable invoked as
            ``tokenizer(text, truncation=True, max_length=..., padding='max_length',
            return_tensors='pt')`` that returns a mapping containing
            ``input_ids`` and ``attention_mask`` tensors with a leading
            dimension of size 1.
        max_length: Truncation/padding length forwarded to the tokenizer.

    Raises:
        ValueError: If ``texts`` and ``roles`` differ in length.
    """

    def __init__(self, texts: List[str], roles: List[int], tokenizer, max_length=128):
        # Fail fast: a mismatch would otherwise surface as a late IndexError
        # (more texts than labels) or silently ignore the extra labels.
        if len(texts) != len(roles):
            raise ValueError(
                f"texts and roles must have the same length, got {len(texts)} and {len(roles)}"
            )
        self.texts = texts
        self.roles = roles
        self.tokenizer = tokenizer
        self.max_length = max_length

    def __len__(self):
        return len(self.texts)

    def __getitem__(self, idx):
        text = self.texts[idx]
        role = self.roles[idx]
        encoded = self.tokenizer(
            text,
            truncation=True,
            max_length=self.max_length,
            padding='max_length',
            return_tensors='pt'
        )
        # The tokenizer output carries a leading dimension of size 1; drop it so
        # the DataLoader's default collate stacks items into (batch, seq_len).
        item = {
            'input_ids': encoded['input_ids'].squeeze(0),
            'attention_mask': encoded['attention_mask'].squeeze(0),
            'label': torch.tensor(role, dtype=torch.long)
        }
        return item
    
# Model definition
class CharacterStyleClassifier(nn.Module):
    """Pretrained backbone followed by a projection head and a linear role classifier.

    The forward pass mean-pools the backbone's last hidden state over the
    attention mask, feeds it through ``Linear -> ReLU -> Dropout`` and then
    through the classifier. It returns both the logits and the projected
    embedding.
    """

    def __init__(self, backbone_name: str, embed_dim: int = 768, proj_dim: int = 256, num_roles: int = 6, dropout: float = 0.4):
        super().__init__()
        self.backbone = AutoModel.from_pretrained(backbone_name, output_hidden_states=False)
        self.hidden_size = self.backbone.config.hidden_size
        # The caller-declared embedding width must match the loaded backbone.
        assert self.hidden_size == embed_dim, f"Backbone hidden_size {self.hidden_size} != embed_dim {embed_dim}"

        projection_layers = [
            nn.Linear(self.hidden_size, proj_dim),
            nn.ReLU(),
            nn.Dropout(dropout),
        ]
        self.proj = nn.Sequential(*projection_layers)
        self.classifier = nn.Linear(proj_dim, num_roles)

    def mean_pooling(self, last_hidden_state, attention_mask):
        """Average ``last_hidden_state`` over dim 1, weighting each position by the mask."""
        mask = attention_mask.unsqueeze(-1).expand(last_hidden_state.size()).float()
        masked_total = (last_hidden_state * mask).sum(dim=1)
        # Clamp the denominator so an all-zero mask cannot divide by zero.
        mask_total = mask.sum(dim=1).clamp(min=1e-9)
        return masked_total / mask_total

    def forward(self, input_ids, attention_mask):
        """Return ``(logits, projected_embedding)`` for the given batch."""
        backbone_out = self.backbone(input_ids=input_ids, attention_mask=attention_mask, return_dict=True)
        pooled = self.mean_pooling(backbone_out.last_hidden_state, attention_mask)
        projected = self.proj(pooled)
        return self.classifier(projected), projected

## Define Utility Functions

In [4]:
import os

from typing import Tuple

from tokenizers import Tokenizer
from tqdm import tqdm
from torch.utils.data import DataLoader
from transformers import AutoTokenizer
from sklearn.preprocessing import LabelEncoder

# Utilities: save/load, centers, style score

def save_checkpoint(model: nn.Module, tokenizer, label_encoder: LabelEncoder, role_centers: Dict[str, np.ndarray], path: str):
    """Write all artifacts of a trained classifier into the directory ``path``.

    Files produced:
        - backbone and tokenizer files, via their own ``save_pretrained``
        - ``head.pt``: state dicts of ``model.proj`` and ``model.classifier``
        - ``label_encoder.json``: the encoder's class list
        - ``role_centers.npz``: one array per entry of ``role_centers``

    The directory is created when it does not exist.
    """
    os.makedirs(path, exist_ok=True)

    head_file = os.path.join(path, 'head.pt')
    encoder_file = os.path.join(path, 'label_encoder.json')
    centers_file = os.path.join(path, 'role_centers.npz')

    model.backbone.save_pretrained(path)  # type: ignore
    head_state = {
        'proj_state': model.proj.state_dict(),  # type: ignore
        'classifier_state': model.classifier.state_dict(),  # type: ignore
    }
    torch.save(head_state, head_file)
    tokenizer.save_pretrained(path)

    encoder_payload = {'classes': label_encoder.classes_.tolist()}
    with open(encoder_file, 'w', encoding='utf-8') as fh:
        # ensure_ascii=False keeps non-ASCII class names readable in the file.
        json.dump(encoder_payload, fh, ensure_ascii=False)

    # Each role name becomes one named array inside the archive.
    np.savez(centers_file, **role_centers)
    print(f"Saved checkpoint to {path}")


def load_checkpoint(path: str, device='cpu') -> Tuple[nn.Module, 'Tokenizer', LabelEncoder, Dict[str, np.ndarray]]:
    """Rebuild a classifier from a directory written by ``save_checkpoint``.

    Args:
        path: Checkpoint directory containing the saved backbone, tokenizer,
            ``head.pt``, ``label_encoder.json`` and ``role_centers.npz``.
        device: Device the head weights are mapped to and the model is moved to.

    Returns:
        ``(model, tokenizer, label_encoder, role_centers)``. The model is
        returned in eval mode so dropout is inactive for inference; call
        ``model.train()`` before continuing training.
    """
    # Imported locally so this function does not depend on a later notebook
    # cell having imported AutoConfig.
    from transformers import AutoConfig

    tokenizer = AutoTokenizer.from_pretrained(path)
    with open(os.path.join(path, 'label_encoder.json'), 'r', encoding='utf-8') as f:
        le_json = json.load(f)
    le = LabelEncoder()
    le.classes_ = np.array(le_json['classes'])

    # 1. Read hidden_size from the saved config only. Instantiating the full
    #    backbone just for this would load its weights a second time.
    backbone_hidden = AutoConfig.from_pretrained(path).hidden_size

    # 2. Load the head weights first so the projection width can be derived
    #    from them instead of being hard-coded.
    # NOTE(review): torch.load unpickles the file; only load trusted checkpoints.
    chk = torch.load(os.path.join(path, 'head.pt'), map_location=device)
    proj_weight = chk['proj_state'].get('0.weight')  # first layer of the nn.Sequential head
    proj_dim = proj_weight.shape[0] if proj_weight is not None else 256

    # 3. Build the model directly from `path`, which loads the saved backbone.
    model = CharacterStyleClassifier(path, embed_dim=backbone_hidden, proj_dim=proj_dim, num_roles=len(le.classes_))
    model.proj.load_state_dict(chk['proj_state'])
    model.classifier.load_state_dict(chk['classifier_state'])

    # 4. Load the role centers; the context manager closes the archive handle.
    with np.load(os.path.join(path, 'role_centers.npz')) as centers_npz:
        role_centers = {k: centers_npz[k] for k in centers_npz.files}

    model = model.to(device)
    model.eval()
    return model, tokenizer, le, role_centers


@torch.no_grad()
def compute_role_centers(model: nn.Module, dataloader: DataLoader, label_encoder: LabelEncoder, device='cpu') -> Dict[str, np.ndarray]:
    """Compute the mean projected embedding of every role seen in ``dataloader``.

    The model is switched to eval mode. Each batch must provide
    ``input_ids``, ``attention_mask`` and ``label``; the second value returned
    by the model is used as the embedding. The result maps the role name
    (decoded with ``label_encoder``) to the mean embedding of that role.
    """
    model.eval()
    vectors_by_label: Dict[int, List[np.ndarray]] = {}
    for batch in tqdm(dataloader, desc="Computing centers"):
        batch_labels = batch['label'].cpu().numpy()
        _, projected = model(
            input_ids=batch['input_ids'].to(device),
            attention_mask=batch['attention_mask'].to(device),
        )
        batch_vectors = projected.detach().cpu().numpy()
        # Group the per-sample vectors under their integer label id.
        for label_id, vector in zip(batch_labels.tolist(), batch_vectors):
            vectors_by_label.setdefault(label_id, []).append(vector)

    # Average each group and key it by the decoded role name.
    return {
        label_encoder.inverse_transform([label_id])[0]: np.stack(vectors, axis=0).mean(axis=0)
        for label_id, vectors in vectors_by_label.items()
    }


def get_style_score(model: nn.Module, tokenizer, text: str, role_center: np.ndarray, device='cpu', max_length=128) -> float:
    """Return the cosine similarity between the embedding of ``text`` and ``role_center``.

    The model is put in eval mode and run without gradients on the single
    tokenized text. The second value it returns is used as the embedding. A
    small epsilon in the denominator avoids division by zero.
    """
    model.eval()
    encoded = tokenizer(text, truncation=True, max_length=max_length, padding='max_length', return_tensors='pt')
    with torch.no_grad():
        _, projected = model(
            input_ids=encoded['input_ids'].to(device),
            attention_mask=encoded['attention_mask'].to(device),
        )
    # Only one text was encoded, so take the first (and only) row.
    text_vec = projected.detach().cpu().numpy()[0]

    dot_product = float(np.dot(text_vec, role_center))
    norm_product = float(np.linalg.norm(text_vec) * np.linalg.norm(role_center) + 1e-9)
    return dot_product / norm_product

## Define Training Function

In [5]:
# Train / Evaluate
import math
from transformers import get_linear_schedule_with_warmup, AutoConfig
from sklearn.metrics import accuracy_score, f1_score, confusion_matrix, classification_report
from torch.optim import AdamW

def evaluate(model: nn.Module, dataloader: DataLoader, label_encoder: LabelEncoder, device='cpu') -> Dict:
    """Run the model over ``dataloader`` and compute classification metrics.

    Args:
        model: Model returning ``(logits, embedding)`` for a batch.
        dataloader: Yields batches with ``input_ids``, ``attention_mask`` and ``label``.
        label_encoder: Fitted encoder; its ``classes_`` name the label ids.
        device: Device the batch tensors are moved to.

    Returns:
        Dict with ``accuracy``, ``macro_f1``, ``micro_f1``, the text ``report``,
        the ``confusion_matrix``, and the raw ``preds``, ``trues`` and
        ``embeddings`` lists.
    """
    model.eval()
    preds = []
    trues = []
    embeddings = []
    with torch.no_grad():
        for batch in dataloader:
            input_ids = batch['input_ids'].to(device)
            attention_mask = batch['attention_mask'].to(device)
            labels = batch['label'].cpu().numpy()
            logits, emb = model(input_ids=input_ids, attention_mask=attention_mask)
            pred = torch.argmax(logits, dim=-1).cpu().numpy()
            preds.extend(pred.tolist())
            trues.extend(labels.tolist())
            embeddings.extend(emb.detach().cpu().numpy().tolist())

    # Pin the full label set. Without `labels`, classification_report raises
    # when a class is absent from both trues and preds (target_names length
    # mismatch), and the confusion matrix shape varies with the split.
    all_label_ids = list(range(len(label_encoder.classes_)))

    acc = accuracy_score(trues, preds)
    # The scalar F1 scores deliberately keep sklearn's default label handling.
    macro_f1 = f1_score(trues, preds, average='macro')
    micro_f1 = f1_score(trues, preds, average='micro')
    report = classification_report(
        trues,
        preds,
        labels=all_label_ids,
        target_names=label_encoder.classes_,
        zero_division=0,
    )
    cm = confusion_matrix(trues, preds, labels=all_label_ids)
    return {
        'accuracy': acc,
        'macro_f1': macro_f1,
        'micro_f1': micro_f1,
        'report': report,
        'confusion_matrix': cm,
        'preds': preds,
        'trues': trues,
        'embeddings': embeddings
    }


def train(
    train_df: pd.DataFrame,
    val_df: pd.DataFrame,
    backbone_name: str = BACKBONE_PATH,
    output_dir: str = str(OUTPUT_DIR),
    epochs: int = 3,
    batch_size: int = 32,
    lr: float = 2e-5,
    weight_decay: float = 3e-4,
    max_length: int = 128,
    device: str | None = None,
    patience: int = 3,
    seed: int = SEED
):
    """Train and validate the classifier with early stopping.

    This function only trains/validates:
    - It does not write the full set of artifacts during training, to avoid
      duplicating what ``save_checkpoint`` does.
    - Whenever the validation macro-F1 improves, the whole model state_dict is
      written to a temporary file (``best_model_state.pt`` in ``output_dir``).
    - After training, the best weights are restored, the role centers are
      computed on the training set, and the temporary file is removed.
    The caller is expected to call ``save_checkpoint()`` once afterwards.

    Args:
        train_df: Training rows with ``text`` and ``role`` columns.
        val_df: Validation rows with ``text`` and ``role`` columns.
        backbone_name: Name or path passed to the tokenizer/config/model loaders.
        output_dir: Directory for the temporary best-weights file.
        epochs: Maximum number of epochs.
        batch_size: Batch size for all data loaders.
        lr: AdamW learning rate.
        weight_decay: AdamW weight decay.
        max_length: Tokenizer truncation/padding length.
        device: Target device; defaults to ``'cuda'`` when available, else ``'cpu'``.
        patience: Number of consecutive non-improving epochs before stopping.
        seed: Seed applied to ``random``, NumPy, torch and the shuffling
            generator at the start of the run.

    Returns:
        ``(model, tokenizer, label_encoder, role_centers)``.
    """
    device = device or ('cuda' if torch.cuda.is_available() else 'cpu')
    print(f"Training on device: {device}")

    # Honor the `seed` argument (previously accepted but ignored), so head
    # initialisation, dropout and shuffling are repeatable for a given seed.
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)

    le = LabelEncoder()
    le.fit(train_df['role'].tolist() + val_df['role'].tolist())
    num_roles = len(le.classes_)
    print('Roles:', le.classes_)

    print("Loading tokenizer and preparing datasets...")
    tokenizer = AutoTokenizer.from_pretrained(backbone_name)

    train_texts = train_df['text'].tolist()
    train_labels = le.transform(train_df['role'].tolist()).tolist()  # type:ignore
    val_texts = val_df['text'].tolist()
    val_labels = le.transform(val_df['role'].tolist()).tolist()  # type:ignore

    train_ds = TextRoleDataset(train_texts, train_labels, tokenizer, max_length)
    val_ds = TextRoleDataset(val_texts, val_labels, tokenizer, max_length)
    # A dedicated generator makes the shuffle order depend only on `seed`.
    shuffle_rng = torch.Generator()
    shuffle_rng.manual_seed(seed)
    train_loader = DataLoader(train_ds, batch_size=batch_size, shuffle=True, generator=shuffle_rng)
    val_loader = DataLoader(val_ds, batch_size=batch_size, shuffle=False)

    embed_dim = AutoConfig.from_pretrained(backbone_name).hidden_size
    model = CharacterStyleClassifier(backbone_name, embed_dim=embed_dim, proj_dim=256, num_roles=num_roles, dropout=0.4)
    model.to(device)

    optimizer = AdamW(model.parameters(), lr=lr, weight_decay=weight_decay)

    # Linear schedule with the first 6% of all optimizer steps used as warmup.
    total_steps = len(train_loader) * epochs
    scheduler = get_linear_schedule_with_warmup(optimizer, num_warmup_steps=math.ceil(0.06 * total_steps), num_training_steps=total_steps)

    criterion = nn.CrossEntropyLoss()

    best_val_f1 = -1.0
    best_epoch = -1
    bad_cnt = 0

    # Temporary file for the best whole-model weights (kept separate from the
    # final artifacts written by save_checkpoint).
    tmp_best_path = os.path.join(output_dir, 'best_model_state.pt')
    os.makedirs(output_dir, exist_ok=True)

    print("Starting training...")
    for epoch in range(1, epochs + 1):
        model.train()
        losses = []
        pbar = tqdm(train_loader, desc=f'Epoch {epoch} training')
        for batch in pbar:
            input_ids = batch['input_ids'].to(device)
            attention_mask = batch['attention_mask'].to(device)
            labels = batch['label'].to(device)

            logits, _ = model(input_ids=input_ids, attention_mask=attention_mask)
            loss = criterion(logits, labels)
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            scheduler.step()

            losses.append(loss.item())
            pbar.set_postfix(loss=np.mean(losses))

        # Validation
        val_metrics = evaluate(model, val_loader, le, device=device)
        val_f1 = val_metrics['macro_f1']
        print(f"Epoch {epoch} validation macro_f1: {val_f1:.4f}, acc: {val_metrics['accuracy']:.4f}")

        if val_f1 > best_val_f1:
            best_val_f1 = val_f1
            best_epoch = epoch
            bad_cnt = 0
            # Overwrite the temporary whole-model state_dict with the new best.
            torch.save(model.state_dict(), tmp_best_path)
            print(f"Updated best model at epoch {epoch} (macro_f1={val_f1:.4f}).")
        else:
            bad_cnt += 1
            if bad_cnt >= patience:
                print(f"Early stopping at epoch {epoch}, best epoch {best_epoch}, best_macro_f1 {best_val_f1:.4f}")
                break

    # Restore the best weights. The `best_epoch` guard ensures a stale file
    # left by an earlier run is never loaded when this run saved nothing.
    if best_epoch > 0 and os.path.exists(tmp_best_path):
        best_state = torch.load(tmp_best_path, map_location=device)
        model.load_state_dict(best_state)
        print(f"Restored best model from {tmp_best_path}")
    else:
        print("Warning: best_model_state.pt not found; using last epoch weights.")

    # Compute the role centers from the restored weights, iterating the
    # training set in its original order.
    full_train_loader = DataLoader(train_ds, batch_size=batch_size, shuffle=False)
    role_centers = compute_role_centers(model, full_train_loader, le, device=device)

    # Best-effort cleanup of the temporary file; report failures instead of
    # swallowing them silently.
    try:
        if os.path.exists(tmp_best_path):
            os.remove(tmp_best_path)
    except OSError as exc:
        print(f"Warning: could not remove temporary file {tmp_best_path}: {exc}")

    return model, tokenizer, le, role_centers

## Train Model

In [6]:
# Hyperparameters for this run
EPOCHS = 2
BATCH_SIZE = 32
LR = 2e-5

model, tokenizer, le, role_centers = train(
    train_df=train_df,
    val_df=val_df,
    backbone_name=BACKBONE_PATH,
    output_dir=str(OUTPUT_DIR),
    epochs=EPOCHS,
    batch_size=BATCH_SIZE,
    lr=LR,
    patience=2
)

# Evaluate on validation set.
# Use the configured batch size and the device the model actually lives on,
# rather than a second hard-coded 32 and an independently derived device.
eval_device = next(model.parameters()).device
val_ds = TextRoleDataset(val_df['text'].tolist(), le.transform(val_df['role'].tolist()).tolist(), tokenizer)  # type:ignore
val_loader = DataLoader(val_ds, batch_size=BATCH_SIZE, shuffle=False)
metrics = evaluate(model, val_loader, le, device=eval_device)
print('Validation report:\n', metrics['report'])

# Save full checkpoint
save_checkpoint(model, tokenizer, le, role_centers, str(OUTPUT_DIR))

Training on device: cuda
Roles: ['凉宫春日' '沐雪' '神里绫华' '胡桃' '钟离']
Loading tokenizer and preparing datasets...
Starting training...


Epoch 1 training: 100%|██████████| 143/143 [00:48<00:00,  2.94it/s, loss=0.846]


Epoch 1 validation macro_f1: 0.8463, acc: 0.8868
Updated best model at epoch 1 (macro_f1=0.8463).


Epoch 2 training: 100%|██████████| 143/143 [00:47<00:00,  3.02it/s, loss=0.297]


Epoch 2 validation macro_f1: 0.8816, acc: 0.9132
Updated best model at epoch 2 (macro_f1=0.8816).
Restored best model from outputs\style-classifier\best_model_state.pt


Computing centers: 100%|██████████| 143/143 [00:23<00:00,  6.04it/s]


Validation report:
               precision    recall  f1-score   support

        凉宫春日       0.89      0.90      0.90       187
          沐雪       0.98      0.96      0.97       424
        神里绫华       0.86      0.93      0.89       272
          胡桃       0.92      0.93      0.92       163
          钟离       0.82      0.66      0.73        94

    accuracy                           0.91      1140
   macro avg       0.89      0.88      0.88      1140
weighted avg       0.91      0.91      0.91      1140

Saved checkpoint to outputs\style-classifier


In [7]:
# Inference demo: cosine style score of one validation sample against the
# center of its own role
sample_idx = 0
sample_row = val_df.iloc[sample_idx]
sample_text = sample_row['text']
sample_role = sample_row['role']

demo_device = 'cuda' if torch.cuda.is_available() else 'cpu'
score = get_style_score(
    model,
    tokenizer,
    sample_text,
    role_centers[sample_role],
    device=demo_device,
)
print(f"Sample role: {sample_role}\nStyle score (cosine to center): {score:.4f}\nText: {sample_text[:120]}...")

Sample role: 凉宫春日
Style score (cosine to center): 0.5417
Text: 你们好!我前来征收一台电脑!...
