diff --git a/tensorflow_io/core/kernels/arrow/arrow_kernels.cc b/tensorflow_io/core/kernels/arrow/arrow_kernels.cc index 5a8378d8d..6641f06a0 100644 --- a/tensorflow_io/core/kernels/arrow/arrow_kernels.cc +++ b/tensorflow_io/core/kernels/arrow/arrow_kernels.cc @@ -544,97 +544,68 @@ class FeatherReadable : public IOReadableInterface { new SizedRandomAccessFile(env_, filename, memory_data, memory_size)); TF_RETURN_IF_ERROR(file_->GetFileSize(&file_size_)); - // FEA1.....[metadata][uint32 metadata_length]FEA1 - static constexpr const char* kFeatherMagicBytes = "FEA1"; - - size_t header_length = strlen(kFeatherMagicBytes); - size_t footer_length = sizeof(uint32) + strlen(kFeatherMagicBytes); - - string buffer; - buffer.resize(header_length > footer_length ? header_length - : footer_length); - - StringPiece result; - - TF_RETURN_IF_ERROR(file_->Read(0, header_length, &result, &buffer[0])); - if (memcmp(buffer.data(), kFeatherMagicBytes, header_length) != 0) { - return errors::InvalidArgument("not a feather file"); + std::shared_ptr feather_file; + feather_file.reset(new ArrowRandomAccessFile(file_.get(), file_size_)); + auto maybe_reader = arrow::ipc::feather::Reader::Open(feather_file); + if (!maybe_reader.ok()) { + return errors::Internal(maybe_reader.status().ToString()); } + std::shared_ptr reader = + maybe_reader.ValueOrDie(); + std::shared_ptr schema = reader->schema(); - TF_RETURN_IF_ERROR(file_->Read(file_size_ - footer_length, footer_length, - &result, &buffer[0])); - if (memcmp(buffer.data() + sizeof(uint32), kFeatherMagicBytes, - footer_length - sizeof(uint32)) != 0) { - return errors::InvalidArgument("incomplete feather file"); - } - - uint32 metadata_length = *reinterpret_cast(buffer.data()); - - buffer.resize(metadata_length); - - TF_RETURN_IF_ERROR(file_->Read(file_size_ - footer_length - metadata_length, - metadata_length, &result, &buffer[0])); - - const ::arrow::ipc::feather::fbs::CTable* table = - ::arrow::ipc::feather::fbs::GetCTable(buffer.data()); - - if (table->version() < ::arrow::ipc::feather::kFeatherV1Version) { - return errors::InvalidArgument("feather file is old: ", table->version(), - " vs. ", - ::arrow::ipc::feather::kFeatherV1Version); + std::shared_ptr table; + arrow::Status s = reader->Read(&table); + if (!s.ok()) { + return errors::Internal(s.ToString()); } - for (size_t i = 0; i < table->columns()->size(); i++) { + for (int i = 0; i < schema->num_fields(); i++) { ::tensorflow::DataType dtype = ::tensorflow::DataType::DT_INVALID; - switch (table->columns()->Get(i)->values()->type()) { - case ::arrow::ipc::feather::fbs::Type::BOOL: + switch (schema->field(i)->type()->id()) { + case ::arrow::Type::BOOL: dtype = ::tensorflow::DataType::DT_BOOL; break; - case ::arrow::ipc::feather::fbs::Type::INT8: + case ::arrow::Type::INT8: dtype = ::tensorflow::DataType::DT_INT8; break; - case ::arrow::ipc::feather::fbs::Type::INT16: + case ::arrow::Type::INT16: dtype = ::tensorflow::DataType::DT_INT16; break; - case ::arrow::ipc::feather::fbs::Type::INT32: + case ::arrow::Type::INT32: dtype = ::tensorflow::DataType::DT_INT32; break; - case ::arrow::ipc::feather::fbs::Type::INT64: + case ::arrow::Type::INT64: dtype = ::tensorflow::DataType::DT_INT64; break; - case ::arrow::ipc::feather::fbs::Type::UINT8: + case ::arrow::Type::UINT8: dtype = ::tensorflow::DataType::DT_UINT8; break; - case ::arrow::ipc::feather::fbs::Type::UINT16: + case ::arrow::Type::UINT16: dtype = ::tensorflow::DataType::DT_UINT16; break; - case ::arrow::ipc::feather::fbs::Type::UINT32: + case ::arrow::Type::UINT32: dtype = ::tensorflow::DataType::DT_UINT32; break; - case ::arrow::ipc::feather::fbs::Type::UINT64: + case ::arrow::Type::UINT64: dtype = ::tensorflow::DataType::DT_UINT64; break; - case ::arrow::ipc::feather::fbs::Type::FLOAT: + case ::arrow::Type::FLOAT: dtype = ::tensorflow::DataType::DT_FLOAT; break; - case ::arrow::ipc::feather::fbs::Type::DOUBLE: + case ::arrow::Type::DOUBLE: dtype = ::tensorflow::DataType::DT_DOUBLE; break; - case ::arrow::ipc::feather::fbs::Type::UTF8: - case ::arrow::ipc::feather::fbs::Type::BINARY: - case ::arrow::ipc::feather::fbs::Type::CATEGORY: - case ::arrow::ipc::feather::fbs::Type::TIMESTAMP: - case ::arrow::ipc::feather::fbs::Type::DATE: - case ::arrow::ipc::feather::fbs::Type::TIME: - // case ::arrow::ipc::feather::fbs::Type::LARGE_UTF8: - // case ::arrow::ipc::feather::fbs::Type::LARGE_BINARY: + case ::arrow::Type::BINARY: + dtype = ::tensorflow::DataType::DT_STRING; + break; default: break; } shapes_.push_back(TensorShape({static_cast(table->num_rows())})); dtypes_.push_back(dtype); - columns_.push_back(table->columns()->Get(i)->name()->str()); - columns_index_[table->columns()->Get(i)->name()->str()] = i; + columns_.push_back(schema->field(i)->name()); + columns_index_[schema->field(i)->name()] = i; } return Status::OK(); @@ -751,6 +722,17 @@ class FeatherReadable : public IOReadableInterface { FEATHER_PROCESS_TYPE(double, ::arrow::NumericArray<::arrow::DoubleType>); break; + case DT_STRING: { + int64 curr_index = 0; + for (auto chunk : slice->chunks()) { + for (int64_t item = 0; item < chunk->length(); item++) { + value->flat()(curr_index) = + (dynamic_cast<::arrow::BinaryArray*>(chunk.get())) + ->GetString(item); + curr_index++; + } + } + } break; default: return errors::InvalidArgument("data type is not supported: ", DataTypeString(value->dtype())); diff --git a/tests/test_feather.py b/tests/test_feather.py index 52668582b..abf030734 100644 --- a/tests/test_feather.py +++ b/tests/test_feather.py @@ -17,11 +17,24 @@ import os import tempfile +import pytest +import tensorflow as tf import tensorflow_io as tfio -def test_feather_format(): +@pytest.mark.parametrize( + ("version"), + [ + 1, + 2, + ], + ids=[ + "v1", + "v2", + ], +) +def test_feather_format(version): """test_feather_format""" import numpy as np import pandas as pd @@ -39,7 +52,7 @@ def test_feather_format(): } df = pd.DataFrame(data).sort_index(axis=1) with tempfile.NamedTemporaryFile(delete=False) as f: - pa_feather.write_feather(df, f, version=1) + pa_feather.write_feather(df, f, version=version) feather = tfio.IOTensor.from_feather(f.name) for column in df.columns: @@ -50,5 +63,32 @@ def test_feather_format(): os.unlink(f.name) +def test_binary_feather_format(): + """test_binary_feather_format""" + import numpy as np + import pandas as pd + + from pyarrow import feather as pa_feather + import pyarrow as pa + + local_path = os.path.join( + os.path.dirname(os.path.abspath(__file__)), "test_image", "sample.png" + ) + with open(local_path, "rb") as f: + data = [f.read()] + table = pa.Table.from_arrays([data], ["data"]) + + chunk_size = 1000 + with tempfile.NamedTemporaryFile(delete=False) as f: + pa_feather.write_feather(table, f, chunksize=chunk_size) + + feather = tfio.IOTensor.from_feather(f.name) + assert feather("data").shape == [1] + assert feather("data").dtype == tf.string + assert np.all(feather("data").to_tensor().numpy() == data[0]) + + os.unlink(f.name) + + if __name__ == "__main__": test.main()