Skip to content

Commit

Permalink
[xla:gpu] NFC: Remove AddressComputationFusion emitter
Browse files Browse the repository at this point in the history
After merging implementations always use DynamicAddressComputationFusion

PiperOrigin-RevId: 620901172
  • Loading branch information
ezhulenev authored and tensorflower-gardener committed Apr 2, 2024
1 parent 0849a6b commit ec6af2b
Show file tree
Hide file tree
Showing 4 changed files with 3 additions and 49 deletions.
1 change: 1 addition & 0 deletions third_party/xla/xla/service/gpu/fusions/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -262,6 +262,7 @@ cc_library(
"//xla/service/gpu/fusions/mlir:elemental_hlo_to_mlir",
"@com_google_absl//absl/algorithm:container",
"@com_google_absl//absl/log:check",
"@com_google_absl//absl/strings",
"@com_google_absl//absl/types:span",
"@llvm-project//mlir:IR",
"@local_tsl//tsl/platform:errors",
Expand Down
22 changes: 0 additions & 22 deletions third_party/xla/xla/service/gpu/fusions/custom.cc
Original file line number Diff line number Diff line change
Expand Up @@ -726,28 +726,6 @@ absl::StatusOr<FusionEmissionResult> CustomFusion::Emit(
return result;
}

absl::StatusOr<FusionEmissionResult> AddressComputationFusion::Emit(
IrEmitterContext& ir_emitter_context,
const HloFusionInstruction& fusion) const {
const HloFusionAdaptor& adaptor = analysis_.fusion();
auto maybe_custom_call_adaptor = HloFindIf(
adaptor.GetRoots(), adaptor,
[](auto node) { return node.opcode() == HloOpcode::kCustomCall; });
if (maybe_custom_call_adaptor == std::nullopt) {
return absl::InternalError(
"AddressComputationFusion requires a CustomCall hero");
}

const auto& custom_call = *static_cast<const HloCustomCallInstruction*>(
&maybe_custom_call_adaptor->instruction());
// TODO(vuson): these Emit* are mostly duplicated from ir_emitter_unnested
if (IsLegacyCublasMatmul(custom_call)) {
return EmitGemm(ir_emitter_context, adaptor, fusion, custom_call);
}

return EmitCustomCall(ir_emitter_context, adaptor, fusion, custom_call);
}

absl::StatusOr<FusionEmissionResult> DynamicAddressComputationFusion::Emit(
IrEmitterContext& ir_emitter_context,
const HloFusionInstruction& fusion) const {
Expand Down
22 changes: 0 additions & 22 deletions third_party/xla/xla/service/gpu/fusions/custom.h
Original file line number Diff line number Diff line change
Expand Up @@ -50,28 +50,6 @@ class CustomFusion : public FusionInterface {
// compile-time instead of allocating a new buffer for it at runtime by
// translating the static slice into offset + size of the original buffer passed
// into the custom call `%gemm`.
class AddressComputationFusion : public FusionInterface {
public:
explicit AddressComputationFusion(const HloFusionAnalysis& analysis)
: analysis_(analysis) {}

absl::StatusOr<FusionEmissionResult> Emit(
IrEmitterContext& ir_emitter_context,
const HloFusionInstruction& fusion) const final;

private:
const HloFusionAnalysis& analysis_;
};

// TODO(vuson): merge these two fusions.
// Emitter for custom fusions implementing dynamic address computation. A
// dynamic address computation contains a custom call hero, with at least one of
// its operands coming from a dynamic contiguous slice, and/or with at least one
// of its results feeding into a contiguous DUS.
//
// The goal is to compute the buffer addresses for sliced operands/results
// without having to allocate new buffers for these by wrapping
// AddressComputationThunk around the original custom call thunk.
class DynamicAddressComputationFusion : public FusionInterface {
public:
explicit DynamicAddressComputationFusion(const HloFusionAnalysis& analysis)
Expand Down
7 changes: 2 additions & 5 deletions third_party/xla/xla/service/gpu/fusions/fusions.cc
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@ limitations under the License.

#include "absl/algorithm/container.h"
#include "absl/log/check.h"
#include "absl/types/span.h"
#include "absl/strings/match.h"
#include "mlir/IR/Value.h" // from @llvm-project
#include "xla/hlo/ir/hlo_instruction.h"
#include "xla/hlo/ir/hlo_instructions.h"
Expand Down Expand Up @@ -176,10 +176,7 @@ absl::StatusOr<std::unique_ptr<FusionInterface>> GetFusionEmitter(
switch (analysis.GetEmitterFusionKind()) {
case HloFusionAnalysis::EmitterFusionKind::kCustomFusion: {
const auto& config = backend_config.custom_fusion_config();
if (config.name() == "address_computation") {
return std::make_unique<AddressComputationFusion>(analysis);
}
if (config.name() == "dynamic_address_computation") {
if (absl::StrContains(config.name(), "address_computation")) {
return std::make_unique<DynamicAddressComputationFusion>(analysis);
}
return std::make_unique<CustomFusion>();
Expand Down

0 comments on commit ec6af2b

Please sign in to comment.