/
dp_optimizer_keras.py
602 lines (504 loc) · 25.9 KB
/
dp_optimizer_keras.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
# Copyright 2020 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Differentially private version of Keras optimizer v2."""
from typing import List, Optional, Type, Union
import warnings
import tensorflow as tf
from tensorflow_privacy.privacy.dp_query import dp_query
from tensorflow_privacy.privacy.dp_query import gaussian_query
from tensorflow_privacy.privacy.dp_query import restart_query
from tensorflow_privacy.privacy.dp_query import tree_aggregation_query
# Type alias for a list of TF tensors and/or variables. Used to annotate the
# `var_list_or_model` argument of the DP-FTRL optimizer factory.
_VarListType = List[Union[tf.Tensor, tf.Variable]]
def _normalize(microbatch_gradient: tf.Tensor,
               num_microbatches: float) -> tf.Tensor:
  """Divides a gradient by the number of microbatches.

  Args:
    microbatch_gradient: Gradient tensor; in this module it is a noised sum of
      per-microbatch gradients.
    num_microbatches: Divisor. It is cast to the dtype of
      `microbatch_gradient` before the division.

  Returns:
    `microbatch_gradient / num_microbatches`, in the gradient's dtype.
  """
  # Cast first so that integer (or int32 tensor) counts divide float grads.
  denominator = tf.cast(num_microbatches, microbatch_gradient.dtype)
  return tf.truediv(microbatch_gradient, denominator)
