
Commit f3bbcb3

Merge 8a6c564 into 9741fe2

File tree

1 file changed: +20 −6 lines changed

opacus/utils/fast_gradient_clipping_utils.py

Lines changed: 20 additions & 6 deletions
@@ -68,7 +68,9 @@ def backward(self):
         reduced_loss.backward(retain_graph=True)
         self.optimizer.zero_grad()
         coeff = self.module.get_clipping_coef()
-        second_loss_per_sample = coeff * self.loss_per_sample
+        second_loss_per_sample = (
+            coeff.to(self.loss_per_sample.device) * self.loss_per_sample
+        )
         second_loss = torch.sum(second_loss_per_sample)
         self.module.disable_hooks()
         second_loss.backward()
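A minimal sketch (not part of the commit) of the device mismatch this hunk guards against: if the clipping coefficients come back on a different device than the per-sample loss, the elementwise product fails, so the coefficients are moved onto the loss tensor's device first. Tensor contents here are illustrative.

import torch

# Per-sample loss typically lives on the model's device (e.g., CUDA);
# the clipping coefficients may be produced on another device (e.g., CPU).
loss_per_sample = torch.rand(4)
coeff = torch.rand(4)

# Before the fix, a cross-device product raises a RuntimeError:
#   second_loss_per_sample = coeff * loss_per_sample
# After the fix, coeff is explicitly moved to the loss tensor's device:
second_loss_per_sample = coeff.to(loss_per_sample.device) * loss_per_sample
second_loss = torch.sum(second_loss_per_sample)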
@@ -104,15 +106,27 @@ def __init__(
         self.loss_reduction = loss_reduction
         self.criterion.reduction = "none"

-    def __call__(self, input, target) -> DPTensorFastGradientClipping:
+    def __call__(self, input, target, shape=None) -> DPTensorFastGradientClipping:
         """
         Redefining the forward function to compute per-sample loss and wrap it in DPTensorFastGradientClipping
         """

-        loss_per_sample = self.criterion(
-            input,
-            target,
-        )
+        loss_per_sample = self.criterion(input, target)
+
+        if shape is not None and loss_per_sample.shape[0] == shape[0] * shape[1]:
+            # Note that the privacy unit for generative NLP tasks is per sequence.
+            # The shape variable is the shape of the logits before flattening, i.e., [batch_size, sequence_length, vocab_size].
+            # This variable is necessary for ghost clipping to work with generative NLP tasks.
+            loss_per_sample = loss_per_sample.view(shape[0], shape[1])  # BxT
+            if self.loss_reduction == "mean":
+                loss_per_sample = loss_per_sample.mean(dim=1)  # B
+            elif self.loss_reduction == "sum":
+                loss_per_sample = loss_per_sample.sum(dim=1)  # B
+            else:
+                raise ValueError(
+                    f"loss_reduction = {self.loss_reduction}. Only 'sum' and 'mean' losses are supported"
+                )
+
         return DPTensorFastGradientClipping(
             self.module, self.optimizer, loss_per_sample, self.loss_reduction
         )
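A rough usage sketch for the new shape argument, assuming a language-modeling setup (all names and shapes below are illustrative, not from the commit): the caller flattens logits to [B*T, V] for the criterion, and the pre-flattening shape lets the wrapper recover per-sequence losses.

import torch
import torch.nn as nn

batch_size, seq_len, vocab_size = 2, 5, 11
logits = torch.randn(batch_size, seq_len, vocab_size)
targets = torch.randint(vocab_size, (batch_size, seq_len))

# reduction="none" yields one loss value per token: shape [B*T].
criterion = nn.CrossEntropyLoss(reduction="none")
loss_per_token = criterion(logits.view(-1, vocab_size), targets.view(-1))

# Equivalent of the new branch: reshape to [B, T] and reduce over the
# time dimension, making the privacy unit the sequence, not the token.
loss_per_sequence = loss_per_token.view(batch_size, seq_len).mean(dim=1)  # [B]

With the wrapped criterion, this corresponds to calling it with the flattened inputs plus shape=logits.shape, which triggers the per-sequence reduction above.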
