
Add support for virtual steps that accumulate per-sample or clipped gradients #16

Closed · wants to merge 32 commits

Commits (32):
3572425  accumulate per-sample batches (ftramer, May 13, 2020)
6235399  merge with refactored clipper (ftramer, May 13, 2020)
8df2433  accumulate in autograd (ftramer, May 13, 2020)
8274334  per-sample grad accumulation (ftramer, May 13, 2020)
fa9fec3  accumulate mini-batches in gradient clipper (ftramer, May 14, 2020)
f646ebb  gradient accumulator for virtual batches (ftramer, May 14, 2020)
9dcc569  rename 'accumulate_grads' to 'virtual_step' (ftramer, May 14, 2020)
0da3fbc  typo (ftramer, May 14, 2020)
affcb1b  replace lambda function by functools.partial (ftramer, May 15, 2020)
1ada364  allow for batches that are smaller than expected (ftramer, May 15, 2020)
d716398  graciously handle the zero-privacy case in the analysis (ftramer, May 15, 2020)
68ca40e  keep activations in memory (ftramer, May 15, 2020)
b5d4e64  fix failing test (ftramer, May 15, 2020)
6e241c3  store accumulated gradients in each parameter (ftramer, May 15, 2020)
d95818d  merge changes to privacy engine (ftramer, May 19, 2020)
452ae78  remove the patching on zero_grad and delete accumulated gradients (ftramer, May 19, 2020)
562b2e4  Fix tests that assume that param.grad_sample will still be available (ftramer, May 19, 2020)
f40a0d0  Run black formatter (ftramer, May 19, 2020)
ab10000  Run black formatter (ftramer, May 19, 2020)
f8e2645  effective_batch_size => n_accumulation_steps (ftramer, May 20, 2020)
ed74cc2  Merge branch 'master' of https://github.com/facebookresearch/pytorch-dp (ftramer, May 21, 2020)
c422de5  Merge branch 'master' of https://github.com/facebookresearch/pytorch-dp (ftramer, May 21, 2020)
74a436d  fix comments, linter, warnings (ftramer, May 26, 2020)
eaa87ad  comment and formatting fixes (ftramer, May 27, 2020)
e5ab1e5  comment fixes (ftramer, May 27, 2020)
80f30ed  clarify batch_size retrieval (ftramer, May 27, 2020)
d6314fb  clarify batch_size retrieval (ftramer, May 27, 2020)
9ddacff  clarify virtual batch test cases (ftramer, May 27, 2020)
58e39f6  test that accumulated gradients are erased upon a call to optimizer.s… (ftramer, May 27, 2020)
a8ca4dc  remove un-used zero_grad method (ftramer, May 27, 2020)
c3f6601  implement virtual batches (ftramer, May 27, 2020)
8a6599e  Merge branch 'virtual_batch' (ftramer, May 27, 2020)
22 changes: 18 additions & 4 deletions examples/imagenet.py
@@ -3,7 +3,6 @@

"""
Runs ImageNet training with differential privacy.

"""

import argparse
@@ -48,7 +47,7 @@

# The following lines enable stat gathering for the clipping process
# and set a default of per layer clipping for the Privacy Engine
clipping = {"clip_per_layer": True, "enable_stat": True}
clipping = {"clip_per_layer": False, "enable_stat": True}

parser = argparse.ArgumentParser(description="PyTorch ImageNet DP Training")
parser.add_argument("data", metavar="DIR", help="path to dataset")
@@ -80,6 +79,14 @@
"batch size of all GPUs on the current node when "
"using Data Parallel or Distributed Data Parallel",
)
parser.add_argument(
"-na",
"--n_accumulation_steps",
default=1,
type=int,
metavar="N",
help="number of mini-batches to accumulate into an effective batch for SGD",
)
parser.add_argument(
"--lr",
"--learning-rate",
@@ -371,7 +378,7 @@ def main_worker(gpu, ngpus_per_node, args):
print("PRIVACY ENGINE ON")
privacy_engine = PrivacyEngine(
model,
batch_size=args.batch_size,
batch_size=args.batch_size * args.n_accumulation_steps,
sample_size=len(train_dataset),
alphas=[1 + x / 10.0 for x in range(1, 100)] + list(range(12, 64)),
noise_multiplier=args.sigma,
@@ -455,7 +462,14 @@ def train(train_loader, model, criterion, optimizer, epoch, args):
# compute gradient and do SGD step
optimizer.zero_grad()
loss.backward()
optimizer.step()

if args.n_accumulation_steps > 1:
optimizer.virtual_step()

# make sure we take a step after processing the last mini-batch in the
# epoch to ensure we start the next epoch with a clean state
if ((i + 1) % args.n_accumulation_steps == 0) or ((i + 1) == len(train_loader)):
optimizer.step()
Contributor:
Are we not taking a step if len_data / batch is not divisible by the effective batch?

Contributor Author (@ftramer), May 19, 2020:

Here's an example of what I had in mind:

  • len(train_loader) = 1250
  • batch_size = 100
  • effective_batch_size = 500

so n_virtual_steps = 5, and i ranges from 0 to 12

In a training epoch, we'll call optimizer.step() when (i+1)=5, and when (i+1)=10, to process batches of size 500.
Then the last iterations will accumulate mini-batches of size 100, 100, and 50, and exit the training loop.
So we're left with 3 mini-batches accumulated into param.grad_sample or param.summed_grad.

Contributor:

Exactly. So what happens to those in the next epoch? Am I missing something?
We either have to drop them by calling zero_grad or handle them at the end of the loop, correct?

Contributor Author (@ftramer), May 19, 2020:

Right, so in the next epoch i is reset to 0, so we'll accumulate 5 more mini-batches before we call optimizer.step().
So now we have an effective batch consisting of 8 mini-batches (3 from the previous epoch and 5 from the current epoch), with 750 samples in total.

This is problematic for two reasons:

  1. It could cause memory issues if we're planning on holding the entire batch in memory, and we were expecting at most 500 gradients.

  2. The privacy accounting for that batch is incorrect because we're assuming a sample probability of effective_batch_size / len(train_loader) = 500/1250 whereas this batch has a sample probability of 750/1250.

Contributor:

So my thought is that we should err on the conservative side (as epsilon is an upper bound).
To do that I'd suggest one of the following:
either:
calling step() once at the end of the loop and checking inside step() whether there is unfinished business. We could check the accumulation state or, cleaner yet, have a flag that tracks it: calling a virtual step sets the flag, and a real step (virtual=False) clears it.
or:
do the same, but call the method finalize() [it just sounds better and is less confusing IMO]. For this you could simply promote finalize_batch.
or:
handle everything internally in the engine, counting the number of calls to step and finalizing whenever needed. For this there would be no need for a virtual_step method at all. [I kind of like this one, but I also know it completely contradicts the talks we had offline, so maybe we can do this in another diff later on; let me know what you think :)]

Contributor:

I like that!

Contributor:

Could you make this change as well, please (I'd assume it would need some minimal changes in the step function too), and we should be more than ready to land this amazing diff.

Contributor Author (@ftramer):

Ok. There's still the question of when/where to delete the param.grad_sample fields so that they stop accumulating.
If we want to do that in clipper.step(), we'd have to rewrite a bunch of tests.
I think the most natural thing is still to monkeypatch optimizer.zero_grad() and do this there, as it aligns nicely with the semantics of zero_grad() for accumulation of regular gradients.

Contributor Author (@ftramer):

With the latest commits, the deletions are now handled directly in clipper.step(), so that gradient aggregators are reset as soon as they are consumed.
And we just have to make sure to always take a step on the last mini-batch of a training epoch.

Contributor:

It's true that zero_grad would make more sense, but we save a whole lot of memory by doing it in clipper, so :)


# measure elapsed time
batch_time.update(time.time() - end)
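
Putting the pieces of the diff above back together, the intended loop in train() reads roughly as follows. This is an illustrative sketch, not the exact file contents; it assumes the PrivacyEngine was attached to optimizer with batch_size=args.batch_size * args.n_accumulation_steps, as in the diff.

```python
# Sketch of gradient accumulation with virtual steps (illustrative).
for i, (images, target) in enumerate(train_loader):
    output = model(images)
    loss = criterion(output, target)

    optimizer.zero_grad()
    loss.backward()

    if args.n_accumulation_steps > 1:
        # Virtual step: fold this mini-batch's clipped/per-sample gradients
        # into the accumulated gradients without updating the weights.
        optimizer.virtual_step()

    # Take a real step after every effective batch, and also after the last
    # mini-batch of the epoch so nothing carries over into the next epoch.
    if ((i + 1) % args.n_accumulation_steps == 0) or ((i + 1) == len(train_loader)):
        optimizer.step()
```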
136 changes: 66 additions & 70 deletions torchdp/autograd_grad_sample.py
@@ -6,7 +6,7 @@
Original license is Unlicense. We put it here for user's convenience, with
the author's permission.
"""

from functools import partial
from typing import List

import torch
@@ -23,24 +23,42 @@
_enforce_fresh_backprop: bool = False


def add_hooks(model: nn.Module) -> None:
def add_hooks(model: nn.Module, loss_type: str = "mean", batch_dim: int = 0) -> None:
"""
Adds hooks to model to save activations and backprop values.
The hooks will
1. save activations into param.activations during forward pass
2. append backprops to params.backprops_list during backward pass.
2. compute per-sample gradients in params.grad_sample during backward pass.
Call "remove_hooks(model)" to disable this.
Args:
model:
model: the model to which hooks are added
loss_type: either "mean" or "sum" depending on whether backpropped
loss was averaged or summed over batch (default: "mean")
batch_dim: the batch dimension (default: 0)
"""
if hasattr(model, "autograd_grad_sample_hooks"):
raise ValueError("Trying to add hooks twice to the same model")

global _hooks_disabled
_hooks_disabled = False

if loss_type not in ("sum", "mean"):
raise ValueError(
f"loss_type = {loss_type}. Only 'sum' and 'mean' losses are supported"
)

handles = []
for layer in model.modules():
if get_layer_type(layer) in _supported_layers_grad_samplers.keys():
handles.append(layer.register_forward_hook(_capture_activations))
handles.append(layer.register_backward_hook(_capture_backprops))

handles.append(
layer.register_backward_hook(
partial(
_capture_backprops, loss_type=loss_type, batch_dim=batch_dim
)
)
)

model.__dict__.setdefault("autograd_grad_sample_hooks", []).extend(handles)
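
For orientation, a rough usage sketch of the hooks added above on a toy model. This is a hypothetical example, not part of the diff; it assumes add_hooks is importable from torchdp.autograd_grad_sample, that nn.Linear is among the supported layers, and that per-sample gradients are stored with shape (batch_size, *parameter_shape).

```python
import torch
import torch.nn as nn

from torchdp.autograd_grad_sample import add_hooks

model = nn.Sequential(nn.Linear(10, 5))
add_hooks(model, loss_type="sum", batch_dim=0)

x = torch.randn(4, 10)       # a batch of 4 samples
loss = model(x).sum()        # summed loss, matching loss_type="sum"
loss.backward()

# The forward hook stored the layer's activations; the backward hook turned
# the captured backprops into per-sample gradients in param.grad_sample.
print(model[0].weight.grad_sample.shape)  # expected: torch.Size([4, 5, 10])
```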

@@ -88,82 +106,60 @@ def _capture_activations(
layer.activations = input[0].detach()


def _capture_backprops(layer: nn.Module, _input, output):
"""Append backprop to layer.backprops_list in backward pass."""
global _enforce_fresh_backprop
def _capture_backprops(
layer: nn.Module,
_input: torch.Tensor,
output: torch.Tensor,
loss_type: str,
batch_dim: int,
):
"""Capture backprops in backward pass and store per-sample gradients."""

if _hooks_disabled:
return

if _enforce_fresh_backprop:
if hasattr(layer, "backprops_list"):
raise ValueError(
f"Seeing result of previous backprop, "
f"use clear_backprops(model) to clear"
)
_enforce_fresh_backprop = False

if not hasattr(layer, "backprops_list"):
layer.backprops_list = []
layer.backprops_list.append(output[0].detach())


def clear_backprops(model: nn.Module) -> None:
"""Delete layer.backprops_list in every layer."""
for layer in model.modules():
if hasattr(layer, "backprops_list"):
del layer.backprops_list


def _check_layer_sanity(layer):
if not hasattr(layer, "activations"):
raise ValueError(
f"No activations detected for {type(layer)},"
" run forward after add_hooks(model)"
)
if not hasattr(layer, "backprops_list"):
raise ValueError("No backprops detected, run backward after add_hooks(model)")
if len(layer.backprops_list) != 1:
raise ValueError(
"Multiple backprops detected, make sure to call clear_backprops(model)"
)
backprops = output[0].detach()
_compute_grad_sample(layer, backprops, loss_type, batch_dim)


def compute_grad_sample(
model: nn.Module, loss_type: str = "mean", batch_dim: int = 0
def _compute_grad_sample(
layer: nn.Module, backprops: torch.Tensor, loss_type: str, batch_dim: int
) -> None:
"""
Compute per-example gradients and save them under 'param.grad_sample'.
Must be called after loss.backprop()
Args:
model:
loss_type: either "mean" or "sum" depending whether backpropped
layer: the layer for which per-sample gradients are computed
backprops: the captured backprops
loss_type: either "mean" or "sum" depending on whether backpropped
loss was averaged or summed over batch
batch_dim: the batch dimension
"""
if loss_type not in ("sum", "mean"):
raise ValueError(f"loss_type = {loss_type}. Only 'sum' and 'mean' supported")
for layer in model.modules():
layer_type = get_layer_type(layer)
if (
not requires_grad(layer)
or layer_type not in _supported_layers_grad_samplers.keys()
):
continue

_check_layer_sanity(layer)

A = layer.activations
n = A.shape[batch_dim]
if loss_type == "mean":
B = layer.backprops_list[0] * n
else: # loss_type == 'sum':
B = layer.backprops_list[0]
# rearrange the blob dimensions
if batch_dim != 0:
A = A.permute([batch_dim] + [x for x in range(A.dim()) if x != batch_dim])
B = B.permute([batch_dim] + [x for x in range(B.dim()) if x != batch_dim])
# compute grad sample for individual layers
compute_layer_grad_sample = _supported_layers_grad_samplers.get(
get_layer_type(layer)
layer_type = get_layer_type(layer)
if (
not requires_grad(layer)
or layer_type not in _supported_layers_grad_samplers.keys()
):
return

if not hasattr(layer, "activations"):
raise ValueError(
f"No activations detected for {type(layer)},"
" run forward after add_hooks(model)"
)
compute_layer_grad_sample(layer, A, B)

A = layer.activations
n = A.shape[batch_dim]
if loss_type == "mean":
B = backprops * n
else: # loss_type == 'sum':
B = backprops
# rearrange the blob dimensions
if batch_dim != 0:
A = A.permute([batch_dim] + [x for x in range(A.dim()) if x != batch_dim])
B = B.permute([batch_dim] + [x for x in range(B.dim()) if x != batch_dim])
# compute grad sample for individual layers
compute_layer_grad_sample = _supported_layers_grad_samplers.get(
get_layer_type(layer)
)
compute_layer_grad_sample(layer, A, B)
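
For reference, the samplers looked up in _supported_layers_grad_samplers take the activations A and the rescaled backprops B prepared above and produce per-sample gradients. A rough sketch of what such a sampler could look like for nn.Linear in the plain 2-D case (illustrative only; the samplers shipped with the library may also accumulate into an existing grad_sample when virtual batches are used):

```python
import torch
import torch.nn as nn

def _linear_grad_sample_sketch(layer: nn.Linear, A: torch.Tensor, B: torch.Tensor) -> None:
    # A: activations of shape (batch, in_features)
    # B: backprops of shape (batch, out_features), already rescaled for loss_type
    # Per-sample weight gradient: one outer product of B[n] and A[n] per sample.
    layer.weight.grad_sample = torch.einsum("ni,nj->nij", B, A)
    if layer.bias is not None:
        # The per-sample bias gradient is just the (rescaled) backprop itself.
        layer.bias.grad_sample = B
```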