// scatter_functor_gpu.cu.cc (64 lines, 46 loc, 2.3 KB)
/* Copyright 2015 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM
#define EIGEN_USE_GPU
#include "tensorflow/core/kernels/scatter_functor_gpu.cu.h"
#include "tensorflow/core/framework/register_types.h"
namespace tensorflow {

// Device tag used by every explicit instantiation in this translation unit.
using GPUDevice = Eigen::GpuDevice;

// Each helper macro below expands to one or more complete explicit
// instantiation definitions, each already terminated by ';'. That lets the
// macros be chained back-to-back with no separators. Staying self-terminating
// matters for DEFINE_GPU_SPECS_ASSIGN_ONLY, which is handed to the TF_CALL_*
// list macros. Those are presumed (see register_types.h) to place their
// per-type expansions next to one another with no ';' in between.

// Instantiates the GPU scatter functor and the GPU scalar-scatter functor for
// a single (element type T, index type Index, update op) combination.
#define DEFINE_GPU_SPECS_OP(T, Index, op)                           \
  template struct functor::ScatterFunctor<GPUDevice, T, Index, op>; \
  template struct functor::ScatterScalarFunctor<GPUDevice, T, Index, op>;

// Instantiates all seven update ops (ASSIGN, ADD, SUB, MUL, DIV, MIN, MAX) for
// one (T, Index) pair.
#define DEFINE_GPU_SPECS_INDEX(T, Index)                      \
  DEFINE_GPU_SPECS_OP(T, Index, scatter_op::UpdateOp::ASSIGN) \
  DEFINE_GPU_SPECS_OP(T, Index, scatter_op::UpdateOp::ADD)    \
  DEFINE_GPU_SPECS_OP(T, Index, scatter_op::UpdateOp::SUB)    \
  DEFINE_GPU_SPECS_OP(T, Index, scatter_op::UpdateOp::MUL)    \
  DEFINE_GPU_SPECS_OP(T, Index, scatter_op::UpdateOp::DIV)    \
  DEFINE_GPU_SPECS_OP(T, Index, scatter_op::UpdateOp::MIN)    \
  DEFINE_GPU_SPECS_OP(T, Index, scatter_op::UpdateOp::MAX)

// Every update op for element type T, with both int32 and int64 indices.
#define DEFINE_GPU_SPECS(T)        \
  DEFINE_GPU_SPECS_INDEX(T, int32) \
  DEFINE_GPU_SPECS_INDEX(T, int64)

// ASSIGN only for element type T, with both int32 and int64 indices.
#define DEFINE_GPU_SPECS_ASSIGN_ONLY(T)                       \
  DEFINE_GPU_SPECS_OP(T, int32, scatter_op::UpdateOp::ASSIGN) \
  DEFINE_GPU_SPECS_OP(T, int64, scatter_op::UpdateOp::ASSIGN)

// Floating-point element types: the full set of update ops. The ';' after
// each call is a redundant empty declaration, kept so each line reads as a
// statement.
DEFINE_GPU_SPECS(Eigen::half);
DEFINE_GPU_SPECS(Eigen::bfloat16);
DEFINE_GPU_SPECS(float);
DEFINE_GPU_SPECS(double);

// bool and the TF integral element types: ASSIGN only.
TF_CALL_bool(DEFINE_GPU_SPECS_ASSIGN_ONLY);
TF_CALL_INTEGRAL_TYPES(DEFINE_GPU_SPECS_ASSIGN_ONLY);

// Keep the helper macros private to this file.
#undef DEFINE_GPU_SPECS_ASSIGN_ONLY
#undef DEFINE_GPU_SPECS
#undef DEFINE_GPU_SPECS_INDEX
#undef DEFINE_GPU_SPECS_OP

}  // namespace tensorflow
#endif // GOOGLE_CUDA || TENSORFLOW_USE_ROCM