From 95abddc82ad9e11681f7b64a6433080860df3896 Mon Sep 17 00:00:00 2001
From: Wanchao Liang
Date: Wed, 15 Nov 2023 22:19:27 -0800
Subject: [PATCH 1/2] [dtensor] group dispatch unwrapping to a method

This PR groups the dispatch unwrapping logic into a method, so that even
custom handlers can reuse many parts of the dispatch logic to do custom
things.

[ghstack-poisoned]
---
 torch/distributed/_tensor/dispatch.py  | 202 +++++++++++++------------
 torch/distributed/_tensor/op_schema.py |   2 +-
 2 files changed, 108 insertions(+), 96 deletions(-)

diff --git a/torch/distributed/_tensor/dispatch.py b/torch/distributed/_tensor/dispatch.py
index f0c0897f12f19..e544995c13e6d 100644
--- a/torch/distributed/_tensor/dispatch.py
+++ b/torch/distributed/_tensor/dispatch.py
@@ -92,98 +92,14 @@ def dispatch(
         if op_call in self._custom_op_handlers:
             return self._custom_op_handlers[op_call](op_call, args, kwargs)  # type: ignore[operator]
 
-        # extract local tensor and sharding infos and run sharding propagation
-        runtime_schema_info = self.sharding_propagator.op_to_schema_info.get(
-            op_call, None
-        )
-
-        if runtime_schema_info is not None and runtime_schema_info.needs_pytree:
-            # flatten args/kwargs when necessary
-            tree_args, args_spec = pytree.tree_flatten(args)
-            args_list: Sequence[object] = tree_args
-        else:
-            args_list, args_spec = args, None
-
-        args_schema: List[object] = []
-        kwargs_schema: Dict[str, object] = {}
-        local_args: List[object] = []
-        local_kwargs: Dict[str, object] = {}
-        mesh: Optional[DeviceMesh] = None
-
-        for arg in args_list:
-            if isinstance(arg, dtensor.DTensor):
-                args_schema.append(arg._spec)
-                local_args.append(arg._local_tensor)
-                if mesh is not None:
-                    if mesh != arg.device_mesh:
-                        raise NotImplementedError(
-                            f"{op_call}: DTensor does not support cross-mesh operation yet!"
-                        )
-                else:
-                    mesh = arg.device_mesh
-            elif isinstance(arg, torch.Tensor):
-                if arg.ndim == 0 and mesh is not None:
-                    # scalar tensor can be safely treated as replicated
-                    args_schema.append(
-                        DTensorSpec(
-                            mesh,
-                            (Replicate(),) * mesh.ndim,
-                            tensor_meta=TensorMeta(
-                                shape=arg.shape, stride=arg.stride(), dtype=arg.dtype
-                            ),
-                        )
-                    )
-                    local_args.append(arg)
-                else:
-                    raise RuntimeError(
-                        f"{op_call}: got mixed torch.Tensor and DTensor, need to convert all"
-                        " torch.Tensor to DTensor before calling distributed operators!"
-                    )
-            else:
-                args_schema.append(arg)
-                local_args.append(arg)
-
-        for k, v in kwargs.items():
-            if isinstance(v, dtensor.DTensor):
-                kwargs_schema[k] = v._spec
-                local_kwargs[k] = v._local_tensor
-                if mesh is not None:
-                    if mesh != v.device_mesh:
-                        raise NotImplementedError(
-                            f"{op_call}: DTensor does not support cross-mesh operation yet!"
-                        )
-                else:
-                    mesh = v.device_mesh
-            elif isinstance(v, torch.Tensor):
-                raise RuntimeError(
-                    f"{op_call}: got mixed torch.Tensor and DTensor, need to convert all"
-                    " torch.Tensor to DTensor before calling distributed operators!"
-                )
-            else:
-                kwargs_schema[k] = v
-                local_kwargs[k] = v
-
-        assert mesh is not None, "found no DeviceMesh from dtensor args!"
- op_info = OpInfo( - mesh, - OpSchema( - op_call, - pytree.tree_unflatten(args_schema, args_spec) - if args_spec - else tuple(args_schema), - kwargs_schema, - schema_info=runtime_schema_info, - ), - args_schema, - tuple(local_args), - local_kwargs, - args_spec, - ) + # extract local tensor and sharding infos to a OpInfo + op_info = self.unwrap_to_op_info(op_call, args, kwargs) self.sharding_propagator.propagate(op_info) output_sharding = op_info.output_sharding assert output_sharding is not None, "output sharding should not be None" + mesh = op_info.mesh if mesh.get_coordinate() is None: # For a non-participating device, we do: # 1. if the return type is scalar, set the local result to None. @@ -233,13 +149,14 @@ def default_tensor(spec: DTensorSpec) -> torch.Tensor: else: if output_sharding.needs_redistribute: # compute locally with redistribute first if needed - assert output_sharding.schema_suggestions is not None - suggested_input_schema = output_sharding.schema_suggestions[0] - self.redistribute_local_args(op_info, suggested_input_schema) + assert output_sharding.schema_suggestion is not None + self.redistribute_local_args(op_info, output_sharding.schema_suggestion) local_tensor_args = ( - pytree.tree_unflatten(cast(List[object], op_info.local_args), args_spec) - if args_spec + pytree.tree_unflatten( + cast(List[object], op_info.local_args), op_info.args_tree_spec + ) + if op_info.args_tree_spec else op_info.local_args ) @@ -255,11 +172,11 @@ def default_tensor(spec: DTensorSpec) -> torch.Tensor: ) # For DTensor random operator, run it within a distribute region with random._rng_tracker._distribute_region( - cast(DTensorSpec, args_schema[0]) + cast(dtensor.DTensor, args[0])._spec ): - local_results = op_call(*local_tensor_args, **local_kwargs) + local_results = op_call(*local_tensor_args, **op_info.local_kwargs) else: - local_results = op_call(*local_tensor_args, **local_kwargs) + local_results = op_call(*local_tensor_args, **op_info.local_kwargs) # communicate the result to all ranks for some operators that return scalar value if output_sharding.output_spec is None: @@ -330,6 +247,101 @@ def redistribute_local_args( op_info.local_args = tuple(new_local_args) + def unwrap_to_op_info( + self, + op_call: torch._ops.OpOverload, + args: Tuple[object, ...], + kwargs: Dict[str, object], + ) -> OpInfo: + # get runtime schema to determine whether to use pytree to flatten inputs + runtime_schema_info = self.sharding_propagator.op_to_schema_info.get( + op_call, None + ) + + if runtime_schema_info is not None and runtime_schema_info.needs_pytree: + # flatten args/kwargs when necessary + tree_args, args_spec = pytree.tree_flatten(args) + args_list: Sequence[object] = tree_args + else: + args_list, args_spec = args, None + + args_schema: List[object] = [] + kwargs_schema: Dict[str, object] = {} + local_args: List[object] = [] + local_kwargs: Dict[str, object] = {} + mesh: Optional[DeviceMesh] = None + + for arg in args_list: + if isinstance(arg, dtensor.DTensor): + args_schema.append(arg._spec) + local_args.append(arg._local_tensor) + if mesh is not None: + if mesh != arg.device_mesh: + raise NotImplementedError( + f"{op_call}: DTensor does not support cross-mesh operation yet!" 
+                        )
+                else:
+                    mesh = arg.device_mesh
+            elif isinstance(arg, torch.Tensor):
+                if arg.ndim == 0 and mesh is not None:
+                    # scalar tensor can be safely treated as replicated
+                    args_schema.append(
+                        DTensorSpec(
+                            mesh,
+                            (Replicate(),) * mesh.ndim,
+                            tensor_meta=TensorMeta(
+                                shape=arg.shape, stride=arg.stride(), dtype=arg.dtype
+                            ),
+                        )
+                    )
+                    local_args.append(arg)
+                else:
+                    raise RuntimeError(
+                        f"{op_call}: got mixed torch.Tensor and DTensor, need to convert all"
+                        " torch.Tensor to DTensor before calling distributed operators!"
+                    )
+            else:
+                args_schema.append(arg)
+                local_args.append(arg)
+
+        for k, v in kwargs.items():
+            if isinstance(v, dtensor.DTensor):
+                kwargs_schema[k] = v._spec
+                local_kwargs[k] = v._local_tensor
+                if mesh is not None:
+                    if mesh != v.device_mesh:
+                        raise NotImplementedError(
+                            f"{op_call}: DTensor does not support cross-mesh operation yet!"
+                        )
+                else:
+                    mesh = v.device_mesh
+            elif isinstance(v, torch.Tensor):
+                raise RuntimeError(
+                    f"{op_call}: got mixed torch.Tensor and DTensor, need to convert all"
+                    " torch.Tensor to DTensor before calling distributed operators!"
+                )
+            else:
+                kwargs_schema[k] = v
+                local_kwargs[k] = v
+
+        assert mesh is not None, "found no DeviceMesh from dtensor args!"
+        op_info = OpInfo(
+            mesh,
+            OpSchema(
+                op_call,
+                pytree.tree_unflatten(args_schema, args_spec)
+                if args_spec
+                else tuple(args_schema),
+                kwargs_schema,
+                schema_info=runtime_schema_info,
+            ),
+            args_schema,
+            tuple(local_args),
+            local_kwargs,
+            args_spec,
+        )
+        return op_info
+
     @staticmethod
     def wrap(res: object, spec: OutputSpecType) -> object:
         def to_dt(res, spec):
diff --git a/torch/distributed/_tensor/op_schema.py b/torch/distributed/_tensor/op_schema.py
index 688f73fd0a401..dd231588ae3ee 100644
--- a/torch/distributed/_tensor/op_schema.py
+++ b/torch/distributed/_tensor/op_schema.py
@@ -370,7 +370,7 @@ class OutputSharding:
     """
 
     output_spec: OutputSpecType
-    schema_suggestions: Optional[List[OpSchema]] = None
+    schema_suggestion: Optional[OpSchema] = None
     failed_reason: Optional[str] = None
     needs_redistribute: bool = False
 

From 1aef3e96dfe78002073ef69c35f0e799d5e3397f Mon Sep 17 00:00:00 2001
From: Wanchao Liang
Date: Wed, 15 Nov 2023 23:36:48 -0800
Subject: [PATCH 2/2] Update on "[dtensor] group dispatch unwrapping to a method"

This PR groups the dispatch unwrapping logic into a method, so that even
custom handlers can reuse many parts of the dispatch logic to do custom
things.

[ghstack-poisoned]
---
 torch/distributed/_tensor/dispatch.py  | 6 ++++--
 torch/distributed/_tensor/op_schema.py | 2 +-
 2 files changed, 5 insertions(+), 3 deletions(-)

diff --git a/torch/distributed/_tensor/dispatch.py b/torch/distributed/_tensor/dispatch.py
index e544995c13e6d..587cc0faca5df 100644
--- a/torch/distributed/_tensor/dispatch.py
+++ b/torch/distributed/_tensor/dispatch.py
@@ -149,8 +149,10 @@ def default_tensor(spec: DTensorSpec) -> torch.Tensor:
         else:
             if output_sharding.needs_redistribute:
                 # compute locally with redistribute first if needed
-                assert output_sharding.schema_suggestion is not None
-                self.redistribute_local_args(op_info, output_sharding.schema_suggestion)
+                assert output_sharding.schema_suggestions is not None
+                self.redistribute_local_args(
+                    op_info, output_sharding.schema_suggestions[0]
+                )
 
             local_tensor_args = (
                 pytree.tree_unflatten(
diff --git a/torch/distributed/_tensor/op_schema.py b/torch/distributed/_tensor/op_schema.py
index dd231588ae3ee..688f73fd0a401 100644
--- a/torch/distributed/_tensor/op_schema.py
+++ b/torch/distributed/_tensor/op_schema.py
@@ -370,7 +370,7 @@ class OutputSharding:
     """
 
    output_spec: OutputSpecType
-    schema_suggestion: Optional[OpSchema] = None
+    schema_suggestions: Optional[List[OpSchema]] = None
     failed_reason: Optional[str] = None
     needs_redistribute: bool = False
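
Usage note (not part of either patch above): with the unwrapping logic grouped into unwrap_to_op_info, a custom op handler can reuse the same argument unwrapping and sharding propagation instead of re-implementing them. The sketch below is illustrative only and makes several assumptions: the handler name my_custom_handler is hypothetical, the dispatcher instance is passed in explicitly (handlers registered in _custom_op_handlers are actually invoked as handler(op_call, args, kwargs), so a real handler would need to reach the dispatcher another way, e.g. via a closure), the dispatcher class in dispatch.py is assumed to be OpDispatcher, and the redistribute and non-participating-rank handling performed by dispatch() is omitted.

# Illustrative sketch only; see the assumptions stated above.
from typing import Dict, List, Tuple, cast

import torch
import torch.utils._pytree as pytree
from torch.distributed._tensor.dispatch import OpDispatcher


def my_custom_handler(
    dispatcher: OpDispatcher,
    op_call: torch._ops.OpOverload,
    args: Tuple[object, ...],
    kwargs: Dict[str, object],
) -> object:
    # Reuse the shared unwrapping logic: builds the OpSchema, collects the
    # local tensors, and finds the common DeviceMesh.
    op_info = dispatcher.unwrap_to_op_info(op_call, args, kwargs)

    # Run sharding propagation the same way the default dispatch path does.
    dispatcher.sharding_propagator.propagate(op_info)
    output_sharding = op_info.output_sharding
    assert output_sharding is not None, "output sharding should not be None"

    # Unflatten the local args if this op required pytree flattening.
    local_args = (
        pytree.tree_unflatten(
            cast(List[object], op_info.local_args), op_info.args_tree_spec
        )
        if op_info.args_tree_spec
        else op_info.local_args
    )

    # Custom local computation would go here; the sketch just runs the op.
    local_results = op_call(*local_args, **op_info.local_kwargs)

    # Re-wrap the local results into DTensors using the propagated output spec.
    return OpDispatcher.wrap(local_results, output_sharding.output_spec)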