In [1]:
from __future__ import annotations

from collections import Counter, defaultdict
from typing import Dict, Iterable, List, Tuple
from pypinyin import pinyin, Style
import math

import pandas as pd

# =========================
#  DATA LOADING
# =========================

dataframe = pd.read_csv(
    "C:\\Users\\tbhro\\PycharmProjects\\nlp-bpc\\zho_data\\first_part.csv",
    encoding="utf-8",
)
zho_text = dataframe["text"].tolist()
zho_text_cleaned = [text.replace("\r\n", "") for text in zho_text]

# All pinyin-initial tokens you got from your model
TOKEN_PATTERNS =  ['ds', 'zd', 'zy', 'js', 'zs', 'dy', 'zm', 'zg', 'zj', 'ys', 'dd', 'wm', 'bs', 'yg', 'zz', 'ss', 'yy', 'jd', 'ws', 'zl', 'yd', 'hs', 'ns', 'nm', 'ls', 'xx', 'hd', 'ld', 'yb', 'zb', 'xs', 'yj', 'gs', 'nd', 'hl', 'xd', 'wd', 'zh', 'ng', 'td', 'yw', 'zx', 'bd', 'zc', 'tm', 'll', 'ts', 'yx', 'nb', 'zw', 'qs', 'jj', 'my', 'xl', 'zn', 'cl', 'ms', 'cs', 'yh', 'yl', 'jl', 'nn', 'yq', 'jg', 'cd', 'md', 'wl', 'hb', 'jb', 'wb', 'ks', 'qd', 'gd', 'ql', 'hh', 'zt', 'jt', 'rs', 'jh', 'gn', 'zq', 'wx', 'kd', 'lb', 'nh', 'rd', 'fs', 'nx', 'rh', 'tb', 'jx', 'yc', 'wh', 'gl', 'nl', 'qb', 'dsh', 'wy', 'jc', 'wg', 'fx', 'mb', 'yt', 'fd', 'bb', 'rg', 'wj', 'xh', 'ym', 'bl', 'yn', 'wt', 'gx', 'zf', 'zr', 'bx', 'kl', 'zds', 'tl', 'ml', 'bg', 'bh', 'wn', 'es', 'mm', 'zjs', 'nj', 'nk', 'gj', 'zjd', 'ny', 'hy', 'tx', 'zyd', 'hx', 'wq', 'cx', 'zk', 'jy', 'ww', 'za', 'jr', 'gg', 'dl', 'nr', 'tt', 'fl', 'mn', 'cc', 'yk', 'th', 'br', 'bf', 'zzd', 'qn', 'jm', 'jn', 'tg', 'bn', 'yf', 'zp', 'lg', 'wc', 'dss', 'zgs', 'yr', 'dx', 'qc', 'wk', 'qg', 'ze', 'xg', 'fc', 'zys', 'zss', 'ln', 'mg', 'dh', 'qk', 'bj', 'wzd', 'as', 'qx', 'an', 'bc', 'jss', 'lx', 'hn', 'kn', 'tj', 'ps', 'zzy', 'zzs', 'dys', 'hg', 'hj', 'cj', 'mt', 'bq', 'mx', 'cg', 'wmd', 'kk', 'bt', 'pd', 'mh', 'lw', 'bss', 'rw', 'lj', 'ad', 'zzg', 'lh', 'qt', 'yss', 'rn', 'ygs', 'qj', 'gc', 'yds', 'zms', 'gw', 'zdd', 'mw', 'ky', 'lt', 'zmd', 'rj', 'ht', 'qq', 'dyg', 'ly', 'bk', 'wf', 'lm', 'by', 'ne', 'nw', 'wms', 'aw', 'xy', 'tc', 'hc', 'ty', 'qw', 'bw', 'xj', 'xc', 'zbs', 'zzj', 'wsm', 'ddd', 'xf', 'xw', 'zzm', 'qh', 'ysm', 'qy', 'jf', 'ed', 'nzm', 'nt', 'el', 'bzd', 'jk', 'qm', 'nc', 'zls', 'rl', 'zwm', 'yyd', 'xm', 'gt', 'hm', 'cf', 'xb', 'hw', 'jw', 'pl', 'nmd', 'zgd', 'bm', 'rc', 'py', 'sss', 'xq', 'jq', 'nzd', 'nq', 'zsd', 'rb', 'kj', 'dsd', 'rt', 'nms', 'gm', 'gh', 'wjd', 'fm', 'ydy', 'zhd', 'gb', 'gr', 'zhs', 'zdl', 'yjs', 'xn', 'ch', 'xt', 'ysd', 'cq', 'rm', 'dyd', 'lk', 'hk', 'fg', 'zxs', 'zxx', 'zws', 'fq', 'ssd', 'cm', 'cb', 'jsd', 'hq', 'dsq', 'ddx', 'zxd', 'yyg', 'ff', 'zns', 'gq', 'yzd', 'dsj', 'yp', 'tn', 'zzl', 'tk', 'mj', 'zcs', 'mys', 'ya', 'cw', 'ygr', 'yzy', 'zwd', 'zml', 'fh', 'zbd', 'fb', 'fj', 'nds', 'xxd', 'ygd', 'rx', 'tmd', 'zyg', 'mq', 'eq', 'zld', 'ct', 'zhl', 'zze', 'zyq', 'dyx', 'la', 'cn', 'tw', 'ssm', 'bsd', 'njs', 'lc', 'lsd', 'ydd', 'jds', 'ha', 'na', 'zyb', 'wds', 'kx', 'hsd', 'znd', 'wzj', 'tms', 'ba', 'yzs', 'mc', 'hr', 'dse', 'xsd', 'qa', 'nsd', 'ck', 'dsl', 'ybs', 'tr', 'tq', 'nzs', 'fk', 'yys', 'wzs', 'jdd', 'znm', 'pp', 'xxs', 'gsd', 'hzs', 'wsd', 'wjs', 'mf', 'nys', 'jdl', 'xr', 'gf', 'zmb', 'hys', 'hzd', 'ybd', 'nf', 'nzj', 'aq', 'hf', 'kb', 'he', 'cy', 'jzd', 'ngs', 'xzd', 'zts', 'hzy', 'qf', 'ja', 'ztd', 'wss', 'xa', 'nzg', 'wr', 'np', 'hp', 'ddl', 'jzy', 'jsl', 'gy', 'mk', 'ry', 'myd']


