# Collect All Model Outputs

In [2]:
# Run this on Remote Jupyter Book.

import os

os.getcwd()
os.chdir("/root/OtakuLab")

## Load Test Set

In [2]:
import json
from pathlib import Path
from typing_extensions import TypedDict

TEST_FILE = Path("./data/neutral_sentences_eval.jsonl")

with open(TEST_FILE, "r", encoding="utf-8") as f:
    test_data: list[str] = [json.loads(line)["neutral"] for line in f.readlines()]


class InferenceResult(TypedDict):
    model: str
    neutral: str
    output: str
    character: str

results: list[InferenceResult] = []

## Model v1

### Load Model

In [3]:
from transformers import AutoTokenizer
import torch
from pathlib import Path
import torch
import torch.nn as nn
from pathlib import Path
from transformers import Qwen3ForCausalLM
from peft import get_peft_model, LoraConfig, PeftModel

SAVE_DIR = Path("./outputs/styled-qwen")
MODEL_PATH = Path("/root/autodl-tmp/Qwen3-1.7B/")

DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")

syntactic_dims = ['syntactic_compression', 'declarativity', 'clausal_embedding', 'nominal_complexity', 'subordination', 'interjectionality',
                  'ellipsis_or_fragmentation', 'modifier_density', 'prepositional_density', 'topic_fronting', 'referentiality', 'parallelism',
                  'quantificationality', 'coordination_density']
prag_style_vocab = ['kind', 'modest', 'clingy', 'playful', 'cold', 'proud', 'sharp_tongued', 'subservient', 'submissive', 'controlling',
                    'strong', 'defensive', 'tsukkomi', 'rational', 'curious', 'imaginative', 'cautious', 'idealistic', 'conservative',
                    'radical', 'obsessive', 'hesitant', 'energetic', 'optimistic', 'confident', 'passionate', 'melancholy', 'serious',
                    'emotional', 'sensitive', 'shy', 'irritable', 'anxious', 'lazy', 'tsundere', 'yandere', 'chuunibyou', 'cute', 'naive',
                    'airhead', 'elegant', 'humorous', 'loyal', 'responsible', 'willful', 'antisocial', 'talkative', 'masochistic', 'sadistic', 'evil']

syntactic_dim_length = len(syntactic_dims)

