Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Replace torch.assert_allclose with assertEqual from pytorch #1387

Merged
merged 1 commit into from
Mar 15, 2021
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
16 changes: 8 additions & 8 deletions test/torchaudio_unittest/compliance_kaldi_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,9 +45,9 @@ def first_sample_of_frame(frame, window_size, window_shift, snip_edges):
window[f, s] = wave[s_in_wave]


@common_utils.skipIfNoSoxBackend
@common_utils.skipIfNoSox
class Test_Kaldi(common_utils.TempDirMixin, common_utils.TorchaudioTestCase):
backend = 'sox'
discort marked this conversation as resolved.
Show resolved Hide resolved
backend = 'sox_io'

kaldi_output_dir = common_utils.get_asset_path('kaldi')
test_filepath = common_utils.get_asset_path('kaldi_file.wav')
Expand Down Expand Up @@ -91,7 +91,7 @@ def _test_get_strided_helper(self, num_samples, window_size, window_shift, snip_

for r in range(m):
extract_window(window, waveform, r, window_size, window_shift, snip_edges)
torch.testing.assert_allclose(window, output)
self.assertEqual(window, output)

def test_get_strided(self):
# generate any combination where 0 < window_size <= num_samples and
Expand All @@ -116,7 +116,7 @@ def _create_data_set(self):
sound, sample_rate = torchaudio.load(self.test_filepath, normalization=False)
print(y >> 16)
self.assertTrue(sample_rate == sr)
torch.testing.assert_allclose(y, sound)
self.assertEqual(y, sound)

def _print_diagnostic(self, output, expect_output):
# given an output and expected output, it will print the absolute/relative errors (max and mean squared)
Expand Down Expand Up @@ -170,15 +170,15 @@ def _compliance_test_helper(self, sound_filepath, filepath_key, expected_num_fil
output = get_output_fn(sound, args)

self._print_diagnostic(output, kaldi_output)
torch.testing.assert_allclose(output, kaldi_output, atol=atol, rtol=rtol)
self.assertEqual(output, kaldi_output, atol=atol, rtol=rtol)

def test_mfcc_empty(self):
# Passing in an empty tensor should result in an error
self.assertRaises(AssertionError, kaldi.mfcc, torch.empty(0))

def test_resample_waveform(self):
def get_output_fn(sound, args):
output = kaldi.resample_waveform(sound, args[1], args[2])
output = kaldi.resample_waveform(sound.to(torch.float32), args[1], args[2])
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Could you explain why it is necessary to add .to(torch.float32) here?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I don't know why, but sound has an integer dtype, so you will get RuntimeError: expected scalar type Long but found Float

In [1]: sound.dtype
Out[1]: torch.int16

To reproduce:

sound = torch.randint(-2742, 8204, (1, 8000))
kaldi.resample_waveform(sound, 16000, 1000)
~/python/projects/torchaudio/torchaudio/compliance/kaldi.py in resample_waveform(waveform, orig_freq, new_freq, lowpass_filter_width)
    841     num_wavs, length = waveform.shape
    842     waveform = F.pad(waveform, (width, width + orig_freq))
--> 843     resampled = F.conv1d(waveform[:, None], kernel, stride=orig_freq)
    844     resampled = resampled.transpose(1, 2).reshape(num_wavs, -1)
    845     target_length = int(math.ceil(new_freq * length / orig_freq))

RuntimeError: expected scalar type Long but found Float

Is it expected? @mthrok

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

ping @mthrok

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Hi @discort

Sorry for the late response. I looked at the code, and it seems that this should have been caught in previous PRs, but the combination of forgetting to change skipIfNoSoxBackend and the removal of torchaudio.load_wav caused this. I think this is fine. Thanks for the follow-up.

return output

self._compliance_test_helper(self.test2_filepath, 'resample', 32, 3, get_output_fn, atol=1e-2, rtol=1e-5)
Expand Down Expand Up @@ -221,7 +221,7 @@ def _test_resample_waveform_accuracy(self, up_scale_factor=None, down_scale_fact
ground_truth = ground_truth[..., n_to_trim:-n_to_trim]
estimate = estimate[..., n_to_trim:-n_to_trim]

torch.testing.assert_allclose(estimate, ground_truth, atol=atol, rtol=rtol)
self.assertEqual(estimate, ground_truth, atol=atol, rtol=rtol)

def test_resample_waveform_downsample_accuracy(self):
for i in range(1, 20):
Expand All @@ -246,4 +246,4 @@ def test_resample_waveform_multi_channel(self):
single_channel = self.test1_signal * (i + 1) * 1.5
single_channel_sampled = kaldi.resample_waveform(single_channel, self.test1_signal_sr,
self.test1_signal_sr // 2)
torch.testing.assert_allclose(multi_sound_sampled[i, :], single_channel_sampled[0], rtol=1e-4, atol=1e-7)
self.assertEqual(multi_sound_sampled[i, :], single_channel_sampled[0], rtol=1e-4, atol=1e-7)