Skip to content

Commit

Permalink
allow vmap in_axes to be a list, fixes google#2367 (google#2395)
Browse files Browse the repository at this point in the history
  • Loading branch information
mattjj committed Mar 10, 2020
1 parent 9fd69a0 commit ebbcbad
Show file tree
Hide file tree
Showing 2 changed files with 22 additions and 0 deletions.
8 changes: 8 additions & 0 deletions jax/api.py
Expand Up @@ -657,6 +657,14 @@ def vmap(fun: Callable, in_axes=0, out_axes=0):
docstr = ("Vectorized version of {fun}. Takes similar arguments as {fun} "
"but with additional array axes over which {fun} is mapped.")

if isinstance(in_axes, list):
# To be a tree prefix of the positional args tuple, in_axes can never be a
# list: if in_axes is not a leaf, it must be a tuple of trees. However,
# in cases like these users expect tuples and lists to be treated
# essentially interchangeably, so we canonicalize lists to tuples here
# rather than raising an error. https://github.com/google/jax/issues/2367
in_axes = tuple(in_axes)

_check_callable(fun)
if (not isinstance(in_axes, (list, tuple, type(None), int))
or not isinstance(out_axes, (list, tuple, type(None), int))):
Expand Down
14 changes: 14 additions & 0 deletions tests/api_test.py
Expand Up @@ -1229,6 +1229,20 @@ def vjp(x_tangent):
b = np.dot(a + np.eye(a.shape[0]), real_x)
print(gf(a, b)) # doesn't crash

def test_vmap_in_axes_list(self):
# https://github.com/google/jax/issues/2367
dictionary = {'a': 5., 'b': np.ones(2)}
x = np.zeros(3)
y = np.arange(3.)


def f(dct, x, y):
return dct['a'] + dct['b'] + x + y

out1 = api.vmap(f, (None, 0, 0))(dictionary, x, y)
out2 = api.vmap(f, [None, 0, 0])(dictionary, x, y)
self.assertAllClose(out1, out2, check_dtypes=True)

def test_vmap_in_axes_tree_prefix_error(self):
# https://github.com/google/jax/issues/795
self.assertRaisesRegex(
Expand Down

0 comments on commit ebbcbad

Please sign in to comment.