Skip to content

Commit

Permalink
Update tests
Browse files Browse the repository at this point in the history
  • Loading branch information
yanboliang committed Jul 15, 2023
1 parent 0e3d5b6 commit 19bc63a
Show file tree
Hide file tree
Showing 2 changed files with 10 additions and 2 deletions.
10 changes: 10 additions & 0 deletions test/test_prims.py
Original file line number Diff line number Diff line change
Expand Up @@ -1212,6 +1212,16 @@ def test_unbind(self):
expect = torch.unbind(a, 1)
self.assertEqual(actual, expect)

def test_logspace_with_complex_input(self):
actual = refs.logspace(2, 10 + 5j, steps=5)
expect = torch.logspace(2, 10 + 5j, steps=5)
self.assertEqual(actual, expect)

def test_linspace_with_complex_input(self):
actual = refs.linspace(2, 10 + 5j, steps=5)
expect = torch.linspace(2, 10 + 5j, steps=5)
self.assertEqual(actual, expect)


instantiate_device_type_tests(TestRefs, globals())

Expand Down
2 changes: 0 additions & 2 deletions torch/testing/_internal/common_methods_invocations.py
Original file line number Diff line number Diff line change
Expand Up @@ -1002,7 +1002,6 @@ def sample_inputs_linspace(op, device, dtype, requires_grad, **kwargs):
yield SampleInput(start, args=(end, nstep), kwargs={"dtype": dtype, "device": device})

yield SampleInput(1, args=(3, 1))
yield SampleInput(1 + 2j, args=(10 + 5j, 5)) # inputs of complex type


def sample_inputs_logspace(op, device, dtype, requires_grad, **kwargs):
Expand All @@ -1022,7 +1021,6 @@ def sample_inputs_logspace(op, device, dtype, requires_grad, **kwargs):
yield SampleInput(start, args=(end, nstep, base), kwargs={"dtype": dtype, "device": device})

yield SampleInput(1, args=(3, 1, 2.))
yield SampleInput(1 + 2j, args=(10 + 5j, 5)) # inputs of complex type


def sample_inputs_isclose(op, device, dtype, requires_grad, **kwargs):
Expand Down

0 comments on commit 19bc63a

Please sign in to comment.