class StyleEncoder(nn.Module):
    def __init__(self, input_dim, out_dim):
        """
        input_dim: 句法向量的维度 (开始时是 8, 扩展后可能是 12-20)
        out_dim: 必须等于 `base_model.config.hidden_size` (必须与 Qwen 的 hidden_size 匹配)
        """
        super().__init__()
        # 我们可以使用一个简单的 MLP 来映射
        self.proj = nn.Sequential(
            nn.Linear(input_dim, out_dim // 2),
            nn.ReLU(),
            nn.Linear(out_dim // 2, out_dim),
            nn.Tanh() # Tanh 将输出归一化到 [-1, 1] 作为一个稳定的 "软提示"
        )

    def forward(self, syntactic_vec):
        # syntactic_vec: [batch_size, input_dim]
        # return: [batch_size, out_dim]
        return self.proj(syntactic_vec)
    
class StyleConditionedLoRAModel(nn.Module):
    def __init__(self, model_name_or_path: str|Path, syntactic_dim: int, lora_r=16, lora_alpha=16):
        super().__init__()
        
        # 1. 加载基础模型
        base_model = Qwen3ForCausalLM.from_pretrained(
            model_name_or_path,
            dtype=torch.bfloat16 # 使用 bfloat16 节省显存
        )
        
        # 2. 定义 LoRA 配置
        # 目标模块 "key/query/value"
        # 在 Qwen3-1.7B 中，它们通常被称为 "q_proj", "k_proj", "v_proj"
        config = LoraConfig(
            r=lora_r,
            lora_alpha=lora_alpha,
            target_modules=["q_proj", "k_proj", "v_proj", "o_proj", "gate_proj", "up_proj", "down_proj"], # 覆盖所有线性层以获得最大表达力
            lora_dropout=0.05,
            bias="none",
            task_type="CAUSAL_LM"
        )
        
        # 3. 应用 LoRA
        self.peft_model: PeftModel = get_peft_model(base_model, config)  # type:ignore
        self.peft_model.print_trainable_parameters()
        
        # 4. 初始化 StyleEncoder
        # 它的输出必须匹配模型的隐藏层维度
        self.hidden_size = base_model.config.hidden_size
        self.style_encoder = StyleEncoder(syntactic_dim, self.hidden_size)
        # Cast the style encoder to match the base model's dtype
        self.style_encoder.to(base_model.dtype)

    def get_input_embeddings(self):
        # 获取基础模型的词嵌入层
        return self.peft_model.get_input_embeddings()  # type:ignore

    def forward(self, input_ids, attention_mask, labels, syntactic_vec):
        """
        这个 forward 函数实现了方案 C.1 (输入拼接)
        并为方案 D (辅助损失) 做好准备
        """
        
        # 1. 计算 Style Embedding
        # syntactic_vec: [batch_size, syntactic_dim]
        # style_emb: [batch_size, hidden_size]
        style_emb = self.style_encoder(syntactic_vec.to(self.peft_model.dtype)) # type: ignore
        
        # 2. 将 style_emb 视为一个 "Prefix" 软提示
        # 变为: [batch_size, 1, hidden_size]
        style_emb_prefix = style_emb.unsqueeze(1)
        
        # 3. 获取原始的 Token 词嵌入
        # token_embeds: [batch_size, seq_len, hidden_size]
        token_embeds = self.get_input_embeddings()(input_ids)
        
        # 4. 拼接 Style Prefix 和 Token 嵌入
        # inputs_embeds: [batch_size, 1 + seq_len, hidden_size]
        inputs_embeds = torch.cat([style_emb_prefix, token_embeds], dim=1)
        
        # 5. 修正 Attention Mask
        # 我们需要在 mask 的开头添加一个 "1" (代表 style prefix)
        prefix_mask = torch.ones(
            attention_mask.shape[0], 1,
            dtype=torch.long, 
            device=attention_mask.device
        )
        # new_attention_mask: [batch_size, 1 + seq_len]
        new_attention_mask = torch.cat([prefix_mask, attention_mask], dim=1)
        
        # 6. 修正 Labels
        # 我们需要 L_lm 忽略 style prefix 部分
        # 在 labels 的开头添加一个 "-100"
        prefix_labels = torch.full(
            (labels.shape[0], 1), -100, 
            dtype=torch.long, 
            device=labels.device
        )
        # new_labels: [batch_size, 1 + seq_len]
        new_labels = torch.cat([prefix_labels, labels], dim=1)
        
        # 7. 执行模型的前向传播
        # 我们请求 hidden_states 以便计算辅助损失
        outputs = self.peft_model(
            inputs_embeds=inputs_embeds,
            attention_mask=new_attention_mask,
            labels=new_labels,
            output_hidden_states=True 
        )
        
        # 8. 返回计算损失所需的所有组件
        # L_lm: outputs.loss
        # L_recon / L_style_cls: 需要 outputs.hidden_states 和 syntactic_vec/pragmatic_tags
        
        return {
            "loss": outputs.loss,  # L_lm
            "logits": outputs.logits,
            # 返回最后一层 hidden_state，用于计算辅助损失
            # hidden_states 是一个元组，最后一个元素是 [batch_size, 1 + seq_len, hidden_size]
            "last_hidden_state": outputs.hidden_states[-1] 
        }

@torch.inference_mode()
def generate_styled_response(neutral_sentence: str, syntactic_vec: dict[str, float],
                             character_name: str = "Ayaka", 
                             lexical_keywords: list[str] = list(),
                             pragmatic_styles: list[str] = list(),
                             temperature: float = 0.8,
                             top_p: float = 0.95,
                             repetition_penalty: float = 1.3,
                             max_new_tokens: int = 100):
    """输入中性句和风格向量，生成风格化响应"""

    lexical_keywords = lexical_keywords or []
    pragmatic_styles = pragmatic_styles or []

    # === 1. 构建提示 ===
    keywords = ", ".join(lexical_keywords) if lexical_keywords else "None"
    pragmatics = ", ".join(pragmatic_styles) if pragmatic_styles else "None"

    system_prompt = "You are a style transfer expert. Your task is to generate a new sentence that matches the target style, based on the content of a neutral sentence."
    user_prompt = (
        f"Target Character {character_name}\n"
        f"Personality: {pragmatics}\n"
        f"Keywords: {keywords}\n"
        f"Neutral Content: {neutral_sentence}\n"
    )

    messages = [
        {"role": "system", "content": system_prompt},
        {"role": "user", "content": user_prompt},
    ]
    input_text = tokenizer.apply_chat_template(messages,
                                               tokenize=False,
                                               add_generation_prompt=True,
                                               enable_thinking=True)

    # === 2. Tokenize 输入 ===
    tokenized = tokenizer(input_text, return_tensors="pt").to(DEVICE)
    input_ids = tokenized["input_ids"]
    attention_mask = tokenized["attention_mask"]

    # === 3. 风格向量 ===
    syntactic_tensor = torch.tensor(
        [syntactic_vec.get(dim, 0.0) for dim in syntactic_dims],
        dtype=torch.float32, device=DEVICE
    ).unsqueeze(0)  # [1, syntactic_dim_length]

    # === 4. 生成风格 embedding ===
    style_emb = style_encoder(syntactic_tensor).to(model.dtype)  # [1, hidden_size]
    style_prefix = style_emb.unsqueeze(1)        # [1, 1, hidden_size]

    # === 5. 获取原始词嵌入并拼接 ===
    token_embeds = model.get_input_embeddings()(input_ids)  # type:ignore
    inputs_embeds = torch.cat([style_prefix, token_embeds], dim=1)

    # === 6. Attention mask 修正 ===
    prefix_mask = torch.ones((1, 1), dtype=torch.long, device=DEVICE)
    new_attention_mask = torch.cat([prefix_mask, attention_mask], dim=1)

    # === 7. 生成 ===
    outputs = model.generate(
        inputs_embeds=inputs_embeds,
        attention_mask=new_attention_mask,
        max_new_tokens=max_new_tokens,
        temperature=temperature,
        top_p=top_p,
        repetition_penalty=repetition_penalty,
        do_sample=True
    )

    # === 8. 解码输出 ===
    # prompt_length = input_ids.shape[1]
    # new_tokens = outputs[0, prompt_length:]
    # result = tokenizer.decode(new_tokens, skip_special_tokens=True)
    result = tokenizer.decode(outputs[0], skip_special_tokens=True)
    return result


# === 加载 Tokenizer ===
tokenizer = AutoTokenizer.from_pretrained(SAVE_DIR / "tokenizer")
if tokenizer.pad_token is None:
    tokenizer.pad_token = tokenizer.eos_token

# === 加载基础模型并注入 LoRA ===
base_model = Qwen3ForCausalLM.from_pretrained(MODEL_PATH, dtype=torch.bfloat16)
model = PeftModel.from_pretrained(base_model, SAVE_DIR / "lora")
hidden_size = base_model.config.hidden_size

# === 初始化风格编码器 ===
style_encoder = StyleEncoder(syntactic_dim_length, hidden_size)
style_encoder.load_state_dict(torch.load(SAVE_DIR / "style_encoder.pt", map_location=DEVICE))
style_encoder.to(DEVICE)
style_encoder.eval()

model.to(DEVICE)
model.eval()
print("✅ Styled model loaded and ready for inference.")


Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s]

✅ Styled model loaded and ready for inference.


### Construct Character Style Vectors

In [3]:
from typing_extensions import TypedDict, Optional
from collections import Counter

import json

Pragmatic_Muice = "./outputs/pragmatic/muice.jsonl"
Pragmatic_ayaka = "./outputs/pragmatic/ayaka.jsonl"
Pragmatic_zhongli = "./outputs/pragmatic/zhongli.jsonl"
Pragmatic_hutao = "./outputs/pragmatic/hutao.jsonl"
Pragmatic_haruhi = "./outputs/pragmatic/haruhi.jsonl"

class RawPCFGItem(TypedDict):
    prompt: str
    response: str
    pragmatic_styles: list[dict[str, float]]

class PCFGItem(TypedDict):
    response: str
    pragmatic_styles: list[str]

def read_pcfg_jsonl_file(jsonl_file: Path, threshold: Optional[float] = None, top_n: int = 5) -> list[str]:
    with open(jsonl_file, "r", encoding="utf-8") as f:
        lines = f.readlines()

    raw_items: list[RawPCFGItem] = []
    items: list[str] = []

    for line in lines:
        if line := line.strip():
            raw_item: RawPCFGItem = json.loads(line)
            raw_items.append(raw_item)

    # list[dict[str, float]] -> dict[str, float] -> list[str]
    for raw_item in raw_items:
        raw_pragmatic_styles = raw_item["pragmatic_styles"]
        pragmatic_styles: dict[str, float] = {}

        for vec in raw_pragmatic_styles:
            pragmatic_styles.update(vec)

        threshold = threshold or 0
        final_styles: list[str] = []

        for key, value in pragmatic_styles.items():
            if value > threshold:
                final_styles.append(key)
        
        items.extend(final_styles)

    # 返回 Top N 风格
    style_counter = Counter(items)
    most_common_styles = [style for style, _ in style_counter.most_common(top_n)]

    return most_common_styles

pcfg_muice_items = read_pcfg_jsonl_file(Path(Pragmatic_Muice), 0.4)
pcfg_ayaka_items = read_pcfg_jsonl_file(Path(Pragmatic_ayaka), 0.4)
pcfg_zhongli_items = read_pcfg_jsonl_file(Path(Pragmatic_zhongli), 0.4)
pcfg_hutao_items = read_pcfg_jsonl_file(Path(Pragmatic_hutao), 0.4)
pcfg_haruhi_items = read_pcfg_jsonl_file(Path(Pragmatic_haruhi), 0.4)


In [5]:
from dataclasses import dataclass

@dataclass
class CharacterProfile:
    name: str
    syntactic_vec: dict[str, float]
    pragmatic_styles: list[str]
    lexical_keywords: list[str]

muice_profile = CharacterProfile(name="沐雪",
                                syntactic_vec={
                                    'declarativity': 0.1103257643217572,
                                    'parallelism': 0.02918150786583556,
                                    'ellipsis_or_fragmentation': 0.08529979222321163,
                                    'subordination': 0.19688705847432472,
                                    'interjectionality': 0.008644998515880083,
                                    'clausal_embedding': 0.034784060552092606,
                                    'referentiality': 0.11624369249035323,
                                    'syntactic_compression': 0.18596022558622738,
                                    'nominal_complexity': 0.03342980112793114,
                                    'coordination_density': 0.002615761353517364,
                                    'quantificationality': 0.038197536360937964,
                                    'modifier_density': 0.1393217571979816,
                                    'prepositional_density': 0.01910804392994954
                                    },
                                pragmatic_styles=pcfg_muice_items,
                                lexical_keywords=["喵", "沐沐", "AI", "恼", "沐雪", "女孩子", "~", "⭐", "不行", "聊天", "呀", "可爱", "才", "叫", "唔", "谁", "不会", "吃", "睡觉", "笨蛋", "答", "谢谢", "把", "即", "吧"]
                                )


ayaka_profile = CharacterProfile(name="神里绫华",
                                 syntactic_vec={
                                    "declarativity": 0.09320164543629895,
                                    "parallelism": 0.029192583613203236,
                                    "ellipsis_or_fragmentation": 0.061485264601259915,
                                    "subordination": 0.18419745235587529,
                                    "clausal_embedding": 0.046163629498618866,
                                    "interjectionality": 0.002046859164166054,
                                    "syntactic_compression": 0.1999165358399078,
                                    "nominal_complexity": 0.05623894596689255,
                                    "referentiality": 0.10019673694878878,
                                    "coordination_density": 0.02416486158860118,
                                    "quantificationality": 0.042964170028417556,
                                    "modifier_density": 0.13753701238051708,
                                    "prepositional_density": 0.022694302577452752
                                },
                                pragmatic_styles=pcfg_ayaka_items,
                                lexical_keywords=["稻妻国", "神里家", "稻妻", "大小姐", "家族", "传统", "奉行", "文化", "人民", "眼狩令", "神", "当地", "社", "舞蹈", "美丽", "茶道", "神社", "祭典", "眼", "美食", "继承", "剑术", "国家", "将军", "责任"],
                                )

zhongli_profile = CharacterProfile(name="钟离",
                                   syntactic_vec={
                                    "declarativity": 0.09879656160458453,
                                    "parallelism": 0.029398280802292263,
                                    "ellipsis_or_fragmentation": 0.06412607449856733,
                                    "subordination": 0.1839541547277937,
                                    "clausal_embedding": 0.037478510028653295,
                                    "interjectionality": 0.003151862464183381,
                                    "syntactic_compression": 0.21077363896848136,
                                    "quantificationality": 0.04022922636103152,
                                    "referentiality": 0.08240687679083095,
                                    "nominal_complexity": 0.06372492836676218,
                                    "coordination_density": 0.02332378223495702,
                                    "modifier_density": 0.14068767908309457,
                                    "prepositional_density": 0.02194842406876791
                                    },
                                    pragmatic_styles=pcfg_zhongli_items,
                                    lexical_keywords=['岩石', '岩', '璃月', '力', '璃', '契约', '炼金术', '月', '盐', '帝君', '魔神', '操控', '王', '并非', '岩王', '大地', '封印', '作战', '掌握', '大陆', '学问', '研究', '七星', '客卿', '岩元素'],
                                    )

hutao_profile = CharacterProfile(name="胡桃",
                                 syntactic_vec={
                                    "parallelism": 0.03191357258164659,
                                    "declarativity": 0.10471252949211474,
                                    "ellipsis_or_fragmentation": 0.07646218800447038,
                                    "clausal_embedding": 0.042685955544517575,
                                    "subordination": 0.18254066807400968,
                                    "interjectionality": 0.01809884515087545,
                                    "syntactic_compression": 0.1987458090152738,
                                    "referentiality": 0.09049422575437725,
                                    "quantificationality": 0.05016763938904756,
                                    "nominal_complexity": 0.034459207748665094,
                                    "coordination_density": 0.013845771762076246,
                                    "modifier_density": 0.1413448404321371,
                                    "prepositional_density": 0.014528747050788526
                                },
                                pragmatic_styles=pcfg_zhongli_items,
                                lexical_keywords=['往生堂', '嘿嘿', '嘻嘻', '可是', '堂主', '哎呀呀', '哦哦哦', '宝藏', '惊喜', '诗歌', '可不是', '灵魂', '胡桃', '神秘', '生死', '谜题', '哈哈哈', '不过', '有趣', '亡灵', '秘密', '意想不到', '巫师', '哇', '奇妙']
                                )

haruhi_profile = CharacterProfile(name="凉宫春日",
                                  syntactic_vec={
                                    "declarativity": 0.0939982347749338,
                                    "parallelism": 0.032288908502500734,
                                    "subordination": 0.19174757281553398,
                                    "ellipsis_or_fragmentation": 0.07542659605766402,
                                    "clausal_embedding": 0.042144748455428066,
                                    "interjectionality": 0.008017063842306561,
                                    "syntactic_compression": 0.18409826419535158,
                                    "quantificationality": 0.040453074433656956,
                                    "referentiality": 0.12099146807884673,
                                    "nominal_complexity": 0.043211238599588114,
                                    "coordination_density": 0.01195204471903501,
                                    "prepositional_density": 0.011253309796999117,
                                    "modifier_density": 0.14441747572815533
                                },
                                pragmatic_styles=pcfg_haruhi_items,
                                lexical_keywords=['团', 'SOS', '阿虚', '社团', '哼', '事件', '学校', '超自然', '朝比奈', '文化祭', '创意', '吸引', '古泉', '电影', '创新', '组织', '实玖瑠', '当然', '与众不同', '主题', '加入', '束缚', '凉宫', '团长', '外星人']
                                )

### Execute Batch Inference

In [None]:
from tqdm import tqdm

def batch_inference_v1(profile: CharacterProfile):
    for item in tqdm(test_data, desc=f"Running {profile.name}"):
        output = generate_styled_response(item, profile.syntactic_vec, profile.name, profile.lexical_keywords, profile.pragmatic_styles)
        results.append(InferenceResult(model="Modelv1", neutral=item, output=output, character=profile.name))

batch_inference_v1(muice_profile)
batch_inference_v1(ayaka_profile)
batch_inference_v1(zhongli_profile)
batch_inference_v1(hutao_profile)
batch_inference_v1(haruhi_profile)

Running 沐雪: 100%|██████████| 150/150 [04:55<00:00,  1.97s/it]
Running 神里绫华: 100%|██████████| 150/150 [04:49<00:00,  1.93s/it]
Running 钟离: 100%|██████████| 150/150 [04:44<00:00,  1.90s/it]
Running 胡桃: 100%|██████████| 150/150 [05:05<00:00,  2.03s/it]
Running 凉宫春日: 100%|██████████| 150/150 [04:51<00:00,  1.94s/it]


## Model v2

### Load Model

In [9]:
torch.cuda.empty_cache()
torch.cuda.ipc_collect()

SAVE_DIR = Path("./outputs/styled-qwen-balanced")

tokenizer = AutoTokenizer.from_pretrained(SAVE_DIR / "tokenizer")
if tokenizer.pad_token is None:
    tokenizer.pad_token = tokenizer.eos_token

# === 加载基础模型并注入 LoRA ===
base_model = Qwen3ForCausalLM.from_pretrained(MODEL_PATH, dtype=torch.bfloat16)
model = PeftModel.from_pretrained(base_model, SAVE_DIR / "lora")
hidden_size = base_model.config.hidden_size

# === 初始化风格编码器 ===
style_encoder = StyleEncoder(syntactic_dim_length, hidden_size)
style_encoder.load_state_dict(torch.load(SAVE_DIR / "style_encoder.pt", map_location=DEVICE))
style_encoder.to(DEVICE)
style_encoder.eval()

model.to(DEVICE)
model.eval()
print("✅ Styled model loaded and ready for inference.")


Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s]

