
ModernBERT for MLM outputs incorrect hidden state shape. #38499

@jfkback

Description

System Info

When using ModernBertForMaskedLM with output_hidden_states=True, the returned hidden states are not padded back to the input shape. A minimal example is included below:

import torch
from transformers import AutoTokenizer, ModernBertForMaskedLM

tokenizer = AutoTokenizer.from_pretrained("answerdotai/ModernBERT-base")
model = ModernBertForMaskedLM.from_pretrained("answerdotai/ModernBERT-base").to("cuda")

inputs = tokenizer(
    [
        # Use the tokenizer's mask token (ModernBERT uses "[MASK]", not "<mask>").
        f"The capital of France is {tokenizer.mask_token}.",
        f"The name of the first president of the United States is {tokenizer.mask_token}.",
    ],
    padding=True,
    return_tensors="pt",
).to("cuda")

with torch.no_grad():
    outputs = model(**inputs, output_hidden_states=True)

print(inputs["attention_mask"].sum())
# >>> 26
print(outputs.hidden_states[-1].shape)
# >>> torch.Size([26, 768])


# This assertion fails: the hidden states come back 2-D (unpadded) instead of
# [batch_size, seq_len, hidden_size].
assert outputs.hidden_states[-1].shape == inputs["input_ids"].shape + (
    model.config.hidden_size,
)

I'm using the following library versions:

  • transformers==4.48.2
  • torch==2.6.0

It appears that the returned tensor is the unpadded (flattened) version: it is 2-D and its first dimension equals the sum of the attention mask, i.e. the number of non-padding tokens. This issue doesn't happen with the non-MLM version (ModernBertModel).
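As a stopgap, the flattened tensor can probably be scattered back into a padded tensor using the attention mask. This is only a sketch based on my assumption that the rows of the 2-D tensor correspond to the non-padding tokens in row-major (batch, then position) order; repad_hidden_states is a hypothetical helper, not part of transformers.

import torch

def repad_hidden_states(flat_hidden, attention_mask):
    # Scatter [num_non_padding_tokens, hidden_size] back to
    # [batch_size, seq_len, hidden_size], zero-filling the padding positions.
    batch_size, seq_len = attention_mask.shape
    padded = flat_hidden.new_zeros(batch_size, seq_len, flat_hidden.shape[-1])
    padded[attention_mask.bool()] = flat_hidden
    return padded

# Usage with the outputs from the snippet above:
# padded = repad_hidden_states(outputs.hidden_states[-1], inputs["attention_mask"])
# padded.shape  # torch.Size([batch_size, seq_len, 768])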

I searched the issues for "modern bert" and "hidden state" and looked at the recent commits, and didn't see any mention of this, but it might have been fixed in a newer version without it being obvious.

Who can help?

@ArthurZucker

Information

  • The official example scripts
  • My own modified scripts

Tasks

  • An officially supported task in the examples folder (such as GLUE/SQuAD, ...)
  • My own task or dataset (give details below)

Reproduction

Run the code provided above with Flash Attention enabled on a CUDA GPU.

Expected behavior

The hidden states should have shape [batch_size, max_sequence_length, hidden_size], but instead they have shape [num_non_padding_tokens, hidden_size] (the first dimension appears to equal the sum of the attention mask, i.e. the number of unpadded tokens).
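A possible interim workaround, based on my assumption that ModernBERT only unpads inputs on the flash-attention path: loading the model with a different attn_implementation (e.g. "sdpa" or "eager") appears to keep the hidden states in the padded 3-D shape. Sketch:

from transformers import ModernBertForMaskedLM

# Assumption: forcing a non-flash attention implementation skips the
# unpadding logic, so hidden_states keep the padded shape.
model = ModernBertForMaskedLM.from_pretrained(
    "answerdotai/ModernBERT-base",
    attn_implementation="sdpa",
).to("cuda")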
