fix(generation): stop beam search per-instance when heuristic satisfied #38778
Conversation
Previously, when `early_stopping` is set to `False`, the early-stopping heuristic only halted generation when **all** batch instances reached the criterion. This caused instances that the heuristic deemed impossible to improve to keep generating, leading to inconsistent and overlong outputs across the batch. Now we apply the heuristic **per-instance**: once all beams of a batch instance are impossible to improve, we mark that instance finished while letting the others continue. This restores the expected behavior and ensures consistency in batched generation.
cc @zucchini-nlp while @gante is out!
Hi @gante, could you please review this? :)
Hi @guang-yng 👋 Thank you for the PR, it definitely looks like it goes in the right direction! Before I approve, I want to do some small debugging of my own to better understand what's going on. In theory, if we say that a given batch item can't be improved with beam search (= its scores can never be better than the current running scores), then we shouldn't need further logic to prevent that beam from being updated (= the changes in this PR) 👀
I see what's going on: prior to the refactor, we processed each batch individually, and we tagged a batch as done as soon as all its sentences were finished.

Contrary to what I wrote above, under our default parameterization we can have better scores than our estimate of the best possible scores. This is done mostly for BC purposes, and I've opened a PR to better document it (#39000).

These two things together mean that we changed the behavior with the refactor (better scores are possible + we don't isolate batches -> unwanted updates can occur) and that we want to merge this PR to revert this behavior change :)
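For context, a simplified paraphrase of the heuristic under discussion (illustrative code, not the exact `BeamHypotheses.is_done` implementation):

```python
# Simplified sketch of the early-stopping heuristic when early_stopping=False.
# Paraphrased for illustration; see BeamHypotheses.is_done in transformers
# for the actual implementation.
def is_done(worst_finished_score: float, best_sum_logprobs: float,
            cur_len: int, length_penalty: float) -> bool:
    # Optimistic estimate: assume the best running beam ends right now.
    # As #39000 documents, this is not a true upper bound, so running beams
    # can still beat it under some parameterizations.
    highest_attainable_score = best_sum_logprobs / cur_len**length_penalty
    return worst_finished_score >= highest_attainable_score
```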
The PR is in great shape 💛
Four tiny requests:
- (see my comment below)
- Rebase this PR after "[generate] document non-canonical beam search default behavior" (#39000) is merged, to ensure we keep the updated comments
- Let's add a comment somewhere explaining that we keep track of done beams to ensure we get the same output regardless of batch size
- It's missing a test to prevent regressions. We can add pretty much the snippet you shared in the PR header as a test in `GenerationIntegrationTests` :D
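For illustration, a sketch of what such a regression test might look like (the checkpoint, prompts, and generation parameters below are placeholders, not the final test):

```python
from transformers import AutoModelForCausalLM, AutoTokenizer

def test_beam_search_early_stopping_is_per_instance():
    # Placeholder checkpoint; the final test would pin a specific small model.
    model_id = "some-small-causal-lm"
    tokenizer = AutoTokenizer.from_pretrained(model_id)
    model = AutoModelForCausalLM.from_pretrained(model_id)

    # Decoder-only models should be padded on the left for generation.
    tokenizer.padding_side = "left"
    if tokenizer.pad_token is None:
        tokenizer.pad_token = tokenizer.eos_token

    short_prompt = "A short prompt."
    long_prompt = "A much longer prompt " * 20
    gen_kwargs = dict(num_beams=4, early_stopping=False, max_new_tokens=64)

    # Generate the short prompt alone ...
    solo_out = model.generate(
        **tokenizer(short_prompt, return_tensors="pt"), **gen_kwargs
    )
    # ... and batched together with a longer prompt.
    batch_out = model.generate(
        **tokenizer([short_prompt, long_prompt], return_tensors="pt", padding=True),
        **gen_kwargs,
    )

    # The short prompt's output should not depend on what it was batched with.
    solo_text = tokenizer.decode(solo_out[0], skip_special_tokens=True)
    batch_text = tokenizer.decode(batch_out[0], skip_special_tokens=True)
    assert solo_text == batch_text
```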
Hi @gante, thank you for the review! Regarding the test: would it be okay to use a 1B model?
@guang-yng 1B models are fine for our CI 🤗
What does this PR do?
This PR fixes a bug in beam search generation where the early-stopping heuristic (when `early_stopping=False`) was incorrectly applied across the entire batch instead of per instance.

🔍 Problem
When `early_stopping=False`, the generation heuristic is supposed to stop generating once it's unlikely that any beam will improve. However, the current behavior waits until all batch instances satisfy this heuristic before halting, which causes inconsistent and overlong outputs across the batch.

✅ Fix
We now apply the early-stopping heuristic per instance. As soon as a single instance has no beams left that can improve, that instance is marked finished and its beams are no longer updated, while the rest of the batch continues. This restores the expected behavior and keeps batched outputs consistent with unbatched ones.
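Conceptually, the change looks something like the following (a minimal sketch with toy stand-ins for the scoring step and the heuristic; these are not the actual `transformers` internals):

```python
import torch

# Toy stand-in for one beam-search scoring step.
def fake_next_scores(scores: torch.Tensor) -> torch.Tensor:
    return scores - torch.rand_like(scores)

# Toy stand-in for the per-instance "no beam can improve" heuristic.
def heuristic_says_done(scores: torch.Tensor) -> torch.Tensor:
    return scores.max(dim=-1).values < -1.0

batch_size, num_beams, max_steps = 2, 3, 20
beam_scores = torch.zeros(batch_size, num_beams)
# One "done" flag per batch instance instead of a single all-or-nothing check.
instances_done = torch.zeros(batch_size, dtype=torch.bool)

for _ in range(max_steps):
    new_scores = fake_next_scores(beam_scores)
    # Freeze finished instances: keep their old beams instead of letting
    # further steps overwrite them (the bug allowed these unwanted updates).
    beam_scores = torch.where(instances_done[:, None], beam_scores, new_scores)
    # Check the heuristic per instance; previously, stopping was gated on the
    # whole batch satisfying it at once.
    instances_done |= heuristic_says_done(beam_scores)
    if instances_done.all():
        break
```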
🧪 Reproduction Example
Working case (single input)
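The original snippet isn't preserved on this page; below is a minimal sketch of a reproduction, where the checkpoint name and prompt are placeholders:

```python
from transformers import AutoModelForCausalLM, AutoTokenizer

# Placeholder checkpoint and prompt; the original snippet used a specific
# model and input that aren't preserved here.
model_id = "some-causal-lm"
tokenizer = AutoTokenizer.from_pretrained(model_id)
model = AutoModelForCausalLM.from_pretrained(model_id)

inputs = tokenizer("A short prompt.", return_tensors="pt")
out = model.generate(**inputs, num_beams=4, early_stopping=False, max_new_tokens=64)
print(tokenizer.decode(out[0], skip_special_tokens=True))
```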
Produces clean output.
Broken case (batched input)
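Again a sketch with placeholders, continuing from the block above: the same call, but with the short prompt batched together with a much longer one:

```python
# Continuing the sketch above: batch the same short prompt with a longer one.
tokenizer.padding_side = "left"  # decoder-only models pad on the left
if tokenizer.pad_token is None:
    tokenizer.pad_token = tokenizer.eos_token

batch = tokenizer(
    ["A short prompt.", "A much longer prompt " * 20],
    return_tensors="pt",
    padding=True,
)
out = model.generate(**batch, num_beams=4, early_stopping=False, max_new_tokens=64)
print(tokenizer.decode(out[0], skip_special_tokens=True))  # now repetitive
```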
Produces repetitive output.
This undesirable repetition happens only when batched with longer examples, and it can occur even with default settings like `length_penalty=1`. The bug appears in recent versions with vectorized beam search; it does not appear in transformers < 4.50.0.
Who can review?
@gante Could you please take a look at this? Thanks!