Skip to content

Commit

Permalink
Pass params to dump_nccl_trace_pickle (#128781)
Browse files Browse the repository at this point in the history
Summary
Pass parameters from request to dump_nccl_trace_pickle handler.
The supported parameters + value are all lowercase.
includecollectives={true, false}
includestacktraces={true, false}
onlyactive={true, false}

Example post is:
/handler/dump_nccl_trace_pickle?includecollectives=true&includestacktraces=false&onlyactive=true

Test Plan:
unit tests

Differential Revision: [D58640474](https://our.internmc.facebook.com/intern/diff/D58640474)
Pull Request resolved: #128781
Approved by: https://github.com/d4l3k
  • Loading branch information
c-p-i-o authored and pytorchmergebot committed Jun 18, 2024
1 parent d9eaa22 commit f7eae27
Show file tree
Hide file tree
Showing 5 changed files with 95 additions and 10 deletions.
37 changes: 37 additions & 0 deletions test/distributed/elastic/test_control_plane.py
Original file line number Diff line number Diff line change
Expand Up @@ -80,6 +80,43 @@ def test_dump_nccl_trace_pickle(self) -> None:
resp = pool.request("POST", "/handler/dump_nccl_trace_pickle")
self.assertEqual(resp.status, 200)
out = pickle.loads(resp.data)
self.assertIsInstance(out, dict)
self.assertIn("version", out)

@requires_cuda
def test_dump_nccl_trace_pickle_with_params(self) -> None:
with local_worker_server() as pool:
# bad key - not lower case
resp = pool.request(
"POST", "/handler/dump_nccl_trace_pickle?includeCollectives=true"
)
self.assertEqual(resp.status, 400)
# unknown key
resp = pool.request(
"POST", "/handler/dump_nccl_trace_pickle?unknownkey=true"
)
self.assertEqual(resp.status, 400)
# bad value - not a bool
resp = pool.request(
"POST", "/handler/dump_nccl_trace_pickle?includecollectives=notabool"
)
self.assertEqual(resp.status, 400)
# bad value - value not lowercase
resp = pool.request(
"POST", "/handler/dump_nccl_trace_pickle?includecollectives=True"
)
self.assertEqual(resp.status, 400)
# good key and value
resp = pool.request(
"POST", "/handler/dump_nccl_trace_pickle?includecollectives=true"
)
self.assertEqual(resp.status, 200)
# multiple good keys and values
resp = pool.request(
"POST",
"/handler/dump_nccl_trace_pickle?includecollectives=true&includestacktraces=false&onlyactive=true",
)
self.assertEqual(resp.status, 200)

def test_tcp(self) -> None:
import requests
Expand Down
51 changes: 51 additions & 0 deletions torch/csrc/distributed/c10d/NCCLUtils.cpp
Original file line number Diff line number Diff line change
@@ -1,7 +1,10 @@
#include <torch/csrc/distributed/c10d/NCCLUtils.hpp>
#include <torch/csrc/distributed/c10d/ProcessGroupNCCL.hpp>
#include <torch/csrc/distributed/c10d/control_plane/Handlers.hpp>

#include <c10/util/CallOnce.h>
#include <c10/util/env.h>
#include <algorithm>

#ifdef USE_C10D_NCCL
#include <vector>
Expand Down Expand Up @@ -238,6 +241,54 @@ std::string getNcclErrorDetailStr(
return interpret + err;
}

control_plane::RegisterHandler dumpHandler{
"dump_nccl_trace_pickle",
[](const control_plane::Request& req, control_plane::Response& res) {
const auto params = req.params();
size_t validParamCount = 0;

// valid params
const std::string includeCollectivesStr = "includecollectives";
const std::string includeStackTracesStr = "includestacktraces";
const std::string onlyActiveStr = "onlyactive";

std::unordered_map<std::string, bool> expectedParams = {
{includeCollectivesStr, true},
{includeStackTracesStr, true},
{onlyActiveStr, false}};

for (const auto& [paramName, paramValue] : params) {
auto it = expectedParams.find(paramName);
if (it != expectedParams.end()) {
validParamCount++;
if (paramValue == "true") {
it->second = true;
} else if (paramValue == "false") {
it->second = false;
} else {
res.setStatus(400);
res.setContent(
"Invalid value for " + paramName +
" valid values are true or false",
"text/plain");
return;
}
}
}
if (validParamCount < params.size()) {
res.setStatus(400);
res.setContent(
"Invalid parameters - unexpected param passed in", "text/plain");
return;
}
res.setContent(
dump_nccl_trace(
expectedParams[includeCollectivesStr],
expectedParams[includeStackTracesStr],
expectedParams[onlyActiveStr]),
"application/octet-stream");
}};

} // namespace c10d

#endif // USE_C10D_NCCL
10 changes: 0 additions & 10 deletions torch/csrc/distributed/c10d/ProcessGroupNCCL.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,6 @@
#include <torch/csrc/distributed/c10d/ProcessGroupNCCL.hpp>
#include <torch/csrc/distributed/c10d/TraceUtils.h>
#include <torch/csrc/distributed/c10d/Utils.hpp>
#include <torch/csrc/distributed/c10d/control_plane/Handlers.hpp>
#include <torch/csrc/distributed/c10d/logger.hpp>
#include <torch/torch.h>

Expand Down Expand Up @@ -380,15 +379,6 @@ std::string dump_nccl_trace(
}
#endif

// TODO(c-p-i-o): add a JSON endpoint.
control_plane::RegisterHandler dumpHandler{
"dump_nccl_trace_pickle",
[](const control_plane::Request& req, control_plane::Response& res) {
// TODO: c-p-i-o: params from the request need to go to dump_nccl_trace.
res.setContent(
dump_nccl_trace(true, true, false), "application/octet-stream");
}};

std::optional<std::function<void(std::function<void(const std::string&)>)>>&
get_cpp_trace_dumper() {
static std::optional<
Expand Down
3 changes: 3 additions & 0 deletions torch/csrc/distributed/c10d/control_plane/Handlers.hpp
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
#pragma once

#include <functional>
#include <map>
#include <string>

#include <c10/macros/Export.h>
Expand All @@ -15,6 +16,8 @@ class TORCH_API Request {
virtual ~Request() = default;

virtual const std::string& body() = 0;

virtual const std::multimap<std::string, std::string>& params() const = 0;
};

// Response represents a response to the handler. This conceptually maps to an
Expand Down
4 changes: 4 additions & 0 deletions torch/csrc/distributed/c10d/control_plane/WorkerServer.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,10 @@ class RequestImpl : public Request {
return req_.body;
}

const std::multimap<std::string, std::string>& params() const override {
return req_.params;
}

private:
const httplib::Request& req_;
};
Expand Down

0 comments on commit f7eae27

Please sign in to comment.