# Syntactic Pattern Dimension Ablation Test

In [1]:
# Run this on Remote Jupyter Book.
# Displays the kernel's working directory; later cells derive BASE_DIR from os.getcwd(),
# so this is the place to confirm the notebook runs from the intended project root.

import os

os.getcwd()
# Uncomment to switch to the project root when the kernel starts somewhere else.
# os.chdir("/root/OtakuLab")

'/root/OtakuLab'

## Define PCFG Mapping

### d=10

In [1]:
# Maps a PCFG production ("LHS -> RHS ...") to one of 10 syntactic style dimensions.
# Each production appears exactly once. The original literal listed three productions
# twice, and Python silently keeps only the last value for a repeated key; the entries
# below are those effective (last-wins) assignments, so the resulting dict is unchanged.
rule_to_10dim = {
    # Referentiality: proper names and referring expressions
    "NP -> NR": "referentiality",
    "NP -> PN": "referentiality",
    "NP -> NR NN": "referentiality",
    "DP -> DT": "referentiality",

    # Interjectionality: exclamations and parenthetical tone
    "INTJ -> IJ": "interjectionality",
    "FLR -> IJ": "interjectionality",
    "FLR -> SP": "interjectionality",
    "IP -> INTJ PU VP": "interjectionality",
    "IP -> INTJ VP": "interjectionality",

    # Declarativity: declarative mood / sentence-type hierarchy
    "CP -> IP SP": "declarativity",
    "CP -> IP SP PU": "declarativity",
    "CP -> IP DEC": "declarativity",
    "TOP -> CP": "declarativity",
    "CP -> CP": "declarativity",
    "TOP -> IP": "declarativity",
    "VP -> VC NP": "declarativity",

    # Clausal Embedding: nested clauses
    "VP -> VV IP": "clausal_embedding",
    "LCP -> IP LC": "clausal_embedding",
    "IP -> ADVP PU NP VP": "clausal_embedding",
    "NP -> CP NP": "clausal_embedding",

    # Subordination: modifying subordinate structures
    # NOTE(review): "IP -> VP" used to be listed here as well, but the later
    # ellipsis_or_fragmentation entry always won; confirm that is the intended dimension.
    "VP -> ADVP VP": "subordination",
    "IP -> NP VP": "subordination",
    "VP -> VV NP": "subordination",
    "VP -> VA": "subordination",
    "IP -> VP SP": "subordination",
    "VP -> PP VP": "subordination",

    # Parallelism: parallel sentence patterns
    "VP -> VP PU VP": "parallelism",
    "VP -> VP PU VP PU VP": "parallelism",
    "UCP -> IP PU CP": "parallelism",
    "IP -> VP PU": "parallelism",
    "TOP -> UCP": "parallelism",

    # Modifier Density: density of modifiers
    # NOTE(review): "NP -> DNP NP" used to be listed here as well, but the later
    # nominal_complexity entry always won; confirm that is the intended dimension.
    "DNP -> ADJP DEG": "modifier_density",
    "DNP -> NP DEG": "modifier_density",
    "ADVP -> AD": "modifier_density",
    "ADJP -> JJ": "modifier_density",

    # Nominal Complexity: complexity of noun phrases
    "NP -> ADJP NP": "nominal_complexity",
    "NP -> DNP NP": "nominal_complexity",
    "NP -> NN NN": "nominal_complexity",

    # Topic Fronting: fronted-topic structures
    "TOP -> NP IP": "topic_fronting",

    # Ellipsis or Fragmentation: colloquial ellipsis, sentence fragments
    "IP -> VP": "ellipsis_or_fragmentation",
    "IP -> ADVP VP": "ellipsis_or_fragmentation",

    # Deep Embedding: counted under clausal_embedding
    "VP -> VV NP IP": "clausal_embedding"
}


### d=18

In [2]:
# ---- Revised PCFG-to-Dimension Mapping ----
# Maps a PCFG production ("LHS -> RHS ...") to one of 18 syntactic style dimensions.
# Each production appears exactly once. The original literal listed three productions
# twice, and Python silently keeps only the last value for a repeated key; the entries
# below are those effective (last-wins) assignments, so the resulting dict is unchanged.
rule_to_18dim = {
    # Referentiality: proper names and referring expressions
    "NP -> NR": "referentiality",
    "NP -> PN": "referentiality",
    "NP -> NR NN": "referentiality",
    "DP -> DT": "referentiality",

    # Interjectionality: exclamations and parenthetical tone
    "INTJ -> IJ": "interjectionality",
    "FLR -> IJ": "interjectionality",
    "FLR -> SP": "interjectionality",
    "IP -> INTJ PU VP": "interjectionality",
    "IP -> INTJ VP": "interjectionality",

    # Declarativity: declarative mood / sentence-type hierarchy
    "CP -> IP SP": "declarativity",
    "CP -> IP SP PU": "declarativity",
    "CP -> IP DEC": "declarativity",
    "TOP -> CP": "declarativity",
    "CP -> CP": "declarativity",
    "TOP -> IP": "declarativity",
    "VP -> VC NP": "declarativity",

    # Clausal Embedding: nested clauses
    "VP -> VV IP": "clausal_embedding",
    "LCP -> IP LC": "clausal_embedding",
    "IP -> ADVP PU NP VP": "clausal_embedding",
    "NP -> CP NP": "clausal_embedding",

    # Subordination: modifying subordinate structures
    # NOTE(review): "IP -> VP" used to be listed here as well, but the later
    # ellipsis_or_fragmentation entry always won; confirm that is the intended dimension.
    "VP -> ADVP VP": "subordination",
    "IP -> NP VP": "subordination",
    "VP -> VV NP": "subordination",
    "VP -> VA": "subordination",
    "IP -> VP SP": "subordination",
    "VP -> PP VP": "subordination",

    # Parallelism: parallel sentence patterns
    "VP -> VP PU VP": "parallelism",
    "VP -> VP PU VP PU VP": "parallelism",
    "UCP -> IP PU CP": "parallelism",
    "IP -> VP PU": "parallelism",
    "TOP -> UCP": "parallelism",

    # Coordination Density: complexity of coordinated structures
    "NP -> NP CC NP": "coordination_density",
    "NP -> NN CC NN": "coordination_density",
    "CP -> CP CC CP": "coordination_density",

    # Modifier Density: density of modifiers
    # NOTE(review): "NP -> DNP NP" used to be listed here as well, but the later
    # nominal_complexity entry always won; confirm that is the intended dimension.
    "DNP -> ADJP DEG": "modifier_density",
    "DNP -> NP DEG": "modifier_density",
    "ADVP -> AD": "modifier_density",
    "ADJP -> JJ": "modifier_density",

    # Nominal Complexity: complexity of noun phrases
    "NP -> ADJP NP": "nominal_complexity",
    "NP -> DNP NP": "nominal_complexity",
    "NP -> NN NN": "nominal_complexity",

    # Prepositional Density: density of prepositional structures
    "PP -> P LCP": "prepositional_density",
    "PP -> P NP": "prepositional_density",

    # Topic Fronting: fronted-topic structures
    "TOP -> NP IP": "topic_fronting",

    # Ellipsis or Fragmentation: colloquial ellipsis, sentence fragments
    "IP -> VP": "ellipsis_or_fragmentation",
    "IP -> ADVP VP": "ellipsis_or_fragmentation",

    # Syntactic Compression: compressed syntax
    "NP -> NN": "syntactic_compression",
    "VP -> VV": "syntactic_compression",
    "VP -> VV VP": "syntactic_compression",

    # Quantificationality: use of numerals and measure words
    "CLP -> M": "quantificationality",
    "QP -> CD CLP": "quantificationality",
    "NP -> DP NP": "quantificationality",
    "NP -> QP CP NP": "quantificationality",

    # Deep Embedding: counted under clausal_embedding
    "VP -> VV NP IP": "clausal_embedding",

    # Aspectuality: aspect marking (了 / 过 / 着)
    "VP -> VV AS": "aspectuality",
    "VP -> VV AS NP": "aspectuality",

    # Existential / Presentational: existence and presentation
    "VP -> VE NP": "existentiality",
    "IP -> NP VE NP": "existentiality",

    # Serial Verb Construction
    "VP -> VP VP": "serial_verb_construction",
    "VP -> VV NP VP": "serial_verb_construction",

    # Discourse Framing: discourse frames / stacked adverbials
    "VP -> ADVP ADVP VP": "discourse_framing",
    "IP -> ADVP NP VP": "discourse_framing",
    "ADVP -> CS": "discourse_framing",

}


## Data Processing and Training Pipeline

In [3]:
import os
import json
import re
from pathlib import Path
from typing import Optional, Any, Dict, Iterable, Tuple
from collections import Counter, defaultdict
from typing_extensions import TypedDict

# All project paths are resolved against the kernel's current working directory.
BASE_DIR = Path(os.getcwd())

# Input data.
DATASET_DIR = BASE_DIR.joinpath("dataset")
NEUTRAL_SENTENCES_FILE = DATASET_DIR.joinpath("neutral_sentences_with_CoT.jsonl")

# Artifacts of this experiment; the directory is created eagerly.
OUTPUTS_DIR = BASE_DIR.joinpath("evaluate", "outputs", "DimensionDissolution")
OUTPUTS_DIR.mkdir(parents=True, exist_ok=True)

class InstructionComponents(TypedDict):
    """Style-conditioning signals attached to one dataset item."""
    # Keyword list stored by DatasetStorage.save_characters_keywords.
    lexical_keywords: list[str]
    # Dimension name -> value, stored by DatasetStorage.save_component.
    syntactic_vector: dict[str, float]
    # Style labels stored by DatasetStorage.save_component.
    pragmatic_styles: list[str]

