Skip to content

Commit

Permalink
Fix SignalInfo member name to frame (#734)
Browse files Browse the repository at this point in the history
This PR fixes the wrong member name of SignalInfo introduced in #718. 

 - `num_samples` == `num_frames` * `num_channels`.
  • Loading branch information
mthrok committed Jun 23, 2020
1 parent 7427bf5 commit e0f4c0e
Show file tree
Hide file tree
Showing 6 changed files with 18 additions and 18 deletions.
10 changes: 5 additions & 5 deletions test/sox_io_backend/test_info.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@ def test_wav(self, dtype, sample_rate, num_channels):
)
info = sox_io_backend.info(path)
assert info.get_sample_rate() == sample_rate
assert info.get_num_samples() == sample_rate * duration
assert info.get_num_frames() == sample_rate * duration
assert info.get_num_channels() == num_channels

@parameterized.expand(list(itertools.product(
Expand All @@ -55,7 +55,7 @@ def test_wav_multiple_channels(self, dtype, sample_rate, num_channels):
)
info = sox_io_backend.info(path)
assert info.get_sample_rate() == sample_rate
assert info.get_num_samples() == sample_rate * duration
assert info.get_num_frames() == sample_rate * duration
assert info.get_num_channels() == num_channels

@parameterized.expand(list(itertools.product(
Expand All @@ -74,7 +74,7 @@ def test_mp3(self, sample_rate, num_channels, bit_rate):
info = sox_io_backend.info(path)
assert info.get_sample_rate() == sample_rate
# mp3 does not preserve the number of samples
# assert info.get_num_samples() == sample_rate * duration
# assert info.get_num_frames() == sample_rate * duration
assert info.get_num_channels() == num_channels

@parameterized.expand(list(itertools.product(
Expand All @@ -92,7 +92,7 @@ def test_flac(self, sample_rate, num_channels, compression_level):
)
info = sox_io_backend.info(path)
assert info.get_sample_rate() == sample_rate
assert info.get_num_samples() == sample_rate * duration
assert info.get_num_frames() == sample_rate * duration
assert info.get_num_channels() == num_channels

@parameterized.expand(list(itertools.product(
Expand All @@ -110,5 +110,5 @@ def test_vorbis(self, sample_rate, num_channels, quality_level):
)
info = sox_io_backend.info(path)
assert info.get_sample_rate() == sample_rate
assert info.get_num_samples() == sample_rate * duration
assert info.get_num_frames() == sample_rate * duration
assert info.get_num_channels() == num_channels
2 changes: 1 addition & 1 deletion test/sox_io_backend/test_torchscript.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,5 +44,5 @@ def test_info_wav(self, dtype, sample_rate, num_channels):
ts_info = ts_info_func(audio_path)

assert py_info.get_sample_rate() == ts_info.get_sample_rate()
assert py_info.get_num_samples() == ts_info.get_num_samples()
assert py_info.get_num_frames() == ts_info.get_num_frames()
assert py_info.get_num_channels() == ts_info.get_num_channels()
2 changes: 1 addition & 1 deletion torchaudio/csrc/register.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@ static auto registerSignalInfo =
.def(torch::init<int64_t, int64_t, int64_t>())
.def("get_sample_rate", &SignalInfo::getSampleRate)
.def("get_num_channels", &SignalInfo::getNumChannels)
.def("get_num_samples", &SignalInfo::getNumSamples);
.def("get_num_frames", &SignalInfo::getNumFrames);

static auto registerGetInfo = torch::RegisterOperators().op(
torch::RegisterOperators::options()
Expand Down
8 changes: 4 additions & 4 deletions torchaudio/csrc/typedefs.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -4,10 +4,10 @@ namespace torchaudio {
SignalInfo::SignalInfo(
const int64_t sample_rate_,
const int64_t num_channels_,
const int64_t num_samples_)
const int64_t num_frames_)
: sample_rate(sample_rate_),
num_channels(num_channels_),
num_samples(num_samples_){};
num_frames(num_frames_){};

int64_t SignalInfo::getSampleRate() const {
return sample_rate;
Expand All @@ -17,7 +17,7 @@ int64_t SignalInfo::getNumChannels() const {
return num_channels;
}

int64_t SignalInfo::getNumSamples() const {
return num_samples;
int64_t SignalInfo::getNumFrames() const {
return num_frames;
}
} // namespace torchaudio
6 changes: 3 additions & 3 deletions torchaudio/csrc/typedefs.h
Original file line number Diff line number Diff line change
Expand Up @@ -7,15 +7,15 @@ namespace torchaudio {
struct SignalInfo : torch::CustomClassHolder {
int64_t sample_rate;
int64_t num_channels;
int64_t num_samples;
int64_t num_frames;

SignalInfo(
const int64_t sample_rate_,
const int64_t num_channels_,
const int64_t num_samples_);
const int64_t num_frames_);
int64_t getSampleRate() const;
int64_t getNumChannels() const;
int64_t getNumSamples() const;
int64_t getNumFrames() const;
};

} // namespace torchaudio
Expand Down
8 changes: 4 additions & 4 deletions torchaudio/extension/extension.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,19 +30,19 @@ class SignalInfo:
without extension.
This class has to implement the same interface as C++ equivalent.
"""
def __init__(self, sample_rate: int, num_channels: int, num_samples: int):
def __init__(self, sample_rate: int, num_channels: int, num_frames: int):
self.sample_rate = sample_rate
self.num_channels = num_channels
self.num_samples = num_samples
self.num_frames = num_frames

def get_sample_rate(self):
return self.sample_rate

def get_num_channels(self):
return self.num_channels

def get_num_samples(self):
return self.num_samples
def get_num_frames(self):
return self.num_frames

DummyModule = namedtuple('torchaudio', ['SignalInfo'])
module = DummyModule(SignalInfo)
Expand Down

0 comments on commit e0f4c0e

Please sign in to comment.