Skip to content

Commit

Permalink
reduce use of lax on static data (e.g. shapes) (google#2933)
Browse files Browse the repository at this point in the history
* reduce use of lax on static data (e.g. shapes)

* use f-string for error message
  • Loading branch information
mattjj committed May 2, 2020
1 parent 64f12a4 commit 9f7115e
Showing 1 changed file with 16 additions and 27 deletions.
43 changes: 16 additions & 27 deletions jax/numpy/lax_numpy.py
Expand Up @@ -1044,7 +1044,7 @@ def gradient_along_axis(a, h, axis):
return []
axis = [_canonicalize_axis(i, a.ndim) for i in axis]

if min([s for i, s in enumerate(a.shape) if i in axis]) < 2:
if _min([s for i, s in enumerate(a.shape) if i in axis]) < 2:
raise ValueError("Shape of array too small to calculate "
"a numerical gradient, "
"at least 2 elements are required.")
Expand Down Expand Up @@ -1852,7 +1852,7 @@ def _pad(array, pad_width, mode, constant_values):
array = asarray(array)
nd = ndim(array)
pad_width = onp.broadcast_to(onp.asarray(pad_width), (nd, 2))
if any(pad_width < 0):
if onp.any(pad_width < 0):
raise ValueError("index can't contain negative values")

if mode == "constant":
Expand Down Expand Up @@ -2313,52 +2313,41 @@ def _repeat_scalar(a, repeats, axis=None):

@_wraps(onp.repeat)
def repeat(a, repeats, axis=None):
'''
:param repeats: int or array of ints
'''
# use `_repeat_scalar` when possible
if isscalar(repeats):
return _repeat_scalar(a, repeats, axis)
repeats_raveled = ravel(array(repeats)) # make sure it's jax's array type
repeats_raveled = onp.ravel(onp.array(repeats))
if size(repeats_raveled) == 1:
return _repeat_scalar(a, list(repeats_raveled)[0], axis)
return _repeat_scalar(a, repeats_raveled.item(), axis)

if axis is None or isscalar(a):
a = ravel(a)
axis = 0

# repeats must match the dimension along the requested axis
a_shape = list(a.shape)
n = a_shape[axis]
if size(repeats_raveled) != n:
raise ValueError("repeats shape {} does not match the dimension on axis {}".format(
repeats_raveled.shape, n
))
if repeats_raveled.size != a.shape[axis]:
raise ValueError(f"repeats shape {repeats_raveled.shape} does not match "
f"the dimension on axis {a.shape[axis]}")

# calculating the new shape
total = sum(repeats_raveled)
total = repeats_raveled.sum()

new_shape = a_shape[:]
new_shape = list(a.shape)
new_shape[axis] = total

a_flattened = ravel(a)

'''
main algorithm:
first break down raveled input array into list of chunks; each chunk is the unit of repeat
then tile the repeats to have same length as the list of chunks
finally repeat each unit x number of times according to the tiled repeat list
'''
chunks = product(a_shape[:axis+1]).item()
# first break down raveled input array into list of chunks; each chunk is the
# unit of repeat. then tile the repeats to have same length as the list of
# chunks. finally repeat each unit x number of times according to the tiled
# repeat list.
chunks = _prod(a.shape[:axis+1])
a_splitted = split(a_flattened, chunks)
repeats_tiled = tile(repeats_raveled, chunks // len(repeats_raveled))
repeats_tiled = onp.tile(repeats_raveled, chunks // len(repeats_raveled))

ret = array([], dtype=a.dtype)
for i, repeat in enumerate(repeats_tiled):
if not isinstance(repeat, int):
repeat = repeat.item()
if repeat != 0:
ret = concatenate((ret, tile(a_splitted[i], repeat)))
ret = concatenate((ret, tile(a_splitted[i], (repeat,))))

return reshape(ret, new_shape)

Expand Down

0 comments on commit 9f7115e

Please sign in to comment.