New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
[ONNX] STFT Support #92087
[ONNX] STFT Support #92087
Conversation
|
Thanks for your contribution! Could you sign the CLA following instructions by the bot comment above, and fix lint issues by running |
Thanks for reviewing this, @justinchuby! I just signed the CLA, and associated my last commits with the right email address. Moreover, I have passed the linter and fixed a problem with the unit tests regarding the |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Looks like there's still errors in the CLA. Squashing the commits may work?
Apologies, I realized that the CLA should be signed under a "company contribution", since most of this work was done using my company's resources (Adobe). This contribution was internally approved before I submitted this PR, so I might need a working day or two to have someone from the Open Source Office sign the CLA. Will report asap. |
Not sure if related to this PR or the underlying onnxruntime, but there seems to be a off-by-one error somewhere when using graph optimization with STFT.
The output shape matches, but the expected shape is different, and so when we add nodes on top of that it causes shape mismatch errors. |
torch/onnx/symbolic_opset17.py
Outdated
# Checks | ||
assert ( | ||
return_complex is None or not return_complex | ||
), "STFT does not currently support complex types" |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
return_complex=False
is deprecated so I don't like the idea of forcing people to use it. I see that complex64 and complex128 are mentioned as types in this document, so is it possible to implement view_as_real
/view_as_complex
in onnx?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I like the idea of having view_as_real and view_as_complex ops in onnx. Would you be willing to open an issue for a new operator? https://github.com/onnx/onnx/issues/new?assignees=&labels=operator&template=operator.md&title=
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
@peterbell10 I also don't like the idea of forcing return_complex=False
, but it is an easy workaround if we don't want to wait to have complex conversion support on ONNX, which seems like it might take quite some time (see below).
@justinchuby Seems like view_as_complex
was already requested here: #49793 Unfortunately, there was not enough interest for this when it was reported. I just commented on the issue, to try to bring it back.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Am I right in saying onnx Cast doesn't support complex types? If it did then you could do something like:
def _as_real(z):
return float(z), float(-1j * z)
def _as_complex(real, imag):
return complex(real) + 1j * complex(imag)
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
That is correct: ONNX doesn't currently support casting for complex types (https://github.com/onnx/onnx/blob/main/docs/Operators.md#Cast).
If it did, we could simply use Cast
on the result of STFT
to make sure the result is returned as complex if return_complex=True
.
That being said, would it be possible to obtain the output of STFT
, put it on a PyTorch tensor, and convert it to complex using the _as_complex
function above, and then put it back into the ONNX Graph?
Sorry, just realized that that's not gonna work, because Cast
from/to complex is not supported, which is exactly what @peterbell10 said (😅).
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Yeah, otherwise there will be a graph break in onnx.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
@milesial I noticed that as well. The fact that it actually returns the correct shape in all the tests I wrote makes me think it might be a bug on the the onnxruntime API (or the ONNX STFT definition?). I couldn't find an issue reported on the I was thinking of reporting it once this PR is merged, so that it might be easier to reproduce. Or if you wanna report it now, by all means do :) |
Got it, good you caught it too, I opened microsoft/onnxruntime#14316 since I don't know if this MR is going to be merged soon |
The CLA is finally signed, apologies it took so long. I'll work on the comments and feedback asap. |
I notice the CLA bot is still having errors. Perhaps squash or rebase some commits? |
There are some build errors I believe happened because this branch is too old. Could you rebase with master? |
Hey @justinchuby , I fixed the problem regarding the type mismatch here. The other problem (i.e., the small differences between the results), should be resolved setting the tolerance to around 1e-5 (like I did, e.g., here). I'm not familiar with how to change this tolerance in the For reference, here's the last chunk of my output of FAILED test/onnx/test_op_consistency.py::TestOnnxModelOutputConsistency_opset17CPU::test_output_match_stft_cpu_float32 - AssertionError: Tensor-likes are not close!
FAILED test/onnx/test_op_consistency.py::TestOnnxModelOutputConsistency_opset17CPU::test_output_match_stft_cpu_float64 - AssertionError: Tensor-likes are not close!
FAILED test/onnx/test_op_consistency.py::TestOnnxModelOutputConsistency_opset18CPU::test_output_match_stft_cpu_float32 - AssertionError: Tensor-likes are not close!
FAILED test/onnx/test_op_consistency.py::TestOnnxModelOutputConsistency_opset18CPU::test_output_match_stft_cpu_float64 - AssertionError: Tensor-likes are not close! Thanks for your help! |
Thanks! I will work on it next week when I am back |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Thanks so much for creating this!
@pytorchbot merge -g |
@pytorchbot merge -g |
@pytorchbot merge |
Merge startedYour change will be merged once all checks pass (ETA 0-4 Hours). Learn more about merging in the wiki. Questions? Feedback? Please reach out to the PyTorch DevX Team |
@pytorchbot merge -f "unrelated cuda and other failures" |
Merge startedYour change will be merged immediately since you used the force (-f) flag, bypassing any CI checks (ETA: 1-5 minutes). Learn more about merging in the wiki. Questions? Feedback? Please reach out to the PyTorch DevX Team |
This PR addresses issue [#81075](pytorch/pytorch#81075), making `torch.stft` compatible with ONNX Opset 17's STFT operator. The conversion works for _most_ of `torch.stft` functionality: - Batched or unbatched inputs - Normalization - Pre-computed windows - Rectangular windows - One-sided returns - Window centering (implicitly supported) What is currently _not_ supported is **complex types**, due to the lack of conversion functionality between PyTorch and ONNX (pytorch/pytorch#86746). Regardless, this is easy to bypass by setting `return_complex=False` when using `torch.stft`. Note that there is already a draft PR to address this (pytorch/pytorch#83944), but it is currently closed and it only partially addresses the conversion (i.e., most of `torch.stft` functionality is lacking, and unit tests are missing). Pull Request resolved: pytorch/pytorch#92087 Approved by: https://github.com/justinchuby
This PR addresses issue #81075, making
torch.stft
compatible with ONNX Opset 17's STFT operator.The conversion works for most of
torch.stft
functionality:What is currently not supported is complex types, due to the lack of conversion functionality between PyTorch and ONNX (#86746).
Regardless, this is easy to bypass by setting
return_complex=False
when usingtorch.stft
.Note that there is already a draft PR to address this (#83944), but it is currently closed and it only partially addresses the conversion (i.e., most of
torch.stft
functionality is lacking, and unit tests are missing).