Skip to content

Commit

Permalink
[xla:gpu] Add support for AddressComputationFusion
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 600307299
  • Loading branch information
tyb0807 authored and tensorflower-gardener committed Jan 22, 2024
1 parent 9a5a1a9 commit 0c9abcc
Show file tree
Hide file tree
Showing 5 changed files with 330 additions and 3 deletions.
25 changes: 25 additions & 0 deletions third_party/xla/xla/service/gpu/fusions/BUILD
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
load("//xla/tests:build_defs.bzl", "xla_test")
load("//xla:xla.bzl", "xla_cc_test")
load("@local_tsl//tsl/platform/default:cuda_build_defs.bzl", "if_cuda_is_configured")

Expand Down Expand Up @@ -55,12 +56,21 @@ cc_library(
visibility = ["//visibility:public"],
deps = [
":fusion_emitter",
"//xla:shape_util",
"//xla:status_macros",
"//xla:statusor",
"//xla/hlo/ir:hlo",
"//xla/mlir_hlo:lhlo",
"//xla/service:buffer_assignment",
"//xla/service/gpu:backend_configs_cc",
"//xla/service/gpu:cublas_cudnn",
"//xla/service/gpu:gemm_thunk",
"//xla/service/gpu:hlo_fusion_analysis",
"//xla/service/gpu:hlo_traversal",
"//xla/service/gpu:ir_emission_utils",
"//xla/service/gpu:ir_emitter_context",
"//xla/service/gpu:kernel_arguments",
"//xla/service/gpu:matmul_utils",
"//xla/service/gpu:thunk",
"//xla/service/gpu/kernels:custom_fusion",
"//xla/service/gpu/kernels:custom_kernel",
Expand All @@ -73,6 +83,21 @@ cc_library(
],
)

# GPU-only test comparing modules using the "address_computation" custom
# fusion against equivalent unfused reference modules.
xla_test(
    name = "address_computation_fusion_test",
    srcs = ["address_computation_fusion_test.cc"],
    backends = ["gpu"],
    deps = [
        "//xla:array3d",
        "//xla:error_spec",
        "//xla:literal_util",
        "//xla:types",
        "//xla/tests:hlo_test_base",
        "@local_tsl//tsl/platform:test",
        "@local_tsl//tsl/platform:test_main",
    ],
)

cc_library(
name = "fusion_emitter",
srcs = ["fusion_emitter.cc"],
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,133 @@
/* Copyright 2024 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "xla/array3d.h"
#include "xla/error_spec.h"
#include "xla/literal_util.h"
#include "xla/tests/hlo_test_base.h"
#include "xla/types.h"
#include "tsl/platform/test.h"

namespace xla {
namespace gpu {
namespace {

// Fixture for tests that run a reference HLO module and a hand-written
// module using the "address_computation" custom fusion, and compare outputs.
class AddressComputationFusionTest : public HloTestBase {};

// Checks that a legacy cuBLAS GEMM whose operands are produced by static
// slices gives the same result when the slices + custom call are wrapped in
// an address computation fusion (hlo_opt) as the unfused reference
// (hlo_ref). run_hlo_passes=false so both modules execute exactly as
// written.
TEST_F(AddressComputationFusionTest, CublasGemmSimple) {
  ErrorSpec error_spec{/*abs=*/1e-3, /*rel=*/1e-3};

  const char* hlo_ref = R"(
  HloModule jit_slice, entry_computation_layout={(bf16[2,8,8]{2,1,0}, bf16[2,8,8]{2,1,0})->bf16[8,8]{1,0}}, allow_spmd_sharding_propagation_to_output={true}

  ENTRY %main.9 (Arg_0.1: bf16[2,8,8], Arg_1.2: bf16[2,8,8]) -> bf16[8,8] {
    %Arg_0.1 = bf16[2,8,8]{2,1,0} parameter(0), sharding={replicated}
    %Arg_1.2 = bf16[2,8,8]{2,1,0} parameter(1), sharding={replicated}
    %slice.13 = bf16[1,8,8]{2,1,0} slice(bf16[2,8,8]{2,1,0} %Arg_0.1), slice={[1:2], [0:8], [0:8]}
    %bitcast.41 = bf16[8,8]{1,0} bitcast(bf16[1,8,8]{2,1,0} %slice.13)
    %slice.14 = bf16[1,8,8]{2,1,0} slice(bf16[2,8,8]{2,1,0} %Arg_1.2), slice={[1:2], [0:8], [0:8]}
    %bitcast.42 = bf16[8,8]{1,0} bitcast(bf16[1,8,8]{2,1,0} %slice.14)
    ROOT %custom-call.1 = bf16[8,8]{1,0} custom-call(bf16[8,8]{1,0} %bitcast.41, bf16[8,8]{1,0} %bitcast.42), custom_call_target="__cublas$gemm", backend_config={"gemm_backend_config":{"alpha_real":1,"beta":0,"dot_dimension_numbers":{"lhs_contracting_dimensions":["1"],"rhs_contracting_dimensions":["0"],"lhs_batch_dimensions":[],"rhs_batch_dimensions":[]},"alpha_imag":0,"precision_config":{"operand_precision":["DEFAULT","DEFAULT"]},"epilogue":"DEFAULT","lhs_stride":"64","rhs_stride":"64","grad_x":false,"grad_y":false}}
  })";

  const char* hlo_opt = R"(
  HloModule jit_slice, entry_computation_layout={(bf16[2,8,8]{2,1,0}, bf16[2,8,8]{2,1,0})->bf16[8,8]{1,0}}, allow_spmd_sharding_propagation_to_output={true}

  %fused_computation (param_0_0: bf16[2,8,8], param_1_0: bf16[2,8,8]) -> bf16[8,8]{1,0} {
    %param_0_0 = bf16[2,8,8]{2,1,0} parameter(0)
    %slice.13 = bf16[1,8,8]{2,1,0} slice(bf16[2,8,8]{2,1,0} %param_0_0), slice={[1:2], [0:8], [0:8]}
    %bitcast.41 = bf16[8,8]{1,0} bitcast(bf16[1,8,8]{2,1,0} %slice.13)
    %param_1_0 = bf16[2,8,8]{2,1,0} parameter(1)
    %slice.14 = bf16[1,8,8]{2,1,0} slice(bf16[2,8,8]{2,1,0} %param_1_0), slice={[1:2], [0:8], [0:8]}
    %bitcast.42 = bf16[8,8]{1,0} bitcast(bf16[1,8,8]{2,1,0} %slice.14)
    ROOT %custom-call.1 = bf16[8,8]{1,0} custom-call(bf16[8,8]{1,0} %bitcast.41, bf16[8,8]{1,0} %bitcast.42), custom_call_target="__cublas$gemm", backend_config={"gemm_backend_config":{"alpha_real":1,"beta":0,"dot_dimension_numbers":{"lhs_contracting_dimensions":["1"],"rhs_contracting_dimensions":["0"],"lhs_batch_dimensions":[],"rhs_batch_dimensions":[]},"alpha_imag":0,"precision_config":{"operand_precision":["DEFAULT","DEFAULT"]},"epilogue":"DEFAULT","lhs_stride":"64","rhs_stride":"64","grad_x":false,"grad_y":false}}
  }

  ENTRY %main.9 (Arg_0.1: bf16[2,8,8], Arg_1.2: bf16[2,8,8]) -> bf16[8,8] {
    %Arg_0.1 = bf16[2,8,8]{2,1,0} parameter(0), sharding={replicated}
    %Arg_1.2 = bf16[2,8,8]{2,1,0} parameter(1), sharding={replicated}
    ROOT %fusion.2 = bf16[8,8]{1,0} fusion(bf16[2,8,8]{2,1,0} %Arg_0.1, bf16[2,8,8]{2,1,0} %Arg_1.2), kind=kCustom, calls=%fused_computation,
        backend_config={"fusion_backend_config":{"kind":"__custom_fusion","custom_fusion_config":{"name":"address_computation"}}}
  })";

  // Literal arguments matching the bf16[2,8,8] entry parameters.
  Array3D<bfloat16> arr0(2, 8, 8);  // bf16[2,8,8]
  Array3D<bfloat16> arr1(2, 8, 8);  // bf16[2,8,8]
  arr0.FillIota(static_cast<bfloat16>(1.0));
  arr1.FillRandom(bfloat16(0.01f), 0.02);

  auto a0 = LiteralUtil::CreateFromArray(arr0);
  auto a1 = LiteralUtil::CreateFromArray(arr1);

  EXPECT_TRUE(RunAndCompareTwoModules(hlo_ref, hlo_opt, {&a0, &a1}, error_spec,
                                      /*run_hlo_passes=*/false));
}

