From 193ecc38e2d3d99dce622b5a5ebc04f27aba5cab Mon Sep 17 00:00:00 2001 From: Davide Libenzi Date: Wed, 19 Feb 2020 12:42:13 -0800 Subject: [PATCH] Avoid copying TF string into std::string. --- third_party/xla_client/record_reader.cc | 9 ++------- third_party/xla_client/record_reader.h | 4 +++- torch_xla/csrc/init_python_bindings.cpp | 10 +++++----- 3 files changed, 10 insertions(+), 13 deletions(-) diff --git a/third_party/xla_client/record_reader.cc b/third_party/xla_client/record_reader.cc index b4689aaa374b..bedcb87330a1 100644 --- a/third_party/xla_client/record_reader.cc +++ b/third_party/xla_client/record_reader.cc @@ -20,18 +20,13 @@ RecordReader::RecordReader(std::string path, const string& compression, reader_.reset(new tensorflow::io::RecordReader(file_.get(), options)); } -bool RecordReader::Read(std::string* value) { - // We need to pass a tensorflow::tstring here, which will ultimately result in - // making a copy. Hopefully the tensorflow string story will end with a nice - // outcome. 
- tensorflow::tstring tvalue; +bool RecordReader::Read(Data* value) { std::lock_guard slock(lock_); - xla::Status status = reader_->ReadRecord(&offset_, &tvalue); + xla::Status status = reader_->ReadRecord(&offset_, value); if (tensorflow::errors::IsOutOfRange(status)) { return false; } XLA_CHECK_OK(status) << path_ << " offset " << offset_; - *value = tvalue; return true; } diff --git a/third_party/xla_client/record_reader.h b/third_party/xla_client/record_reader.h index fe828833b7c0..e931a89cbed6 100644 --- a/third_party/xla_client/record_reader.h +++ b/third_party/xla_client/record_reader.h @@ -13,12 +13,14 @@ namespace util { class RecordReader { public: + using Data = tensorflow::tstring; + RecordReader(std::string path, const std::string& compression, int64 buffer_size); const std::string& path() const { return path_; } - bool Read(std::string* value); + bool Read(Data* value); private: std::string path_; diff --git a/torch_xla/csrc/init_python_bindings.cpp b/torch_xla/csrc/init_python_bindings.cpp index 0fa48e5a5849..7d7b34e5a59f 100644 --- a/torch_xla/csrc/init_python_bindings.cpp +++ b/torch_xla/csrc/init_python_bindings.cpp @@ -275,7 +275,7 @@ std::shared_ptr CreateRecordReader( } bool RecordRead(const std::shared_ptr& reader, - std::string* value) { + xla::util::RecordReader::Data* value) { NoGilSection nogil; return reader->Read(value); } @@ -286,12 +286,12 @@ py::object RecordReadExample( return std::vector({size}); }; - std::string value; + xla::util::RecordReader::Data value; if (!RecordRead(reader, &value)) { return py::none(); } tensorflow::Example exmsg; - if (!exmsg.ParseFromString(value)) { + if (!exmsg.ParseFromArray(value.data(), value.size())) { XLA_ERROR() << "Unable to parse TF example from " << reader->path(); } auto example = py::dict(); @@ -592,11 +592,11 @@ void InitXlaModuleBindings(py::module m) { m.def( "_xla_tfrecord_read", [](const std::shared_ptr& reader) -> py::object { - std::string record; + xla::util::RecordReader::Data 
record; if (!RecordRead(reader, &record)) { return py::none(); } - return py::bytes(record); + return py::bytes(record.data(), record.size()); }); m.def("_xla_tfexample_read", [](const std::shared_ptr& reader) {