# Implementación demostrativa de BPE (Byte Pair Encoding)

In [35]:
import re
from collections import defaultdict, Counter

In [40]:
class BPE:
    """Minimal, didactic Byte Pair Encoding (BPE) tokenizer.

    Words are split into characters followed by an end-of-word marker
    ``</w>``. ``learn_bpe`` repeatedly merges the most frequent adjacent
    symbol pair; ``apply_bpe`` replays the learned merges, in order, on new
    words; ``encode``/``decode`` convert between symbols and integer IDs.
    """

    def __init__(self):
        # Final training vocabulary after learning: space-separated symbols -> frequency.
        self.vocab: dict[str, int] = {}
        # Learned merges, in the order they must be replayed.
        self.bpe_merges: list[tuple[str, str]] = []
        # Bidirectional token <-> ID lookup tables.
        self.token_to_id: dict[str, int] = {}
        self.id_to_token: dict[int, str] = {}

    def get_vocab(self, corpus: list[str]):
        """Build the initial vocabulary from ``corpus``.

        Each word becomes its characters separated by single spaces, followed
        by the end-of-word marker ``</w>`` (e.g. "low" -> "l o w </w>").

        Returns:
            Counter mapping each space-separated word to its frequency.
        """
        vocab = Counter()
        for word in corpus:
            word = ' '.join(list(word)) + ' </w>'  # end-of-word marker
            vocab[word] += 1
        return vocab

    def get_stats(self, vocab):
        """Count adjacent symbol pairs across ``vocab``, weighted by word frequency.

        Returns:
            defaultdict(int) mapping ``(left, right)`` pairs to their counts,
            in first-occurrence order.
        """
        pairs = defaultdict(int)
        for word, freq in vocab.items():
            symbols = word.split()
            for i in range(len(symbols) - 1):
                pairs[(symbols[i], symbols[i + 1])] += freq
        return pairs

    def merge_vocab(self, pair, vocab):
        """Merge every occurrence of ``pair`` in every word of ``vocab``.

        Occurrences are replaced left to right without overlapping (``re.sub``
        semantics). The lookarounds ensure only whole symbols match, so the
        pair ("a", "b") does not match inside "xa b".

        Returns:
            A new dict with the merged words and the original frequencies.
        """
        bigram = re.escape(' '.join(pair))
        p = re.compile(r'(?<!\S)' + bigram + r'(?!\S)')
        new_vocab = {}
        for word in vocab:
            new_word = p.sub(''.join(pair), word)
            new_vocab[new_word] = vocab[word]
        return new_vocab

    def learn_bpe(self, corpus: list, num_merges=10):
        """Learn up to ``num_merges`` merges from ``corpus``.

        Any previously learned merges and ID mappings are discarded, so calling
        this twice does not stack a new run on top of an old one.
        Stops early when no adjacent pairs remain.
        """
        self.bpe_merges = []
        vocab = self.get_vocab(corpus)
        for _ in range(num_merges):
            pairs = self.get_stats(vocab)
            if not pairs:
                break
            # max() returns the first maximal key, so ties go to the pair seen first.
            best_pair = max(pairs, key=pairs.get)
            vocab = self.merge_vocab(best_pair, vocab)
            self.bpe_merges.append(best_pair)

        self.vocab = vocab
        self.build_token_mappings(vocab)

    def build_token_mappings(self, vocab):
        """Rebuild the token <-> ID tables from ``vocab`` and the learned merges.

        Tokens of the final ``vocab`` get IDs ``0..k-1`` in sorted order.
        Every other symbol ``apply_bpe`` can emit for words made of training
        characters is then appended, also sorted. These are the base symbols
        consumed by merges and the intermediate merge results. Without them,
        encoding a word that stops at an intermediate symbol would raise
        ``KeyError``.
        """
        final_tokens = sorted({token for word in vocab for token in word.split()})
        known = set(final_tokens)
        merge_symbols = set()
        for left, right in self.bpe_merges:
            merge_symbols.update((left, right, left + right))
        ordered = final_tokens + sorted(merge_symbols - known)

        # Fresh tables: stale entries from an earlier run could otherwise
        # collide with the new IDs.
        self.token_to_id = {token: idx for idx, token in enumerate(ordered)}
        self.id_to_token = dict(enumerate(ordered))

    def apply_bpe(self, word):
        """Segment ``word`` by replaying the learned merges in order.

        Each merge is applied in one left-to-right pass over non-overlapping
        occurrences, the same semantics as ``merge_vocab``. For example, "aaa"
        with the merge ("a", "a") becomes ["aa", "a", "</w>"]. A single pass
        suffices: a merged symbol is longer than either half, so it cannot
        form a new occurrence of the same pair.

        Returns:
            list[str]: the resulting symbols.
        """
        # Same splitting as get_vocab; whitespace characters in ``word`` are dropped.
        symbols = (' '.join(word) + ' </w>').split()
        for left, right in self.bpe_merges:
            merged_symbol = left + right
            out = []
            i = 0
            while i < len(symbols):
                if i + 1 < len(symbols) and symbols[i] == left and symbols[i + 1] == right:
                    out.append(merged_symbol)
                    i += 2
                else:
                    out.append(symbols[i])
                    i += 1
            symbols = out
        return symbols

    def encode(self, text):
        """Segment ``text`` with the learned merges and look up token IDs.

        Returns:
            ``(bpe_tokens, token_ids)``.

        Raises:
            KeyError: if a resulting token has no ID, e.g. a character that
                never appeared in the training corpus.
        """
        bpe_tokens = self.apply_bpe(text)
        try:
            token_ids = [self.token_to_id[token] for token in bpe_tokens]
        except KeyError as err:
            raise KeyError(
                f"token {err.args[0]!r} is not in the learned vocabulary"
            ) from err
        return bpe_tokens, token_ids

    def decode(self, token_ids):
        """Map token IDs back to text.

        The tokens are concatenated and every ``</w>`` marker is removed.

        Raises:
            KeyError: if an ID is not in ``id_to_token``.
        """
        tokens = [self.id_to_token[token_id] for token_id in token_ids]
        text = ''.join(tokens).replace('</w>', '')
        return text



# Ejemplo

In [41]:
# Toy training corpus
corpus = ["low", "lowest", "newer", "wider"]

# Learn 10 merges and list them in the order they were learned
bpe = BPE()
bpe.learn_bpe(corpus, num_merges=10)
print("BPE Merges:", bpe.bpe_merges)

# Segment a word with the learned merges and report symbols and their IDs
word = "low"
tokens, token_ids = bpe.encode(word)
print(
    f"Word: {word}",
    f"Tokens: {tokens}",
    f"Token IDs: {token_ids}",
    sep="\n",
)


BPE Merges: [('l', 'o'), ('lo', 'w'), ('e', 'r'), ('er', '</w>'), ('low', '</w>'), ('low', 'e'), ('lowe', 's'), ('lowes', 't'), ('lowest', '</w>'), ('n', 'e')]
Word: low
Tokens: ['low</w>']
Token IDs: [3]


In [42]:
# Round trip: map the IDs back to text (decode strips the `</w>` marker)
bpe.decode(token_ids)

'low'