// Same as CublasGemmSimple, but the cuBLAS custom call returns a
// (result, workspace) tuple, exercising the tuple-result emission path.
//
// Fix: the HLO modules declare f16[2,8,8] entry parameters, but the test
// previously built Array3D<bfloat16> literals (the comments even said
// "bf16[2,8,8]"), which does not match the entry computation layout. Use
// half (f16) arrays so the argument literals agree with the modules.
TEST_F(AddressComputationFusionTest, CublasGemmWithWorkspace) {
  ErrorSpec error_spec{/*abs=*/1e-3, /*rel=*/1e-3};

  const char* hlo_ref = R"(
  HloModule jit_slice, entry_computation_layout={(f16[2,8,8]{2,1,0}, f16[2,8,8]{2,1,0})->(f16[8,8]{1,0}, s8[256]{0})}, allow_spmd_sharding_propagation_to_output={true}

  ENTRY %main.9 (Arg_0.1: f16[2,8,8], Arg_1.2: f16[2,8,8]) -> (f16[8,8]{1,0}, s8[256]{0}) {
    %Arg_0.1 = f16[2,8,8]{2,1,0} parameter(0), sharding={replicated}
    %Arg_1.2 = f16[2,8,8]{2,1,0} parameter(1), sharding={replicated}
    %slice.13 = f16[1,8,8]{2,1,0} slice(f16[2,8,8]{2,1,0} %Arg_0.1), slice={[1:2], [0:8], [0:8]}
    %bitcast.41 = f16[8,8]{1,0} bitcast(f16[1,8,8]{2,1,0} %slice.13)
    %slice.14 = f16[1,8,8]{2,1,0} slice(f16[2,8,8]{2,1,0} %Arg_1.2), slice={[1:2], [0:8], [0:8]}
    %bitcast.42 = f16[8,8]{1,0} bitcast(f16[1,8,8]{2,1,0} %slice.14)
    ROOT %custom-call.1 = (f16[8,8]{1,0}, s8[256]{0}) custom-call(f16[8,8]{1,0} %bitcast.41, f16[8,8]{1,0} %bitcast.42), custom_call_target="__cublas$gemm", backend_config={"gemm_backend_config":{"alpha_real":1,"beta":0,"dot_dimension_numbers":{"lhs_contracting_dimensions":["1"],"rhs_contracting_dimensions":["0"],"lhs_batch_dimensions":[],"rhs_batch_dimensions":[]},"alpha_imag":0,"precision_config":{"operand_precision":["DEFAULT","DEFAULT"]},"epilogue":"DEFAULT","lhs_stride":"64","rhs_stride":"64","grad_x":false,"grad_y":false}}
  })";

  const char* hlo_opt = R"(
  HloModule jit_slice, entry_computation_layout={(f16[2,8,8]{2,1,0}, f16[2,8,8]{2,1,0})->(f16[8,8]{1,0}, s8[256]{0})}, allow_spmd_sharding_propagation_to_output={true}

  %fused_computation (param_0_0: f16[2,8,8], param_1_0: f16[2,8,8]) -> (f16[8,8]{1,0}, s8[256]{0}) {
    %param_0_0 = f16[2,8,8]{2,1,0} parameter(0)
    %slice.13 = f16[1,8,8]{2,1,0} slice(f16[2,8,8]{2,1,0} %param_0_0), slice={[1:2], [0:8], [0:8]}
    %bitcast.41 = f16[8,8]{1,0} bitcast(f16[1,8,8]{2,1,0} %slice.13)
    %param_1_0 = f16[2,8,8]{2,1,0} parameter(1)
    %slice.14 = f16[1,8,8]{2,1,0} slice(f16[2,8,8]{2,1,0} %param_1_0), slice={[1:2], [0:8], [0:8]}
    %bitcast.42 = f16[8,8]{1,0} bitcast(f16[1,8,8]{2,1,0} %slice.14)
    %custom-call.1 = (f16[8,8]{1,0}, s8[256]{0}) custom-call(f16[8,8]{1,0} %bitcast.41, f16[8,8]{1,0} %bitcast.42), custom_call_target="__cublas$gemm", backend_config={"gemm_backend_config":{"alpha_real":1,"beta":0,"dot_dimension_numbers":{"lhs_contracting_dimensions":["1"],"rhs_contracting_dimensions":["0"],"lhs_batch_dimensions":[],"rhs_batch_dimensions":[]},"alpha_imag":0,"precision_config":{"operand_precision":["DEFAULT","DEFAULT"]},"epilogue":"DEFAULT","lhs_stride":"64","rhs_stride":"64","grad_x":false,"grad_y":false}}
    %get-tuple-element.0 = f16[8,8]{1,0} get-tuple-element((f16[8,8]{1,0}, s8[256]{0}) %custom-call.1), index=0
    %get-tuple-element.1 = s8[256]{0} get-tuple-element((f16[8,8]{1,0}, s8[256]{0}) %custom-call.1), index=1
    ROOT %tuple = (f16[8,8]{1,0}, s8[256]{0}) tuple(%get-tuple-element.0, %get-tuple-element.1)
  }

  ENTRY %main.9 (Arg_0.1: f16[2,8,8], Arg_1.2: f16[2,8,8]) -> (f16[8,8]{1,0}, s8[256]{0}) {
    %Arg_0.1 = f16[2,8,8]{2,1,0} parameter(0), sharding={replicated}
    %Arg_1.2 = f16[2,8,8]{2,1,0} parameter(1), sharding={replicated}
    ROOT %fusion.2 = (f16[8,8]{1,0}, s8[256]{0}) fusion(f16[2,8,8]{2,1,0} %Arg_0.1, f16[2,8,8]{2,1,0} %Arg_1.2), kind=kCustom, calls=%fused_computation,
        backend_config={"fusion_backend_config":{"kind":"__custom_fusion","custom_fusion_config":{"name":"address_computation"}}}
  })";

  // Literal arguments matching the f16[2,8,8] entry parameters.
  Array3D<half> arr0(2, 8, 8);  // f16[2,8,8]
  Array3D<half> arr1(2, 8, 8);  // f16[2,8,8]
  arr0.FillRandom(half(0.01f), 0.02);
  arr1.FillIota(static_cast<half>(10.0));

  auto a0 = LiteralUtil::CreateFromArray(arr0);
  auto a1 = LiteralUtil::CreateFromArray(arr1);

  EXPECT_TRUE(RunAndCompareTwoModules(hlo_ref, hlo_opt, {&a0, &a1}, error_spec,
                                      /*run_hlo_passes=*/false));
}

} // namespace
} // namespace gpu
} // namespace xla
133 changes: 133 additions & 0 deletions third_party/xla/xla/service/gpu/fusions/custom.cc
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,10 @@ limitations under the License.
==============================================================================*/
#include "xla/service/gpu/fusions/custom.h"

