In [3]:
import regex as re
from typing import List, Dict, Tuple
from tokenizers import Tokenizer, models, trainers, pre_tokenizers, decoders
from collections import defaultdict
import heapq
from math import log

In [4]:
def get_syllable(label: str, burmese_consonant: str, others: str) -> List[str]:
    """
    Segment a Burmese word into syllables using regex-based rules.
    
    Args:
        label (str): Input Burmese text or word.
        burmese_consonant (str): Regex range for Burmese consonants.
        others (str): Regex range for other characters (vowels, punctuation, etc.).
    
    Returns:
        List[str]: List of syllables.
    """
    # Define regex patterns for Burmese consonants and other characters
    # label = re.sub(r"(?<![္])(["+burmese_consonant+"])(?![်္|့])|(["+others+"])", r" \1\2", label).strip()
    # label = re.sub('(?<=[က-ၴ])([a-zA-Z0-9])', r' \1', label)
    # label = re.sub('([0-9၀-၉])\s+([0-9၀-၉])\s*', r'\1\2 ', label)
    # label = re.sub('([0-9၀-၉])\s+(\+)', r'\1 \2 ', label)
    # label = label.split()
    label = re.sub(r"(?<![္])([" + burmese_consonant + r"])(?![်္|့])|([" + others + r"])", r" \1\2", label).strip()
    label = re.sub(r"(?<=[က-ၴ])([a-zA-Z0-9])", r" \1", label)
    label = re.sub(r"([0-9၀-၉])\s+([0-9၀-၉])\s*", r"\1\2 ", label)
    label = re.sub(r"([0-9၀-၉])\s+(\+)", r"\1 \2 ", label)
    label = label.split()
    
    return label

def syllable_split(label: str) -> List[str]:
    """
    Split Burmese text into syllables, handling spaces and word boundaries.
    
    Args:
        label (str): Input Burmese text.
    
    Returns:
        List[str]: List of syllables.
    """
    burmese_consonant = 'က-အ'
    others = r"ဣဤဥဦဧဩဪဿ၌၍၏၀-၉၊။!-/:-@[-`{-~\s.,"
    
    label_syllable = [get_syllable(s, burmese_consonant, others) + [' '] for s in label.split()]
    return [s for sublist in label_syllable for s in sublist][:-1]

In [5]:
# Trie implementation for efficient dictionary lookups
class TrieNode:
    def __init__(self):
        self.children = {}
        self.is_end = False
        self.token = None
        self.freq = 0.0  # Frequency for scoring

class Trie:
    def __init__(self):
        self.root = TrieNode()

    def insert(self, word: str, freq: float = 1.0):
        node = self.root
        for char in word:
            if char not in node.children:
                node.children[char] = TrieNode()
            node = node.children[char]
        node.is_end = True
        node.token = word
        node.freq = freq
        # print(f"Inserted token: {word}, Frequency: {freq}")

    def find_all_matches(self, syllables: List[str], start: int) -> List[Tuple[str, float, int]]:
        matches = []
        max_len = min(10, len(syllables) - start)  # Limit max token length (adjust as needed)
        
        for length in range(1, max_len + 1):
            candidate = ''.join(syllables[start:start + length])
            node = self.root
            valid = True
            # Traverse trie character by character
            for char in candidate:
                if char not in node.children:
                    valid = False
                    break
                node = node.children[char]
            if valid and node.is_end:
                matches.append((node.token, node.freq, start + length))
        return matches

