In [29]:
import torch
from torch import nn
from typing import List

In [39]:
class BPETokenizer:
    def __init__(self):
        self.itos = {}
        self.stoi = {}
        self.vocab_size = 0
        self.max_key = 0
        self.merges = {}
    
    def train(self, text_corpus: str, vocab_size: int):
        self.vocab_size = vocab_size
        text_corpus_set = set(text_corpus)
        text_corpus_unique = sorted(text_corpus_set)

        for ch in text_corpus_unique:
            self.stoi[ch] = self.max_key
            self.itos[self.max_key] = ch
            self.max_key += 1

        text_idx = [self.stoi[ch] for ch in text_corpus]

        # i = 0
        while self.max_key < vocab_size:
            # if i == 1:
            #     break
            pair_count = {}
            for i1, i2 in zip(text_idx, text_idx[1:]):
                pair_count[(i1, i2)] = pair_count.get((i1, i2), 0) + 1

            max_pair = max(pair_count, key=lambda pair: pair_count[pair])
            print(max_pair) # DEBUG
            new_text_idx = []

            i = 0
            while i < len(text_idx)-1:
                if (text_idx[i], text_idx[i+1]) == max_pair:
                    new_text_idx.append(self.max_key)
                    i += 2
                else:
                    new_text_idx.append(text_idx[i])
                    i += 1
            
            merged_pair = self.itos[max_pair[0]] + self.itos[max_pair[1]]
            self.itos[self.max_key] = merged_pair
            self.stoi[merged_pair] = self.max_key

            self.merges[max_pair] = self.max_key
            self.max_key += 1
            text_idx = new_text_idx

            # print(pair_count)
            # print(max(pair_count.items(), key=lambda item: item[1]))
            # i += 1

        self.final_vocab = sorted(self.stoi.items(), key=lambda item: len(item[0]), reverse=True)
    
    def encode(self, text: str):
        encoded_idx = []
        rem_text = text

        while rem_text:
            for substr, subkey in self.final_vocab:
                if rem_text.startswith(substr):
                    encoded_idx.append(subkey)
                    rem_text = rem_text[len(substr):]
                    break
        
        return encoded_idx

    def decode(self, idx: List[int]):
        return "".join(self.itos[i] for i in idx)


In [40]:
sample_text = "The Dursleys had everything they wanted, but they also had a secret, and their greatest fear was that somebody would discover it. They didn’t think they could bear it if anyone found out about the Potters. Mrs. Potter was Mrs. Dursley’s sister, but they hadn’t met for several years; in fact, Mrs. Dursley pretended she didn’t have a sister, because her sister and her good-for-nothing husband were as unDursleyish as it was possible to be. The Dursleys shuddered to think what the neighbors would say if the Potters arrived in the street. The Dursleys knew that the Potters had a small son, too, but they had never even seen him. This boy was another good reason for keeping the Potters away; they didn’t want Dudley mixing with a child like that."

In [41]:
tok = BPETokenizer()
tok.train(sample_text, 70)

(26, 16)
(0, 33)
(13, 24)
(13, 0)
(12, 0)
(13, 31)
(25, 0)
(24, 25)
(26, 0)
(38, 0)
(17, 21)
(3, 0)
(26, 35)
(5, 27)
(16, 9)
(29, 9)
(1, 0)
(26, 34)
(0, 25)
(46, 40)
(52, 19)
(22, 27)
(8, 16)
(53, 38)
(24, 13)
(9, 21)
(17, 25)
(7, 22)
(60, 26)
(61, 45)
(47, 37)
(13, 28)
(49, 10)
(24, 0)
(35, 0)
(44, 55)
(36, 62)


In [43]:
tok.final_vocab

[('e Potter', 69),
 ('Dursley', 56),
 ('Potter', 62),
 ('Dursl', 53),
 ('t th', 50),
 ('Durs', 52),
 ('had ', 63),
 ('. Th', 68),
 (' th', 34),
 ('ey ', 42),
 ('ter', 45),
 ('Pot', 61),
 (', b', 65),
 ('er ', 67),
 ('th', 33),
 ('er', 35),
 ('e ', 36),
 ('d ', 37),
 ('ey', 38),
 ('s ', 39),
 ('rs', 40),
 ('t ', 41),
 ('in', 43),
 ('. ', 44),
 ('Du', 46),
 ('ha', 47),
 ('wa', 48),
 (', ', 49),
 (' s', 51),
 ('ou', 54),
 ('Th', 55),
 ('re', 57),
 ('an', 58),
 ('is', 59),
 ('Po', 60),
 ('ev', 64),
 ('r ', 66),
 (' ', 0),
 (',', 1),
 ('-', 2),
 ('.', 3),
 (';', 4),
 ('D', 5),
 ('M', 6),
 ('P', 7),
 ('T', 8),
 ('a', 9),
 ('b', 10),
 ('c', 11),
 ('d', 12),
 ('e', 13),
 ('f', 14),
 ('g', 15),
 ('h', 16),
 ('i', 17),
 ('k', 18),
 ('l', 19),
 ('m', 20),
 ('n', 21),
 ('o', 22),
 ('p', 23),
 ('r', 24),
 ('s', 25),
 ('t', 26),
 ('u', 27),
 ('v', 28),
 ('w', 29),
 ('x', 30),
 ('y', 31),
 ('’', 32)]

In [44]:
tok.encode("hello harry potter, my name is Salazar Slytherin!")

KeyboardInterrupt: 

