Skip to content

Commit

Permalink
Allow subshape selection
Browse files Browse the repository at this point in the history
  • Loading branch information
wesselb committed May 14, 2021
1 parent adf9cd7 commit cff8852
Show file tree
Hide file tree
Showing 3 changed files with 19 additions and 1 deletion.
2 changes: 1 addition & 1 deletion README.md
Expand Up @@ -387,7 +387,7 @@ choice(a)

### Shaping
```
shape(a)
shape(a, *dims)
rank(a)
length(a) (alias: size)
isscalar(a)
Expand Down
11 changes: 11 additions & 0 deletions lab/shaping.py
Expand Up @@ -63,6 +63,7 @@ def shape(a: Numeric): # pragma: no cover
Args:
a (tensor): Tensor.
*dims (int, optional): Dimensions to get.
Returns:
object: Shape of `a`.
Expand All @@ -82,6 +83,16 @@ def shape(a: Union[list, tuple]):
return np.array(a).shape


@dispatch
def shape(a, *dims: Int):
a_shape = B.shape(a)
subshape = tuple(a_shape[i] for i in dims)
if LazyShapes.enabled:
return Shape(*subshape)
else:
return subshape


@dispatch
def rank(a: Union[Numeric, list, tuple]): # pragma: no cover
"""Get the shape of a tensor.
Expand Down
7 changes: 7 additions & 0 deletions tests/test_shaping.py
Expand Up @@ -38,6 +38,7 @@ def test_sizing(f, check_lazy_shapes):
@pytest.mark.parametrize(
"x,shape",
[
(1, ()),
([], (0,)),
([5], (1,)),
([[5], [6]], (2, 1)),
Expand All @@ -50,6 +51,12 @@ def test_shape(x, shape, check_lazy_shapes):
assert B.shape(x) == shape


def test_subshape(check_lazy_shapes):
assert B.shape(B.zeros(2), 0) == (2,)
assert B.shape(B.zeros(2, 3, 4), 1) == (3,)
assert B.shape(B.zeros(2, 3, 4), 0, 2) == (2, 4)


def test_lazy_shape():
a = B.randn(2, 2)

Expand Down

0 comments on commit cff8852

Please sign in to comment.