sowing across nnx.while_loop #4799

Answered by cgarciae
jacknewsom asked this question in Q&A

@jacknewsom to sow effectively inside a while loop, you need to preallocate an array with a maximum size and write each sown value at an index that increases every time you sow. The buffer has to be preallocated because values carried through a while loop must keep a fixed shape. You can do this with sow's init_fn and reduce_fn:

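import jax.numpy as jnp
from flax import linen as nn  # assumed source of the `nn` alias

max_size: int = ...

def init_fn(x):
  # Preallocate the buffer and start the write index at 0.
  return dict(
    i=jnp.array(0),
    x=jnp.zeros((max_size, *x.shape), dtype=x.dtype)
  )

def reduce_fn(acc, x):
  # Store x at the current index, then advance the index.
  return dict(
    i=acc['i'] + 1,
    x=acc['x'].at[acc['i']].set(x),
  )

class MLP(nn.Module):
  def __call__(self, x):
    # sow calls init_fn with no arguments, so bind x in a closure.
    self.sow("intermediates", "x", x, init_fn=lambda: init_fn(x), reduce_fn=reduce_fn)
    ...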
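
To see the accumulator on its own, here is a minimal sketch that uses the same init_fn / reduce_fn pattern with plain `jax.lax.while_loop` and no Flax module. The value `max_size = 8`, the `run` helper, and the doubling update are hypothetical and only there for illustration; this is not the sow API itself, just the buffer-plus-index idea it relies on.

```python
import jax
import jax.numpy as jnp

max_size = 8  # hypothetical upper bound on the number of recorded values


def init_fn(x):
  # Fixed-shape buffer plus a write index, so the loop carry never changes shape.
  return dict(
    i=jnp.array(0, dtype=jnp.int32),
    x=jnp.zeros((max_size, *x.shape), dtype=x.dtype),
  )


def reduce_fn(acc, x):
  # Write x at the current index, then advance the index.
  return dict(
    i=acc['i'] + 1,
    x=acc['x'].at[acc['i']].set(x),
  )


def run(x0, n_steps):
  def cond_fn(carry):
    _, acc = carry
    return acc['i'] < n_steps

  def body_fn(carry):
    x, acc = carry
    x = x * 2.0  # stand-in for the real per-step computation
    return x, reduce_fn(acc, x)

  return jax.lax.while_loop(cond_fn, body_fn, (x0, init_fn(x0)))


final_x, acc = jax.jit(run)(jnp.ones((2,)), 3)
# acc['x'][:acc['i']] holds the values recorded at each iteration.
```

After the loop, only the first `acc['i']` rows of `acc['x']` are meaningful; the remaining rows stay at zero. If the loop can run more than `max_size` times, the later writes fall out of bounds, so `max_size` should be chosen as a safe upper bound.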

Sowing inside scan is trickier, because you need to initialize the state before calling scan. You can do something lik…

Answer selected by jacknewsom