[dtensor][experiment] experimenting with displaying distributed model parameters and printing sharding info (#127987)

**Summary**
Adds example code that displays a distributed model's parameters, verifies them against ground truth, and prints sharding information.

**Test Plan**
torchrun --standalone --nnodes=1 --nproc-per-node=4 torch/distributed/_tensor/examples/display_sharding_example.py
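
For orientation, here is a condensed sketch of the flow the example script drives, using the API added in this PR (`CommDebugMode.print_paramater_info` / `print_sharding_info`). It assumes a 4-rank torchrun launch as in the test plan; `ToyMLP` is a hypothetical stand-in for the `MLPModule` test model.

```python
import torch
import torch.nn as nn
from torch.distributed._tensor import DeviceMesh
from torch.distributed._tensor.debug import CommDebugMode
from torch.distributed.tensor.parallel import (
    ColwiseParallel,
    RowwiseParallel,
    parallelize_module,
)


class ToyMLP(nn.Module):
    # hypothetical stand-in for MLPModule: net1 -> ReLU -> net2
    def __init__(self):
        super().__init__()
        self.net1 = nn.Linear(10, 16)
        self.relu = nn.ReLU()
        self.net2 = nn.Linear(16, 10)

    def forward(self, x):
        return self.net2(self.relu(self.net1(x)))


device_type = "cuda" if torch.cuda.is_available() else "cpu"
mesh = DeviceMesh(device_type, torch.arange(4))  # one entry per worker rank
model = parallelize_module(
    ToyMLP().to(device_type),
    mesh,
    {"net1": ColwiseParallel(), "net2": RowwiseParallel()},
)

torch.manual_seed(0)  # same replicated input on every rank, as in run_example()
comm_mode = CommDebugMode()
with comm_mode:  # forward hooks record each module's parameters and DTensor placements
    model(torch.rand(8, 10, device=device_type)).sum().backward()

comm_mode.print_paramater_info()  # module FQN -> {param name: param.data}
comm_mode.print_sharding_info()   # "<fqn>.<param>: <placements>" for each DTensor param
```

`get_parameter_info()` / `get_sharding_info()` return the underlying dicts if the raw data is needed instead of printed output.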

Pull Request resolved: #127987
Approved by: https://github.com/XilunWu
ghstack dependencies: #127358, #127360, #127630
sinhaanshul authored and pytorchmergebot committed Jun 9, 2024
1 parent 2c2cf1d commit f681e36
Showing 2 changed files with 107 additions and 7 deletions.
22 changes: 22 additions & 0 deletions torch/distributed/_tensor/debug/comm_mode.py
@@ -61,6 +61,7 @@ class ModuleParamaterShardingTracker(ModuleTracker):
    def __init__(self):
        super().__init__()
        self.module_parameters_dict = {}
        self.sharding_dict = {}

    def _fw_pre_hook(self, mod, input):
        name = super()._get_mod_name(mod)
@@ -77,14 +78,26 @@ def _fw_pre_hook(self, mod, input):

            self.module_parameters_dict[name][param_name] = param.data

            if isinstance(param.data, DTensor):
                key_name = name + "." + param_name
                self.sharding_dict[key_name] = param.data.placements

    def __enter__(self):
        self.module_parameters_dict.clear()
        self.sharding_dict.clear()
        self._fw_pre_handle = register_module_forward_pre_hook(self._fw_pre_hook)
        self._fw_post_handle = register_module_forward_hook(super()._fw_post_hook)

    def __exit__(self, *args):
        super().__exit__(*args)

    def print_paramater_info(self):
        print(self.module_parameters_dict)

    def print_sharding_info(self):
        for key, value in self.sharding_dict.items():
            print(key + ": " + str(value))


class CommDebugMode(TorchDispatchMode):
"""
@@ -130,6 +143,9 @@ def get_comm_counts(self) -> Dict[Any, int]:
    def get_parameter_info(self) -> Dict[str, Dict[str, Any]]:
        return self.advanced_module_tracker.module_parameters_dict

    def get_sharding_info(self) -> Dict[str, Dict[str, Any]]:
        return self.advanced_module_tracker.sharding_dict

    def __enter__(self):
        self.comm_counts.clear()
        super().__enter__()
@@ -140,6 +156,12 @@ def __exit__(self, *args):
        self.advanced_module_tracker.__exit__()
        super().__exit__(*args)

    def print_paramater_info(self):
        self.advanced_module_tracker.print_paramater_info()

    def print_sharding_info(self):
        self.advanced_module_tracker.print_sharding_info()

    def __torch_dispatch__(self, func, types, args=(), kwargs=None):
        # When running this mode with DTensor, ordinarily all modes will
        # run **before** subclasses get a chance to run.
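For reference, the `ModuleParamaterShardingTracker` extended above can also be used on its own, outside `CommDebugMode` (which drives it through its `advanced_module_tracker`). A minimal sketch, assuming the `model` and `device_type` names from the sketch in the description above:

```python
import torch

from torch.distributed._tensor.debug.comm_mode import ModuleParamaterShardingTracker

tracker = ModuleParamaterShardingTracker()
with tracker:  # registers module forward pre/post hooks via ModuleTracker
    model(torch.rand(8, 10, device=device_type))

tracker.print_paramater_info()  # module_parameters_dict collected by the pre-hook
tracker.print_sharding_info()   # placements of every DTensor parameter encountered
```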
92 changes: 85 additions & 7 deletions torch/distributed/_tensor/examples/display_sharding_example.py
@@ -1,23 +1,50 @@
import os
from typing import Any, Dict

import torch

from torch.distributed._tensor import DeviceMesh, Shard
from torch.distributed._tensor.debug import CommDebugMode

from torch.distributed._tensor.debug.comm_mode import ModuleParamaterShardingTracker

from torch.distributed.tensor.parallel import (
    ColwiseParallel,
    parallelize_module,
    RowwiseParallel,
)

from torch.testing._internal.distributed._tensor.common_dtensor import (
    MLPModule,
    MLPStacked,
    NUM_DEVICES,
)


def get_device_type():
    return (
        "cuda"
        if torch.cuda.is_available() and torch.cuda.device_count() >= 4
        else "cpu"
    )


c10d_functional = torch.ops.c10d_functional

aten = torch.ops.aten
supported_ops = [aten.view.default, aten._to_copy.default]


class DisplayShardingExample:
"""
Checks if the set of keys in ground truth dictionary and the set
produced in advanced_module_tracker are in the same order
"""

def __init__(self, world_size, rank):
self.world_size = world_size
self.rank = rank
self.device_type = get_device_type()

def same_set_of_keys(self, dict1, dict2):
dict1_keys = []
dict2_keys = []
@@ -54,9 +81,7 @@ def ground_truth(self, model):
        return module_parameters_dict

    def test_display_parameters_MLP(self):
        """
        Example of using obtaining all module's FQN and parameters for a given model
        """
        """Example of obtaining all module's FQN and parameters for a given model"""

        inp_size = [8, 10]

@@ -67,7 +92,6 @@ def test_display_parameters_MLP(self):

        LR = 0.25

        optim = torch.optim.SGD(model.parameters(), lr=LR)
        comm_mode = CommDebugMode()
        module_tracker = ModuleParamaterShardingTracker()

@@ -91,7 +115,61 @@ def test_display_parameters_MLP(self):
            )
        )

    def test_display_parameters_MLP_distributed(
        self, is_seq_parallel=False, recompute_activation=False
    ):
        "Example of obtaining all module's FQN and parameters for a given distributed model and printing the sharding info"
        device_mesh = DeviceMesh(
            self.device_type,
            torch.arange(0, NUM_DEVICES),
        )
        inp_size = [8, 10]
        rng_seed = self.rank if is_seq_parallel else 0
        torch.manual_seed(rng_seed)
        inp = torch.rand(*inp_size, device=self.device_type)
        model = MLPModule(self.device_type)

        LR = 0.25

        parallelize_plan = {
            "net1": ColwiseParallel(input_layouts=Shard(0))
            if is_seq_parallel
            else ColwiseParallel(),
            "net2": RowwiseParallel(output_layouts=Shard(0))
            if is_seq_parallel
            else RowwiseParallel(),
        }

        model = parallelize_module(model, device_mesh, parallelize_plan)

        comm_mode = CommDebugMode()

        with comm_mode:
            output_tp = model(inp)
            output_tp.sum().backward()

        print(
            self.same_set_of_keys(
                self.ground_truth(model), comm_mode.get_parameter_info()
            )
        )

        comm_mode.print_sharding_info()


def run_example(world_size, rank):
    # set manual seed
    torch.manual_seed(0)

    # run the example
    instantiated_test = DisplayShardingExample(world_size, rank)
    instantiated_test.test_display_parameters_MLP_distributed()


if __name__ == "__main__":
    instantiated_test = DisplayShardingExample()
    instantiated_test.test_display_parameters_MLP()
    # this script is launched via torchrun which automatically manages ProcessGroup
    rank = int(os.environ["RANK"])
    world_size = int(os.environ["WORLD_SIZE"])
    assert world_size == 4  # our example uses 4 worker ranks

    run_example(world_size, rank)
