[ROCm] Added ROCm support for inplace_ops #29094

Merged
6 changes: 3 additions & 3 deletions tensorflow/core/kernels/inplace_ops.cc
@@ -211,7 +211,7 @@ REGISTER_KERNEL_BUILDER(Name("_ParallelConcatUpdate")
ParallelConcatUpdate<CPUDevice>);
#endif // TENSORFLOW_USE_SYCL

-#if GOOGLE_CUDA
+#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM

typedef Eigen::GpuDevice GPUDevice;

@@ -482,7 +482,7 @@ REGISTER_EMPTY(int64, CPU)
REGISTER_EMPTY(bool, CPU)
REGISTER_EMPTY(uint8, CPU)

-#if GOOGLE_CUDA
+#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM

typedef Eigen::GpuDevice GPUDevice;

@@ -545,7 +545,7 @@ REGISTER_EMPTY(Eigen::half, GPU);
REGISTER_EMPTY(int64, GPU);
REGISTER_EMPTY(int32, GPU);

-#endif // GOOGLE_CUDA
+#endif // GOOGLE_CUDA || TENSORFLOW_USE_ROCM

} // end namespace
} // end namespace tensorflow
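
For reference, the change in inplace_ops.cc is purely a build-guard widening: GPU kernel registrations that were previously compiled only under GOOGLE_CUDA are now also compiled for ROCm builds, where TENSORFLOW_USE_ROCM is defined instead. A minimal sketch of that guard pattern, using a hypothetical FooOp registration for illustration (not the exact macros or op names from this file):

#include "tensorflow/core/framework/op_kernel.h"

namespace tensorflow {

#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM
typedef Eigen::GpuDevice GPUDevice;  // valid for both CUDA and ROCm builds

// Hypothetical no-op kernel, used only to illustrate the registration guard.
class FooOp : public OpKernel {
 public:
  explicit FooOp(OpKernelConstruction* ctx) : OpKernel(ctx) {}
  void Compute(OpKernelContext* ctx) override {}
};

// Registered for the GPU device under either backend.
REGISTER_KERNEL_BUILDER(Name("Foo").Device(DEVICE_GPU), FooOp);
#endif  // GOOGLE_CUDA || TENSORFLOW_USE_ROCM

}  // namespace tensorflow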
48 changes: 24 additions & 24 deletions tensorflow/core/kernels/inplace_ops_functor_gpu.cu.cc
@@ -13,7 +13,7 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

-#if GOOGLE_CUDA
+#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM

#define EIGEN_USE_GPU

@@ -30,7 +30,7 @@ template <typename T>
__global__ void DoParallelConcatOpKernel(int nthreads, const int64 rows,
const int64 cols, int32 loc,
const T* src, T* dst) {
-CUDA_1D_KERNEL_LOOP(idx, nthreads) {
+GPU_1D_KERNEL_LOOP(idx, nthreads) {
int64 c = idx % cols;
int64 r = (loc % rows + rows) % rows; // Guard index range.
T* p = dst + r * cols + c;
@@ -43,13 +43,13 @@ template <typename T>
Status DoParallelConcatUpdate(const Device& d, const Tensor& value, int32 loc,
Tensor* output) {
const int64 nelem = value.NumElements();
-GpuLaunchConfig cfg = GetCudaLaunchConfig(nelem, d);
+GpuLaunchConfig cfg = GetGpuLaunchConfig(nelem, d);
auto Toutput = output->flat_outer_dims<T>();
const int64 nrows = Toutput.dimension(0);
const int64 ncols = Toutput.dimension(1);
const T* src = value.flat<T>().data();
T* dst = output->flat<T>().data();
-TF_CHECK_OK(CudaLaunchKernel(
+TF_CHECK_OK(GpuLaunchKernel(
DoParallelConcatOpKernel<T>, cfg.block_count, cfg.thread_per_block, 0,
d.stream(), cfg.virtual_thread_count, nrows, ncols, loc, src, dst));
return Status::OK();
@@ -82,7 +82,7 @@ template <typename T, InplaceOpType op>
__global__ void DoInplaceOpKernel(int nthreads, const int64 rows,
const int64 cols, const int64 n, const T* src,
const int32* rowids, T* dst) {
-CUDA_1D_KERNEL_LOOP(idx, nthreads) {
+GPU_1D_KERNEL_LOOP(idx, nthreads) {
int64 r = idx / cols;
int64 c = idx % cols;
r = (rowids[r] % rows + rows) % rows; // Guard index range.
@@ -106,7 +106,7 @@ template <typename T>
void DoInplaceOp(const Device& d, InplaceOpType op, const Tensor& i,
const Tensor& v, Tensor* y) {
const int64 nelem = v.NumElements();
-GpuLaunchConfig cfg = GetCudaLaunchConfig(nelem, d);
+GpuLaunchConfig cfg = GetGpuLaunchConfig(nelem, d);
auto Ty = y->flat_outer_dims<T>();
const int64 nrows = Ty.dimension(0);
const int64 ncols = Ty.dimension(1);
@@ -117,22 +117,22 @@ void DoInplaceOp(const Device& d, InplaceOpType op, const Tensor& i,
T* dst = y->flat<T>().data();
switch (op) {
case I_UPDATE:
-TF_CHECK_OK(CudaLaunchKernel(DoInplaceOpKernel<T, I_UPDATE>,
-cfg.block_count, cfg.thread_per_block, 0,
-d.stream(), cfg.virtual_thread_count, nrows,
-ncols, n, src, rowids, dst));
+TF_CHECK_OK(GpuLaunchKernel(DoInplaceOpKernel<T, I_UPDATE>,
+cfg.block_count, cfg.thread_per_block, 0,
+d.stream(), cfg.virtual_thread_count, nrows,
+ncols, n, src, rowids, dst));
break;
case I_ADD:
-TF_CHECK_OK(CudaLaunchKernel(DoInplaceOpKernel<T, I_ADD>, cfg.block_count,
-cfg.thread_per_block, 0, d.stream(),
-cfg.virtual_thread_count, nrows, ncols, n,
-src, rowids, dst));
+TF_CHECK_OK(GpuLaunchKernel(DoInplaceOpKernel<T, I_ADD>, cfg.block_count,
+cfg.thread_per_block, 0, d.stream(),
+cfg.virtual_thread_count, nrows, ncols, n,
+src, rowids, dst));
break;
case I_SUB:
-TF_CHECK_OK(CudaLaunchKernel(DoInplaceOpKernel<T, I_SUB>, cfg.block_count,
-cfg.thread_per_block, 0, d.stream(),
-cfg.virtual_thread_count, nrows, ncols, n,
-src, rowids, dst));
+TF_CHECK_OK(GpuLaunchKernel(DoInplaceOpKernel<T, I_SUB>, cfg.block_count,
+cfg.thread_per_block, 0, d.stream(),
+cfg.virtual_thread_count, nrows, ncols, n,
+src, rowids, dst));
break;
}
}
@@ -141,7 +141,7 @@ template <bool>
void DoInplaceOp(const Device& d, InplaceOpType op, const Tensor& i,
const Tensor& v, Tensor* y) {
const int64 nelem = v.NumElements();
-GpuLaunchConfig cfg = GetCudaLaunchConfig(nelem, d);
+GpuLaunchConfig cfg = GetGpuLaunchConfig(nelem, d);
auto Ty = y->flat_outer_dims<bool>();
const int64 nrows = Ty.dimension(0);
const int64 ncols = Ty.dimension(1);
@@ -151,10 +151,10 @@ void DoInplaceOp(const Device& d, InplaceOpType op, const Tensor& i,
const int32* rowids = i.flat<int32>().data();
bool* dst = y->flat<bool>().data();
if (op == I_UPDATE) {
-TF_CHECK_OK(CudaLaunchKernel(DoInplaceOpKernel<bool, I_UPDATE>,
-cfg.block_count, cfg.thread_per_block, 0,
-d.stream(), cfg.virtual_thread_count, nrows,
-ncols, n, src, rowids, dst));
+TF_CHECK_OK(GpuLaunchKernel(DoInplaceOpKernel<bool, I_UPDATE>,
+cfg.block_count, cfg.thread_per_block, 0,
+d.stream(), cfg.virtual_thread_count, nrows,
+ncols, n, src, rowids, dst));
}
}

@@ -205,4 +205,4 @@ Status DoCopy(const Device& d, const Tensor& x, Tensor* y) {

} // end namespace functor
} // namespace tensorflow
-#endif // GOOGLE_CUDA
+#endif // GOOGLE_CUDA || TENSORFLOW_USE_ROCM
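
In inplace_ops_functor_gpu.cu.cc the same guard is paired with the backend-neutral launch helpers: GPU_1D_KERNEL_LOOP, GetGpuLaunchConfig and GpuLaunchKernel replace the CUDA_*/Cuda* variants and expand to CUDA or HIP depending on the build. A minimal sketch of that launch pattern, assuming the helpers from tensorflow/core/util/gpu_kernel_helper.h; FillKernel and DoFill are hypothetical names used only for illustration:

#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM
#define EIGEN_USE_GPU

#include "tensorflow/core/framework/types.h"
#include "tensorflow/core/lib/core/status.h"
#include "tensorflow/core/util/gpu_kernel_helper.h"

namespace tensorflow {

// Hypothetical kernel: writes `value` into every element of `dst`.
template <typename T>
__global__ void FillKernel(int nthreads, T value, T* dst) {
  // Grid-stride loop; compiles to the CUDA or HIP form as appropriate.
  GPU_1D_KERNEL_LOOP(idx, nthreads) { dst[idx] = value; }
}

// Hypothetical host-side wrapper showing the portable launch sequence.
template <typename T>
Status DoFill(const Eigen::GpuDevice& d, T value, T* dst, int64 nelem) {
  // Chooses block/grid dimensions for the active device.
  GpuLaunchConfig cfg = GetGpuLaunchConfig(nelem, d);
  // Enqueues the kernel on the device stream and returns a Status.
  return GpuLaunchKernel(FillKernel<T>, cfg.block_count, cfg.thread_per_block,
                         0 /* shared memory */, d.stream(),
                         cfg.virtual_thread_count, value, dst);
}

}  // namespace tensorflow
#endif  // GOOGLE_CUDA || TENSORFLOW_USE_ROCM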