-
Notifications
You must be signed in to change notification settings - Fork 21.4k
/
NativeFunctions.h
54 lines (46 loc) · 1.95 KB
/
NativeFunctions.h
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
#pragma once
// ${generated_comment}
#include <ATen/Context.h>
#include <ATen/core/Reduction.h>
#include <c10/core/ScalarType.h>
#include <c10/core/TensorOptions.h>
#include <array>
#include <functional>
#include <string>
#include <tuple>
#include <vector>
// Forward declaration only: lets this header name c10::Scalar without pulling
// in its definition (presumably needed by the generated declarations below).
namespace c10 { class Scalar; } // namespace c10
// Forward declarations of ATen types, so the declarations in this header can
// mention them by pointer/reference without including their definitions.
// NOTE(review): the class-keys (class vs struct) must match the real
// definitions elsewhere in ATen — confirm if any of these types change.
namespace at {
class Tensor;
struct Generator;
struct Type;
} // namespace at
namespace at {

// Convenience factories that build a Tensor from plain C++ values.
//
// For each (T, S) pair supplied by the AT_FORALL_* macro invocations below,
// TENSOR declares one out-of-line overload,
//   tensor(ArrayRef<T> values, const TensorOptions& options)
// which is defined in ATen/Utils.cpp, plus five inline overloads that
// ultimately forward to it:
//   - tensor(initializer_list<T>, options) and tensor(T, options) wrap their
//     input in an ArrayRef<T> and keep the caller's options;
//   - tensor(ArrayRef<T>), tensor(initializer_list<T>) and tensor(T) use the
//     default options at::dtype(k##S), i.e. the ScalarType constant whose
//     name is S prefixed with `k`.
// Comments inside the macro body use /* */ because line splicing happens
// before comment removal, so a // comment would swallow the continued lines.
#define TENSOR(T, S)                                                          \
  CAFFE2_API Tensor tensor(ArrayRef<T> values, const TensorOptions& options); \
  inline Tensor tensor(                                                       \
      std::initializer_list<T> values, const TensorOptions& options) {        \
    return at::tensor(ArrayRef<T>(values), options);                          \
  }                                                                           \
  inline Tensor tensor(T value, const TensorOptions& options) {               \
    /* One-element view of the by-value parameter; assumes the out-of-line */ \
    /* overload copies the data rather than retaining the view. */            \
    return at::tensor(ArrayRef<T>(value), options);                           \
  }                                                                           \
  inline Tensor tensor(ArrayRef<T> values) {                                  \
    /* ArrayRef is a cheap non-owning view: pass it by copy, no std::move. */ \
    return at::tensor(values, at::dtype(k##S));                               \
  }                                                                           \
  inline Tensor tensor(std::initializer_list<T> values) {                     \
    return at::tensor(ArrayRef<T>(values));                                   \
  }                                                                           \
  inline Tensor tensor(T value) {                                             \
    return at::tensor(ArrayRef<T>(value));                                    \
  }
AT_FORALL_SCALAR_TYPES_AND3(Bool, Half, BFloat16, TENSOR)
AT_FORALL_COMPLEX_TYPES(TENSOR)
#undef TENSOR

namespace native {
// Template placeholder: presumably replaced by ATen's code generator with the
// declarations of the native operator implementations.
${native_function_declarations}
} // namespace native

} // namespace at