From c0cfe799e6fcacf7773eb08532de7a7141d3584a Mon Sep 17 00:00:00 2001 From: Liyu Gong Date: Tue, 11 Dec 2018 12:22:41 -0500 Subject: [PATCH 1/3] Fix matmul shape when a.ndim <= 2 --- sparse/coo/common.py | 8 +++++++- sparse/tests/test_coo.py | 1 + 2 files changed, 8 insertions(+), 1 deletion(-) diff --git a/sparse/coo/common.py b/sparse/coo/common.py index 741450b8..3d2ce9d7 100644 --- a/sparse/coo/common.py +++ b/sparse/coo/common.py @@ -185,9 +185,15 @@ def matmul(a, b): (type(a), type(b))) # When one of the input is less than 2-d, it is equivalent to dot - if a.ndim <= 2 or b.ndim <= 2: + if b.ndim <= 2: return dot(a, b) + if a.ndim <= 2: + res = dot(a, b) + axes = list(range(res.ndim)) + axes.insert(-1, axes.pop(0)) + return res.transpose(axes) + # If a can be squeeze to a vector, use dot will be faster if a.ndim <= b.ndim and np.prod(a.shape[:-1]) == 1: res = dot(a.reshape(-1), b) diff --git a/sparse/tests/test_coo.py b/sparse/tests/test_coo.py index 40de0e5d..79c70b6e 100644 --- a/sparse/tests/test_coo.py +++ b/sparse/tests/test_coo.py @@ -283,6 +283,7 @@ def test_tensordot(a_shape, b_shape, axes): ((5,), (5, 6)), ((4, 5), (5,)), ((5,), (5,)), + ((3, 4), (1, 2, 4, 3)), ]) def test_matmul(a_shape, b_shape): sa = sparse.random(a_shape, density=0.5) From d251ea5cdf664eac50c406096f601909476627e0 Mon Sep 17 00:00:00 2001 From: Liyu Gong Date: Tue, 11 Dec 2018 12:25:00 -0500 Subject: [PATCH 2/3] Modify comment --- sparse/coo/common.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/sparse/coo/common.py b/sparse/coo/common.py index 3d2ce9d7..a62ce483 100644 --- a/sparse/coo/common.py +++ b/sparse/coo/common.py @@ -184,10 +184,11 @@ def matmul(a, b): "Cannot perform dot product on types %s, %s" % (type(a), type(b))) - # When one of the input is less than 2-d, it is equivalent to dot + # When b is 2-d, it is equivalent to dot if b.ndim <= 2: return dot(a, b) + # when a is 2-d, we need to transpose result after dot if a.ndim <= 2: res = dot(a, b) axes = list(range(res.ndim)) From 5cc73cc4cfee1556c0a281c109c3dfbb3faf0ca2 Mon Sep 17 00:00:00 2001 From: Liyu Gong Date: Tue, 11 Dec 2018 12:35:28 -0500 Subject: [PATCH 3/3] Fix a flake8 problem --- sparse/coo/common.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/sparse/coo/common.py b/sparse/coo/common.py index a62ce483..f189f581 100644 --- a/sparse/coo/common.py +++ b/sparse/coo/common.py @@ -194,7 +194,7 @@ def matmul(a, b): axes = list(range(res.ndim)) axes.insert(-1, axes.pop(0)) return res.transpose(axes) - + # If a can be squeeze to a vector, use dot will be faster if a.ndim <= b.ndim and np.prod(a.shape[:-1]) == 1: res = dot(a.reshape(-1), b)