In [None]:
class BurmeseTokenizer:
    def __init__(self, dictionary: Dict[str, str], token_freqs: Dict[str, float] = None, bpe_vocab_size: int = 10000):
        """
        Initialize the Burmese tokenizer with a root-and-particle dictionary and optional token frequencies.
        
        Args:
            dictionary (Dict[str, str]): Dictionary mapping words to 'root' or 'particle'.
            token_freqs (Dict[str, float]): Token frequencies for disambiguation (optional).
            bpe_vocab_size (int): Vocabulary size for BPE training.
        """
        self.dictionary = dictionary
        self.bpe_tokenizer = None
        self.bpe_vocab_size = bpe_vocab_size
        # Initialize trie with dictionary
        self.trie = Trie()
        for word in dictionary:
            freq = token_freqs.get(word, 1.0) if token_freqs else 1.0
            self.trie.insert(word, freq)

    def segment_syllables(self, text: str) -> List[str]:
        """Segment Burmese text into syllables using syllable_split."""
        return syllable_split(text)
    
    def maximum_matching(self, syllables: List[str], beam_size: int = 3) -> List[str]:
        """
        Recombine syllables into root words and particles using probability-based maximum matching.
        
        Args:
            syllables (List[str]): List of syllables.
            beam_size (int): Number of segmentations to consider (for disambiguation).
        
        Returns:
            List[str]: List of tokenized root words and particles.
        """
        # Beam search state: (position, tokens, score)
        beam = [(0, [], 0.0)]  # Start with position 0, empty tokens, and score 0.0
        final_segmentations = []

        while beam:
            new_beam = []
            for pos, tokens, score in beam:
                if pos >= len(syllables):
                    # If we reach the end of the syllables, add the segmentation to final results
                    final_segmentations.append((tokens, score))
                    continue

                # Find all possible matches at the current position
                matches = self.trie.find_all_matches(syllables, pos)
                # print(f"Position: {pos}, Matches: {matches}")
                if not matches:
                    # No match: treat as single syllable with low frequency
                    matches = [(syllables[pos], 0.01, pos + 1)]  # Assign a very low frequency for unmatched syllables

                # Extend the beam with each match
                for token, freq, next_pos in matches:
                    new_tokens = tokens + [token]
                    # Score: log-sum of frequencies with a penalty for short tokens
                    new_score = score + log(max(freq, 1e-10)) - 0.01 * len(token)
                    new_beam.append((next_pos, new_tokens, new_score))

            # Keep the top-k segmentations based on score
            beam = heapq.nlargest(beam_size, new_beam, key=lambda x: x[2])

        # Select the best segmentation from the final results
        if final_segmentations:
            best_tokens, _ = max(final_segmentations, key=lambda x: x[1])
            return best_tokens

        # Fallback: return syllables as-is if no valid segmentation is found
        return syllables

    def train_bpe(self, texts: List[str], special_tokens: List[str] = ["[PAD]", "[UNK]", "[CLS]", "[SEP]", "[MASK]"]):
        """
        Train a BPE tokenizer with dictionary constraints on a raw Burmese text corpus.
        
        Args:
            texts (List[str]): List of raw Burmese texts (e.g., sentences from your dataset).
            special_tokens (List[str]): Special tokens for NLP frameworks.
        """
        self.bpe_tokenizer = Tokenizer(models.BPE(unk_token="[UNK]"))
        self.bpe_tokenizer.pre_tokenizer = pre_tokenizers.Whitespace()
        self.bpe_tokenizer.decoder = decoders.ByteLevel()

        # Pre-tokenize corpus using maximum_matching to protect dictionary tokens
        pretokenized_texts = []
        for text in texts:
            syllables = self.segment_syllables(text)
            tokens = self.maximum_matching(syllables)
            pretokenized_texts.append(' '.join(tokens))

        # Add dictionary tokens as protected tokens
        protected_tokens = list(self.dictionary.keys()) + special_tokens
        trainer = trainers.BpeTrainer(
            vocab_size=self.bpe_vocab_size,
            special_tokens=protected_tokens,
            initial_alphabet=[c for word in self.dictionary.keys() for c in word]
        )
        self.bpe_tokenizer.train_from_iterator(pretokenized_texts, trainer)
        self.bpe_tokenizer.save("burmese_bpe_tokenizer.json")
    
    def tokenize(self, text: str, use_bpe: bool = True) -> Tuple[List[str], List[int]]:
        """
        Tokenize Burmese text into sub-word units.
        
        Args:
            text (str): Input Burmese text.
            use_bpe (bool): Whether to apply BPE after dictionary-based tokenization.
        
        Returns:
            Tuple[List[str], List[int]]: List of tokens and their corresponding IDs.
        """
        syllables = self.segment_syllables(text)
        tokens = self.maximum_matching(syllables)
        if use_bpe and self.bpe_tokenizer:
            encoded = self.bpe_tokenizer.encode(' '.join(tokens))
            return encoded.tokens, encoded.ids
        return tokens, list(range(len(tokens)))
    
    def decode(self, token_ids: List[int]) -> str:
        """Decode token IDs back to Burmese text."""
        if self.bpe_tokenizer:
            return self.bpe_tokenizer.decode(token_ids)
        return ''.join(self.dictionary.get(id, '[UNK]') for id in token_ids)

    def batch_tokenize(self, texts: List[str], use_bpe: bool = True) -> List[Tuple[List[str], List[int]]]:
        """Tokenize a batch of texts for efficiency."""
        return [self.tokenize(text, use_bpe) for text in texts]

