-
Notifications
You must be signed in to change notification settings - Fork 0
/
utils.py
62 lines (48 loc) · 2.04 KB
/
utils.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
import csv
import Trie
class Utils:
    """Helpers for loading the CSV corpus into a trie and fuzzy-searching it."""

    @staticmethod
    def load_data(path='csv_data/data.csv'):
        """Load and clean the CSV corpus and return it as a populated trie.

        The first row is treated as a header and skipped. For every other
        row, blank fields (empty or whitespace-only) are dropped, the
        remaining fields are joined with single spaces, and the result is
        stripped, lower-cased and inserted into a new ``Trie.TrieNode``.

        :param path: CSV file to read; defaults to ``csv_data/data.csv``
            (resolved relative to the current working directory).
        :return: the root ``Trie.TrieNode`` holding one entry per non-blank
            data row. An empty file yields an empty trie.
        """
        trie = Trie.TrieNode()
        # newline='' is what the csv module expects, so quoted fields that
        # contain line breaks are parsed correctly.
        with open(path, 'r', newline='') as f:
            reader = csv.reader(f, delimiter=',')
            # Skip the header row; an empty file returns an empty trie
            # instead of leaking StopIteration to the caller.
            if next(reader, None) is None:
                return trie
            for row in reader:
                # Drop *every* blank field. Removing only the first '' / ' '
                # left extra blanks behind and produced double spaces.
                fields = [field for field in row if field.strip()]
                entry = ' '.join(fields).strip().lower()
                # Blank CSV lines would otherwise insert an empty string.
                if entry:
                    trie.insert(entry)
        return trie

    @staticmethod
    def search(word, maxCost, trie):
        """Return trie entries within ``maxCost`` Levenshtein edits of ``word``.

        Walks the trie depth-first, computing one row of the Levenshtein
        dynamic-programming table per trie node and pruning any branch whose
        row minimum already exceeds ``maxCost``.

        :param word: the search term.
        :param maxCost: maximum edit distance (inclusive) for a match.
        :param trie: root node. Only its ``children`` mapping is read here,
            so an entry stored on the root itself is never reported.
        :return: list of ``(entry, score)`` tuples in traversal order (not
            sorted). ``score`` is the edit distance divided by the longer of
            the two lengths: 0 is an exact match, 1 means nothing in common.
        """
        size = len(word)
        # Row 0 of the DP table: the distance from the empty prefix to each
        # prefix of ``word`` is just that prefix's length.
        first_row = list(range(size + 1))
        results = []
        for letter in trie.children:
            Utils.look_recursive(trie.children[letter], letter, word,
                                 first_row, results, maxCost, size)
        return results

    @staticmethod
    def look_recursive(node, letter, word, previousRow, results, maxCost, size):
        """Extend the Levenshtein table by ``letter`` and recurse into ``node``.

        :param node: trie node reached by following ``letter``. Its ``word``
            attribute is expected to be ``None`` unless an entry ends here.
        :param letter: character on the edge leading to ``node``.
        :param word: the search term.
        :param previousRow: DP row computed for the parent node's prefix.
        :param results: list that matching ``(entry, score)`` tuples are
            appended to (mutated in place).
        :param maxCost: maximum edit distance (inclusive) for a match.
        :param size: ``len(word)``, passed down so it is not recomputed.
        """
        columns = len(word) + 1
        # Column 0: turning this node's prefix into the empty string costs
        # one deletion per character, i.e. the node's depth.
        currentRow = [previousRow[0] + 1]
        # Fill one column per character of the search word.
        for column in range(1, columns):
            insertCost = currentRow[column - 1] + 1
            deleteCost = previousRow[column] + 1
            if word[column - 1] != letter:
                replaceCost = previousRow[column - 1] + 1
            else:
                # Matching characters: substitution is free.
                replaceCost = previousRow[column - 1]
            currentRow.append(min(insertCost, deleteCost, replaceCost))

        if currentRow[-1] <= maxCost and node.word is not None:
            # Weighted Levenshtein score for ranking: 0 is the closest match,
            # 1 means no match.
            results.append((node.word, currentRow[-1] / max(len(node.word), size)))

        # Only descend while some cell is still within budget. Row minima
        # never decrease with depth, so deeper nodes could not match either.
        if min(currentRow) <= maxCost:
            for child_letter in node.children:
                Utils.look_recursive(node.children[child_letter], child_letter,
                                     word, currentRow, results, maxCost, size)