Skip to content

Commit

Permalink
Simplify more reshapes.
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 636241926
  • Loading branch information
jreiffers authored and tensorflower-gardener committed May 22, 2024
1 parent 878f357 commit 5942a7a
Show file tree
Hide file tree
Showing 3 changed files with 106 additions and 45 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -221,7 +221,7 @@ TEST_F(MlirTransposeFusionTest, ThreadIndexingVectorized210) {
(d0, d1, d2, d3, d4, d5)[s0, s1] -> (
d0 floordiv 32 + s0 * 4,
d3 floordiv 128,
(d3 mod 128) * 64 + s1 + (d0 mod 32) * 2
(d0 mod 32) * 2 + s1 + (d3 mod 128) * 64
)
domain:
d0 in [0, 127]
Expand Down
135 changes: 91 additions & 44 deletions third_party/xla/xla/service/gpu/model/indexing_map.cc
Original file line number Diff line number Diff line change
Expand Up @@ -63,6 +63,13 @@ using mlir::getAffineConstantExpr;
using mlir::getAffineDimExpr;
using mlir::MLIRContext;

// Returns the left-hand operand of the binary affine expression `e`.
//
// Precondition: `e` must be an AffineBinaryOpExpr. The conversion uses
// `mlir::cast` (not `dyn_cast`), so there is no fallback path for other
// expression kinds.
AffineExpr GetLhs(AffineExpr e) {
  return mlir::cast<AffineBinaryOpExpr>(e).getLHS();
}
// Returns the right-hand operand of the binary affine expression `e`.
//
// Precondition: `e` must be an AffineBinaryOpExpr. The conversion uses
// `mlir::cast` (not `dyn_cast`), so there is no fallback path for other
// expression kinds.
AffineExpr GetRhs(AffineExpr e) {
  return mlir::cast<AffineBinaryOpExpr>(e).getRHS();
}

class AffineExprSimplifier {
public:
explicit AffineExprSimplifier(RangeEvaluator* range_evaluator)
Expand All @@ -76,6 +83,13 @@ class AffineExprSimplifier {
private:
std::optional<int64_t> GetConstantRhs(mlir::AffineExpr expr,
AffineExprKind kind);
// Splits `expr` into an (operand, factor) pair.
//
// When GetConstantRhs reports a value for `expr` queried as a Mul, the
// result is {left operand of `expr`, that value}. Otherwise `expr` is
// returned unchanged with a factor of 1.
// NOTE(review): GetConstantRhs is defined elsewhere; presumably it yields a
// value only for a Mul whose right operand is a constant - confirm.
std::pair<mlir::AffineExpr, int64_t> ExtractMultiplier(mlir::AffineExpr expr) {
  const std::optional<int64_t> factor =
      GetConstantRhs(expr, AffineExprKind::Mul);
  if (!factor.has_value()) {
    return {expr, 1};
  }
  return {GetLhs(expr), *factor};
}

// Simplifier for mod.
// - Rewrites (a * 100 + ...) % 100 to (...) % 100
Expand Down Expand Up @@ -144,9 +158,7 @@ AffineExpr AffineExprSimplifier::RewriteMod(AffineBinaryOpExpr mod) {
// = (c % b) * a
if (auto mul = GetConstantRhs(lhs_simplified, AffineExprKind::Mul);
mul && (m % *mul == 0)) {
return (mlir::cast<AffineBinaryOpExpr>(lhs_simplified).getLHS() %
(m / *mul)) *
*mul;
return (GetLhs(lhs_simplified) % (m / *mul)) * *mul;
}

int64_t extracted_constant = 0;
Expand Down Expand Up @@ -220,8 +232,7 @@ AffineExpr AffineExprSimplifier::RewriteFloorDiv(AffineBinaryOpExpr div) {
// = (c // a) % b contract mod
if (auto mod = GetConstantRhs(lhs_simplified, AffineExprKind::Mod);
mod && (*mod % d == 0)) {
return mlir::cast<AffineBinaryOpExpr>(lhs_simplified).getLHS().floorDiv(d) %
(*mod / d);
return GetLhs(lhs_simplified).floorDiv(d) % (*mod / d);
}

// If the dividend's range has a single element, return its value.
Expand Down Expand Up @@ -253,8 +264,7 @@ AffineExpr AffineExprSimplifier::RewriteFloorDiv(AffineBinaryOpExpr div) {
// one x, but we currently have no reason to do that.
if (*multiplier % d == 0) {
int64_t factor = *multiplier / d;
extracted =
extracted + mlir::cast<AffineBinaryOpExpr>(expr).getLHS() * factor;
extracted = extracted + GetLhs(expr) * factor;
// Remove from dividend.
return true;
}
Expand Down Expand Up @@ -306,10 +316,9 @@ AffineExpr AffineExprSimplifier::RewriteFloorDiv(AffineBinaryOpExpr div) {
if (auto multiplier = GetConstantRhs(expr, AffineExprKind::Mul);
multiplier && (*multiplier > 0) &&
((*multiplier % max_remaining_multiplier_gcd) == 0)) {
auto expr_lhs = mlir::cast<AffineBinaryOpExpr>(expr).getLHS();
partially_extracted =
partially_extracted +
expr_lhs * (*multiplier / max_remaining_multiplier_gcd);
GetLhs(expr) * (*multiplier / max_remaining_multiplier_gcd);
// Remove from dividend.
return true;
}
Expand Down Expand Up @@ -359,9 +368,8 @@ AffineExpr AffineExprSimplifier::RemoveSummands(
void AffineExprSimplifier::VisitSummands(
mlir::AffineExpr expr, const std::function<void(mlir::AffineExpr)>& visit) {
if (expr.getKind() == AffineExprKind::Add) {
auto add = mlir::dyn_cast<AffineBinaryOpExpr>(expr);
VisitSummands(add.getLHS(), visit);
VisitSummands(add.getRHS(), visit);
VisitSummands(GetLhs(expr), visit);
VisitSummands(GetRhs(expr), visit);
} else {
visit(expr);
}
Expand Down Expand Up @@ -389,13 +397,11 @@ int CompareExprs(AffineExpr a, AffineExpr b) {
case AffineExprKind::CeilDiv:
case AffineExprKind::Mul:
case AffineExprKind::Mod: {
auto a_bin = mlir::cast<AffineBinaryOpExpr>(a);
auto b_bin = mlir::cast<AffineBinaryOpExpr>(b);
auto lhs = CompareExprs(a_bin.getLHS(), b_bin.getLHS());
auto lhs = CompareExprs(GetLhs(a), GetLhs(b));
if (lhs != 0) {
return lhs;
}
return CompareExprs(a_bin.getRHS(), b_bin.getRHS());
return CompareExprs(GetRhs(a), GetRhs(b));
}
case AffineExprKind::Constant: {
a_value = mlir::cast<AffineConstantExpr>(a).getValue();
Expand Down Expand Up @@ -433,42 +439,83 @@ AffineExpr CanonicalizeOrder(AffineExpr in) {
AffineExpr AffineExprSimplifier::SimplifyOnce(AffineExpr expr) {
switch (expr.getKind()) {
case AffineExprKind::Mul: {
auto binop = mlir::cast<AffineBinaryOpExpr>(expr);
auto lhs = SimplifyOnce(binop.getLHS());
auto rhs = SimplifyOnce(binop.getRHS());
auto lhs = SimplifyOnce(GetLhs(expr));
auto rhs = SimplifyOnce(GetRhs(expr));
return getAffineBinaryOpExpr(expr.getKind(), lhs, rhs);
}
case AffineExprKind::Add: {
auto binop = mlir::cast<AffineBinaryOpExpr>(expr);
auto lhs = SimplifyOnce(binop.getLHS());
auto rhs = SimplifyOnce(binop.getRHS());

// Rewrite `(x // c) * c + (x % c)` to `x`.
// This should also work with (a+b)+c.
auto rewrite_add = [&](AffineExpr a, AffineExpr b) -> AffineExpr {
if (auto mod = GetConstantRhs(a, AffineExprKind::Mod)) {
if (auto mul = GetConstantRhs(b, AffineExprKind::Mul); mod == mul) {
auto b_lhs = mlir::cast<AffineBinaryOpExpr>(b).getLHS();
if (auto div = GetConstantRhs(b_lhs, AffineExprKind::FloorDiv);
div == mul) {
auto x = mlir::cast<AffineBinaryOpExpr>(b_lhs).getLHS();
if (x == mlir::cast<AffineBinaryOpExpr>(a).getLHS()) {
return x;
}
}
}
// Rewrite `(x % c) * d + (x // c) * (c * d)` to `x * d`. We have to do it
// in this rather convoluted way because the MLIR simplifier sinks
// multiplications into summands.
SmallVector<std::pair<AffineExpr, int64_t /*multiplier*/>> mods;
SmallVector<std::pair<AffineExpr, int64_t /*multiplier*/>> divs;
SmallVector<AffineExpr> others;
bool changed = false;
VisitSummands(expr, [&](AffineExpr expr) {
AffineExpr simplified = SimplifyOnce(expr);
changed |= simplified != expr;
auto [lhs, multiplier] = ExtractMultiplier(simplified);
if (lhs.getKind() == AffineExprKind::Mod) {
mods.push_back({lhs, multiplier});
} else if (lhs.getKind() == AffineExprKind::FloorDiv) {
divs.push_back({lhs, multiplier});
} else {
others.push_back(simplified);
}
return nullptr;
};
});

if (auto rewritten = rewrite_add(lhs, rhs)) {
return rewritten;
// We never see large sums in practice, so there's no point building a
// hash map.
if (mods.size() * divs.size() >= 100) {
std::string s;
llvm::raw_string_ostream ss(s);
ss << expr;
LOG(WARNING) << "Unexpectedly large number of mods and divs in " << s
<< ". Please open an issue on GitHub at "
<< "https://github.com/openxla/xla.";
}
if (auto rewritten = rewrite_add(rhs, lhs)) {
return rewritten;

for (int mod_i = 0; mod_i < mods.size(); ++mod_i) {
auto [mod, mod_mul] = mods[mod_i];
auto mod_c = GetConstantRhs(mod, AffineExprKind::Mod);
if (!mod_c) continue;

for (int div_i = 0; div_i < divs.size(); ++div_i) {
auto [div, div_mul] = divs[div_i];
if (!div) continue; // Already erased.
if (GetLhs(mod) != GetLhs(div)) continue;

auto div_c = GetConstantRhs(div, AffineExprKind::FloorDiv);
if (div_mul % mod_mul) continue;
if (mod_c != div_c || (div_mul / mod_mul) != mod_c) continue;

others.push_back(GetLhs(mod) * mod_mul);
divs[div_i].first = nullptr;
mods[mod_i].first = nullptr;
changed = true;
break;
}
}

return getAffineBinaryOpExpr(expr.getKind(), lhs, rhs);
if (!changed) {
return expr;
}

AffineExpr result = mlir::getAffineConstantExpr(0, expr.getContext());
for (auto expr : others) {
result = result + expr;
}
for (auto [expr, mul] : mods) {
if (expr) {
result = result + (expr * mul);
}
}
for (auto [expr, mul] : divs) {
if (expr) {
result = result + (expr * mul);
}
}
return result;
}
case AffineExprKind::Mod:
return RewriteMod(mlir::cast<AffineBinaryOpExpr>(expr));
Expand Down
14 changes: 14 additions & 0 deletions third_party/xla/xla/service/gpu/model/indexing_map_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -611,6 +611,20 @@ TEST_F(IndexingMapTest, AffineMapSimplification_SimplifyReshape) {
)"));
}

// Reshape simplification where the mod term, the floordiv term and an
// unrelated summand (d1) all sit in the same sum: the mod term is scaled by
// 128 and the floordiv term by 8 * 128 = 1024. After Simplify() the two d0
// terms are expected to be merged into `d0 * 128`.
TEST_F(IndexingMapTest, AffineMapSimplification_SimplifyReshape2) {
  const char* const reshape_expr =
      "(d0, d1) -> ((d0 mod 8) * 128 + d1 + (d0 floordiv 8) * 1024)";
  auto affine_map = ParseAffineMap(reshape_expr, &mlir_context_);
  IndexingMap indexing_map =
      IndexingMap::FromTensorSizes(affine_map, {1024, 128}, {});
  indexing_map.Simplify(GetIndexingMapForInstruction);

  EXPECT_THAT(indexing_map.ToString(printer_), MatchIndexingString(R"(
(d0, d1) -> (d0 * 128 + d1)
domain:
d0 in [0, 1023]
d1 in [0, 127]
)"));
}

TEST_F(IndexingMapTest, AffineMapSimplification_SimplifyReshape_Regression) {
// We have s0 * 128 in the mod, but s0 * 64 in the floordiv *.
auto serialized_map =
Expand Down

0 comments on commit 5942a7a

Please sign in to comment.