New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Replace torch.assert_allclose with assertEqual from pytorch #1387
Conversation
I believe I didn't touch failed functionality. Any suggestions how to make the build passing? binary_macos_conda_py3.9 - Failed
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Hi @discort
Thanks for the update. It looks mostly good. Regarding the build failure, since this PR does not change anything about the library code / build process, we can ignore them. (Build failure often happens due to upstream packaging and there might be something very fragile about the way build job is setup for macOS.)
|
||
def test_mfcc_empty(self): | ||
# Passing in an empty tensor should result in an error | ||
self.assertRaises(AssertionError, kaldi.mfcc, torch.empty(0)) | ||
|
||
def test_resample_waveform(self): | ||
def get_output_fn(sound, args): | ||
output = kaldi.resample_waveform(sound, args[1], args[2]) | ||
output = kaldi.resample_waveform(sound.to(torch.float32), args[1], args[2]) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Could you explain why it is necessary to add to(torch.float32)
here?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I don't why but sound
equals to int, so you will get RuntimeError: expected scalar type Long but found Float
In [1]: sound.dtype
Out[1]: torch.int16
To reproduce:
sound = torch.randint(-2742, 8204, (1, 8000))
kaldi.resample_waveform(sound, 16000, 1000)
~/python/projects/torchaudio/torchaudio/compliance/kaldi.py in resample_waveform(waveform, orig_freq, new_freq, lowpass_filter_width)
841 num_wavs, length = waveform.shape
842 waveform = F.pad(waveform, (width, width + orig_freq))
--> 843 resampled = F.conv1d(waveform[:, None], kernel, stride=orig_freq)
844 resampled = resampled.transpose(1, 2).reshape(num_wavs, -1)
845 target_length = int(math.ceil(new_freq * length / orig_freq))
RuntimeError: expected scalar type Long but found Float
Is it expected? @mthrok
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
ping @mthrok
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Hi @discort
Sorry for the late response. I looked at the code, and it seems that this should have been caught in previous PRs but the combination of forgetting to change skipIfNoSoxBackend
and the removal of torchaudio.load_wav
caused this. I think this is fine. Thanks for the followup.
|
||
def test_mfcc_empty(self): | ||
# Passing in an empty tensor should result in an error | ||
self.assertRaises(AssertionError, kaldi.mfcc, torch.empty(0)) | ||
|
||
def test_resample_waveform(self): | ||
def get_output_fn(sound, args): | ||
output = kaldi.resample_waveform(sound, args[1], args[2]) | ||
output = kaldi.resample_waveform(sound.to(torch.float32), args[1], args[2]) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Hi @discort
Sorry for the late response. I looked at the code, and it seems that this should have been caught in previous PRs but the combination of forgetting to change skipIfNoSoxBackend
and the removal of torchaudio.load_wav
caused this. I think this is fine. Thanks for the followup.
The code was not formatted properly, which caused the subsequent steps to fail
Closes #680
torch.assert_allclose
withassertEqual
from pytorch test framework.