class DatasetItem(TypedDict):
    """One training record: a neutral sentence paired with a character's styled rendition."""
    # Character the styled sentence belongs to.
    character: str
    # Neutral sentence given to the model as content.
    neutral_sentence: str
    # Starts as an empty dict (see DatasetStorage.new_item) and is filled in incrementally,
    # hence the looser alternative type.
    instruction_components: InstructionComponents | dict[str, list[str]|dict]
    # Reasoning text that is placed before the styled sentence in the training target.
    thinking_process: str
    # Styled sentence; DatasetStorage also uses it as the record key.
    output: str

class DatasetStorage:
    """In-memory collection of dataset records, indexed by their styled ``output`` sentence."""

    def __init__(self) -> None:
        # styled sentence -> record
        self.items: dict[str, DatasetItem] = {}

    def __len__(self) -> int:
        return len(self.items)

    def new_item(self, character: str, neutral_sentence: str, thinking_process:str, output: str):
        """Register a record with empty instruction components.

        A record that already exists under the same ``output`` is replaced.
        """
        self.items[output] = DatasetItem(
            character=character,
            neutral_sentence=neutral_sentence,
            instruction_components={},
            thinking_process=thinking_process,
            output=output,
        )

    def save_characters_keywords(self, character: str, lexical_keywords: list[str]):
        """Attach ``lexical_keywords`` to every record of ``character``.

        Raises:
            ValueError: if no record belongs to ``character``.
        """
        matching = [record for record in self.items.values() if record['character'] == character]
        for record in matching:
            self.items[record['output']]['instruction_components']['lexical_keywords'] = lexical_keywords
        if not matching:
            raise ValueError(f"角色 {character} 未找到")

    def save_component(self,
                       output: str,
                       syntactic_vector: Optional[dict[str, float]] = None,
                       pragmatic_styles: Optional[list[str]] = None):
        """Store the given components on the record keyed by ``output``.

        Falsy values (``None`` as well as an empty dict/list) are ignored, so a record
        never receives an empty component through this method.

        Raises:
            ValueError: if ``output`` is not a known record key.
        """
        if output not in self.items:
            raise ValueError(f"风格句: {output} 似乎未加载")

        record = self.items[output]
        if syntactic_vector:
            record['instruction_components']['syntactic_vector'] = syntactic_vector
        if pragmatic_styles:
            record['instruction_components']['pragmatic_styles'] = pragmatic_styles

    @staticmethod
    def _verify_validity(item: DatasetItem):
        """Return True when the record is complete enough to be exported.

        Character, neutral sentence, output and instruction components must be truthy,
        and the components must hold list/list/dict values for keywords/styles/vector.
        """
        required_fields = (
            item['character'],
            item['neutral_sentence'],
            item['output'],
            item['instruction_components'],
        )
        if not all(required_fields):
            return False

        components = item['instruction_components']
        expected_types = (
            ('lexical_keywords', list),
            ('pragmatic_styles', list),
            ('syntactic_vector', dict),
        )
        return all(isinstance(components.get(key), kind) for key, kind in expected_types)

    @staticmethod
    def _oversampling(items_list: list[DatasetItem], item: DatasetItem):
        """Append extra copies of ``item`` to ``items_list`` when it carries a difficult label.

        Five extra copies for the first label group, otherwise three for the second group,
        otherwise none. The copies are in addition to the one the caller already appended.
        """
        difficult_labels_5x = {'tsundere', 'sharp_tongued', 'proud'}
        difficult_labels_3x = {'tsukkomi', 'chuunibyou', 'yandere', 'airhead'}
        styles = set(item['instruction_components'].get('pragmatic_styles', []))

        if styles & difficult_labels_5x:
            extra_copies = 5
        elif styles & difficult_labels_3x:
            extra_copies = 3
        else:
            extra_copies = 0
        items_list.extend([item] * extra_copies)

    def output(self, output_path: Path, oversampling: bool = False):
        """Write all valid records (optionally oversampled) to ``output_path`` as JSON."""
        all_records = list(self.items.values())
        exported = []
        valid_count = 0

        for record in all_records:
            if self._verify_validity(record):
                exported.append(record)
                valid_count += 1
                if oversampling:
                    self._oversampling(exported, record)

        output_path.parent.mkdir(parents=True, exist_ok=True)
        with open(output_path, "w", encoding="utf-8") as f:
            json.dump(exported, f, ensure_ascii=False, indent=2)

        print(f"训练集已导出至 {output_path}, 总有效训练集数量: {len(exported)}, 跳过数量: {len(all_records) - valid_count}")


class PCFGExtractor:
    """Counts PCFG productions in constituency trees and folds them into style dimensions.

    Trees are nested two-element lists ``[label, children]``. A production is recorded
    for a node only when ``children`` is a list whose elements are all two-element lists.
    """

    def __init__(self, rule_mapping: dict[str, str]) -> None:
        # lhs symbol -> Counter of rhs symbol tuples
        self.rules_counter: Dict[str, Counter[Tuple[str, ...]]] = defaultdict(Counter)
        self.total_rules: int = 0
        # "LHS -> RHS ..." production string -> dimension name
        self.rule_mapping = rule_mapping

    def load_trees(self, file_path: Path) -> list[dict[str, Any]]:
        """Load and return the JSON content of ``file_path``."""
        with open(file_path, "r", encoding="utf-8") as f:
            return json.load(f)

    def extract_rules_from_tree(self, tree: Any) -> None:
        """Recursively count the productions found in ``tree``; malformed nodes are ignored."""
        if not isinstance(tree, list) or len(tree) != 2:
            return
        lhs_symbol, rhs = tree
        if isinstance(rhs, list) and all(isinstance(child, list) and len(child) == 2 for child in rhs):
            rhs_symbols = tuple(child[0] for child in rhs)
            self.rules_counter[lhs_symbol][rhs_symbols] += 1
            self.total_rules += 1
            for child in rhs:
                self.extract_rules_from_tree(child)

    def extract_from_data(self, data: Iterable[dict[str, Any]]) -> None:
        """Count productions of every tree under the ``"con"`` key of each record."""
        for item in data:
            for tree in item.get("con", []):
                self.extract_rules_from_tree(tree)

    def build_feature_vectors(self) -> dict[str, float]:
        """Aggregate the counted productions into a normalized per-dimension vector.

        Every dimension named in ``rule_mapping`` is present in the result (0.0 when it
        has no matching production). Values are divided by the number of mapped
        productions, so they sum to 1 whenever at least one production is mapped.
        A coverage summary is printed; with no extracted rules the percentages are
        reported as 0.00% instead of raising ``ZeroDivisionError``.
        """
        # Make sure every dimension is present so the vector layout never has gaps.
        feature_vector: dict[str, float] = {dim: 0.0 for dim in set(self.rule_mapping.values())}
        unmapped_rules_count = 0

        for lhs_symbol, rhs_counter in self.rules_counter.items():
            for rhs_symbols, freq in rhs_counter.items():
                rule = f"{lhs_symbol} -> {' '.join(rhs_symbols)}"
                dim = self.rule_mapping.get(rule)
                if not dim:
                    unmapped_rules_count += freq
                    continue
                feature_vector[dim] += freq

        mapped_rules_count = self.total_rules - unmapped_rules_count
        # Guard the percentage computation for an extractor that has seen no rules.
        denominator = self.total_rules or 1
        print(f"Total rules processed: {self.total_rules}, "
              f"Unmapped rules: {unmapped_rules_count}({unmapped_rules_count / denominator * 100:.2f}%), "
              f"Mapped rules: {mapped_rules_count}({mapped_rules_count / denominator * 100:.2f}%)")

        total = sum(feature_vector.values()) or 1.0
        return {dim: freq / total for dim, freq in feature_vector.items()}

# === Unified syntactic-vector computation: extract raw rule counts first, then aggregate them per dimension mapping ===
# Both the 10d and the 18d syntactic vectors are derived from the same underlying rule
# distribution, which avoids bias from inconsistent dimension definitions.
_PCFG_RULE_COUNT_CACHE: dict[str, tuple[dict[str, int], int]] = {}

def extract_pcfg_rule_counts(path: Path) -> tuple[dict[str, int], int]:
    """Return ``(rule -> count, total rule count)`` for the tree file at ``path``.

    Results are memoized per resolved path in ``_PCFG_RULE_COUNT_CACHE``; the cached
    tuple itself is returned on later calls, so callers must not mutate it.
    """
    cache_key = str(path.resolve())
    cached = _PCFG_RULE_COUNT_CACHE.get(cache_key)
    if cached is not None:
        return cached

    # The mapping is irrelevant here: only the raw production counts are needed.
    extractor = PCFGExtractor(rule_mapping={})
    extractor.extract_from_data(extractor.load_trees(path))

    rule_counts: dict[str, int] = {
        f"{lhs} -> {' '.join(rhs)}": int(occurrences)
        for lhs, by_rhs in extractor.rules_counter.items()
        for rhs, occurrences in by_rhs.items()
    }

    entry = (rule_counts, int(extractor.total_rules))
    _PCFG_RULE_COUNT_CACHE[cache_key] = entry
    return entry

def aggregate_rule_counts_to_dims(rule_counts: dict[str, int], rule_mapping: dict[str, str]) -> tuple[dict[str, float], int, int]:
    """Fold raw production counts into a normalized per-dimension vector.

    Returns:
        ``(vector, mapped, unmapped)`` where ``vector`` holds every dimension named in
        ``rule_mapping`` with its share of the mapped productions (all 0.0 when nothing
        is mapped), ``mapped`` is the number of productions that hit a dimension and
        ``unmapped`` the number that did not.
    """
    per_dim: dict[str, int] = dict.fromkeys(set(rule_mapping.values()), 0)
    skipped = 0
    for rule_name, occurrences in rule_counts.items():
        target = rule_mapping.get(rule_name)
        if target:
            per_dim[target] += occurrences
        else:
            skipped += occurrences

    mapped = int(sum(per_dim.values()))
    # Avoid dividing by zero when no production is mapped.
    denominator = mapped if mapped else 1
    normalized = {name: value / denominator for name, value in per_dim.items()}
    return normalized, mapped, skipped

