Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Migrate inverse short-time Fourier transform from torchaudio #34827

Closed
wants to merge 2 commits into from

Conversation

mthrok
Copy link
Contributor

@mthrok mthrok commented Mar 16, 2020

As suggested in #3775, this PR migrates Python code to perform ISTFT from torchaudio to PyTorch.

  • Migrate function.
  • Migrate tests.
    • round trip precision test
    • Batch test
    • Linearity test
    • Sine test
    • Argument validation test
    • Fundamental case test (zeros and ones)

@dr-ci
Copy link

dr-ci bot commented Mar 16, 2020

💊 CircleCI build failures summary and remediations

As of commit 45a146f (more details on the Dr. CI page):


None of the build failures appear to be your fault 💚


  • 1/1 broken upstream at merge base b55dee9 from Apr 06 until Apr 07 (9 commits; 2e8f954 - 444073e)

    Please rebase on the viable/strict branch (expand for instructions)

    If your commit is newer than viable/strict, you can try basing on an older, stable commit:

    git fetch https://github.com/pytorch/pytorch viable/strict
    git rebase --onto FETCH_HEAD $(git merge-base origin/master HEAD)
    

    If your commit is older than viable/strict:

    git fetch https://github.com/pytorch/pytorch viable/strict
    git rebase FETCH_HEAD
    

    Check out the recency history of this "viable master" tracking branch.


🚧 1 upstream failure:

These were probably caused by upstream breakages:


This comment was automatically generated by Dr. CI (expand for details).Follow this link to opt-out of these comments for your Pull Requests.

Please report bugs/suggestions on the GitHub issue tracker.

See how this bot performed.

This comment has been revised 95 times.

@cpuhrsch cpuhrsch requested a review from vincentqb March 16, 2020 20:23
@mthrok mthrok changed the title WIP: Migrate inverse short-time Fourier transform Migrate inverse short-time Fourier transform from torchaudio Mar 17, 2020
@mthrok mthrok marked this pull request as ready for review March 17, 2020 18:16
Copy link
Contributor

@facebook-github-bot facebook-github-bot left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@vincentqb has imported this pull request. If you are a Facebook employee, you can view this diff on Phabricator.

Copy link
Contributor

@vincentqb vincentqb left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can you also add the documentation? Some documentation tests are failing.

@mthrok mthrok force-pushed the dev-istft branch 2 times, most recently from 50d40af to e9d8538 Compare March 17, 2020 21:46
Copy link
Contributor

@facebook-github-bot facebook-github-bot left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@vincentqb has imported this pull request. If you are a Facebook employee, you can view this diff on Phabricator.

@mthrok
Copy link
Contributor Author

mthrok commented Mar 20, 2020

@vincentqb

I am confused about jit-ability of istft function. It seems that moving the function from torchaudio to pytorch slightly changes the property.
In both case, the bare istft function is jit-able, but if istft function is a part of custom function, it fails only when it's in pytorch.
The later case, the error message read, it requires implementation in aten.

import torch
import torchaudio.functional

# Toggle this to change the behavior.
istft = torch.istft
# istft = torchaudio.functional.istft

stft_kwargs = {
    'n_fft': 15,
    'hop_length': 3,
    'win_length': 11,
    'window': torch.hamming_window(11),
    'center': True,
    'pad_mode': 'constant',
    'normalized': True,
    'onesided': False,
}

istft_kwargs = stft_kwargs.copy()
istft_kwargs['length'] = None

input = torch.stft(torch.randn(3, 15), **stft_kwargs)


def nested_istft(
        stft_matrix, n_fft, hop_length, win_length, window, center, pad_mode, normalized, onesided, length):
    # type: (Tensor, int, Optional[int], Optional[int], Optional[Tensor], bool, str, bool, bool, Optional[int]) -> Tensor
    return istft(stft_matrix, n_fft, hop_length, win_length, window, center, pad_mode, normalized, onesided, length)


def test_func(istft_func):
    jit_istft = torch.jit.script(istft_func)
    py_result = istft_func(input, **istft_kwargs)
    jit_result = jit_istft(input, **istft_kwargs)
    assert torch.allclose(py_result, jit_result, atol=1e-6)


print('checking bare istft')
test_func(istft)
print('checking nested istft')
test_func(nested_istft)
print('ok')

When istft function is inside of torchaudio.functional, it works.

$ python test.py
checking bare istft
checking nested istft
ok

but when istft is inside of pytorch, then it fails.

$ python test.py
checking bare istft
checking nested istft
Traceback (most recent call last):
  File "foo.py", line 39, in <module>
    test_func(nested_istft)
  File "foo.py", line 30, in test_func
    jit_istft = torch.jit.script(istft_func)
  File "/scratch/moto/pytorch/torch/jit/__init__.py", line 1296, in script
    fn = torch._C._jit_script_compile(qualified_name, ast, _rcb, get_default_args(obj))
RuntimeError:
Unknown builtin op: aten::istft.
Here are some suggestions:
	aten::stft
	aten::ifft
	aten::irfft

The original call is:
  File "foo.py", line 27
        stft_matrix, n_fft, hop_length, win_length, window, center, pad_mode, normalized, onesided, length):
    # type: (Tensor, int, Optional[int], Optional[int], Optional[Tensor], bool, str, bool, bool, Optional[int]) -> Tensor
    return istft(stft_matrix, n_fft, hop_length, win_length, window, center, pad_mode, normalized, onesided, length)
           ~~~~~ <--- HERE

@vincentqb
Copy link
Contributor

vincentqb commented Mar 23, 2020

This is related comment about aten error, and linked issue #30780.

@eellison
Copy link
Contributor

@mthrok could you follow the example in #33737 ? Since the inputs all have one single type it shouldn't be very hard to make this torchscriptable. Just add
"istft" to ops = ["stft", "lu", "lu_unpack", "cdist"].

basically we generally resolve torch.{op} to "aten::{op}", but functions in torch/functional.py are an exception to this. There still exist some functions in torch/functional.py which we want to resolve to their aten equivalent like broadcast_tensors so ops have to opt-in to compiling the python function. (broadcast_tensors we didn't add bc we don't yet support (*args), but the aten builtin compiles correctly).

@mthrok
Copy link
Contributor Author

mthrok commented Mar 26, 2020

Thanks @eellison for the suggestion. The trick did solve the issue above.

I applied that to PR, but now some tests are failing due to something related to ATen, which were not happening earlier.
On separate note I am working on porting istft to ATen, and I think ATen implementation will resolve this. So I will come back when I am done with ATen implementation of istft.

@facebook-github-bot
Copy link
Contributor

@mthrok merged this pull request in 5a27ec0.

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
Projects
None yet
Development

Successfully merging this pull request may close these issues.

None yet

5 participants