# Imports


In [6]:
import pandas as pd 
import numpy as np

import scipy
from scipy.stats import norm

from rapidfuzz.distance import Levenshtein

import matplotlib
import matplotlib.pyplot as plt
import seaborn as sns

import pickle
import joblib

from abc import ABC, abstractmethod

import sys 
import random
import string
import math
import time

### Version Checks


In [7]:
# Report the interpreter and library versions this notebook was run with,
# grouped by blank lines: python / numpy + pandas / scipy / plotting stack.
print(f"python: {sys.version}")
print()

print(f"numpy: {np.__version__}")
print(f"pandas: {pd.__version__}")
print()

print(f"scipy: {scipy.__version__}")
print()

print(f"matplotlib: {matplotlib.__version__}")
print(f"seaborn: {sns.__version__}")

python: 3.12.9 (tags/v3.12.9:fdb8142, Feb  4 2025, 15:27:58) [MSC v.1942 64 bit (AMD64)]

numpy: 2.2.4
pandas: 2.2.3

scipy: 1.15.2

matplotlib: 3.10.1
seaborn: 0.13.2


# Template class


In [16]:
from abc import ABC, abstractmethod
import numpy as np

class TreeTemplate(ABC):
    """Abstract base for one node of a binary string-partitioning tree.

    A node stores the corpus size (``num_strings``), the maximum number of
    splits (``split_num``), its own ``depth``, the optional indices of its two
    pivot strings (``s1_idx`` / ``s2_idx``) and its ``left`` / ``right``
    children, which start out as ``None``.

    ``tf_array`` handling: a node created at ``depth == 0`` without a table
    allocates an all-False boolean array of shape ``(num_strings, split_num)``;
    in every other case the ``tf_array`` argument is stored unchanged (which
    may be ``None``).
    """

    def __init__(self, num_strings, split_num, depth, s1_idx=None, s2_idx=None, tf_array=None):
        self.num_strings, self.split_num, self.depth = num_strings, split_num, depth
        self.s1_idx, self.s2_idx = s1_idx, s2_idx
        self.left = self.right = None

        # Only a root that was handed no table creates a fresh one.
        is_fresh_root = tf_array is None and depth == 0
        self.tf_array = (
            np.zeros((num_strings, split_num), dtype=bool) if is_fresh_root else tf_array
        )

    @abstractmethod
    def distance(self, a, b):
        """Return the distance between ``a`` and ``b`` (subclass-defined)."""

    @abstractmethod
    def build_tree(self, strings, indices):
        """Build the (sub)tree for ``indices`` of ``strings`` (subclass-defined)."""

    @abstractmethod
    def find_matches(self, strings, new_str):
        """Look up entries of ``strings`` matching ``new_str`` (subclass-defined)."""

    @abstractmethod
    def save_tree(self, filename):
        """Persist the tree to ``filename`` (subclass-defined)."""

    @staticmethod
    @abstractmethod
    def load_tree(filename):
        """Restore a tree from ``filename`` (subclass-defined)."""




In [17]:
from abc import ABC, abstractmethod

class IndexTemplate(ABC):
    """Abstract interface for a multi-tree string index.

    The constructor only records the configuration (``num_trees``,
    ``num_strings``, ``split_num``) and starts with an empty ``trees`` list;
    all behaviour is supplied by subclasses through the abstract methods below.
    """

    def __init__(self, num_trees, num_strings, split_num):
        self.num_trees, self.num_strings, self.split_num = num_trees, num_strings, split_num
        self.trees = []

    @abstractmethod
    def add_item(self, i, item):
        """Register ``item`` under index ``i`` (subclass-defined)."""

    @abstractmethod
    def build(self):
        """Build the index (subclass-defined)."""

    @abstractmethod
    def unbuild(self):
        """Undo ``build`` (subclass-defined)."""

    @abstractmethod
    def get_nns_by_vector(self, query, topk=None):
        """Return neighbours of ``query``, optionally limited by ``topk`` (subclass-defined)."""

    @abstractmethod
    def get_nns_by_item(self, i, n=None):
        """Return neighbours of the item at index ``i``, optionally limited by ``n`` (subclass-defined)."""

    @abstractmethod
    def save(self, filename):
        """Persist the index to ``filename`` (subclass-defined)."""

    @staticmethod
    @abstractmethod
    def load(filename):
        """Restore an index from ``filename`` (subclass-defined)."""

    @abstractmethod
    def unload(self):
        """Release loaded data (subclass-defined)."""

    @abstractmethod
    def on_disk_build(self, fn) -> bool:
        """Prepare building against file ``fn`` (subclass-defined)."""
    

