In [17]:
import numpy as np
import random
from tabulate import tabulate

class Tree:
    def __init__(self, min_leaf_size=1, max_depth=15):
        self.root = None
        self.words = None
        self.min_leaf_size = min_leaf_size
        self.max_depth = max_depth

    def fit(self, words, verbose=False):
        self.words = words
        self.root = Node(depth=0, parent=None)

        self.root.fit(all_words=self.words, my_words_idx=np.arange(len(self.words)),
                      min_leaf_size=self.min_leaf_size, max_depth=self.max_depth, verbose=verbose)

    def predict(self, bigrams, max_words=5):
        return self.root.predict(bigrams, max_words)


class Node:
    def __init__(self, depth, parent):
        self.depth = depth
        self.parent = parent
        self.all_words = None
        self.my_words_idx = None
        self.children = {}
        self.is_leaf = True
        self.query = None
        self.history = []

    def get_query(self):
        return self.query

    def get_child(self, response):
        if self.is_leaf:
            return self
        else:
            if response not in self.children:
                response = list(self.children.keys())[0]
            return self.children[response]

    def get_bigrams(self, word, lim=5):
        bg = [''.join(bg) for bg in zip(word, word[1:])]
        bg = sorted(set(bg))
        return tuple(bg)[:lim]

    def get_random_bigram(self):
        return chr(ord('a') + random.randint(0, 25)) + chr(ord('a') + random.randint(0, 25))

    def process_leaf(self, all_words, my_words_idx, history, verbose):
        self.my_words_idx = my_words_idx

    def process_node(self, all_words, my_words_idx, history, verbose):
        query = self.get_random_bigram()
        split_dict = {True: [], False: []}

        for idx in my_words_idx:
            bg_list = self.get_bigrams(all_words[idx])
            split_dict[query in bg_list].append(idx)

        return query, split_dict

    def fit(self, all_words, my_words_idx, min_leaf_size, max_depth, fmt_str="    ", verbose=False):
        self.all_words = all_words
        self.my_words_idx = my_words_idx

        if len(my_words_idx) <= min_leaf_size or self.depth >= max_depth:
            self.is_leaf = True
            self.process_leaf(self.all_words, self.my_words_idx, self.history, verbose)
        else:
            self.is_leaf = False
            self.query, split_dict = self.process_node(self.all_words, self.my_words_idx, self.history, verbose)

            for i, (response, split) in enumerate(split_dict.items()):
                self.children[response] = Node(depth=self.depth + 1, parent=self)
                history = self.history.copy()
                history.append(self.query)
                self.children[response].history = history
                self.children[response].fit(self.all_words, split, min_leaf_size, max_depth, fmt_str, verbose)

    def predict(self, bigrams, max_words=5):
        node = self
        valid_words = []

        def contains_all_bigrams(word, bigrams):
            word_bigrams = self.get_bigrams(word)
            return all(bg in word_bigrams for bg in bigrams)

        while len(valid_words) < max_words and not node.is_leaf:
            node = node.get_child(any(bg in bigrams for bg in node.get_bigrams(self.all_words[node.my_words_idx[0]])))

        for idx in node.my_words_idx:
            word = self.all_words[idx]
            if contains_all_bigrams(word, bigrams):
                valid_words.append(word)
                if len(valid_words) == max_words:
                    break

        return valid_words


def load_dictionary(file_path):
    with open(file_path, 'r') as file:
        words = [word.strip() for word in file.readlines()]
    return words


def generate_random_bigrams(word, max_bigrams=5):
    bigrams = []
    while len(bigrams) < max_bigrams:
        bg = ''.join(random.sample(word, 2))
        if bg not in bigrams:
            bigrams.append(bg)
    return bigrams


def get_bigrams(word, limit=5):
    # Create bigrams from the word
    bigrams = [word[i:i+2] for i in range(len(word)-1)]

    # Remove duplicates and sort them
    unique_bigrams = sorted(set(bigrams))

    # Return the first 'limit' bigrams
    return unique_bigrams[:limit]


def process_words_from_file(file_path):
    word_bigram_pairs = []

    with open(file_path, 'r') as file:
        words = file.readlines()

    # Remove any trailing whitespace characters including newline characters
    words = [word.strip() for word in words]

    for word in words:
        bigrams = get_bigrams(word)
        word_bigram_pairs.append((word, bigrams))

    return word_bigram_pairs


def calculate_accuracy(root_word, guess_list):
    correct_word = root_word
    total_points = 0
    for i, word in enumerate(guess_list):
        position = i + 1  # Position starts from 1
        if word == correct_word:
            points = 1 / position
            total_points += points
            break  # Stop after finding the correct word

    return total_points


def main():
    # Load dictionary from file
    file_path = '/Users/shubhamkumarjha/Downloads/dict.txt'  # Adjust the path to your dictionary file
    words = load_dictionary(file_path)

    # Initialize and train the decision tree
    tree = Tree(min_leaf_size=1, max_depth=5)
    tree.fit(words, verbose=True)

    # Generate word_bigram_pairs for testing
    word_bigram_pairs = process_words_from_file(file_path)

    # Randomly select 200 word_bigram_pairs for testing
    random_pairs = random.sample(word_bigram_pairs, 200)

    # Create a table for root words and guess words
    table_data = []

    total_accuracy_points = 0
    total_possible_points = 0

    for word, test_bigrams in random_pairs:
        # Predict up to 5 words containing all test bigrams
        guess_list = tree.predict(test_bigrams, max_words=5)

        # Sort the guess_list by increasing string length
        guess_list.sort(key=len)

        # Calculate accuracy and points
        accuracy_points = calculate_accuracy(word, guess_list)
        total_accuracy_points += accuracy_points
        total_possible_points += 1  # Each correct guess list has a maximum of 1 point

        # Append to table data
        table_data.append([word, guess_list])



    # Calculate accuracy
    if total_possible_points > 0:
        accuracy = total_accuracy_points / total_possible_points
        print(f"\nAccuracy: {accuracy:.2f}")

if __name__ == "__main__":
    main()



Accuracy: 0.86
