Skip to content

Commit

Permalink
[xla:gpu] Add a GemmDegenerateDimRemover pass
Browse files Browse the repository at this point in the history
This pass removes the degenerate dimension introduced by GemvRewriter. We should remove degenerate dimensions after we run GemmFusion.

PiperOrigin-RevId: 629606212
  • Loading branch information
anlunx authored and tensorflower-gardener committed May 1, 2024
1 parent 8808c11 commit c734e43
Show file tree
Hide file tree
Showing 4 changed files with 332 additions and 0 deletions.
35 changes: 35 additions & 0 deletions third_party/xla/xla/service/gpu/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -1411,6 +1411,41 @@ xla_cc_test(
],
)

cc_library(
name = "gemm_degenerate_dim_remover",
srcs = ["gemm_degenerate_dim_remover.cc"],
hdrs = ["gemm_degenerate_dim_remover.h"],
deps = [
"//xla:shape_util",
"//xla:util",
"//xla:xla_data_proto_cc",
"//xla/hlo/ir:hlo",
"//xla/service:hlo_pass",
"@com_google_absl//absl/container:flat_hash_set",
"@com_google_absl//absl/container:inlined_vector",
"@com_google_absl//absl/status",
"@com_google_absl//absl/status:statusor",
"@com_google_absl//absl/strings:string_view",
"@com_google_absl//absl/types:span",
"@local_tsl//tsl/platform:errors",
"@local_tsl//tsl/platform:statusor",
],
)

xla_cc_test(
name = "gemm_degenerate_dim_remover_test",
srcs = ["gemm_degenerate_dim_remover_test.cc"],
deps = [
":gemm_degenerate_dim_remover",
"//xla/hlo/ir:hlo",
"//xla/tests:hlo_test_base",
"//xla/tests:xla_internal_test_main", # fixdeps: keep
"@com_google_absl//absl/status:statusor",
"@com_google_googletest//:gtest_main",
"@local_tsl//tsl/platform:statusor",
],
)

cc_library(
name = "split_k_gemm_rewriter",
srcs = ["split_k_gemm_rewriter.cc"],
Expand Down
147 changes: 147 additions & 0 deletions third_party/xla/xla/service/gpu/gemm_degenerate_dim_remover.cc
Original file line number Diff line number Diff line change
@@ -0,0 +1,147 @@
/* Copyright 2024 The OpenXLA Authors.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "xla/service/gpu/gemm_degenerate_dim_remover.h"

#include <cstdint>
#include <vector>

#include "absl/container/flat_hash_set.h"
#include "absl/container/inlined_vector.h"
#include "absl/status/status.h"
#include "absl/status/statusor.h"
#include "absl/strings/string_view.h"
#include "absl/types/span.h"
#include "xla/hlo/ir/dfs_hlo_visitor_with_default.h"
#include "xla/hlo/ir/hlo_casting_utils.h"
#include "xla/hlo/ir/hlo_computation.h"
#include "xla/hlo/ir/hlo_instruction.h"
#include "xla/hlo/ir/hlo_instructions.h"
#include "xla/hlo/ir/hlo_module.h"
#include "xla/hlo/ir/hlo_opcode.h"
#include "xla/layout.h"
#include "xla/layout_util.h"
#include "xla/shape.h"
#include "xla/util.h"
#include "xla/xla_data.pb.h"
#include "tsl/platform/errors.h"
#include "tsl/platform/statusor.h"

namespace xla {
namespace gpu {

namespace {

// Construct a new layout by adding removing the minor-most dimension to the
// input layout. For example, {3, 2, 1, 0} is extended to {2, 1, 0}.
// We expect that the input layout is normalized by LayoutNormalizer, so that
// the input layout has a descending ordering.
absl::StatusOr<Layout> GetLayoutWithNewMinorMostDimension(
const Layout& layout) {
if (!LayoutUtil::IsMonotonicWithDim0Major(layout)) {
return absl::InvalidArgumentError("Layout is not normalized.");
}
return LayoutUtil::MakeDescendingLayout(layout.minor_to_major_size() - 1);
}

class GemmDegenerateDimRemoverVisitor : public DfsHloRewriteVisitor {
public:
absl::Status HandleDot(HloInstruction* instr) override {
HloDotInstruction* dot = Cast<HloDotInstruction>(instr);
HloInstruction* lhs = dot->mutable_operand(0);
HloInstruction* rhs = dot->mutable_operand(1);

HloInstruction* new_lhs = nullptr;
HloInstruction* new_rhs = nullptr;

// The degenerate dimension is the last dimension of the LHS or RHS.
if (lhs->shape().dimensions().back() == 1) {
if (lhs->opcode() != HloOpcode::kBitcast) {
return absl::InternalError("Degenerate operand is not a bitcast.");
}
new_lhs = lhs->mutable_operand(0);
new_rhs = rhs;
} else if (rhs->shape().dimensions().back() == 1) {
if (rhs->opcode() != HloOpcode::kBitcast) {
return absl::InternalError("Degenerate operand is not a bitcast.");
}
new_lhs = lhs;
new_rhs = rhs->mutable_operand(0);
} else {
return absl::OkStatus();
}

changed_ = true;

std::vector<int64_t> new_out_dimensions;
new_out_dimensions.reserve(dot->shape().dimensions().size() - 1);
for (int64_t dim_size : dot->shape().dimensions()) {
if (dim_size == 1) {
continue;
}
new_out_dimensions.push_back(dim_size);
}

// GemvRewriter should only add one degenerate dimension.
if (new_out_dimensions.size() != dot->shape().dimensions().size() - 1) {
return absl::InternalError(
"More than one degenerate dimension in the output shape.");
}

Shape new_out_shape(
dot->shape().element_type(), new_out_dimensions,
absl::InlinedVector<bool, 4>(new_out_dimensions.size(), false),
/*tuple_shapes=*/{});
TF_ASSIGN_OR_RETURN(
*new_out_shape.mutable_layout(),
GetLayoutWithNewMinorMostDimension(dot->shape().layout()));

