In [34]:
from collections import defaultdict
import json
from tqdm import tqdm

def get_scores(splits: dict, corpus: list[str])->dict:
    subword_freq = defaultdict(int)
    pair_freq = defaultdict(int)
    for word in corpus:
        n = len(splits[word])
        for i, c in enumerate(splits[word]):
            subword_freq[c] += 1
            if i + 1 < n:
                pair_freq[(c, splits[word][i + 1])] += 1
    # scores = { (p0, p1): p_f / (subword_freq[p0] * subword_freq[p1]) for (p0, p1), p_f in pair_freq.items() }
    # scores = { (p0, p1): p_f / (subword_freq[p0]) for (p0, p1), p_f in pair_freq.items() }
    scores = { (p0, p1): p_f for (p0, p1), p_f in pair_freq.items() }
    return scores

class WordpieceTokenizer:
    def __init__(self) -> None:
        self.vocab = []
        self.unk_token = '<unk>'

    def load(self, path: str):
        with open(path, 'r') as f:
            self.vocab = json.load(f)

    def train(self, corpus: list[str], vocab_size: int, unk_token: str = '<unk>'):
        """
        corpus: a list of words from pre-tokenization
        """
        self.vocab = [unk_token]
        vocab_set = {unk_token}
        splits = {}
        words = set()
        for word in corpus:
            split = []
            for id, c in enumerate(word):
                if id != 0:
                    c = '##' + c
                if c not in vocab_set:
                    self.vocab.append(c)
                    vocab_set.add(c)
                split.append(c)
            splits[word] = split
            words.add(word)

        # print(splits)

        while len(self.vocab) < vocab_size:
            
            if len(self.vocab) % 100 == 0:
                print(best_p, best_s)
                print(len(self.vocab))

            scores = get_scores(splits, corpus)
            if len(scores) == 0:
                break

            best_p, best_s = '', None
            for p, s in scores.items():
                if best_s is None or s > best_s:
                    best_p = p
                    best_s = s
            if best_s == 1:
                break
            new_subword = best_p[0] + best_p[1][2:]
            self.vocab.append(new_subword)
            
            for word in words:
                split = splits[word]
                new_split = []
                ignore = -1
                for i in range(len(split)):
                    if i == ignore:
                        continue
                    sw = split[i]
                    if i + 1 < len(split) and (sw + split[i + 1][2:] == new_subword):
                        new_split.append(new_subword)
                        ignore = i + 1
                    else:
                        new_split.append(sw)
                splits[word] = new_split       

            # break             

    def tokenize(self, word: str)->list[str]:
        """
        word: a word produced from pre-tokenization
        """
        toks = []
        first = True
        while len(word):
            tok = None
            pos = None
            for i in range(len(word)):
                prefix = '##' + word[:i + 1] if not first else word[:i + 1]
                if prefix in self.vocab:
                    tok = prefix
                    pos = i + 1
            if tok is None:
                return [self.unk_token]
            word = word[pos:]
            toks.append(tok)
            first = False
        return toks


In [35]:
T = WordpieceTokenizer()
data = ['angel', 'of', 'vitality', 'cemetery', 'gatekeeper']
T.train(data, 37)
for w in data:
    print(T.tokenize(w))

['a', '##n', '##g', '##e', '##l']
['o', '##f']
['v', '##it', '##a', '##l', '##it', '##y']
['c', '##e', '##m', '##e', '##te', '##r', '##y']
['g', '##a', '##te', '##k', '##e', '##e', '##p', '##e', '##r']


In [2]:
from dataset.mtgcards import CardName
from torchtext.legacy.data import Field
import spacy

  from .autonotebook import tqdm as notebook_tqdm


In [3]:
spacy_en = spacy.load('en_core_web_sm')
spacy_zh = spacy.load("zh_core_web_sm")

def tokenizer_en(text):
    return [tok.text for tok in spacy_en.tokenizer(text)]

def tokenizer_zh(text):
    return [tok.text for tok in spacy_zh.tokenizer(text)]


SRC = Field(tokenize = tokenizer_en, 
                init_token = '<sos>', 
                eos_token = '<eos>', 
                lower = True)
TRG = Field(tokenize = tokenizer_zh, 
                init_token = '<sos>', 
                eos_token = '<eos>', 
                lower = True)
fields = {'src': ('src', SRC), 'trg': ('trg', TRG)}
train_data, valid_data, test_data = CardName.splits(fields=fields)

print(f'Number of train data: {len(train_data)}')
print(f'Number of valid data: {len(valid_data)}')
print(f'Number of test data: {len(test_data)}')

Number of train data: 17014
Number of valid data: 447
Number of test data: 449


In [52]:
data = []
for example in train_data:
    for _ in example.src:
        data.append(_)
data = list(set(data))
print(len(data), data[:10])

11945 ['cut', 'zellix', 'nettlevine', 'bandar', 'lethemancer', 'silumgar', 'kargan', 'gnawing', 'flutterfox', 'blindblast']


In [53]:
import random
T.train(random.sample(data, 10000), vocab_size=1000)
print(len(T.vocab))
print(*T.vocab)

('##g', '##e') 206
100
('##b', '##er') 57
200
('##an', '##ce') 31
300
('s', '##ur') 21
400
('c', '##al') 17
500
('##o', '##le') 13
600
('##in', '##er') 11
700
('ar', '##c') 9
800
('o', '##ff') 8
900
1000
<unk> i ##r ##i ##d ##a ##n w ##g ##e ##s ##l ##b ##c ##k t ##t c ##o ##m ##p b h l ##u g s ##y e ##z f r k ##h d a q p m o n x v ##x ##v ##f u ##' ##w ##q j y z ##j ' ##. & ##- ##â ##û , ! ##é ##ü - ##ö ##er ##in ##ar ##on ##or ##an ##at ##en ##st ##al ##ing ##re ##le ##ra ##it ##ed ##ro ##el ##ri ##ou ##es ##is ##ch ##ion ##la ##il ##ic ##et ##ol ##oo ##un ##ad ##ur ##ge ##ter ##sh ##id ##ul st ##ig ##ent ##ak ##ec ##th ##om ##am ##der ##ver ##ow ##us in ##ist ##em ##ir ##ac sh ##ut sp ##um ##ss re ##ers ##ht ##im ##ve ##ag ##ap ##ast th sk ##ce ##ant ##ot ch de ##as ##ood ##ation ##ard sc ##ath ##if ##all ##ous bl un con ar ##and ##ker ##ind ##os ##ate ##per ##ight ##ct ##ire ##av ##de ##ear ##orn ##ell ##age ##op ##ru ##ck ##ble ##orm ##ine ##od ##aw ##og en ##ling ##ph ##se ##ith 

In [61]:
# T.load('result/BPE-10k-10k.json')
# T.vocab = [s for s in T.vocab if len(s) <= 5 or (s[0] == '#' and len(s) <= 10)]
print(len(T.vocab))
for word in random.sample(data, 10):
    print(word, T.tokenize(word))

1000
waste ['w', '##ast', '##e']
rabblemaster ['ra', '##b', '##ble', '##master']
akoum ['ak', '##ou', '##m']
meditation ['med', '##it', '##ation']
heirloom ['he', '##irl', '##oom']
scrounging ['sc', '##ro', '##un', '##ging']
miasma ['m', '##ia', '##sm', '##a']
blightning ['blight', '##ning']
dispeller ['disp', '##eller']
syrix ['sy', '##ri', '##x']


In [62]:
import json
with open('result/BPE-10k-1k.json', 'w') as f:
    json.dump(list(T.vocab), f)