Check inputs have same dtype in addmm_impl_cpu_ even if input has zero numel (#100274)

Fixes #99226

When an input has zero numel, addmm_impl_cpu_'s check that the inputs have the same dtype is bypassed. This PR adds a check before the early return.
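For illustration (not part of the commit), a minimal Python sketch of the behavior the new check enforces, assuming a build that includes this change; the shapes and dtypes are chosen so one operand has zero numel, which previously skipped the dtype comparison:

import torch

# m1 has zero numel (0 rows), so the old CPU path hit an early return before
# any dtype comparison; the dtypes still disagree (int16 vs. float32).
a = torch.zeros(0, 10, dtype=torch.int16)
b = torch.randn(10, 20, dtype=torch.float32)

try:
    torch.mm(a, b)  # with this check in place: RuntimeError about mismatched dtypes
except RuntimeError as e:
    print(e)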
Pull Request resolved: #100274
Approved by: https://github.com/ngimel
soulitzer authored and pytorchmergebot committed Apr 29, 2023
1 parent d7fa7fa commit 6a02342
Showing 3 changed files with 16 additions and 0 deletions.
5 changes: 5 additions & 0 deletions aten/src/ATen/native/LinearAlgebra.cpp
@@ -1312,6 +1312,11 @@ Tensor outer(const Tensor& self, const Tensor& vec2) {
 static void addmm_impl_cpu_(
     Tensor &result, const Tensor &self, Tensor m1, Tensor m2, const Scalar& beta, const Scalar& alpha) {
   TORCH_INTERNAL_ASSERT(self.dim() == 2 && m1.dim() == 2 && m2.dim() == 2);
+
+  TORCH_CHECK(
+    m1.dtype() == m2.dtype(),
+    "expected m1 and m2 to have the same dtype, but got: ", m1.dtype(), " != ", m2.dtype()
+  )
   // Array access is faster than .size(n) and .stride(n)
   const auto self_sizes = self.sizes();
   auto m1_strides = m1.strides();
4 changes: 4 additions & 0 deletions aten/src/ATen/native/cuda/Blas.cpp
@@ -151,6 +151,10 @@ Tensor& addmm_out_cuda_impl(Tensor& result, const Tensor& self, const Tensor& ma
   // preflights a check to try to avoid actually needing to call
   // expand().
   TORCH_CHECK(mat1.dim() == 2 && mat2.dim() == 2, "tensors must be 2-D");
+  TORCH_CHECK(
+    mat1.dtype() == mat2.dtype(),
+    "expected mat1 and mat2 to have the same dtype, but got: ", mat1.dtype(), " != ", mat2.dtype()
+  )
 
   TensorArg args[]{{result, "out", 0}, {self, "self", 1}, {mat1, "mat1", 2}, {mat2, "mat2", 3}};
   checkAllSameGPU(__func__, args);
7 changes: 7 additions & 0 deletions test/test_linalg.py
@@ -5832,6 +5832,13 @@ def t_b(tensor):
         self.assertEqual(torch.bmm(Ab_conj, Bb_, out=out_b), torch.bmm(Ab_conj_physical, Bb_, out=out_b))
         self.assertEqual(torch.bmm(t_b(Ab_conj), Bb_, out=out_b), torch.bmm(t_b(Ab_conj_physical), Bb_, out=out_b))
 
+    @onlyNativeDeviceTypes
+    def test_mm_empty_inputs_mixed_dtype_errors(self, device):
+        a = torch.randint(0, 10, [1, 10], dtype=torch.int16, device=device)
+        b = torch.randn(10, 20, dtype=torch.float32, device=device)
+        with self.assertRaisesRegex(RuntimeError, "expected .* and .* to have the same dtype, but got:"):
+            torch.mm(a, b)
+
     @onlyNativeDeviceTypes
     @dtypes(torch.float32, torch.float64)
     def test_strided_mm_bmm(self, device, dtype):
