diff --git a/tensorflow_io/audio/kernels/audio_kernels.cc b/tensorflow_io/audio/kernels/audio_kernels.cc index 022dc6f49..a46b79eb8 100644 --- a/tensorflow_io/audio/kernels/audio_kernels.cc +++ b/tensorflow_io/audio/kernels/audio_kernels.cc @@ -83,18 +83,36 @@ class WAVReadable : public IOReadableInterface { StringPiece result; TF_RETURN_IF_ERROR(file_->Read(0, sizeof(header_), &result, (char *)(&header_))); + header_length_ = sizeof(header_); + int64 fmt_position = 12; + while (memcmp(header_.fmt, "fmt ", 4) != 0) { + // Skip JUNK/bext/etc field. + if (memcmp(header_.fmt, "JUNK", 4) != 0 && + memcmp(header_.fmt, "bext", 4) != 0 && + memcmp(header_.fmt, "iXML", 4) != 0 && + memcmp(header_.fmt, "qlty", 4) != 0 && + memcmp(header_.fmt, "mext", 4) != 0 && + memcmp(header_.fmt, "levl", 4) != 0 && + memcmp(header_.fmt, "link", 4) != 0 && + memcmp(header_.fmt, "axml", 4) != 0) { + return errors::InvalidArgument("unexpected field: ", header_.fmt); + } + int32 size_of_chunk = 4 + 4 + header_.fmt_size; + if (header_.fmt_size % 2 == 1) { + size_of_chunk += 1; + } + fmt_position += size_of_chunk; + // Re-read the header + TF_RETURN_IF_ERROR(file_->Read(fmt_position, sizeof(header_) - 12, &result, (char *)(&header_) + 12)); + header_length_ = fmt_position + sizeof(header_) - 12; + } TF_RETURN_IF_ERROR(ValidateWAVHeader(&header_)); if (header_.riff_size + 8 != file_size_) { // corrupted file? } int64 filesize = header_.riff_size + 8; - - int64 position = result.size(); - - if (header_.fmt_size != 16) { - position += header_.fmt_size - 16; - } + int64 position = header_length_ + header_.fmt_size - 16; int64 nSamples = 0; do { @@ -167,7 +185,7 @@ class WAVReadable : public IOReadableInterface { // corrupted file? } int64 filesize = header_.riff_size + 8; - int64 position = sizeof(header_) + header_.fmt_size - 16; + int64 position = header_length_ + header_.fmt_size - 16; do { StringPiece result; struct DataHeader head; @@ -241,6 +259,7 @@ class WAVReadable : public IOReadableInterface { DataType dtype_; TensorShape shape_; struct WAVHeader header_; + int64 header_length_; }; REGISTER_KERNEL_BUILDER(Name("IO>WAVReadableInit").Device(DEVICE_CPU),