[XLA:GPU][MLIR-based emitters] Implement scatter with atomics.
PiperOrigin-RevId: 613567580
pifon2a authored and tensorflower-gardener committed Mar 7, 2024
1 parent 554d0f1 commit 0fac627
Showing 8 changed files with 309 additions and 60 deletions.
1 change: 1 addition & 0 deletions third_party/xla/xla/service/gpu/fusions/BUILD
@@ -336,6 +336,7 @@ cc_library(
"@llvm-project//mlir:ArithDialect",
"@llvm-project//mlir:FuncDialect",
"@llvm-project//mlir:IR",
"@llvm-project//mlir:SCFDialect",
"@llvm-project//mlir:TensorDialect",
],
)
@@ -45,6 +45,7 @@ limitations under the License.
#include "mlir/IR/BuiltinAttributeInterfaces.h" // from @llvm-project
#include "mlir/IR/BuiltinAttributes.h" // from @llvm-project
#include "mlir/IR/BuiltinTypes.h" // from @llvm-project
#include "mlir/IR/IRMapping.h" // from @llvm-project
#include "mlir/IR/ImplicitLocOpBuilder.h" // from @llvm-project
#include "mlir/IR/TypeRange.h" // from @llvm-project
#include "mlir/IR/Types.h" // from @llvm-project
@@ -80,7 +81,9 @@ namespace {

using llvm::SmallVector;
using llvm::SmallVectorImpl;
using mlir::Block;
using mlir::ImplicitLocOpBuilder;
using mlir::IRMapping;
using mlir::Location;
using mlir::OpBuilder;
using mlir::Value;
@@ -1134,6 +1137,25 @@ mlir::Value ClampIndex(mlir::Value index, bool is_unsigned, int64_t high,
return index;
}

SmallVector<Value, 2> InlineBlock(OpBuilder& builder, Block& src_block,
ValueRange mapped_args) {
IRMapping mapping;
for (auto [from, to] : llvm::zip(src_block.getArguments(), mapped_args)) {
mapping.map(from, to);
}
for (auto& op : src_block.without_terminator()) {
builder.clone(op, mapping);
}
auto* terminator = src_block.getTerminator();
SmallVector<Value, 2> mapped_results;

mapped_results.reserve(terminator->getNumOperands());
for (mlir::Value result : terminator->getOperands()) {
mapped_results.push_back(mapping.lookup(result));
}
return mapped_results;
}

} // namespace mlir_converter
} // namespace gpu
} // namespace xla
@@ -110,10 +110,17 @@ llvm::SmallVector<mlir::Value> EmitLoopNest(
mlir::ValueRange iter_args, mlir::ValueRange dim_values,
mlir::ValueRange symbol_values)>& create_body);

// Clamps
// Clamps `index` to [0, high] boundaries.
mlir::Value ClampIndex(mlir::Value index, bool is_unsigned, int64_t high,
mlir::ImplicitLocOpBuilder& b);

// Inlines `src_block`, using `mapped_args` to initialize an IRMapping from
// the block arguments of `src_block` to `mapped_args`. Returns the remapped
// operands of the block's terminator.
mlir::SmallVector<mlir::Value, 2> InlineBlock(mlir::OpBuilder& builder,
mlir::Block& src_block,
mlir::ValueRange mapped_args);

} // namespace mlir_converter
} // namespace gpu
} // namespace xla
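
[Editor's note] For context, a minimal sketch of how `InlineBlock` can be used to splice a scatter combiner into the current insertion point, in the spirit of the call sites added to scatter_mlir.cc below. The wrapper, its name, and the `reducer`/`lhs`/`rhs` values are illustrative only; the include for the header declaring `InlineBlock` is omitted because its path is not visible in this diff.

#include "llvm/ADT/SmallVector.h"
#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/IR/ImplicitLocOpBuilder.h"
#include "mlir/IR/Value.h"

// Clones the body of `reducer` at the insertion point of `b`, substituting
// `lhs` and `rhs` for its block arguments, and returns the value its
// terminator would have yielded.
mlir::Value InlineCombiner(mlir::ImplicitLocOpBuilder& b,
                           mlir::func::FuncOp reducer, mlir::Value lhs,
                           mlir::Value rhs) {
  llvm::SmallVector<mlir::Value, 2> results =
      xla::gpu::mlir_converter::InlineBlock(b, reducer.getBody().front(),
                                            {lhs, rhs});
  return results.front();
}
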
1 change: 0 additions & 1 deletion third_party/xla/xla/service/gpu/fusions/mlir/ir/BUILD
@@ -62,6 +62,5 @@ cc_library(
"@llvm-project//mlir:IR",
"@llvm-project//mlir:InferTypeOpInterface",
"@llvm-project//mlir:Support",
"@llvm-project//mlir:Transforms",
],
)
37 changes: 37 additions & 0 deletions third_party/xla/xla/service/gpu/fusions/mlir/ir/xla_gpu_ops.cc
@@ -108,23 +108,60 @@ mlir::LogicalResult PureCallOp::verifySymbolUses(
return mlir::success();
}

//===----------------------------------------------------------------------===//
// AllocateSharedOp
//===----------------------------------------------------------------------===//

void AllocateSharedOp::getAsmResultNames(
llvm::function_ref<void(mlir::Value, mlir::StringRef)> setNameFn) {
setNameFn(getResult(), "shmem");
}

//===----------------------------------------------------------------------===//
// AtomicRMWOp
//===----------------------------------------------------------------------===//

void AtomicRMWOp::getAsmResultNames(
llvm::function_ref<void(mlir::Value, mlir::StringRef)> setNameFn) {
setNameFn(getResult(), "atomic_rmw");
}

using mlir::OpBuilder;
using mlir::OperationState;
using mlir::RankedTensorType;
using mlir::Region;
using mlir::Type;
using mlir::Value;
using mlir::ValueRange;

void AtomicRMWOp::build(OpBuilder &builder, OperationState &result,
Value tensor, ValueRange ivs) {
OpBuilder::InsertionGuard g(builder);
result.addOperands(tensor);
result.addOperands(ivs);
result.addTypes(tensor.getType());

auto tensor_type = llvm::cast<RankedTensorType>(tensor.getType());
Region *body = result.addRegion();
builder.createBlock(body);
body->addArgument(tensor_type.getElementType(), tensor.getLoc());
}

//===----------------------------------------------------------------------===//
// PureCallOp
//===----------------------------------------------------------------------===//

void PureCallOp::getAsmResultNames(
llvm::function_ref<void(mlir::Value, mlir::StringRef)> setNameFn) {
for (auto result : getResults()) {
setNameFn(result, "pure_call");
}
}

//===----------------------------------------------------------------------===//
// SyncThreadsOp
//===----------------------------------------------------------------------===//

void SyncThreadsOp::getAsmResultNames(
llvm::function_ref<void(mlir::Value, mlir::StringRef)> setNameFn) {
for (auto result : getResults()) {
14 changes: 14 additions & 0 deletions third_party/xla/xla/service/gpu/fusions/mlir/ir/xla_gpu_ops.td
@@ -94,6 +94,20 @@ def XLAGPU_AtomicRMWOp : XLAGPU_Op<"atomic_rmw",
// the updated value.
let regions = (region SizedRegion<1>:$computation);

let skipDefaultBuilders = 1;
let builders = [OpBuilder<(ins "mlir::Value":$memref, "mlir::ValueRange":$ivs)>];

let extraClassDeclaration = [{
mlir::Block* getBody() { return &getComputation().front(); }
mlir::OpBuilder getBodyBuilder() {
return mlir::OpBuilder(getBody(), std::prev(getBody()->end()));
}
// The value stored in tensor[ivs].
mlir::Value getCurrentValue() {
return getRegion().getArgument(0);
}
}];

let assemblyFormat = [{
$input `[` $indices `]` `:` type($input) $computation attr-dict
}];
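
[Editor's note] To make the new builder and body accessors concrete, here is one way a caller could emit a floating-point atomic add with them, mirroring the pattern `EmitScatterComputation` uses in scatter_mlir.cc below. This is a hedged sketch, not emitter code: `b`, `output`, `indices`, and `update` are illustrative names, the element type is assumed to be a float type, and the include path is inferred from the .td/.cc locations above.

#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/ImplicitLocOpBuilder.h"
#include "mlir/IR/Value.h"
#include "xla/service/gpu/fusions/mlir/ir/xla_gpu_ops.h"  // inferred path

// Emits output[indices] <- output[indices] + update as an xla_gpu.atomic_rmw
// and returns the updated tensor value.
mlir::Value EmitAtomicAdd(mlir::ImplicitLocOpBuilder& b, mlir::Value output,
                          mlir::ValueRange indices, mlir::Value update) {
  // The builder declared above creates a one-block region whose single block
  // argument is the current element of `output` at `indices`.
  auto rmw = b.create<xla::gpu::AtomicRMWOp>(output, indices);
  mlir::OpBuilder body_builder = rmw.getBodyBuilder();
  mlir::Value current = rmw.getCurrentValue();
  // Combine the current value with the update (assumes a float element type).
  mlir::Value summed = body_builder.create<mlir::arith::AddFOp>(
      rmw.getLoc(), current, update);
  body_builder.create<xla::gpu::YieldOp>(rmw.getLoc(), summed);
  return rmw.getResult();
}
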
129 changes: 85 additions & 44 deletions third_party/xla/xla/service/gpu/fusions/scatter_mlir.cc
@@ -26,9 +26,11 @@ limitations under the License.
#include "llvm/ADT/SmallVector.h"
#include "mlir/Dialect/Arith/IR/Arith.h" // from @llvm-project
#include "mlir/Dialect/Func/IR/FuncOps.h" // from @llvm-project
#include "mlir/Dialect/SCF/IR/SCF.h" // from @llvm-project
#include "mlir/Dialect/Tensor/IR/Tensor.h" // from @llvm-project
#include "mlir/IR/AffineExpr.h" // from @llvm-project
#include "mlir/IR/AffineMap.h" // from @llvm-project
#include "mlir/IR/Builders.h" // from @llvm-project
#include "mlir/IR/ImplicitLocOpBuilder.h" // from @llvm-project
#include "mlir/IR/MLIRContext.h" // from @llvm-project
#include "mlir/IR/Value.h" // from @llvm-project
@@ -49,8 +51,14 @@ namespace gpu {
namespace {

using llvm::SmallVector;
using mlir::Location;
using mlir::OpBuilder;
using mlir::Value;
using mlir::ValueRange;
using mlir::arith::AddIOp;
using mlir::arith::AndIOp;
using mlir::arith::CmpIOp;
using mlir::arith::CmpIPredicate;
using mlir::arith::ConstantIndexOp;
using mlir::func::ReturnOp;
using mlir::tensor::InsertOp;
Expand All @@ -59,22 +67,19 @@ using mlir_converter::CallTargetProvider;
using mlir_converter::PartitionedComputations;
using mlir_converter::ProvideParameter;

namespace scf = ::mlir::scf;

} // namespace

bool MlirScatterFusion::IsSupported(const HloFusionAnalysis& analysis) {
auto* scatter = Cast<HloScatterInstruction>(analysis.fusion_heroes().front());
if (!scatter->unique_indices()) {
LOG(ERROR) << "MlirScatterFusion with atomics is not yet implemented";
return false;
}
if (scatter->scatter_operand_count() != 1) {
LOG(ERROR) << "Variadic scatter is not supported like in the legacy "
"emitter, although it is possible to make it work when the "
"indices are unique.";
return false;
}
// Do not enable it for now.
return false;
return true;
}

std::optional<IndexingMap> MlirScatterFusion::ComputeThreadIdToOutputIndexing(
@@ -137,6 +142,33 @@ MlirScatterFusion::GetInstructionsWithCustomCodegen(
}

mlir::Value EmitScatterComputation(
const HloInstruction* scatter, ValueRange indices, Value update_elem,
Value output_tensor,
const mlir_converter::PartitionedComputation& root_computation,
const mlir_converter::CallTargetProvider& call_targets,
mlir::ImplicitLocOpBuilder& b) {
constexpr int kScatterOperandIndex = 0;
auto reducer =
call_targets(scatter->called_computations()[0]->root_instruction());
if (scatter->unique_indices()) {
auto operand_elem =
ProvideParameter(root_computation, scatter, kScatterOperandIndex,
indices, call_targets, b)[0];
auto reduced_val = mlir_converter::InlineBlock(
b, reducer.getBody().front(), {operand_elem, update_elem})[0];

return b.create<InsertOp>(reduced_val, output_tensor, indices);
}
auto atomic_rmw = b.create<AtomicRMWOp>(output_tensor, indices);
mlir::OpBuilder body_builder = atomic_rmw.getBodyBuilder();
auto reduced_val = mlir_converter::InlineBlock(
body_builder, reducer.getBody().front(),
{atomic_rmw.getCurrentValue(), update_elem})[0];
body_builder.create<xla::gpu::YieldOp>(reducer->getLoc(), reduced_val);
return atomic_rmw->getResult(0);
}

// The scatter has to be canonicalized with `scatter_simplifier` pass.
absl::Status MlirScatterFusion::EmitEntryFunction(
const PartitionedComputations& computations,
@@ -150,6 +182,7 @@
scatter->operand(kScatterOperandIndex);
const HloInstruction* scatter_indices =
scatter->operand(kScatterIndicesIndex);
const HloInstruction* scatter_update = scatter->operand(kScatterUpdateIndex);

mlir::MLIRContext* mlir_context = entry_function.getContext();
auto thread_id_to_update_map =
@@ -165,13 +198,9 @@
mlir::ImplicitLocOpBuilder b(entry_function.getLoc(), entry_function);
b.setInsertionPointToStart(entry_function.addEntryBlock());

int num_inputs = fusion.fused_instructions_computation()->num_parameters();
int num_outputs = entry_function.getArguments().size() - num_inputs;
auto output_tensor_args =
entry_function.getArguments().drop_front(num_inputs);
SmallVector<Value> result_tensors{output_tensor_args.begin(),
output_tensor_args.end()};
SmallVector<Value> result_tensors{entry_function.getArguments().back()};
auto c0 = b.create<ConstantIndexOp>(0);

auto scatter_result = EmitThreadLoopNest(
b, result_tensors, thread_id_to_update_map,
[&](ValueRange output_tensors, ValueRange dim_values,
@@ -185,40 +214,52 @@
update_tensor_indices, call_targets, b)
.front();

// Extract and clamp indices.
SmallVector<Value, 4> clamped_indices(scatter_operand->shape().rank(),
c0);
for (int i = 0; i < scatter_indices->shape().dimensions(1); ++i) {
SmallVector<Value, 4> indices_tensor_indices = {
update_tensor_indices.front(), b.create<ConstantIndexOp>(i)};
auto index =
ProvideParameter(root_computation, scatter, kScatterIndicesIndex,
indices_tensor_indices, call_targets, b)[0];
index = mlir_converter::ClampIndex(
index, /*is_unsigned=*/false,
scatter_operand->shape().dimensions(i), b);
index = b.create<mlir::arith::AddIOp>(index,
update_tensor_indices[i + 1]);
}
// Call scatter's computation.
auto reducer =
call_targets(scatter->called_computations()[0]->root_instruction());
if (scatter->unique_indices()) {
auto operand_elem =
ProvideParameter(root_computation, scatter, kScatterOperandIndex,
clamped_indices, call_targets, b)[0];
auto result_scalars = b.create<PureCallOp>(
reducer, llvm::ArrayRef({operand_elem, update_elem}));
SmallVector<Value> updated_operand;
updated_operand.reserve(num_outputs);
for (auto [tensor, value] :
llvm::zip(output_tensors, result_scalars.getResults())) {
updated_operand.push_back(
b.create<InsertOp>(value, tensor, clamped_indices));
// Extract slice offsets from the scatter_indices operand and check whether
// the whole slice of the scatter_update operand fits into the output.
mlir::Value is_in_bounds =
b.create<mlir::arith::ConstantIntOp>(1, b.getI1Type());
SmallVector<Value, 4> indices{
llvm::ArrayRef(update_tensor_indices).drop_front()};
for (int i = 0; i < scatter_operand->shape().rank(); ++i) {
Value extracted_index = c0;
if (i < scatter_indices->shape().dimensions(1)) {
SmallVector<Value, 4> indices_tensor_indices = {
update_tensor_indices.front(), b.create<ConstantIndexOp>(i)};
extracted_index = ProvideParameter(
root_computation, scatter, kScatterIndicesIndex,
indices_tensor_indices, call_targets, b)[0];
if (extracted_index.getType() != b.getIndexType()) {
extracted_index = b.create<mlir::arith::IndexCastOp>(
b.getIndexType(), extracted_index);
}
}
return updated_operand;
is_in_bounds = b.create<AndIOp>(
is_in_bounds,
b.create<CmpIOp>(CmpIPredicate::sge, extracted_index, c0));
Value ub = b.create<ConstantIndexOp>(
scatter_operand->shape().dimensions(i) -
scatter_update->shape().dimensions(i + 1));
is_in_bounds = b.create<AndIOp>(
is_in_bounds,
b.create<CmpIOp>(CmpIPredicate::sle, extracted_index, ub));
indices[i] = b.create<AddIOp>(extracted_index, indices[i]);
}
return output_tensors;
// Call scatter's computation if is_in_bounds.
Value output_tensor = output_tensors.front();
Value predicated_update =
b.create<scf::IfOp>(
is_in_bounds,
[&](OpBuilder& then_builder, Location then_loc) -> void {
Value updated_output = EmitScatterComputation(
scatter, indices, update_elem, output_tensor,
root_computation, call_targets, b);
b.create<scf::YieldOp>(updated_output);
},
[&](OpBuilder& else_b, Location else_loc) {
b.create<scf::YieldOp>(output_tensor);
})
.getResult(0);
return {predicated_update};
});
b.create<ReturnOp>(scatter_result);
return absl::OkStatus();
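
[Editor's note] For intuition about the bounds check emitted above: a slice of the update that starts at `offset` along an output dimension fits iff 0 <= offset <= operand_dim - update_dim, and the per-dimension predicates are AND-ed together before guarding the scf.if. A tiny host-side illustration follows (not emitter code; the sizes are made up).

#include <cstdint>
#include <iostream>

// Mirrors the per-dimension predicate built with CmpIOp/AndIOp above.
bool SliceFits(int64_t offset, int64_t operand_dim, int64_t update_dim) {
  return offset >= 0 && offset <= operand_dim - update_dim;
}

int main() {
  // Output dimension of size 10, update slice of size 3: valid offsets are
  // 0..7, so 7 fits and 8 does not.
  std::cout << SliceFits(7, 10, 3) << " " << SliceFits(8, 10, 3) << "\n";
  return 0;
}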
