[XLA:GPU][MLIR-based emitters] Implement scatter with atomics. #63187

Merged: 1 commit, Mar 7, 2024
1 change: 1 addition & 0 deletions third_party/xla/xla/service/gpu/fusions/BUILD
@@ -336,6 +336,7 @@ cc_library(
"@llvm-project//mlir:ArithDialect",
"@llvm-project//mlir:FuncDialect",
"@llvm-project//mlir:IR",
"@llvm-project//mlir:SCFDialect",
"@llvm-project//mlir:TensorDialect",
],
)
@@ -45,6 +45,7 @@ limitations under the License.
#include "mlir/IR/BuiltinAttributeInterfaces.h" // from @llvm-project
#include "mlir/IR/BuiltinAttributes.h" // from @llvm-project
#include "mlir/IR/BuiltinTypes.h" // from @llvm-project
#include "mlir/IR/IRMapping.h" // from @llvm-project
#include "mlir/IR/ImplicitLocOpBuilder.h" // from @llvm-project
#include "mlir/IR/TypeRange.h" // from @llvm-project
#include "mlir/IR/Types.h" // from @llvm-project
@@ -80,7 +81,9 @@ namespace {

using llvm::SmallVector;
using llvm::SmallVectorImpl;
using mlir::Block;
using mlir::ImplicitLocOpBuilder;
using mlir::IRMapping;
using mlir::Location;
using mlir::OpBuilder;
using mlir::Value;
@@ -1134,6 +1137,25 @@ mlir::Value ClampIndex(mlir::Value index, bool is_unsigned, int64_t high,
return index;
}

SmallVector<Value, 2> InlineBlock(OpBuilder& builder, Block& src_block,
ValueRange mapped_args) {
IRMapping mapping;
for (auto [from, to] : llvm::zip(src_block.getArguments(), mapped_args)) {
mapping.map(from, to);
}
for (auto& op : src_block.without_terminator()) {
builder.clone(op, mapping);
}
auto* terminator = src_block.getTerminator();
SmallVector<Value, 2> mapped_results;
mapped_results.reserve(terminator->getNumOperands());
for (mlir::Value result : terminator->getOperands()) {
mapped_results.push_back(mapping.lookup(result));
}
return mapped_results;
}

} // namespace mlir_converter
} // namespace gpu
} // namespace xla
@@ -110,10 +110,17 @@ llvm::SmallVector<mlir::Value> EmitLoopNest(
mlir::ValueRange iter_args, mlir::ValueRange dim_values,
mlir::ValueRange symbol_values)>& create_body);

// Clamps `index` to [0, high] boundaries.
mlir::Value ClampIndex(mlir::Value index, bool is_unsigned, int64_t high,
mlir::ImplicitLocOpBuilder& b);

// Inlines `src_block`, using `mapped_args` to initialize the IRMapping from
// the block arguments of `src_block` to `mapped_args`. Returns the remapped
// operands of the terminator.
mlir::SmallVector<mlir::Value, 2> InlineBlock(mlir::OpBuilder& builder,
mlir::Block& src_block,
mlir::ValueRange mapped_args);

} // namespace mlir_converter
} // namespace gpu
} // namespace xla
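
Reviewer note: a minimal usage sketch for the new InlineBlock helper, applying a two-argument reducer block. The wrapper function is a hypothetical illustration, and the include of the header modified above is elided, matching this view:

// Hypothetical sketch (not part of this change): combine two scalars by
// inlining a reducer block at the current insertion point.
#include "mlir/IR/Block.h"
#include "mlir/IR/ImplicitLocOpBuilder.h"
#include "mlir/IR/Value.h"

mlir::Value CombineWithReducer(mlir::ImplicitLocOpBuilder& b,
                               mlir::Block& reducer_block, mlir::Value lhs,
                               mlir::Value rhs) {
  // InlineBlock clones every op except the terminator, mapping the block
  // arguments to {lhs, rhs}, and returns the remapped terminator operands.
  return xla::gpu::mlir_converter::InlineBlock(b, reducer_block, {lhs, rhs})[0];
}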
1 change: 0 additions & 1 deletion third_party/xla/xla/service/gpu/fusions/mlir/ir/BUILD
@@ -62,6 +62,5 @@ cc_library(
"@llvm-project//mlir:IR",
"@llvm-project//mlir:InferTypeOpInterface",
"@llvm-project//mlir:Support",
"@llvm-project//mlir:Transforms",
],
)
37 changes: 37 additions & 0 deletions third_party/xla/xla/service/gpu/fusions/mlir/ir/xla_gpu_ops.cc
@@ -108,23 +108,60 @@ mlir::LogicalResult PureCallOp::verifySymbolUses(
return mlir::success();
}

//===----------------------------------------------------------------------===//
// AllocateSharedOp
//===----------------------------------------------------------------------===//

void AllocateSharedOp::getAsmResultNames(
llvm::function_ref<void(mlir::Value, mlir::StringRef)> setNameFn) {
setNameFn(getResult(), "shmem");
}

//===----------------------------------------------------------------------===//
// AtomicRMWOp
//===----------------------------------------------------------------------===//

void AtomicRMWOp::getAsmResultNames(
llvm::function_ref<void(mlir::Value, mlir::StringRef)> setNameFn) {
setNameFn(getResult(), "atomic_rmw");
}

using mlir::OpBuilder;
using mlir::OperationState;
using mlir::RankedTensorType;
using mlir::Region;
using mlir::Type;
using mlir::Value;
using mlir::ValueRange;

void AtomicRMWOp::build(OpBuilder &builder, OperationState &result,
Value tensor, ValueRange ivs) {
OpBuilder::InsertionGuard g(builder);
result.addOperands(tensor);
result.addOperands(ivs);
result.addTypes(tensor.getType());

auto tensor_type = llvm::cast<RankedTensorType>(tensor.getType());
Region *body = result.addRegion();
builder.createBlock(body);
body->addArgument(tensor_type.getElementType(), tensor.getLoc());
}

//===----------------------------------------------------------------------===//
// PureCallOp
//===----------------------------------------------------------------------===//

void PureCallOp::getAsmResultNames(
llvm::function_ref<void(mlir::Value, mlir::StringRef)> setNameFn) {
for (auto result : getResults()) {
setNameFn(result, "pure_call");
}
}

//===----------------------------------------------------------------------===//
// SyncThreadsOp
//===----------------------------------------------------------------------===//

void SyncThreadsOp::getAsmResultNames(
llvm::function_ref<void(mlir::Value, mlir::StringRef)> setNameFn) {
for (auto result : getResults()) {
14 changes: 14 additions & 0 deletions third_party/xla/xla/service/gpu/fusions/mlir/ir/xla_gpu_ops.td
@@ -94,6 +94,20 @@ def XLAGPU_AtomicRMWOp : XLAGPU_Op<"atomic_rmw",
// the updated value.
let regions = (region SizedRegion<1>:$computation);

let skipDefaultBuilders = 1;
let builders = [OpBuilder<(ins "mlir::Value":$memref, "mlir::ValueRange":$ivs)>];

let extraClassDeclaration = [{
mlir::Block* getBody() { return &getComputation().front(); }
mlir::OpBuilder getBodyBuilder() {
return mlir::OpBuilder(getBody(), std::prev(getBody()->end()));
}
// The value stored in tensor[ivs].
mlir::Value getCurrentValue() {
return getRegion().getArgument(0);
}
}];

let assemblyFormat = [{
$input `[` $indices `]` `:` type($input) $computation attr-dict
}];
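
Reviewer note: a sketch of how the new builder composes with getBodyBuilder() and getCurrentValue(). The helper and the arith.addf combiner are hypothetical; the YieldOp terminator mirrors its use in scatter_mlir.cc below:

// Hypothetical sketch (not part of this change): atomically apply
// tensor[ivs] += update and return the updated tensor value.
#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/IR/ImplicitLocOpBuilder.h"
// ... plus the header generated from this .td file.

mlir::Value EmitAtomicAdd(mlir::ImplicitLocOpBuilder& b, mlir::Value tensor,
                          mlir::ValueRange ivs, mlir::Value update) {
  // The custom builder creates the op with a single-block region whose
  // argument has the tensor's element type.
  auto rmw = b.create<xla::gpu::AtomicRMWOp>(tensor, ivs);
  mlir::OpBuilder body_builder = rmw.getBodyBuilder();
  mlir::Value sum = body_builder.create<mlir::arith::AddFOp>(
      rmw.getLoc(), rmw.getCurrentValue(), update);
  body_builder.create<xla::gpu::YieldOp>(rmw.getLoc(), sum);
  return rmw.getResult();
}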
129 changes: 85 additions & 44 deletions third_party/xla/xla/service/gpu/fusions/scatter_mlir.cc
@@ -26,9 +26,11 @@ limitations under the License.
#include "llvm/ADT/SmallVector.h"
#include "mlir/Dialect/Arith/IR/Arith.h" // from @llvm-project
#include "mlir/Dialect/Func/IR/FuncOps.h" // from @llvm-project
#include "mlir/Dialect/SCF/IR/SCF.h" // from @llvm-project
#include "mlir/Dialect/Tensor/IR/Tensor.h" // from @llvm-project
#include "mlir/IR/AffineExpr.h" // from @llvm-project
#include "mlir/IR/AffineMap.h" // from @llvm-project
#include "mlir/IR/Builders.h" // from @llvm-project
#include "mlir/IR/ImplicitLocOpBuilder.h" // from @llvm-project
#include "mlir/IR/MLIRContext.h" // from @llvm-project
#include "mlir/IR/Value.h" // from @llvm-project
@@ -49,8 +51,14 @@ namespace gpu {
namespace {

using llvm::SmallVector;
using mlir::Location;
using mlir::OpBuilder;
using mlir::Value;
using mlir::ValueRange;
using mlir::arith::AddIOp;
using mlir::arith::AndIOp;
using mlir::arith::CmpIOp;
using mlir::arith::CmpIPredicate;
using mlir::arith::ConstantIndexOp;
using mlir::func::ReturnOp;
using mlir::tensor::InsertOp;
@@ -59,22 +67,19 @@ using mlir_converter::CallTargetProvider;
using mlir_converter::PartitionedComputations;
using mlir_converter::ProvideParameter;

namespace scf = ::mlir::scf;

} // namespace

bool MlirScatterFusion::IsSupported(const HloFusionAnalysis& analysis) {
auto* scatter = Cast<HloScatterInstruction>(analysis.fusion_heroes().front());
if (scatter->scatter_operand_count() != 1) {
LOG(ERROR) << "Variadic scatter is not supported like in the legacy "
"emitter, although it is possible to make it work when the "
"indices are unique.";
return false;
}
return true;
}

std::optional<IndexingMap> MlirScatterFusion::ComputeThreadIdToOutputIndexing(
@@ -137,6 +142,33 @@ MlirScatterFusion::GetInstructionsWithCustomCodegen(
}

mlir::Value EmitScatterComputation(
const HloInstruction* scatter, ValueRange indices, Value update_elem,
Value output_tensor,
const mlir_converter::PartitionedComputation& root_computation,
const mlir_converter::CallTargetProvider& call_targets,
mlir::ImplicitLocOpBuilder& b) {
constexpr int kScatterOperandIndex = 0;
auto reducer =
call_targets(scatter->called_computations()[0]->root_instruction());
if (scatter->unique_indices()) {
auto operand_elem =
ProvideParameter(root_computation, scatter, kScatterOperandIndex,
indices, call_targets, b)[0];
auto reduced_val = mlir_converter::InlineBlock(
b, reducer.getBody().front(), {operand_elem, update_elem})[0];

return b.create<InsertOp>(reduced_val, output_tensor, indices);
}
auto atomic_rmw = b.create<AtomicRMWOp>(output_tensor, indices);
mlir::OpBuilder body_builder = atomic_rmw.getBodyBuilder();
auto reduced_val = mlir_converter::InlineBlock(
body_builder, reducer.getBody().front(),
{atomic_rmw.getCurrentValue(), update_elem})[0];
body_builder.create<xla::gpu::YieldOp>(reducer->getLoc(), reduced_val);
return atomic_rmw->getResult(0);
}

// The scatter has to be canonicalized with `scatter_simplifier` pass.
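// After simplification (illustration with assumed shapes, not from this
// change): the indices are a 2-D tensor of per-update start offsets and the
// updates have one leading dimension enumerating the updates, e.g. operand
// f32[10,5], indices s32[8,1], updates f32[8,1,5]. This is why the loop below
// reads scatter_update->shape().dimensions(i + 1).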
absl::Status MlirScatterFusion::EmitEntryFunction(
const PartitionedComputations& computations,
@@ -150,6 +182,7 @@
scatter->operand(kScatterOperandIndex);
const HloInstruction* scatter_indices =
scatter->operand(kScatterIndicesIndex);
const HloInstruction* scatter_update = scatter->operand(kScatterUpdateIndex);

mlir::MLIRContext* mlir_context = entry_function.getContext();
auto thread_id_to_update_map =
@@ -165,13 +198,9 @@
mlir::ImplicitLocOpBuilder b(entry_function.getLoc(), entry_function);
b.setInsertionPointToStart(entry_function.addEntryBlock());

SmallVector<Value> result_tensors{entry_function.getArguments().back()};
auto c0 = b.create<ConstantIndexOp>(0);

auto scatter_result = EmitThreadLoopNest(
b, result_tensors, thread_id_to_update_map,
[&](ValueRange output_tensors, ValueRange dim_values,
@@ -185,40 +214,52 @@
update_tensor_indices, call_targets, b)
.front();

// Extract slice offsets from scatter_indices operand, compute if the
// whole slice of scatter_update operand will fit into the output.
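// Worked illustration (assumed sizes, not from this change): with an output
// dimension of size 10 and an update-slice dimension of size 3, the valid
// start offsets are 0..7, i.e. offset <= output_dim - update_dim; that
// difference is the `ub` computed below, and `is_in_bounds` ANDs the
// per-dimension sge/sle checks.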
mlir::Value is_in_bounds =
b.create<mlir::arith::ConstantIntOp>(1, b.getI1Type());
SmallVector<Value, 4> indices{
llvm::ArrayRef(update_tensor_indices).drop_front()};
for (int i = 0; i < scatter_operand->shape().rank(); ++i) {
Value extracted_index = c0;
if (i < scatter_indices->shape().dimensions(1)) {
SmallVector<Value, 4> indices_tensor_indices = {
update_tensor_indices.front(), b.create<ConstantIndexOp>(i)};
extracted_index = ProvideParameter(
root_computation, scatter, kScatterIndicesIndex,
indices_tensor_indices, call_targets, b)[0];
if (extracted_index.getType() != b.getIndexType()) {
extracted_index = b.create<mlir::arith::IndexCastOp>(
b.getIndexType(), extracted_index);
}
}
is_in_bounds = b.create<AndIOp>(
is_in_bounds,
b.create<CmpIOp>(CmpIPredicate::sge, extracted_index, c0));
Value ub = b.create<ConstantIndexOp>(
scatter_operand->shape().dimensions(i) -
scatter_update->shape().dimensions(i + 1));
is_in_bounds = b.create<AndIOp>(
is_in_bounds,
b.create<CmpIOp>(CmpIPredicate::sle, extracted_index, ub));
indices[i] = b.create<AddIOp>(extracted_index, indices[i]);
}
// Call scatter's computation if is_in_bounds.
Value output_tensor = output_tensors.front();
Value predicated_update =
b.create<scf::IfOp>(
is_in_bounds,
[&](OpBuilder& then_builder, Location then_loc) -> void {
Value updated_output = EmitScatterComputation(
scatter, indices, update_elem, output_tensor,
root_computation, call_targets, b);
b.create<scf::YieldOp>(updated_output);
},
[&](OpBuilder& else_b, Location else_loc) {
b.create<scf::YieldOp>(output_tensor);
})
.getResult(0);
return {predicated_update};
});
b.create<ReturnOp>(scatter_result);
return absl::OkStatus();