# Construct PCFG Models for Three Different Style Spaces

## Load Style Space (Training Set)

In [5]:
from pathlib import Path
from typing import List
import json

MUICE_PATH = Path("../Dataset/Muice/train.jsonl")
HARUHI_PATH = Path(r"..\Dataset\Haruhi\Haruhi_clean.jsonl")
PSYDC_PATH1 = Path(r"..\Dataset\PsyDTCorpus\PsyDTCorpus_train_mulit_turn_packing.json")
PSYDC_PATH2 = Path(r"..\Dataset\PsyDTCorpus\PsyDTCorpus_test_single_turn_split.json")

# Check if Dataset exist
assert MUICE_PATH.is_file(), "请确保 Muice Dataset 的路径正确！"
assert HARUHI_PATH.is_file(), "请确保 Haruhi Dataset 的路径正确！"
assert PSYDC_PATH1.is_file(), "请确保 PsyDTCorpus Dataset 的路径正确！"

# Load Muice Dataset
dataset_lines = MUICE_PATH.read_text(encoding="utf-8").splitlines()

muice_responses: List[str] = []

for line in dataset_lines:
    item = json.loads(line)
    muice_responses.append(item["Response"])

# 出于计算效率，只取 80% 进行计算
muice_responses = muice_responses[:int(len(muice_responses)*0.8)]

# Load Haruhi Dataset
dataset_lines = HARUHI_PATH.read_text(encoding="utf-8").splitlines()

ayaka_responses: List[str] = []
zhongli_responses: List[str] = []
hutao_responses: List[str] = []
haruhi_responses: List[str] = []

for line in dataset_lines:
    item = json.loads(line)

    # 只提取单一角色的训练集
    if item["agent_role"] == "神里绫华":
        ayaka_responses.append(item["agent_response"])
    elif item["agent_role"] == "钟离":
        zhongli_responses.append(item["agent_response"])
    elif item["agent_role"] == "胡桃":
        hutao_responses.append(item["agent_response"])
    elif item["agent_role_name_en"] == "haruhi":
        haruhi_responses.append(item["agent_response"])

# 出于计算效率，只取 70% 进行计算
ayaka_responses = ayaka_responses[:int(len(ayaka_responses)*0.7)]
zhongli_responses = zhongli_responses[:int(len(zhongli_responses)*0.7)]
hutao_responses = hutao_responses[:int(len(hutao_responses)*0.7)]
haruhi_responses = haruhi_responses[:int(len(haruhi_responses)*0.7)]

# Load PsyDTCorpus Dataset
psydc_part1 = json.loads(PSYDC_PATH1.read_text(encoding="utf-8"))
psydc_part2 = json.loads(PSYDC_PATH2.read_text(encoding="utf-8"))
psydc: list[dict] = psydc_part1 + psydc_part2
psydc_responses: List[str] = []

for item in psydc:
    messages: list[dict[str, str]] = item["messages"]
    for index in range(2, len(messages), 2):
        message = messages[index]
        assert message["role"] == "assistant", message
        psydc_responses.append(message["content"])

# 取 70 %作为训练集
psydc_responses = psydc_responses[:int(len(psydc_responses)*0.7)]

# 输出所有训练集的长度

print(f"Muice Responses: {len(muice_responses)}")
print(f"Ayaka Responses: {len(ayaka_responses)}")
print(f"Zhongli Responses: {len(zhongli_responses)}")
print(f"Hutao Responses: {len(hutao_responses)}")
print(f"Haruhi Responses: {len(haruhi_responses)}")
print(f"PsyDTCorpus Responses: {len(psydc_responses)}")

Muice Responses: 2721
Ayaka Responses: 991
Zhongli Responses: 359
Hutao Responses: 591
Haruhi Responses: 770
PsyDTCorpus Responses: 89652


## Call HanLP for Constituency Parsing

In [2]:
from hanlp_restful import HanLPClient
from hanlp_common.document import Document
from time import sleep
from dotenv import load_dotenv
from os import getenv
import re

load_dotenv()

HanLP = HanLPClient('https://www.hanlp.com/api', auth=getenv("HANLP_API_KEY", None), language='zh')  # type: ignore

def limit_sentense_length(text: str, max_length: int = 150) -> list[str]:
    if len(text) < 150:
        return [text]
    
    # print("Warning: 句子超出了长度上限: ", text)

    # 按照标点符号进行分割
    sentences = re.split(r'([。！？；，,.!?;])', text)
    chunks:list[str] = []
    temp = ''
    for i in range(0, len(sentences), 2):
        chunk = sentences[i] + (sentences[i+1] if i+1 < len(sentences) else '')
        if len(temp) + len(chunk) < max_length:
            temp += chunk
        else:
            chunks.append(temp)
            temp = chunk
    if temp:
        chunks.append(temp)
    return chunks

