In [None]:
from collections import Counter
from typing import List
import re
import tokenizer
import numpy as np

In [2]:
token_rex = r"\w+|[^\w\s]"

In [3]:
def corpus_common_tokens(texts, num_tokens=30000):
    token_counts = Counter()
    for text in texts:
        for match in re.findall(token_rex, text):
            token_counts[match] += 1
    return set(token for token, count in token_counts.most_common(num_tokens))

tokenizer.test_tokenizer_from_corpus_fn(corpus_common_tokens)

In [4]:
class Tokenizer:
    def __init__(self, token_list):
        self.token_list = token_list
        self.tokens_by_id = {entry['id']: entry['piece'] for entry in token_list}
        self.ids_by_token = {entry['piece']: entry['id'] for entry in token_list}
        assert 3 not in self.tokens_by_id.keys()
        assert '[UNK]' not in self.ids_by_token.keys()
        self.unk_token = '[UNK]'
        self.unk_token_id = 3
        self.tokens_by_id[self.unk_token_id] = self.unk_token
        self.ids_by_token[self.unk_token] = self.unk_token_id
        
    def decode(self, ids: List[int]) -> str:
        return ' '.join(self.tokens_by_id[id] for id in ids)
    
    def tokenize(self, text):
        return [self.ids_by_token.get(token, self.unk_token_id) for token in re.findall(token_rex, text)]

tokenizer.test_tokenizer(Tokenizer)

In [21]:
class Trie:
    def __init__(self, token_id=None, children=None):
        self.token_id = token_id
        self.children = {} if children is None else children
        
    def add_token(self, token, token_id):
        branch = self
        for c in token:
            try:
                branch = branch.children[c]
            except KeyError:
                new_node = Trie()
                branch.children[c] = new_node
                branch = new_node
        branch.token_id = token_id
        
    def invariant(self):
        for c in self.children.values():
            assert c.token_id is not None or c.children
            c.invariant()
        
    def remove_token(self, token):
        branch = self
        parents = []
        for c in token:
            parents.append(branch)
            try:
                branch = branch.children[c]
            except KeyError:
                raise KeyError(token)
        if branch.token_id is None:
            raise KeyError(token)
        branch.token_id = None
        for c in reversed(token):
            if branch.children or branch.token_id is not None:
                break
            branch = parents.pop()
            del branch.children[c]
    
    def __repr__(self):
        return f'Trie(token_id={self.token_id}, children={self.children})'
    
    def all_tokens(self):
        if self.token_id is not None:
            yield ('', self.token_id)
        for (c, child) in self.children.items():
            for token, token_id in child.all_tokens():
                yield (c + token, token_id)
                
    def find(self, text):
        cursor = self
        for c in text:
            try:
                cursor = cursor.children[c]
            except KeyError:
                raise KeyError(text)
        if cursor.token_id is None:
            raise KeyError(text)
        else:
            return cursor.token_id
    
    def find_longest_prefix(self, text, start_pos=0):
        token_found = None
        token_id = self.token_id
        cursor = self
        for i in range(start_pos, len(text)):
            c = text[i]
            try:
                cursor = cursor.children[c]
            except KeyError:
                break
            if cursor.token_id is not None:
                token_found = text[start_pos:i+1]
                token_id = cursor.token_id
        return token_found, token_id
    
trie = Trie()
for token_id, token in enumerate(['abc', 'ab', 'ad', 'ef']):
    trie.add_token(token=token, token_id=token_id)
print(trie.find_longest_prefix('abcdefg', start_pos=4))
trie.remove_token('ef')
trie.invariant()
print(list(trie.all_tokens()))
print(trie.find('ab'))

('ef', 3)
[('ab', 1), ('abc', 0), ('ad', 2)]
1


In [7]:
reference = tokenizer.BPETokenizer.from_corpus(tokenizer.minicorpus)

