
Replacement for experimental_run_tf_function after removal from tf.keras.Model.compile #35138

Closed
tgaddair opened this issue Dec 15, 2019 · 22 comments
Labels: comp:keras, stale, stat:awaiting response, TF 2.1, type:docs-bug

@tgaddair

It looks like experimental_run_tf_function was removed from tf.keras.Model.compile in this commit a few days ago: c73c99c#diff-de9b96ac2d81503324cbbbe21732031fR1159

In Horovod, this flag / graph mode is necessary in order for Optimizer.get_gradients() to be called, which aggregates gradients across workers. Since this flag has been removed, distributed training in Horovod with tf.keras is not working in our nightly builds.

Is there a workaround to achieve the same behavior with the latest changes on master?

Note that we cannot perform the allreduce aggregation in apply_gradients due to interactions with gradient clipping and loss scaling (see horovod/horovod#1347).
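For context, a minimal sketch of the Horovod + tf.keras pattern this is about (the model, loss, and learning rate below are placeholders, not code from the issue):

```python
# Sketch of the pre-2.2 Horovod + tf.keras setup. With
# experimental_run_tf_function=False, Keras goes through optimizer.get_updates(),
# which calls Optimizer.get_gradients(); hvd.DistributedOptimizer overrides
# get_gradients() to allreduce gradients across workers.
import tensorflow as tf
import horovod.tensorflow.keras as hvd

hvd.init()

model = tf.keras.Sequential([tf.keras.layers.Dense(10)])
opt = hvd.DistributedOptimizer(tf.keras.optimizers.SGD(0.01 * hvd.size()))

model.compile(
    optimizer=opt,
    loss="sparse_categorical_crossentropy",
    # Removed on master by c73c99c; without it Keras takes the GradientTape
    # path and get_gradients() is never called, so no allreduce happens.
    experimental_run_tf_function=False,
)
```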

@tgaddair tgaddair added the type:docs-bug Document issues label Dec 15, 2019
@oanush oanush self-assigned this Dec 17, 2019
@oanush oanush added comp:keras Keras related issues TF 2.1 for tracking issues in 2.1 release labels Dec 17, 2019
@oanush oanush assigned jvishnuvardhan and unassigned oanush Dec 17, 2019
@jvishnuvardhan
Contributor

@tgaddair I notice your PR is already merged. Do we still need to keep this open? Thanks!

@jvishnuvardhan jvishnuvardhan added the stat:awaiting response Status - Awaiting response from author label Dec 19, 2019
@tgaddair
Author

Hey @jvishnuvardhan, thanks for the response. That PR was only to pin our integration tests to an older version of tf-nightly. We are still looking for a long-term fix that will make Horovod compatible with TensorFlow 2.1.0.

@tensorflowbutler tensorflowbutler removed the stat:awaiting response Status - Awaiting response from author label Dec 20, 2019
@tgaddair
Author

cc @martinwicke, @alsrgv was telling me you might have some thoughts on this as well. Thanks.

@jvishnuvardhan jvishnuvardhan added the stat:awaiting tensorflower Status - Awaiting response from tensorflower label Dec 23, 2019
@martinwicke
Member

@tanzhenyu or @robieta would know more details.

@tensorflowbutler tensorflowbutler removed the stat:awaiting tensorflower Status - Awaiting response from tensorflower label Jan 8, 2020
@MarkDaoust MarkDaoust assigned tanzhenyu and robieta and unassigned MarkDaoust Jan 8, 2020
@tgaddair
Author

Hey @tanzhenyu and @robieta, any thoughts on this? Currently we're unable to support TensorFlow 2.1 effectively without this. But hopefully there's a workaround.

@nbro

nbro commented Feb 5, 2020

@tgaddair In my case, experimental_run_tf_function still exists as an argument of the compile method in TF 2.1 (and I've been using this flag for some months with TF 2.0 and now with TF 2.1), so I don't get this issue.

@tgaddair
Author

tgaddair commented Feb 5, 2020

@nbro the commit (c73c99c) was made before the release of TF 2.1, but was not included in the TF 2.1 branch, and as such will not go into effect until TF 2.2.

@nbro

nbro commented Feb 5, 2020

@tgaddair I am currently using experimental_run_tf_function in TF 2.1.

@tgaddair
Author

tgaddair commented Feb 5, 2020

@nbro I think you misunderstood me. I'm saying that this functionality will be dropped in TF 2.2, not 2.1. In the original issue, I assumed it would be in effect in 2.1 because the change was made before the 2.1 release, but the change was never merged into the 2.1 branch. Does that make sense?

See this discussion for more details: #36398

@nbro

nbro commented Feb 5, 2020

@tgaddair Yes, now it makes sense.

@Flamefire
Contributor

Flamefire commented Apr 22, 2020

The 2.2 release is on its way; I guess the solution from #37765 (edit: actually #36398) will work.

Some analysis (still using 2.1):
The same issue exists not only with model.fit but also with model.train_on_batch (which is ultimately called by fit). It also occurs when using tf.config.experimental_run_functions_eagerly(True) to see what is going on: despite experimental_run_tf_function being set, the eager path uses the same method, which does NOT call get_gradients.

IMO the underlying problem seems to be in the duality of the methods used.

  • Method A uses optimizer.get_updates() which calls get_gradients and apply_gradients
  • Method B uses GradientTape and calls apply_gradients

Having multiple seemingly equivalent ways of doing the same thing always complicates (or even breaks) things.
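A rough, self-contained sketch of the two paths (the model, data, and loss below are placeholders just to make the duality concrete):

```python
import tensorflow as tf

model = tf.keras.Sequential([tf.keras.layers.Dense(1)])
loss_fn = tf.keras.losses.MeanSquaredError()
optimizer = tf.keras.optimizers.SGD(0.01)
x = tf.random.normal([8, 4])
y = tf.random.normal([8, 1])

# Method A: graph-style path. optimizer.get_updates() calls
# Optimizer.get_gradients() internally (the hook Horovod's
# DistributedOptimizer relies on); it needs a graph context.
@tf.function
def method_a():
    loss = loss_fn(y, model(x, training=True))
    return optimizer.get_updates(loss, model.trainable_variables)

# Method B: GradientTape path. Gradients come from the tape and are
# passed straight to apply_gradients(); get_gradients() is never called.
def method_b():
    with tf.GradientTape() as tape:
        loss = loss_fn(y, model(x, training=True))
    grads = tape.gradient(loss, model.trainable_variables)
    optimizer.apply_gradients(zip(grads, model.trainable_variables))
```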

But to move forward: Horovod, for example, has a DistributedGradientTape which works when get_gradients is not invoked.

Question: Why is there a get_gradients in the optimizer? Isn't that the task for e.g. GradientTape?

@tgaddair Correct me if I'm mistaken, but I think a good way to solve this is to use the plain Optimizer (not the Horovod DistributedOptimizer) and tell TensorFlow to use Horovod's DistributedGradientTape.

So proposal:

Short-term: Allow users to pass in an optional GradientTape to model.compile or make the training loop query the optimizer to create one.

Long-term: Clarify why there are at least 2 ways to compute the gradients and decide for one to use. Support that properly. Obviously my call would be on GradientTape to split responsibilities properly.

@MarkDaoust
Member

I'm not familiar with all the issues here.

But have you seen that train_step is now overridable in Keras models? https://twitter.com/fchollet/status/1250622989541838848

That lets you affect how the model calculates the gradients. Does that help this use case?
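For reference, a minimal sketch of what overriding train_step looks like in TF >= 2.2 (following the pattern from the linked announcement; the layer and metric handling here are illustrative):

```python
import tensorflow as tf

class CustomModel(tf.keras.Model):
    def __init__(self):
        super().__init__()
        self.dense = tf.keras.layers.Dense(10)

    def call(self, inputs):
        return self.dense(inputs)

    # train_step is overridable starting in TF 2.2; model.fit() calls it
    # for every batch, so this is where gradient handling can be customized.
    def train_step(self, data):
        x, y = data
        with tf.GradientTape() as tape:
            y_pred = self(x, training=True)
            loss = self.compiled_loss(y, y_pred)
        # A wrapped/distributed tape or a custom aggregation step could be
        # substituted around this tape.gradient call.
        grads = tape.gradient(loss, self.trainable_variables)
        self.optimizer.apply_gradients(zip(grads, self.trainable_variables))
        self.compiled_metrics.update_state(y, y_pred)
        return {m.name: m.result() for m in self.metrics}
```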

@Flamefire
Contributor

Not really. Still too high-level. The goal is to only replace the way TF gathers gradients, not change the whole training algorithm.

@tgaddair
Author

Hey @Flamefire, following #36398 Horovod will work correctly when calling model.fit() with a hvd.DistributedOptimizer, so at least for now the problem is resolved.

But I agree that there is an API problem that needs to be addressed: we have two ways of doing the same thing.

One solution that's been proposed is to add an optimizer.minimize() function where we could wrap all of this, eliminating the need for the DistributedGradientTape (which I prefer, because it is more consistent with every other API we support, including PyTorch, MXNet, and TensorFlow v1).

The other solution is, as you suggest, to use DistributedGradientTape with a plain optimizer. This is what we do when not using Keras for TensorFlow v2 (see: tensorflow_mnist.py). I think it would also be reasonable to use this in place of optimizer._aggregate_gradients(), but it requires that TensorFlow allow us to override/inject the gradient tape when using model.fit() (not currently supported).
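For readers following along, a rough sketch of that non-Keras pattern (adapted from the linked example; the model, loss, and learning rate are placeholders):

```python
# A plain optimizer plus Horovod's DistributedGradientTape, which
# allreduces gradients across workers instead of the optimizer doing it.
import tensorflow as tf
import horovod.tensorflow as hvd

hvd.init()

model = tf.keras.Sequential([tf.keras.layers.Dense(10)])
loss_fn = tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True)
opt = tf.keras.optimizers.SGD(0.01 * hvd.size())  # plain optimizer, no wrapper

@tf.function
def training_step(x, y, first_batch):
    with tf.GradientTape() as tape:
        loss = loss_fn(y, model(x, training=True))
    # Wrapping the tape is what replaces DistributedOptimizer.get_gradients().
    tape = hvd.DistributedGradientTape(tape)
    grads = tape.gradient(loss, model.trainable_variables)
    opt.apply_gradients(zip(grads, model.trainable_variables))
    if first_batch:
        # Make sure all workers start from the same state.
        hvd.broadcast_variables(model.variables, root_rank=0)
        hvd.broadcast_variables(opt.variables(), root_rank=0)
    return loss
```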

As to why optimizer.get_gradients exists, it's because that's what was used in TensorFlow v1, and Horovod maintains compatibility across both v1 and v2 using the same DistributedOptimizer.

@Flamefire
Contributor

I think you misunderstood some parts. optimizer.get_gradients exists in TF. It gets used in 2.1 when experimental_run_tf_function=False while GradientTape is used for True. This is the cause of this whole issue and the current workaround in Horovod (disabling experimental_run_tf_function)

optimizer.minimize exists too. It calls _compute_gradients which does use GradientTape in yet another way to compute the gradients. This makes it 3 ways...

Hence my proposal to unify the whole training stack in TF so there is exactly one way, and one place, where things are done. This can then be extended to provide customization points. Example: provide a custom GradientTape, which would eliminate the need for the DistributedOptimizer.

I don't think Optimizer should have a method for gathering gradients. An optimizer is a thing like SGD, Adam, ... which defines how gradients are applied. Also the LossScalingOptimizer makes sense. But the DistributedOptimizer is a hack IMO.

@tgaddair
Author

Hey @Flamefire, I think we're on the same page here. Essentially, these different ways of doing the same thing need to be unified. The only question is whether it should be done via a custom gradient tape or custom hooks into the optimizer. Or put differently, how much of the training loop should be managed as internals of the optimizer.

I don't have a strong preference either way, so long as TensorFlow chooses to be consistent going forward (which has historically been an issue, as you've pointed out with the three different ways of doing the same thing). I like keeping DistributedOptimizer because it doesn't require Horovod users to relearn the API for TensorFlow 2, but that's secondary to consistency and consolidation.

@relativeflux

relativeflux commented Mar 3, 2021

What is the actual solution for this? I see the TF2 Keras example in the docs still uses experimental_run_tf_function=False, but is the recommended solution for TF>=2.2 to override model train_step (which I am doing anyway) and use DistributedGradientTape with a plain optimizer?

@MarkDaoust MarkDaoust assigned MarkDaoust and unassigned tanzhenyu and MarkDaoust Mar 3, 2021
@tgaddair
Author

tgaddair commented Mar 3, 2021

Hey @relativeflux, the TF2 Keras example uses experimental_run_tf_function for backwards compatibility with older versions of TF2. However, with versions >= 2.2, you should be able to use model.fit() without it. The DistributedGradientTape approach is also fine if you aren't using model.fit().
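A minimal sketch of that TF >= 2.2 setup (the model and dataset below are placeholders):

```python
# Plain model.fit() with hvd.DistributedOptimizer and no
# experimental_run_tf_function argument; following #36398, the wrapped
# optimizer handles gradient aggregation inside fit().
import tensorflow as tf
import horovod.tensorflow.keras as hvd

hvd.init()

model = tf.keras.Sequential([tf.keras.layers.Dense(10, activation="softmax")])
opt = hvd.DistributedOptimizer(tf.keras.optimizers.Adam(0.001 * hvd.size()))
model.compile(optimizer=opt, loss="sparse_categorical_crossentropy")

dataset = tf.data.Dataset.from_tensor_slices(
    (tf.random.normal([128, 20]),
     tf.random.uniform([128], maxval=10, dtype=tf.int64))
).batch(32)

model.fit(dataset, epochs=2,
          callbacks=[hvd.callbacks.BroadcastGlobalVariablesCallback(0)])
```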

@relativeflux

@tgaddair Excellent, that's good to know.

@rmothukuru
Contributor

@tgaddair,
With respect to your comment, can you please let us know if we can close this issue? Thanks!

@rmothukuru rmothukuru self-assigned this Apr 14, 2021
@rmothukuru rmothukuru added the stat:awaiting response Status - Awaiting response from author label Apr 14, 2021
@google-ml-butler

This issue has been automatically marked as stale because it has not had recent activity. It will be closed if no further activity occurs. Thank you.

@google-ml-butler google-ml-butler bot added the stale This label marks the issue/pr stale - to be closed automatically if no activity label Apr 21, 2021
@google-ml-butler

Closing as stale. Please reopen if you'd like to work on this further.
