@@ -68,7 +68,9 @@ def backward(self):
         reduced_loss.backward(retain_graph=True)
         self.optimizer.zero_grad()
         coeff = self.module.get_clipping_coef()
-        second_loss_per_sample = coeff * self.loss_per_sample
+        second_loss_per_sample = (
+            coeff.to(self.loss_per_sample.device) * self.loss_per_sample
+        )
         second_loss = torch.sum(second_loss_per_sample)
         self.module.disable_hooks()
         second_loss.backward()
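The hunk above guards against a device mismatch: the clipping coefficients may not live on the same device as the per-sample losses, and element-wise multiplication across devices raises an error. A minimal sketch of the situation the added `.to(...)` call handles (the tensors and sizes below are illustrative assumptions, not part of the diff):

import torch

# Illustrative only: per-sample losses on CPU, clipping coefficients possibly on GPU.
loss_per_sample = torch.rand(4)  # CPU tensor
device = "cuda" if torch.cuda.is_available() else "cpu"
coeff = torch.rand(4, device=device)

# coeff * loss_per_sample would raise a RuntimeError when the devices differ;
# moving coeff to the loss tensor's device first is always safe.
scaled = coeff.to(loss_per_sample.device) * loss_per_sample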
@@ -104,15 +106,27 @@ def __init__(
         self.loss_reduction = loss_reduction
         self.criterion.reduction = "none"

-    def __call__(self, input, target) -> DPTensorFastGradientClipping:
+    def __call__(self, input, target, shape=None) -> DPTensorFastGradientClipping:
         """
         Redefining the forward function to compute per-sample loss and wrap it in DPTensorFastGradientClipping
         """

-        loss_per_sample = self.criterion(
-            input,
-            target,
-        )
+        loss_per_sample = self.criterion(input, target)
+
+        if shape is not None and loss_per_sample.shape[0] == shape[0] * shape[1]:
+            # Note that the privacy unit for generative NLP tasks is per sequence.
+            # The shape variable is the shape of the logits before flattening, i.e., [batch_size, sequence_length, vocab_size].
+            # This variable is necessary for ghost clipping to work with generative NLP tasks.
+            loss_per_sample = loss_per_sample.view(shape[0], shape[1])  # B x T
+            if self.loss_reduction == "mean":
+                loss_per_sample = loss_per_sample.mean(dim=1)  # B
+            elif self.loss_reduction == "sum":
+                loss_per_sample = loss_per_sample.sum(dim=1)  # B
+            else:
+                raise ValueError(
+                    f"loss_reduction = {self.loss_reduction}. Only 'sum' and 'mean' losses are supported"
+                )
+
         return DPTensorFastGradientClipping(
             self.module, self.optimizer, loss_per_sample, self.loss_reduction
         )
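The new `shape` argument lets the per-sample criterion re-group a flattened token-level loss back into per-sequence losses, since the privacy unit for generative NLP tasks is the sequence. A hypothetical usage sketch, assuming `model`, `criterion` (an instance of the wrapper modified above), `input_ids`, and `labels` are set up elsewhere and the model returns logits of shape [batch_size, seq_len, vocab_size]:

# Hypothetical generative-LM training step; names below are assumptions, not part of the diff.
logits = model(input_ids)           # [batch_size, seq_len, vocab_size]
B, T, V = logits.shape
loss = criterion(
    logits.view(B * T, V),          # flattened logits, as token-level cross-entropy expects
    labels.view(B * T),             # flattened targets
    shape=(B, T),                   # pre-flattening shape so losses can be viewed as B x T
)
loss.backward()                     # the returned DPTensorFastGradientClipping runs the clipped backward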