Skip to content

Commit

Permalink
[Kineto][NCCL][5/n] Populate in/out split size info for all_to_all fr…
Browse files Browse the repository at this point in the history
…om CPU to CUDA kernel (pytorch#112308)

Summary:

X-link: pytorch/kineto#822

This diff populates all_to_all input and out split size from CPU op to GPU kernel when valid.

Test Plan:
**Trace example**:
- For non all_to_all collective functions: https://fburl.com/perfdoctor/4nobsu15
https://pxl.cl/3GNVb

- For all_to_all: https://fburl.com/perfdoctor/f418goys

https://pxl.cl/3H2nd

Reviewed By: aaronenyeshi, idning

Differential Revision: D50762093
  • Loading branch information
yoyoyocmu authored and facebook-github-bot committed Oct 30, 2023
1 parent a14f8e0 commit 1ac8eea
Showing 1 changed file with 27 additions and 6 deletions.
33 changes: 27 additions & 6 deletions torch/csrc/profiler/util.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -344,6 +344,7 @@ static constexpr auto kOutMsgSize = "Out msg size";
static constexpr auto kInSplit = "In split size";
static constexpr auto kOutSplit = "Out split size";
static constexpr auto kGroupSize = "Group size";
static constexpr int32_t kTruncatLength = 30;
#endif // USE_C10D
#endif // USE_DISTRIBUTED

Expand All @@ -365,12 +366,32 @@ std::unordered_map<std::string, std::string> saveNcclMeta(
kDtype, fmt::format("\"{}\"", c10::toString(debugInfo->getDType())));
map.emplace(kInMsgSize, std::to_string(debugInfo->getInMessageSize()));
map.emplace(kOutMsgSize, std::to_string(debugInfo->getOutMessageSize()));
map.emplace(
kInSplit,
fmt::format("[{}]", fmt::join(debugInfo->getInputSplitSizes(), ", ")));
map.emplace(
kOutSplit,
fmt::format("[{}]", fmt::join(debugInfo->getOutputSplitSizes(), ", ")));
auto& inSplitSizes = debugInfo->getInputSplitSizes();
if (!inSplitSizes.empty() && inSplitSizes.size() <= kTruncatLength) {
map.emplace(kInSplit, fmt::format("[{}]", fmt::join(inSplitSizes, ", ")));
} else if (inSplitSizes.size() > kTruncatLength) {
map.emplace(
kInSplit,
fmt::format(
"[{}, ...]",
fmt::join(
inSplitSizes.begin(),
inSplitSizes.begin() + kTruncatLength,
", ")));
}
auto& outSplitSizes = debugInfo->getOutputSplitSizes();
if (!outSplitSizes.empty() && outSplitSizes.size() <= kTruncatLength) {
map.emplace(kOutSplit, fmt::format("[{}]", fmt::join(outSplitSizes, ", ")));
} else if (outSplitSizes.size() > kTruncatLength) {
map.emplace(
kOutSplit,
fmt::format(
"[{}, ...]",
fmt::join(
outSplitSizes.begin(),
outSplitSizes.begin() + kTruncatLength,
", ")));
}
map.emplace(kGroupSize, std::to_string(debugInfo->getWorldSize()));
#endif // USE_C10D
#endif // USE_DISTRIBUTED
Expand Down

0 comments on commit 1ac8eea

Please sign in to comment.