Commit
Add unit tests for token alignment
RobinPicard committed Mar 10, 2024
1 parent 4aa74f2 commit 6bb90f8
Showing 4 changed files with 332 additions and 45 deletions.
5 changes: 3 additions & 2 deletions outlines/fsm/fsm.py
@@ -423,7 +423,7 @@ def find_crossing_tokens(
) -> Dict[int, List[int]]:
"""Find the tokens that could replace one or more tokens at the end of token_ids
while conserving the same initial text (and extending it by at least one character).
Return a dictionary with, for the indexes in the token_ids, the associated crossing tokens.
Return a dictionary with, for the indexes in the token_ids with matches, the associated crossing tokens.
"""
reversed_vocabulary = {value: key for key, value in vocabulary.items()}
len_token_ids = len(token_ids)
@@ -441,7 +441,8 @@
if text.startswith(characters_considered)
and len(text) > len(characters_considered)
]
crossing_tokens_map[len_token_ids - index - 1] = crossing_token_ids
if crossing_token_ids:
crossing_tokens_map[len_token_ids - index - 1] = crossing_token_ids

return crossing_tokens_map

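As an aside on what `find_crossing_tokens` computes: the sketch below is a minimal, self-contained approximation of the logic shown in the diff above, using a toy vocabulary that maps token text to token id. The name `toy_find_crossing_tokens` and the example vocabulary are illustrative assumptions, not part of the outlines API.

from typing import Dict, List

def toy_find_crossing_tokens(
    token_ids: List[int], vocabulary: Dict[str, int]
) -> Dict[int, List[int]]:
    # Map token id -> token text, mirroring reversed_vocabulary in the diff above
    reversed_vocabulary = {token_id: text for text, token_id in vocabulary.items()}
    crossing_tokens_map: Dict[int, List[int]] = {}
    len_token_ids = len(token_ids)
    for index in range(len_token_ids):
        # Text spelled by the last (index + 1) tokens of the prompt
        characters_considered = "".join(
            reversed_vocabulary[tid] for tid in token_ids[len_token_ids - index - 1 :]
        )
        # Tokens whose text reproduces that suffix and extends it by at least one character
        crossing_token_ids = [
            token_id
            for text, token_id in vocabulary.items()
            if text.startswith(characters_considered)
            and len(text) > len(characters_considered)
        ]
        if crossing_token_ids:
            crossing_tokens_map[len_token_ids - index - 1] = crossing_token_ids
    return crossing_tokens_map

vocab = {"hel": 0, "lo": 1, "hello wor": 2, "lore": 3}
# Prompt "hello" encoded as ["hel", "lo"]: "hello wor" can replace both tokens,
# while "lore" can replace the trailing "lo"
print(toy_find_crossing_tokens([0, 1], vocab))  # {1: [3], 0: [2]}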
95 changes: 54 additions & 41 deletions outlines/generate/api.py
@@ -1,8 +1,8 @@
from typing import Iterator, List, Optional, Union
from typing import Iterator, List, Optional, Tuple, Union

import torch

from outlines.fsm.fsm import FSMState
from outlines.fsm.fsm import FSM, FSMState
from outlines.generate.generator import sequence_generator


@@ -21,6 +21,53 @@ def __init__(
self.device = device
self.num_samples = sampler.samples

def align_prompt_tokens(
self,
prompt_token_ids: torch.Tensor,
attention_masks: torch.Tensor,
fsms: List[FSM],
) -> Tuple[torch.Tensor, torch.Tensor]:
"""Implement token alignment for each fsm. Return the updated token_ids and attention_masks"""
aligned_prompts, aligned_masks = zip(
*[
fsm.align_prompt_tokens(prompt, mask)
for prompt, mask, fsm in zip(prompt_token_ids, attention_masks, fsms)
]
)
# We have to pad some of the prompts if they are not all of the same length after this operation
max_length_aligned_prompt = max(prompt.shape[0] for prompt in aligned_prompts)
padded_aligned_prompts = [
torch.cat(
[
torch.full(
(max_length_aligned_prompt - prompt.shape[0],),
0,
device=prompt_token_ids.device,
dtype=prompt.dtype,
),
prompt,
]
)
for prompt in aligned_prompts
]
padded_aligned_masks = [
torch.cat(
[
torch.full(
(max_length_aligned_prompt - mask.shape[0],),
0,
device=prompt_token_ids.device,
dtype=mask.dtype,
),
mask,
]
)
for mask in aligned_masks
]
aligned_prompt_token_ids = torch.stack(padded_aligned_prompts)
aligned_attention_masks = torch.stack(padded_aligned_masks)
return aligned_prompt_token_ids, aligned_attention_masks
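Because token alignment can drop tokens from the end of some prompts, the aligned prompts in a batch may no longer share a length; the method above therefore left-pads prompts and masks back to a common length before stacking. A minimal sketch of that padding step on toy tensors (the values and the pad id 0 are assumptions for illustration):

import torch

# Two aligned prompts of unequal length after token alignment (toy values)
aligned_prompts = [torch.tensor([5, 6, 7]), torch.tensor([8, 9])]
max_length = max(prompt.shape[0] for prompt in aligned_prompts)

# Left-pad each prompt with 0 so torch.stack yields a rectangular batch;
# the matching attention mask positions are padded with 0 as well, so the
# model ignores them
padded_prompts = [
    torch.cat([torch.full((max_length - p.shape[0],), 0, dtype=p.dtype), p])
    for p in aligned_prompts
]
print(torch.stack(padded_prompts))
# tensor([[5, 6, 7],
#         [0, 8, 9]])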

def get_generated_token_ids(
self,
prompt_token_ids: torch.Tensor,
@@ -189,49 +236,15 @@ def __call__(
num_samples = self.num_samples
batch_size = len(prompts)

fsm_states = [FSMState(0) for _ in range(batch_size * num_samples)]
fsms = [self.fsm.copy() for _ in range(batch_size * num_samples)]

prompt_token_ids = torch.repeat_interleave(prompt_token_ids, num_samples, dim=0)
attention_masks = torch.repeat_interleave(attention_masks, num_samples, dim=0)

# Token alignment may shorten some of the prompts by removing tokens at their end.
# We have to pad some of the prompts if they are not all of the same length after this operation
aligned_prompts, aligned_masks = zip(
*[
fsm.align_prompt_tokens(prompt, mask)
for prompt, mask, fsm in zip(prompt_token_ids, attention_masks, fsms)
]
fsm_states = [FSMState(0) for _ in range(batch_size * num_samples)]
fsms = [self.fsm.copy() for _ in range(batch_size * num_samples)]

aligned_prompt_token_ids, aligned_attention_masks = self.align_prompt_tokens(
prompt_token_ids, attention_masks, fsms
)
max_length_aligned_prompt = max(prompt.shape[0] for prompt in aligned_prompts)
padded_aligned_prompts = [
torch.cat(
[
torch.full(
(max_length_aligned_prompt - prompt.shape[0],),
0,
dtype=prompt.dtype,
),
prompt,
]
)
for prompt in aligned_prompts
]
padded_aligned_masks = [
torch.cat(
[
torch.full(
(max_length_aligned_prompt - mask.shape[0],),
0,
dtype=mask.dtype,
),
mask,
]
)
for mask in aligned_masks
]
aligned_prompt_token_ids = torch.stack(padded_aligned_prompts)
aligned_attention_masks = torch.stack(padded_aligned_masks)

weights = torch.zeros(
(batch_size * num_samples), dtype=torch.float, device=self.device
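For context on the refactored `__call__`: each prompt and attention mask is first duplicated `num_samples` times with `torch.repeat_interleave`, so every sample gets its own FSM copy, and only then is `align_prompt_tokens` applied to the expanded batch. A small illustration of that expansion (toy values assumed):

import torch

# Two prompts of length 3, with 2 samples requested per prompt (toy values)
prompt_token_ids = torch.tensor([[1, 2, 3], [4, 5, 6]])
num_samples = 2

# Each row is repeated consecutively, giving batch_size * num_samples rows
expanded = torch.repeat_interleave(prompt_token_ids, num_samples, dim=0)
print(expanded)
# tensor([[1, 2, 3],
#         [1, 2, 3],
#         [4, 5, 6],
#         [4, 5, 6]])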
