
[ONNX] STFT Support #92087

Closed
wants to merge 12 commits into from

Conversation

urinieto
Contributor

This PR addresses issue #81075, making torch.stft compatible with ONNX Opset 17's STFT operator.

The conversion works for most of torch.stft functionality:

  • Batched or unbatched inputs
  • Normalization
  • Pre-computed windows
  • Rectangular windows
  • One-sided returns
  • Window centering (implicitly supported)

What is currently not supported is complex types, due to the lack of conversion functionality between PyTorch and ONNX (#86746).

Regardless, this is easy to bypass by setting return_complex=False when using torch.stft.
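As a sketch of that workaround (helper names and the 400/160 parameters here are illustrative, not from this PR): return_complex=False yields the same (..., freq, frames, 2) real layout as taking torch.view_as_real of the complex result.

```python
import torch

# Hypothetical helper names; n_fft=400, hop=160 are illustrative values.
def stft_real(audio: torch.Tensor, n_fft: int = 400, hop: int = 160) -> torch.Tensor:
    # The workaround: ask torch.stft for a real-valued result directly, giving
    # a (..., n_fft // 2 + 1, n_frames, 2) tensor with real/imag in the last dim.
    window = torch.hann_window(n_fft)
    return torch.stft(audio, n_fft, hop_length=hop, window=window,
                      return_complex=False)

def stft_complex_viewed(audio: torch.Tensor, n_fft: int = 400, hop: int = 160) -> torch.Tensor:
    # The same layout, obtained from the complex result via view_as_real.
    window = torch.hann_window(n_fft)
    spec = torch.stft(audio, n_fft, hop_length=hop, window=window,
                      return_complex=True)
    return torch.view_as_real(spec)

audio = torch.randn(1, 16000)
out = stft_real(audio)
# With center=True (the default), n_frames = 1 + len // hop, so (1, 201, 101, 2).
assert out.shape == (1, 201, 101, 2)
assert torch.equal(out, stft_complex_viewed(audio))
```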

Note that there is already a draft PR to address this (#83944), but it is currently closed and it only partially addresses the conversion (i.e., most of torch.stft functionality is lacking, and unit tests are missing).

@pytorch-bot pytorch-bot bot added the release notes: onnx torch.onnx related changes that should show up in the release notes label Jan 12, 2023
@linux-foundation-easycla

linux-foundation-easycla bot commented Jan 12, 2023

CLA Signed

The committers listed above are authorized under a signed CLA.

  • ✅ login: urinieto / name: Oriol Nieto (a95ffd7)

@pytorch-bot

pytorch-bot bot commented Jan 12, 2023

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/92087

Note: Links to docs will display an error until the docs builds have been completed.

❌ 7 Failures

As of commit b88508c:

NEW FAILURES - The following jobs have failed:

This comment was automatically generated by Dr. CI and updates every 15 minutes.

@justinchuby justinchuby self-assigned this Jan 12, 2023
@drisspg drisspg added the triaged This issue has been looked at by a team member, and triaged and prioritized into an appropriate module label Jan 13, 2023
@justinchuby
Collaborator

Thanks for your contribution! Could you sign the CLA following the instructions in the bot comment above, and fix the lint issues by running lintrunner -a -m master? I will add comments after the CLA is signed.

https://github.com/pytorch/pytorch/wiki/lintrunner

@justinchuby justinchuby added the module: onnx Related to torch.onnx label Jan 13, 2023
@urinieto
Contributor Author

Thanks for reviewing this, @justinchuby! I just signed the CLA, and associated my last commits with the right email address. Moreover, I have passed the linter and fixed a problem with the unit tests regarding the return_complex parameter. Should be good to be reviewed :)

Collaborator

@justinchuby left a comment


Looks like there are still errors in the CLA check. Squashing the commits may work?

Review comments were left on test/onnx/test_operators.py and torch/onnx/symbolic_opset17.py (threads since resolved).
@urinieto
Contributor Author

Apologies, I realized that the CLA should be signed under a "company contribution", since most of this work was done using my company's resources (Adobe). This contribution was internally approved before I submitted this PR, so I might need a working day or two to have someone from the Open Source Office sign the CLA. Will report asap.

@milesial
Contributor

Not sure if related to this PR or the underlying onnxruntime, but there seems to be an off-by-one error somewhere when using graph optimization with STFT.

import torch
import torch.nn as nn
import onnxruntime as rt

class Test(nn.Module):
    def forward(self, audio):
        stft = torch.stft(audio, 400, 160, return_complex=False)
        return stft

m = Test()
print('torch', m(torch.randn(1, 16000 * 30)).shape)

torch.onnx.export(m,
                  torch.randn(1, 16000 * 30),
                  'test.onnx',
                  export_params=True,
                  input_names=['in'],
                  output_names=['out'],
                  opset_version=17, verbose=False)

sess_options = rt.SessionOptions()
sess_options.graph_optimization_level = rt.GraphOptimizationLevel.ORT_DISABLE_ALL
sess_options.optimized_model_filepath = 'test.onnx'
session = rt.InferenceSession('test.onnx', sess_options)
print('onnx unoptimized', session.run(None, {'in': torch.randn(1, 16000 * 30).numpy()})[0].shape)

sess_options.graph_optimization_level = rt.GraphOptimizationLevel.ORT_ENABLE_BASIC
sess_options.optimized_model_filepath = 'test.onnx'
session = rt.InferenceSession('test.onnx', sess_options)
print('onnx optimized', session.run(None, {'in': torch.randn(1, 16000 * 30).numpy()})[0].shape)

torch torch.Size([1, 201, 4801, 2])
onnx unoptimized (1, 201, 4801, 2)
onnx optimized (1, 201, 4801, 2)

2023-01-15 19:38:25.256968539 [W:onnxruntime:, execution_frame.cc:828 VerifyOutputSizes] Expected shape from model of {1,201,4802,2} does not match actual shape of {1,201,4801,2} for output out

The output shape matches, but the expected shape is different, and so when we add nodes on top of that it causes shape mismatch errors.


# Checks
assert (
    return_complex is None or not return_complex
), "STFT does not currently support complex types"
Collaborator


return_complex=False is deprecated so I don't like the idea of forcing people to use it. I see that complex64 and complex128 are mentioned as types in this document, so is it possible to implement view_as_real/view_as_complex in onnx?
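For context, here is a minimal sketch of the PyTorch semantics such ONNX operators would mirror (the values are illustrative):

```python
import torch

# view_as_real reinterprets a complex tensor as a real tensor with a trailing
# dimension of size 2 holding (real, imag); view_as_complex inverts it.
z = torch.tensor([1.0 + 2.0j, 3.0 - 4.0j])
r = torch.view_as_real(z)      # shape (2, 2): [[1, 2], [3, -4]]
z2 = torch.view_as_complex(r)  # round-trips without copying data
assert torch.equal(z, z2)
```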

Collaborator


I like the idea of having view_as_real and view_as_complex ops in onnx. Would you be willing to open an issue for a new operator? https://github.com/onnx/onnx/issues/new?assignees=&labels=operator&template=operator.md&title=

Contributor Author


@peterbell10 I also don't like the idea of forcing return_complex=False, but it is an easy workaround if we don't want to wait to have complex conversion support on ONNX, which seems like it might take quite some time (see below).

@justinchuby Seems like view_as_complex was already requested here: #49793 Unfortunately, there was not enough interest for this when it was reported. I just commented on the issue, to try to bring it back.

Collaborator


Am I right in saying onnx Cast doesn't support complex types? If it did then you could do something like:

def _as_real(z):
    return float(z), float(-1j * z)

def _as_complex(real, imag):
    return complex(real) + 1j * complex(imag)
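A runnable NumPy sketch of the Cast trick the snippet above relies on (NumPy's complex-to-float cast keeps the real part, emitting a ComplexWarning while doing so):

```python
import numpy as np

z = np.array([1.0 + 2.0j, 3.0 - 4.0j])
# Cast(z) -> Re(z); Cast(-1j * z) -> Re(-1j * z) = Im(z). Two real-valued
# Casts thus recover both components without any complex output tensor.
real = z.astype(np.float64)          # [1., 3.]  (ComplexWarning: imag discarded)
imag = (-1j * z).astype(np.float64)  # [2., -4.]
restored = real + 1j * imag
assert np.array_equal(restored, z)
```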

Contributor Author

@urinieto commented Jan 17, 2023


That is correct: ONNX doesn't currently support casting for complex types (https://github.com/onnx/onnx/blob/main/docs/Operators.md#Cast).

If it did, we could simply use Cast on the result of STFT to make sure the result is returned as complex if return_complex=True.

That being said, would it be possible to obtain the output of STFT, put it on a PyTorch tensor, and convert it to complex using the _as_complex function above, and then put it back into the ONNX Graph?

Sorry, just realized that that's not gonna work, because Cast from/to complex is not supported, which is exactly what @peterbell10 said (😅).

Collaborator


Yeah, otherwise there will be a graph break in onnx.


@urinieto
Contributor Author

Not sure if related to this PR or the underlying onnxruntime, but there seems to be an off-by-one error somewhere when using graph optimization with STFT.

@milesial I noticed that as well. The fact that it actually returns the correct shape in all the tests I wrote makes me think it might be a bug in the onnxruntime API (or the ONNX STFT definition?). I couldn't find a reported issue on the onnxruntime repo.

I was thinking of reporting it once this PR is merged, so that it might be easier to reproduce. Or if you wanna report it now, by all means do :)

@milesial
Contributor

Got it, good you caught it too. I opened microsoft/onnxruntime#14316 since I don't know if this PR is going to be merged soon.

@urinieto
Contributor Author

urinieto commented Feb 9, 2023

The CLA is finally signed, apologies it took so long. I'll work on the comments and feedback asap.

@justinchuby
Collaborator

I notice the CLA bot is still having errors. Perhaps squash or rebase some commits?

@justinchuby
Collaborator

There are some build errors I believe happened because this branch is too old. Could you rebase with master?

@urinieto
Contributor Author

urinieto commented Mar 1, 2023

You may run the test with pytest test/onnx/test_op_consistency.py -v -s -k stft

Hey @justinchuby , I fixed the problem regarding the type mismatch here.

The other problem (i.e., the small differences between the results) should be resolved by setting the tolerance to around 1e-5 (like I did, e.g., here). I'm not familiar with how to change this tolerance in the test_op_consistency.py file, could you take care of it?

For reference, here's the last chunk of my output of pytest test/onnx/test_op_consistency.py -v -s -k stft:

FAILED test/onnx/test_op_consistency.py::TestOnnxModelOutputConsistency_opset17CPU::test_output_match_stft_cpu_float32 - AssertionError: Tensor-likes are not close!
FAILED test/onnx/test_op_consistency.py::TestOnnxModelOutputConsistency_opset17CPU::test_output_match_stft_cpu_float64 - AssertionError: Tensor-likes are not close!
FAILED test/onnx/test_op_consistency.py::TestOnnxModelOutputConsistency_opset18CPU::test_output_match_stft_cpu_float32 - AssertionError: Tensor-likes are not close!
FAILED test/onnx/test_op_consistency.py::TestOnnxModelOutputConsistency_opset18CPU::test_output_match_stft_cpu_float64 - AssertionError: Tensor-likes are not close!

Thanks for your help!
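As a sketch of the kind of comparison being discussed (the function name and input shape are illustrative; the real harness lives in test_op_consistency.py):

```python
import torch

def outputs_match(torch_out: torch.Tensor, ort_out: torch.Tensor) -> bool:
    # A loosened rtol/atol of ~1e-5 absorbs the small numeric differences
    # between the PyTorch and ONNX Runtime STFT implementations.
    return torch.allclose(torch_out, ort_out, rtol=1e-5, atol=1e-5)

a = torch.randn(1, 201, 101, 2)
b = a + 1e-7 * torch.randn_like(a)  # simulate tiny backend differences
assert outputs_match(a, b)
assert not outputs_match(a, a + 1.0)
```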

@justinchuby
Collaborator

Thanks! I will work on it next week when I am back

@justinchuby justinchuby added the ciflow/trunk Trigger trunk jobs on your pull request label Mar 9, 2023
Collaborator

@justinchuby left a comment


Thanks so much for creating this!

@justinchuby
Collaborator

@pytorchbot merge -g


@justinchuby justinchuby removed the ciflow/trunk Trigger trunk jobs on your pull request label Mar 9, 2023
@justinchuby
Collaborator

@pytorchbot merge

@pytorch-bot pytorch-bot bot added the ciflow/trunk Trigger trunk jobs on your pull request label Mar 10, 2023
@pytorchmergebot
Collaborator

Merge started

Your change will be merged once all checks pass (ETA 0-4 Hours).

Learn more about merging in the wiki.

Questions? Feedback? Please reach out to the PyTorch DevX Team


@justinchuby
Collaborator

@pytorchbot merge -f "unrelated cuda and other failures"

@pytorchmergebot
Collaborator

Merge started

Your change will be merged immediately since you used the force (-f) flag, bypassing any CI checks (ETA: 1-5 minutes).

Learn more about merging in the wiki.

Questions? Feedback? Please reach out to the PyTorch DevX Team


cyyever pushed a commit to cyyever/pytorch_private that referenced this pull request Mar 12, 2023
Pull Request resolved: pytorch/pytorch#92087
Approved by: https://github.com/justinchuby
Labels

  • ciflow/trunk (Trigger trunk jobs on your pull request)
  • Merged
  • module: onnx (Related to torch.onnx)
  • open source
  • release notes: onnx (torch.onnx related changes that should show up in the release notes)
  • triaged (This issue has been looked at by a team member, and triaged and prioritized into an appropriate module)

7 participants