Skip to content

Commit

Permalink
Allow for 1D inputs to pw and ew
Browse files Browse the repository at this point in the history
  • Loading branch information
wesselb committed May 19, 2019
1 parent efe3071 commit 3eaf9c4
Show file tree
Hide file tree
Showing 2 changed files with 52 additions and 16 deletions.
28 changes: 28 additions & 0 deletions lab/linear_algebra.py
Original file line number Diff line number Diff line change
Expand Up @@ -321,9 +321,13 @@ def pw_dists2(a, b):
matrix: Square of the Euclidean norm of the pairwise differences
between the elements of `a` and `b`.
"""
a = B.uprank(a)
b = B.uprank(b)

# Optimise the one-dimensional case.
if B.shape(a)[1] == 1 and B.shape(b)[1] == 1:
return (a - B.transpose(b)) ** 2

norms_a = B.sum(a ** 2, axis=1)[:, None]
norms_b = B.sum(b ** 2, axis=1)[None, :]
return norms_a + norms_b - 2 * B.matmul(a, b, tr_b=True)
Expand All @@ -347,9 +351,13 @@ def pw_dists(a, b):
matrix: Euclidean norm of the pairwise differences between the
elements of `a` and `b`.
"""
a = B.uprank(a)
b = B.uprank(b)

# Optimise the one-dimensional case.
if B.shape(a)[1] == 1 and B.shape(b)[1] == 1:
return B.abs(a - B.transpose(b))

return B.sqrt(B.maximum(B.pw_dists2(a, b),
B.cast(B.dtype(a), 1e-30)))

Expand All @@ -373,6 +381,8 @@ def ew_dists2(a, b):
matrix: Square of the Euclidean norm of the element-wise differences
between the elements of `a` and `b`.
"""
a = B.uprank(a)
b = B.uprank(b)
return B.sum((a - b) ** 2, axis=1)[:, None]


Expand All @@ -394,9 +404,13 @@ def ew_dists(a, b):
matrix: Euclidean norm of the element-wise differences between the
elements of `a` and `b`.
"""
a = B.uprank(a)
b = B.uprank(b)

# Optimise the one-dimensional case.
if B.shape(a)[1] == 1 and B.shape(b)[1] == 1:
return B.abs(a - b)

return B.sqrt(B.maximum(B.ew_dists2(a, b),
B.cast(B.dtype(a), 1e-30)))

Expand All @@ -420,9 +434,13 @@ def pw_sums2(a, b):
matrix: Square of the Euclidean norm of the pairwise sums
between the elements of `a` and `b`.
"""
a = B.uprank(a)
b = B.uprank(b)

# Optimise the one-dimensional case.
if B.shape(a)[1] == 1 and B.shape(b)[1] == 1:
return (a + B.transpose(b)) ** 2

norms_a = B.sum(a ** 2, axis=1)[:, None]
norms_b = B.sum(b ** 2, axis=1)[None, :]
return norms_a + norms_b + 2 * B.matmul(a, b, tr_b=True)
Expand All @@ -446,9 +464,13 @@ def pw_sums(a, b):
matrix: Euclidean norm of the pairwise sums between the
elements of `a` and `b`.
"""
a = B.uprank(a)
b = B.uprank(b)

# Optimise the one-dimensional case.
if B.shape(a)[1] == 1 and B.shape(b)[1] == 1:
return B.abs(a + B.transpose(b))

return B.sqrt(B.maximum(B.pw_sums2(a, b),
B.cast(B.dtype(a), 1e-30)))

Expand All @@ -472,6 +494,8 @@ def ew_sums2(a, b):
matrix: Square of the Euclidean norm of the element-wise sums
between the elements of `a` and `b`.
"""
a = B.uprank(a)
b = B.uprank(b)
return B.sum((a + b) ** 2, axis=1)[:, None]


Expand All @@ -493,9 +517,13 @@ def ew_sums(a, b):
matrix: Euclidean norm of the element-wise sums between the
elements of `a` and `b`.
"""
a = B.uprank(a)
b = B.uprank(b)