# Example Usage
if __name__ == "__main__":
    from math import log

    # Example dictionary with frequencies
    dictionary = {
        "ပညာရေး": "root",
        "ပညာ": "root",
        "ဝန်ကြီးဌာန": "root",
        "ဝန်ကြီး": "root",
        "သည်": "particle",
        "အထက်တန်း": "root",
        "အထက်": "root",
        "ကျောင်း": "root",
        "များ": "particle",
        "တွင်": "particle",
        "သင်ကြား": "root",
        "ရေး": "particle",
        "အတွက်": "particle",
        "အထောက်အကူ": "root",
        "ပြု": "root",
        "စာအုပ်": "root",
        "ကို": "particle",
        "ထုတ်ဝေ": "root",
        "ထုတ်": "root",
        "ဝေ": "root",
        "ခဲ့": "particle",
        "။": "punctuation",
        "၊": "punctuation",
        "ဗိုလ်ချုပ်ကတော်": "root",
        "ဗိုလ်ချုပ်": "root",
        "က": "particle",
        "တော်လှန်": "root",
        "လှန်": "root",
        "တော်": "root",
    }

    # Example token frequencies (precomputed from your corpus)
    token_freqs = {
        "ဗိုလ်ချုပ်ကတော်": 0.01,  # Rare
        "ဗိုလ်ချုပ်": 0.1,           # Common
        "က": 0.5,                   # Very common
        "တော်လှန်": 0.05,          # Moderately common
        "လှန်": 0.02,              # Less common
        "တော်": 0.03,             # Moderately common
        "သည်": 0.4,               # Very common
        # Add frequencies for other tokens
    }

    # Initialize tokenizer
    tokenizer = BurmeseTokenizer(dictionary, token_freqs)

    # Example texts
    text1 = "ပညာရေးဝန်ကြီးဌာနသည် အထက်တန်းကျောင်းများတွင် သင်ကြားမှုအတွက် သင်ကြားရေးအထောက်အကူပြု စာအုပ်များကို ထုတ်ဝေခဲ့သည်။"
    text2 = "တော်လှန်ရေးအထောက်အကူအတွက်သင်ကြားသည်။"
    text3 = "ဗိုလ်ချုပ်ကတော်လှန်သည်"

    # Tokenize without BPE
    tokens_1, token_ids_1 = tokenizer.tokenize(text1, use_bpe=False)
    print("Tokens (no BPE, text1):", tokens_1)
    print("Token IDs (no BPE, text1):", token_ids_1)

    # Tokenize without BPE
    tokens_2, token_ids_2 = tokenizer.tokenize(text2, use_bpe=False)
    print("Tokens (no BPE, text2):", tokens_2)
    print("Token IDs (no BPE, text2):", token_ids_2)

    # Train BPE on your raw Burmese text corpus
    corpus = [text1, text2] * 50  # Replace with your actual raw text corpus
    tokenizer.train_bpe(corpus)

    # Tokenize with BPE
    tokens_1_BPE, token_ids_1_BPE = tokenizer.tokenize(text1, use_bpe=True)
    print("Tokens (with BPE, text1):", tokens_1_BPE)
    print("Token IDs (with BPE, text1):", token_ids_1_BPE)

    # Tokenize with BPE
    tokens_2_BPE, token_ids_2_BPE = tokenizer.tokenize(text2, use_bpe=True)
    print("Tokens (with BPE, text2):", tokens_2_BPE)
    print("Token IDs (with BPE, text2):", token_ids_2_BPE)

    # Decode example
    # decoded_text = tokenizer.decode(token_ids)
    # print("Decoded text:", decoded_text)

    # Batch tokenize example
    # batch_results = tokenizer.batch_tokenize([text1, text2], use_bpe=False)
    # print("Batch tokens (no BPE):", [tokens for tokens, _ in batch_results])