In [18]:
class ForestTemplate(ABC):
    """Abstract estimator-style interface for a forest of string trees.

    The constructor records ``n_estimators``, ``max_depth`` and
    ``random_state``, starts with empty ``trees`` / ``strings`` lists and an
    unfitted flag. When ``random_state`` is not ``None`` it seeds BOTH the
    global ``random`` module and the global NumPy RNG with that value.
    """

    def __init__(self, n_estimators, max_depth, random_state=None):
        self.n_estimators, self.max_depth, self.random_state = (
            n_estimators,
            max_depth,
            random_state,
        )
        self.trees, self.strings = [], []

        # Seeding is global (process-wide), not per-instance.
        if random_state is not None:
            for seed_rng in (random.seed, np.random.seed):
                seed_rng(random_state)

        self._is_fitted = False

    @abstractmethod
    def fit(self, strings: list[str]):
        """Fit the forest on ``strings`` (subclass-defined)."""

    @abstractmethod
    def transform(self, strings: list[str]) -> np.ndarray:
        """Encode ``strings`` as an array (subclass-defined)."""

    @abstractmethod
    def predict(self, queries: list[str], topk=1) -> list[list[str]]:
        """Return up to ``topk`` results per query (subclass-defined)."""

    @abstractmethod
    def save(self, filename: str):
        """Persist the forest to ``filename`` (subclass-defined)."""

    @staticmethod
    @abstractmethod
    def load(filename: str):
        """Restore a forest from ``filename`` (subclass-defined)."""

    @abstractmethod
    def get_params(self, deep=True):
        """Return the estimator parameters (subclass-defined)."""

    @abstractmethod
    def set_params(self, **params):
        """Update estimator parameters (subclass-defined)."""

    @abstractmethod
    def __repr__(self):
        """Return a printable description (subclass-defined)."""
      

# LevenshteinTree


In [19]:
from rapidfuzz.distance import Levenshtein
import pickle

