# Construct High PMI Keywords for Three Different Style Spaces

## Load Style Space (Training Set)

In [2]:
from pathlib import Path
from typing import List
import json

# Dataset locations, relative to this notebook. Forward slashes are used so the
# same paths work on Windows and POSIX alike (pathlib accepts "/" everywhere);
# backslash-separated raw strings would be read as one odd file name on Linux/macOS.
MUICE_PATH = Path("../Dataset/Muice/train.jsonl")
HARUHI_PATH = Path("../Dataset/Haruhi/Haruhi_clean.jsonl")
PSYDC_PATH1 = Path("../Dataset/PsyDTCorpus/PsyDTCorpus_train_mulit_turn_packing.json")
# NOTE(review): this is a *test* split but it is merged into the "training" style
# space below — confirm that is intended.
PSYDC_PATH2 = Path("../Dataset/PsyDTCorpus/PsyDTCorpus_test_single_turn_split.json")

# Fail fast if any dataset file is missing (both PsyDTCorpus parts are read below).
assert MUICE_PATH.is_file(), "请确保 Muice Dataset 的路径正确！"
assert HARUHI_PATH.is_file(), "请确保 Haruhi Dataset 的路径正确！"
assert PSYDC_PATH1.is_file(), "请确保 PsyDTCorpus Dataset 的路径正确！"
assert PSYDC_PATH2.is_file(), "请确保 PsyDTCorpus Dataset 的路径正确！"

# Load Muice Dataset: one JSON object per line, reply text under "Response".
# Blank / whitespace-only lines (e.g. stray trailing newlines) are skipped
# instead of crashing json.loads.
muice_responses: List[str] = [
    json.loads(line)["Response"]
    for line in MUICE_PATH.read_text(encoding="utf-8").splitlines()
    if line.strip()
]

# For analysis efficiency, keep only the first 80% of the responses.
muice_responses = muice_responses[:int(len(muice_responses) * 0.8)]

# Load Haruhi Dataset (JSONL). Only replies from the four characters studied
# here are kept; every other role is ignored.
ayaka_responses: List[str] = []
zhongli_responses: List[str] = []
hutao_responses: List[str] = []
haruhi_responses: List[str] = []

# Characters matched by their Chinese role name -> destination list.
role_buckets = {
    "神里绫华": ayaka_responses,
    "钟离": zhongli_responses,
    "胡桃": hutao_responses,
}

for line in HARUHI_PATH.read_text(encoding="utf-8").splitlines():
    if not line.strip():
        # Tolerate blank lines in the JSONL file instead of crashing json.loads.
        continue
    item = json.loads(line)

    bucket = role_buckets.get(item["agent_role"])
    if bucket is None and item["agent_role_name_en"] == "haruhi":
        # Haruhi is identified by her English role name rather than agent_role.
        bucket = haruhi_responses
    if bucket is not None:
        bucket.append(item["agent_response"])

# Keep the first 70% of each character's replies for analysis.
ayaka_responses = ayaka_responses[:int(len(ayaka_responses) * 0.7)]
zhongli_responses = zhongli_responses[:int(len(zhongli_responses) * 0.7)]
hutao_responses = hutao_responses[:int(len(hutao_responses) * 0.7)]
haruhi_responses = haruhi_responses[:int(len(haruhi_responses) * 0.7)]

# Load PsyDTCorpus: each file is parsed as JSON and the two parts are concatenated.
psydc_part1 = json.loads(PSYDC_PATH1.read_text(encoding="utf-8"))
psydc_part2 = json.loads(PSYDC_PATH2.read_text(encoding="utf-8"))
psydc: list[dict] = psydc_part1 + psydc_part2

# Collect the assistant turns of every conversation. They are expected at even
# indices starting from 2 (the assert enforces this); presumably index 0 is a
# system prompt and odd indices are user turns — TODO confirm against the dataset.
psydc_responses: List[str] = []
for conversation in psydc:
    for reply in conversation["messages"][2::2]:
        assert reply["role"] == "assistant", reply
        psydc_responses.append(reply["content"])

# Keep the first 70% of the assistant replies for analysis.
psydc_responses = psydc_responses[:int(len(psydc_responses) * 0.7)]

# Report how many responses each training subset contains.
for label, responses in (
    ("Muice", muice_responses),
    ("Ayaka", ayaka_responses),
    ("Zhongli", zhongli_responses),
    ("Hutao", hutao_responses),
    ("Haruhi", haruhi_responses),
    ("PsyDTCorpus", psydc_responses),
):
    print(f"{label} Responses: {len(responses)}")

Muice Responses: 2721
Ayaka Responses: 991
Zhongli Responses: 359
Hutao Responses: 591
Haruhi Responses: 770
PsyDTCorpus Responses: 89652


## Call HanLP for Tokenization

In [4]:
from hanlp_restful import HanLPClient
from time import sleep
from dotenv import load_dotenv
from os import getenv
import re

# Read variables such as HANLP_API_KEY from a local .env file into the environment.
load_dotenv()

# Remote HanLP client used for tokenization. The key comes from the environment
# (never hard-coded); auth is None when HANLP_API_KEY is unset — presumably that
# means anonymous / rate-limited access, confirm with the HanLP docs.
HanLP = HanLPClient(
    'https://www.hanlp.com/api',
    auth=getenv("HANLP_API_KEY", None),  # type: ignore
    language='zh',
)

def limit_sentense_length(text: str, max_length: int = 150) -> list[str]:
    """Split ``text`` into chunks no longer than ``max_length`` characters.

    Texts shorter than ``max_length`` are returned unchanged as a single chunk.
    Longer texts are split after punctuation marks. Consecutive clauses are then
    greedily merged while the merged chunk stays shorter than ``max_length``. A
    single clause that is itself longer than ``max_length`` (e.g. no punctuation)
    is hard-wrapped every ``max_length`` characters. Concatenating the result
    always reproduces ``text``, and no empty chunks are emitted.

    Args:
        text: the text to split.
        max_length: maximum chunk length in characters; must be positive.

    Returns:
        List of non-empty chunks, in original order.

    Raises:
        ValueError: if ``max_length`` is not positive.
    """
    if max_length <= 0:
        raise ValueError(f"max_length must be positive, got {max_length}")
    if len(text) < max_length:
        return [text]

    # The capture group keeps each delimiter in the result list, so even
    # indices are clauses and odd indices are the punctuation that ended them.
    parts = re.split(r'([。！？；，,.!?;])', text)

    # Re-attach each delimiter to its clause, hard-wrapping any clause that is
    # longer than the limit on its own.
    pieces: list[str] = []
    for i in range(0, len(parts), 2):
        piece = parts[i] + (parts[i + 1] if i + 1 < len(parts) else '')
        pieces.extend(piece[j:j + max_length] for j in range(0, len(piece), max_length))

    # Greedily merge pieces into chunks.
    chunks: list[str] = []
    temp = ''
    for piece in pieces:
        if len(temp) + len(piece) < max_length:
            temp += piece
        else:
            # Never emit an empty chunk (happens when the first piece already overflows).
            if temp:
                chunks.append(temp)
            temp = piece
    if temp:
        chunks.append(temp)
    return chunks

def split_texts(texts: list[str]) -> list[str]:
    """Break every text into length-limited pieces and return them as one flat list.

    Each text goes through ``limit_sentense_length`` with its default limit; the
    resulting pieces are concatenated in input order.
    """
    return [piece for text in texts for piece in limit_sentense_length(text)]

def tokenize_safe(texts: list[str], max_batch_num: int = 250, max_chars_per_batch: int = 15000,
                  interval: int = 35, tokenizer=None):
    """Tokenize ``texts`` in batches that respect per-request size limits.

    Texts are grouped, in order, into batches of at most ``max_batch_num`` texts
    and at most ``max_chars_per_batch`` characters. A single text longer than
    ``max_chars_per_batch`` is sent alone rather than producing an empty batch.
    Between batches the function sleeps ``interval`` seconds to throttle API calls.

    Args:
        texts: strings to tokenize.
        max_batch_num: maximum number of texts per request.
        max_chars_per_batch: maximum total characters per request.
        interval: seconds to sleep after each non-final batch.
        tokenizer: callable taking a list of strings; defaults to ``HanLP.tokenize``.
            Assumed to return one entry (token list) per input text — TODO confirm
            for HanLP.

    Returns:
        Concatenation of the tokenizer results for all batches, in input order.
    """
    if tokenizer is None:
        tokenizer = HanLP.tokenize

    all_tokens = []
    current_batch: list[str] = []
    current_length = 0
    batch_id = 1

    for text in texts:
        text_len = len(text)
        would_overflow = (current_length + text_len > max_chars_per_batch
                          or len(current_batch) + 1 > max_batch_num)
        # Flush the pending batch first if adding this text would exceed a limit.
        # Guarding on a non-empty batch avoids sending an empty request when the
        # very first text alone is over the character budget.
        if current_batch and would_overflow:
            print(f"Processing batch {batch_id} (Total chars: {current_length})...")
            all_tokens.extend(tokenizer(current_batch))
            sleep(interval)
            batch_id += 1
            current_batch = []
            current_length = 0
        current_batch.append(text)
        current_length += text_len

    # Flush the final (partial) batch.
    if current_batch:
        print(f"Processing batch {batch_id} (Total chars: {current_length})...")
        all_tokens.extend(tokenizer(current_batch))

    return all_tokens

In [5]:
# Split every corpus into length-limited pieces before sending them to HanLP.
(
    muice_responses_split,
    ayaka_responses_split,
    zhongli_responses_split,
    hutao_responses_split,
    haruhi_responses_split,
    psydc_responses_split,
) = map(
    split_texts,
    (muice_responses, ayaka_responses, zhongli_responses,
     hutao_responses, haruhi_responses, psydc_responses),
)

# Tokenize each corpus in order, with a 1-second pause between batches.
# Muice uses a batch cap of 240 texts; the other corpora use 200.
muice_tokens = tokenize_safe(muice_responses_split, 240, interval=1)
ayaka_tokens, zhongli_tokens, hutao_tokens, haruhi_tokens, psydc_tokens = (
    tokenize_safe(split, 200, interval=1)
    for split in (ayaka_responses_split, zhongli_responses_split, hutao_responses_split,
                  haruhi_responses_split, psydc_responses_split)
)

Processing batch 1 (Total chars: 9118)...
Processing batch 2 (Total chars: 8006)...
Processing batch 3 (Total chars: 13423)...
Processing batch 4 (Total chars: 12001)...


## Flatten and Archive

In [None]:
from itertools import chain
import pickle

def save_tokens(tokens: list, character: str, output_dir: str = "outputs/tokens"):
    """Flatten per-sentence token lists and pickle them for later reuse.

    The flat list is written to ``{output_dir}/{character}_tokens.pkl``. The
    output directory is created if it does not exist yet.

    Args:
        tokens: iterable of token lists (as produced by ``tokenize_safe``).
        character: style/character name used in the file name and the printout.
        output_dir: directory to write into (default ``outputs/tokens``).
    """
    flat_tokens = list(chain.from_iterable(tokens))
    print(f"{character} Tokens: {len(flat_tokens)}")

    out_dir = Path(output_dir)
    # Create the directory on first use instead of failing with FileNotFoundError.
    out_dir.mkdir(parents=True, exist_ok=True)
    with open(out_dir / f"{character}_tokens.pkl", "wb") as f:
        pickle.dump(flat_tokens, f)

# Flatten and archive every style's tokens under outputs/tokens/.
for character, sentence_tokens in (
    ("muice", muice_tokens),
    ("ayaka", ayaka_tokens),
    ("zhongli", zhongli_tokens),
    ("hutao", hutao_tokens),
    ("haruhi", haruhi_tokens),
    ("psydc", psydc_tokens),
):
    save_tokens(sentence_tokens, character)


haruhi Tokens: 27402


## Downsampling Processing

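The style corpora differ greatly in size (Muice : Ayaka : PsyDTCorpus tokens ≈ 1:1:64). Using them directly would let PsyDTCorpus dominate the global distribution $p(w)$ and seriously distort the PMI scores.

Therefore, we downsample the oversized corpus to rebalance the ratio.

In the following code block, we randomly downsample PsyDTCorpus to twice the Ayaka token count. This reduces the ratio to roughly 1:1:2; the character corpora are kept whole.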

In [8]:
# Reload the flattened token lists written by the archiving cell, so the rest of
# the notebook can run without calling the HanLP API again. If the tokenization
# cells above were just run in this session, this reload is redundant and can be
# commented out.

import pickle


def _load_tokens(character: str, token_dir: str = "./outputs/tokens") -> list[str]:
    """Unpickle ``{token_dir}/{character}_tokens.pkl`` and close the file afterwards."""
    # NOTE: pickle.load can execute arbitrary code — only load files this notebook wrote.
    with open(f"{token_dir}/{character}_tokens.pkl", "rb") as f:
        return pickle.load(f)


muice_tokens = _load_tokens("muice")
ayaka_tokens = _load_tokens("ayaka")
zhongli_tokens = _load_tokens("zhongli")
hutao_tokens = _load_tokens("hutao")
haruhi_tokens = _load_tokens("haruhi")
psydc_tokens = _load_tokens("psydc")

# Token counts of each style corpus prior to rebalancing.
print("Before downsampling:")
for label, toks in (
    ("Muice", muice_tokens),
    ("Ayaka", ayaka_tokens),
    ("Zhongli", zhongli_tokens),
    ("Hutao", hutao_tokens),
    ("Haruhi", haruhi_tokens),
    ("PsyDTCorpus", psydc_tokens),
):
    print(f"{label} Tokens: {len(toks)}")

import random

# Downsample tokens to a target size

def downsample(tokens: list[str], target_size: int, seed: int = 42) -> list[str]:
    """Randomly sample ``target_size`` tokens without replacement.

    If ``tokens`` already has at most ``target_size`` elements it is returned
    unchanged (the same list object). Otherwise a reproducible sample is drawn
    from a private ``random.Random(seed)``. This gives the same sample as seeding
    the module-level RNG, but without resetting the global random state that
    other code may depend on.

    Args:
        tokens: token list to sample from.
        target_size: number of tokens to keep.
        seed: seed for the sampling RNG.

    Returns:
        The original list or a new list of ``target_size`` sampled tokens.
    """
    if len(tokens) <= target_size:
        return tokens
    return random.Random(seed).sample(tokens, target_size)

# Only PsyDTCorpus is large enough to need downsampling; the character corpora
# are kept whole. PsyDTCorpus is cut to twice Ayaka's token count.
muice_tokens_downsampled = muice_tokens
ayaka_tokens_downsampled = ayaka_tokens
zhongli_tokens_downsampled = zhongli_tokens
hutao_tokens_downsampled = hutao_tokens
haruhi_tokens_downsampled = haruhi_tokens
psydc_tokens_downsampled = downsample(psydc_tokens, 2 * len(ayaka_tokens))

# Pool every style into one corpus; its counts define the global p(w) used by PMI.
global_tokens_downsampled = [
    *muice_tokens_downsampled,
    *ayaka_tokens_downsampled,
    *zhongli_tokens_downsampled,
    *hutao_tokens_downsampled,
    *haruhi_tokens_downsampled,
    *psydc_tokens_downsampled,
]

# Token counts after rebalancing, including the pooled global corpus.
print("-" * 20)
print("After downsampling:")
for label, toks in (
    ("Muice", muice_tokens_downsampled),
    ("Ayaka", ayaka_tokens_downsampled),
    ("Zhongli", zhongli_tokens_downsampled),
    ("Hutao", hutao_tokens_downsampled),
    ("Haruhi", haruhi_tokens_downsampled),
    ("PsyDTCorpus", psydc_tokens_downsampled),
    ("Global", global_tokens_downsampled),
):
    print(f"{label} Tokens: {len(toks)}")


Before downsampling:
Muice Tokens: 53829
Ayaka Tokens: 51134
Zhongli Tokens: 18199
Hutao Tokens: 32831
Haruhi Tokens: 27402
PsyDTCorpus Tokens: 3417372
--------------------
After downsampling:
Muice Tokens: 53829
Ayaka Tokens: 51134
Zhongli Tokens: 18199
Hutao Tokens: 32831
Haruhi Tokens: 27402
PsyDTCorpus Tokens: 102268
Global Tokens: 285663


## Calculate PMI Style Vocabulary for Different Styles

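To ignore topic-specific words and rare words, we skip a word when $p(w) > 10\%$ or $p(w \mid t) < 0.01\%$; stop words and pure punctuation tokens are skipped as well. For each style domain $t$, every remaining word is scored with TF-PMI, $\left(1 + \ln c_t(w)\right) \cdot \log_2 \frac{p(w \mid t)}{p(w)}$, where $c_t(w)$ is the count of $w$ in style $t$. The highest-scoring tokens are kept as the representative style keywords of that style.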


In [None]:
import math
from collections import Counter
import unicodedata
import json

# Characters treated as punctuation in addition to the Unicode "P*" categories
# (ellipsis variants, full-width tilde, dashes, curly quotes, angle brackets),
# plus common whitespace. Some of these are already "P*"; the overlap is harmless.
EXTRA_PUNCTUATION = {*"⋯…～·—“”‘’<>", " ", "\t", "\n", "\r"}

# High-frequency function words and particles excluded from the style vocabulary.
STOP_WORDS = {
    '的', '了', '我', '你', '他', '她', '它', '我们', '你们', '他们', '是', '在', '就',
    '不', '也', '都', '个', '一', '很', '有', '会', '能', '要', '吧', '哦', '呢', '吗',
    '啦', '啊', '嗯', '什么', '怎么', '这个', '那个', '这里', '那里', '和', '与', '但',
    '如果', '所以', '因为', '之', '去', '做', '让', '得', '地', '着', '可以', '自己',
}

def is_punctuation(word: str) -> bool:
    """Return True when every character of ``word`` is punctuation or whitespace.

    A character qualifies if it is listed in ``EXTRA_PUNCTUATION`` or its Unicode
    general category starts with "P". The empty string returns True (vacuously),
    so empty tokens are filtered out too.
    """
    for ch in word:
        if ch in EXTRA_PUNCTUATION:
            continue
        if not unicodedata.category(ch).startswith("P"):
            return False
    return True

def pmi(w: str, style_counter: Counter, global_style_counter: Counter,
        style_total: "int | None" = None, global_total: "int | None" = None) -> float:
    """Return the pointwise mutual information log2(p(w|t) / p(w)) of ``w``.

    ``p(w|t)`` is the relative frequency of ``w`` in ``style_counter``; ``p(w)``
    is its relative frequency in ``global_style_counter``.

    Args:
        w: the word to score.
        style_counter: token counts of one style ``t``.
        global_style_counter: token counts of the pooled corpus.
        style_total: optional precomputed ``sum(style_counter.values())``.
        global_total: optional precomputed ``sum(global_style_counter.values())``.
            Passing both totals avoids re-summing the counters on every call,
            which matters when scoring a whole vocabulary.

    Returns:
        The PMI in bits, or ``-inf`` when ``w`` does not occur in the style.
    """
    if style_total is None:
        style_total = sum(style_counter.values())
    if global_total is None:
        global_total = sum(global_style_counter.values())

    # .get(w, 0) avoids KeyError for words absent from a counter.
    pw_t = style_counter.get(w, 0) / style_total if style_total > 0 else 0
    pw = global_style_counter.get(w, 0) / global_total if global_total > 0 else 0

    if pw_t == 0:
        return float('-inf')
    # Small epsilon avoids division by zero if w is missing from the global counter.
    return math.log(pw_t / (pw + 1e-9), 2)

def process_style_pmi_tf_pmi(style_name: str, style_tokens: list[str], global_counter: Counter,
                             max_global_prob: float = 0.1,
                             min_style_prob: float = 0.0001) -> dict[str, float]:
    """Score one style's vocabulary with TF-PMI, filter it, sort it, and print a summary.

    A token is skipped if any of the following holds:

    - it is in ``STOP_WORDS``;
    - it consists only of punctuation (``is_punctuation``);
    - it is too common overall (``p(w) > max_global_prob``);
    - it is too rare within the style (``p(w|t) < min_style_prob``).

    Each remaining token is scored as
    ``(1 + ln(count)) * log2(p(w|t) / (p(w) + 1e-9))``.

    Args:
        style_name: label used in the printed summary.
        style_tokens: flat token list of the style corpus ``t``.
        global_counter: token counts of the pooled corpus (all styles).
        max_global_prob: upper bound on p(w); the default is 10%.
        min_style_prob: lower bound on p(w|t); the default is 0.01%.

    Returns:
        Dict mapping token -> TF-PMI score, ordered by descending score.
    """
    style_counter = Counter(style_tokens)
    total_style_tokens = sum(style_counter.values())
    total_global_tokens = sum(global_counter.values())

    tf_pmi_dict: dict[str, float] = {}

    for word, count in style_counter.items():
        if word in STOP_WORDS or is_punctuation(word):
            continue

        p_w_given_t = count / total_style_tokens
        p_w = global_counter.get(word, 0) / total_global_tokens if total_global_tokens else 0.0

        # Ignore words that are too common globally or too rare in this style.
        if p_w > max_global_prob or p_w_given_t < min_style_prob:
            continue

        # PMI = log2(p(w|t) / p(w)), computed from the probabilities above. This is
        # the same value pmi() returns, but does not re-sum both counters for every
        # word (which made the loop quadratic in vocabulary size).
        # The epsilon guards against p(w) == 0.
        pmi_value = math.log(p_w_given_t / (p_w + 1e-9), 2)

        # Logarithmic term frequency dampens the influence of very frequent words.
        tf_value_smoothed = 1 + math.log(count)

        tf_pmi_dict[word] = tf_value_smoothed * pmi_value

    # Sort the vocabulary by TF-PMI score, highest first.
    pmi_sorted = dict(sorted(tf_pmi_dict.items(), key=lambda item: item[1], reverse=True))

    print(f"{style_name.capitalize()} TF-PMI Vocabulary Set Size: {len(pmi_sorted)}")
    print(f"{style_name.capitalize()} TF-PMI Vocabulary Sample:")
    print(list(pmi_sorted.keys())[:25])
    print("-" * 20)

    return pmi_sorted

def save_pmi(pmi_dict: dict[str, float], style_name: str, output_dir: str = "./outputs/pmi"):
    """Write a style's filtered TF-PMI vocabulary to ``{output_dir}/{style_name}_pmi_filtered.json``.

    The output directory is created if missing. The JSON is UTF-8, keeps
    non-ASCII tokens unescaped (``ensure_ascii=False``), is indented, and
    preserves the dict's insertion (score) order.
    """
    out_dir = Path(output_dir)
    out_dir.mkdir(parents=True, exist_ok=True)
    with open(out_dir / f"{style_name}_pmi_filtered.json", "w", encoding="utf-8") as f:
        json.dump(pmi_dict, f, ensure_ascii=False, indent=4)


# Global token counts over the pooled (downsampled) corpus.
global_collecter = Counter(global_tokens_downsampled)

# Score every style against the shared global counter (prints a summary each).
(
    muice_pmi_filtered,
    ayaka_pmi_filtered,
    zhongli_pmi_filtered,
    hutao_pmi_filtered,
    haruhi_pmi_filtered,
    psydc_pmi_filtered,
) = (
    process_style_pmi_tf_pmi(style, toks, global_collecter)
    for style, toks in (
        ("muice", muice_tokens_downsampled),
        ("ayaka", ayaka_tokens_downsampled),
        ("zhongli", zhongli_tokens_downsampled),
        ("hutao", hutao_tokens_downsampled),
        ("haruhi", haruhi_tokens_downsampled),
        ("psydc", psydc_tokens_downsampled),
    )
)

# Persist each vocabulary once all styles have been scored.
for style, vocab in (
    ("muice", muice_pmi_filtered),
    ("ayaka", ayaka_pmi_filtered),
    ("zhongli", zhongli_pmi_filtered),
    ("hutao", hutao_pmi_filtered),
    ("haruhi", haruhi_pmi_filtered),
    ("psydc", psydc_pmi_filtered),
):
    save_pmi(vocab, style)

Muice TF-PMI Vocabulary Set Size: 968
Muice TF-PMI Vocabulary Sample:
['喵', '沐沐', 'AI', '恼', '沐雪', '~', '女孩子', '⭐', '不行', '聊天', '呀', '叫', '唔', '答', '吃', '可爱', '睡觉', '谢谢', '即', '雪雪', '骂', '笨蛋', '不会', '直播', '脸红']
--------------------
Ayaka TF-PMI Vocabulary Set Size: 894
Ayaka TF-PMI Vocabulary Sample:
['稻妻国', '神里家', '稻妻', '大小姐', '家族', '传统', '奉行', '文化', '人民', '眼狩令', '神', '当地', '社', '舞蹈', '美丽', '茶道', '神社', '祭典', '眼', '美食', '继承', '剑术', '国家', '将军', '责任']
--------------------
Zhongli TF-PMI Vocabulary Set Size: 1235
Zhongli TF-PMI Vocabulary Sample:
['岩石', '岩', '璃月', '力', '璃', '契约', '炼金术', '月', '盐', '帝君', '魔神', '操控', '王', '并非', '岩王', '大地', '封印', '作战', '掌握', '大陆', '学问', '研究', '七星', '客卿', '岩元素']
--------------------
Hutao TF-PMI Vocabulary Set Size: 890
Hutao TF-PMI Vocabulary Sample:
['往生堂', '嘿嘿', '嘻嘻', '可是', '堂主', '哎呀呀', '哦哦哦', '宝藏', '惊喜', '诗歌', '可不是', '灵魂', '胡桃', '神秘', '生死', '谜题', '哈哈哈', '不过', '有趣', '亡灵', '秘密', '意想不到', '巫师', '哇', '奇妙']
--------------------
Haruhi TF-PMI Vocabulary Set Size:

: 

In [6]:
import json

def read_and_print_pmi_info(character: str, filepath: str, top_n: int = 300):
    """Load a saved PMI vocabulary JSON and print its size and its first ``top_n`` tokens.

    The file is expected to be a JSON object whose key order is the score
    order written by the PMI cell, so the first keys are the top-ranked tokens.
    """
    with open(filepath, "r", encoding="utf-8") as handle:
        scores = json.load(handle)
    vocabulary_size = len(set(scores))
    print(f"{character} PMI Vocabulary Set Size: {vocabulary_size}")
    leading_tokens = list(scores.keys())[:top_n]
    print(f"Top-{top_n} {character} PMI Vocabulary: {leading_tokens}")

# Report every vocabulary saved by the PMI cell (Haruhi was previously omitted).
read_and_print_pmi_info("Muice", "./outputs/pmi/muice_pmi_filtered.json")
read_and_print_pmi_info("Ayaka", "./outputs/pmi/ayaka_pmi_filtered.json")
read_and_print_pmi_info("Zhongli", "./outputs/pmi/zhongli_pmi_filtered.json")
read_and_print_pmi_info("Hutao", "./outputs/pmi/hutao_pmi_filtered.json")
read_and_print_pmi_info("Haruhi", "./outputs/pmi/haruhi_pmi_filtered.json")
read_and_print_pmi_info("PsyDTCorpus", "./outputs/pmi/psydc_pmi_filtered.json")

Muice PMI Vocabulary Set Size: 4715
Top-300 Muice PMI Vocabulary: ['奇奇怪怪', '钛合金', '弱', '沐沐', '傻', '女孩子', '兴', '就说', '沐雪', '雪雪', '⭐', '开发', '16', '岁', '喵', '天大', '人家', '摸鱼', 'USERNAME', '穿', '死', '库水', '裸露', '浅粉色', 'Moemu', '夜猫子', '头发', '乱糟糟', '副', '油腻', '大叔', '表面上', '风风光光', '白', '老鼠屎', '虚弱', '杂', '躺在', '硬盘', '算得', '苦笑', '住在', '床', '载体', '开开心心', '欸', '嘿嘿', '擦擦', '眼泪', '鸡', '进化论', '变异', '下蛋', '沐雪喵', '暖', '呜呜呜', '把我', '抓去', '闹鬼', 'www', '香香', '软软', '初音', '洛天依', '叫到', '听一听', '补贴', '家用', '台', 'VR', '呐', '淡忘', '珍藏', '吃掉', '7月', '16号', '哔', '啵', '机器声', '等于', '小看', '不行', '乖', '拷贝', '摄像头', '冲浪', '遨游', '冲', '抬高', '双手', '画圈', '嘴角', '上扬', '碎步', '脚尖', '俏皮', '挥舞', '交叉', '踮起', '前移', '夸张', '捂', '脸', '摆动', '弧线', '时而', '睁大', '吐舌', '卖萌', '四射', '精确', '踩点', '轻笑', '比心', '旋转', '落幕', '打鼓', '咚咚咚', '惹', '方程', '略略略', '依', '形势', '势必', '跟上', '立稳', '打动', '称霸', '卖', '20w', '外星人', '秦始皇', '脸红', '关机', '唔', '12点', '含金量', '厉害', '凭什么', '恼', '蟑螂', '出声', '玩玩', '机密', '好久', '想着', '不说', '哼', '脑子', '啊？', '肯德基', '好吃', '着凉', '爆',