
Commit 177cd98

Merge 8aac554 into c7d6144 (2 parents: c7d6144 + 8aac554)

File tree: 4 files changed (+489, -0 lines)

4 files changed

+489
-0
lines changed
Lines changed: 155 additions & 0 deletions
@@ -0,0 +1,155 @@
#!/usr/bin/env python3
# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import os
import sys
import unittest

import torch
import torch.distributed as dist
import torch.multiprocessing as mp
import torch.nn as nn
import torch.optim as optim
from opacus.optimizers.ddpoptimizer_fast_gradient_clipping import (
    DistributedDPOptimizerFastGradientClipping,
)
from opacus.utils.adaptive_clipping.adaptive_clipping_utils import (
    PrivacyEngineAdaptiveClipping,
)
from torch.nn.parallel import DistributedDataParallel as DDP
from torch.utils.data import DataLoader, TensorDataset
from torch.utils.data.distributed import DistributedSampler


def setup(rank, world_size):
    if sys.platform == "win32":
        raise ValueError("Windows platform is not supported for this test")
    else:
        os.environ["MASTER_ADDR"] = "localhost"
        os.environ["MASTER_PORT"] = "12355"

        # initialize the process group

        os.environ["RANK"] = str(rank)
        os.environ["WORLD_SIZE"] = str(world_size)
        torch.distributed.init_process_group(
            init_method="env://",
            backend="nccl",
        )


def cleanup():
    dist.destroy_process_group()


class ToyModel(nn.Module):
    def __init__(self):
        super(ToyModel, self).__init__()
        self.net1 = nn.Linear(10, 10)
        self.relu = nn.ReLU()
        self.net2 = nn.Linear(10, 5)

    def forward(self, x):
        return self.net2(self.relu(self.net1(x)))


def demo_basic(rank, weight, world_size, dp):
    torch.manual_seed(world_size)
    batch_size = 32
    setup(rank, world_size)

    # create model and move it to GPU with id rank
    model = ToyModel().to(rank)
    model.net1.weight.data.zero_()
    optimizer = optim.SGD(model.parameters(), lr=1)

    # create dataset
    labels = torch.randn(2 * batch_size, 5).to(rank)
    data = torch.randn(2 * batch_size, 10)
    dataset = TensorDataset(data, labels)

    criterion = nn.CrossEntropyLoss(reduction="mean")

    max_grad_norm = 1e8

    ddp_model = DDP(model, device_ids=[rank])

    privacy_engine = PrivacyEngineAdaptiveClipping()

    sampler = DistributedSampler(
        dataset, num_replicas=world_size, rank=rank, shuffle=False
    )
    data_loader = DataLoader(dataset, batch_size=batch_size, sampler=sampler)

    if dp:
        ddp_model, optimizer, criterion, data_loader = privacy_engine.make_private(
            module=ddp_model,
            optimizer=optimizer,
            criterion=criterion,
            data_loader=data_loader,
            noise_multiplier=0,
            max_grad_norm=max_grad_norm,
            poisson_sampling=False,
            grad_sample_mode="ghost",
            target_unclipped_quantile=1.0,
        )
        assert isinstance(optimizer, DistributedDPOptimizerFastGradientClipping)

    for x, y in data_loader:
        outputs = ddp_model(x.to(rank))
        loss = criterion(outputs, y)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        break

    weight.copy_(model.net1.weight.data.cpu())
    cleanup()


def run_demo(demo_fn, weight, world_size, dp):
    mp.spawn(
        demo_fn,
        args=(weight, world_size, dp),
        nprocs=world_size,
        join=True,
    )


class GradientComputationTestAdaptiveClipping(unittest.TestCase):
    def test_gradient_correct_adaptive(self) -> None:

        # Tests that gradient is the same with DP or without DP in the distributed setting
        n_gpus = torch.cuda.device_count()
        self.assertTrue(
            n_gpus >= 2, f"Need at least 2 gpus but was provided only {n_gpus}."
        )

        weight_dp, weight_nodp = torch.ones(10, 10), torch.ones(10, 10)

        run_demo(
            demo_basic,
            weight_nodp,
            2,
            dp=False,
        )
        run_demo(
            demo_basic,
            weight_dp,
            2,
            dp=True,
        )

        self.assertTrue(torch.allclose(weight_dp, weight_nodp, atol=1e-5, rtol=1e-3))
Lines changed: 53 additions & 0 deletions
@@ -0,0 +1,53 @@
# Adaptive Clipping (with Ghost Clipping)

Adaptive clipping [1] adapts the clipping norm (and amount of noise) during training to a quantile of per-sample gradient norms. It can reduce hyper-parameter tuning efforts and improve model accuracy by injecting less noise.

It is supported with:
- Ghost clipping
- Distributed data parallel training

It is **not** currently supported with:
- Vanilla DP-SGD
- Virtual batch sizes via Batch Memory Manager

## Overview

`PrivacyEngineAdaptiveClipping` is the entry-point for adaptive clipping training. It extends `PrivacyEngine` with additional arguments for adaptive clipping:

* `target_unclipped_quantile`: the quantile of per-sample gradient norms at which to clip (between 0 and 1)
* `min_clipbound`: the minimum allowed clipping norm
* `max_clipbound`: the maximum allowed clipping norm
* `clipbound_learning_rate`: the learning rate for tracking the true quantile
* `max_grad_norm`: the initial clipping norm (used at step 0)

The main hyper-parameter to tune is `target_unclipped_quantile`, which replaces tuning the clipping norm (`max_grad_norm`) in constant clipping DP-SGD. This parameter can be easier to tune, since the search is over a smaller range of values.
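To make the role of these parameters concrete, here is a minimal, illustrative sketch of the geometric clipbound update introduced in [1]. The function name and signature are hypothetical, the privacy noise normally added to the unclipped fraction is omitted, and Opacus's internal implementation may differ:

```python
import math

def update_clipbound(
    clipbound,                      # current clipping norm C_t
    per_sample_norms,               # list of per-sample gradient norms for this step
    target_unclipped_quantile=0.5,  # fraction of samples we want to leave unclipped
    clipbound_learning_rate=0.2,    # step size of the geometric update
    min_clipbound=1.0,
    max_clipbound=1e8,
):
    # Fraction of samples that would *not* be clipped at the current bound.
    unclipped_frac = sum(n <= clipbound for n in per_sample_norms) / len(per_sample_norms)
    # Multiplicative update: shrink the bound when too few samples are clipped,
    # grow it when too many are, so the bound tracks the target quantile.
    clipbound *= math.exp(
        -clipbound_learning_rate * (unclipped_frac - target_unclipped_quantile)
    )
    # Keep the bound inside the allowed range.
    return min(max(clipbound, min_clipbound), max_clipbound)
```

Because the update is multiplicative, the bound can move across orders of magnitude in relatively few steps, so the initial `max_grad_norm` typically only needs to be a rough guess.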
## Example usage

```python
from opacus.utils.adaptive_clipping.adaptive_clipping_utils import PrivacyEngineAdaptiveClipping

# ...
privacy_engine = PrivacyEngineAdaptiveClipping()
model, optimizer, criterion, train_loader = privacy_engine.make_private(
    module=model,
    optimizer=optimizer,
    data_loader=train_loader,
    criterion=criterion,
    noise_multiplier=args.sigma,
    max_grad_norm=10,  # initial clipping norm
    grad_sample_mode="ghost",
    target_unclipped_quantile=0.5,  # key parameter, may need tuning
    min_clipbound=1,  # default value
    max_clipbound=1e8,  # default value
    clipbound_learning_rate=0.2,  # default value, tuning not recommended
)
# ...
```

Note that `grad_sample_mode` must be set to `"ghost"` for adaptive clipping to work.
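As in the multi-GPU test added by this commit, the objects returned by `make_private` are then used like ordinary PyTorch objects; in particular, the returned `criterion` wraps the loss so that the bookkeeping ghost clipping needs happens behind the usual `loss.backward()` call. A sketch of a training step under these assumptions (variable names are illustrative):

```python
# Illustrative training step after make_private(..., grad_sample_mode="ghost");
# `model`, `optimizer`, `criterion`, and `train_loader` are the objects it returned.
for x, y in train_loader:
    outputs = model(x)
    loss = criterion(outputs, y)  # use the wrapped criterion returned by make_private
    optimizer.zero_grad()
    loss.backward()               # per-sample clipping is handled behind this call
    optimizer.step()
```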
## References

[1] Galen Andrew, Om Thakkar, H. Brendan McMahan, Swaroop Ramaswamy, "Differentially Private Learning with Adaptive Clipping", NeurIPS 2021.
Lines changed: 12 additions & 0 deletions
@@ -0,0 +1,12 @@
from .adaptive_clipping_utils import (
    DPLossFastGradientAdaptiveClipping,
    DPTensorFastGradientAdaptiveClipping,
    PrivacyEngineAdaptiveClipping,
)


__all__ = [
    "DPTensorFastGradientAdaptiveClipping",
    "DPLossFastGradientAdaptiveClipping",
    "PrivacyEngineAdaptiveClipping",
]
