Skip to content

Commit

Permalink
Merge pull request #609 from gabrieltseng/pytorch/improvements
Browse files Browse the repository at this point in the history
Add Adaptive Average Pooling and Transposed Convolutions to the Pytorch Deep Explainer
  • Loading branch information
slundberg committed Jun 7, 2019
2 parents 7269d02 + 2ce9aaa commit 9866ae7
Show file tree
Hide file tree
Showing 2 changed files with 11 additions and 3 deletions.
6 changes: 6 additions & 0 deletions shap/explainers/deep/deep_pytorch.py
Original file line number Diff line number Diff line change
Expand Up @@ -351,10 +351,16 @@ def nonlinear_1d(module, grad_input, grad_output):
op_handler['Conv1d'] = linear_1d
op_handler['Conv2d'] = linear_1d
op_handler['Conv3d'] = linear_1d
op_handler['ConvTranspose1d'] = linear_1d
op_handler['ConvTranspose2d'] = linear_1d
op_handler['ConvTranspose3d'] = linear_1d
op_handler['Linear'] = linear_1d
op_handler['AvgPool1d'] = linear_1d
op_handler['AvgPool2d'] = linear_1d
op_handler['AvgPool3d'] = linear_1d
op_handler['AdaptiveAvgPool1d'] = linear_1d
op_handler['AdaptiveAvgPool2d'] = linear_1d
op_handler['AdaptiveAvgPool3d'] = linear_1d
op_handler['BatchNorm1d'] = linear_1d
op_handler['BatchNorm2d'] = linear_1d
op_handler['BatchNorm3d'] = linear_1d
Expand Down
8 changes: 5 additions & 3 deletions tests/explainers/test_deep.py
Original file line number Diff line number Diff line change
Expand Up @@ -258,7 +258,8 @@ def __init__(self):
nn.MaxPool2d(2),
nn.Tanh(),
nn.Conv2d(10, 20, kernel_size=5),
nn.MaxPool2d(2),
nn.ConvTranspose2d(20, 20, 1),
nn.AdaptiveAvgPool2d(output_size=(4, 4)),
nn.Softplus(),
)
self.fc_layers = nn.Sequential(
Expand Down Expand Up @@ -369,12 +370,13 @@ def __init__(self, num_features):
super(Net, self).__init__()
self.linear = nn.Linear(num_features // 2, 2)
self.conv1d = nn.Conv1d(1, 1, 1)
self.convt1d = nn.ConvTranspose1d(1, 1, 1)
self.leaky_relu = nn.LeakyReLU()
self.maxpool1 = nn.MaxPool1d(kernel_size=2)
self.aapool1d = nn.AdaptiveAvgPool1d(output_size=6)
self.maxpool2 = nn.MaxPool1d(kernel_size=2)

def forward(self, X):
x = self.maxpool1(self.conv1d(X.unsqueeze(1))).squeeze(1)
x = self.aapool1d(self.convt1d(self.conv1d(X.unsqueeze(1)))).squeeze(1)
return self.maxpool2(self.linear(self.leaky_relu(x)).unsqueeze(1)).squeeze(1)
model = Net(num_features)
optimizer = torch.optim.Adam(model.parameters())
Expand Down

0 comments on commit 9866ae7

Please sign in to comment.