Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
100 changes: 41 additions & 59 deletions tensorflow_io/core/kernels/arrow/arrow_kernels.cc
Original file line number Diff line number Diff line change
Expand Up @@ -544,97 +544,68 @@ class FeatherReadable : public IOReadableInterface {
new SizedRandomAccessFile(env_, filename, memory_data, memory_size));
TF_RETURN_IF_ERROR(file_->GetFileSize(&file_size_));

// FEA1.....[metadata][uint32 metadata_length]FEA1
static constexpr const char* kFeatherMagicBytes = "FEA1";

size_t header_length = strlen(kFeatherMagicBytes);
size_t footer_length = sizeof(uint32) + strlen(kFeatherMagicBytes);

string buffer;
buffer.resize(header_length > footer_length ? header_length
: footer_length);

StringPiece result;

TF_RETURN_IF_ERROR(file_->Read(0, header_length, &result, &buffer[0]));
if (memcmp(buffer.data(), kFeatherMagicBytes, header_length) != 0) {
return errors::InvalidArgument("not a feather file");
std::shared_ptr<ArrowRandomAccessFile> feather_file;
feather_file.reset(new ArrowRandomAccessFile(file_.get(), file_size_));
auto maybe_reader = arrow::ipc::feather::Reader::Open(feather_file);
if (!maybe_reader.ok()) {
return errors::Internal(maybe_reader.status().ToString());
}
std::shared_ptr<arrow::ipc::feather::Reader> reader =
maybe_reader.ValueOrDie();
std::shared_ptr<arrow::Schema> schema = reader->schema();

TF_RETURN_IF_ERROR(file_->Read(file_size_ - footer_length, footer_length,
&result, &buffer[0]));
if (memcmp(buffer.data() + sizeof(uint32), kFeatherMagicBytes,
footer_length - sizeof(uint32)) != 0) {
return errors::InvalidArgument("incomplete feather file");
}

uint32 metadata_length = *reinterpret_cast<const uint32*>(buffer.data());

buffer.resize(metadata_length);

TF_RETURN_IF_ERROR(file_->Read(file_size_ - footer_length - metadata_length,
metadata_length, &result, &buffer[0]));

const ::arrow::ipc::feather::fbs::CTable* table =
::arrow::ipc::feather::fbs::GetCTable(buffer.data());

if (table->version() < ::arrow::ipc::feather::kFeatherV1Version) {
return errors::InvalidArgument("feather file is old: ", table->version(),
" vs. ",
::arrow::ipc::feather::kFeatherV1Version);
std::shared_ptr<arrow::Table> table;
arrow::Status s = reader->Read(&table);
if (!s.ok()) {
return errors::Internal(s.ToString());
}

for (size_t i = 0; i < table->columns()->size(); i++) {
for (int i = 0; i < schema->num_fields(); i++) {
::tensorflow::DataType dtype = ::tensorflow::DataType::DT_INVALID;
switch (table->columns()->Get(i)->values()->type()) {
case ::arrow::ipc::feather::fbs::Type::BOOL:
switch (schema->field(i)->type()->id()) {
case ::arrow::Type::BOOL:
dtype = ::tensorflow::DataType::DT_BOOL;
break;
case ::arrow::ipc::feather::fbs::Type::INT8:
case ::arrow::Type::INT8:
dtype = ::tensorflow::DataType::DT_INT8;
break;
case ::arrow::ipc::feather::fbs::Type::INT16:
case ::arrow::Type::INT16:
dtype = ::tensorflow::DataType::DT_INT16;
break;
case ::arrow::ipc::feather::fbs::Type::INT32:
case ::arrow::Type::INT32:
dtype = ::tensorflow::DataType::DT_INT32;
break;
case ::arrow::ipc::feather::fbs::Type::INT64:
case ::arrow::Type::INT64:
dtype = ::tensorflow::DataType::DT_INT64;
break;
case ::arrow::ipc::feather::fbs::Type::UINT8:
case ::arrow::Type::UINT8:
dtype = ::tensorflow::DataType::DT_UINT8;
break;
case ::arrow::ipc::feather::fbs::Type::UINT16:
case ::arrow::Type::UINT16:
dtype = ::tensorflow::DataType::DT_UINT16;
break;
case ::arrow::ipc::feather::fbs::Type::UINT32:
case ::arrow::Type::UINT32:
dtype = ::tensorflow::DataType::DT_UINT32;
break;
case ::arrow::ipc::feather::fbs::Type::UINT64:
case ::arrow::Type::UINT64:
dtype = ::tensorflow::DataType::DT_UINT64;
break;
case ::arrow::ipc::feather::fbs::Type::FLOAT:
case ::arrow::Type::FLOAT:
dtype = ::tensorflow::DataType::DT_FLOAT;
break;
case ::arrow::ipc::feather::fbs::Type::DOUBLE:
case ::arrow::Type::DOUBLE:
dtype = ::tensorflow::DataType::DT_DOUBLE;
break;
case ::arrow::ipc::feather::fbs::Type::UTF8:
case ::arrow::ipc::feather::fbs::Type::BINARY:
case ::arrow::ipc::feather::fbs::Type::CATEGORY:
case ::arrow::ipc::feather::fbs::Type::TIMESTAMP:
case ::arrow::ipc::feather::fbs::Type::DATE:
case ::arrow::ipc::feather::fbs::Type::TIME:
// case ::arrow::ipc::feather::fbs::Type::LARGE_UTF8:
// case ::arrow::ipc::feather::fbs::Type::LARGE_BINARY:
case ::arrow::Type::BINARY:
dtype = ::tensorflow::DataType::DT_STRING;
break;
default:
break;
}
shapes_.push_back(TensorShape({static_cast<int64>(table->num_rows())}));
dtypes_.push_back(dtype);
columns_.push_back(table->columns()->Get(i)->name()->str());
columns_index_[table->columns()->Get(i)->name()->str()] = i;
columns_.push_back(schema->field(i)->name());
columns_index_[schema->field(i)->name()] = i;
}

return Status::OK();
Expand Down Expand Up @@ -751,6 +722,17 @@ class FeatherReadable : public IOReadableInterface {
FEATHER_PROCESS_TYPE(double,
::arrow::NumericArray<::arrow::DoubleType>);
break;
case DT_STRING: {
int64 curr_index = 0;
for (auto chunk : slice->chunks()) {
for (int64_t item = 0; item < chunk->length(); item++) {
value->flat<tstring>()(curr_index) =
(dynamic_cast<::arrow::BinaryArray*>(chunk.get()))
->GetString(item);
curr_index++;
}
}
} break;
default:
return errors::InvalidArgument("data type is not supported: ",
DataTypeString(value->dtype()));
Expand Down
44 changes: 42 additions & 2 deletions tests/test_feather.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,11 +17,24 @@

import os
import tempfile
import pytest

import tensorflow as tf
import tensorflow_io as tfio


def test_feather_format():
@pytest.mark.parametrize(
("version"),
[
1,
2,
],
ids=[
"v1",
"v2",
],
)
def test_feather_format(version):
"""test_feather_format"""
import numpy as np
import pandas as pd
Expand All @@ -39,7 +52,7 @@ def test_feather_format():
}
df = pd.DataFrame(data).sort_index(axis=1)
with tempfile.NamedTemporaryFile(delete=False) as f:
pa_feather.write_feather(df, f, version=1)
pa_feather.write_feather(df, f, version=version)

feather = tfio.IOTensor.from_feather(f.name)
for column in df.columns:
Expand All @@ -50,5 +63,32 @@ def test_feather_format():
os.unlink(f.name)


def test_binary_feather_format():
    """Round-trip a binary column through a Feather file.

    Reads ``test_image/sample.png`` (located next to this test file) as raw
    bytes, stores it as the single row of a ``data`` column in a pyarrow
    Table, writes the table with ``write_feather(..., chunksize=chunk_size)``
    and checks that ``tfio.IOTensor.from_feather`` exposes that column with
    shape ``[1]``, dtype ``tf.string`` and the same bytes.
    """
    import numpy as np

    from pyarrow import feather as pa_feather
    import pyarrow as pa

    local_path = os.path.join(
        os.path.dirname(os.path.abspath(__file__)), "test_image", "sample.png"
    )
    with open(local_path, "rb") as f:
        data = [f.read()]
    table = pa.Table.from_arrays([data], ["data"])

    chunk_size = 1000
    # delete=False keeps the file on disk after the handle is closed so it can
    # be reopened by path; removal is therefore done explicitly below.
    tmp = tempfile.NamedTemporaryFile(delete=False)
    try:
        with tmp:
            pa_feather.write_feather(table, tmp, chunksize=chunk_size)

        feather = tfio.IOTensor.from_feather(tmp.name)
        assert feather("data").shape == [1]
        assert feather("data").dtype == tf.string
        assert np.all(feather("data").to_tensor().numpy() == data[0])
    finally:
        # Remove the temp file even when the write or an assertion fails.
        os.unlink(tmp.name)


if __name__ == "__main__":
    # The tests here are plain pytest functions and no `test` name is imported
    # in the visible part of this module, so run the file through pytest.
    pytest.main([__file__])