[dtensor] skip pytree when not necessary #110132

Closed
wants to merge 7 commits into from
2 changes: 1 addition & 1 deletion test/distributed/_tensor/test_device_mesh.py
@@ -154,7 +154,7 @@ def test_device_mesh_hash(self):
         mesh_tensor_2d = torch.arange(8).reshape(4, 2)
         mesh = DeviceMesh(self.device_type, mesh_tensor_2d)
         mesh2 = DeviceMesh(self.device_type, mesh_tensor_2d)
-        self.assertNotEqual(hash(mesh), hash(mesh2))
+        self.assertEqual(hash(mesh), hash(mesh2))
         mesh_tensor_3d = torch.arange(8).reshape(2, 2, 2)
         mesh3 = DeviceMesh(self.device_type, mesh_tensor_3d)
         self.assertNotEqual(hash(mesh), hash(mesh3))
13 changes: 10 additions & 3 deletions torch/distributed/_tensor/device_mesh.py
@@ -160,6 +160,10 @@ def __init__(
             else torch.tensor(mesh, dtype=torch.int)
         )
         self.mesh_dim_names = mesh_dim_names
+
+        # private field to pre-generate DeviceMesh's hash
+        self._flatten_mesh_list = tuple(self.mesh.flatten().tolist())
+        self._hash = hash((self._flatten_mesh_list, self.mesh.shape))
         # always try to create default (world) pg, even if it is not initialized
         # already. The world pg is used for device mesh identity (rank) on each
         # process (we need to know if the current global rank is in the mesh or not)
@@ -278,14 +282,17 @@ def __repr__(self) -> str:
         return f"DeviceMesh:({self.mesh.tolist()})"
 
     def __hash__(self):
-        return hash((self.mesh, id(self)))
+        return self._hash
 
     def __eq__(self, other: object) -> bool:
         if not isinstance(other, DeviceMesh):
             return False
-        if id(self) == id(other):
+        if id(self.mesh) == id(other.mesh):
             return True
-        return self.mesh.equal(other.mesh)
+        return (
+            self.mesh.shape == other.mesh.shape
+            and self._flatten_mesh_list == other._flatten_mesh_list
+        )
 
     def __getitem__(self, mesh_dim_name: str) -> "DeviceMesh":
         """
103 changes: 55 additions & 48 deletions torch/distributed/_tensor/dispatch.py
@@ -82,24 +82,29 @@ def redistribute_local_args(
 
     # TODO: the op schema should probably just remain flattened so that we can avoid this tree flatten
     # Need to fix all the ops before doing this.
-    flatten_args_schema_to_reshard = tree_flatten(suggested_input_schema.args_schema)[0]
+    if op_info.args_tree_spec is not None:
+        flatten_args_schema_to_reshard = tuple(
+            tree_flatten(suggested_input_schema.args_schema)[0]
+        )
+    else:
+        flatten_args_schema_to_reshard = suggested_input_schema.args_schema
 
-    new_flat_local_args: List[object] = []
+    new_local_args: List[object] = []
     for i, arg_spec in enumerate(op_info.flat_args_schema):
         reshard_arg_spec = flatten_args_schema_to_reshard[i]
         if isinstance(arg_spec, DTensorSpec):
-            local_tensor = cast(torch.Tensor, op_info.flat_local_args[i])
+            local_tensor = cast(torch.Tensor, op_info.local_args[i])
             if arg_spec != reshard_arg_spec:
                 resharded_local_tensor = redistribute_local_tensor(
                     local_tensor, arg_spec, reshard_arg_spec
                 )
-                new_flat_local_args.append(resharded_local_tensor)
+                new_local_args.append(resharded_local_tensor)
             else:
-                new_flat_local_args.append(local_tensor)
+                new_local_args.append(local_tensor)
         else:
-            new_flat_local_args.append(reshard_arg_spec)
+            new_local_args.append(reshard_arg_spec)
 
-    op_info.flat_local_args = new_flat_local_args
+    op_info.local_args = tuple(new_local_args)
 
 
 def operator_dispatch(
@@ -118,19 +123,25 @@ def _operator_dispatch(
     kwargs: Dict[str, object],
     sharding_propagator: ShardingPropagator,
 ) -> Tuple[object, OpSchema, OutputSharding]:
-    # unwrap the op info from args/kwargs
-    flat_args_list, args_spec = tree_flatten(args)
-    flat_kwargs_list, kwargs_spec = tree_flatten(kwargs)
-    flat_args_schema: List[object] = []
-    flat_local_args: List[object] = []
-    flat_kwargs_schema: List[object] = []
-    flat_local_kwargs: List[object] = []
+    runtime_schema_info = sharding_propagator.op_to_schema_info.get(op_call, None)
+
+    if runtime_schema_info is not None and runtime_schema_info.needs_pytree:
+        # flatten args/kwargs when necessary
+        tree_args, args_spec = tree_flatten(args)
+        args_list: Sequence[object] = tree_args
+    else:
+        args_list, args_spec = args, None
+
+    args_schema: List[object] = []
+    kwargs_schema: Dict[str, object] = {}
+    local_args: List[object] = []
+    local_kwargs: Dict[str, object] = {}
     mesh: Optional[DeviceMesh] = None
 
-    for arg in flat_args_list:
+    for arg in args_list:
         if isinstance(arg, dtensor.DTensor):
-            flat_args_schema.append(arg._spec)
-            flat_local_args.append(arg._local_tensor)
+            args_schema.append(arg._spec)
+            local_args.append(arg._local_tensor)
             if mesh is not None:
                 if mesh != arg.device_mesh:
                     raise NotImplementedError(
@@ -144,44 +155,42 @@ def _operator_dispatch(
                         " torch.Tensor to DTensor before calling distributed operators!"
                     )
         else:
-            flat_args_schema.append(arg)
-            flat_local_args.append(arg)
+            args_schema.append(arg)
+            local_args.append(arg)
 
-    for kwarg in flat_kwargs_list:
-        if isinstance(kwarg, dtensor.DTensor):
-            flat_kwargs_schema.append(kwarg._spec)
-            flat_local_kwargs.append(kwarg._local_tensor)
+    for k, v in kwargs.items():
+        if isinstance(v, dtensor.DTensor):
+            kwargs_schema[k] = v._spec
+            local_kwargs[k] = v._local_tensor
             if mesh is not None:
-                if mesh != kwarg.device_mesh:
+                if mesh != v.device_mesh:
                     raise NotImplementedError(
                         f"{op_call}: DTensor does not support cross-mesh operation yet!"
                     )
             else:
-                mesh = kwarg.device_mesh
-        elif isinstance(kwarg, torch.Tensor):
+                mesh = v.device_mesh
+        elif isinstance(v, torch.Tensor):
             raise RuntimeError(
                 f"{op_call}: got mixed torch.Tensor and DTensor, need to convert all"
                 " torch.Tensor to DTensor before calling distributed operators!"
             )
         else:
-            flat_kwargs_schema.append(kwarg)
-            flat_local_kwargs.append(kwarg)
+            kwargs_schema[k] = v
+            local_kwargs[k] = v
 
     assert mesh is not None, "found no DeviceMesh from dtensor args!"
     op_info = OpInfo(
         mesh,
         OpSchema(
             op_call,
-            tree_unflatten(flat_args_schema, args_spec),
-            tree_unflatten(flat_kwargs_schema, kwargs_spec),
-            schema_info=sharding_propagator.op_to_schema_info.get(op_call, None),
+            tree_unflatten(args_schema, args_spec) if args_spec else tuple(args_schema),
+            kwargs_schema,
+            schema_info=runtime_schema_info,
         ),
-        flat_args_schema,
-        flat_kwargs_schema,
-        flat_local_args,
-        flat_local_kwargs,
+        args_schema,
+        tuple(local_args),
+        local_kwargs,
         args_spec,
-        kwargs_spec,
     )
 
     sharding_propagator.propagate(op_info)
@@ -241,16 +250,14 @@ def default_tensor(spec: DTensorSpec) -> torch.Tensor:
             suggested_input_schema = output_sharding.schema_suggestions[0]
             redistribute_local_args(op_info, suggested_input_schema)
 
-        local_tensor_args = tree_unflatten(
-            op_info.flat_local_args, op_info.args_tree_spec
-        )
-        local_tensor_kwargs = tree_unflatten(
-            op_info.flat_local_kwargs, op_info.kwargs_tree_spec
+        local_tensor_args = (
+            tree_unflatten(cast(List[object], op_info.local_args), args_spec)
+            if args_spec
+            else op_info.local_args
         )
 
         # run local op computation with potentially modified args/kwargs
         local_tensor_args = cast(Tuple[object, ...], local_tensor_args)
-        local_tensor_kwargs = cast(Dict[str, object], local_tensor_kwargs)
         if _is_random_op(op_call) and is_rng_supported_mesh(mesh):
             if not random._rng_tracker:
                 raise RuntimeError(
@@ -261,11 +268,11 @@ def default_tensor(spec: DTensorSpec) -> torch.Tensor:
                 )
             # For DTensor random operator, run it within a distribute region
             with random._rng_tracker._distribute_region(
-                cast(DTensorSpec, flat_args_schema[0])
+                cast(DTensorSpec, args_schema[0])
             ):
-                local_results = op_call(*local_tensor_args, **local_tensor_kwargs)
+                local_results = op_call(*local_tensor_args, **local_kwargs)
         else:
-            local_results = op_call(*local_tensor_args, **local_tensor_kwargs)
+            local_results = op_call(*local_tensor_args, **local_kwargs)
 
     # communicate the result to all ranks for some operators that return scalar value
     if output_sharding.output_spec is None:
@@ -290,9 +297,9 @@ def default_tensor(spec: DTensorSpec) -> torch.Tensor:
         )
         out_dts = []
         spec_idx = 0
-        for arg in op_call._schema.arguments:
-            if arg.is_out:
-                out_dt = cast(dtensor.DTensor, kwargs[arg.name])
+        for argument in op_call._schema.arguments:
+            if argument.is_out:
+                out_dt = cast(dtensor.DTensor, kwargs[argument.name])
                 out_dt._spec = cast(DTensorSpec, output_specs[spec_idx])
                 out_dts.append(out_dt)
                 spec_idx += 1
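
The dispatch-side effect of the change above: `tree_flatten`/`tree_unflatten` only run when the op's `RuntimeSchemaInfo` sets `needs_pytree`; otherwise the top-level args tuple and the kwargs dict are used directly. A rough standalone sketch of that control flow (the helper names `flatten_args_if_needed` and `unflatten_args_if_needed` are hypothetical, for illustration only):

    from typing import Any, Optional, Sequence, Tuple

    from torch.utils._pytree import TreeSpec, tree_flatten, tree_unflatten

    def flatten_args_if_needed(
        args: Tuple[Any, ...], needs_pytree: bool
    ) -> Tuple[Sequence[Any], Optional[TreeSpec]]:
        # Only pay the pytree cost when the op opts in (e.g. it takes a List[Tensor]).
        if needs_pytree:
            flat_args, spec = tree_flatten(args)
            return flat_args, spec
        # Fast path: treat the top-level args tuple as already flat.
        return args, None

    def unflatten_args_if_needed(
        local_args: Sequence[Any], spec: Optional[TreeSpec]
    ) -> Sequence[Any]:
        # Mirror of the above for the local tensors handed to the local op call.
        return tree_unflatten(list(local_args), spec) if spec is not None else local_args
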
14 changes: 7 additions & 7 deletions torch/distributed/_tensor/op_schema.py
@@ -145,8 +145,10 @@ class RuntimeSchemaInfo:
     static_argnum: int = 100
     # This static_kwargkey records static kwarg names which would affect sharding prop
     static_kwargkey: Optional[List[str]] = None
-    # TODO: make use of this field
-    needs_pytree: bool = True
+    # Each op can decide whether it wants to use pytree flatten/unflatten during
+    # operator eager execution. By default we don't flatten/unflatten; an op only
+    # opts in when it needs to, which helps accelerate eager performance.
+    needs_pytree: bool = False
 
 
 @dataclass
@@ -330,11 +332,9 @@ class OpInfo:
     mesh: DeviceMesh
     schema: OpSchema
     flat_args_schema: List[object]
-    flat_kwargs_schema: List[object]
-    flat_local_args: List[object]
-    flat_local_kwargs: List[object]
-    args_tree_spec: TreeSpec
-    kwargs_tree_spec: TreeSpec
+    local_args: Sequence[object]
+    local_kwargs: Dict[str, object]
+    args_tree_spec: Optional[TreeSpec] = None
 
     # the output sharding info
     output_sharding: Optional[OutputSharding] = None
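
Note the default flip: `needs_pytree` used to default to `True` (with a TODO to make use of it) and now defaults to `False`, so pytree handling is strictly opt-in per op. A small illustration of the intended usage, assuming the module path of the file shown above (`torch.distributed._tensor.op_schema`):

    from torch.distributed._tensor.op_schema import RuntimeSchemaInfo

    # Default: dispatch skips pytree flatten/unflatten entirely.
    default_info = RuntimeSchemaInfo()
    assert default_info.needs_pytree is False

    # Ops whose tensor arguments arrive nested (e.g. in a List[Tensor]) opt in
    # explicitly, optionally with a static_argnum as the first positional argument.
    cat_like_info = RuntimeSchemaInfo(1, needs_pytree=True)
    assert cat_like_info.needs_pytree is True
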
6 changes: 4 additions & 2 deletions torch/distributed/_tensor/ops/tensor_ops.py
@@ -290,7 +290,7 @@ def prop_index_select(op_schema: OpSchema) -> OutputSharding:
     return result
 
 
-@register_prop_rule(aten.index.Tensor)
+@register_prop_rule(aten.index.Tensor, schema_info=RuntimeSchemaInfo(needs_pytree=True))
 def prop_index(op_schema: OpSchema) -> OutputSharding:
     """
     Expect replicated on the first input; _mostly_ pointwise on the second input.
@@ -413,7 +413,9 @@ def place(vp: Placement, ip: Placement) -> Placement:
     return result
 
 
-@register_prop_rule(aten.cat.default, schema_info=RuntimeSchemaInfo(1))
+@register_prop_rule(
+    aten.cat.default, schema_info=RuntimeSchemaInfo(1, needs_pytree=True)
+)
 def cat_rule(op_schema: OpSchema) -> OutputSharding:
     # torch.cat requires all tensors must either have the same shape (except
     # in the concatenating dimension) or be "empty". "Empty" here strictly means
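
These two rules opt in because their tensor arguments arrive nested inside a Python list (`torch.cat` takes a sequence of tensors, and `aten.index.Tensor` takes a list of index tensors), so the dispatcher only sees the individual tensors after a pytree flatten. A quick illustration with plain tensors:

    import torch
    from torch.utils._pytree import tree_flatten

    # Args roughly as aten.cat.default receives them: a list of tensors plus a dim.
    args = ([torch.ones(2), torch.zeros(2)], 0)
    flat_args, spec = tree_flatten(args)
    print([type(a).__name__ for a in flat_args])  # ['Tensor', 'Tensor', 'int']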