[ROCm] Add ROCm support for batch_kernels #26756

Merged

Changes from all commits
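The change itself is mechanical: every #if GOOGLE_CUDA guard around a GPU-only code path in the batch kernels becomes #if GOOGLE_CUDA || TENSORFLOW_USE_ROCM, so the same paths compile under the HIP/ROCm toolchain. A minimal, self-contained sketch of the guard pattern follows (the two macro names are TensorFlow's; everything else is illustrative):

// guard_sketch.cc -- minimal illustration of the guard pattern this PR
// applies. GOOGLE_CUDA is defined for CUDA builds and TENSORFLOW_USE_ROCM
// for ROCm builds; an undefined macro evaluates to 0 inside #if, so a
// CPU-only build takes the #else branch.
#include <iostream>

#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM
constexpr bool kGpuKernelsCompiledIn = true;
#else
constexpr bool kGpuKernelsCompiledIn = false;
#endif

int main() {
  std::cout << (kGpuKernelsCompiledIn ? "GPU kernels available\n"
                                      : "CPU-only build\n");
  return 0;
}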
12 changes: 6 additions & 6 deletions tensorflow/core/kernels/batch_kernels.cc
@@ -82,12 +82,12 @@ Status Concat(OpKernelContext* context, const gtl::ArraySlice<Tensor>& inputs,
context->allocate_temp(DataTypeToEnum<T>::value, output_shape, output));
if (output->NumElements() > 0) {
auto output_flat = output->shaped<T, 2>({1, output->NumElements()});
-#if GOOGLE_CUDA
+#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM
if (std::is_same<Device, GPUDevice>::value) {
ConcatGPU<T>(context, inputs_flat, output, &output_flat);
return Status::OK();
}
-#endif // GOOGLE_CUDA
+#endif // GOOGLE_CUDA || TENSORFLOW_USE_ROCM
ConcatCPU<T>(context->device(), inputs_flat, &output_flat);
}
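For context on the hunk above: Concat dispatches on its Device template parameter, so the std::is_same check is resolved at instantiation time, and the preprocessor guard additionally keeps ConcatGPU out of CPU-only binaries. A hedged, self-contained sketch of that two-level dispatch (CPUDevice and GPUDevice below are empty stand-in tags, not TensorFlow's device types):

// dispatch_sketch.cc -- illustrative only, not TensorFlow code.
#include <iostream>
#include <type_traits>

struct CPUDevice {};
struct GPUDevice {};

template <typename Device>
void ConcatSketch() {
#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM
  // Compile-time check: taken only when instantiated for GPUDevice,
  // and compiled at all only in CUDA/ROCm builds.
  if (std::is_same<Device, GPUDevice>::value) {
    std::cout << "would call ConcatGPU\n";
    return;
  }
#endif
  std::cout << "calls ConcatCPU\n";
}

int main() {
  ConcatSketch<CPUDevice>();  // always the CPU path
  ConcatSketch<GPUDevice>();  // GPU path only when built with
                              // -DGOOGLE_CUDA=1 or -DTENSORFLOW_USE_ROCM=1
  return 0;
}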

@@ -172,7 +172,7 @@ Status SplitCPU(OpKernelContext* context, const Tensor& input,
return Status::OK();
}

-#if GOOGLE_CUDA
+#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM

// Handles the general case, on GPU.
template <typename T>
@@ -183,7 +183,7 @@ Status SplitGPU(OpKernelContext* context, const Tensor& input,
LOG(FATAL) << "Not yet implemented"; // Crash ok
}

-#endif // GOOGLE_CUDA
+#endif // GOOGLE_CUDA || TENSORFLOW_USE_ROCM

// The outer function that dispatches to the various Split*() functions above.
template <typename T>
@@ -197,10 +197,10 @@ Status Split(OpKernelContext* context, const Tensor& input,
return Status::OK();
}

-#if GOOGLE_CUDA
+#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM
// TODO(olston, apassos): Handle non-CPU cases.
// return SplitGPU<T>(context, input, sizes, outputs);
-#endif // GOOGLE_CUDA
+#endif // GOOGLE_CUDA || TENSORFLOW_USE_ROCM
return SplitCPU<T>(context, input, sizes, outputs);
}
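One reading note on the Split hunk: unlike Concat, the guarded block here still contains only the TODO, so even a CUDA or ROCm build falls through to SplitCPU<T>; the general-case SplitGPU<T> above deliberately crashes with LOG(FATAL). The guard change therefore keeps the commented-out GPU call consistent for ROCm builds without enabling it.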

4 changes: 2 additions & 2 deletions tensorflow/core/kernels/concat_lib.h
@@ -47,7 +47,7 @@ void ConcatCPU(
const std::vector<std::unique_ptr<typename TTypes<T, 2>::ConstMatrix>>&
inputs,
typename TTypes<T, 2>::Matrix* output);
-#if GOOGLE_CUDA
+#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM
template <typename T>
void ConcatGPU(
OpKernelContext* c,
@@ -73,7 +73,7 @@ TF_CALL_bfloat16(REGISTER);
TF_CALL_bool(REGISTER);
TF_CALL_uint8(REGISTER);
#undef REGISTER
-#endif // GOOGLE_CUDA
+#endif // GOOGLE_CUDA || TENSORFLOW_USE_ROCM

#ifdef TENSORFLOW_USE_SYCL
template <typename T>
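Finally, the REGISTER block in the concat_lib.h hunk applies a per-type macro through TF_CALL_* helpers to declare ConcatGPU for each supported dtype. A self-contained sketch of that X-macro registration idiom (macro and function names below are stand-ins, not TensorFlow's actual definitions):

// register_sketch.cc -- the X-macro idiom behind TF_CALL_*(REGISTER):
// a per-type macro is applied to a list of types, here stamping out an
// explicit instantiation of a template for each one.
#include <iostream>

template <typename T>
void ConcatGpuSketch(const T& v) {
  std::cout << v << "\n";
}

// Stand-ins for TF_CALL_float(m) etc., which expand a macro m per dtype:
#define CALL_FLOAT(m) m(float)
#define CALL_INT32(m) m(int)

// REGISTER produces one explicit instantiation per listed type:
#define REGISTER(T) template void ConcatGpuSketch<T>(const T&);
CALL_FLOAT(REGISTER)
CALL_INT32(REGISTER)
#undef REGISTER

int main() {
  ConcatGpuSketch(1.5f);
  ConcatGpuSketch(42);
  return 0;
}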