def make_keras_generic_optimizer_class(
    cls: Type[tf.keras.optimizers.Optimizer]):
  """Returns a differentially private (DP) subclass of `cls`.

  Args:
    cls: Class from which to derive a DP subclass. Should be a subclass of
      `tf.keras.optimizers.legacy.Optimizer`.

  Returns:
    A generic DP-SGD subclass of `cls`, compatible with many DP queries.
  """

  class DPOptimizerClass(cls):  # pylint: disable=empty-docstring,missing-class-docstring
    __doc__ = """Differentially private subclass of class `{base_class}`.

    You can use this as a differentially private replacement for
    `{base_class}`. This optimizer implements a differentially private version
    of the stochastic gradient descent optimizer `cls` using the chosen
    `dp_query.DPQuery` instance.

    When instantiating this optimizer, you need to supply several
    DP-related arguments followed by the standard arguments for
    `{short_base_class}`.

    Examples:

    ```python
    # Create optimizer.
    dp_sum_query = gaussian_query.GaussianSumQuery(l2_norm_clip=1.0, stddev=0.5)
    opt = {dp_keras_class}(dp_sum_query=dp_sum_query, num_microbatches=1,
                           <standard arguments>)
    ```

    When using the optimizer, be sure to pass in the loss as a
    rank-one tensor with one entry for each example.

    The optimizer can be used directly via its `minimize` method, or
    through a Keras `Model`.

    ```python
    # Compute loss as a tensor by using tf.losses.Reduction.NONE.
    # Compute vector of per-example loss rather than its mean over a minibatch.
    loss = tf.keras.losses.CategoricalCrossentropy(
        from_logits=True, reduction=tf.losses.Reduction.NONE)

    # Use optimizer directly.
    opt.minimize(loss, var_list=[var])
    ```

    ```python
    # Compute loss as a tensor by using tf.losses.Reduction.NONE.
    # Compute vector of per-example loss rather than its mean over a minibatch.
    loss = tf.keras.losses.CategoricalCrossentropy(
        from_logits=True, reduction=tf.losses.Reduction.NONE)

    # Use optimizer in a Keras model.
    model = tf.keras.Sequential(...)
    model.compile(optimizer=opt, loss=loss, metrics=['accuracy'])
    model.fit(...)
    ```

    In DP-SGD training, a larger batch size typically helps to achieve better
    privacy/utility tradeoff. However there is typically a maximum batch size
    imposed by hardware.
    This optimizer can emulate large batch sizes on hardware with limited
    memory by accumulating gradients for several steps before actually
    applying them to update model weights.
    Constructor argument `gradient_accumulation_steps` controls the number
    of steps for which gradients are accumulated before updating
    the model weights.

    Below is an example which demonstrates how to use this feature:

    ```python
    # Create optimizer which will be accumulating gradients for 4 steps.
    # and then performing an update of model weights.
    dp_sum_query = gaussian_query.GaussianSumQuery(l2_norm_clip=1.0, stddev=0.5)
    opt = {dp_keras_class}(dp_sum_query=dp_sum_query,
                           num_microbatches=1,
                           gradient_accumulation_steps=4,
                           <standard arguments>)

    # Use optimizer in a regular way.
    # First three calls to opt.minimize won't update model weights and will
    # only accumulate gradients. Model weights will be updated on the fourth
    # call to opt.minimize
    opt.minimize(loss, var_list=[var])
    ```

    Note that when using this feature,
    1. effective batch size is `gradient_accumulation_steps * one_step_batch_size`
      where `one_step_batch_size` is the size of the batch passed to single step
      of the optimizer. Thus user may have to adjust learning rate, weight decay
      and possibly other training hyperparameters accordingly.
    2. effective noise (the noise to be used for privacy computation) is
      `noise_multiplier * sqrt(gradient_accumulation_steps)`, as the optimizer
      adds the noise of `dp_sum_query` at every step. Thus user may have
      to adjust the `noise_multiplier` or the privacy computation.
      Additionally, user may need to adjust the batch size in the data generator,
      or the number of calls to the data generator, depending on the training
      framework used. For example, when using Keras model.fit(...) with a
      user-defined data generator, one may need to make the data generator return
      `one_step_batch_size` examples each time, and scale the `steps_per_epoch`
      by `gradient_accumulation_steps`. This is because the data generator is
      called `steps_per_epoch` times per epoch, and one call only returns
      `one_step_batch_size` (instead of `effective_batch_size`) examples now.
    """.format(
        base_class='tf.keras.optimizers.legacy.' + cls.__name__,
        short_base_class=cls.__name__,
        dp_keras_class='DPKeras' + cls.__name__)

    # The class tf.keras.optimizers.legacy.Optimizer has two methods to compute
    # gradients, `_compute_gradients` and `get_gradients`. The first works
    # with eager execution, while the second runs in graph mode and is used
    # by canned estimators.

    # Internally, DPOptimizerClass keeps the DP query's global state in
    # `self._global_state`. It is created in `__init__` for the
    # `_compute_gradients` path and re-created on first use in `get_gradients`
    # for the graph-mode path. This should be invisible to users of this class.

    def __init__(
        self,
        dp_sum_query: dp_query.DPQuery,
        num_microbatches: Optional[int] = None,
        gradient_accumulation_steps: int = 1,
        *args,  # pylint: disable=keyword-arg-before-vararg, g-doc-args
        **kwargs):
      """Initializes the DPOptimizerClass.

      Args:
        dp_sum_query: `DPQuery` object, specifying differential privacy
          mechanism to use.
        num_microbatches: Number of microbatches into which each minibatch is
          split. Default is `None` which means that number of microbatches is
          equal to batch size (i.e. each microbatch contains exactly one
          example). If `gradient_accumulation_steps` is greater than 1 and
          `num_microbatches` is not `None` then the effective number of
          microbatches is equal to `num_microbatches *
          gradient_accumulation_steps`.
        gradient_accumulation_steps: If greater than 1 then optimizer will be
          accumulating gradients for this number of optimizer steps before
          applying them to update model weights. If this argument is set to 1
          then updates will be applied on each optimizer step.
        *args: These will be passed on to the base class `__init__` method.
        **kwargs: These will be passed on to the base class `__init__` method.
      """
      super().__init__(*args, **kwargs)
      self.gradient_accumulation_steps = gradient_accumulation_steps
      self._num_microbatches = num_microbatches
      self._dp_sum_query = dp_sum_query
      self._was_dp_gradients_called = False
      # We initialize here for `_compute_gradients` because of requirements from
      # the tf.keras.Model API. Specifically, keras models use the
      # `_compute_gradients` method for both eager and graph mode. So,
      # instantiating the state here is necessary to avoid graph compilation
      # issues.
      self._global_state = self._dp_sum_query.initial_global_state()

    def _create_slots(self, var_list):
      """Adds a `grad_acc` slot per variable when accumulation is enabled."""
      super()._create_slots(var_list)  # pytype: disable=attribute-error
      if self.gradient_accumulation_steps > 1:
        for var in var_list:
          self.add_slot(var, 'grad_acc')

    def _prepare_local(self, var_device, var_dtype, apply_state):
      """Adds accumulation coefficients to `apply_state` when enabled.

      `apply_update` is True on every `gradient_accumulation_steps`-th step
      (based on `self.iterations + 1`); `grad_scaler` is
      `1 / gradient_accumulation_steps` in `var_dtype`.
      """
      super()._prepare_local(var_device, var_dtype, apply_state)  # pytype: disable=attribute-error
      if self.gradient_accumulation_steps > 1:
        apply_update = tf.math.equal(
            tf.math.floormod(self.iterations + 1,
                             self.gradient_accumulation_steps), 0)
        grad_scaler = tf.cast(1. / self.gradient_accumulation_steps, var_dtype)
        apply_state[(var_device, var_dtype)].update({
            'apply_update': apply_update,
            'grad_scaler': grad_scaler
        })

    def _resource_apply_dense(self, grad, var, apply_state=None):
      """Applies or accumulates a dense gradient.

      Without accumulation this defers to the base class. With accumulation,
      the scaled gradient is added to the `grad_acc` slot; on an update step
      the accumulated (averaged) gradient is applied and the slot is zeroed.
      """
      if self.gradient_accumulation_steps > 1:
        var_device, var_dtype = var.device, var.dtype.base_dtype
        coefficients = ((apply_state or {}).get((var_device, var_dtype)) or
                        self._fallback_apply_state(var_device, var_dtype))
        grad_acc = self.get_slot(var, 'grad_acc')

        def _update_grad():
          # Explicit two-argument `super` because zero-argument `super()` does
          # not work inside this nested function.
          apply_grad_op = super(DPOptimizerClass, self)._resource_apply_dense(
              grad_acc + grad * coefficients['grad_scaler'], var,
              apply_state)  # pytype: disable=attribute-error
          # Reset the accumulator only after the update has been applied.
          with tf.control_dependencies([apply_grad_op]):
            return grad_acc.assign(
                tf.zeros_like(grad_acc),
                use_locking=self._use_locking,
                read_value=False)

        def _accumulate():
          return grad_acc.assign_add(
              grad * coefficients['grad_scaler'],
              use_locking=self._use_locking,
              read_value=False)

        return tf.cond(coefficients['apply_update'], _update_grad, _accumulate)
      else:
        return super()._resource_apply_dense(grad, var, apply_state)  # pytype: disable=attribute-error

    def _resource_apply_sparse_duplicate_indices(self, *args, **kwargs):
      """Sparse updates; unsupported together with gradient accumulation."""
      if self.gradient_accumulation_steps > 1:
        raise NotImplementedError(
            'Sparse gradients are not supported with large batch emulation.')
      else:
        return super()._resource_apply_sparse_duplicate_indices(*args, **kwargs)  # pytype: disable=attribute-error

    def _resource_apply_sparse(self, *args, **kwargs):
      """Sparse updates; unsupported together with gradient accumulation."""
      if self.gradient_accumulation_steps > 1:
        raise NotImplementedError(
            'Sparse gradients are not supported with large batch emulation.')
      else:
        return super()._resource_apply_sparse(*args, **kwargs)  # pytype: disable=attribute-error

    def _compute_gradients(self, loss, var_list, grad_loss=None, tape=None):
      """DP-SGD version of base class method.

      `grad_loss` is accepted for signature compatibility but is not used.

      Returns:
        A list of `(noised_gradient, variable)` pairs.

      Raises:
        ValueError: If `loss` is a `Tensor` and no `tape` is given.
      """
      self._was_dp_gradients_called = True
      # Compute loss.
      if not callable(loss) and tape is None:
        raise ValueError('`tape` is required when a `Tensor` loss is passed.')

      tape = tape if tape is not None else tf.GradientTape()

      # The per-microbatch reduction must be recorded on the tape so that the
      # jacobian below can differentiate it.
      with tape:
        if callable(loss):
          if not callable(var_list):
            tape.watch(var_list)

          loss = loss()
        if self._num_microbatches is None:
          num_microbatches = tf.shape(input=loss)[0]
        else:
          num_microbatches = self._num_microbatches
        microbatch_losses = tf.reduce_mean(
            tf.reshape(loss, [num_microbatches, -1]), axis=1)

        if callable(var_list):
          var_list = var_list()

      var_list = tf.nest.flatten(var_list)

      sample_params = (
          self._dp_sum_query.derive_sample_params(self._global_state))

      # Compute the per-microbatch losses using helpful jacobian method.
      with tf.keras.backend.name_scope(self._name + '/gradients'):
        jacobian_per_var = tape.jacobian(
            microbatch_losses, var_list, unconnected_gradients='zero')

        def process_microbatch(sample_state, microbatch_jacobians):
          """Process one microbatch (record) with privacy helper."""
          sample_state = self._dp_sum_query.accumulate_record(
              sample_params, sample_state, microbatch_jacobians)
          return sample_state

        sample_state = self._dp_sum_query.initial_sample_state(var_list)

        def body_fn(idx, sample_state):
          microbatch_jacobians_per_var = [
              jacobian[idx] for jacobian in jacobian_per_var
          ]
          sample_state = process_microbatch(sample_state,
                                            microbatch_jacobians_per_var)
          return tf.add(idx, 1), sample_state

        cond_fn = lambda idx, _: tf.less(idx, num_microbatches)
        idx = tf.constant(0)
        _, sample_state = tf.while_loop(cond_fn, body_fn, [idx, sample_state])

        grad_sums, self._global_state, _ = (
            self._dp_sum_query.get_noised_result(sample_state,
                                                 self._global_state))
        final_grads = tf.nest.map_structure(_normalize, grad_sums,
                                            [num_microbatches] * len(grad_sums))

      return list(zip(final_grads, var_list))

    def get_gradients(self, loss, params):
      """DP-SGD version of base class method.

      Returns:
        The noised, microbatch-averaged gradients for `params` (gradients only,
        without the variables).
      """
      if not self._was_dp_gradients_called:
        # We create the global state here due to tf.Estimator API requirements,
        # specifically, that instantiating the global state outside this
        # function leads to graph compilation errors of attempting to capture an
        # EagerTensor.
        self._global_state = self._dp_sum_query.initial_global_state()
        self._was_dp_gradients_called = True

      # This code mostly follows the logic in the original DPOptimizerClass
      # in dp_optimizer.py, except that this returns only the gradients,
      # not the gradients and variables.
      if self._num_microbatches is None:
        num_microbatches = tf.shape(input=loss)[0]
      else:
        num_microbatches = self._num_microbatches

      microbatch_losses = tf.reshape(loss, [num_microbatches, -1])
      sample_params = (
          self._dp_sum_query.derive_sample_params(self._global_state))

      def process_microbatch(i, sample_state):
        """Process one microbatch (record) with privacy helper."""
        mean_loss = tf.reduce_mean(
            input_tensor=tf.gather(microbatch_losses, [i]))
        grads = tf.gradients(mean_loss, params)
        sample_state = self._dp_sum_query.accumulate_record(
            sample_params, sample_state, grads)
        return sample_state

      sample_state = self._dp_sum_query.initial_sample_state(params)

      def body_fn(idx, sample_state):
        sample_state = process_microbatch(idx, sample_state)
        return tf.add(idx, 1), sample_state

      cond_fn = lambda idx, _: tf.less(idx, num_microbatches)
      idx = tf.constant(0)
      _, sample_state = tf.while_loop(cond_fn, body_fn, [idx, sample_state])

      grad_sums, self._global_state, _ = (
          self._dp_sum_query.get_noised_result(sample_state,
                                               self._global_state))

      final_grads = tf.nest.map_structure(_normalize, grad_sums,
                                          [num_microbatches] * len(grad_sums))

      return final_grads

    def get_config(self):
      """Returns the config of the optimizer.

      An optimizer config is a Python dictionary (serializable)
      containing the configuration of an optimizer.
      The same optimizer can be reinstantiated later
      (without any saved state) from this configuration.

      Returns:
        Python dictionary.
      """
      config = super().get_config()
      # The below is necessary to ensure that the global state can be serialized
      # by JSON serializers inside of tensorflow saving.
      global_state_as_numpy = tf.nest.map_structure(lambda x: x.numpy(),
                                                    self._global_state)
      # NOTE(review): `_asdict()` assumes the query's global state is a
      # namedtuple — confirm for every DPQuery used with this class.
      config.update({
          'global_state': global_state_as_numpy._asdict(),
          'num_microbatches': self._num_microbatches,
      })
      return config

    def apply_gradients(self, *args, **kwargs):
      """DP-SGD version of base class method.

      Raises:
        AssertionError: If neither `_compute_gradients` nor `get_gradients`
          was called first, i.e. the gradients were not privatized.
      """
      # Raise explicitly rather than via `assert`, so the privacy check is not
      # stripped when Python runs with `-O`.
      if not self._was_dp_gradients_called:
        raise AssertionError(
            'Neither _compute_gradients() or get_gradients() on the '
            'differentially private optimizer was called. This means the '
            'training is not differentially private. It may be the case that '
            'you need to upgrade to TF 2.4 or higher to use this particular '
            'optimizer.')
      return super().apply_gradients(*args, **kwargs)

  return DPOptimizerClass