# =========================
#  BASIC UTILITIES
# =========================

def is_chinese_char(ch: str) -> bool:
    """Rudimentary check if ch is a CJK Unified Ideograph."""
    return "\u4e00" <= ch <= "\u9fff"


def char_to_initial(ch: str) -> str | None:
    """
    Convert a single Chinese character to the first letter of its pinyin.
    Example: '我' -> 'w', '要' -> 'y'
    Returns None if pinyin can't be obtained.
    """
    if not is_chinese_char(ch):
        return None

    py_list = pinyin(ch, style=Style.NORMAL, strict=False)
    if not py_list or not py_list[0]:
        return None

    syllable = py_list[0][0]  # e.g. 'wo', 'yao', 'shui'
    return syllable[0].lower()  # 'w', 'y', 's', ...


# =========================
#  NON-OVERLAPPING COUNTS
# =========================

def segment_initials_longest(initials: str, patterns: Iterable[str]) -> List[str]:
    """
    Greedy longest-match segmentation over an initials string.

    - initials: e.g. 'ygydz...'
    - patterns: iterable of valid patterns (e.g. TOKEN_PATTERNS)
    Returns: list of segments like ['yg', 'y', 'd', ...]
    """
    initials = initials.strip().lower()
    n = len(initials)
    i = 0
    segments: List[str] = []

    pat_set = set(pat for pat in patterns if pat)
    if not pat_set:
        return list(initials)

    max_len = max(len(p) for p in pat_set)

    while i < n:
        matched = None
        # try longest candidate first
        for L in range(max_len, 0, -1):
            if i + L > n:
                continue
            cand = initials[i:i + L]
            if cand in pat_set:
                matched = cand
                break

        if matched is None:
            # fallback: unknown single letter
            matched = initials[i]

        segments.append(matched)
        i += len(matched)

    return segments


