[Feature Request] Flax GeneralizedModule should be able to pass rngs dict to module.init during summary #154

Closed
sooheon opened this issue Feb 3, 2021 · 5 comments · Fixed by #185
Labels: enhancement (New feature or request)

Comments

sooheon (Contributor) commented Feb 3, 2021

Currently, the low-level API works for toy linen modules, but it does not allow passing in multiple RNG keys, which Flax modules require for e.g. dropout.

I'm not sure about the API design, but the hardcoded init happens in LinenModule.init. Somehow it should be possible to pass it a set of stream names to associate with rng.next() values.
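
Something along these lines, perhaps (purely illustrative; the rng_names argument does not exist in elegy today, and rng is assumed to be an elegy.RNGSeq):

import flax.linen as nn

class LinenModuleSketch:
    # Hypothetical wrapper: the user declares which rng stream names to create.
    def __init__(self, module: nn.Module, rng_names=("params",)):
        self.module = module
        self.rng_names = rng_names

    def init(self, rng, *args):
        # draw one rng.next() per declared stream instead of
        # hardcoding {"params": rng.next()}
        rngs = {name: rng.next() for name in self.rng_names}
        return self.module.init(rngs, *args)

# e.g. LinenModuleSketch(MLP(), rng_names=("params", "dropout"))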

Minimal repro:

import dataget
import elegy
import flax.linen as nn
import jax
import jax.numpy as jnp
import optax

X_train, y_train, X_test, y_test = dataget.image.mnist(global_cache=True).get()

print("X_train:", X_train.shape, X_train.dtype)
print("y_train:", y_train.shape, y_train.dtype)
print("X_test:", X_test.shape, X_test.dtype)
print("y_test:", y_test.shape, y_test.dtype)


# %%
class MLP(nn.Module):
    @nn.compact
    def __call__(self, x):
        x = nn.Dense(300)(x)
        x = nn.relu(x)
        x = nn.Dropout(0.1)(x)
        x = nn.Dense(10)(x)
        return x


class FlaxLinearClassifier(elegy.Model):
    def test_step(
        self, x, y_true, states: elegy.States, initializing: bool, rng: elegy.RNGSeq
    ):
        x = jnp.reshape(x, (x.shape[0], -1)) / 255
        if initializing:
            variables = self.module.init(
                {"params": rng.next(), "dropout": rng.next()}, x
            )
            params = variables["params"]
        else:
            params = states.net_params

        logits = self.module.apply({"params": params}, x, rngs={"dropout": rng.next()})
        labels = jax.nn.one_hot(y_true, 10)
        loss = jnp.mean(-jnp.sum(labels * jax.nn.log_softmax(logits), axis=-1))
        accuracy = jnp.mean(jnp.argmax(logits, axis=-1) == y_true)

        logs = dict(accuracy=accuracy, loss=loss)
        return loss, logs, states.update(rng=rng, net_params=params)


model = FlaxLinearClassifier(module=MLP(), optimizer=optax.adamw(1e-3))

model.summary(X_test[:64])

AssertionError: Need PRNG for "dropout"

sooheon added the enhancement label Feb 3, 2021
cgarciae (Collaborator) commented Feb 3, 2021

Hey @sooheon! This is a good point.

On one hand, we can improve our LinenModule implementation. Currently, the calls to linen.Module.init and linen.Module.apply are implemented like this:

https://github.com/poets-ai/elegy/blob/master/elegy/generalized_module/linen_module.py#L24
https://github.com/poets-ai/elegy/blob/master/elegy/generalized_module/linen_module.py#L61-L67

As you point out, rng values are only given for params; the problem is that we don't know a priori what names the user might use. I think we could just add the most common names even if users won't need them, but I don't know if this solves the problem in general.
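
In code, that idea would look roughly like this (just a sketch over the lines linked above; DEFAULT_RNG_NAMES does not exist yet):

def init(self, rng, *args, **kwargs):
    # sketch: overcompensate with a few common stream names; Flax simply
    # ignores rng entries that the module never asks for
    DEFAULT_RNG_NAMES = ("params", "dropout")  # hypothetical default list
    rngs = {name: rng.next() for name in DEFAULT_RNG_NAMES}
    return self.module.init(rngs, *args, **kwargs)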

On the other hand, summary calls pred_step (not test_step), so you can refactor your code like this to fix the issue:

import dataget
import elegy
import flax.linen as nn
import jax
import jax.numpy as jnp
import optax

X_train, y_train, X_test, y_test = dataget.image.mnist(global_cache=True).get()

print("X_train:", X_train.shape, X_train.dtype)
print("y_train:", y_train.shape, y_train.dtype)
print("X_test:", X_test.shape, X_test.dtype)
print("y_test:", y_test.shape, y_test.dtype)


class MLP(nn.Module):
    @nn.compact
    #@elegy.flax_summarize # use decorators to report module summaries
    def __call__(self, x):
        x = nn.Dense(300)(x)  # core modules don't report summaries :(
        x = nn.relu(x)
        x = nn.Dropout(0.1)(x)
        x = nn.Dense(10)(x)
        return x


class FlaxLinearClassifier(elegy.Model):
    def pred_step(
        self, x, states: elegy.States, initializing: bool, rng: elegy.RNGSeq
    ) -> elegy.PredStep:
        x = jnp.reshape(x, (x.shape[0], -1)) / 255
        if initializing:
            variables = self.module.init(
                {"params": rng.next(), "dropout": rng.next()}, x
            )
            params = variables["params"]
        else:
            params = states.net_params

        logits = self.module.apply({"params": params}, x, rngs={"dropout": rng.next()})

        return elegy.PredStep.simple(logits, states.update(rng=rng, net_params=params))

    def test_step(self, x, y_true, states, mode, initializing):
        # call_pred_step is the recommended way of invoking pred_step
        logits, states, _, _, _ = self.call_pred_step(x, mode, states, initializing)

        labels = jax.nn.one_hot(y_true, 10)
        loss = jnp.mean(-jnp.sum(labels * jax.nn.log_softmax(logits), axis=-1))
        accuracy = jnp.mean(jnp.argmax(logits, axis=-1) == y_true)

        logs = dict(accuracy=accuracy, loss=loss)
        return loss, logs, states


model = FlaxLinearClassifier(module=MLP(), optimizer=optax.adamw(1e-3))

model.summary(X_test[:64])

I found a bug, so @elegy.flax_summarize (which creates the summaries for the Module's outputs) is commented out above; it can be uncommented once #155 is merged.

cgarciae (Collaborator) commented Feb 3, 2021

I'll be adding guides for the low-level API so it becomes a bit clearer what the different methods you can override do (pred_step, test_step, grad_step, train_step) and how you can compose them.

sooheon (Contributor, Author) commented Feb 4, 2021

Yeah, an example using all of the canonical user-facing API would definitely go a long way.

cgarciae (Collaborator) commented:

@sooheon, is there a way to extract all the possible names that might need an rng from the variables? I am thinking of overcompensating (giving more names than required) just to keep Flax happy.

sooheon (Contributor, Author) commented Feb 22, 2021

There's no static way to know ahead of time, AFAICT. Submodules can call self.make_rng(name='foo'), and it's pretty much up to you to provide the 'foo' rng. Hopefully this clunky API gets improved in the future. OTOH, just adding dropout would give you 99% coverage, I think (I've yet to see a different rng key required).
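
For instance (a toy sketch, unrelated to the repro above), nothing stops a submodule from asking for an arbitrarily named stream:

import flax.linen as nn
import jax
import jax.numpy as jnp

class NoisyDense(nn.Module):
    features: int = 8

    @nn.compact
    def __call__(self, x):
        # the "noise" stream name exists only because this module asked for it
        noise = jax.random.normal(self.make_rng("noise"), x.shape)
        return nn.Dense(self.features)(x + 0.01 * noise)

x = jnp.ones((1, 4))
module = NoisyDense()
# init and apply only succeed if the caller provides the "noise" stream
variables = module.init({"params": jax.random.PRNGKey(0), "noise": jax.random.PRNGKey(1)}, x)
y = module.apply(variables, x, rngs={"noise": jax.random.PRNGKey(2)})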
