Skip to content

Commit

Permalink
Add testing for foreach scalar Tensor overloads in inductor
Browse files Browse the repository at this point in the history
ghstack-source-id: 35e81cd659e1c42d9378309a09c26e95f0278284
Pull Request resolved: #111600
  • Loading branch information
janeyx99 committed Oct 19, 2023
1 parent 985f8ec commit 8ab1659
Showing 1 changed file with 19 additions and 0 deletions.
19 changes: 19 additions & 0 deletions test/inductor/test_foreach.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,6 +56,7 @@
)
bin_ops = parametrize("op", bin_ops_under_test, name_fn=lambda f: f.__name__)
scalar_bin_ops = parametrize("op", bin_ops_under_test[:4], name_fn=lambda f: f.__name__)
scalar_tensor_bin_ops = parametrize("op", bin_ops_under_test[:2], name_fn=lambda f: f.__name__)
decomp_ops = parametrize("op", compose_ops, name_fn=lambda f: f.__name__)


Expand Down Expand Up @@ -115,6 +116,18 @@ def fn(a0, a1):
torch.rand(20, 20, device="cuda:0"),
),
)

def _test_single_scalar_tensor(self, op):
def fn(a0, a1):
return op([a0, a1], torch.tensor(3.3, device="cuda:0"))

self.check_model_cuda(
fn,
(
torch.rand(10, 10, device="cuda:0"),
torch.rand(20, 20, device="cuda:0"),
),
)

# called in test_cpp_wrapper.py
@requires_cuda()
Expand All @@ -132,6 +145,12 @@ def test_single_list(self, op):
def test_single_scalar(self, op):
self._test_single_scalar(op)
self.assertEqual(torch._inductor.metrics.generated_kernel_count, 1)

@requires_cuda()
@scalar_tensor_bin_ops
def test_single_scalar_tensor(self, op):
self._test_single_scalar_tensor(op)
self.assertEqual(torch._inductor.metrics.generated_kernel_count, 1)

@requires_cuda()
@all_ops
Expand Down

0 comments on commit 8ab1659

Please sign in to comment.