In [1]:
import re
from collections import defaultdict, Counter
from typing import List, Dict, Tuple, Optional
import json
import pickle

In [2]:
def get_syllable(label: str, burmese_consonant: str, others: str) -> List[str]:
    """
    Cut one chunk of Burmese text into syllable-sized pieces.

    The function works by inserting spaces at syllable boundaries with a
    sequence of regex substitutions and then splitting on whitespace.

    Args:
        label (str): Text to segment.
        burmese_consonant (str): Contents of a regex character class (without
            the surrounding brackets) matching syllable-initial consonants.
        others (str): Contents of a regex character class (without the
            surrounding brackets) matching characters that always start a
            new piece.

    Returns:
        List[str]: The whitespace-separated pieces, in order. Empty input
        yields an empty list.
    """
    # A space goes in front of every consonant that is not preceded by ္ and
    # not followed by one of ် ္ | ့ (the "|" sits inside the character class,
    # so it is matched literally), and in front of every character in `others`.
    break_points = r"(?<![္])([" + burmese_consonant + r"])(?![်္|့])|([" + others + r"])"
    spaced = re.sub(break_points, r" \1\2", label).strip()

    # Follow-up fixes, applied in this exact order:
    cleanup_rules = (
        # separate an ASCII letter/digit from a directly preceding character in က-ၴ
        (r"(?<=[က-ၴ])([a-zA-Z0-9])", r" \1"),
        # glue two digits that are separated by whitespace, leaving one trailing space
        (r"([0-9၀-၉])\s+([0-9၀-၉])\s*", r"\1\2 "),
        # digit, whitespace, "+" -> single space between them and one after the "+"
        (r"([0-9၀-၉])\s+(\+)", r"\1 \2 "),
    )
    for pattern, replacement in cleanup_rules:
        spaced = re.sub(pattern, replacement, spaced)

    return spaced.split()

def syllable_split(label: str) -> List[str]:
    """
    Turn Burmese text into a flat list of syllables with word gaps preserved.

    The text is cut on whitespace, every chunk is segmented by
    ``get_syllable``, and a single ``' '`` element separates the syllables of
    consecutive chunks. No separator follows the final chunk.

    Args:
        label (str): Input Burmese text.

    Returns:
        List[str]: Syllables interleaved with ``' '`` separators; an empty
        list when the input contains no non-whitespace characters.
    """
    burmese_consonant = 'က-အ'
    others = r"ဣဤဥဦဧဩဪဿ၌၍၏၀-၉၊။!-/:-@[-`{-~\s.,"

    pieces = []
    for chunk in label.split():
        pieces.extend(get_syllable(chunk, burmese_consonant, others))
        pieces.append(' ')

    # Every chunk contributed a trailing separator; the last one is surplus.
    return pieces[:-1]