def make_gaussian_query_optimizer_class(cls):
  """Returns a differentially private optimizer using the `GaussianSumQuery`.

  Args:
    cls: `DPOptimizerClass`, the output of
      `make_keras_generic_optimizer_class`.

  Returns:
    A callable that builds an instance of `cls` configured with a
    `GaussianSumQuery`, i.e. the canonical DP-SGD implementation.
  """

  def return_gaussian_query_optimizer(
      l2_norm_clip: float,
      noise_multiplier: float,
      num_microbatches: Optional[int] = None,
      gradient_accumulation_steps: int = 1,
      *args,  # pylint: disable=keyword-arg-before-vararg, g-doc-args
      **kwargs):
    """Returns a `DPOptimizerClass` `cls` using the `GaussianSumQuery`.

    This function is a thin wrapper around
    `make_keras_generic_optimizer_class.<locals>.DPOptimizerClass` which can be
    used to apply a `GaussianSumQuery` to any `DPOptimizerClass`.

    When combined with stochastic gradient descent, this creates the canonical
    DP-SGD algorithm of "Deep Learning with Differential Privacy"
    (see https://arxiv.org/abs/1607.00133).

    When instantiating this optimizer, you need to supply several
    DP-related arguments followed by the standard arguments for the wrapped
    base Keras optimizer.

    As an example, see the below or the documentation of the DPOptimizerClass.

    ```python
    # Create optimizer.
    opt = DPKerasSGDOptimizer(l2_norm_clip=1.0, noise_multiplier=0.5,
                              num_microbatches=1, <standard arguments>)
    ```

    Args:
      l2_norm_clip: Clipping norm (max L2 norm of per microbatch gradients).
      noise_multiplier: Ratio of the standard deviation to the clipping norm.
      num_microbatches: Number of microbatches into which each minibatch is
        split. Default is `None` which means that number of microbatches is
        equal to batch size (i.e. each microbatch contains exactly one example).
        If `gradient_accumulation_steps` is greater than 1 and
        `num_microbatches` is not `None` then the effective number of
        microbatches is equal to `num_microbatches *
        gradient_accumulation_steps`.
      gradient_accumulation_steps: If greater than 1 then optimizer will be
        accumulating gradients for this number of optimizer steps before
        applying them to update model weights. If this argument is set to 1 then
        updates will be applied on each optimizer step.
      *args: These will be passed on to the base class `__init__` method.
      **kwargs: These will be passed on to the base class `__init__` method.

    Returns:
      An instance of `cls`.
    """
    # Noise standard deviation is expressed relative to the clipping norm.
    dp_sum_query = gaussian_query.GaussianSumQuery(
        l2_norm_clip, l2_norm_clip * noise_multiplier)
    # Pass the DP arguments positionally: with keyword DP arguments, any extra
    # positional `args` would bind to `dp_sum_query` first and raise
    # "got multiple values for argument". Positionally, `args` flow into
    # `cls.__init__`'s `*args` and on to the base optimizer.
    return cls(
        dp_sum_query,
        num_microbatches,
        gradient_accumulation_steps,
        *args,
        **kwargs)

  return return_gaussian_query_optimizer
