In [28]:
from copy import copy
import json
import os
dict_file = '/usr/share/dict/words'
from time import time
import numpy as np
import tqdm

In [2]:
with open(dict_file, 'r') as f:
    word_list = [x.strip().lower() for x in f.readlines()]

In [3]:
# All data is stored in word list and trie list. This class is just a viewer 
# That takes an index and follows the node at that index.
# It is meant to be easy to move this viewer around the trie without
# Extra overhead

class trie_node:
    def __init__(self, word_list, trie_list, index=0):
        #print(index, len(trie_list))
        # 0 is the root index
        self.word_list = word_list
        self.trie_list = trie_list
        # Creating the root node
        if index == -1:
            self.reindex(self.add_node())
        # Loading a node
        else:
            self.reindex(index)
    
    def reindex(self, index):
        if index < len(self.trie_list) and index >= 0:
            self.index = index
        else:
            assert False, "Trie index " + index + " out of bounds" 
    
    def get_word_inds(self):
        return self.trie_list[self.index][0]
    
    def get_children(self):
        return self.trie_list[self.index][1]
    
    def get_words(self):
        return [self.word_list[x] for x in self.get_word_inds()]
    
    # Creates new trie node for child
    def get_child(self, letter):
        if letter in self.get_children():
            return trie_node(self.word_list, self.trie_list, index=self.get_children[letter])
        else:
            return None
    
    # Reindexes this node to become the child
    def become_child(self, letter):
        if letter in self.get_children().keys():
            self.reindex(self.get_children()[letter])
            return True
        else:
            return False
    
    # Yields new trie nodes for each child
    def iter_children(self):
        for letter, index in self.get_children().iteritems():
            yield letter, trie_node(self.word_list, self.trie_list, index=index)
    
    def add_node(self, node=None):
        if node is None:
            self.trie_list.append(([], {}))
        else:
            self.trie_list.append(node)
        return len(self.trie_list) - 1
        
    def add_word(self, word, index):
        original_index = self.index
        for letter in word:
            if letter in self.get_children():
                self.become_child(letter)
            else:
                next_ind = self.add_node()
                self.get_children()[letter] = next_ind
                self.reindex(next_ind)
        self.get_word_inds().append(index)
        self.reindex(original_index)

                
    def search(self, letters):
        original_index = self.index
        for letter in letters:
            if not self.become_child(letter):
                return []
        words = self.get_words()
        self.reindex(original_index)
        return words

def build_trie(search_terms, word_list, anagram=False):
    trie_list = []
    root_node = trie_node(word_list, trie_list, index=-1)
    
    for index, search_term in tqdm.tqdm(enumerate(search_terms)):
        if anagram:
            root_node.add_word(sorted(search_term), index)        
        else:
            root_node.add_word(search_term, index)
        root_node.reindex(0)
    
    return trie_list, word_list, root_node

def save_tries(trie_nodes, names, out_file):
    with open(out_file, 'w') as f:
        json.dump({name :{'word_list': trie_node.word_list,
                      'trie_list': trie_node.trie_list} for name, trie_node in zip(names, trie_nodes)}, f)

# Yields all combinations of the lists passed.
def comb(words):
    if words:
        for word in words[0]:
            for word_list in comb(words[1:]):
                if word is None:
                    yield [] + word_list
                elif type(word) is tuple:
                    yield [word[0]] + word_list
                else:
                    yield [word] + word_list
    else:
        yield []

# Finds all words matching a given pattern (with wild cards) and/or 
# satisfying a given anagram

