[Kineto][NCCL][5/n] Populate in/out split size info for all_to_all from CPU to CUDA kernel

Summary: This diff propagates the all_to_all input and output split sizes from the CPU op to the GPU kernel when they are valid, that is, non-empty and no longer than kTruncatLength (30) entries.

Test Plan:
**Trace example**:
- For non all_to_all collective functions: https://fburl.com/perfdoctor/4nobsu15
https://pxl.cl/3GNVb

- For all_to_all: https://fburl.com/perfdoctor/f418goys

https://pxl.cl/3H2nd

Differential Revision: D50762093
yoyoyocmu authored and facebook-github-bot committed Oct 28, 2023
1 parent a1a765c commit 6f6be94
Showing 1 changed file with 11 additions and 4 deletions.
15 changes: 11 additions & 4 deletions torch/csrc/profiler/util.cpp
@@ -344,6 +344,7 @@ static constexpr auto kOutMsgSize = "Out msg size";
 static constexpr auto kInSplit = "In split size";
 static constexpr auto kOutSplit = "Out split size";
 static constexpr auto kGroupSize = "Group size";
+static constexpr int32_t kTruncatLength = 30;
 #endif // USE_C10D
 #endif // USE_DISTRIBUTED

@@ -365,12 +366,18 @@ std::unordered_map<std::string, std::string> saveNcclMeta(
       kDtype, fmt::format("\"{}\"", c10::toString(debugInfo->getDType())));
   map.emplace(kInMsgSize, std::to_string(debugInfo->getInMessageSize()));
   map.emplace(kOutMsgSize, std::to_string(debugInfo->getOutMessageSize()));
-  map.emplace(
+  auto& inSplitSizes = debugInfo->getInputSplitSizes();
+  if (!inSplitSizes.empty() && inSplitSizes.size() <= kTruncatLength) {
+    map.emplace(
       kInSplit,
-      fmt::format("[{}]", fmt::join(debugInfo->getInputSplitSizes(), ", ")));
-  map.emplace(
+        fmt::format("[{}]", fmt::join(inSplitSizes, ", ")));
+  }
+  auto & outSplitSizes = debugInfo->getOutputSplitSizes();
+  if (!outSplitSizes.empty() && outSplitSizes.size() <= kTruncatLength) {
+    map.emplace(
       kOutSplit,
-      fmt::format("[{}]", fmt::join(debugInfo->getOutputSplitSizes(), ", ")));
+        fmt::format("[{}]", fmt::join(outSplitSizes, ", ")));
+  }
   map.emplace(kGroupSize, std::to_string(debugInfo->getWorldSize()));
 #endif // USE_C10D
 #endif // USE_DISTRIBUTED
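
The guard this change adds around each split-size entry can be tried outside PyTorch. What follows is only a minimal sketch under stated assumptions: `joinBracketed` and `emplaceSplitIfValid` are hypothetical names that do not exist in torch/csrc/profiler/util.cpp, the split sizes are assumed to be a `std::vector<int64_t>`, and a standard-library join stands in for `fmt::format`/`fmt::join` so the snippet has no third-party dependency.

```cpp
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <sstream>
#include <string>
#include <unordered_map>
#include <vector>

// Same value as the constant added in the diff (spelling kept as in the source).
static constexpr int32_t kTruncatLength = 30;

// Hypothetical stand-in for fmt::format("[{}]", fmt::join(v, ", ")).
std::string joinBracketed(const std::vector<int64_t>& v) {
  std::ostringstream os;
  os << '[';
  for (std::size_t i = 0; i < v.size(); ++i) {
    if (i != 0) {
      os << ", ";
    }
    os << v[i];
  }
  os << ']';
  return os.str();
}

// Hypothetical helper: record the split sizes under `key` only when the list
// is non-empty and has at most kTruncatLength entries; otherwise the key is
// left out of the map entirely.
void emplaceSplitIfValid(
    std::unordered_map<std::string, std::string>& map,
    const std::string& key,
    const std::vector<int64_t>& splits) {
  if (!splits.empty() &&
      splits.size() <= static_cast<std::size_t>(kTruncatLength)) {
    map.emplace(key, joinBracketed(splits));
  }
}
```

With this sketch, a three-element list is stored as a bracketed string, while an empty list or one with 31 entries adds no key at all, so consumers of the metadata should treat a missing split-size key as "not recorded" rather than as an empty split.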