Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions tensorflow/core/kernels/gpu_utils.h
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@ limitations under the License.
#ifndef TENSORFLOW_CORE_KERNELS_GPU_UTILS_H_
#define TENSORFLOW_CORE_KERNELS_GPU_UTILS_H_

-#if GOOGLE_CUDA
+#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM

#include <unordered_map>

Expand Down Expand Up @@ -214,6 +214,6 @@ Status BestCudnnConvAlgorithm(absl::Span<const AutotuneResult> results,

} // namespace tensorflow

-#endif // GOOGLE_CUDA
+#endif // GOOGLE_CUDA || TENSORFLOW_USE_ROCM

#endif // TENSORFLOW_CORE_KERNELS_GPU_UTILS_H_
16 changes: 9 additions & 7 deletions tensorflow/core/kernels/matmul_op.cc
Original file line number Diff line number Diff line change
Expand Up @@ -26,9 +26,11 @@ limitations under the License.
#include "tensorflow/core/util/matmul_autotune.h"
#if GOOGLE_CUDA
#include "third_party/gpus/cuda/include/cuda.h"
+#endif
+#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM
#include "tensorflow/core/kernels/gpu_utils.h"
#include "tensorflow/core/platform/stream_executor.h"
-#endif // GOOGLE_CUDA
+#endif // GOOGLE_CUDA || TENSORFLOW_USE_ROCM

namespace tensorflow {

Expand Down Expand Up @@ -111,11 +113,11 @@ bool ExplicitVectorMatrixOptimization<Eigen::half>(

template <typename Device, typename T>
struct LaunchMatMulBase {
-#if GOOGLE_CUDA
+#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM
typedef se::blas::AlgorithmType AlgorithmType;
#else
typedef int64 AlgorithmType;
-#endif // GOOGLE_CUDA
+#endif // GOOGLE_CUDA || TENSORFLOW_USE_ROCM

static void launch(
OpKernelContext* ctx, const Tensor& a, const Tensor& b,
Expand Down Expand Up @@ -154,7 +156,7 @@ template <typename T, bool USE_CUBLAS>
struct LaunchMatMul<SYCLDevice, T, USE_CUBLAS> : public LaunchMatMulSYCL<T> {};
#endif // TENSORFLOW_USE_SYCL

-#if GOOGLE_CUDA
+#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM

namespace {

Expand Down Expand Up @@ -433,7 +435,7 @@ struct LaunchMatMul<GPUDevice, T, true /* USE_CUBLAS */> {
}
};

-#endif // GOOGLE_CUDA
+#endif // GOOGLE_CUDA || TENSORFLOW_USE_ROCM

template <typename Device, typename T, bool USE_CUBLAS>
class MatMulOp : public OpKernel {
Expand Down Expand Up @@ -622,13 +624,13 @@ TF_CALL_complex64(REGISTER_CPU);
TF_CALL_complex128(REGISTER_CPU);
#endif // INTEL_MKL && ENABLE_MKL

-#if GOOGLE_CUDA
+#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM
TF_CALL_float(REGISTER_GPU);
TF_CALL_double(REGISTER_GPU);
TF_CALL_complex64(REGISTER_GPU);
TF_CALL_complex128(REGISTER_GPU);
TF_CALL_half(REGISTER_GPU);
-#endif // GOOGLE_CUDA
+#endif // GOOGLE_CUDA || TENSORFLOW_USE_ROCM

#ifdef TENSORFLOW_USE_SYCL
#define REGISTER_SYCL(T) \
Expand Down
4 changes: 2 additions & 2 deletions tensorflow/core/kernels/matmul_op.h
Original file line number Diff line number Diff line change
Expand Up @@ -58,7 +58,7 @@ struct MatMulFunctor {

} // end namespace functor

-#if GOOGLE_CUDA
+#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM
// Encapsulate all the shape information that is used in matmul operations.
class MatmulParameters {
public:
Expand Down Expand Up @@ -117,7 +117,7 @@ class MatmulParameters {

typedef Eigen::GpuDevice GPUDevice;

-#endif // GOOGLE_CUDA
+#endif // GOOGLE_CUDA || TENSORFLOW_USE_ROCM

} // end namespace tensorflow

Expand Down