fix(generation): stop beam search per-instance when heuristic satisfied #38778

Open · wants to merge 4 commits into base: main

Conversation

guang-yng

What does this PR do?

This PR fixes a bug in beam search generation where the early-stopping heuristic (used when early_stopping=False) was incorrectly applied across the entire batch instead of per instance.

🔍 Problem

When early_stopping=False, the generation heuristic is supposed to stop generating once it’s unlikely that any beam will improve. However, the current behavior waits until all batch instances satisfy this heuristic before halting. This causes:

  • Instances that are already “done” (according to the heuristic) to continue generating,
  • Unnecessarily long and repetitive outputs,
  • Inconsistent behavior depending on batch composition.
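
For context, a simplified sketch of what the heuristic checks (illustrative only, not the exact library code; names are made up):

# Simplified sketch of the early_stopping=False heuristic: a batch instance is "done"
# when even its best running beam, under an optimistic estimate, can no longer beat
# the worst finished hypothesis. (Illustrative; not the real implementation.)
def heuristic_is_done(worst_finished_score, best_running_sum_logprobs, cur_len, length_penalty):
    highest_attainable_score = best_running_sum_logprobs / (cur_len ** length_penalty)
    return worst_finished_score >= highest_attainable_score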

✅ Fix

We now apply the early-stopping heuristic per instance. As soon as an instance has no beams left that can improve on its finished hypotheses, that instance is marked as done and its beams are no longer updated, while the rest of the batch keeps generating. This restores the expected behavior and leads to:

  • Consistency between single-instance and batched generation,
  • Parity with behavior in transformers < 4.50.
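
A minimal sketch of the idea (not the actual implementation; tensor names are illustrative):

import torch

# Keep a per-instance "done" mask and stop updating the beams of instances whose
# heuristic is already satisfied, instead of waiting for the whole batch.
def update_done_mask(done, worst_finished_scores, best_running_scores):
    # done: (batch_size,) bool; the score tensors are (batch_size,) float
    newly_done = worst_finished_scores >= best_running_scores
    return done | newly_done

# Subsequent beam/hypothesis updates are masked with `done`, so finished instances
# keep their current best hypotheses while unfinished instances continue generating.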

🧪 Reproduction Example

Working case (single input)

from transformers import AutoTokenizer, AutoModelForCausalLM, GenerationConfig

olmo_model = AutoModelForCausalLM.from_pretrained("allenai/OLMo-2-0425-1B-Instruct")
olmo_model = olmo_model.to("cuda")
tokenizer = AutoTokenizer.from_pretrained("allenai/OLMo-2-0425-1B-Instruct", padding_side="left")
generation_config = GenerationConfig(
    num_beams=10,        # beam search
    max_new_tokens=256,
    length_penalty=2,    # early_stopping is not set, so it defaults to False and the heuristic applies
)

question = [ {"role": "user", "content": "What is 3+5?"} ]

question = tokenizer.apply_chat_template(
    question, tokenize=False, add_generation_prompt=True, return_tensors="pt"
)

inputs = tokenizer(question, return_tensors="pt", padding=True).to("cuda")

outputs = olmo_model.generate(
    **inputs,
    generation_config=generation_config,
)
response = tokenizer.decode(outputs[0], skip_special_tokens=True)
print(response)

Produces clean output:

The sum of 3 and 5 is 8. 
So, 3 + 5 = 8. 
...
The sum of 3 and 5 is \(\boxed{8}\).

Broken case (batched input)

question = [ {"role": "user", "content": "What is 3+5?"} ]
cot_question = [ {"role": "user", "content": "What is 3+5? Explain your reasoning step by step, and provide the final answer at the end."} ]

question = tokenizer.apply_chat_template(
    question, tokenize=False, add_generation_prompt=True, return_tensors="pt"
)
cot_question = tokenizer.apply_chat_template(
    cot_question, tokenize=False, add_generation_prompt=True, return_tensors="pt"
)

# Batch the short question together with a longer chain-of-thought prompt
inputs = tokenizer([question, cot_question], return_tensors="pt", padding=True).to("cuda")

outputs = olmo_model.generate(
    **inputs,
    generation_config=generation_config,
)
responses = tokenizer.batch_decode(outputs, skip_special_tokens=True)
print(responses[0])

Produces repetitive output:

The sum of 3 and 5 is 8. 
...
The sum of \(3 + 5\) is \(\boxed{8}\).

If you have any more questions or need further assistance, feel free to ask!
The sum of \(3 + 5\) is \(\boxed{8}\).

If you have any more questions or need further assistance, feel free to ask!
The sum of \(3 + 5\) is \(\boxed{8}\). 

If you have any more questions or need further assistance, feel free to ask!
The sum of \(3 + 5\) is \(\boxed{8}\). 

If you have any more questions or need further assistance, feel free to ask!
The sum of \(3 + 5\) is \(\boxed{8}\). 

If you have any more questions or need further assistance, feel free to ask!

This undesirable repetition happens only when batched with longer examples. It can occur even with default settings like length_penalty=1.

This bug appears in recent versions with vectorized beam search. It does not appear in transformers < 4.50.0.
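
Until the fix is released, pinning the library below 4.50 (for example, pip install "transformers<4.50") is a possible workaround, assuming no features from newer releases are needed.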

Who can review?

@gante Could you please take a look at this? Thanks!

Previously, when early_stopping was set to `False`, the early-stopping heuristic only halted generation once **all** batch instances had reached the criterion. This caused instances that, according to the heuristic, could no longer improve to keep generating, leading to inconsistent and overlong outputs across the batch.

Now we apply the heuristic **per instance**: once all beams of a batch instance can no longer improve, we mark that instance as finished while letting the others continue. This restores the expected behavior and ensures consistency in batched generation.
@Rocketknight1
Member

cc @zucchini-nlp while @gante is out!

@zucchini-nlp
Member

For beam search I suggest waiting for @gante, because it was recently refactored and he will be a better person to decide whether this is a bug or an intended change from the refactoring. He will be coming back next week :)

@guang-yng
Author

Hi @gante, could you please review this? :)

@gante
Member

gante commented Jun 24, 2025

Hi @guang-yng 👋

Thank you for the PR, it definitely looks like it goes in the right direction!

Before I approve, I want to do some small debugging of my own to understand better what's going on. In theory, if we say that a given batch item can't be improved with beam search (= its scores can never be better than current running scores), then we shouldn't need to add further logic to prevent that beam from being updated (= the changes in this PR) 👀

@gante
Member

gante commented Jun 24, 2025

@guang-yng

I see what's going on: prior to the refactor, we processed each batch individually, and we tagged a batch as done as soon as all sentences were done (e.g. early stopping criteria were satisfied). No further changes were allowed. That is not the case after the vectorization refactor: all batches are processed together, and we rely on logic/math to mask unwanted updates.

Contrary to what I wrote above, under our default parameterization we can have better scores than our estimate of the best possible scores. This is done mostly for BC purposes, and I've opened a PR to better document it (#39000)

These two things together mean that we changed the behavior with the refactor (better scores are possible + we don't isolate batches -> unwanted updates can occur) and that we want to merge this PR to revert this behavior change :)
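
To make the second point concrete, a toy illustration (assuming the heuristic divides the running sum of log-probs by the current length, as in the simplified sketch above; the numbers are made up):

# With length_penalty=2, a running beam whose sum of log-probs is -10 at length 10 gets
# the optimistic estimate -10 / 10**2 = -0.1. If it then appends near-certain tokens
# (log-prob ~ 0) until length 20, its final score is about -10 / 20**2 = -0.025, which
# is better than the estimate, so the estimate is not a strict upper bound by default.
length_penalty = 2
estimate_at_len_10 = -10 / 10**length_penalty   # -0.1
score_at_len_20 = -10 / 20**length_penalty      # -0.025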

@gante
Member

The PR is in great shape 💛

Four tiny requests:

  1. (see my comment below)
  2. Rebase this PR after "[generate] document non-canonical beam search default behavior" (#39000) is merged, to ensure we keep the updated comments
  3. Let's add a comment somewhere explaining that we keep track of done beams to ensure we get the same output regardless of batch size
  4. It's missing a test to prevent regressions. We can add pretty much the snippet you shared in the PR header as a test in GenerationIntegrationTests :D

@guang-yng
Author

Hi @gante, thank you for the review!

> It's missing a test to prevent regressions. We can add pretty much the snippet you shared in the PR header as a test in GenerationIntegrationTests :D

Regarding the test: would it be okay to use the "allenai/OLMo-2-0425-1B-Instruct" model? It's a 1B model, so I'm wondering whether it might be too large for a test case.

@gante
Member

gante commented Jun 25, 2025

@guang-yng 1B models are fine for our CI 🤗
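
For reference, a rough sketch of what such a regression test could look like (hypothetical: the class name, structure, and assertion are assumptions, and the final test in GenerationIntegrationTests may differ):

import unittest

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer, GenerationConfig


class BeamSearchPerInstanceEarlyStoppingTest(unittest.TestCase):
    def test_beam_search_heuristic_is_applied_per_instance(self):
        model_id = "allenai/OLMo-2-0425-1B-Instruct"
        device = "cuda" if torch.cuda.is_available() else "cpu"
        tokenizer = AutoTokenizer.from_pretrained(model_id, padding_side="left")
        model = AutoModelForCausalLM.from_pretrained(model_id).to(device)
        generation_config = GenerationConfig(num_beams=10, max_new_tokens=256, length_penalty=2)

        short_prompt = tokenizer.apply_chat_template(
            [{"role": "user", "content": "What is 3+5?"}],
            tokenize=False, add_generation_prompt=True,
        )
        long_prompt = tokenizer.apply_chat_template(
            [{"role": "user", "content": "What is 3+5? Explain your reasoning step by step, and provide the final answer at the end."}],
            tokenize=False, add_generation_prompt=True,
        )

        single = tokenizer(short_prompt, return_tensors="pt", padding=True).to(device)
        batched = tokenizer([short_prompt, long_prompt], return_tensors="pt", padding=True).to(device)

        out_single = model.generate(**single, generation_config=generation_config)
        out_batched = model.generate(**batched, generation_config=generation_config)

        # The short prompt should decode to the same text whether it is generated alone
        # or batched together with a longer prompt.
        text_single = tokenizer.decode(out_single[0], skip_special_tokens=True)
        text_batched = tokenizer.decode(out_batched[0], skip_special_tokens=True)
        self.assertEqual(text_single, text_batched)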
