-
Notifications
You must be signed in to change notification settings - Fork 74k
/
reduction_ops.h
206 lines (180 loc) · 8.96 KB
/
reduction_ops.h
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
/* Copyright 2015 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef TENSORFLOW_CORE_KERNELS_REDUCTION_OPS_H_
#define TENSORFLOW_CORE_KERNELS_REDUCTION_OPS_H_
// Functor definitions for Reduction ops, must be compilable by nvcc.
#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/tensor_types.h"
namespace tensorflow {
namespace functor {
// Compile-time traits of a reducer type. This primary template supplies the
// defaults; a reducer opts out by specializing ReducerTraits for its own type.
//
// IsScalarIdentity defaults to true.
// NOTE(review): presumably this means that reducing a single element yields
// that element unchanged -- confirm against the code that reads the trait.
template <typename Reducer>
struct ReducerTraits {
enum { IsScalarIdentity = true };
};
// Tag type used to select the mean-reduction template specializations. It does
// not reduce anything itself: the mean is computed with a SumReducer combined
// with an on-the-fly division by the reduction factor.
template <class Scalar>
struct MeanReducer {
  // Returns zero converted to Scalar.
  Scalar initialize() const { return static_cast<Scalar>(0); }
};
// Tag type used to select the l2-norm (Euclidean norm) template
// specializations; it carries no reduction logic of its own.
template <class Scalar>
struct EuclideanNormReducer {
  // Returns zero converted to Scalar.
  Scalar initialize() const { return static_cast<Scalar>(0); }
};
// Trait override for the Euclidean norm: IsScalarIdentity is false for every
// EuclideanNormReducer<Scalar>, unlike the primary template's default of true.
// NOTE(review): presumably because the norm of a single element x is |x|
// rather than x itself -- confirm against the code that reads the trait.
template <typename Scalar>
struct ReducerTraits<EuclideanNormReducer<Scalar>> {
enum { IsScalarIdentity = false };
};
// Generic reduction: reduces `in` over `reduction_axes` with `reducer` and
// assigns the result to `out` through `out.device(d)`. Reducers that need
// different arithmetic provide their own specializations of this template.
template <class Device, class OUT_T, class IN_T, class ReductionAxes,
          class Reducer>
struct ReduceEigenImpl {
  void operator()(const Device& d, OUT_T out, IN_T in,
                  const ReductionAxes& reduction_axes, const Reducer& reducer) {
    // `in` is a by-value parameter, so it outlives this expression object.
    const auto reduced = in.reduce(reduction_axes, reducer);
    out.device(d) = reduced;
  }
};
// Specialization for BF16 Reducer to fix accuracy: the input is cast to
// IntermediateType, reduced there with a default-constructed
// Reducer<IntermediateType> (the `reducer` argument only selects this
// specialization and is otherwise unused), and only the reduced result is cast
// back to ScalarType. The output scalar type must be exactly ScalarType.
// TODO: All BF16 reducers should have specializations to fix accuracy.
//
// The macro body deliberately ends without a ';' so that the invocation
// supplies it; this avoids a stray empty declaration at namespace scope.
#define CASTING_SPECIALIZATION(Reducer, ScalarType, IntermediateType)        \
  template <typename Device, typename OUT_T, typename IN_T,                  \
            typename ReductionAxes>                                          \
  struct ReduceEigenImpl<Device, OUT_T, IN_T, ReductionAxes,                 \
                         Reducer<ScalarType>> {                              \
    void operator()(const Device& d, OUT_T out, IN_T in,                     \
                    const ReductionAxes& reduction_axes,                     \
                    const Reducer<ScalarType>& reducer) {                    \
      static_assert(std::is_same<ScalarType, typename OUT_T::Scalar>::value, \
                    "");                                                     \
      Reducer<IntermediateType> intermediate_reducer;                        \
      auto in_as_intermediate = in.template cast<IntermediateType>();        \
      out.device(d) =                                                        \
          in_as_intermediate.reduce(reduction_axes, intermediate_reducer)    \
              .template cast<ScalarType>();                                  \
    }                                                                        \
  }

CASTING_SPECIALIZATION(Eigen::internal::SumReducer, bfloat16, float);
#undef CASTING_SPECIALIZATION
// Mean reduction: sums `in` over `reduction_axes` with Eigen's SumReducer and
// divides by in.size() / out.size(), i.e. the number of input elements per
// output element. The `reducer` argument only selects this specialization and
// is not used in the body. The output scalar type must be exactly Scalar.
template <typename Device, typename OUT_T, typename IN_T,
          typename ReductionAxes, typename Scalar>
struct ReduceEigenImpl<Device, OUT_T, IN_T, ReductionAxes,
                       functor::MeanReducer<Scalar>> {
  void operator()(const Device& d, OUT_T out, IN_T in,
                  const ReductionAxes& reduction_axes,
                  const functor::MeanReducer<Scalar>& reducer) {
    static_assert(std::is_same<Scalar, typename OUT_T::Scalar>::value, "");
    // Integer quotient of the element counts, converted to Scalar.
    const Scalar elements_per_output =
        static_cast<Scalar>(in.size() / out.size());
    Eigen::internal::SumReducer<Scalar> sum_reducer;
    // The Eigen expression is kept in a single statement on purpose.
    out.device(d) =
        in.reduce(reduction_axes, sum_reducer) / elements_per_output;
  }
};
// Mean specializations for narrow integer types. To avoid integer overflow the
// input is cast to IntermediateType, summed and divided there, and only the
// final quotient is cast back to ScalarType. The divisor is
// in.size() / out.size() converted to IntermediateType.
#define CASTING_SPECIALIZATION(ScalarType, IntermediateType)                  \
  template <typename Device, typename OUT_T, typename IN_T,                   \
            typename ReductionAxes>                                           \
  struct ReduceEigenImpl<Device, OUT_T, IN_T, ReductionAxes,                  \
                         functor::MeanReducer<ScalarType>> {                  \
    void operator()(const Device& d, OUT_T out, IN_T in,                      \
                    const ReductionAxes& reduction_axes,                      \
                    const functor::MeanReducer<ScalarType>& reducer) {        \
      static_assert(std::is_same<ScalarType, typename OUT_T::Scalar>::value,  \
                    "");                                                      \
      const IntermediateType elements_per_output =                            \
          static_cast<IntermediateType>(in.size() / out.size());              \
      Eigen::internal::SumReducer<IntermediateType> sum_reducer;              \
      out.device(d) = (in.template cast<IntermediateType>().reduce(           \
                           reduction_axes, sum_reducer) /                     \
                       elements_per_output)                                   \
                          .template cast<ScalarType>();                       \
    }                                                                         \
  }

CASTING_SPECIALIZATION(uint8, uint64);
CASTING_SPECIALIZATION(uint16, uint64);
CASTING_SPECIALIZATION(uint32, uint64);
CASTING_SPECIALIZATION(int8, int64_t);
CASTING_SPECIALIZATION(int16, int64_t);
CASTING_SPECIALIZATION(int32, int64_t);
#undef CASTING_SPECIALIZATION
// Euclidean (l2) norm reduction: computes sqrt(sum(in * conj(in))) over
// `reduction_axes`. Multiplying by the conjugate makes each term the squared
// magnitude of the element, which also covers complex inputs.
// The `reducer` argument only selects this specialization and is not used in
// the body. The output scalar type must be exactly Scalar.
// TODO(rmlarsen): Refactor this such that taking the sqrt can be optional
// controlled by an attribute.
template <typename Device, typename OUT_T, typename IN_T,
typename ReductionAxes, typename Scalar>
struct ReduceEigenImpl<Device, OUT_T, IN_T, ReductionAxes,
functor::EuclideanNormReducer<Scalar>> {
void operator()(const Device& d, OUT_T out, IN_T in,
const ReductionAxes& reduction_axes,
const functor::EuclideanNormReducer<Scalar>& reducer) {
static_assert(std::is_same<Scalar, typename OUT_T::Scalar>::value, "");
// Accumulation happens in Scalar itself.
Eigen::internal::SumReducer<Scalar> sum_reducer;
// NOTE(review): do not split this expression into named `auto` temporaries
// without first checking Eigen's expression lifetime rules.
out.device(d) =
(in * in.conjugate()).reduce(reduction_axes, sum_reducer).sqrt();
}
};
// Euclidean (l2) norm reduction for bfloat16. The input is cast to float, the
// products in * conj(in) are summed in float, the square root is taken in
// float, and only the final result is cast back to bfloat16.
// NOTE(review): presumably done to limit the precision loss of accumulating
// directly in bfloat16 -- confirm.
// The `reducer` argument only selects this specialization and is not used in
// the body. The output scalar type must be exactly bfloat16.
template <typename Device, typename OUT_T, typename IN_T,
typename ReductionAxes>
struct ReduceEigenImpl<Device, OUT_T, IN_T, ReductionAxes,
functor::EuclideanNormReducer<bfloat16>> {
void operator()(const Device& d, OUT_T out, IN_T in,
const ReductionAxes& reduction_axes,
const functor::EuclideanNormReducer<bfloat16>& reducer) {
static_assert(std::is_same<bfloat16, typename OUT_T::Scalar>::value, "");
// Sum in float rather than in bfloat16.
Eigen::internal::SumReducer<float> sum_reducer;
auto in_as_float = in.template cast<float>();
out.device(d) = (in_as_float * in_as_float.conjugate())
.reduce(reduction_axes, sum_reducer)
.sqrt()
.template cast<bfloat16>();
}
};
// For most reducers the identity element is whatever Reducer::initialize()
// returns. Reducers that need a different value specialize this template.
template <class Reducer>
struct Identity {
  // Forwards to reducer.initialize(), preserving its exact return type.
  static decltype(auto) identity(const Reducer& reducer) {
    return reducer.initialize();
  }
};
// MeanReducer is a special case, since it doesn't technically have an identity.
// Thus, ideally we'd return nan. However, mean is instantiated for integer
// types as well, so we do the nan override only for floating point types.
#define FIX_MEAN_IDENTITY(T) \
template <> \
struct Identity<functor::MeanReducer<T>> { \
static T identity(const functor::MeanReducer<T>&) { \
return Eigen::NumTraits<T>::quiet_NaN(); \
} \
};
FIX_MEAN_IDENTITY(Eigen::half)
FIX_MEAN_IDENTITY(Eigen::bfloat16)
FIX_MEAN_IDENTITY(float)
FIX_MEAN_IDENTITY(double)
#undef FIX_MEAN_IDENTITY
// Assigns Identity<Reducer>::identity(reducer) to every element of `out` on
// device `d`. The assignment is routed through MaybeWith32BitIndexing, which
// hands the callback the output view it should write to.
template <typename Device, typename OUT_T, typename Reducer>
void FillIdentityEigenImpl(const Device& d, OUT_T out, const Reducer& reducer) {
  const auto fill_with_identity = [&d, &reducer](auto out32) {
    out32.device(d) = out32.constant(Identity<Reducer>::identity(reducer));
  };
  MaybeWith32BitIndexing<Device>(fill_with_identity, out);
}
// Entry points for running a reduction with `Reducer` on `Device`. Only the
// declarations appear in this header; no definitions are visible here.
template <typename Device, typename Reducer>
struct ReduceFunctor {
// NOTE(review): presumably reduces `in` over `reduction_axes` with `reducer`
// and writes the result to `out`, using `ctx` to reach the device -- confirm
// against the definitions.
template <typename OUT_T, typename IN_T, typename ReductionAxes>
static void Reduce(OpKernelContext* ctx, OUT_T out, IN_T in,
const ReductionAxes& reduction_axes,
const Reducer& reducer);
// NOTE(review): presumably fills `out` with the reducer's identity value on
// device `d` (compare FillIdentityEigenImpl) -- confirm against the
// definitions.
template <typename OUT_T>
static void FillIdentity(const Device& d, OUT_T out, const Reducer& reducer);
};
} // namespace functor
} // namespace tensorflow
#endif // TENSORFLOW_CORE_KERNELS_REDUCTION_OPS_H_