Skip to content

Commit

Permalink
Add std::variant backport as c10::variant (pytorch#26836)
Browse files Browse the repository at this point in the history
Summary:
Pull Request resolved: pytorch#26836

* **pytorch#26836 Add std::variant backport as c10::variant**

Test Plan: Imported from OSS

Differential Revision: D17649064

Pulled By: yf225

fbshipit-source-id: aa5ee26fe7078cc66d03663b9ff9e998e1d5839a
  • Loading branch information
Will Feng authored and facebook-github-bot committed Sep 28, 2019
1 parent cca3a36 commit 0cd1880
Show file tree
Hide file tree
Showing 3 changed files with 2,865 additions and 1 deletion.
3 changes: 2 additions & 1 deletion aten/src/ATen/test/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,8 @@ list(APPEND ATen_CPU_TEST_SRCS
${CMAKE_CURRENT_SOURCE_DIR}/xla_tensor_test.cpp
${CMAKE_CURRENT_SOURCE_DIR}/tensor_iterator_test.cpp
${CMAKE_CURRENT_SOURCE_DIR}/cpu_generator_test.cpp
${CMAKE_CURRENT_SOURCE_DIR}/pow_test.cpp)
${CMAKE_CURRENT_SOURCE_DIR}/pow_test.cpp
${CMAKE_CURRENT_SOURCE_DIR}/variant_test.cpp)

list(APPEND ATen_CUDA_TEST_SRCS
${CMAKE_CURRENT_SOURCE_DIR}/cuda_integer_divider_test.cu
Expand Down
41 changes: 41 additions & 0 deletions aten/src/ATen/test/variant_test.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,41 @@
#include <gtest/gtest.h>

#include <c10/util/variant.h>

namespace testns {

namespace enumtype {
// NOTE: We need to provide the default constructor for each struct,
// otherwise Clang 3.8 would complain:
// ```
// error: default initialization of an object of const type 'const enumtype::Enum1'
// without a user-provided default constructor
// ```
struct Enum1 { Enum1() {} };
struct Enum2 { Enum2() {} };
struct Enum3 { Enum3() {} };
} // namespace enumtype

const enumtype::Enum1 kEnum1;
const enumtype::Enum2 kEnum2;
const enumtype::Enum3 kEnum3;

} // namespace testns

std::string func(c10::variant<testns::enumtype::Enum1, testns::enumtype::Enum2, testns::enumtype::Enum3> v) {
if (c10::get_if<testns::enumtype::Enum1>(&v)) {
return "Enum1";
} else if (c10::get_if<testns::enumtype::Enum2>(&v)) {
return "Enum2";
} else if (c10::get_if<testns::enumtype::Enum3>(&v)) {
return "Enum3";
} else {
return "Unsupported enum";
}
}

TEST(VariantTest, Basic) {
ASSERT_EQ(func(testns::kEnum1), "Enum1");
ASSERT_EQ(func(testns::kEnum2), "Enum2");
ASSERT_EQ(func(testns::kEnum3), "Enum3");
}

0 comments on commit 0cd1880

Please sign in to comment.