#include <cstddef>
#include <cstdint>
#include <memory>
#include <optional>
#include <utility>
#include <variant>
#include <vector>
Expand All @@ -23,17 +26,29 @@ limitations under the License.
#include "absl/status/status.h"
#include "absl/strings/str_cat.h"
#include "mlir/IR/Operation.h" // from @llvm-project
#include "xla/hlo/ir/hlo_casting_utils.h"
#include "xla/hlo/ir/hlo_instruction.h"
#include "xla/hlo/ir/hlo_instructions.h"
#include "xla/hlo/ir/hlo_opcode.h"
#include "xla/mlir_hlo/lhlo/IR/lhlo_ops.h"
#include "xla/service/buffer_assignment.h"
#include "xla/service/gpu/backend_configs.pb.h"
#include "xla/service/gpu/cublas_cudnn.h"
#include "xla/service/gpu/fusions/fusion_emitter.h"
#include "xla/service/gpu/gemm_thunk.h"
#include "xla/service/gpu/hlo_fusion_analysis.h"
#include "xla/service/gpu/hlo_traversal.h"
#include "xla/service/gpu/ir_emission_utils.h"
#include "xla/service/gpu/ir_emitter_context.h"
#include "xla/service/gpu/kernel_arguments.h"
#include "xla/service/gpu/kernels/custom_fusion.h"
#include "xla/service/gpu/kernels/custom_kernel.h"
#include "xla/service/gpu/matmul_utils.h"
#include "xla/service/gpu/runtime3/kernel_thunk.h"
#include "xla/service/gpu/thunk.h"
#include "xla/shape.h"
#include "xla/shape_util.h"
#include "xla/status_macros.h"
#include "xla/statusor.h"
#include "tsl/platform/statusor.h"