def build_non_overlapping_counts(
    texts: Iterable[str],
    patterns: Iterable[str],
) -> Tuple[Dict[str, Counter], Dict[str, Counter]]:
    """
    Build two frequency tables using non-overlapping segmentation:

      - single_initial_counts: 'y'  -> Counter({'有': ..., '要': ..., '也': ...})
      - pattern_word_counts:   'yg' -> Counter({'一个': ..., '一共': ...})

    Pipeline per line:
      1. Filter out non-Chinese characters.
      2. For remaining characters, compute initials (1 per char).
      3. Segment initials string using greedy longest-match over TOKEN_PATTERNS.
      4. For each segment:
         - length == 1: update single_initial_counts[initial][char]
         - length > 1:  update pattern_word_counts[pattern][multi_char_word]

    This ensures:
      - multi-letter patterns "claim" their positions,
      - single initials are trained only on positions not part of a multi-pattern.
    """
    single_initial_counts: Dict[str, Counter] = defaultdict(Counter)
    pattern_word_counts: Dict[str, Counter] = defaultdict(Counter)

    for text in texts:
        # Keep only Chinese chars, with their initials
        chars: List[str] = []
        initials: List[str] = []

        for ch in text:
            init = char_to_initial(ch)
            if init is None:
                # ignore non-Chinese or un-mappable chars for stats
                continue
            chars.append(ch)
            initials.append(init)

        if not chars:
            continue

        hanzi_seq = "".join(chars)
        initials_seq = "".join(initials)

        # segment initials into patterns
        segs = segment_initials_longest(initials_seq, patterns)
        pos = 0

        for seg in segs:
            L = len(seg)
            if pos + L > len(hanzi_seq):
                break  # safety

            hanzi_chunk = hanzi_seq[pos:pos + L]
            pos += L

            if L == 1:
                ini = seg
                char = hanzi_chunk  # single character
                single_initial_counts[ini][char] += 1
            else:
                pattern_word_counts[seg][hanzi_chunk] += 1

    return single_initial_counts, pattern_word_counts


# =========================
#  SIMPLE PER-LETTER DECODER
# =========================

def build_initial_to_best_char(single_initial_counts: Dict[str, Counter]) -> Dict[str, str]:
    """
    For each initial letter, pick the single most frequent character.
    Example: {'w': Counter({'我': 1234, '问': 200, ...})}
    -> {'w': '我'}
    """
    best: Dict[str, str] = {}
    for initial, counter in single_initial_counts.items():
        if not counter:
            continue
        most_common_char, _ = counter.most_common(1)[0]
        best[initial] = most_common_char
    return best


def decode_initial_sequence(
    initials: str,
    initial_to_best_char: Dict[str, str],
    unknown_char: str = "□",
) -> str:
    """
    Convert an initial-letter sequence (e.g. 'wysj')
    to a Chinese string using a simple per-initial mapping.

    unknown_char: used when an initial is not in the learned mapping.
    """
    result_chars = []
    for letter in initials.lower().strip():
        result_chars.append(initial_to_best_char.get(letter, unknown_char))
    return "".join(result_chars)


# =========================
#  MULTI-PATTERN DECODER (DP)
# =========================

def build_pattern_to_best_word(
    single_initial_counts: Dict[str, Counter],
    pattern_word_counts: Dict[str, Counter],
    min_count_single: int = 1,
    min_count_multi: int = 2,
):
    """
    Combine:
      - single_initial_counts: 'y'  -> Counter of single chars
      - pattern_word_counts:   'yg' -> Counter of multi-char words

    into:
      - pattern_to_word:  'y'  -> '有', 'yg' -> '一个', ...
      - pattern_to_score: log(count(best_candidate))
      - max_pattern_len:  maximum length of any pattern
    """
    pattern_to_word: Dict[str, str] = {}
    pattern_to_score: Dict[str, float] = {}
    max_len = 0

    # single-letter patterns
    for pat, counter in single_initial_counts.items():
        if not counter:
            continue
        best_word, cnt = counter.most_common(1)[0]
        if cnt < min_count_single:
            continue
        pattern_to_word[pat] = best_word
        pattern_to_score[pat] = math.log(cnt)
        max_len = max(max_len, len(pat))

    # multi-letter patterns
    for pat, counter in pattern_word_counts.items():
        if not counter:
            continue
        best_word, cnt = counter.most_common(1)[0]
        if cnt < min_count_multi:
            continue
        pattern_to_word[pat] = best_word
        pattern_to_score[pat] = math.log(cnt)
        max_len = max(max_len, len(pat))

    return pattern_to_word, pattern_to_score, max_len


