Skip to content

Commit

Permalink
Fix transpose and batching bugs
Browse files Browse the repository at this point in the history
  • Loading branch information
wesselb committed May 21, 2019
1 parent 155ee95 commit 453e2db
Show file tree
Hide file tree
Showing 7 changed files with 129 additions and 32 deletions.
14 changes: 9 additions & 5 deletions lab/autograd/linear_algebra.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,15 +12,16 @@
from ..custom import toeplitz_solve, s_toeplitz_solve
from ..linear_algebra import _default_perm
from ..types import NPNumeric
from ..util import batch_computation

__all__ = []
log = logging.getLogger(__name__)


@dispatch(NPNumeric, NPNumeric)
def matmul(a, b, tr_a=False, tr_b=False):
a = a.T if tr_a else a
b = b.T if tr_b else b
a = transpose(a) if tr_a else a
b = transpose(b) if tr_b else b
return anp.matmul(a, b)


Expand Down Expand Up @@ -66,7 +67,7 @@ def kron(a, b):
@dispatch(NPNumeric)
def svd(a, compute_uv=True):
res = anp.linalg.svd(a, full_matrices=False, compute_uv=compute_uv)
return (res[0], res[1], res[2].T.conj()) if compute_uv else res
return (res[0], res[1], transpose(res[2]).conj()) if compute_uv else res


@dispatch(NPNumeric, NPNumeric)
Expand Down Expand Up @@ -96,12 +97,15 @@ def cholesky(a):

@dispatch(NPNumeric, NPNumeric)
def cholesky_solve(a, b):
return triangular_solve(a.T, triangular_solve(a, b), lower_a=False)
return triangular_solve(transpose(a), triangular_solve(a, b), lower_a=False)


@dispatch(NPNumeric, NPNumeric)
def triangular_solve(a, b, lower_a=True):
return asla.solve_triangular(a, b, trans='N', lower=lower_a)
def _triangular_solve(a_, b_):
return asla.solve_triangular(a_, b_, trans='N', lower=lower_a)

return batch_computation(_triangular_solve, a, b)


