|
3 | 3 | #include <c10/util/ArrayRef.h> |
4 | 4 | #include <c10/util/Exception.h> |
5 | 5 |
|
| 6 | +#include <torch/headeronly/core/MemoryFormat.h> |
| 7 | + |
6 | 8 | #include <cstdint> |
7 | | -#include <ostream> |
8 | 9 | #include <vector> |
9 | 10 |
|
10 | | -// Memory format is not the property of a Tensor. It is the way to tell an |
11 | | -// operator how the result should be organized in memory and nothing more. That |
12 | | -// means memory format should never be used as return value for any tensor state |
13 | | -// interrogation functions (internally and externally). |
14 | | -// |
15 | | -// Possible options are: |
16 | | -// Preserve: |
17 | | -// If any of the input tensors is in channels_last format, operator output |
18 | | -// should be in channels_last format |
19 | | -// |
20 | | -// Contiguous: |
21 | | -// Regardless of input tensors format, the output should be contiguous |
22 | | -// Tensor. |
23 | | -// |
24 | | -// ChannelsLast: |
25 | | -// Regardless of input tensors format, the output should be in channels_last |
26 | | -// format. |
27 | | - |
28 | 11 | namespace c10 { |
29 | | -enum class MemoryFormat : int8_t { |
30 | | - Contiguous, |
31 | | - Preserve, |
32 | | - ChannelsLast, |
33 | | - ChannelsLast3d, |
34 | | - NumOptions |
35 | | -}; |
36 | 12 |
|
37 | 13 | // If you are seeing this, it means that this call site was not checked if |
38 | 14 | // the memory format could be preserved, and it was switched to old default |
39 | 15 | // behaviour of contiguous |
40 | 16 | #define LEGACY_CONTIGUOUS_MEMORY_FORMAT c10::get_contiguous_memory_format() |
41 | 17 |
|
42 | | -inline MemoryFormat get_contiguous_memory_format() { |
43 | | - return MemoryFormat::Contiguous; |
44 | | -} |
45 | | - |
46 | | -inline std::ostream& operator<<( |
47 | | - std::ostream& stream, |
48 | | - at::MemoryFormat memory_format) { |
49 | | - switch (memory_format) { |
50 | | - case MemoryFormat::Preserve: |
51 | | - return stream << "Preserve"; |
52 | | - case MemoryFormat::Contiguous: |
53 | | - return stream << "Contiguous"; |
54 | | - case MemoryFormat::ChannelsLast: |
55 | | - return stream << "ChannelsLast"; |
56 | | - case MemoryFormat::ChannelsLast3d: |
57 | | - return stream << "ChannelsLast3d"; |
58 | | - case MemoryFormat::NumOptions: |
59 | | - default: |
60 | | - TORCH_CHECK(false, "Unknown memory format ", memory_format); |
61 | | - } |
62 | | -} |
63 | | - |
64 | 18 | // Note: Hardcoded the channel last stride indices here to get better |
65 | 19 | // performance |
66 | 20 | template <typename T> |
|
0 commit comments