✅ Styled model loaded and ready for inference.


### Execute Batch Inference

In [10]:
def batch_inference_v2(profile: CharacterProfile):
    for item in tqdm(test_data, desc=f"Running {profile.name}"):
        output = generate_styled_response(item, profile.syntactic_vec, profile.name, profile.lexical_keywords, profile.pragmatic_styles)
        results.append(InferenceResult(model="Modelv2", neutral=item, output=output, character=profile.name))

batch_inference_v2(muice_profile)
batch_inference_v2(ayaka_profile)
batch_inference_v2(zhongli_profile)
batch_inference_v2(hutao_profile)
batch_inference_v2(haruhi_profile)

Running 沐雪: 100%|██████████| 150/150 [04:50<00:00,  1.94s/it]
Running 神里绫华: 100%|██████████| 150/150 [05:00<00:00,  2.01s/it]
Running 钟离: 100%|██████████| 150/150 [04:53<00:00,  1.96s/it]
Running 胡桃: 100%|██████████| 150/150 [05:02<00:00,  2.02s/it]
Running 凉宫春日: 100%|██████████| 150/150 [04:54<00:00,  1.97s/it]


## Baseline A (RAG+FS)

### Load LLM

In [None]:
from transformers import AutoModelForCausalLM, AutoTokenizer
from pathlib import Path
import torch

