Commit

Remove shape_int
wesselb committed May 19, 2019
1 parent 3eaf9c4 commit 7522962
Showing 12 changed files with 74 additions and 50 deletions.
1 change: 0 additions & 1 deletion README.md
@@ -315,7 +315,6 @@ randn(ref)
### Shaping
```
shape(a)
-shape_int(a)
rank(a)
length(a) (alias: size)
isscalar(a)
```
49 changes: 49 additions & 0 deletions benchmark.py
@@ -0,0 +1,49 @@
from time import time

import autograd.numpy as np

import lab as B

n = 20
m = 1

t = np.float64
eps = B.cast(t, B.epsilon)


def f1(x):
    dists2 = (x - B.transpose(x)) ** 2
    K = B.exp(-0.5 * dists2)
    K = K + B.epsilon * B.eye(t, n)
    L = B.cholesky(K)
    return B.matmul(L, B.ones(t, n, m))


def f2(x):
    dists2 = (x - np.transpose(x)) ** 2
    K = np.exp(-0.5 * dists2)
    K = K + B.epsilon * np.eye(n, dtype=t)
    L = np.linalg.cholesky(K)
    return np.matmul(L, np.ones((n, m)))


# Perform computation once.
x = np.linspace(0, 1, n, dtype=t)[:, None]
f1(x)
f2(x)

its = 10000

s = time()
for _ in range(its):
    z = f2(x)
us_native = (time() - s) / its * 1e6

s = time()
for _ in range(its):
    z = f1(x)
us_lab = (time() - s) / its * 1e6

print('Overhead: {:.1f} us / {:.1f} %'
      ''.format(us_lab - us_native,
                100 * (us_lab / us_native - 1)))
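Run as `python benchmark.py`: after one warm-up call of each variant, both functions are timed over 10,000 iterations, and the script prints lab's per-call overhead relative to the native autograd/NumPy implementation, in microseconds and as a percentage.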
2 changes: 1 addition & 1 deletion lab/autograd/shaping.py
@@ -45,7 +45,7 @@ def vec_to_tril(a):
def tril_to_vec(a):
    if B.rank(a) != 2:
        raise ValueError('Input must be rank 2.')
-    n, m = B.shape_int(a)
+    n, m = B.shape(a)
    if n != m:
        raise ValueError('Input must be square.')
    return a[anp.tril_indices(n)]
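For intuition on `tril_to_vec` above: `np.tril_indices` enumerates the lower triangle row by row. A minimal NumPy sketch of the same indexing, separate from the library code:

```python
import numpy as np

a = np.arange(9.).reshape(3, 3)
n = a.shape[0]
# np.tril_indices(3) yields rows (0, 1, 1, 2, 2, 2) and
# columns (0, 0, 1, 0, 1, 2): the lower triangle, row by row.
print(a[np.tril_indices(n)])  # [0. 3. 4. 6. 7. 8.]
```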
16 changes: 1 addition & 15 deletions lab/shaping.py
@@ -7,7 +7,6 @@
from .util import abstract

__all__ = ['shape',
-           'shape_int',
           'rank',
           'length', 'size',
           'isscalar',
@@ -43,19 +42,6 @@ def shape(a): # pragma: no cover
    return ()


-@dispatch(Numeric)
-def shape_int(a):  # pragma: no cover
-    """Get the shape of a tensor as a tuple of integers.
-
-    Args:
-        a (tensor): Tensor.
-
-    Returns:
-        tuple: Shape of `a` as a tuple of integers.
-    """
-    return shape(a)


@dispatch(Numeric)
def rank(a):  # pragma: no cover
    """Get the rank of a tensor.
@@ -173,7 +159,7 @@ def flatten(a): # pragma: no cover


def _vec_to_tril_shape(a):
-    n = shape_int(a)[0]
+    n = int(shape(a)[0])  # Dimensions are not necessarily integers!
    return int(((1 + 8 * n) ** .5 - 1) / 2)
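An aside on the formula above: a vector of length n fills the lower triangle of an m×m matrix exactly when n = m(m+1)/2, so the side length is recovered as m = (sqrt(1 + 8n) - 1)/2. A quick sketch verifying the round trip; the helper name is ours, purely illustrative:

```python
# Illustrative only (not part of lab): invert n = m * (m + 1) / 2.
def side_from_tril_length(n):
    return int(((1 + 8 * n) ** .5 - 1) / 2)

for m in range(1, 10):
    n = m * (m + 1) // 2  # Length of the vectorised lower triangle.
    assert side_from_tril_length(n) == m
```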


4 changes: 2 additions & 2 deletions lab/tensorflow/linear_algebra.py
@@ -42,8 +42,8 @@ def trace(a, axis1=0, axis2=1):

@dispatch(TFNumeric, TFNumeric)
def kron(a, b):
-    shape_a = B.shape_int(a)
-    shape_b = B.shape_int(b)
+    shape_a = B.shape(a)
+    shape_b = B.shape(b)

    # Check that ranks are equal.
    if len(shape_a) != len(shape_b):
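For reference on the rank check in `kron`: the Kronecker product of an m×n matrix with a p×q matrix has shape (mp, nq), so the operands must have equal rank for the output shape to be well defined. A quick NumPy illustration, not the TensorFlow code path:

```python
import numpy as np

a = np.ones((2, 3))
b = np.ones((4, 5))
# Output dimensions are the element-wise products of the input shapes.
print(np.kron(a, b).shape)  # (8, 15)
```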
9 changes: 2 additions & 7 deletions lab/tensorflow/shaping.py
@@ -12,11 +12,6 @@
__all__ = []


-@dispatch(TFNumeric)
-def shape_int(a):
-    return tuple(B.shape(a).as_list())


@dispatch(TFNumeric)
def length(a):
    return tf.size(a)
@@ -56,10 +51,10 @@ def vec_to_tril(a):
def tril_to_vec(a):
    if B.rank(a) != 2:
        raise ValueError('Input must be rank 2.')
-    n, m = shape_int(a)
+    n, m = B.shape(a)
    if n != m:
        raise ValueError('Input must be square.')
-    return tf.gather_nd(a, list(zip(*np.tril_indices(n))))
+    return tf.gather_nd(a, list(zip(*np.tril_indices(int(n)))))


@dispatch([TFNumeric])
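The `int(n)` cast in the new line matters because, under TF 1.x graph mode, entries of a tensor's static shape are `tf.Dimension` objects rather than plain Python integers, which is exactly the pitfall flagged by the comment in `lab/shaping.py`. A minimal sketch of the cast pattern, assuming TF 1.x semantics:

```python
import numpy as np
import tensorflow as tf  # Assumes TF 1.x graph-mode semantics.

a = tf.placeholder(tf.float64, shape=(3, 3))
n = a.shape[0]                 # A tf.Dimension under TF 1.x, not an int.
idx = np.tril_indices(int(n))  # Cast to a plain int before handing to NumPy.
```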
4 changes: 2 additions & 2 deletions lab/torch/linear_algebra.py
@@ -41,8 +41,8 @@ def trace(a, axis1=0, axis2=1):

@dispatch(TorchNumeric, TorchNumeric)
def kron(a, b):
-    shape_a = B.shape_int(a)
-    shape_b = B.shape_int(b)
+    shape_a = B.shape(a)
+    shape_b = B.shape(b)

    # Check that ranks are equal.
    if len(shape_a) != len(shape_b):
7 changes: 1 addition & 6 deletions lab/torch/shaping.py
@@ -12,11 +12,6 @@
__all__ = []


-@dispatch(TorchNumeric)
-def shape_int(a):
-    return tuple(B.shape(a))


@dispatch(TorchNumeric)
def length(a):
    return a.numel()
@@ -51,7 +46,7 @@ def vec_to_tril(a):
def tril_to_vec(a):
    if B.rank(a) != 2:
        raise ValueError('Input must be rank 2.')
-    n, m = shape_int(a)
+    n, m = B.shape(a)
    if n != m:
        raise ValueError('Input must be square.')
    return a[np.tril_indices(n)]
10 changes: 5 additions & 5 deletions tests/test_generic.py
@@ -37,8 +37,8 @@ def test_zeros_ones_eye():
Value(3))

    # Check shape of calls.
-    yield eq, B.shape_int(f(2)), (2, 2) if f is B.eye else (2,)
-    yield eq, B.shape_int(f(2, 3)), (2, 3)
+    yield eq, B.shape(f(2)), (2, 2) if f is B.eye else (2,)
+    yield eq, B.shape(f(2, 3)), (2, 3)

    # Check shape type of calls.
    yield eq, B.dtype(f(2)), B.default_dtype
@@ -50,9 +50,9 @@ def test_zeros_ones_eye():
    ref = B.randn(t1, 4, 5)

    # Check shape of calls.
-    yield eq, B.shape_int(f(t2, 2)), (2, 2) if f is B.eye else (2,)
-    yield eq, B.shape_int(f(t2, 2, 3)), (2, 3)
-    yield eq, B.shape_int(f(ref)), (4, 5)
+    yield eq, B.shape(f(t2, 2)), (2, 2) if f is B.eye else (2,)
+    yield eq, B.shape(f(t2, 2, 3)), (2, 3)
+    yield eq, B.shape(f(ref)), (4, 5)

    # Check shape type of calls.
    yield eq, B.dtype(f(t2, 2)), t2
2 changes: 1 addition & 1 deletion tests/test_linear_algebra.py
@@ -184,7 +184,7 @@ def test_pw_1d():

    # Check that we can feed both rank 1 and rank 2 tensors.
    for f, g in product(*([[lambda x: x, lambda x: x[:, 0]]] * 2)):

        yield allclose, B.pw_dists2(f(a), g(b)), np.abs(a - b.T) ** 2
        yield allclose, B.pw_dists2(f(a)), np.abs(a - a.T) ** 2
        yield allclose, B.pw_dists(f(a), g(b)), np.abs(a - b.T)
18 changes: 9 additions & 9 deletions tests/test_random.py
@@ -29,29 +29,29 @@ def test_random_generators():
    for f in [B.rand, B.randn]:
        # Test without specifying data type.
        yield deq, B.dtype(f()), B.default_dtype
-        yield allclose, B.shape_int(f()), (), False
+        yield eq, B.shape(f()), ()
        yield deq, B.dtype(f(2)), B.default_dtype
-        yield allclose, B.shape_int(f(2)), (2,)
+        yield allclose, B.shape(f(2)), (2,)
        yield deq, B.dtype(f(2, 3)), B.default_dtype
-        yield allclose, B.shape_int(f(2, 3)), (2, 3)
+        yield eq, B.shape(f(2, 3)), (2, 3)

        # Test with specifying data type.
        for t in [np.float32, tf.float32, torch.float32]:
            # Test direct specification.
            yield deq, B.dtype(f(t)), t
-            yield allclose, B.shape_int(f(t)), (), False
+            yield eq, B.shape(f(t)), ()
            yield deq, B.dtype(f(t, 2)), t
-            yield allclose, B.shape_int(f(t, 2)), (2,), False
+            yield eq, B.shape(f(t, 2)), (2,)
            yield deq, B.dtype(f(t, 2, 3)), t
-            yield allclose, B.shape_int(f(t, 2, 3)), (2, 3), False
+            yield eq, B.shape(f(t, 2, 3)), (2, 3)

            # Test reference specification.
            yield deq, B.dtype(f(f(t))), t
-            yield allclose, B.shape_int(f(f())), (), False
+            yield eq, B.shape(f(f())), ()
            yield deq, B.dtype(f(f(t, 2))), t
-            yield allclose, B.shape_int(f(f(t, 2))), (2,), False
+            yield eq, B.shape(f(f(t, 2))), (2,)
            yield deq, B.dtype(f(f(t, 2, 3))), t
-            yield allclose, B.shape_int(f(f(t, 2, 3))), (2, 3), False
+            yield eq, B.shape(f(f(t, 2, 3))), (2, 3)


def test_conversion_warnings():
2 changes: 1 addition & 1 deletion tests/test_shaping.py
@@ -10,7 +10,7 @@


def test_sizing():
-    for f in [B.shape, B.shape_int, B.rank, B.length, B.size]:
+    for f in [B.shape, B.rank, B.length, B.size]:
        yield check_function, f, (Tensor(),), {}, False
        yield check_function, f, (Tensor(3, ),), {}, False
        yield check_function, f, (Tensor(3, 4),), {}, False
