Skip to content

Commit

Permalink
Allow std::array as operator argument and return
Browse files Browse the repository at this point in the history
Pull Request resolved: #34399

Custom ops can now take std::array as arguments and return it.
This PR also moves the ops in native_functions.yaml that were blocked by this to now `use_c10_dispatcher: full`.
ghstack-source-id: 99833901

Differential Revision: [D20315072](https://our.internmc.facebook.com/intern/diff/D20315072/)
  • Loading branch information
smessmer committed Mar 10, 2020
1 parent 8a1f95b commit 8541db5
Show file tree
Hide file tree
Showing 6 changed files with 189 additions and 0 deletions.
12 changes: 12 additions & 0 deletions aten/src/ATen/core/boxing/kernel_functor.h
Expand Up @@ -105,6 +105,12 @@ namespace detail {
static_assert(!std::is_same<T, at::Scalar>::value, "You tried to register a kernel with an unsupported input type: ArrayRef<Scalar>. Please use List<int64_t>, List<double> or Tensor instead.");
};

template<class T, size_t N, bool AllowDeprecatedTypes>
struct assert_is_valid_input_type<std::array<T, N>, AllowDeprecatedTypes>
: assert_is_valid_input_type<T, AllowDeprecatedTypes> {
static_assert(!std::is_same<T, at::Scalar>::value, "You tried to register a kernel with an unsupported input type: std::array<Scalar, N>. Please use std::array<int64_t, N> instead.");
};

// The following specialisations of assert_is_valid_input_type are technically not
// necessary since we would hit the base case and show an error message
// there if they didn't exist, but we can show a better error message
Expand Down Expand Up @@ -171,6 +177,12 @@ namespace detail {
// TODO static_assert(AllowDeprecatedTypes, "You tried to register a kernel with an unsupported output type: std::vector<T>. Please use List<T> instead.");
};

template<class T, size_t N, bool AllowDeprecatedTypes>
struct assert_is_valid_output_type<std::array<T, N>, AllowDeprecatedTypes>
: assert_is_valid_output_type<T, AllowDeprecatedTypes> {
static_assert(!std::is_same<T, at::Scalar>::value, "You tried to register a kernel with an unsupported output type: std::array<Scalar, N>. Please use std::array<int64_t, N> instead.");
};

// The following specialisations of assert_is_valid_output_type are technically not
// necessary since we would hit the base case and show an error message
// there if they didn't exist, but we can show a better error message
Expand Down
8 changes: 8 additions & 0 deletions aten/src/ATen/core/boxing/test_helpers.h
Expand Up @@ -97,6 +97,14 @@ void expectListEquals(c10::ArrayRef<T> expected, std::vector<T> actual) {
}
}

template<class T, size_t N>
void expectListEquals(c10::ArrayRef<T> expected, std::array<T, N> actual) {
EXPECT_EQ(expected.size(), actual.size());
for (size_t i = 0; i < expected.size(); ++i) {
EXPECT_EQ(expected[i], actual[i]);
}
}

