Skip to content

Commit

Permalink
List[index]::toOptionalStringRef (#42263)
Browse files Browse the repository at this point in the history
Summary:
Pull Request resolved: #42263

Allow a way to get a reference to the stored string in a `List<optional<string>>` without having to copy the string.
This for example improves perf of the map_lookup op by 3x.
ghstack-source-id: 109162026

Test Plan: unit tests

Reviewed By: ezyang

Differential Revision: D22830381

fbshipit-source-id: e6af2bc8cebd6e68794eb18daf183979bc6297ae
  • Loading branch information
smessmer authored and facebook-github-bot committed Aug 6, 2020
1 parent f22aa60 commit b44a10c
Show file tree
Hide file tree
Showing 4 changed files with 21 additions and 0 deletions.
5 changes: 5 additions & 0 deletions aten/src/ATen/core/List.h
Original file line number Diff line number Diff line change
Expand Up @@ -71,6 +71,11 @@ class ListElementReference final {
return iterator_->toStringRef();
}

template<class _T = T>
std::enable_if_t<std::is_same<c10::optional<std::string>, T>::value && std::is_same<_T, T>::value, c10::optional<std::reference_wrapper<const std::string>>> toOptionalStringRef() {
return iterator_->toOptionalStringRef();
}

friend void swap<T, Iterator>(ListElementReference&& lhs, ListElementReference&& rhs);

private:
Expand Down
8 changes: 8 additions & 0 deletions aten/src/ATen/core/List_test.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1088,3 +1088,11 @@ TEST(ListTest, canAccessStringByReference) {
const std::string& str = list[1].toStringRef();
EXPECT_EQ("two", str);
}

TEST(ListTest, canAccessOptionalStringByReference) {
List<c10::optional<std::string>> list({"one", "two", c10::nullopt});
c10::optional<std::reference_wrapper<const std::string>> str1 = list[1].toOptionalStringRef();
c10::optional<std::reference_wrapper<const std::string>> str2 = list[2].toOptionalStringRef();
EXPECT_EQ("two", str1.value().get());
EXPECT_FALSE(str2.has_value());
}
1 change: 1 addition & 0 deletions aten/src/ATen/core/ivalue.h
Original file line number Diff line number Diff line change
Expand Up @@ -404,6 +404,7 @@ struct CAFFE2_API IValue final {
c10::intrusive_ptr<ivalue::ConstantString> toString() &&;
c10::intrusive_ptr<ivalue::ConstantString> toString() const &;
const std::string& toStringRef() const;
c10::optional<std::reference_wrapper<const std::string>> toOptionalStringRef() const;

// DoubleList
bool isDoubleList() const;
Expand Down
7 changes: 7 additions & 0 deletions aten/src/ATen/core/ivalue_inl.h
Original file line number Diff line number Diff line change
Expand Up @@ -1041,6 +1041,13 @@ inline const std::string& IValue::toStringRef() const {
AT_ASSERT(isString(), "Expected String but got ", tagKind());
return static_cast<const c10::ivalue::ConstantString*>(payload.as_intrusive_ptr)->string();
}
inline c10::optional<std::reference_wrapper<const std::string>> IValue::toOptionalStringRef() const {
if (isNone()) {
return c10::nullopt;
}
AT_ASSERT(isString(), "Expected optional<string> but got ", tagKind());
return std::reference_wrapper<const std::string>(static_cast<const c10::ivalue::ConstantString*>(payload.as_intrusive_ptr)->string());
}

inline PyObject* IValue::toPyObject() const {
return toPyObjectHolder()->getPyObject();
Expand Down

0 comments on commit b44a10c

Please sign in to comment.