class LevenshteinTree(TreeTemplate):
    """Random two-pivot partition tree over strings under an edit distance.

    Each internal node holds the indices of two pivot strings (``s1_idx``,
    ``s2_idx``). A string goes to the *left* child when it is at least as far
    from pivot 1 as from pivot 2 (``d1 >= d2``) and to the *right* child
    otherwise. The decision taken at depth ``d`` for corpus string ``i`` is
    recorded in the shared boolean table ``tf_array[i, d]`` (``True`` = left).

    ``distance_fn`` defaults to ``Levenshtein.distance``; any callable
    ``f(a, b) -> number`` may be supplied and is used by every node.
    """

    def __init__(self, num_strings, split_num, depth, s1_idx=None, s2_idx=None, tf_array=None, distance_fn=None):
        super().__init__(num_strings, split_num, depth, s1_idx, s2_idx, tf_array)
        self._distance_fn = distance_fn or Levenshtein.distance

    def distance(self, a, b):
        """Return the distance between ``a`` and ``b`` using this tree's metric."""
        return self._distance_fn(a, b)

    def _spawn(self, depth, s1_idx=None, s2_idx=None):
        """Create a node sharing this node's table and distance function."""
        return LevenshteinTree(
            self.num_strings,
            self.split_num,
            depth,
            s1_idx,
            s2_idx,
            tf_array=self.tf_array,
            distance_fn=self._distance_fn,
        )

    def build_tree(self, strings, indices):
        """Recursively partition ``indices`` and return the resulting node.

        Args:
            strings: indexable corpus; ``strings[i]`` is the i-th string.
            indices: positions (into ``strings``) handled by this subtree.

        Returns:
            A new ``LevenshteinTree`` node, or ``None`` when the group has at
            most one element or the depth limit ``split_num`` is reached.

        Side effects:
            Writes column ``self.depth`` of the shared ``tf_array`` for every
            index in ``indices`` and draws from the global NumPy RNG.
        """
        indices = np.asarray(indices)

        # Nothing to split: empty/singleton group or depth budget exhausted.
        if self.depth >= self.split_num or len(indices) <= 1:
            return None

        if len(indices) == 2:
            # The two members are the only possible pivots. Ordering them this
            # way sends indices[0] left (True) and indices[1] right (False)
            # whenever the strings differ, and -- unlike a pivot-less leaf --
            # lets get_code() reproduce that bit for a query.
            s1_idx, s2_idx = indices[1], indices[0]
        else:
            rand_pos = np.random.choice(len(indices), size=2, replace=False)
            s1_idx, s2_idx = indices[rand_pos[0]], indices[rand_pos[1]]

        node = self._spawn(self.depth, s1_idx, s2_idx)

        pivot1, pivot2 = strings[s1_idx], strings[s2_idx]
        for idx in indices:
            d1 = self.distance(strings[idx], pivot1)
            d2 = self.distance(strings[idx], pivot2)
            self.tf_array[idx, self.depth] = d1 >= d2

        goes_left = self.tf_array[indices, self.depth]
        left = indices[np.flatnonzero(goes_left)]
        right = indices[np.flatnonzero(~goes_left)]

        # Children share the table and the metric (a custom distance_fn must
        # not silently fall back to the default below the root).
        node.left = self._spawn(self.depth + 1).build_tree(strings, left)
        node.right = self._spawn(self.depth + 1).build_tree(strings, right)

        return node

    def get_code(self, strings, new_str):
        """Route ``new_str`` down the tree and return its boolean fingerprint.

        Args:
            strings: the corpus the tree was built on (pivot indices refer to it).
            new_str: the query string.

        Returns:
            Boolean array of length ``split_num``; entry ``k`` is ``True`` when
            the k-th decision went left. Entries past the end of the path stay
            ``False``.
        """
        fingerprint = np.zeros(self.split_num, dtype=bool)
        level = 0
        node = self

        # The level bound replaces the former post-increment assert, which
        # fired spuriously on a path that legitimately used all split_num levels.
        while (
            node is not None
            and level < self.split_num
            and node.s1_idx is not None
            and node.s2_idx is not None
        ):
            d1 = self.distance(new_str, strings[node.s1_idx])
            d2 = self.distance(new_str, strings[node.s2_idx])
            go_left = d1 >= d2
            fingerprint[level] = go_left
            node = node.left if go_left else node.right
            level += 1

        return fingerprint

    def find_matches(self, strings, new_str, return_strings=True):
        """Return corpus entries whose recorded code equals the query's code.

        Args:
            strings: the corpus the tree was built on.
            new_str: the query string.
            return_strings: when true return ``(indices, matched_strings)``,
                otherwise only the index array.
        """
        fingerprint = self.get_code(strings, new_str)
        matches = np.where((self.tf_array == fingerprint).all(axis=1))[0]

        if return_strings:
            return matches, [strings[i] for i in matches]
        return matches

    def transform(self, strings: list[str], reference_strings=None) -> np.ndarray:
        """Encode every string in ``strings`` as a fingerprint row.

        Args:
            strings: strings to encode.
            reference_strings: corpus the tree was built on. Defaults to
                ``strings`` itself, which is only meaningful when ``strings``
                IS the build corpus (pivot indices are looked up in it).

        Returns:
            Boolean array of shape ``(len(strings), split_num)``.
        """
        reference = strings if reference_strings is None else reference_strings
        result = np.zeros((len(strings), self.split_num), dtype=bool)
        for i, s in enumerate(strings):
            result[i] = self.get_code(reference, s)
        return result

    def height(self) -> int:
        """Return the number of nodes on the longest root-to-leaf path."""
        def _max_depth(node):
            if node is None:
                return 0
            return 1 + max(_max_depth(node.left), _max_depth(node.right))

        return _max_depth(self)

    def depth(self) -> int:
        """Alias of :meth:`height`, kept for backward compatibility.

        NOTE: the instance attribute ``self.depth`` set in ``__init__`` shadows
        this method on instances, so it is only reachable through the class
        (``LevenshteinTree.depth(node)``). Use :meth:`height` instead.
        """
        return LevenshteinTree.height(self)

    def save_tree(self, filename):
        """Pickle this tree to ``filename``.

        A custom ``distance_fn`` must itself be picklable (e.g. not a lambda).
        """
        with open(filename, "wb") as f:
            pickle.dump(self, f)

    @staticmethod
    def load_tree(filename):
        """Unpickle a tree from ``filename``.

        SECURITY: ``pickle.load`` can execute arbitrary code; only load files
        from a trusted source.
        """
        with open(filename, "rb") as f:
            return pickle.load(f)



