From a62099ff24f5158d874195cb0cefde13db362ea5 Mon Sep 17 00:00:00 2001 From: Andrew Gu Date: Mon, 20 Nov 2023 09:29:18 -0800 Subject: [PATCH] [DTensor] Replaced neg dim normalization with assert in global info [ghstack-poisoned] --- torch/distributed/_tensor/_utils.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/torch/distributed/_tensor/_utils.py b/torch/distributed/_tensor/_utils.py index 10d9b11c51b15..04b714c7789a3 100644 --- a/torch/distributed/_tensor/_utils.py +++ b/torch/distributed/_tensor/_utils.py @@ -145,8 +145,10 @@ def compute_global_tensor_info( if placement.is_shard(): shard_placement = cast(Shard, placement) if shard_placement.dim < 0: - # normalize shard dim to be positive - shard_placement.dim += len(tensor_shape) + raise AssertionError( + "Shard placements should have negative dims normalized in " + f"the user-facing APIs: {shard_placement}" + ) shard_dim = shard_placement.dim assert (