Improve the debugging message for when the foreach mta_called assertion fails
[ghstack-poisoned]
janeyx99 committed Jun 18, 2024
1 parent 44722c6 commit 681ec37
Showing 1 changed file with 1 addition and 8 deletions.
test/test_foreach.py: 1 addition & 8 deletions
@@ -90,7 +90,7 @@ def __call__(self, inputs, is_cuda, expect_fastpath, **kwargs):
             mta_called = any("multi_tensor_apply_kernel" in k for k in keys)
             assert mta_called == (
                 expect_fastpath and (not zero_size)
-            ), f"{mta_called=}, {expect_fastpath=}, {zero_size=}"
+            ), f"{mta_called=}, {expect_fastpath=}, {zero_size=}, {self.func.__name__=}, {keys=}"
         else:
             actual = self.func(*inputs, **kwargs)
         if self.is_inplace:
@@ -205,7 +205,6 @@ def test_all_zero_size_tensors_do_not_launch_kernel(self, device, dtype, op):
         "failing flakily on non sm86 cuda jobs",
     )
     def test_parity(self, device, dtype, op, noncontiguous, inplace):
-        torch.manual_seed(2024)
         if inplace:
             _, _, func, ref = self._get_funcs(op)
         else:
@@ -585,7 +584,6 @@ def test_binary_op_scalar_with_different_tensor_dtypes(self, device, dtype, op):
         "failing flakily on non sm86 cuda jobs, ex https://github.com/pytorch/pytorch/issues/125035",
     )
     def test_binary_op_list_error_cases(self, device, dtype, op):
-        torch.manual_seed(202406)
         foreach_op, foreach_op_, ref, ref_ = (
             op.method_variant,
             op.inplace_variant,
@@ -680,7 +678,6 @@ def test_binary_op_list_error_cases(self, device, dtype, op):
         "failing flakily on non sm86 cuda jobs, ex https://github.com/pytorch/pytorch/issues/125775",
     )
     def test_binary_op_list_slow_path(self, device, dtype, op):
-        torch.manual_seed(20240607)
         foreach_op, native_op, foreach_op_, native_op_ = self._get_funcs(op)
         # 0-strides
         tensor1 = make_tensor((10, 10), dtype=dtype, device=device)
@@ -799,7 +796,6 @@ def test_binary_op_list_slow_path(self, device, dtype, op):
         "failing flakily on non sm86 cuda jobs",
     )
     def test_binary_op_float_inf_nan(self, device, dtype, op):
-        torch.manual_seed(2024)
         inputs = (
             [
                 torch.tensor([float("inf")], device=device, dtype=dtype),
@@ -869,9 +865,6 @@ def test_unary_op_tensors_on_different_devices(self, device, dtype, op):
         "failing flakily on non sm86 cuda jobs",
     )
     def test_binary_op_tensors_on_different_devices(self, device, dtype, op):
-        torch.manual_seed(202406)
-        # `tensors1`: ['cuda', 'cpu']
-        # `tensors2`: ['cuda', 'cpu']
         _cuda_tensors = next(
             iter(op.sample_inputs(device, dtype, num_input_tensors=[2], same_size=True))
         ).input
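
For context, the assertion touched above checks that the fused CUDA multi_tensor_apply_kernel actually appeared in the profiler trace whenever the foreach fastpath was expected. The snippet below is a minimal sketch of that pattern outside the test harness; check_fastpath, the sample tensor list, and the torch._foreach_sqrt call are illustrative stand-ins, while the profiler-key scan and the enriched f-string message mirror the changed lines.

import torch

def check_fastpath(func, inputs, expect_fastpath, zero_size=False, **kwargs):
    # Run the op under the profiler and collect the event keys it recorded.
    with torch.profiler.profile() as prof:
        func(*inputs, **kwargs)
    keys = tuple(e.key for e in prof.key_averages())
    mta_called = any("multi_tensor_apply_kernel" in k for k in keys)
    # The enriched message from this commit: report the op name and every
    # observed profiler key so a flaky mismatch is easier to diagnose.
    assert mta_called == (expect_fastpath and (not zero_size)), (
        f"{mta_called=}, {expect_fastpath=}, {zero_size=}, "
        f"{func.__name__=}, {keys=}"
    )

if torch.cuda.is_available():
    # Contiguous, same-dtype CUDA tensors are expected to take the fastpath.
    xs = [torch.ones(3, device="cuda") for _ in range(4)]
    check_fastpath(torch._foreach_sqrt, (xs,), expect_fastpath=True)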