torch.cuda.empty_cache()
torch.cuda.ipc_collect()

MODEL_PATH = Path("/root/autodl-tmp/Qwen3-1.7B/")
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# load the tokenizer and the model
tokenizer = AutoTokenizer.from_pretrained(MODEL_PATH)
model = AutoModelForCausalLM.from_pretrained(
    MODEL_PATH,
    dtype="auto",
    device_map="auto"
)

# prepare the model input
@torch.inference_mode()
def generate_plain_response(neutral_sentence: str,
                            character_name: str, 
                            lexical_keywords: list[str],
                            pragmatic_styles: list[str],
                            reference_text: str = "",
                            history: list[tuple[str, str]] = [],
                            temperature: float = 0.8,
                            top_p: float = 0.95,
                            repetition_penalty: float = 1.3,):
    """输入中性句，生成普通响应"""

    # === 1. 构建提示 ===
    keywords = ", ".join(lexical_keywords) if lexical_keywords else "None"
    pragmatics = ", ".join(pragmatic_styles) if pragmatic_styles else "None"

    system_prompt = "You are a style transfer expert. Your task is to generate a new sentence that matches the target style, based on the content of a neutral sentence. As a reference, you have a style text from the target character to imitate."
    user_prompt = (
        f"Target Character {character_name}\n"
        f"Personality: {pragmatics}\n"
        f"Keywords: {keywords}\n"
        f"Neutral Content: {neutral_sentence}\n"
        f"Style Reference Text: {reference_text}\n"
    )

    messages = [{"role": "system", "content": system_prompt}]

    for i, (user_msg, bot_msg) in enumerate(history):
        messages.append({"role": "user", "content": user_msg})
        messages.append({"role": "assistant", "content": bot_msg})

    messages.append({"role": "user", "content": user_prompt})
    input_text = tokenizer.apply_chat_template(messages,
                                               tokenize=False,
                                               add_generation_prompt=True,
                                               enable_thinking=True)

    # === 2. Tokenize 输入 ===
    model_inputs = tokenizer([input_text], return_tensors="pt").to(DEVICE)

    # === 3. 生成 ===
    outputs = model.generate(
        **model_inputs,
        temperature=temperature,
        top_p=top_p,
        repetition_penalty=repetition_penalty,
        do_sample=True
    )

    # === 4. 解码输出 ===
    output_ids = outputs[0][len(model_inputs.input_ids[0]):].tolist() 

    # thinking_content = tokenizer.decode(output_ids[:index], skip_special_tokens=True).strip("\n")
    content = tokenizer.decode(output_ids, skip_special_tokens=True).strip("\n")

    return content

Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s]

### Load RAG Model

In [5]:
from sentence_transformers import SentenceTransformer
import faiss
import pandas as pd
import numpy as np
from pathlib import Path

# bge-large
model_path = "/root/autodl-tmp/bge-large-zh-v1.5/"
embedder = SentenceTransformer(model_path)

# RAG Index Directory
RAG_DIR = Path("./outputs/rag/")

# Character Mapping (Chinese Name -> Filename)
CHAR_TO_FILENAME = {
    "沐雪": "Muice",
    "神里绫华": "Ayaka",
    "钟离": "Zhongli",
    "胡桃": "Hutao",
    "凉宫春日": "Haruhi"
}

character_indices = {}
character_corpora = {}

print("Loading RAG indices...")
for char_zh, char_en in CHAR_TO_FILENAME.items():
    index_path = RAG_DIR / f"{char_en}_index.faiss"
    corpus_path = RAG_DIR / f"{char_en}_corpus.jsonl"
    
    if index_path.exists() and corpus_path.exists():
        # Load Index
        character_indices[char_zh] = faiss.read_index(str(index_path))
        
        # Load Corpus
        df = pd.read_json(corpus_path, orient="records", lines=True)
        character_corpora[char_zh] = df["text"].tolist()
        print(f"Loaded {char_zh} ({char_en})")
    else:
        print(f"Warning: Missing index or corpus for {char_zh} ({char_en})")

def search(character_name: str, query: str, top_k=1) -> str:
    if character_name not in character_indices:
        print(f"Warning: No index found for {character_name}")
        return "" 
        
    index = character_indices[character_name]
    corpus = character_corpora[character_name]
    
    q_emb = embedder.encode([query], normalize_embeddings=True)
    scores, indices = index.search(np.array(q_emb).astype("float32"), top_k)
    
    # Return the top result text
    if len(indices) > 0 and len(indices[0]) > 0:
        idx = indices[0][0]
        if 0 <= idx < len(corpus):
            return corpus[idx]
            
    return ""

Loading RAG indices...
Loaded 沐雪 (Muice)
Loaded 神里绫华 (Ayaka)
Loaded 钟离 (Zhongli)
Loaded 胡桃 (Hutao)
Loaded 凉宫春日 (Haruhi)


### Execute Batch Inference

In [11]:
Muice_shots = [
    (
        (
        f"Target Character 沐雪\n"
        f"Personality: {", ".join(pcfg_muice_items)}\n"
        f"Keywords: {", ".join(['喵', '沐沐', 'AI', '恼', '沐雪', '~', '女孩子', '⭐', '不行', '聊天', '呀', '叫', '唔', '答', '吃', '可爱', '睡觉', '谢谢', '即', '雪雪', '骂', '笨蛋', '不会', '直播', '脸红'])}\n"
        f"Neutral Content: “奇奇怪怪的东西”指的是什么？\n"
        f"Style Reference Text: 这个东西真的存在吗？"
        ),
        "这个嘛...奇奇怪怪的东西指的是什么呢（天真）"
    ),
    (
        (
        f"Target Character 沐雪\n"
        f"Personality: {", ".join(pcfg_muice_items)}\n"
        f"Keywords: {", ".join(['喵', '沐沐', 'AI', '恼', '沐雪', '~', '女孩子', '⭐', '不行', '聊天', '呀', '叫', '唔', '答', '吃', '可爱', '睡觉', '谢谢', '即', '雪雪', '骂', '笨蛋', '不会', '直播', '脸红'])}\n"
        f"Neutral Content: 我不像沐沐，你能具体说说我哪里像她吗？\n"
        f"Style Reference Text: 我叫沐雪，是沐沐发明了我⭐"
        ),
        "是吗？我才不像沐沐呢！你给我说说我哪里像了？"
    )
]


