In [1]:
from collections import Counter
from typing import List
import re
import tokenizer

In [2]:
# Token pattern shared by the word-level helpers below: either a run of word
# characters (`\w+`) or a single character that is neither a word character
# nor whitespace (`[^\w\s]`). Whitespace is never emitted as a token.
token_rex = r"\w+|[^\w\s]"

In [3]:
def corpus_common_tokens(texts, num_tokens=30000):
    """Return the set of the most frequent tokens found in ``texts``.

    Every text is split with the module-level ``token_rex`` pattern and the
    occurrences are tallied across the whole corpus.

    Args:
        texts: iterable of strings to scan.
        num_tokens: maximum number of distinct tokens to keep.

    Returns:
        A set holding at most ``num_tokens`` tokens, chosen by descending
        frequency.
    """
    frequencies = Counter()
    for document in texts:
        # Counter.update tallies each element of the iterable it is given.
        frequencies.update(re.findall(token_rex, document))
    top_entries = frequencies.most_common(num_tokens)
    return {piece for piece, _ in top_entries}

# Hand the function to the checker in the external `tokenizer` module.
# NOTE(review): what the checker asserts is not visible from this notebook —
# presumably it validates the returned vocabulary; confirm against that module.
tokenizer.test_tokenizer_from_corpus_fn(corpus_common_tokens)

In [4]:
class Tokenizer:
    """Word-level tokenizer backed by a fixed vocabulary.

    ``token_list`` is a collection of entries, each exposing an ``'id'`` and a
    ``'piece'`` key. Two lookup tables are derived from it, and an ``[UNK]``
    token with id 3 is registered on top; the vocabulary must not already use
    either that id or that piece.
    """

    def __init__(self, token_list):
        self.token_list = token_list
        # Two separate passes over token_list, one per lookup direction.
        self.tokens_by_id = dict(
            (entry['id'], entry['piece']) for entry in token_list
        )
        self.ids_by_token = dict(
            (entry['piece'], entry['id']) for entry in token_list
        )
        # The slot reserved for the unknown token must still be free.
        assert 3 not in self.tokens_by_id
        assert '[UNK]' not in self.ids_by_token
        self.unk_token = '[UNK]'
        self.unk_token_id = 3
        self.tokens_by_id[self.unk_token_id] = self.unk_token
        self.ids_by_token[self.unk_token] = self.unk_token_id

    def decode(self, ids: List[int]) -> str:
        """Map ids back to their pieces and join them with single spaces.

        Raises:
            KeyError: if an id is not present in ``tokens_by_id``.
        """
        pieces = [self.tokens_by_id[token_id] for token_id in ids]
        return ' '.join(pieces)

    def tokenize(self, text):
        """Split ``text`` with ``token_rex`` and return the list of token ids.

        Pieces missing from the vocabulary map to ``unk_token_id``.
        """
        encoded = []
        for piece in re.findall(token_rex, text):
            encoded.append(self.ids_by_token.get(piece, self.unk_token_id))
        return encoded

# Hand the class to the checker in the external `tokenizer` module.
# NOTE(review): the checker's expectations are not visible from this notebook —
# presumably it exercises tokenize/decode; confirm against that module.
tokenizer.test_tokenizer(Tokenizer)

In [6]:
class Trie:
    """Character trie that maps token strings to token ids.

    A node carries an optional ``token_id`` (set when the path from the root
    to this node spells a stored token) and a ``children`` dict keyed by the
    next character.
    """

    def __init__(self, token_id=None, children=None):
        self.token_id = token_id
        if children is None:
            children = {}
        self.children = children

    def add_token(self, token, token_id):
        """Store ``token`` under ``token_id``, creating missing nodes on the way."""
        node = self
        for ch in token:
            if ch not in node.children:
                node.children[ch] = Trie()
            node = node.children[ch]
        node.token_id = token_id

    def invariant(self):
        """Assert that no descendant is a dead end (no id and no children)."""
        for child in self.children.values():
            is_live = child.token_id is not None or child.children
            assert is_live
            child.invariant()

    def remove_token(self, token):
        """Delete ``token`` and prune any branch left empty by its removal.

        Raises:
            KeyError: if ``token`` is not stored in the trie.
        """
        node = self
        trail = []  # (parent, character) pairs along the path to the token
        for ch in token:
            if ch not in node.children:
                raise KeyError(token)
            trail.append((node, ch))
            node = node.children[ch]
        if node.token_id is None:
            raise KeyError(token)
        node.token_id = None
        # Walk back up, unlinking nodes that neither end a token nor lead to one.
        while trail and not node.children and node.token_id is None:
            node, ch = trail.pop()
            del node.children[ch]

    def __repr__(self):
        return f'Trie(token_id={self.token_id}, children={self.children})'

    def all_tokens(self):
        """Yield ``(token, token_id)`` pairs for every token stored below this node."""
        if self.token_id is not None:
            yield ('', self.token_id)
        for ch, subtree in self.children.items():
            yield from (
                (ch + suffix, found_id)
                for suffix, found_id in subtree.all_tokens()
            )

    def find_longest(self, text, start_pos=0):
        """Find the longest stored token that prefixes ``text[start_pos:]``.

        Returns:
            ``(token, token_id)`` for the longest match. When nothing matches,
            the token is ``None`` and the id is this node's own ``token_id``.
        """
        best_piece = None
        best_id = self.token_id
        node = self
        for end in range(start_pos, len(text)):
            ch = text[end]
            if ch not in node.children:
                break
            node = node.children[ch]
            if node.token_id is not None:
                best_piece = text[start_pos:end + 1]
                best_id = node.token_id
        return best_piece, best_id
    