Tokens (no BPE, text1): ['ပညာ', 'ရေး', 'ဝန်ကြီးဌာန', 'သည်', ' ', 'အထက်တန်း', 'ကျောင်း', 'များ', 'တွင်', ' ', 'သင်ကြား', 'မှု', 'အတွက်', ' ', 'သင်ကြား', 'ရေး', 'အထောက်အကူ', 'ပြု', ' ', 'စာအုပ်', 'များ', 'ကို', ' ', 'ထုတ်ဝေ', 'ခဲ့', 'သည်', '။']
Token IDs (no BPE, text1): [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26]
Tokens (no BPE, text2): ['တော်လှန်', 'ရေး', 'အထောက်အကူ', 'အတွက်', 'သင်ကြား', 'သည်', '။']
Token IDs (no BPE, text2): [0, 1, 2, 3, 4, 5, 6]
Tokens (with BPE, text1): ['ပညာ', 'ရေး', 'ဝန်ကြီးဌာန', 'သည်', 'အထက်တန်း', 'ကျောင်း', 'များ', 'တွင်', 'သင်ကြား', 'မှု', 'အတွက်', 'သင်ကြား', 'ရေး', 'အထောက်အကူ', 'ပြု', 'စာအုပ်', 'များ', 'ကို', 'ထုတ်ဝေ', 'ခဲ့', 'သည်', '။']
Token IDs (with BPE, text1): [1, 11, 2, 4, 5, 7, 8, 9, 10, 109, 12, 10, 11, 13, 14, 15, 8, 16, 17, 20, 4, 21]
Tokens (with BPE, text2): ['တော်လှန်', 'ရေး', 'အထောက်အကူ', 'အတွက်', 'သင်ကြား', 'သည်', '။']
Token IDs (with BPE, text2): [26, 11, 13, 12, 10, 4, 21]


In [7]:
# Example syllables
syllables = ["ပညာ", "ရေး", "ဝန်", "ကြီး", "ဌာန", "သည်"]

# Expected output: ["ပညာရေး", "ဝန်ကြီးဌာန", "သည်"]
tokens = tokenizer.maximum_matching(syllables)
print("Recombined tokens:", tokens)

Recombined tokens: ['ပညာရေး', 'ဝန်ကြီးဌာန', 'သည်']


In [8]:
# Example syllables
syllables = ["ဗိုလ်", "ချုပ်", "က", "တော်", "လှန်", "သည်"]

# Expected matches for each position
print("Matches at position 0:", tokenizer.trie.find_all_matches(syllables, 0))
print("Matches at position 1:", tokenizer.trie.find_all_matches(syllables, 1))
print("Matches at position 2:", tokenizer.trie.find_all_matches(syllables, 2))

Matches at position 0: [('ဗိုလ်ချုပ်', 0.1, 2), ('ဗိုလ်ချုပ်ကတော်', 0.01, 4)]
Matches at position 1: []
Matches at position 2: [('က', 0.5, 3)]


In [9]:
syllables = tokenizer.segment_syllables(text2)
print("Syllables:", syllables)

Syllables: ['တော်', 'လှန်', 'ရေး', 'အ', 'ထောက်', 'အ', 'ကူ', 'အ', 'တွက်', 'သင်', 'ကြား', 'သည်', '။']


In [10]:
# Example dictionary
dictionary = {
    "ဗိုလ်ချုပ်": 0.1,
    "က": 0.5,
    "တော်": 0.03,
    "လှန်": 0.02,
    "သည်": 0.4,
}

# Initialize Trie
trie = Trie()
for word, freq in dictionary.items():
    trie.insert(word, freq)

# Example syllables
syllables = ["ဗိုလ်", "ချုပ်", "က", "တော်", "လှန်", "သည်"]

# Test find_all_matches
print("Matches at position 0:", trie.find_all_matches(syllables, 0))
print("Matches at position 1:", trie.find_all_matches(syllables, 1))
print("Matches at position 2:", trie.find_all_matches(syllables, 2))

Matches at position 0: [('ဗိုလ်ချုပ်', 0.1, 2)]
Matches at position 1: []
Matches at position 2: [('က', 0.5, 3)]
