From ebbcbad547e56357e600cd8c19232d4b91cf4f00 Mon Sep 17 00:00:00 2001 From: Matthew Johnson Date: Tue, 10 Mar 2020 08:29:46 -0700 Subject: [PATCH] allow vmap in_axes to be a list, fixes #2367 (#2395) --- jax/api.py | 8 ++++++++ tests/api_test.py | 14 ++++++++++++++ 2 files changed, 22 insertions(+) diff --git a/jax/api.py b/jax/api.py index e8492db51d70..2b9b7859e4d4 100644 --- a/jax/api.py +++ b/jax/api.py @@ -657,6 +657,14 @@ def vmap(fun: Callable, in_axes=0, out_axes=0): docstr = ("Vectorized version of {fun}. Takes similar arguments as {fun} " "but with additional array axes over which {fun} is mapped.") + if isinstance(in_axes, list): + # To be a tree prefix of the positional args tuple, in_axes can never be a + # list: if in_axes is not a leaf, it must be a tuple of trees. However, + # in cases like these users expect tuples and lists to be treated + # essentially interchangeably, so we canonicalize lists to tuples here + # rather than raising an error. https://github.com/google/jax/issues/2367 + in_axes = tuple(in_axes) + _check_callable(fun) if (not isinstance(in_axes, (list, tuple, type(None), int)) or not isinstance(out_axes, (list, tuple, type(None), int))): diff --git a/tests/api_test.py b/tests/api_test.py index 98c505f1560d..23d7d5b49f17 100644 --- a/tests/api_test.py +++ b/tests/api_test.py @@ -1229,6 +1229,20 @@ def vjp(x_tangent): b = np.dot(a + np.eye(a.shape[0]), real_x) print(gf(a, b)) # doesn't crash + def test_vmap_in_axes_list(self): + # https://github.com/google/jax/issues/2367 + dictionary = {'a': 5., 'b': np.ones(2)} + x = np.zeros(3) + y = np.arange(3.) + + + def f(dct, x, y): + return dct['a'] + dct['b'] + x + y + + out1 = api.vmap(f, (None, 0, 0))(dictionary, x, y) + out2 = api.vmap(f, [None, 0, 0])(dictionary, x, y) + self.assertAllClose(out1, out2, check_dtypes=True) + def test_vmap_in_axes_tree_prefix_error(self): # https://github.com/google/jax/issues/795 self.assertRaisesRegex(