In [237]:
class BPETokenizer(Tokenizer):
    def __init__(self, token_list):
        self.token_list = token_list
        self.token_trie = Trie()
        self.tokens_by_id = {}
        for token in token_list:
            token_id = token['id']
            piece = token['piece']
            self.token_trie.add_token(token=piece, token_id=token_id)
            self.tokens_by_id[token_id] = piece
            
    def trie_tokenize(self, text):
        i = 0
        tokens = []
        while i < len(text):
            token, token_id = self.token_trie.find_longest(text, start_pos=i)
            tokens.append((token, token_id))
            i += len(token)
        return tokens
    
    def tokenize_involving_list_deletion(self, text):
        tokens = [c for c in text]
        for token in self.tokens_by_id.values():
            i = 0
            while i < len(tokens):
                while i + 1 < len(tokens) and tokens[i] + tokens[i + 1] == token:
                    tokens[i] = token
                    del tokens[i + 1]
                i += 1
        return [self.token_trie.find(token) for token in tokens]
    
    def tokenize(self, text):
        tokens = [c for c in text]
        for token in self.tokens_by_id.values():
            i = 0
            new_tokens = []
            while i < len(tokens):
                if i + 1 < len(tokens) and tokens[i] + tokens[i + 1] == token:
                    new_tokens.append(token)
                    i += 2
                else:
                    new_tokens.append(tokens[i])
                    i += 1
            tokens = new_tokens
        return [self.token_trie.find(token) for token in tokens]
    
    @classmethod
    def from_corpus(cls, corpus: List[str], max_vocab_size=1000):
        tokens = [[c for c in text] for text in corpus]
        ids_by_token = {}
        tokens_by_id = {}
        fresh_id = 5
        for text in corpus:
            for c in text:
                if c not in ids_by_token:
                    ids_by_token[c] = fresh_id
                    tokens_by_id[fresh_id] = c
                    fresh_id += 1
        pair_counter = Counter()
        while fresh_id < max_vocab_size and any(len(text) > 1 for text in tokens):
            for text in tokens:
                for i in range(len(text) - 1):
                    pair_counter[(text[i], text[i + 1])] += 1
            ((next_pair, next_count),) = pair_counter.most_common(1)
            next_token = ''.join(next_pair)
            ids_by_token[next_token] = fresh_id
            tokens_by_id[fresh_id] = next_token
            fresh_id += 1
            for i in range(len(tokens)):
                text = tokens[i]
                text_with_new_token = []
                j = 0
                while j < len(text):
                    if j + 1 < len(text) and text[j] + text[j + 1] == next_token:
                        text_with_new_token.append(next_token)
                        j += 2
                    else:
                        text_with_new_token.append(text[j])
                        j += 1
                tokens[i] = text_with_new_token
        return cls([{'id': token_id, 'piece': token} for token, token_id in ids_by_token.items()])

# tokenizer.test_bpe_tokenizer(BPETokenizer)
# tokenizer.test_tokenizer_from_corpus(BPETokenizer)

In [190]:
class BPEtokenizerAlex:
    def __init__(self, tokens):
        self.tokens_by_id = {}
        self.ids_by_token = {}
        for token in tokens:
            self.tokens_by_id[token['id']]=token['piece']
            self.ids_by_token[token['piece']]=token['id']
    
    def tokenize(self, string):
        char_ids = np.array([self.ids_by_token[char] for char in string])
        next_token = np.arange(1,len(char_ids)+1)
        live = np.ones(len(char_ids))
        token_locations= {}
        #print(sorted(self.tokens_by_id.keys(),key = lambda x: len(str(self.tokens_by_id[x]))))
        for ids in sorted(self.tokens_by_id.keys(),key = lambda x: len(str(self.tokens_by_id[x]))):
            token = self.tokens_by_id[ids]
            if len(str(token)) == 1:
                #print(token)
                token_locations[ids]=[]
                for i in range(0,len(char_ids)):
                    if char_ids[i]==ids:
                        token_locations[ids].append(i)    
                #print(token_locations[ids])
                continue
            token_locations[ids]=[]
            insert_pts = []
            new_sid_token_loc = {}
            for i in range(1,len(token)):
                stoken1=token[:i]
                stoken2=token[i:]
                if stoken1 not in self.ids_by_token.keys():
                    continue
                if stoken2 not in self.ids_by_token.keys():
                    continue
                sid1=self.ids_by_token[stoken1]
                sid2=self.ids_by_token[stoken2]
                
                new_sid_token_loc[sid1]=[]
                #print(token_locations.keys())
                #print(sid1)
                #print(token_locations[sid1])
                #print(self.tokens_by_id[sid1])
                #print(self.tokens_by_id[ids])
                for token_loc in token_locations[sid1]:
                    insert_pts.append([(token_loc,sid1)])
            insert_pts = sorted(insert_pts)
            for (token_loc,
                    
            if next_token[token_loc]!= len(char_ids) and char_ids[next_token[token_loc]]==sid2 and live[token_loc]==1 and live[next_token[token_loc]]==1:
                        char_ids[token_loc]=ids
                        token_locations[ids].append(token_loc)
                        live[next_token[token_loc]]=0
                        next_token[token_loc]=next_token[next_token[token_loc]]
                    else:
                        new_sid_token_loc.append(token_loc)
                token_locations[sid1]=new_sid_token_loc
        new_tokens=[]
        ptr = 0 
        while ptr != len(char_ids):
            new_tokens.append(char_ids[ptr])
            ptr=next_token[ptr]
            
        return new_tokens
                
                
                
        
        

In [191]:
Atokenizer = BPEtokenizerAlex(reference.token_list)

In [215]:
tokenizer.test_bpe_tokenizer(BPEtokenizerAlex)

AssertionError: 

In [216]:
reference.tokenize("hello, my name is tom trundlewich")

[42, 296, 11, 464, 807, 80, 168, 83, 125, 83, 31, 504, 725, 118, 134]

In [217]:
Atokenizer.tokenize("hello, my name is tom trundlewich")

[671, 96, 11, 464, 807, 80, 168, 83, 125, 83, 31, 504, 725, 118, 134]

In [222]:
Atokenizer.tokens_by_id[42]

'h'