In [3]:
class SyllableBPE:
    """
    BPE-style subword tokenizer whose base symbols are syllables.

    Training starts from the syllables produced by ``segment_to_syllables``
    and repeatedly merges the most frequent adjacent pair inside words.
    The last syllable of every word carries ``word_end_token`` so that word
    boundaries can be restored when decoding.
    """

    def __init__(self):
        self.vocab = {}  # Maps tokens to IDs
        self.inverse_vocab = {}  # Maps IDs to tokens
        self.merges = []  # List of (token1, token2) merge rules, in learned order
        self.word_end_token = "</w>"  # Special token for word boundaries

    def segment_to_syllables(self, text: str) -> List[str]:
        """
        Split ``text`` into syllables.

        Delegates to the module-level ``syllable_split``. Override this method
        in a subclass to plug in a different segmenter.
        """
        return syllable_split(text)

    def preprocess_text(self, text: str) -> List[List[str]]:
        """
        Split text on whitespace, segment each word into syllables, and append
        the word-end token to the last syllable of every word.

        Words for which the segmenter returns nothing are skipped.
        """
        words = text.strip().split()
        syllablized_words = []
        for word in words:
            syllables = self.segment_to_syllables(word)
            if syllables:
                syllables[-1] += self.word_end_token  # Add </w> to last syllable
                syllablized_words.append(syllables)
        return syllablized_words

    def get_pair_frequencies(self, words: List[List[str]]) -> Dict[Tuple[str, str], int]:
        """
        Count how often each adjacent token pair occurs inside the given words.

        Pairs never span two words. Keys appear in first-occurrence order.
        """
        pairs = defaultdict(int)
        for word in words:
            for left, right in zip(word, word[1:]):
                pairs[(left, right)] += 1
        return pairs

    def _merge_tokens(self, pair: Tuple[str, str]) -> str:
        """
        Build the token string that results from merging ``pair``.

        Any word-end marker is stripped from the right-hand token and
        re-appended once at the very end when the right-hand token ended a
        word, so the merged token keeps its "ends a word" status.
        """
        left, right = pair
        marker = self.word_end_token
        merged = left + right.replace(marker, "")
        if right.endswith(marker):
            merged += marker
        return merged

    def merge_pair(self, pair: Tuple[str, str], words: List[List[str]]) -> List[List[str]]:
        """
        Merge every left-to-right, non-overlapping occurrence of ``pair`` in
        all words and return the updated words (the input is not modified).
        """
        new_token = self._merge_tokens(pair)
        new_words = []
        for word in words:
            new_word = []
            i = 0
            while i < len(word):
                if i < len(word) - 1 and (word[i], word[i + 1]) == pair:
                    new_word.append(new_token)
                    i += 2  # both halves of the pair are consumed
                else:
                    new_word.append(word[i])
                    i += 1
            new_words.append(new_word)
        return new_words

    def train(self, texts: List[str], num_merges: int, vocab_size: Optional[int] = None):
        """
        Train the BPE model on a list of texts.

        Any previously learned vocabulary and merge rules are discarded.

        Args:
            texts: Training texts; each is split into words and syllables via
                ``preprocess_text``.
            num_merges: Upper bound on the number of merge rules to learn.
            vocab_size: Optional cap on the vocabulary size. No merge is
                performed once the vocabulary has reached this size. The base
                syllables are always kept, even if they alone exceed it.
        """
        # Preprocess all texts into syllable-segmented words
        all_words = []
        for text in texts:
            all_words.extend(self.preprocess_text(text))

        # Start from a clean state so that retraining an instance does not
        # keep merge rules learned from an earlier corpus.
        self.vocab = {}
        self.inverse_vocab = {}
        self.merges = []

        # Initialize vocabulary with unique syllables, in first-seen order
        for word in all_words:
            for syllable in word:
                if syllable not in self.vocab:
                    token_id = len(self.vocab)
                    self.vocab[syllable] = token_id
                    self.inverse_vocab[token_id] = syllable

        # Perform merges
        for _ in range(num_merges):
            # Check the cap *before* merging so it is honoured even when the
            # base vocabulary already meets it.
            if vocab_size and len(self.vocab) >= vocab_size:
                break
            pairs = self.get_pair_frequencies(all_words)
            if not pairs:
                break
            # Ties are resolved in favour of the pair seen first.
            most_frequent_pair = max(pairs, key=pairs.get)
            all_words = self.merge_pair(most_frequent_pair, all_words)
            self.merges.append(most_frequent_pair)

            # Add the new token to the vocabulary. If the same string already
            # has an ID (e.g. the segmenter also emits it as a base syllable),
            # keep that ID so vocab and inverse_vocab stay one-to-one.
            new_token = self._merge_tokens(most_frequent_pair)
            if new_token not in self.vocab:
                token_id = len(self.vocab)
                self.vocab[new_token] = token_id
                self.inverse_vocab[token_id] = new_token

    def encode(self, text: str) -> List[int]:
        """
        Encode text into token IDs.

        Each word is consumed left to right; at every position the longest
        run of syllables whose concatenation is in the vocabulary is emitted.
        The learned ``merges`` list is not consulted here, only ``vocab``.

        Raises:
            KeyError: If a syllable (and every longer run starting at it) is
                missing from the vocabulary.
        """
        marker = self.word_end_token
        words = self.preprocess_text(text)
        encoded = []
        for word in words:
            while word:
                # Try to find the longest matching token
                best_match = word[0]
                best_match_len = 1
                for i in range(1, len(word) + 1):
                    # Only the last syllable of a word carries the marker, so
                    # it is stripped from the joined text and re-added when
                    # the run reaches the end of the word.
                    candidate = "".join(word[:i]).replace(marker, "") + (
                        marker if word[i - 1].endswith(marker) else ""
                    )
                    if candidate in self.vocab:
                        best_match = candidate
                        best_match_len = i
                if best_match not in self.vocab:
                    raise KeyError(
                        f"Token {best_match!r} is not in the vocabulary; "
                        "it was not seen during training"
                    )
                encoded.append(self.vocab[best_match])
                word = word[best_match_len:]
        return encoded

    def decode(self, token_ids: List[int]) -> str:
        """
        Decode token IDs back to text.

        Tokens are concatenated and every word-end marker becomes a space;
        leading/trailing whitespace is stripped from the result.

        Raises:
            KeyError: If an ID is not present in ``inverse_vocab``.
        """
        tokens = [self.inverse_vocab[tid] for tid in token_ids]
        text = "".join(tokens).replace(self.word_end_token, " ")
        return text.strip()

    def save(self, prefix: str):
        """
        Save vocabulary and merge rules to files.

        Writes ``<prefix>_vocab.json`` (UTF-8 JSON, token -> ID) and
        ``<prefix>_merges.pkl`` (pickled list of merge pairs).
        """
        with open(f"{prefix}_vocab.json", "w", encoding="utf-8") as f:
            json.dump(self.vocab, f, ensure_ascii=False)
        with open(f"{prefix}_merges.pkl", "wb") as f:
            pickle.dump(self.merges, f)

    def load(self, prefix: str):
        """
        Load vocabulary and merge rules from the files written by ``save``.

        SECURITY: the merges file is read with ``pickle.load``, which can
        execute arbitrary code. Only load files from a trusted source.
        """
        with open(f"{prefix}_vocab.json", "r", encoding="utf-8") as f:
            raw_vocab = json.load(f)
        with open(f"{prefix}_merges.pkl", "rb") as f:
            # NOTE(security): unpickling untrusted data is unsafe.
            self.merges = pickle.load(f)
        # Normalise IDs to int and rebuild the reverse mapping.
        self.vocab = {token: int(idx) for token, idx in raw_vocab.items()}
        self.inverse_vocab = {idx: token for token, idx in self.vocab.items()}