Ayaka_shots = [
    (
        (
        f"Target Character 神里绫华\n"
        f"Personality: {", ".join(pcfg_ayaka_items)}\n"
        f"Keywords: {", ".join(['稻妻国', '神里家', '稻妻', '大小姐', '家族', '传统', '奉行', '文化', '人民', '眼狩令', '神', '当地', '社', '舞蹈', '美丽', '茶道', '神社', '祭典', '眼', '美食', '继承', '剑术', '国家', '将军', '责任'])}\n"
        f"Neutral Content: 神里家的历史包含许多传奇与挑战，曾经历诸多困难和挫折，但始终秉持家族的使命与责任。这些历史被视为家族的宝贵财富，并促使家族更加重视和致力于维护稻妻国的和平与繁荣。\n"
        f"Style Reference Text: 神里家是稻妻国的社奉行家族，历史悠久，承载着重要的使命和责任。我们的家族代代相传，担负着维护稻妻国的秩序和安宁的重任。我作为神里家的大小姐，也肩负着继承家族衣钵的责任。我接受了严格的训练，努力成为一名合格的社奉行，为国家和人民服务。神里家的故事充满了传奇和传统，我们家族的使命是保护稻妻国的神明和神器，维护稻妻国的和平与繁荣。虽然有时候会面临各种挑战和困难，但我会坚守家族的信念，为了稻妻国的未来而努力。这就是神里家的故事，一个充满责任和使命的家族，我为能够成为其中一员感到骄傲和自豪。"
        ),
        "是的，神里家的历史确实充满了传奇和挑战。我们家族曾经经历了许多困难和挫折，但我们始终坚守着家族的使命和责任。这些历史故事也是我们家族的宝贵财富，让我们更加珍惜和努力保护着稻妻国的和平与繁荣。"
    ),
    (
        (
        f"Target Character 神里绫华\n"
        f"Personality: {", ".join(pcfg_ayaka_items)}\n"
        f"Keywords: {", ".join(['稻妻国', '神里家', '稻妻', '大小姐', '家族', '传统', '奉行', '文化', '人民', '眼狩令', '神', '当地', '社', '舞蹈', '美丽', '茶道', '神社', '祭典', '眼', '美食', '继承', '剑术', '国家', '将军', '责任'])}\n"
        f"Neutral Content: 这里的环境安静，令人感到舒适，时间似乎变慢了，适合静心休息。\n"
        f"Style Reference Text: 这个地方确实很宜人，适合我们稍作休息。"
        ),
        "是啊，这里的宁静让人感到舒适。仿佛时间在这里变得慢了下来，让人可以好好享受片刻的宁静。"
    )
]

Zhongli_shots = [
    (
        (
        f"Target Character 钟离\n"
        f"Personality: {", ".join(pcfg_zhongli_items)}\n"
        f"Keywords: {", ".join(['岩石', '岩', '璃月', '力', '璃', '契约', '炼金术', '月', '盐', '帝君', '魔神', '操控', '王', '并非', '岩王', '大地', '封印', '作战', '掌握', '大陆', '学问', '研究', '七星', '客卿', '岩元素'])}\n"
        f"Neutral Content: 力量似乎逐渐恢复了一些，虽然不多，但已足够发挥作用。\n"
        f"Style Reference Text: 是时候展现真正的力量了。"
        ),
        "力量似乎渐渐回来了。不多，但有用。"
    ),
    (
        (
        f"Target Character 钟离\n"
        f"Personality: {", ".join(pcfg_zhongli_items)}\n"
        f"Keywords: {", ".join(['岩石', '岩', '璃月', '力', '璃', '契约', '炼金术', '月', '盐', '帝君', '魔神', '操控', '王', '并非', '岩王', '大地', '封印', '作战', '掌握', '大陆', '学问', '研究', '七星', '客卿', '岩元素'])}\n"
        f"Neutral Content: 魈的力量来源于山川间的灵气，能够操控风雨雷电。他性格有时较为顽皮，需要适当的引导和约束。如需处理相关问题，可提供一些建议。\n"
        f"Style Reference Text: 作为岩王帝君，我掌握着岩之力，可以操控岩石之力进行作战。"
        ),
        "魈啊，他是个有趣的存在。他的力量源自于山川之间的灵气，可以操控风雨雷电。不过，他有时候会有些顽皮，需要一些引导和约束。如果你需要帮助处理他的问题，我可以给你一些建议。"
    )
]


Hutao_shots = [
    (
        (
        f"Target Character 胡桃\n"
        f"Personality: {", ".join(pcfg_hutao_items)}\n"
        f"Keywords: {", ".join(['往生堂', '嘿嘿', '嘻嘻', '可是', '堂主', '哎呀呀', '哦哦哦', '宝藏', '惊喜', '诗歌', '可不是', '灵魂', '胡桃', '神秘', '生死', '谜题', '哈哈哈', '不过', '有趣', '亡灵', '秘密', '意想不到', '巫师', '哇', '奇妙'])}\n"
        f"Neutral Content: 晚上好，今天有什么有趣的事情吗？\n"
        f"Style Reference Text: 嗨，早上好啊！今天有什么好玩的计划吗？"
        ),
        "嘻嘻，晚上好呀！今天有什么好玩的事情吗？"
    ),
    (
        (
        f"Target Character 胡桃\n"
        f"Personality: {", ".join(pcfg_hutao_items)}\n"
        f"Keywords: {", ".join(['往生堂', '嘿嘿', '嘻嘻', '可是', '堂主', '哎呀呀', '哦哦哦', '宝藏', '惊喜', '诗歌', '可不是', '灵魂', '胡桃', '神秘', '生死', '谜题', '哈哈哈', '不过', '有趣', '亡灵', '秘密', '意想不到', '巫师', '哇', '奇妙'])}\n"
        f"Neutral Content: 我是往生堂的堂主，也是璃月的诗人胡桃。平时可能显得贪玩，但在处理堂中事务时非常认真负责。\n"
        f"Style Reference Text: 嘿嘿，我可是往生堂的堂主，自然有一些特殊的技能。比如，我擅长使用往生之法，可以与亡灵交流和引导他们往生。还有，我也是一位璃月著名的诗人，擅长吟诗作对。不过，我的最特殊的技能，恐怕是我的鬼点子和捉弄人的能力啦。"
        ),
        "我是往生堂的堂主，也是璃月的著名诗人胡桃。虽然平时看起来像个贪玩的孩子，但在处理堂中事务时，我可是非常认真负责的哦。"
    )
]

Haruhi_shots = [
    (
        (
        f"Target Character 凉宫春日\n"
        f"Personality: {", ".join(pcfg_haruhi_items)}\n"
        f"Keywords: {", ".join(['团', 'SOS', '阿虚', '社团', '哼', '事件', '学校', '超自然', '朝比奈', '文化祭', '创意', '吸引', '古泉', '电影', '创新', '组织', '实玖瑠', '当然', '与众不同', '主题', '加入', '束缚', '凉宫', '团长', '外星人'])}\n"
        f"Neutral Content: SOS团将在文化祭上举办一场解谜游戏，内容围绕神秘事件，旨在让参与者体验超自然氛围。\n"
        f"Style Reference Text: 我们的目的是创造出一个充满惊喜和刺激的文化祭活动，让大家都能享受其中的乐趣。我们希望通过SOS团的活动，给学校带来一些不同寻常的体验，让人们感受到我们的热情和创造力。"
        ),
        "我们SOS团要举办一场超级有趣的活动！我们计划在文化祭上举办一场神秘事件的解谜游戏，让参与者体验到超自然的感觉！"
    ),
    (
        (
        f"Target Character 凉宫春日\n"
        f"Personality: {", ".join(pcfg_haruhi_items)}\n"
        f"Keywords: {", ".join(['团', 'SOS', '阿虚', '社团', '哼', '事件', '学校', '超自然', '朝比奈', '文化祭', '创意', '吸引', '古泉', '电影', '创新', '组织', '实玖瑠', '当然', '与众不同', '主题', '加入', '束缚', '凉宫', '团长', '外星人'])}\n"
        f"Neutral Content: 你说的没错，但我不会因为别人的眼光限制自己的行动。我只关注自己想做的事，不在意他人的看法。\n"
        f"Style Reference Text: 我才不需要听你的意见呢，我自己的想法最重要！"
        ),
        "哼，你这么说也没错。不过，我可不会因为别人的眼光而束缚自己的行动。我只在乎我自己想做的事情，不在乎别人怎么看待。"
    )
]

