[Explainability Evaluation] Add accuracy metrics for evaluation with groundtruth #6137
Conversation
Codecov Report
```diff
@@            Coverage Diff             @@
##           master    #6137      +/-   ##
==========================================
+ Coverage   84.52%   84.54%    +0.02%
==========================================
  Files         376      377        +1
  Lines       20906    20940       +34
==========================================
+ Hits        17670    17703       +33
- Misses       3236     3237        +1
```
Thank you @shhs29 for this PR! Here are a few high-level comments:

- Please move the contents of `explain/evaluate/accuracy_metrics.py` into `explain/metrics.py`.
- Let's try to refrain from using networkx to compute these metrics. Given that you can access the masks directly from the `Explanation`, using `torch` for most of your computations will be much more efficient.
- Moreover, try to use `torchmetrics` to evaluate ROC.
- Don't forget about thresholding masks!

Other things to help you out: you can add this `masks` property to the `Explanation` class:
```python
@property
def masks(self) -> dict[str, Tensor]:
    r"""Returns a dictionary of all masks available in the explanation."""
    mask_dict = {
        key: self[key]
        for key in self.keys
        if key.endswith('_mask') and self[key] is not None
    }
    return dict(sorted(mask_dict.items()))
```
This will allow you to access all the available masks in the explanation as `explanation.masks`. You can then combine all the masks in the returned dictionary into a single tensor with `.view(-1)` and `torch.cat`. This will make it easy to work with all masks at once; afterwards you can calculate the number of true positives as `torch.sum(gt_mask_tensor == ex_mask_tensor)` and proceed similarly for the other metrics. Also, with these improvements you should be left with considerably less code, so there is no need to split the function into two parts; let's just have `groundtruth_metrics(explanation: Explanation, groundtruth: Explanation)`. A minimal sketch of this approach follows.
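To make the suggestion concrete, here is a minimal, self-contained sketch using the tensor names from this thread; the mask dictionaries are illustrative stand-ins for `explanation.masks` and `groundtruth.masks`, not code from the PR:

```python
import torch

# Stand-ins for `explanation.masks` and `groundtruth.masks`:
ex_masks = {'edge_mask': torch.tensor([0.9, 0.0, 0.4]),
            'node_mask': torch.tensor([0.7, 0.0])}
gt_masks = {'edge_mask': torch.tensor([1.0, 0.0, 1.0]),
            'node_mask': torch.tensor([1.0, 0.0])}

# Flatten each mask with .view(-1) and concatenate into one tensor:
ex_mask_tensor = torch.cat([m.view(-1) for m in ex_masks.values()])
gt_mask_tensor = torch.cat([m.view(-1) for m in gt_masks.values()])

# Hard-threshold (> 0) to binary masks, then count confusion-matrix entries:
ex_bin = (ex_mask_tensor > 0).long()
gt_bin = (gt_mask_tensor > 0).long()
tp = int(((ex_bin == 1) & (gt_bin == 1)).sum())  # 3
fp = int(((ex_bin == 1) & (gt_bin == 0)).sum())  # 0
tn = int(((ex_bin == 0) & (gt_bin == 0)).sum())  # 2
fn = int(((ex_bin == 0) & (gt_bin == 1)).sum())  # 0
```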
Hi @BlazStojanovic,
Thank you @shhs29 and @venomouscyanide, this is looking much better. Left some more comments for you to address, otherwise we are close to closing this issue :)
test/explain/test_explanation.py
Outdated
```python
def test_masks(data, node_mask, edge_mask, node_feat_mask, edge_feat_mask):
    expected = []
    if node_mask:
        expected.append('node_mask')
    if edge_mask:
        expected.append('edge_mask')
    if node_feat_mask:
        expected.append('node_feat_mask')
    if edge_feat_mask:
        expected.append('edge_feat_mask')

    explanation = create_random_explanation(
        data,
        node_mask=node_mask,
        edge_mask=edge_mask,
        node_feat_mask=node_feat_mask,
        edge_feat_mask=edge_feat_mask,
    )
```
Let's also test that the right mask values are included in the `masks` property:
Suggested change:

```python
@pytest.mark.parametrize('node_mask', [True, False])
@pytest.mark.parametrize('edge_mask', [True, False])
@pytest.mark.parametrize('node_feat_mask', [True, False])
@pytest.mark.parametrize('edge_feat_mask', [True, False])
def test_masks(data, node_mask, edge_mask, node_feat_mask, edge_feat_mask):
    explanation = create_random_explanation(
        data,
        node_mask=node_mask,
        edge_mask=edge_mask,
        node_feat_mask=node_feat_mask,
        edge_feat_mask=edge_feat_mask,
    )

    expected_keys = []
    expected_values = []
    if node_mask:
        expected_keys.append('node_mask')
        expected_values.append(explanation.node_mask)
    if edge_mask:
        expected_keys.append('edge_mask')
        expected_values.append(explanation.edge_mask)
    if node_feat_mask:
        expected_keys.append('node_feat_mask')
        expected_values.append(explanation.node_feat_mask)
    if edge_feat_mask:
        expected_keys.append('edge_feat_mask')
        expected_values.append(explanation.edge_feat_mask)

    assert set(explanation.masks.keys()) == set(expected_keys)
    assert set(explanation.masks.values()) == set(expected_values)
```
torch_geometric/explain/metrics.py
Outdated
```python
from torch_geometric.explain import Explanation


def groundtruth_metrics(
```
Suggested change:

```diff
-def groundtruth_metrics(
+def get_groundtruth_metrics(
```
test/explain/test_metrics.py
Outdated
```python
assert accuracy_metrics[0] == 1.0
assert accuracy_metrics[1] == 1.0
assert accuracy_metrics[2] == 1.0
assert accuracy_metrics[3] == 1.0
assert accuracy_metrics[4] == 0.5
```
Have you run the tests to see whether these assertions hold? To me it is not immediately obvious that two random explanations will result in these values.
Yes, you are right; random explanations are not a good way to set up tests. I have updated the test to use hardcoded mask values. A few questions that I had in mind:

1. What should the return type of these metrics be? Should they be tensors or floats?
2. Should we round off the different accuracy values?
3. Should we check for division by 0 (i.e., when TP and TN are 0)?
1. Great, using hardcoded masks here is a better way to test. Just a note when writing such tests in the future: always make sure to verify the outcomes of hardcoded tests independently (i.e., by hand or with another library).
2. Probably returning `Tensor`s is better; torch functions return them anyway, so you don't need to change anything but the type hint.
3. I think we shouldn't round anything; let the user do this afterwards if they so choose.
4. Yes, this is a very good point! We need to handle edge cases; let's handle them in a way consistent with `sklearn` (see the sketch after this list):
   - when `true positive + false positive == 0`, precision returns 0
   - when `true positive + false negative == 0`, recall returns 0
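A hedged sketch of the zero-division handling agreed on in this thread, including the `f1_score` condition mentioned below (the counts are illustrative, and the exact structure of the PR's code may differ):

```python
import torch

# Illustrative confusion-matrix counts for a degenerate case:
tp = torch.tensor(0)
fp = torch.tensor(0)
fn = torch.tensor(4)

# sklearn-style zero-division handling: return 0 instead of NaN.
precision = tp / (tp + fp) if (tp + fp) > 0 else torch.tensor(0.)
recall = tp / (tp + fn) if (tp + fn) > 0 else torch.tensor(0.)
f1_score = (2 * precision * recall / (precision + recall)
            if precision > 0 and recall > 0 else torch.tensor(0.))
```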
@BlazStojanovic Thanks for the detailed comment. The status of each of these items is as follows:

1. I have verified all the values, except for AUROC, against the `sklearn` metrics. I need some clarification on AUROC. Based on my analysis of the `sklearn` AUC score, its equivalent in `torchmetrics` is:
   ```python
   auroc = AUROC(task="binary")
   auc = auroc(ex_mask_tensor, gt_mask_tensor)
   ```
   Moreover, looking at their example of AUROC calculation, we do not need to compute the ROC thresholds first. However, I would appreciate your thoughts on this (a small cross-check sketch follows this list).
2. I have updated the type hint.
3. I am currently not rounding the values.
4. I have added a couple of new conditions and a new test for this flow. In addition to the conditions you mentioned, I added one more condition for `f1_score`: it is set to 0 when `precision == 0.0 or recall == 0.0`.
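To check the claimed equivalence between `torchmetrics` and `sklearn`, a small hedged cross-check sketch (the input tensors are illustrative, not from the PR):

```python
import torch
from torchmetrics import AUROC
from sklearn.metrics import roc_auc_score

gt_mask_tensor = torch.tensor([1, 0, 1, 1, 0])            # ground-truth labels
ex_mask_tensor = torch.tensor([0.9, 0.2, 0.7, 0.4, 0.1])  # soft explanation scores

auroc = AUROC(task="binary")
print(auroc(ex_mask_tensor, gt_mask_tensor))          # tensor(1.)
print(roc_auc_score(gt_mask_tensor, ex_mask_tensor))  # 1.0
```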
torch_geometric/explain/metrics.py
Outdated
```python
def groundtruth_metrics(
        explanation: Explanation,
        groundtruth: Explanation) -> Tuple[float, float, float, float, float]:
    """accuracy_scores: Compute accuracy scores when
```
Describe the thresholding behaviour in the docstring, i.e. the >0 thresholding used to compute TP and FP.
Also explain the order of the returned metrics.
torch_geometric/explain/metrics.py
Outdated
```python
accuracy = (tp + tn) / (tp + fp + tn + fn)
recall = tp / (tp + fn)
f1_score = 2 * (precision * recall) / (precision + recall)
roc = ROC(task="binary")
```
Don't you need to do this before you threshold the `ex_mask_tensor`? As written, `roc` receives a binary tensor, which it cannot threshold. (This comment refers to line 41.)
Yes, that's right. I have updated the implementation to use the original `ex_mask_tensor` in the ROC computation.
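To illustrate why this matters (my example, not from the thread): binarizing the scores before computing ROC collapses the ranking information that AUROC measures, so the AUROC of the thresholded mask can differ from that of the raw mask:

```python
import torch
from torchmetrics import AUROC

gt = torch.tensor([1, 0, 1, 0])
soft = torch.tensor([0.90, 0.40, 0.45, 0.10])  # raw explanation scores

auroc = AUROC(task="binary")
print(auroc(soft, gt))                  # tensor(1.)     -- perfect ranking
print(auroc((soft > 0.5).float(), gt))  # tensor(0.7500) -- ranking lost
```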
Just left two more minor comments; this is looking very good, and once you incorporate these last two, we can approve and merge this! 👍🏻
torch_geometric/explain/metrics.py
Outdated
```
Currently we perform hard thresholding (where the threshold value
is set to 0) on explanation and groundtruth masks to get true
positives, true negatives, false positives and false negatives.
I.e., all values in explanation masks and ground truth masks which
are greater than 0 is set to 1.
```
I think it might actually be better to have an additional argument, `threshold=0.0`, which defaults to 0.0 but can be set by the user to control the thresholding of both masks. The docstring would then describe the default behavior.
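A minimal sketch of the proposed thresholding argument (the helper name is hypothetical, not from the PR):

```python
import torch

def hard_threshold(mask: torch.Tensor, threshold: float = 0.0) -> torch.Tensor:
    # Values strictly greater than `threshold` become 1, the rest 0.
    return (mask > threshold).long()

mask = torch.tensor([0.0, 0.3, 0.7])
print(hard_threshold(mask))       # tensor([0, 1, 1])
print(hard_threshold(mask, 0.5))  # tensor([0, 0, 1])
```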
torch_geometric/explain/metrics.py
Outdated
```python
roc = ROC(task="binary")
fpr, tpr, thresholds = roc(ex_mask_tensor, gt_mask_tensor)
```
You're right, we don't need the ROC calculation for AUROC. I think we have two options here: we can also return the full ROC curve, or we can remove these two lines and just return the AUROC. I leave the choice up to you.
The ROC curve gives three values (`fpr`, `tpr`, `thresholds`), whereas all other metrics return a single value. To maintain consistency, I have decided to calculate AUROC and return that. Moreover, with the inclusion of the ROC curve, our return value would contain nested values, in which case I believe returning a dict makes more sense.

@BlazStojanovic I am struggling a bit with reStructuredText. Can you please share some best practices regarding .rst files and how to make sure the formatting is correct?
torch_geometric/explain/metrics.py
Outdated
r"""Returns different accuracy metrics on explanation when | ||
groundtruth is available. |
r"""Returns different accuracy metrics on explanation when | |
groundtruth is available. | |
r"""Compares an explanation with the ground truth explanation. Returns basic evaluation metrics - accuracy, recall, precision, f1_score, and auroc. |
torch_geometric/explain/metrics.py
Outdated
```
.. note::
    Currently we perform hard thresholding (where the threshold value
    defaults to 0) on explanation and groundtruth masks to get true
    positives, true negatives, false positives and false negatives.
    I.e., all values in explanation masks and ground truth masks which
    are greater than the threshold value is set to 1.
```
You can remove this note
torch_geometric/explain/metrics.py
Outdated
```
threshold (float): threshold value to perform hard thresholding.
    (default: :obj:`0.0`)
```
Suggested change:

```diff
-threshold (float): threshold value to perform hard thresholding.
-    (default: :obj:`0.0`)
+threshold (float): threshold value to perform hard thresholding of the
+    `explanation` and `groundtruth` masks. (default: :obj:`0.0`)
```
@BlazStojanovic @rusty1s For the fidelity metric now in main, the metrics are under a new folder, so we have a different structure. Is that fine?
Yes, don't worry about it :)
Thank you! Please note that I changed the code slightly. I think it is a bit dangerous to combine masks of different levels with each other, so I changed the interface to expect masks rather than `Explanation` objects.
This is a PR for issue #5962.
Accuracy metrics such as accuracy, recall, precision, AUC, and f1_score are available.
TODOs