Add custom stop token ids for generation (by making eos_token_id List[int]) #1

Merged
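The point of the change is that callers can pass several stop token ids and decoding halts on whichever appears first. Below is a minimal usage sketch of what that enables, assuming the rest of the generation stack in this branch threads the list form through `generate` (only `beam_search.py` and `logits_process.py` appear in this excerpt); the checkpoint and the extra stop id are placeholders, not taken from this PR:

```python
from transformers import AutoModelForCausalLM, AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("gpt2")  # placeholder checkpoint
model = AutoModelForCausalLM.from_pretrained("gpt2")

inputs = tokenizer("The quick brown fox", return_tensors="pt")

# eos_token_id may now be a list: generation stops as soon as any of the ids is produced.
outputs = model.generate(
    **inputs,
    max_new_tokens=50,
    num_beams=4,  # exercises the beam_search.py changes below
    eos_token_id=[tokenizer.eos_token_id, 198],  # 198 is an illustrative extra stop id (newline in the GPT-2 vocab)
)
print(tokenizer.decode(outputs[0], skip_special_tokens=True))
```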
32 changes: 23 additions & 9 deletions src/transformers/generation/beam_search.py
@@ -16,7 +16,7 @@
import warnings
from abc import ABC, abstractmethod
from collections import UserDict
from typing import List, Optional, Tuple
from typing import List, Optional, Tuple, Union

import numpy as np
import torch
@@ -212,7 +212,7 @@ def process(
next_tokens: torch.LongTensor,
next_indices: torch.LongTensor,
pad_token_id: Optional[int] = None,
eos_token_id: Optional[int] = None,
eos_token_id: Optional[Union[int, List[int]]] = None,
beam_indices: Optional[torch.LongTensor] = None,
) -> Tuple[torch.Tensor]:
cur_len = input_ids.shape[-1]
@@ -234,6 +234,9 @@ def process(
next_beam_tokens = torch.zeros((batch_size, self.group_size), dtype=next_tokens.dtype, device=device)
next_beam_indices = torch.zeros((batch_size, self.group_size), dtype=next_indices.dtype, device=device)

if isinstance(eos_token_id, int):
eos_token_id = [eos_token_id]

for batch_idx, beam_hyp in enumerate(self._beam_hyps):
if self._done[batch_idx]:
if self.num_beams < len(beam_hyp):
@@ -253,7 +256,7 @@
):
batch_beam_idx = batch_idx * self.group_size + next_index
# add to generated hypotheses if end of sentence
if (eos_token_id is not None) and (next_token.item() == eos_token_id):
if (eos_token_id is not None) and (next_token.item() in eos_token_id):
# if beam_token does not belong to top num_beams tokens, it should not be added
is_beam_token_worse_than_top_num_beams = beam_token_rank >= self.group_size
if is_beam_token_worse_than_top_num_beams:
@@ -307,11 +310,14 @@ def finalize(
final_beam_indices: torch.LongTensor,
max_length: int,
pad_token_id: Optional[int] = None,
eos_token_id: Optional[int] = None,
eos_token_id: Optional[Union[int, List[int]]] = None,
beam_indices: Optional[torch.LongTensor] = None,
) -> Tuple[torch.LongTensor]:
batch_size = len(self._beam_hyps)

if isinstance(eos_token_id, int):
eos_token_id = [eos_token_id]

# finalize all open beam hypotheses and add to generated hypotheses
for batch_idx, beam_hyp in enumerate(self._beam_hyps):
if self._done[batch_idx]:
@@ -376,7 +382,8 @@ def finalize(
indices[i, : len(best_idx)] = torch.tensor(best_idx)

if sent_lengths[i] < sent_max_len:
decoded[i, sent_lengths[i]] = eos_token_id
# inserting only the first eos_token_id
decoded[i, sent_lengths[i]] = eos_token_id[0]

return UserDict(
{
@@ -491,7 +498,7 @@ def process(
next_indices: torch.LongTensor,
scores_for_all_vocab: torch.FloatTensor,
pad_token_id: Optional[int] = None,
eos_token_id: Optional[int] = None,
eos_token_id: Optional[Union[int, List[int]]] = None,
) -> Tuple[torch.Tensor]:
r"""
Args:
@@ -549,6 +556,9 @@ def process(
next_beam_tokens = torch.zeros((batch_size, self.group_size), dtype=next_tokens.dtype, device=device)
next_beam_indices = torch.zeros((batch_size, self.group_size), dtype=next_indices.dtype, device=device)

if isinstance(eos_token_id, int):
eos_token_id = [eos_token_id]

for batch_idx, beam_hyp in enumerate(self._beam_hyps):
if self._done[batch_idx]:
if self.num_beams < len(beam_hyp):
@@ -568,7 +578,7 @@
):
batch_beam_idx = batch_idx * self.group_size + next_index
# add to generated hypotheses if end of sentence
if (eos_token_id is not None) and (next_token.item() == eos_token_id):
if (eos_token_id is not None) and (next_token.item() in eos_token_id):

# if beam_token does not belong to top num_beams tokens, it should not be added
is_beam_token_worse_than_top_num_beams = beam_token_rank >= self.group_size
@@ -773,10 +783,13 @@ def finalize(
final_beam_indices: torch.LongTensor,
max_length: int,
pad_token_id: Optional[int] = None,
eos_token_id: Optional[int] = None,
eos_token_id: Optional[Union[int, List[int]]] = None,
) -> Tuple[torch.LongTensor]:
batch_size = len(self._beam_hyps)

if isinstance(eos_token_id, int):
eos_token_id = [eos_token_id]

# finalize all open beam hypotheses and add to generated hypotheses
for batch_idx, beam_hyp in enumerate(self._beam_hyps):
if self._done[batch_idx]:
@@ -840,7 +853,8 @@ def finalize(
for i, hypo in enumerate(best):
decoded[i, : sent_lengths[i]] = hypo
if sent_lengths[i] < sent_max_len:
decoded[i, sent_lengths[i]] = eos_token_id
# inserting only the first eos_token_id
decoded[i, sent_lengths[i]] = eos_token_id[0]

return UserDict(
{
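Every hunk in `beam_search.py` applies the same pattern: normalize `eos_token_id` to a list at the top of `process`/`finalize`, replace the `== eos_token_id` equality check with an `in` membership test, and write only `eos_token_id[0]` when padding a finished sequence (the other ids in the list still terminate hypotheses during search). A standalone sketch of that pattern; the helper names are illustrative, not part of the PR:

```python
from typing import List, Optional, Union

def normalize_eos(eos_token_id: Optional[Union[int, List[int]]]) -> Optional[List[int]]:
    # Accept either a single id or a list of ids and always work with a list internally.
    if isinstance(eos_token_id, int):
        return [eos_token_id]
    return eos_token_id

def is_stop_token(token_id: int, eos_token_id: Optional[Union[int, List[int]]]) -> bool:
    # The former `token_id == eos_token_id` comparison becomes a membership test.
    eos_token_id = normalize_eos(eos_token_id)
    return eos_token_id is not None and token_id in eos_token_id

assert is_stop_token(50256, 50256)          # old single-id behaviour is preserved
assert is_stop_token(198, [50256, 198])     # any id in the list ends the hypothesis
assert not is_stop_token(7, [50256, 198])
assert not is_stop_token(7, None)           # no eos configured: nothing is treated as a stop token
```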
48 changes: 30 additions & 18 deletions src/transformers/generation/logits_process.py
@@ -15,7 +15,7 @@

import inspect
import math
from typing import Callable, Iterable, List, Optional, Tuple
from typing import Callable, Iterable, List, Optional, Tuple, Union

import numpy as np
import torch
@@ -100,24 +100,27 @@ class MinLengthLogitsProcessor(LogitsProcessor):
Args:
min_length (`int`):
The minimum length below which the score of `eos_token_id` is set to `-float("Inf")`.
eos_token_id (`int`):
eos_token_id (`Union[int, List[int]]`):
The id of the *end-of-sequence* token.
"""

def __init__(self, min_length: int, eos_token_id: int):
def __init__(self, min_length: int, eos_token_id: Union[int, List[int]]):
if not isinstance(min_length, int) or min_length < 0:
raise ValueError(f"`min_length` has to be a positive integer, but is {min_length}")

if not isinstance(eos_token_id, int) or eos_token_id < 0:
raise ValueError(f"`eos_token_id` has to be a positive integer, but is {eos_token_id}")
if isinstance(eos_token_id, int):
eos_token_id = [eos_token_id]
if not all([isinstance(i, int) for i in eos_token_id]) or any([i < 0 for i in eos_token_id]):
raise ValueError(f"`eos_token_id` has to be a list of positive integers, but is {eos_token_id}")

self.min_length = min_length
self.eos_token_id = eos_token_id

def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor:
cur_len = input_ids.shape[-1]
if cur_len < self.min_length:
scores[:, self.eos_token_id] = -float("inf")
for i in self.eos_token_id:
scores[:, i] = -float("inf")
return scores


@@ -395,11 +398,11 @@ class NoBadWordsLogitsProcessor(LogitsProcessor):
List of list of token ids that are not allowed to be generated. In order to get the token ids of the words
that should not appear in the generated text, use `tokenizer(bad_words, add_prefix_space=True,
add_special_tokens=False).input_ids`.
eos_token_id (`int`):
eos_token_id (`Union[int, List[int]]`):
The id of the *end-of-sequence* token.
"""

def __init__(self, bad_words_ids: List[List[int]], eos_token_id: int):
def __init__(self, bad_words_ids: List[List[int]], eos_token_id: Union[int, List[int]]):

if not isinstance(bad_words_ids, List) or len(bad_words_ids) == 0:
raise ValueError(f"`bad_words_ids` has to be a non-empty list, but is {bad_words_ids}.")
@@ -413,7 +416,10 @@ def __init__(self, bad_words_ids: List[List[int]], eos_token_id: int):
f"Each list in `bad_words_ids` has to be a list of positive integers, but is {bad_words_ids}."
)

bad_words_ids = list(filter(lambda bad_token_seq: bad_token_seq != [eos_token_id], bad_words_ids))
if isinstance(eos_token_id, int):
eos_token_id = [eos_token_id]

bad_words_ids = list(filter(lambda bad_token_seq: all([bad_token_seq != [i] for i in eos_token_id]), bad_words_ids))
self.bad_words_id_length_1 = []
self.bad_words_id_length_greater_than_1 = []
for word in bad_words_ids:
@@ -628,20 +634,23 @@ class ForcedEOSTokenLogitsProcessor(LogitsProcessor):
Args:
max_length (`int`):
The maximum length of the sequence to be generated.
eos_token_id (`int`):
eos_token_id (`Union[int, List[int]]`):
The id of the token to force as the last generated token when `max_length` is reached.
"""

def __init__(self, max_length: int, eos_token_id: int):
def __init__(self, max_length: int, eos_token_id: Union[int, List[int]]):
self.max_length = max_length
if isinstance(eos_token_id, int):
eos_token_id = [eos_token_id]
self.eos_token_id = eos_token_id

def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor:
cur_len = input_ids.shape[-1]
if cur_len == self.max_length - 1:
num_tokens = scores.shape[1]
scores[:, [i for i in range(num_tokens) if i != self.eos_token_id]] = -float("inf")
scores[:, self.eos_token_id] = 0
scores[:, [i for i in range(num_tokens) if i not in self.eos_token_id]] = -float("inf")
for i in self.eos_token_id:
scores[:, i] = 0
return scores


@@ -671,23 +680,26 @@ class ExponentialDecayLengthPenalty(LogitsProcessor):
exponential_decay_length_penalty (`tuple(int, float)`, *optional*):
This tuple shall consist of: `(start_index, decay_factor)` where `start_index` indicates where penalty
starts and `decay_factor` represents the factor of exponential decay
eos_token_id (`int`):
eos_token_id (`Union[int, List[int]]`):
The id of the *end-of-sequence* token.
input_ids_seq_length (`int`):
The length of the input sequence.
"""

def __init__(self, exponential_decay_length_penalty: Tuple, eos_token_id: int, input_ids_seq_length: int):
def __init__(self, exponential_decay_length_penalty: Tuple, eos_token_id: Union[int, List[int]], input_ids_seq_length: int):
self.regulation_start = exponential_decay_length_penalty[0] + input_ids_seq_length
self.regulation_factor = exponential_decay_length_penalty[1]
if isinstance(eos_token_id, int):
eos_token_id = [eos_token_id]
self.eos_token_id = eos_token_id

def __call__(self, input_ids: torch.Tensor, scores: torch.Tensor) -> torch.FloatTensor:
cur_len = input_ids.shape[-1]
if cur_len > self.regulation_start:
scores[:, self.eos_token_id] = scores[:, self.eos_token_id] * pow(
self.regulation_factor, cur_len - self.regulation_start
)
for i in self.eos_token_id:
scores[:, i] = scores[:, i] * pow(
self.regulation_factor, cur_len - self.regulation_start
)
return scores


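The logits processors accept the same `Union[int, List[int]]` form, so masking and forcing now loop over every id in the list. A quick check of `MinLengthLogitsProcessor` with two stop ids, assuming a transformers build that includes this change; the ids and vocabulary size are illustrative:

```python
import torch
from transformers import MinLengthLogitsProcessor

# With this change the processor masks every id in the list while the
# sequence is still shorter than min_length.
processor = MinLengthLogitsProcessor(min_length=5, eos_token_id=[50256, 198])

input_ids = torch.tensor([[10, 11, 12]])   # current length 3 < min_length 5
scores = torch.zeros((1, 50257))           # dummy logits over a GPT-2-sized vocabulary
scores = processor(input_ids, scores)

assert scores[0, 50256] == -float("inf")
assert scores[0, 198] == -float("inf")
assert scores[0, 0] == 0.0                 # other tokens are untouched
```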