def load_and_extract_pcfg(path: Path, rule_mapping: dict[str, str]) -> dict[str, float]:
    """Compute the per-dimension syntactic vector for the tree file at ``path``.

    Prints a mapped/unmapped coverage summary unless the file produced no rules.
    """
    rule_counts, total_rules = extract_pcfg_rule_counts(path)
    vector, mapped, unmapped = aggregate_rule_counts_to_dims(rule_counts, rule_mapping)
    if total_rules <= 0:
        return vector

    unmapped_pct = unmapped / total_rules * 100
    mapped_pct = mapped / total_rules * 100
    print(
        f"Total rules processed: {total_rules}, "
        f"Unmapped rules: {unmapped}({unmapped_pct:.2f}%), "
        f"Mapped rules: {mapped}({mapped_pct:.2f}%)"
    )
    return vector


In [None]:
# === Dataset generation ===

# English character name as it appears in the neutral-sentence file -> Chinese name used as the dataset's character field.
EN_NAME_TO_ZH = {"Muice": "沐雪", "Ayaka": "神里绫华", "Zhongli": "钟离", "Hutao": "胡桃", "Haruhi": "凉宫春日"}

def generate_dataset_for_dim(dim_name: str, rule_mapping: dict[str, str], pragmatic_threshold: float = 0.4):
    """
    Build the oversampled training set for one rule-to-dimension mapping.

    Steps: load the neutral sentences, attach per-character lexical keywords, attach the
    per-character syntactic vector computed with ``rule_mapping``, attach per-sentence
    pragmatic styles, then export with oversampling.

    Args:
        dim_name: tag such as "10d" or "18d"; used in log output and the output file name.
        rule_mapping: PCFG production -> dimension name.
        pragmatic_threshold: a pragmatic style is kept when its score is strictly
            greater than this value.

    Returns:
        Path of the exported JSON file.

    Raises:
        ValueError: if one of the five characters has no neutral sentence loaded.
    """
    # (character name used in the dataset, lowercase stem of its artifact files)
    characters = (
        ("沐雪", "muice"),
        ("神里绫华", "ayaka"),
        ("钟离", "zhongli"),
        ("胡桃", "hutao"),
        ("凉宫春日", "haruhi"),
    )

    print(f"Generating {dim_name} dataset...")
    storage = DatasetStorage()

    # 1. Load the neutral sentences
    print("Loading neutral sentences...")
    with open(NEUTRAL_SENTENCES_FILE, "r", encoding="utf-8") as f:
        for line in f:
            if line := line.rstrip():
                item = json.loads(line)
                character = EN_NAME_TO_ZH.get(item["character"]) or item["character"]
                storage.new_item(character, item["neutral"], item['CoT'], item["original"])

    # 2. Load the lexical keywords (reuses files under the outputs directory)
    print("Loading lexical keywords...")
    def get_lexical(name):
        path = BASE_DIR / "outputs" / "pmi" / f"{name}_pmi_filtered.json"
        if not path.exists():
            # No keyword file for this character: fall back to an empty keyword list.
            return []
        with open(path, "r", encoding="utf-8") as f:
            return list(json.load(f).keys())[:25]

    for char_zh, char_en in characters:
        storage.save_characters_keywords(char_zh, get_lexical(char_en))

    # 3. Extract the PCFG vectors
    print("Extracting PCFG vectors...")

    def process_pcfg(char_zh, char_en):
        path = BASE_DIR / "outputs" / "cons" / f"{char_en}.json"
        if not path.exists():
            print(f"Warning: PCFG file not found for {char_en}")
            return
        vector = load_and_extract_pcfg(path, rule_mapping)
        # Every sentence of the character shares the same character-level vector.
        for item in storage.items.values():
            if item["character"] == char_zh:
                storage.save_component(item["output"], syntactic_vector=vector)

    for char_zh, char_en in characters:
        process_pcfg(char_zh, char_en)

    # 4. Load the pragmatic styles
    print("Loading pragmatic styles...")

    def load_pragmatic(char_en):
        path = BASE_DIR/ "outputs" / "pragmatic" / f"{char_en}.jsonl"
        if not path.exists():
            return

        loaded = 0
        skipped = 0
        with open(path, "r", encoding="utf-8") as f:
            for line in f:
                if not line.strip():
                    continue
                raw_item = json.loads(line)
                # Flatten the list of single-score dicts into one style -> score map.
                styles_map = {}
                for vec in raw_item["pragmatic_styles"]:
                    styles_map.update(vec)

                final_styles = [k for k, v in styles_map.items() if v > pragmatic_threshold]
                try:
                    storage.save_component(raw_item["response"], pragmatic_styles=final_styles)
                except (KeyError, TypeError, ValueError):
                    # Best effort: records without a usable "response", or whose response
                    # is not a loaded sentence, are skipped, but they are counted.
                    skipped += 1
                    continue
                loaded += 1
        print(f"  Loaded pragmatic styles for {char_en} ({loaded} items)")
        if skipped:
            print(f"  Skipped {skipped} pragmatic records for {char_en} (no matching dataset item)")

    for _, char_en in characters:
        load_pragmatic(char_en)

    # 5. Export
    oversample_file = OUTPUTS_DIR / f"llm_train_{dim_name}_oversampling.json"
    storage.output(oversample_file, oversampling=True)
    return oversample_file

# Generate both dataset variants; each call returns the path of the exported JSON file.
data_path_10d = generate_dataset_for_dim("10d", rule_to_10dim)
data_path_18d = generate_dataset_for_dim("18d", rule_to_18dim)


## Model Definition and Training
Define model structures that support style vector inputs of different dimensions, and provide training functions.

In [None]:
import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader
from transformers import AutoTokenizer, Qwen3ForCausalLM
from transformers.data.data_collator import DataCollatorWithPadding
from transformers.optimization import get_scheduler
from peft import get_peft_model, LoraConfig, PeftModel
from sklearn.model_selection import train_test_split
from tqdm.autonotebook import tqdm
import math

BASE_MODEL = Path("../Models/Qwen3-1.7B")

class StyleDataset(Dataset):
    """Turns exported dataset records into tokenized chat samples plus style vectors.

    Each sample holds ``input_ids`` and ``labels`` (plain lists; prompt positions are
    -100 so that only the assistant response is supervised), ``syntactic_vec`` (one
    float per entry of ``syntactic_dims``) and ``prag_vec`` (multi-hot over
    ``prag_style_vocab``).
    """

    def __init__(self, path:Path, tokenizer, prag_style_vocab: list[str], syntactic_dims: list[str], max_length=256):
        with open(path, "r", encoding="utf-8") as f:
            self.data: list[DatasetItem] = json.load(f)
        self.tokenizer = tokenizer
        # style label -> index in the multi-hot vector
        self.prag_vocab = {style: i for i, style in enumerate(prag_style_vocab)}
        self.syntactic_dims = syntactic_dims
        self.max_length = max_length

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx:int):
        item = self.data[idx]
        instr = item["instruction_components"]
        # Defensive fallback: the export step should already have filtered such records.
        if not isinstance(instr, dict):
             print(f"Warning: instruction_components is not a dict for item idx {idx}")
             instr = {"lexical_keywords":[], "syntactic_vector":{}, "pragmatic_styles":[]} # type:ignore

        lexical_keywords = instr.get("lexical_keywords")
        keywords = ", ".join(lexical_keywords) if lexical_keywords else "None"
        style_labels = instr.get("pragmatic_styles")
        pragmatic_styles = ", ".join(style_labels) if style_labels else "None"

        system_prompt = "You are a style transfer expert. Your task is to generate a new sentence that matches the target style, based on the content of a neutral sentence."
        user_prompt = (
            f"Target Character {item['character']}\n"
            f"Personality: {pragmatic_styles}\n"
            f"Keywords: {keywords}\n"
            f"Neutral Content: {item['neutral_sentence']}\n"
        )
        assistant_response = f"{item['thinking_process']}\n\n{item['output']}"

        full_text = self.tokenizer.apply_chat_template(
            [
                {"role": "system", "content": system_prompt},
                {"role": "user", "content": user_prompt},
                {"role": "assistant", "content": assistant_response}
            ],
            tokenize=False,
            add_generation_prompt=False,
            enable_thinking=False
        )

        # Truncate from the left so the end of the assistant response is always kept.
        self.tokenizer.truncation_side = "left"
        full_tokenized = self.tokenizer(
            full_text,
            truncation=True,
            max_length=self.max_length,
            padding=False,
        )
        input_ids = full_tokenized["input_ids"]

        prompt_only_text = self.tokenizer.apply_chat_template(
            [
                {"role": "system", "content": system_prompt},
                {"role": "user", "content": user_prompt},
            ],
            tokenize=False,
            add_generation_prompt=True
        )
        prompt_len = len(self.tokenizer(prompt_only_text).input_ids)

        # Left truncation removes tokens from the start of the prompt, so the number of
        # prompt tokens that survive is prompt_len minus the number of dropped tokens.
        # Masking the full prompt_len would also hide response tokens (all of them once
        # prompt_len >= max_length). The untruncated length is only needed when the
        # sequence reached max_length.
        # Assumes the tokenizer drops tokens strictly from the left — TODO confirm for
        # tokenizers that prepend special tokens.
        dropped = 0
        if len(input_ids) >= self.max_length:
            untruncated_len = len(self.tokenizer(full_text).input_ids)
            dropped = max(untruncated_len - len(input_ids), 0)

        labels = input_ids.copy()
        # mask prompt
        mask_len = min(max(prompt_len - dropped, 0), len(labels))
        labels[:mask_len] = [-100] * mask_len

        syn_dict_raw = instr.get("syntactic_vector", {})
        syn_dict = syn_dict_raw if isinstance(syn_dict_raw, dict) else {}
        # Dimensions missing from the record default to 0.0.
        syntactic_vec = torch.tensor(
            [syn_dict.get(dim, 0.0) for dim in self.syntactic_dims],
            dtype=torch.float
        )
        prag_vec = torch.zeros(len(self.prag_vocab), dtype=torch.float32)

        # Labels outside the vocabulary are ignored.
        for tag in instr.get("pragmatic_styles", []):
            if tag in self.prag_vocab:
                prag_vec[self.prag_vocab[tag]] = 1.0

        return {
            "input_ids": input_ids,
            "syntactic_vec": syntactic_vec,
            "prag_vec": prag_vec,
            "labels": labels
        }

