WGAN-GP low-level API example #157
Conversation
Codecov Report

```diff
@@           Coverage Diff           @@
##           master     #157   +/-   ##
=======================================
  Coverage   84.90%   84.90%
=======================================
  Files         131      131
  Lines        6815     6815
=======================================
  Hits         5786     5786
  Misses       1029     1029
```

Continue to review full report at Codecov.
Awesome @alexander-g!

I wonder if it's our implementation of BCE, or the optimizer, or our implementation of the layers. I've also seen differences in learning capacity between Jax/Elegy and Keras; I believe this is sometimes because the implementation of Adam in Keras uses a lot of improvements.
What about creating a
I am all for this; my main worry was the reasons why Flax does it like this: google/flax#857 (comment). But I believe we might be able to safely assume that there is only one set of trainable parameters (what we call
I like the idea of making `states = states.update(params_a=params_a, params_b=params_b)` the default implementation of
This is odd; the various
Sure, I'll take a look :)
@alexander-g I see you are still using the
This is still a work in progress, but for more info on the low-level API you can consult this document: Also check out:
You're right, I've got my git repo mixed up. Will fix it soon
I re-define it to avoid the
Yes, please
```python
wgan = WGAN_GP()
wgan.init(np.zeros([8, 128]))
```
I've been thinking a lot about initialization lately. I am wondering if we would benefit from having a mandatory `Model.init` that you must call once at the beginning, or whether we should continue with our progressive initialization strategy (e.g. if you call `predict`, only `pred_step` is initialized since you don't have the labels to initialize the metrics, but if you then call `fit`, the rest of the model (metrics + optimizer) is initialized). I think having an explicit `init` is cleaner and less confusing, but it would require an additional mandatory call from the user, as in this case. I'll open an issue with an RFC around this topic.
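A rough sketch of the two strategies from the user's perspective (hypothetical usage for illustration, not Elegy's actual internals):

```python
import numpy as np

# Explicit strategy: one mandatory call creates params, metrics, and optimizer.
model = WGAN_GP()
model.init(np.zeros([8, 128]))   # must be called before anything else
preds = model.predict(np.zeros([8, 128]))

# Progressive strategy: each entry point initializes only what it needs.
model = WGAN_GP()
preds = model.predict(np.zeros([8, 128]))  # initializes pred_step only
# a later model.fit(...) would initialize metrics + optimizer on first use
```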
examples/WGAN-GP/model.py
Outdated
```python
d_grads, d_opt_states = self.d_optimizer.update(d_grads, d_opt_states, d_params)
d_params = optax.apply_updates(d_params, d_grads)

return d_loss, d_params, d_states, d_opt_states, rng, gp
```
Consider accepting and returning `states` as with most `*_step` methods; this can make life easier for the caller of these methods.
I had tried this too but got turned off by having to write `states.update(this=this, that=that, banana=banana, ...etc)` every time. I've now added `elegy.States.safe_update()`, which is basically the opposite of `maybe_update()` and allows doing `states.safe_update(**locals())`. Less writing, happier me.
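For illustration, a minimal sketch of how these three update variants could be implemented (assuming a `States` object that stores its fields in a dict; this is not Elegy's actual implementation):

```python
class States:
    def __init__(self, **fields):
        self._fields = dict(fields)

    def __getattr__(self, name):
        try:
            return self._fields[name]
        except KeyError:
            raise AttributeError(name)

    def update(self, **kwargs):
        # unconditionally set or overwrite fields
        return States(**{**self._fields, **kwargs})

    def maybe_update(self, **kwargs):
        # only add fields that are not present yet
        new = {k: v for k, v in kwargs.items() if k not in self._fields}
        return States(**{**self._fields, **new})

    def safe_update(self, **kwargs):
        # the opposite: only overwrite fields that already exist, so that
        # states.safe_update(**locals()) silently ignores unrelated locals
        known = {k: v for k, v in kwargs.items() if k in self._fields}
        return States(**{**self._fields, **known})
```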
`update_known` and `update_unknown` might be better names for `safe_update` and `maybe_update`.
I like those names! You can introduce `update_known` in this PR and we can refactor to `update_unknown` in a future PR once we merge the 2 PRs that are active.
Done.
Consider overriding the operators `|`, `&`, and `+` as shortcuts for `update_unknown`, `update_known`, and `update` for even less writing: `states & locals()`
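Continuing the sketch above, the operators could map onto the update variants roughly like this (again illustrative only, using a hypothetical subclass):

```python
class OpStates(States):  # extends the States sketch above
    def __or__(self, other):
        # states | locals()  ->  update_unknown (only add missing fields)
        return self.maybe_update(**dict(other))

    def __and__(self, other):
        # states & locals()  ->  update_known (only overwrite existing fields)
        return self.safe_update(**dict(other))

    def __add__(self, other):
        # states + locals()  ->  update (set everything)
        return self.update(**dict(other))
```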
@alexander-g this is looking great! Left a few comments. BTW: if you are interested in continuing work on GANs, I was thinking that maybe we could add a
Good idea
examples/WGAN-GP/model.py
Outdated
```python
return {"d_loss": d_loss, "g_loss": g_loss, "gp": gp}, states.update(step=step)

def discriminator_step(self, x_real: jnp.ndarray, S: elegy.States):
```
Not sure if we should use `S` as a variable name. It looks cool since it's short, but it's not within the Python style guide, and all other examples use `states`. Also, the linter thinks it's a class.
Changed to `states`. What I find confusing about it is that it contains other attributes that are also called states: `states.g_states`, `states.g_opt_states`, ...

I see you've changed the name for non-trainable parameters to `ParameterCollection`; is that the correct term now? I think this is also not the best name, but I can't think of anything better either.
Yeah, `ParametersCollection` is not a great name now that `Module.apply` takes `parameters` and `collections`. A better name might be `StatesCollection`, since you will usually have the states of the network there.
examples/WGAN-GP/model.py
Outdated
```python
step = states.step + 1
no_update = lambda args: (0.0, args[1])
do_update = lambda args: self.generator_step(len(args[0]), args[1])
g_loss, states = jax.lax.cond(step % 5 == 0, do_update, no_update, (x, states))
```
I believe this should work and it's a bit more readable:

```python
g_loss, states = jax.lax.cond(
    step % 5 == 0,
    lambda _: self.generator_step(len(x), states),
    lambda _: (0.0, states),
    None,
)
```

Not sure exactly when the `operand` argument is necessary.
No, this doesn't work; the operand is required:

```
jax.core.UnexpectedTracerError: Encountered an unexpected tracer. Perhaps this tracer escaped through global state from a previously traced function.
The functions being transformed should not save traced values to global state.
The tracer that caused this error was created on line elegy/elegy/types.py:66 (next).
When the tracer was created, the final 5 stack frames (most recent last) excluding JAX-internal frames were:
elegy/elegy/utils.py:101 (wrapper)
elegy/examples/WGAN-GP/model.py:89 (train_step)
elegy/examples/WGAN-GP/model.py:88 (<lambda>)
elegy/examples/WGAN-GP/model.py:125 (generator_step)
elegy/elegy/types.py:66 (next)
The function being traced when the tracer leaked was <lambda> at elegy/examples/WGAN-GP/model.py:88.
To catch the leak earlier, try setting the environment variable JAX_CHECK_TRACER_LEAKS or using the `jax.checking_leaks` context manager.
```

Because of `rng.next()`.
Changed to a mix which I, for one, find more readable:

```python
no_update = lambda states: (0.0, states)
do_update = lambda states: self.generator_step(len(x), states)
g_loss, states = jax.lax.cond(step % 5 == 0, do_update, no_update, states)
```
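A self-contained toy example of the pattern that ended up working: the carried `states` is passed to `jax.lax.cond` as the operand, so both branch functions receive it as a traced argument instead of capturing it from the enclosing scope (simplified; the real `generator_step` also threads an RNG, which is what triggered the leak above):

```python
import jax
import jax.numpy as jnp

@jax.jit
def train_step(step, states):
    # stand-ins for generator_step / skipping the update
    do_update = lambda states: states + 1.0
    no_update = lambda states: states
    # `states` flows through the operand, keeping all traced values
    # inside the cond instead of leaking through a closure
    return jax.lax.cond(step % 5 == 0, do_update, no_update, states)

print(train_step(5, jnp.zeros(())))  # runs do_update -> 1.0
print(train_step(3, jnp.zeros(())))  # runs no_update -> 0.0
```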
Interesting, I got it working with a simple example but maybe it was too simple. This looks good!
@alexander-g looking good! Left more comments. Would you mind if we merge #163 and then update this code to use the new
No idea about the failing test.
Sometimes tests fail due to a network timeout. We sadly download stuff during tests, which I think is not good :(
@alexander-g made a small change to use
A more extensive example using the new low-level API:
Wasserstein-GAN with Gradient Penalty (WGAN-GP) trained on the CelebA dataset.
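As a quick refresher on the technique (a generic sketch of the gradient penalty, not the code from this PR; `d_apply` and its signature are assumptions for illustration):

```python
import jax
import jax.numpy as jnp

def gradient_penalty(d_apply, d_params, x_real, x_fake, rng):
    # interpolate between real and fake samples
    batch = x_real.shape[0]
    eps = jax.random.uniform(rng, shape=(batch,) + (1,) * (x_real.ndim - 1))
    x_hat = eps * x_real + (1.0 - eps) * x_fake

    # per-sample gradient of the critic output w.r.t. its input
    critic = lambda x: d_apply(d_params, x[None]).squeeze()
    grads = jax.vmap(jax.grad(critic))(x_hat)

    # penalize deviation of the per-sample gradient norm from 1
    norms = jnp.sqrt(jnp.sum(grads**2, axis=tuple(range(1, grads.ndim))) + 1e-12)
    return jnp.mean((norms - 1.0) ** 2)
```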
Some good generated images:
Some notes:
- `Module.apply()` requiring a prior `init()`: it's just too much boilerplate to use an `if-else` every time. I avoided it by manually calling `wgan.states = wgan.init(...)` after model instantiation, which I think is also not nice.
- `Module.apply()` should accept `params` and `states` separately instead of `collections`. It's annoying having to construct a dict `{'params': params, 'states': states}` every time.
- It would be nice if `elegy.States` was a `dict` so that the user can decide by themselves what to put into it. With GANs, where you have to manage generator and discriminator states separately, one always has to split them like `(g_states, d_states) = net_states`, which is again annoying.
- `Model.save()` fails on this model, partially due to the extra jitted functions, but even when I remove them, `cloudpickle` chokes on `_HooksContext`.
@cgarciae I'm not completely sure I've used the low-level API correctly; maybe you can take a closer look?