Skip to content

Commit

Permalink
Fix null byte \x00 issue by switching to numba.types.unicode_type
Browse files Browse the repository at this point in the history
  • Loading branch information
M0gician authored and lapp0 committed May 30, 2024
1 parent 538f77a commit 4a5ef55
Showing 1 changed file with 11 additions and 15 deletions.
26 changes: 11 additions & 15 deletions outlines/fsm/regex.py
Original file line number Diff line number Diff line change
Expand Up @@ -87,14 +87,11 @@ def fsm_info(self):
((k, z) for k, v in self.trans_key_to_states.items() for z in v),
dtype=np.dtype("int64, int64"),
)
alphabet_symbol_mapping_items = np.fromiter(
(
it
for it in self.alphabet._symbol_mapping.items()
if it[0] != anything_else
),
dtype=np.dtype("U2, int64"),
)
alphabet_symbol_mapping_items = [
(k, v)
for k, v in self.alphabet._symbol_mapping.items()
if k != anything_else
]
nb_finals = np.fromiter(self.finals, dtype=np.dtype("int64"))
self.__dict__["_fsm_info"] = create_fsm_info(
self.initial,
Expand All @@ -110,7 +107,7 @@ def fsm_info(self):

nb_int_list_type = numba.types.ListType(numba.int64)
nb_int_pair_type = numba.types.UniTuple(numba.int64, 2)
nb_unichar_2_type = numba.types.UnicodeCharSeq(2)
nb_unicode_type = numba.types.unicode_type


@numba.njit(cache=True)
Expand All @@ -136,7 +133,7 @@ def create_fsm_info(

# symbols are keyed as unicode strings (numba.types.unicode_type) rather than
# fixed 2-char sequences; incomplete utf-8 sequences are still represented
# as 2-hex-digit pairs
alphabet_symbol_map = numba.typed.Dict.empty(nb_unichar_2_type, numba.int64)
alphabet_symbol_map = numba.typed.Dict.empty(nb_unicode_type, numba.int64)
for symbol_and_trans_key in alphabet_symbol_mapping_items:
alphabet_symbol_map[symbol_and_trans_key[0]] = symbol_and_trans_key[1]

Expand Down Expand Up @@ -804,7 +801,7 @@ def reduced_vocabulary(
raise RuntimeError(
f"Cannot convert token `{token}` ({token_idx}) to bytes: {token_str}"
)
token_str = tuple(byte_symbol(b) for b in token_bytes)
token_str = "".join(byte_symbol(b) for b in token_bytes)

vocabulary.setdefault(token_str, []).append(token_idx)
else:
Expand All @@ -813,15 +810,14 @@ def reduced_vocabulary(
vocabulary_nb = numba.typed.List.empty_list(
numba.types.Tuple(
(
nb_unichar_2_type[:],
nb_unicode_type,
numba.int64[:],
)
)
)
for token_tuple, token_ids in vocabulary.items():
token_tuple_np = np.fromiter(token_tuple, dtype=np.dtype("U2"))
for token_str, token_ids in vocabulary.items():
token_ids_np = np.fromiter(token_ids, dtype=np.dtype("int64"))
vocabulary_nb.append((token_tuple_np, token_ids_np))
vocabulary_nb.append((token_str, token_ids_np))

return vocabulary_nb, empty_token_ids

Expand Down

0 comments on commit 4a5ef55

Please sign in to comment.