
Add SpectralOps CPU implementation for ARM/PowerPC processors (where MKL is not available) #41592

Closed
arnasRad opened this issue Jul 17, 2020 · 24 comments
Assignees
Labels
function request A request for a new function or the addition of new arguments/modes to an existing function. module: arm Related to ARM architectures builds of PyTorch. Includes Apple M1 module: build Build system issues module: fft module: POWER Issues specific to the POWER/ppc architecture triaged This issue has been looked at by a team member, and triaged and prioritized into an appropriate module

Comments

@arnasRad

arnasRad commented Jul 17, 2020

🐛 Bug

A `fft: ATen not compiled with MKL support` RuntimeError is thrown when trying to compute a Spectrogram on a Jetson Nano, which uses an ARM64 processor.

To Reproduce

Code sample:

import torchaudio

waveform, sample_rate = torchaudio.load('test.wav')
spectrogram = torchaudio.transforms.Spectrogram(sample_rate)(waveform)

Stack trace:

Traceback (most recent call last):
  File "spectrogram_test.py", line 4, in <module>
    spectrogram = torchaudio.transforms.Spectrogram(sample_rate)(waveform)
  File "/home/witty/ai-benchmark-2/lib/python3.6/site-packages/torch/nn/modules/module.py", line 722, in _call_impl
    result = self.forward(*input, **kwargs)
  File "/home/witty/ai-benchmark-2/lib/python3.6/site-packages/torchaudio-0.7.0a0+102174e-py3.6-linux-aarch64.egg/torchaudio/transforms.py", line 84, in forward
    self.win_length, self.power, self.normalized)
  File "/home/witty/ai-benchmark-2/lib/python3.6/site-packages/torchaudio-0.7.0a0+102174e-py3.6-linux-aarch64.egg/torchaudio/functional.py", line 162, in spectrogram
    waveform, n_fft, hop_length, win_length, window, True, "reflect", False, True
  File "/home/witty/ai-benchmark-2/lib/python3.6/site-packages/torch/functional.py", line 465, in stft
    return _VF.stft(input, n_fft, hop_length, win_length, window, normalized, onesided)
RuntimeError: fft: ATen not compiled with MKL support

Expected behavior

A spectrogram is created from the waveform.

Environment

Commands used to install PyTorch:

wget https://nvidia.box.com/shared/static/yr6sjswn25z7oankw8zy1roow9cy5ur1.whl -O torch-1.6.0rc2-cp36-cp36m-linux_aarch64.whl
sudo apt-get install python-pip libopenblas-base libopenmpi-dev 
pip install Cython
pip install numpy torch-1.6.0rc2-cp36-cp36m-linux_aarch64.whl

Commands used to install torchaudio:
sox:

sudo apt-get update -y
sudo apt-get install -y libsox-dev
pip install sox

torchaudio:

git clone https://github.com/pytorch/audio.git audio
cd audio && python setup.py install

torchaudio.__version__ output:
0.7.0a0+102174e

collect_env.py output:

PyTorch version: 1.6.0
Is debug build: No
CUDA used to build PyTorch: 10.2

OS: Ubuntu 18.04.4 LTS
GCC version: (Ubuntu/Linaro 7.5.0-3ubuntu1~18.04) 7.5.0
CMake version: version 3.10.2

Python version: 3.6
Is CUDA available: Yes
CUDA runtime version: Could not collect
GPU models and configuration: Could not collect
Nvidia driver version: Could not collect
cuDNN version: Probably one of the following:
/usr/lib/aarch64-linux-gnu/libcudnn.so.8.0.0
/usr/lib/aarch64-linux-gnu/libcudnn_adv_infer.so.8.0.0
/usr/lib/aarch64-linux-gnu/libcudnn_adv_train.so.8.0.0
/usr/lib/aarch64-linux-gnu/libcudnn_cnn_infer.so.8.0.0
/usr/lib/aarch64-linux-gnu/libcudnn_cnn_train.so.8.0.0
/usr/lib/aarch64-linux-gnu/libcudnn_etc.so.8.0.0
/usr/lib/aarch64-linux-gnu/libcudnn_ops_infer.so.8.0.0
/usr/lib/aarch64-linux-gnu/libcudnn_ops_train.so.8.0.0

Versions of relevant libraries:
[pip3] numpy==1.16.1
[pip3] pytorch-ignite==0.3.0
[pip3] torch==1.6.0
[pip3] torchaudio==0.7.0a0+102174e
[conda] Could not collect

Other relevant information:
MKL is not installed, because it is not supported on ARM processors; oneDNN installed

Additional context

I did not install MKL because it is not supported on ARM processors, so building PyTorch from source with MKL support is not possible. Is there any workaround to this problem?

cc @malfet @seemethere @walterddr @mruberry @peterbell10 @ezyang

@mthrok
Contributor

mthrok commented Jul 17, 2020

Looks like this is a PyTorch issue (torch.functional.stft). I will move this issue to the PyTorch repo.

@mthrok mthrok transferred this issue from pytorch/audio Jul 17, 2020
@zou3519 zou3519 added module: build Build system issues module: operators triaged This issue has been looked at by a team member, and triaged and prioritized into an appropriate module shadow review Request the triage shadow to take a second look at your triage and see if they agree or not labels Jul 17, 2020
@zou3519
Contributor

zou3519 commented Jul 17, 2020

I am not sure if we support the Jetson Nano and ARM64 processors. Will try to find someone who knows more...

@ssnl
Collaborator

ssnl commented Jul 17, 2020

The Jetson Nano wheels are compiled by NVIDIA, so maybe they will know who is responsible for this.

@malfet malfet added enhancement Not as big of a feature, but technically not a bug. Should be easy to fix module: build Build system issues and removed module: build Build system issues labels Jul 17, 2020
@malfet
Contributor

malfet commented Jul 17, 2020

This sounds like an enhancement to me: at the moment, there is only an MKL-accelerated CPU implementation of at::_fft_with_size:


MKL is not available on ARM, so in order to achieve feature parity, one should rely on other implementations, for example FFTW (http://www.fftw.org/).

@malfet malfet changed the title fft: ATen not compiled with MKL support on ARM processors Add SpectralOps CPU implementation for ARM/PowerPC processors (where MKL is not available) Jul 17, 2020
@arnasRad
Author

arnasRad commented Jul 20, 2020

Thanks @malfet. I was able to compute fft on ARM by using CUDA device on waveform:

import torchaudio
import torch

waveform, sample_rate = torchaudio.load('test.wav')
waveform = waveform.to("cuda:0")

spectrogram = torchaudio.transforms.Spectrogram(sample_rate).to("cuda:0")(waveform)
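For boards where a CUDA device is not an option either, a NumPy-only fallback can compute a magnitude spectrogram entirely on the CPU, since NumPy's FFT does not depend on MKL. This is an illustrative sketch (the function name and the defaults n_fft=400, hop_length=200, chosen to roughly match torchaudio's Spectrogram defaults, are my own), not part of the original report:

```python
import numpy as np

def numpy_spectrogram(waveform, n_fft=400, hop_length=200):
    """Magnitude spectrogram via NumPy's FFT (no MKL needed).

    waveform: 1-D float array. Returns (n_fft // 2 + 1, n_frames).
    """
    window = np.hanning(n_fft)
    n_frames = 1 + (len(waveform) - n_fft) // hop_length
    frames = np.stack([
        waveform[i * hop_length:i * hop_length + n_fft] * window
        for i in range(n_frames)
    ])
    # rfft keeps only the onesided (non-negative frequency) bins
    return np.abs(np.fft.rfft(frames, axis=1)).T ** 2

# 440 Hz sine at 16 kHz; its energy lands in bin 440 * 400 / 16000 = 11
signal = np.sin(2 * np.pi * 440 * np.arange(16000) / 16000)
spec = numpy_spectrogram(signal)
print(spec.shape)  # (201, 79)
```

This is much slower than a vectorized native implementation, but it unblocks CPU-only ARM boards until PyTorch ships a non-MKL backend.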

@mruberry mruberry added module: arm Related to ARM architectures builds of PyTorch. Includes Apple M1 module: fft module: POWER Issues specific to the POWER/ppc architecture and removed module: operators (deprecated) shadow review Request the triage shadow to take a second look at your triage and see if they agree or not labels Oct 7, 2020
@Flamefire
Collaborator

Test code from #42426:

import torch
x = torch.randn(10,10,2)
torch.fft(x,1)

I'd suggest changing the title to: "Support torch.fft without MKL support"
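Since the deprecated torch.fft(x, 1) interpreted the trailing dimension of size 2 as (real, imag) pairs, the test above can be reproduced without MKL using NumPy alone. This is an illustrative workaround, not the proposed fix:

```python
import numpy as np

# torch.fft(x, 1) treated the last dimension of size 2 as (real, imag)
# pairs; the NumPy equivalent builds an explicit complex array instead.
x = np.random.randn(10, 10, 2)
xc = x[..., 0] + 1j * x[..., 1]            # assemble the complex signal
y = np.fft.fft(xc, axis=-1)                # 1-D FFT over the last dimension
out = np.stack([y.real, y.imag], axis=-1)  # back to the (..., 2) layout
print(out.shape)  # (10, 10, 2)
```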

@mruberry mruberry added function request A request for a new function or the addition of new arguments/modes to an existing function. and removed enhancement Not as big of a feature, but technically not a bug. Should be easy to fix labels Nov 30, 2020
@StuartIanNaylor

StuartIanNaylor commented Dec 26, 2020

Is there any way Torch could incorporate https://developer.arm.com/tools-and-software/server-and-hpc/downloads/arm-performance-libraries

Or is there a cross-platform equivalent, as that seems to be 64-bit only? OpenBLAS, FFTW...?

I was overjoyed that someone had done some great work for the Raspberry Pi, but my first test used FFT (also 32-bit).

isakbosman/pytorch_arm_builds#1

@mruberry
Collaborator

Is there any way Torch could incorporate https://developer.arm.com/tools-and-software/server-and-hpc/downloads/arm-performance-libraries

It might be interesting to build a PyTorch-compatible library that can use the linked software, but I think the ARM community would be expected to drive any effort to support fft-like functionality in PyTorch on their hardware.

@walterddr
Copy link
Contributor

Another alternative is to support FFTW3?

@StuartIanNaylor

StuartIanNaylor commented Dec 27, 2020

It might be interesting to build a PyTorch-compatible library that can use the linked software, but I think the ARM community would be expected to drive any effort to support fft-like functionality in PyTorch on their hardware.

The reality is we should all be supporting open-source libraries such as OpenBLAS and FFTW, not hardware-specific ones such as Intel MKL.

The Arm Performance Libraries, Intel MKL, and anything AMD might want to submit should be feature requests against a normal open-source core, not the other way round.
Really, we shouldn't be using any manufacturer-based libs unless, like NVIDIA, they provide PyTorch for their Jetson boards or GPU specifics; likewise Intel could maintain an Intel MKL-accelerated PyTorch.

Uses industry-standard C and Fortran APIs for compatibility with popular BLAS, LAPACK, and FFTW functions—no code changes required

If Intel MKL uses industry-standard C and Fortran APIs, why are we using 'Intel MKL' and not the 'industry-standard C and Fortran APIs' as the core?

@mruberry
Collaborator

@StuartIanNaylor I encourage you to produce a PyTorch-compatible library using these alternatives or a PR that discusses their impact on build size and performance.

@StuartIanNaylor

StuartIanNaylor commented Dec 28, 2020

@StuartIanNaylor I encourage you to produce a PyTorch-compatible library using these alternatives or a PR that discusses their impact on build size and performance.

The rationale of my reply is that that would be illogical: surely the industry standards should be the core, and maybe hardware specifics such as Intel MKL should be the ones provided separately.

But the necessity of a PyTorch-compatible library is exactly what I am questioning: why is it needed when there are industry-standard APIs and libraries?
Why use a hardware- and manufacturer-specific one?

@mruberry
Collaborator

Sorry I'm not sure what you're getting at, @StuartIanNaylor. As mentioned, we would consider adopting other math libraries or even entirely native implementations for these operations, but someone needs to do the work and demonstrate the correctness and performance of these alternatives. The best way to advocate for these changes is by doing that work.

@StuartIanNaylor

StuartIanNaylor commented Dec 29, 2020

Sorry I'm not sure what you're getting at, @StuartIanNaylor. As mentioned, we would consider adopting other math libraries or even entirely native implementations for these operations, but someone needs to do the work and demonstrate the correctness and performance of these alternatives. The best way to advocate for these changes is by doing that work.

@mruberry I apologise if you are not sure what I am getting at, but I just find it really strange that you use a specific hardware vendor's math libs when industry-standard, cross-platform libs exist.
It doesn't need other vendors' math libs; the problem is that for some reason you picked Intel MKL over the well-known standard libs that open source has.

What is the point of even being able to raise or comment on an issue if the response is "if you want it, do it yourself"?!

I am asking why Intel MKL was chosen for the core, as that honestly confuses me, and no, unfortunately I don't have the ability to implement standard math libs or ARM-specific ones.

Also, why do we need to demonstrate anything when 'Intel MKL' is right there in the issue title, and being exclusive to one architecture is worse than bad performance?

@mruberry
Collaborator

I see. I think we're straying far from the original issue here. These more general questions are best asked on our forum: https://discuss.pytorch.org/.

@Flamefire
Collaborator

My 2c: I think the misunderstanding comes from the reply:

It might be interesting to build a PyTorch-compatible library that can use the linked software, but I think the ARM community would be expected to drive any effort to support fft-like functionality in PyTorch on their hardware.

That is reasonable as an extension to this issue: support high(er)-performance special implementations for specific hardware (ARM).

However, the basic suggestion was that PyTorch support a cross-architecture FFT lib like FFTW by default, and not default to a "special implementation for a specific hardware (Intel MKL)", which renders it unusable when doing CPU work on non-Intel-x86 hardware (e.g. AMD [bad performance], POWER, ARM, ...).

See my comment #41592 (comment)

@mruberry
Collaborator

mruberry commented Jan 4, 2021

However, the basic suggestion was that PyTorch support a cross-architecture FFT lib like FFTW by default, and not default to a "special implementation for a specific hardware (Intel MKL)", which renders it unusable when doing CPU work on non-Intel-x86 hardware (e.g. AMD [bad performance], POWER, ARM, ...).

See my comment #41592 (comment)

Supporting fftw (or any particular library) is interesting. The questions we should answer when considering an alternative to our current approach are:

  • what is the performance vs. what we have today? (easy to evaluate on platforms that don't support mkldnn)
  • what would be the increase in build size from adopting a new solution?
  • how much harder will PyTorch be to maintain if we adopt a new library?
  • how will PyTorch be deployed with the new library?
  • how will the new library be tested?
  • how costly is it to implement the proposed alternative?
  • how compelling are the scenarios the new library will unlock?
  • how challenging are these scenarios to pursue without PyTorch implementing this functionality?

These can be tricky questions to answer. To the last question, however, is there not a PyTorch-compatible library calling fftw already available? PyTorch CPU tensors can be converted to NumPy arrays without copying memory. Is there no fftw package that operates on NumPy arrays?
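The zero-copy interop mentioned here can be sketched as follows; the round trip through NumPy (and any FFT package that accepts NumPy arrays, e.g. numpy.fft or pyFFTW) works on ARM today because no MKL code path is involved. A minimal demonstration, not an endorsed workaround:

```python
import numpy as np
import torch

t = torch.randn(8, 1024, dtype=torch.float64)  # CPU tensor
a = t.numpy()                  # zero-copy NumPy view of the same buffer
a[0, 0] = 42.0                 # mutating the view...
assert t[0, 0].item() == 42.0  # ...also mutates the tensor: no copy was made

# Any FFT package that accepts NumPy arrays (numpy.fft, scipy.fft,
# pyFFTW, ...) can now operate on the data without MKL.
spectrum = np.fft.fft(a, axis=-1)
result = torch.from_numpy(spectrum)  # wrap back, again without copying
print(result.shape, result.dtype)    # torch.Size([8, 1024]) torch.complex128
```

The obvious limitation is that the detour through NumPy is not differentiable and not scriptable, which is why a native non-MKL backend is still needed.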

@ezyang
Contributor

ezyang commented Jan 4, 2021

@mruberry I think that is a little too adversarial ;) There are several places in PyTorch where we have multiple libraries providing implementations of one function, with some decisions about when to select which one. If we are serious about ARM (which we should be!) then it's not a hard call to say that we should add another library to cover FFT support in this situation. Now, obviously there is work to figure out which library is appropriate and whether or not we should even ship it with our regular CPU binaries (probably not, if MKL fft is universally better), and as core developers we might not prioritize this work, but if my job were to make PyTorch work as well as possible on ARM, this would probably be part of the mandate.

@Flamefire
Collaborator

probably not, if MKL fft is universally better

I did a few simple tests using numpy, pyfft and mkl-fft in Python and this seems to be true for x86. But again: MKL does not work at all on non x86. I'm actually surprised because I expected to see the slowdown on an AMD Rome processor due to the known "downgrading" of MKL performance on non-Intel processors. But I was not able to verify that for FFTs using Python as the interface.
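For reference, a minimal timing harness of the kind such comparisons use; the back ends named in the comment are the ones mentioned above, and the numbers are entirely machine-dependent:

```python
import timeit

import numpy as np

# Batched complex 1-D FFT; swap np.fft.fft for scipy.fft.fft,
# pyfftw.interfaces.numpy_fft.fft, or mkl_fft.fft to compare back ends.
x = np.random.randn(512, 512) + 1j * np.random.randn(512, 512)

per_call = timeit.timeit(lambda: np.fft.fft(x, axis=-1), number=50) / 50
print(f"numpy.fft: {per_call * 1e3:.3f} ms per call")
```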

@ezyang
Contributor

ezyang commented Jan 5, 2021

yeah, sorry, I should have specified, regular x86 cpu binaries :)

@malfet malfet self-assigned this Jan 14, 2021
@discort

discort commented Mar 29, 2021

Thanks for working on it @malfet
Any progress?

@iseeyuan
Contributor

The issue also happens in mobile: https://discuss.pytorch.org/t/fft-operations-on-mobile/119598. Please keep us posted for any updates!

@malfet
Contributor

malfet commented Apr 28, 2021

We are likely going to use pocketfft on non-x86 platforms
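As a stopgap until that lands: SciPy ≥ 1.4 already ships scipy.fft backed by a C++ pocketfft, so a detour through NumPy/SciPy exercises the same library on ARM today. A sketch:

```python
import numpy as np
from scipy import fft as sfft  # scipy.fft is backed by pocketfft

x = np.random.randn(4, 1024)
y = sfft.rfft(x, axis=-1)  # onesided FFT, as torch.stft uses by default
print(y.shape)  # (4, 513): n // 2 + 1 non-negative frequency bins
```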

@StuartIanNaylor

Pocketfft would be a great inclusion, but any choice other than vendor-specific libs will do.

Has anyone got an ETA? I would really like to use this on ARM64 and have tried, as it seems have many others, but I still hit:

/home/pi/speech-brain/venv/lib/python3.7/site-packages/torch/functional.py:585: UserWarning: stft with return_complex=False is deprecated. In a future pytorch release, stft will return complex tensors for all inputs, and return_complex=False will raise an error.
Note: you can still call torch.view_as_real on the complex output to recover the old return format. (Triggered internally at  ../aten/src/ATen/native/SpectralOps.cpp:483.)
  normalized, onesided, return_complex)
Traceback (most recent call last):
  File "delay-sum.py", line 38, in <module>
    Xs = stft(xs)
  File "/home/pi/speech-brain/venv/lib/python3.7/site-packages/torch/nn/modules/module.py", line 880, in _call_impl
    result = self.forward(*input, **kwargs)
  File "/home/pi/speech-brain/speechbrain-0.5.7/speechbrain/processing/features.py", line 171, in forward
    return_complex=False,
  File "/home/pi/speech-brain/venv/lib/python3.7/site-packages/torch/functional.py", line 585, in stft
    normalized, onesided, return_complex)
RuntimeError: fft: ATen not compiled with MKL support

I just tried all the great wheels supplied by https://mathinf.eu/pytorch/arm64/2021-01/
Still the same...