# LevenshteinIndex


In [20]:
class LevenshteinIndex(IndexTemplate):
    """Annoy-style index of ``LevenshteinTree`` objects over a fixed-size buffer.

    ``num_strings`` slots are reserved up front; strings are placed with
    :meth:`add_item` / :meth:`add_items_bulk`, then :meth:`build` grows
    ``num_trees`` trees with at most ``split_num`` levels each.
    """

    def __init__(self, num_trees, num_strings, split_num):
        super().__init__(num_trees, num_strings, split_num)
        self._string_buffer = [None] * num_strings  # reserve space
        self._item_count = 0
        self._verbose = False

    # add one-by-one
    def add_item(self, i: int, string: str):
        """Store ``string`` in slot ``i``; counts the slot once when first filled.

        Raises:
            IndexError: if ``i`` is outside the reserved buffer.
        """
        if self._string_buffer[i] is None:
            self._item_count += 1
        self._string_buffer[i] = string

    # add data at once
    def add_items_bulk(self, strings):
        """Fill slots ``0 .. len(strings) - 1`` from ``strings`` in one call.

        Raises:
            TypeError: if ``strings`` is not a list, NumPy array or pandas Series.
            ValueError: if more strings are given than slots were reserved.
        """
        if not isinstance(strings, (list, np.ndarray, pd.Series)):
            raise TypeError("Input must be a list, numpy array, or pandas Series of strings.")

        strings_array = np.asarray(strings, dtype=object)
        n = len(strings_array)

        if n > len(self._string_buffer):
            raise ValueError("Too many strings to add to the index.")

        # Count slots that were empty so far; a plain Python sum keeps
        # _item_count a built-in int rather than a NumPy scalar.
        self._item_count += sum(1 for slot in self._string_buffer[:n] if slot is None)

        # Bulk assign to the internal buffer
        self._string_buffer[:n] = list(strings_array)

    # Annoy style
    def build(self):
        """Build ``num_trees`` trees over all populated slots.

        Empty (``None``) slots are skipped, so a partially filled index can be
        built; with every slot filled this covers ``0 .. num_strings - 1``.

        Raises:
            ValueError: if fewer than two strings have been added (a tree
                needs two pivots).
        """
        strings = self._string_buffer
        filled = np.array(
            [i for i, s in enumerate(strings) if s is not None], dtype=int
        )
        if len(filled) < 2:
            raise ValueError("At least two strings are required to build the index.")

        for _ in range(self.num_trees):
            root = LevenshteinTree(self.num_strings, self.split_num, 0)
            self.trees.append(root.build_tree(strings, filled))

    def unbuild(self):
        """Drop all trees but keep the stored strings."""
        self.trees = []
        return True

    def get_nns_by_vector(self, query_str, topk=None):
        """Return indices of stored strings that share a tree code with the query.

        Args:
            query_str: the query string.
            topk: when given, candidates are ranked by Levenshtein distance to
                ``query_str`` (ties by index) and only the best ``topk`` are
                returned; when ``None`` all candidates are returned in index order.

        Returns:
            List of ``int`` indices into the string buffer.
        """
        candidates = set()
        for tree in self.trees:
            # Ask for indices only; the default also returns the matched
            # strings, which must not be mixed into the index set.
            hits = tree.find_matches(self._string_buffer, query_str, return_strings=False)
            candidates.update(int(i) for i in hits)

        # Unfilled slots have an all-False code row and could match by accident.
        unique_matches = sorted(
            i for i in candidates if self._string_buffer[i] is not None
        )

        if topk is None:
            return unique_matches

        ranked = sorted(
            unique_matches,
            key=lambda i: (Levenshtein.distance(query_str, self._string_buffer[i]), i),
        )
        return ranked[:topk]

    # same as get_nns_by_vector
    def get_nns_by_string(self, query_str, topk=None):
        """Alias of :meth:`get_nns_by_vector` for string queries."""
        return self.get_nns_by_vector(query_str, topk=topk)

    def get_nns_by_item(self, i, n=None):
        """Query with the string stored in slot ``i``; ``n`` limits the results."""
        query_str = self._string_buffer[i]
        return self.get_nns_by_string(query_str, topk=n)

    def save(self, filename: str):
        """Pickle the whole index (strings and trees) to ``filename``."""
        with open(filename, "wb") as f:
            pickle.dump(self, f)

    @staticmethod
    def load(filename: str):
        """Unpickle an index from ``filename``.

        SECURITY: ``pickle.load`` can execute arbitrary code; only load files
        from a trusted source.
        """
        with open(filename, "rb") as f:
            return pickle.load(f)

    def unload(self):
        """Drop trees and strings; the buffer becomes empty (zero slots)."""
        self.trees = []
        self._string_buffer = []
        self._item_count = 0
        return True

    def on_disk_build(self, fn: str) -> bool:
        """Placeholder: only prints a warning and returns ``True``."""
        # Placeholder: no real on-disk building for Levenshtein trees
        print("[Warning] on_disk_build is not supported yet.")
        return True

    def get_distance(self, i: int, j: int) -> int:
        """Return the Levenshtein distance between the strings in slots ``i`` and ``j``.

        Raises:
            ValueError: if no string data is present or either slot is empty.
        """
        if not hasattr(self, '_string_buffer'):
            raise ValueError("No string data loaded. Use add_item() or build().")

        first, second = self._string_buffer[i], self._string_buffer[j]
        if first is None or second is None:
            raise ValueError("No string data loaded. Use add_item() or build().")

        return Levenshtein.distance(first, second)

    def get_n_items(self) -> int:
        """Return how many slots have been filled."""
        return self._item_count

    def get_n_trees(self) -> int:
        """Return the number of trees currently built."""
        return len(self.trees)

    def get_strings(self) -> list[str]:
        """Return the internal string buffer (not a copy)."""
        return self._string_buffer

    def get_item_vector(self, i: int, tree_id: int = 0) -> list[bool]:
        """Return the code of the string in slot ``i`` under tree ``tree_id``.

        Raises:
            ValueError: if no trees have been built.
        """
        if not self.trees:
            raise ValueError("No trees built.")
        string = self._string_buffer[i]
        return self.trees[tree_id].get_code(self._string_buffer, string).tolist()

    def verbose(self, v: bool) -> bool:
        """Record the verbosity flag; returns ``True``."""
        self._verbose = v
        return True

    def set_seed(self, s: int) -> None:
        """Seed the global ``random`` and NumPy RNGs used while building trees."""
        random.seed(s)
        np.random.seed(s)



