Skip to content

Conversation

titaiwangms
Copy link
Collaborator

@titaiwangms titaiwangms commented Feb 15, 2023

@pytorch-bot pytorch-bot bot added the release notes: onnx torch.onnx related changes that should show up in the release notes label Feb 15, 2023
@pytorch-bot
Copy link

pytorch-bot bot commented Feb 15, 2023

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/94919

Note: Links to docs will display an error until the docs builds have been completed.

✅ No Failures

As of commit ea118eb:
💚 Looks good so far! There are no failures yet. 💚

This comment was automatically generated by Dr. CI and updates every 15 minutes.

@BowenBao
Copy link
Collaborator

Could you add a TODO/FIXME comment and/or tracking issue if this is temp?

@titaiwangms
Copy link
Collaborator Author

titaiwangms commented Feb 15, 2023

Could you add a TODO/FIXME comment and/or tracking issue if this is temp?

Haven't figured out a way on this ...
https://github.com/microsoft/onnx-script/pull/439/files#r1107520708

We don't have it in torch.onnx.export API actually... So maybe not a temp...

@_onnx_symbolic("aten::bitwise_not")
@_beartype.beartype
def bitwise_not(g: jit_utils.GraphContext, input):
    if not symbolic_helper._is_bool(input):
        raise errors.SymbolicValueError(
            "ONNX export does NOT support exporting bitwise Not "
            "for non-boolean input values",
            input,
        )
    return g.op("Not", input)

Concusion: ONNX export does NOT support exporting bitwise with INT dtype right now.

@titaiwangms titaiwangms changed the title [ONNX] Temp support aten::bit_wise_not in fx-onnx exporter [ONNX] Support aten::bit_wise_not in fx-onnx exporter Feb 15, 2023
@titaiwangms
Copy link
Collaborator Author

@pytorchbot merge -g

@pytorch-bot pytorch-bot bot added the ciflow/trunk Trigger trunk jobs on your pull request label Feb 16, 2023
@pytorchmergebot
Copy link
Collaborator

Merge started

Your change will be merged once all checks pass (ETA 0-4 Hours).

Learn more about merging in the wiki.

Questions? Feedback? Please reach out to the PyTorch DevX Team

Advanced Debugging
Check the merge workflow status
here

pruthvistony added a commit to ROCm/pytorch that referenced this pull request May 2, 2023
@facebook-github-bot facebook-github-bot deleted the gh/AllenTiTaiWang/41/head branch June 8, 2023 14:22
jhavukainen pushed a commit to kulinseth/pytorch that referenced this pull request Mar 15, 2024
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

ciflow/trunk Trigger trunk jobs on your pull request Merged open source release notes: onnx torch.onnx related changes that should show up in the release notes

Projects

None yet

Development

Successfully merging this pull request may close these issues.

6 participants