Skip to content

Commit

Permalink
DOC: write a new dosctring for jax.numpy.vectorize (jax-ml#2944)
Browse files Browse the repository at this point in the history
* DOC: write a new dosctring for jax.numpy.vectorize

This version is customized entirely for JAX.

* review and typo fixes
  • Loading branch information
shoyer committed May 4, 2020
1 parent 5a0bf46 commit 6aab9e5
Showing 1 changed file with 63 additions and 8 deletions.
71 changes: 63 additions & 8 deletions jax/numpy/vectorize.py
Original file line number Diff line number Diff line change
Expand Up @@ -181,16 +181,71 @@ def new_func(*args):
return new_func, dynamic_args


@_wraps(onp.vectorize, lax_description=textwrap.dedent("""
JAX's implementation of vectorize should be considerably more efficient
than NumPy's, because it uses a batching transformation rather than an
explicit "for" loop.
Note that JAX only supports the optional ``excluded`` (integer only) and
``signature`` arguments, both of which must be specified with keywords.
"""))
def vectorize(pyfunc, *, excluded=frozenset(), signature=None):
"""Define a vectorized function with broadcasting.
``vectorize`` is a convenience wrapper for defining vectorized functions with
broadcasting, in the style of NumPy's `generalized universal functions <https://numpy.org/doc/stable/reference/c-api/generalized-ufuncs.html>`_.
It allows for defining functions that are automatically repeated across
any leading dimensions, without the implementation of the function needing to
be concerned about how to handle higher dimensional inputs.
``jax.numpy.vectorize`` has the same interface as ``numpy.vectorize``, but it
is syntactic sugar for an auto-batching transformation (``vmap``) rather
than a Python loop. This should be considerably more efficient, but the
implementation must be written in terms of functions that act on JAX arrays.
Args:
pyfunc: vectorized function.
excluded: optional set of integers representing positional arguments for
which the function will not be vectorized. These will be passed directly
to ``pyfunc`` unmodified.
signature: optional generalized universal function signature, e.g.,
``(m,n),(n)->(m)`` for vectorized matrix-vector multiplication. If
provided, ``pyfunc`` will be called with (and expected to return) arrays
with shapes given by the size of corresponding core dimensions. By
default, pyfunc is assumed to take scalars arrays as input and output.
Returns:
Vectorized version of the given function.
Here a few examples of how one could write vectorized linear algebra routines
using ``vectorize``::
import jax.numpy as jnp
from functools import partial
@partial(jnp.vectorize, signature='(k),(k)->(k)')
def cross_product(a, b):
assert a.shape == b.shape and a.ndim == b.ndim == 1
return jnp.array([a[1] * b[2] - a[2] * b[1],
a[2] * b[0] - a[0] * b[2],
a[0] * b[1] - a[1] * b[0]])
@partial(jnp.vectorize, signature='(n,m),(m)->(n)')
def matrix_vector_product(matrix, vector):
assert matrix.ndim == 2 and matrix.shape[1:] == vector.shape
return matrix @ vector
These functions are only written to handle 1D or 2D arrays (the ``assert``
statements will never be violated), but with vectorize they support
arbitrary dimensional inputs with NumPy style broadcasting, e.g.,
>>> cross_product(jnp.ones(3), jnp.ones(3)).shape
(3,)
>>> cross_product(jnp.ones((2, 3)), jnp.ones(3)).shape
(2, 3)
>>> cross_product(jnp.ones((1, 2, 3)), jnp.ones((2, 1, 3))).shape
(2, 2, 3)
>>> matrix_vector_product(jnp.ones(3), jnp.ones(3))
ValueError: input with shape (3,) does not have enough dimensions for all
core dimensions ('n', 'k') on vectorized function with excluded=frozenset()
and signature='(n,k),(k)->(k)'
>>> matrix_vector_product(jnp.ones((2, 3)), jnp.ones(3)).shape
(2,)
>>> matrix_vector_product(jnp.ones((2, 3)), jnp.ones((4, 3))).shape
(4, 2) # not the same as np.matmul
"""
if any(not isinstance(exclude, int) for exclude in excluded):
raise TypeError("jax.numpy.vectorize can only exclude integer arguments, "
"but excluded={!r}".format(excluded))
Expand Down

0 comments on commit 6aab9e5

Please sign in to comment.