In [42]:
d = {('T', 'h'): 5, ('h', 'e'): 21, ('e', ' '): 16, (' ', 'D'): 6, ('D', 'u'): 7, ('u', 'r'): 6, ('r', 's'): 15, ('s', 'l'): 6, ('l', 'e'): 8, ('e', 'y'): 14, ('y', 's'): 3, ('s', ' '): 15, (' ', 'h'): 10, ('h', 'a'): 10, ('a', 'd'): 5, ('d', ' '): 16, (' ', 'e'): 2, ('e', 'v'): 4, ('v', 'e'): 7, ('e', 'r'): 17, ('r', 'y'): 1, ('y', 't'): 1, ('t', 'h'): 22, ('h', 'i'): 7, ('i', 'n'): 8, ('n', 'g'): 4, ('g', ' '): 4, (' ', 't'): 21, ('y', ' '): 12, (' ', 'w'): 11, ('w', 'a'): 7, ('a', 'n'): 7, ('n', 't'): 2, ('t', 'e'): 11, ('e', 'd'): 4, ('d', ','): 1, (',', ' '): 7, (' ', 'b'): 7, ('b', 'u'): 3, ('u', 't'): 5, ('t', ' '): 17, (' ', 'a'): 14, ('a', 'l'): 3, ('l', 's'): 1, ('s', 'o'): 4, ('o', ' '): 3, ('a', ' '): 4, (' ', 's'): 13, ('s', 'e'): 4, ('e', 'c'): 2, ('c', 'r'): 1, ('r', 'e'): 7, ('e', 't'): 4, ('t', ','): 2, ('n', 'd'): 5, ('e', 'i'): 2, ('i', 'r'): 1, ('r', ' '): 12, (' ', 'g'): 3, ('g', 'r'): 1, ('e', 'a'): 5, ('a', 't'): 5, ('e', 's'): 1, ('s', 't'): 5, (' ', 'f'): 5, ('f', 'e'): 1, ('a', 'r'): 4, ('a', 's'): 7, ('o', 'm'): 1, ('m', 'e'): 2, ('e', 'b'): 1, ('b', 'o'): 4, ('o', 'd'): 3, ('d', 'y'): 1, ('w', 'o'): 2, ('o', 'u'): 6, ('u', 'l'): 3, ('l', 'd'): 4, (' ', 'd'): 4, ('d', 'i'): 4, ('i', 's'): 6, ('s', 'c'): 1, ('c', 'o'): 2, ('o', 'v'): 1, (' ', 'i'): 7, ('i', 't'): 4, ('t', '.'): 3, ('.', ' '): 8, (' ', 'T'): 4, ('i', 'd'): 3, ('d', 'n'): 4, ('n', '’'): 4, ('’', 't'): 4, ('n', 'k'): 2, ('k', ' '): 2, (' ', 'c'): 2, ('b', 'e'): 3, ('i', 'f'): 2, ('f', ' '): 2, ('n', 'y'): 1, ('y', 'o'): 1, ('o', 'n'): 3, ('n', 'e'): 4, ('f', 'o'): 4, ('u', 'n'): 2, (' ', 'o'): 1, ('a', 'b'): 1, (' ', 'P'): 5, ('P', 'o'): 5, ('o', 't'): 7, ('t', 't'): 5, ('s', '.'): 4, (' ', 'M'): 3, ('M', 'r'): 3, ('y', '’'): 1, ('’', 's'): 1, ('s', 'i'): 4, ('r', ','): 2, (' ', 'm'): 2, ('o', 'r'): 4, ('r', 'a'): 1, ('l', ' '): 2, (' ', 'y'): 1, ('y', 'e'): 1, ('s', ';'): 1, (';', ' '): 2, ('n', ' '): 5, ('f', 'a'): 1, ('a', 'c'): 1, ('c', 't'): 1, (' ', 'p'): 2, ('p', 'r'): 1, ('e', 'n'): 3, ('d', 'e'): 2, ('s', 'h'): 3, ('a', 'v'): 1, ('c', 'a'): 1, ('a', 'u'): 1, ('u', 's'): 2, ('g', 'o'): 2, ('o', 'o'): 3, ('d', '-'): 1, ('-', 'f'): 1, ('r', '-'): 1, ('-', 'n'): 1, ('n', 'o'): 2, ('h', 'u'): 2, ('s', 'b'): 1, ('b', 'a'): 1, ('w', 'e'): 1, (' ', 'u'): 1, ('n', 'D'): 1, ('y', 'i'): 1, ('h', ' '): 2, ('p', 'o'): 1, ('o', 's'): 1, ('s', 's'): 1, ('i', 'b'): 1, ('b', 'l'): 1, ('t', 'o'): 3, ('e', '.'): 1, ('u', 'd'): 2, ('d', 'd'): 1, ('w', 'h'): 1, (' ', 'n'): 2, ('i', 'g'): 1, ('g', 'h'): 1, ('h', 'b'): 1, ('s', 'a'): 1, ('a', 'y'): 2, ('r', 'r'): 1, ('r', 'i'): 1, ('i', 'v'): 1, ('t', 'r'): 1, ('e', 'e'): 3, (' ', 'k'): 2, ('k', 'n'): 1, ('e', 'w'): 1, ('w', ' '): 1, ('s', 'm'): 1, ('m', 'a'): 1, ('l', 'l'): 1, ('n', ','): 1, ('o', ','): 1, ('i', 'm'): 1, ('m', '.'): 1, ('o', 'y'): 1, (' ', 'r'): 1, ('k', 'e'): 2, ('e', 'p'): 1, ('p', 'i'): 1, ('a', 'w'): 1, ('y', ';'): 1, ('d', 'l'): 1, ('m', 'i'): 1, ('i', 'x'): 1, ('x', 'i'): 1, ('w', 'i'): 1, ('c', 'h'): 1, ('i', 'l'): 1, (' ', 'l'): 1, ('l', 'i'): 1, ('i', 'k'): 1}

max(d, key=lambda item: d[item])

('t', 'h')