In [None]:
import re
from collections import defaultdict, Counter

class BPETokenizer:
    """Minimal character-level byte-pair-encoding (BPE) tokenizer.

    Text is split on whitespace; each word becomes its characters followed by
    the end-of-word marker ``'</w>'``, so merges never cross word boundaries
    and :meth:`decode` can turn the marker back into a space.

    Attributes:
        vocab_size: Upper bound on the number of merge operations performed by
            :meth:`train`. Despite its name it is a merge budget, not the final
            vocabulary size: the learned vocabulary holds every base symbol
            seen in training plus one entry per distinct merged token.
        bpe_merges: Learned ``(left, right)`` merge rules, in application order.
        token_to_id: Maps each vocabulary token to its integer id.
        id_to_token: Inverse of ``token_to_id``.
    """

    END_OF_WORD = '</w>'

    def __init__(self, vocab_size=100):
        """Create an untrained tokenizer allowing up to ``vocab_size`` merges."""
        self.vocab_size = vocab_size
        self.bpe_merges = []
        self.token_to_id = {}
        self.id_to_token = {}

    def get_stats(self, corpus):
        """Count adjacent symbol pairs across ``corpus``.

        Args:
            corpus: Iterable of words, each a sequence of symbol strings.

        Returns:
            ``defaultdict(int)`` mapping ``(left, right)`` pairs to occurrence
            counts. Keys are in first-seen order, which :meth:`train` relies on
            for deterministic tie-breaking.
        """
        pairs = defaultdict(int)
        for word in corpus:
            for pair in zip(word, word[1:]):
                pairs[pair] += 1
        return pairs

    @staticmethod
    def _merge_pair(symbols, pair):
        """Return a new list with each adjacent ``pair`` in ``symbols`` fused.

        Occurrences are merged left to right without overlap, so
        ``['a', 'a', 'a']`` with pair ``('a', 'a')`` becomes ``['aa', 'a']``.
        """
        first, second = pair
        merged = first + second
        out = []
        i = 0
        n = len(symbols)
        while i < n:
            if i < n - 1 and symbols[i] == first and symbols[i + 1] == second:
                out.append(merged)
                i += 2
            else:
                out.append(symbols[i])
                i += 1
        return out

    def merge_vocab(self, corpus, pair):
        """Return a new corpus (list of symbol lists) with ``pair`` merged.

        Symbols are compared literally. The previous regex-based version fed the
        merged text to ``re.sub`` as a replacement template, so a backslash in
        a symbol was treated as an escape sequence and could corrupt or delete
        the word.
        """
        return [self._merge_pair(word, pair) for word in corpus]

    def train(self, text):
        """Learn merge rules from ``text`` and rebuild the vocabulary.

        Any merges from a previous call are discarded. Pairs are merged
        greedily by frequency, up to ``self.vocab_size`` times. Training stops
        early once every word has collapsed into a single token.

        Token ids are deterministic: base symbols sorted by string value come
        first, followed by merged tokens in the order they were learned.
        """
        # Reset so retraining does not stack new rules on top of stale ones.
        self.bpe_merges = []
        corpus = [tuple(word) + (self.END_OF_WORD,) for word in text.split()]
        # Every base symbol must be in the vocabulary. Otherwise a word that
        # merges differently at encode time would silently lose characters.
        alphabet = sorted({symbol for word in corpus for symbol in word})

        for _ in range(self.vocab_size):
            pairs = self.get_stats(corpus)
            if not pairs:
                break
            # max() returns the first maximal key, so ties go to the pair seen first.
            best = max(pairs, key=pairs.get)
            self.bpe_merges.append(best)
            corpus = self.merge_vocab(corpus, best)

        # The vocabulary is the base alphabet plus every merged token, so any
        # intermediate token produced while encoding has an id.
        self.token_to_id = {}
        for token in alphabet + [''.join(pair) for pair in self.bpe_merges]:
            if token not in self.token_to_id:
                self.token_to_id[token] = len(self.token_to_id)
        self.id_to_token = {i: tok for tok, i in self.token_to_id.items()}

    def encode(self, text):
        """Encode ``text`` into a list of token ids.

        Each whitespace-separated word has the learned merges applied in
        training order. Symbols with no vocabulary entry (e.g. characters never
        seen during training) are silently skipped.
        """
        ids = []
        for word in text.split():
            symbols = list(word) + [self.END_OF_WORD]
            for pair in self.bpe_merges:
                symbols = self._merge_pair(symbols, pair)
            ids.extend(self.token_to_id[s] for s in symbols if s in self.token_to_id)
        return ids

    def decode(self, ids):
        """Decode token ids back into text.

        End-of-word markers become spaces, and surrounding whitespace is
        stripped. Raises ``KeyError`` for an id not in the vocabulary.
        """
        tokens = [self.id_to_token[i] for i in ids]
        return ''.join(t.replace(self.END_OF_WORD, ' ') for t in tokens).strip()
        
        
 # Test
bpe = BPETokenizer(vocab_size=50)
bpe.train("low lowest lower lowest low")

encoded = bpe.encode("lowest low")
# The concrete ids depend on how train() numbered the vocabulary, so they are
# printed for inspection rather than hard-coded as an expected value.
print("Encoded IDs:", encoded)

decoded = bpe.decode(encoded)
print("Decoded Text:", decoded)
# Expected output: Decoded Text: lowest low
assert decoded == "lowest low", decoded