def decode_initials_dp(
    initials: str,
    pattern_to_word: Dict[str, str],
    pattern_to_score: Dict[str, float],
    max_pattern_len: int,
    unknown_char: str = "□",
) -> str:
    """
    Dynamic programming decoder for initials → Hanzi.

    It chooses the segmentation and sequence of patterns that maximize
    the sum of log-count scores.
    """
    initials = initials.strip().lower()
    n = len(initials)

    best = [-math.inf] * (n + 1)
    back: List[str | None] = [None] * (n + 1)
    best[0] = 0.0

    for i in range(n):
        if best[i] == -math.inf:
            continue
        for L in range(1, max_pattern_len + 1):
            j = i + L
            if j > n:
                break

            pat = initials[i:j]
            if pat not in pattern_to_word:
                continue

            score = best[i] + pattern_to_score[pat]
            if score > best[j]:
                best[j] = score
                back[j] = pat

    if best[n] == -math.inf:
        return unknown_char * n

    # backtracking
    out: List[str] = []
    pos = n
    while pos > 0:
        pat = back[pos]
        if pat is None:   # safeguard
            out.append(unknown_char)
            pos -= 1
            continue
        out.append(pattern_to_word[pat])
        pos -= len(pat)

    out.reverse()
    return "".join(out)



In [None]:
# =========================
#  TRAIN STATS ON YOUR CSV
# =========================

training_texts = zho_text_cleaned

# 1) Build clean, non-overlapping statistics
single_initial_counts, pattern_word_counts = build_non_overlapping_counts(
    training_texts,
    patterns=TOKEN_PATTERNS,
)

# 2) Simple one-letter mapping
initial_to_best_char = build_initial_to_best_char(single_initial_counts)

# 3) Multi-pattern DP decoder tables
pattern_to_word, pattern_to_score, max_pattern_len = build_pattern_to_best_word(
    single_initial_counts,
    pattern_word_counts,
    min_count_single=1,
    min_count_multi=2,
)

# 20 mins


In [4]:
if __name__ == "__main__":
    # Tiny smoke test: use the first line of your corpus
    if training_texts:
        sample_hanzi = training_texts[0]

        # build initials for the Chinese chars only (for decoding test)
        initials_seq = "".join(
            init
            for init in (char_to_initial(ch) for ch in sample_hanzi)
            if init is not None
        )

        print("ORIGINAL HANZI (first 80):", sample_hanzi[:80])
        print("INITIALS         (first 80):", initials_seq[:80])

        baseline = decode_initial_sequence(initials_seq, initial_to_best_char)

        decoded_dp = decode_initials_dp(
            initials=initials_seq,
            pattern_to_word=pattern_to_word,
            pattern_to_score=pattern_to_score,
            max_pattern_len=max_pattern_len,
        )

        print("BASELINE (per-letter)      :", baseline[:80])
        print("DP MULTI-PATTERN DECODER   :", decoded_dp[:80])


ORIGINAL HANZI (first 80): 来自南方都市报播报武汉大学老牌坊被撞损一事持续引发关注 六月八十日 南都记者从武汉市公安局洪山分局一位工作人员处获悉 撞损牌坊的肇事司机因涉嫌过失损毁文物罪被依
INITIALS         (first 80): lznfdsbbbwhdxlpfbzsyscxyfgzlybsrndjzcwhsgajhsfjywgzrychxzspfdzssjysxgsshwwzbyfxs
BASELINE (per-letter)      : 了长你方的是不不不我孩的现了平方不长是一是成现一方个长了一不是人你的就长成我孩是个啊就孩是方就一我个长人一成孩现长是平方的长是是就一是现个是是孩我我长不一方现是
DP MULTI-PATTERN DECODER   : 了长你方的是不不不我孩的现了平方不长是一是成现一方个长了一不是人你的就长成我孩是个啊就孩是方就一我个长人一成孩现长是平方的长是是就一是现个是是孩我我长不一方现是
