In [None]:
# ============================================
# ASSIGNMENT — BYTE PAIR ENCODING & WORDPIECE
# ============================================
# Author: (Your Name)
# Input file: sentence_tokens.txt
# Requirements: None (pure Python implementation)
# Target: 32,000 merge steps / vocab size

import re
from collections import defaultdict, Counter
from math import log

# --------------------------------------------------
# STEP 1. Load Corpus
# --------------------------------------------------
# Read the corpus once; each kept entry is a line with surrounding
# whitespace removed, and lines that are blank after trimming are dropped.
with open("sentence_tokens.txt", "r", encoding="utf-8") as f:
    corpus = [stripped for stripped in (line.strip() for line in f) if stripped]

print(f"Loaded {len(corpus)} sentences.")

# --------------------------------------------------
# STEP 2. Prepare data for subword learning
# --------------------------------------------------
def corpus_to_char_tokens(corpus):
    """Spell out every word of every sentence as space-separated characters.

    Each sentence is split on whitespace; each resulting word becomes a single
    string whose characters are joined by one space and followed by the
    end-of-word marker, e.g. ``"cat"`` -> ``"c a t </w>"``.

    Returns a list with one entry per input sentence, each entry being the
    list of spelled-out words of that sentence (empty for a blank sentence).
    """
    def spell_out(word):
        # Joining a string iterates over its characters.
        return " ".join(word) + " </w>"

    return [
        [spell_out(word) for word in sentence.strip().split()]
        for sentence in corpus
    ]

# Character-level rendering of the corpus: one list of spelled-out words
# per sentence.
char_level_corpus = corpus_to_char_tokens(corpus)

# Tally how many times each spelled-out word occurs across all sentences.
word_freq = Counter()
for sent in char_level_corpus:
    word_freq.update(sent)

print(f"Unique words in corpus: {len(word_freq)}")

# --------------------------------------------------
# STEP 3. Helper functions for BPE
# --------------------------------------------------
def get_pair_freq(word_freq):
    """Tally adjacent symbol pairs, weighted by word frequency.

    ``word_freq`` maps space-separated symbol strings to their counts. For
    every word, each pair of neighbouring symbols contributes the word's
    count to that pair's total.

    Returns a ``defaultdict(int)`` keyed by ``(left, right)`` tuples.
    """
    pair_counts = defaultdict(int)
    for spaced_word, count in word_freq.items():
        tokens = spaced_word.split()
        # Pair each symbol with its right-hand neighbour.
        for left, right in zip(tokens, tokens[1:]):
            pair_counts[(left, right)] += count
    return pair_counts


def merge_vocab(pair, v_in):
    """Fuse every occurrence of a symbol pair into a single symbol.

    Args:
        pair: Two-element sequence ``(left, right)`` of symbols to merge.
        v_in: Mapping from space-separated symbol strings to frequencies.

    Returns:
        A new dict keyed by the rewritten symbol strings. If two input keys
        become identical after the merge, their frequencies are added
        together instead of one silently overwriting the other.
    """
    bigram = re.escape(" ".join(pair))
    # The lookarounds require whitespace (or a string edge) on both sides so
    # that only whole symbols match, never the tail/head of a longer symbol.
    pattern = re.compile(r"(?<!\S)" + bigram + r"(?!\S)")
    merged_symbol = "".join(pair)

    def replace(_match):
        # A callable replacement is inserted verbatim; a plain string would be
        # treated as a template, so a backslash inside a token would be
        # interpreted as an escape (raising re.error or altering the text).
        return merged_symbol

    v_out = {}
    for word, freq in v_in.items():
        w_out = pattern.sub(replace, word)
        v_out[w_out] = v_out.get(w_out, 0) + freq
    return v_out


