[XLA:GPU][IndexAnalysis] Add DimVar, RangeVar and RTVar to IndexingMap.
RTVar is a new kind of symbol/variable that is associated with the runtime value of an HLO instruction; it can be used to model ops whose indexing depends on data, e.g. dynamic-update-slice (DUS), gather, etc. A sketch of the three variable kinds follows the file summary below.

PiperOrigin-RevId: 616791671
pifon2a authored and tensorflower-gardener committed Mar 19, 2024
1 parent d022a42 commit f05f34c
Showing 15 changed files with 385 additions and 163 deletions.
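For orientation before the per-file diffs: the three variable kinds and the constructor shape they imply can be reconstructed from the call sites in this commit. The sketch below is exactly that, a reconstruction; field names, defaults, and the integral RTVarID are assumptions, and the real definitions live in xla/service/gpu/model/indexing_map.h.

```cpp
#include <cstdint>

// Inclusive bounds [lower, upper]; value-initializes to the single point 0.
struct Interval {
  int64_t lower = 0;
  int64_t upper = 0;
};

// Bounds of one affine dimension of the map (e.g. a thread or block id).
struct DimVar {
  Interval bounds;
};

// Bounds of one symbol the map iterates over (e.g. an unroll or window loop).
struct RangeVar {
  Interval range;
};

// Handle to a runtime value produced by an HLO instruction, e.g. the start
// index of a dynamic-update-slice or the indices operand of a gather.
using RTVarID = int64_t;  // assumption: an integral id
struct RTVar {
  RTVarID id;
};
```

An IndexingMap is then built from three vectors plus optional constraints, in this order: IndexingMap(indexing_context, affine_map, dim_vars, range_vars, rt_vars[, constraints]).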
30 changes: 15 additions & 15 deletions third_party/xla/xla/service/gpu/fusions/fusion_emitter.cc
@@ -172,27 +172,27 @@ IndexingMap KernelFusionInterface::GetDefaultThreadIdToOutputIndexingMap(
divisor *= output_shape.dimensions(dimension);
}

-  std::vector<Interval> dimension_ranges = {
-      {0, static_cast<int64_t>(launch_dims.thread_counts_per_block().x) - 1},
-      {0, static_cast<int64_t>(launch_dims.thread_counts_per_block().y) - 1},
-      {0, static_cast<int64_t>(launch_dims.thread_counts_per_block().z) - 1},
-      {0, static_cast<int64_t>(launch_dims.block_counts().x) - 1},
-      {0, static_cast<int64_t>(launch_dims.block_counts().y) - 1},
-      {0, static_cast<int64_t>(launch_dims.block_counts().z) - 1},
+  std::vector<DimVar> dim_vars = {
+      {{0, static_cast<int64_t>(launch_dims.thread_counts_per_block().x) - 1}},
+      {{0, static_cast<int64_t>(launch_dims.thread_counts_per_block().y) - 1}},
+      {{0, static_cast<int64_t>(launch_dims.thread_counts_per_block().z) - 1}},
+      {{0, static_cast<int64_t>(launch_dims.block_counts().x) - 1}},
+      {{0, static_cast<int64_t>(launch_dims.block_counts().y) - 1}},
+      {{0, static_cast<int64_t>(launch_dims.block_counts().z) - 1}},
};
-  std::vector<Interval> symbol_ranges;
+  std::vector<RangeVar> range_vars;
int64_t num_elements = ShapeUtil::ElementsIn(output_shape);
-  symbol_ranges.push_back(
-      {0, CeilOfRatio(num_elements,
-                      static_cast<int64_t>(launch_dims.launch_bound()) *
-                          unroll_factor) -
-              1});
-  symbol_ranges.push_back({0, unroll_factor - 1});
+  range_vars.push_back(
+      {{0, CeilOfRatio(num_elements,
+                       static_cast<int64_t>(launch_dims.launch_bound()) *
+                           unroll_factor) -
+               1}});
+  range_vars.push_back({0, unroll_factor - 1});
IndexingMap indexing_map(
indexing_context,
mlir::AffineMap::get(/*dimCount=*/6,
/*symbolCount=*/2, output_dims, mlir_context),
-      dimension_ranges, symbol_ranges);
+      dim_vars, range_vars, /*rt_vars=*/{});
// Remove the unroll_elem_id symbol if unrolling divides num_elements.
if (num_elements % unroll_factor == 0) {
indexing_map.AddConstraint(linear_index.replace({{unroll_elem_id, c0}}),
@@ -576,8 +576,8 @@ Value CheckConstraints(const IndexingMap& map, ValueRange dims,
ret, CheckConstraint(ApplyAffineExpr(expression, dims, symbols, b),
range, b));
}
-  for (auto&& [index, range] : llvm::enumerate(map.GetDimensionRanges())) {
-    ret = b.create<AndIOp>(ret, CheckConstraint(dims[index], range, b));
+  for (auto&& [index, bound] : llvm::enumerate(map.GetDimensionBounds())) {
+    ret = b.create<AndIOp>(ret, CheckConstraint(dims[index], bound, b));
}
return ret;
}
@@ -1061,9 +1061,9 @@ void GetLoopBoundsFromIndexingMap(ImplicitLocOpBuilder& b,
SmallVectorImpl<Value>* steps) {
Value c1 = b.create<ConstantIndexOp>(1);

-  for (const Interval& range : indexing_map.GetSymbolRanges()) {
-    lbs->push_back(b.create<ConstantIndexOp>(range.lower));
-    ubs->push_back(b.create<ConstantIndexOp>(range.upper + 1));
+  for (const Interval& bound : indexing_map.GetSymbolBounds()) {
+    lbs->push_back(b.create<ConstantIndexOp>(bound.lower));
+    ubs->push_back(b.create<ConstantIndexOp>(bound.upper + 1));
// Note that this is not optimal, when there are mod constraints on symbols,
// e.g. for reduce-window. In that case we have to extract loop steps from
// the mod constraints.
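To make the two RangeVar bounds above concrete, here is a worked instance with invented numbers (variable names mirror the code; CeilOfRatio is the helper already used in this file):

```cpp
// Hypothetical launch: 1-D output of 10,000 elements, 128 threads/block x
// 4 blocks (launch_bound = 512), unroll factor 2.
int64_t num_elements = 10000;
int64_t launch_bound = 512;
int64_t unroll_factor = 2;

std::vector<RangeVar> range_vars;
// s0 counts per-thread chunks: CeilOfRatio(10000, 512 * 2) - 1 = 10 - 1 = 9,
// so the first symbol gets bounds [0, 9].
range_vars.push_back(
    {{0, CeilOfRatio(num_elements, launch_bound * unroll_factor) - 1}});
// s1 indexes inside one unrolled chunk: bounds [0, 1].
range_vars.push_back({{0, unroll_factor - 1}});
```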
13 changes: 7 additions & 6 deletions third_party/xla/xla/service/gpu/fusions/mlir/simplify_affine.cc
@@ -100,15 +100,15 @@ struct RewriteAffineApply
mlir::affine::AffineApplyOp op,
mlir::PatternRewriter& rewriter) const override {
auto affine_map = op.getAffineMap();
-    std::vector<Interval> dim_ranges(affine_map.getNumDims());
-    std::vector<Interval> symbol_ranges(affine_map.getNumSymbols());
+    std::vector<DimVar> dim_ranges(affine_map.getNumDims());
+    std::vector<RangeVar> symbol_ranges(affine_map.getNumSymbols());

for (int i = 0; i < affine_map.getNumInputs(); ++i) {
if (auto range = GetRange(op->getOperand(i))) {
if (i >= dim_ranges.size()) {
-          symbol_ranges[i - dim_ranges.size()] = *range;
+          symbol_ranges[i - dim_ranges.size()] = RangeVar{*range};
        } else {
-          dim_ranges[i] = *range;
+          dim_ranges[i] = DimVar{*range};
}
} else {
return rewriter.notifyMatchFailure(op, "failed to deduce range");
@@ -117,11 +117,12 @@ struct RewriteAffineApply

IndexingContext indexing_context(op->getContext());
IndexingMap map(&indexing_context, op.getAffineMap(), dim_ranges,
-                    symbol_ranges);
+                    symbol_ranges, /*rt_vars=*/{});
map.Simplify();
auto expr = map.GetAffineMap().getResult(0);

-    RangeEvaluator range_evaluator(dim_ranges, symbol_ranges, op->getContext());
+    RangeEvaluator range_evaluator(map.GetDimensionBounds(),
+                                   map.GetSymbolBounds(), op->getContext());
std::function<bool(mlir::AffineExpr)> can_be_lowered;
bool fits_32_bits = true;
can_be_lowered = [&](mlir::AffineExpr expr) {
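Right after this hunk (not shown in the diff), the pass decides whether the lowered arithmetic fits in 32 bits. A hedged sketch of that check, assuming RangeEvaluator::ComputeExpressionRange returns the Interval an affine expression can take given the map's bounds, and that mlir::AffineExpr::walk visits every sub-expression:

```cpp
#include <limits>

bool fits_32_bits = true;
expr.walk([&](mlir::AffineExpr sub_expr) {
  // Bound each sub-expression and require it to stay within int32.
  Interval bounds = range_evaluator.ComputeExpressionRange(sub_expr);
  fits_32_bits &= bounds.lower >= std::numeric_limits<int32_t>::min() &&
                  bounds.upper <= std::numeric_limits<int32_t>::max();
});
```

Feeding the evaluator map.GetDimensionBounds()/GetSymbolBounds() instead of the locally built vectors (the actual change in this hunk) means any bound tightening performed by map.Simplify() is visible to this check.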
12 changes: 6 additions & 6 deletions third_party/xla/xla/service/gpu/fusions/reduction_base.cc
@@ -333,12 +333,12 @@ std::optional<IndexingMap> ReductionInfo::ComputeThreadIdToOutputIndexing(

auto physical_shape = ShapeUtil::DeleteDimensions(hero->dimensions(),
hero->operand(0)->shape());
-  std::vector<Interval> dimension_ranges{
-      {0, tiling_.GetNumThreadsPerBlock() - 1},
+  std::vector<DimVar> dimension_ranges{
+      {{0, tiling_.GetNumThreadsPerBlock() - 1}},
       {},
       {},
-      {0, tiling_.GetNumBlocks() - 1},
-      {0, static_cast<int64_t>(groups_.grouped_roots.size() - 1)},
+      {{0, tiling_.GetNumBlocks() - 1}},
+      {{0, static_cast<int64_t>(groups_.grouped_roots.size() - 1)}},
{},
};

@@ -357,7 +357,7 @@ std::optional<IndexingMap> ReductionInfo::ComputeThreadIdToOutputIndexing(
mlir::AffineMap::get(
6, 0, block_offsets.getResult(kRowKept) + thread_ids[kRowKept],
mlir_context),
-      dimension_ranges, {});
+      dimension_ranges, /*range_vars=*/{}, /*rt_vars=*/{});
int rows_per_warp = GetRowsPerWarp();
if (rows_per_warp > 1) {
linear_index.AddConstraint(
@@ -379,7 +379,7 @@ std::optional<IndexingMap> ReductionInfo::ComputeThreadIdToOutputIndexing(
{block_offsets.getResult(kColMajorKept),
block_offsets.getResult(kColMinorKept) + thread_ids[kColReduced]},
mlir_context),
-      dimension_ranges, {});
+      dimension_ranges, /*range_vars=*/{}, /*rt_vars=*/{});

projected_index.AddConstraint(
mlir::getAffineDimExpr(
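One subtlety in the dimension vector above: the bare {} entries value-initialize a DimVar, which, assuming Interval defaults to {0, 0}, pins that dimension to the single value 0. A sketch with hypothetical stand-in names for the bounds:

```cpp
// Assuming a value-initialized DimVar has bounds [0, 0]; names are
// illustrative stand-ins for the tiling_/groups_ expressions above.
std::vector<DimVar> dimension_ranges{
    {{0, num_threads_per_block - 1}},  // d0: thread id x
    {},                                // d1: thread id y, fixed to 0
    {},                                // d2: thread id z, fixed to 0
    {{0, num_blocks - 1}},             // d3: block id x
    {{0, num_groups - 1}},             // d4: block id y (output group)
    {},                                // d5: block id z, fixed to 0
};
```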
6 changes: 3 additions & 3 deletions third_party/xla/xla/service/gpu/fusions/scatter_mlir.cc
@@ -123,9 +123,9 @@ std::optional<IndexingMap> MlirScatterFusion::ComputeThreadIdToInputIndexing(
{mlir::getAffineDimExpr(0, mlir_context),
mlir::getAffineSymbolExpr(0, mlir_context)},
mlir_context),
-      /*dim_ranges=*/RangesFromTensorSizes(scatter_update_shape.dimensions()),
-      /*symbol_ranges=*/
-      RangesFromTensorSizes({scatter_indices_shape.dimensions(1)})};
+      DimVarsFromTensorSizes(scatter_update_shape.dimensions()),
+      RangeVarsFromTensorSizes({scatter_indices_shape.dimensions(1)}),
+      /*rt_vars=*/{}};
auto scatter_indices_map = scatter_update_map * updates_to_indices_map;
scatter_indices_map.Simplify();
return scatter_indices_map;
10 changes: 6 additions & 4 deletions third_party/xla/xla/service/gpu/fusions/transpose_mlir.cc
@@ -207,8 +207,9 @@ IndexingMap GetSharedMemoryWriteIndexingMap(
thread_id_indexing.GetSymbolCount(),
{c0, th_x.floorDiv(32) + 4 * tile_sizes[loop_dim], th_x % 32},
mlir_context),
-      thread_id_indexing.GetDimensionRanges(),
-      thread_id_indexing.GetSymbolRanges(),
+      thread_id_indexing.GetDimVars(),
+      thread_id_indexing.GetRangeVars(),
+      thread_id_indexing.GetRTVars(),
thread_id_indexing.GetConstraints()};
shmem_write_indexing.Simplify();
return shmem_write_indexing;
@@ -222,8 +223,9 @@ IndexingMap GetSharedMemoryReadIndexingMap(
GetSharedMemoryWriteIndexingMap(thread_id_indexing, loop_dim);
return IndexingMap{thread_id_indexing.GetIndexingContext(),
write_indexing.GetAffineMap().getSubMap({0, 2, 1}),
-                     write_indexing.GetDimensionRanges(),
-                     write_indexing.GetSymbolRanges(),
+                     write_indexing.GetDimVars(),
+                     write_indexing.GetRangeVars(),
+                     write_indexing.GetRTVars(),
write_indexing.GetConstraints()};
}

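Taken together with the earlier files, the accessor surface of the reworked IndexingMap looks roughly like this. This is a sketch assembled from call sites only, with return types assumed from how the results are used, not copied from the actual header:

```cpp
class IndexingMap {
 public:
  // The variable vectors themselves, e.g. for building a derived map as in
  // GetSharedMemoryWriteIndexingMap above.
  const std::vector<DimVar>& GetDimVars() const;
  const std::vector<RangeVar>& GetRangeVars() const;
  const std::vector<RTVar>& GetRTVars() const;

  // Interval views over the same data; these replace the former
  // GetDimensionRange(s)/GetSymbolRange(s) accessors renamed in this commit.
  Interval GetDimensionBound(int64_t dim_id) const;
  Interval GetSymbolBound(int64_t symbol_id) const;
  std::vector<Interval> GetDimensionBounds() const;
  std::vector<Interval> GetSymbolBounds() const;
};
```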
9 changes: 5 additions & 4 deletions third_party/xla/xla/service/gpu/model/coalescing_analysis.cc
@@ -115,7 +115,7 @@ void FindAllIndices(const IndexingMap& thread_id_to_physical_index,
std::vector<AffineExpr>* symbols,
std::vector<int64_t>* indices) {
if (dim_id < thread_id_to_physical_index.GetDimensionCount()) {
-    Interval dim_range = thread_id_to_physical_index.GetDimensionRange(dim_id);
+    Interval dim_range = thread_id_to_physical_index.GetDimensionBound(dim_id);
for (int64_t dim_value = dim_range.lower; dim_value <= dim_range.upper;
++dim_value) {
dimensions->push_back(getAffineConstantExpr(dim_value, mlir_context));
@@ -127,7 +127,7 @@ void FindAllIndices(const IndexingMap& thread_id_to_physical_index,
}
if (symbol_id < thread_id_to_physical_index.GetSymbolCount()) {
Interval symbol_range =
-        thread_id_to_physical_index.GetSymbolRange(symbol_id);
+        thread_id_to_physical_index.GetSymbolBound(symbol_id);
for (int64_t symbol_value = symbol_range.lower;
symbol_value <= symbol_range.upper; ++symbol_value) {
symbols->push_back(getAffineConstantExpr(symbol_value, mlir_context));
@@ -232,8 +232,9 @@ bool IsCoalesced(const IndexingMap& thread_id_to_input_indexing_map,
IndexingMap thread_x_first_32_elements{
indexing_context,
AffineMap::get(1, 0, {thread_x_dim, c0, c0, c0, c0, c0}, mlir_context),
-      {Interval{0, 31}},
-      {}};
+      {DimVar{{0, 31}}},
+      /*range_vars=*/{},
+      /*rt_vars=*/{}};
IndexingMap thread_x_to_linearized_input =
thread_x_first_32_elements * thread_id_to_input_indexing_map;
thread_x_to_linearized_input.Simplify();
@@ -97,8 +97,8 @@ int64_t GetIterationSpaceSize(const IndexingMap& indexing_map,
return num_iters;
};

-  return get_ranges_iteration_space_size(indexing_map.GetSymbolRanges()) *
-         get_ranges_iteration_space_size(indexing_map.GetDimensionRanges());
+  return get_ranges_iteration_space_size(indexing_map.GetSymbolBounds()) *
+         get_ranges_iteration_space_size(indexing_map.GetDimensionBounds());
}

EstimateRunTimeData
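Spelling out the two calls above with the lambda inlined (a sketch; behavior matches the helper defined earlier in this function):

```cpp
// The iteration space is the product of the sizes of every symbol bound and
// every dimension bound of the map.
int64_t num_iters = 1;
for (const Interval& bound : indexing_map.GetSymbolBounds()) {
  num_iters *= bound.upper - bound.lower + 1;
}
for (const Interval& bound : indexing_map.GetDimensionBounds()) {
  num_iters *= bound.upper - bound.lower + 1;
}
```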
55 changes: 26 additions & 29 deletions third_party/xla/xla/service/gpu/model/indexing_analysis.cc
@@ -151,15 +151,6 @@ HloInstructionIndexing ComputeInputToOutputBroadcastOpIndexing(
return HloInstructionIndexing::FromIndexingMaps({indexing_map});
}

-std::vector<Interval> RangesFromUpperBounds(absl::Span<const int64_t> bounds) {
-  std::vector<Interval> dim_ranges;
-  dim_ranges.reserve(bounds.size());
-  for (int64_t dim : bounds) {
-    dim_ranges.push_back(Interval{0, dim - 1});
-  }
-  return dim_ranges;
-}

HloInstructionIndexing ComputeOutputToInputConcatenateOpIndexing(
const HloConcatenateInstruction* concat,
IndexingContext* indexing_context) {
@@ -171,7 +162,7 @@ HloInstructionIndexing ComputeOutputToInputConcatenateOpIndexing(
// be adjusted for a particular operand_id.
mlir::MutableAffineMap affine_map =
AffineMap::getMultiDimIdentityMap(operand_0_dims.size(), mlir_context);
-  std::vector<Interval> dim_ranges = RangesFromUpperBounds(operand_0_dims);
+  std::vector<DimVar> dim_vars = DimVarsFromTensorSizes(operand_0_dims);

HloInstructionIndexing concat_indexing;
concat_indexing.indexing_maps.resize(concat->operand_count());
@@ -181,10 +172,10 @@ HloInstructionIndexing ComputeOutputToInputConcatenateOpIndexing(
for (const auto [operand_id, operand] : llvm::enumerate(concat->operands())) {
affine_map.setResult(concat_dim, concat_dim_expr - offset);
int64_t operand_concat_dim = operand->shape().dimensions()[concat_dim];
-    dim_ranges[concat_dim] = Interval{offset, offset + operand_concat_dim - 1};
+    dim_vars[concat_dim] = DimVar{{offset, offset + operand_concat_dim - 1}};
concat_indexing.indexing_maps[operand_id].insert(
-        IndexingMap(indexing_context, affine_map.getAffineMap(), dim_ranges,
-                    /*symbol_ranges=*/{}));
+        IndexingMap(indexing_context, affine_map.getAffineMap(), dim_vars,
+                    /*range_vars=*/{}, /*rt_vars=*/{}));
offset += operand_concat_dim;
}
return concat_indexing;
@@ -325,16 +316,16 @@ IndexingMap ComputeOutputToInputPadOpIndexingImpl(

std::vector<AffineExpr> exprs;
std::vector<std::pair<AffineExpr, Interval>> constraints;
-  std::vector<Interval> dimension_ranges;
+  std::vector<DimVar> dim_vars;
exprs.reserve(output_rank);
constraints.reserve(output_rank);
int64_t output_dim_id = 0;
for (const auto [output_dim, pad_low, pad_high, pad_interior] :
llvm::zip(output_dims, padding_low, padding_high, padding_interior)) {
AffineExpr dim_expr = getAffineDimExpr(output_dim_id, mlir_context);
-    dimension_ranges.push_back(
-        Interval{std::max(int64_t{0}, pad_low),
-                 std::min(output_dim - 1, output_dim - 1 - pad_high)});
+    dim_vars.push_back(
+        {Interval{std::max(int64_t{0}, pad_low),
+                  std::min(output_dim - 1, output_dim - 1 - pad_high)}});
if (pad_interior == 0) {
exprs.push_back(dim_expr - pad_low);
} else {
@@ -347,7 +338,10 @@ IndexingMap ComputeOutputToInputPadOpIndexingImpl(
return IndexingMap{
indexing_context,
AffineMap::get(output_rank, /*symbolCount=*/0, exprs, mlir_context),
-      dimension_ranges, /*symbol_ranges = */ {}, absl::MakeSpan(constraints)};
+      std::move(dim_vars),
+      /*range_vars = */ {},
+      /*rt_vars = */ {},
+      absl::MakeSpan(constraints)};
}

HloInstructionIndexing ComputeOutputToInputPadOpIndexing(
@@ -487,10 +481,11 @@ HloInstructionIndexing ComputeOutputToInputReduceWindowOpIndexing(
padding_interior.reserve(rank);
padded_input_dimensions.reserve(rank);
SmallVector<AffineExpr, 4> exprs;
-  std::vector<Interval> dim_ranges, symbol_ranges;
+  std::vector<DimVar> dim_vars;
+  std::vector<RangeVar> range_vars;
exprs.reserve(rank);
-  dim_ranges.reserve(rank);
-  symbol_ranges.reserve(rank);
+  dim_vars.reserve(rank);
+  range_vars.reserve(rank);
for (const auto& [dim_id, window_config] :
llvm::enumerate(reduce_window->window().dimensions())) {
padding_low.push_back(window_config.padding_low());
@@ -507,8 +502,8 @@ HloInstructionIndexing ComputeOutputToInputReduceWindowOpIndexing(
AffineExpr symbol_expr = getAffineSymbolExpr(dim_id, mlir_context);

exprs.push_back(symbol_expr + window_config.stride() * dim_expr);
-    dim_ranges.push_back(Interval{0, output_shape.dimensions(dim_id) - 1});
-    symbol_ranges.push_back(Interval{0, window_config.size() - 1});
+    dim_vars.push_back({Interval{0, output_shape.dimensions(dim_id) - 1}});
+    range_vars.push_back({Interval{0, window_config.size() - 1}});
}
// Indexing map for pad op that pads the input.
IndexingMap padded_input_indexing = ComputeOutputToInputPadOpIndexingImpl(
@@ -517,7 +512,7 @@ HloInstructionIndexing ComputeOutputToInputReduceWindowOpIndexing(
// Indexing map for reduce-window, that does not do any padding.
IndexingMap reduce_window_indexing_no_padding(
indexing_context, AffineMap::get(rank, rank, exprs, mlir_context),
-      dim_ranges, symbol_ranges);
+      dim_vars, range_vars, /*rt_vars=*/{});

// Composed indexing.
IndexingMap inputs_indexing = ComposeIndexingMaps(
@@ -926,7 +921,8 @@ IndexingMap GetIndexingMapFromPhysicalLayoutToLogical(
const Shape& shape, IndexingContext* indexing_context) {
MLIRContext* mlir_context = indexing_context->GetMLIRContext();
if (shape.rank() == 0) {
-    return IndexingMap(indexing_context, AffineMap::get(mlir_context), {}, {});
+    return IndexingMap(indexing_context, AffineMap::get(mlir_context),
+                       /*dim_vars=*/{}, /*range_vars=*/{}, /*rt_vars=*/{});
}
return IndexingMap::FromTensorSizes(
indexing_context,
@@ -942,7 +938,8 @@ IndexingMap GetIndexingMapFromLogicalToPhysicalLayout(
const Shape& shape, IndexingContext* indexing_context) {
MLIRContext* mlir_context = indexing_context->GetMLIRContext();
if (shape.rank() == 0) {
-    return IndexingMap(indexing_context, AffineMap::get(mlir_context), {}, {});
+    return IndexingMap(indexing_context, AffineMap::get(mlir_context),
+                       /*dim_vars=*/{}, /*range_vars=*/{}, /*rt_vars=*/{});
}
return IndexingMap::FromTensorSizes(
indexing_context,
@@ -1000,14 +997,14 @@ IndexingMap GetIndexingMapForTiling(AffineMap block_offsets,
llvm::zip(block_offsets.getResults(), thread_offsets.getResults())) {
offsets.push_back(block + thread);
}
-  std::vector<Interval> dimension_ranges{
-      {0, threads_per_block - 1}, {}, {}, {0, num_blocks - 1}, {}, {},
+  std::vector<DimVar> dimension_ranges{
+      {{0, threads_per_block - 1}}, {}, {}, {{0, num_blocks - 1}}, {}, {},
};
auto affine_map = mlir::AffineMap::get(block_offsets.getNumDims(),
block_offsets.getNumSymbols(), offsets,
indexing_context->GetMLIRContext());
IndexingMap map{indexing_context, affine_map, dimension_ranges,
-                  RangesFromUpperBounds(thread_tile_sizes)};
+                  RangeVarsFromTensorSizes(thread_tile_sizes), /*rt_vars=*/{}};
for (int i = 0; i < tiled_shape.size(); ++i) {
map.AddConstraint(affine_map.getResult(i), {0, tiled_shape[i] - 1});
}
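RangesFromUpperBounds is deleted above in favor of shared helpers used throughout this commit. Their likely shape, mirroring the deleted function (a sketch; the real signatures presumably live next to IndexingMap and may differ):

```cpp
std::vector<DimVar> DimVarsFromTensorSizes(absl::Span<const int64_t> sizes) {
  std::vector<DimVar> dim_vars;
  dim_vars.reserve(sizes.size());
  for (int64_t size : sizes) {
    dim_vars.push_back({Interval{0, size - 1}});  // [0, size - 1], as before
  }
  return dim_vars;
}

std::vector<RangeVar> RangeVarsFromTensorSizes(
    absl::Span<const int64_t> sizes) {
  std::vector<RangeVar> range_vars;
  range_vars.reserve(sizes.size());
  for (int64_t size : sizes) {
    range_vars.push_back({Interval{0, size - 1}});
  }
  return range_vars;
}
```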
16 changes: 13 additions & 3 deletions third_party/xla/xla/service/gpu/model/indexing_context.cc
@@ -15,12 +15,22 @@ limitations under the License.

#include "xla/service/gpu/model/indexing_context.h"

+#include <utility>

#include "xla/service/gpu/model/indexing_map.h"

namespace xla {
namespace gpu {

-IndexingContext::RTValsID IndexingContext::RegisterRTSymbol(
-    const HloInstruction* instr, IndexingMap indexing_map) {
-  return 0;
+static RTVarID rt_var_count = 0;
+
+RTVar IndexingContext::RegisterRTVar(RTVarData rt_var_data) {
+  rt_vars_registry_.insert(std::make_pair(rt_var_count, rt_var_data));
+  return RTVar{rt_var_count++};
}

+RTVarData& IndexingContext::GetRTVarData(RTVarID id) {
+  return rt_vars_registry_.at(id);
+}

} // namespace gpu
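Hypothetical usage of the new registry; RTVarData is opaque in this diff, the constructor shape is taken from simplify_affine.cc above, and the id field name follows the sketch at the top of this page:

```cpp
IndexingContext ctx(mlir_context);
RTVar rt_var = ctx.RegisterRTVar(rt_var_data);  // mints a fresh handle
RTVarData& data = ctx.GetRTVarData(rt_var.id);  // metadata lookup by id
```

Note that rt_var_count is a file-level static, so ids are unique across all IndexingContext instances in the process, while rt_vars_registry_ is per-instance; looking up an id registered on a different context would throw from .at().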