# LevenshteinForest


In [21]:
class LevenshteinForest(ForestTemplate):
    """Estimator-style ensemble of ``LevenshteinTree`` objects.

    ``fit`` builds ``n_estimators`` trees of at most ``max_depth`` levels over
    the training strings; ``transform`` concatenates the per-tree codes of
    each input and ``predict`` votes across trees for matching training strings.
    """

    def __init__(self, n_estimators=10, max_depth=120, random_state=None):
        super().__init__(n_estimators, max_depth, random_state)

    def _require_fitted(self, caller: str) -> None:
        """Raise if the forest has not been fitted.

        An explicit raise (rather than ``assert``) survives ``python -O``; the
        exception type stays ``AssertionError`` for backward compatibility.
        """
        if not self._is_fitted:
            raise AssertionError(f"Call fit() before {caller}()")

    def fit(self, strings: list[str]):
        """Build the forest on ``strings`` and return ``self``.

        Args:
            strings: any iterable of strings (list, NumPy array, pandas Series).
                It is copied into a list so later lookups are positional
                regardless of the container's own indexing.

        Raises:
            ValueError: if fewer than two strings are given (a tree needs two
                pivots).
        """
        self.strings = list(strings)
        self.num_strings = len(self.strings)
        if self.num_strings < 2:
            raise ValueError("At least two strings are required to fit the forest.")

        # Re-seed so that every fit with a fixed random_state is reproducible,
        # not only the first one after construction.
        if self.random_state is not None:
            random.seed(self.random_state)
            np.random.seed(self.random_state)

        self.trees = []
        for _ in range(self.n_estimators):
            root = LevenshteinTree(self.num_strings, self.max_depth, 0)
            self.trees.append(root.build_tree(self.strings, np.arange(self.num_strings)))

        self._is_fitted = True
        return self

    def transform(self, strings: list[str]) -> np.ndarray:
        """Encode ``strings`` with every tree.

        Returns:
            Boolean array of shape ``(len(strings), n_estimators * max_depth)``
            whose rows are the per-tree codes laid side by side.

        Raises:
            AssertionError: if called before ``fit``.
        """
        self._require_fitted("transform")

        result = np.zeros((len(strings), self.n_estimators * self.max_depth), dtype=bool)
        if not self.trees:
            return result

        for i, s in enumerate(strings):
            # Pivot indices refer to the TRAINING strings, not to `strings`.
            result[i] = np.concatenate(
                [tree.get_code(self.strings, s) for tree in self.trees]
            )

        return result

    def predict(self, queries: list[str], topk=1) -> list[list[str]]:
        """Return, per query, up to ``topk`` training strings ranked by tree votes.

        A training string gets one vote from each tree in which its recorded
        code equals the query's code; ties keep first-seen order.

        Raises:
            AssertionError: if called before ``fit``.
        """
        self._require_fitted("predict")
        predictions = []

        for q in queries:
            match_counts = {}

            for tree in self.trees:
                matches = tree.find_matches(self.strings, q, return_strings=False)
                for idx in matches:
                    match_counts[idx] = match_counts.get(idx, 0) + 1

            # Most frequently matched first (stable for equal counts).
            sorted_matches = sorted(match_counts.items(), key=lambda x: -x[1])
            predictions.append([self.strings[idx] for idx, _ in sorted_matches[:topk]])

        return predictions

    def save(self, filename: str):
        """Pickle the fitted forest (including the training strings) to ``filename``."""
        with open(filename, "wb") as f:
            pickle.dump(self, f)

    @staticmethod
    def load(filename: str):
        """Unpickle a forest from ``filename``.

        SECURITY: ``pickle.load`` can execute arbitrary code; only load files
        from a trusted source.
        """
        with open(filename, "rb") as f:
            return pickle.load(f)

    def get_params(self, deep=True):
        """Return the constructor parameters as a dict (``deep`` is unused)."""
        return {
            "n_estimators": self.n_estimators,
            "max_depth": self.max_depth,
            "random_state": self.random_state
        }

    def set_params(self, **params):
        """Set existing attributes from ``params`` and return ``self``.

        Keys that are not already attributes of the instance are ignored.
        """
        for key, value in params.items():
            if hasattr(self, key):
                setattr(self, key, value)
        return self

    def __repr__(self):
        return (
            f"LevenshteinForest("
            f"n_estimators={self.n_estimators}, "
            f"max_depth={self.max_depth}, "
            f"random_state={self.random_state}, "
            f"is_fitted={self._is_fitted})"
        )