def make_dpftrl_tree_aggregation_optimizer_class(cls):
  """Returns a differentially private follow-the-regularized-leader optimizer.

  Args:
    cls: `DPOptimizerClass`, the output of
      `make_keras_generic_optimizer_class`.

  Returns:
    A callable that builds an instance of `cls` configured with a
    `TreeResidualSumQuery` (optionally wrapped in a `RestartQuery`).
  """

  def return_dpftrl_tree_aggregation_optimizer(
      l2_norm_clip: float,
      noise_multiplier: float,
      var_list_or_model: Union[_VarListType, tf.keras.Model],
      num_microbatches: Optional[int] = None,
      gradient_accumulation_steps: int = 1,
      restart_period: Optional[int] = None,
      restart_warmup: Optional[int] = None,
      noise_seed: Optional[int] = None,
      *args,  # pylint: disable=keyword-arg-before-vararg, g-doc-args
      **kwargs):
    """Returns a `DPOptimizerClass` `cls` using the `TreeAggregationQuery`.

    Combining this query with a SGD optimizer can be used to implement the
    DP-FTRL algorithm in
    "Practical and Private (Deep) Learning without Sampling or Shuffling".

    This function is a thin wrapper around
    `make_keras_generic_optimizer_class.<locals>.DPOptimizerClass` which can be
    used to apply a `TreeAggregationQuery` to any `DPOptimizerClass`.

    Args:
      l2_norm_clip: Clipping norm (max L2 norm of per microbatch gradients).
      noise_multiplier: Ratio of the standard deviation to the clipping norm.
      var_list_or_model: Either a tf.keras.Model or a list of tf.variables from
        which `tf.TensorSpec`s can be defined. These specify the structure and
        shapes of records (gradients).
      num_microbatches: Number of microbatches into which each minibatch is
        split. Default is `None` which means that number of microbatches is
        equal to batch size (i.e. each microbatch contains exactly one example).
        If `gradient_accumulation_steps` is greater than 1 and
        `num_microbatches` is not `None` then the effective number of
        microbatches is equal to `num_microbatches *
        gradient_accumulation_steps`.
      gradient_accumulation_steps: If greater than 1 then optimizer will be
        accumulating gradients for this number of optimizer steps before
        applying them to update model weights. If this argument is set to 1 then
        updates will be applied on each optimizer step.
      restart_period: (Optional) Restart will occur after `restart_period`
        steps. The default (None) means there will be no periodic restarts.
        Must be a positive integer. If `restart_warmup` is passed, this only
        applies to the second restart and onwards and must be not None.
      restart_warmup: (Optional) The first restart will occur after
        `restart_warmup` steps. The default (None) means no warmup. Must be an
        integer in the range [1, `restart_period` - 1].
      noise_seed: (Optional) Integer seed for the Gaussian noise generator. If
        `None`, a nondeterministic seed based on system time will be generated.
      *args: These will be passed on to the base class `__init__` method.
      **kwargs: These will be passed on to the base class `__init__` method.

    Returns:
      An instance of `cls`.

    Raises:
      ValueError: If restart_warmup is not None and restart_period is None.
    """
    if restart_warmup is not None and restart_period is None:
      raise ValueError(
          '`restart_period` was None when `restart_warmup` was not None.')

    # The query needs the structure/shapes of the gradients it will receive.
    if isinstance(var_list_or_model, tf.keras.layers.Layer):
      model_trainable_specs = tf.nest.map_structure(
          lambda t: tf.TensorSpec(t.shape),
          var_list_or_model.trainable_variables)
    else:
      model_trainable_specs = tf.nest.map_structure(
          lambda t: tf.TensorSpec(tf.shape(t)), var_list_or_model)

    # Build the tree-aggregation query once; wrap it for periodic restarts
    # only when a restart period is requested.
    tree_aggregation_sum_query = (
        tree_aggregation_query.TreeResidualSumQuery.build_l2_gaussian_query(
            l2_norm_clip, noise_multiplier, model_trainable_specs,
            noise_seed))
    if restart_period is not None:
      restart_indicator = restart_query.PeriodicRoundRestartIndicator(
          period=restart_period, warmup=restart_warmup)
      tree_aggregation_sum_query = restart_query.RestartQuery(
          tree_aggregation_sum_query, restart_indicator)

    # Pass the DP arguments positionally so extra positional `args` reach the
    # base optimizer instead of colliding with the `dp_sum_query` parameter.
    return cls(
        tree_aggregation_sum_query,
        num_microbatches,
        gradient_accumulation_steps,
        *args,
        **kwargs)

  return return_dpftrl_tree_aggregation_optimizer
