From 7ccc9023a367e22519844c14b849e12bca85c621 Mon Sep 17 00:00:00 2001 From: Faycel Kouteib Date: Fri, 26 Apr 2024 19:21:46 -0700 Subject: [PATCH] Fix linting errors in my changes. --- test/distributed/_tensor/experimental/test_local_map.py | 4 +--- test/distributed/_tensor/experimental/test_tp_transform.py | 4 +--- test/distributed/_tensor/test_api.py | 4 +--- test/distributed/_tensor/test_common_rules.py | 4 +--- test/distributed/_tensor/test_convolution_ops.py | 4 +--- test/distributed/_tensor/test_dtensor.py | 4 +--- test/distributed/_tensor/test_embedding_ops.py | 4 +--- test/distributed/_tensor/test_experimental_ops.py | 4 +--- test/distributed/_tensor/test_init.py | 4 +--- test/distributed/_tensor/test_redistribute.py | 4 +--- test/distributed/_tensor/test_utils.py | 5 +---- test/distributed/_tensor/test_view_ops.py | 4 +--- 12 files changed, 12 insertions(+), 37 deletions(-) diff --git a/test/distributed/_tensor/experimental/test_local_map.py b/test/distributed/_tensor/experimental/test_local_map.py index 32a0b8312467..5181db4ced32 100644 --- a/test/distributed/_tensor/experimental/test_local_map.py +++ b/test/distributed/_tensor/experimental/test_local_map.py @@ -12,9 +12,7 @@ from torch.distributed._tensor.debug import CommDebugMode from torch.distributed._tensor.experimental import local_map from torch.testing._internal.common_utils import run_tests -from torch.testing._internal.distributed._tensor.common_dtensor import ( - DTensorOpTestBase, -) +from torch.testing._internal.distributed._tensor.common_dtensor import DTensorOpTestBase def equal_forward(device_mesh, X, Y): diff --git a/test/distributed/_tensor/experimental/test_tp_transform.py b/test/distributed/_tensor/experimental/test_tp_transform.py index 3d4f8483bc7c..d60e25031417 100644 --- a/test/distributed/_tensor/experimental/test_tp_transform.py +++ b/test/distributed/_tensor/experimental/test_tp_transform.py @@ -12,9 +12,7 @@ RowwiseParallel, ) from torch.testing._internal.common_utils import run_tests -from torch.testing._internal.distributed._tensor.common_dtensor import ( - DTensorOpTestBase, -) +from torch.testing._internal.distributed._tensor.common_dtensor import DTensorOpTestBase class MLPListModule(torch.nn.Module): diff --git a/test/distributed/_tensor/test_api.py b/test/distributed/_tensor/test_api.py index e30cb27c3aa8..f52033301915 100644 --- a/test/distributed/_tensor/test_api.py +++ b/test/distributed/_tensor/test_api.py @@ -12,9 +12,7 @@ Shard, ) from torch.testing._internal.common_utils import run_tests -from torch.testing._internal.distributed._tensor.common_dtensor import ( - DTensorOpTestBase, -) +from torch.testing._internal.distributed._tensor.common_dtensor import DTensorOpTestBase class MyModel(nn.Module): diff --git a/test/distributed/_tensor/test_common_rules.py b/test/distributed/_tensor/test_common_rules.py index 714b689b5154..b087d0b63d6e 100644 --- a/test/distributed/_tensor/test_common_rules.py +++ b/test/distributed/_tensor/test_common_rules.py @@ -8,9 +8,7 @@ from torch.distributed._tensor.ops.common_rules import einop_rule, pointwise_rule from torch.distributed._tensor.placement_types import DTensorSpec, TensorMeta from torch.testing._internal.common_utils import run_tests -from torch.testing._internal.distributed._tensor.common_dtensor import ( - DTensorOpTestBase, -) +from torch.testing._internal.distributed._tensor.common_dtensor import DTensorOpTestBase aten = torch.ops.aten diff --git a/test/distributed/_tensor/test_convolution_ops.py b/test/distributed/_tensor/test_convolution_ops.py index 941170cfd4bc..6781353902d1 100644 --- a/test/distributed/_tensor/test_convolution_ops.py +++ b/test/distributed/_tensor/test_convolution_ops.py @@ -13,9 +13,7 @@ Shard, ) from torch.testing._internal.common_utils import run_tests -from torch.testing._internal.distributed._tensor.common_dtensor import ( - DTensorOpTestBase, -) +from torch.testing._internal.distributed._tensor.common_dtensor import DTensorOpTestBase ITER_TIME = 10 LR = 0.001 diff --git a/test/distributed/_tensor/test_dtensor.py b/test/distributed/_tensor/test_dtensor.py index f03794d49c7e..6203616b9254 100644 --- a/test/distributed/_tensor/test_dtensor.py +++ b/test/distributed/_tensor/test_dtensor.py @@ -21,9 +21,7 @@ ) from torch.testing._internal.common_utils import run_tests -from torch.testing._internal.distributed._tensor.common_dtensor import ( - DTensorOpTestBase, -) +from torch.testing._internal.distributed._tensor.common_dtensor import DTensorOpTestBase c10d_functional = torch.ops.c10d_functional diff --git a/test/distributed/_tensor/test_embedding_ops.py b/test/distributed/_tensor/test_embedding_ops.py index 47dd39e8ba3c..d85b82d9afab 100644 --- a/test/distributed/_tensor/test_embedding_ops.py +++ b/test/distributed/_tensor/test_embedding_ops.py @@ -12,9 +12,7 @@ ) from torch.distributed._tensor.debug import CommDebugMode from torch.testing._internal.common_utils import run_tests, TEST_WITH_DEV_DBG_ASAN -from torch.testing._internal.distributed._tensor.common_dtensor import ( - DTensorOpTestBase, -) +from torch.testing._internal.distributed._tensor.common_dtensor import DTensorOpTestBase if TEST_WITH_DEV_DBG_ASAN: print( diff --git a/test/distributed/_tensor/test_experimental_ops.py b/test/distributed/_tensor/test_experimental_ops.py index 7066b8daae65..cc8a585b2ffd 100644 --- a/test/distributed/_tensor/test_experimental_ops.py +++ b/test/distributed/_tensor/test_experimental_ops.py @@ -8,9 +8,7 @@ from torch.distributed._tensor import distribute_tensor, Replicate from torch.testing._internal.common_utils import run_tests -from torch.testing._internal.distributed._tensor.common_dtensor import ( - DTensorOpTestBase, -) +from torch.testing._internal.distributed._tensor.common_dtensor import DTensorOpTestBase ITER_TIME = 10 diff --git a/test/distributed/_tensor/test_init.py b/test/distributed/_tensor/test_init.py index 391e05d188c1..dfdd4f0b7f34 100644 --- a/test/distributed/_tensor/test_init.py +++ b/test/distributed/_tensor/test_init.py @@ -4,9 +4,7 @@ import torch from torch.distributed._tensor import DeviceMesh, DTensor, Replicate, Shard, zeros from torch.testing._internal.common_utils import run_tests -from torch.testing._internal.distributed._tensor.common_dtensor import ( - DTensorOpTestBase, -) +from torch.testing._internal.distributed._tensor.common_dtensor import DTensorOpTestBase class DTensorInitOpsTest(DTensorOpTestBase): diff --git a/test/distributed/_tensor/test_redistribute.py b/test/distributed/_tensor/test_redistribute.py index bd96a839db94..ba4d7174a81a 100644 --- a/test/distributed/_tensor/test_redistribute.py +++ b/test/distributed/_tensor/test_redistribute.py @@ -10,9 +10,7 @@ from torch.testing._internal.common_utils import run_tests -from torch.testing._internal.distributed._tensor.common_dtensor import ( - DTensorOpTestBase, -) +from torch.testing._internal.distributed._tensor.common_dtensor import DTensorOpTestBase funcol = torch.ops.c10d_functional diff --git a/test/distributed/_tensor/test_utils.py b/test/distributed/_tensor/test_utils.py index 6a1a3308121c..a0789b864d6c 100644 --- a/test/distributed/_tensor/test_utils.py +++ b/test/distributed/_tensor/test_utils.py @@ -14,9 +14,7 @@ from torch.distributed.device_mesh import DeviceMesh, init_device_mesh from torch.testing._internal.common_utils import run_tests -from torch.testing._internal.distributed._tensor.common_dtensor import ( - DTensorOpTestBase, -) +from torch.testing._internal.distributed._tensor.common_dtensor import DTensorOpTestBase c10d_functional = torch.ops.c10d_functional @@ -55,7 +53,6 @@ def test_compute_local_shape_2d_uneven(self): else: self.assertEqual(local_size3[1], 3) - def test_compute_local_shape_and_global_offset_1D(self): one_d_placements = [[Shard(0)], [Replicate()]] diff --git a/test/distributed/_tensor/test_view_ops.py b/test/distributed/_tensor/test_view_ops.py index 55d290327f54..c35059b8fd21 100644 --- a/test/distributed/_tensor/test_view_ops.py +++ b/test/distributed/_tensor/test_view_ops.py @@ -21,9 +21,7 @@ ) from torch.distributed._tensor.placement_types import Placement from torch.testing._internal.common_utils import run_tests -from torch.testing._internal.distributed._tensor.common_dtensor import ( - DTensorOpTestBase, -) +from torch.testing._internal.distributed._tensor.common_dtensor import DTensorOpTestBase from torch.utils import _pytree as pytree