HloComputation* computation = dot->parent();
HloInstruction* new_dot =
computation->AddInstruction(HloInstruction::CreateDot(
new_out_shape, new_lhs, new_rhs, dot->dot_dimension_numbers(),
dot->precision_config()));

if (dot->user_count() != 1) {
return absl::InternalError("Dot should have exactly one user.");
}
HloInstruction* bitcast = dot->users()[0];
if (bitcast->opcode() != HloOpcode::kBitcast) {
return absl::InternalError("Dot user should be a bitcast.");
}
return computation->ReplaceInstruction(bitcast, new_dot);
}

bool changed() const { return changed_; }

private:
bool changed_ = false;
};

} // namespace

absl::StatusOr<bool> GemmDegenerateDimRemover::Run(
HloModule* module,
const absl::flat_hash_set<absl::string_view>& execution_threads) {
GemmDegenerateDimRemoverVisitor visitor;
for (HloComputation* computation :
module->MakeNonfusionComputations(execution_threads)) {
TF_RETURN_IF_ERROR(computation->Accept(&visitor));
}
return visitor.changed();
}

} // namespace gpu
} // namespace xla
48 changes: 48 additions & 0 deletions third_party/xla/xla/service/gpu/gemm_degenerate_dim_remover.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,48 @@
/* Copyright 2024 The OpenXLA Authors.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef XLA_SERVICE_GPU_GEMM_DEGENERATE_DIM_REMOVER_H_
#define XLA_SERVICE_GPU_GEMM_DEGENERATE_DIM_REMOVER_H_

#include "absl/container/flat_hash_set.h"
#include "absl/status/statusor.h"
#include "absl/strings/string_view.h"
#include "xla/hlo/ir/hlo_module.h"
#include "xla/service/hlo_pass_interface.h"

namespace xla {
namespace gpu {

// Rewrite a gemm with a degenerate dimension to a matrix-vector multiplication.
// For example, [m x n] @ [n x 1] is rewritten to [m x n] @ [n], and [n x 1]
// @ [m x n] is rewritten to [n] @ [m x n].
//
// The degenerate dimension is introduced by GemvRewriter, we should remove it
// after GemmFusion is run.
class GemmDegenerateDimRemover : public HloModulePass {
public:
absl::string_view name() const override {
return "gemm-degenerate-dim-remover";
}

using HloPassInterface::Run;
absl::StatusOr<bool> Run(
HloModule* module,
const absl::flat_hash_set<absl::string_view>& execution_threads) override;
};

} // namespace gpu
} // namespace xla

#endif // XLA_SERVICE_GPU_GEMM_DEGENERATE_DIM_REMOVER_H_
102 changes: 102 additions & 0 deletions third_party/xla/xla/service/gpu/gemm_degenerate_dim_remover_test.cc
Original file line number Diff line number Diff line change
@@ -0,0 +1,102 @@
/* Copyright 2024 The OpenXLA Authors.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "xla/service/gpu/gemm_degenerate_dim_remover.h"

#include <memory>
#include <optional>

#include <gtest/gtest.h>
#include "absl/status/statusor.h"
#include "xla/hlo/ir/hlo_module.h"
#include "xla/tests/hlo_test_base.h"
#include "tsl/platform/statusor.h"

namespace xla::gpu {
namespace {

class GemmDegenerateDimRemoverTest : public HloTestBase {};

TEST_F(GemmDegenerateDimRemoverTest, RewriteMatrixVectorMultiplicationToGemm) {
const char* hlo = R"(
HloModule m
ENTRY e {
p0 = f32[32,7] parameter(0)
p1 = f32[7] parameter(1)
bitcast = f32[7, 1] bitcast(p1)
dot = f32[32,1] dot(p0, bitcast),
lhs_contracting_dims={1}, rhs_contracting_dims={0}
ROOT result = f32[32] bitcast(dot)
})";

const char* expected = R"()
// CHECK: %[[P0:.*]] = f32[32,7]{1,0} parameter(0)
// CHECK: %[[P1:.*]] = f32[7]{0} parameter(1)
// CHECK: ROOT %[[DOT:.*]] = f32[32]{0} dot(%[[P0]], %[[P1]]), lhs_contracting_dims={1}, rhs_contracting_dims={0}
})";

RunAndFilecheckHloRewrite(hlo, GemmDegenerateDimRemover(), expected);
}

TEST_F(GemmDegenerateDimRemoverTest, RewriteVectorMatrixMultiplicationToGemm) {
const char* hlo = R"(
HloModule m
ENTRY e {
p0 = f32[7] parameter(0)
p1 = f32[7,32] parameter(1)
bitcast = f32[7, 1] bitcast(p0)
dot = f32[1,32] dot(bitcast, p1),
lhs_contracting_dims={0}, rhs_contracting_dims={0}
ROOT result = f32[32] bitcast(dot)
})";

const char* expected = R"()
// CHECK: %[[P0:.*]] = f32[7]{0} parameter(0)
// CHECK: %[[P1:.*]] = f32[7,32]{1,0} parameter(1)
// CHECK: ROOT %[[DOT:.*]] = f32[32]{0} dot(%[[P0]], %[[P1]]), lhs_contracting_dims={0}, rhs_contracting_dims={0}
})";

RunAndFilecheckHloRewrite(hlo, GemmDegenerateDimRemover(), expected);
}

TEST_F(GemmDegenerateDimRemoverTest,
RewriteMatrixVectorMultiplicationWithBatch) {
const char* hlo = R"(
HloModule m
ENTRY e {
p0 = f32[2,5,32,7] parameter(0)
p1 = f32[2,5,7] parameter(1)
bitcast = f32[2,5,7,1]{3,2,1,0} bitcast(p1)
d = f32[2,5,32,1] dot(p0, bitcast),
lhs_batch_dims={0,1}, rhs_batch_dims={0,1},
lhs_contracting_dims={3}, rhs_contracting_dims={2}
ROOT result = f32[2,5,32] bitcast(d)
})";

const char* expected = R"()
// CHECK: %[[P0:.*]] = f32[2,5,32,7]{3,2,1,0} parameter(0)
// CHECK: %[[P1:.*]] = f32[2,5,7]{2,1,0} parameter(1)
// CHECK: ROOT %[[DOT:.*]] = f32[2,5,32]{2,1,0} dot(%[[P0]], %[[P1]]),
// CHECK-SAME: lhs_batch_dims={0,1}, lhs_contracting_dims={3}, rhs_batch_dims={0,1}, rhs_contracting_dims={2}
})";

RunAndFilecheckHloRewrite(hlo, GemmDegenerateDimRemover(), expected);
}

} // namespace
} // namespace xla::gpu

0 comments on commit c734e43

Please sign in to comment.