Implement token alignment for RegexFSM and StopAtFSM
RobinPicard committed Jan 30, 2024
1 parent a04e8d4 commit 01bfc21
Showing 2 changed files with 89 additions and 44 deletions.
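For context: token alignment addresses prompts that end in the middle of what the tokenizer would otherwise emit as a single token (e.g. a prompt ending in "wor" when the vocabulary contains "world"). The idea is to strip the prompt's last token and let the first generation step choose any token whose text starts with the removed text. A minimal sketch of the idea, using an invented toy vocabulary:

vocabulary = {"Hello": 0, "wor": 1, "world": 2, "word": 3, "cat": 4}
prompt = "Hello wor"
last_token_text = "wor"  # text of the prompt's final token

# The model is fed the prompt without its last token...
aligned_prompt = prompt[: -len(last_token_text)]  # "Hello "
# ...and the first step is restricted to tokens that re-produce the
# removed text, possibly extending past the original prompt boundary.
alignment_tokens = [
    token for text, token in vocabulary.items() if text.startswith(last_token_text)
]
print(aligned_prompt, alignment_tokens)  # Hello  [1, 2, 3]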
115 changes: 76 additions & 39 deletions outlines/fsm/fsm.py
@@ -1,7 +1,6 @@
 from copy import deepcopy
-from typing import TYPE_CHECKING, List, NewType, Protocol
+from typing import TYPE_CHECKING, Dict, List, NewType, Protocol, Tuple

-import cloudpickle
 import interegular
 from lark import Lark

@@ -15,6 +14,9 @@


 class FSM(Protocol):
+    def align_prompt_tokens(self, prompt: str) -> str:
+        ...
+
     def allowed_token_ids(self, state: FSMState) -> List[int]:
         ...

@@ -39,8 +41,23 @@ class StopAtTokenFSM(FSM):

     def __init__(self, tokenizer: "Tokenizer", stop_token_id: int):
         self.stop_token_id = stop_token_id
-        self.vocabulary = tokenizer.vocabulary.values()
-        self.final_states = {1}
+        self.tokenizer = tokenizer
+        self.vocabulary = tokenizer.vocabulary
+        self.final_states = {2}
+        self.valid_alignment_tokens: List[int] = []
+
+    def align_prompt_tokens(self, prompt: str) -> str:
+        """Remove the last token from the prompt and set the value of self.valid_alignment_tokens"""
+        token_ids, _ = self.tokenizer.encode(prompt)
+        last_token_id = int(token_ids[0][-1])
+        last_token_text = self.tokenizer.decode([last_token_id])[0]
+        # select the tokens that start with the text removed from the prompt
+        self.valid_alignment_tokens = [
+            token
+            for text, token in self.vocabulary.items()
+            if text.startswith(last_token_text)
+        ]
+        return prompt[: -len(last_token_text)]
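A hedged usage sketch of the method above; the stub tokenizer is invented, but its encode/decode shapes mirror the calls in the diff (encode returns batched token ids, decode maps a list of ids to a list of strings):

class StubTokenizer:
    vocabulary = {"Hello": 0, " wor": 1, " world": 2, " word": 3, "<eos>": 4}

    def encode(self, text):
        # pretend "Hello wor" tokenizes as ["Hello", " wor"]
        return [[0, 1]], None

    def decode(self, ids):
        inverse = {v: k for k, v in self.vocabulary.items()}
        return [inverse[i] for i in ids]

fsm = StopAtTokenFSM(StubTokenizer(), stop_token_id=4)
print(fsm.align_prompt_tokens("Hello wor"))  # Hello
print(fsm.valid_alignment_tokens)            # [1, 2, 3]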

     def allowed_token_ids(self, state: FSMState) -> List[int]:
         """Generate a list of allowed tokens for the next step.
@@ -59,7 +76,9 @@ def allowed_token_ids(self, state: FSMState) -> List[int]:
"""
if state == 0:
return list(self.vocabulary)
return self.valid_alignment_tokens
elif state == 1:
return list(self.vocabulary.values())
else:
return [self.stop_token_id]

@@ -83,17 +102,17 @@ def next_state(self, state: FSMState, token_id: int) -> FSMState:
"""
if token_id == self.stop_token_id:
return FSMState(1)
return FSMState(2)

return FSMState(0)
return FSMState(1)
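My reading of the revised state machine implied by the two methods above (not spelled out in the source): state 0 is the token-alignment step, where only tokens completing the removed prompt text are allowed; state 1 is unconstrained generation over the full vocabulary; state 2 is the final state entered once stop_token_id is produced, which is why final_states moves from {1} to {2} in __init__.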

     def is_final_state(self, state: FSMState) -> bool:
         """Determine whether the current state of the FSM is a final state."""
         return state in self.final_states

     def copy(self) -> "StopAtTokenFSM":
         """Create a copy of the FSM."""
-        return self
+        return deepcopy(self)


class RegexFSM(FSM):
@@ -122,41 +141,61 @@ def __init__(self, regex_string: str, tokenizer: "Tokenizer"):
             -1
         }  # Include the EOS token in final states
         self.tokenizer = tokenizer
-        self.vocabulary = tokenizer.vocabulary.values()
+        self.vocabulary = tokenizer.vocabulary
         self.end_token_id = tokenizer.eos_token_id

     def align_prompt_tokens(self, prompt: str) -> str:
         """Remove the last token from the prompt and update the states_to_token_maps accordingly"""
         token_ids, _ = self.tokenizer.encode(prompt)
         last_token_id = int(token_ids[0][-1])
         last_token_text = self.tokenizer.decode([last_token_id])[0]
-        vocabulary = {
-            self.tokenizer.decode([token_id])[0]: token_id
-            for token_id in range(len(self.vocabulary))
-        }
-        starting_state_tokens = {
-            self.tokenizer.decode([token_id])[0]: self.states_to_token_maps[0][token_id]
-            for token_id in self.states_to_token_maps[0]
-        }
-        # select the tokens that start with the text removed from the prompt and whose text after the
-        # initial prompt corresponds to that of one of the allowed tokens of the starting state
-        possible_tokens = {
-            vocabulary[token_text]: starting_state_tokens[token_text[len(last_token_text):]]
-            for token_text in vocabulary
-            if (
-                token_text.startswith(last_token_text)
-                and starting_state_tokens.get(token_text[len(last_token_text):])
-            )
+        last_token_length = len(last_token_text)
+        # select the tokens that start with the text removed from the prompt
+        crossing_tokens = {
+            token: text
+            for text, token in self.vocabulary.items()
+            if text.startswith(last_token_text)
         }
+        # keep only the tokens whose text after the boundary matches the fsm
+        valid_tokens_states = self.find_valid_crossing_tokens(
+            crossing_tokens, last_token_length
+        )
         # update the states_to_token_maps in the following manner:
         # the value of the starting state is assigned to a new state, the starting state is now the
-        # possible_tokens found above + the last_token we removed (that leads to the new state)
-        additional_state_id = max(list(self.states_to_token_maps.keys()) + list(self.final_states)) + 1
+        # valid_tokens_states found above
+        additional_state_id = (
+            max(list(self.states_to_token_maps.keys()) + list(self.final_states)) + 1
+        )
         self.states_to_token_maps[additional_state_id] = self.states_to_token_maps[0]
-        self.states_to_token_maps[0] = {**possible_tokens, last_token_id: additional_state_id}
-
-        return prompt[:-len(last_token_text)]
+        self.states_to_token_maps[0] = {}
+        for token, state in valid_tokens_states:
+            if state == 0:
+                self.states_to_token_maps[0][token] = additional_state_id
+            else:
+                self.states_to_token_maps[0][token] = state
+        return prompt[: -len(last_token_text)]
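To make the start-state rewiring concrete, a toy walk-through (token ids and states are invented, not values from the library):

# Before alignment: start state 0 accepts token 7 and moves to final state 1.
states_to_token_maps = {0: {7: 1}}
final_states = {1}
# Alignment removed the text "x": token 4 is exactly "x" (the walk ends at
# state 0), token 9 is "xa" and its text after the boundary leads to state 1.
valid_tokens_states = [(4, 0), (9, 1)]

additional_state_id = max(list(states_to_token_maps.keys()) + list(final_states)) + 1
states_to_token_maps[additional_state_id] = states_to_token_maps[0]
states_to_token_maps[0] = {}
for token, state in valid_tokens_states:
    states_to_token_maps[0][token] = additional_state_id if state == 0 else state

print(states_to_token_maps)  # {0: {4: 2, 9: 1}, 2: {7: 1}}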

+    def find_valid_crossing_tokens(
+        self, crossing_tokens: Dict[int, str], last_token_length: int
+    ) -> List[Tuple[int, int]]:
+        """For each crossing token, check that the characters after the boundary match the FSM
+        and find the state it would lead to. Return the valid tokens with the associated state
+        """
+        valid_tokens = []
+        for token, text in crossing_tokens.items():
+            is_valid = True
+            crossing_text = text[last_token_length:]
+            state = 0
+            for char in crossing_text:
+                char_token = self.vocabulary.get(char)
+                try:
+                    state = self.states_to_token_maps[state][char_token]  # type: ignore
+                except KeyError:
+                    is_valid = False
+                    break
+            if is_valid:
+                valid_tokens.append((token, state))
+        return valid_tokens
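The validation walk can be pictured as follows; note that it checks one character at a time, so it implicitly relies on each individual character being a vocabulary entry itself (all values below are invented for illustration):

vocabulary = {"x": 4, "a": 7, "b": 8, "xab": 9}
states_to_token_maps = {0: {7: 1}, 1: {8: 2}}

# Crossing token 9 is "xab" and the removed prompt text is "x",
# so the text to validate after the boundary is "ab".
state = 0
for char in "ab":
    state = states_to_token_maps[state][vocabulary[char]]  # KeyError: rejected
print(state)  # 2, meaning token 9 is valid and drives the FSM to state 2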

     def allowed_token_ids(self, state: FSMState) -> List[int]:
         """Generate a list of allowed tokens for the next step.
@@ -222,12 +261,7 @@ def is_final_state(self, state: FSMState) -> bool:

     def copy(self) -> "RegexFSM":
         """Create a copy of the FSM."""
-        # temporary solution to the problem of unpickleable dict_values
-        self.vocabulary = cloudpickle.dumps(self.vocabulary)
-        copy = deepcopy(self)
-        self.vocabulary = cloudpickle.loads(self.vocabulary)
-        copy.vocabulary = cloudpickle.loads(copy.vocabulary)
-        return copy
+        return deepcopy(self)
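A side effect of this change worth noting: since self.vocabulary now holds the tokenizer's vocabulary dict rather than a dict_values view, it pickles cleanly, so the cloudpickle round-trip workaround above can be dropped in favor of a plain deepcopy.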


class CFGFSM(FSM):
@@ -257,6 +291,10 @@ def __init__(self, cfg_string: str, tokenizer: "Tokenizer"):
         self.done = False
         self.regex_fsm: RegexFSM

+    def align_prompt_tokens(self, prompt: str) -> str:
+        """Not implemented for CFGFSM"""
+        return prompt

     def _set_next_regex_fsm(self) -> None:
         """Use the CFG incremental parser to set the next regex FSM.
@@ -278,7 +316,6 @@ def _set_next_regex_fsm(self) -> None:
             self.allow_eos = True
             options.add("")
             assert len(options) > 1
-
         regex_string = r"(" + r"|".join([r"(" + x + r")" for x in options]) + r")"
         self.regex_fsm = RegexFSM(regex_string, self.tokenizer)
         self.reset_state = True
18 changes: 13 additions & 5 deletions outlines/generate/api.py
@@ -1,5 +1,6 @@
 import json as pyjson
 import warnings
+from copy import deepcopy
 from typing import Callable, Iterator, List, Optional, Tuple, Union

 import torch
@@ -77,12 +78,15 @@ def get_generated_token_ids(
         return token_ids

     def get_generated_sequences(
-        self, generated_token_ids: List[torch.Tensor], initial_prompts: List[str], prompts: List[str]
+        self,
+        generated_token_ids: List[torch.Tensor],
+        initial_prompts: List[str],
+        prompts: List[str],
     ) -> List[str]:
         """Give the text sequences generated based on the tokens generated and the initial prompts"""
         generated_tokens_text = self.tokenizer.decode(generated_token_ids)
         return [
-            generated_tokens_text[i][len(initial_prompts[i]) - len(prompts[i]):]
+            generated_tokens_text[i][len(initial_prompts[i]) - len(prompts[i]) :]
             for i in range(len(generated_tokens_text))
         ]
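A worked example of the slice above (toy strings): the offset is the number of prompt characters that alignment removed and the model then re-generated, so only text beyond the user's original prompt is returned.

initial_prompt = "Hello wor"   # prompt as the user passed it
aligned_prompt = "Hello "      # prompt actually fed after token alignment
generated_text = "world!"      # decoded text of the generated token ids

offset = len(initial_prompt) - len(aligned_prompt)  # 3 overlapping characters
print(generated_text[offset:])  # ld!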

@@ -196,7 +200,7 @@ def __call__(

         if isinstance(prompts, str):
             prompts = [prompts]
-        initial_prompts = copy.deepcopy(prompts)
+        initial_prompts = deepcopy(prompts)

         if isinstance(stop_at, str):
             stop_at = [stop_at]
@@ -205,7 +209,9 @@
         max_tokens = max_tokens or self.max_tokens
         num_sequences = len(prompts)
         fsms = [self.fsm.copy() for _ in prompts]
-        prompts = [fsm.align_prompt_tokens(prompt) for fsm, prompt in zip(fsms, prompts)]
+        prompts = [
+            fsm.align_prompt_tokens(prompt) for fsm, prompt in zip(fsms, prompts)
+        ]

         if rng is None:
             rng = torch.Generator(device=self.device)
@@ -239,7 +245,9 @@
         generated_token_ids = self.get_generated_token_ids(
             init_state, initial_prompts, last_state
         )
-        generated = self.get_generated_sequences(generated_token_ids, initial_prompts, prompts)
+        generated = self.get_generated_sequences(
+            generated_token_ids, initial_prompts, prompts
+        )
         stripped = [
             self.strip_stop_sequences(sequence, stop_sequences)
             for sequence in generated
