In [17]:
import json 
import regex as re

In [18]:
with open('../poems.txt', 'r', encoding='utf-8') as poem_file:
    poems = poem_file.read().replace('\xa0', ' ') #replaces non-breaking spaces as normal spaces 

In [19]:
poems

'Ricardo Reis\nDiana através dos ramos\n\nDiana através dos ramos\nEspreita a vinda de Endymion\nEndymion que nunca vem,\nEndymion, Endymion,\nLá longe na floresta…\nE a sua voz chamando\nExclama através dos ramos\nEndymion, Endymion…\n\nAssim choram os deuses…\n\n\nRicardo Reis\nAqui, sem outro Apolo do que Apolo,\n\nAqui, sem outro Apolo do que Apolo,\nSem um suspiro abandonemos Cristo\n        E a febre de buscarmos\n        Um deus dos dualismos.\n\nE longe da cristã sensualidade\nQue a casta calma da beleza antiga\n        Nos restitua o antigo\n        Sentimento da vida.\n\n\nRicardo Reis\nEm Ceres anoitece.\n\nEm Ceres anoitece.\nNos píncaros ainda\n        Faz luz.\n\nSinto-me tão grande\nNesta hora solene\n        E vã\n\nQue, assim como há deuses\nDos campos, das flores\n        Das searas,\n\n\nAgora eu quisera\nQue um deus existisse\n        De mim.\n\n\nRicardo Reis\nNão a ti, mas aos teus, odeio, Cristo.\n\nNão a ti, mas aos teus, odeio, Cristo.\nTu não és mais que um de

In [20]:
def get_pair_frequency(tokens, counts = None): #checks for the number of token pairs in the text. Returns dictionary with pair: # of occurrences
    counts={}  if counts is None else counts
    for pair in zip(tokens, tokens[1:]):
        counts[pair] = counts.get(pair, 0) + 1
    return counts

def merge_pair(ids, pair, idx):
    new_ids = []
    i = 0
    while i < len(ids):
        if i< len(ids)-1 and pair[0] == ids[i] and pair[1] == ids[i+1]:
            new_ids.append(idx)
            i += 2
        else:
            new_ids.append(ids[i])
            i += 1
    return new_ids

In [4]:
class Tokenizer():
    def __init__(self):
        self.vocab = {idx: bytes([idx]) for idx in range(256)}
        self.merges = {}
    
    def train(self, vocab_size, text):
        tokens = text.encode('utf-8')
        for i in range(vocab_size - 256):
            pair_frequency = get_pair_frequency(tokens)
            top_pair = max(pair_frequency, key=pair_frequency.get) #returns pair that appear the most
            idx = len(self.vocab)
            print(f'Merging {top_pair} -> {idx}')
            self.merges[top_pair] = idx
            tokens = merge_pair(tokens, top_pair, idx)
            self.vocab[idx] = self.vocab[top_pair[0]] + self.vocab[top_pair[1]]
    
    def encode(self, text):
        tokens = list(text.encode('utf-8'))
        
        while len(tokens) > 1:
            pair_frequency = get_pair_frequency(tokens)
            pair = min(pair_frequency, key = lambda p:self.merges.get(p, float("inf"))) #returns pair that was merged first
            if pair not in self.merges:
                break
        
            idx = self.merges[pair] #get encoded id
            tokens = merge_pair(tokens, pair, idx)

        return tokens
    
    def decode(self, ids):
        tokens = b"".join([self.vocab[idx] for idx in ids]) #ids(list of integers as encoded) -> bytes
        text = tokens.decode("utf-8", errors='replace') #bytes -> characters
        return text


In [5]:
tok = Tokenizer()
tok.train(270, poems)

Merging (101, 32) -> 256
Merging (111, 32) -> 257
Merging (97, 32) -> 258
Merging (115, 32) -> 259
Merging (194, 160) -> 260
Merging (260, 260) -> 261
Merging (101, 114) -> 262
Merging (44, 32) -> 263
Merging (113, 117) -> 264
Merging (109, 32) -> 265
Merging (101, 115) -> 266
Merging (101, 110) -> 267
Merging (10, 10) -> 268
Merging (111, 114) -> 269


In [59]:
tok.decode(tok.encode('o meu nome é Fernandão.'))

'o meu nome é Fernandão.'

In [6]:
GPT_SPLIT_PATTERN = r"""'(?i:[sdmt]|ll|ve|re)|[^\r\n\p{L}\p{N}]?+\p{L}+|\p{N}{1,3}| ?[^\s\p{L}\p{N}]++[\r\n]*|\s*[\r\n]|\s+(?!\S)|\s+"""

