Updated docs/test for dot and vdot #47242

Closed
Changes from 1 commit
94 changes: 92 additions & 2 deletions test/test_linalg.py
@@ -3,13 +3,13 @@
import itertools
import warnings
from math import inf, nan, isnan
from random import randrange
from random import randint, randrange

from torch.testing._internal.common_utils import \
(TestCase, run_tests, TEST_NUMPY, IS_MACOS, IS_WINDOWS, TEST_WITH_ASAN, make_tensor)
from torch.testing._internal.common_device_type import \
(instantiate_device_type_tests, dtypes, dtypesIfCUDA,
onlyCUDA, skipCUDAIfNoMagma, skipCPUIfNoLapack, precisionOverride)
onlyCUDA, onlyOnCPUAndCUDA, skipCUDAIfNoMagma, skipCPUIfNoLapack, precisionOverride)
from torch.testing._internal.jit_metaprogramming_utils import gen_script_fn_and_args
from torch.autograd import gradcheck

@@ -1037,6 +1037,96 @@ def test_tensorsolve_errors_and_warnings(self, device, dtype):
with self.assertRaisesRegex(RuntimeError, "result dtype Int does not match self dtype"):
torch.linalg.tensorsolve(a, b, out=out)

def _test_dot_vdot_vs_numpy(self, device, dtype, torch_fn, np_fn):
def compare_with_numpy_bin_op(torch_fn, np_fn, x, y):
y_np = y.cpu().numpy()

# `compare_with_numpy` takes care of moving `x` to the correct device for calling np_fn.
self.compare_with_numpy(lambda inp: torch_fn(inp, y), lambda inp: np_fn(inp, y_np), x)

# Test out variant
ref = torch_fn(x, y)
out = torch.zeros_like(ref)
torch_fn(x, y, out=out)
self.assertEqual(ref, out)

# Use this tensor for out variant tests.
out = torch.randn((), dtype=dtype, device=device)

def compare_out_variant(torch_fn, x, y):
torch_fn(x, y, out=out)
self.assertEqual(torch_fn(x, y), out)

for _ in range(10):
numel = randint(10, 1000)
v1 = torch.randn(numel, dtype=dtype, device=device)
v2 = torch.randn(numel, dtype=dtype, device=device)
compare_with_numpy_bin_op(torch_fn, np_fn, v1, v2)
compare_out_variant(torch_fn, v1, v2)

# Test 0-strided
v3 = torch.randn(1, dtype=dtype, device=device).expand(numel)
compare_with_numpy_bin_op(torch_fn, np_fn, v1, v3)
compare_out_variant(torch_fn, v1, v3)

compare_with_numpy_bin_op(torch_fn, np_fn, v3, v1)
compare_out_variant(torch_fn, v3, v1)

# Test stride greater than 1
v4 = torch.randn(numel, numel, dtype=dtype, device=device)[:, numel - 1]
compare_with_numpy_bin_op(torch_fn, np_fn, v1, v4)
compare_out_variant(torch_fn, v1, v4)

compare_with_numpy_bin_op(torch_fn, np_fn, v4, v1)
compare_out_variant(torch_fn, v4, v1)

@precisionOverride({torch.cfloat: 1e-4, torch.float32: 5e-5})
@dtypes(torch.float, torch.double, torch.cfloat, torch.cdouble)
def test_dot_vs_numpy(self, device, dtype):
self._test_dot_vdot_vs_numpy(device, dtype, torch.dot, np.dot)

@precisionOverride({torch.cfloat: 1e-4, torch.float32: 5e-5})
@dtypes(torch.float, torch.double, torch.cfloat, torch.cdouble)
def test_vdot_vs_numpy(self, device, dtype):
self._test_dot_vdot_vs_numpy(device, dtype, torch.vdot, np.vdot)

def _test_dot_vdot_invalid_args(self, device, torch_fn, complex_dtypes=False):
if complex_dtypes:
x = torch.randn(1, dtype=torch.cfloat, device=device)
y = torch.randn(3, dtype=torch.cdouble, device=device)
else:
x = torch.randn(1, dtype=torch.float, device=device)
y = torch.randn(3, dtype=torch.double, device=device)

with self.assertRaisesRegex(RuntimeError,
'dot : expected both vectors to have same dtype'):
torch_fn(x, y)

with self.assertRaisesRegex(RuntimeError,
'1D tensors expected'):
torch_fn(x.reshape(1, 1), y)

with self.assertRaisesRegex(RuntimeError,
'inconsistent tensor size'):
torch_fn(x.expand(9), y.to(x.dtype))

if self.device_type != 'cpu':
x_cpu = x.expand(3).cpu()

with self.assertRaisesRegex(RuntimeError,
'expected all tensors to be on the same device'):
torch_fn(x_cpu, y.to(x.dtype))

@onlyOnCPUAndCUDA
def test_vdot_invalid_args(self, device):
self._test_dot_vdot_invalid_args(device, torch.vdot)
self._test_dot_vdot_invalid_args(device, torch.vdot, complex_dtypes=True)

@onlyOnCPUAndCUDA
def test_dot_invalid_args(self, device):
self._test_dot_vdot_invalid_args(device, torch.dot)
self._test_dot_vdot_invalid_args(device, torch.dot, complex_dtypes=True)

instantiate_device_type_tests(TestLinalg, globals())

if __name__ == '__main__':
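For readers skimming the new helper: the "0-strided" and "stride greater than 1" cases above rely on two standard stride tricks, expanding a single-element tensor and slicing a column out of a 2D tensor. A quick illustrative sketch (not part of the diff) of what those views look like:

```python
import torch

# Expanding a 1-element tensor gives a 0-strided view:
# every element aliases the same storage location.
v = torch.randn(1).expand(5)
print(v.stride())    # (0,)

# Slicing a column out of a contiguous 2D tensor gives a
# non-contiguous view whose stride equals the row length.
m = torch.randn(4, 4)
col = m[:, 3]
print(col.stride())  # (4,)
```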
84 changes: 0 additions & 84 deletions test/test_torch.py
@@ -17036,90 +17036,6 @@ def test_matmul_45724(self, device):
torch.matmul(a, b, out=c)
self.assertEqual(c, cpu_result)

def _test_dot_vdot_vs_numpy(self, device, dtype, torch_fn, np_fn):
def compare_with_numpy_bin_op(torch_fn, np_fn, x, y):
y_np = y.cpu().numpy()

# `compare_with_numpy` takes care of moving `x` to correct device for calling np_fn.
self.compare_with_numpy(lambda inp: torch_fn(inp, y), lambda inp: np_fn(inp, y_np), x)

# Use this tensor for out variant tests.
out = torch.randn((), dtype=dtype, device=device)

def compare_out_variant(torch_fn, x, y):
torch_fn(v1, v2, out=out)
self.assertEqual(torch_fn(v1, v2), out)

for _ in range(10):
numel = random.randint(10, 1000)
v1 = torch.randn(numel, dtype=dtype, device=device)
v2 = torch.randn(numel, dtype=dtype, device=device)
compare_with_numpy_bin_op(torch_fn, np_fn, v1, v2)
compare_out_variant(torch_fn, v1, v2)

# Test 0-strided
v3 = torch.randn(1, dtype=dtype, device=device).expand(numel)
compare_with_numpy_bin_op(torch_fn, np_fn, v1, v3)
compare_out_variant(torch_fn, v1, v3)

compare_with_numpy_bin_op(torch_fn, np_fn, v3, v1)
compare_out_variant(torch_fn, v3, v1)

# Test stride greater than 1
v4 = torch.randn(numel, numel, dtype=dtype, device=device)[:, numel - 1]
compare_with_numpy_bin_op(torch_fn, np_fn, v1, v4)
compare_out_variant(torch_fn, v1, v4)

compare_with_numpy_bin_op(torch_fn, np_fn, v4, v1)
compare_out_variant(torch_fn, v4, v1)

@precisionOverride({torch.cfloat: 1e-4, torch.float32: 5e-5})
@dtypes(torch.float, torch.double, torch.cfloat, torch.cdouble)
def test_dot_vs_numpy(self, device, dtype):
self._test_dot_vdot_vs_numpy(device, dtype, torch.dot, np.dot)

@precisionOverride({torch.cfloat: 1e-4, torch.float32: 5e-5})
@dtypes(torch.float, torch.double, torch.cfloat, torch.cdouble)
def test_vdot_vs_numpy(self, device, dtype):
self._test_dot_vdot_vs_numpy(device, dtype, torch.vdot, np.vdot)

def _test_dot_vdot_invalid_args(self, device, torch_fn, complex_dtypes=False):
if complex_dtypes:
x = torch.randn(1, dtype=torch.cfloat, device=device)
y = torch.randn(3, dtype=torch.cdouble, device=device)
else:
x = torch.randn(1, dtype=torch.float, device=device)
y = torch.randn(3, dtype=torch.double, device=device)

with self.assertRaisesRegex(RuntimeError,
'dot : expected both vectors to have same dtype'):
torch_fn(x, y)

with self.assertRaisesRegex(RuntimeError,
'1D tensors expected'):
torch_fn(x.reshape(1, 1), y)

with self.assertRaisesRegex(RuntimeError,
'inconsistent tensor size'):
torch_fn(x.expand(9), y.to(x.dtype))

if self.device_type != 'cpu':
x_cpu = x.expand(3).cpu()

with self.assertRaisesRegex(RuntimeError,
'expected all tensors to be on the same device'):
torch_fn(x_cpu, y.to(x.dtype))

@onlyOnCPUAndCUDA
def test_vdot_invalid_args(self, device):
self._test_dot_vdot_invalid_args(device, torch.vdot)
self._test_dot_vdot_invalid_args(device, torch.vdot, complex_dtypes=True)

@onlyOnCPUAndCUDA
def test_dot_invalid_args(self, device):
self._test_dot_vdot_invalid_args(device, torch.dot)
self._test_dot_vdot_invalid_args(device, torch.dot, complex_dtypes=True)

@onlyCPU
@slowTest
@dtypes(torch.float)
4 changes: 2 additions & 2 deletions torch/_tensor_docs.py
@@ -1210,7 +1210,7 @@ def add_docstr_all(method, docstr):

add_docstr_all('dot',
r"""
dot(tensor2) -> Tensor
dot(other) -> Tensor

See :func:`torch.dot`
""")
@@ -4004,7 +4004,7 @@ def callable(a, b) -> number

add_docstr_all('vdot',
r"""
dot(other) -> Tensor
vdot(other) -> Tensor

See :func:`torch.vdot`
""")
27 changes: 20 additions & 7 deletions torch/_torch_docs.py
@@ -2563,11 +2563,21 @@ def merge_dicts(*dicts):

add_docstr(torch.dot,
r"""
dot(input, tensor) -> Tensor
dot(input, other, *, out=None) -> Tensor

Computes the dot product (inner product) of two tensors.
Computes the dot product (inner product) of two 1D tensors.

.. note:: This function does not :ref:`broadcast <broadcasting-semantics>`.
.. note::

Unlike NumPy's dot, torch.dot intentionally only supports computing the dot product
of two 1D tensors with the same number of elements.

Args:
input (Tensor): first tensor in the dot product, must be 1D.
other (Tensor): second tensor in the dot product, must be 1D.

Keyword args:
{out}

Example::

@@ -2579,15 +2589,18 @@ def merge_dicts(*dicts):
r"""
vdot(input, other, *, out=None) -> Tensor

Computes the dot product (inner product) of two tensors. The vdot(a, b) function
Computes the dot product (inner product) of two 1D tensors. The vdot(a, b) function
handles complex numbers differently than dot(a, b). If the first argument is complex,
the complex conjugate of the first argument is used for the calculation of the dot product.

.. note:: This function does not :ref:`broadcast <broadcasting-semantics>`.
.. note::

Unlike NumPy's vdot, torch.vdot intentionally only supports computing the dot product
of two 1D tensors with the same number of elements.

Args:
input (Tensor): first tensor in the dot product. Its conjugate is used if it's complex.
other (Tensor): second tensor in the dot product.
input (Tensor): first tensor in the dot product, must be 1D. Its conjugate is used if it's complex.
other (Tensor): second tensor in the dot product, must be 1D.

Keyword args:
{out}
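As a reference for the documented semantics (an illustrative sketch, not taken from the diff; the printed values are worked out by hand):

```python
import torch

# Real 1D inputs: dot and vdot agree and reduce to the ordinary inner product.
a = torch.tensor([2.0, 3.0])
b = torch.tensor([2.0, 1.0])
print(torch.dot(a, b))    # tensor(7.)

# Complex 1D inputs: vdot conjugates its first argument, dot does not.
x = torch.tensor([1 + 2j, 3 - 1j])
y = torch.tensor([2 + 1j, 4 + 0j])
print(torch.dot(x, y))    # (1+2j)*(2+1j) + (3-1j)*4 = 12+1j
print(torch.vdot(x, y))   # (1-2j)*(2+1j) + (3+1j)*4 = 16+1j

# Unlike np.dot, torch.dot does not fall back to matrix multiplication
# for 2D inputs; it raises "1D tensors expected" instead.
try:
    torch.dot(a.reshape(1, 2), b.reshape(2, 1))
except RuntimeError as err:
    print(err)
```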
4 changes: 2 additions & 2 deletions torch/overrides.py
@@ -320,7 +320,7 @@ def get_testing_overrides() -> Dict[Callable, Callable]:
torch.dist: lambda input, other, p=2: -1,
torch.div: lambda input, other, out=None: -1,
torch.divide: lambda input, other, out=None: -1,
torch.dot: lambda mat1, mat2: -1,
torch.dot: lambda input, other, out=None: -1,
torch.dropout: lambda input, p, train, inplace=False: -1,
torch.dsmm: lambda input, mat2: -1,
torch.hsmm: lambda mat1, mat2: -1,
@@ -678,7 +678,7 @@ def get_testing_overrides() -> Dict[Callable, Callable]:
torch.randn_like: lambda input, dtype=None, layout=None, device=None, requires_grad=False: -1,
torch.ravel: lambda input: -1,
torch.real: lambda input, out=None: -1,
torch.vdot: lambda mat1, mat2: -1,
torch.vdot: lambda input, other, out=None: -1,
torch.view_as_real: lambda input: -1,
torch.view_as_complex: lambda input: -1,
torch.reciprocal: lambda input, out=None: -1,
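One way to see why the two lambdas above were updated (an illustrative sketch, not code from the PR): the dummy callables returned by `torch.overrides.get_testing_overrides()` are expected to mirror the real signatures, which now accept an `out=` keyword.

```python
import inspect
import torch
from torch.overrides import get_testing_overrides

# The dummy lambdas stand in for the real functions in the
# __torch_function__ override tests, so their signatures should
# track the public API, including the out keyword argument.
overrides = get_testing_overrides()
print(inspect.signature(overrides[torch.dot]))   # (input, other, out=None)
print(inspect.signature(overrides[torch.vdot]))  # (input, other, out=None)
```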