class StyleDataCollator(DataCollatorWithPadding):
    """Pads token fields through the parent collator and batches the style tensors.

    ``labels`` are padded with -100 to the padded ``input_ids`` length, on the same side
    the tokenizer pads on, so every label stays aligned with its token.
    """

    def __init__(self, tokenizer, **kwargs):
        super().__init__(tokenizer=tokenizer, padding=True, **kwargs)

    def __call__(self, features):
        syntactic_vecs = torch.stack([f["syntactic_vec"] for f in features])
        prag_vecs = torch.stack([f["prag_vec"] for f in features])
        labels = [f["labels"] for f in features]

        # The parent collator only understands tokenizer fields; strip everything else.
        base_features = [
            {k: v for k, v in f.items() if k not in ("syntactic_vec", "prag_vec", "labels")}
            for f in features
        ]
        batch = super().__call__(base_features)

        max_label_length = batch["input_ids"].shape[1]
        # Follow the tokenizer's padding side; right padding is assumed when it is unset.
        pad_on_left = getattr(self.tokenizer, "padding_side", "right") == "left"
        padded_labels = []
        for label_seq in labels:
            truncated = label_seq[:max_label_length]
            filler = [-100] * (max_label_length - len(truncated))
            padded_labels.append(filler + truncated if pad_on_left else truncated + filler)

        batch["labels"] = torch.tensor(padded_labels, dtype=torch.long)
        batch["syntactic_vec"] = syntactic_vecs
        batch["prag_vec"] = prag_vecs
        return batch

class StyleEncoder(nn.Module):
    """Two-layer MLP projecting a style vector of size ``input_dim`` to size ``out_dim``.

    Layout: Linear(input_dim, out_dim // 2) -> ReLU -> Linear(out_dim // 2, out_dim) -> Tanh.
    """

    def __init__(self, input_dim, out_dim):
        super().__init__()
        bottleneck_dim = out_dim // 2
        layers = [
            nn.Linear(input_dim, bottleneck_dim),
            nn.ReLU(),
            nn.Linear(bottleneck_dim, out_dim),
            nn.Tanh(),
        ]
        self.proj = nn.Sequential(*layers)

    def forward(self, syntactic_vec):
        """Project ``syntactic_vec``; the final Tanh bounds the output to [-1, 1]."""
        projected = self.proj(syntactic_vec)
        return projected
    
class StyleConditionedLoRAModel(nn.Module):
    """LoRA-adapted causal LM that receives the syntactic vector as one extra prefix embedding."""

    def __init__(self, model_name_or_path: str|Path, syntactic_dim: int, lora_r=16, lora_alpha=16):
        super().__init__()
        lora_targets = ["q_proj", "k_proj", "v_proj", "o_proj", "gate_proj", "up_proj", "down_proj"]
        backbone = Qwen3ForCausalLM.from_pretrained(
            model_name_or_path,
            dtype=torch.bfloat16
        )
        adapter_config = LoraConfig(
            r=lora_r,
            lora_alpha=lora_alpha,
            target_modules=lora_targets,
            lora_dropout=0.05,
            bias="none",
            task_type="CAUSAL_LM"
        )
        self.peft_model: PeftModel = get_peft_model(backbone, adapter_config) # type: ignore

        # The style encoder outputs vectors of the LM hidden size, in the LM's dtype.
        self.hidden_size = backbone.config.hidden_size
        self.style_encoder = StyleEncoder(syntactic_dim, self.hidden_size)
        self.style_encoder.to(backbone.dtype)

    def get_input_embeddings(self):
        """Return the token-embedding module of the wrapped model."""
        return self.peft_model.get_input_embeddings() # type: ignore

    def forward(self, input_ids, attention_mask, labels, syntactic_vec):
        """Run the LM on ``[style prefix] + tokens``.

        The prefix position is attended to (mask 1) and excluded from the LM loss
        (label -100).

        Returns:
            dict with ``loss``, ``logits`` and ``last_hidden_state`` (the last entry of
            the model's hidden states, which includes the prefix position at index 0).
        """
        style_vector = self.style_encoder(syntactic_vec.to(self.peft_model.dtype)) # type: ignore
        embedding_layer = self.get_input_embeddings()
        inputs_embeds = torch.cat([style_vector.unsqueeze(1), embedding_layer(input_ids)], dim=1)

        mask_prefix = torch.ones(attention_mask.shape[0], 1, dtype=torch.long, device=attention_mask.device)
        label_prefix = torch.full((labels.shape[0], 1), -100, dtype=torch.long, device=labels.device)
        extended_mask = torch.cat([mask_prefix, attention_mask], dim=1)
        extended_labels = torch.cat([label_prefix, labels], dim=1)

        outputs = self.peft_model(
            inputs_embeds=inputs_embeds,
            attention_mask=extended_mask,
            labels=extended_labels,
            output_hidden_states=True
        )
        result = {
            "loss": outputs.loss,
            "logits": outputs.logits,
            "last_hidden_state": outputs.hidden_states[-1],
        }
        return result

# === 训练函数 ===

def run_training_experiment(dataset_path: Path, syntactic_dims: list[str], output_dir: Path, epochs=2):
    """Fine-tune the style-conditioned LoRA model on one dataset variant and save it.

    Args:
        dataset_path: JSON training file read by StyleDataset.
        syntactic_dims: ordered dimension names; fixes the layout and size of the
            syntactic vector fed to the style encoder.
        output_dir: directory that receives the LoRA adapter, the style encoder weights,
            the auxiliary predictor head weights and the tokenizer.
        epochs: number of passes over the training split.

    Returns:
        None. The function also returns early, after printing a message, when the
        tokenizer cannot be loaded.
    """
    print(f"\nStarted training experiment for dataset: {dataset_path}")
    print(f"Syntactic Dimensions ({len(syntactic_dims)}): {syntactic_dims}")
    
    # Label vocabulary for the multi-hot pragmatic vector; the position in this list is the vector index.
    prag_style_vocab = ['kind', 'modest', 'clingy', 'playful', 'cold', 'proud', 'sharp_tongued', 'subservient', 'submissive', 'controlling',
                        'strong', 'defensive', 'tsukkomi', 'rational', 'curious', 'imaginative', 'cautious', 'idealistic', 'conservative',
                        'radical', 'obsessive', 'hesitant', 'energetic', 'optimistic', 'confident', 'passionate', 'melancholy', 'serious',
                        'emotional', 'sensitive', 'shy', 'irritable', 'anxious', 'lazy', 'tsundere', 'yandere', 'chuunibyou', 'cute', 'naive',
                        'airhead', 'elegant', 'humorous', 'loyal', 'responsible', 'willful', 'antisocial', 'talkative', 'masochistic', 'sadistic', 'evil']
    
    BATCH_SIZE = 6
    LEARNING_RATE = 1e-5
    # Weights of the two auxiliary losses in the total loss.
    lambda_recon = 0.05
    lambda_style = 0.5
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    
    # Load Setup
    # NOTE(review): any exception raised while loading the tokenizer is reported as a
    # missing model path and the experiment is skipped without raising.
    try:
        tokenizer = AutoTokenizer.from_pretrained(BASE_MODEL)
    except Exception:
        print("Model path not found, using a dummy path for compilation check or valid path required.")
        return

    if tokenizer.pad_token is None:
        tokenizer.pad_token = tokenizer.eos_token 

    dataset = StyleDataset(dataset_path, tokenizer, prag_style_vocab, syntactic_dims)
        
    # 80/20 split with a fixed seed.
    train_data, val_data = train_test_split(dataset, test_size=0.2, random_state=42)
    collator = StyleDataCollator(tokenizer, pad_to_multiple_of=8)
    train_loader = DataLoader(train_data, batch_size=BATCH_SIZE, shuffle=True, collate_fn=collator)
    # NOTE(review): val_loader is built but never read in this function.
    val_loader = DataLoader(val_data, batch_size=BATCH_SIZE, collate_fn=collator)
    
    styled_model = StyleConditionedLoRAModel(BASE_MODEL, len(syntactic_dims))
    styled_model.to(device)
    
    mse_loss_fn = nn.MSELoss()
    bce_loss_fn = nn.BCEWithLogitsLoss()
    # Auxiliary head: one output per syntactic dimension followed by one logit per pragmatic label.
    style_predictor_head = nn.Linear(styled_model.hidden_size, len(syntactic_dims) + len(prag_style_vocab)).to(device)
    
    optimizer = torch.optim.AdamW(list(styled_model.parameters()) + list(style_predictor_head.parameters()), lr=LEARNING_RATE)
    num_training_steps = len(train_loader) * epochs
    lr_scheduler = get_scheduler("cosine", optimizer=optimizer, num_warmup_steps=100, num_training_steps=num_training_steps)
    
    # Training Loop
    styled_model.train()
    style_predictor_head.train()
    # NOTE(review): global_step is incremented below but never read in this function.
    global_step = 0
    # Keeps the loss normalization below from dividing by zero.
    epsilon = 1e-8 

    for epoch in range(epochs):
        progress_bar = tqdm(train_loader, desc=f"Epoch {epoch+1}/{epochs}")
        for batch in progress_bar:
            input_ids = batch["input_ids"].to(device)
            attention_mask = batch["attention_mask"].to(device)
            labels = batch["labels"].to(device)
            syntactic_vec = batch["syntactic_vec"].to(device)
            pragmatic_tags = batch["prag_vec"].to(device)

            outputs = styled_model(input_ids, attention_mask, labels, syntactic_vec)
            
            L_lm = outputs["loss"]
            # Position 0 is the style prefix that the model prepends to the token embeddings.
            style_prefix_hidden_state = outputs["last_hidden_state"][:, 0] 
            
            aux_preds = style_predictor_head(style_prefix_hidden_state.float())
            pred_syntactic = aux_preds[:, :len(syntactic_dims)]
            pred_pragmatic = aux_preds[:, len(syntactic_dims):]
            
            L_recon = mse_loss_fn(pred_syntactic, syntactic_vec)
            L_style_cls = bce_loss_fn(pred_pragmatic, pragmatic_tags.float())
            
            # Each auxiliary loss is divided by its own detached value: its numeric
            # contribution is therefore about 1 x lambda, while gradients still flow
            # through the numerator.
            L_total = L_lm + lambda_recon * (L_recon / (L_recon.detach() + epsilon)) \
                   + lambda_style * (L_style_cls / (L_style_cls.detach() + epsilon))
            
            L_total.backward()
            optimizer.step()
            lr_scheduler.step()
            optimizer.zero_grad()
            global_step += 1
            
            progress_bar.set_postfix({"L_total": L_total.item(), "L_lm": L_lm.item()})
            
    # Save
    output_dir.mkdir(parents=True, exist_ok=True)
    styled_model.peft_model.save_pretrained(output_dir / "lora") # type: ignore
    torch.save(styled_model.style_encoder.state_dict(), output_dir / "style_encoder.pt")
    torch.save(style_predictor_head.state_dict(), output_dir / "style_predictor_head.pt")
    tokenizer.save_pretrained(output_dir / "tokenizer")
    print(f"Model saved to {output_dir}")

