Fix token_lengths kwarg leaking into the labels encoder #372

Open
yuriihavrylko wants to merge 1 commit into urchade:main from yuriihavrylko:fix/bi-encoder-token-lengths
@yuriihavrylko

Fixes #370

Problem

Since 3a38d45 ("make gliner serving more efficient") the collator emits a token_lengths kwarg — precomputed CPU-side text lengths that avoid an attention_mask.sum().tolist() GPU→CPU sync. The uni-encoder text path pops it, but BiEncoder.encode_labels only removed packing_config and pair_attention_mask before forwarding kwargs to the labels backbone. The text-side token_lengths therefore reaches the labels encoder and crashes any plain HF backbone:

TypeError: BertModel.forward() got an unexpected keyword argument 'token_lengths'
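
The failure mode can be reproduced without GLiNER or transformers. The sketch below is a toy stand-in, not the library's code: `strict_backbone` and `encode_labels_unfixed` are hypothetical names, and the backbone is assumed to have a fixed signature, as a plain HF `forward()` does on transformers 4.x.

```python
def strict_backbone(input_ids, attention_mask=None):
    """Hypothetical stand-in for a backbone whose forward() has a fixed signature."""
    return [len(ids) for ids in input_ids]


def encode_labels_unfixed(input_ids, attention_mask=None, **kwargs):
    """Mimics the pre-fix behaviour described above: only two kwargs are dropped."""
    kwargs.pop("packing_config", None)
    kwargs.pop("pair_attention_mask", None)
    # token_lengths is still in kwargs at this point and gets forwarded.
    return strict_backbone(input_ids, attention_mask=attention_mask, **kwargs)


def try_encode():
    """Call the unfixed path and return the error message, or None on success."""
    try:
        encode_labels_unfixed([[101, 2711, 102]], token_lengths=[7])
    except TypeError as exc:
        return str(exc)
    return None


error_message = try_encode()
print(error_message)
```

The call fails with a `TypeError` that names `token_lengths`, the same shape of error as the `BertModel.forward()` traceback above.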

This breaks every bi-encoder checkpoint (e.g. knowledgator/modern-gliner-bi-base-v1.0, knowledgator/modern-gliner-bi-large-v1.0) on transformers 4.x. Bisected to 3a38d45: every commit before it predicts correctly, while 3a38d45 and every later commit crash.

Fix

Pop token_lengths in encode_labels alongside the other text-side kwargs — the value describes text sequences and is meaningless for label sequences.
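
A minimal sketch of the idea, assuming a toy backbone rather than the real `BiEncoder` class: the three kwarg names come from the description above, while `TEXT_SIDE_KWARGS`, `strict_backbone`, and this standalone `encode_labels` are illustrative only and do not reproduce the actual diff.

```python
# Kwargs that describe the text batch only (names taken from the PR description).
TEXT_SIDE_KWARGS = ("packing_config", "pair_attention_mask", "token_lengths")


def strict_backbone(input_ids, attention_mask=None):
    """Hypothetical stand-in for a labels backbone with a fixed signature."""
    return [len(ids) for ids in input_ids]


def encode_labels(input_ids, attention_mask=None, **kwargs):
    """Drop text-side kwargs before forwarding the rest to the labels backbone."""
    for name in TEXT_SIDE_KWARGS:
        # pop with a default so a missing key is not an error
        kwargs.pop(name, None)
    return strict_backbone(input_ids, attention_mask=attention_mask, **kwargs)


result = encode_labels(
    [[101, 2711, 102], [101, 3295, 102]],
    attention_mask=[[1, 1, 1], [1, 1, 1]],
    token_lengths=[7, 5],   # text-side value, irrelevant for labels
    packing_config=None,
)
print(result)
```

Popping with a default keeps the call safe for collators that never emit `token_lengths`, so older inputs are unaffected.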

Verification

knowledgator/modern-gliner-bi-base-v1.0 + transformers 4.48.3:

  • current main: TypeError above
  • with this fix: [('Angela Merkel', 'person', 0.85), ('Kyiv', 'location', 0.75), ('CDU', 'organization', 0.93)] — identical to gliner 0.2.16 output

Note: on transformers 5.x this crash does not reproduce (the kwarg is silently swallowed), but bi-encoders fail there for an unrelated reason — see the companion PR for #324.

Development

Successfully merging this pull request may close these issues.

0.2.27 TypeError: BertModel.forward() got an unexpected keyword argument 'token_lengths'
