Bound stop-token check to written tokens in dflash_generate #109
Open
SuperMarioYL wants to merge 1 commit into z-lab:main from
Conversation
Problem
-------

`dflash_generate` pre-allocates `output_ids` with `mask_token_id` past the
prompt (model.py:79). The in-loop early-exit check at the bottom of the
decode loop scanned the full pre-allocated tail:
```python
if stop_token_ids is not None and any(
    stop_token_id in output_ids[:, num_input_tokens:]
    for stop_token_id in stop_token_ids
):
    break
```
When `mask_token_id` happens to be one of the `stop_token_ids` (a
model-config-dependent edge case the project already cares about; see
PR z-lab#76, "Preserve output tokens that equal mask_token_id"), the
`mask_token_id` fill in the unwritten tail of the buffer satisfies the
`in` check on the very first iteration and generation aborts after one
block.
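A minimal repro of the spurious match, with toy ids standing in for the
real model configuration (all values here are hypothetical):

```python
import torch

mask_token_id = 151666        # hypothetical; model-config dependent
stop_token_ids = [151666]     # stop set that happens to contain the mask id
num_input_tokens = 4

# Pre-allocated buffer: prompt tokens followed by a mask-filled tail.
output_ids = torch.full((1, 16), mask_token_id, dtype=torch.long)
output_ids[0, :num_input_tokens] = torch.tensor([1, 2, 3, 4])

# The legacy check scans the whole tail, so the fill value matches
# before a single token has been generated.
fires = any(
    stop_token_id in output_ids[:, num_input_tokens:]
    for stop_token_id in stop_token_ids
)
print(fires)  # True
```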
Fix
---

Aligning with the post-loop trim at model.py:151-155, which already
uses `torch.isin` over the trimmed slice, the in-loop check now scopes
the scan to `output_ids[0, num_input_tokens : start + 1]`, i.e. the
positions that have actually been written this run. The pre-allocated
`stop_token_tensor` is hoisted out of the loop so the in-loop and
post-loop checks share it.
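A runnable sketch of the scoped check under the same toy setup (the
`start` cursor value is illustrative; the real loop advances it per
decoded block):

```python
import torch

stop_token_ids = [151666]     # hypothetical, as in the repro above
num_input_tokens = 4
output_ids = torch.full((1, 16), 151666, dtype=torch.long)
output_ids[0, :num_input_tokens] = torch.tensor([1, 2, 3, 4])

# Hoisted out of the decode loop so both checks share it.
stop_token_tensor = torch.tensor(stop_token_ids, device=output_ids.device)

# Suppose one block of ordinary tokens has been written, moving the cursor.
output_ids[0, 4:8] = torch.tensor([10, 11, 12, 13])
start = 7

written = output_ids[0, num_input_tokens : start + 1]
print(torch.isin(written, stop_token_tensor).any().item())  # False: tail fill ignored
```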
Tests
-----
Added `tests/test_model.py` covering pure-Python / pure-tensor logic that
runs on CPU without weights:

* `build_target_layer_ids` interpolation (1-layer, 2-layer, 4-layer)
* `extract_context_feature` offset+concat shape and values
* `sample` argmax / temperature paths
* a regression test for the buffer-scan pattern (the legacy check fires
  spuriously, the new check does not), plus a sibling test confirming the
  new check still detects a real stop token after the cursor advances
  (see the sketch after this list)
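A sketch of the regression pair, reusing the toy buffer from the repro
above; the helper names (`legacy_check`, `scoped_check`) are illustrative,
not the actual test code:

```python
import torch

MASK, STOP = 151666, 151666   # hypothetical ids where mask == stop

def legacy_check(output_ids, num_input_tokens, stop_token_ids):
    # Old behaviour: scans the entire pre-allocated tail.
    return any(s in output_ids[:, num_input_tokens:] for s in stop_token_ids)

def scoped_check(output_ids, num_input_tokens, start, stop_token_tensor):
    # New behaviour: scans only positions written this run.
    written = output_ids[0, num_input_tokens : start + 1]
    return bool(torch.isin(written, stop_token_tensor).any())

def make_buffer():
    output_ids = torch.full((1, 16), MASK, dtype=torch.long)
    output_ids[0, :4] = torch.tensor([1, 2, 3, 4])
    return output_ids

def test_legacy_check_fires_spuriously():
    output_ids = make_buffer()                   # nothing decoded yet
    assert legacy_check(output_ids, 4, [STOP])   # fires on the mask fill
    assert not scoped_check(output_ids, 4, 3, torch.tensor([STOP]))

def test_scoped_check_detects_real_stop():
    output_ids = make_buffer()
    output_ids[0, 4:8] = torch.tensor([10, 11, STOP, 13])  # stop actually written
    assert scoped_check(output_ids, 4, 7, torch.tensor([STOP]))
```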
Wired in via a `[project.optional-dependencies] test` extra so existing
backends are unaffected:

```sh
uv pip install -e ".[test]"
python -m pytest tests/test_model.py -v
```
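For reference, a minimal sketch of what the extra could look like in
pyproject.toml (the bare `pytest` requirement is an assumption; the real
file may pin a version):

```toml
[project.optional-dependencies]
test = ["pytest"]
```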
Refs #76.