# --------------------------------------------------
# STEP 4. BYTE PAIR ENCODING (BPE)
# --------------------------------------------------
def byte_pair_encoding(word_freq, num_merges=32000):
    """Learn BPE merges from a table of spelled-out words.

    Starting from a copy of ``word_freq`` (space-separated symbol strings
    mapped to counts), repeatedly fuses the pair of adjacent symbols with the
    highest total frequency. Stops after ``num_merges`` merges or as soon as
    no adjacent pair is left. Progress is printed every 1000 merges.

    Returns:
        ``(merges, final_vocab)`` where ``merges`` is the list of merged
        pairs in the order they were learned and ``final_vocab`` is the set
        of symbols present in the words after the last merge.
    """
    vocab = word_freq.copy()  # leave the caller's table untouched
    merges = []

    for step in range(1, num_merges + 1):
        pair_counts = get_pair_freq(vocab)
        if not pair_counts:
            # Every word has collapsed to a single symbol.
            break

        top_pair = max(pair_counts, key=pair_counts.get)
        vocab = merge_vocab(top_pair, vocab)
        merges.append(top_pair)

        if step % 1000 == 0:
            print(f"Merge step {step}: {top_pair}")

    # Collect the symbols that remain in the fully merged words.
    final_vocab = {symbol for word in vocab for symbol in word.split()}
    return merges, final_vocab


# Learn the BPE merge list and token set from the word-frequency table.
print("\nTraining Byte Pair Encoding...")
bpe_merges, bpe_vocab = byte_pair_encoding(word_freq, 32000)
print(f"BPE training complete. Final vocab size ≈ {len(bpe_vocab)}")

# --------------------------------------------------
# STEP 5. WORDPIECE IMPLEMENTATION
# --------------------------------------------------
def train_wordpiece(corpus, vocab_size=32000):
    """Train a simplified WordPiece vocabulary.

    Starts from a character-level vocabulary and repeatedly merges the
    adjacent symbol pair with the highest WordPiece score

        score(a, b) = freq(a b) / (freq(a) * freq(b))

    i.e. the pair whose parts occur together most often relative to how often
    they occur at all. This is what distinguishes WordPiece from BPE, which
    always merges the most frequent pair. Ties on the score are broken in
    favour of the more frequent pair.

    Words are represented as space-separated symbols ending in the ``</w>``
    end-of-word marker (the same representation the BPE code uses).

    Args:
        corpus: Iterable of sentences (strings); words are split on whitespace.
        vocab_size: Training stops once the vocabulary reaches this size, or
            earlier if no adjacent pair is left to merge.

    Returns:
        ``(merges, vocab)`` where ``merges`` is the list of merged pairs in
        learning order and ``vocab`` is a dict mapping each token (initial
        characters, ``"##UNK"``, and every merged token) to ``1``.
    """
    # Word counts over the raw corpus.
    word_freq = Counter(word for sent in corpus for word in sent.split())

    # Character-level initialization, plus the unknown-token placeholder.
    vocab = {ch: 1 for word in word_freq for ch in word}
    vocab["##UNK"] = 1

    # Spell each word out as characters followed by the end-of-word marker.
    subwords = {" ".join(w) + " </w>": f for w, f in word_freq.items()}

    merges = []
    while len(vocab) < vocab_size:
        pair_freq = get_pair_freq(subwords)
        if not pair_freq:
            break

        # Frequency of each individual symbol, needed to normalize pair counts.
        symbol_freq = Counter()
        for word, freq in subwords.items():
            for symbol in word.split():
                symbol_freq[symbol] += freq

        def wordpiece_score(pair):
            left, right = pair
            count = pair_freq[pair]
            return (count / (symbol_freq[left] * symbol_freq[right]), count)

        best = max(pair_freq, key=wordpiece_score)
        merges.append(best)
        subwords = merge_vocab(best, subwords)
        new_token = "".join(best)
        vocab[new_token] = 1
        if len(vocab) % 1000 == 0:
            print(f"WordPiece vocab size: {len(vocab)}")

    return merges, vocab


# Learn the WordPiece merge list and vocabulary directly from the raw sentences.
print("\nTraining WordPiece...")
wp_merges, wp_vocab = train_wordpiece(corpus, 32000)
print(f"WordPiece training complete. Final vocab size ≈ {len(wp_vocab)}")

