From 709098959334fd6230258d18b2b4dbca5176a38a Mon Sep 17 00:00:00 2001 From: Vikash Singh Date: Wed, 22 Nov 2017 00:07:45 +0530 Subject: [PATCH] added feature to return span_info when using extract_keywords --- flashtext/keyword.py | 21 ++++++++++---- test/test_kp_exceptions.py | 23 +++++++++++++++- test/test_kp_extract_span.py | 53 ++++++++++++++++++++++++++++++++++++ test/test_remove_keywords.py | 15 ++++++++++ 4 files changed, 105 insertions(+), 7 deletions(-) create mode 100644 test/test_kp_extract_span.py diff --git a/flashtext/keyword.py b/flashtext/keyword.py index 880baf3..f10fc81 100644 --- a/flashtext/keyword.py +++ b/flashtext/keyword.py @@ -381,7 +381,7 @@ def add_keywords_from_list(self, keyword_list): """ if not isinstance(keyword_list, list): - raise AttributeError("keyword_list should be a list") + raise AttributeError("keyword_list should be a list") for keyword in keyword_list: self.add_keyword(keyword) @@ -441,7 +441,7 @@ def get_all_keywords(self, term_so_far='', current_dict=None): terms_present[key] = sub_values[key] return terms_present - def extract_keywords(self, sentence): + def extract_keywords(self, sentence, span_info=False): """Searches in the string for all keywords present in corpus. Keywords present are added to a list `keywords_extracted` and returned. @@ -468,7 +468,9 @@ def extract_keywords(self, sentence): if not self.case_sensitive: sentence = sentence.lower() current_dict = self.keyword_trie_dict + sequence_start_pos = 0 sequence_end_pos = 0 + reset_current_dict = False idx = 0 sentence_len = len(sentence) while idx < sentence_len: @@ -515,17 +517,19 @@ def extract_keywords(self, sentence): idx = sequence_end_pos current_dict = self.keyword_trie_dict if longest_sequence_found: - keywords_extracted.append(longest_sequence_found) - + keywords_extracted.append((longest_sequence_found, sequence_start_pos, idx)) + reset_current_dict = True else: # we reset current_dict current_dict = self.keyword_trie_dict + reset_current_dict = True elif char in current_dict: # we can continue from this char current_dict = current_dict[char] else: # we reset current_dict current_dict = self.keyword_trie_dict + reset_current_dict = True # skip to end of word idy = idx + 1 while idy < sentence_len: @@ -538,9 +542,14 @@ def extract_keywords(self, sentence): if idx + 1 >= sentence_len: if self._keyword in current_dict: sequence_found = current_dict[self._keyword] - keywords_extracted.append(sequence_found) + keywords_extracted.append((sequence_found, sequence_start_pos, sentence_len)) idx += 1 - return keywords_extracted + if reset_current_dict: + reset_current_dict = False + sequence_start_pos = idx + if span_info: + return keywords_extracted + return [value[0] for value in keywords_extracted] def replace_keywords(self, sentence): """Searches in the string for all keywords present in corpus. diff --git a/test/test_kp_exceptions.py b/test/test_kp_exceptions.py index bcd7d99..3c0ebff 100644 --- a/test/test_kp_exceptions.py +++ b/test/test_kp_exceptions.py @@ -30,7 +30,13 @@ def test_add_keyword_file_missing(self): with pytest.raises(IOError): keyword_processor.add_keyword_from_file('missing_file') - def test_add_keyword_file_missing(self): + def test_add_keyword_from_list(self): + keyword_processor = KeywordProcessor() + keyword_list = "java" + with pytest.raises(AttributeError): + keyword_processor.add_keywords_from_list(keyword_list) + + def test_add_keyword_from_dictionary(self): keyword_processor = KeywordProcessor() keyword_dict = { "java": "java_2e", @@ -39,6 +45,21 @@ def test_add_keyword_file_missing(self): with pytest.raises(AttributeError): keyword_processor.add_keywords_from_dict(keyword_dict) + def test_remove_keyword_from_list(self): + keyword_processor = KeywordProcessor() + keyword_list = "java" + with pytest.raises(AttributeError): + keyword_processor.remove_keywords_from_list(keyword_list) + + def test_remove_keyword_from_dictionary(self): + keyword_processor = KeywordProcessor() + keyword_dict = { + "java": "java_2e", + "product management": "product manager" + } + with pytest.raises(AttributeError): + keyword_processor.remove_keywords_from_dict(keyword_dict) + def test_empty_string(self): keyword_processor = KeywordProcessor() keyword_dict = { diff --git a/test/test_kp_extract_span.py b/test/test_kp_extract_span.py new file mode 100644 index 0000000..2b9f7a4 --- /dev/null +++ b/test/test_kp_extract_span.py @@ -0,0 +1,53 @@ +from flashtext import KeywordProcessor +import logging +import unittest +import json + +logger = logging.getLogger(__name__) + + +class TestKPExtractorSpan(unittest.TestCase): + def setUp(self): + logger.info("Starting...") + with open('test/keyword_extractor_test_cases.json') as f: + self.test_cases = json.load(f) + + def tearDown(self): + logger.info("Ending.") + + def test_extract_keywords(self): + """For each of the test case initialize a new KeywordProcessor. + Add the keywords the test case to KeywordProcessor. + Extract keywords and check if they match the expected result for the test case. + + """ + for test_id, test_case in enumerate(self.test_cases): + keyword_processor = KeywordProcessor() + for key in test_case['keyword_dict']: + keyword_processor.add_keywords_from_list(test_case['keyword_dict'][key]) + keywords_extracted = keyword_processor.extract_keywords(test_case['sentence'], span_info=True) + for kwd in keywords_extracted: + # returned keyword lowered should match the sapn from sentence + self.assertEqual( + kwd[0].lower(), test_case['sentence'].lower()[kwd[1]:kwd[2]], + "keywords span don't match the expected results for test case: {}".format(test_id)) + + def test_extract_keywords_case_sensitive(self): + """For each of the test case initialize a new KeywordProcessor. + Add the keywords the test case to KeywordProcessor. + Extract keywords and check if they match the expected result for the test case. + + """ + for test_id, test_case in enumerate(self.test_cases): + keyword_processor = KeywordProcessor(case_sensitive=True) + for key in test_case['keyword_dict']: + keyword_processor.add_keywords_from_list(test_case['keyword_dict'][key]) + keywords_extracted = keyword_processor.extract_keywords(test_case['sentence'], span_info=True) + for kwd in keywords_extracted: + # returned keyword should match the sapn from sentence + self.assertEqual( + kwd[0], test_case['sentence'][kwd[1]:kwd[2]], + "keywords span don't match the expected results for test case: {}".format(test_id)) + +if __name__ == '__main__': + unittest.main() diff --git a/test/test_remove_keywords.py b/test/test_remove_keywords.py index 40010bc..4735a51 100644 --- a/test/test_remove_keywords.py +++ b/test/test_remove_keywords.py @@ -31,6 +31,21 @@ def test_remove_keywords(self): self.assertEqual(keywords_extracted, test_case['keywords'], "keywords_extracted don't match the expected results for test case: {}".format(test_id)) + def test_remove_keywords_using_list(self): + """For each of the test case initialize a new KeywordProcessor. + Add the keywords the test case to KeywordProcessor. + Remove the keywords in remove_keyword_dict + Extract keywords and check if they match the expected result for the test case. + """ + for test_id, test_case in enumerate(self.test_cases): + keyword_processor = KeywordProcessor() + keyword_processor.add_keywords_from_dict(test_case['keyword_dict']) + for key in test_case['remove_keyword_dict']: + keyword_processor.remove_keywords_from_list(test_case['remove_keyword_dict'][key]) + keywords_extracted = keyword_processor.extract_keywords(test_case['sentence']) + self.assertEqual(keywords_extracted, test_case['keywords'], + "keywords_extracted don't match the expected results for test case: {}".format(test_id)) + def test_remove_keywords_dictionary_compare(self): """For each of the test case initialize a new KeywordProcessor. Add the keywords the test case to KeywordProcessor.