Expand Down Expand Up @@ -62,6 +77,61 @@ absl::StatusOr<std::unique_ptr<Thunk>> BuildCustomKernelThunkForFusion(
instr, std::move(custom_kernel), std::move(kernel_arguments.args()));
}

// Returns true if `instr` is a slice that only restricts the major-most
// (leading, in layout order) dimension; every other dimension must be taken
// in full. Such a slice is contiguous in memory, so it can be expressed as a
// pure (byte offset, byte size) adjustment of the operand buffer.
bool IsSliceInLeadingDimOnly(const HloInstruction& instr) {
  const auto* slice = DynCast<HloSliceInstruction>(&instr);
  if (slice == nullptr) return false;
  const Shape& shape = slice->operand(0)->shape();
  // minor_to_major lists dimensions from minor-most to major-most, so the
  // last entry is the leading (major-most) dimension.
  const int64_t major_dim = shape.layout().minor_to_major().back();
  // int64_t loop index: shape.rank() and major_dim are int64_t; the original
  // size_t index mixed signed/unsigned in the comparisons below.
  for (int64_t i = 0; i < shape.rank(); ++i) {
    if (i == major_dim) continue;
    // Any non-leading dimension must cover its full extent.
    if (slice->slice_starts(i) != 0 ||
        slice->slice_limits(i) != shape.dimensions(i)) {
      return false;
    }
  }
  return true;
}

