Add support for virtual steps that accumulate per-sample or clipped gradients #16
Closed
Commits (32)
3572425 accumulate per-sample batches (ftramer)
6235399 merge with refactored clipper (ftramer)
8df2433 accumulate in autograd (ftramer)
8274334 per-sample grad accumulation (ftramer)
fa9fec3 accumulate mini-batches in gradient clipper (ftramer)
f646ebb gradient accumulator for virtual batches (ftramer)
9dcc569 rename 'accumulate_grads' to 'virtual_step' (ftramer)
0da3fbc typo (ftramer)
affcb1b replace lambda function by functools.partial (ftramer)
1ada364 allow for batches that are smaller than expected (ftramer)
d716398 graciously handle the zero-privacy case in the analysis (ftramer)
68ca40e keep activations in memory (ftramer)
b5d4e64 fix failing test (ftramer)
6e241c3 store accumulated gradients in each parameter (ftramer)
d95818d merge changes to privacy engine (ftramer)
452ae78 remove the patching on zero_grad and delete accumulated gradients (ftramer)
562b2e4 Fix tests that assume that param.grad_sample will still be available (ftramer)
f40a0d0 Run black formatter (ftramer)
ab10000 Run black formatter (ftramer)
f8e2645 effective_batch_size => n_accumulation_steps (ftramer)
ed74cc2 Merge branch 'master' of https://github.com/facebookresearch/pytorch-dp (ftramer)
c422de5 Merge branch 'master' of https://github.com/facebookresearch/pytorch-dp (ftramer)
74a436d fix comments, linter, warnings (ftramer)
eaa87ad comment and formatting fixes (ftramer)
e5ab1e5 comment fixes (ftramer)
80f30ed clarify batch_size retrieval (ftramer)
d6314fb clarify batch_size retrieval (ftramer)
9ddacff clarify virtual batch test cases (ftramer)
58e39f6 test that accumulated gradients are erased upon a call to optimizer.s… (ftramer)
a8ca4dc remove un-used zero_grad method (ftramer)
c3f6601 implement virtual batches (ftramer)
8a6599e Merge branch 'virtual_batch' (ftramer)
Conversation
Are we not taking a step if len_data / batch is not divisible by effective_batch?
Here's an example of what I had in mind:
len(train_loader) = 1250
batch_size = 100
effective_batch_size = 500
so n_virtual_steps = 5, and i ranges from 0 to 12.
In a training epoch, we'll call optimizer.step() when (i+1) == 5 and when (i+1) == 10, to process batches of size 500. Then, the last iterations will process mini-batches of size 100, 100, and 50, and exit the training loop.
So we're left with 3 mini-batches accumulated into param.grad_sample or param.summed_grad.
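A minimal sketch of that loop, assuming the usual model / criterion / train_loader / optimizer names and a privacy engine already attached so that the virtual_step() method added in this PR is exposed on the optimizer (exact API details may differ):

```python
# Numbers from the example above: 1250 samples, mini-batches of 100,
# an effective (virtual) batch of 500, hence 5 virtual steps per real step.
batch_size, effective_batch_size = 100, 500
n_virtual_steps = effective_batch_size // batch_size  # = 5

for i, (data, target) in enumerate(train_loader):      # i ranges from 0 to 12
    loss = criterion(model(data), target)
    loss.backward()                                     # per-sample grads are recorded

    if (i + 1) % n_virtual_steps == 0:
        optimizer.step()          # real step at (i+1) == 5 and (i+1) == 10
        optimizer.zero_grad()
    else:
        optimizer.virtual_step()  # accumulate only; no parameter update
```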
Exactly. So what's happening to those in the next epoch? Am I missing something?
We either have to drop them by calling zero_grad, or handle them at the end of the loop, correct?
Right, so in the next epoch i is reset to 0, so we'll accumulate 5 more mini-batches before we call optimizer.step().
So now we have an effective batch consisting of 8 mini-batches (3 from the previous epoch and 5 from the current epoch), with 750 samples total.
This is problematic for two reasons:
1. It could cause memory issues if we're planning on holding the entire batch in memory, since we were expecting at most 500 gradients.
2. The privacy accounting for that batch is incorrect, because we're assuming a sample probability of effective_batch_size / len(train_loader) = 500/1250, whereas this batch has a sample probability of 750/1250.
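Spelling out the mismatch, with N = 1250 training samples and q the sampling rate assumed by the privacy accountant (just restating the numbers above):

```latex
q_{\text{assumed}} = \frac{\texttt{effective\_batch\_size}}{N} = \frac{500}{1250} = 0.4,
\qquad
q_{\text{actual}} = \frac{750}{1250} = 0.6 .
```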
So my thought is we should err on the conservative side (as epsilon is an upper bound). To do that I'd suggest one of the following:
1. Call step() once at the end of the loop and check inside step() whether there is unfinished business. We could check the accumulation state or, cleaner yet, have a flag that tracks it: the flag is set whenever a virtual step is called, and cleared on a non-virtual (virtual=False) step. (A rough sketch of this option follows below.)
2. Do the same, but call the method finalize() [it just sounds better and less confusing IMO]. For this you could simply promote finalize_batch.
3. Handle everything internally in the engine, counting the number of calls to step and finalizing whenever needed. With this there would be no need for a virtual_step method at all. [I kind of like this one, but I also know it completely contradicts the talks we had offline, so maybe we can do this in another diff later on. Let me know what you think about this :)]
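A rough sketch of the first option, illustrative only: it reuses the loop and names from the earlier example and assumes step() internally checks a pending-accumulation flag set by virtual_step(), so the trailing call is safe even when nothing is left to consume.

```python
for i, (data, target) in enumerate(train_loader):
    loss = criterion(model(data), target)
    loss.backward()
    if (i + 1) % n_virtual_steps == 0:
        optimizer.step()          # real step; would clear the pending flag
        optimizer.zero_grad()
    else:
        optimizer.virtual_step()  # accumulate only; would set the pending flag

# One extra call at the end of the loop: if the flag is set, step() finalizes
# the left-over partial batch; otherwise it would be a no-op.
optimizer.step()
optimizer.zero_grad()
```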
I like that!
Could you make this change as well please (I'd assume it would need some minimal changes in the step function as well)? Then we should be more than ready to land this amazing diff.
Ok. There's still the question of when/where to delete the param.grad_sample fields so that they stop accumulating.
If we want to do that in clipper.step(), we'd have to rewrite a bunch of tests.
I think the most natural thing is still to monkeypatch optimizer.zero_grad() and do the deletion there, as it aligns nicely with the semantics of zero_grad() for accumulation of regular gradients.
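For reference, a hedged sketch of what such a monkeypatch could look like (illustrative only; as the next comment notes, the PR ultimately moved the deletion into clipper.step() instead):

```python
def patch_zero_grad(optimizer):
    """Hypothetical patch: make optimizer.zero_grad() also drop per-sample grads."""
    original_zero_grad = optimizer.zero_grad

    def dp_zero_grad(*args, **kwargs):
        original_zero_grad(*args, **kwargs)       # clear .grad as usual
        for group in optimizer.param_groups:
            for p in group["params"]:
                if hasattr(p, "grad_sample"):
                    del p.grad_sample             # stop accumulating per-sample grads

    optimizer.zero_grad = dp_zero_grad

# usage (illustrative): patch_zero_grad(optimizer)
```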
With the latest commits, the deletions are now handled directly in clipper.step(), so that gradient aggregators are reset as soon as they are consumed.
And we just have to make sure to always take a step on the last mini-batch of a training epoch.
It's true that zero_grad would make more sense, but we save a whole lot of memory by doing it in the clipper, so :)