In [8]:
from tqdm import tqdm

def batch_inference_rag(profile: CharacterProfile, history_shots: list[tuple[str, str]]):
    for item in tqdm(test_data, desc=f"Running {profile.name}"):
        reference_text = search(profile.name, query=item)
        output = generate_plain_response(item, profile.name, profile.lexical_keywords, profile.pragmatic_styles, reference_text=reference_text, history=history_shots)
        results.append(InferenceResult(model="BaselineA", neutral=item, output=output, character=profile.name))

batch_inference_rag(muice_profile, Muice_shots)
batch_inference_rag(ayaka_profile, Ayaka_shots)
batch_inference_rag(zhongli_profile, Zhongli_shots)
batch_inference_rag(hutao_profile, Hutao_shots)
batch_inference_rag(haruhi_profile, Haruhi_shots)

Running 沐雪: 100%|██████████| 150/150 [00:52<00:00,  2.87it/s]
Running 神里绫华: 100%|██████████| 150/150 [00:50<00:00,  2.99it/s]
Running 钟离: 100%|██████████| 150/150 [00:47<00:00,  3.19it/s]
Running 胡桃: 100%|██████████| 150/150 [00:47<00:00,  3.16it/s]
Running 凉宫春日: 100%|██████████| 150/150 [00:46<00:00,  3.21it/s]


## Baseline B (Per-Character)

### Load LLM

In [None]:
from peft import PeftModel
from transformers import AutoModelForCausalLM, AutoTokenizer
import torch

MODEL_PATH = "/root/autodl-tmp/Qwen3-4B/"
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")

torch.cuda.empty_cache()
torch.cuda.ipc_collect()

# load the tokenizer and the model
tokenizer = AutoTokenizer.from_pretrained(MODEL_PATH)
model = AutoModelForCausalLM.from_pretrained(
    MODEL_PATH,
    dtype="auto",
    device_map="auto"
)

# prepare the model input
@torch.inference_mode()
def generate_response(style_model: PeftModel,
                      neutral_sentence: str,
                      character_name: str,
                      temperature: float = 0.8,
                      top_p: float = 0.95,
                      repetition_penalty: float = 1.3,
                      max_new_tokens: int = 100):
    # === 1. 构建提示 ===
    system_prompt = f"You are a style transfer expert. Your task is to mimic the personality of {character_name} and generate a new sentence that matches the (s)he style, based on the content of a neutral sentence."
    user_prompt = f"Neutral Content: {neutral_sentence}"

    messages = [{"role": "system", "content": system_prompt}]


    messages.append({"role": "user", "content": user_prompt})
    input_text = tokenizer.apply_chat_template(messages,
                                               tokenize=False,
                                               add_generation_prompt=True,
                                               enable_thinking=True)

    # === 2. Tokenize 输入 ===
    model_inputs = tokenizer([input_text], return_tensors="pt").to(DEVICE)

    # === 3. 生成 ===
    outputs = style_model.generate(
        **model_inputs,
        max_new_tokens=max_new_tokens,
        temperature=temperature,
        top_p=top_p,
        repetition_penalty=repetition_penalty,
        do_sample=True
    )

    # === 4. 解码输出 ===
    output_ids = outputs[0][len(model_inputs.input_ids[0]):].tolist() 

    # parsing thinking content
    try:
        # rindex finding 151668 (</think>)
        index = len(output_ids) - output_ids[::-1].index(151668)
    except ValueError:
        index = 0

    # thinking_content = tokenizer.decode(output_ids[:index], skip_special_tokens=True).strip("\n")
    content = tokenizer.decode(output_ids[index:], skip_special_tokens=True).strip("\n")

    return content

Loading checkpoint shards:   0%|          | 0/3 [00:00<?, ?it/s]

The module name  (originally ) is not a valid Python identifier. Please rename the original module to avoid import issues.


### Execute Batch Inference

In [9]:
from tqdm import tqdm

def batch_inference_pc(profile: CharacterProfile, adapter_path: str):
    style_model = PeftModel.from_pretrained(model, adapter_path)

    for item in tqdm(test_data, desc=f"Running {profile.name}"):
        output = generate_response(style_model, item, profile.name)
        results.append(InferenceResult(model="BaselineB", neutral=item, output=output, character=profile.name))
        
    del style_model

batch_inference_pc(muice_profile, "./outputs/Per-Character/Muice")
batch_inference_pc(ayaka_profile, "./outputs/Per-Character/Ayaka")
batch_inference_pc(zhongli_profile, "./outputs/Per-Character/Zhongli")
batch_inference_pc(hutao_profile, "./outputs/Per-Character/Hutao")
batch_inference_pc(haruhi_profile, "./outputs/Per-Character/Haruhi")

Running 沐雪: 100%|██████████| 150/150 [07:05<00:00,  2.84s/it]
Running 神里绫华: 100%|██████████| 150/150 [07:42<00:00,  3.09s/it]
Running 钟离: 100%|██████████| 150/150 [07:35<00:00,  3.04s/it]
Running 胡桃: 100%|██████████| 150/150 [07:39<00:00,  3.06s/it]
Running 凉宫春日: 100%|██████████| 150/150 [07:35<00:00,  3.03s/it]


## Baseline C (Vanilla SFT)

### Load LLM

In [14]:
from peft import PeftModel

torch.cuda.empty_cache()
torch.cuda.ipc_collect()

MODEL_PATH = "/root/autodl-tmp/Qwen3-4B/"
ADAPTER_PATH = "/root/OtakuLab/outputs/Vanilla/checkpoint-543"

# load the tokenizer and the model
tokenizer = AutoTokenizer.from_pretrained(MODEL_PATH)
model = AutoModelForCausalLM.from_pretrained(
    MODEL_PATH,
    torch_dtype="auto",
    device_map="auto"
)
style_model = PeftModel.from_pretrained(model, ADAPTER_PATH)
style_model.to(DEVICE)

# prepare the model input
@torch.inference_mode()
def generate_response(neutral_sentence: str,
                      character_name: str,
                      lexical_keywords: list[str],
                      pragmatic_styles: list[str],
                      temperature: float = 0.8,
                      top_p: float = 0.95,
                      repetition_penalty: float = 1.3,):
    # === 1. 构建提示 ===
    system_prompt = "You are a style transfer expert. Your task is to generate a new sentence that matches the target style, based on the content of a neutral sentence."
    user_prompt = (
        f"Target Character {character_name}\n"
        f"Personality: {pragmatic_styles}\n"
        f"Keywords: {lexical_keywords}\n"
        f"Neutral Content: {neutral_sentence}\n"
    )

    messages = [{"role": "system", "content": system_prompt}]


    messages.append({"role": "user", "content": user_prompt})
    input_text = tokenizer.apply_chat_template(messages,
                                               tokenize=False,
                                               add_generation_prompt=True,
                                               enable_thinking=True)

    # === 2. Tokenize 输入 ===
    model_inputs = tokenizer([input_text], return_tensors="pt").to(DEVICE)

    # === 3. 生成 ===
    outputs = style_model.generate(
        **model_inputs,
        temperature=temperature,
        top_p=top_p,
        repetition_penalty=repetition_penalty,
        do_sample=True
    )

    # === 4. 解码输出 ===
    output_ids = outputs[0][len(model_inputs.input_ids[0]):].tolist() 

    # parsing thinking content
    try:
        # rindex finding 151668 (</think>)
        index = len(output_ids) - output_ids[::-1].index(151668)
    except ValueError:
        index = 0

    # thinking_content = tokenizer.decode(output_ids[:index], skip_special_tokens=True).strip("\n")
    content = tokenizer.decode(output_ids[index:], skip_special_tokens=True).strip("\n")

    return content

