Skip to content

Commit

Permalink
added feature to return span_info when using extract_keywords
Browse files Browse the repository at this point in the history
  • Loading branch information
vi3k6i5 committed Nov 21, 2017
1 parent ae2d85d commit 7090989
Show file tree
Hide file tree
Showing 4 changed files with 105 additions and 7 deletions.
21 changes: 15 additions & 6 deletions flashtext/keyword.py
Expand Up @@ -381,7 +381,7 @@ def add_keywords_from_list(self, keyword_list):
"""
if not isinstance(keyword_list, list):
raise AttributeError("keyword_list should be a list")
raise AttributeError("keyword_list should be a list")

for keyword in keyword_list:
self.add_keyword(keyword)
Expand Down Expand Up @@ -441,7 +441,7 @@ def get_all_keywords(self, term_so_far='', current_dict=None):
terms_present[key] = sub_values[key]
return terms_present

def extract_keywords(self, sentence):
def extract_keywords(self, sentence, span_info=False):
"""Searches in the string for all keywords present in corpus.
Keywords present are added to a list `keywords_extracted` and returned.
Expand All @@ -468,7 +468,9 @@ def extract_keywords(self, sentence):
if not self.case_sensitive:
sentence = sentence.lower()
current_dict = self.keyword_trie_dict
sequence_start_pos = 0
sequence_end_pos = 0
reset_current_dict = False
idx = 0
sentence_len = len(sentence)
while idx < sentence_len:
Expand Down Expand Up @@ -515,17 +517,19 @@ def extract_keywords(self, sentence):
idx = sequence_end_pos
current_dict = self.keyword_trie_dict
if longest_sequence_found:
keywords_extracted.append(longest_sequence_found)

keywords_extracted.append((longest_sequence_found, sequence_start_pos, idx))
reset_current_dict = True
else:
# we reset current_dict
current_dict = self.keyword_trie_dict
reset_current_dict = True
elif char in current_dict:
# we can continue from this char
current_dict = current_dict[char]
else:
# we reset current_dict
current_dict = self.keyword_trie_dict
reset_current_dict = True
# skip to end of word
idy = idx + 1
while idy < sentence_len:
Expand All @@ -538,9 +542,14 @@ def extract_keywords(self, sentence):
if idx + 1 >= sentence_len:
if self._keyword in current_dict:
sequence_found = current_dict[self._keyword]
keywords_extracted.append(sequence_found)
keywords_extracted.append((sequence_found, sequence_start_pos, sentence_len))
idx += 1
return keywords_extracted
if reset_current_dict:
reset_current_dict = False
sequence_start_pos = idx
if span_info:
return keywords_extracted
return [value[0] for value in keywords_extracted]

def replace_keywords(self, sentence):
"""Searches in the string for all keywords present in corpus.
Expand Down
23 changes: 22 additions & 1 deletion test/test_kp_exceptions.py
Expand Up @@ -30,7 +30,13 @@ def test_add_keyword_file_missing(self):
with pytest.raises(IOError):
keyword_processor.add_keyword_from_file('missing_file')

def test_add_keyword_file_missing(self):
def test_add_keyword_from_list(self):
keyword_processor = KeywordProcessor()
keyword_list = "java"
with pytest.raises(AttributeError):
keyword_processor.add_keywords_from_list(keyword_list)

def test_add_keyword_from_dictionary(self):
keyword_processor = KeywordProcessor()
keyword_dict = {
"java": "java_2e",
Expand All @@ -39,6 +45,21 @@ def test_add_keyword_file_missing(self):
with pytest.raises(AttributeError):
keyword_processor.add_keywords_from_dict(keyword_dict)

def test_remove_keyword_from_list(self):
keyword_processor = KeywordProcessor()
keyword_list = "java"
with pytest.raises(AttributeError):
keyword_processor.remove_keywords_from_list(keyword_list)

def test_remove_keyword_from_dictionary(self):
keyword_processor = KeywordProcessor()
keyword_dict = {
"java": "java_2e",
"product management": "product manager"
}
with pytest.raises(AttributeError):
keyword_processor.remove_keywords_from_dict(keyword_dict)

def test_empty_string(self):
keyword_processor = KeywordProcessor()
keyword_dict = {
Expand Down
53 changes: 53 additions & 0 deletions test/test_kp_extract_span.py
@@ -0,0 +1,53 @@
from flashtext import KeywordProcessor
import logging
import unittest
import json

logger = logging.getLogger(__name__)


class TestKPExtractorSpan(unittest.TestCase):
def setUp(self):
logger.info("Starting...")
with open('test/keyword_extractor_test_cases.json') as f:
self.test_cases = json.load(f)

def tearDown(self):
logger.info("Ending.")

def test_extract_keywords(self):
"""For each of the test case initialize a new KeywordProcessor.
Add the keywords the test case to KeywordProcessor.
Extract keywords and check if they match the expected result for the test case.
"""
for test_id, test_case in enumerate(self.test_cases):
keyword_processor = KeywordProcessor()
for key in test_case['keyword_dict']:
keyword_processor.add_keywords_from_list(test_case['keyword_dict'][key])
keywords_extracted = keyword_processor.extract_keywords(test_case['sentence'], span_info=True)
for kwd in keywords_extracted:
# returned keyword lowered should match the sapn from sentence
self.assertEqual(
kwd[0].lower(), test_case['sentence'].lower()[kwd[1]:kwd[2]],
"keywords span don't match the expected results for test case: {}".format(test_id))

def test_extract_keywords_case_sensitive(self):
"""For each of the test case initialize a new KeywordProcessor.
Add the keywords the test case to KeywordProcessor.
Extract keywords and check if they match the expected result for the test case.
"""
for test_id, test_case in enumerate(self.test_cases):
keyword_processor = KeywordProcessor(case_sensitive=True)
for key in test_case['keyword_dict']:
keyword_processor.add_keywords_from_list(test_case['keyword_dict'][key])
keywords_extracted = keyword_processor.extract_keywords(test_case['sentence'], span_info=True)
for kwd in keywords_extracted:
# returned keyword should match the sapn from sentence
self.assertEqual(
kwd[0], test_case['sentence'][kwd[1]:kwd[2]],
"keywords span don't match the expected results for test case: {}".format(test_id))

if __name__ == '__main__':
unittest.main()
15 changes: 15 additions & 0 deletions test/test_remove_keywords.py
Expand Up @@ -31,6 +31,21 @@ def test_remove_keywords(self):
self.assertEqual(keywords_extracted, test_case['keywords'],
"keywords_extracted don't match the expected results for test case: {}".format(test_id))

def test_remove_keywords_using_list(self):
"""For each of the test case initialize a new KeywordProcessor.
Add the keywords the test case to KeywordProcessor.
Remove the keywords in remove_keyword_dict
Extract keywords and check if they match the expected result for the test case.
"""
for test_id, test_case in enumerate(self.test_cases):
keyword_processor = KeywordProcessor()
keyword_processor.add_keywords_from_dict(test_case['keyword_dict'])
for key in test_case['remove_keyword_dict']:
keyword_processor.remove_keywords_from_list(test_case['remove_keyword_dict'][key])
keywords_extracted = keyword_processor.extract_keywords(test_case['sentence'])
self.assertEqual(keywords_extracted, test_case['keywords'],
"keywords_extracted don't match the expected results for test case: {}".format(test_id))

def test_remove_keywords_dictionary_compare(self):
"""For each of the test case initialize a new KeywordProcessor.
Add the keywords the test case to KeywordProcessor.
Expand Down

0 comments on commit 7090989

Please sign in to comment.