In [3]:
class TrieNode:
    def __init__(self):
        self.children = {}  # Maps character to TrieNode
        self.is_end_of_word = False  # Marks the end of a word

class FSTTrie:
    def __init__(self):
        self.root = TrieNode()
    
    def insert(self, word):
        """Insert a word into the trie."""
        node = self.root
        for char in word:
            if char not in node.children:
                node.children[char] = TrieNode()
            node = node.children[char]
        node.is_end_of_word = True
    
    def search(self, word):
        """Search for a word in the trie. Returns True if found."""
        node = self.root
        for char in word:
            if char not in node.children:
                return False
            node = node.children[char]
        return node.is_end_of_word
    
    def starts_with(self, prefix):
        """Check if there exists any word with the given prefix."""
        node = self.root
        for char in prefix:
            if char not in node.children:
                return False
            node = node.children[char]
        return True
    
    def decode(self, input_sequence):
        """
        Decode an input sequence of characters into valid words
        by exploring all possible paths in the trie.
        """
        results = []
        node = self.root

        def dfs(current_node, path, remaining_input):
            if not remaining_input:  # If no input remains
                if current_node.is_end_of_word:  # Check if this is a valid word
                    results.append("".join(path))
                return

            # Process the next character in the input
            current_char = remaining_input[0]
            if current_char in current_node.children:
                # Follow the valid path
                dfs(current_node.children[current_char], path + [current_char], remaining_input[1:])
            
            # Optionally handle blanks or mismatches (if needed)

        dfs(node, [], input_sequence)
        return results

# Example Usage
fst_trie = FSTTrie()
lexicon = ["go", "forth", "gone"]

# Insert words into the trie
for word in lexicon:
    fst_trie.insert(word)

# Decode input sequences
input_sequence = "go"  # Perfect match
print("Decoded words for input 'go':", fst_trie.decode(input_sequence))

input_sequence = "fo"  # Partial match
print("Decoded words for input 'fo':", fst_trie.decode(input_sequence))

input_sequence = "ghost"  # No match in lexicon
print("Decoded words for input 'ghost':", fst_trie.decode(input_sequence))


Decoded words for input 'go': ['go']
Decoded words for input 'fo': []
Decoded words for input 'ghost': []


In [8]:
def edit_distance(word1, word2):
    """Calculate the minimum edit distance between two words."""
    m, n = len(word1), len(word2)
    dp = [[0] * (n + 1) for _ in range(m + 1)]

    for i in range(m + 1):
        for j in range(n + 1):
            if i == 0:
                dp[i][j] = j
            elif j == 0:
                dp[i][j] = i
            elif word1[i - 1] == word2[j - 1]:
                dp[i][j] = dp[i - 1][j - 1]
            else:
                dp[i][j] = 1 + min(dp[i - 1][j], dp[i][j - 1], dp[i - 1][j - 1])
    return dp[m][n]

def find_closest_word(trie, input_word):
    """Find the closest word in the trie to the input word based on edit distance."""
    closest_word = None
    min_distance = float('inf')

    for word in lexicon:
        distance = edit_distance(input_word, word)
        if distance < min_distance:
            min_distance = distance
            closest_word = word
    
    return closest_word

# Example Usage
print("Closest to 'ghost':", find_closest_word(fst_trie, "ghost"))  # Expected output: 'gone' or 'go'


Closest to 'ghost': go


In [9]:
import heapq

class FSTTrie:
    def __init__(self):
        self.root = TrieNode()

    def insert(self, word):
        node = self.root
        for char in word:
            if char not in node.children:
                node.children[char] = TrieNode()
            node = node.children[char]
        node.is_end_of_word = True

    def beam_search(self, char_probs, beam_width=3):
        """
        Perform beam search on the FST Trie with character probabilities.
        
        Args:
            char_probs (list of dict): Each element is a dict of {char: prob}.
            beam_width (int): The maximum number of candidates to retain at each step.
        
        Returns:
            List of tuples: Decoded words and their scores.
        """
        # Beam is a priority queue (negative score for max-heap behavior)
        beam = [(0, self.root, [])]  # (score, current_node, path_so_far)

        for timestep_probs in char_probs:
            new_beam = []
            for score, node, path in beam:
                for char, prob in timestep_probs.items():
                    if char in node.children:
                        child_node = node.children[char]
                        new_score = score + prob  # Add log-prob for beam pruning
                        new_path = path + [char]
                        heapq.heappush(new_beam, (new_score, child_node, new_path))
            
            # Keep only the top beam_width candidates
            beam = heapq.nlargest(beam_width, new_beam)

        # Extract complete words from the beam
        results = []
        for score, node, path in beam:
            if node.is_end_of_word:
                results.append(("".join(path), score))
        
        return results


