Skip to content

Commit

Permalink
Updating the algorithm attribute from the mlir_gemm_test as requested…
Browse files Browse the repository at this point in the history
… in the PR feedback
  • Loading branch information
deven-amd committed May 14, 2021
1 parent f4c6913 commit ffb210f
Show file tree
Hide file tree
Showing 2 changed files with 1 addition and 8 deletions.
7 changes: 0 additions & 7 deletions tensorflow/compiler/xla/service/gpu/gemm_thunk.cc
Original file line number Diff line number Diff line change
Expand Up @@ -123,12 +123,6 @@ static Status DoGemmWithAlgorithm(
se::DeviceMemory<Element> rhs_data(rhs_matrix.data);
se::DeviceMemory<Element> output_data(output_matrix.data);

// Ignore the "algorithm" field on the ROCm platform. This is because
// autotuning for GEMM is not yet available on the ROCm platform.
// The "algorithm" field does not get populated in the "normal" flow
// on the ROCm platform, but at least one unit test populates it directly,
// hence the need for this check.
#if !defined(TENSORFLOW_USE_ROCM)
if (algorithm) {
// Autotuning is disabled for batch_size != 1.
CHECK_EQ(1, batch_size);
Expand All @@ -143,7 +137,6 @@ static Status DoGemmWithAlgorithm(
/*leading dim of output=*/output_matrix.num_rows, computation_type,
*algorithm, output_profile_result);
}
#endif // !defined(TENSORFLOW_USE_ROCM)

if (batch_size != 1) {
int64 lhs_stride = lhs_matrix.num_rows * lhs_matrix.num_cols;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@ TEST_F(GemmTest, SimpleCase1) {
%arg2: memref<2x2xf32> {lmhlo.output_index = dense<[0]> : tensor<1xindex>}) attributes {
result_xla_shape = "(f32[4]) "
} {
"lmhlo_gpu.gemm"(%arg0, %arg1, %arg2) {algorithm = 7 : i64, alpha_imag = 0.000000e+00 : f64, alpha_real = 1.000000e+00 : f64, batch_size = 1 : i64, dot_dimension_numbers = {lhs_batching_dimensions = dense<> : tensor<0xi64>, lhs_contracting_dimensions = dense<1> : tensor<1xi64>, rhs_batching_dimensions = dense<> : tensor<0xi64>, rhs_contracting_dimensions = dense<0> : tensor<1xi64>}} : (memref<2x2xf32>, memref<2x2xf32>, memref<2x2xf32>) -> ()
"lmhlo_gpu.gemm"(%arg0, %arg1, %arg2) {alpha_imag = 0.000000e+00 : f64, alpha_real = 1.000000e+00 : f64, batch_size = 1 : i64, dot_dimension_numbers = {lhs_batching_dimensions = dense<> : tensor<0xi64>, lhs_contracting_dimensions = dense<1> : tensor<1xi64>, rhs_batching_dimensions = dense<> : tensor<0xi64>, rhs_contracting_dimensions = dense<0> : tensor<1xi64>}} : (memref<2x2xf32>, memref<2x2xf32>, memref<2x2xf32>) -> ()
"lmhlo.terminator"() : () -> ()
}
})";
Expand Down

0 comments on commit ffb210f

Please sign in to comment.