Skip to content

Commit 1be5d20

Browse files
author
Igor Shilov
committed
update all tutorials
1 parent f8ce4f1 commit 1be5d20

File tree

7 files changed

+3073
-3070
lines changed

7 files changed

+3073
-3070
lines changed

opacus/tests/batch_memory_manager_test.py

Lines changed: 11 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -62,10 +62,19 @@ def test_basic(
6262
max_grad_norm=1.0,
6363
poisson_sampling=False,
6464
)
65+
max_physical_batch_size = 3
6566
with BatchMemoryManager(
66-
data_loader=data_loader, max_physical_batch_size=3, optimizer=optimizer
67+
data_loader=data_loader,
68+
max_physical_batch_size=max_physical_batch_size,
69+
optimizer=optimizer,
6770
) as new_data_loader:
68-
self.assertEqual(len(data_loader), len(new_data_loader))
71+
self.assertEqual(
72+
len(data_loader), len(data_loader.dataset) // self.batch_size
73+
)
74+
self.assertEqual(
75+
len(new_data_loader),
76+
len(data_loader.dataset) // max_physical_batch_size,
77+
)
6978
weights_before = torch.clone(model._module.fc.weight)
7079
for i, (x, y) in enumerate(new_data_loader):
7180
self.assertTrue(x.shape[0] <= 3)

opacus/utils/batch_memory_manager.py

Lines changed: 13 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -3,8 +3,12 @@
33

44
import numpy as np
55
from opacus.optimizers import DPOptimizer
6-
from torch.utils.data import DataLoader, Sampler, BatchSampler
7-
from opacus.utils.uniform_sampler import UniformWithReplacementSampler, DistributedUniformWithReplacementSampler
6+
from opacus.utils.uniform_sampler import (
7+
DistributedUniformWithReplacementSampler,
8+
UniformWithReplacementSampler,
9+
)
10+
from torch.utils.data import BatchSampler, DataLoader, Sampler
11+
812

913
class BatchSplittingSampler(Sampler[List[int]]):
1014
def __init__(
@@ -27,13 +31,18 @@ def __iter__(self):
2731

2832
def __len__(self):
2933
if isinstance(self.sampler, BatchSampler):
30-
return int(len(self.sampler) * (self.sampler.batch_size / self.max_batch_size))
31-
elif isinstance(self.sampler, UniformWithReplacementSampler) or isinstance(self.sampler, DistributedUniformWithReplacementSampler):
34+
return int(
35+
len(self.sampler) * (self.sampler.batch_size / self.max_batch_size)
36+
)
37+
elif isinstance(self.sampler, UniformWithReplacementSampler) or isinstance(
38+
self.sampler, DistributedUniformWithReplacementSampler
39+
):
3240
expected_batch_size = self.sampler.sample_rate * self.sampler.num_samples
3341
return int(len(self.sampler) * (expected_batch_size / self.max_batch_size))
3442

3543
return len(self.sampler)
3644

45+
3746
def wrap_data_loader(data_loader, max_batch_size: int, optimizer: DPOptimizer):
3847
return DataLoader(
3948
dataset=data_loader.dataset,

opacus/validators/batch_norm.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,14 +2,14 @@
22
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
33

44
import logging
5-
import sys
65
from typing import Union
76

87
import torch.nn as nn
98

109
from .errors import ShouldReplaceModuleError, UnsupportableModuleError
1110
from .utils import register_module_fixer, register_module_validator
1211

12+
1313
logger = logging.getLogger(__name__)
1414

1515
BATCHNORM = Union[nn.BatchNorm1d, nn.BatchNorm2d, nn.BatchNorm3d, nn.SyncBatchNorm]

opacus/validators/module_validator.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,6 @@
22
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
33

44
import logging
5-
import sys
65
from typing import List
76

87
import torch.nn as nn
@@ -13,8 +12,10 @@
1312
UnsupportedModuleError,
1413
)
1514

15+
1616
logger = logging.getLogger(__name__)
1717

18+
1819
class ModuleValidator:
1920
"""
2021
Encapsulates all the validation logic required by Opacus.
@@ -49,7 +50,7 @@ def validate(
4950
# 2. validate that all trainable modules are supported by GradSampleModule.
5051
errors.extend(GradSampleModule.validate(module=module, raise_if_error=False))
5152
# 3. perform module specific validations.
52-
#TODO: use module name here - it's useful part of error message
53+
# TODO: use module name here - it's useful part of error message
5354
for _, sub_module in module.named_modules():
5455
if type(sub_module) in ModuleValidator.VALIDATORS:
5556
sub_module_validator = ModuleValidator.VALIDATORS[type(sub_module)]

tutorials/building_image_classifier.ipynb

Lines changed: 1153 additions & 658 deletions
Large diffs are not rendered by default.

0 commit comments

Comments (0)