Create a PyTorch DDP detailed usage logger for analysis and debugging
Add APIs for logging PyTorch DDP logging data in applications.

Differential Revision: [D25933411](https://our.internmc.facebook.com/intern/diff/D25933411/)

**NOTE FOR REVIEWERS**: This PR has internal Facebook specific changes or comments, please review them on [Phabricator](https://our.internmc.facebook.com/intern/diff/D25933411/)!

ghstack-source-id: 119899629
Pull Request resolved: #50637
zhaojuanmao committed Jan 16, 2021
1 parent a9a4706 commit 1013385
Showing 7 changed files with 29 additions and 6 deletions.
18 changes: 18 additions & 0 deletions c10/util/Logging.cpp
@@ -70,6 +70,7 @@ Error::Error(SourceLocation source_location, std::string msg)
               (*GetFetchStackTrace())())) {}

 using APIUsageLoggerType = std::function<void(const std::string&)>;
+using DDPUsageLoggerType = std::function<void(const c10::DDPLoggingData&)>;

 namespace {
 bool IsAPIUsageDebugMode() {
@@ -87,20 +88,37 @@ APIUsageLoggerType* GetAPIUsageLogger() {
       IsAPIUsageDebugMode() ? &APIUsageDebug : [](const string&) {};
   return &func;
 };
+
+DDPUsageLoggerType* GetDDPUsageLogger() {
+  static DDPUsageLoggerType func = [](const c10::DDPLoggingData&) {};
+  return &func;
+};
 } // namespace

 void SetAPIUsageLogger(std::function<void(const std::string&)> logger) {
   TORCH_CHECK(logger);
   *GetAPIUsageLogger() = logger;
 }

+void SetPyTorchDDPUsageLogger(std::function<void(const c10::DDPLoggingData&)> logger) {
+  TORCH_CHECK(logger);
+  *GetDDPUsageLogger() = logger;
+}
+
 void LogAPIUsage(const std::string& event) try {
   if (auto logger = GetAPIUsageLogger())
     (*logger)(event);
 } catch (std::bad_function_call&) {
   // static destructor race
 }

+void LogPyTorchDDPUsage(const c10::DDPLoggingData& ddpData) try {
+  if (auto logger = GetDDPUsageLogger())
+    (*logger)(ddpData);
+} catch (std::bad_function_call&) {
+  // static destructor race
+}
+
 namespace detail {
 bool LogAPIUsageFakeReturn(const std::string& event) try {
   if (auto logger = GetAPIUsageLogger())
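
As a sketch of the consumer side — assuming only the c10::SetPyTorchDDPUsageLogger and c10::DDPLoggingData surface shown in this diff; the function name InstallDDPUsageLogger and the stderr sink are illustrative, not part of the commit — an application could register a callback once at startup:

    #include <c10/util/Logging.h>
    #include <iostream>

    // Hypothetical startup hook: forward each DDP construction record to stderr.
    // Only fields declared in the DDPLoggingData struct of this diff are used.
    void InstallDDPUsageLogger() {
      c10::SetPyTorchDDPUsageLogger([](const c10::DDPLoggingData& data) {
        std::cerr << "[ddp] rank=" << data.rank
                  << " backend=" << data.backend
                  << " module=" << data.module_name
                  << " device_ids=" << data.device_ids
                  << " bucket_cap_mb=" << data.bucket_cap_mb << std::endl;
      });
    }

Until something like this runs, the default lambda installed by GetDDPUsageLogger makes every LogPyTorchDDPUsage call a no-op, so the hook is effectively free unless an application opts in.
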
5 changes: 4 additions & 1 deletion c10/util/Logging.h
@@ -315,14 +315,17 @@ struct DDPLoggingData {
   int rank;
   std::string backend;
   std::string module_name;
-  std::vector<int> device_ids;
+  std::string device_ids;
   int output_device;
   bool broadcast_buffers;
   int bucket_cap_mb;
   bool find_unused_parameters;
   bool gradient_as_bucket_view;
 };

+C10_API void SetPyTorchDDPUsageLogger(std::function<void(const c10::DDPLoggingData&)> logger);
+C10_API void LogPyTorchDDPUsage(const c10::DDPLoggingData& ddpData);
+
 namespace detail {
 // Return value is needed to do the static variable initialization trick
 C10_API bool LogAPIUsageFakeReturn(const std::string& context);
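
To tie the struct to concrete values, here is a producer-side sketch. The field values are copied from the test_ddp_logging_data assertions at the bottom of this commit; calling LogPyTorchDDPUsage directly like this, and default-constructing the struct, are assumptions for illustration only — in the real flow Reducer::set_construction_logging_data fills the record and emits it, as shown below:

    #include <c10/util/Logging.h>

    // Illustrative record matching the test's expectations for a CPU 'Net' module.
    c10::DDPLoggingData MakeExampleRecord() {
      c10::DDPLoggingData data;
      data.rank = 0;                       // dist.get_rank() on the first process
      data.backend = "gloo";
      data.module_name = "Net";
      data.device_ids = "";                // no device ids for a CPU module
      data.output_device = -1;
      data.broadcast_buffers = true;
      data.bucket_cap_mb = 25;             // DDP's default bucket size
      data.find_unused_parameters = false;
      data.gradient_as_bucket_view = false;
      return data;
    }

    void EmitExample() {
      c10::LogPyTorchDDPUsage(MakeExampleRecord());
    }
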
2 changes: 1 addition & 1 deletion torch/csrc/distributed/c10d/init.cpp
@@ -187,7 +187,7 @@ PyObject* c10d_init(PyObject* _unused, PyObject* noargs) {
           ::c10d::Reducer& reducer,
           const std::string& backend,
           const std::string& module_name,
-          const std::vector<int>& device_ids,
+          const std::string& device_ids,
           int output_device,
           bool broadcast_buffers) -> void {
         reducer.set_construction_logging_data(
4 changes: 3 additions & 1 deletion torch/lib/c10d/reducer.cpp
@@ -1469,7 +1469,7 @@ void Reducer::ensure_prior_reduction_finished() {
 void Reducer::set_construction_logging_data(
     const std::string& backend,
     const std::string& module_name,
-    const std::vector<int>& device_ids,
+    const std::string& device_ids,
     int output_device,
     bool broadcast_buffers
 ) {
@@ -1483,6 +1483,8 @@ void Reducer::set_construction_logging_data(
   ddp_logging_data_->bucket_cap_mb = bucket_bytes_cap_ / (1024 * 1024);
   ddp_logging_data_->find_unused_parameters = find_unused_parameters_;
   ddp_logging_data_->gradient_as_bucket_view = gradient_as_bucket_view_;
+
+  LogPyTorchDDPUsage(*ddp_logging_data_);
 }

 c10::DDPLoggingData Reducer::get_ddp_logging_data() {
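
Since set_construction_logging_data now emits the record exactly once per DistributedDataParallel construction, a capturing logger is enough for the analysis-and-debugging use case named in the commit title. A minimal sketch, assuming nothing beyond the APIs above — the vector sink and the mutex are my additions:

    #include <c10/util/Logging.h>
    #include <mutex>
    #include <vector>

    // Accumulate DDP construction records for later inspection (e.g. in tests).
    std::vector<c10::DDPLoggingData>& CapturedDDPRecords() {
      static std::vector<c10::DDPLoggingData> records;
      return records;
    }

    void CaptureDDPUsage() {
      static std::mutex mu;  // several Reducers may log concurrently
      c10::SetPyTorchDDPUsageLogger([](const c10::DDPLoggingData& data) {
        std::lock_guard<std::mutex> lock(mu);
        CapturedDDPRecords().push_back(data);
      });
    }
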
2 changes: 1 addition & 1 deletion torch/lib/c10d/reducer.hpp
@@ -110,7 +110,7 @@ class Reducer {
   void set_construction_logging_data(
       const std::string& backend,
       const std::string& module_name,
-      const std::vector<int>& device_ids,
+      const std::string& device_ids,
       int output_device,
       bool broadcast_buffers);

2 changes: 1 addition & 1 deletion torch/nn/parallel/distributed.py
@@ -611,7 +611,7 @@ def produces_sparse_gradient(module):
             self.reducer,
             dist.get_backend(),
             self.module.__class__.__name__,
-            [] if self.device_ids is None else self.device_ids,
+            "" if self.device_ids is None else ",".join(map(str, self.device_ids)),
            -1 if self.output_device is None else self.output_device,
             self.broadcast_buffers)
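
In other words, a module wrapped with device_ids=[0, 1] is now recorded as the string "0,1", and the CPU/no-device case is recorded as "", which is exactly what the updated test below asserts.
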
2 changes: 1 addition & 1 deletion torch/testing/_internal/distributed/distributed_test.py
@@ -3150,7 +3150,7 @@ def test_ddp_logging_data(self):
         self.assertEqual(ddp_logging_data.rank, dist.get_rank())
         self.assertEqual(ddp_logging_data.backend, 'gloo')
         self.assertEqual(ddp_logging_data.module_name, 'Net')
-        self.assertEqual(ddp_logging_data.device_ids, [])
+        self.assertEqual(ddp_logging_data.device_ids, "")
         self.assertEqual(ddp_logging_data.output_device, -1)
         self.assertEqual(ddp_logging_data.broadcast_buffers, True)
         self.assertEqual(ddp_logging_data.bucket_cap_mb, 25)