In [6]:
# Example usage: train on a tiny corpus, round-trip one sentence through
# encode/decode, then save the model and load it into a fresh instance.
if __name__ == "__main__":
    # Small Burmese training corpus (swap in a larger corpus for real use)
    sample_texts = [
        "ပညာရေးဝန်ကြီးဌာနသည် အထက်တန်းကျောင်းများတွင် သင်ကြားမှုအတွက် သင်ကြားရေးအထောက်အကူပြု စာအုပ်များကို ထုတ်ဝေခဲ့သည်။",
        "တော်လှန်ရေးအထောက်အကူအတွက်သင်ကြားသည်။",
        "ဗိုလ်ချုပ်က တော်ကောက်သည်"
    ]

    # Initialize and train the model (learns at most 10 merge rules)
    bpe = SyllableBPE()
    bpe.train(sample_texts, num_merges=10)

    # Encode a new text. encode() looks tokens up in the trained vocabulary,
    # so text containing syllables absent from sample_texts raises KeyError.
    test_text = "သည် အထက်တန်းကျောင်းတွင် သင်ကြားသည်"
    encoded = bpe.encode(test_text)
    print("Encoded:", encoded)

    # Decode back to text (word-end markers become spaces)
    decoded = bpe.decode(encoded)
    print("Decoded:", decoded)

    # Save the model: writes burmese_bpe_vocab.json and burmese_bpe_merges.pkl
    # relative to the current working directory
    bpe.save("burmese_bpe")

    # Load the model into a fresh instance from those two files
    new_bpe = SyllableBPE()
    new_bpe.load("burmese_bpe")
    print("Loaded model vocab size:", len(new_bpe.vocab))

Encoded: [7, 8, 9, 10, 11, 13, 36, 7]
Decoded: သည် အထက်တန်းကျောင်းတွင် သင်ကြားသည်
Loaded model vocab size: 46
