In [1]:
import hashlib
import os
import re
from collections import Counter
from itertools import product
from pathlib import Path
from typing import Dict, List, Optional

## Read the reference pronunciation dictionary

In [2]:
# Directory holding the MFA pronunciation dictionaries. Defaults to the original
# author's location; set the MFA_DICT_DIR environment variable to point elsewhere
# without editing the notebook.
base_dir = Path(os.environ.get("MFA_DICT_DIR", "/Users/seantyh/Documents/MFA/pretrained_models/dictionary/"))
ref_dict = base_dir / "mandarin_taiwan_mfa.dict"

In [3]:
# One tab-separated entry per non-blank line. The main loop reads column 0 as the
# word and column 5 as the space-separated pronunciation.
# Read as UTF-8 explicitly: the pronunciations are IPA (non-ASCII).
pdict_text = ref_dict.read_text(encoding="utf-8")
# Filter on the line text itself: filtering the split result never removes
# anything, because "".split("\t") == [''].
pdict_list = [line.split("\t") for line in pdict_text.strip().split("\n") if line.strip()]
# Fail early with a clear message instead of an IndexError inside the main loop.
assert all(len(entry) >= 6 for entry in pdict_list), "expected >= 6 tab-separated columns per entry"

In [4]:
# Peek at the first parsed entries: word, four numeric columns, then the pronunciation.
pdict_list[:5]

[['<eps>', '1.0', '0.0', '0.0', '0.0', 'sil'],
 ['<unk>', '0.99', '0.27', '1.86', '0.86', 'spn'],
 ['㐌', '0.99', '0.26', '1.0', '1.0', 'i˧˥'],
 ['㐖', '0.99', '0.26', '1.0', '1.0', 'ɕ j e˧˥'],
 ['㐖毒', '0.99', '0.26', '1.0', '1.0', 'ɕ j e˧˥ t u˧˥']]

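## Set mapping rules

Each rule swaps a retroflex sibilant initial (ʈʂ, ʈʂʰ, ʂ) with its non-retroflex counterpart (ts, tsʰ, s), in both directions. When the initial is followed by the syllabic vowel, the vowel is swapped together with it (ʐ̩ ↔ z̩).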

In [5]:
# Bidirectional swap table between retroflex sibilant initials (ʈʂ, ʈʂʰ, ʂ) and
# their non-retroflex counterparts (ts, tsʰ, s).
# - Two-phone keys swap the initial together with a following syllabic vowel
#   (ʐ̩ <-> z̩), e.g. "ʈʂ ʐ̩" -> "ts z̩", so no mixed "ts ʐ̩" is ever produced.
# - Single-phone keys carry a trailing space so they only match a complete phone
#   that is followed by another phone (e.g. "ʂ " cannot match the "ʂʰ" of "ʈʂʰ").
mappings = {
    "ʈʂ ʐ̩": "ts z̩",
    "ʈʂʰ ʐ̩": "tsʰ z̩",
    "ʂ ʐ̩": "s z̩",
    "ts z̩": "ʈʂ ʐ̩",
    "tsʰ z̩": "ʈʂʰ ʐ̩",
    "s z̩": "ʂ ʐ̩",
    "ʈʂ ": "ts ",
    "ʈʂʰ ": "tsʰ ",
    "ʂ ": "s ",
    "ts ": "ʈʂ ",
    "tsʰ ": "ʈʂʰ ",
    "s ": "ʂ "
}

# Every rule must have its inverse, so the table describes a pure swap.
assert all(mappings.get(dst) == src for src, dst in mappings.items()), "mappings must be symmetric"

# Regex alternation takes the FIRST alternative that matches, and "ʈʂ " is a
# prefix of "ʈʂ ʐ̩". Sorting keys longest-first guarantees the two-phone rules win
# regardless of dict order. `(?<!\S)` anchors each match at a phone boundary
# (start of string or after a space). Keys are escaped so they always match literally.
pat = re.compile(
    r"(?<!\S)(?:"
    + "|".join(re.escape(key) for key in sorted(mappings, key=len, reverse=True))
    + r")"
)

In [6]:
def make_variants(ori_pron: str, matches: List[re.Match], debug=False,
                  mapping: Optional[Dict[str, str]] = None) -> List[str]:
    """Enumerate every variant of ``ori_pron`` obtained by optionally swapping each match.

    Each match marks a span of ``ori_pron`` that may be replaced by its
    counterpart in the mapping table. All 2**len(matches) keep/swap
    combinations are produced, in ``itertools.product`` order (last match
    toggles fastest). The first variant is therefore always ``ori_pron``
    unchanged.

    Args:
        ori_pron: space-separated pronunciation string.
        matches: non-overlapping matches over ``ori_pron`` in left-to-right
            order, e.g. ``list(pat.finditer(ori_pron))``.
        debug: if true, print the original and each (combination, variant) pair.
        mapping: replacement table keyed by the matched text. Defaults to the
            module-level ``mappings``.

    Returns:
        The list of 2**len(matches) variants; ``[ori_pron]`` when there are no matches.

    Raises:
        KeyError: if a matched text has no entry in the mapping table.
    """
    table = mappings if mapping is None else mapping
    _print = print if debug else (lambda *args: None)
    _print("Origin:", ori_pron)

    # Cut ori_pron into the text that never changes: gaps[i] precedes
    # matches[i], and `tail` follows the last match (the whole string if no matches).
    gaps = []
    prev_end = 0
    for m in matches:
        start, end = m.span()
        gaps.append(ori_pron[prev_end:start])
        prev_end = end
    tail = ori_pron[prev_end:]
    originals = [m.group() for m in matches]

    variants = []
    # One flag per match: 0 keeps the matched text, 1 swaps it via `table`.
    # With no matches, product yields a single empty combination, so the result is [ori_pron].
    for comb in product((0, 1), repeat=len(matches)):
        pieces = []
        for gap, phone, flag in zip(gaps, originals, comb):
            pieces.append(gap)
            pieces.append(table[phone] if flag else phone)
        pieces.append(tail)
        new_pron = "".join(pieces)
        _print(comb, new_pron)
        variants.append(new_pron)
    return variants

In [7]:
# Demo: two matches ("ʈʂ " and "ʂ ʐ̩") give 2**2 = 4 variants, the original first.
make_variants(s:="ʈʂ ə˥˥ n ʂ ʐ̩˧˥", list(pat.finditer(s)))

['ʈʂ ə˥˥ n ʂ ʐ̩˧˥', 'ʈʂ ə˥˥ n s z̩˧˥', 'ts ə˥˥ n ʂ ʐ̩˧˥', 'ts ə˥˥ n s z̩˧˥']

In [8]:
# Demo: "ʈʂ ʐ̩" is swapped as a unit with its syllabic vowel ("ts z̩"), independently of "ʈʂʰ ".
make_variants(s:="ʈʂ ʐ̩˥˩ ʈʂʰ o˧˥ ŋ", list(pat.finditer(s)))

['ʈʂ ʐ̩˥˩ ʈʂʰ o˧˥ ŋ',
 'ʈʂ ʐ̩˥˩ tsʰ o˧˥ ŋ',
 'ts z̩˥˩ ʈʂʰ o˧˥ ŋ',
 'ts z̩˥˩ tsʰ o˧˥ ŋ']

In [9]:
## simple testing
def compare_list(x, y):
    """Return True when ``x`` and ``y`` hold the same items with the same counts, in any order.

    Multiset comparison, stricter than set equality: a result containing
    duplicated or missing variants no longer compares equal. Items must be hashable.
    """
    return Counter(x) == Counter(y)
# Regression cases: (pronunciation, every variant make_variants must produce).
variant_cases = [
    ("ʈʂ u˥˥ ŋ",  # 中
     ['ʈʂ u˥˥ ŋ', 'ts u˥˥ ŋ']),
    ("ʈʂ ʐ̩˥˥ ʈʂ u˥˥",  # 蜘蛛
     ['ʈʂ ʐ̩˥˥ ʈʂ u˥˥', 'ʈʂ ʐ̩˥˥ ts u˥˥', 'ts z̩˥˥ ʈʂ u˥˥', 'ts z̩˥˥ ts u˥˥']),
    ("ʈʂ u˥˥ ŋ ʂ a˥˥ n ʈʂ w a˥˥ ŋ",  # 中山裝
     ['ʈʂ u˥˥ ŋ ʂ a˥˥ n ʈʂ w a˥˥ ŋ',
      'ʈʂ u˥˥ ŋ ʂ a˥˥ n ts w a˥˥ ŋ',
      'ʈʂ u˥˥ ŋ s a˥˥ n ʈʂ w a˥˥ ŋ',
      'ʈʂ u˥˥ ŋ s a˥˥ n ts w a˥˥ ŋ',
      'ts u˥˥ ŋ ʂ a˥˥ n ʈʂ w a˥˥ ŋ',
      'ts u˥˥ ŋ ʂ a˥˥ n ts w a˥˥ ŋ',
      'ts u˥˥ ŋ s a˥˥ n ʈʂ w a˥˥ ŋ',
      'ts u˥˥ ŋ s a˥˥ n ts w a˥˥ ŋ']),
]
# debug=True prints every (combination, variant) pair for visual inspection.
for s, expected_variants in variant_cases:
    produced = make_variants(s, list(pat.finditer(s)), debug=True)
    assert compare_list(produced, expected_variants)


Origin: ʈʂ u˥˥ ŋ
(0,) ʈʂ u˥˥ ŋ
(1,) ts u˥˥ ŋ
Origin: ʈʂ ʐ̩˥˥ ʈʂ u˥˥
(0, 0) ʈʂ ʐ̩˥˥ ʈʂ u˥˥
(0, 1) ʈʂ ʐ̩˥˥ ts u˥˥
(1, 0) ts z̩˥˥ ʈʂ u˥˥
(1, 1) ts z̩˥˥ ts u˥˥
Origin: ʈʂ u˥˥ ŋ ʂ a˥˥ n ʈʂ w a˥˥ ŋ
(0, 0, 0) ʈʂ u˥˥ ŋ ʂ a˥˥ n ʈʂ w a˥˥ ŋ
(0, 0, 1) ʈʂ u˥˥ ŋ ʂ a˥˥ n ts w a˥˥ ŋ
(0, 1, 0) ʈʂ u˥˥ ŋ s a˥˥ n ʈʂ w a˥˥ ŋ
(0, 1, 1) ʈʂ u˥˥ ŋ s a˥˥ n ts w a˥˥ ŋ
(1, 0, 0) ts u˥˥ ŋ ʂ a˥˥ n ʈʂ w a˥˥ ŋ
(1, 0, 1) ts u˥˥ ŋ ʂ a˥˥ n ts w a˥˥ ŋ
(1, 1, 0) ts u˥˥ ŋ s a˥˥ n ʈʂ w a˥˥ ŋ
(1, 1, 1) ts u˥˥ ŋ s a˥˥ n ts w a˥˥ ŋ


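## Main loop

Every entry whose pronunciation contains a swappable segment is replaced by all of its variants (the original pronunciation is the first). All other entries are copied unchanged.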

In [10]:
vardict_list = []
# Value written to columns 2-4 of every generated variant (loop-invariant).
# NOTE(review): variants get a flat 0.5 here instead of the original entry's
# columns 1-4 (item_x[1:5]); confirm this is intended.
sil_weight = str(0.5)
for item_x in pdict_list:
    ori_pron = item_x[5]
    matches = list(pat.finditer(ori_pron))
    n_matches = len(matches)
    if n_matches == 0:
        # Nothing to swap: keep the reference entry as-is.
        vardict_list.append(item_x)
        continue
    variants = make_variants(ori_pron, matches)
    word = item_x[0]
    # Each of the 2**n variants gets prior 1/2**n, rounded to 2 decimals.
    # Floor at 0.01: for n >= 8 the rounding would otherwise write "0.0",
    # giving those pronunciations a zero prior.
    pron_prior = str(max(round(1 / 2**n_matches, 2), 0.01))
    vardict_list.extend(
        [word, pron_prior, sil_weight, sil_weight, sil_weight, var_x]
        for var_x in variants
    )

In [11]:
# Number of entries before and after variant expansion.
len(pdict_list), len(vardict_list)

(82661, 131379)

## Write the variant dictionary

In [12]:

var_dict_path = base_dir / "mandarin_taiwan_mfa_retrovar.dict"
# Write UTF-8 explicitly (entries contain IPA characters) with "\n" line endings,
# so the file's bytes and its checksum do not depend on the platform's
# locale encoding or newline translation.
with var_dict_path.open("w", encoding="utf-8", newline="\n") as f:
    f.writelines("\t".join(item_x) + "\n" for item_x in vardict_list)

In [13]:
# SHA-1 of the written dictionary, computed in Python. This needs no `sha1sum`
# binary and has no shell-quoting issues with paths containing spaces.
# Printed in sha1sum's "<hex>  <name>" layout.
print(f"{hashlib.sha1(var_dict_path.read_bytes()).hexdigest()}  {var_dict_path.name}")

b853afe049ce2f12fce2818abc615d416fb01584  mandarin_taiwan_mfa_retrovar.dict
