Create PyTorch DDP logging APIs for applications to use
Pull Request resolved: #50637

Add APIs for logging PyTorch DDP logging data in applications.

Differential Revision: [D25933411](https://our.internmc.facebook.com/intern/diff/D25933411/)

**NOTE FOR REVIEWERS**: This PR has internal Facebook specific changes or comments, please review them on [Phabricator](https://our.internmc.facebook.com/intern/diff/D25933411/)!
ghstack-source-id: 120813409
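
Usage sketch (not part of this commit, for illustration only): an application can register a callback through c10::SetPyTorchDDPUsageLogger, and the DDP reducer then invokes it via LogPyTorchDDPUsage once construction data is populated. The include path and the c10 namespace are assumed from c10/util/Logging.h; only the backend_name field is read here, since it is the one visible in this diff.

// Sketch only: application-side registration of a DDP usage logger.
// Assumes <c10/util/Logging.h> is on the include path and the new symbols
// live in namespace c10, as the header changes below suggest.
#include <iostream>

#include <c10/util/Logging.h>

void RegisterDDPUsageLogger() {
  c10::SetPyTorchDDPUsageLogger([](const c10::DDPLoggingData& data) {
    // Forward DDP construction metadata to the application's own sink.
    std::cout << "DDP constructed with backend: " << data.backend_name
              << std::endl;
  });
}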
zhaojuanmao committed Feb 2, 2021
1 parent b1907f5 commit 74b02f8
Showing 4 changed files with 36 additions and 1 deletion.
18 changes: 18 additions & 0 deletions c10/util/Logging.cpp
@@ -70,6 +70,7 @@ Error::Error(SourceLocation source_location, std::string msg)
          (*GetFetchStackTrace())())) {}

using APIUsageLoggerType = std::function<void(const std::string&)>;
using DDPUsageLoggerType = std::function<void(const c10::DDPLoggingData&)>;

namespace {
bool IsAPIUsageDebugMode() {
@@ -87,20 +88,37 @@ APIUsageLoggerType* GetAPIUsageLogger() {
      IsAPIUsageDebugMode() ? &APIUsageDebug : [](const string&) {};
  return &func;
};

DDPUsageLoggerType* GetDDPUsageLogger() {
  static DDPUsageLoggerType func = [](const c10::DDPLoggingData&) {};
  return &func;
};
} // namespace

void SetAPIUsageLogger(std::function<void(const std::string&)> logger) {
  TORCH_CHECK(logger);
  *GetAPIUsageLogger() = logger;
}

void SetPyTorchDDPUsageLogger(std::function<void(const c10::DDPLoggingData&)> logger) {
  TORCH_CHECK(logger);
  *GetDDPUsageLogger() = logger;
}

void LogAPIUsage(const std::string& event) try {
  if (auto logger = GetAPIUsageLogger())
    (*logger)(event);
} catch (std::bad_function_call&) {
  // static destructor race
}

void LogPyTorchDDPUsage(const c10::DDPLoggingData& ddpData) try {
  if (auto logger = GetDDPUsageLogger())
    (*logger)(ddpData);
} catch (std::bad_function_call&) {
  // static destructor race
}

namespace detail {
bool LogAPIUsageFakeReturn(const std::string& event) try {
  if (auto logger = GetAPIUsageLogger())
3 changes: 3 additions & 0 deletions c10/util/Logging.h
@@ -322,6 +322,9 @@ struct DDPLoggingData {
  std::string backend_name;
};

C10_API void SetPyTorchDDPUsageLogger(std::function<void(const c10::DDPLoggingData&)> logger);
C10_API void LogPyTorchDDPUsage(const c10::DDPLoggingData& ddpData);

namespace detail {
// Return value is needed to do the static variable initialization trick
C10_API bool LogAPIUsageFakeReturn(const std::string& context);
2 changes: 2 additions & 0 deletions torch/lib/c10d/reducer.cpp
@@ -1482,6 +1482,8 @@ void Reducer::set_construction_logging_data(
  ddp_logging_data_->find_unused_parameters = find_unused_parameters_;
  ddp_logging_data_->gradient_as_bucket_view = gradient_as_bucket_view_;
  ddp_logging_data_->backend_name = process_group_->getBackendName();

  LogPyTorchDDPUsage(*ddp_logging_data_);
}

c10::DDPLoggingData Reducer::get_ddp_logging_data() {
14 changes: 13 additions & 1 deletion torch/testing/_internal/distributed/distributed_test.py
@@ -3218,7 +3218,7 @@ def test_DistributedDataParallel_SyncBatchNorm_Diff_Input_Sizes_gradient(self):
    @unittest.skipIf(
        BACKEND == "nccl", "nccl does not support DDP on CPU models"
    )
    def test_ddp_logging_data(self):
    def test_ddp_logging_data_cpu(self):
        group, group_id, rank = self._init_global_test()
        model_DDP = copy.deepcopy(DDP_NET)
        model_DDP = nn.parallel.DistributedDataParallel(model_DDP)
@@ -3236,6 +3236,18 @@ def test_ddp_logging_data(self):
        self.assertEqual(ddp_logging_data.gradient_as_bucket_view, False)
        self.assertEqual(ddp_logging_data.backend_name, dist.get_backend(group_id))

    @unittest.skipIf(BACKEND != 'nccl' and BACKEND != 'gloo',
                     "Only Nccl & Gloo backend support DistributedDataParallel")
    @skip_if_no_gpu
    def test_ddp_logging_data_gpu(self):
        group, group_id, rank = self._init_global_test()
        model_DDP = copy.deepcopy(DDP_NET)
        model_DDP.cuda(rank)
        model_DDP = nn.parallel.DistributedDataParallel(model_DDP, device_ids=[rank])
        ddp_logging_data = model_DDP.get_ddp_logging_data()
        self.assertEqual(ddp_logging_data.device_ids, [rank])
        self.assertEqual(ddp_logging_data.output_device, rank)

    @skipIfNoTorchVision
    def test_SyncBatchNorm_process_group(self):
        # When adopting `convert_sync_batchnorm` to convert a `nn.modules`,