# Dimension names of each mapping, sorted so the syntactic vector layout is deterministic.
syntactic_dims_10 = sorted(set(rule_to_10dim.values()))
syntactic_dims_18 = sorted(set(rule_to_18dim.values()))


In [None]:
# Train one model per dataset variant; each run saves its artifacts under its own output directory.
run_training_experiment(data_path_10d, syntactic_dims_10, OUTPUTS_DIR / "styled-qwen-10d")

run_training_experiment(data_path_18d, syntactic_dims_18, OUTPUTS_DIR / "styled-qwen-18d")

## Construct Character Style Vectors

In [8]:
from typing_extensions import TypedDict, Optional
from collections import Counter

import json

# Per-character pragmatic-style files (JSON Lines), given relative to the working directory.
Pragmatic_Muice = "./outputs/pragmatic/muice.jsonl"
Pragmatic_ayaka = "./outputs/pragmatic/ayaka.jsonl"
Pragmatic_zhongli = "./outputs/pragmatic/zhongli.jsonl"
Pragmatic_hutao = "./outputs/pragmatic/hutao.jsonl"
Pragmatic_haruhi = "./outputs/pragmatic/haruhi.jsonl"

class RawPCFGItem(TypedDict):
    """One raw line of a pragmatic-style JSONL file, as parsed by read_pcfg_jsonl_file."""
    prompt: str
    response: str
    # List of style -> score dicts; read_pcfg_jsonl_file merges them into a single dict.
    pragmatic_styles: list[dict[str, float]]

class PCFGItem(TypedDict):
    """A response paired with a list of style labels."""
    # NOTE(review): not referenced by the code visible in this part of the notebook — confirm it is still needed.
    response: str
    pragmatic_styles: list[str]

def read_pcfg_jsonl_file(jsonl_file: Path, threshold: Optional[float] = None, top_n: int = 5) -> list[str]:
    """Return the ``top_n`` most frequent pragmatic style labels in a JSONL file.

    Every non-blank line is parsed as a JSON object. Its ``pragmatic_styles``
    entry (a list of ``{label: score}`` dicts) is merged into a single mapping,
    later dicts overriding earlier ones for a repeated label. Each label whose
    score is strictly greater than the threshold counts once for that line.

    This definition is kept in step with the identically named function in the
    "Batch Inference" cell (missing-file guard, tolerant ``pragmatic_styles``
    lookup), so results do not depend on which cell was executed last.

    Args:
        jsonl_file: Path of the JSONL file to read.
        threshold: Exclusive lower bound on a label's score; ``None`` means 0.
        top_n: Maximum number of labels to return.

    Returns:
        Labels ordered from most to least frequent (ties keep first-seen
        order); an empty list when the file does not exist.

    Raises:
        json.JSONDecodeError: If a non-blank line is not valid JSON.
    """
    if not jsonl_file.exists():
        print(f"Warning: Pragmatic file not found: {jsonl_file}")
        return []

    # Resolve the threshold once instead of reassigning the parameter per item.
    min_score = threshold or 0
    style_counter = Counter()

    # Stream the file line by line rather than buffering it with readlines().
    with open(jsonl_file, "r", encoding="utf-8") as f:
        for line in f:
            line = line.strip()
            if not line:
                continue
            raw_item = json.loads(line)

            # list[dict[str, float]] -> dict[str, float]
            merged_styles: dict[str, float] = {}
            for vec in raw_item.get("pragmatic_styles", []):
                merged_styles.update(vec)

            style_counter.update(
                label for label, score in merged_styles.items() if score > min_score
            )

    return [style for style, _ in style_counter.most_common(top_n)]

# Most frequent style labels per character (default top_n=5), counting only scores above 0.4.
pcfg_muice_items = read_pcfg_jsonl_file(Path(Pragmatic_Muice), 0.4)
pcfg_ayaka_items = read_pcfg_jsonl_file(Path(Pragmatic_ayaka), 0.4)
pcfg_zhongli_items = read_pcfg_jsonl_file(Path(Pragmatic_zhongli), 0.4)
pcfg_hutao_items = read_pcfg_jsonl_file(Path(Pragmatic_hutao), 0.4)
pcfg_haruhi_items = read_pcfg_jsonl_file(Path(Pragmatic_haruhi), 0.4)


## Batch Inference

In [7]:
import json
from dataclasses import dataclass
from pathlib import Path
from collections import Counter
from typing import Optional, Any
from typing_extensions import TypedDict

import torch
import torch.nn as nn
from tqdm import tqdm
from transformers import AutoTokenizer, Qwen3ForCausalLM
from peft import PeftModel

# === Section aligned with Collect_all_output.ipynb (same schema and naming) ===
TEST_FILE = Path("./data/neutral_sentences_eval.jsonl")
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# One JSON object per line; only its "neutral" field is kept.
with open(TEST_FILE, "r", encoding="utf-8") as f:
    test_data: list[str] = [json.loads(record)["neutral"] for record in f]

class InferenceResult(TypedDict):
    """One generated output, as appended to ``results`` and exported to JSONL."""
    model: str      # model tag passed to batch_inference, e.g. "10d-model"
    neutral: str    # neutral input sentence
    output: str     # generated text
    character: str  # CharacterProfile.name of the target character

# Accumulates one InferenceResult per generated output; batch_inference appends to it.
results: list[InferenceResult] = []

# Hard-coded lexical keywords per character, keyed by the English character name.
KEYWORDS_MAP = {
    "Muice": ["喵", "沐沐", "AI", "恼", "沐雪", "女孩子", "~", "⭐", "不行", "聊天", "呀", "可爱", "才", "叫", "唔", "谁", "不会", "吃", "睡觉", "笨蛋", "答", "谢谢", "把", "即", "吧"],
    "Ayaka": ["稻妻国", "神里家", "稻妻", "大小姐", "家族", "传统", "奉行", "文化", "人民", "眼狩令", "神", "当地", "社", "舞蹈", "美丽", "茶道", "神社", "祭典", "眼", "美食", "继承", "剑术", "国家", "将军", "责任"],
    "Zhongli": ['岩石', '岩', '璃月', '力', '璃', '契约', '炼金术', '月', '盐', '帝君', '魔神', '操控', '王', '并非', '岩王', '大地', '封印', '作战', '掌握', '大陆', '学问', '研究', '七星', '客卿', '岩元素'],
    "Hutao": ['往生堂', '嘿嘿', '嘻嘻', '可是', '堂主', '哎呀呀', '哦哦哦', '宝藏', '惊喜', '诗歌', '可不是', '灵魂', '胡桃', '神秘', '生死', '谜题', '哈哈哈', '不过', '有趣', '亡灵', '秘密', '意想不到', '巫师', '哇', '奇妙'],
    "Haruhi": ['团', 'SOS', '阿虚', '社团', '哼', '事件', '学校', '超自然', '朝比奈', '文化祭', '创意', '吸引', '古泉', '电影', '创新', '组织', '实玖瑠', '当然', '与众不同', '主题', '加入', '束缚', '凉宫', '团长', '外星人'],
}

@dataclass
class CharacterProfile:
    """Style inputs for one character, as assembled by load_profiles."""
    name: str                         # display name inserted into the prompt
    syntactic_vec: dict[str, float]   # syntactic dimension name -> value
    pragmatic_styles: list[str]       # most frequent pragmatic style labels
    lexical_keywords: list[str]       # keyword list taken from KEYWORDS_MAP

class RawPCFGItem(TypedDict):
    """One parsed line of a pragmatic-style JSONL file.

    NOTE(review): same fields as the RawPCFGItem declared in the
    "Construct Character Style Vectors" cell; whichever cell runs last wins.
    """
    prompt: str
    response: str
    # List of {style label: score} dicts; read_pcfg_jsonl_file merges them into one mapping.
    pragmatic_styles: list[dict[str, float]]