Loading checkpoint shards:   0%|          | 0/3 [00:00<?, ?it/s]

The module name  (originally ) is not a valid Python identifier. Please rename the original module to avoid import issues.


### Execute Batch Inference

In [15]:
def batch_inference_vanilla(profile: CharacterProfile):
    torch.cuda.empty_cache()
    torch.cuda.ipc_collect()

    for item in tqdm(test_data, desc=f"Running {profile.name}"):
        output = generate_response(item, profile.name, profile.lexical_keywords, profile.pragmatic_styles)
        results.append(InferenceResult(model="BaselineC", neutral=item, output=output, character=profile.name))

batch_inference_vanilla(muice_profile,)
batch_inference_vanilla(ayaka_profile)
batch_inference_vanilla(zhongli_profile)
batch_inference_vanilla(hutao_profile)
batch_inference_vanilla(haruhi_profile)

Running 沐雪:   0%|          | 0/150 [00:00<?, ?it/s]

Running 沐雪: 100%|██████████| 150/150 [01:46<00:00,  1.41it/s]
Running 神里绫华: 100%|██████████| 150/150 [01:48<00:00,  1.39it/s]
Running 钟离: 100%|██████████| 150/150 [01:42<00:00,  1.46it/s]
Running 胡桃: 100%|██████████| 150/150 [01:44<00:00,  1.44it/s]
Running 凉宫春日: 100%|██████████| 150/150 [01:43<00:00,  1.45it/s]


## Model v2 (Without CoT)

### Load Model

In [5]:
torch.cuda.empty_cache()
torch.cuda.ipc_collect()

SAVE_DIR = Path("./outputs/styled-qwen-balanced")

@torch.inference_mode()
def generate_styled_response_without_CoT(neutral_sentence: str,
                             syntactic_vec: dict[str, float],
                             character_name: str = "Ayaka", 
                             lexical_keywords: list[str] = list(),
                             pragmatic_styles: list[str] = list(),
                             temperature: float = 0.8,
                             top_p: float = 0.95,
                             repetition_penalty: float = 1.3,
                             max_new_tokens: int = 100):
    """输入中性句和风格向量，生成风格化响应"""

    lexical_keywords = lexical_keywords or []
    pragmatic_styles = pragmatic_styles or []

    # === 1. 构建提示 ===
    keywords = ", ".join(lexical_keywords) if lexical_keywords else "None"
    pragmatics = ", ".join(pragmatic_styles) if pragmatic_styles else "None"

    system_prompt = "You are a style transfer expert. Your task is to generate a new sentence that matches the target style, based on the content of a neutral sentence."
    user_prompt = (
        f"Target Character {character_name}\n"
        f"Personality: {pragmatics}\n"
        f"Keywords: {keywords}\n"
        f"Neutral Content: {neutral_sentence}\n"
    )

    messages = [
        {"role": "system", "content": system_prompt},
        {"role": "user", "content": user_prompt},
    ]
    input_text = tokenizer.apply_chat_template(messages,
                                               tokenize=False,
                                               add_generation_prompt=True,
                                               enable_thinking=False)  # CoT Disabled.

    # === 2. Tokenize 输入 ===
    tokenized = tokenizer(input_text, return_tensors="pt").to(DEVICE)
    input_ids = tokenized["input_ids"]
    attention_mask = tokenized["attention_mask"]

    # === 3. 风格向量 ===
    syntactic_tensor = torch.tensor(
        [syntactic_vec.get(dim, 0.0) for dim in syntactic_dims],
        dtype=torch.float32, device=DEVICE
    ).unsqueeze(0)  # [1, syntactic_dim_length]

    # === 4. 生成风格 embedding ===
    style_emb = style_encoder(syntactic_tensor).to(model.dtype)  # [1, hidden_size]
    style_prefix = style_emb.unsqueeze(1)        # [1, 1, hidden_size]

    # === 5. 获取原始词嵌入并拼接 ===
    token_embeds = model.get_input_embeddings()(input_ids)  # type:ignore
    inputs_embeds = torch.cat([style_prefix, token_embeds], dim=1)

    # === 6. Attention mask 修正 ===
    prefix_mask = torch.ones((1, 1), dtype=torch.long, device=DEVICE)
    new_attention_mask = torch.cat([prefix_mask, attention_mask], dim=1)

    # === 7. 生成 ===
    outputs = model.generate(
        inputs_embeds=inputs_embeds,
        attention_mask=new_attention_mask,
        max_new_tokens=max_new_tokens,
        temperature=temperature,
        top_p=top_p,
        repetition_penalty=repetition_penalty,
        do_sample=True
    )

    # === 8. 解码输出 ===
    # prompt_length = input_ids.shape[1]
    # new_tokens = outputs[0, prompt_length:]
    # result = tokenizer.decode(new_tokens, skip_special_tokens=True)
    result = tokenizer.decode(outputs[0], skip_special_tokens=True)
    return result

tokenizer = AutoTokenizer.from_pretrained(SAVE_DIR / "tokenizer")
if tokenizer.pad_token is None:
    tokenizer.pad_token = tokenizer.eos_token

# === 加载基础模型并注入 LoRA ===
base_model = Qwen3ForCausalLM.from_pretrained(MODEL_PATH, dtype=torch.bfloat16)
model = PeftModel.from_pretrained(base_model, SAVE_DIR / "lora")
hidden_size = base_model.config.hidden_size

# === 初始化风格编码器 ===
style_encoder = StyleEncoder(syntactic_dim_length, hidden_size)
style_encoder.load_state_dict(torch.load(SAVE_DIR / "style_encoder.pt", map_location=DEVICE))
style_encoder.to(DEVICE)
style_encoder.eval()

model.to(DEVICE)
model.eval()
print("✅ Styled model loaded and ready for inference.")


Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s]

✅ Styled model loaded and ready for inference.


### Execute Batch Inference

In [7]:
from tqdm import tqdm

def batch_inference_v2_no_CoT(profile: CharacterProfile):
    for item in tqdm(test_data, desc=f"Running {profile.name}"):
        output = generate_styled_response_without_CoT(item, profile.syntactic_vec, profile.name, profile.lexical_keywords, profile.pragmatic_styles)
        results.append(InferenceResult(model="Modelv2(inference-only)", neutral=item, output=output, character=profile.name))

batch_inference_v2_no_CoT(muice_profile)
batch_inference_v2_no_CoT(ayaka_profile)
batch_inference_v2_no_CoT(zhongli_profile)
batch_inference_v2_no_CoT(hutao_profile)
batch_inference_v2_no_CoT(haruhi_profile)

Running 沐雪: 100%|██████████| 150/150 [01:52<00:00,  1.34it/s]
Running 神里绫华: 100%|██████████| 150/150 [02:15<00:00,  1.10it/s]
Running 钟离: 100%|██████████| 150/150 [02:01<00:00,  1.23it/s]
Running 胡桃: 100%|██████████| 150/150 [02:12<00:00,  1.13it/s]
Running 凉宫春日: 100%|██████████| 150/150 [02:11<00:00,  1.14it/s]


## Baseline D (Strong LLM + FS)

### Define LLM Class

In [14]:
import os
from dataclasses import dataclass, field
from typing import Optional, Dict, Any, Sequence, Tuple, Iterable

from openai import OpenAI, BadRequestError
from openai.types.chat import ChatCompletionMessageParam as ChatMsgParam
from dotenv import load_dotenv