// Returns the buffer slice for the custom-call operand reached from `start`,
// folding any static slice found on the path into a compile-time
// (offset, size) adjustment of `bufferized_instr`'s allocation slice.
//
// If no kSlice instruction is found, the original slice is returned as-is.
// Only unit-stride slices restricted to the leading dimension are supported.
//
// Fix: the returned slice's offset must be relative to the allocation, not
// to the operand buffer, so the original slice's offset within its
// allocation is added to the computed byte offset. The previous code dropped
// orig_slice.offset(), which is wrong whenever the operand buffer does not
// start at offset 0 of its allocation.
absl::StatusOr<BufferAllocation::Slice> GetSliceWithUpdatedOffsetAndSize(
    const BufferAssignment& buffer_assignment, const HloFusionAdaptor& fusion,
    const HloInstruction* bufferized_instr, const HloInstruction& start) {
  TF_ASSIGN_OR_RETURN(
      BufferAllocation::Slice orig_slice,
      GetAllocationSlice(buffer_assignment, bufferized_instr, {}));

  auto maybe_slice_adaptor =
      HloFindIf({HloInstructionAdaptor(start)}, fusion,
                [](auto node) { return node.opcode() == HloOpcode::kSlice; });
  if (maybe_slice_adaptor == std::nullopt) return orig_slice;

  const auto& slice_instr = *static_cast<const HloSliceInstruction*>(
      &maybe_slice_adaptor->instruction());

  TF_RET_CHECK(IsSliceWithUnitStrides(&slice_instr))
      << "AddressComputationFusion only handles slices with unit strides "
         "currently";
  TF_RET_CHECK(IsSliceInLeadingDimOnly(slice_instr))
      << "AddressComputationFusion only handles slices in leading dim "
         "currently";

  // Given this shape f16[10,10,10]{2,1,0}, sliced into f16[2,10,10]{2,1,0}
  // we say that the sliced shape contains 2 slice units of f16[10,10].
  const Shape& shape = slice_instr.shape();
  const int64_t major_dim = shape.layout().minor_to_major().back();
  // The sliced leading dim is the number of slice units.
  const int64_t slice_unit_count = shape.dimensions(major_dim);
  // Guard the division below against degenerate (empty) slices.
  TF_RET_CHECK(slice_unit_count > 0)
      << "AddressComputationFusion requires a non-empty slice";
  const int64_t num_elem = ShapeUtil::ElementsIn(shape);
  // The number of elements in a slice unit is the total number of elements
  // divided by the number of slice units.
  const int64_t slice_unit_num_elem = num_elem / slice_unit_count;
  const int64_t slice_unit_byte_size =
      slice_unit_num_elem *
      ShapeUtil::ByteSizeOfPrimitiveType(shape.element_type());

  // Byte offset of the slice within the operand buffer...
  const int64_t offset =
      slice_instr.slice_starts(major_dim) * slice_unit_byte_size;
  const int64_t size = slice_unit_count * slice_unit_byte_size;
  // ...made absolute by adding the operand buffer's own offset within its
  // allocation.
  return BufferAllocation::Slice(orig_slice.allocation(),
                                 orig_slice.offset() + offset, size);
}

} // namespace