def read_pcfg_jsonl_file(jsonl_file: Path, threshold: Optional[float] = None, top_n: int = 5) -> list[str]:
    """Extract the ``top_n`` most frequent pragmatic styles from a JSONL file (Collect_all_output implementation).

    Each non-blank line is a JSON object whose optional ``pragmatic_styles``
    list of ``{label: score}`` dicts is merged into one mapping (later dicts
    win on repeated labels). A label is counted for that line when its score
    is strictly greater than ``threshold`` (``None`` is treated as 0).

    Returns the labels from most to least frequent, or ``[]`` after printing
    a warning when the file does not exist.
    """
    if not jsonl_file.exists():
        print(f"Warning: Pragmatic file not found: {jsonl_file}")
        return []

    with open(jsonl_file, "r", encoding="utf-8") as handle:
        records = [json.loads(text) for text in map(str.strip, handle) if text]

    cutoff = threshold or 0
    tally = Counter()
    for record in records:
        scores = {}
        for partial in record.get("pragmatic_styles", []):
            scores.update(partial)
        tally.update(label for label, value in scores.items() if value > cutoff)

    return [label for label, _ in tally.most_common(top_n)]

def load_profiles(rule_mapping: dict[str, str]) -> list[CharacterProfile]:
    """Build one CharacterProfile per character, with the same structure as Collect_all_output.

    Keywords come from the hard-coded KEYWORDS_MAP, pragmatic styles are the
    top 5 labels above 0.4 from ``./outputs/pragmatic/<key>.jsonl``, and the
    syntactic vector is extracted from ``./outputs/cons/<key>.json`` with
    ``rule_mapping`` (an empty dict when that file is missing).
    """
    # (KEYWORDS_MAP key, display name used in the profile, file stem)
    roster = (
        ("Muice", "沐雪", "muice"),
        ("Ayaka", "神里绫华", "ayaka"),
        ("Zhongli", "钟离", "zhongli"),
        ("Hutao", "胡桃", "hutao"),
        ("Haruhi", "凉宫春日", "haruhi"),
    )
    cons_dir = Path("./outputs/cons")
    pragmatic_dir = Path("./outputs/pragmatic")

    def build_profile(keyword_key: str, display_name: str, stem: str) -> CharacterProfile:
        cons_file = cons_dir / f"{stem}.json"
        if cons_file.exists():
            vector: dict[str, float] = load_and_extract_pcfg(cons_file, rule_mapping)
        else:
            vector = {}
        styles = read_pcfg_jsonl_file(pragmatic_dir / f"{stem}.jsonl", threshold=0.4, top_n=5)
        return CharacterProfile(
            name=display_name,
            syntactic_vec=vector,
            pragmatic_styles=styles,
            lexical_keywords=KEYWORDS_MAP.get(keyword_key, []),
        )

    return [build_profile(*entry) for entry in roster]

class StyleEncoder(nn.Module):
    """Two-layer MLP that projects a syntactic vector to ``out_dim`` features.

    Layout: Linear(input_dim, out_dim // 2) -> ReLU -> Linear(out_dim // 2, out_dim) -> Tanh.
    The layers stay inside ``self.proj`` in this order so saved state_dict keys
    (``proj.0.*``, ``proj.2.*``) keep loading.
    """

    def __init__(self, input_dim: int, out_dim: int):
        super().__init__()
        bottleneck = out_dim // 2
        stages = [
            nn.Linear(input_dim, bottleneck),
            nn.ReLU(),
            nn.Linear(bottleneck, out_dim),
            nn.Tanh(),
        ]
        self.proj = nn.Sequential(*stages)

    def forward(self, syntactic_vec: torch.Tensor) -> torch.Tensor:
        """Apply the projection to ``syntactic_vec`` and return the result."""
        projected = self.proj(syntactic_vec)
        return projected

# === These globals are injected by _load_styled_model() (declared here to avoid Pylance "undefined" warnings) ===
tokenizer: Any = None
model: Any = None
style_encoder: Any = None
syntactic_dims: list[str] = []

@torch.inference_mode()
def generate_styled_response(
    neutral_sentence: str,
    syntactic_vec: dict[str, float],
    character_name: str = "Ayaka",
    lexical_keywords: list[str] = list(),
    pragmatic_styles: list[str] = list(),
    temperature: float = 0.8,
    top_p: float = 0.95,
    repetition_penalty: float = 1.3,
    max_new_tokens: int = 100,
) -> str:
    """Generate a styled response from a neutral sentence and a style vector.

    Implementation and prompt format follow Collect_all_output.ipynb. Reads the
    module-level ``tokenizer``, ``model``, ``style_encoder`` and
    ``syntactic_dims`` set by ``_load_styled_model``.

    Args:
        neutral_sentence: Content to restyle.
        syntactic_vec: Dimension name -> value; dimensions missing from the
            dict are filled with 0.0, in ``syntactic_dims`` order.
        character_name: Name inserted into the prompt.
        lexical_keywords: Keywords listed in the prompt ("None" when empty).
        pragmatic_styles: Style labels listed in the prompt ("None" when empty).
        temperature, top_p, repetition_penalty, max_new_tokens: Forwarded to
            ``model.generate`` (sampling is always enabled).

    Returns:
        ``tokenizer.decode`` of the first generated sequence, special tokens skipped.
    """
    assert tokenizer is not None and model is not None and style_encoder is not None, "Model/tokenizer/style_encoder 未加载"
    assert syntactic_dims, "syntactic_dims 未初始化（请先调用 _load_styled_model）"

    # Falsy inputs (None or an empty list) are rendered as the literal "None".
    keyword_text = ", ".join(lexical_keywords) if lexical_keywords else "None"
    personality_text = ", ".join(pragmatic_styles) if pragmatic_styles else "None"

    prompt_lines = [
        f"Target Character {character_name}",
        f"Personality: {personality_text}",
        f"Keywords: {keyword_text}",
        f"Neutral Content: {neutral_sentence}",
    ]
    chat = [
        {
            "role": "system",
            "content": "You are a style transfer expert. Your task is to generate a new sentence that matches the target style, based on the content of a neutral sentence.",
        },
        {"role": "user", "content": "\n".join(prompt_lines) + "\n"},
    ]
    prompt_text = tokenizer.apply_chat_template(
        chat,
        tokenize=False,
        add_generation_prompt=True,
        enable_thinking=False,
    )

    encoded = tokenizer(prompt_text, return_tensors="pt").to(DEVICE)
    prompt_ids = encoded["input_ids"]
    prompt_mask = encoded["attention_mask"]

    # Dense style vector in syntactic_dims order, shaped (1, len(syntactic_dims)).
    dim_values = [syntactic_vec.get(dim, 0.0) for dim in syntactic_dims]
    style_input = torch.tensor(dim_values, dtype=torch.float32, device=DEVICE).unsqueeze(0)

    # Run the encoder in its own parameter dtype, then cast to the LM dtype.
    encoder_dtype = next(style_encoder.parameters()).dtype
    style_vector = style_encoder(style_input.to(encoder_dtype)).to(model.dtype)

    # Prepend the style embedding as one extra position in front of the prompt
    # embeddings, and extend the attention mask by one to cover it.
    prompt_embeds = model.get_input_embeddings()(prompt_ids)  # type: ignore
    full_embeds = torch.cat([style_vector.unsqueeze(1), prompt_embeds], dim=1)
    full_mask = torch.cat(
        [torch.ones((1, 1), dtype=torch.long, device=DEVICE), prompt_mask],
        dim=1,
    )

    generated = model.generate(
        inputs_embeds=full_embeds,
        attention_mask=full_mask,
        max_new_tokens=max_new_tokens,
        temperature=temperature,
        top_p=top_p,
        repetition_penalty=repetition_penalty,
        do_sample=True,
    )

    return tokenizer.decode(generated[0], skip_special_tokens=True)

In [None]:
# === 逐模型批量推理 ===
import gc

# Base model passed to Qwen3ForCausalLM.from_pretrained in _load_styled_model.
MODEL_PATH = Path("../Models/Qwen3-1.7B/")

# Artifact directories for the two runs (same paths the training calls write to).
SAVE_DIR_10D = OUTPUTS_DIR / "styled-qwen-10d"
SAVE_DIR_18D = OUTPUTS_DIR / "styled-qwen-18d"

def _check_model_dir(model_save_dir: Path) -> None:
    """Raise FileNotFoundError unless the directory holds ``tokenizer``, ``lora`` and ``style_encoder.pt``."""
    for artifact in ("tokenizer", "lora", "style_encoder.pt"):
        if not (model_save_dir / artifact).exists():
            raise FileNotFoundError(f"模型目录不完整: {model_save_dir}")

def _load_styled_model(save_dir: Path, syntactic_dims_: list[str]):
    """Load tokenizer, LoRA-wrapped LM and style encoder from ``save_dir`` into module globals.

    Sets the globals ``tokenizer``, ``model``, ``style_encoder`` and
    ``syntactic_dims`` that ``generate_styled_response`` reads.

    Args:
        save_dir: Directory validated by ``_check_model_dir``.
        syntactic_dims_: Dimension names; their count is the encoder input size.

    Returns:
        The base ``Qwen3ForCausalLM`` instance that the PEFT model wraps.

    Raises:
        FileNotFoundError: If ``save_dir`` is missing a required artifact.
    """
    global tokenizer, model, style_encoder, syntactic_dims
    _check_model_dir(save_dir)

    syntactic_dims = syntactic_dims_

    tokenizer = AutoTokenizer.from_pretrained(save_dir / "tokenizer")
    # Fall back to EOS when the tokenizer defines no pad token.
    if tokenizer.pad_token is None:
        tokenizer.pad_token = tokenizer.eos_token

    backbone = Qwen3ForCausalLM.from_pretrained(MODEL_PATH, dtype=torch.bfloat16)
    model = PeftModel.from_pretrained(backbone, save_dir / "lora")
    model.to(DEVICE)
    model.eval()

    # The encoder output width must equal the LM hidden size.
    style_encoder = StyleEncoder(len(syntactic_dims), backbone.config.hidden_size)
    encoder_weights = torch.load(save_dir / "style_encoder.pt", map_location=DEVICE)
    style_encoder.load_state_dict(encoder_weights)
    style_encoder.to(DEVICE)
    style_encoder.eval()

    print(f"✅ Styled model loaded: {save_dir}")
    return backbone

