diff --git a/shap/explainers/deep/deep_pytorch.py b/shap/explainers/deep/deep_pytorch.py
index ef276c216..885e78f83 100644
--- a/shap/explainers/deep/deep_pytorch.py
+++ b/shap/explainers/deep/deep_pytorch.py
@@ -351,10 +351,16 @@ def nonlinear_1d(module, grad_input, grad_output):
 op_handler['Conv1d'] = linear_1d
 op_handler['Conv2d'] = linear_1d
 op_handler['Conv3d'] = linear_1d
+op_handler['ConvTranspose1d'] = linear_1d
+op_handler['ConvTranspose2d'] = linear_1d
+op_handler['ConvTranspose3d'] = linear_1d
 op_handler['Linear'] = linear_1d
 op_handler['AvgPool1d'] = linear_1d
 op_handler['AvgPool2d'] = linear_1d
 op_handler['AvgPool3d'] = linear_1d
+op_handler['AdaptiveAvgPool1d'] = linear_1d
+op_handler['AdaptiveAvgPool2d'] = linear_1d
+op_handler['AdaptiveAvgPool3d'] = linear_1d
 op_handler['BatchNorm1d'] = linear_1d
 op_handler['BatchNorm2d'] = linear_1d
 op_handler['BatchNorm3d'] = linear_1d
diff --git a/tests/explainers/test_deep.py b/tests/explainers/test_deep.py
index 2a9bf242f..5468ea4c7 100644
--- a/tests/explainers/test_deep.py
+++ b/tests/explainers/test_deep.py
@@ -258,7 +258,8 @@ def __init__(self):
                 nn.MaxPool2d(2),
                 nn.Tanh(),
                 nn.Conv2d(10, 20, kernel_size=5),
-                nn.MaxPool2d(2),
+                nn.ConvTranspose2d(20, 20, 1),
+                nn.AdaptiveAvgPool2d(output_size=(4, 4)),
                 nn.Softplus(),
             )
             self.fc_layers = nn.Sequential(
@@ -369,12 +370,13 @@ def __init__(self, num_features):
             super(Net, self).__init__()
             self.linear = nn.Linear(num_features // 2, 2)
             self.conv1d = nn.Conv1d(1, 1, 1)
+            self.convt1d = nn.ConvTranspose1d(1, 1, 1)
             self.leaky_relu = nn.LeakyReLU()
-            self.maxpool1 = nn.MaxPool1d(kernel_size=2)
+            self.aapool1d = nn.AdaptiveAvgPool1d(output_size=6)
             self.maxpool2 = nn.MaxPool1d(kernel_size=2)
 
         def forward(self, X):
-            x = self.maxpool1(self.conv1d(X.unsqueeze(1))).squeeze(1)
+            x = self.aapool1d(self.convt1d(self.conv1d(X.unsqueeze(1)))).squeeze(1)
             return self.maxpool2(self.linear(self.leaky_relu(x)).unsqueeze(1)).squeeze(1)
     model = Net(num_features)
     optimizer = torch.optim.Adam(model.parameters())
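
Both module families apply a fixed linear map to their input on every forward pass, so registering them with the existing `linear_1d` handler (the gradient pass-through already used for `Conv*d`, `Linear`, and `AvgPool*d`) is sufficient; no new handler logic is needed. Below is a minimal usage sketch, not part of the patch, showing the newly supported layers inside a model explained with `shap.DeepExplainer`. The `ToyNet` name, layer sizes, and tensor shapes are illustrative assumptions.

```python
# Minimal sketch (illustrative, not from this patch): a toy model using
# the newly supported ConvTranspose2d and AdaptiveAvgPool2d layers.
import torch
import torch.nn as nn
import shap

class ToyNet(nn.Module):
    def __init__(self):
        super().__init__()
        self.conv = nn.Conv2d(1, 8, kernel_size=3)            # already handled
        self.convt = nn.ConvTranspose2d(8, 8, kernel_size=1)  # newly handled
        self.pool = nn.AdaptiveAvgPool2d(output_size=(4, 4))  # newly handled
        self.fc = nn.Linear(8 * 4 * 4, 2)

    def forward(self, x):
        x = self.pool(self.convt(self.conv(x)))
        return self.fc(x.flatten(1))

model = ToyNet().eval()
background = torch.randn(20, 1, 16, 16)  # reference distribution
samples = torch.randn(3, 1, 16, 16)      # inputs to explain

# Before this patch, DeepExplainer reported ConvTranspose2d and
# AdaptiveAvgPool2d as unrecognized nn.Modules during the backward pass.
explainer = shap.DeepExplainer(model, background)
shap_values = explainer.shap_values(samples)  # per-output attributions
```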