Skip to content

Commit

Permalink
Fix linting errors in my changes.
Browse files Browse the repository at this point in the history
  • Loading branch information
fkouteib committed Apr 27, 2024
1 parent 0362967 commit 7ccc902
Show file tree
Hide file tree
Showing 12 changed files with 12 additions and 37 deletions.
4 changes: 1 addition & 3 deletions test/distributed/_tensor/experimental/test_local_map.py
Expand Up @@ -12,9 +12,7 @@
from torch.distributed._tensor.debug import CommDebugMode
from torch.distributed._tensor.experimental import local_map
from torch.testing._internal.common_utils import run_tests
from torch.testing._internal.distributed._tensor.common_dtensor import (
DTensorOpTestBase,
)
from torch.testing._internal.distributed._tensor.common_dtensor import DTensorOpTestBase


def equal_forward(device_mesh, X, Y):
Expand Down
4 changes: 1 addition & 3 deletions test/distributed/_tensor/experimental/test_tp_transform.py
Expand Up @@ -12,9 +12,7 @@
RowwiseParallel,
)
from torch.testing._internal.common_utils import run_tests
from torch.testing._internal.distributed._tensor.common_dtensor import (
DTensorOpTestBase,
)
from torch.testing._internal.distributed._tensor.common_dtensor import DTensorOpTestBase


class MLPListModule(torch.nn.Module):
Expand Down
4 changes: 1 addition & 3 deletions test/distributed/_tensor/test_api.py
Expand Up @@ -12,9 +12,7 @@
Shard,
)
from torch.testing._internal.common_utils import run_tests
from torch.testing._internal.distributed._tensor.common_dtensor import (
DTensorOpTestBase,
)
from torch.testing._internal.distributed._tensor.common_dtensor import DTensorOpTestBase


class MyModel(nn.Module):
Expand Down
4 changes: 1 addition & 3 deletions test/distributed/_tensor/test_common_rules.py
Expand Up @@ -8,9 +8,7 @@
from torch.distributed._tensor.ops.common_rules import einop_rule, pointwise_rule
from torch.distributed._tensor.placement_types import DTensorSpec, TensorMeta
from torch.testing._internal.common_utils import run_tests
from torch.testing._internal.distributed._tensor.common_dtensor import (
DTensorOpTestBase,
)
from torch.testing._internal.distributed._tensor.common_dtensor import DTensorOpTestBase

aten = torch.ops.aten

Expand Down
4 changes: 1 addition & 3 deletions test/distributed/_tensor/test_convolution_ops.py
Expand Up @@ -13,9 +13,7 @@
Shard,
)
from torch.testing._internal.common_utils import run_tests
from torch.testing._internal.distributed._tensor.common_dtensor import (
DTensorOpTestBase,
)
from torch.testing._internal.distributed._tensor.common_dtensor import DTensorOpTestBase

ITER_TIME = 10
LR = 0.001
Expand Down
4 changes: 1 addition & 3 deletions test/distributed/_tensor/test_dtensor.py
Expand Up @@ -21,9 +21,7 @@
)

from torch.testing._internal.common_utils import run_tests
from torch.testing._internal.distributed._tensor.common_dtensor import (
DTensorOpTestBase,
)
from torch.testing._internal.distributed._tensor.common_dtensor import DTensorOpTestBase


c10d_functional = torch.ops.c10d_functional
Expand Down
4 changes: 1 addition & 3 deletions test/distributed/_tensor/test_embedding_ops.py
Expand Up @@ -12,9 +12,7 @@
)
from torch.distributed._tensor.debug import CommDebugMode
from torch.testing._internal.common_utils import run_tests, TEST_WITH_DEV_DBG_ASAN
from torch.testing._internal.distributed._tensor.common_dtensor import (
DTensorOpTestBase,
)
from torch.testing._internal.distributed._tensor.common_dtensor import DTensorOpTestBase

if TEST_WITH_DEV_DBG_ASAN:
print(
Expand Down
4 changes: 1 addition & 3 deletions test/distributed/_tensor/test_experimental_ops.py
Expand Up @@ -8,9 +8,7 @@

from torch.distributed._tensor import distribute_tensor, Replicate
from torch.testing._internal.common_utils import run_tests
from torch.testing._internal.distributed._tensor.common_dtensor import (
DTensorOpTestBase,
)
from torch.testing._internal.distributed._tensor.common_dtensor import DTensorOpTestBase


ITER_TIME = 10
Expand Down
4 changes: 1 addition & 3 deletions test/distributed/_tensor/test_init.py
Expand Up @@ -4,9 +4,7 @@
import torch
from torch.distributed._tensor import DeviceMesh, DTensor, Replicate, Shard, zeros
from torch.testing._internal.common_utils import run_tests
from torch.testing._internal.distributed._tensor.common_dtensor import (
DTensorOpTestBase,
)
from torch.testing._internal.distributed._tensor.common_dtensor import DTensorOpTestBase


class DTensorInitOpsTest(DTensorOpTestBase):
Expand Down
4 changes: 1 addition & 3 deletions test/distributed/_tensor/test_redistribute.py
Expand Up @@ -10,9 +10,7 @@

from torch.testing._internal.common_utils import run_tests

from torch.testing._internal.distributed._tensor.common_dtensor import (
DTensorOpTestBase,
)
from torch.testing._internal.distributed._tensor.common_dtensor import DTensorOpTestBase

funcol = torch.ops.c10d_functional

Expand Down
5 changes: 1 addition & 4 deletions test/distributed/_tensor/test_utils.py
Expand Up @@ -14,9 +14,7 @@
from torch.distributed.device_mesh import DeviceMesh, init_device_mesh

from torch.testing._internal.common_utils import run_tests
from torch.testing._internal.distributed._tensor.common_dtensor import (
DTensorOpTestBase,
)
from torch.testing._internal.distributed._tensor.common_dtensor import DTensorOpTestBase

c10d_functional = torch.ops.c10d_functional

Expand Down Expand Up @@ -55,7 +53,6 @@ def test_compute_local_shape_2d_uneven(self):
else:
self.assertEqual(local_size3[1], 3)


def test_compute_local_shape_and_global_offset_1D(self):
one_d_placements = [[Shard(0)], [Replicate()]]

Expand Down
4 changes: 1 addition & 3 deletions test/distributed/_tensor/test_view_ops.py
Expand Up @@ -21,9 +21,7 @@
)
from torch.distributed._tensor.placement_types import Placement
from torch.testing._internal.common_utils import run_tests
from torch.testing._internal.distributed._tensor.common_dtensor import (
DTensorOpTestBase,
)
from torch.testing._internal.distributed._tensor.common_dtensor import DTensorOpTestBase
from torch.utils import _pytree as pytree


Expand Down

0 comments on commit 7ccc902

Please sign in to comment.