
Preserve output tokens that equal mask_token_id #76

Merged

jianc99 merged 1 commit into z-lab:main from shaun0927:fix/mask-token-strip on Apr 17, 2026

Conversation

@shaun0927
Contributor

Closes #73.

Problem

dflash_generate's final cleanup at dflash/model.py:150-151 strips trailing
mask-padded positions by value, not position:

    output_ids = output_ids[:, :max_length]
    output_ids = output_ids[:, output_ids[0] != mask_token_id]

Any legitimate output token whose id equals mask_token_id is silently
dropped from the middle of the sequence. The published configs avoid the
worst case by choosing mask_token_id from each tokenizer's reserved
zone, but the cleanup is still fragile by construction (a misconfigured
tokenizer or a model that emits the special token loses tokens).

Reproduction without a model:

    import torch
    mask_token_id = 0
    buf = torch.tensor([[11, 12, 13, 101, 102, 0, 103, 104, 0, 0, 0]])  # last 3 = unfilled
    stripped = buf[:, buf[0] != mask_token_id]
    # [[11, 12, 13, 101, 102, 103, 104]]   -- the legitimate 0 disappears

Fix

The generation loop already tracks the last written index via start
(used at line 139 for past_key_values_target.crop(start)). A
length-based slice that caps at max_length is equivalent in the normal
case and correct in the pathological case:

    output_ids = output_ids[:, :min(start + 1, max_length)]

The + 1 is needed because output_ids[:, start] holds the posterior token
sampled in the final iteration (it is written at index
old_start + acceptance_length + 1, after which start advances to that
index). Capping at max_length preserves the existing max_new_tokens
guarantee on natural exit.
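
To see the two slices side by side, here is a minimal sketch (not from the
repo; the buffer contents and the value of start are constructed by hand to
extend the reproduction above):

    import torch

    mask_token_id = 0
    max_length = 8
    # Positions 0-7 are written output (including a legitimate 0 at index 5);
    # positions 8-10 are unfilled padding that still holds mask_token_id.
    buf = torch.tensor([[11, 12, 13, 101, 102, 0, 103, 104, 0, 0, 0]])
    start = 7  # index of the last written token

    # Original cleanup: value-based, drops the legitimate 0 at index 5.
    original = buf[:, :max_length]
    original = original[:, original[0] != mask_token_id]
    # tensor([[ 11,  12,  13, 101, 102, 103, 104]])

    # Patched cleanup: length-based, keeps every written position.
    patched = buf[:, :min(start + 1, max_length)]
    # tensor([[ 11,  12,  13, 101, 102,   0, 103, 104]])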

Equivalence in the normal case

Numbers from a small mental simulation (num_input_tokens=5, max_length=10,
block_size=5):

| Termination | Original output length | Patched output length |
| --- | --- | --- |
| Natural exit (start overshoots, no in-range mask) | 10 | 10 |
| Early stop, stop token at posterior position | 10 (incl. stop) | 10 (incl. stop) |
| Early stop, stop token mid-sequence | 9 | 9 |
| Natural exit, generated token equals mask_token_id | 9 (bug) | 10 (correct) |

A unit-style assertion is included as a comment in .analysis/verify/k2_fix_test.py
in the analysis tree (not committed to the repo) to confirm parity on the
first three rows and divergence on the fourth.
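
Since that file is not committed, a self-contained version of the check might
look like the sketch below. It assumes only what the table states (the loop
has written real tokens up to index start, and every later position still
holds mask_token_id); the helper names are invented for the sketch:

    import torch

    MASK = 0      # mask_token_id
    MAX_LEN = 10  # max_length from the simulation above

    def original_cleanup(buf):
        out = buf[:, :MAX_LEN]
        return out[:, out[0] != MASK]

    def patched_cleanup(buf, start):
        return buf[:, :min(start + 1, MAX_LEN)]

    def make(written, start, buf_len=12):
        # Buffer whose first len(written) positions are real output and
        # whose remainder still holds the mask fill value.
        return torch.tensor([written + [MASK] * (buf_len - len(written))]), start

    cases = {
        "natural exit, no in-range mask": make(list(range(1, 13)), start=11),
        "early stop, stop at posterior position": make(list(range(1, 11)), start=9),
        "early stop, stop mid-sequence": make(list(range(1, 10)), start=8),
    }
    # Parity on the first three rows of the table.
    for name, (buf, start) in cases.items():
        assert torch.equal(original_cleanup(buf), patched_cleanup(buf, start)), name

    # Fourth row: a legitimate token whose id equals mask_token_id.
    buf, start = make([1, 2, 3, 4, MASK, 6, 7, 8, 9, 10], start=9)
    assert original_cleanup(buf).shape[1] == 9              # bug: drops the 0
    assert patched_cleanup(buf, start).shape[1] == MAX_LEN  # correct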

Notes

This is one of three small correctness fixes I'm sending against the
recently merged repo (also #74, #75). Each is on its own branch so they
can be merged or rejected independently.

Commit message

The final cleanup block in dflash_generate used

    output_ids = output_ids[:, output_ids[0] != mask_token_id]
to strip trailing positions that the generation loop had not yet written
into.  The check is value-based, so any legitimate output token whose id
happens to equal mask_token_id is silently dropped from the middle of
the sequence.

The generation loop already tracks the last written position via
`start`, so a length-based slice is equivalent in the normal case and
correct in the pathological case.  Replace the two trailing lines with a
single slice that caps at max_length:

    output_ids = output_ids[:, :min(start + 1, max_length)]

This preserves the cap on total generation length (which used to be
enforced by the `:max_length` line) and removes the dependency on the
sampled token values never colliding with mask_token_id.

Refs: z-lab#73
@jianc99
Member

jianc99 commented Apr 17, 2026

Good catch! Thanks.



Successfully merging this pull request may close these issues.

dflash_generate: final mask cleanup is value-based and drops legitimate output tokens equal to mask_token_id (#73)