fix(generation): stop beam search per-instance when heuristic satisfied #38778

Open · wants to merge 4 commits into base: main
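
In brief: beam search already used a heuristic that asks whether any still-running beam could ever beat the worst finished hypothesis, but the answer was reduced to a single boolean over the whole batch, so one slow instance kept the stopping check alive for every other instance. This PR tracks the answer per batch instance, so each instance can be frozen as soon as its own heuristic is satisfied. Below is a minimal standalone sketch of that check with dummy values; the tensor shapes, the length-penalty branch, and the -1.0e9 sentinel mirror the diff, while the concrete numbers are purely illustrative.

import torch

# Dummy state for batch_size=2, num_beams=3. As in the real code, beam_scores
# holds length-normalized scores for finished hypotheses and -1.0e9 in slots
# that never finished.
running_beam_scores = torch.tensor([[-1.0, -2.0, -3.0], [-4.0, -5.0, -6.0]])
beam_scores = torch.tensor([[-0.10, -0.12, -0.15], [-9.0, -1.0e9, -1.0e9]])
is_sent_finished = torch.tensor([[True, True, True], [True, False, False]])
cur_len, max_length, decoder_prompt_len = 10, 20, 4
early_stopping, length_penalty = False, 1.0

# Best score a still-running beam could hypothetically reach; `max_length`
# only matters when early_stopping == "never" and length_penalty > 0.0.
if early_stopping == "never" and length_penalty > 0.0:
    best_hypothetical_length = max_length - decoder_prompt_len
else:
    best_hypothetical_length = cur_len - decoder_prompt_len
best_possible_running_score = running_beam_scores[:, :1] / (best_hypothetical_length**length_penalty)

# Worst finished score per instance; unfinished slots stay pinned at -1.0e9,
# so any unfinished beam keeps its instance alive.
worst_finished_score = torch.where(is_sent_finished, torch.min(beam_scores, dim=1, keepdim=True)[0], -1.0e9)

# Per-instance flag of shape (batch_size, 1) instead of one batch-wide bool.
is_improvement_possible = torch.any(best_possible_running_score > worst_finished_score, dim=-1, keepdim=True)
print(is_improvement_possible)  # -> [[False], [True]]

Instance 0 can stop: its best hypothetical running score, -1.0 / 6 ≈ -0.167, cannot beat its worst finished score of -0.15. Instance 1 still has unfinished beams, so it keeps going. In the diff, this result is additionally AND-ed with the previous flag, making the per-instance bit sticky: once False, it stays False.
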
58 changes: 43 additions & 15 deletions src/transformers/generation/utils.py
@@ -3755,16 +3755,36 @@ def _gather_beams(tensor: torch.Tensor, beam_indices: torch.Tensor) -> torch.Tensor:
         return gathered_tensor
 
     @staticmethod
-    def _beam_search_has_unfinished_sequences(
+    def _get_improvement_possibility(
+        is_improvement_possible: torch.Tensor,
         running_beam_scores: torch.Tensor,
         beam_scores: torch.Tensor,
         is_sent_finished: torch.Tensor,
-        next_token_hits_stopping_criteria: torch.Tensor,
         cur_len: int,
         max_length: int,
         decoder_prompt_len: int,
         early_stopping: Union[bool, str],
         length_penalty: float,
     ):
         """
+        Check if there is a possibility to improve the finished beam_scores.
+        """
+        if early_stopping == "never" and length_penalty > 0.0:
+            best_hypothetical_length = max_length - decoder_prompt_len
+        else:
+            best_hypothetical_length = cur_len - decoder_prompt_len
+        best_possible_running_score = running_beam_scores[:, :1] / (best_hypothetical_length**length_penalty)
+        worst_finished_score = torch.where(is_sent_finished, torch.min(beam_scores, dim=1, keepdim=True)[0], -1.0e9)
+        return is_improvement_possible & torch.any(
+            best_possible_running_score > worst_finished_score, dim=-1, keepdim=True
+        )
+
+    @staticmethod
+    def _beam_search_has_unfinished_sequences(
+        is_improvement_possible: torch.Tensor,
+        is_sent_finished: torch.Tensor,
+        next_token_hits_stopping_criteria: torch.Tensor,
+        early_stopping: Union[bool, str],
+    ):
+        """
         Beam Search stopping condition -- halts the generation loop if any of these conditions becomes False
@@ -3776,13 +3796,7 @@ def _beam_search_has_unfinished_sequences(
         # early_stopping == "never" -> compute the best score from `max_length` or `cur_len`, depending on the
         #   sign of `length_penalty`. Positive `length_penalty` favors longer sequences, thus we use
         #   `max_length` there.
-        if early_stopping == "never" and length_penalty > 0.0:
-            best_hypothetical_length = max_length - decoder_prompt_len
-        else:
-            best_hypothetical_length = cur_len - decoder_prompt_len
-        best_possible_running_score = running_beam_scores[:, :1] / (best_hypothetical_length**length_penalty)
-        worst_finished_score = torch.where(is_sent_finished, torch.min(beam_scores, dim=1, keepdim=True)[0], -1.0e9)
-        improvement_possible = torch.any(best_possible_running_score > worst_finished_score)
+        improvement_possible = torch.any(is_improvement_possible)
 
         # b. Is there still a beam without fully completed sequences? This is only relevant if early_stopping is
         # enabled, where we want to finish as soon as all beams have a completed sequence.
@@ -3878,6 +3892,7 @@ def _update_finished_beams(
         topk_log_probs: torch.Tensor,
         beam_indices: torch.Tensor,
         topk_running_beam_indices: torch.Tensor,
+        is_improvement_possible: torch.Tensor,
         is_sent_finished: torch.Tensor,
         next_token_hits_stopping_criteria: torch.Tensor,
         top_num_beam_mask: torch.Tensor,
@@ -3902,6 +3917,9 @@ def _update_finished_beams(
         # - make sure no scores can be added anymore if beam is full and early stopping is on
         beams_in_batch_are_full = torch.all(is_sent_finished, axis=-1, keepdims=True) & (early_stopping is True)
         topk_log_probs += beams_in_batch_are_full.to(torch.float32) * -1.0e9
+        # - make sure no scores can be added anymore if improvement is not possible
+        topk_log_probs += (~is_improvement_possible).to(torch.float32) * -1.0e9
+
         # - make sure still running sequences cannot be chosen as finalized beam
         topk_log_probs += (~did_top_num_beams_just_finished) * -1.0e9
 
@@ -4053,6 +4071,9 @@ def _beam_search(
         # per batch, beam-item state bit indicating if sentence has finished.
         is_sent_finished = torch.zeros((batch_size, num_beams), dtype=torch.bool, device=input_ids.device)
 
+        # per batch state bit indicating if there is a possibility to improve the best finished sentence.
+        is_improvement_possible = torch.ones((batch_size, 1), dtype=torch.bool, device=input_ids.device)
+
         # per batch, beam-item state bit indicating if there are valid continuations.
         next_token_hits_stopping_criteria = torch.zeros(
             (batch_size, num_beams), dtype=torch.bool, device=input_ids.device
@@ -4165,6 +4186,7 @@ def _beam_search(
                 topk_log_probs=topk_log_probs,
                 beam_indices=beam_indices,
                 topk_running_beam_indices=topk_running_beam_indices,
+                is_improvement_possible=is_improvement_possible,
                 is_sent_finished=is_sent_finished,
                 next_token_hits_stopping_criteria=next_token_hits_stopping_criteria,
                 top_num_beam_mask=top_num_beam_mask,
@@ -4186,16 +4208,22 @@ def _beam_search(
             )
 
             cur_len = cur_len + 1
+            is_improvement_possible = self._get_improvement_possibility(
+                is_improvement_possible=is_improvement_possible,
+                running_beam_scores=running_beam_scores,
+                beam_scores=beam_scores,
+                is_sent_finished=is_sent_finished,
+                cur_len=cur_len,
+                max_length=max_length,
+                decoder_prompt_len=decoder_prompt_len,
+                early_stopping=early_stopping,
+                length_penalty=length_penalty,
+            )
             this_peer_finished = not self._beam_search_has_unfinished_sequences(
-                running_beam_scores,
-                beam_scores,
+                is_improvement_possible,
                 is_sent_finished,
                 next_token_hits_stopping_criteria,
-                cur_len,
-                max_length,
-                decoder_prompt_len,
                 early_stopping,
-                length_penalty,
             )
 
         # 5. prepare outputs
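
To close the loop, here is how the per-instance flag is consumed, again as a hedged sketch with the generation loop omitted (names and shapes mirror the diff above): an instance whose flag has dropped to False gets its candidate scores pushed down by -1.0e9, so _update_finished_beams can no longer add hypotheses for it, while generation itself continues only as long as at least one instance is still improvable.

import torch

batch_size, num_candidates = 2, 6
topk_log_probs = torch.randn(batch_size, num_candidates)
is_improvement_possible = torch.tensor([[False], [True]])  # from the sketch above

# Same additive-mask trick the function already uses for full beams: adding
# -1.0e9 makes these candidates unselectable while keeping the update free of
# data-dependent control flow.
topk_log_probs = topk_log_probs + (~is_improvement_possible).to(torch.float32) * -1.0e9

# Loop-level condition: keep generating only while some batch instance might
# still improve on its best finished hypothesis.
this_peer_finished = not torch.any(is_improvement_possible)

print(torch.all(topk_log_probs[0] < -1.0e8).item())  # True -> instance 0 is frozen
print(this_peer_finished)  # False -> instance 1 keeps the loop alive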