
[onnx] support attn_mask fp16 type #110306

Closed

Conversation

rui-ren
Contributor

@rui-ren rui-ren commented Sep 29, 2023

When a user defines a customized attention mask with `dtype=torch.float16`, e.g.

```python
import torch
from torch.nn import functional as F

float_min = torch.finfo(torch.float16).min

attention_mask_fp16 = (attention_mask * 1.0).masked_fill(
    attention_mask, float_min
).to(torch.float16)

attn_output = F.scaled_dot_product_attention(
    query_layer_, key_layer_, value_layer_, attention_mask_fp16, 0.0, is_causal=False
)
```

the ONNX graph cannot be exported.
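As background on the mask construction: `masked_fill` replaces the entries where the boolean mask is `True` with the fill value, which is how the additive fp16 mask is built from a boolean attention mask. A minimal plain-Python sketch of that semantics (no torch; `masked_fill` here is a hypothetical stand-in for the tensor op, operating on flat lists):

```python
# torch.finfo(torch.float16).min, the most negative representable fp16 value;
# adding it to a score effectively zeroes that position after softmax.
FLOAT16_MIN = -65504.0

def masked_fill(values, mask, fill_value):
    """Return a copy of `values` with `fill_value` wherever `mask` is True."""
    return [fill_value if m else v for v, m in zip(values, mask)]

# True marks a position to be masked out of attention.
scores_row = [0.0, 0.0, 0.0, 0.0]
mask_row = [False, False, True, True]
additive_mask = masked_fill(scores_row, mask_row, FLOAT16_MIN)
# additive_mask → [0.0, 0.0, -65504.0, -65504.0]
```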

When `q`, `k`, and `v` are fp16, we can also support an fp16 `attn_mask` by changing the dtype check from an equality test against `FLOAT` to a membership test:

```python
elif _type_utils.JitScalarType.from_value(attn_mask) in (
    _type_utils.JitScalarType.FLOAT,
    _type_utils.JitScalarType.HALF,
):
```

With this change, the `.onnx` graph can be exported.
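The effect of widening that check can be sketched in isolation; below, `ScalarType` is a hypothetical stand-in for `torch.onnx._type_utils.JitScalarType`, not the real class:

```python
from enum import Enum

class ScalarType(Enum):
    # Hypothetical stand-in for _type_utils.JitScalarType.
    FLOAT = 1
    HALF = 2
    BOOL = 3

def mask_dtype_supported(attn_mask_type: ScalarType) -> bool:
    # Before the fix: attn_mask_type == ScalarType.FLOAT (fp16 masks rejected).
    # After the fix: a membership test accepts both fp32 and fp16 masks.
    return attn_mask_type in (ScalarType.FLOAT, ScalarType.HALF)
```

With the old equality test, an fp16 mask fell through to the error path and the export failed; the tuple membership accepts both float dtypes.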

Fixes #109336

@pytorch-bot

pytorch-bot bot commented Sep 29, 2023

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/110306

Note: Links to docs will display an error until the docs builds have been completed.

✅ No Failures

As of commit 6149727 with merge base d04b35e:
💚 Looks good so far! There are no failures yet. 💚

This comment was automatically generated by Dr. CI and updates every 15 minutes.

@linux-foundation-easycla

linux-foundation-easycla bot commented Sep 29, 2023

CLA Signed

The committers listed above are authorized under a signed CLA.

@pytorch-bot pytorch-bot bot added the release notes: onnx torch.onnx related changes that should show up in the release notes label Sep 29, 2023
@titaiwangms
Collaborator

titaiwangms commented Sep 29, 2023

This looks good!
Please sign CLA and lint the code (https://github.com/pytorch/pytorch/wiki/lintrunner).

@titaiwangms titaiwangms self-assigned this Sep 29, 2023
@titaiwangms titaiwangms added module: onnx Related to torch.onnx topic: improvements topic category labels Sep 29, 2023
@rui-ren rui-ren marked this pull request as ready for review September 29, 2023 23:59
Collaborator

@titaiwangms titaiwangms left a comment


LGTM! Thanks!

@titaiwangms titaiwangms added the ciflow/trunk Trigger trunk jobs on your pull request label Oct 1, 2023
@titaiwangms
Collaborator

@pytorchbot merge

@pytorchmergebot
Collaborator

Merge started

Your change will be merged once all checks pass (ETA 0-4 Hours).

Learn more about merging in the wiki.

Questions? Feedback? Please reach out to the PyTorch DevX Team

Advanced Debugging: check the merge workflow status here.

@rui-ren rui-ren deleted the rui-ren/onnx-support-attn-mask-fp16-dtype branch October 1, 2023 15:44
Labels
- ciflow/trunk (Trigger trunk jobs on your pull request)
- Merged
- module: onnx (Related to torch.onnx)
- open source
- release notes: onnx (torch.onnx related changes that should show up in the release notes)
- topic: improvements (topic category)
Projects
Status: Done
Development

Successfully merging this pull request may close these issues.

[onnx exporter] Falcon-7b onnx graph exporter issue from huggingface source
4 participants