diff --git a/demo_run.py b/demo_run.py index ed5b26a..06cb698 100644 --- a/demo_run.py +++ b/demo_run.py @@ -40,12 +40,12 @@ def load_data_2_root(data): if __name__ == "__main__": - root_name = basedir + "data/root.pkl" + root_name = basedir + "/data/root.pkl" stopwords = get_stopwords() if os.path.exists(root_name): root = load_model(root_name) else: - dict_name = basedir + 'data/dict.txt' + dict_name = basedir + '/data/dict.txt' word_freq = load_dictionary(dict_name) root = TrieNode('*', word_freq) save_model(root, root_name) diff --git a/model.py b/model.py index 17980a3..4ff227a 100644 --- a/model.py +++ b/model.py @@ -81,6 +81,7 @@ def add(self, word): length = len(word) node = self.root if length == 3: + word = list(word) word[0], word[1], word[2] = word[1], word[2], word[0] for count, char in enumerate(word): diff --git a/utils.py b/utils.py index 1df28b4..f46ccd8 100644 --- a/utils.py +++ b/utils.py @@ -15,7 +15,10 @@ def get_stopwords(): def generate_ngram(input_list, n): - return zip(*[input_list[i:] for i in range(n)]) + result = [] + for i in range(1, n+1): + result.extend(zip(*[input_list[j:] for j in range(i)])) + return result def load_dictionary(filename):