[Feature Request] Flax GeneralizedModule should be able to pass rngs dict to module.init during summary #154
Comments
Hey @sooheon! This is a good point. On one hand we can improve our hardcoded `init` in https://github.com/poets-ai/elegy/blob/master/elegy/generalized_module/linen_module.py#L24; as you point out, only an rng value for `params` is given there. On the other hand, you can take control of initialization yourself via the low-level API:

```python
import dataget
import elegy
import flax.linen as nn
import jax
import jax.numpy as jnp
import optax

# Load MNIST.
X_train, y_train, X_test, y_test = dataget.image.mnist(global_cache=True).get()

print("X_train:", X_train.shape, X_train.dtype)
print("y_train:", y_train.shape, y_train.dtype)
print("X_test:", X_test.shape, X_test.dtype)
print("y_test:", y_test.shape, y_test.dtype)


class MLP(nn.Module):
    @nn.compact
    # @elegy.flax_summarize  # use decorators to report module summaries
    def __call__(self, x):
        x = nn.Dense(300)(x)  # core modules don't report summaries :(
        x = nn.relu(x)
        # dropout needs its own PRNG stream, both at init and apply time
        x = nn.Dropout(0.1, deterministic=False)(x)
        x = nn.Dense(10)(x)
        return x


class FlaxLinearClassifier(elegy.Model):
    def pred_step(
        self, x, states: elegy.States, initializing: bool, rng: elegy.RNGSeq
    ) -> elegy.PredStep:
        x = jnp.reshape(x, (x.shape[0], -1)) / 255

        if initializing:
            # supply an rng for every collection the module requires
            variables = self.module.init(
                {"params": rng.next(), "dropout": rng.next()}, x
            )
            params = variables["params"]
        else:
            params = states.net_params

        logits = self.module.apply(
            {"params": params}, x, rngs={"dropout": rng.next()}
        )

        return elegy.PredStep.simple(
            logits, states.update(rng=rng, net_params=params)
        )

    def test_step(self, x, y_true, states, mode, initializing):
        # call_pred_step is the recommended way of invoking pred_step
        logits, states, _, _, _ = self.call_pred_step(x, mode, states, initializing)

        labels = jax.nn.one_hot(y_true, 10)
        loss = jnp.mean(-jnp.sum(labels * jax.nn.log_softmax(logits), axis=-1))
        accuracy = jnp.mean(jnp.argmax(logits, axis=-1) == y_true)

        logs = dict(accuracy=accuracy, loss=loss)
        return loss, logs, states


model = FlaxLinearClassifier(module=MLP(), optimizer=optax.adamw(1e-3))
model.summary(X_test[:64])
```

I found a bug, so the `elegy.flax_summarize` decorator is commented out above.
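For the first option, the change to `LinenModule.init` could look roughly like this (a hypothetical sketch, not elegy's actual implementation; the `rng_names` parameter is invented for illustration):

```python
# Hypothetical sketch: instead of hardcoding a single "params" key,
# build the rngs dict for module.init from a configurable set of names.
def init_with_rngs(module, rng, x, rng_names=("params",)):
    rngs = {name: rng.next() for name in rng_names}  # e.g. ("params", "dropout")
    return module.init(rngs, x)
```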
I'll be adding guides for the low-level API so it becomes a bit clearer what the different methods you can override do.
Yeah, an example using all of the canonical user-facing API would definitely go a long way.
@sooheon is there a way to extract all the possible names that might need an rng?
There's no static way to know ahead of time afaict. Submodules can call `make_rng` with arbitrary names, so you only find out by running the module.
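To illustrate why, here is a small sketch (the `NoisyDense` module and its "noise" stream are hypothetical): a submodule can request a PRNG under any name via `make_rng`, so the full set of names is only discoverable by actually tracing the module.

```python
import jax
import jax.numpy as jnp
import flax.linen as nn


class NoisyDense(nn.Module):
    features: int

    @nn.compact
    def __call__(self, x):
        # an ad-hoc PRNG stream name, only visible once the module runs
        noise = jax.random.normal(self.make_rng("noise"), x.shape)
        return nn.Dense(self.features)(x + noise)


# init must receive every stream the module asks for, or it fails with
# an error like: Need PRNG for "noise"
variables = NoisyDense(features=2).init(
    {"params": jax.random.PRNGKey(0), "noise": jax.random.PRNGKey(1)},
    jnp.ones((1, 4)),
)
```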
Currently, the low-level API works for toy linen modules, but it does not allow passing in multiple RNG keys, which Flax modules require for e.g. dropout.

Not sure about the API design, but the hardcoded init happens in `LinenModule.init`. Somehow it should be possible to pass it a set of keywords with which to associate `rng.next()` values.

Minimal repro:
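A sketch of such a repro (the `MLP` below is hypothetical; any linen module that uses dropout triggers it):

```python
import elegy
import flax.linen as nn
import jax.numpy as jnp
import optax


class MLP(nn.Module):
    @nn.compact
    def __call__(self, x):
        x = nn.Dense(32)(x)
        # requires a "dropout" PRNG stream at init and apply time
        x = nn.Dropout(0.1, deterministic=False)(x)
        return nn.Dense(10)(x)


# LinenModule.init hardcodes a single "params" rng, so this fails:
model = elegy.Model(module=MLP(), optimizer=optax.adamw(1e-3))
model.summary(jnp.ones((8, 28 * 28)))
```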
AssertionError: Need PRNG for "dropout"