absl::StatusOr<FusionEmissionResult> CustomFusionEmitter::Emit(
Expand Down Expand Up @@ -113,5 +183,68 @@ absl::StatusOr<FusionEmissionResult> CustomFusionEmitter::Emit(
return result;
}

// Emits a thunk for an address computation fusion: finds the custom-call
// hero inside the fusion and, for supported targets, builds a thunk whose
// operand buffer slices are resolved at compile time from the static slices
// feeding the custom call. Currently only legacy cuBLAS matmul custom calls
// are supported; anything else returns Unimplemented.
absl::StatusOr<FusionEmissionResult> AddressComputationFusionEmitter::Emit(
    IrEmitterContext& ir_emitter_context, mlir::lmhlo::FusionOp fusion_op,
    const HloFusionInstruction& fusion) const {
  const BufferAssignment& buffer_assignment =
      ir_emitter_context.buffer_assignment();

  // An address computation fusion is defined by a custom-call hero reachable
  // from the fusion roots.
  const HloFusionAdaptor& adaptor = analysis_.fusion();
  auto maybe_custom_call_adaptor = HloFindIf(
      adaptor.GetRoots(), adaptor,
      [](auto node) { return node.opcode() == HloOpcode::kCustomCall; });
  TF_RET_CHECK(maybe_custom_call_adaptor != std::nullopt)
      << "AddressComputationFusion requires a CustomCall hero";

  const auto& custom_call = *static_cast<const HloCustomCallInstruction*>(
      &maybe_custom_call_adaptor->instruction());
  if (IsLegacyCublasMatmul(custom_call)) {
    // Fold the static slices feeding each GEMM operand into the fusion
    // operands' allocation slices (offset + size computed at compile time).
    TF_ASSIGN_OR_RETURN(BufferAllocation::Slice lhs_slice,
                        GetSliceWithUpdatedOffsetAndSize(
                            buffer_assignment, adaptor, fusion.operand(0),
                            *custom_call.operand(0)));

    TF_ASSIGN_OR_RETURN(BufferAllocation::Slice rhs_slice,
                        GetSliceWithUpdatedOffsetAndSize(
                            buffer_assignment, adaptor, fusion.operand(1),
                            *custom_call.operand(1)));

    BufferAllocation::Slice output;
    std::optional<BufferAllocation::Slice> workspace;

    // Result of a legacy cuBLAS custom call can be a tuple if we explicitly
    // allocate workspace buffer in HLO. If result is an array, it means that
    // workspace is not available, and cuBLAS will allocate its own workspace.
    if (custom_call.shape().IsArray()) {
      TF_ASSIGN_OR_RETURN(output,
                          GetAllocationSlice(buffer_assignment, &fusion, {}));
    } else {
      // Tuple result: {0} is the GEMM output, {1} is the workspace buffer.
      TF_ASSIGN_OR_RETURN(output,
                          GetAllocationSlice(buffer_assignment, &fusion, {0}));
      TF_ASSIGN_OR_RETURN(workspace,
                          GetAllocationSlice(buffer_assignment, &fusion, {1}));
    }

    bool deterministic_ops =
        ir_emitter_context.debug_options().xla_gpu_deterministic_ops();

    TF_ASSIGN_OR_RETURN(
        GemmConfig config,
        GemmConfig::For(static_cast<const HloInstruction*>(&custom_call)));
    auto thunk = std::make_unique<GemmThunk>(
        Thunk::ThunkInfo::WithProfileAnnotation(&custom_call),
        std::move(config), lhs_slice, rhs_slice, output, workspace,
        deterministic_ops);

    FusionEmissionResult result;
    result.thunks.push_back(std::move(thunk));
    return result;
  }

  return absl::UnimplementedError(
      absl::StrCat("No emission for AddressComputationFusion of custom call ",
                   custom_call.custom_call_target()));
}

} // namespace gpu
} // namespace xla
30 changes: 30 additions & 0 deletions third_party/xla/xla/service/gpu/fusions/custom.h
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@ limitations under the License.
#include "xla/hlo/ir/hlo_instructions.h"
#include "xla/mlir_hlo/lhlo/IR/lhlo_ops.h"
#include "xla/service/gpu/fusions/fusion_emitter.h"
#include "xla/service/gpu/hlo_fusion_analysis.h"
#include "xla/service/gpu/ir_emitter_context.h"
#include "xla/statusor.h"

