From 9f7115ece2dcc0ef324e736bd5f15607fb9c15ee Mon Sep 17 00:00:00 2001 From: Matthew Johnson Date: Sat, 2 May 2020 12:02:43 -0700 Subject: [PATCH] reduce use of lax on static data (e.g. shapes) (#2933) * reduce use of lax on static data (e.g. shapes) * use f-string for error message --- jax/numpy/lax_numpy.py | 43 ++++++++++++++++-------------------------- 1 file changed, 16 insertions(+), 27 deletions(-) diff --git a/jax/numpy/lax_numpy.py b/jax/numpy/lax_numpy.py index 173660ed1aa3..a7e07c647d72 100644 --- a/jax/numpy/lax_numpy.py +++ b/jax/numpy/lax_numpy.py @@ -1044,7 +1044,7 @@ def gradient_along_axis(a, h, axis): return [] axis = [_canonicalize_axis(i, a.ndim) for i in axis] - if min([s for i, s in enumerate(a.shape) if i in axis]) < 2: + if _min([s for i, s in enumerate(a.shape) if i in axis]) < 2: raise ValueError("Shape of array too small to calculate " "a numerical gradient, " "at least 2 elements are required.") @@ -1852,7 +1852,7 @@ def _pad(array, pad_width, mode, constant_values): array = asarray(array) nd = ndim(array) pad_width = onp.broadcast_to(onp.asarray(pad_width), (nd, 2)) - if any(pad_width < 0): + if onp.any(pad_width < 0): raise ValueError("index can't contain negative values") if mode == "constant": @@ -2313,52 +2313,41 @@ def _repeat_scalar(a, repeats, axis=None): @_wraps(onp.repeat) def repeat(a, repeats, axis=None): - ''' - :param repeats: int or array of ints - ''' # use `_repeat_scalar` when possible if isscalar(repeats): return _repeat_scalar(a, repeats, axis) - repeats_raveled = ravel(array(repeats)) # make sure it's jax's array type + repeats_raveled = onp.ravel(onp.array(repeats)) if size(repeats_raveled) == 1: - return _repeat_scalar(a, list(repeats_raveled)[0], axis) + return _repeat_scalar(a, repeats_raveled.item(), axis) if axis is None or isscalar(a): a = ravel(a) axis = 0 # repeats must match the dimension along the requested axis - a_shape = list(a.shape) - n = a_shape[axis] - if size(repeats_raveled) != n: - raise ValueError("repeats shape {} does not match the dimension on axis {}".format( - repeats_raveled.shape, n - )) + if repeats_raveled.size != a.shape[axis]: + raise ValueError(f"repeats shape {repeats_raveled.shape} does not match " + f"the dimension on axis {a.shape[axis]}") # calculating the new shape - total = sum(repeats_raveled) + total = repeats_raveled.sum() - new_shape = a_shape[:] + new_shape = list(a.shape) new_shape[axis] = total - a_flattened = ravel(a) - ''' - main algorithm: - first break down raveled input array into list of chunks; each chunk is the unit of repeat - then tile the repeats to have same length as the list of chunks - finally repeat each unit x number of times according to the tiled repeat list - ''' - chunks = product(a_shape[:axis+1]).item() + # first break down raveled input array into list of chunks; each chunk is the + # unit of repeat. then tile the repeats to have same length as the list of + # chunks. finally repeat each unit x number of times according to the tiled + # repeat list. + chunks = _prod(a.shape[:axis+1]) a_splitted = split(a_flattened, chunks) - repeats_tiled = tile(repeats_raveled, chunks // len(repeats_raveled)) + repeats_tiled = onp.tile(repeats_raveled, chunks // len(repeats_raveled)) ret = array([], dtype=a.dtype) for i, repeat in enumerate(repeats_tiled): - if not isinstance(repeat, int): - repeat = repeat.item() if repeat != 0: - ret = concatenate((ret, tile(a_splitted[i], repeat))) + ret = concatenate((ret, tile(a_splitted[i], (repeat,)))) return reshape(ret, new_shape)