def make_keras_optimizer_class(cls: Type[tf.keras.optimizers.Optimizer]):
  """Returns a differentially private optimizer using the `GaussianSumQuery`.

  For backwards compatibility, we create this symbol to match the previous
  output of `make_keras_optimizer_class` but using the new logic.

  Deprecated: prefer `make_gaussian_query_optimizer_class` applied to the
  output of `make_keras_generic_optimizer_class`.

  Args:
    cls: Class from which to derive a DP subclass. Should be a subclass of
      `tf.keras.optimizers.legacy.Optimizer`.

  Returns:
    The callable produced by `make_gaussian_query_optimizer_class` for the
    generic DP subclass of `cls`.
  """
  # `stacklevel=2` attributes the warning to the caller's line instead of this
  # module, so users can locate the deprecated call.
  warnings.warn(
      '`make_keras_optimizer_class` will be deprecated on 2023-02-23. '
      'Please switch to `make_gaussian_query_optimizer_class` and the '
      'generic optimizers (`make_keras_generic_optimizer_class`).',
      stacklevel=2)
  generic_dp_cls = make_keras_generic_optimizer_class(cls)
  return make_gaussian_query_optimizer_class(generic_dp_cls)
# Generic DP optimizers: each takes an arbitrary `dp_sum_query` argument.
GenericDPAdagradOptimizer = make_keras_generic_optimizer_class(
    tf.keras.optimizers.legacy.Adagrad)
GenericDPAdamOptimizer = make_keras_generic_optimizer_class(
    tf.keras.optimizers.legacy.Adam)
GenericDPSGDOptimizer = make_keras_generic_optimizer_class(
    tf.keras.optimizers.legacy.SGD)

# DP-FTRL: generic DP-SGD combined with a tree-aggregation query.
DPFTRLTreeAggregationOptimizer = (
    make_dpftrl_tree_aggregation_optimizer_class(GenericDPSGDOptimizer))

# We keep the same names for backwards compatibility.
# These take `l2_norm_clip`/`noise_multiplier` and build a GaussianSumQuery.
DPKerasAdagradOptimizer = make_gaussian_query_optimizer_class(
    GenericDPAdagradOptimizer)
DPKerasAdamOptimizer = make_gaussian_query_optimizer_class(
    GenericDPAdamOptimizer)
DPKerasSGDOptimizer = make_gaussian_query_optimizer_class(GenericDPSGDOptimizer)