PR: Enable local gradient accumulation #546

Merged
merged 30 commits into uber:master from andfoy:disable_gradient_updates on Dec 12, 2018

Conversation

@andfoy
Contributor

andfoy commented Oct 6, 2018

Fixes #543
Fixes #455

@CLAassistant


CLAassistant commented Oct 6, 2018

CLA assistant check
All committers have signed the CLA.

@alsrgv
Collaborator

alsrgv left a comment

Thanks for the PR. Good start! I left a few comments.

            self._handles[p] = (handle, ctx)
        return hook

    def synchronize(self):
        for p, value in self._handles.items():
            handle, ctx = value
            if not self._reduce_gradients:


@alsrgv

alsrgv Oct 7, 2018

Collaborator

It'd be better to check for handle is None, since the flag may have been flipped after some gradients were submitted.

            self._handles[p] = (handle, ctx)
        return hook

    def synchronize(self):
        for p, value in self._handles.items():
            handle, ctx = value
            if not self._reduce_gradients:
                warnings.warn("Attempting to synchronize an optimizer that "


@alsrgv

alsrgv Oct 7, 2018

Collaborator

We should mention why it's bad (by default gradients are allreduced asynchronously). Not sure we should fallback / mention it, since it may create an illusion of "no problem".


@alsrgv

alsrgv Oct 9, 2018

Collaborator

With new backward_passes_per_step, we can say something like Called step on optimizer before computing gradients backwards_passes_per_step times. This will cause performance degradation, as gradient exchange is interleaved with computation only on the backwards_passes_per_step'th pass.

@@ -63,6 +65,9 @@ def __init__(self, params, named_parameters, compression):
        if size() > 1:
            self._register_hooks()

    def ignore_gradients(self, state):


@alsrgv

alsrgv Oct 7, 2018

Collaborator

I'll think a bit more about a better name for this method.


@andfoy

andfoy Oct 7, 2018

Contributor

@alsrgv What about disable/enable_gradient_updates?


@tgaddair

tgaddair Oct 8, 2018

Collaborator

Thanks for putting this together @andfoy!

This is a very useful capability, but personally, I would rather not add this method to the optimizer API. I think it forces the user to make a lot of modifications on their end to experiment with local gradient accumulation, when it could be simplified instead to a single parameter to the DistributedOptimizer constructor:

opt = hvd.DistributedOptimizer(opt, aggregation_delay=5)

We could then manage all of the state internally without having to expose the optimizer API as a state machine, which I think would be more in line with our goals of making Horovod easy to plug into existing single-GPU scripts. The step() function would then be a no-op in the case that we're currently aggregating locally.

We lose some versatility, but I think it would greatly simplify things and make the feature less intimidating for people who want to try it out without modifying their training script (and less prone to error).

@alsrgv thoughts?
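
A minimal sketch of that idea, using hypothetical names rather than anything in Horovod: the wrapper counts backward passes and makes step() a no-op until the configured number of passes has elapsed (the caller would also skip zero_grad() on the skipped steps).

    import torch

    class DelayedStepSGD(torch.optim.SGD):
        """Hypothetical wrapper: accumulate gradients locally for
        `aggregation_delay` backward passes before applying an update."""

        def __init__(self, params, lr, aggregation_delay=1):
            super().__init__(params, lr=lr)
            self._aggregation_delay = aggregation_delay
            self._passes = 0

        def step(self, closure=None):
            self._passes += 1
            if self._passes < self._aggregation_delay:
                return None  # keep accumulating; no allreduce, no update
            self._passes = 0
            # In the Horovod case, gradients would be allreduced here
            # before the parent optimizer applies the update.
            return super().step(closure)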


@alsrgv

alsrgv Oct 8, 2018

Collaborator

Assuming the aggregation_delay is static, this sounds like a neat idea!


@andfoy

andfoy Oct 8, 2018

Contributor

Thanks for your suggestions; that sounds good, but I have a corner case (like in our research): what would happen if the aggregation_delay is dynamic rather than static, e.g., with variable-sized batches?


@tgaddair

tgaddair Oct 8, 2018

Collaborator

Good question. I was wondering if it would suffice to have an opt.set_aggregation_delay(delay) method? That way advanced users could set it dynamically based on some criteria, and beginning users could just set it via the constructor. Would that address the corner case you had in mind?


@andfoy

andfoy Oct 8, 2018

Contributor

@tgaddair, that sounds perfect! So, let me introduce the changes on the PR, and we can discuss it further.


@alsrgv

alsrgv Oct 8, 2018

Collaborator

@andfoy, if the delay is dynamic, I'm wondering about a couple of things:

  1. Will the number of synchronizations be the same across all workers?
  2. Do you always know the aggregation delay before you start aggregating?

If not, we may still want to expose finer-grained controls.


@andfoy

andfoy Oct 8, 2018

Contributor

@alsrgv Addressing your two questions:

  1. In my specific case, there is a single synchronization step across all workers once they have finished their individual gradient accumulations.
  2. Before starting an accumulation loop, the number of steps is known by each worker.

The general idea would be to synchronize once all workers have finished, which is limited by the worker with the largest number of gradient accumulation iterations. However, I don't know in which cases the more general scenario (unknown number of iterations, a different synchronization scheme) would arise, but I think that in that case it would be preferable to disable/enable gradient synchronization as needed.

andfoy added some commits Oct 7, 2018

@andfoy

Contributor

andfoy commented Oct 8, 2018

@alsrgv @tgaddair I've pushed a version of this PR using set_aggregation_delay, but I can go back to ignore_gradients if you decide to.

@alsrgv
Collaborator

alsrgv left a comment

I propose we rename aggregation_delay to backward_passes_per_step. Let's make it super clear in the documentation for DistributedOptimizer that we interleave allreduce with the last backward pass of the step.

What do you think?
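
Assuming the rename, usage in a training script would look roughly like the sketch below (the model and data are placeholders; the loss is divided by the number of passes so the accumulated gradient is an average):

    import torch
    import torch.nn.functional as F
    import horovod.torch as hvd

    hvd.init()
    model = torch.nn.Linear(10, 2)
    opt = torch.optim.SGD(model.parameters(), lr=0.01 * hvd.size())
    opt = hvd.DistributedOptimizer(opt,
                                   named_parameters=model.named_parameters(),
                                   backward_passes_per_step=4)

    for it in range(100):
        data, target = torch.randn(32, 10), torch.randint(0, 2, (32,))
        loss = F.cross_entropy(model(data), target)
        (loss / 4).backward()          # gradients accumulate locally
        if (it + 1) % 4 == 0:
            opt.step()                 # allreduce overlaps the 4th backward pass
            opt.zero_grad()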

@@ -40,7 +41,8 @@

class _DistributedOptimizer(torch.optim.Optimizer):
    def __init__(self, params, named_parameters, compression):
    def __init__(self, params, named_parameters, compression,
                 aggregation_delay=1):


@alsrgv

alsrgv Oct 9, 2018

Collaborator

Can you expose backward_passes_per_step to DistributedOptimizer() constructor?

            if p in self._handles and self._handles[p][0] is not None:
                if self._parameter_update_delay[p] <= 0:
                    raise AssertionError(
                        "Gradients were accumulated twice without a "


@alsrgv

alsrgv Oct 9, 2018

Collaborator

With other comments in mind, we can say something like Gradients were computed more than backward_passes_per_step times without call to step(). Please increase backward_passes_per_step to accumulate gradients locally.

"have a performance impact or cause "
"synchronization failures")
handle, ctx = self._all_reduce_grad(p)
self.ignore_gradients(False)


@alsrgv

alsrgv Oct 9, 2018

Collaborator

Don't need this.

            self._handles[p] = (handle, ctx)
        return hook

    def synchronize(self):
        for p, value in self._handles.items():
            handle, ctx = value
            if not self._reduce_gradients:
                warnings.warn("Attempting to synchronize an optimizer that "


@alsrgv

alsrgv Oct 9, 2018

Collaborator

With new backward_passes_per_step, we can say something like Called step on optimizer before computing gradients backwards_passes_per_step times. This will cause performance degradation, as gradient exchange is interleaved with computation only on the backwards_passes_per_step'th pass.

@alsrgv

Collaborator

alsrgv commented Oct 9, 2018

@andfoy, by the way - would you want to divide by # of backward_passes_per_step before allreduce?

@andfoy

Contributor

andfoy commented Oct 9, 2018

would you want to divide by # of backward_passes_per_step before allreduce?

@alsrgv are you referring to scaling the gradients by backward_passes_per_step?

@alsrgv

Collaborator

alsrgv commented Oct 9, 2018

@andfoy, yes - divide gradients by backward_passes_per_step before allreduce.

@alsrgv

Collaborator

alsrgv commented Oct 9, 2018

@andfoy, sorry, that was a question. :-) Would such feature be helpful to you or other use cases you can imagine?

@andfoy

Contributor

andfoy commented Oct 9, 2018

sorry, that was the question. :-) Would such feature be helpful to you or other use cases you can imagine?

Hahaha, no worries. Well, I was thinking that all gradient accumulation examples suggest that users add the scaling factor in the loss calculation, so it would be less transparent if the Optimizer did it automatically. Besides, in some cases (like ours), we would like to scale the gradients by other factors, which is more use-case specific, rather than a single solution for all.

andfoy added some commits Oct 9, 2018

@andfoy

Contributor

andfoy commented Oct 9, 2018

@alsrgv I'm having a strange error on Travis, do you know what it could be?

@alsrgv

Collaborator

alsrgv commented Oct 10, 2018

@andfoy, I just fixed the breaking change from PyTorch 1.0 that has been causing unit tests to fail.
If you rebase on master, the break should be fixed.

Besides, in some cases (like ours), we would like to scale the gradients by other factors, which is more use-case specific, rather than a single solution for all.

Would you scale gradients before or after averaging them? Will there be use cases where you'd scale them before averaging?

We had more discussions with @tgaddair and are now thinking about some sort of extensibility model for DistributedOptimizer that would allow decoupling functionality like this into a separate class.

@andfoy

Contributor

andfoy commented Oct 11, 2018

Would you scale gradients before or after averaging them? Will there be use cases where you'd scale them before averaging?

Well, the idea would be to average the loss itself, so that it propagates to the gradients. However, it is possible that someone would prefer to average the gradient values, due to numerical precision or other issues. I don't know how to tackle this from a holistic point of view; what are your thoughts, @alsrgv @tgaddair?

@andfoy

Contributor

andfoy commented Oct 11, 2018

We had more discussions with @tgaddair and are now thinking about some sort of extensibility model for DistributedOptimizer that would allow decoupling functionality like this into a separate class.

Would you like to introduce this new DistributedOptimizer wrapper on this PR?

@whr94621


whr94621 commented Oct 12, 2018

I have a question on the implementation of DistributedOptimizer. Why should we use hooks to implement the allreduce operation on gradients? Can we do this operation manually in the step method, i.e.:

    def step(self, closure=None):
        # just for example
        # for p in all parameters: 
        #     all_reduce_grad_(p)
        self.synchronize()
        return super(self.__class__, self).step(closure)
@andfoy

Contributor

andfoy commented Oct 12, 2018

@whr94621 We allow users to synchronize their gradients even if they did not call backward() backward_passes_per_step times.

@whr94621


whr94621 commented Oct 12, 2018

@whr94621 We allow users to synchronize their gradients even if they did not call backward() backward_passes_per_step times.

All right, now I see that line of code. But I am still curious about the "performance degradation" mentioned in the warning message.

If we always do handle, ctx = self._all_reduce_grad(p) in step (for example, fixing backward_passes_per_step to a very large value), why would it be slower than the implementation in this PR?

@andfoy

Contributor

andfoy commented Oct 12, 2018

@whr94621 Allreduce requests are sent to the MPI workers long before they update their parameters, so you would expect that, by the time synchronize gets executed, all the gradients have already been allreduced across all nodes.
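
A simplified sketch of that mechanism, loosely following the PR's _register_hooks rather than reproducing it exactly: a hook on each parameter's gradient accumulator launches a non-blocking allreduce as soon as that gradient is ready, so communication overlaps with the rest of backward().

    import horovod.torch as hvd

    def register_async_allreduce_hooks(model, handles, grad_accs):
        for name, p in model.named_parameters():
            if not p.requires_grad:
                continue
            # The gradient accumulator node fires after p.grad has been written.
            p_tmp = p.expand_as(p)
            grad_acc = p_tmp.grad_fn.next_functions[0][0]

            def make_hook(name, p):
                def hook(*ignore):
                    handles[p] = hvd.allreduce_async_(p.grad, name=name)
                return hook

            grad_acc.register_hook(make_hook(name, p))
            grad_accs.append(grad_acc)  # keep a reference so the node stays alive

    def wait_for_allreduces(handles):
        # By the time this runs, most allreduces have typically finished already.
        for p, handle in handles.items():
            hvd.synchronize(handle)
        handles.clear()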

@alsrgv
Collaborator

alsrgv left a comment

After more discussion, we decided to keep the API as-is for now and, once we see more use cases, refactor it with a period of backwards compatibility.

Could you add this functionality to the PyTorch ImageNet example as well? This allows trading off the total batch size for bandwidth, which should help scaling in low-bandwidth environments.

Resolved review thread: horovod/torch/__init__.py
@@ -72,23 +76,44 @@ def _register_hooks(self):
                    grad_acc.register_hook(self._make_hook(p))
                    self._grad_accs.append(grad_acc)

    def _all_reduce_grad(self, p):


@alsrgv

alsrgv Oct 12, 2018

Collaborator

Nit: let's do s/all_reduce/allreduce/g for consistency.

    def _make_hook(self, p):
        def hook(*ignore):
            assert p not in self._handles
            if p in self._handles and self._handles[p][0] is not None:
                if self._parameter_update_delay[p] <= 0:


@alsrgv

alsrgv Oct 12, 2018

Collaborator

This if self._parameter_update_delay[p] <= 0 should be turned into another assertion.

                if self._parameter_update_delay[p] <= 0:
                    raise AssertionError(
                        "Gradients were computed more than "
                        "backward_passes_per_step times call to step(). "


@alsrgv

alsrgv Oct 12, 2018

Collaborator

s/times call/times before call/

            self._handles[p] = (handle, ctx)
        return hook

    def synchronize(self):
        for p, value in self._handles.items():
            handle, ctx = value
            if handle is None:
                warnings.warn("Called step on optimizer before computing "


@alsrgv

alsrgv Oct 12, 2018

Collaborator

s/step/step() or synchronize()/

            output = synchronize(handle)
            self._parameter_update_delay[p] = self.backward_passes_per_step


@alsrgv

alsrgv Oct 12, 2018

Collaborator

s/_parameter_update_delay/_allreduce_delay/g

@whr94621


whr94621 commented Oct 13, 2018

@alsrgv Ok now I understand the usage. Thank you for your explanation!

andfoy added some commits Oct 18, 2018

@andfoy

Contributor

andfoy commented Oct 18, 2018

@alsrgv Done!

@alsrgv
Collaborator

alsrgv left a comment

Thanks! I tried running a convergence test, but bumped into a few issues.

            self._handles[p] = (handle, ctx)
        return hook

    def synchronize(self):
        for p, value in self._handles.items():
            handle, ctx = value
            if handle is None:
                warnings.warn("Called step()/synchronize() on optimizer "


@alsrgv

alsrgv Oct 18, 2018

Collaborator

Let's remove this warning (and import warnings) since this will routinely happen due to DistributedSampler.

loss.backward()
# Split batch_size * batches_per_allreduce batch into sub-batches
# of size batch_size
for i in range(0, allreduce_batch_size, args.batch_size):


@alsrgv

alsrgv Oct 18, 2018

Collaborator

Turns out DistributedSampler just makes sure that every worker has the same # of batches, but does not guarantee that the number of data points is divisible by the batch size.

So, we should do something like this:

# Split `data` into batches of size `batch_size`.
for i in range(0, len(data), args.batch_size):
    ...
output = model(data_batch)
train_accuracy.update(accuracy(output, target_batch))
loss = F.cross_entropy(output, target_batch)
# Average gradients among sub-batches


@alsrgv

alsrgv Oct 18, 2018

Collaborator

Train loss should be updated before re-scaling the loss, since it does its own averaging. Also, we need to scale by the proper number of batches.

train_loss.update(loss)
# Average gradients among batches in this data slice.
loss.div_(math.ceil(float(len(data)) / args.batch_size))
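
Putting the two fixes above together, the inner loop of the example would look roughly like this (reusing the example script's variables; the actual file may differ in detail):

    import math
    import torch.nn.functional as F

    optimizer.zero_grad()
    # Split `data` into sub-batches of size `batch_size`; the last may be smaller.
    for i in range(0, len(data), args.batch_size):
        data_batch = data[i:i + args.batch_size]
        target_batch = target[i:i + args.batch_size]
        output = model(data_batch)
        train_accuracy.update(accuracy(output, target_batch))
        loss = F.cross_entropy(output, target_batch)
        train_loss.update(loss)  # update before rescaling; the metric averages itself
        # Average gradients over the sub-batches in this data slice.
        loss.div_(math.ceil(float(len(data)) / args.batch_size))
        loss.backward()
    optimizer.step()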
"degradation, as gradient exchange is "
"interleaved with computation only on the "
"backwards_passes_per_step'th pass.")
handle, ctx = self._allreduce_grad(p)


@alsrgv

alsrgv Oct 18, 2018

Collaborator

We need to do two passes. In a first pass, we check all the handles and backfill missing allreduces. In a second pass, we synchronize all the handles. If we don't do that, we can get a race and hang forever.
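
A simplified sketch of the two-pass version of the optimizer's synchronize(), reusing the attributes shown in the diffs above; submitting and waiting in a single loop can deadlock when workers reach the missing allreduces in different orders.

    import horovod.torch as hvd

    def synchronize(self):
        # Pass 1: make sure every parameter has an outstanding allreduce.
        for p, (handle, ctx) in self._handles.items():
            if handle is None:
                self._handles[p] = self._allreduce_grad(p)
        # Pass 2: wait for all of them; only now is it safe to block.
        for p, (handle, ctx) in self._handles.items():
            output = hvd.synchronize(handle)
            self._allreduce_delay[p] = self.backward_passes_per_step
            p.grad.set_(self._compression.decompress(output, ctx))
        self._handles.clear()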

@andfoy

Contributor

andfoy commented Oct 19, 2018

@alsrgv Give it a try!

@andfoy

Contributor

andfoy commented Oct 24, 2018

@alsrgv Any updates on this one?

@alsrgv
Collaborator

alsrgv left a comment

@andfoy, I'm running convergence tests. A couple of leftover comments:

output = model(data_batch)
train_accuracy.update(accuracy(output, target_batch))
loss = F.cross_entropy(output, target_batch)
train_loss.update(loss.item())


@alsrgv

alsrgv Oct 24, 2018

Collaborator

Need to remove .item(): train_loss.update(loss)

            handle, ctx = None, None
            self._allreduce_delay[p] -= 1
            if self._allreduce_delay[p] == 0:
                handle, ctx = self._allreduce_grad(p)
            self._handles[p] = (handle, ctx)
        return hook

    def synchronize(self):
        for p, value in self._handles.items():
            handle, ctx = value


@alsrgv

alsrgv Oct 24, 2018

Collaborator

You can in-line value, like you do below:

for p, (handle, ctx) in self._handles.items():
@alsrgv

Collaborator

alsrgv commented Oct 25, 2018

Bad news. Both the current version (with the fixes I mentioned earlier) and the original idea of not adjusting gradients fail to converge correctly.

Training curves: [screenshot, 2018-10-24]

Validation curves: [screenshot, 2018-10-24]

Blue is the curve of Horovod 0.15.1 in FP16 mode run over 64 GPUs. Orange is this PR run over 16 GPUs with batches-per-allreduce=4. Gray is an older version of this PR that did not adjust gradients.

Any ideas why it doesn't converge as well?

@andfoy

Contributor

andfoy commented Oct 25, 2018

Any ideas why it doesn't converge as well?

Maybe we're not rescaling the gradients correctly? Or is the lr wrong? Shall we try without rescaling the lr? Let me check the maths regarding gradient accumulation and I'll get back to this

@alsrgv

Collaborator

alsrgv commented Nov 1, 2018

@andfoy, did you get any insights?

@andfoy

Contributor

andfoy commented Nov 2, 2018

@alsrgv According to the maths, we should divide the loss only by args.batches_per_allreduce. Also, we should increase the LR by hvd.size().

@andfoy

Contributor

andfoy commented Nov 2, 2018

We need to update the gradients as follows, where M is the total number of GPUs, N is the number of gradient accumulation steps, and B is the batch size:

[equation image]
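
The equation image is not preserved here; assuming the standard averaging scheme discussed above (each sub-batch loss averaged over its B samples and divided by N before the allreduce averages over the M workers), the intended update is presumably:

    \[
        g \;=\; \frac{1}{M}\sum_{m=1}^{M}\;\frac{1}{N}\sum_{n=1}^{N}\;
                \frac{1}{B}\sum_{b=1}^{B}\,
                \nabla_{\theta}\,\mathcal{L}\!\left(x_{m,n,b};\,\theta\right),
        \qquad
        \theta \;\leftarrow\; \theta \;-\; \eta\, g
    \]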

@alsrgv I don't know if we also need to scale the lr by N?

@alsrgv

Collaborator

alsrgv commented Nov 8, 2018

@andfoy, sorry for the delay. Isn't that what the current version is doing already?

@andfoy

Contributor

andfoy commented Nov 10, 2018

@alsrgv Don't worry, right now I'm very busy with the CVPR deadline. I'll get back to this after Friday

@andfoy

Contributor

andfoy commented Nov 18, 2018

@alsrgv What happens if we don't scale the learning rate by the total number of accumulations?

@andfoy

Contributor

andfoy commented Dec 8, 2018

@alsrgv Any new updates/ideas on this one?

@alsrgv

Collaborator

alsrgv commented Dec 10, 2018

@andfoy, I found the problem. In the learning rate warmup adjustment, we did not multiply the LR by batches_per_allreduce. Could you rebase your branch, resolve conflicts, and make this fix? I'm running the convergence test; so far it matches the baseline perfectly.
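
For reference, a rough sketch of the fix (names follow the ImageNet example; the exact warmup formula there may differ): the warmed-up LR needs the batches_per_allreduce factor because the effective global batch is batch_size * batches_per_allreduce * hvd.size().

    import horovod.torch as hvd

    def apply_learning_rate(optimizer, args, lr_adj):
        # lr_adj is the warmup/decay multiplier computed elsewhere in the example.
        for param_group in optimizer.param_groups:
            param_group['lr'] = (args.base_lr * hvd.size() *
                                 args.batches_per_allreduce * lr_adj)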

@alsrgv
Collaborator

alsrgv left a comment

One comment. Verification is almost done.

Resolved review thread (outdated): examples/pytorch_imagenet_resnet50.py

@andfoy andfoy changed the title PR: Add ignore_gradients function to accumulate gradients locally PR: Enable distributed gradient accumulation Dec 11, 2018

@andfoy andfoy changed the title PR: Enable distributed gradient accumulation PR: Enable local gradient accumulation Dec 11, 2018

@alsrgv

alsrgv approved these changes Dec 12, 2018

Collaborator

alsrgv left a comment

The convergence test has passed successfully. In my experiment, GPU-hour efficiency increased by 1.5x!

@alsrgv alsrgv merged commit 9081ba3 into uber:master Dec 12, 2018

2 checks passed

continuous-integration/travis-ci/pr: The Travis CI build passed
license/cla: Contributor License Agreement is signed.

@andfoy andfoy deleted the andfoy:disable_gradient_updates branch Dec 12, 2018
