`sow`ing across `nnx.while_loop`
#4799
Hello, is it possible to `sow` across `nnx.while_loop`?
For example, the snippet below fails:
with this error:
Is there any way around this that lets me use JIT and `sow` in a loop?
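@jacknewsom To sow inside a while loop, you need to preallocate an array with a maximum size and write each value you sow at an index that increases every time you sow. To do this you can use `sow`'s `init_fn` and `reduce_fn`:

```python
import functools

import flax.linen as nn
import jax.numpy as jnp

max_size: int = ...  # upper bound on the number of values you will sow

def init_fn(x):
    # Preallocated buffer plus a write index; its shape never changes.
    return dict(
        i=jnp.array(0),
        x=jnp.zeros((max_size, *x.shape), x.dtype),
    )

def reduce_fn(acc, x):
    # Write x at the current index, then advance the index.
    return dict(
        i=acc['i'] + 1,
        x=acc['x'].at[acc['i']].set(x),
    )

class MLP(nn.Module):
    @nn.compact
    def __call__(self, x):
        # sow calls init_fn() with no arguments, so bind x with functools.partial.
        self.sow("intermediates", "x", x,
                 init_fn=functools.partial(init_fn, x), reduce_fn=reduce_fn)
        ...
```

Sowing inside `scan` is trickier, as you need to initialize the state before calling `scan`. You can do something like:

```python
class MLP(nn.Module):
    depth: int  # number of times Block is applied by scan

    @nn.compact
    def __call__(self, x):
        self.sow("intermediates", "x", x,
                 init_fn=functools.partial(init_fn, x), reduce_fn=reduce_fn)
        block = Block()  # the scanned layer (not defined in this snippet)
        if self.is_initializing():
            # Create the sown state up front, outside of scan.
            carry = ...  # TODO
            block.sow("intermediates", "carry", carry,
                      init_fn=functools.partial(init_fn, carry), reduce_fn=reduce_fn)
        carry, _ = nn.scan(
            Block.__call__,
            variable_axes={"params": 0, "intermediates": 0},
            split_rngs={"params": True},
            length=self.depth,
        )(block, x, None)
        return carry
```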
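The fixed-size-buffer approach described in the answer can be tried in plain JAX, without Flax. This is a minimal sketch, not Flax's internals: `MAX_SIZE`, `init_buffer`, `record`, and the `tanh` update are hypothetical stand-ins for `max_size`, `init_fn`, `reduce_fn`, and a layer's computation. The buffer is created before the loop and carried through `jax.lax.while_loop`, so every iteration sees a carry with the same structure and shapes.

```python
import jax
import jax.numpy as jnp

MAX_SIZE = 5  # hypothetical capacity: max number of values recorded


def init_buffer(example):
    # Plays the role of init_fn: a write index plus a preallocated buffer
    # shaped like the values that will be recorded.
    return dict(
        i=jnp.array(0),
        x=jnp.zeros((MAX_SIZE, *example.shape), example.dtype),
    )


def record(acc, value):
    # Plays the role of reduce_fn: write at the current index, then advance it.
    return dict(i=acc["i"] + 1, x=acc["x"].at[acc["i"]].set(value))


def cond(state):
    step, _, _ = state
    return step < 3  # run three iterations


def body(state):
    step, h, acc = state
    h = jnp.tanh(h + 1.0)  # stand-in for a layer's computation
    acc = record(acc, h)   # "sow" the intermediate value
    return step + 1, h, acc


@jax.jit
def run(h0):
    # The buffer must exist before the loop starts so the carry keeps
    # a fixed structure and fixed shapes across iterations.
    acc0 = init_buffer(h0)
    _, h, acc = jax.lax.while_loop(cond, body, (jnp.array(0), h0, acc0))
    return h, acc


h, acc = run(jnp.zeros(2))
print(acc["i"])      # number of recorded values
print(acc["x"][:3])  # recorded values; rows from index 3 on stay zero
```

Because `.at[...].set()` skips out-of-bounds updates by default (JAX's documented out-of-bounds indexing behavior), recording more than `MAX_SIZE` values would silently drop the extras rather than raise, so `max_size` should be a true upper bound.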
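Why the buffer has to have a fixed size can be seen in plain JAX. `sow`'s default accumulator appends each value to a tuple (`xs + (x,)`), so a loop carry built that way gains one leaf per iteration, while `jax.lax.while_loop` requires the body to return a carry with exactly the same pytree structure as its input. The minimal sketch below uses plain JAX rather than Flax, and its function names are hypothetical; it shows the loop being rejected at trace time with a `TypeError`.

```python
import jax
import jax.numpy as jnp


def tuple_reduce(xs, x):
    # Mirrors sow's default accumulator: append the new value to a tuple.
    return xs + (x,)


def body(state):
    step, acc = state
    # The returned tuple has one more element than the one passed in,
    # so the carry's pytree structure changes on every iteration.
    return step + 1, tuple_reduce(acc, step)


def tuple_carry_is_rejected():
    try:
        jax.lax.while_loop(lambda s: s[0] < 3, body, (jnp.array(0), ()))
    except TypeError as err:
        print("while_loop rejected the growing carry:", type(err).__name__)
        return True
    return False


tuple_carry_is_rejected()
```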