
Can use with BetterTransformer? #8

Closed
cceyda opened this issue May 24, 2023 · 3 comments

cceyda commented May 24, 2023

I was hoping this would work for optimizing the underlying model.encoder, since it is (mostly?) independent of the rest, but I'm getting a shape error:

RuntimeError: shape '[1, 512]' is invalid for input of size 262144

which basically says it expected a [1, 512] attention_mask, which is weird because the attention_mask input is shaped [512, 512].

Here is the test code:

from span_marker import SpanMarkerModel
from optimum.bettertransformer import BetterTransformer

# Download from the 🤗 Hub
model = SpanMarkerModel.from_pretrained("tomaarsen/span-marker-bert-base-fewnerd-fine-super").eval()

# Run inference with the original encoder
entities = model.predict("Amelia Earhart flew her single engine Lockheed Vega 5B across the Atlantic to Paris.")
print(entities)  # works

# Swap in the BetterTransformer-converted encoder
better_encoder = BetterTransformer.transform(model.encoder)
model.encoder = better_encoder

# Run inference again (fails with the traceback below)
entities = model.predict("Amelia Earhart flew her single engine Lockheed Vega 5B across the Atlantic to Paris.")
print(entities)
│ /Users/ceyda.1/miniconda/lib/python3.9/site-packages/span_marker/modeling.py:137 in forward      │
│                                                                                                  │
│   134 │   │   │   SpanMarkerOutput: The output dataclass.                                        │
│   135 │   │   """                                                                                │
│   136 │   │   token_type_ids = torch.zeros_like(input_ids)                                       │
│ ❱ 137 │   │   outputs = self.encoder(                                                            │
│   138 │   │   │   input_ids,                                                                     │
│   139 │   │   │   attention_mask=attention_mask,                                                 │
│   140 │   │   │   token_type_ids=token_type_ids,                                                 │
│                                                                                                  │
│ /Users/ceyda.1/miniconda/lib/python3.9/site-packages/torch/nn/modules/module.py:1501 in          │
│ _call_impl                                                                                       │
│                                                                                                  │
│   1498 │   │   if not (self._backward_hooks or self._backward_pre_hooks or self._forward_hooks   │
│   1499 │   │   │   │   or _global_backward_pre_hooks or _global_backward_hooks                   │
│   1500 │   │   │   │   or _global_forward_hooks or _global_forward_pre_hooks):                   │
│ ❱ 1501 │   │   │   return forward_call(*args, **kwargs)                                          │
│   1502 │   │   # Do not call functions when jit is used                                          │
│   1503 │   │   full_backward_hooks, non_full_backward_hooks = [], []                             │
│   1504 │   │   backward_pre_hooks = []                                                           │
│                                                                                                  │
│ /Users/ceyda.1/miniconda/lib/python3.9/site-packages/transformers/models/bert/modeling_bert.py:1 │
│ 020 in forward                                                                                   │
│                                                                                                  │
│   1017 │   │   │   inputs_embeds=inputs_embeds,                                                  │
│   1018 │   │   │   past_key_values_length=past_key_values_length,                                │
│   1019 │   │   )                                                                                 │
│ ❱ 1020 │   │   encoder_outputs = self.encoder(                                                   │
│   1021 │   │   │   embedding_output,                                                             │
│   1022 │   │   │   attention_mask=extended_attention_mask,                                       │
│   1023 │   │   │   head_mask=head_mask,                                                          │
│                                                                                                  │
│ /Users/ceyda.1/miniconda/lib/python3.9/site-packages/torch/nn/modules/module.py:1501 in          │
│ _call_impl                                                                                       │
│                                                                                                  │
│   1498 │   │   if not (self._backward_hooks or self._backward_pre_hooks or self._forward_hooks   │
│   1499 │   │   │   │   or _global_backward_pre_hooks or _global_backward_hooks                   │
│   1500 │   │   │   │   or _global_forward_hooks or _global_forward_pre_hooks):                   │
│ ❱ 1501 │   │   │   return forward_call(*args, **kwargs)                                          │
│   1502 │   │   # Do not call functions when jit is used                                          │
│   1503 │   │   full_backward_hooks, non_full_backward_hooks = [], []                             │
│   1504 │   │   backward_pre_hooks = []                                                           │
│                                                                                                  │
│ /Users/ceyda.1/miniconda/lib/python3.9/site-packages/transformers/models/bert/modeling_bert.py:6 │
│ 10 in forward                                                                                    │
│                                                                                                  │
│    607 │   │   │   │   │   encoder_attention_mask,                                               │
│    608 │   │   │   │   )                                                                         │
│    609 │   │   │   else:                                                                         │
│ ❱  610 │   │   │   │   layer_outputs = layer_module(                                             │
│    611 │   │   │   │   │   hidden_states,                                                        │
│    612 │   │   │   │   │   attention_mask,                                                       │
│    613 │   │   │   │   │   layer_head_mask,                                                      │
│                                                                                                  │
│ /Users/ceyda.1/miniconda/lib/python3.9/site-packages/torch/nn/modules/module.py:1501 in          │
│ _call_impl                                                                                       │
│                                                                                                  │
│   1498 │   │   if not (self._backward_hooks or self._backward_pre_hooks or self._forward_hooks   │
│   1499 │   │   │   │   or _global_backward_pre_hooks or _global_backward_hooks                   │
│   1500 │   │   │   │   or _global_forward_hooks or _global_forward_pre_hooks):                   │
│ ❱ 1501 │   │   │   return forward_call(*args, **kwargs)                                          │
│   1502 │   │   # Do not call functions when jit is used                                          │
│   1503 │   │   full_backward_hooks, non_full_backward_hooks = [], []                             │
│   1504 │   │   backward_pre_hooks = []                                                           │
│                                                                                                  │
│ /Users/ceyda.1/miniconda/lib/python3.9/site-packages/optimum/bettertransformer/models/encoder_mo │
│ dels.py:246 in forward                                                                           │
│                                                                                                  │
│    243 │   │   │   # attention mask comes in with values 0 and -inf. we convert to torch.nn.Tra  │
│    244 │   │   │   # 0->false->keep this token -inf->true->mask this token                       │
│    245 │   │   │   attention_mask = attention_mask.bool()                                        │
│ ❱  246 │   │   │   attention_mask = torch.reshape(attention_mask, (attention_mask.shape[0], att  │
│    247 │   │   │   hidden_states = torch._nested_tensor_from_mask(hidden_states, ~attention_mas  │
│    248 │   │   │   attention_mask = None                                                         │
│    249                                                                                           │
╰──────────────────────────────────────────────────────────────────────────────────────────────────╯
RuntimeError: shape '[1, 512]' is invalid for input of size 262144
  • transformers version: 4.29.2
  • Python version: 3.9.16
  • PyTorch version (GPU?): 2.0.1 (False)
  • optimum version: 1.8.6
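
For context, the failing reshape can be reproduced with plain tensors. A minimal sketch (the shapes are taken from the error message above; the tensors themselves are dummies):

import torch

# Minimal reproduction of the reshape that fails inside
# optimum/bettertransformer/models/encoder_models.py (see traceback above).
batch_size, seq_len = 1, 512

# A regular padding mask (one flag per token) flattens without issue:
padding_mask = torch.zeros(batch_size, seq_len)
print(torch.reshape(padding_mask, (padding_mask.shape[0], padding_mask.shape[-1])).shape)
# torch.Size([1, 512])

# SpanMarker passes a full attention matrix (one flag per token pair),
# so there are 512 * 512 = 262144 elements, which cannot fit into [1, 512]:
attention_matrix = torch.zeros(batch_size, seq_len, seq_len)
try:
    torch.reshape(attention_matrix, (attention_matrix.shape[0], attention_matrix.shape[-1]))
except RuntimeError as err:
    print(err)  # shape '[1, 512]' is invalid for input of size 262144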
tomaarsen (Owner) commented

Hello!

As discussed internally on Slack, it seems that BetterTransformer does not support attention matrices like the ones SpanMarker requires. See this figure for an example of an attention matrix that I would use:
[Figure: example SpanMarker attention matrix]

I'm afraid they only support attention masks, i.e. with shape [seq_length] rather than [seq_length, seq_length].

  • Tom Aarsen
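
To make the two shapes concrete, a rough sketch (the span indices and mask values are made up for illustration, not taken from SpanMarker's preprocessing):

import torch

seq_length = 6

# Shape [batch, seq_length]: a padding mask with one flag per token.
# This is the kind of mask BetterTransformer can handle.
padding_mask = torch.tensor([[1, 1, 1, 1, 0, 0]])

# Shape [batch, seq_length, seq_length]: an attention matrix with one flag
# per token pair. The pattern below is made up: text tokens attend to text
# tokens, while two trailing "marker" positions attend to a chosen span and
# to each other.
attention_matrix = torch.zeros(1, seq_length, seq_length, dtype=torch.long)
attention_matrix[0, :4, :4] = 1   # text tokens <-> text tokens
attention_matrix[0, 4:, 1:3] = 1  # marker tokens -> their span
attention_matrix[0, 4:, 4:] = 1   # marker tokens <-> each other

print(padding_mask.shape)      # torch.Size([1, 6])
print(attention_matrix.shape)  # torch.Size([1, 6, 6])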


cceyda commented May 26, 2023

Yeah, I was hopeful because the model translation worked without an error, but after debugging it more deeply, it probably won't work without writing custom model translation code.

tomaarsen (Owner) commented

I think so too. This likely requires changes on the BetterTransformer side. I'll close this for now, then.