class RegexTokenizer(Tokenizer):
    def __init__(self):
        super().__init__() 
        self.pattern = GPT_SPLIT_PATTERN
        self.compiled_pattern = re.compile(self.pattern)
        self.merges = {} # (int, int) -> int
        self.vocab = {idx: bytes([idx]) for idx in range(256)} # idx -> bytes

    def train(self, text, vocab_size, verbose=False):
        
        num_merges = vocab_size - 256

        text_chunks = re.findall(self.compiled_pattern, text)        
        ids = [list(ch.encode("utf-8")) for ch in text_chunks]
        
        for i in range(num_merges):

            stats = {}

            for chunk_ids in ids:
                get_pair_frequency(chunk_ids, stats)

            top_pair = max(stats, key=stats.get) #returns most common pair
            idx = 256 + i
            
            if verbose:
                print(f'Merging {top_pair} into {idx}')
            
            ids = [merge_pair(chunk_ids, top_pair, idx) for chunk_ids in ids]
            
            self.merges[top_pair] = idx
            self.vocab[idx] = self.vocab[top_pair[0]] + self.vocab[top_pair[1]]

    def _encode_chunk(self, chunk_bytes):
        ids = list(chunk_bytes)
        while len(ids)>=2:
            stats = get_pair_frequency(ids)
            pair = min(stats, key=lambda p:self.merges.get(p, float("inf"))) #get the pair that has was merged first
            if pair not in self.merges:
                break #nothing else to merge
            
            idx = self.merges[pair]
            ids = merge_pair(ids, pair, idx)
        
        return ids

    def encode(self, text):
        # chunks encoded separately and then merged together
        text_chunks = re.findall(self.compiled_pattern, text)

        ids = []

        for chunk in text_chunks:
            chunk_bytes = chunk.encode("utf-8")
            chunk_ids = self._encode_chunk(chunk_bytes)
            ids.extend(chunk_ids)
        
        return ids

    def decode(self, ids):
        tokens = b"".join([self.vocab[idx] for idx in ids])
        text = tokens.decode("utf-8", errors='replace') #translates bytes to characters
        return text


In [22]:
toks = RegexTokenizer()
toks.train(poems, 5000, verbose=True)

Merging (32, 32) into 256
Merging (32, 97) into 257
Merging (32, 100) into 258
Merging (32, 101) into 259
Merging (32, 115) into 260
Merging (101, 114) into 261
Merging (32, 109) into 262
Merging (113, 117) into 263
Merging (32, 99) into 264
Merging (111, 115) into 265
Merging (32, 116) into 266
Merging (10, 10) into 267
Merging (111, 114) into 268
Merging (101, 115) into 269
Merging (32, 112) into 270
Merging (97, 115) into 271
Merging (32, 110) into 272
Merging (32, 111) into 273
Merging (110, 116) into 274
Merging (32, 263) into 275
Merging (97, 114) into 276
Merging (100, 111) into 277
Merging (275, 101) into 278
Merging (256, 256) into 279
Merging (195, 163) into 280
Merging (105, 115) into 281
Merging (280, 111) into 282
Merging (258, 101) into 283
Merging (105, 110) into 284
Merging (32, 118) into 285
Merging (101, 109) into 286
Merging (97, 110) into 287
Merging (32, 102) into 288
Merging (100, 97) into 289
Merging (46, 46) into 290
Merging (44, 10) into 291
Merging (111, 109) 

In [23]:
import pickle

with open('toks_vocab_5k.pkl', 'wb') as f:
    pickle.dump(toks.vocab, f)

In [24]:
with open('toks_vocab_5k.pkl', 'rb') as f:
    vocab = pickle.load(f)

print(vocab)

{0: b'\x00', 1: b'\x01', 2: b'\x02', 3: b'\x03', 4: b'\x04', 5: b'\x05', 6: b'\x06', 7: b'\x07', 8: b'\x08', 9: b'\t', 10: b'\n', 11: b'\x0b', 12: b'\x0c', 13: b'\r', 14: b'\x0e', 15: b'\x0f', 16: b'\x10', 17: b'\x11', 18: b'\x12', 19: b'\x13', 20: b'\x14', 21: b'\x15', 22: b'\x16', 23: b'\x17', 24: b'\x18', 25: b'\x19', 26: b'\x1a', 27: b'\x1b', 28: b'\x1c', 29: b'\x1d', 30: b'\x1e', 31: b'\x1f', 32: b' ', 33: b'!', 34: b'"', 35: b'#', 36: b'$', 37: b'%', 38: b'&', 39: b"'", 40: b'(', 41: b')', 42: b'*', 43: b'+', 44: b',', 45: b'-', 46: b'.', 47: b'/', 48: b'0', 49: b'1', 50: b'2', 51: b'3', 52: b'4', 53: b'5', 54: b'6', 55: b'7', 56: b'8', 57: b'9', 58: b':', 59: b';', 60: b'<', 61: b'=', 62: b'>', 63: b'?', 64: b'@', 65: b'A', 66: b'B', 67: b'C', 68: b'D', 69: b'E', 70: b'F', 71: b'G', 72: b'H', 73: b'I', 74: b'J', 75: b'K', 76: b'L', 77: b'M', 78: b'N', 79: b'O', 80: b'P', 81: b'Q', 82: b'R', 83: b'S', 84: b'T', 85: b'U', 86: b'V', 87: b'W', 88: b'X', 89: b'Y', 90: b'Z', 91: b'[',