In [1]:
import jax
import jax.numpy as jnp

In [2]:
# candidate_ids is a JAX 2-D array of shape (1, max_length): a single row
# holding the ids 0 .. max_length - 1.
max_length = 64
candidate_ids = jnp.expand_dims(jnp.arange(max_length), axis=0)

In [3]:
# Display the (1, max_length) array of candidate ids built in the previous cell.
candidate_ids

Array([[ 0,  1,  2,  3,  4,  5,  6,  7,  8,  9, 10, 11, 12, 13, 14, 15,
        16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31,
        32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47,
        48, 49, 50, 51, 52, 53, 54, 55, 56, 57, 58, 59, 60, 61, 62, 63]],      dtype=int32)

In [4]:
# padding_pos is a JAX 2-D array of shape (1, 3) holding the positions 1, 2, 3.
padding_pos = jnp.expand_dims(jnp.arange(1, 4), axis=0)
padding_pos

Array([[1, 2, 3]], dtype=int32)

In [5]:
# Eager (non-jit) slicing: outside of jit, `padding_pos[0, 0]` holds a concrete
# value, so it can be used directly as a slice size. This keeps the first
# `padding_pos[0, 0]` ids of the single row of `candidate_ids`.
new_candidate_ids = jax.lax.dynamic_slice(
    candidate_ids,
    start_indices=(0, 0),
    slice_sizes=(1, padding_pos[0, 0]),
)
new_candidate_ids

Array([[0]], dtype=int32)

In [6]:
# jit the function and try again
def slice_ids(input_ids, idx=1):
    """Return the first ``idx`` columns of row 0 of the 2-D array ``input_ids``.

    ``jax.lax.dynamic_slice`` needs concrete slice sizes. When this function is
    wrapped in ``jax.jit``, ``idx`` must therefore be marked static
    (``static_argnums=1`` or ``static_argnames="idx"``). Otherwise it is traced
    and the slice fails.

    Args:
        input_ids: 2-D array. The notebook passes shape (1, max_length).
        idx: slice width, i.e. how many leading columns to keep.

    Returns:
        Array of shape (1, idx).
    """
    window = (1, idx)
    return jax.lax.dynamic_slice(input_ids, start_indices=(0, 0), slice_sizes=window)

# No static_argnums / static_argnames here, so jax.jit traces every argument,
# including `idx`. Calling this with any `idx` fails because `dynamic_slice`
# needs a concrete (static) slice size.
jit_slice_ids = jax.jit(slice_ids)

In [7]:
# Even though `idx` is a plain Python int here, jit traces it (it is not marked
# static), so `dynamic_slice` receives a traced slice size and raises.
# This failure is the point of the cell: catch it and show the message so that
# "Restart & Run All" can continue to the fix below. JAX's concretization
# errors subclass TypeError, so this also covers newer JAX versions.
try:
    new_candidate_ids = jit_slice_ids(candidate_ids, idx=int(padding_pos[0, 0]))
except TypeError as err:
    print(f"{type(err).__name__}: {err}")

TypeError: Shapes must be 1D sequences of concrete values of integer type, got (1, Traced<ShapedArray(int32[], weak_type=True)>with<DynamicJaxprTrace(level=0/1)>).
If using `jit`, try using `static_argnums` or applying `jit` to smaller subfunctions.

In [9]:
# `dynamic_slice` needs a concrete slice size, so `idx` must be static under jit.
# `idx` is passed by keyword below, so mark it with `static_argnames`. JAX also
# maps the name to its position via the function signature, so positional calls
# are treated as static too.
# see: https://jax.readthedocs.io/en/latest/notebooks/Common_Gotchas_in_JAX.html#python-control-flow-jit
jit_slice_ids = jax.jit(slice_ids, static_argnames="idx")

new_candidate_ids = jit_slice_ids(candidate_ids, idx=int(padding_pos[0, 0]))
# Sanity check: the jitted result must match plain (eager) slicing.
assert (new_candidate_ids == candidate_ids[:, : int(padding_pos[0, 0])]).all()
new_candidate_ids

Array([[0]], dtype=int32)