# Example Usage


### LevenshteinIndex

- add strings one by one from a file


In [None]:
# Example: fill a LevenshteinIndex slot by slot while streaming a text file,
# build it, run one query and pickle the result.

# create index
# NOTE(review): num_strings is hard-coded; the index reserves exactly this many
# slots, so the input file presumably has 10,000 lines -- confirm.
leven_index = LevenshteinIndex(num_trees=10, num_strings=10_000, split_num=120)

# add strings to index
# Each line (stripped of surrounding whitespace) goes into the slot matching
# its zero-based line number.
with open('Data/Random_strings_10K.txt', 'r', encoding='utf-8') as f:
    for i, line in enumerate(f):
        leven_index.add_item(i, line.strip())

# build index
leven_index.build()

# query
query = "hello world"
neighbors = leven_index.get_nns_by_string(query, topk=3)

# `neighbors` holds positions into the index's string buffer.
print("Top 3 nearest neighbors:")
for idx in neighbors:
    print(f"[{idx}] {leven_index.get_strings()[idx]}")


# Persist the whole index object with pickle.
leven_index.save("Models/LevenshteinIndex_10k_1.pkl")

### LevenshteinIndex

- read all strings at once


In [None]:
# Example: read the whole file into memory first, size the index from the
# actual number of lines, and load it with a single bulk call.
with open('Data/Random_strings_10K.txt', 'r', encoding='utf-8') as f:
    strings = [line.strip() for line in f]

