
WGAN-GP low-level API example #157

Merged · 16 commits merged into poets-ai:master on Feb 14, 2021

Conversation

@alexander-g (Contributor) commented Feb 4, 2021

A more extensive example using the new low-level API:
Wasserstein-GAN with Gradient Penalty (WGAN-GP) trained on the CelebA dataset.

Some good generated images (attached): epoch-0079, epoch-0084, epoch-0089

Some notes:

  • I first tried to train a DCGAN, which uses binary cross-entropy, but I ran into balancing issues: the discriminator quickly becomes too good, so the generator does not learn anything. The same model implemented in PyTorch or TensorFlow works. Most modern GANs don't use the WGAN loss anymore; most use BCE.
  • I'm still in favor of making Module.apply() fall back to init() when the module is not yet initialized. It's just too much boilerplate to write an if-else every time (see the sketch after this list). I avoided it by manually calling wgan.states = wgan.init(...) after model instantiation, which I think is also not nice.
  • Can we make Module.apply() accept params and states separately instead of a collections dict? It's annoying having to construct {'params': params, 'states': states} on every call.
  • It would be nice if elegy.States were a dict so that the user can decide for themselves what to put into it. With GANs, where generator and discriminator states have to be managed separately, one always has to split them like (g_states, d_states) = net_states, which is again annoying.
  • Model.save() fails on this model, partly due to the extra jitted functions; but even when I remove them, cloudpickle chokes on _HooksContext.
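For context, a loose sketch of the pattern the second and third points refer to; the call signatures here are assumptions for illustration, not the exact Elegy API:

def pred_step(module, x, params, states, initializing):
    # the if/else boilerplate I would like to avoid
    if initializing:
        return module.init(x)  # assumed signature
    # the collections dict that has to be rebuilt on every call
    return module.apply({"params": params, "states": states}, x)  # assumed signature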

@cgarciae I'm not completely sure I've used the low-level API correctly, maybe you can take a closer look?

@codecov-io commented Feb 4, 2021

Codecov Report

Merging #157 (db7b420) into master (cb70cf5) will not change coverage.
The diff coverage is n/a.

@@           Coverage Diff           @@
##           master     #157   +/-   ##
=======================================
  Coverage   84.90%   84.90%           
=======================================
  Files         131      131           
  Lines        6815     6815           
=======================================
  Hits         5786     5786           
  Misses       1029     1029           


@cgarciae (Collaborator) commented Feb 4, 2021

Awesome @alexander-g !

I first tried to train a DCGAN, which uses binary cross-entropy, but I ran into balancing issues: the discriminator quickly becomes too good, so the generator does not learn anything. The same model implemented in PyTorch or TensorFlow works. Most modern GANs don't use the WGAN loss anymore; most use BCE.

I wonder if it's our implementation of BCE, or the optimizer, or our implementation of the layers. I've also seen differences in learning capacity between JAX/Elegy and Keras; I believe this is sometimes because the Adam implementation in Keras includes a lot of improvements, while optax.adam is a very vanilla implementation.

I'm still in favor of making Module.apply() fall back to init() when the module is not yet initialized. It's just too much boilerplate to write an if-else every time. I avoided it by manually calling wgan.states = wgan.init(...) after model instantiation, which I think is also not nice.

What about creating an init_or_apply method? I would like this behavior to be explicit.
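As a rough illustration, something along these lines; the name comes from this discussion, but the signature and fall-back condition are assumptions, not an existing API:

def init_or_apply(module, collections, *args, **kwargs):
    # fall back to init() the first time, otherwise delegate to apply()
    # with the existing collections (hypothetical signatures)
    if not collections:
        return module.init(*args, **kwargs)
    return module.apply(collections, *args, **kwargs)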

Can we make Module.apply() accept params and states separately instead of a collections dict? It's annoying having to construct {'params': params, 'states': states} on every call.

I am all for this; my main worry was the reason why Flax does it like this: google/flax#857 (comment)

But I believe we might be able to safely assume that there is only one set of trainable parameters (what we call parameters) and the rest, while being organized by collection name, are non-trainable. I'll work on this next!

It would be nice if elegy.States were a dict so that the user can decide for themselves what to put into it. With GANs, where generator and discriminator states have to be managed separately, one always has to split them like (g_states, d_states) = net_states, which is again annoying.

I like the idea of making elegy.States more flexible; however, Model would still use a predefined set of names, so if the user does something like this in e.g. test_step:

states = states.update(params_a=params_a, params_b=params_b)

the default implementation of train_step would still feed states.net_params to the optimizer. I mean, the user might still need to be aware of the standard names depending on what they are doing.
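A toy illustration of this caveat (the custom attribute names params_a / params_b are made up for the sketch):

from types import SimpleNamespace

states = SimpleNamespace(net_params={"w": 1.0})
# the user stores parameters under custom names in e.g. test_step:
states.params_a, states.params_b = {"wa": 2.0}, {"wb": 3.0}
# ...but a default train_step that only knows the standard name would still do:
optimizer_input = states.net_params  # params_a / params_b are never seen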

Model.save() fails on this model, partly due to the extra jitted functions; but even when I remove them, cloudpickle chokes on _HooksContext.

This is odd; the various elegy_* examples are not having issues serializing. I would love to pinpoint the part that is causing the problem.

@cgarciae I'm not completely sure I've used the low-level API correctly, maybe you can take a closer look?

Sure, I'll take a look :)

@cgarciae (Collaborator) commented Feb 4, 2021

@alexander-g I see you are still using the Model.init method which no longer exists. Is this branch up to date with master?

This is still work in progress, but for more info on the low-level API you can consult this document:

https://github.com/poets-ai/elegy/blob/ad03ea9e73aeae972acd5194cc45893117e1db52/docs/low-level-api/basics.md

Also check out:

@alexander-g (Contributor Author)

You're right, I've got my git repo mixed up. Will fix it soon

@alexander-g marked this pull request as draft on February 6, 2021 07:52
@alexander-g (Contributor Author) commented Feb 10, 2021

  • updated to version 0.5.0
  • tested for 5 epochs instead of the full 100; the architecture is unchanged and intermediate outputs looked OK, so I hope it works for a full run too
  • Model.save() works
  • added to test-examples.sh with 2 epochs and 3 dummy images for automated testing

I see you are still using the Model.init method which no longer exists.

I re-define it to avoid the if initializing: module.init() else: module.apply() boilerplate. I have to call apply seven times; this gets too verbose.

What about creating an init_or_apply method?

Yes, please

@alexander-g marked this pull request as ready for review on February 10, 2021 07:39
)

wgan = WGAN_GP()
wgan.init(np.zeros([8, 128]))
@cgarciae (Collaborator) commented Feb 10, 2021

I've been thinking a lot about initialization lately. I am wondering whether we would benefit from having a mandatory Model.init that you must call once at the beginning, or whether we should continue with our progressive initialization strategy (e.g. if you call predict, only pred_step is initialized since you don't have the labels to initialize the metrics, but if you then call fit the rest of the model (metrics + optimizer) is initialized).

I think having an explicit init is cleaner and less confusing, but it would require an additional mandatory call from the user, as in this case. I'll open an issue with an RFC around this topic.

examples/WGAN-GP/model.py (outdated review thread, resolved)
d_grads, d_opt_states = self.d_optimizer.update(d_grads, d_opt_states, d_params)
d_params = optax.apply_updates(d_params, d_grads)

return d_loss, d_params, d_states, d_opt_states, rng, gp
@cgarciae (Collaborator)

Consider accepting and returning states as with most *_step methods; that can make life easier for the caller of these methods.

@alexander-g (Contributor Author)

I had tried this too but got turned off by having to write states.update(this=this, that=that, banana=banana, ...) every time. I've now added elegy.States.safe_update(), which is basically the opposite of maybe_update() and allows doing states.safe_update(**locals()). Less writing, happier me.
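A rough, self-contained sketch of the intended semantics (illustrative only, not the actual elegy.States implementation):

class States(dict):
    def maybe_update(self, **kwargs):
        # only fills in keys that are still missing
        return States({**kwargs, **self})

    def safe_update(self, **kwargs):
        # only overwrites keys that already exist, so **locals() cannot
        # introduce stray local variables as new state entries
        return States({**self, **{k: v for k, v in kwargs.items() if k in self}})

states = States(d_params=1, g_params=2)
states = states.safe_update(d_params=10, some_local=99)
print(states)  # {'d_params': 10, 'g_params': 2} -- 'some_local' is ignored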

@alexander-g (Contributor Author)

update_known and update_unknown might be better names for safe_update and maybe_update

@cgarciae (Collaborator)

I like those names! You can introduce update_known in this PR and we can refactor to update_unknown in a future PR once we merge these 2 PRs that are active.

@alexander-g (Contributor Author)

Done.
Consider overriding the operators |, & and + as shortcuts for update_unknown, update_known and update, for even less writing: states & locals()
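For example, a minimal sketch of the & variant (illustrative semantics only, nothing implemented yet):

class States(dict):
    def update_known(self, **kwargs):
        return States({**self, **{k: v for k, v in kwargs.items() if k in self}})

    def __and__(self, other):
        return self.update_known(**other)

def train_step(states):
    g_params, g_loss = "new-params", 0.5  # stand-ins for real values
    return states & locals()              # only keys states already has get updated

print(train_step(States(g_params=None)))  # {'g_params': 'new-params'}; g_loss is ignored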

@cgarciae (Collaborator)

@alexander-g this is looking great! Left a few comments.

BTW: if you are interested in continuing work on GANs, I was thinking that maybe we could add a GANModel in the future so the user only has to pass a couple of arguments to the GANModel constructor and it would get the training loop for free, basically generalizing what you just did and maybe adding the option to have metrics.

@alexander-g (Contributor Author)

@alexander-g this is looking great! Left a few comments.

BTW: if you are interested in continuing work on GANs, I was thinking that maybe we could add a GANModel in the future so the user only has to pass a couple of arguments to the GANModel constructor and it would get the training loop for free, basically generalizing what you just did and maybe adding the option to have metrics.

Good idea


return {"d_loss": d_loss, "g_loss": g_loss, "gp": gp}, states.update(step=step)

def discriminator_step(self, x_real: jnp.ndarray, S: elegy.States):
@cgarciae (Collaborator)

Not sure we should use S as a variable name; it looks cool since it's short, but it's not within the Python style guide and all other examples use states. Also, the linter thinks it's a class.

@alexander-g (Contributor Author)

Changed to states.
What I find confusing about it is that it contains other attributes also called states: states.g_states, states.g_opt_states, ...

I see you've changed the name for non-trainable parameters to ParameterCollection; is that the correct term now? I think this is also not the best name, but I can't think of anything better either.

@cgarciae (Collaborator) commented Feb 12, 2021

Yeah, ParametersCollection is not a great name now that Module.apply takes parameters and collections; a better name might be StatesCollection, since you will usually have the states of the network there.

step = states.step + 1
no_update = lambda args: (0.0, args[1])
do_update = lambda args: self.generator_step(len(args[0]), args[1])
g_loss, states = jax.lax.cond(step % 5 == 0, do_update, no_update, (x, states))
@cgarciae (Collaborator) commented Feb 11, 2021

I believe this should work and it's a bit more readable:

g_loss, states = jax.lax.cond(
    step % 5 == 0, 
    lambda _: (0.0, states), 
    lambda _: self.generator_step(len(x), states), 
    None,
)

Not sure exactly when the operand argument is necessary.

@alexander-g (Contributor Author)

No, this doesn't work; the operand is required:

jax.core.UnexpectedTracerError: Encountered an unexpected tracer. Perhaps this tracer escaped through global state from a previously traced function.
The functions being transformed should not save traced values to global state.
The tracer that caused this error was created on line elegy/elegy/types.py:66 (next).
When the tracer was created, the final 5 stack frames (most recent last) excluding JAX-internal frames were:
elegy/elegy/utils.py:101 (wrapper)
elegy/examples/WGAN-GP/model.py:89 (train_step)
elegy/examples/WGAN-GP/model.py:88 (<lambda>)
elegy/examples/WGAN-GP/model.py:125 (generator_step)
elegy/elegy/types.py:66 (next)
The function being traced when the tracer leaked was <lambda> at elegy/examples/WGAN-GP/model.py:88.
To catch the leak earlier, try setting the environment variable JAX_CHECK_TRACER_LEAKS or using the `jax.checking_leaks` context manager.

Because of rng.next()

@alexander-g (Contributor Author)

Changed to a mix, which I for one find more readable:

no_update = lambda states: (0.0, states)
do_update = lambda states: self.generator_step(len(x), states)
g_loss, states = jax.lax.cond(step % 5 == 0, do_update, no_update, states)

@cgarciae (Collaborator)

Interesting, I got it working with a simple example but maybe it was too simple. This looks good!

@cgarciae (Collaborator)

@alexander-g looking good! Left more comments.

Would you mind if we merge #163 first and then update this code to use the new init and init_step methods? It would be a good showcase. Also, init_or_apply is there if you want to use it, but given that you implemented init separately it might not be needed.

@alexander-g (Contributor Author) commented Feb 12, 2021

No idea about the failing test.
Edit: it passes again, but it was not related to this PR.

@cgarciae (Collaborator)

No idea about the failing test.
Edit: it passes again, but it was not related to this PR.

Sometimes tests fail due to a network timeout. We sadly download stuff during tests, which I think is not good :(

@cgarciae (Collaborator)

@alexander-g made a small change to use init_step, ran it with the test-examples script and it works.

@cgarciae merged commit 5c9db37 into poets-ai:master on Feb 14, 2021
@alexander-g deleted the wgan branch on February 15, 2021 07:29