Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Allow std::array as operator argument and return #34399

Closed
wants to merge 4 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
12 changes: 12 additions & 0 deletions aten/src/ATen/core/boxing/impl/make_boxed_from_unboxed_functor.h
Original file line number Diff line number Diff line change
Expand Up @@ -105,6 +105,12 @@ namespace impl {
static_assert(!std::is_same<T, at::Scalar>::value, "You tried to register a kernel with an unsupported input type: ArrayRef<Scalar>. Please use List<int64_t>, List<double> or Tensor instead.");
};

template<class T, size_t N, bool AllowDeprecatedTypes>
struct assert_is_valid_input_type<std::array<T, N>, AllowDeprecatedTypes>
: assert_is_valid_input_type<T, AllowDeprecatedTypes> {
static_assert(!std::is_same<T, at::Scalar>::value, "You tried to register a kernel with an unsupported input type: std::array<Scalar, N>. Please use std::array<int64_t, N> instead.");
};

// The following specialisations of assert_is_valid_input_type are technically not
// necessary since we would hit the base case and show an error message
// there if they didn't exist, but we can show a better error message
Expand Down Expand Up @@ -171,6 +177,12 @@ namespace impl {
// TODO static_assert(AllowDeprecatedTypes, "You tried to register a kernel with an unsupported output type: std::vector<T>. Please use List<T> instead.");
};

template<class T, size_t N, bool AllowDeprecatedTypes>
struct assert_is_valid_output_type<std::array<T, N>, AllowDeprecatedTypes>
: assert_is_valid_output_type<T, AllowDeprecatedTypes> {
static_assert(!std::is_same<T, at::Scalar>::value, "You tried to register a kernel with an unsupported output type: std::array<Scalar, N>. Please use std::array<int64_t, N> instead.");
};

// The following specialisations of assert_is_valid_output_type are technically not
// necessary since we would hit the base case and show an error message
// there if they didn't exist, but we can show a better error message
Expand Down
8 changes: 8 additions & 0 deletions aten/src/ATen/core/boxing/impl/test_helpers.h
Original file line number Diff line number Diff line change
Expand Up @@ -73,6 +73,14 @@ inline void expectThrows(Functor&& functor, const char* expectMessageContains) {
<< expectMessageContains << "\" but didn't throw";
}

template<class T, size_t N>
void expectListEquals(c10::ArrayRef<T> expected, std::array<T, N> actual) {
EXPECT_EQ(expected.size(), actual.size());
for (size_t i = 0; i < expected.size(); ++i) {
EXPECT_EQ(expected[i], actual[i]);
}
}

