# Tokenizer
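We stick with UTF-8 throughout: for mostly-ASCII text (such as English) it is the most compact of the standard Unicode encodings, and it is backward-compatible with ASCII.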

In [6]:
text = "hello"
[*text.encode("utf-8")]  # UTF-8 byte values; ASCII characters take one byte each

[104, 101, 108, 108, 111]

In [7]:
japanese = "こんにちは"
[*japanese.encode("utf-8")]  # each of these hiragana characters encodes to 3 bytes

[227, 129, 147, 227, 130, 147, 227, 129, 171, 227, 129, 161, 227, 129, 175]

In [31]:
def get_stats(ids, counts=None):
    """Count how often each pair of consecutive ids occurs in ``ids``.

    Args:
        ids: sequence of token ids.
        counts: optional dict to accumulate into; it is updated in place and
            returned. A fresh dict is used when omitted.

    Returns:
        dict mapping (left_id, right_id) -> number of occurrences.
    """
    if counts is None:
        counts = {}
    for left, right in zip(ids, ids[1:]):
        key = (left, right)
        counts[key] = counts.get(key, 0) + 1
    return counts

In [28]:
def merge(ids, pair, idx):
    """Return a new list in which each occurrence of ``pair`` in ``ids`` is replaced by ``idx``.

    The scan is greedy and left to right, so overlapping matches are not
    merged twice. For example, merging (1, 1) in [1, 1, 1] yields [idx, 1].
    ``ids`` itself is not modified.
    """
    out = []
    n = len(ids)
    pos = 0
    while pos < n:
        # The bounds check comes before ids[pos + 1], so the last element is safe.
        if ids[pos] == pair[0] and pos + 1 < n and ids[pos + 1] == pair[1]:
            out.append(idx)
            pos += 2
            continue
        out.append(ids[pos])
        pos += 1
    return out

In [40]:
class SimpleTokenizer:
    """Byte-level BPE tokenizer trained on UTF-8 text.

    Attributes:
        merges: dict mapping a (left_id, right_id) pair to the token id it was
            merged into; ids are assigned 256, 257, ... in training order.
        vocab: dict mapping token id -> the bytes that token represents.
        special_tokens: dict mapping special-token string -> token id. Nothing
            in this class populates it; it is initialised so that
            ``_build_vocab`` works.
    """

    def __init__(self):
        # Start as an untrained tokenizer: only the 256 single-byte tokens,
        # so encode/decode already work (on raw bytes) before train().
        self.merges = {}
        self.special_tokens = {}
        self.vocab = self._build_vocab()

    def train(self, text, vocab_size, verbose=False):
        """Learn up to ``vocab_size - 256`` BPE merges from ``text``.

        Args:
            text: training corpus; it is encoded to UTF-8 bytes before merging.
            vocab_size: target vocabulary size; must be at least 256 (the raw
                byte tokens).
            verbose: when True, print one line per learned merge.

        Training stops early if no consecutive pairs remain (for example,
        empty or fully-merged text), because there is nothing left to merge.
        """
        assert vocab_size >= 256
        num_merges = vocab_size - 256

        ids = list(text.encode("utf-8"))

        merges = {}
        vocab = {idx: bytes([idx]) for idx in range(256)}  # int to bytes

        for i in range(num_merges):
            stats = get_stats(ids)
            if not stats:
                # Fewer than two ids left: max() below would raise ValueError.
                break
            pair = max(stats, key=stats.get)  # most frequent pair

            idx = 256 + i
            ids = merge(ids, pair, idx)
            merges[pair] = idx
            vocab[idx] = vocab[pair[0]] + vocab[pair[1]]
            if verbose:
                print(f"merge {i+1}/{num_merges}: {pair} -> {idx} ({vocab[idx]}) had {stats[pair]} occurrences")

        self.merges = merges
        self.vocab = vocab

    def encode(self, text):
        """Encode ``text`` to a list of token ids.

        Start from the UTF-8 bytes. Then repeatedly apply the learned merge with
        the lowest id (the one learned earliest in ``train``) among the pairs
        currently present. Stop when no present pair has a learned merge.
        """
        ids = list(text.encode("utf-8"))
        while len(ids) >= 2:
            stats = get_stats(ids)
            # Pairs without a learned merge rank as +inf and are never chosen
            # while any mergeable pair exists.
            pair = min(stats, key=lambda p: self.merges.get(p, float("inf")))
            if pair not in self.merges:
                break  # nothing else can be merged
            idx = self.merges[pair]
            ids = merge(ids, pair, idx)
        return ids

    def decode(self, ids):
        """Decode a list of token ids back into a string.

        Byte sequences that are not valid UTF-8 are replaced with U+FFFD.
        An id missing from ``self.vocab`` raises ``KeyError``.
        """
        text_bytes = b"".join(self.vocab[idx] for idx in ids)
        text = text_bytes.decode("utf-8", errors="replace")
        return text

    def _build_vocab(self):
        """Rebuild the id -> bytes vocabulary from ``merges`` and ``special_tokens``."""
        # vocab is simply and deterministically derived from merges
        vocab = {idx: bytes([idx]) for idx in range(256)}
        # Merge ids are assigned in increasing order, so a pair's children are
        # already in vocab when dict insertion order is preserved.
        for (p0, p1), idx in self.merges.items():
            vocab[idx] = vocab[p0] + vocab[p1]
        for special, idx in self.special_tokens.items():
            vocab[idx] = special.encode("utf-8")
        return vocab


In [41]:
# Read the training corpus; the context manager closes the file handle,
# which the bare open(...).read() form never did.
with open("shake_texts.txt", "r", encoding="utf-8") as fh:
    text = fh.read()

In [42]:
# Learn 512 - 256 = 256 merges on the corpus, logging each one.
tokenizer = SimpleTokenizer()
tokenizer.train(text, vocab_size=512, verbose=True)

merge 1/256: (101, 32) -> 256 (b'e ') had 27643 occurrences
merge 2/256: (116, 104) -> 257 (b'th') had 22739 occurrences
merge 3/256: (116, 32) -> 258 (b't ') had 16508 occurrences
merge 4/256: (115, 32) -> 259 (b's ') had 15364 occurrences
merge 5/256: (100, 32) -> 260 (b'd ') had 14165 occurrences
merge 6/256: (44, 32) -> 261 (b', ') had 14098 occurrences
merge 7/256: (111, 117) -> 262 (b'ou') had 12730 occurrences
merge 8/256: (101, 114) -> 263 (b'er') had 11771 occurrences
merge 9/256: (105, 110) -> 264 (b'in') had 10606 occurrences
merge 10/256: (121, 32) -> 265 (b'y ') had 10283 occurrences
merge 11/256: (97, 110) -> 266 (b'an') had 10197 occurrences
merge 12/256: (58, 10) -> 267 (b':\n') had 8762 occurrences
merge 13/256: (111, 114) -> 268 (b'or') had 8458 occurrences
merge 14/256: (111, 32) -> 269 (b'o ') had 8134 occurrences
merge 15/256: (101, 110) -> 270 (b'en') had 7568 occurrences
merge 16/256: (10, 10) -> 271 (b'\n\n') had 7098 occurrences
merge 17/256: (97, 114) -> 272 (

In [None]:
# Round-trip a short string through the trained tokenizer.
result = tokenizer.encode("You are")
print(f"encoded-> {result}")
decoded = tokenizer.decode(result)
print(f"decoded-> {decoded}")

In [None]:

    # def save(self, file_prefix):
    #     """
    #     Saves two files: file_prefix.vocab and file_prefix.model
    #     This is inspired (but not equivalent to!) sentencepiece's model saving:
    #     - model file is the critical one, intended for load()
    #     - vocab file is just a pretty printed version for human inspection only
    #     """
    #     # write the model: to be used in load() later
    #     model_file = file_prefix + ".model"
    #     with open(model_file, 'w') as f:
    #         # write the version, pattern and merges, that's all that's needed
    #         f.write("minbpe v1\n")
    #         f.write(f"{self.pattern}\n")
    #         # write the special tokens, first the number of them, then each one
    #         f.write(f"{len(self.special_tokens)}\n")
    #         for special, idx in self.special_tokens.items():
    #             f.write(f"{special} {idx}\n")
    #         # the merges dict
    #         for idx1, idx2 in self.merges:
    #             f.write(f"{idx1} {idx2}\n")
    #     # write the vocab: for the human to look at
    #     vocab_file = file_prefix + ".vocab"
    #     inverted_merges = {idx: pair for pair, idx in self.merges.items()}
    #     with open(vocab_file, "w", encoding="utf-8") as f:
    #         for idx, token in self.vocab.items():
    #             # note: many tokens may be partial utf-8 sequences
    #             # and cannot be decoded into valid strings. Here we're using
    #             # errors='replace' to replace them with the replacement char �.
    #             # this also means that we couldn't possibly use .vocab in load()
    #             # because decoding in this way is a lossy operation!
    #             s = render_token(token)
    #             # find the children of this token, if any
    #             if idx in inverted_merges:
    #                 # if this token has children, render it nicely as a merge
    #                 idx0, idx1 = inverted_merges[idx]
    #                 s0 = render_token(self.vocab[idx0])
    #                 s1 = render_token(self.vocab[idx1])
    #                 f.write(f"[{s0}][{s1}] -> [{s}] {idx}\n")
    #             else:
    #                 # otherwise this is leaf token, just print it
    #                 # (this should just be the first 256 tokens, the bytes)
    #                 f.write(f"[{s}] {idx}\n")

    # def load(self, model_file):
    #     """Inverse of save() but only for the model file"""
    #     assert model_file.endswith(".model")
    #     # read the model file
    #     merges = {}
    #     special_tokens = {}
    #     idx = 256
    #     with open(model_file, 'r', encoding="utf-8") as f:
    #         # read the version
    #         version = f.readline().strip()
    #         assert version == "minbpe v1"
    #         # read the pattern
    #         self.pattern = f.readline().strip()
    #         # read the special tokens
    #         num_special = int(f.readline().strip())
    #         for _ in range(num_special):
    #             special, special_idx = f.readline().strip().split()
    #             special_tokens[special] = int(special_idx)
    #         # read the merges
    #         for line in f:
    #             idx1, idx2 = map(int, line.split())
    #             merges[(idx1, idx2)] = idx
    #             idx += 1
    #     self.merges = merges
    #     self.special_tokens = special_tokens
    #     self.vocab = self._build_vocab()
    