[AMD] OptimizeLDSUsage pass #3730
Conversation
int tmpCvtLDS = getCvtOpLDSUsage(tmpCvt);
int newCvtLDS = getCvtOpLDSUsage(newEpilogueCvt);
if (tmpCvtLDS <= LDSSize && newCvtLDS <= LDSSize) {
  int LDSUsage = std::max(tmpCvtLDS, newCvtLDS);
@oplavsic
I've changed this part of the algorithm: https://github.com/openai/triton/pull/3730/files#diff-0d63e5cd9cf58151489fd9a5206b43a0902939004e58f3a7ec5258fa7d473267L227
Was it crucial?
To clarify what this PR is doing: at the moment we have an LDS optimization in the DecomposeUnsupportedConversions pass. The current approach cannot optimize the convert_layout in the Hopper flash attention test, so LDS overflows.
The first item is needed because the old set of intermediate layouts was not able to optimize conversions found in the Hopper FA. The second item is needed to generalize the optimization. For example, take a look at this example:
%1 consumes 16 KB of LDS, %2 requires ~64 KB of LDS for a scratch buffer. P.S. I had some concerns that the new optimization could affect existing benchmarks. I had an offline conversation with the author of the original optimization (@oplavsic) and we decided it is best to leave the old optimization functionally the same, but move some functions into a common place and make them parameterizable.
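The LDS-saving idea described above can be sketched numerically: when one expensive convert is split into two converts through an intermediate layout, the two temporary scratch buffers are not live at the same time, so the peak requirement is the maximum of the pair rather than the original size. A hedged Python sketch (the function name and numbers are illustrative, not taken from the PR):

```python
def scratch_after_decomposition(direct_bytes, via_tmp_bytes):
    """Peak LDS scratch when a convert is split into two converts.

    direct_bytes: scratch needed by the original src->dst convert.
    via_tmp_bytes: (src->tmp bytes, tmp->dst bytes) for the decomposed pair.
    The two temporary converts are not live simultaneously, so the peak
    usage is the maximum of the two, not the sum. The decomposition is only
    taken when it actually helps.
    """
    return min(direct_bytes, max(via_tmp_bytes))

# A ~64 KB direct convert replaced by two 16 KB converts keeps peak scratch
# usage at 16 KB.
print(scratch_after_decomposition(64 * 1024, (16 * 1024, 16 * 1024)))
```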
* ->
* %1 = cvtOp %0 (srcLayout -> dstLayout)
* %2 = cvtOp %0 (srcLayout -> tmpLayout)
* %3 = cvtOp %1 (tmpLayout -> dstLayout)
Should this be %3 = cvtOp %2?
This function creates two cvtOps based on a given cvtOp. Could you be more specific in the comment about which cvtOp is the new one and which is the old one?
// LDS reduction is possible by changing the shape of WarpsPerCta attribute in
// mfma layout. The implicit LDS usage of cvt(mfma->blocked) op depends on the
// number of warps per CTA that mfma layout uses along x dimension and block
// layout uses across y dimension.
It's a little confusing whether x refers to the row or column. We can use dim 0 and dim 1 instead.
// LDS usage of this op is roughly calculated as:
// LDS_USAGE = getShapePerCTA(mfma_layout)[0] * getShapePerCTA(blocked_layout)[1] * sizeof(data_type)
// LDS_USAGE = warpsPerCTA(mfma_layout)[0] * warpsPerCta(blocked_layout)[1] * C,
// where C = 32 * sizePerWarp(blocked_layout)[1] * threadsPerWarp(blocked_layout)[1] * sizeof(data_type)
Why is 32 hardcoded? Is it assuming mfma32 is used?
To be honest, I did not look deeply into this comment; I just copied it from the original algorithm.
It was implemented a long time ago; we probably had only mfma32 at the time.
I'll take a closer look and adjust.
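For reference, the formula from the comment under discussion can be transcribed into a small Python sketch. The mfma_dim parameter makes the hardcoded 32 explicit, under the assumption raised above that it corresponds to mfma32; all argument values below are illustrative:

```python
def estimate_cvt_lds_usage(mfma_warps_per_cta, blocked_warps_per_cta,
                           blocked_size_per_warp, blocked_threads_per_warp,
                           elem_bytes, mfma_dim=32):
    """Rough LDS usage of a cvt(mfma->blocked) op, mirroring the comment:
    LDS_USAGE = warpsPerCTA(mfma)[0] * warpsPerCTA(blocked)[1] * C,
    C = mfma_dim * sizePerWarp(blocked)[1] * threadsPerWarp(blocked)[1] * sizeof(dtype)
    """
    c = mfma_dim * blocked_size_per_warp[1] * blocked_threads_per_warp[1] * elem_bytes
    return mfma_warps_per_cta[0] * blocked_warps_per_cta[1] * c

# e.g. 4x2 mfma warps against 2x4 blocked warps, fp16 elements:
print(estimate_cvt_lds_usage([4, 2], [2, 4], [1, 1], [8, 8], 2))  # 8192 bytes
```

This also makes the reviewer's point visible: shrinking warpsPerCTA(mfma)[0] or warpsPerCTA(blocked)[1] directly shrinks the estimate.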
for (int i = 0; i < tmpLayouts.size(); i++) {
  auto tmpLayout = tmpLayouts[i];
  std::tie(tmpCvt, newEpilogueCvt) =
      createNewConvertOps(builder, cvtOp, tmpLayout);
In this loop, we only want to know the index of the tmpLayout that gives us the min LDS usage. Do we really need to create the cvtOps and erase them at the end of each iteration?
This creation/deletion is needed because the algorithm uses the getScratchConfigForCvtLayout(ConvertLayoutOp, unsigned&, unsigned&) function from Allocation.cpp to estimate LDS usage.
I can introduce a new interface so we can avoid this redundant work.
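The interface change suggested here could be sketched as follows: query a cost function per candidate layout and materialize only the winning pair of cvtOps. This is a hypothetical sketch; pick_best_tmp_layout and the cost-function shape are assumptions, not the PR's API:

```python
def pick_best_tmp_layout(tmp_layouts, estimate_lds):
    """Return the index of the candidate layout minimizing estimated LDS usage.

    estimate_lds is assumed to be a pure cost function (e.g. a refactored
    getScratchConfigForCvtLayout working on layouts rather than created ops),
    so no IR needs to be built and erased per iteration.
    """
    best_idx, best_cost = None, float("inf")
    for i, layout in enumerate(tmp_layouts):
        cost = estimate_lds(layout)
        if cost < best_cost:
            best_idx, best_cost = i, cost
    return best_idx

# Toy cost table standing in for the real estimate:
print(pick_best_tmp_layout(["a", "b", "c"], {"a": 64, "b": 16, "c": 32}.get))
```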
 * @return mapping from operation to list of live LDS buffers
 */
std::map<mlir::Operation *, SmallVector<Allocation::BufferId>>
analyzeBufferLiveness(FunctionOpInterface func, const Allocation *allocations) {
This is not AMD specific. Maybe we should put it in Analysis/Allocation.cpp?
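Indeed, the analysis itself is target-agnostic. A minimal sketch of the idea, simplified to buffer live intervals over a linear order of operations (the real analyzeBufferLiveness works on MLIR operations and Allocation buffer ids):

```python
def analyze_buffer_liveness(num_ops, intervals):
    """Map each operation index to the buffers live there.

    intervals: {buffer_id: (first_op, last_op)} inclusive live ranges,
    standing in for the liveness information Allocation provides.
    """
    live_at = {op: [] for op in range(num_ops)}
    for buf, (start, end) in sorted(intervals.items()):
        for op in range(start, end + 1):
            live_at[op].append(buf)
    return live_at

live = analyze_buffer_liveness(4, {0: (0, 2), 1: (1, 3)})
print(live)  # buffers 0 and 1 overlap at ops 1 and 2
```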
}

SmallVector<triton::gpu::ConvertLayoutOp>
findLDSBottleneck(ModuleAllocation &allocAnalysis, FunctionOpInterface func) {
We could also put this in the common part since it can benefit the NV path. But then again, NV GPUs have pretty large shared memory...
@binarman I have a question regarding
In this example,

namespace {

constexpr int LDSSize = 65536;
Could we not hardcode it but pass it from the front end?
@zhanglx13 about the condition: it filters out cases which will definitely overflow LDS, and there is no early exit.
Yes, at least the early return condition needs to be removed.
Now I see, I missed this early return, thank you!
module attributes {"triton_gpu.num-warps" = 8 : i32, "triton_gpu.threads-per-warp" = 64 : i32} {
  tt.func public @alloc_convert_load(%arg0: tensor<128x128xf16, #blocked>, %arg1: tensor<128x128xf32, #blocked>) attributes {noinline = false} {
    %1 = triton_gpu.local_alloc %arg0 : (tensor<128x128xf16, #blocked>) -> !tt.memdesc<128x128xf16, #shared>
    %2 = triton_gpu.convert_layout %arg1 : tensor<128x128xf32, #blocked> -> tensor<128x128xf32, #mma>
Sorry, I forgot to mention: I think this cvtOp is decomposed just because it uses more than 64 KB of LDS, since padding is used. Therefore, this test does not test the functionality that a cvtOp could still be decomposed even if it uses less than 64 KB of LDS.
Added a new test: it uses fp16 instead of fp32, so the cvt scratch buffer is 2x smaller.
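The arithmetic behind the fp16 variant is simple: a 128x128 convert needs at least rows * cols * elem_bytes of scratch before padding, which puts fp32 exactly at the 64 KB boundary (so padding tips it over) and fp16 at half that. A quick sketch:

```python
def raw_scratch_bytes(shape, elem_bytes):
    """Lower bound on cvt scratch size, ignoring padding and replication."""
    n = 1
    for d in shape:
        n *= d
    return n * elem_bytes

print(raw_scratch_bytes([128, 128], 4))  # fp32: 65536 bytes, at the 64 KB limit
print(raw_scratch_bytes([128, 128], 2))  # fp16: 32768 bytes, half of that
```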
third_party/amd/backend/compiler.py
@@ -147,6 +147,8 @@ def make_llir(src, metadata, options):
    pm = ir.pass_manager(mod.context)
    pm.enable_debug()
    amd.passes.ttgpuir.add_decompose_unsupported_conversions(pm)
    lds_size = 65536
I am not sure where to place the code choosing the LDS size, so it is a plain constant at this point.
Let's introduce some interface in a later PR.
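One possible shape for such an interface on the front-end side, purely as a sketch: the function name and the arch table below are assumptions, not part of this PR (a later revision mentions inferring the LDS size from the arch name):

```python
def lds_size_for_arch(arch: str) -> int:
    """Pick the LDS size per workgroup for a given AMD arch name.

    Illustrative table only; 64 KB matches the constant currently
    hardcoded in the pass, and is used as the fallback.
    """
    known = {"gfx90a": 65536, "gfx942": 65536}
    return known.get(arch, 65536)

print(lds_size_for_arch("gfx90a"))
```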
It should be convenient to rebase onto Lei's PR #3808
(Converting to draft as we chatted: need to first get all issues addressed from the AMD side before making it open.)
@antiagainst @zhanglx13
namespace triton {
namespace AMD {

constexpr int kPtrBitWidth = 64;
Do we really need to hardcode the pointer bitwidth? Can we just use an inline constant?
This part is copied from Allocation.cpp (it is not part of the public interface).
Maybe I can actually move this part into some public interface, for example the Analysis/Utility module.
This is what I was talking about: binarman#6
res.LDS = std::numeric_limits<typeof(res.LDS)>::max();

triton::gpu::ConvertLayoutOp tmpCvt;
triton::gpu::ConvertLayoutOp newEpilogueCvt;
The above three lines are not used.
threadsPerWarp[rank - 2] = warpSize / threadsPerWarp[rank - 1];
auto order = triton::gpu::getOrder(srcEnc);
auto layoutCTA = triton::gpu::getCTALayout(srcEnc);
auto fallbackLayout = triton::gpu::BlockedEncodingAttr::get(
- For this fallbackLayout, all the components except warpsPerCTA are loop invariants. Maybe we can create a base BlockLayout out of the loop and use createTmpLayout(blockEnc, warpsPerCTA) inside the loop to update only the warpsPerCTA?
- Why is 8 chosen in warpSize / 8?
- In general, why do we need this fallbackLayout? Is it covered by either srcEnc or dstEnc?
- Why is 8 chosen in warpSize / 8?

For wave64 it will be [8, 8], for wave32 it will be [4, 8]. This is done to make the layout tile "square", so no dimension of the minimal tile dominates.

- In general, why do we need this fallbackLayout? Is it covered by either srcEnc or dstEnc?

In some cases a different warpsPerCTA of the src or dst layout is not enough to reduce LDS usage, but some other layouts can be appropriate. These fallback layouts are designed to have as compact a tile as possible, i.e. elementsPerThread = [1, ... 1], and threadsPerWarp as "square" as possible.
I believe that in most cases the fallback layout will be chosen as the temporary layout. This could be suboptimal in terms of performance, but it is fine, because without this transformation the kernel will not compile at all.
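The "square" threadsPerWarp choice described in this answer can be sketched directly; the fixed inner dimension of 8 comes from the code under review:

```python
def fallback_threads_per_warp(warp_size):
    """threadsPerWarp for the fallback layout: inner dim fixed at 8,
    outer dim derived from the warp size, keeping the tile near-square."""
    inner = 8
    return [warp_size // inner, inner]

print(fallback_threads_per_warp(64))  # wave64 -> [8, 8]
print(fallback_threads_per_warp(32))  # wave32 -> [4, 8]
```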
  return;
}

triton::gpu::ConvertLayoutOp tmpCvt;
Are we using this tmpCvt?
Nope, I will rewrite this part as done in the DecomposeUnsupportedConversions pass.
if (offset + size > LDSSize) {
  auto maxScratchBufferSize = computeMaxScratchBufferSize(
      cvtOp, funcAnalysis, liveBuffers[cvtOp]);
  candidates.push_back({cvtOp, maxScratchBufferSize});
This function is very confusing to me.
- Why do we need opBuffer? Just to check it's valid?
- Does liveBuffers[cvtOp] include opBuffer? To put it another way, is one of the bufIds the scratch buffer allocated for this cvtOp?
- It seems to me that this function assumes there is at most one extra buffer that can overlap with the buffer for this cvtOp? If there are more live buffers that overlap with this cvtOp, we should still only push cvtOp into candidates once, but compute maxScratchBufferSize based on all overlapping live buffers.
Why do we need opBuffer? Just to check it's valid?
Sorry, this is a leftover after refactoring. I used to pass it to computeMaxScratchBufferSize, but then started computing it inside the function.
Does liveBuffers[cvtOp] include opBuffer? To put it another way, is one of the bufIds the scratch buffer allocated for this cvtOp?
Yes, a scratch buffer is the same as the "long-living" buffers; the only difference is that its lifetime is limited to one operation.
It seems to me that this function assumes there is at most one extra buffer that can overlap with the buffer for this cvtOp? If there are more live buffers that overlap with this cvtOp, we should still only push cvtOp into candidates once, but compute maxScratchBufferSize based on all overlapping live buffers.
No, there could be any number of buffers whose lifetimes overlap with the scratch buffer.
Let me remove the loop from this function; it should make the algorithm clearer.
int64_t scratchBufferSize = allocation->getAllocatedSize(scratchBufferId);
size_t totalLDSConsumption = 0;
for (auto buf : liveBuffers)
  totalLDSConsumption = std::max(
If all liveBuffers are live at this cvtOp, should we use sum instead of max here?
Max is the more conservative metric in this sense. Let's consider that we have "holes" in memory:
let's say the green buffer is the scratch buffer that we want to optimize; violet and blue are long-living buffers in shared layout.
A hole is created because the pink tensor is allocated at tick 1 and deallocated at tick 2, but the previously allocated violet tensor continues to live.
Summing buffer sizes would tell us that we have 24 KB (3 * 8 KB) for the scratch buffer, but in reality we probably want to make it smaller.
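The conservative estimate defended here can be sketched as follows: the scratch buffer only uses the LDS above the highest end offset of any live buffer, so the computation takes the max of (offset + size) rather than summing hole sizes. A simplified Python version (the 64 KB constant mirrors the one used in the pass):

```python
LDS_SIZE = 65536  # LDS size assumed by the pass

def max_scratch_buffer_size(live_buffers):
    """Conservative space available for a cvtOp's scratch buffer.

    live_buffers: list of (offset, size) for buffers live at the cvtOp.
    Holes between live buffers are deliberately not counted; only the
    region above the highest occupied offset is considered usable.
    """
    high_water = 0
    for offset, size in live_buffers:
        high_water = max(high_water, offset + size)
    return LDS_SIZE - high_water

# Two 8 KB live buffers with a hole between them: the hole is ignored and
# only the space above offset 24 KB is available.
print(max_scratch_buffer_size([(0, 8192), (16384, 8192)]))
```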
 * space available for scratch buffer.
 */
int64_t
computeMaxScratchBufferSize(triton::gpu::ConvertLayoutOp op,
Maybe computeTargetBufferSize? I feel like "target" or "desired" is more accurate about what we want to do here.
@zhanglx13 @antiagainst PTAL
@binarman @zhanglx13 what's the status on this pull request? Do we still need it?
I don't think we should focus on this at the moment, because it is not blocking anything and no test/kernel requires this change. I have used this change a few times during debugging: adding device prints increases LDS consumption, and a normally working test can overflow LDS.
auto srcType = cvtOp.getSrc().getType();
auto bytes =
    isa<triton::PointerType>(srcType.getElementType())
        ? elems * kPtrBitWidth / 8
Where is kPtrBitWidth defined?
It is defined here: https://github.com/triton-lang/triton/pull/3730/files#diff-69efd7149b566a254eabbb7b7808df841b5fb3e78f82d074bc26aa9369d4e4bfR19
I agree that it is not the cleanest solution; feel free to propose another place.
I've moved the refactoring of DecomposeUnsupportedConversions.cpp to a separate PR #4262, so now here we only have changes related to the new pass. Hope this will make review slightly easier.
This PR introduces the OptimizeLDSUsage pass, which generalizes the LDS optimization that was part of the DecomposeUnsupportedLayouts pass.
- use arch name to infer LDS size
- remove unused code, simplify code, rename entities, etc.