def _unload_model(base_model):
    """Drop this module's references to the loaded styled model and free cached GPU memory.

    The globals ``tokenizer``, ``model`` and ``style_encoder`` are reset to
    ``None`` (their declared "not loaded" value) rather than deleted, so a later
    ``generate_styled_response`` call fails on its own "not loaded" assertion
    instead of a NameError. Calling this repeatedly, or before anything was
    loaded, is safe.

    Args:
        base_model: The object returned by ``_load_styled_model``. Only the
            local reference is dropped here; any reference the caller still
            holds keeps the model alive until the caller deletes it too.
    """
    global tokenizer, model, style_encoder
    tokenizer = None
    model = None
    style_encoder = None

    # Release the parameter's reference before collecting so it cannot pin the model.
    del base_model
    gc.collect()
    if torch.cuda.is_available():
        torch.cuda.empty_cache()
        torch.cuda.ipc_collect()

# === Run 10d ===
profiles_10d = load_profiles(rule_to_10dim)
base_model_10d = _load_styled_model(SAVE_DIR_10D, syntactic_dims_10)

def batch_inference(profile: CharacterProfile, model_name: str):
    """Generate a styled output for every sentence in ``test_data`` and append it to ``results``.

    Args:
        profile: Character whose vector, keywords and styles drive generation.
        model_name: Tag stored in each result's ``model`` field.
    """
    for neutral in tqdm(test_data, desc=f"Running {profile.name}"):
        styled = generate_styled_response(
            neutral,
            profile.syntactic_vec,
            character_name=profile.name,
            lexical_keywords=profile.lexical_keywords,
            pragmatic_styles=profile.pragmatic_styles,
        )
        record = InferenceResult(
            model=model_name,
            neutral=neutral,
            output=styled,
            character=profile.name,
        )
        results.append(record)

for prof in profiles_10d:
    batch_inference(prof, "10d-model")

_unload_model(base_model_10d)
# _unload_model can only drop its own references. This notebook-level name would
# otherwise keep the 10d weights alive while the 18d model is loaded.
del base_model_10d
gc.collect()
if torch.cuda.is_available():
    torch.cuda.empty_cache()

# === Run 18d ===
profiles_18d = load_profiles(rule_to_18dim)
base_model_18d = _load_styled_model(SAVE_DIR_18D, syntactic_dims_18)

for prof in profiles_18d:
    batch_inference(prof, "18d-model")

_unload_model(base_model_18d)
# Same as above: release the last reference so the memory can be reclaimed.
del base_model_18d
gc.collect()
if torch.cuda.is_available():
    torch.cuda.empty_cache()

print(f"✅ Done. Total inference items: {len(results)}")

## Export Final Results

In [10]:
# Release cached GPU memory before exporting. Guarded because DEVICE may be the
# CPU, and torch.cuda.ipc_collect() initialises CUDA and fails without a GPU.
if torch.cuda.is_available():
    torch.cuda.empty_cache()
    torch.cuda.ipc_collect()

OUTPUT_PATH = Path("./outputs/DimensionDissolution/batch_run_result.jsonl")
OUTPUT_PATH.parent.mkdir(parents=True, exist_ok=True)

# Open once in write mode: this truncates any previous export (no duplicate
# accumulation) and avoids reopening the file for every record.
with open(OUTPUT_PATH, "w", encoding="utf-8") as f:
    for item in results:
        f.write(json.dumps(item, ensure_ascii=False) + "\n")

print(f"✅ Saved: {OUTPUT_PATH} (n={len(results)})")

✅ Saved: evaluate\outputs\DimensionDissolution\batch_run_result.jsonl (n=1500)


## Automated Evaluation

### Load Batch Inference Results

In [11]:
from pathlib import Path
from typing_extensions import TypedDict
import json

class InferenceItem(TypedDict):
    """One line of the exported batch-inference JSONL file."""
    model: str      # model tag, e.g. "10d-model"
    neutral: str    # neutral input sentence
    output: str     # generated text
    character: str  # target character name

# Read the ablation outputs exported by this notebook (same schema as Collect_all_output).
BATCH_INFERENCE_FILE = Path("./outputs/DimensionDissolution/batch_run_result.jsonl")

batch_inference_items: list[InferenceItem] = []

with open(BATCH_INFERENCE_FILE, "r", encoding="utf-8") as f:
    lines = f.readlines()
    for line in lines:
        item: InferenceItem = json.loads(line.strip())
        # When a "</think>" marker is present, keep only the segment that follows
        # the first marker (up to the next one, if any), stripped of whitespace.
        if "</think>" in item["output"]:
            item["output"] = item["output"].split("</think>")[1].strip()
        batch_inference_items.append(item)

### Style Consistency Evaluation

In [12]:
from pathlib import Path
from typing import Tuple, Dict, List
from tqdm import tqdm

from transformers import AutoModel, AutoTokenizer
from torch.utils.data import DataLoader
from sklearn.preprocessing import LabelEncoder
from tokenizers import Tokenizer

import numpy as np
import torch.nn as nn
import torch
import os

# Absolute path of the RoBERTa backbone directory (not referenced by the code visible in this cell).
BACKBONE_PATH = str(Path('../Models/chinese-roberta-wwm-ext').resolve())
# Directory passed to load_checkpoint at the end of this cell.
CHECKPOINT_PATH = Path('outputs/style-classifier')

device = 'cuda' if torch.cuda.is_available() else 'cpu'

class CharacterStyleClassifier(nn.Module):
    """Backbone encoder + projection head + linear role classifier.

    ``forward`` returns both the class logits and the projected sentence
    embedding. Attribute names (``backbone``, ``proj``, ``classifier``) are
    relied on by ``load_checkpoint``.
    """

    def __init__(self, backbone_name: str, embed_dim: int = 768, proj_dim: int = 256, num_roles: int = 6, dropout: float = 0.4):
        super().__init__()
        self.backbone = AutoModel.from_pretrained(backbone_name, output_hidden_states=False)
        self.hidden_size = self.backbone.config.hidden_size
        assert self.hidden_size == embed_dim, f"Backbone hidden_size {self.hidden_size} != embed_dim {embed_dim}"

        head_layers = [
            nn.Linear(self.hidden_size, proj_dim),
            nn.ReLU(),
            nn.Dropout(dropout),
        ]
        self.proj = nn.Sequential(*head_layers)
        self.classifier = nn.Linear(proj_dim, num_roles)

    def mean_pooling(self, last_hidden_state, attention_mask):
        """Average hidden states over positions where ``attention_mask`` is non-zero."""
        weights = attention_mask.unsqueeze(-1).expand_as(last_hidden_state).float()
        weighted_sum = (last_hidden_state * weights).sum(dim=1)
        # Clamp so a fully masked row divides by 1e-9 instead of zero.
        weight_total = weights.sum(dim=1).clamp(min=1e-9)
        return weighted_sum / weight_total

    def forward(self, input_ids, attention_mask):
        """Return ``(logits, projected_embedding)`` for a tokenized batch."""
        encoded = self.backbone(input_ids=input_ids, attention_mask=attention_mask, return_dict=True)
        sentence_vec = self.mean_pooling(encoded.last_hidden_state, attention_mask)
        projected = self.proj(sentence_vec)
        return self.classifier(projected), projected

def load_checkpoint(path: str, device='cpu') -> Tuple[nn.Module, 'Tokenizer', LabelEncoder, Dict[str, np.ndarray]]:
    """Restore the style classifier and its companions from a checkpoint directory.

    Files read from ``path``: the tokenizer and backbone (via ``from_pretrained``),
    ``label_encoder.json`` (key ``classes``), ``head.pt`` (keys ``proj_state``
    and ``classifier_state``) and ``role_centers.npz``.

    Args:
        path: Checkpoint directory.
        device: Device for ``torch.load``'s ``map_location`` and the returned model.

    Returns:
        ``(model, tokenizer, label_encoder, role_centers)``, where
        ``role_centers`` maps each array name stored in the ``.npz`` file to
        its array.
    """
    # Function-scope import: only the config is needed to read hidden_size, so
    # the backbone weights are not loaded a second time just for that value.
    from transformers import AutoConfig

    tokenizer = AutoTokenizer.from_pretrained(path)
    with open(os.path.join(path, 'label_encoder.json'), 'r', encoding='utf-8') as f:
        le_json = json.load(f)
    le = LabelEncoder()
    le.classes_ = np.array(le_json['classes'])

    backbone_hidden = AutoConfig.from_pretrained(path).hidden_size
    model = CharacterStyleClassifier(path, embed_dim=backbone_hidden, proj_dim=256, num_roles=len(le.classes_))
    chk = torch.load(os.path.join(path, 'head.pt'), map_location=device)
    model.proj.load_state_dict(chk['proj_state'])
    model.classifier.load_state_dict(chk['classifier_state'])

    # Read every array while the archive is open, then close the file handle.
    with np.load(os.path.join(path, 'role_centers.npz')) as centers_npz:
        role_centers = {k: centers_npz[k] for k in centers_npz.files}
    return model.to(device), tokenizer, le, role_centers

