Skip to content

Commit

Permalink
Reject saving GSM when not compatible (#1384)
Browse files Browse the repository at this point in the history
  • Loading branch information
jhurt committed Mar 12, 2021
1 parent 6d81ab8 commit 47d0008
Show file tree
Hide file tree
Showing 2 changed files with 26 additions and 1 deletion.
10 changes: 9 additions & 1 deletion test/torchaudio_unittest/backend/sox_io/save_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -305,7 +305,15 @@ def test_save_amr_nb(self, test_mode, bit_rate):
)
def test_save_gsm(self, test_mode):
self.assert_save_consistency(
"gsm", test_mode=test_mode)
"gsm", num_channels=1, test_mode=test_mode)
with self.assertRaises(
RuntimeError, msg="gsm format only supports single channel audio."):
self.assert_save_consistency(
"gsm", num_channels=2, test_mode=test_mode)
with self.assertRaises(
RuntimeError, msg="gsm format only supports a sampling rate of 8kHz."):
self.assert_save_consistency(
"gsm", sample_rate=16000, test_mode=test_mode)

@parameterized.expand([
("wav", "PCM_S", 16),
Expand Down
17 changes: 17 additions & 0 deletions torchaudio/csrc/sox/io.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -101,6 +101,13 @@ void save_audio_file(
const auto num_channels = tensor.size(channels_first ? 0 : 1);
TORCH_CHECK(
num_channels == 1, "htk format only supports single channel audio.");
} else if (filetype == "gsm") {
const auto num_channels = tensor.size(channels_first ? 0 : 1);
TORCH_CHECK(
num_channels == 1, "gsm format only supports single channel audio.");
TORCH_CHECK(
sample_rate == 8000,
"gsm format only supports a sampling rate of 8kHz.");
}
const auto signal_info =
get_signalinfo(&tensor, sample_rate, filetype, channels_first);
Expand Down Expand Up @@ -243,6 +250,16 @@ void save_audio_fileobj(
throw std::runtime_error(
"htk format only supports single channel audio.");
}
} else if (filetype == "gsm") {
const auto num_channels = tensor.size(channels_first ? 0 : 1);
if (num_channels != 1) {
throw std::runtime_error(
"gsm format only supports single channel audio.");
}
if (sample_rate != 8000) {
throw std::runtime_error(
"gsm format only supports a sampling rate of 8kHz.");
}
}
const auto signal_info =
get_signalinfo(&tensor, sample_rate, filetype, channels_first);
Expand Down

0 comments on commit 47d0008

Please sign in to comment.