Skip to content

Commit 8449d3d

Browse files
authored
Merge 8de232d into 7264cd7
2 parents 7264cd7 + 8de232d commit 8449d3d

File tree

2 files changed

+5
-5
lines changed

2 files changed

+5
-5
lines changed

opacus/tests/grad_sample_module_fast_gradient_clipping_test.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -120,7 +120,7 @@ def setUp_data_sequantial(self, size, length, dim):
120120

121121
@given(
122122
size=st.sampled_from([10]),
123-
length=st.sampled_from([1, 10]),
123+
length=st.sampled_from([5]),
124124
dim=st.sampled_from([2]),
125125
)
126126
@settings(deadline=1000000)
@@ -192,12 +192,12 @@ def test_norm_calculation_fast_gradient_clipping(self, size, length, dim):
192192
diff = flat_norms_normal - flat_norms_gc
193193

194194
logging.info(f"Max difference between (vanilla) Opacus and FGC = {max(diff)}")
195-
msg = "Fail: Gradients from vanilla DP-SGD and from fast gradient clipping are different"
195+
msg = "Fail: Per-sample gradient norms from vanilla DP-SGD and from fast gradient clipping are different"
196196
assert torch.allclose(flat_norms_normal, flat_norms_gc, atol=1e-3), msg
197197

198198
@given(
199199
size=st.sampled_from([10]),
200-
length=st.sampled_from([1, 10]),
200+
length=st.sampled_from([5]),
201201
dim=st.sampled_from([2]),
202202
)
203203
@settings(deadline=1000000)

opacus/tests/privacy_engine_test.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -268,7 +268,7 @@ def _compare_to_vanilla(
268268
do_clip=st.booleans(),
269269
do_noise=st.booleans(),
270270
use_closure=st.booleans(),
271-
max_steps=st.sampled_from([1, 4]),
271+
max_steps=st.sampled_from([1, 3]),
272272
)
273273
@settings(suppress_health_check=list(HealthCheck), deadline=None)
274274
def test_compare_to_vanilla(
@@ -660,7 +660,7 @@ def test_checkpoints(
660660

661661
@given(
662662
noise_multiplier=st.floats(0.5, 5.0),
663-
max_steps=st.integers(8, 10),
663+
max_steps=st.integers(3, 5),
664664
secure_mode=st.just(False), # TODO: enable after fixing torchcsprng build
665665
)
666666
@settings(suppress_health_check=list(HealthCheck), deadline=None)

0 commit comments

Comments
 (0)