def split_texts(texts: list[str]) -> list[str]:
    """将文本列表中的每个文本分割成更小的句子"""
    all_sentences = []
    for text in texts:
        sentences = limit_sentense_length(text)
        all_sentences.extend(sentences)
    return all_sentences

def constituency_parsing_safe(texts: list[str], max_batch_num: int = 250, max_chars_per_batch: int = 15000, interval: int = 35) -> List[Document]:
    """对文本进行分词，同时限制每一批总字符数"""
    all_docs = []
    current_batch = []
    current_length = 0
    batch_id = 1

    for text in texts:
        text_len = len(text)

        # 如果加上这个句子会超出限制，则先处理已有批次
        if current_length + text_len > max_chars_per_batch or len(current_batch) + 1 > max_batch_num:
            print(f"Processing batch {batch_id} (Total chars: {current_length})...", end='')
            doc = HanLP.parse(current_batch, tasks=['pos', 'con'])
            all_docs.append(doc)
            print("done.")
            sleep(interval)

            batch_id += 1
            # 重置 batch
            current_batch = [text]
            current_length = text_len
        else:
            current_batch.append(text)
            current_length += text_len

    # 最后一批也别忘记
    if current_batch:
        print(f"Processing batch {batch_id} (Total chars: {current_length})...", end='')
        doc = HanLP.parse(current_batch, tasks=['pos', 'con'])
        all_docs.append(doc)
        print("done.")

    return all_docs

In [6]:
muice_responses_split = split_texts(muice_responses)
ayaka_responses_split = split_texts(ayaka_responses)
zhongli_responses_split = split_texts(zhongli_responses)
hutao_responses_split = split_texts(hutao_responses)
haruhi_responses_split = split_texts(haruhi_responses)
psydc_responses_split = split_texts(psydc_responses)

muice_cons = constituency_parsing_safe(muice_responses_split, 200, interval=1)
ayaka_cons = constituency_parsing_safe(ayaka_responses_split, 200, interval=1)
zhongli_cons = constituency_parsing_safe(zhongli_responses_split, 200, interval=1)
hutao_cons = constituency_parsing_safe(hutao_responses_split, 200, interval=1)
haruhi_cons = constituency_parsing_safe(haruhi_responses_split, 200, interval=1)
psydc_cons = constituency_parsing_safe(psydc_responses_split, 200, interval=1)

Processing batch 1 (Total chars: 9118)...done.
Processing batch 2 (Total chars: 8006)...done.
Processing batch 3 (Total chars: 13423)...done.
Processing batch 4 (Total chars: 12001)...done.


## Archive Raw Results to JSON File

In [None]:
import json

def save_cons_to_json(cons_docs: List[Document], character: str):
    cons_json = [doc.to_dict() for doc in cons_docs]
    file_path = f'./outputs/cons/{character}.json'
    with open(file_path, 'w', encoding='utf-8') as f:
        json.dump(cons_json, f, ensure_ascii=False)

save_cons_to_json(muice_cons, "muice")
save_cons_to_json(ayaka_cons, "ayaka")
save_cons_to_json(zhongli_cons, "zhongli")
save_cons_to_json(hutao_cons, "hutao")
save_cons_to_json(haruhi_cons, "haruhi")
save_cons_to_json(psydc_cons, "psydc")


## Calculate Log-Likelihood Ratio of Two Styles Based on Psydc

In [None]:
import json
import math
from typing import List, Dict, Tuple, Any, Optional, Literal
from collections import defaultdict, Counter