// NB: This is not really sound, but all of the type sets constructed here
// are singletons so it's fine
static inline c10::DispatchKey extractDispatchKey(const at::Tensor& t) {
Expand Down
2 changes: 2 additions & 0 deletions aten/src/ATen/core/ivalue.h
Expand Up @@ -329,6 +329,8 @@ struct CAFFE2_API IValue final {
IValue(at::ArrayRef<T> v);
template<class T>
IValue(const std::vector<T>& v);
template<class T, size_t N>
IValue(std::array<T, N> v);

// GenericDict
IValue(c10::Dict<IValue, IValue> v);
Expand Down
30 changes: 30 additions & 0 deletions aten/src/ATen/core/ivalue_inl.h
Expand Up @@ -547,6 +547,28 @@ c10::List<Elem> generic_to(
return impl::toTypedList<Elem>(std::move(ivalue).toList());
}

namespace detail {
template <typename Elem, size_t... I>
std::array<Elem, sizeof...(I)> generic_to_array(
IValue ivalue,
_fake_type<std::array<Elem, sizeof...(I)>>,
std::index_sequence<I...>) {
// We need to do a deep copy of the array because there might be other
// references to this same IValue that also use the list. We can't just
// move the elements out.
auto list = std::move(ivalue).to<List<Elem>>();
TORCH_CHECK(list.size() == sizeof...(I), "Tried to convert a List with ", list.size()," elements to a fixed-size array of size ", sizeof...(I));
return {list[I]...};
}
}

template <typename Elem, size_t N>
std::array<Elem, N> generic_to(
IValue ivalue,
_fake_type<std::array<Elem, N>> ft) {
return detail::generic_to_array(ivalue, ft, std::make_index_sequence<N>());
}

template <typename Key, typename Value>
c10::Dict<Key, Value> generic_to(
IValue ivalue,
Expand Down Expand Up @@ -742,6 +764,14 @@ template<class T> inline IValue::IValue(const std::vector<T>& v)
list.push_back(e);
}
}
template<class T, size_t N> inline IValue::IValue(std::array<T, N> v)
: IValue(c10::List<T>()) {
auto list = to<c10::List<T>>();
list.reserve(v.size());
for (auto& e : v) {
list.push_back(std::move(e));
}
}

inline IValue::IValue(c10::impl::GenericDict v)
: tag(Tag::GenericDict), is_intrusive_ptr(true) {
Expand Down
54 changes: 54 additions & 0 deletions aten/src/ATen/core/op_registration/op_registration_test.cpp
Expand Up @@ -1169,6 +1169,60 @@ TEST(OperatorRegistrationTest, testAvailableArgTypes) {
"(Tensor[] a) -> Tensor[]");


// std::array list types (with empty list)
testArgTypes<std::array<double, 0>>::test(
std::array<double, 0>(), [] (std::array<double, 0> v) {},
std::array<double, 0>(), [] (const IValue& v) {EXPECT_EQ(0, (v.to<c10::List<double>>().size()));},
"(float[0] a) -> float[0]");
testArgTypes<std::array<int64_t, 0>>::test(
std::array<int64_t, 0>(), [] (std::array<int64_t, 0> v) {},
std::array<int64_t, 0>(), [] (const IValue& v) {EXPECT_EQ(0, (v.to<c10::List<int64_t>>().size()));},
"(int[0] a) -> int[0]");
testArgTypes<std::array<bool, 0>>::test(
std::array<bool, 0>(), [] (std::array<bool, 0> v) {},
std::array<bool, 0>(), [] (const IValue& v) {EXPECT_EQ(0, (v.to<std::array<bool, 0>>().size()));},
"(bool[0] a) -> bool[0]");
testArgTypes<std::array<std::string, 0>>::test(
std::array<std::string, 0>(), [] (std::array<std::string, 0> v) {EXPECT_EQ(0, v.size());},
std::array<std::string, 0>(), [] (const IValue& v) {EXPECT_EQ(0, v.toListRef().size());},
"(str[0] a) -> str[0]");


// std::array list types (with non-empty list)
testArgTypes<std::array<double, 2>>::test(
std::array<double, 2>({1.5, 2.5}), [] (std::array<double, 2> v) {expectListEquals({1.5, 2.5}, v);},
std::array<double, 2>({3.5, 4.5}), [] (const IValue& v) {expectListEquals({3.5, 4.5}, v.to<std::array<double, 2>>());},
"(float[2] a) -> float[2]");
testArgTypes<std::array<int64_t, 2>>::test(
std::array<int64_t, 2>({1, 2}), [] (std::array<int64_t, 2> v) {expectListEquals({1, 2}, v);},
std::array<int64_t, 2>({3, 4}), [] (const IValue& v) {expectListEquals({3, 4}, v.to<std::array<int64_t, 2>>());},
"(int[2] a) -> int[2]");
testArgTypes<std::array<bool, 2>>::test(
std::array<bool, 2>({true, false}), [] (std::array<bool, 2> v) {expectListEquals({true, false}, v);},
std::array<bool, 2>({true, false}), [] (const IValue& v) {expectListEquals({true, false}, v.to<std::array<bool, 2>>());},
"(bool[2] a) -> bool[2]");
testArgTypes<std::array<std::string, 2>>::test(
std::array<std::string, 2>({"first", "second"}), [] (std::array<std::string, 2> v) {expectListEquals({"first", "second"}, v);},
std::array<std::string, 2>({"first", "second"}), [] (const IValue& v) {
EXPECT_EQ(2, v.toListRef().size());
EXPECT_EQ("first", v.toListRef()[0].toStringRef());
EXPECT_EQ("second", v.toListRef()[1].toStringRef());
},
"(str[2] a) -> str[2]");
testArgTypes<std::array<Tensor, 2>>::test(
std::array<Tensor, 2>({dummyTensor(c10::DispatchKey::CPUTensorId), dummyTensor(c10::DispatchKey::CUDATensorId)}), [] (std::array<Tensor, 2> v) {
EXPECT_EQ(2, v.size());
EXPECT_EQ(c10::DispatchKey::CPUTensorId, extractDispatchKey(v[0]));
EXPECT_EQ(c10::DispatchKey::CUDATensorId, extractDispatchKey(v[1]));
},
std::array<Tensor, 2>({dummyTensor(c10::DispatchKey::CUDATensorId), dummyTensor(c10::DispatchKey::CPUTensorId)}), [] (const IValue& v) {
EXPECT_EQ(2, v.to<c10::List<at::Tensor>>().size());
EXPECT_EQ(c10::DispatchKey::CUDATensorId, extractDispatchKey(v.to<c10::List<at::Tensor>>().get(0)));
EXPECT_EQ(c10::DispatchKey::CPUTensorId, extractDispatchKey(v.to<c10::List<at::Tensor>>().get(1)));
},
"(Tensor[2] a) -> Tensor[2]");


// deprecated list types (with empty list)
testArgTypes<std::vector<double>>::test<TestLegacyAPI>(
std::vector<double>(), [] (const std::vector<double>& v) {EXPECT_EQ(0, v.size());},
Expand Down

0 comments on commit 8541db5

Please sign in to comment.