[mlir-hlo] Added BufferReuse Optimization. #48883
Should this be split into multiple PRs, one with just the Userange analysis and one with the buffer_reuse optimization?

def BufferReuse : FunctionPass<"buffer-reuse"> {
  let summary = "Reuses already allocated buffers to save allocation "
                "operations if all criteria are met.";

Maybe just say "is provably safe" instead of "if all criteria are met"? IMO that's a bit more exact.

/// every alloc value.
class UserangeAnalysis {
 public:
  using UseInterval = std::pair<size_t, size_t>;

Maybe add a brief comment on what the `size_t`s in `UseInterval` mean?

/// Computes the ID for the operation. If the operation contains operands
/// which have read effects, the returning ID will be odd.

Does it make sense to have a richer encoding here instead of implicitly encoding information in the last bit?
Yes, you are right. This part is still work in progress. I will change the PR of the UserangeAnalysis to a draft and work on this issue. Also I tried to split this PR into 2 parts. Part 1 can be found here: #48847.

/// Constructs an Userange builder.
UserangeInfoBuilder(Liveness liveness, ValueSetT values,
                    OperationListT opList)
    : values(values), opList(opList), liveness(liveness) {}

Maybe `std::move` these in?
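
To illustrate the suggestion, here is a minimal sketch of a constructor that takes its arguments by value and moves them into the members. `InfoBuilder` and its member types are hypothetical stand-ins for `UserangeInfoBuilder`, `ValueSetT`, and `OperationListT`, which are not shown in full in this thread.

```cpp
#include <cassert>
#include <string>
#include <utility>
#include <vector>

// Hypothetical stand-in for UserangeInfoBuilder. The parameters are taken by
// value, so the caller decides whether to copy or to move into them.
class InfoBuilder {
 public:
  InfoBuilder(std::vector<int> values, std::vector<std::string> opList)
      // Moving avoids a second copy of each container into the members.
      : values(std::move(values)), opList(std::move(opList)) {
    assert(!this->opList.empty() && "sketch expects at least one operation");
  }

  const std::vector<int> &getValues() const { return values; }
  const std::vector<std::string> &getOpList() const { return opList; }

 private:
  std::vector<int> values;
  std::vector<std::string> opList;
};
```

Without `std::move`, each by-value parameter would be copied once at the call site and a second time into the member.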

for (auto iterA = a.begin(), endA = a.end();
     iterA != endA && iterB != endB;) {
  // iterA is strictly before iterB => increment iterA.
  if (iterA->second < iterB->first) ++iterA;

Might be better to encapsulate the `UserangeInterval`s in a class of its own and provide methods like `strictlyBefore`?
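
A sketch of what such a wrapper could look like, assuming the two `size_t` values are operation IDs and both bounds are inclusive (the latter is confirmed further down in this review). The `overlaps` helper is an addition for illustration and is not part of the PR.

```cpp
#include <cassert>
#include <cstddef>

// Hypothetical wrapper around the pair-based UseInterval.
// Assumption: both bounds are operation IDs and both are inclusive.
struct UserangeInterval {
  std::size_t start;
  std::size_t end;

  UserangeInterval(std::size_t s, std::size_t e) : start(s), end(e) {
    assert(s <= e && "interval bounds must be ordered");
  }

  // True if this interval ends before `other` begins, so they share no ID.
  bool strictlyBefore(const UserangeInterval &other) const {
    return end < other.start;
  }

  // True if the two intervals share at least one ID.
  bool overlaps(const UserangeInterval &other) const {
    return !strictlyBefore(other) && !other.strictlyBefore(*this);
  }
};
```

With this in place, the loop condition `iterA->second < iterB->first` would read `iterA->strictlyBefore(*iterB)`.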

// Usually, we would expect the case of iterB beeing strictly before iterA.
// However, due to the initial assumption that all intervals of b are
// included in some interval of a, we do not need to check if iterB is
// striclty before iterA.

strictly

for (auto iterA = a.begin(), endA = a.end();
     iterA != endA && iterB != endB;) {
  // iterA is strictly before iterB => increment iterA.
  if (iterA->second < iterB->first) ++iterA;

What if `iterA` reaches the end after this increment? Or is that logically impossible?

This is not possible. The prerequisite is that B is a "proper subset" of A. If `iterA` is on the last element and `IntervalVector b` is not subtracted yet, then `iterA` cannot be before `iterB`.

// iterB is at the start of iterA, but iterA has some values that go
// beyond those of iterB. We have to set the lower bound of iterA to the
// upper bound of iterB + 1 and increment iterB.
// A(3, 100) - B(3, 5) => A(6,100)

So these intervals are end-inclusive?
Yes.

}
// iterB is in the middle of iterA. We have to split iterA and increment
// iterB.
// A(2, 10) B(5, 7) => (2, 4), (8, 10)

You're missing a `-` I think.

                                    const IntervalVector& b) const {
  auto iterB = b.begin();
  auto endB = b.end();
  for (auto iterA = a.begin(), endA = a.end();

Can some of the logic in the loop be factored into an `operator-=` function on `UserangeInterval`?
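
For reference, the subtraction discussed in the snippets above can be sketched as a standalone function. This is an illustration under the stated preconditions (sorted, non-overlapping, end-inclusive intervals, and every interval of `b` contained in some interval of `a`), not the PR's implementation. The name `subtractIntervals` is hypothetical.

```cpp
#include <cassert>
#include <cstddef>
#include <utility>
#include <vector>

using UseInterval = std::pair<std::size_t, std::size_t>;  // inclusive bounds
using IntervalVector = std::vector<UseInterval>;

// Returns a - b. Assumes both vectors are sorted and non-overlapping, and
// that every interval of b lies inside some interval of a.
IntervalVector subtractIntervals(const IntervalVector &a,
                                 const IntervalVector &b) {
  IntervalVector result;
  auto iterB = b.begin();
  for (UseInterval current : a) {
    bool consumed = false;
    // Cut every interval of b that lies inside `current` out of it.
    while (!consumed && iterB != b.end() && iterB->second <= current.second) {
      assert(iterB->first >= current.first && "b must be contained in a");
      // Keep the part of `current` in front of the b interval, if any.
      if (iterB->first > current.first)
        result.emplace_back(current.first, iterB->first - 1);
      if (iterB->second == current.second)
        consumed = true;  // Nothing of `current` remains behind b.
      else
        current.first = iterB->second + 1;  // Continue behind b.
      ++iterB;
    }
    if (!consumed) result.push_back(current);
  }
  return result;
}
```

The two cases from the code comments come out as documented there: `A(3, 100) - B(3, 5)` yields `(6, 100)`, and `A(2, 10) - B(5, 7)` yields `(2, 4), (8, 10)`.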
@dfki-albo Can you please check @sanjoy's comments and keep us posted? Thanks!

Operation* defOpA = a.getDefiningOp();
Operation* defOpB = b.getDefiningOp();

// If the alloc method or the number of operands is not the same the types

nit: Cannot be -> might not be.

/// Checks if the types of the given values are compatible for a
/// replacement.
bool checkTypeCompatibility(Value a, Value b) {

Don't call this type. It is based on more than types. Maybe just `checkReuseCompatibility`?

if (itemA == itemB || !checkTypeCompatibility(itemA, itemB)) continue;

// Check if itemA can replace itemB.
if (!userange.rangesInterfere(itemA, itemB)) continue;

Why is this negated? If the useranges interfere, there can be no reuse, right?
Yes, you are right. I'll add this to the comment and also move the negation into the method, because otherwise the name is misleading.
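
To make the expected semantics concrete, here is a sketch of an interference check whose name matches its result: it returns `true` when two use ranges overlap, in which case no reuse is possible. It works on plain interval vectors rather than on `Value`s, and `intervalsInterfere` is a hypothetical name, not the PR's method.

```cpp
#include <cassert>
#include <cstddef>
#include <utility>
#include <vector>

using UseInterval = std::pair<std::size_t, std::size_t>;  // inclusive bounds
using IntervalVector = std::vector<UseInterval>;

// Returns true if any interval of `a` shares an ID with any interval of `b`.
// Assumes both vectors are sorted and non-overlapping.
bool intervalsInterfere(const IntervalVector &a, const IntervalVector &b) {
  auto iterA = a.begin();
  auto iterB = b.begin();
  while (iterA != a.end() && iterB != b.end()) {
    assert(iterA->first <= iterA->second && iterB->first <= iterB->second);
    if (iterA->second < iterB->first)
      ++iterA;  // iterA is strictly before iterB.
    else if (iterB->second < iterA->first)
      ++iterB;  // iterB is strictly before iterA.
    else
      return true;  // The current intervals overlap.
  }
  return false;
}
```

A caller would then skip a candidate with `if (intervalsInterfere(rangeA, rangeB)) continue;`, without a negation.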

  }
  ++it;
}
if (it == potReuseVector.end()) potReuseVector.push_back(itemB);

Why not first search the insertion point `it` and then always insert using `insert`, even if `it` points to the end?
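
A minimal sketch of that pattern on a plain `std::vector`. The ordering criterion is passed in as a comparator because the one used by the PR is not shown in this thread. `insertSorted` is a hypothetical helper name.

```cpp
#include <algorithm>
#include <cassert>
#include <functional>
#include <vector>

// Inserts `item` in front of the first element that does not compare less
// than it. `less` is a placeholder for whatever ordering the pass uses.
template <typename T, typename Less>
void insertSorted(std::vector<T> &values, const T &item, Less less) {
  auto insertionPoint =
      std::find_if(values.begin(), values.end(),
                   [&](const T &existing) { return !less(existing, item); });
  // insert() accepts end(), so no separate push_back branch is needed.
  values.insert(insertionPoint, item);
  assert(!values.empty());
}
```

Because `insert` at `end()` behaves like `push_back`, the special case in the original loop goes away.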

// Create a list of values that can potentially be replaced for each value
// in the useRangeMap. The potentialReuseMap maps each value to the
// respective list.
llvm::MapVector<Value, SmallVector<Value, 4>> potentialReuseMap;

nit: this function is very long. Consider splitting into smaller helpers.

// Iterate over the potential reuses and check if they can still be
// reused.
for (Value* potReuseValue = potReuses->begin();

Why a pointer? Just use `Value`.

Maybe use `std::remove_if`?
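
The erase-remove idiom behind that suggestion can be sketched as follows. The containers and the `int` element type are stand-ins for `SmallVector<Value, 4>` and the replaced set. The predicate covers only the `replacedSet` part of the original condition.

```cpp
#include <algorithm>
#include <cassert>
#include <set>
#include <vector>

// Drops every candidate that has already been replaced, keeping the relative
// order of the remaining candidates. The types are simplified stand-ins.
void pruneReplaced(std::vector<int> &potReuses,
                   const std::set<int> &replacedSet) {
  auto newEnd = std::remove_if(
      potReuses.begin(), potReuses.end(),
      [&](int candidate) { return replacedSet.count(candidate) != 0; });
  assert(newEnd <= potReuses.end());
  // remove_if only shifts the kept elements forward; erase trims the tail.
  potReuses.erase(newEnd, potReuses.end());
}
```

In LLVM code, `llvm::erase_if` from `STLExtras.h` wraps the same idiom in a single call.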

if (replacedSet.contains(*potReuseValue) ||
    transitiveInterference(*potReuseValue, potReuses,
                           actualReuseMap) ||
    !userange.rangesInterfere(item, *potReuseValue))

I am always confused why this is negated. If the ranges interfere, then no reuse is possible, no?

for (auto itReuseMap = potentialReuseMap.begin();
     itReuseMap != potentialReuseMap.end();) {
  Value item = itReuseMap->first;
  SmallVector<Value, 4>* potReuses = &itReuseMap->second;

How about using a reference?

// The defining block of itemA has to dominate all uses of itemB.
if (!dominatesAllUses(defOpBlock, itemB)) continue;

// Insert itemB into the right place of the potReuseVector. The order of

It would be interesting to explore other orders here. This is essentially greedy, but one could try to prioritize reuses for certain cases. For example, if one has a copy, then source and destination should preferably be reused.

@@ -26,6 +27,7 @@ int main(int argc, char **argv) {
  mlir::registerAllPasses();
  mlir::mhlo::registerAllMhloPasses();
  mlir::lmhlo::registerAllLmhloPasses();
  mlir::registerAllTransformPasses();

Why are these needed here and everywhere? Was that for debugging?
@dfki-albo Can you please resolve conflicts? Thanks!

@@ -18,6 +18,17 @@ limitations under the License.

include "mlir/Pass/PassBase.td"

def BufferReuse : FunctionPass<"buffer-reuse"> {
  let summary = "Reuses already allocated buffers to save allocation "
                "operations if is provably safe.";

nit: if it is

@@ -30,6 +30,11 @@ struct TestUserangePass : public TestUserangeBase<TestUserangePass> {
    registry.insert<mlir::lmhlo::LmhloDialect>();
  }

  StringRef getArgument() const final { return "test-print-userange"; }

Why is this needed?

continue;

// Get the defining block of itemA.
Block *defOpBlock = itemA.isa<BlockArgument>()

You can always use `getParentBlock()` here. It returns the block of the defining op, as well.

Also, `itemA` is an alloc, so it cannot be a block argument, right?

  potReuseVector.insert(insertionPoint, itemB);
}

potentialReuseMap.insert(

Do you need the full type here or would `{itemA, potReuseVector}` suffice?

    Value potReuseValue, SmallVector<Value, 4> &potReuses,
    llvm::MapVector<Value, DenseSet<Value>> &actualReuseMap) {
  return actualReuseMap.find(potReuseValue) != actualReuseMap.end() &&
         llvm::any_of(actualReuseMap[potReuseValue], [&](Value vReuse) {

Nit: This does a double lookup. If you already use find, you can use the result of the find to access the found value.

    llvm::MapVector<Value, DenseSet<Value>> &actualReuseMap) {
  return actualReuseMap.find(potReuseValue) != actualReuseMap.end() &&
         llvm::any_of(actualReuseMap[potReuseValue], [&](Value vReuse) {
           return !std::count(potReuses.begin(), potReuses.end(), vReuse);

Nit: Use find here, as that stops after first occurrence.
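
Both nits on this helper (single lookup, `find` instead of `count`) can be sketched together. The standard containers and `int` values below are stand-ins for `llvm::MapVector`, `DenseSet`, and `mlir::Value`. The logic mirrors the snippet above: it reports whether any actual reuse of `potReuseValue` is missing from `potReuses`.

```cpp
#include <algorithm>
#include <cassert>
#include <map>
#include <set>
#include <vector>

// Simplified stand-ins for the LLVM containers and mlir::Value.
using Value = int;
using ReuseMap = std::map<Value, std::set<Value>>;

bool transitiveInterference(Value potReuseValue,
                            const std::vector<Value> &potReuses,
                            const ReuseMap &actualReuseMap) {
  // One lookup: keep the iterator instead of calling operator[] afterwards.
  auto it = actualReuseMap.find(potReuseValue);
  if (it == actualReuseMap.end()) return false;
  assert(it->first == potReuseValue);
  return std::any_of(it->second.begin(), it->second.end(), [&](Value vReuse) {
    // std::find stops at the first match, std::count scans the whole range.
    return std::find(potReuses.begin(), potReuses.end(), vReuse) ==
           potReuses.end();
  });
}
```

Reusing the iterator also allows the map to be passed as `const`, since `operator[]` is no longer needed.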

  return false;

// If all operands are equal the types are compatible.
for (auto const &pair :

This is not true. Consider `memref<?x5xi32>` vs. `memref<5x?xi32>`. Also, you need to consider the basetype here, too.
As an extension, this could also work on the size of the element type. Not in this PR, please leave a TODO for this.

I was a bit imprecise. Please fix the issue with partially static types and also consider the basetype.
An extension to support reuse for elementtypes of the same size can be a TODO.

Use `std::equal`?
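
The counterexample can be made concrete with a small model. In the sketch below, `AllocInfo`, `kDynamic`, and both function names are hypothetical. A shape is a list of dimension sizes with `-1` marking a `?` dimension, and the element type is left out for brevity. It shows that equal operands alone do not imply equal shapes, and uses `std::equal` for the operand comparison. It is not the fix that landed in the PR.

```cpp
#include <algorithm>
#include <cassert>
#include <cstdint>
#include <vector>

constexpr std::int64_t kDynamic = -1;  // hypothetical marker for a `?` dim

// Hypothetical, simplified description of an alloc: its memref shape and the
// IDs of the values that supply the dynamic dimensions, in order.
struct AllocInfo {
  std::vector<std::int64_t> shape;
  std::vector<int> dynamicOperands;
};

// Compares only the operands, as the loop under review does.
bool operandsEqual(const AllocInfo &a, const AllocInfo &b) {
  return std::equal(a.dynamicOperands.begin(), a.dynamicOperands.end(),
                    b.dynamicOperands.begin(), b.dynamicOperands.end());
}

// Additionally requires the static and dynamic dimensions to line up.
bool shapesAndOperandsEqual(const AllocInfo &a, const AllocInfo &b) {
  assert(a.shape.size() >= a.dynamicOperands.size());
  return a.shape == b.shape && operandsEqual(a, b);
}
```

For `memref<?x5xi32>` and `memref<5x?xi32>` fed by the same size value, `operandsEqual` returns `true` while `shapesAndOperandsEqual` returns `false`.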

/// A Fixpoint iteration over the potential reuses to compute the actual
/// reuses.
llvm::MapVector<Value, DenseSet<Value>> computeActualReuse(
    llvm::MapVector<Value, SmallVector<Value, 4>> &potentialReuseMap) {

Having read this type so often, maybe give it a name?

    llvm::MapVector<Value, DenseSet<Value>> &actualReuseMap) {
  for (auto &potReuser : potentialReuseMap) {
    Value item = potReuser.first;
    SmallVector<Value, 4> potReuses = potReuser.second;

nit: this copies the vector.

  return false;

// If all operands are equal the types are compatible.
for (auto const &pair :

Why is the code formatted differently in this PR? Is that an accidental change?
Please fix the comment and use `std::equal`. I will run presubmits on this change anyway, though.

    defOpA->getNumOperands() != defOpB->getNumOperands())
  return false;

// TODO: Fix for memref<?x5xi32> vs memref<5x?xi32>, also consider the

This is now fixed.
tensorflow/compiler/mlir/hlo/BUILD

deps = [
    ":hlo",
    ":transforms_pass_inc_gen",
    "@llvm-project//mlir:MemRefDialect",

These are ordered incorrectly. See the builder log.

Imported from GitHub PR tensorflow/tensorflow#48883

In this PR, we want to introduce a new optimization on reusing already allocated buffers to save memory consumption. The optimization consists of two steps. First, for each buffer, we find a list of buffers that are potential reuses. A possible reuse has the following properties:
- the types are compatible
- no interference in the `UserangeAnalysis` (#48847)
- the dominance is still given

The second step is a fixpoint iteration over the potential reuses. This step is divided into two substeps:
- try to assign possible reuses for each buffer
- update the potential reuses based on the assignments from step 1.

After the distribution of all possible reusable buffers is done, they are actually replaced.

```
For example:                         Result:
func @simpleReuse(%arg0: i1) {       func @simpleReuse(%arg0: i1) {
  %0 = alloc()                         %0 = alloc()
  %1 = alloc()
  cond_br %arg0, ^bb1, ^bb2            cond_br %arg0, ^bb1, ^bb2
^bb1:                                ^bb1:
  use(%0)                              use(%0)
  br ^bb3                              br ^bb3
^bb2:                                ^bb2:
  use(%1)                              use(%0)
  br ^bb3                              br ^bb3
^bb3:                                ^bb3:
  return                               return
}                                    }
```

In this simple example `%1` can be replaced with `%0`, because all requirements mentioned above are fulfilled. This PR is a follow up to #48847, in which we introduced the `UserangeAnalysis`.

Copybara import of the project:

-- ba20a66f43af9a4a6d2639116de768dd050016a8 by Alexander Bosch <Alexander.Bosch@dfki.de>: PR #48847
-- d7b1461fa31c2e847d6eaf087cefd30a8fcdb90b by Alexander Bosch <Alexander.Bosch@dfki.de>: Implementation of a BufferReuse Optimization.
-- 5dfaf703438c4260b0edc0334f0b0428b47cca5c by Alexander Bosch <Alexander.Bosch@dfki.de>: Addressed reviewers comments.
-- 299f161e3c8630b577fb529b1662f8c9f6395fca by Alexander Bosch <Alexander.Bosch@dfki.de>: Addressed reviewer comments.
-- bc43c5670dd1b52441c8d11e7863aad3b8a01c85 by Alexander Bosch <Alexander.Bosch@dfki.de>: Fixed false pass registration.
-- d28e1964a13b0cd8f60c4586e0e93c1d670b8f3d by Alexander Bosch <Alexander.Bosch@dfki.de>: Resolved conflicts.
-- 51ebd09c55b8da21421646a6c10a23995cbd0d2a by Alexander Bosch <Alexander.Bosch@dfki.de>: Rebased with underlying PR #48847.
-- 5ee27a13cda451e0cf520951fc2e08450e4984ec by Alexander Bosch <Alexander.Bosch@dfki.de>: Addressed reviewers comments.
-- 456af48684a1bf1d549e244340afc639e804051d by Alexander Bosch <Alexander.Bosch@dfki.de>: Fixed a bug in checkReuseCompatibility.
-- cff3fa094a6928d24f36ed9432176287bd396800 by Alexander Bosch <Alexander.Bosch@dfki.de>: Fixed formatting issues and addressed reviewers comments.
-- 19928064d51c70aca9f5031cc743ea3e39a20541 by Alexander Bosch <Alexander.Bosch@dfki.de>: Removed anti-pattern.

PiperOrigin-RevId: 385111737