# --------------------------------------------------
# STEP 6. Encode a sentence example
# --------------------------------------------------
def encode_with_merges(sentence, merges, end_token="</w>"):
    """Tokenize a sentence by replaying learned BPE merges.

    Each whitespace-delimited word is split into characters plus
    ``end_token``. Merges are then applied in the order they were learned:
    at every step the adjacent pair with the lowest position in ``merges`` is
    fused, so the result matches the segmentation the training procedure
    would produce for the same word.

    Args:
        sentence: Text to tokenize; words are split on whitespace.
        merges: Sequence of ``(left, right)`` pairs, earliest-learned first.
        end_token: Marker appended to every word before merging.

    Returns:
        Flat list of tokens for the whole sentence.
    """
    # Map each pair to its priority once, so lookups are O(1) instead of a
    # linear scan of the merge list; the first occurrence of a pair wins.
    ranks = {}
    for rank, pair in enumerate(merges):
        ranks.setdefault(tuple(pair), rank)

    tokens = []
    for word in sentence.split():
        symbols = list(word) + [end_token]
        while len(symbols) > 1:
            candidates = [p for p in zip(symbols, symbols[1:]) if p in ranks]
            if not candidates:
                break
            # Apply the earliest-learned merge, not the leftmost one.
            best = min(candidates, key=ranks.__getitem__)
            merged = "".join(best)

            new_symbols = []
            i = 0
            last = len(symbols) - 1
            while i <= last:
                if i < last and (symbols[i], symbols[i + 1]) == best:
                    new_symbols.append(merged)
                    i += 2
                else:
                    new_symbols.append(symbols[i])
                    i += 1
            symbols = new_symbols
        tokens.extend(symbols)
    return tokens

# Demonstrate encoding on a lower-cased sample, using only the first 500
# learned BPE merges.
sample = "Natural language processing"
print(
    "\nEncoding sample with BPE:\n",
    encode_with_merges(sample.lower(), bpe_merges[:500]),
)

# --------------------------------------------------
# STEP 7. Save outputs (optional)
# --------------------------------------------------
# Write each vocabulary to its own file: one token per line, sorted.
with open("bpe_vocab.txt", "w", encoding="utf-8") as f:
    f.writelines(tok + "\n" for tok in sorted(bpe_vocab))

# Sorting the WordPiece dict iterates over its keys (the tokens).
with open("wordpiece_vocab.txt", "w", encoding="utf-8") as f:
    f.writelines(tok + "\n" for tok in sorted(wp_vocab))

print("\nSaved bpe_vocab.txt and wordpiece_vocab.txt")


Loaded 144154 sentences.
Unique words in corpus: 91309

Training Byte Pair Encoding...
Merge step 1000: ('मै', 'च</w>')
Merge step 2000: ('सा', 'ं')
Merge step 3000: ('मुस्लि', 'म</w>')
Merge step 4000: ('म', 'क')
Merge step 5000: ('त', 'ें')
Merge step 6000: ('एम', 'सी')
Merge step 7000: ('आकर्', 'षित</w>')
Merge step 8000: ('हा', 'ली</w>')
Merge step 9000: ('स्ट्र', 'ा')
Merge step 10000: ('शे', 'र')
Merge step 11000: ('ना', 'नक</w>')
Merge step 12000: ('पीपु', 'ल्स</w>')
Merge step 13000: ('ढूंढ', '</w>')
Merge step 14000: ('ग', 'पुर</w>')
Merge step 15000: ('बी', 'जापुर</w>')
Merge step 16000: ('किसा', 'नो</w>')
Merge step 17000: ('व', 'ता</w>')
Merge step 18000: ('क', 'म्प</w>')
Merge step 19000: ('बि', 'खरा</w>')
Merge step 20000: ('करो', 'ड़')
Merge step 21000: ('एफ', 'एसएल</w>')
Merge step 22000: ('सेलिब्रि', 'टी</w>')
Merge step 23000: ('औचि', 'त्य</w>')
Merge step 24000: ('सिसौ', 'दिया</w>')
Merge step 25000: ('टू', 'ल')
Merge step 26000: ('कल्याण', 'पुर</w>')
Merge step 2700