diff --git a/c10/util/Logging.cpp b/c10/util/Logging.cpp
index 14b336a794f0..89506329e857 100644
--- a/c10/util/Logging.cpp
+++ b/c10/util/Logging.cpp
@@ -70,6 +70,7 @@ Error::Error(SourceLocation source_location, std::string msg)
           (*GetFetchStackTrace())())) {}
 
 using APIUsageLoggerType = std::function<void(const std::string&)>;
+using DDPUsageLoggerType = std::function<void(const c10::DDPLoggingData&)>;
 
 namespace {
 bool IsAPIUsageDebugMode() {
@@ -87,6 +88,11 @@ APIUsageLoggerType* GetAPIUsageLogger() {
       IsAPIUsageDebugMode() ? &APIUsageDebug : [](const string&) {};
   return &func;
 };
+
+DDPUsageLoggerType* GetDDPUsageLogger() {
+  static DDPUsageLoggerType func = [](const c10::DDPLoggingData&) {};
+  return &func;
+};
 } // namespace
 
 void SetAPIUsageLogger(std::function<void(const std::string&)> logger) {
@@ -94,6 +100,11 @@ void SetAPIUsageLogger(std::function<void(const std::string&)> logger) {
   *GetAPIUsageLogger() = logger;
 }
 
+void SetPyTorchDDPUsageLogger(std::function<void(const c10::DDPLoggingData&)> logger) {
+  TORCH_CHECK(logger);
+  *GetDDPUsageLogger() = logger;
+}
+
 void LogAPIUsage(const std::string& event) try {
   if (auto logger = GetAPIUsageLogger())
     (*logger)(event);
@@ -101,6 +112,13 @@ void LogAPIUsage(const std::string& event) try {
   // static destructor race
 }
 
+void LogPyTorchDDPUsage(const c10::DDPLoggingData& ddpData) try {
+  if (auto logger = GetDDPUsageLogger())
+    (*logger)(ddpData);
+} catch (std::bad_function_call&) {
+  // static destructor race
+}
+
 namespace detail {
 bool LogAPIUsageFakeReturn(const std::string& event) try {
   if (auto logger = GetAPIUsageLogger())
diff --git a/c10/util/Logging.h b/c10/util/Logging.h
index 6b5bc5d12cd8..9febbb8f61e9 100644
--- a/c10/util/Logging.h
+++ b/c10/util/Logging.h
@@ -322,6 +322,9 @@ struct DDPLoggingData {
   std::string backend_name;
 };
 
+C10_API void SetPyTorchDDPUsageLogger(std::function<void(const c10::DDPLoggingData&)> logger);
+C10_API void LogPyTorchDDPUsage(const c10::DDPLoggingData& ddpData);
+
 namespace detail {
 // Return value is needed to do the static variable initialization trick
 C10_API bool LogAPIUsageFakeReturn(const std::string& context);
diff --git a/torch/lib/c10d/reducer.cpp b/torch/lib/c10d/reducer.cpp
index 3d6f949ab3e8..39b2ece357fe 100644
--- a/torch/lib/c10d/reducer.cpp
+++ b/torch/lib/c10d/reducer.cpp
@@ -1482,6 +1482,8 @@ void Reducer::set_construction_logging_data(
   ddp_logging_data_->find_unused_parameters = find_unused_parameters_;
   ddp_logging_data_->gradient_as_bucket_view = gradient_as_bucket_view_;
   ddp_logging_data_->backend_name = process_group_->getBackendName();
+
+  LogPyTorchDDPUsage(*ddp_logging_data_);
 }
 
 c10::DDPLoggingData Reducer::get_ddp_logging_data() {
diff --git a/torch/testing/_internal/distributed/distributed_test.py b/torch/testing/_internal/distributed/distributed_test.py
index d2e5e40bbd9e..507179d36b81 100644
--- a/torch/testing/_internal/distributed/distributed_test.py
+++ b/torch/testing/_internal/distributed/distributed_test.py
@@ -3218,7 +3218,7 @@ def test_DistributedDataParallel_SyncBatchNorm_Diff_Input_Sizes_gradient(self):
     @unittest.skipIf(
         BACKEND == "nccl", "nccl does not support DDP on CPU models"
     )
-    def test_ddp_logging_data(self):
+    def test_ddp_logging_data_cpu(self):
         group, group_id, rank = self._init_global_test()
         model_DDP = copy.deepcopy(DDP_NET)
         model_DDP = nn.parallel.DistributedDataParallel(model_DDP)
@@ -3236,6 +3236,18 @@ def test_ddp_logging_data(self):
         self.assertEqual(ddp_logging_data.gradient_as_bucket_view, False)
         self.assertEqual(ddp_logging_data.backend_name, dist.get_backend(group_id))
 
+    @unittest.skipIf(BACKEND != 'nccl' and BACKEND != 'gloo',
+                     "Only Nccl & Gloo backend support DistributedDataParallel")
+    @skip_if_no_gpu
+    def test_ddp_logging_data_gpu(self):
+        group, group_id, rank = self._init_global_test()
+        model_DDP = copy.deepcopy(DDP_NET)
+        model_DDP.cuda(rank)
+        model_DDP = nn.parallel.DistributedDataParallel(model_DDP, device_ids=[rank])
+        ddp_logging_data = model_DDP.get_ddp_logging_data()
+        self.assertEqual(ddp_logging_data.device_ids, [rank])
+        self.assertEqual(ddp_logging_data.output_device, rank)
+
     @skipIfNoTorchVision
     def test_SyncBatchNorm_process_group(self):
         # When adopting `convert_sync_batchnorm` to convert a `nn.modules`,
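
Usage sketch (not part of the patch): how the new hooks could be wired up, assuming only the declarations this diff adds to c10/util/Logging.h and that they live in the c10 namespace like the surrounding API-usage logging functions. The main() wrapper and the std::cout sink below are illustrative; in the patch itself LogPyTorchDDPUsage() is invoked once per DDP construction from Reducer::set_construction_logging_data().

// Minimal sketch: register a process-wide DDP usage callback and trigger it.
// Assumes the SetPyTorchDDPUsageLogger / LogPyTorchDDPUsage declarations from
// this diff; the printed message and direct LogPyTorchDDPUsage() call are
// only for demonstration.
#include <iostream>

#include <c10/util/Logging.h>

int main() {
  // Install the callback; every subsequent LogPyTorchDDPUsage() call forwards
  // the populated DDPLoggingData to this lambda.
  c10::SetPyTorchDDPUsageLogger([](const c10::DDPLoggingData& data) {
    std::cout << "DDP constructed, backend: " << data.backend_name << "\n";
  });

  // Exercise the callback directly; normally the reducer does this during
  // DDP construction.
  c10::DDPLoggingData data;
  data.backend_name = "gloo";
  c10::LogPyTorchDDPUsage(data);
  return 0;
}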