Preserve output tokens that equal mask_token_id #76
Merged

jianc99 merged 1 commit into z-lab:main on Apr 17, 2026

Conversation
The final cleanup block in `dflash_generate` used

```python
output_ids = output_ids[:, output_ids[0] != mask_token_id]
```

to strip trailing positions that the generation loop had not yet written into. The check is value-based, so any legitimate output token whose id happens to equal `mask_token_id` is silently dropped from the middle of the sequence.

The generation loop already tracks the last written position via `start`, so a length-based slice is equivalent in the normal case and correct in the pathological case. Replace the two trailing lines with a single slice that caps at `max_length`:

```python
output_ids = output_ids[:, :min(start + 1, max_length)]
```

This preserves the cap on total generation length (previously enforced by the `:max_length` line) and removes the dependency on sampled token values never colliding with `mask_token_id`.
Refs: z-lab#73
Member

Good catch! Thanks.
Closes #73.
Problem
`dflash_generate`'s final cleanup at `dflash/model.py:150-151` strips trailing mask-padded positions by value, not position: any legitimate output token whose id equals `mask_token_id` is silently dropped from the middle of the sequence. The published configs avoid the worst case by choosing `mask_token_id` from each tokenizer's reserved zone, but the cleanup is still fragile by construction (a misconfigured tokenizer, or a model that emits the special token, loses tokens).
Reproduction without a model:
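The original reproduction snippet did not survive this page, so here is a minimal sketch of the failure mode. It uses NumPy in place of torch (the boolean-mask and slice indexing semantics are the same for this purpose), and the token ids, `start`, and `max_length` values are invented for illustration:

```python
import numpy as np

mask_token_id = 999  # illustrative; real configs pick a reserved-zone id

# Simulated generation buffer: positions 0..5 were written by the loop,
# trailing positions still hold mask_token_id padding. Position 3 holds
# a *legitimate* sampled token that happens to equal mask_token_id.
output_ids = np.array([[11, 22, 33, 999, 44, 55, 999, 999]])
start = 5        # last written index, as tracked by the loop
max_length = 8

# Old value-based cleanup: also drops the legitimate token at position 3.
value_based = output_ids[:, output_ids[0] != mask_token_id]

# New length-based cleanup: keeps everything the loop actually wrote.
length_based = output_ids[:, :min(start + 1, max_length)]

print(value_based.tolist())   # [[11, 22, 33, 44, 55]] -- token lost
print(length_based.tolist())  # [[11, 22, 33, 999, 44, 55]]
```

Both slices strip the trailing padding; only the value-based one corrupts the middle of the sequence.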
Fix

The generation loop already tracks the last written index via `start` (used at line 139 for `past_key_values_target.crop(start)`). A length-based slice that caps at `max_length`, i.e. `output_ids[:, :min(start + 1, max_length)]`, is equivalent in the normal case and correct in the pathological case.

The `+ 1` is needed because `output_ids[:, start]` holds the posterior token sampled in the final iteration (written at index `old_start + acceptance_length + 1`, after which `start` advances to that index). Capping at `max_length` preserves the existing `max_new_tokens` guarantee on natural exit.
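To make the `+ 1` arithmetic concrete, here is a hypothetical trace of one final loop iteration. The variable names follow the description above, but the specific values are invented for illustration:

```python
# Hypothetical values for the final loop iteration (illustrative only).
old_start = 4
acceptance_length = 2

# Per the description above, the posterior token sampled in the final
# iteration is written at this index...
write_index = old_start + acceptance_length + 1   # 4 + 2 + 1 = 7

# ...and `start` then advances to that same index, so `start` names the
# last *written* position, not one past it.
start = write_index

# A half-open slice therefore needs `start + 1` to keep that token.
kept_positions = list(range(start + 1))           # indices 0..start
print(len(kept_positions))
```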
Equivalence in the normal case
The comparison uses numbers from a small mental simulation (`num_input_tokens=5`, `max_length=10`, `block_size=5`). A unit-style assertion is included as a comment in `.analysis/verify/k2_fix_test.py` in the analysis tree (not committed to the repo), confirming parity on the first three rows and divergence on the fourth.
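The simulation table itself did not survive this page. The parity/divergence claim can be sketched as follows; these are illustrative rows with made-up token ids (not the PR's exact table), again using NumPy in place of torch:

```python
import numpy as np

mask_token_id = 999   # illustrative id
max_length = 10

def value_based(output_ids):
    # Old cleanup: drop every position whose value equals the mask id.
    return output_ids[:, output_ids[0] != mask_token_id]

def length_based(output_ids, start):
    # New cleanup: keep exactly the positions the loop wrote.
    return output_ids[:, :min(start + 1, max_length)]

def make_row(written):
    # Buffer of length 10: written tokens, then mask-id padding.
    buf = written + [mask_token_id] * (10 - len(written))
    return np.array([buf]), len(written) - 1   # (buffer, start index)

# Normal rows: no sampled token collides with mask_token_id -> parity.
for written in ([1, 2, 3], [1, 2, 3, 4, 5], [7] * 10):
    ids, start = make_row(written)
    assert value_based(ids).tolist() == length_based(ids, start).tolist()

# Pathological row: a legitimate 999 mid-sequence -> divergence.
ids, start = make_row([1, 999, 3])
print(value_based(ids).tolist())          # [[1, 3]] -- token dropped
print(length_based(ids, start).tolist())  # [[1, 999, 3]]
```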
Notes
This is one of three small correctness fixes I'm sending against the
recently merged repo (also #74, #75). Each is on its own branch so they
can be merged or rejected independently.