Switched the test cases to use threads instead of processes #115032

Closed
wants to merge 1 commit into from
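In short, each test class moves from the multi-process DTensorTestBase harness (which required the @with_comms decorator to set up a process group per rank) to DTensorOpTestBase, which runs the ranks as threads in a single process, so the decorator is dropped. A minimal sketch of the resulting pattern, using only names that already appear in this diff; the example class name and tensor values are illustrative, not part of the PR:

import torch
from torch.distributed._tensor import distribute_tensor
from torch.distributed._tensor.placement_types import Shard
from torch.testing._internal.common_utils import run_tests
from torch.testing._internal.distributed._tensor.common_dtensor import (
    DTensorOpTestBase,
)


class ExampleThreadedDTensorTest(DTensorOpTestBase):  # hypothetical example class
    def test_shard_roundtrip(self):
        # build_device_mesh() comes from the threaded base class (see
        # test_math_ops.py below); no @with_comms decorator is needed.
        mesh = self.build_device_mesh()
        # Deterministic input so every rank constructs the same tensor.
        local = torch.arange(12 * 8, dtype=torch.float32).reshape(12, 8)
        sharded = distribute_tensor(local, mesh, [Shard(0)])
        # Gathering the shards back should reproduce the original tensor.
        self.assertEqual(sharded.full_tensor(), local)


if __name__ == "__main__":
    run_tests()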
11 changes: 5 additions & 6 deletions test/distributed/_tensor/test_embedding_ops.py
@@ -12,8 +12,7 @@
)
from torch.testing._internal.common_utils import run_tests, TEST_WITH_DEV_DBG_ASAN
from torch.testing._internal.distributed._tensor.common_dtensor import (
DTensorTestBase,
with_comms,
DTensorOpTestBase,
)

if TEST_WITH_DEV_DBG_ASAN:
@@ -24,7 +23,7 @@
sys.exit(0)


class TestEmbeddingOp(DTensorTestBase):
class TestEmbeddingOp(DTensorOpTestBase):
def _run_embedding_op_test(
self,
shard_dim,
@@ -112,7 +111,7 @@ def _run_embedding_op_test(
)
self.assertEqual(local_output, sharded_output.full_tensor())

@with_comms

def test_sharded_embedding_colwise(self):
self._run_embedding_op_test(1, [5, 4], 17, 12)
self._run_embedding_op_test(1, [6, 7, 6], 21, 11)
@@ -122,7 +121,7 @@ def test_sharded_embedding_colwise(self):
self._run_embedding_op_test(1, [34], 15, 14, padding_idx=10)
self._run_embedding_op_test(1, [8, 6, 5, 4], 23, 13, padding_idx=12)

@with_comms

def test_sharded_embedding_colwise_errors(self):
with self.assertRaisesRegex(
NotImplementedError,
@@ -132,7 +131,7 @@ def test_sharded_embedding_colwise_errors(self):
1, [8, 6, 5, 4], 23, 13, padding_idx=12, max_norm=2.0
)

@with_comms

def test_sharded_embedding_rowwise(self):
with self.assertRaisesRegex(
NotImplementedError,
14 changes: 7 additions & 7 deletions test/distributed/_tensor/test_math_ops.py
@@ -9,13 +9,13 @@
from torch.distributed._tensor.placement_types import Replicate, Shard
from torch.testing._internal.common_utils import run_tests
from torch.testing._internal.distributed._tensor.common_dtensor import (
DTensorTestBase,
DTensorOpTestBase,
skip_unless_torch_gpu,
with_comms,
)


class DistMathOpsTest(DTensorTestBase):
class DistMathOpsTest(DTensorOpTestBase):
def linear_op_reductions(self, op_str):
device_mesh = self.build_device_mesh()
shard_spec = [Shard(0)]
@@ -44,18 +44,18 @@ def linear_op_reductions(self, op_str):
dt_full_reduced = op_dt().full_tensor()
self.assertEqual(dt_full_reduced, full_reduced_tensor)

@with_comms

def test_linear_op_reductions(self):
for op_str in ("all", "sum", "prod", "max", "min"):
self.linear_op_reductions(op_str)

@with_comms

@skip_unless_torch_gpu
def test_mean(self):
self.linear_op_reductions("mean")

# TODO: forward test can be removed once test_softmax_with_bwd passes on CPU
@with_comms

def test_softmax_fwd(self):
device_mesh = self.build_device_mesh()

@@ -88,7 +88,7 @@ def test_softmax_fwd(self):
# TODO: get test_softmax_with_bwd pass on CPU
# DTensor's _softmax_backward_data produces wrong result on CPU on certain dimension.
# fail_on_cpu_list = [(0, -1), (1, -1)]
@with_comms

@skip_unless_torch_gpu
def test_softmax_with_bwd(self):
device_mesh = self.build_device_mesh()
@@ -126,7 +126,7 @@ def test_softmax_with_bwd(self):
self.assertIsNotNone(dist_x.grad)
self.assertEqual(dist_x.grad.full_tensor(), x.grad)

@with_comms

def test_full_shard_math_ops(self):
mesh_shape = (2, self.world_size // 2)
mesh = DeviceMesh(
77 changes: 38 additions & 39 deletions test/distributed/_tensor/test_matrix_ops.py
@@ -15,43 +15,42 @@
)
from torch.testing._internal.common_utils import run_tests
from torch.testing._internal.distributed._tensor.common_dtensor import (
DTensorTestBase,
DTensorOpTestBase,
skip_unless_torch_gpu,
with_comms,
)


class DistMatrixOpsTest(DTensorTestBase):
@with_comms
class DistMatrixOpsTest(DTensorOpTestBase):

def test_addmm(self):
device_mesh = DeviceMesh(self.device_type, list(range(self.world_size)))
self.build_mesh = DeviceMesh(self.device_type, list(range(self.world_size)))
Contributor @fduwjj commented on Dec 3, 2023:
On top of Andrew's question: DeviceMesh returns a device_mesh, so maybe you want to call it self.device_mesh or self.mesh?
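For illustration, a minimal sketch of the rename this comment seems to suggest; the attribute name self.device_mesh (and the example class) is an assumption here, not something this PR adopts:

import torch
from torch.distributed._tensor import DeviceMesh, distribute_tensor
from torch.distributed._tensor.placement_types import Shard
from torch.testing._internal.distributed._tensor.common_dtensor import (
    DTensorOpTestBase,
)


class ExampleNamingTest(DTensorOpTestBase):  # hypothetical example class
    def test_addmm_style_setup(self):
        # Name the attribute after what it holds instead of `self.build_mesh`;
        # a plain local (device_mesh = self.build_device_mesh()) would read
        # just as clearly, as test_math_ops.py does.
        self.device_mesh = DeviceMesh(self.device_type, list(range(self.world_size)))
        mat1 = distribute_tensor(torch.ones(12, 8), self.device_mesh, [Shard(0)])
        self.assertEqual(mat1.full_tensor().shape, torch.Size([12, 8]))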

shard_spec = [Shard(0)]
replica_spec = [Replicate()]

tensor_to_shard = torch.randn(12, 8)
mat1 = distribute_tensor(tensor_to_shard, device_mesh, shard_spec)
mat1 = distribute_tensor(tensor_to_shard, self.build_mesh, shard_spec)
tensor_to_replicate = torch.randn(8, 4)
mat2 = distribute_tensor(tensor_to_replicate, device_mesh, replica_spec)
mat2 = distribute_tensor(tensor_to_replicate, self.build_mesh, replica_spec)
input_tensor = torch.randn(4)
input = distribute_tensor(input_tensor, device_mesh, replica_spec)
input = distribute_tensor(input_tensor, self.build_mesh, replica_spec)

dist_res = torch.addmm(input, mat1, mat2)
local_res = torch.addmm(input_tensor, tensor_to_shard, tensor_to_replicate)
self.assertEqual(dist_res.full_tensor(), local_res)

@with_comms

def test_addmm_auto_redistribute(self):
device_mesh = DeviceMesh(self.device_type, list(range(self.world_size)))
self.build_mesh = DeviceMesh(self.device_type, list(range(self.world_size)))
shard0_spec = [Shard(0)]
shard1_spec = [Shard(1)]
replica_spec = [Replicate()]

tensor_to_shard1 = torch.randn(12, 8, requires_grad=True)
mat1 = distribute_tensor(tensor_to_shard1, device_mesh, shard1_spec)
mat1 = distribute_tensor(tensor_to_shard1, self.build_mesh, shard1_spec)
tensor_to_shard0 = torch.randn(8, 4, requires_grad=True)
mat2 = distribute_tensor(tensor_to_shard0, device_mesh, shard0_spec)
mat2 = distribute_tensor(tensor_to_shard0, self.build_mesh, shard0_spec)
input_tensor = torch.randn(4, requires_grad=True)
input = distribute_tensor(input_tensor, device_mesh, replica_spec)
input = distribute_tensor(input_tensor, self.build_mesh, replica_spec)

local_res = torch.addmm(input_tensor, tensor_to_shard1, tensor_to_shard0)
dist_res = torch.addmm(input, mat1, mat2)
@@ -70,9 +69,9 @@ def test_addmm_auto_redistribute(self):
self.assertIsNotNone(mat2.grad)
self.assertEqual(mat2.grad.full_tensor(), tensor_to_shard0.grad)

@with_comms

def test_mm(self):
device_mesh = DeviceMesh(self.device_type, list(range(self.world_size)))
self.build_mesh = DeviceMesh(self.device_type, list(range(self.world_size)))
shard0_spec = Shard(0)
shard1_spec = Shard(1)
replica_spec = Replicate()
@@ -84,10 +83,10 @@ def test_mm(self):
def test_placement_comb(
placements1: List[Placement], placements2: List[Placement]
) -> None:
dt1 = distribute_tensor(t1, device_mesh, placements1)
dt2 = distribute_tensor(t2, device_mesh, placements2)
dt1 = distribute_tensor(t1, self.build_mesh, placements1)
dt2 = distribute_tensor(t2, self.build_mesh, placements2)
dist_res: DTensor = cast(DTensor, torch.mm(dt1, dt2)).redistribute(
device_mesh, [replica_spec]
self.build_mesh, [replica_spec]
)
self.assertEqual(dist_res.to_local(), local_res)
# backward
@@ -100,30 +99,30 @@ def test_placement_comb(
for spec in shard_specs_comb:
test_placement_comb([spec[0]], [spec[1]])

@with_comms

def test_t(self):
device_mesh = DeviceMesh(self.device_type, list(range(self.world_size)))
self.build_mesh = DeviceMesh(self.device_type, list(range(self.world_size)))
shard_spec = [Shard(0)]

tensor_to_transpose = torch.randn(12, 8, requires_grad=True)
mat = distribute_tensor(tensor_to_transpose, device_mesh, shard_spec)
mat = distribute_tensor(tensor_to_transpose, self.build_mesh, shard_spec)
tranposed_mat = mat.t()
self.assertEqual(tranposed_mat.size(), torch.Size([8, 12]))
self.assertEqual(tranposed_mat.placements, [Shard(1)])
tranposed_mat2 = tranposed_mat.t()
self.assertEqual(tranposed_mat2.size(), torch.Size([12, 8]))
self.assertEqual(tranposed_mat2.placements, shard_spec)

@with_comms

def test_t_partial(self):
device_mesh = DeviceMesh(self.device_type, list(range(self.world_size)))
self.build_mesh = DeviceMesh(self.device_type, list(range(self.world_size)))

a = torch.randn(12, 8)
b = torch.randn(8, 4)
c = torch.mm(a, b).t()

da = distribute_tensor(a, device_mesh, [Shard(1)])
db = distribute_tensor(b, device_mesh, [Shard(0)])
da = distribute_tensor(a, self.build_mesh, [Shard(1)])
db = distribute_tensor(b, self.build_mesh, [Shard(0)])

# mm(da, db) should return a _Partial tensor.
# transposing it should keep it _Partial
@@ -134,14 +133,14 @@
# check that the local and distributed op results match
self.assertEqual(
c,
dc.redistribute(device_mesh, [Replicate()]).to_local(),
dc.redistribute(self.build_mesh, [Replicate()]).to_local(),
)

# baddbmm introduces nan occasionally on CPU: https://github.com/pytorch/pytorch/issues/80588
@with_comms

@skip_unless_torch_gpu
def test_baddbmm(self):
device_mesh = DeviceMesh(self.device_type, list(range(self.world_size)))
self.build_mesh = DeviceMesh(self.device_type, list(range(self.world_size)))
tensor = torch.rand(4, 4, 8, device=self.device_type, requires_grad=True)
batch_1 = torch.rand(4, 4, 8, device=self.device_type, requires_grad=True)
batch_2 = torch.rand(4, 8, 8, device=self.device_type, requires_grad=True)
@@ -154,15 +153,15 @@ def test_placement_comb(
alpha: int,
batch_1_grad: Optional[torch.Tensor],
) -> None:
tensor_dt = distribute_tensor(tensor, device_mesh, tensor_placements)
batch_1_dt = distribute_tensor(batch_1, device_mesh, batch_1_placements)
batch_2_dt = distribute_tensor(batch_2, device_mesh, batch_2_placements)
tensor_dt = distribute_tensor(tensor, self.build_mesh, tensor_placements)
batch_1_dt = distribute_tensor(batch_1, self.build_mesh, batch_1_placements)
batch_2_dt = distribute_tensor(batch_2, self.build_mesh, batch_2_placements)
dist_res = cast(
DTensor,
torch.baddbmm(
tensor_dt, batch_1_dt, batch_2_dt, beta=beta, alpha=alpha
),
).redistribute(device_mesh, [Replicate()])
).redistribute(self.build_mesh, [Replicate()])
dist_local_res = dist_res.to_local()
assert not torch.isnan(local_result).any()
assert not torch.isnan(dist_local_res).any()
@@ -173,7 +172,7 @@ def test_placement_comb(
# dist_res.backward(grad_dist_res)
# self.assertIsNotNone(batch_1_dt.grad)
# batch_1_grad_local = batch_1_dt.grad.redistribute(
# device_mesh, [Replicate()]
# self.build_mesh, [Replicate()]
# ).to_local()
# self.assertEqual(batch_1_grad_local, batch_1_grad)

@@ -203,9 +202,9 @@ def test_placement_comb(
[spec[0]], [spec[1]], [spec[2]], beta, alpha, batch_1.grad
)

@with_comms

def test_bmm(self):
device_mesh = DeviceMesh(self.device_type, list(range(self.world_size)))
self.build_mesh = DeviceMesh(self.device_type, list(range(self.world_size)))
mat1 = torch.rand(4, 8, 4, device=self.device_type, requires_grad=True)
mat2 = torch.rand(4, 4, 8, device=self.device_type, requires_grad=True)
local_result = torch.bmm(mat1, mat2)
@@ -216,10 +215,10 @@ def test_placement_comb(
placements1: List[Placement],
placements2: List[Placement],
) -> None:
mat1_dt = distribute_tensor(mat1, device_mesh, placements1)
mat2_dt = distribute_tensor(mat2, device_mesh, placements2)
mat1_dt = distribute_tensor(mat1, self.build_mesh, placements1)
mat2_dt = distribute_tensor(mat2, self.build_mesh, placements2)
dist_res = cast(DTensor, torch.bmm(mat1_dt, mat2_dt)).redistribute(
device_mesh, [Replicate()]
self.build_mesh, [Replicate()]
)
dist_local_res = dist_res.to_local()
self.assertEqual(dist_local_res, local_result)
@@ -232,7 +231,7 @@ def test_placement_comb(
self.assertIsNotNone(mat1_dt.grad)
mat1_dt_grad = cast(DTensor, mat1_dt.grad)
mat1_grad_local = mat1_dt_grad.redistribute(
device_mesh, [Replicate()]
self.build_mesh, [Replicate()]
).to_local()
self.assertEqual(mat1_grad_local, mat1.grad)