# Example Usage
fst_trie = FSTTrie()
lexicon = ["go", "gone", "forth"]
for word in lexicon:
    fst_trie.insert(word)

# Simulated LSTM output probabilities (at each timestep, char → log-probability)
char_probs = [
    {'g': -0.2, 'f': -1.0},   # Probabilities at timestep 1
    {'o': -0.1, 'a': -1.2},   # Probabilities at timestep 2
    {'n': -0.3, 'r': -1.5},   # Probabilities at timestep 3
    {'e': -0.4, 't': -1.1},   # Probabilities at timestep 4
]

# Perform beam search
results = fst_trie.beam_search(char_probs, beam_width=3)
print("Beam Search Results:", results)


Beam Search Results: [('gone', -1.0)]


In [11]:
import heapq
from collections import defaultdict

class TrieNode:
    def __init__(self):
        self.children = {}
        self.is_end_of_word = False

class FSTTrie:
    def __init__(self):
        self.root = TrieNode()

    def insert(self, word):
        node = self.root
        for char in word:
            if char not in node.children:
                node.children[char] = TrieNode()
            node = node.children[char]
        node.is_end_of_word = True

    def ctc_beam_search(self, char_probs, beam_width=3, blank='_'):
        """
        Perform CTC beam search with FST constraints, and collapse repeated characters.
        
        Args:
            char_probs (list of dict): Each element is a dict of {char: prob}.
            beam_width (int): The maximum number of candidates to retain at each step.
            blank (str): The blank symbol for CTC decoding.
        
        Returns:
            List of tuples: Decoded words and their scores.
        """
        # Beam is a priority queue (negative score for max-heap behavior)
        beam = [("", self.root, 0.0, None)]  # (prefix, current_node, score, last_char)

        for timestep_probs in char_probs:
            new_beam = defaultdict(lambda: (-float('inf'), -float('inf')))

            for prefix, node, p_score, last_char in beam:
                # Add blank symbol transition: Blank means skip this timestep
                new_p_score = p_score + timestep_probs.get(blank, -float('inf'))
                new_beam[(prefix, node)] = (
                    max(new_beam[(prefix, node)][0], new_p_score),  # Update p_score for blank
                    new_beam[(prefix, node)][1]  # Keep existing last_char
                )

                # Process each character (skip if it doesn't exist in the trie)
                for char, prob in timestep_probs.items():
                    if char == blank or char not in node.children:
                        continue

                    # Collapse repeated characters (skip adding if same as last_char)
                    if char == last_char:
                        continue
                    
                    child_node = node.children[char]
                    new_prefix = prefix + char

                    # Add probability for non-blank characters (non-repeating)
                    new_p_score = p_score + prob
                    new_beam[(new_prefix, child_node)] = (
                        new_beam[(new_prefix, child_node)][0],  # Keep existing blank score
                        max(new_beam[(new_prefix, child_node)][1], new_p_score)  # Update p_score for non-blank
                    )

            # Convert scores to log probabilities for ranking and limit beam size
            beam = heapq.nlargest(
                beam_width,
                [(prefix, node, p_b, p_nb) for (prefix, node), (p_b, p_nb) in new_beam.items()],
                key=lambda x: max(x[2], x[3])  # Use max of blank and non-blank probabilities
            )

        # Finalize by collapsing sequences and filtering valid words
        results = []
        for prefix, node, p_b, p_nb in beam:
            if node.is_end_of_word:
                results.append((prefix, max(p_b, p_nb)))
        
        return sorted(results, key=lambda x: x[1], reverse=True)


# Example Usage
fst_trie = FSTTrie()
lexicon = ["go", "gone", "forth"]
for word in lexicon:
    fst_trie.insert(word)

# Simulated LSTM output probabilities (at each timestep, char → log-probability)
char_probs = [
    {'g': -0.2, 'f': -1.0, '_': -0.6},  # Timestep 1
    {'o': -0.1, 'a': -1.2, '_': -0.4},  # Timestep 2
    {'o': -0.3, 'r': -1.5, '_': -0.5},  # Timestep 3
    {'n': -0.4, 't': -1.1, '_': -0.6},  # Timestep 4
]

# Perform beam search with CTC decoding
results = fst_trie.ctc_beam_search(char_probs, beam_width=3)
print("CTC Beam Search Results:", results)


CTC Beam Search Results: [('go', -inf)]
