
Track models with SyncBN in DDP #66215

@rohan-varma


🚀 Feature

We are interested in understanding and improving the performance of models that use SyncBN, so we would like to detect whether a model contains SyncBN layers and record this in the DDP logging data.

This task is to add a `has_sync_bn` field to the DDP logging data and ensure it is populated correctly for models that use SyncBatchNorm. Such models are typically wrapped as follows:

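```python
import torch

# `model` is the user's nn.Module and `gpu_index` is this process's local GPU;
# the default process group must already be initialized before wrapping in DDP.
model = torch.nn.SyncBatchNorm.convert_sync_batchnorm(model)
model = torch.nn.parallel.DistributedDataParallel(
    model,
    device_ids=[gpu_index],
    output_device=gpu_index,
    gradient_as_bucket_view=True,
)
```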

so the flag can be logged in the DDP constructor, where `set_construction_data_and_log` is called: https://github.com/pytorch/pytorch/blob/master/torch/nn/parallel/distributed.py#L672

To detect whether a model actually contains SyncBN, we can probably iterate over its modules recursively (for example with `module.modules()`) and check each one with `isinstance(m, torch.nn.SyncBatchNorm)`.
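A minimal sketch of that check follows. The helper name `has_sync_bn` is hypothetical (it is not an existing PyTorch function); it relies only on `nn.Module.modules()`, which yields the module itself and all of its descendants, so SyncBatchNorm layers nested at any depth are found. Because it only inspects module types, it runs on CPU and does not need an initialized process group.

```python
import torch.nn as nn


def has_sync_bn(module: nn.Module) -> bool:
    """Return True if `module` or any of its submodules is a SyncBatchNorm layer.

    Hypothetical helper for illustration; not part of the PyTorch API.
    """
    # nn.Module.modules() walks the module tree recursively, including `module` itself.
    return any(isinstance(m, nn.SyncBatchNorm) for m in module.modules())


if __name__ == "__main__":
    model = nn.Sequential(nn.Conv2d(3, 8, kernel_size=3), nn.BatchNorm2d(8), nn.ReLU())
    print(has_sync_bn(model))  # False: plain BatchNorm2d is not SyncBatchNorm
    # convert_sync_batchnorm replaces BatchNorm layers with SyncBatchNorm.
    model = nn.SyncBatchNorm.convert_sync_batchnorm(model)
    print(has_sync_bn(model))  # True
```

In DDP, the flag would presumably be computed once in the constructor and passed along with the other construction-time fields when `set_construction_data_and_log` is called; how that call is extended to carry the new field is left to the implementation.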

cc @pietern @mrshenli @pritamdamania87 @zhaojuanmao @satgera @rohan-varma @gqchen @aazzolini @osalpekar @jiayisuse @SciPioneer @H-Huang

Metadata



Labels

- better-engineering (relatively self-contained tasks for better engineering contributors)
- module: ddp (issues/PRs related to distributed data parallel training)
- oncall: distributed (add this issue/PR to the distributed oncall triage queue)
- pt_distributed_rampup (ramp-up tasks for new developers on PT distributed)
