diff --git a/src/transformers/tokenization_utils.py b/src/transformers/tokenization_utils.py
index 171a3251f01454..5b2170b697d48f 100644
--- a/src/transformers/tokenization_utils.py
+++ b/src/transformers/tokenization_utils.py
@@ -20,6 +20,7 @@
 import itertools
 import re
 import unicodedata
+from collections import OrderedDict
 from typing import Any, Dict, List, Optional, Tuple, Union, overload
 
 from .file_utils import PaddingStrategy, TensorType, add_end_docstrings
@@ -102,7 +103,6 @@ def split(self, text: str) -> List[str]:
         >>> trie.split("[CLS] This is a extra_id_100")
         ["[CLS]", " This is a ", "extra_id_100"]
         """
-
         # indexes are counted left of the chars index.
         # "hello", index 0, is left of h, index 1 is between h and e.
         # index 5 is right of the "o".
@@ -115,7 +115,7 @@ def split(self, text: str) -> List[str]:
         # If the trie contains, "blowing", and "lower" and we encounter the
         # string "blower", we need to split into ["b", "lower"].
         # This is where we need to keep track of multiple possible starts.
-        states = {}
+        states = OrderedDict()
 
         # This will contain every indices where we need
         # to cut.
@@ -144,36 +144,36 @@ def split(self, text: str) -> List[str]:
 
             # In this case, we already have partial matches (But unfinished)
             for start, trie_pointer in states.items():
-                if current_char in trie_pointer:
+                if "" in trie_pointer:
+                    # This is a final match, we need to reset and
+                    # store the results in `offsets`.
+
+                    # Lookahead to match longest first
+                    # Important in case of extra_id_1 vs extra_id_100
+                    lookahead_index = current
+                    end = current
+                    next_char = text[lookahead_index] if lookahead_index < len(text) else None
+                    while next_char in trie_pointer:
+                        trie_pointer = trie_pointer[next_char]
+                        lookahead_index += 1
+                        if "" in trie_pointer:
+                            end = lookahead_index
+                            skip = lookahead_index
+
+                        if lookahead_index == len(text):
+                            # End of string
+                            break
+                        next_char = text[lookahead_index]
+                    # End lookahead
+
+                    # Storing and resetting
+                    offsets.append(start)
+                    offsets.append(end)
+                    reset = True
+                elif current_char in trie_pointer:
                     # The current character being looked at has a match within the trie
                     # update the pointer (it will be stored back into states later).
                     trie_pointer = trie_pointer[current_char]
-                    if "" in trie_pointer:
-                        # This is a final match, we need to reset and
-                        # store the results in `offsets`.
-
-                        # Lookahead to match longest first
-                        # Important in case of extra_id_1 vs extra_id_100
-                        lookahead_index = current + 1
-                        end = current + 1
-                        next_char = text[lookahead_index] if lookahead_index < len(text) else None
-                        while next_char in trie_pointer:
-                            trie_pointer = trie_pointer[next_char]
-                            lookahead_index += 1
-                            if "" in trie_pointer:
-                                end = lookahead_index
-                                skip = lookahead_index
-
-                            if lookahead_index == len(text):
-                                # End of string
-                                break
-                            next_char = text[lookahead_index]
-                        # End lookahead
-
-                        # Storing and resetting
-                        offsets.append(start)
-                        offsets.append(end)
-                        reset = True
 
                 # Storing back the new pointer into the states.
                 # Partial matches got longer by one.
@@ -198,6 +198,18 @@ def split(self, text: str) -> List[str]:
             if current_char in self.data:
                 states[current] = self.data[current_char]
 
+        # We have a cut at the end with states.
+        for start, trie_pointer in states.items():
+            if "" in trie_pointer:
+                # This is a final match, we need to reset and
+                # store the results in `offsets`.
+                end = len(text)
+                offsets.append(start)
+                offsets.append(end)
+                # Longest cut is always the one with lower start so the first
+                # item so we need to break.
+                break
+
         # We have all the offsets now, we just need to do the actual splitting.
         # We need to eventually add the first part of the string and the eventual
         # last part.
diff --git a/tests/test_tokenization_common.py b/tests/test_tokenization_common.py
index 50aca9c4c8af8c..1a58f516927791 100644
--- a/tests/test_tokenization_common.py
+++ b/tests/test_tokenization_common.py
@@ -3562,3 +3562,15 @@ def test_trie_split(self):
         trie.add("extra_id_1")
         trie.add("extra_id_100")
         self.assertEqual(trie.split("[CLS] This is a extra_id_100"), ["[CLS]", " This is a ", "extra_id_100"])
+
+    def test_trie_single(self):
+        trie = Trie()
+        trie.add("A")
+        self.assertEqual(trie.split("ABC"), ["A", "BC"])
+        self.assertEqual(trie.split("BCA"), ["BC", "A"])
+
+    def test_trie_final(self):
+        trie = Trie()
+        trie.add("TOKEN]")
+        trie.add("[SPECIAL_TOKEN]")
+        self.assertEqual(trie.split("This is something [SPECIAL_TOKEN]"), ["This is something ", "[SPECIAL_TOKEN]"])
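As a quick illustration of the behaviour the new tests pin down, here is a minimal usage sketch; it assumes `Trie` is importable from `transformers.tokenization_utils`, where the diff defines it, and simply mirrors the assertions in `test_trie_final` and `test_trie_single`:

```python
from transformers.tokenization_utils import Trie

# A special token that ends exactly at the end of the input is now emitted as
# its own piece, thanks to the extra pass over `states` after the main loop.
trie = Trie()
trie.add("[SPECIAL_TOKEN]")
print(trie.split("This is something [SPECIAL_TOKEN]"))
# -> ["This is something ", "[SPECIAL_TOKEN]"]

# Single-character tokens work because the terminal marker "" is checked
# before consuming the next character (lookahead starts at `current`).
single = Trie()
single.add("A")
print(single.split("ABC"))  # -> ["A", "BC"]
print(single.split("BCA"))  # -> ["BC", "A"]
```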