# Create index and add
leven_index = LevenshteinIndex(num_trees=10, num_strings=len(strings), split_num=120)
# Equivalent one-by-one variant, kept for reference:
# for i, s in enumerate(strings):
#     leven_index.add_item(i, s)
leven_index.add_items_bulk(strings)


# build index
leven_index.build()

# Query
query = "hello world"
neighbors = leven_index.get_nns_by_string(query, topk=3)

# `neighbors` holds positions into the index's string buffer.
print("Top 3 nearest neighbors:")
for idx in neighbors:
    print(f"[{idx}] {leven_index.get_strings()[idx]}")

# save
leven_index.save("Models/LevenshteinIndex_10k_2.pkl")


### LevenshteinForest


In [None]:
# Example: fit a LevenshteinForest on the 'String' column of a tab-delimited
# file, query it once and pickle the fitted model.
# NOTE(review): this path spells '10k' in lower case while the index examples
# open 'Data/Random_strings_10K.txt'; on a case-sensitive file system these are
# different files -- confirm which spelling is correct.
# NOTE(review): read_csv consumes the first line as a header and a 'String'
# column is expected, whereas the index examples treat every line as data --
# presumably the file has a header row; verify.
data = pd.read_csv('Data/Random_strings_10k.txt', delimiter='\t', encoding='utf-8')
train_strings = data['String']

forest = LevenshteinForest(n_estimators=10, max_depth=150)
forest.fit(train_strings)

# Optional: per-tree codes of the training strings.
# codes = forest.transform(train_strings)
# One list of up to `topk` training strings is returned per query.
predictions = forest.predict(["new string"], topk=3)
print(predictions)

forest.save("Models/LevenshteinForest_10k.pkl")
# Reload the pickled model later with:
# loaded = LevenshteinForest.load("Models/LevenshteinForest_10k.pkl")