@torch.no_grad()
def compute_role_centers(model: nn.Module, dataloader: DataLoader, label_encoder: LabelEncoder, device='cpu') -> Dict[str, np.ndarray]:
    """Average the projected embeddings of each label into one center vector.

    Args:
        model: Called as ``model(input_ids=..., attention_mask=...)``; the
            second element of its return value is used as the embedding.
        dataloader: Yields batches with ``input_ids``, ``attention_mask`` and ``label``.
        label_encoder: Maps integer labels back to role names.
        device: Device the batch tensors are moved to.

    Returns:
        Role name -> mean embedding over all samples with that label.
    """
    model.eval()
    per_label: Dict[int, List[np.ndarray]] = {}
    for batch in tqdm(dataloader, desc="Computing centers"):
        ids = batch['input_ids'].to(device)
        mask = batch['attention_mask'].to(device)
        label_ids = batch['label'].cpu().numpy()
        _, projected = model(input_ids=ids, attention_mask=mask)
        vectors = projected.detach().cpu().numpy()
        for label_id, vector in zip(label_ids, vectors):
            per_label.setdefault(int(label_id), []).append(vector)

    centers = {}
    for label_id, vectors in per_label.items():
        center = np.stack(vectors, axis=0).mean(axis=0)
        centers[label_encoder.inverse_transform([label_id])[0]] = center
    return centers

def get_style_score(model: nn.Module, tokenizer, text: str, role_center: np.ndarray, device='cpu', max_length=128) -> float:
    """Cosine similarity between the projected embedding of ``text`` and ``role_center``.

    The text is tokenized with truncation and padding to ``max_length``; the
    second element returned by ``model`` is used as the embedding. 1e-9 is
    added to the denominator to avoid division by zero.
    """
    model.eval()
    encoded = tokenizer(text, truncation=True, max_length=max_length, padding='max_length', return_tensors='pt')
    ids = encoded['input_ids'].to(device)
    mask = encoded['attention_mask'].to(device)
    with torch.no_grad():
        _, projected = model(input_ids=ids, attention_mask=mask)
        text_vec = projected.detach().cpu().numpy()[0]
    dot = float(np.dot(text_vec, role_center))
    scale = float(np.linalg.norm(text_vec) * np.linalg.norm(role_center) + 1e-9)
    return dot / scale

# NOTE: this rebinds the module-level `model` and `tokenizer` names, which
# generate_styled_response also reads; call _load_styled_model again before
# re-running batch inference after this cell.
model, tokenizer, label_encoder, role_centers = load_checkpoint(str(CHECKPOINT_PATH), device)

In [13]:
import pandas as pd

# NOTE: `results` is rebound here; from this cell on it holds scored items
# (inference fields plus "style_score"), not the raw inference records.
results = []

for item in tqdm(batch_inference_items, desc="Evaluating Style"):
    character = item['character']
    output_text = item['output']

    if character not in role_centers:
        print(f"Warning: Character '{character}' not found in role centers. Skipping.")
        continue

    center = role_centers[character]
    score = get_style_score(model, tokenizer, output_text, center, device=device)
    results.append({**item, "style_score": score})

# Group the style scores by model name.
model_scores = {}
for res in results:
    model_scores.setdefault(res['model'], []).append(res['style_score'])

print("\nStyle Consistency Scores:")
for m_name, scores in model_scores.items():
    avg_score = np.mean(scores)
    print(f"Model: {m_name}, Average Style Score: {avg_score:.4f} (n={len(scores)})")

df_results = pd.DataFrame(results)

Evaluating Style: 100%|██████████| 1500/1500 [00:14<00:00, 101.73it/s]


Style Consistency Scores:
Model: 10d-model, Average Style Score: 0.6581 (n=750)
Model: 18d-model, Average Style Score: 0.6644 (n=750)





### Semantic Preservation Evaluation

In [15]:
from sentence_transformers import SentenceTransformer, util

# Release cached GPU memory before loading the sentence encoder. Guarded because
# torch.cuda.ipc_collect() initialises CUDA and fails on a CPU-only machine.
if torch.cuda.is_available():
    torch.cuda.empty_cache()
    torch.cuda.ipc_collect()

# NOTE: rebinds MODEL_PATH, which the batch-inference cell defines as the Qwen base-model path.
MODEL_PATH = "../Models/bge-large-zh-v1.5"
semantic_model = SentenceTransformer(MODEL_PATH)

def compute_semantic_similarity(neutral_text, stylized_text):
    """Cosine similarity between the normalized ``semantic_model`` embeddings of two texts, as a Python float."""
    neutral_vec, stylized_vec = semantic_model.encode(
        [neutral_text, stylized_text],
        normalize_embeddings=True,
    )
    return util.cos_sim(neutral_vec, stylized_vec).item()

In [16]:
# Attach a semantic-preservation score (neutral input vs. generated output) to every scored item, in place.
for item in tqdm(results, desc="Evaluating Semantic"):
    neutral_text = item['neutral']
    stylized_text = item['output']
    score = compute_semantic_similarity(neutral_text, stylized_text)
    item['semantic_score'] = score

# === Compute the H-Score and the valid (penalized) style score ===
SEM_THRESHOLD = 0.75  # semantic threshold

for item in results:
    style_score = item.get('style_score', 0.0)
    sem_score = item.get('semantic_score', 0.0)
    # Harmonic mean of style and semantic scores; 1e-9 guards against a zero denominator.
    h_score = (2 * style_score * sem_score) / (style_score + sem_score + 1e-9)
    # The style score is kept as-is only when the semantic score exceeds the threshold; otherwise it is scaled by 0.1.
    penalty_factor = 1.0 if sem_score > SEM_THRESHOLD else 0.1
    penalized_style = style_score * penalty_factor
    item['h_score'] = h_score
    item['valid_style_score'] = penalized_style

# Group per-item scores by model; semantic scores are additionally grouped per character.
model_metrics = {}
for item in results:
    m_name = item['model']
    char_name = item['character']
    sem = item['semantic_score']
    sty = item['style_score']
    h = item['h_score']
    vsty = item['valid_style_score']
    bucket = model_metrics.setdefault(
        m_name,
        {'semantic': [], 'style': [], 'h_score': [], 'valid_style': [], 'per_char': {}},
    )
    bucket['semantic'].append(sem)
    bucket['style'].append(sty)
    bucket['h_score'].append(h)
    bucket['valid_style'].append(vsty)
    bucket['per_char'].setdefault(char_name, []).append(sem)

print(f"\nSemantic Preservation + Style Scores(SEM_THRESHOLD={SEM_THRESHOLD}):")
print("=" * 80)
for m_name, metrics in model_metrics.items():
    avg_sem = np.mean(metrics['semantic'])
    avg_style = np.mean(metrics['style'])
    avg_h = np.mean(metrics['h_score'])
    avg_vsty = np.mean(metrics['valid_style'])
    print(f"Model: {m_name}")
    # Per-character semantic averages, in sorted character-name order.
    sorted_chars = sorted(metrics['per_char'])
    for char_name in sorted_chars:
        scores = metrics['per_char'][char_name]
        avg_char_sem = np.mean(scores)
        print(f"  - Character: {char_name:<10} | Sem: {avg_char_sem:.4f} (n={len(scores)})")
    print(f"  >> Avg Semantic: {avg_sem:.4f}")
    print(f"  >> Avg Style  : {avg_style:.4f}")
    print(f"  >> Avg H-Score: {avg_h:.4f}")
    print(f"  >> Avg Valid Style: {avg_vsty:.4f}")
    print("-" * 80)

df_results = pd.DataFrame(results)

Evaluating Semantic: 100%|██████████| 1500/1500 [11:08<00:00,  2.24it/s]


Semantic Preservation + Style Scores(SEM_THRESHOLD=0.75):
Model: 10d-model
  - Character: 凉宫春日       | Sem: 0.8511 (n=150)
  - Character: 沐雪         | Sem: 0.8881 (n=150)
  - Character: 神里绫华       | Sem: 0.8605 (n=150)
  - Character: 胡桃         | Sem: 0.8575 (n=150)
  - Character: 钟离         | Sem: 0.8915 (n=150)
  >> Avg Semantic: 0.8698
  >> Avg Style  : 0.6581
  >> Avg H-Score: 0.7088
  >> Avg Valid Style: 0.5877
--------------------------------------------------------------------------------
Model: 18d-model
  - Character: 凉宫春日       | Sem: 0.8537 (n=150)
  - Character: 沐雪         | Sem: 0.8915 (n=150)
  - Character: 神里绫华       | Sem: 0.8799 (n=150)
  - Character: 胡桃         | Sem: 0.8445 (n=150)
  - Character: 钟离         | Sem: 0.8953 (n=150)
  >> Avg Semantic: 0.8730
  >> Avg Style  : 0.6644
  >> Avg H-Score: 0.7139
  >> Avg Valid Style: 0.5896
--------------------------------------------------------------------------------





### Generate 2D Scatter Plot/Comprehensive Metrics Table

In [None]:
import matplotlib.pyplot as plt
import seaborn as sns

# Per-item scatter of semantic score vs. raw style score, colored by model.
fig, ax = plt.subplots(figsize=(8, 6))
sns.scatterplot(data=df_results, x="semantic_score", y="style_score", hue="model", alpha=0.6, ax=ax)
ax.set_xlabel("Semantic Score")
ax.set_ylabel("Style Score (Raw)")
ax.set_title("Semantic vs Style Trade-off")
ax.set_xlim(0, 1)
ax.set_ylim(0, 1)
ax.legend(title="Model")
ax.grid(True, linestyle="--", alpha=0.3)
plt.show()

# One leaderboard row per model: mean of each metric, rounded to 4 decimals.
leaderboard_rows = [
    {
        "model": m_name,
        "semantic_score": round(np.mean(metrics["semantic"]), 4),
        "style_score_raw": round(np.mean(metrics["style"]), 4),
        "h_score": round(np.mean(metrics["h_score"]), 4),
        "valid_style_score": round(np.mean(metrics["valid_style"]), 4),
    }
    for m_name, metrics in model_metrics.items()
]

df_leaderboard = pd.DataFrame(leaderboard_rows).sort_values(by=["valid_style_score"], ascending=False)
df_leaderboard