class word_finder():
    def __init__(self, in_file, file_type=None):
        # Load from json / txt word list
        if type(in_file) is str:
            if file_type is None:
                file_type = os.path.splitext(in_file)[1]
            if file_type == '.txt':
                with open(in_file, 'r') as f:
                    word_list = [x.strip().lower() for x in f.readlines()]
                self.ordered_trie = build_trie(word_list, word_list, anagram=False)[2]
                self.anagram_trie = build_trie(word_list, word_list, anagram=True)[2]
            elif file_type == '.json':
                with open(in_file, 'r') as f:
                    tries = json.load(f)
                self.ordered_trie = trie_node(tries['ordered_trie']['word_list'],
                                              tries['ordered_trie']['trie_list'], 0)
                self.anagram_trie = trie_node(tries['anagram_trie']['word_list'],
                                              tries['anagram_trie']['trie_list'], 0)
            else:
                print("Invalid file type, tries not loaded")

            # Add filler node to anagram trie (necessary for anagram finding)
            self.filler_index = len(self.anagram_trie.trie_list)
            self.anagram_trie.add_node(([], {x:self.filler_index for x in 'abcdefghijklmnopqrstuvwxyz'}))
        # Copy constructor
        else:
            self.anagram_trie = in_file.anagram_trie
            self.ordered_trie = in_file.ordered_trie
            self.filler_index = in_file.filler_index
    
    def save(self, outfile):
        save_tries([self.ordered_trie, self.anagram_trie],
                  ['ordered_trie', 'anagram_trie'],
                  out_file)
    
    # All word strings matching a certain pattern. Lengths specifies 
    # the length of each word (in order) in the word string. Use '.' as a 
    # wildcard.
    # Ex: match_pattern('go.d.ayt.y.u', lengths=(4, 3, 2, 3))
    # Will return ('good', 'day', 'to', 'you')
    # Among other things
    def match_pattern(self, word, index=0, output=None, lengths=None):
        self.ordered_trie.reindex(index)
        if output is None:
            output = set()
        # Run match pattern for each word in lengths and then return all combinations
        if lengths is not None:
            lengths = list(lengths)
            assert sum(lengths) == len(word)
            suffix = word
            words = []
            output = set()
            for num in lengths:
                prefix = suffix[:num]
                suffix = suffix[num:]
                words.append(self.match_pattern(prefix))
            for x in comb(words):
                output.add(tuple(x))
        elif not len(word):
            for x in self.ordered_trie.get_words():
                output.add((x, ))
        else:
            prefix = word[0]
            suffix = word[1:]
            if prefix == '.':
                for _, next_index in self.ordered_trie.get_children().items():
                    self.match_pattern(suffix, 
                                       index=next_index, 
                                       output=output)
            elif prefix in self.ordered_trie.get_children().keys():
                self.match_pattern(suffix, 
                                   index=self.ordered_trie.get_children()[prefix], 
                                   output=output)
        return output
    
    # Finds all mulitword anagrams using the letters passed. Lengths
    # specifies the length of each word in the output set. If the sum
    # of lengths is less than the number of letters, than partial
    # anagrams matching the word structure will be found
    # Eg. anagrams('pzpipxleate', lengths=(5, 3))
    # Will return ('apple', 'pie') among other things
    
    def anagrams(self, letters, lengths=None):
        letters = sorted(letters)
        if lengths is None:
            lengths = [len(letters)]
        lengths = list(lengths)
        assert sum(lengths) <= len(letters)
        results = set()
        if sum(lengths) < len(letters):
            queue = [[[0] * len(lengths) + [self.filler_index], 
                      lengths + [len(letters) - sum(lengths)], 
                      letters]]
        else:
            queue = [[[0] * len(lengths), lengths, letters]]
        while queue:
            indices, lengths, letters = queue.pop(0)
            if not letters:
                words = []
                for idx in indices:
                    if idx != self.filler_index:
                        self.anagram_trie.reindex(idx)
                        words.append(self.anagram_trie.get_words())
                if all(words):
                    for x in comb(words):
                        results.add(tuple(x)) 
            else:
                prefix = letters[0]
                suffix = suffix = letters[1:]       
                for i, (idx, length) in enumerate(zip(indices, lengths)):
                    self.anagram_trie.reindex(idx)
                    if length and prefix in self.anagram_trie.get_children().keys():
                        new_lengths = copy(lengths)
                        new_lengths[i] = length - 1
                        new_indices = copy(indices)
                        new_indices[i] = self.anagram_trie.get_children()[prefix]
                        queue.append([new_indices, new_lengths, suffix])

        return results
    
    # Returns all words that fit the given pattern (if passed)
    # and match the given anagram (if passed)
    def find_words(self, pattern = None, letters = None, lengths = None):
        if pattern is not None:
            output = self.match_pattern(pattern, lengths=lengths)
            if letters is not None:
                return output.intersection(self.anagrams(letters, lengths=lengths))
            return output
        if letters is not None:
            return self.anagrams(letters, lengths=lengths)
        return {}


In [124]:
a = time()
with open('flat_trie.json', 'r') as f:
    t2, w2 = json.load(f)
print(time() - a)

1.41257882118


In [178]:
dict_file

'/usr/share/dict/words'

In [179]:
! stat /usr/share/dict/words

16777221 755234 lrwxr-xr-x 1 root wheel 0 4 "Sep 17 14:33:55 2019" "Mar  7 10:14:46 2019" "Sep 17 14:33:55 2019" "Mar  7 10:14:46 2019" 4194304 0 0x80000 /usr/share/dict/words


In [218]:
a = time()
engine.anagrams('pzpipxleatepxiwar', lengths=(5, 8, 4))
print(time() - a)

71.9436020851
