Skip to content

Commit

Permalink
[XLA:GPU] Compute a graph of TiledHloInstruction for given tile param…
Browse files Browse the repository at this point in the history
…eters.

Use TiledHloInstructions in the Triton emitter.

PiperOrigin-RevId: 623103735
  • Loading branch information
olegshyshkov authored and tensorflower-gardener committed Apr 9, 2024
1 parent 068d733 commit 5501d29
Show file tree
Hide file tree
Showing 6 changed files with 231 additions and 200 deletions.
1 change: 1 addition & 0 deletions third_party/xla/xla/service/gpu/BUILD
Expand Up @@ -538,6 +538,7 @@ cc_library(
"//xla/service/gpu/model:indexing_map",
"//xla/service/gpu/model:symbolic_tile_analysis",
"//xla/service/gpu/model:symbolic_tiled_hlo_instruction",
"//xla/service/gpu/model:tiled_hlo_instruction",
"//xla/service/llvm_ir:llvm_util",
"//xla/stream_executor:device_description",
"//xla/stream_executor:launch_dim",
Expand Down
80 changes: 36 additions & 44 deletions third_party/xla/xla/service/gpu/ir_emitter_triton.cc
Expand Up @@ -115,7 +115,7 @@ limitations under the License.
#include "xla/service/gpu/matmul_utils.h"
#include "xla/service/gpu/model/indexing_map.h"
#include "xla/service/gpu/model/symbolic_tile_analysis.h"
#include "xla/service/gpu/model/symbolic_tiled_hlo_instruction.h"
#include "xla/service/gpu/model/tiled_hlo_instruction.h"
#include "xla/service/gpu/target_util.h"
#include "xla/service/gpu/triton_fusion_analysis.h"
#include "xla/service/gpu/triton_tiling_propagation.h"
Expand Down Expand Up @@ -738,11 +738,10 @@ absl::StatusOr<Value> EmitNestedFusion(

// TODO(b/331332678): Add unit tests to target this function specifically.
Value EmitTiledBroadcast(
ImplicitLocOpBuilder& b, const SymbolicTileAnalysis& analysis,
const SymbolicTiledHloInstruction& tiled_broadcast,
absl::flat_hash_map<const SymbolicTiledHloInstruction*, Value>& values) {
auto input_tile_shape = analysis.TileSizes(*tiled_broadcast.operand(0));
auto output_tile_shape = analysis.TileSizes(tiled_broadcast);
ImplicitLocOpBuilder& b, const TiledHloInstruction& tiled_broadcast,
absl::flat_hash_map<const TiledHloInstruction*, Value>& values) {
auto input_tile_shape = tiled_broadcast.operand(0)->tile_sizes();
auto output_tile_shape = tiled_broadcast.tile_sizes();

Value expanded_input = values[tiled_broadcast.operand(0)];

Expand Down Expand Up @@ -799,11 +798,10 @@ Value EmitTiledBroadcast(
absl::StatusOr<Value> EmitTiledHloInstruction(
ImplicitLocOpBuilder& b, absl::string_view libdevice_path,
const se::DeviceDescription& device_info,
const SymbolicTileAnalysis& analysis,
const SymbolicTiledHloInstruction& tiled_hlo,
std::function<absl::StatusOr<Value>(const SymbolicTiledHloInstruction&)>
const TiledHloInstruction& tiled_hlo,
std::function<absl::StatusOr<Value>(const TiledHloInstruction&)>
emit_param_load_fn,
absl::flat_hash_map<const SymbolicTiledHloInstruction*, Value>& values) {
absl::flat_hash_map<const TiledHloInstruction*, Value>& values) {
const HloInstruction* hlo = tiled_hlo.hlo();

if (hlo->opcode() == HloOpcode::kParameter) {
Expand All @@ -817,7 +815,7 @@ absl::StatusOr<Value> EmitTiledHloInstruction(
}

if (hlo->opcode() == HloOpcode::kBroadcast) {
return EmitTiledBroadcast(b, analysis, tiled_hlo, values);
return EmitTiledBroadcast(b, tiled_hlo, values);
}

if (hlo->opcode() == HloOpcode::kReduce) {
Expand All @@ -829,7 +827,7 @@ absl::StatusOr<Value> EmitTiledHloInstruction(
std::vector<Value> operands;
operands.reserve(hlo->operands().size());

for (const SymbolicTiledHloInstruction* operand : tiled_hlo.operands()) {
for (const TiledHloInstruction* operand : tiled_hlo.operands()) {
operands.push_back(values[operand]);
}
return EmitElementwise(b, libdevice_path, device_info, *hlo, operands);
Expand All @@ -852,21 +850,22 @@ absl::StatusOr<Value> EmitTiledHloInstruction(
absl::StatusOr<Value> EmitTiledScope(
ImplicitLocOpBuilder& b, absl::string_view libdevice_path,
const se::DeviceDescription& device_info,
const SymbolicTileAnalysis& analysis,
std::function<absl::StatusOr<Value>(const SymbolicTiledHloInstruction&)>
const std::vector<std::unique_ptr<TiledHloInstruction>>&
tiled_hlo_instructions,
std::function<absl::StatusOr<Value>(const TiledHloInstruction&)>
emit_param_load_fn,
absl::flat_hash_map<const SymbolicTiledHloInstruction*, Value>& values) {
for (const auto& tiled_hlo : analysis.GetTiledHloInstructions()) {
absl::flat_hash_map<const TiledHloInstruction*, Value>& values) {
for (const auto& tiled_hlo : tiled_hlo_instructions) {
TF_ASSIGN_OR_RETURN(
Value result,
EmitTiledHloInstruction(b, libdevice_path, device_info, analysis,
*tiled_hlo, emit_param_load_fn, values));
EmitTiledHloInstruction(b, libdevice_path, device_info, *tiled_hlo,
emit_param_load_fn, values));
TF_RET_CHECK(values.insert({tiled_hlo.get(), result}).second)
<< tiled_hlo->hlo()->ToString();
VLOG(8) << "Emitted "
<< tiled_hlo->hlo()->ToString(HloPrintOptions::ShortParsable());
}
return values[analysis.GetRoot()];
return values[tiled_hlo_instructions.back().get()];
}

// Emit sequence of instructions using compatible tiling ordered producers
Expand Down Expand Up @@ -2219,10 +2218,10 @@ absl::Status EmitMatMul(mlir::OpBuilder builder,
// `tile_offset_indexing` is a mapping from
// (program_id) -> [tile_offset0, ..., tile_offsetN]
Value ComputeBasePtrOffset(ImplicitLocOpBuilder b, Value pid,
const Shape& shape,
const IndexingMap& tile_offset_indexing) {
const TiledHloInstruction& tiled_hlo) {
const Shape& shape = tiled_hlo.hlo()->shape();
ArrayRef<mlir::AffineExpr> dimension_exprs =
tile_offset_indexing.GetAffineMap().getResults();
tiled_hlo.block_id_to_tile_offsets_indexing().GetAffineMap().getResults();

mlir::AffineExpr linear_index =
mlir::getAffineConstantExpr(0, b.getContext());
Expand Down Expand Up @@ -2292,7 +2291,9 @@ absl::Status EmitTiledSoftMax(mlir::OpBuilder builder,
computation->root_instruction()->shape().rank(), 1);
output_tile_sizes.back() = row_len;

analysis->SetTileSizes(output_tile_sizes);
TF_ASSIGN_OR_RETURN(
std::vector<std::unique_ptr<TiledHloInstruction>> tiled_hlo_instructions,
analysis->ComputeTiledHloInstructions(output_tile_sizes));

// block_size must be a power of two.
int result_block_size = llvm::PowerOf2Ceil(row_len);
Expand All @@ -2303,27 +2304,21 @@ absl::Status EmitTiledSoftMax(mlir::OpBuilder builder,
}

// Emits load instructions
auto emit_param_load = [&](const SymbolicTiledHloInstruction& tiled_hlo)
-> absl::StatusOr<Value> {
auto emit_param_load =
[&](const TiledHloInstruction& tiled_hlo) -> absl::StatusOr<Value> {
std::vector<Value> tile_sizes, tile_strides, tile_offsets;
for (auto [size, stride, offset] : llvm::zip(
analysis->TileSizes(tiled_hlo), analysis->TileStrides(tiled_hlo),
analysis->TileOffsets(tiled_hlo))) {
for (auto [size, stride] :
llvm::zip(tiled_hlo.tile_sizes(), tiled_hlo.tile_strides())) {
if (size == 1) continue;

tile_sizes.push_back(CreateConst(b, b.getI64Type(), size));
tile_strides.push_back(CreateConst(b, b.getI64Type(), stride));
tile_offsets.push_back(CreateConst(b, b.getI32Type(), offset));
tile_offsets.push_back(CreateConst(b, b.getI32Type(), 0));
}

TF_ASSIGN_OR_RETURN(
IndexingMap program_id_to_input_tile_indexing,
analysis->ComputeBlockIdToTileOffsetIndexing(tiled_hlo));

// Manually compute pointer offset to avoid materialized fully parallel
// dimensions in the tile. Current codegen tried to avoid size-1 dims.
Value ptr_offset = ComputeBasePtrOffset(b, pid, tiled_hlo.hlo()->shape(),
program_id_to_input_tile_indexing);
Value ptr_offset = ComputeBasePtrOffset(b, pid, tiled_hlo);

auto fn_arg = fn.getArgument(tiled_hlo.hlo()->parameter_number());
auto tile_ptr = AddPtr(b, fn_arg, ptr_offset);
Expand All @@ -2343,17 +2338,14 @@ absl::Status EmitTiledSoftMax(mlir::OpBuilder builder,
return EmitParameterLoad(b, emitted_tensor, boundary_checks);
};

absl::flat_hash_map<const SymbolicTiledHloInstruction*, Value> values_out;
TF_ASSIGN_OR_RETURN(Value result,
EmitTiledScope(b, libdevice_path, device_info, *analysis,
emit_param_load, values_out));

absl::flat_hash_map<const TiledHloInstruction*, Value> values_out;
TF_ASSIGN_OR_RETURN(
IndexingMap program_id_to_output_tile_indexing,
analysis->ComputeBlockIdToTileOffsetIndexing(*analysis->GetRoot()));
Value result,
EmitTiledScope(b, libdevice_path, device_info, tiled_hlo_instructions,
emit_param_load, values_out));

Value ptr_offset = ComputeBasePtrOffset(b, pid, root_shape,
program_id_to_output_tile_indexing);
Value ptr_offset =
ComputeBasePtrOffset(b, pid, *tiled_hlo_instructions.back());

Value store_tensor = b.create<mt::MakeTensorPtrOp>(
/*base=*/AddPtr(b, fn.getArgument(computation->num_parameters()),
Expand Down
6 changes: 4 additions & 2 deletions third_party/xla/xla/service/gpu/model/BUILD
Expand Up @@ -632,8 +632,8 @@ cc_library(
":indexing_map",
":symbolic_tile",
":symbolic_tiled_hlo_instruction",
":tiled_hlo_instruction",
"//xla:status",
"//xla:status_macros",
"//xla/hlo/ir:hlo",
"//xla/service:instruction_fusion",
"@com_google_absl//absl/algorithm:container",
Expand All @@ -646,6 +646,8 @@ cc_library(
"@com_google_absl//absl/types:span",
"@llvm-project//llvm:Support",
"@llvm-project//mlir:IR",
"@local_tsl//tsl/platform:status",
"@local_tsl//tsl/platform:statusor",
],
)

Expand All @@ -655,7 +657,7 @@ xla_cc_test(
deps = [
":indexing_test_utils",
":symbolic_tile_analysis",
":symbolic_tiled_hlo_instruction",
":tiled_hlo_instruction",
"//xla/hlo/ir:hlo",
"//xla/tests:hlo_test_base",
"//xla/tests:verified_hlo_module",
Expand Down

0 comments on commit 5501d29

Please sign in to comment.