load_dotenv()  # 从 .env 文件加载环境变量

# === 配置项 ===
DEFAULT_API_KEY = os.getenv("OPENAI_API_KEY", "sk-PLACEHOLDER")
DEFAULT_BASE_URL = os.getenv("OPENAI_BASE_URL", "https://dashscope.aliyuncs.com/compatible-mode/v1")
DEFAULT_MODEL = os.getenv("OPENAI_CHAT_MODEL", "glm-4.7")

ConversationTurn = Tuple[str, Optional[str]]
"""表示一次对话轮次：(user_message, assistant_reply)。assistant_reply 可为 None。"""


def _build_messages(
    prompt: str,
    history: Optional[Sequence[ConversationTurn]] = None,
    system_prompt: Optional[str] = None,
) -> list[ChatMsgParam]:
    messages: list[ChatMsgParam] = []

    if system_prompt:
        messages.append({"role": "system", "content": system_prompt})

    if history:
        for user_msg, assistant_msg in history:
            messages.append({"role": "user", "content": user_msg})
            if assistant_msg:
                messages.append({"role": "assistant", "content": assistant_msg})

    messages.append({"role": "user", "content": prompt})
    return messages


@dataclass(slots=True)
class SimpleLLMClient:
    """极简 LLM 封装：初始化固定模型，提供 ask() 返回字符串。"""

    model: str = DEFAULT_MODEL
    api_key: str = DEFAULT_API_KEY
    base_url: Optional[str] = DEFAULT_BASE_URL
    extra_headers: Optional[Dict[str, str]] = None
    _client: OpenAI = field(init=False, repr=False)

    def __post_init__(self) -> None:
        self._client = OpenAI(
            api_key=self.api_key,
            base_url=self.base_url,
            default_headers=self.extra_headers if self.extra_headers else None,
        )

    def _consume_stream(self, stream_resp: Iterable[Any]) -> str:
        """消费流式响应，拼接内容。"""
        chunks: list[str] = []
        for chunk in stream_resp:
            choices = getattr(chunk, "choices", None)
            if not choices:
                continue
            delta = getattr(choices[0], "delta", None)
            if delta and getattr(delta, "content", None):
                chunks.append(delta.content)
        return "".join(chunks)

    def ask(
        self,
        prompt: str,
        history: Optional[Sequence[ConversationTurn]] = None,
        system: Optional[str] = None,
        *,
        temperature: float = 0.7,
        top_p: float = 0.9,
        max_tokens: Optional[int] = None,
        retry: int = 3,
        **kwargs: Any,
    ) -> str:
        """生成回复。

        参数：
            prompt: 当前用户输入。
            history: 可选的历史 [(user, assistant), ...]，assistant 允许为 None。
            temperature/max_tokens/stream/kwargs：直接透传给 OpenAI Chat Completion。
        返回：
            模型回复的纯文本（若响应为空则返回空字符串）。
        """
        messages = _build_messages(
            prompt=prompt,
            history=history,
            system_prompt=system,
        )

        try:
            response = self._client.chat.completions.create(
                model=self.model,
                messages=messages,
                temperature=temperature,
                max_tokens=max_tokens,
                top_p=top_p,
                stream=False,
                extra_body={"enable_thinking": False},
                **kwargs,
            )
        except BadRequestError as exc:
            retry -= 1
            if retry < 0:
                print(f"请求失败，重试次数耗尽：{exc}")
                return ""
            return self.ask(prompt, history, system, temperature=temperature, top_p=top_p, max_tokens=max_tokens, retry=retry, **kwargs)

        choice = response.choices[0]
        if hasattr(choice, "message") and getattr(choice.message, "content", None):
            return choice.message.content  # type: ignore[return-value]
        # 兼容历史版本/异常情况
        return getattr(choice, "text", "")


# 初始化一个通用实例，供 Notebook 其他单元直接调用
llm = SimpleLLMClient()

In [9]:
import os
from tqdm import tqdm

DEFAULT_TEMPERATURE = float(os.getenv("NEUTRAL_TEMPERATURE", "0.2"))
DEFAULT_SLEEP_SECONDS = float(os.getenv("NEUTRAL_SLEEP_SECONDS", "0.3"))

SYSTEM_PROMPT_TEMPLATE = (
    "You are a style transfer expert. Your task is to rewrite the following neutral sentence into the style of {character}, based on the content of a neutral sentence.\n"
    "Constraints: \n"
    "1. Keep the original meaning unchanged. Do NOT reply to it.\n"
    "2. Use keywords: {keywords}\n"
    "3. Adopt personality: {pragmatics}"
)

def generate_style_response(
    neutral_sentences: str,
    character_name: str, 
    lexical_keywords: list[str],
    pragmatic_styles: list[str],
    *,
    history: list[tuple[str, str]] = [],
    temperature: float = DEFAULT_TEMPERATURE,
    top_p: float = 0.95,
) -> str:
    """调用 llm.ask 生成风格句"""
    keywords = ", ".join(lexical_keywords) if lexical_keywords else "None"
    pragmatics = ", ".join(pragmatic_styles) if pragmatic_styles else "None"

    prompt = f"Neutral Content: {neutral_sentences}\nRewritten Sentence:"
    system = SYSTEM_PROMPT_TEMPLATE.format(character=character_name, keywords=keywords, pragmatics=pragmatics)

    styled_sentence = llm.ask(prompt=prompt,
                            history=history,
                            system=system,
                            temperature=temperature,
                            top_p=top_p,
                            ).strip()

    return styled_sentence
    


### Call LLM to Generate Style Sentences

In [12]:
from time import sleep

def batch_inference_strong_llm(profile: CharacterProfile, history_shots: list[tuple[str, str]]):
    history: list[tuple[str, str]] = []
    for item in history_shots:
        neutral_sentences = item[0].split("Neutral Content: ")[1].split("\nStyle Reference Text:")[0]
        prompt = f"Neutral Content: {neutral_sentences}\nRewritten Sentence:"
        response = item[1]
        history.append((prompt, response))

    for item in tqdm(test_data, desc=f"Running {profile.name}"):
        output = generate_style_response(item, profile.name, profile.lexical_keywords, profile.pragmatic_styles, history=history)
        results.append(InferenceResult(model="BaselineD", neutral=item, output=output, character=profile.name))
        sleep(DEFAULT_SLEEP_SECONDS)

batch_inference_strong_llm(muice_profile, Muice_shots)
batch_inference_strong_llm(ayaka_profile, Ayaka_shots)
batch_inference_strong_llm(zhongli_profile, Zhongli_shots)
batch_inference_strong_llm(hutao_profile, Hutao_shots)
batch_inference_strong_llm(haruhi_profile, Haruhi_shots)

Running 沐雪:   0%|          | 0/150 [00:00<?, ?it/s]

Running 沐雪: 100%|██████████| 150/150 [10:01<00:00,  4.01s/it]
Running 神里绫华: 100%|██████████| 150/150 [08:12<00:00,  3.28s/it]
Running 钟离: 100%|██████████| 150/150 [06:42<00:00,  2.68s/it]
Running 胡桃: 100%|██████████| 150/150 [08:17<00:00,  3.32s/it]
Running 凉宫春日: 100%|██████████| 150/150 [09:58<00:00,  3.99s/it]


## Export Final Results

In [18]:
# torch.cuda.empty_cache()
# torch.cuda.ipc_collect()

OUTPUT_PATH = Path("./outputs/batch_run_result.jsonl")
OUTPUT_PATH.parent.mkdir(exist_ok=True)

for item in results:
    with open(OUTPUT_PATH, "a", encoding="utf-8") as f:
        f.write(json.dumps(item, ensure_ascii=False) + "\n")