Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
33 changes: 26 additions & 7 deletions tensorflow_io/audio/kernels/audio_kernels.cc
Original file line number Diff line number Diff line change
Expand Up @@ -83,18 +83,36 @@ class WAVReadable : public IOReadableInterface {

StringPiece result;
TF_RETURN_IF_ERROR(file_->Read(0, sizeof(header_), &result, (char *)(&header_)));
header_length_ = sizeof(header_);
int64 fmt_position = 12;
while (memcmp(header_.fmt, "fmt ", 4) != 0) {
// Skip a known non-"fmt " chunk (JUNK, bext, iXML, ...); any other chunk id is rejected.
if (memcmp(header_.fmt, "JUNK", 4) != 0 &&
memcmp(header_.fmt, "bext", 4) != 0 &&
memcmp(header_.fmt, "iXML", 4) != 0 &&
memcmp(header_.fmt, "qlty", 4) != 0 &&
memcmp(header_.fmt, "mext", 4) != 0 &&
memcmp(header_.fmt, "levl", 4) != 0 &&
memcmp(header_.fmt, "link", 4) != 0 &&
memcmp(header_.fmt, "axml", 4) != 0) {
return errors::InvalidArgument("unexpected field: ", header_.fmt);
}
int32 size_of_chunk = 4 + 4 + header_.fmt_size;
if (header_.fmt_size % 2 == 1) {
size_of_chunk += 1;
}
fmt_position += size_of_chunk;
// Re-read the header from the next chunk position; the first 12 bytes of header_ are kept as-is.
TF_RETURN_IF_ERROR(file_->Read(fmt_position, sizeof(header_) - 12, &result, (char *)(&header_) + 12));
header_length_ = fmt_position + sizeof(header_) - 12;
}

TF_RETURN_IF_ERROR(ValidateWAVHeader(&header_));
if (header_.riff_size + 8 != file_size_) {
// corrupted file?
}
int64 filesize = header_.riff_size + 8;

int64 position = result.size();

if (header_.fmt_size != 16) {
position += header_.fmt_size - 16;
}
int64 position = header_length_ + header_.fmt_size - 16;

int64 nSamples = 0;
do {
Expand Down Expand Up @@ -167,7 +185,7 @@ class WAVReadable : public IOReadableInterface {
// corrupted file?
}
int64 filesize = header_.riff_size + 8;
int64 position = sizeof(header_) + header_.fmt_size - 16;
int64 position = header_length_ + header_.fmt_size - 16;
do {
StringPiece result;
struct DataHeader head;
Expand Down Expand Up @@ -241,6 +259,7 @@ class WAVReadable : public IOReadableInterface {
DataType dtype_;
TensorShape shape_;
struct WAVHeader header_;
int64 header_length_;
};

REGISTER_KERNEL_BUILDER(Name("IO>WAVReadableInit").Device(DEVICE_CPU),
Expand Down