-
Notifications
You must be signed in to change notification settings - Fork 21.4k
/
TypeProperties.cpp
144 lines (121 loc) · 4.14 KB
/
TypeProperties.cpp
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
#include <ATen/ATen.h>
#include <ATen/Dispatch.h>
#include <ATen/NativeFunctions.h>
#include <ATen/native/TypeProperties.h>
#include <type_traits>
namespace at { namespace native {
// Returns whether `self` lives on a CUDA device, as reported by the tensor.
bool is_cuda(const Tensor& self) {
  const bool on_cuda = self.is_cuda();
  return on_cuda;
}
// Always answers false; the tensor argument is never inspected, so its
// name is omitted to document that it is intentionally unused.
bool is_distributed(const Tensor& /*self*/) {
  constexpr bool kDistributed = false;
  return kDistributed;
}
bool is_complex(const Tensor& self) {
return at::isComplexType(self.scalar_type());
}
bool is_floating_point(const Tensor& self) {
return at::isFloatingType(self.scalar_type());
}
bool is_signed(const Tensor &self) {
return at::isSignedType(self.scalar_type());
}
// Returns whether `self` uses a sparse representation, as reported by the
// tensor itself.
bool is_sparse(const Tensor& self) {
  const bool sparse = self.is_sparse();
  return sparse;
}
// Returns whether `self` is a quantized tensor, as reported by the tensor
// itself.
bool is_quantized(const Tensor& self) {
  const bool quantized = self.is_quantized();
  return quantized;
}
// Returns true when `self` and `from` have compatible tensor types, meaning
// `from`'s TensorImpl may be shallow-copied into `self`. The decision is made
// by `self`'s TensorImpl, which is handed `from`'s key set.
bool _has_compatible_shallow_copy_type(const Tensor& self, const Tensor& from) {
  const auto source_keys = from.key_set();
  auto* target_impl = self.unsafeGetTensorImpl();
  return target_impl->has_compatible_shallow_copy_type(source_keys);
}
// Returns `self` converted via Tensor::to using the TensorOptions of `other`.
Tensor type_as(const Tensor& self, const Tensor& other) {
  const auto target_options = other.options();
  return self.to(target_options);
}
// Promotes two scalar types, treating ScalarType::Undefined as "no type seen
// yet". If either side is Undefined, the other side is returned unchanged, so
// two Undefined inputs yield Undefined. Otherwise the result is
// promoteTypes(a, b).
static inline ScalarType promote_skip_undefined(ScalarType a, ScalarType b) {
  const bool a_known = a != ScalarType::Undefined;
  const bool b_known = b != ScalarType::Undefined;
  if (a_known && b_known) {
    return promoteTypes(a, b);
  }
  return a_known ? a : b;
}
// Merges the result type of a higher-priority operand category (`higher`)
// with that of a lower-priority one (`lower`). Rules:
//   * a complex `higher` always wins;
//   * a floating `higher` wins unless `lower` is complex;
//   * otherwise the two are promoted together when `higher` is Bool or
//     `lower` is floating/complex (Undefined on either side is skipped);
//   * otherwise `higher` wins if it is defined, falling back to `lower`.
static inline ScalarType combine_categories(ScalarType higher, ScalarType lower) {
  const bool lower_is_complex = isComplexType(lower);
  const bool higher_dominates =
      isComplexType(higher) || (isFloatingType(higher) && !lower_is_complex);
  if (higher_dominates) {
    return higher;
  }

  // Only a lower category of a "stronger" kind (or a Bool higher category)
  // can change the result; merge the two in that case.
  const bool needs_promotion =
      higher == ScalarType::Bool || isFloatingType(lower) || lower_is_complex;
  if (needs_promotion) {
    return promote_skip_undefined(higher, lower);
  }

  return higher == ScalarType::Undefined ? lower : higher;
}
// Folds one tensor's dtype into the running result-type state and returns
// the updated copy; `in_state` itself is not modified.
//
// * Undefined tensors are ignored and `in_state` is returned as-is.
// * For wrapped numbers, a complex dtype is replaced by the default complex
//   dtype and a floating dtype by the default dtype before being recorded.
// * The dtype is promoted into exactly one bucket, checked in this order:
//   dimResult when tensor.dim() > 0, else wrappedResult for wrapped numbers,
//   else zeroResult.
ResultTypeState update_result_type_state(const Tensor& tensor, const ResultTypeState& in_state) {
  if (!tensor.defined()) {
    return in_state;
  }

  const bool wrapped = tensor.unsafeGetTensorImpl()->is_wrapped_number();

  ScalarType dtype = tensor.scalar_type();
  if (wrapped) {
    if (isComplexType(dtype)) {
      dtype = typeMetaToScalarType(at::get_default_complex_dtype());
    } else if (isFloatingType(dtype)) {
      dtype = typeMetaToScalarType(at::get_default_dtype());
    }
  }

  ResultTypeState out_state = in_state;
  if (tensor.dim() > 0) {
    out_state.dimResult = promote_skip_undefined(in_state.dimResult, dtype);
  } else if (wrapped) {
    out_state.wrappedResult = promote_skip_undefined(in_state.wrappedResult, dtype);
  } else {
    out_state.zeroResult = promote_skip_undefined(in_state.zeroResult, dtype);
  }
  return out_state;
}
// Collapses an accumulated ResultTypeState into a single dtype. Zero-dim
// results are first merged (as the higher category) with wrapped-number
// results, and that outcome is then merged (as the lower category) into the
// dimensioned results.
ScalarType result_type(const ResultTypeState& in_state) {
  const ScalarType scalar_like =
      combine_categories(in_state.zeroResult, in_state.wrappedResult);
  return combine_categories(in_state.dimResult, scalar_like);
}
// Computes the promoted dtype of a list of tensors by threading a
// value-initialized ResultTypeState through update_result_type_state, one
// tensor at a time, and then collapsing it.
ScalarType result_type(TensorList tensors) {
  ResultTypeState accumulated{};
  for (const auto& t : tensors) {
    accumulated = update_result_type_state(t, accumulated);
  }
  return result_type(accumulated);
}
// Returns the dtype produced by type promotion of two tensor operands,
// computed through the TensorList overload.
ScalarType result_type(const Tensor &tensor, const Tensor &other) {
  // Both arguments are const references: std::move on them would only yield
  // const rvalues that still bind to Tensor's copy constructor, so copy them
  // explicitly rather than pretending to move.
  std::vector<Tensor> tensors{tensor, other};
  return native::result_type(tensors);
}
// Returns the dtype produced by type promotion of a tensor and a Scalar.
// The Scalar is materialized with scalar_to_tensor and flagged as a wrapped
// number, so update_result_type_state applies its wrapped-number handling
// to it (assumes scalar_to_tensor yields a 0-dim tensor; TODO confirm).
ScalarType result_type(const Tensor &tensor, const Scalar other) {
  auto tensor2 = scalar_to_tensor(other);
  tensor2.unsafeGetTensorImpl()->set_wrapped_number(true);
  // `tensor` is a const reference, so it is copied (moving it would be a
  // no-op). `tensor2` is a local we own and can genuinely be moved.
  std::vector<Tensor> tensors{tensor, std::move(tensor2)};
  return native::result_type(tensors);
}
// Scalar-first overload: forwards to the at:: (Tensor, Scalar) entry point
// with the arguments swapped.
ScalarType result_type(const Scalar scalar, const Tensor &tensor) {
  const ScalarType promoted = at::result_type(tensor, scalar);
  return promoted;
}
// Returns the promoted dtype of two Scalars. The first Scalar is turned into
// a tensor flagged as a wrapped number. The second is then handled by the
// at:: (Tensor, Scalar) entry point.
ScalarType result_type(const Scalar scalar1, const Scalar scalar2) {
  auto lhs_tensor = scalar_to_tensor(scalar1);
  auto* lhs_impl = lhs_tensor.unsafeGetTensorImpl();
  lhs_impl->set_wrapped_number(true);
  return at::result_type(lhs_tensor, scalar2);
}
// Reports whether dtype `from` may be cast to dtype `to`, as decided by
// at::canCast.
bool can_cast(const at::ScalarType from, const at::ScalarType to) {
  const bool allowed = at::canCast(from, to);
  return allowed;
}
// Returns promoteTypes(type1, type2). It raises via TORCH_CHECK if the
// promotion is Undefined, which this function treats as an unsupported pair,
// rather than returning Undefined to the caller.
ScalarType promote_types(ScalarType type1, ScalarType type2) {
  const ScalarType promoted = promoteTypes(type1, type2);
  const bool supported = promoted != ScalarType::Undefined;
  TORCH_CHECK(supported, "Promotion from ", type1, " and ", type2, " is unsupported.");
  return promoted;
}
}} // namespace at::native