🚀 Feature
We are interested in understanding and improving the performance of models that use SyncBN, so we would like to detect whether a model contains SyncBN layers and record that in DDP logging.
This task is to add a field `has_sync_bn` to the DDP logging data and ensure it is correctly populated for models with sync batch norm. The model is generally instantiated as follows:
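```python
import torch

# Replace every BatchNorm*d layer in the model with SyncBatchNorm
model = torch.nn.SyncBatchNorm.convert_sync_batchnorm(model)
# Wrap the converted model in DDP (gpu_index is this process's local GPU)
model = torch.nn.parallel.DistributedDataParallel(
    model,
    device_ids=[gpu_index],
    output_device=gpu_index,
    gradient_as_bucket_view=True,
)
```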
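As a quick illustration of why a type check is enough, here is a minimal CPU-only sketch (the small `nn.Sequential` model is a made-up example, not taken from the issue). `convert_sync_batchnorm` replaces the BatchNorm submodules with `SyncBatchNorm` instances rather than wrapping them, and the conversion itself does not require an initialized process group:

```python
import torch.nn as nn

# Hypothetical example model with a regular BatchNorm layer.
model = nn.Sequential(nn.Conv2d(3, 8, kernel_size=3), nn.BatchNorm2d(8), nn.ReLU())
print(type(model[1]).__name__)  # BatchNorm2d

# convert_sync_batchnorm swaps each BatchNorm*d submodule for a SyncBatchNorm module;
# only the module types change here, no forward pass or process group is involved.
model = nn.SyncBatchNorm.convert_sync_batchnorm(model)
print(type(model[1]).__name__)  # SyncBatchNorm
```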
So we can log this in the DDP constructor, where `set_construction_data_and_log` is called: https://github.com/pytorch/pytorch/blob/master/torch/nn/parallel/distributed.py#L672
To detect whether a model actually contains SyncBN, we can probably iterate through its modules recursively (e.g. via `model.modules()`) and run `isinstance` checks against `torch.nn.SyncBatchNorm`.
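A minimal sketch of the detection step, assuming the check runs on the module passed to the DDP constructor. `has_sync_bn` and `build_logging_data` are hypothetical helpers, and the plain dict only stands in for DDP's real logging data, which is populated through `set_construction_data_and_log`:

```python
import torch.nn as nn


def has_sync_bn(module: nn.Module) -> bool:
    """Return True if `module` or any of its descendants is a SyncBatchNorm layer."""
    # module.modules() yields the module itself and all submodules, recursively.
    return any(isinstance(m, nn.SyncBatchNorm) for m in module.modules())


def build_logging_data(module: nn.Module) -> dict:
    # Hypothetical stand-in for DDP's construction-time logging data.
    # The field is recorded as 0/1 here; its real type is an assumption.
    return {"has_sync_bn": int(has_sync_bn(module))}


plain = nn.Sequential(nn.Linear(4, 4), nn.BatchNorm1d(4))
converted = nn.SyncBatchNorm.convert_sync_batchnorm(
    nn.Sequential(nn.Linear(4, 4), nn.BatchNorm1d(4))
)
print(build_logging_data(plain))      # {'has_sync_bn': 0}
print(build_logging_data(converted))  # {'has_sync_bn': 1}
```

Because `modules()` walks the whole module tree, SyncBatchNorm layers nested deep inside submodules are also detected.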
cc @pietern @mrshenli @pritamdamania87 @zhaojuanmao @satgera @rohan-varma @gqchen @aazzolini @osalpekar @jiayisuse @SciPioneer @H-Huang