# Optimise the one-dimensional case.
if B.shape(a)[1] == 1 and B.shape(b)[1] == 1:
return B.abs(a + b)

return B.sqrt(B.maximum(B.ew_sums2(a, b),
B.cast(B.dtype(a), 1e-30)))

Expand Down
40 changes: 24 additions & 16 deletions tests/test_linear_algebra.py
Original file line number Diff line number Diff line change
Expand Up @@ -181,14 +181,18 @@ def approx_allclose(a, b):

def test_pw_1d():
a, b = Matrix(5, 1).np(), Matrix(10, 1).np()
yield allclose, B.pw_dists2(a, b), np.abs(a - b.T) ** 2
yield allclose, B.pw_dists2(a), np.abs(a - a.T) ** 2
yield allclose, B.pw_dists(a, b), np.abs(a - b.T)
yield allclose, B.pw_dists(a), np.abs(a - a.T)
yield allclose, B.pw_sums2(a, b), np.abs(a + b.T) ** 2
yield allclose, B.pw_sums2(a), np.abs(a + a.T) ** 2
yield allclose, B.pw_sums(a, b), np.abs(a + b.T)
yield allclose, B.pw_sums(a), np.abs(a + a.T)

# Check that we can feed both rank 1 and rank 2 tensors.
for f, g in product(*([[lambda x: x, lambda x: x[:, 0]]] * 2)):

yield allclose, B.pw_dists2(f(a), g(b)), np.abs(a - b.T) ** 2
yield allclose, B.pw_dists2(f(a)), np.abs(a - a.T) ** 2
yield allclose, B.pw_dists(f(a), g(b)), np.abs(a - b.T)
yield allclose, B.pw_dists(f(a)), np.abs(a - a.T)
yield allclose, B.pw_sums2(f(a), g(b)), np.abs(a + b.T) ** 2
yield allclose, B.pw_sums2(f(a)), np.abs(a + a.T) ** 2
yield allclose, B.pw_sums(f(a), g(b)), np.abs(a + b.T)
yield allclose, B.pw_sums(f(a)), np.abs(a + a.T)


def test_ew_2d():
Expand All @@ -213,11 +217,15 @@ def test_ew_2d():

def test_ew_1d():
a, b = Matrix(10, 1).np(), Matrix(10, 1).np()
yield allclose, B.ew_dists2(a, b), np.abs(a - b) ** 2
yield allclose, B.ew_dists2(a), np.zeros((10, 1))
yield allclose, B.ew_dists(a, b), np.abs(a - b)
yield allclose, B.ew_dists(a), np.zeros((10, 1))
yield allclose, B.ew_sums2(a, b), np.abs(a + b) ** 2
yield allclose, B.ew_sums2(a), np.abs(a + a) ** 2
yield allclose, B.ew_sums(a, b), np.abs(a + b)
yield allclose, B.ew_sums(a), np.abs(a + a)

# Check that we can feed both rank 1 and rank 2 tensors.
for f, g in product(*([[lambda x: x, lambda x: x[:, 0]]] * 2)):

yield allclose, B.ew_dists2(f(a), g(b)), np.abs(a - b) ** 2
yield allclose, B.ew_dists2(f(a)), np.zeros((10, 1))
yield allclose, B.ew_dists(f(a), g(b)), np.abs(a - b)
yield allclose, B.ew_dists(f(a)), np.zeros((10, 1))
yield allclose, B.ew_sums2(f(a), g(b)), np.abs(a + b) ** 2
yield allclose, B.ew_sums2(f(a)), np.abs(a + a) ** 2
yield allclose, B.ew_sums(f(a), g(b)), np.abs(a + b)
yield allclose, B.ew_sums(f(a)), np.abs(a + a)

0 comments on commit 3eaf9c4

Please sign in to comment.