[dtensor] group dispatch unwrapping to a method #113846

Closed · wants to merge 2 commits
202 changes: 108 additions & 94 deletions torch/distributed/_tensor/dispatch.py
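The gist of this PR: the inline argument-unwrapping logic in `dispatch` is moved into a new `unwrap_to_op_info` method, so `dispatch` itself reads as unwrap → propagate sharding → redistribute/run locally → wrap results. A simplified sketch of that flow, paraphrased from the diff below (not the verbatim PyTorch source; it omits the pytree unflattening, non-participating ranks, and the RNG-tracker handling shown in the full diff):

def dispatch(self, op_call, args, kwargs):
    # extract local tensors and sharding specs from the DTensor args/kwargs
    op_info = self.unwrap_to_op_info(op_call, args, kwargs)

    # decide input/output shardings for this op
    self.sharding_propagator.propagate(op_info)
    output_sharding = op_info.output_sharding

    # redistribute local args first if the suggested input schema requires it
    if output_sharding.needs_redistribute:
        self.redistribute_local_args(op_info, output_sharding.schema_suggestions[0])

    # run the op on plain torch.Tensors, then re-wrap the results as DTensors
    local_results = op_call(*op_info.local_args, **op_info.local_kwargs)
    return self.wrap(local_results, output_sharding.output_spec)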
@@ -92,98 +92,14 @@ def dispatch(
if op_call in self._custom_op_handlers:
return self._custom_op_handlers[op_call](op_call, args, kwargs) # type: ignore[operator]

# extract local tensor and sharding infos and run sharding propagation
runtime_schema_info = self.sharding_propagator.op_to_schema_info.get(
op_call, None
)

if runtime_schema_info is not None and runtime_schema_info.needs_pytree:
# flatten args/kwargs when necessary
tree_args, args_spec = pytree.tree_flatten(args)
args_list: Sequence[object] = tree_args
else:
args_list, args_spec = args, None

args_schema: List[object] = []
kwargs_schema: Dict[str, object] = {}
local_args: List[object] = []
local_kwargs: Dict[str, object] = {}
mesh: Optional[DeviceMesh] = None

for arg in args_list:
if isinstance(arg, dtensor.DTensor):
args_schema.append(arg._spec)
local_args.append(arg._local_tensor)
if mesh is not None:
if mesh != arg.device_mesh:
raise NotImplementedError(
f"{op_call}: DTensor does not support cross-mesh operation yet!"
)
else:
mesh = arg.device_mesh
elif isinstance(arg, torch.Tensor):
if arg.ndim == 0 and mesh is not None:
# scalar tensor can be safely treated as replicated
args_schema.append(
DTensorSpec(
mesh,
(Replicate(),) * mesh.ndim,
tensor_meta=TensorMeta(
shape=arg.shape, stride=arg.stride(), dtype=arg.dtype
),
)
)
local_args.append(arg)
else:
raise RuntimeError(
f"{op_call}: got mixed torch.Tensor and DTensor, need to convert all"
" torch.Tensor to DTensor before calling distributed operators!"
)
else:
args_schema.append(arg)
local_args.append(arg)

for k, v in kwargs.items():
if isinstance(v, dtensor.DTensor):
kwargs_schema[k] = v._spec
local_kwargs[k] = v._local_tensor
if mesh is not None:
if mesh != v.device_mesh:
raise NotImplementedError(
f"{op_call}: DTensor does not support cross-mesh operation yet!"
)
else:
mesh = v.device_mesh
elif isinstance(v, torch.Tensor):
raise RuntimeError(
f"{op_call}: got mixed torch.Tensor and DTensor, need to convert all"
" torch.Tensor to DTensor before calling distributed operators!"
)
else:
kwargs_schema[k] = v
local_kwargs[k] = v

assert mesh is not None, "found no DeviceMesh from dtensor args!"
op_info = OpInfo(
mesh,
OpSchema(
op_call,
pytree.tree_unflatten(args_schema, args_spec)
if args_spec
else tuple(args_schema),
kwargs_schema,
schema_info=runtime_schema_info,
),
args_schema,
tuple(local_args),
local_kwargs,
args_spec,
)
# extract local tensor and sharding info into an OpInfo
op_info = self.unwrap_to_op_info(op_call, args, kwargs)

self.sharding_propagator.propagate(op_info)
output_sharding = op_info.output_sharding
assert output_sharding is not None, "output sharding should not be None"

mesh = op_info.mesh
if mesh.get_coordinate() is None:
# For a non-participating device, we do:
# 1. if the return type is scalar, set the local result to None.
@@ -234,12 +150,15 @@ def default_tensor(spec: DTensorSpec) -> torch.Tensor:
if output_sharding.needs_redistribute:
# compute locally with redistribute first if needed
assert output_sharding.schema_suggestions is not None
suggested_input_schema = output_sharding.schema_suggestions[0]
self.redistribute_local_args(op_info, suggested_input_schema)
self.redistribute_local_args(
op_info, output_sharding.schema_suggestions[0]
)

local_tensor_args = (
pytree.tree_unflatten(cast(List[object], op_info.local_args), args_spec)
if args_spec
pytree.tree_unflatten(
cast(List[object], op_info.local_args), op_info.args_tree_spec
)
if op_info.args_tree_spec
else op_info.local_args
)

@@ -255,11 +174,11 @@
)
# For DTensor random operator, run it within a distribute region
with random._rng_tracker._distribute_region(
cast(DTensorSpec, args_schema[0])
cast(dtensor.DTensor, args[0])._spec
):
local_results = op_call(*local_tensor_args, **local_kwargs)
local_results = op_call(*local_tensor_args, **op_info.local_kwargs)
else:
local_results = op_call(*local_tensor_args, **local_kwargs)
local_results = op_call(*local_tensor_args, **op_info.local_kwargs)

# communicate the result to all ranks for some operators that return scalar value
if output_sharding.output_spec is None:
Expand Down Expand Up @@ -330,6 +249,101 @@ def redistribute_local_args(

op_info.local_args = tuple(new_local_args)

def unwrap_to_op_info(
self,
op_call: torch._ops.OpOverload,
args: Tuple[object, ...],
kwargs: Dict[str, object],
) -> OpInfo:
# get runtime schema to determine whether to use pytree to flatten inputs
runtime_schema_info = self.sharding_propagator.op_to_schema_info.get(
op_call, None
)

if runtime_schema_info is not None and runtime_schema_info.needs_pytree:
# flatten args/kwargs when necessary
tree_args, args_spec = pytree.tree_flatten(args)
args_list: Sequence[object] = tree_args
else:
args_list, args_spec = args, None

args_schema: List[object] = []
kwargs_schema: Dict[str, object] = {}
local_args: List[object] = []
local_kwargs: Dict[str, object] = {}
mesh: Optional[DeviceMesh] = None

for arg in args_list:
if isinstance(arg, dtensor.DTensor):
args_schema.append(arg._spec)
local_args.append(arg._local_tensor)
if mesh is not None:
if mesh != arg.device_mesh:
raise NotImplementedError(
f"{op_call}: DTensor does not support cross-mesh operation yet!"
)
else:
mesh = arg.device_mesh
elif isinstance(arg, torch.Tensor):
if arg.ndim == 0 and mesh is not None:
# scalar tensor can be safely treated as replicated
args_schema.append(
DTensorSpec(
mesh,
(Replicate(),) * mesh.ndim,
tensor_meta=TensorMeta(
shape=arg.shape, stride=arg.stride(), dtype=arg.dtype
),
)
)
local_args.append(arg)
else:
raise RuntimeError(
f"{op_call}: got mixed torch.Tensor and DTensor, need to convert all"
" torch.Tensor to DTensor before calling distributed operators!"
)
else:
args_schema.append(arg)
local_args.append(arg)

for k, v in kwargs.items():
if isinstance(v, dtensor.DTensor):
kwargs_schema[k] = v._spec
local_kwargs[k] = v._local_tensor
if mesh is not None:
if mesh != v.device_mesh:
raise NotImplementedError(
f"{op_call}: DTensor does not support cross-mesh operation yet!"
)
else:
mesh = v.device_mesh
elif isinstance(v, torch.Tensor):
raise RuntimeError(
f"{op_call}: got mixed torch.Tensor and DTensor, need to convert all"
" torch.Tensor to DTensor before calling distributed operators!"
)
else:
kwargs_schema[k] = v
local_kwargs[k] = v

assert mesh is not None, "found no DeviceMesh from dtensor args!"
op_info = OpInfo(
mesh,
OpSchema(
op_call,
pytree.tree_unflatten(args_schema, args_spec)
if args_spec
else tuple(args_schema),
kwargs_schema,
schema_info=runtime_schema_info,
),
args_schema,
tuple(local_args),
local_kwargs,
args_spec,
)
return op_info

@staticmethod
def wrap(res: object, spec: OutputSpecType) -> object:
def to_dt(res, spec):
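For context, a hypothetical end-to-end usage sketch of the path this dispatcher serves (assumes the default process group is already initialized, e.g. via torchrun; `DeviceMesh`, `Shard`, and `distribute_tensor` are the public DTensor entry points). Mixing a plain torch.Tensor into such an op is what trips the RuntimeError raised in `unwrap_to_op_info` above:

import torch
import torch.distributed as dist
from torch.distributed._tensor import DeviceMesh, Shard, distribute_tensor

# one-dimensional mesh over all ranks (process group assumed initialized)
mesh = DeviceMesh("cuda", list(range(dist.get_world_size())))

# shard both operands along dim 0
a = distribute_tensor(torch.randn(8, 8), mesh, [Shard(0)])
b = distribute_tensor(torch.randn(8, 8), mesh, [Shard(0)])

# aten.add.Tensor on two DTensors is routed through dispatch():
# unwrap_to_op_info -> sharding propagation -> local op -> wrap
c = a + b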