f = autograd_register(toeplitz_solve, s_toeplitz_solve)
Expand Down
3 changes: 3 additions & 0 deletions lab/linear_algebra.py
Original file line number Diff line number Diff line change
Expand Up @@ -117,6 +117,9 @@ def kron(a, b): # pragma: no cover
def svd(a, compute_uv=True): # pragma: no cover
"""Compute the singular value decomposition.
Note:
PyTorch does not allow batch computation.
Args:
a (tensor): Matrix to decompose.
compute_uv (bool, optional): Also compute `U` and `V`. Defaults to
Expand Down
9 changes: 5 additions & 4 deletions lab/torch/linear_algebra.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,14 +9,15 @@
from ..custom import toeplitz_solve, s_toeplitz_solve
from ..linear_algebra import _default_perm
from ..types import TorchNumeric
from ..util import batch_computation

__all__ = []


@dispatch(TorchNumeric, TorchNumeric)
def matmul(a, b, tr_a=False, tr_b=False):
a = a.t() if tr_a else a
b = b.t() if tr_b else b
a = transpose(a) if tr_a else a
b = transpose(b) if tr_b else b
return torch.matmul(a, b)


Expand Down Expand Up @@ -71,12 +72,12 @@ def inv(a):

@dispatch(TorchNumeric)
def det(a):
return torch.det(a)
return batch_computation(torch.det, a)


@dispatch(TorchNumeric)
def logdet(a):
return torch.logdet(a)
return batch_computation(torch.logdet, a)


@dispatch(TorchNumeric)
Expand Down
34 changes: 32 additions & 2 deletions lab/util.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,13 +2,43 @@

from __future__ import absolute_import, division, print_function

from functools import wraps
from functools import wraps, reduce
from operator import mul

import plum

from . import B

__all__ = ['abstract']
__all__ = ['batch_computation', 'abstract']


def batch_computation(f, *xs):
    """Apply a function over all the batches of the arguments, where the
    arguments are assumed to be matrices or batches of matrices.

    Args:
        f (function): Function to apply to every batch. It is called with one
            matrix per argument in `xs`, taken from the trailing two
            dimensions of that argument.
        *xs (tensor): Matrices or batches of matrices. All arguments must
            have the same batch shape, i.e. the same leading dimensions.

    Returns:
        tensor: Result in batched form.

    Raises:
        ValueError: If no arguments are given or if the batch shapes of the
            arguments are inconsistent.
    """
    # Without any argument there is no batch shape to iterate over.
    if not xs:
        raise ValueError('At least one argument is required.')

    # Check that all batch shapes are the same before doing any other work.
    batch_shapes = [B.shape(x)[:-2] for x in xs]
    batch_shape = batch_shapes[0]
    if not all(s == batch_shape for s in batch_shapes[1:]):
        raise ValueError('Inconsistent batch shapes.')

    # Collapse all batch dimensions into a single leading dimension.
    xs = [B.reshape(x, -1, *B.shape(x)[-2:]) for x in xs]

    # Loop over batches. The initial value of one makes the product of an
    # empty batch shape equal to one, so plain matrices are handled too.
    num_batches = reduce(mul, batch_shape, 1)
    batches = [f(*[x[i, :, :] for x in xs]) for i in range(num_batches)]

    # Construct result, restore the original batch dimensions, and return.
    res = B.stack(*batches, axis=0)
    return B.reshape(res, *(batch_shape + B.shape(res)[1:]))


def abstract(promote=-1):
Expand Down
40 changes: 29 additions & 11 deletions tests/test_linear_algebra.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
import numpy as np

import lab as B
from . import check_function, Matrix, Bool, Value, PSD, Tensor
from . import check_function, Matrix, Bool, Value, PSD, Tensor, PSDTriangular
# noinspection PyUnresolvedReferences
from . import eq, neq, lt, le, ge, gt, raises, call, ok, allclose, approx, is_

Expand Down Expand Up @@ -48,7 +48,10 @@ def test_transpose():
def test_matmul():
for f in [B.matmul, B.mm, B.dot]:
yield check_function, f, \
(Matrix(), Matrix()), {'tr_a': Bool(), 'tr_b': Bool()}
(Tensor(3, 3), Tensor(3, 3)), {'tr_a': Bool(), 'tr_b': Bool()}
yield check_function, f, \
(Tensor(4, 3, 3), Tensor(4, 3, 3)), \
{'tr_a': Bool(), 'tr_b': Bool()}


def test_trace():
Expand All @@ -70,6 +73,8 @@ def test_trace():

def test_kron():
yield check_function, B.kron, (Tensor(2, 3), Tensor(4, 5))
# Cannot test tensors of higher rank, because TensorFlow's broadcasting
# behaviour does not allow that.
yield raises, ValueError, \
lambda: B.kron(Tensor(2).tf(), Tensor(4, 5).tf())
yield raises, ValueError, \
Expand All @@ -85,49 +90,62 @@ def svd(a, compute_uv=True):
else:
return B.svd(a, compute_uv=False)

yield check_function, svd, (Matrix(3, 2),), {'compute_uv': Bool()}
yield check_function, svd, (Tensor(3, 2),), {'compute_uv': Bool()}
# Torch does not allow batch computation.


def test_solve():
yield check_function, B.solve, (Matrix(), Matrix())
yield check_function, B.solve, (Matrix(3, 3), Matrix(3, 4))
yield check_function, B.solve, (Matrix(5, 3, 3), Matrix(5, 3, 4))


def test_inv():
yield check_function, B.inv, (Matrix(),)
yield check_function, B.inv, (Matrix(4, 3, 3),)


def test_det():
yield check_function, B.det, (Matrix(),)
yield check_function, B.det, (Matrix(4, 3, 3),)


def test_logdet():
yield check_function, B.logdet, (PSD(),)
yield check_function, B.logdet, (PSD(4, 3, 3),)


def test_cholesky():
for f in [B.cholesky, B.chol]:
yield check_function, f, (PSD(),)
yield check_function, f, (PSD(4, 3, 3),)


def test_cholesky_solve():
chol = B.cholesky(PSD().np())
for f in [B.cholesky_solve, B.cholsolve]:
yield check_function, f, (Matrix(mat=chol), Matrix())
yield check_function, f, (PSDTriangular(3, 3), Matrix(3, 4))
yield check_function, f, (PSDTriangular(5, 3, 3), Matrix(5, 3, 4))


def test_triangular_solve():
chol = B.cholesky(PSD().np())
for f in [B.triangular_solve, B.trisolve]:
yield check_function, f, \
(Matrix(mat=chol), Matrix()), {'lower_a': Value(True)}
(PSDTriangular(3, 3), Matrix(3, 4)), \
{'lower_a': Value(True)}
yield check_function, f, \
(PSDTriangular(5, 3, 3), Matrix(5, 3, 4)), \
{'lower_a': Value(True)}
yield check_function, f, \
(PSDTriangular(3, 3, upper=True), Matrix(3, 4)), \
{'lower_a': Value(False)}
yield check_function, f, \
(Matrix(mat=chol.T), Matrix()), {'lower_a': Value(False)}
(PSDTriangular(5, 3, 3, upper=True), Matrix(5, 3, 4)), \
{'lower_a': Value(False)}


def test_toeplitz_solve():
for f in [B.toeplitz_solve, B.toepsolve]:
yield check_function, f, (Tensor(3), Matrix(3))
yield check_function, f, (Tensor(3), Matrix(3))
yield check_function, f, (Tensor(3), Tensor(2), Matrix(3, 4))
yield check_function, f, (Tensor(3), Matrix(3, 4))


def test_outer():
Expand Down
11 changes: 10 additions & 1 deletion tests/test_util.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,11 +6,20 @@
from plum import NotFoundLookupError

import lab as B
from lab.util import abstract
from lab.util import abstract, batch_computation
# noinspection PyUnresolvedReferences
from . import eq, neq, lt, le, ge, gt, raises, call, ok, allclose, approx


def test_batch_computation():
    """Check that mismatching batch shapes are rejected."""
    # Correctness is already checked by usage in linear algebra functions. Here
    # we test the check of batch shapes.
    mismatched_shapes = [
        ((3, 4, 4), (2, 4, 4)),
        ((2, 2, 4, 4), (2, 4, 4)),
    ]
    for shape_a, shape_b in mismatched_shapes:
        # Bind the shapes as defaults so that each yielded callable keeps its
        # own pair, regardless of when the test runner invokes it.
        yield raises, ValueError, \
            lambda sa=shape_a, sb=shape_b: \
            batch_computation(None, B.randn(*sa), B.randn(*sb))


def test_metadata():
# Test that the name and docstrings for functions are available.
yield eq, B.transpose.__name__, 'transpose'
Expand Down
50 changes: 41 additions & 9 deletions tests/util.py
Original file line number Diff line number Diff line change
Expand Up @@ -210,18 +210,50 @@ def __init__(self, *dims, **kw_args):
class Matrix(Tensor):
"""Matrix placeholder."""

def __init__(self, rows=3, cols=None, mat=None):
# Default the number of columns to the number of rows.
cols = rows if cols is None else cols
Tensor.__init__(self, rows, cols, mat=mat)
def __init__(self, *shape, **kw_args):
# Handle shorthands.
if shape == ():
shape = (3, 3)
elif len(shape) == 1:
shape = shape * 2

Tensor.__init__(self, *shape, **kw_args)

class PSD(Matrix):
"""Positive-definite matrix placeholder."""

def __init__(self, rows=3):
a = np.random.randn(rows, rows)
Matrix.__init__(self, mat=np.matmul(a, np.transpose(a)))
class PSD(Matrix):
"""Positive-definite tensor placeholder."""

def __init__(self, *shape):
# Handle shorthands.
if shape == ():
shape = (3, 3)
elif len(shape) == 1:
shape = shape * 2

if not shape[-2] == shape[-1]:
raise ValueError('PSD matrix must be square.')

a = np.random.randn(*shape)
perm = list(range(len(a.shape)))
perm[-2], perm[-1] = perm[-1], perm[-2]
a_t = np.transpose(a, perm)
Matrix.__init__(self, mat=np.matmul(a, a_t))


class PSDTriangular(PSD):
    """Triangular matrix placeholder, obtained by zeroing one triangle of the
    matrices generated by `PSD`.

    Args:
        *shape (int): Shape of the tensor, forwarded to `PSD`.
        upper (bool, optional): Produce upper-triangular instead of
            lower-triangular matrices. Defaults to `False`.
    """

    def __init__(self, *shape, **kw_args):
        PSD.__init__(self, *shape)

        # Zero upper triangular part. The matrices live in the last two
        # dimensions; any leading dimensions are batch dimensions and must
        # not determine the loop bounds.
        rows, cols = self.mat.shape[-2:]
        for i in range(rows):
            for j in range(i + 1, cols):
                self.mat[..., i, j] = 0

        # Create upper-triangular matrices, if asked for, by swapping the
        # last two dimensions.
        if kw_args.get('upper', False):
            perm = list(range(len(self.mat.shape)))
            perm[-2], perm[-1] = perm[-1], perm[-2]
            self.mat = np.transpose(self.mat, perm)


class Tuple(object):
Expand Down

0 comments on commit 453e2db

Please sign in to comment.