Skip to content

Commit

Permalink
Correctly process sparse dots (do not drop sparsity info when calling…
Browse files Browse the repository at this point in the history
… CreateDot)

PiperOrigin-RevId: 636517867
  • Loading branch information
sergeykozub authored and tensorflower-gardener committed May 23, 2024
1 parent 2dda854 commit 6f61468
Show file tree
Hide file tree
Showing 2 changed files with 6 additions and 0 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -132,6 +132,7 @@ auto OptionalBroadcast(Pattern pattern) {

bool IsBatchedMatmul(const HloInstruction* instr) {
if (instr->opcode() != HloOpcode::kDot) return false;
if (Cast<HloDotInstruction>(instr)->sparse_operands()) return false;
const DotDimensionNumbers& dot_dims = instr->dot_dimension_numbers();
bool is_batch_dot = !dot_dims.lhs_batch_dimensions().empty() ||
!dot_dims.rhs_batch_dimensions().empty();
Expand Down
5 changes: 5 additions & 0 deletions third_party/xla/xla/service/gpu/matmul_utils.cc
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,9 @@ limitations under the License.
#include "absl/strings/str_cat.h"
#include "absl/types/span.h"
#include "xla/autotuning.pb.h"
#include "xla/hlo/ir/hlo_casting_utils.h"
#include "xla/hlo/ir/hlo_instruction.h"
#include "xla/hlo/ir/hlo_instructions.h"
#include "xla/hlo/ir/hlo_opcode.h"
#include "xla/primitive_util.h"
#include "xla/service/algorithm_util.h"
Expand Down Expand Up @@ -246,6 +248,9 @@ std::vector<int64_t> NormalizedRelativeOrder(absl::Span<const int64_t> dims) {

absl::StatusOr<bool> CanFoldTransposeOperandIntoDot(const HloInstruction& dot,
int64_t operand_idx) {
if (Cast<HloDotInstruction>(&dot)->sparse_operands()) {
return false;
}
TF_RET_CHECK(dot.opcode() == HloOpcode::kDot);
TF_RET_CHECK(dot.operand_count() > operand_idx);

Expand Down

0 comments on commit 6f61468

Please sign in to comment.