# Smoke test: build a small trie, look up the longest match from a given
# offset, then remove a token and re-check the structure.
trie = Trie()
for token_id, token in enumerate(['abc', 'ab', 'ad', 'ef']):
    trie.add_token(token=token, token_id=token_id)
# Index 4 of 'abcdefg' is 'e', so the longest stored token starting there is 'ef'.
print(trie.find_longest('abcdefg', start_pos=4))
# Removing 'ef' leaves its branch empty, so it gets pruned; invariant() asserts
# that no node without an id and without children is left behind.
trie.remove_token('ef')
trie.invariant()
print(list(trie.all_tokens()))

('ef', 3)
[('ab', 1), ('abc', 0), ('ad', 2)]


In [7]:
# Build a tokenizer with the external `tokenizer` module's BPETokenizer from
# its bundled `minicorpus`; later cells read its `token_list` and call its
# `tokenize` for comparison.
reference = tokenizer.BPETokenizer.from_corpus(tokenizer.minicorpus)

In [8]:
class BPETokenizer(Tokenizer):
    """Greedy longest-match tokenizer over a fixed vocabulary of pieces.

    The vocabulary lives in a ``Trie`` so the longest piece that prefixes the
    remaining text can be found in a single walk. ``Tokenizer.__init__`` is
    not called, so the ``[UNK]`` attributes it sets up do not exist here.

    Attributes:
        token_trie: ``Trie`` holding every vocabulary piece.
        tokens_by_id: maps a token id to its full vocabulary entry (the
            mapping carrying ``'id'`` and ``'piece'``), not to the bare piece.
    """

    def __init__(self, tokens):
        """Index ``tokens``, an iterable of entries with ``'id'`` and ``'piece'`` keys."""
        self.token_trie = Trie()
        self.tokens_by_id = {}
        for token in tokens:
            token_id = token['id']
            piece = token['piece']
            self.token_trie.add_token(token=piece, token_id=token_id)
            self.tokens_by_id[token_id] = token

    def tokenize(self, text):
        """Split ``text`` greedily into the longest matching vocabulary pieces.

        Returns:
            A list of ``(piece, token_id)`` tuples covering ``text`` from left
            to right. A character that no vocabulary piece starts with is
            emitted on its own as ``(character, None)``.
        """
        position = 0
        pieces = []
        while position < len(text):
            piece, piece_id = self.token_trie.find_longest(text, start_pos=position)
            if piece is None:
                # find_longest reports "no match" with a None piece; taking
                # len() of it used to raise TypeError. Emit the lone character
                # with no id so the scan always advances.
                piece, piece_id = text[position], None
            pieces.append((piece, piece_id))
            position += len(piece)
        return pieces

    def decode(self, ids: List[int]) -> str:
        """Concatenate the pieces for ``ids`` back into text.

        The inherited ``decode`` joins ``tokens_by_id`` values directly, which
        fails here because those values are whole entries rather than strings.
        Pieces are joined without a separator since they carry their own
        whitespace.

        Raises:
            KeyError: if an id (including ``None``) is not in the vocabulary.
        """
        return ''.join(self.tokens_by_id[token_id]['piece'] for token_id in ids)
    
# Reuse the reference vocabulary, then display the reference ids next to this
# class's output for the same sentence.
# NOTE(review): the recorded outputs below do not agree (the reference starts
# with id 42, this class with id 894), so greedy longest-match is not
# reproducing the reference segmentation.
mine = BPETokenizer(reference.token_list)
(
    reference.tokenize("hello, my name is tom trundlewich"),
    mine.tokens_by_id[894],
    mine.tokenize("hello, my name is tom trundlewich"),
)

([42, 296, 11, 464, 807, 80, 168, 83, 125, 83, 31, 504, 725, 118, 134],
 {'id': 894, 'piece': 'hel'},
 [('hel', 894),
  ('lo', 242),
  (', my ', 464),
  ('nam', 807),
  ('e ', 80),
  ('is ', 119),
  ('to', 503),
  ('m', 49),
  (' tru', 471),
  ('n', 30),
  ('d', 71),
  ('le', 725),
  ('wi', 118),
  ('ch', 134)])