template<class T>
void expectListEquals(c10::ArrayRef<T> expected, c10::ArrayRef<T> actual) {
EXPECT_EQ(expected.size(), actual.size());
Expand Down
2 changes: 2 additions & 0 deletions aten/src/ATen/core/ivalue.h
Original file line number Diff line number Diff line change
Expand Up @@ -400,6 +400,8 @@ struct CAFFE2_API IValue final {
class T,
enable_if_ivalue_constructible<T> = nullptr>
IValue(const std::vector<T>& v);
template<class T, size_t N>
IValue(std::array<T, N> v);

// GenericDict
IValue(c10::Dict<IValue, IValue> v);
Expand Down
30 changes: 30 additions & 0 deletions aten/src/ATen/core/ivalue_inl.h
Original file line number Diff line number Diff line change
Expand Up @@ -597,6 +597,28 @@ c10::List<Elem> generic_to(
return impl::toTypedList<Elem>(std::move(ivalue).toList());
}

namespace detail {
template <typename Elem, size_t... I>
std::array<Elem, sizeof...(I)> generic_to_array(
IValue ivalue,
_fake_type<std::array<Elem, sizeof...(I)>>,
std::index_sequence<I...>) {
// We need to do a deep copy of the array because there might be other
// references to this same IValue that also use the list. We can't just
// move the elements out.
auto list = std::move(ivalue).to<List<Elem>>();
TORCH_CHECK(list.size() == sizeof...(I), "Tried to convert a List with ", list.size()," elements to a fixed-size array of size ", sizeof...(I));
return {list[I]...};
}
}

template <typename Elem, size_t N>
std::array<Elem, N> generic_to(
IValue ivalue,
_fake_type<std::array<Elem, N>> ft) {
return detail::generic_to_array(ivalue, ft, std::make_index_sequence<N>());
}

template <typename Key, typename Value>
c10::Dict<Key, Value> generic_to(
IValue ivalue,
Expand Down Expand Up @@ -793,6 +815,14 @@ inline IValue::IValue(const std::vector<T>& v) : IValue(c10::List<T>()) {
list.push_back(e);
}
}
template<class T, size_t N> inline IValue::IValue(std::array<T, N> v)
: IValue(c10::List<T>()) {
auto list = to<c10::List<T>>();
list.reserve(v.size());
for (auto& e : v) {
list.push_back(std::move(e));
}
}

inline IValue::IValue(c10::impl::GenericDict v)
: tag(Tag::GenericDict), is_intrusive_ptr(true) {
Expand Down
54 changes: 54 additions & 0 deletions aten/src/ATen/core/op_registration/op_registration_test.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1173,6 +1173,60 @@ TEST(OperatorRegistrationTest, testAvailableArgTypes) {
"(Tensor[] a) -> Tensor[]");


// std::array list types (with empty list)
testArgTypes<std::array<double, 0>>::test(
std::array<double, 0>(), [] (std::array<double, 0> v) {},
std::array<double, 0>(), [] (const IValue& v) {EXPECT_EQ(0, (v.to<c10::List<double>>().size()));},
"(float[0] a) -> float[0]");
testArgTypes<std::array<int64_t, 0>>::test(
std::array<int64_t, 0>(), [] (std::array<int64_t, 0> v) {},
std::array<int64_t, 0>(), [] (const IValue& v) {EXPECT_EQ(0, (v.to<c10::List<int64_t>>().size()));},
"(int[0] a) -> int[0]");
testArgTypes<std::array<bool, 0>>::test(
std::array<bool, 0>(), [] (std::array<bool, 0> v) {},
std::array<bool, 0>(), [] (const IValue& v) {EXPECT_EQ(0, (v.to<std::array<bool, 0>>().size()));},
"(bool[0] a) -> bool[0]");
testArgTypes<std::array<std::string, 0>>::test(
std::array<std::string, 0>(), [] (std::array<std::string, 0> v) {EXPECT_EQ(0, v.size());},
std::array<std::string, 0>(), [] (const IValue& v) {EXPECT_EQ(0, v.toListRef().size());},
"(str[0] a) -> str[0]");


// std::array list types (with non-empty list)
testArgTypes<std::array<double, 2>>::test(
std::array<double, 2>({1.5, 2.5}), [] (std::array<double, 2> v) {expectListEquals({1.5, 2.5}, v);},
std::array<double, 2>({3.5, 4.5}), [] (const IValue& v) {expectListEquals({3.5, 4.5}, v.to<std::array<double, 2>>());},
"(float[2] a) -> float[2]");
testArgTypes<std::array<int64_t, 2>>::test(
std::array<int64_t, 2>({1, 2}), [] (std::array<int64_t, 2> v) {expectListEquals({1, 2}, v);},
std::array<int64_t, 2>({3, 4}), [] (const IValue& v) {expectListEquals({3, 4}, v.to<std::array<int64_t, 2>>());},
"(int[2] a) -> int[2]");
testArgTypes<std::array<bool, 2>>::test(
std::array<bool, 2>({true, false}), [] (std::array<bool, 2> v) {expectListEquals({true, false}, v);},
std::array<bool, 2>({true, false}), [] (const IValue& v) {expectListEquals({true, false}, v.to<std::array<bool, 2>>());},
"(bool[2] a) -> bool[2]");
testArgTypes<std::array<std::string, 2>>::test(
std::array<std::string, 2>({"first", "second"}), [] (std::array<std::string, 2> v) {expectListEquals({"first", "second"}, v);},
std::array<std::string, 2>({"first", "second"}), [] (const IValue& v) {
EXPECT_EQ(2, v.toListRef().size());
EXPECT_EQ("first", v.toListRef()[0].toStringRef());
EXPECT_EQ("second", v.toListRef()[1].toStringRef());
},
"(str[2] a) -> str[2]");
testArgTypes<std::array<Tensor, 2>>::test(
std::array<Tensor, 2>({dummyTensor(c10::DispatchKey::CPUTensorId), dummyTensor(c10::DispatchKey::CUDATensorId)}), [] (std::array<Tensor, 2> v) {
EXPECT_EQ(2, v.size());
EXPECT_EQ(c10::DispatchKey::CPUTensorId, extractDispatchKey(v[0]));
EXPECT_EQ(c10::DispatchKey::CUDATensorId, extractDispatchKey(v[1]));
},
std::array<Tensor, 2>({dummyTensor(c10::DispatchKey::CUDATensorId), dummyTensor(c10::DispatchKey::CPUTensorId)}), [] (const IValue& v) {
EXPECT_EQ(2, v.to<c10::List<at::Tensor>>().size());
EXPECT_EQ(c10::DispatchKey::CUDATensorId, extractDispatchKey(v.to<c10::List<at::Tensor>>().get(0)));
EXPECT_EQ(c10::DispatchKey::CPUTensorId, extractDispatchKey(v.to<c10::List<at::Tensor>>().get(1)));
},
"(Tensor[2] a) -> Tensor[2]");


// deprecated list types (with empty list)
testArgTypes<std::vector<double>>::test<TestLegacyAPI>(
std::vector<double>(), [] (const std::vector<double>& v) {EXPECT_EQ(0, v.size());},
Expand Down