diff --git a/torch/distributed/_tensor/_utils.py b/torch/distributed/_tensor/_utils.py
index eaabdf02cc8e6..592cb6bd6e25e 100644
--- a/torch/distributed/_tensor/_utils.py
+++ b/torch/distributed/_tensor/_utils.py
@@ -43,42 +43,6 @@ def compute_local_shape(
     return tuple(local_shape)
 
 
-# TODO: audit existing code base to see if we can safely remove this API.
-def compute_local_offset(
-    global_shape: ShapeType, mesh: DeviceMesh, placements: Sequence[Placement]
-) -> Tuple[int, ...]:
-    """
-    Compute the offsets of a local shard of the given DTensor on its current
-    global rank. This is mostly used by distributed checkpointing to know the
-    exact offsets of the local shard.
-    """
-    my_coordinate = mesh.get_coordinate()
-
-    if my_coordinate is None:
-        # if rank not in the mesh, return empty offset
-        return ()
-    else:
-        local_offsets = [0] * len(global_shape)
-        local_shape = list(global_shape)
-
-        for idx, placement in enumerate(placements):
-            mesh_dim_size = mesh.size(idx)
-            if isinstance(placement, Shard):
-                shard_dim = placement.dim
-                assert shard_dim < len(
-                    local_shape
-                ), f"Sharding dim {shard_dim} greater than tensor ndim {len(local_shape)}"
-                shard_size, shard_offset = placement._local_shard_size_on_dim(
-                    local_shape[shard_dim],
-                    mesh_dim_size,
-                    my_coordinate[idx],
-                    return_offset=True,
-                )
-                local_shape[shard_dim] = shard_size
-                local_offsets[shard_dim] = shard_offset
-        return tuple(local_offsets)
-
-
 def compute_local_shape_and_global_offset(
     global_shape: ShapeType, mesh: DeviceMesh, placements: Sequence[Placement]
 ) -> Tuple[Tuple[int, ...], Tuple[int, ...]]:
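
The hunk above drops `compute_local_offset` in favor of `compute_local_shape_and_global_offset`, whose signature is visible in the trailing context lines. Below is a minimal migration sketch, not part of the patch: it assumes a hypothetical single-rank gloo setup purely to make the call self-contained, and the import paths may differ across PyTorch versions.

```python
import os

import torch.distributed as dist
from torch.distributed._tensor import DeviceMesh, Shard
from torch.distributed._tensor._utils import compute_local_shape_and_global_offset

# Hypothetical single-process setup so a DeviceMesh can be constructed.
os.environ.setdefault("MASTER_ADDR", "127.0.0.1")
os.environ.setdefault("MASTER_PORT", "29500")
dist.init_process_group("gloo", rank=0, world_size=1)

mesh = DeviceMesh("cpu", [0])   # 1-D mesh containing only this rank
placements = [Shard(0)]         # shard the tensor along dim 0
global_shape = (16, 8)

# Old:  offsets = compute_local_offset(global_shape, mesh, placements)
# New:  the replacement returns both the local shard's shape and its global offset.
local_shape, global_offset = compute_local_shape_and_global_offset(
    global_shape, mesh, placements
)
print(local_shape, global_offset)   # (16, 8) (0, 0) on a single rank

dist.destroy_process_group()
```

Callers that only needed the offsets (e.g. distributed checkpointing, per the removed docstring) can simply discard the first element of the returned pair.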