From 1013385cfd814cbe319cc79097b3fc3d49f27711 Mon Sep 17 00:00:00 2001
From: yanlizhao
Date: Fri, 15 Jan 2021 16:28:22 -0800
Subject: [PATCH] Create a PyTorch DDP detailed usage logger for analysis and
 debugging

Add APIs for logging PyTorch DDP logging data in applications.

Differential Revision: [D25933411](https://our.internmc.facebook.com/intern/diff/D25933411/)

**NOTE FOR REVIEWERS**: This PR has internal Facebook-specific changes or comments; please review them on [Phabricator](https://our.internmc.facebook.com/intern/diff/D25933411/)!

ghstack-source-id: 119899629
Pull Request resolved: https://github.com/pytorch/pytorch/pull/50637
---
 c10/util/Logging.cpp                              | 18 ++++++++++++++++++
 c10/util/Logging.h                                |  5 ++++-
 torch/csrc/distributed/c10d/init.cpp              |  2 +-
 torch/lib/c10d/reducer.cpp                        |  4 +++-
 torch/lib/c10d/reducer.hpp                        |  2 +-
 torch/nn/parallel/distributed.py                  |  2 +-
 .../_internal/distributed/distributed_test.py    |  2 +-
 7 files changed, 29 insertions(+), 6 deletions(-)

diff --git a/c10/util/Logging.cpp b/c10/util/Logging.cpp
index 14b336a794f0..89506329e857 100644
--- a/c10/util/Logging.cpp
+++ b/c10/util/Logging.cpp
@@ -70,6 +70,7 @@ Error::Error(SourceLocation source_location, std::string msg)
           (*GetFetchStackTrace())())) {}
 
 using APIUsageLoggerType = std::function<void(const std::string&)>;
+using DDPUsageLoggerType = std::function<void(const c10::DDPLoggingData&)>;
 
 namespace {
 bool IsAPIUsageDebugMode() {
@@ -87,6 +88,11 @@ APIUsageLoggerType* GetAPIUsageLogger() {
       IsAPIUsageDebugMode() ? &APIUsageDebug : [](const string&) {};
   return &func;
 };
+
+DDPUsageLoggerType* GetDDPUsageLogger() {
+  static DDPUsageLoggerType func = [](const c10::DDPLoggingData&) {};
+  return &func;
+};
 } // namespace
 
 void SetAPIUsageLogger(std::function<void(const std::string&)> logger) {
@@ -94,6 +100,11 @@ void SetAPIUsageLogger(std::function<void(const std::string&)> logger) {
   *GetAPIUsageLogger() = logger;
 }
 
+void SetPyTorchDDPUsageLogger(std::function<void(const c10::DDPLoggingData&)> logger) {
+  TORCH_CHECK(logger);
+  *GetDDPUsageLogger() = logger;
+}
+
 void LogAPIUsage(const std::string& event) try {
   if (auto logger = GetAPIUsageLogger())
     (*logger)(event);
@@ -101,6 +112,13 @@ void LogAPIUsage(const std::string& event) try {
   // static destructor race
 }
 
+void LogPyTorchDDPUsage(const c10::DDPLoggingData& ddpData) try {
+  if (auto logger = GetDDPUsageLogger())
+    (*logger)(ddpData);
+} catch (std::bad_function_call&) {
+  // static destructor race
+}
+
 namespace detail {
 bool LogAPIUsageFakeReturn(const std::string& event) try {
   if (auto logger = GetAPIUsageLogger())
diff --git a/c10/util/Logging.h b/c10/util/Logging.h
index 9e57f5b6ff54..308bb059de87 100644
--- a/c10/util/Logging.h
+++ b/c10/util/Logging.h
@@ -315,7 +315,7 @@ struct DDPLoggingData {
   int rank;
   std::string backend;
   std::string module_name;
-  std::vector<int> device_ids;
+  std::string device_ids;
   int output_device;
   bool broadcast_buffers;
   int bucket_cap_mb;
@@ -323,6 +323,9 @@ struct DDPLoggingData {
   bool gradient_as_bucket_view;
 };
 
+C10_API void SetPyTorchDDPUsageLogger(std::function<void(const c10::DDPLoggingData&)> logger);
+C10_API void LogPyTorchDDPUsage(const c10::DDPLoggingData& ddpData);
+
 namespace detail {
 // Return value is needed to do the static variable initialization trick
 C10_API bool LogAPIUsageFakeReturn(const std::string& context);
diff --git a/torch/csrc/distributed/c10d/init.cpp b/torch/csrc/distributed/c10d/init.cpp
index 24acfdfc82bb..971dc786f7c1 100644
--- a/torch/csrc/distributed/c10d/init.cpp
+++ b/torch/csrc/distributed/c10d/init.cpp
@@ -187,7 +187,7 @@ PyObject* c10d_init(PyObject* _unused, PyObject* noargs) {
              ::c10d::Reducer& reducer,
              const std::string& backend,
              const std::string& module_name,
-             const std::vector<int>& device_ids,
+             const std::string& device_ids,
              int output_device,
              bool broadcast_buffers) -> void {
            reducer.set_construction_logging_data(
diff --git a/torch/lib/c10d/reducer.cpp b/torch/lib/c10d/reducer.cpp
index abc9ddba2448..8b7d9d4a964c 100644
--- a/torch/lib/c10d/reducer.cpp
+++ b/torch/lib/c10d/reducer.cpp
@@ -1469,7 +1469,7 @@ void Reducer::ensure_prior_reduction_finished() {
 void Reducer::set_construction_logging_data(
     const std::string& backend,
     const std::string& module_name,
-    const std::vector<int>& device_ids,
+    const std::string& device_ids,
     int output_device,
     bool broadcast_buffers
 ) {
@@ -1483,6 +1483,8 @@ void Reducer::set_construction_logging_data(
   ddp_logging_data_->bucket_cap_mb = bucket_bytes_cap_ / (1024 * 1024);
   ddp_logging_data_->find_unused_parameters = find_unused_parameters_;
   ddp_logging_data_->gradient_as_bucket_view = gradient_as_bucket_view_;
+
+  LogPyTorchDDPUsage(*ddp_logging_data_);
 }
 
 c10::DDPLoggingData Reducer::get_ddp_logging_data() {
diff --git a/torch/lib/c10d/reducer.hpp b/torch/lib/c10d/reducer.hpp
index 9be551fb7f8d..0e5c8be165b2 100644
--- a/torch/lib/c10d/reducer.hpp
+++ b/torch/lib/c10d/reducer.hpp
@@ -110,7 +110,7 @@ class Reducer {
   void set_construction_logging_data(
       const std::string& backend,
       const std::string& module_name,
-      const std::vector<int>& device_ids,
+      const std::string& device_ids,
       int output_device,
       bool broadcast_buffers);
 
diff --git a/torch/nn/parallel/distributed.py b/torch/nn/parallel/distributed.py
index db9f0594c1b2..586088c0c81a 100644
--- a/torch/nn/parallel/distributed.py
+++ b/torch/nn/parallel/distributed.py
@@ -611,7 +611,7 @@ def produces_sparse_gradient(module):
             self.reducer,
             dist.get_backend(),
             self.module.__class__.__name__,
-            [] if self.device_ids is None else self.device_ids,
+            "" if self.device_ids is None else ",".join(map(str, self.device_ids)),
             -1 if self.output_device is None else self.output_device,
             self.broadcast_buffers)
 
diff --git a/torch/testing/_internal/distributed/distributed_test.py b/torch/testing/_internal/distributed/distributed_test.py
index 31d2061a042d..11f742ec3652 100644
--- a/torch/testing/_internal/distributed/distributed_test.py
+++ b/torch/testing/_internal/distributed/distributed_test.py
@@ -3150,7 +3150,7 @@ def test_ddp_logging_data(self):
         self.assertEqual(ddp_logging_data.rank, dist.get_rank())
         self.assertEqual(ddp_logging_data.backend, 'gloo')
         self.assertEqual(ddp_logging_data.module_name, 'Net')
-        self.assertEqual(ddp_logging_data.device_ids, [])
+        self.assertEqual(ddp_logging_data.device_ids, "")
         self.assertEqual(ddp_logging_data.output_device, -1)
         self.assertEqual(ddp_logging_data.broadcast_buffers, True)
         self.assertEqual(ddp_logging_data.bucket_cap_mb, 25)
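For context, here is a minimal sketch of how an application could consume the new hook. The registration entry point (`c10::SetPyTorchDDPUsageLogger`) and the `c10::DDPLoggingData` fields are taken from the patch above; the `main()` wrapper and the stderr sink are illustrative assumptions, standing in for whatever telemetry backend a deployment actually uses.

```cpp
// Illustrative sketch: register a process-wide DDP usage logger before any
// DistributedDataParallel module is constructed. The callback is invoked from
// Reducer::set_construction_logging_data() via LogPyTorchDDPUsage().
#include <iostream>

#include <c10/util/Logging.h>

int main() {
  c10::SetPyTorchDDPUsageLogger([](const c10::DDPLoggingData& data) {
    // After this patch, device_ids is a comma-separated string (e.g. "0,1"),
    // or "" when DDP was constructed without explicit device ids.
    std::cerr << "DDP construction:"
              << " rank=" << data.rank
              << " backend=" << data.backend
              << " module=" << data.module_name
              << " device_ids=\"" << data.device_ids << "\""
              << " output_device=" << data.output_device
              << " broadcast_buffers=" << data.broadcast_buffers
              << " bucket_cap_mb=" << data.bucket_cap_mb
              << " find_unused_parameters=" << data.find_unused_parameters
              << " gradient_as_bucket_view=" << data.gradient_as_bucket_view
              << std::endl;
  });
  // ... set up the process group and construct the DDP-wrapped model ...
  return 0;
}
```

Flattening `device_ids` from `std::vector<int>` to a comma-separated `std::string` also leaves every `DDPLoggingData` field a scalar or string, which is presumably what makes the struct easy to hand off to generic logging sinks like the one sketched here.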