Bug in LimeBase example? #905

Closed
th789 opened this issue Mar 19, 2022 · 4 comments

th789 commented Mar 19, 2022

🐛 Bug

I am trying to run the example provided for LimeBase on https://captum.ai/api/lime.html. However, running the example fails with the following error: "TypeError: <lambda>() got an unexpected keyword argument 'kernel_width'" (more info below).

To Reproduce

import torch
import torch.nn as nn
from captum.attr import LimeBase
from captum._utils.models.linear_model import SkLearnLinearModel

class SimpleClassifier(nn.Module):
    def __init__(self):
        super().__init__()
        self.linear = nn.Linear(5, 3)
        self.sigmoid = nn.Sigmoid()
    def forward(self, x):
        out = self.linear(x)
        out = self.sigmoid(out)
        return out

# All code below is copied from the provided example
net = SimpleClassifier()

def similarity_kernel(original_input, perturbed_input, perturbed_interpretable_input, **kwargs):
      # kernel_width will be provided to attribute as a kwarg
      kernel_width = kwargs["kernel_width"]
      l2_dist = torch.norm(original_input - perturbed_input)
      return torch.exp(- (l2_dist**2) / (kernel_width**2))

def perturb_func(original_input, **kwargs):
      return original_input + torch.randn_like(original_input)

input = torch.randn(2, 5)

lime_attr = LimeBase(net,
                     SkLearnLinearModel("linear_model.Ridge"),
                     similarity_func=similarity_kernel,
                     perturb_func=perturb_func,
                     perturb_interpretable_space=False,
                     from_interp_rep_transform=None,
                     to_interp_rep_transform=lambda x: x)

attr_coefs = lime_attr.attribute(input, target=1, kernel_width=1.1) #error message appears after running this line
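In this example, `kernel_width` is only read by `similarity_kernel`, which weights each perturbed sample by exp(-‖x − x'‖² / w²). A dependency-free sketch of the same formula shows how the width controls how quickly a sample's weight decays with distance. Plain Python lists stand in for tensors, and `exponential_kernel` is a hypothetical name, not a Captum function:

```python
import math
from typing import Sequence


def exponential_kernel(original: Sequence[float], perturbed: Sequence[float], kernel_width: float) -> float:
    """Pure-Python mirror of similarity_kernel: exp(-||x - x'||^2 / w^2)."""
    # Squared Euclidean (L2) distance, i.e. torch.norm(...) ** 2 in the example
    sq_dist = sum((a - b) ** 2 for a, b in zip(original, perturbed))
    return math.exp(-sq_dist / kernel_width ** 2)


x = [0.5, -1.0, 2.0, 0.0, 1.0]
near = [0.6, -1.0, 2.0, 0.0, 1.0]   # small perturbation
far = [2.5, 1.0, 0.0, 0.0, 1.0]     # large perturbation

print(exponential_kernel(x, x, kernel_width=1.1))     # 1.0 for an unperturbed sample
print(exponential_kernel(x, near, kernel_width=1.1))  # close to 1
print(exponential_kernel(x, far, kernel_width=1.1))   # close to 0
```

A larger `kernel_width` flattens the kernel, so distant perturbations keep more weight in the fitted linear surrogate.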

Error message:
[Screenshot of the error traceback, not reproduced here]

Expected behavior

The example should run without error, and attr_coefs should contain the feature attributions.

Environment

Describe the environment used for Captum

 - Captum version: 0.5.0
 - Pytorch version: 1.10.0+cu111
 - OS (e.g., Linux): macOS
 - How you installed Captum (`conda`, `pip`, source): `conda` and `pip` (the error occurs whether I install with `conda install captum -c pytorch` or `pip install captum`)
 - Python version: 3.7.12
@vivekmig (Contributor)

Hi @th789, thanks for catching this! It seems that the lambda function provided as to_interp_rep_transform should be changed to accept kwargs. We will update the documentation to correct this.
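The TypeError from the report can be reproduced without Captum. This minimal sketch assumes, consistent with the traceback, that LimeBase forwards the extra keyword arguments given to `attribute` (here `kernel_width`) to `to_interp_rep_transform`. `call_with_kwargs` is a hypothetical stand-in for that call site, not Captum code:

```python
# Hypothetical stand-in for the call site: it forwards every keyword
# argument it received, as LimeBase does with extra attribute() kwargs.
def call_with_kwargs(transform, sample, **kwargs):
    return transform(sample, **kwargs)


strict = lambda x: x              # as in the original example: no **kwargs
tolerant = lambda x, **kwargs: x  # accepts and ignores extra keyword arguments

try:
    call_with_kwargs(strict, [1.0, 2.0], kernel_width=1.1)
except TypeError as err:
    print(err)  # <lambda>() got an unexpected keyword argument 'kernel_width'

print(call_with_kwargs(tolerant, [1.0, 2.0], kernel_width=1.1))  # [1.0, 2.0]
```

Accepting `**kwargs` lets one set of keyword arguments be shared by all user-supplied callbacks, each reading only the keys it needs.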

@th789 (Author)

th789 commented Mar 22, 2022

Hi @vivekmig, thank you so much for the reply! I had tried that, but it led to another error message (see below). Perhaps I misunderstood what you meant?

#(keeping all lines above the same)
def to_interp_rep_transform_none(x, **kwargs):
    return x

lime_attr = LimeBase(net,
                     SkLearnLinearModel("linear_model.Ridge"),
                     similarity_func=similarity_kernel,
                     perturb_func=perturb_func,
                     perturb_interpretable_space=False,
                     from_interp_rep_transform=None,
                     to_interp_rep_transform=to_interp_rep_transform_none)

attr_coefs = lime_attr.attribute(input, target=1, kernel_width=1.1)

[Screenshot of the new error message, not reproduced here]

@vivekmig (Contributor)

Hi @th789, the signature also needs to include a positional argument for the original inputs, based on the expected function signature defined here. It seems this wasn't updated after the function signature changed, sorry about that! You can try out the updated sample in #908.
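A plain-Python sketch of why the second attempt still fails, assuming (as the maintainer describes) that the transform is called with the perturbed sample and the original input as two positional arguments, followed by the forwarded kwargs. `call_to_interp` and `to_interp_rep_transform_fixed` are hypothetical names; check the exact signature against the LimeBase docstring and the sample in #908:

```python
from typing import Any, Callable


# Hypothetical stand-in for LimeBase's call site (assumption, not Captum code):
# perturbed sample and original input passed positionally, then the kwargs.
def call_to_interp(transform: Callable[..., Any], curr_sample: Any,
                   original_input: Any, **kwargs: Any) -> Any:
    return transform(curr_sample, original_input, **kwargs)


def to_interp_rep_transform_none(x, **kwargs):  # second attempt in this thread
    return x


def to_interp_rep_transform_fixed(curr_sample, original_input, **kwargs):
    # Identity transform: the perturbed sample is used as-is as the interpretable input
    return curr_sample


sample, original = [0.3, 1.2], [0.1, 1.0]
try:
    call_to_interp(to_interp_rep_transform_none, sample, original, kernel_width=1.1)
except TypeError as err:
    print(err)  # ... takes 1 positional argument but 2 were given

print(call_to_interp(to_interp_rep_transform_fixed, sample, original, kernel_width=1.1))  # [0.3, 1.2]
```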

In addition to changing the function, you would also need to use a batch size of 1, since LimeBase supports attribution for only one example at a time, unless the forward function returns a batch-wise output or loss.
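If several examples need attributions, one option (a sketch, not something prescribed in this thread) is to loop and pass each row as its own batch of size 1. With torch this would be a slice such as `input[i:i+1]`, which keeps the batch dimension. Below, plain lists stand in for tensors, and `attribute_per_example` and `fake_attribute` are hypothetical helpers, the latter standing in for `lime_attr.attribute`:

```python
from typing import Any, Callable, List, Sequence


def attribute_per_example(attribute: Callable[..., Any],
                          batch: Sequence[Sequence[float]], **kwargs: Any) -> List[Any]:
    """Run a single-example attribution method over a batch, one row at a time."""
    results = []
    for i in range(len(batch)):
        single = batch[i:i + 1]  # slice (not index) so the batch dimension stays, size 1
        results.append(attribute(single, **kwargs))
    return results


def fake_attribute(inputs, target, kernel_width):
    # Stand-in for lime_attr.attribute: refuses anything but a batch of one example
    assert len(inputs) == 1, "expected a batch of size 1"
    return [v * target for v in inputs[0]]


batch = [[0.1, 0.2, 0.3, 0.4, 0.5], [1.0, 0.0, -1.0, 0.0, 2.0]]
print(attribute_per_example(fake_attribute, batch, target=1, kernel_width=1.1))
```

This trades speed for simplicity: each example gets its own set of perturbations and its own surrogate model fit.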

@th789 (Author)

th789 commented Mar 22, 2022

Thank you very much @vivekmig, I really appreciate it!

facebook-github-bot pushed a commit that referenced this issue Mar 22, 2022
Summary:
LimeBase example code did not apply the expected signature, causing issues when running code based on the example, as reported in #905. This fixes the example to run appropriately.

Pull Request resolved: #908

Reviewed By: miguelmartin75

Differential Revision: D35050422

Pulled By: vivekmig

fbshipit-source-id: 9bf1e727c3fa02a8e9dbf5aee00125a59f8a2062