-
Notifications
You must be signed in to change notification settings - Fork 21.3k
/
Common.h
82 lines (64 loc) · 1.89 KB
/
Common.h
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
#pragma once
#include <ATen/ATen.h>
#ifdef USE_XNNPACK
#include <xnnpack.h>
namespace at {
namespace native {
namespace xnnpack {
namespace internal {
// Named dimension indices and shape helpers used when handing tensors,
// filters, and kernel parameters to XNNPACK.
struct Layout final {
  // 4D Activation Maps: position of each dimension within a 4-D shape.
  struct Activation4D final {
    static constexpr size_t batch = 0u;
    static constexpr size_t channels = 1u;
    static constexpr size_t height = 2u;
    static constexpr size_t width = 3u;
  };

  // ND Activation Maps
  struct ActivationND final {
    // Some operators may not be limited to 4 dimensional tensors. In that
    // scenario, XNNPACK denotes that operator with an _nc suffix and expects
    // all dimensions, except channels, to be flattened into one argument:
    // batch_size.

    // Returns the product of every dimension except the last (channel) one.
    //
    // tensor: the shape (list of dimension sizes) to flatten.
    // Returns -1 when the shape is empty, and 1 when the shape has a single
    // dimension, since that dimension is the channel dimension and there is
    // nothing left to flatten.
    static int64_t batch(const IntArrayRef tensor) {
      if (C10_UNLIKELY(tensor.empty())) {
        return -1;
      }

      // The last dimension is reported by channel(); it must never be folded
      // into the batch as well, otherwise batch * channel would exceed the
      // number of elements described by the shape.
      const size_t channel_index = tensor.size() - 1u;
      if (channel_index == 0u) {
        return 1;
      }

      // Handle the case where batch size is zero: a leading dimension below
      // one is counted as one.
      int64_t batch = std::max<int64_t>(1, tensor[0]);

      for (size_t index = 1u; index < channel_index; ++index) {
        batch *= tensor[index];
      }

      return batch;
    }

    // Returns the size of the last dimension of the shape, or -1 when the
    // shape is empty.
    static int64_t channel(const IntArrayRef tensor) {
      if (C10_UNLIKELY(tensor.empty())) {
        return -1;
      }

      return tensor.back();
    }
  };

  // Convolution Filters: position of each dimension within a filter shape.
  struct Filter final {
    static constexpr size_t output = 0u;
    static constexpr size_t input = 1u;
    static constexpr size_t height = 2u;
    static constexpr size_t width = 3u;
  };

  // Parameters (Pooling Kernels, Dilation, Padding, Stride, etc.): position of
  // the height and width entries within a two-element parameter list.
  struct Parameter final {
    static constexpr size_t height = 0u;
    static constexpr size_t width = 1u;
  };
};
// Function object that releases an XNNPACK operator by passing it to
// xnn_delete_operator.
struct Deleter final {
  void operator()(const xnn_operator_t op) const {
    // Whatever xnn_delete_operator returns is intentionally discarded here.
    static_cast<void>(xnn_delete_operator(op));
  }
};
// Owning smart pointer to an XNNPACK operator. When it is destroyed or reset
// while holding a non-null pointer, Deleter is invoked on that pointer, which
// in turn calls xnn_delete_operator.
// NOTE(review): relies on xnn_operator_t being a pointer to xnn_operator -
// confirm against xnnpack.h.
using Operator = std::unique_ptr<xnn_operator, Deleter>;
// Declaration only; the definition is not in this header.
// NOTE(review): presumably reports whether XNNPACK is usable in this
// process - confirm against the definition.
bool available();
} // namespace internal
} // namespace xnnpack
} // namespace native
} // namespace at
#endif /* USE_XNNPACK */