MLIR emitters: Vectorize column reductions.
Special thanks to GitHub user lingzhi98, who experimented with this in
openxla/xla#11018.

I tried to keep the logic for vectorized and non-vectorized reductions as
similar as possible. The vectorized logic looks like this (a sketch follows
the list):

- produce N reduced elements per thread, store the intermediate results in
  a vector V
- loop over the N elements of V, writing each one to shmem
- loop over N elements, reading them from shmem and writing the result to
  global memory
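A minimal CUDA sketch of this scheme for N = 2 (an editorial illustration,
not the emitter's actual output; the kernel name, the 32x8 thread block, the
float element type, and addition as the reduction op are all assumptions):

__global__ void ColumnReduceVectorized(const float* in, float* out,
                                       int rows, int cols) {
  constexpr int kN = 2;  // N reduced elements (adjacent columns) per thread
  // One shmem row per y-slice of the block; padded to avoid bank conflicts.
  __shared__ float shmem[8][32 * kN + 1];

  // Assumes blockDim == (32, 8) and cols % (blockDim.x * kN) == 0.
  int col0 = (blockIdx.x * blockDim.x + threadIdx.x) * kN;

  // Step 1: produce N reduced elements per thread, kept in a vector V.
  float v[kN] = {0.0f, 0.0f};
  for (int row = threadIdx.y; row < rows; row += blockDim.y) {
    for (int i = 0; i < kN; ++i) {
      v[i] += in[row * cols + col0 + i];  // adjacent columns: one 2-wide load
    }
  }

  // Step 2: loop over the N elements of V, writing each one to shmem.
  for (int i = 0; i < kN; ++i) {
    shmem[threadIdx.y][threadIdx.x * kN + i] = v[i];
  }
  __syncthreads();

  // Step 3: loop over the N elements, reading them from shmem and writing
  // the result to global memory.
  if (threadIdx.y == 0) {
    for (int i = 0; i < kN; ++i) {
      float total = 0.0f;
      for (int y = 0; y < blockDim.y; ++y) {
        total += shmem[y][threadIdx.x * kN + i];
      }
      out[col0 + i] = total;
    }
  }
}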

PiperOrigin-RevId: 636130464
jreiffers authored and tensorflower-gardener committed May 22, 2024
1 parent 98a4c09 commit cd3853d
Showing 13 changed files with 293 additions and 101 deletions.
3 changes: 2 additions & 1 deletion third_party/xla/xla/service/gpu/fusions/BUILD
@@ -870,7 +870,6 @@ cc_library(
"@com_google_absl//absl/algorithm:container",
"@com_google_absl//absl/container:flat_hash_map",
"@com_google_absl//absl/container:flat_hash_set",
"@com_google_absl//absl/log:check",
"@com_google_absl//absl/status",
"@com_google_absl//absl/types:span",
"@llvm-project//llvm:Support",
@@ -880,7 +879,9 @@ cc_library(
"@llvm-project//mlir:GPUDialect",
"@llvm-project//mlir:IR",
"@llvm-project//mlir:SCFDialect",
"@llvm-project//mlir:Support",
"@llvm-project//mlir:TensorDialect",
"@llvm-project//mlir:VectorDialect",
],
)

3 changes: 3 additions & 0 deletions third_party/xla/xla/service/gpu/fusions/mlir/BUILD
@@ -103,6 +103,7 @@ cc_library(
"@llvm-project//mlir:SCFDialect",
"@llvm-project//mlir:Support",
"@llvm-project//mlir:TensorDialect",
"@llvm-project//mlir:VectorDialect",
"@local_tsl//tsl/platform:errors",
"@local_tsl//tsl/platform:statusor",
],
@@ -214,6 +215,7 @@ cc_library(
"@llvm-project//mlir:TensorDialect",
"@llvm-project//mlir:ToLLVMIRTranslation",
"@llvm-project//mlir:Transforms",
"@llvm-project//mlir:VectorDialect",
"@local_tsl//tsl/platform:errors",
"@local_tsl//tsl/platform:statusor",
],
@@ -334,6 +336,7 @@ cc_library(
"@llvm-project//mlir:Support",
"@llvm-project//mlir:TensorDialect",
"@llvm-project//mlir:TransformUtils",
"@llvm-project//mlir:VectorToLLVM",
"@llvm-project//mlir:VectorTransforms",
],
)
@@ -43,11 +43,13 @@ limitations under the License.
#include "mlir/Dialect/Func/IR/FuncOps.h" // from @llvm-project
#include "mlir/Dialect/SCF/IR/SCF.h" // from @llvm-project
#include "mlir/Dialect/Tensor/IR/Tensor.h" // from @llvm-project
#include "mlir/Dialect/Vector/IR/VectorOps.h" // from @llvm-project
#include "mlir/IR/AffineExpr.h" // from @llvm-project
#include "mlir/IR/Attributes.h" // from @llvm-project
#include "mlir/IR/BuiltinAttributeInterfaces.h" // from @llvm-project
#include "mlir/IR/BuiltinAttributes.h" // from @llvm-project
#include "mlir/IR/BuiltinOps.h" // from @llvm-project
#include "mlir/IR/BuiltinTypeInterfaces.h" // from @llvm-project
#include "mlir/IR/BuiltinTypes.h" // from @llvm-project
#include "mlir/IR/IRMapping.h" // from @llvm-project
#include "mlir/IR/ImplicitLocOpBuilder.h" // from @llvm-project
@@ -1297,10 +1299,25 @@ SmallVector<Value> EmitLoopNest(
mlir::function_ref<SmallVector<Value>(ValueRange /*iter_args*/,
ValueRange /*dim_values*/,
ValueRange /*symbol_values*/)>
create_body,
bool vectorize) {
SmallVector<Value, 4> lbs, ubs, steps;
GetLoopBoundsFromIndexingMap(b, indexing_map, &lbs, &ubs, &steps);

SmallVector<Value, 4> vector_inits;
if (vectorize) {
CHECK_EQ(indexing_map.GetSymbolBounds().back().lower, 0);
int vector_size = indexing_map.GetSymbolBounds().back().upper + 1;
vector_inits = iter_args_inits;
for (auto& init : vector_inits) {
if (!mlir::isa<mlir::ShapedType>(init.getType())) {
auto vector_ty = mlir::VectorType::get({vector_size}, init.getType());
init = b.create<mlir::vector::SplatOp>(vector_ty, init);
}
}
iter_args_inits = vector_inits;
}

scf::LoopNest loop_nest = scf::buildLoopNest(
b, b.getLoc(), lbs, ubs, steps, iter_args_inits,
[&](OpBuilder& nested_builder, Location loc, ValueRange symbol_values,
@@ -1313,7 +1330,28 @@
[&](OpBuilder& then_builder, Location then_loc) -> void {
OpBuilder::InsertionGuard g(b);
b.setInsertionPointToStart(then_builder.getInsertionBlock());
SmallVector<Value, 4> results;
if (vectorize) {
SmallVector<Value, 4> vector_args;
vector_args = iter_args;
// Extract the vector elements.
for (auto& init : vector_args) {
if (mlir::isa<mlir::VectorType>(init.getType())) {
init = b.create<mlir::vector::ExtractOp>(
init, symbol_values.back());
}
}
results = create_body(vector_args, dim_values, symbol_values);
// Insert the results.
for (auto [index, init] : llvm::enumerate(iter_args)) {
if (mlir::isa<mlir::VectorType>(init.getType())) {
results[index] = b.create<mlir::vector::InsertOp>(
results[index], iter_args[index], symbol_values.back());
}
}
} else {
results = create_body(iter_args, dim_values, symbol_values);
}
b.create<scf::YieldOp>(results);
},
[&](OpBuilder& else_b, Location else_loc) {
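For a scalar f32 accumulator and vector size 2, the loop produced by the code
above has roughly this shape (a hand-written sketch, not actual emitter
output):

%init = vector.splat %cst : vector<2xf32>
%result = scf.for %s = %c0 to %c2 step %c1
    iter_args(%acc = %init) -> (vector<2xf32>) {
  %elem = vector.extract %acc[%s] : f32 from vector<2xf32>
  %sum = arith.addf %elem, %val : f32
  %new_acc = vector.insert %sum, %acc[%s] : f32 into vector<2xf32>
  scf.yield %new_acc : vector<2xf32>
}

The scalar init is splatted (SplatOp), the body still operates on the scalar
extracted at the current symbol value (ExtractOp), and its result is written
back into the accumulator vector (InsertOp).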
@@ -99,13 +99,26 @@ mlir::Value CheckConstraints(const IndexingMap& map, mlir::ValueRange dims,

// Emits a loop nest over the entire domain of the indexing_map at a point
// `dim_values`.
// If `vectorize` is set, the loop over the last symbol in `indexing_map`
// essentially turns into multiple independent loops, and their results are
// returned together as a vector:
// - the body will still be called with scalars and should return scalars.
// - the symbol's bounds should be [0, 1] or [0, 3] (inclusive, i.e. vector
//   size 2 or 4) for vectorization to have an effect. [0, 0] is supported
//   and has no effect. The lower bound must be 0.
// - all scalar results of `EmitLoopNest` will become vectors instead. Scalar
//   inits will be initialized with a vector splat. Passing a vector init is
//   supported.
// - Tensor arguments and results are unaffected.
llvm::SmallVector<mlir::Value> EmitLoopNest(
mlir::ImplicitLocOpBuilder& b, mlir::ValueRange dim_values,
mlir::ValueRange iter_args_inits, const IndexingMap& indexing_map,
mlir::function_ref<llvm::SmallVector<mlir::Value>(
mlir::ValueRange iter_args, mlir::ValueRange dim_values,
mlir::ValueRange symbol_values)>
create_body,
bool vectorize = false);

// Same as EmitLoopNest, but the body building function can return an error
// which gets returned from EmitLoopNestWithStatus.
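To illustrate the contract documented above, a hypothetical call site could
look as follows (an editorial sketch, not code from this commit: `b`,
`thread_and_block_ids`, `acc_init` — a scalar f32 — `loaded_value`, and
`map`, whose last symbol has bounds [0, 1], are all assumed to exist):

llvm::SmallVector<mlir::Value> results = mlir_converter::EmitLoopNest(
    b, /*dim_values=*/thread_and_block_ids, /*iter_args_inits=*/{acc_init},
    map,
    [&](mlir::ValueRange iter_args, mlir::ValueRange dim_values,
        mlir::ValueRange symbol_values) -> llvm::SmallVector<mlir::Value> {
      // The body still sees a scalar: iter_args[0] is the element of the
      // accumulator vector selected by the last symbol value.
      mlir::Value sum =
          b.create<mlir::arith::AddFOp>(iter_args[0], loaded_value);
      return {sum};
    },
    /*vectorize=*/true);
// Because acc_init is a scalar, it is splatted to vector<2xf32>, and
// results[0] comes back as a vector<2xf32> holding two independent partial
// reductions.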
3 changes: 3 additions & 0 deletions third_party/xla/xla/service/gpu/fusions/mlir/lower_tensors.cc
@@ -393,6 +393,9 @@ struct RewriteNonScalarConstants
mlir::LogicalResult matchAndRewrite(
mlir::arith::ConstantOp op,
mlir::PatternRewriter& rewriter) const override {
if (mlir::isa<mlir::VectorType>(op.getType())) {
return rewriter.notifyMatchFailure(op, "the op is a vector constant");
}
auto shaped_ty = mlir::dyn_cast<mlir::ShapedType>(op.getValue().getType());
// We only need to rewrite non-scalar constants.
if (!shaped_ty || shaped_ty.getNumElements() < 2) {
2 changes: 2 additions & 0 deletions third_party/xla/xla/service/gpu/fusions/mlir/lower_to_llvm.cc
@@ -26,6 +26,7 @@ limitations under the License.
#include "mlir/Conversion/LLVMCommon/TypeConverter.h" // from @llvm-project
#include "mlir/Conversion/MathToLLVM/MathToLLVM.h" // from @llvm-project
#include "mlir/Conversion/SCFToControlFlow/SCFToControlFlow.h" // from @llvm-project
#include "mlir/Conversion/VectorToLLVM/ConvertVectorToLLVM.h" // from @llvm-project
#include "mlir/Dialect/Arith/IR/Arith.h" // from @llvm-project
#include "mlir/Dialect/Arith/Transforms/Passes.h" // from @llvm-project
#include "mlir/Dialect/Complex/IR/Complex.h" // from @llvm-project
@@ -67,6 +68,7 @@ class LowerToLLVMPass : public impl::LowerToLLVMPassBase<LowerToLLVMPass> {
patterns);
mlir::populateGpuToNVVMConversionPatterns(type_converter, patterns);
mlir::populateFuncToLLVMConversionPatterns(type_converter, patterns);
mlir::populateVectorToLLVMConversionPatterns(type_converter, patterns);
mlir::cf::populateControlFlowToLLVMConversionPatterns(type_converter,
patterns);
mlir::populateComplexToLLVMConversionPatterns(type_converter, patterns);
@@ -56,6 +56,7 @@ limitations under the License.
#include "mlir/Dialect/MemRef/Transforms/Passes.h" // from @llvm-project
#include "mlir/Dialect/SCF/IR/SCF.h" // from @llvm-project
#include "mlir/Dialect/Tensor/IR/Tensor.h" // from @llvm-project
#include "mlir/Dialect/Vector/IR/VectorOps.h" // from @llvm-project
#include "mlir/IR/Attributes.h" // from @llvm-project
#include "mlir/IR/Builders.h" // from @llvm-project
#include "mlir/IR/BuiltinAttributes.h" // from @llvm-project
@@ -361,7 +362,8 @@ MlirFusionEmitterBase::CreateMLIRModule(
mlir::arith::ArithDialect, mlir::cf::ControlFlowDialect,
mlir::math::MathDialect, mlir::scf::SCFDialect,
mlir::mhlo::MhloDialect, mlir::gpu::GPUDialect,
mlir::vector::VectorDialect, mlir::NVVM::NVVMDialect,
xla::gpu::XlaGpuDialect>();
mlir::DialectRegistry registry;
mlir::func::registerInlinerExtension(registry);
mlir::registerBuiltinDialectTranslation(registry);
@@ -456,9 +458,10 @@ SmallVector<Value> MlirFusionEmitterBase::EmitThreadLoopNest(
const IndexingMap& indexing_map,
const std::function<
SmallVector<Value>(ValueRange outputs_tensors, ValueRange dim_values,
ValueRange symbol_values)>& create_body,
bool vectorize) const {
return mlir_converter::EmitLoopNest(b, EmitThreadAndBlockIds(b), outputs,
indexing_map, create_body, vectorize);
}

absl::Status MlirFusionEmitterBase::EmitMlir(
@@ -96,12 +96,15 @@ class MlirFusionEmitterBase : public KernelFusionInterface {
// the symbol 0 as the outermost loop. The indices of the map's dimensions and
// symbols are passed to the lambda separately. The return values of the
// function are the updated outputs.
// For the meaning of `vectorize`, see the documentation of `EmitLoopNest` in
// elemental_hlo_to_mlir.h.
llvm::SmallVector<mlir::Value> EmitThreadLoopNest(
mlir::ImplicitLocOpBuilder& b, mlir::ValueRange outputs,
const IndexingMap& indexing_map,
const std::function<llvm::SmallVector<mlir::Value>(
mlir::ValueRange outputs, mlir::ValueRange dim_values,
mlir::ValueRange symbol_values)>& create_body,
bool vectorize = false) const;

mlir::Value EmitBlockId(mlir::ImplicitLocOpBuilder& builder, int dim) const;
mlir::Value EmitThreadId(mlir::ImplicitLocOpBuilder& builder, int dim) const;
@@ -174,6 +174,19 @@ module {

// -----

module {
func.func @vector_constant() -> vector<2xindex> {
%c1 = arith.constant dense<[1, 2]> : vector<2xindex>
func.return %c1 : vector<2xindex>
}
}

// Vector constants should not be rewritten.
// CHECK: @vector_constant
// CHECK-NEXT: arith.constant

// -----

module {
func.func @complex_tensor_insert(
%arg0: tensor<10xcomplex<f32>>) -> tensor<10xcomplex<f32>> {
86 changes: 56 additions & 30 deletions third_party/xla/xla/service/gpu/fusions/reduction_base.cc
@@ -70,37 +70,43 @@ int RowReductionGetRowsPerWarp(int reduced_dimension_size) {

int GetVectorSize(const HloFusionAnalysis& analysis,
const ReductionDimensions& reduction_dimensions,
                  int num_threads, Vector3 reduction_tiling, bool for_mlir) {
  // If the minor dimension is not divisible by 2, we can't currently
  // vectorize.
  int64_t minor_dim = reduction_dimensions.dimensions.back();
  if (minor_dim % 2 != 0) {
    return 1;
  }

  // Only enable vectorization if all threads will still have work.
  if (num_threads * 2 > minor_dim) {
    return 1;
  }

  if (MayPreventVectorization(analysis.fusion())) {
    return 1;
  }

  if (reduction_dimensions.is_row_reduction) {
    constexpr int kRowMinorReduced =
        ReductionDimensions::kRowMinorReducedDimension;

    const auto* cuda_cc = std::get_if<se::CudaComputeCapability>(
        &analysis.device_info().gpu_compute_capability());
    if (cuda_cc == nullptr) return 1;
    if (cuda_cc->IsAtLeast(se::CudaComputeCapability::VOLTA)) return 2;
    if (cuda_cc->IsAtLeast(se::CudaComputeCapability::PASCAL_)) {
      return analysis.input_output_info().smallest_input_dtype_bits <= 32 &&
                     reduction_dimensions.dimensions[kRowMinorReduced] %
                             (reduction_tiling[kRowMinorReduced] *
                              num_threads) ==
                         0
                 ? 2
                 : 1;
    }
    return 1;
  }

return for_mlir ? 2 : 1;
}

ReductionGroups GroupDisjointReductions(const HloFusionAnalysis& analysis,
@@ -301,14 +307,14 @@ ReductionInfo ReductionInfo::Create(const HloFusionAnalysis& analysis,
}

int vector_size = GetVectorSize(analysis, reduction_dimensions, num_threads_x,
reduction_tiling, for_mlir);

absl::InlinedVector<int64_t, 4> num_threads{1, num_threads_y, num_threads_x};
absl::InlinedVector<int64_t, 4> tiled_shape{shape[0], shape[1],
shape[2] / vector_size};
absl::InlinedVector<int64_t, 4> tile_per_thread{
reduction_tiling[0], reduction_tiling[1],
std::max<int64_t>(reduction_tiling[2] / vector_size, 1)};
if (for_mlir) {
// The indexing map simplifier does not currently handle this correctly,
// leading to loop bounds that are too large.
Expand All @@ -332,12 +338,27 @@ ReductionInfo ReductionInfo::Create(const HloFusionAnalysis& analysis,
// uses the thread ID as the coordinate.
tile_per_thread[2] = 1;
}
if (vector_size != 1 ||
(for_mlir && !reduction_dimensions.is_row_reduction)) {
num_threads.push_back(1); // The vector dimension is a loop.
tiled_shape.push_back(vector_size);
tile_per_thread.push_back(vector_size);
}

// The MLIR emitter treats the last tiled dimension as the number of parallel
// independent reductions per thread (to use vectorized loads). This is only
// needed for column reductions: row reductions can use vectorized loads for
// the same reduction.
// row reduction: [[a, b], [c, d]] -> [a + b, c + d]
// column reduction: [[a, b], [c, d]] -> [a + c, b + d]
// In both cases [a, b] are loaded together, but only in the column reduction
// do they contribute to different result elements.
if (for_mlir && reduction_dimensions.is_row_reduction) {
num_threads.push_back(1);
tiled_shape.push_back(1);
tile_per_thread.push_back(1);
}

Tiling tiling(tiled_shape, tile_per_thread, num_threads,
/*loops_to_unroll=*/{false, false, true, false});
bool reduction_is_race_free = ReductionIsRaceFree(
@@ -404,13 +425,18 @@ std::optional<IndexingMap> ReductionInfo::ComputeThreadIdToOutputIndexing(
physical_shape, ctx));
}

mlir::SmallVector<mlir::AffineExpr> projected_dims{
block_offsets.getResult(kColMajorKept),
block_offsets.getResult(kColMinorKept) + thread_ids[kColReduced]};
std::vector<RangeVar> range_vars;
if (thread_ids.size() == 4) {
int vector_size = tiling_.GetThreadTileSize().back();
range_vars.push_back({0, vector_size - 1});
projected_dims.push_back(mlir::getAffineSymbolExpr(0, ctx));
}
IndexingMap projected_index(
mlir::AffineMap::get(6, range_vars.size(), projected_dims, ctx),
dimension_ranges, range_vars, /*rt_vars=*/{});

projected_index.AddConstraint(
mlir::getAffineDimExpr(
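As a plain-C++ illustration of the row-versus-column comment in
ReductionInfo::Create above (an editorial sketch, not part of the commit):

#include <array>
#include <cstdio>

int main() {
  // in = [[a, b], [c, d]]
  std::array<std::array<float, 2>, 2> in = {{{1.0f, 2.0f}, {3.0f, 4.0f}}};
  // Row reduction: the 2-wide load [a, b] feeds a single output element.
  std::array<float, 2> row = {in[0][0] + in[0][1], in[1][0] + in[1][1]};
  // Column reduction: the same load [a, b] feeds two different output
  // elements, i.e. two independent reductions per thread.
  std::array<float, 2> col = {in[0][0] + in[1][0], in[0][1] + in[1][1]};
  std::printf("row: [%g, %g]  col: [%g, %g]\n", row[0], row[1], col[0],
              col[1]);
  return 0;
}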
11 changes: 6 additions & 5 deletions third_party/xla/xla/service/gpu/fusions/reduction_base_test.cc
@@ -556,22 +556,23 @@ TEST_F(ReductionTest, MlirColumnReduction) {
EXPECT_THAT(
fusion.ComputeThreadIdToInputIndexing(0, 0, &mlir_context_)->ToString(),
MatchIndexingString(R"(
(d0, d1, d2, d3, d4, d5)[s0, s1, s2, s3] -> (
d3 floordiv 24,
d0 floordiv 32 + s1 * 32,
((d3 mod 24) * 32 + d0 mod 32) * 2 + s3
)
domain:
d0 in [0, 1023]
d1 in [0, 0]
d2 in [0, 0]
d3 in [0, 4607]
d4 in [0, 0]
d5 in [0, 0]
s0 in [0, 0]
s1 in [0, 1]
s2 in [0, 0]
s3 in [0, 1]
(d3 mod 24) * 32 + d0 mod 32 in [0, 767]
d0 floordiv 32 + s1 * 32 in [0, 63]
)"));
}
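For reference, the updated expectations are internally consistent with a
vector size of 2: the column expression ((d3 mod 24) * 32 + d0 mod 32) * 2 + s3
spans [0, 1535], i.e. 1536 physical columns handled as 768 two-element
vectors; 768 vectorized columns / 32 lanes = 24 column blocks per tile row,
hence d3 floordiv 24, d3 mod 24, and the [0, 767] constraint on the
vectorized column index, while the new symbol s3 in [0, 1] picks the element
within each vector.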
