Implement token alignment for StopAtEosFSM and RegexFSM
RobinPicard committed Feb 17, 2024
1 parent e99d92d commit 4aa74f2
Showing 2 changed files with 291 additions and 22 deletions.
237 changes: 225 additions & 12 deletions outlines/fsm/fsm.py
@@ -1,6 +1,9 @@
from typing import TYPE_CHECKING, List, NewType, Protocol, Tuple
from collections import defaultdict
from copy import deepcopy
from typing import TYPE_CHECKING, Dict, List, NewType, Protocol, Tuple

import interegular
import torch
from lark import Lark

# from outlines.fsm.parsing import PartialLark
@@ -22,6 +25,11 @@ def is_final_state(self, state: FSMState) -> bool:
"""Determine whether the current state of the FSM is a final state."""
return state == self.final_state

def align_prompt_tokens(
self, token_ids: torch.Tensor, attention_masks: torch.Tensor
) -> Tuple[torch.Tensor, torch.Tensor]:
...

def allowed_token_ids(self, state: FSMState) -> List[int]:
...

@@ -37,13 +45,41 @@ class StopAtEosFSM(FSM):

def __init__(self, tokenizer: "Tokenizer"):
self.eos_token_id = tokenizer.eos_token_id
self.vocabulary = tokenizer.vocabulary.values()
self.vocabulary = tokenizer.vocabulary
self.tokenizer = tokenizer
self.states_to_token_maps = self.create_states_to_tokens_map()

def create_states_to_tokens_map(self) -> Dict[int, Dict[int, int]]:
"""Create the states_to_tokens_map. All tokens from the starting state lead
to itself, except for the eos_token that leads to the final state."""
return {
self.first_state: {
token_id: self.first_state
if token_id != self.eos_token_id
else self.final_state
for token_id in self.vocabulary.values()
}
}
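For illustration, assuming a hypothetical vocabulary {"a": 0, "b": 1, "<eos>": 2} with eos_token_id = 2, and assuming the base FSM uses first_state = 0 and final_state = -1, this would return:

{0: {0: 0, 1: 0, 2: -1}}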

def align_prompt_tokens(
self, token_ids: torch.Tensor, attention_masks: torch.Tensor
) -> Tuple[torch.Tensor, torch.Tensor]:
"""Update the states_to_token_maps and return the aligned prompt tokens and attention masks"""
(
token_ids,
attention_masks,
self.states_to_token_maps,
) = align_tokens_states_to_token_maps(
token_ids, attention_masks, self.vocabulary, self.states_to_token_maps
)
return token_ids, attention_masks

def allowed_token_ids(self, state: FSMState) -> List[int]:
"""Generate a list of allowed tokens for the next step.
When in the initial state we allow every token to be generated.
In the final state the only allowed token is `stop_token_id`.
Otherwise we allow the tokens that are valid transitions from the
current state of the states_to_token_maps.
Parameters
----------
@@ -57,14 +93,13 @@ def allowed_token_ids(self, state: FSMState) -> List[int]:
"""
if self.is_final_state(state):
return [self.eos_token_id]
return list(self.vocabulary)
return list(self.states_to_token_maps[state].keys())

def next_state(self, state: FSMState, token_id: int) -> FSMState:
"""Update the state of the FSM.
The FSM stays in the initial state `0` unless the specified stop token
has been generated or the maximum number of tokens has been reached, in
which case the FSM moves to the final state `-1`.
The FSM transitions from one state to the next according to the
states_to_token_maps until the final state is reached.
Parameters
----------
@@ -78,14 +113,14 @@ def next_state(self, state: FSMState, token_id: int) -> FSMState:
The new state of the FSM.
"""
if token_id == self.eos_token_id:
if self.is_final_state(state):
return self.final_state

return self.first_state
return FSMState(self.states_to_token_maps[state][token_id])

def copy(self) -> "StopAtEosFSM":
"""Create a copy of the FSM."""
return self
return deepcopy(self)


class RegexFSM(FSM):
@@ -121,9 +156,22 @@ def create_states_mapping(
self.states_to_token_maps, self.empty_token_ids = create_states_mapping(
regex_string, tuple(sorted(tokenizer.vocabulary.items()))
)
self.vocabulary = tokenizer.vocabulary.values()
self.vocabulary = tokenizer.vocabulary
self.eos_token_id = tokenizer.eos_token_id

def align_prompt_tokens(
self, token_ids: torch.Tensor, attention_masks: torch.Tensor
) -> Tuple[torch.Tensor, torch.Tensor]:
"""Update the states_to_token_maps and return the aligned prompt tokens and attention masks"""
(
token_ids,
attention_masks,
self.states_to_token_maps,
) = align_tokens_states_to_token_maps(
token_ids, attention_masks, self.vocabulary, self.states_to_token_maps
)
return token_ids, attention_masks

def allowed_token_ids(self, state: FSMState) -> List[int]:
"""Generate a list of allowed tokens for the next step.
@@ -184,7 +232,7 @@ def next_state(self, state: FSMState, token_id: int) -> FSMState:

def copy(self) -> "RegexFSM":
"""Create a copy of the FSM."""
return self
return deepcopy(self)


class CFGFSM(FSM):
@@ -218,6 +266,12 @@ def __init__(self, cfg_string: str, tokenizer):
self.proposal_last: List[int] = []
self.regex_fsm_last: RegexFSM

def align_prompt_tokens(
self, token_ids: torch.Tensor, attention_masks: torch.Tensor
) -> Tuple[torch.Tensor, torch.Tensor]:
"""Not applicable to this type of FSM"""
return token_ids, attention_masks

def allowed_token_ids(self, state: FSMState) -> List[int]:
"""Generate a list of allowed tokens for the next step.
@@ -333,3 +387,162 @@ def next_state(self, state: FSMState, token_id: int) -> FSMState:
def copy(self) -> "CFGFSM":
"""Create a copy of the FSM."""
return CFGFSM(self.cfg_string, self.tokenizer)


def align_tokens_states_to_token_maps(
token_ids: torch.Tensor,
attention_masks: torch.Tensor,
vocabulary: Dict[str, int],
states_to_token_maps: Dict[int, Dict[int, int]],
) -> Tuple[torch.Tensor, torch.Tensor, Dict[int, Dict[int, int]]]:
"""Apply token alignment to the provided prompt tokens and attention masks given the
states_to_token_maps of a FSM. Return the updated tokens/maps as well as the updated
states_to_token_maps"""
prompt_token_ids = token_ids.tolist()
crossing_tokens = find_crossing_tokens(prompt_token_ids, vocabulary)
valid_crossing_tokens = get_crossing_tokens_target_states(
states_to_token_maps, crossing_tokens, prompt_token_ids, vocabulary
)
if not valid_crossing_tokens:
return token_ids, attention_masks, states_to_token_maps
(
states_to_token_maps,
number_cropped_tokens,
) = add_crossing_tokens_states_to_tokens_map(
states_to_token_maps, prompt_token_ids, valid_crossing_tokens
)
return (
token_ids[:-number_cropped_tokens],
attention_masks[:-number_cropped_tokens],
states_to_token_maps,
)
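A minimal end-to-end sketch of this helper, using a hypothetical four-token vocabulary and a toy states_to_token_maps that only accepts "b" from its start state (all values below are illustrative assumptions, not taken from the repository):

import torch

vocabulary = {"c": 0, "a": 1, "b": 2, "ab": 3}
states_to_token_maps = {0: {2: 1}}  # from state 0, only "b" (id 2) is allowed
token_ids = torch.tensor([0, 1])  # prompt "ca" tokenized as ["c", "a"]
attention_masks = torch.tensor([1, 1])

new_ids, new_masks, new_map = align_tokens_states_to_token_maps(
    token_ids, attention_masks, vocabulary, states_to_token_maps
)
# new_ids == tensor([0]): the trailing "a" is cropped from the prompt
# new_map == {0: {1: 2, 3: 1}, 2: {2: 1}}: from the new start state the model can
# regenerate "a" (id 1) or emit the crossing token "ab" (id 3) straight away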


def find_crossing_tokens(
token_ids: List[int], vocabulary: Dict[str, int]
) -> Dict[int, List[int]]:
"""Find the tokens that could replace one or more tokens at the end of token_ids
while conserving the same intial text (and extending it by at least one character).
Return a dictionary with, for the indexes in the token_ids, the associated crossing tokens.
"""
reversed_vocabulary = {value: key for key, value in vocabulary.items()}
len_token_ids = len(token_ids)
max_length_token_text = max(len(item) for item in vocabulary.keys())
characters_considered = ""
crossing_tokens_map = {}

for index, token_id in enumerate(reversed(token_ids)):
characters_considered = reversed_vocabulary[token_id] + characters_considered
if len(characters_considered) >= max_length_token_text:
break
crossing_token_ids = [
token_id
for text, token_id in vocabulary.items()
if text.startswith(characters_considered)
and len(text) > len(characters_considered)
]
crossing_tokens_map[len_token_ids - index - 1] = crossing_token_ids

return crossing_tokens_map
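Continuing the illustrative example above:

find_crossing_tokens([0, 1], {"c": 0, "a": 1, "b": 2, "ab": 3})
# -> {1: [3]}: the last prompt token "a" (index 1) could be replaced by "ab" (id 3),
# which preserves the prompt text and extends it by one character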


def get_crossing_tokens_target_states(
states_to_tokens_map: Dict[int, Dict[int, int]],
crossing_tokens: Dict[int, List[int]],
prompt_token_ids: List[int],
vocabulary: Dict[str, int],
) -> Dict[int, Dict[int, int]]:
"""For each crossing token associated to an index, check that the characters after the boundary
match the states_to_tokens_map and find the state it would lead to. Return a dict with, for each
provided indexes, the associated valid tokens with the state they would lead to.
"""
reversed_vocabulary = {value: key for key, value in vocabulary.items()}
prompt_token_texts = [
reversed_vocabulary[token_id] for token_id in prompt_token_ids
]

valid_crossing_tokens: Dict[int, Dict[int, int]] = defaultdict(dict)
for pos, tokens in crossing_tokens.items():
for token in tokens:
is_valid = True
characters = reversed_vocabulary[token]
characters_before_border = "".join(prompt_token_texts[pos:])
characters_after_border = characters[len(characters_before_border) :]
state = 0
for char in characters_after_border:
char_token = vocabulary.get(char)
try:
state = states_to_tokens_map[state][char_token] # type: ignore
except KeyError:
is_valid = False
break
if is_valid:
valid_crossing_tokens[pos][token] = state

return valid_crossing_tokens
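With the same illustrative data, the crossing token "ab" is walked character by character through the FSM:

get_crossing_tokens_target_states(
    {0: {2: 1}},  # states_to_tokens_map: state 0 only accepts "b" (id 2)
    {1: [3]},  # crossing tokens found above
    [0, 1],  # prompt token ids
    {"c": 0, "a": 1, "b": 2, "ab": 3},
)
# -> {1: {3: 1}}: past the prompt boundary, "ab" contributes the character "b",
# which the FSM accepts from state 0, leading to state 1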


def add_crossing_tokens_states_to_tokens_map(
states_to_tokens_map: Dict[int, Dict[int, int]],
prompt_token_ids: List[int],
crossing_tokens_map: Dict[int, Dict[int, int]],
) -> Tuple[Dict[int, Dict[int, int]], int]:
"""Modify the states_to_tokens_map to account for the crossing tokens. This operation modifies
the starting state of the fsm as we would include some characters at the end of the prompt in
the states_to_tokens_map.
Attention! the starting state of the states_to_tokens_map provided must be 0.
Return the updated states_to_tokens_map and the number of cropped tokens/additional states
"""
if not crossing_tokens_map:
return states_to_tokens_map, 0
first_crossing_token_pos = min(
[key for key, value in crossing_tokens_map.items() if value]
)
number_additional_states = len(prompt_token_ids) - first_crossing_token_pos
highest_state = max(
max(states_to_tokens_map.keys()),
max(max(items.values()) for items in states_to_tokens_map.values()),
)

for i in range(number_additional_states):
# add the token that was originally part of the prompt
if i == number_additional_states - 1:
states_to_tokens_map[highest_state + 1 + i] = {
prompt_token_ids[first_crossing_token_pos + i]: 0
}
else:
states_to_tokens_map[highest_state + 1 + i] = {
prompt_token_ids[first_crossing_token_pos + i]: highest_state + 2 + i
}
# add the crossing tokens
crossing_tokens = crossing_tokens_map.get(first_crossing_token_pos + i)
if crossing_tokens:
for token, target_state in crossing_tokens.items():
states_to_tokens_map[highest_state + 1 + i][token] = target_state

# set the id of our new initial state to 0
states_to_tokens_map = swap_state_ids_states_to_tokens_map(
states_to_tokens_map, highest_state + 1, 0
)
return states_to_tokens_map, number_additional_states
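Still with the illustrative data, one new state is added for the cropped prompt token and then swapped in as the start state 0:

add_crossing_tokens_states_to_tokens_map(
    {0: {2: 1}},  # original map, starting state 0
    [0, 1],  # prompt token ids
    {1: {3: 1}},  # valid crossing tokens with their target states
)
# -> ({0: {1: 2, 3: 1}, 2: {2: 1}}, 1): one prompt token is cropped; from the new
# state 0 the model can regenerate "a" (id 1, moving to state 2, which still expects
# "b") or emit "ab" (id 3) directly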


def swap_state_ids_states_to_tokens_map(
states_to_tokens_map: Dict[int, Dict[int, int]],
first_state_id: int,
second_state_id: int,
) -> Dict[int, Dict[int, int]]:
"""Swap the id of two states of the states_to_tokens_map while conserving all transitions"""
first_state_transitions = states_to_tokens_map.pop(first_state_id)
second_state_transitions = states_to_tokens_map.pop(second_state_id)
states_to_tokens_map[first_state_id] = second_state_transitions
states_to_tokens_map[second_state_id] = first_state_transitions

for transitions in states_to_tokens_map.values():
for token, target_state_id in list(transitions.items()):
if target_state_id == first_state_id:
transitions[token] = second_state_id
elif target_state_id == second_state_id:
transitions[token] = first_state_id

return states_to_tokens_map
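A small sketch of the swap on a hypothetical two-state map:

swap_state_ids_states_to_tokens_map({0: {5: 1}, 1: {6: 0}}, 0, 1)
# -> {0: {6: 1}, 1: {5: 0}}: the two states exchange ids and every transition that
# targeted either of them is redirected accordingly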