diff --git a/sparse/coo/common.py b/sparse/coo/common.py index 741450b8..f189f581 100644 --- a/sparse/coo/common.py +++ b/sparse/coo/common.py @@ -184,10 +184,17 @@ def matmul(a, b): "Cannot perform dot product on types %s, %s" % (type(a), type(b))) - # When one of the input is less than 2-d, it is equivalent to dot - if a.ndim <= 2 or b.ndim <= 2: + # When b is 2-d, it is equivalent to dot + if b.ndim <= 2: return dot(a, b) + # when a is 2-d, we need to transpose result after dot + if a.ndim <= 2: + res = dot(a, b) + axes = list(range(res.ndim)) + axes.insert(-1, axes.pop(0)) + return res.transpose(axes) + # If a can be squeeze to a vector, use dot will be faster if a.ndim <= b.ndim and np.prod(a.shape[:-1]) == 1: res = dot(a.reshape(-1), b) diff --git a/sparse/tests/test_coo.py b/sparse/tests/test_coo.py index 40de0e5d..79c70b6e 100644 --- a/sparse/tests/test_coo.py +++ b/sparse/tests/test_coo.py @@ -283,6 +283,7 @@ def test_tensordot(a_shape, b_shape, axes): ((5,), (5, 6)), ((4, 5), (5,)), ((5,), (5,)), + ((3, 4), (1, 2, 4, 3)), ]) def test_matmul(a_shape, b_shape): sa = sparse.random(a_shape, density=0.5)