diff --git a/opacus/grad_sample/conv.py b/opacus/grad_sample/conv.py index 16f2d95b..633782f9 100644 --- a/opacus/grad_sample/conv.py +++ b/opacus/grad_sample/conv.py @@ -51,7 +51,7 @@ def compute_conv_grad_sample( return ret # get activations and backprops in shape depending on the Conv layer - if type(layer) == nn.Conv2d: + if type(layer) is nn.Conv2d: activations = unfold2d( activations, kernel_size=layer.kernel_size, @@ -59,7 +59,7 @@ def compute_conv_grad_sample( stride=layer.stride, dilation=layer.dilation, ) - elif type(layer) == nn.Conv1d: + elif type(layer) is nn.Conv1d: activations = activations.unsqueeze(-2) # add the H dimension # set arguments to tuples with appropriate second element if layer.padding == "same": @@ -77,7 +77,7 @@ def compute_conv_grad_sample( stride=(1, layer.stride[0]), dilation=(1, layer.dilation[0]), ) - elif type(layer) == nn.Conv3d: + elif type(layer) is nn.Conv3d: activations = unfold3d( activations, kernel_size=layer.kernel_size, diff --git a/opacus/privacy_engine.py b/opacus/privacy_engine.py index e9f59eb3..0ca6811b 100644 --- a/opacus/privacy_engine.py +++ b/opacus/privacy_engine.py @@ -181,7 +181,7 @@ def _prepare_model( if ( module.batch_first != batch_first or module.loss_reduction != loss_reduction - or type(module) != get_gsm_class(grad_sample_mode) + or type(module) is not get_gsm_class(grad_sample_mode) ): raise ValueError( f"Pre-existing GradSampleModule doesn't match new arguments."