class PCFGExtractor:
    def __init__(self):
        self.rules_counter: Dict[str, Counter[Tuple[str, ...]]] = defaultdict(Counter)
        self.name: str = ""
        self.total_rules: int = 0

    def load_trees(self, file_path: str) -> List[Dict[str, Any]]:
        self.name = file_path
        with open(file_path, 'r', encoding='utf-8') as f:
            return json.load(f)

    def extract_rules_from_tree(self, tree: Any):
        if not isinstance(tree, list) or len(tree) != 2:
            return
        lhs_symbol, rhs = tree
        if isinstance(rhs, list) and all(isinstance(child, list) and len(child) == 2 for child in rhs):
            rhs_symbols = tuple(child[0] for child in rhs)
            self.rules_counter[lhs_symbol][rhs_symbols] += 1
            self.total_rules += 1
            for child in rhs:
                self.extract_rules_from_tree(child)

    def extract_from_data(self, data: List[Dict[str, Any]]):
        for item in data:
            for tree in item.get("con", []):
                self.extract_rules_from_tree(tree)

    def build_pcfg(self) -> Dict[str, Dict[Tuple[str, ...], float]]:
        pcfg_distribution = {}
        for lhs_symbol, rhs_counter in self.rules_counter.items():
            total_count = sum(rhs_counter.values())
            pcfg_distribution[lhs_symbol] = {
                rhs: count / total_count for rhs, count in rhs_counter.items()
            }
        return pcfg_distribution

    def print_pcfg(
        self,
        pcfg: Dict[str, Dict[Tuple[str, ...], float]],
        sort_by: Literal["freq", "prob", "llr"] = 'freq',
        top_k: Optional[int] = None,
        baseline: Optional["PCFGExtractor"] = None,
        eps: float = 1e-5
    ):
        print(f"==={self.name} PCFG 产生式规则（按{'频率' if sort_by == 'freq' else ('对数似然比' if sort_by == 'llr' else '概率')}排序） ===")

        all_rules = []
        for lhs_symbol in self.rules_counter:
            for rhs_symbols in self.rules_counter[lhs_symbol]:
                freq = self.rules_counter[lhs_symbol][rhs_symbols]
                prob = pcfg[lhs_symbol][rhs_symbols]

                # PR / LLR
                pr = llr = None
                if baseline:
                    base_freq = baseline.rules_counter.get(lhs_symbol, {}).get(rhs_symbols, 0)
                    base_total = baseline.total_rules + eps
                    base_prob = base_freq / base_total

                    pr = (prob + eps) / (base_prob + eps)

                    k1, n1 = freq + eps, self.total_rules + eps
                    k2, n2 = base_freq + eps, base_total
                    mu = (k1 + k2) / (n1 + n2)
                    llr = 2 * (k1 * math.log(k1 / (n1 * mu)) + k2 * math.log(k2 / (n2 * mu)))

                all_rules.append((lhs_symbol, rhs_symbols, freq, prob, pr, llr))

        # 排序
        if sort_by == 'llr':
            all_rules.sort(key=lambda x: x[5] or 0, reverse=True)
        else:
            all_rules.sort(key=lambda x: x[2] if sort_by == 'freq' else x[3], reverse=True)

        # 打印
        for i, (lhs, rhs, freq, prob, pr, llr) in enumerate(all_rules):
            if top_k is not None and i >= top_k:
                break
            rhs_str = ' '.join(rhs)
            line = f"{lhs} → {rhs_str:<40} | freq={freq:<5} | P={prob:.4f}"
            if baseline:
                line += f" | PR={pr:.2f} | LLR={llr:.2f}"
            print(line)


def build_and_print_pcfg(file_path: str, baseline: Optional[PCFGExtractor] = None):
    extractor = PCFGExtractor()
    trees_data = extractor.load_trees(file_path)
    extractor.extract_from_data(trees_data)
    pcfg = extractor.build_pcfg()

    if baseline:
        extractor.print_pcfg(pcfg, sort_by='llr', top_k=15, baseline=baseline)
    else:
        extractor.print_pcfg(pcfg, sort_by='freq', top_k=15, baseline=baseline)
    print()

# 1. 先加载基准语料（psydc）
baseline_extractor = PCFGExtractor()
baseline_data = baseline_extractor.load_trees("./outputs/cons/psydc.json")
baseline_extractor.extract_from_data(baseline_data)
baseline_extractor.build_pcfg()  # 可选，但为了接口统一性

# 2. 比较 muice / haruhi 与 psydc 的差异
build_and_print_pcfg("./outputs/cons/muice.json", baseline=baseline_extractor)
build_and_print_pcfg("./outputs/cons/haruhi.json", baseline=baseline_extractor)
build_and_print_pcfg("./outputs/cons/psydc.json")  # 自身基准不做对比


===./outputs/cons/muice_cons.json PCFG 产生式规则（按对数似然比排序） ===
NP → NR                                       | freq=513   | P=0.0305 | PR=498.97 | LLR=3364.83
INTJ → IJ                                       | freq=259   | P=0.8548 | PR=27928.32 | LLR=1756.39
TOP → CP                                       | freq=396   | P=0.1415 | PR=506.86 | LLR=1549.77
VP → VV                                       | freq=2924  | P=0.1439 | PR=5.83 | LLR=670.21
CP → IP SP                                    | freq=1632  | P=0.3664 | PR=30.80 | LLR=585.91
UCP → IP PU CP                                 | freq=61    | P=0.0685 | PR=6695.62 | LLR=497.26
FLR → SP                                       | freq=84    | P=0.4615 | PR=16155.72 | LLR=471.29
DNP → ADJP DEG                                 | freq=108   | P=0.0816 | PR=11.13 | LLR=468.22
FLR → IJ                                       | freq=77    | P=0.4231 | PR=14354.33 | LLR=417.82
PP → P LCP                                    | freq=158   | P=0.1300 | P