Expand All @@ -33,6 +34,35 @@ class CustomFusionEmitter : public FusionInterface {
const HloFusionInstruction& fusion) const final;
};

// Emitter for custom fusions implementing address computation. An address
// computation contains a custom call hero, with at least one of its operands
// comes from a static contiguous slice. E.g. operand `%cast` of `%gemm` coming
// from `%slice`:
// %address_computation {
// %p0 = f32[2, 1024, 1024]
// %p1 = f32[1024, 1024]
// %slice = f32[1, 1024, 1024] slice(%p0)
// %cast = f32[1024, 1024] bitcast(%slice)
// ROOT %gemm = custom_call(%cast, %p1) __cublas$Gemm
// }
//
// The goal is to compute the buffer addresses for such operands (`%cast`) at
// compile-time instead of allocating a new buffer for it at runtime by
// translating the static slice into offset + size of the original buffer passed
// into the custom call `%gemm`.
class AddressComputationFusionEmitter : public FusionInterface {
 public:
  // `analysis` must outlive this emitter; it is held by reference.
  explicit AddressComputationFusionEmitter(const HloFusionAnalysis& analysis)
      : analysis_(analysis) {}

  // Emits a thunk for the fusion's custom-call hero with operand buffer
  // slices resolved at compile time. Returns Unimplemented for custom-call
  // targets that are not supported.
  absl::StatusOr<FusionEmissionResult> Emit(
      IrEmitterContext& ir_emitter_context, mlir::lmhlo::FusionOp fusion_op,
      const HloFusionInstruction& fusion) const final;

 private:
  // Fusion analysis for the fusion being emitted (non-owning reference).
  const HloFusionAnalysis& analysis_;
};

} // namespace gpu
} // namespace xla

Expand Down
Loading

0 comments on commit 0c9abcc

Please sign in to comment.