-
Notifications
You must be signed in to change notification settings - Fork 492
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
add progress & apply to feature ablation, shapley, & lime based methods #630
Conversation
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Looks great! Thank you for working on this @aobo-y !
I left couple comments.
Do you mind running type hint checker and fix related issues:
./scripts/run_mypy.sh
Also, running isort .
for sorting the imports - just to check that everything is okay.
cc: @vivekmig, @miguelmartin75 , @bilalsal
captum/_utils/progress.py
Outdated
if tqdm and use_tqdm: | ||
return tqdm(iterable, desc=desc, total=total, **kwargs) | ||
else: | ||
return _simple_progress_out(iterable, desc=desc, total=total) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
nit: We might want to show a warning saying that tqdm is not available and we are going to print the progress in the output.
@@ -179,6 +182,11 @@ def attribute( | |||
and use a single feature mask to describe the features | |||
for all examples in the batch. | |||
Default: 1 | |||
show_progress (bool, optional): Displays the progress of computation. |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I think that it would be good to have this for shapley values, shapley values sampling and kernel shap as well.
captum/_utils/progress.py
Outdated
else f"{desc}{'.' * cur}" | ||
) | ||
|
||
print("\r" + progress_str(cur), end="", file=sys.stderr) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
instead of doing printouts perhaps we can have file=sys.stderr
as an input argument and wrap it with something like this?
https://github.com/tqdm/tqdm/blob/bcce20f771a16cb8e4ac5cc5b2307374a2c0e535/tqdm/utils.py#L131
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Thanks for the great work on this @aobo-y ! Just one suggestion on the progress bar per input.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Looks great, thanks for the awesome work on this :) !
As @NarineK mentioned, it would be great to also add this for option for other perturbation based methods (Shapley Value Sampling, Lime, etc.), but we can do that later in a separate diff as well.
@@ -302,25 +310,39 @@ def attribute( | |||
for input in inputs | |||
] | |||
|
|||
if show_progress: |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
nit: Could we potentially move this before the initial eval and also add 1 to the total_forwards to include this evaluation? It doesn't make a large difference, but the progress bar will show up sooner rather than after a full forward pass, which may seem like a delay to the user.
d05b582
to
9cecdd4
Compare
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
@aobo-y has imported this pull request. If you are a Facebook employee, you can view this diff on Phabricator.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
@aobo-y has imported this pull request. If you are a Facebook employee, you can view this diff on Phabricator.
Issue #581
progress(...)
. By default, it try to loadtqdm
. If unavailable, fallback to a naive local implementationFeatureAblation
,FeaturePermutation
,Occlution
unitest.mock
to mockstderr
to test the progress output.black .
andflake8 .
Sample fallback progress output
Sample tqdm progress output