## build_vocab_mapping.py

In [28]:
from collections import defaultdict
from transformers import BertTokenizer
import warnings
warnings.filterwarnings('ignore', category=FutureWarning)

tokenizer = BertTokenizer.from_pretrained('fnlp/bart-base-chinese')
tokenizer.save_vocabulary('vocab-bart-base-chinese.txt')

from lib import token_id_to_token, token_to_token_id, conv_table, is_alpha_char, is_cjkv

In [29]:
token_new_to_candidate_token_ids = defaultdict(set)

for token_id, token in token_id_to_token.items():
    if not is_cjkv(token):  # non-CJKV
        token_new_to_candidate_token_ids[token].add(token_id)

    else:  # CJKV
        trads = conv_table.get(token)

        if trads is None:
            token_new_to_candidate_token_ids[token].add(token_id)
        elif len(trads) == 1:
            trad = trads
            token_new_to_candidate_token_ids[trad].add(token_id)
        else:
            trad_first, *trad_rests = trads

            # trad_first
            token_new_to_candidate_token_ids[trad_first].add(token_id)

            # trad_rests
            for trad_rest in trad_rests:
                if trad_rest in token_to_token_id:
                    token_id_new = token_to_token_id[trad_rest]
                else:
                    token_id_new = token_id
                token_new_to_candidate_token_ids[trad_rest].add(token_id_new)

In [30]:
def filter_candidate_token_ids(token_new: str, candidate_token_ids: set[int]) -> set[int]:
    # non-CJKV tokens

    if not is_cjkv(token_new):
        return candidate_token_ids

    # CJKV tokens with length of 1

    if len(candidate_token_ids) == 1:
        return candidate_token_ids

    # CJKV tokens with length greater than 1

    candidate_token_ids_new = set()

    for candidate_token_id in candidate_token_ids:
        candidate_token = token_id_to_token[candidate_token_id]

        if not is_alpha_char(token_new) and candidate_token == token_new:
            continue

        candidate_token_ids_new.add(candidate_token_id)

    return candidate_token_ids_new

token_new_to_candidate_token_ids = {
    token_new: filter_candidate_token_ids(token_new, candidate_token_ids)
    for token_new, candidate_token_ids
    in token_new_to_candidate_token_ids.items()
}

In [31]:
token_new_to_token_id_new = []

for token_new, candidate_token_ids in token_new_to_candidate_token_ids.items():
    if len(candidate_token_ids) == 1:
        token_id_new = next(iter(candidate_token_ids))

    elif len(candidate_token_ids) > 1:
        candidate_tokens = [
            token_id_to_token[candidate_token_id]
            for candidate_token_id
            in candidate_token_ids
        ]

        # print(token_new, ''.join(candidate_tokens))  # for generating `preferences`

        preferences = {
            '麼': '么',  # 么麽
            '於': '于',  # 於于
            '夥': '伙',  # 伙夥
            '餘': '余',  # 余馀
            '徵': '征',  # 徵征
            '鍾': '钟',  # 钟锺
            '諮': '咨',  # 咨谘
            '麪': '麺',  # 麺面
            '钁': '钁', 
            '託': '託',
            '穭': '穭', # still raised KeyError
            '線': '线',
            '縕': '縕',
            '蘋': '苹',
            '蘊': '蕴',
            '轀': '辒',
            '醞': '酝',
            '驄': '骢',
        }
        try:
            candidate_token = preferences[token_new]  # guaranteed that `token_new` is always inside `preferences`
            token_id_new = token_to_token_id[candidate_token]
        except KeyError as ke:
            print("KeyError:", ke)

    else:  # len(candidate_token_ids) == 0
        raise ValueError('The length of `candidate_token_ids` should not be zero.')

    token_new_to_token_id_new.append((token_new, token_id_new))

KeyError: '穭'


In [32]:
with open('vocab_mapping.txt', 'w', encoding='utf-8') as f:
    for token_new, token_id_new in token_new_to_token_id_new:
        print(token_new, token_id_new, file=f)

## add_cantonese_tokens.py

In [20]:
from collections import Counter
#from random import Random
from transformers import BertTokenizerFast
import pickle

from lib import token_to_token_id as vocab_old, is_unused, is_cjkv # load_lihkg, 

#rng = Random(42)
# sentences = rng.choices(sentences, k=524288) # why 524288?

sentences = pickle.load(open("../data/wordsegs.pickle", "rb"))
print("Number of Word Segments:", len(sentences))

Number of Word Segments: 462064472


In [4]:
vocab_new = set()

with open('vocab_mapping.txt', encoding='utf-8') as f:
    for line in f:
        token, token_id = line.rstrip('\n').rsplit(' ', 1)
        if is_unused(token):
            print(token)
            continue
        vocab_new.add(token)

print("Number of Word Segments:", len(vocab_new))

