
Commit 8a30179

Move MemoryFormat/Layout to headeronly
[ghstack-poisoned]
1 parent 0b4d592 commit 8a30179

File tree: 13 files changed (+484, -123 lines)

c10/core/Layout.h

Lines changed: 1 addition & 46 deletions

@@ -3,30 +3,9 @@
 #include <c10/core/Backend.h>
 #include <c10/util/Exception.h>
 
-#include <cstdint>
-#include <ostream>
+#include <torch/headeronly/core/Layout.h>
 
 namespace c10 {
-enum class Layout : int8_t {
-  Strided,
-  Sparse,
-  SparseCsr,
-  Mkldnn,
-  SparseCsc,
-  SparseBsr,
-  SparseBsc,
-  Jagged,
-  NumOptions
-};
-
-constexpr auto kStrided = Layout::Strided;
-constexpr auto kSparse = Layout::Sparse;
-constexpr auto kSparseCsr = Layout::SparseCsr;
-constexpr auto kMkldnn = Layout::Mkldnn;
-constexpr auto kSparseCsc = Layout::SparseCsc;
-constexpr auto kSparseBsr = Layout::SparseBsr;
-constexpr auto kSparseBsc = Layout::SparseBsc;
-constexpr auto kJagged = Layout::Jagged;
 
 inline Layout layout_from_backend(Backend backend) {
   C10_DIAGNOSTIC_PUSH_AND_IGNORED_IF_DEFINED("-Wswitch-enum")
@@ -56,28 +35,4 @@ inline Layout layout_from_backend(Backend backend) {
   C10_DIAGNOSTIC_POP()
 }
 
-inline std::ostream& operator<<(std::ostream& stream, at::Layout layout) {
-  switch (layout) {
-    case at::kStrided:
-      return stream << "Strided";
-    case at::kSparse:
-      return stream << "Sparse";
-    case at::kSparseCsr:
-      return stream << "SparseCsr";
-    case at::kSparseCsc:
-      return stream << "SparseCsc";
-    case at::kSparseBsr:
-      return stream << "SparseBsr";
-    case at::kSparseBsc:
-      return stream << "SparseBsc";
-    case at::kMkldnn:
-      return stream << "Mkldnn";
-    case at::kJagged:
-      return stream << "Jagged";
-    case Layout::NumOptions:
-    default:
-      TORCH_CHECK(false, "Unknown layout");
-  }
-}
-
 } // namespace c10
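Everything removed here now lives in torch/headeronly/core/Layout.h under the torch::headeronly namespace, as the new test further down exercises. A minimal consumer sketch, assuming only that header on the include path (the main() harness is illustrative, not part of this commit):

#include <torch/headeronly/core/Layout.h>

#include <iostream>

int main() {
  using torch::headeronly::Layout;
  // The k* aliases stay constexpr, so they remain usable at compile time.
  static_assert(torch::headeronly::kStrided == Layout::Strided);
  // operator<< is header-only as well; streaming needs no libtorch symbols.
  std::cout << Layout::SparseCsr << '\n'; // prints "SparseCsr"
  return 0;
}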

c10/core/MemoryFormat.h

Lines changed: 2 additions & 48 deletions

@@ -3,64 +3,18 @@
 #include <c10/util/ArrayRef.h>
 #include <c10/util/Exception.h>
 
+#include <torch/headeronly/core/MemoryFormat.h>
+
 #include <cstdint>
-#include <ostream>
 #include <vector>
 
-// Memory format is not the property of a Tensor. It is the way to tell an
-// operator how the result should be organized in memory and nothing more. That
-// means memory format should never be used as return value for any tensor state
-// interrogation functions (internally and externally).
-//
-// Possible options are:
-// Preserve:
-//   If any of the input tensors is in channels_last format, operator output
-//   should be in channels_last format
-//
-// Contiguous:
-//   Regardless of input tensors format, the output should be contiguous
-//   Tensor.
-//
-// ChannelsLast:
-//   Regardless of input tensors format, the output should be in channels_last
-//   format.
-
 namespace c10 {
-enum class MemoryFormat : int8_t {
-  Contiguous,
-  Preserve,
-  ChannelsLast,
-  ChannelsLast3d,
-  NumOptions
-};
 
 // If you are seeing this, it means that this call site was not checked if
 // the memory format could be preserved, and it was switched to old default
 // behaviour of contiguous
 #define LEGACY_CONTIGUOUS_MEMORY_FORMAT c10::get_contiguous_memory_format()
 
-inline MemoryFormat get_contiguous_memory_format() {
-  return MemoryFormat::Contiguous;
-}
-
-inline std::ostream& operator<<(
-    std::ostream& stream,
-    at::MemoryFormat memory_format) {
-  switch (memory_format) {
-    case MemoryFormat::Preserve:
-      return stream << "Preserve";
-    case MemoryFormat::Contiguous:
-      return stream << "Contiguous";
-    case MemoryFormat::ChannelsLast:
-      return stream << "ChannelsLast";
-    case MemoryFormat::ChannelsLast3d:
-      return stream << "ChannelsLast3d";
-    case MemoryFormat::NumOptions:
-    default:
-      TORCH_CHECK(false, "Unknown memory format ", memory_format);
-  }
-}
-
 // Note: Hardcoded the channel last stride indices here to get better
 // performance
 template <typename T>
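As with Layout, the enum, get_contiguous_memory_format(), and the stream operator now come from torch/headeronly/core/MemoryFormat.h; the c10 header keeps only the LEGACY_CONTIGUOUS_MEMORY_FORMAT macro and the channels-last stride helpers. A minimal header-only usage sketch (names match the new test further down; the main() harness is illustrative):

#include <torch/headeronly/core/MemoryFormat.h>

#include <iostream>

int main() {
  using torch::headeronly::MemoryFormat;
  // Header-only default, matching the old c10 behavior: Contiguous.
  MemoryFormat fmt = torch::headeronly::get_contiguous_memory_format();
  std::cout << fmt << '\n';                          // prints "Contiguous"
  std::cout << MemoryFormat::ChannelsLast3d << '\n'; // prints "ChannelsLast3d"
  return 0;
}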

test/cpp/aoti_abi_check/CMakeLists.txt

Lines changed: 2 additions & 0 deletions

@@ -16,8 +16,10 @@ set(AOTI_ABI_CHECK_TEST_SRCS
   ${AOTI_ABI_CHECK_TEST_ROOT}/test_dtype.cpp
   ${AOTI_ABI_CHECK_TEST_ROOT}/test_exception.cpp
   ${AOTI_ABI_CHECK_TEST_ROOT}/test_headeronlyarrayref.cpp
+  ${AOTI_ABI_CHECK_TEST_ROOT}/test_layout.cpp
   ${AOTI_ABI_CHECK_TEST_ROOT}/test_macros.cpp
   ${AOTI_ABI_CHECK_TEST_ROOT}/test_math.cpp
+  ${AOTI_ABI_CHECK_TEST_ROOT}/test_memoryformat.cpp
   ${AOTI_ABI_CHECK_TEST_ROOT}/test_metaprogramming.cpp
   ${AOTI_ABI_CHECK_TEST_ROOT}/test_rand.cpp
   ${AOTI_ABI_CHECK_TEST_ROOT}/test_scalartype.cpp

test/cpp/aoti_abi_check/test_layout.cpp

Lines changed: 65 additions & 0 deletions

@@ -0,0 +1,65 @@
+#include <gtest/gtest.h>
+
+#include <torch/headeronly/core/Layout.h>
+
+TEST(TestLayout, TestLayout) {
+  using torch::headeronly::Layout;
+  constexpr Layout expected_layouts[] = {
+      torch::headeronly::kStrided,
+      torch::headeronly::kSparse,
+      torch::headeronly::kSparseCsr,
+      torch::headeronly::kMkldnn,
+      torch::headeronly::kSparseCsc,
+      torch::headeronly::kSparseBsr,
+      torch::headeronly::kSparseBsc,
+      torch::headeronly::kJagged,
+  };
+  for (int8_t i = 0; i < static_cast<int8_t>(Layout::NumOptions); i++) {
+    EXPECT_EQ(static_cast<Layout>(i), expected_layouts[i]);
+  }
+}
+
+TEST(TestLayout, operator_left_shift) {
+  using torch::headeronly::Layout;
+
+  {
+    std::stringstream ss;
+    ss << Layout::Strided;
+    EXPECT_EQ(ss.str(), "Strided");
+  }
+  {
+    std::stringstream ss;
+    ss << Layout::Sparse;
+    EXPECT_EQ(ss.str(), "Sparse");
+  }
+  {
+    std::stringstream ss;
+    ss << Layout::SparseCsr;
+    EXPECT_EQ(ss.str(), "SparseCsr");
+  }
+  {
+    std::stringstream ss;
+    ss << Layout::SparseCsc;
+    EXPECT_EQ(ss.str(), "SparseCsc");
+  }
+  {
+    std::stringstream ss;
+    ss << Layout::SparseBsr;
+    EXPECT_EQ(ss.str(), "SparseBsr");
+  }
+  {
+    std::stringstream ss;
+    ss << Layout::SparseBsc;
+    EXPECT_EQ(ss.str(), "SparseBsc");
+  }
+  {
+    std::stringstream ss;
+    ss << Layout::Mkldnn;
+    EXPECT_EQ(ss.str(), "Mkldnn");
+  }
+  {
+    std::stringstream ss;
+    ss << Layout::Jagged;
+    EXPECT_EQ(ss.str(), "Jagged");
+  }
+}

test/cpp/aoti_abi_check/test_memoryformat.cpp

Lines changed: 48 additions & 0 deletions

@@ -0,0 +1,48 @@
+#include <gtest/gtest.h>
+
+#include <torch/headeronly/core/MemoryFormat.h>
+
+TEST(TestMemoryFormat, TestMemoryFormat) {
+  using torch::headeronly::MemoryFormat;
+  constexpr MemoryFormat expected_memory_formats[] = {
+      MemoryFormat::Contiguous,
+      MemoryFormat::Preserve,
+      MemoryFormat::ChannelsLast,
+      MemoryFormat::ChannelsLast3d,
+  };
+  for (int8_t i = 0; i < static_cast<int8_t>(MemoryFormat::NumOptions); i++) {
+    EXPECT_EQ(static_cast<MemoryFormat>(i), expected_memory_formats[i]);
+  }
+}
+
+TEST(TestMemoryFormat, get_contiguous_memory_format) {
+  using torch::headeronly::get_contiguous_memory_format;
+  using torch::headeronly::MemoryFormat;
+
+  EXPECT_EQ(get_contiguous_memory_format(), MemoryFormat::Contiguous);
+}
+
+TEST(TestMemoryFormat, operator_left_shift) {
+  using torch::headeronly::MemoryFormat;
+
+  {
+    std::stringstream ss;
+    ss << MemoryFormat::Preserve;
+    EXPECT_EQ(ss.str(), "Preserve");
+  }
+  {
+    std::stringstream ss;
+    ss << MemoryFormat::Contiguous;
+    EXPECT_EQ(ss.str(), "Contiguous");
+  }
+  {
+    std::stringstream ss;
+    ss << MemoryFormat::ChannelsLast;
+    EXPECT_EQ(ss.str(), "ChannelsLast");
+  }
+  {
+    std::stringstream ss;
+    ss << MemoryFormat::ChannelsLast3d;
+    EXPECT_EQ(ss.str(), "ChannelsLast3d");
+  }
+}

test/cpp_extensions/libtorch_agnostic_2_10_extension/libtorch_agnostic_2_10/csrc/my_empty.cpp

Lines changed: 5 additions & 3 deletions

@@ -10,14 +10,16 @@ using torch::stable::Tensor;
 Tensor my_empty(
     torch::headeronly::HeaderOnlyArrayRef<int64_t> size,
     std::optional<torch::headeronly::ScalarType> dtype,
+    std::optional<torch::headeronly::Layout> layout,
     std::optional<torch::stable::Device> device,
-    std::optional<bool> pin_memory) {
-  return empty(size, dtype, device, pin_memory);
+    std::optional<bool> pin_memory,
+    std::optional<torch::headeronly::MemoryFormat> memory_format) {
+  return empty(size, dtype, layout, device, pin_memory, memory_format);
 }
 
 STABLE_TORCH_LIBRARY_FRAGMENT(libtorch_agnostic_2_10, m) {
   m.def(
-      "my_empty(int[] size, ScalarType? dtype=None, Device? device=None, bool? pin_memory=None) -> Tensor");
+      "my_empty(int[] size, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None, MemoryFormat? memory_format=None) -> Tensor");
 }
 
 STABLE_TORCH_LIBRARY_IMPL(libtorch_agnostic_2_10, CompositeExplicitAutograd, m) {
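Each optional argument in the schema string ("Layout? layout=None", "MemoryFormat? memory_format=None") maps positionally to a std::optional parameter in the C++ signature. A hedged sketch of calling the kernel directly from the same translation unit; the brace-initialized size and the ScalarType::Float spelling are assumptions rather than something this diff shows:

// Hypothetical direct call; assumes HeaderOnlyArrayRef<int64_t> accepts an
// initializer list the way c10::ArrayRef does. std::nullopt corresponds to
// a schema default of None.
Tensor t = my_empty(
    /*size=*/{2, 3},
    /*dtype=*/torch::headeronly::ScalarType::Float,
    /*layout=*/torch::headeronly::Layout::Strided,
    /*device=*/std::nullopt,
    /*pin_memory=*/std::nullopt,
    /*memory_format=*/torch::headeronly::MemoryFormat::Contiguous);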

test/cpp_extensions/libtorch_agnostic_2_10_extension/libtorch_agnostic_2_10/ops.py

Lines changed: 7 additions & 3 deletions

@@ -156,20 +156,24 @@ def test_get_num_threads() -> int:
     return torch.ops.libtorch_agnostic_2_10.test_get_num_threads.default()
 
 
-def my_empty(size, dtype=None, device=None, pin_memory=None) -> Tensor:
+def my_empty(
+    size, dtype=None, layout=None, device=None, pin_memory=None, memory_format=None
+) -> Tensor:
     """
-    Creates an empty tensor with the specified size, dtype, device, and pin_memory.
+    Creates an empty tensor with the specified size, dtype, layout, device, pin_memory, and memory_format.
 
     Args:
         size: list[int] - size of the tensor to create
         dtype: ScalarType or None - data type of the tensor
+        layout: Layout or None - layout of the tensor
         device: Device or None - device on which to create the tensor
         pin_memory: bool or None - whether to use pinned memory
+        memory_format: MemoryFormat or None - memory format of the tensor
 
     Returns: Tensor - an uninitialized tensor with the specified properties
     """
     return torch.ops.libtorch_agnostic_2_10.my_empty.default(
-        size, dtype, device, pin_memory
+        size, dtype, layout, device, pin_memory, memory_format
     )
 
 
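With the extended schema, a Python round trip would look like ops.my_empty([2, 3], dtype=torch.float32, layout=torch.strided, memory_format=torch.contiguous_format); torch.strided and torch.contiguous_format are the standard Python bindings for the Layout and MemoryFormat schema types (the exact test invocation is not part of this commit).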
