Create a DDPLoggingData and expose it to the Python interface
Pull Request resolved: #50622

1. Define a DDPLoggingData struct that serves as the placeholder for all DDP-related logging fields
2. Put the DDPLoggingData struct in the c10 directory so that it can be easily imported by both c10 and torch files
3. Expose a get_ddp_logging_data() method in Python so that users can retrieve the logging data and dump it in their applications
4. Add a unit test verifying that the logging data can be set and retrieved as expected
5. Follow-ups will add more logging fields such as performance stats, internal states, and environment variables
ghstack-source-id: 120307365

Differential Revision: [D25930527](https://our.internmc.facebook.com/intern/diff/D25930527/)
zhaojuanmao committed Jan 26, 2021
1 parent 57fb2c0 commit 8096a50
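
For context, a minimal usage sketch of the new Python-facing API (assumes a process group has already been initialized, e.g. via torch.distributed.init_process_group with the gloo backend; the toy model is illustrative only):

import torch.nn as nn
from torch.nn.parallel import DistributedDataParallel as DDP

# CPU model: device_ids defaults to None, so the logged device_ids will be [].
model = DDP(nn.Linear(10, 10))
logging_data = model.get_ddp_logging_data()

# Dump the construction-time fields added by this change.
print(logging_data.world_size, logging_data.rank, logging_data.module_name)
print(logging_data.bucket_cap_mb, logging_data.find_unused_parameters)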
Showing 8 changed files with 135 additions and 3 deletions.
21 changes: 21 additions & 0 deletions c10/util/Logging.h
@@ -300,6 +300,27 @@ BINARY_COMP_HELPER(LessEquals, <=)
C10_API void SetAPIUsageLogger(std::function<void(const std::string&)> logger);
C10_API void LogAPIUsage(const std::string& context);

// PyTorch ddp usage logging capabilities
// DDPLoggingData holds data that can be logged in applications
// for analysis and debugging. Data structure is defined in
// c10 directory so that it can be easily imported by both c10
// and torch files.
// TODO -- right now we start by logging a small set of straightforward
// fields; follow-ups will add more fields such as performance stats,
// internal states, and env variables.
struct DDPLoggingData {
// Data that can be collected at DistributedDataParallel construction time
int world_size;
int rank;
std::string module_name;
std::vector<int> device_ids;
int output_device;
bool broadcast_buffers;
int bucket_cap_mb;
bool find_unused_parameters;
bool gradient_as_bucket_view;
};

namespace detail {
// Return value is needed to do the static variable initialization trick
C10_API bool LogAPIUsageFakeReturn(const std::string& context);
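The struct above is what gets mirrored into Python as a DDPLoggingData object (see the pybind11 registration in init.cpp further down). As a sketch of how an application might dump it once retrieved — the ddp_logging_to_dict helper is illustrative, not part of this commit:

import json

def ddp_logging_to_dict(data):
    # Field names mirror the DDPLoggingData struct defined above.
    fields = ["world_size", "rank", "module_name", "device_ids",
              "output_device", "broadcast_buffers", "bucket_cap_mb",
              "find_unused_parameters", "gradient_as_bucket_view"]
    return {name: getattr(data, name) for name in fields}

# e.g. json.dumps(ddp_logging_to_dict(model.get_ddp_logging_data()))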
8 changes: 8 additions & 0 deletions torch/_C/_distributed_c10d.pyi
@@ -15,6 +15,14 @@ class BuiltinCommHookType(Enum):
def _register_comm_hook(reducer: Reducer, state: Any, comm_hook: Any): ...
def _register_builtin_comm_hook(reducer: Reducer, comm_hook_type: BuiltinCommHookType): ...

def _get_ddp_logging_data(reducer: Reducer): ...
def _set_construction_logging_data(
reducer: Reducer,
module_name: str,
device_ids: List[int],
output_device: int,
broadcast_buffers: bool): ...

class _GradBucket:
def __init__(self, tensors: List[Tensor]): ...
def get_tensors(self) -> List[Tensor]: ...
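These stubs describe the low-level bindings; users are expected to go through DistributedDataParallel.get_ddp_logging_data() (added further below), which forwards the module's internal reducer. A sketch of the equivalent direct call, again assuming an initialized process group (reducer access is an internal detail, shown only to illustrate the stubs):

import torch.distributed as dist
import torch.nn as nn
from torch.nn.parallel import DistributedDataParallel as DDP

ddp_model = DDP(nn.Linear(4, 4))
# Same data as ddp_model.get_ddp_logging_data()
data = dist._get_ddp_logging_data(ddp_model.reducer)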
37 changes: 35 additions & 2 deletions torch/csrc/distributed/c10d/init.cpp
@@ -180,7 +180,29 @@ PyObject* c10d_init(PyObject* _unused, PyObject* noargs) {
"_register_builtin_comm_hook",
&_register_builtin_comm_hook,
py::arg("reducer"),
py::arg("comm_hook_type"));
py::arg("comm_hook_type"))
.def(
"_set_construction_logging_data",
[](
::c10d::Reducer& reducer,
const std::string& module_name,
const std::vector<int>& device_ids,
int output_device,
bool broadcast_buffers) -> void {
reducer.set_construction_logging_data(
module_name, device_ids, output_device, broadcast_buffers);
},
py::arg("reducer"),
py::arg("module_name"),
py::arg("device_ids"),
py::arg("output_device"),
py::arg("broadcast_buffers"))
.def(
"_get_ddp_logging_data",
[](::c10d::Reducer& reducer) -> c10::DDPLoggingData {
return reducer.get_ddp_logging_data();
},
py::arg("reducer"));

shared_ptr_class_<::c10d::GradBucket>(module, "_GradBucket")
.def(
@@ -1159,6 +1181,18 @@ that adds a prefix to each key inserted to the store.
Note that ``fut.done()`` returns only whether the operation has been enqueued on the GPU.
)");

py::class_<c10::DDPLoggingData>(module, "DDPLoggingData")
.def(py::init<>())
.def_readwrite("world_size", &c10::DDPLoggingData::world_size)
.def_readwrite("rank", &c10::DDPLoggingData::rank)
.def_readwrite("module_name", &c10::DDPLoggingData::module_name)
.def_readwrite("device_ids", &c10::DDPLoggingData::device_ids)
.def_readwrite("output_device", &c10::DDPLoggingData::output_device)
.def_readwrite("broadcast_buffers", &c10::DDPLoggingData::broadcast_buffers)
.def_readwrite("bucket_cap_mb", &c10::DDPLoggingData::bucket_cap_mb)
.def_readwrite("find_unused_parameters", &c10::DDPLoggingData::find_unused_parameters)
.def_readwrite("gradient_as_bucket_view", &c10::DDPLoggingData::gradient_as_bucket_view);

module.def(
"_compute_bucket_assignment_by_size",
&::c10d::compute_bucket_assignment_by_size,
@@ -1668,7 +1702,6 @@ static const auto DistributedC10dFrontendTorchBind =
.def(
"get_name_of_process_group",
&::c10d::DistributedC10d::getNameOfProcessGroup);

} // namespace

// c10d methods on torch._C
2 changes: 2 additions & 0 deletions torch/distributed/__init__.py
@@ -34,6 +34,8 @@ def is_available():
_broadcast_coalesced,
_compute_bucket_assignment_by_size,
_test_python_store,
_set_construction_logging_data,
_get_ddp_logging_data
)
if sys.platform != 'win32':
from torch._C._distributed_c10d import (
24 changes: 23 additions & 1 deletion torch/lib/c10d/reducer.cpp
@@ -48,7 +48,8 @@ Reducer::Reducer(
has_rebuilt_bucket_(false),
bucket_bytes_cap_(bucket_bytes_cap),
divFactor_(kUnsetDivFactor),
comm_hook_(nullptr) {
comm_hook_(nullptr),
ddp_logging_data_(std::move(std::make_unique<c10::DDPLoggingData>())) {
C10_LOG_API_USAGE_ONCE("torch.distributed.ddp.reducer");
TORCH_CHECK(replicas_.size() >= 1, "Expected at least one model replica.");
TORCH_CHECK(replicas_[0].size() >= 1, "Expected at least one parameter.");
@@ -1465,6 +1466,27 @@ void Reducer::ensure_prior_reduction_finished() {
}
}

void Reducer::set_construction_logging_data(
const std::string& module_name,
const std::vector<int>& device_ids,
int output_device,
bool broadcast_buffers
) {
ddp_logging_data_->module_name = module_name;
ddp_logging_data_->device_ids = device_ids;
ddp_logging_data_->output_device = output_device;
ddp_logging_data_->broadcast_buffers = broadcast_buffers;
ddp_logging_data_->world_size = process_group_->getSize();
ddp_logging_data_->rank = process_group_->getRank();
ddp_logging_data_->bucket_cap_mb = bucket_bytes_cap_ / (1024 * 1024);
ddp_logging_data_->find_unused_parameters = find_unused_parameters_;
ddp_logging_data_->gradient_as_bucket_view = gradient_as_bucket_view_;
}

c10::DDPLoggingData Reducer::get_ddp_logging_data() {
return *ddp_logging_data_;
}

namespace {

// Tensors may be coalesced into buckets. Buckets must contain tensors of
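Note that bucket_cap_mb above is derived from the reducer's byte-level cap, so with DDP's default bucket size the logged value comes out to 25, which is what the new unit test below asserts. A quick arithmetic check (the 25 MiB default is the usual DDP value and is assumed here, not something set by this diff):

bucket_bytes_cap = 25 * 1024 * 1024            # default DDP bucket cap, in bytes
bucket_cap_mb = bucket_bytes_cap // (1024 * 1024)
assert bucket_cap_mb == 25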
16 changes: 16 additions & 0 deletions torch/lib/c10d/reducer.hpp
@@ -105,6 +105,18 @@ class Reducer {
// index has been used.
std::vector<at::Tensor> get_local_used_maps_on_device() const;

// Set logging data that can be collected at DistributedDataParallel
// construction time.
void set_construction_logging_data(
const std::string& module_name,
const std::vector<int>& device_ids,
int output_device,
bool broadcast_buffers);

// An interface for users to get DDPLoggingData and log it
// in their applications.
c10::DDPLoggingData get_ddp_logging_data();

protected:
// Forward declaration.
struct Bucket;
@@ -358,6 +370,10 @@
private:
// comm_hook_ is used to access the DDP communication hook if registered.
std::unique_ptr<CommHookInterface> comm_hook_;

// ddp_logging_data_ is used to hold all the DDP-related logging
// data fields.
std::unique_ptr<c10::DDPLoggingData> ddp_logging_data_;
};

// This is equivalent to take_tensors but returns indices into the
11 changes: 11 additions & 0 deletions torch/nn/parallel/distributed.py
@@ -606,6 +606,14 @@ def produces_sparse_gradient(module):
self.find_unused_parameters,
self.gradient_as_bucket_view)

# Set logging data that can be collected at construction time.
dist._set_construction_logging_data(
self.reducer,
self.module.__class__.__name__,
[] if self.device_ids is None else self.device_ids,
-1 if self.output_device is None else self.output_device,
self.broadcast_buffers)

# passing a handle to torch.nn.SyncBatchNorm layer
self._passing_sync_batchnorm_handle(self._module_copies)

Expand Down Expand Up @@ -765,6 +773,9 @@ def train(self, mode=True):
module.train(mode)
return self

def get_ddp_logging_data(self):
return dist._get_ddp_logging_data(self.reducer)

# When running in join mode, schedules an allreduce to match the one in the
# forward pass to determine the no. of currently active processes and whether
# all processes have joined.
19 changes: 19 additions & 0 deletions torch/testing/_internal/distributed/distributed_test.py
@@ -3152,6 +3152,25 @@ def test_DistributedDataParallel_SyncBatchNorm_Diff_Input_Sizes_gradient(self):
global_bs=global_bs,
offset=bs_offset)

@unittest.skipIf(
BACKEND == "nccl", "nccl does not support DDP on CPU models"
)
def test_ddp_logging_data(self):
model_DDP = copy.deepcopy(DDP_NET)
model_DDP = nn.parallel.DistributedDataParallel(model_DDP)
ddp_logging_data = model_DDP.get_ddp_logging_data()
self.assertEqual(ddp_logging_data.world_size, dist.get_world_size())
self.assertEqual(ddp_logging_data.rank, dist.get_rank())
self.assertEqual(ddp_logging_data.module_name, 'Net')
self.assertEqual(ddp_logging_data.device_ids, [])
# output_device is -1 by default if it is not set, e.g.
# output_device of CPU training is -1.
self.assertEqual(ddp_logging_data.output_device, -1)
self.assertEqual(ddp_logging_data.broadcast_buffers, True)
self.assertEqual(ddp_logging_data.bucket_cap_mb, 25)
self.assertEqual(ddp_logging_data.find_unused_parameters, False)
self.assertEqual(ddp_logging_data.gradient_as_bucket_view, False)

@skipIfNoTorchVision
def test_SyncBatchNorm_process_group(self):
# When adopting `convert_sync_batchnorm` to convert a `nn.modules`,