Number of Word Segments: 48265


In [22]:
import regex as re

pattern = re.compile(r'[\u4e00-\u9fff]+')

def is_cjkv2(token: str) -> bool:
    return bool(pattern.fullmatch(token))

In [24]:
%%time
import pandas as pd

def check_valid_token(t):
    return (is_cjkv2(t) and t not in vocab_new)

#counter = Counter(sentences.loc[sentences[0].apply(check_valid_token), 0])
counter = Counter(filter(check_valid_token, sentences[0]))

cjkv_new = set((token for token, _ in counter.most_common(15000))) # from 150 to 15000
print("Number of Vocabs:", len(cjkv_new))

Number of Vocabs: 15000
CPU times: user 4min 48s, sys: 13.9 ms, total: 4min 48s
Wall time: 4min 48s


In [16]:
print(sentences)

              0
0             又
1             靠
2           象徵性
3             嘅
4             白
...         ...
462064467     -
462064468    嶺東
462064469  高速公路
462064470    歌謠
462064471     祭

[462064472 rows x 1 columns]


In [25]:
print(counter.most_common(15000))

[('可以', 1490575), ('自己', 1476024), ('吓', 1355171), ('真係', 1287517), ('唔係', 1134929), ('香港', 1122459), ('其實', 856066), ('已經', 792448), ('如果', 787417), ('覺得', 782074), ('唔會', 775420), ('唔好', 759778), ('應該', 728988), ('一個', 727754), ('因為', 669496), ('好多', 642726), ('好似', 620370), ('所以', 553489), ('就係', 538692), ('不過', 534682), ('唔知', 528478), ('大家', 523274), ('係咪', 519137), ('之後', 470919), ('可能', 454299), ('而家', 450557), ('呢個', 426055), ('問題', 410641), ('點解', 405477), ('中國', 352468), ('政府', 350523), ('香港人', 349659), ('有冇', 346096), ('開始', 327002), ('其他', 324744), ('之前', 320613), ('一定', 310819), ('依家', 307827), ('但係', 306678), ('時間', 304205), ('老母', 301987), ('只係', 299584), ('見到', 287022), ('有人', 286300), ('一樣', 283487), ('仲有', 281182), ('仲要', 280428), ('最後', 276360), ('美國', 271282), ('人地', 268698), ('希望', 266130), ('成日', 264736), ('一齊', 258835), ('根本', 256841), ('今日', 255952), ('同埋', 255886), ('支持', 254239), ('就算', 237074), ('手足', 234672), ('仆街', 232291), ('呢啲', 229617), ('我們', 229040), (

In [26]:
print(cjkv_new)

{'金錢', '街上', '距離', '原本', '間諜', '著緊', '多兩', '身位', '學者', '中鋒', '除夕', '農民', '第二天', '支柱', '取決', '上月', '辦事處', '罪名', '慶幸', '行街', '汽油', '刑事', '負債', '觸及', '又是', '牧師', '擺脫', '派對', '偏心', '圍繞', '業界', '同時', '定價', '逃稅', '十點', '不用', '邊套', '安定', '證明', '各地', '要睇', '法拉', '打機', '憐憫', '痾屎', '曖昧', '格式', '估算', '水平', '前世', '至於', '另有', '有希望', '此時', '勇氣', '其餘', '大飛', '反感', '老屎忽', '氣力', '兼且', '票數', '單打', '致死', '西灣河', '成書', '百病', '二號', '意想不到', '阿公', '出發', '手寫', '初時', '安眠藥', '大家姐', '戲份', '偷偷', '社長', '維吾爾人', '流料', '全能', '介乎', '不再', '東莞', '從來不', '風暴', '開燈', '但是', '掌上壓', '趁機', '香水', '發起', '體力', '小馬', '連敗', '影響力', '四萬', '肚腩', '有期徒刑', '只顧', '叮噹', '週年', '兔子', '吐露港', '奧地利', '腰斬', '踢出', '朋友仔', '俾錢', '成身', '細分', '門面', '深刻', '感到', '懷中', '爸打', '股票', '份量', '知名', '挑釁', '海關', '秒鐘', '其他', '路面', '進口', '耐心', '讓路', '幹部', '瀨屎', '撥款', '上心', '大全', '下水', '內疚', '就見到', '扮嘢', '熱線', '常態', '比較', '長命', '豬頭', '風格', '劇集', '淨低', '副主席', '坦克', '朝早', '極品', '夜間', '人民日報', '連接', '靈異', '升溫', '翻譯', '拍片', '屌閪', '笑笑口', '下人', '大中華', '指向', '進出', '湧入', '恐

In [27]:
with open('counter.pickle', 'wb') as c:
    pickle.dump(counter, c)

In [33]:
with open('yue.txt', encoding='utf-8') as f:
    text = f.read()

for c in text:
    if is_cjkv(c) and c not in vocab_new:
        cjkv_new.add(c)

In [34]:
cs = set()
for c in cjkv_new:
    if c not in vocab_old:
        cs.add(c)

In [49]:
with open('add_token.txt', 'w', encoding='utf-8') as f:
    print('\n'.join(sorted(cs)), file=f) # changed to one token each line

## replace_embedding.py

In [54]:
import jax; jax.config.update('jax_platforms', 'cpu')

import jax.numpy as np
from transformers import BertTokenizer, FlaxBartModel

from lib import random_init_embed, save_params, seed2key

tokenizer = BertTokenizer.from_pretrained('fnlp/bart-base-chinese')
model = FlaxBartModel.from_pretrained('fnlp/bart-base-chinese', from_pt=True)

key = seed2key(42)

Some weights of the model checkpoint at fnlp/bart-base-chinese were not used when initializing FlaxBartModel: {('final_logits_bias',), ('decoder', 'embed_tokens', 'kernel'), ('lm_head', 'kernel'), ('encoder', 'embed_tokens', 'kernel')}
- This IS expected if you are initializing FlaxBartModel from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing FlaxBartModel from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).


In [37]:
token_new_to_token_id = {}

with open('vocab_mapping.txt', encoding='utf-8') as f:
    for line in f:
        token_new, token_id = line.rstrip('\n').rsplit(' ', 1)
        token_id = int(token_id)
        token_new_to_token_id[token_new] = token_id

In [50]:
with open('add_token.txt', encoding='utf-8') as f:
    #new_chars = f.read().rstrip('\n')
    new_chars = [line.rstrip() for line in f]

In [51]:
print(new_chars)

['㖭', '㗱', '㗾', '㧬', '㾓', '䏙', '䚗', '䟕', '一一', '一下', '一下子', '一世', '一世人', '一九', '一事', '一事無成', '一二', '一二邊', '一些', '一代', '一任', '一份', '一份子', '一併', '一來', '一係', '一個', '一個一個', '一個中國', '一個二個', '一個人', '一個個', '一個字', '一個對', '一個站', '一個鐘', '一倍', '一億', '一入', '一兩個', '一兩日', '一兩次', '一共', '一再', '一分', '一切', '一刻', '一則', '一千', '一千幾百', '一半', '一口', '一口氣', '一句', '一句話', '一同', '一向', '一味', '一哥', '一啖', '一啲', '一單', '一嘢', '一回', '一回事', '一國兩制', '一團', '一地', '一堂', '一堆', '一堆人', '一場', '一塊', '一大段', '一大班', '一大輪', '一天', '一套', '一如', '一姐', '一字', '一定', '一定會', '一定要', '一家', '一家人', '一家大細', '一實', '一對', '一對一', '一小時', '一層', '一巴', '一帶', '一帶一路', '一年', '一年半', '一年半載', '一年多', '一度', '一廂情願', '一張', '一律', '一心', '一意孤行', '一成', '一戰', '一手', '一打', '一批', '一把', '一招', '一排', '一改', '一文不值', '一方', '一方面', '一於', '一旁', '一族', '一日', '一日到黑', '一旦', '一早', '一早就', '一時', '一時三刻', '一時之間', '一時間', '一會', '一會兒', '一月', '一本', '一架', '一條', '一棟', '一概', '一模一樣', '一樣', '一樣嘢', '一次', '一次又一次', '一次性', '一次過', '一步', '一步一步', '一步步', '一段', '一派', '一流', '一炮過', '一片', '一班', '一班人', '一球', '一生

In [56]:
num_of_unused = sum(1 for token in token_new_to_token_id if token.startswith('[unused'))
assert num_of_unused == 99

new_chars_a = new_chars[:num_of_unused]
new_chars_b = new_chars[num_of_unused:]

In [57]:
emb_old = list(model.params['shared']['embedding'])
emb_new = list(random_init_embed(key, len(new_chars_b)))

In [58]:
token_new_to_emb = {}

unused_idx = 1
new_chars_iter = iter(new_chars_a)

for token_new, token_id in token_new_to_token_id.items():
    if token_new.startswith('[unused'):
        token_id = token_new_to_token_id[f'[unused{unused_idx}]']
        new_char = next(new_chars_iter)
        token_new_to_emb[new_char] = emb_old[token_id]
        unused_idx += 1
    else:
        token_new_to_emb[token_new] = emb_old[token_id]

for i, token_new in enumerate(new_chars_b):
    token_new_to_emb[token_new] = emb_new[i]

In [59]:
vocab = []
emb_all = []

for i, (token_new, emb) in enumerate(token_new_to_emb.items()):
    assert isinstance(token_new, str)
    assert emb.shape[0] == 768
    assert token_new != '\n'
    vocab.append(token_new)
    emb_all.append(emb)

emb_all = np.vstack(emb_all)

In [60]:
with open('vocab-bart-base-cantonese.txt', 'w', encoding='utf-8') as f:
    for token in vocab:
        print(token, file=f)

save_params(emb_all, 'embed_params.dat')

## upload_to_hub.py

In [44]:
from transformers import BertTokenizer

tokenizer = BertTokenizer.from_pretrained('vocab-bart-base-cantonese.txt')
tokenizer.push_to_hub('raptorkwok/bart-base-cantonese', private=True)

model = FlaxBartModel.from_pretrained('fnlp/bart-base-chinese', from_pt=True)
model.push_to_hub('raptorkwok/bart-base-cantonese', private=True)


The tokenizer class you load from this checkpoint is not the same type as the class this function is called from. It may result in unexpected tokenization. 
The tokenizer class you load from this checkpoint is 'BartTokenizer'. 
The class this function is called from is 'BertTokenizer'.
Some weights of the model checkpoint at fnlp/bart-base-chinese were not used when initializing FlaxBartModel: {('final_logits_bias',), ('decoder', 'embed_tokens', 'kernel'), ('lm_head', 'kernel'), ('encoder', 'embed_tokens', 'kernel')}
- This IS expected if you are initializing FlaxBartModel from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing FlaxBartModel from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).


CommitInfo(commit_url='https://huggingface.co/raptorkwok/bart-base-cantonese/commit/c1580f0f39489014a15330577ca5502e73c1bbe7', commit_message='Upload tokenizer', commit_description='', oid='c1580f0f39489014a15330577ca5502e73c1bbe7', pr_url=None, pr_revision=None, pr_num=None)

In [61]:
## Inference

from transformers import BertTokenizer, BartForConditionalGeneration, Text2TextGenerationPipeline
tokenizer = BertTokenizer.from_pretrained('raptorkwok/bart-base-cantonese')
model = BartForConditionalGeneration.from_pretrained('raptorkwok/bart-base-cantonese', from_flax=True) # from_flax=True is required
text2text_generator = Text2TextGenerationPipeline(model, tokenizer)  
output = text2text_generator('聽日就要返香港，我激動到[MASK]唔着', max_length=50, do_sample=False)
print(output[0]['generated_text'].replace(' ', ''))

Some weights of the Flax model were not used when initializing the PyTorch model BartForConditionalGeneration: ['shared.weight', 'encoder.embed_positions.weight', 'encoder.layers.0.self_attn.k_proj.weight', 'encoder.layers.0.self_attn.v_proj.weight', 'encoder.layers.0.self_attn.q_proj.weight', 'encoder.layers.0.self_attn.out_proj.weight', 'encoder.layers.0.self_attn_layer_norm.weight', 'encoder.layers.0.fc1.weight', 'encoder.layers.0.fc2.weight', 'encoder.layers.0.final_layer_norm.weight', 'encoder.layers.1.self_attn.k_proj.weight', 'encoder.layers.1.self_attn.v_proj.weight', 'encoder.layers.1.self_attn.q_proj.weight', 'encoder.layers.1.self_attn.out_proj.weight', 'encoder.layers.1.self_attn_layer_norm.weight', 'encoder.layers.1.fc1.weight', 'encoder.layers.1.fc2.weight', 'encoder.layers.1.final_layer_norm.weight', 'encoder.layers.2.self_attn.k_proj.weight', 'encoder.layers.2.self_attn.v_proj.weight', 'encoder.layers.2.self_attn.q_proj.weight', 'encoder.layers.2.self_attn.out_proj.weig

YY


In [63]:
from transformers import pipeline

test_texts = ["今晚我哋去飲酒！", "再嘈我打爆你個嘴！", "尋晚瞓得唔好", "我哋今日去睇陳奕迅演唱會", "如果你想見陳生，你要事先預約。"]

tokenizer = BertTokenizer.from_pretrained('raptorkwok/bart-base-cantonese')

translator_case1 = pipeline("translation", model=model, tokenizer="raptorkwok/bart-base-cantonese", max_length=50)
for test_text in test_texts:
    print("Translation:", translator_case1(test_text)[0]['translation_text'].replace(' ', ''))
    print("Tokenize:", tokenizer.tokenize(test_text))

Translation: YY
Tokenize: ['今', '晚', '我', '哋', '去', '飲', '酒', '！']
Translation: YY
Tokenize: ['再', '嘈', '我', '打', '爆', '你', '個', '嘴', '！']
Translation: YY
Tokenize: ['尋', '晚', '瞓', '得', '唔', '好']
Translation: YY
Tokenize: ['我', '哋', '今', '日', '去', '睇', '陳', '奕', '迅', '演', '唱', '會']
Translation: YY
Tokenize: ['如', '果', '你', '想', '見', '陳', '生', '，', '你', '要', '事', '先', '預', '約', '。']
