Skip to content

Commit

Permalink
develop
Browse files Browse the repository at this point in the history
  • Loading branch information
SouSingh committed Sep 8, 2023
1 parent 09f7cb0 commit c19adf7
Show file tree
Hide file tree
Showing 4 changed files with 8 additions and 8 deletions.
4 changes: 2 additions & 2 deletions test/distributed/_tensor/test_embedding_ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
from torch.distributed._tensor.placement_types import Replicate, Shard
from torch.testing._internal.common_utils import run_tests, TEST_WITH_DEV_DBG_ASAN
from torch.testing._internal.distributed._tensor.common_dtensor import (
DTensorTestBase,
DTensorOpTestBase,
with_comms,
)

Expand All @@ -19,7 +19,7 @@
sys.exit(0)


class TestEmbeddingOp(DTensorTestBase):
class TestEmbeddingOp(DTensorOpTestBase):
def _run_embedding_op_test(
self,
shard_dim,
Expand Down
4 changes: 2 additions & 2 deletions test/distributed/_tensor/test_math_ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,13 +9,13 @@
from torch.distributed._tensor.placement_types import Replicate, Shard
from torch.testing._internal.common_utils import run_tests
from torch.testing._internal.distributed._tensor.common_dtensor import (
DTensorTestBase,
DTensorOpTestBase,
skip_unless_torch_gpu,
with_comms,
)


class DistMathOpsTest(DTensorTestBase):
class DistMathOpsTest(DTensorOpTestBase):
@with_comms
def test_sum(self):
device_mesh = self.build_device_mesh()
Expand Down
4 changes: 2 additions & 2 deletions test/distributed/_tensor/test_matrix_ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,13 +15,13 @@
)
from torch.testing._internal.common_utils import run_tests
from torch.testing._internal.distributed._tensor.common_dtensor import (
DTensorTestBase,
DTensorOpTestBase,
skip_unless_torch_gpu,
with_comms,
)


class DistMatrixOpsTest(DTensorTestBase):
class DistMatrixOpsTest(DTensorOpTestBase):
@with_comms
def test_addmm(self):
device_mesh = DeviceMesh(self.device_type, list(range(self.world_size)))
Expand Down
4 changes: 2 additions & 2 deletions test/distributed/_tensor/test_tensor_ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,12 +7,12 @@
from torch.testing._internal.common_utils import run_tests
from torch.testing._internal.distributed._tensor.common_dtensor import (
DTensorConverter,
DTensorTestBase,
DTensorOpTestBase,
with_comms,
)


class DistTensorOpsTest(DTensorTestBase):
class DistTensorOpsTest(DTensorOpTestBase):
@with_comms
def test_aten_contiguous(self):
# this op not covered by dtensor_ops
Expand Down

0 comments on commit c19adf7

Please sign in to comment.