
[Explainability Evaluation] Add accuracy metrics for evaluation with groundtruth #6137

Merged
merged 40 commits into pyg-team:master from the add-eval-metric-with-groundtruth branch on Dec 16, 2022

Conversation

@shhs29 (Contributor) commented Dec 5, 2022

This PR addresses issue #5962.
Accuracy metrics such as accuracy, recall, precision, auc, and f1_score are available.

TODOs

  • Add docstring as needed.
  • Update CHANGELOG.md
  • Fix failing import of torchmetrics.

@shhs29 changed the title from "Add accuracy metrics for explainability" to "[Explainability] Add accuracy metrics for explainability" on Dec 5, 2022
@shhs29 changed the title from "[Explainability] Add accuracy metrics for explainability" to "[Explainability] Add accuracy metrics for evaluation with groundtruth" on Dec 5, 2022
codecov bot commented Dec 5, 2022

Codecov Report

Merging #6137 (4912084) into master (d2f2503) will increase coverage by 0.02%.
The diff coverage is 100.00%.

❗ Current head 4912084 differs from the pull request's most recent head f872a48. Consider uploading reports for commit f872a48 to get more accurate results.

@@            Coverage Diff             @@
##           master    #6137      +/-   ##
==========================================
+ Coverage   84.52%   84.54%   +0.02%     
==========================================
  Files         376      377       +1     
  Lines       20906    20940      +34     
==========================================
+ Hits        17670    17703      +33     
- Misses       3236     3237       +1     
Impacted Files Coverage Δ
torch_geometric/explain/explanation.py 98.87% <100.00%> (+0.05%) ⬆️
torch_geometric/explain/metrics.py 100.00% <100.00%> (ø)
torch_geometric/utils/subgraph.py 98.78% <0.00%> (-1.22%) ⬇️


@shhs29 changed the title from "[Explainability] Add accuracy metrics for evaluation with groundtruth" to "[Explainability Evaluation] Add accuracy metrics for evaluation with groundtruth" on Dec 5, 2022
@shhs29 force-pushed the add-eval-metric-with-groundtruth branch from 3f8104d to a3a08bf on December 6, 2022 07:47
@shhs29 marked this pull request as ready for review on December 6, 2022 07:47
@BlazStojanovic (Contributor) left a comment:

Thank you @shhs29 for this PR! Here are a few high level comments:

  • Please move the contents of explain/evaluate/accuracy_metrics.py into explain/metrics.py
  • Let's try to refrain from using networkx when computing these metrics. Since you can access the masks directly from the Explanation, using torch for most of your computations will be much more efficient.
  • Moreover, try to use torchmetrics to evaluate ROC
  • Don't forget about thresholding masks!

Other things to help you out:
You can add this masks property to the Explanation class:

    @property
    def masks(self) -> Dict[str, Tensor]:  # Dict from typing, Tensor from torch
        r"""Returns a dictionary of all masks available in the explanation."""
        mask_dict = {
            key: self[key]
            for key in self.keys
            if key.endswith('_mask') and self[key] is not None
        }
        return dict(sorted(mask_dict.items()))

This will allow you to access all the available masks in the explanation as

explanation.masks

then you can combine all the masks in the returned dictionary into a single tensor with torch.view(-1) and torch.cat. This will make it easy to work with all masks at once; afterwards you can calculate the number of true positives as

torch.sum(gt_mask_tensor == ex_mask_tensor)

and similarly for the other metrics. With these improvements you should be left with considerably less code, so there is no need to split the function into two parts; let's just have groundtruth_metrics(explanation: Explanation, groundtruth: Explanation).
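For readers of this thread, a minimal, self-contained sketch of the approach described above, with toy tensors standing in for explanation.masks and groundtruth.masks; the variable names and the >0 hard thresholding are assumptions based on the later discussion, not the merged code:

import torch


def concat_masks(masks: dict) -> torch.Tensor:
    """Flatten every mask and concatenate them into a single 1D tensor."""
    return torch.cat([mask.view(-1) for mask in masks.values()])


# Toy masks standing in for `explanation.masks` and `groundtruth.masks`:
ex_masks = {'edge_mask': torch.tensor([0.9, 0.0, 0.3]),
            'node_mask': torch.tensor([0.0, 0.7])}
gt_masks = {'edge_mask': torch.tensor([1.0, 0.0, 1.0]),
            'node_mask': torch.tensor([0.0, 1.0])}

# Hard-threshold at 0 (all values greater than 0 become 1):
ex_mask_tensor = (concat_masks(ex_masks) > 0).long()
gt_mask_tensor = (concat_masks(gt_masks) > 0).long()

# Note: torch.sum(gt_mask_tensor == ex_mask_tensor) counts every position where
# the two masks agree (TP + TN); the four confusion-matrix counts are split out
# explicitly here to make the individual metrics clear:
tp = int(((ex_mask_tensor == 1) & (gt_mask_tensor == 1)).sum())
fp = int(((ex_mask_tensor == 1) & (gt_mask_tensor == 0)).sum())
fn = int(((ex_mask_tensor == 0) & (gt_mask_tensor == 1)).sum())
tn = int(((ex_mask_tensor == 0) & (gt_mask_tensor == 0)).sum())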

@shhs29 (Contributor, Author) commented Dec 8, 2022

Hi @BlazStojanovic,
Thanks a lot for the detailed comments. @venomouscyanide and I will work on these and update the PR soon.

@BlazStojanovic (Contributor) left a comment:

Thank you @shhs29 and @venomouscyanide, this is looking much better. Left some more comments for you to address, otherwise we are close to closing this issue :)

Comment on lines 76 to 93
def test_masks(data, node_mask, edge_mask, node_feat_mask, edge_feat_mask):
    expected = []
    if node_mask:
        expected.append('node_mask')
    if edge_mask:
        expected.append('edge_mask')
    if node_feat_mask:
        expected.append('node_feat_mask')
    if edge_feat_mask:
        expected.append('edge_feat_mask')

    explanation = create_random_explanation(
        data,
        node_mask=node_mask,
        edge_mask=edge_mask,
        node_feat_mask=node_feat_mask,
        edge_feat_mask=edge_feat_mask,
    )
Contributor:

Let's also test that the right mask values are included in the masks property.

Suggested change, replacing:

def test_masks(data, node_mask, edge_mask, node_feat_mask, edge_feat_mask):
    expected = []
    if node_mask:
        expected.append('node_mask')
    if edge_mask:
        expected.append('edge_mask')
    if node_feat_mask:
        expected.append('node_feat_mask')
    if edge_feat_mask:
        expected.append('edge_feat_mask')
    explanation = create_random_explanation(
        data,
        node_mask=node_mask,
        edge_mask=edge_mask,
        node_feat_mask=node_feat_mask,
        edge_feat_mask=edge_feat_mask,
    )

with:

@pytest.mark.parametrize('node_mask', [True, False])
@pytest.mark.parametrize('edge_mask', [True, False])
@pytest.mark.parametrize('node_feat_mask', [True, False])
@pytest.mark.parametrize('edge_feat_mask', [True, False])
def test_masks(data, node_mask, edge_mask, node_feat_mask, edge_feat_mask):
    explanation = create_random_explanation(
        data,
        node_mask=node_mask,
        edge_mask=edge_mask,
        node_feat_mask=node_feat_mask,
        edge_feat_mask=edge_feat_mask,
    )
    expected_keys = []
    expected_values = []
    if node_mask:
        expected_keys.append('node_mask')
        expected_values.append(explanation.node_mask)
    if edge_mask:
        expected_keys.append('edge_mask')
        expected_values.append(explanation.edge_mask)
    if node_feat_mask:
        expected_keys.append('node_feat_mask')
        expected_values.append(explanation.node_feat_mask)
    if edge_feat_mask:
        expected_keys.append('edge_feat_mask')
        expected_values.append(explanation.edge_feat_mask)
    assert set(explanation.masks.keys()) == set(expected_keys)
    assert set(explanation.masks.values()) == set(expected_values)

from torch_geometric.explain import Explanation


def groundtruth_metrics(
Contributor:

Suggested change
def groundtruth_metrics(
def get_groundtruth_metrics(

Comment on lines 73 to 77
assert accuracy_metrics[0] == 1.0
assert accuracy_metrics[1] == 1.0
assert accuracy_metrics[2] == 1.0
assert accuracy_metrics[3] == 1.0
assert accuracy_metrics[4] == 0.5
Contributor:

Have you run the tests to see if these asserts are met? To me it is not immediately obvious that two random explanations will result in these values.

@shhs29 (Contributor, Author) commented Dec 13, 2022:

Yes, you are right. Random explanations are not a good way to set up tests. I have updated the test to use hardcoded mask values. A few questions I had in mind:

  1. What should the return type of these metrics be? Should they be tensors or floats?
  2. Should we round off the different accuracy values?
  3. Should we check for division by 0 (i.e., when TP and TN are 0)?

Contributor:

  1. Great, using hardcoded masks here is a better way to test. Just a note for when you write such tests in the future: always verify the outcomes of hardcoded tests independently (i.e., by hand or with another library).
  2. Returning Tensors is probably better; torch functions return them anyway, so you only need to change the type hint.
  3. I think we shouldn't round anything; let the user do this afterwards if they so choose.
  4. Yes, this is a very good point! We need to handle these edge cases; let's handle them in a way consistent with sklearn (see the sketch after this list):
  • when true positive + false positive == 0, precision returns 0
  • when true positive + false negative == 0, recall returns 0
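A small sketch of the sklearn-consistent edge-case handling described above; the helper name and the integer confusion counts it takes as input are illustrative assumptions, not code from the PR:

def safe_precision_recall_f1(tp: int, fp: int, fn: int):
    """Return (precision, recall, f1) with zero division handled as in sklearn."""
    precision = tp / (tp + fp) if (tp + fp) > 0 else 0.0
    recall = tp / (tp + fn) if (tp + fn) > 0 else 0.0
    f1_score = (2 * precision * recall / (precision + recall)
                if (precision + recall) > 0 else 0.0)
    return precision, recall, f1_score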

@shhs29 (Contributor, Author) commented Dec 14, 2022:

@BlazStojanovic Thanks for the detailed comment. The status of each of these items is as follows:

  1. I have verified all the values, except for auroc, against sklearn metrics. I need some clarification on AUROC. Based on my analysis of the sklearn auc score, its equivalent in torchmetrics is
     auroc = AUROC(task="binary")
     auc = auroc(ex_mask_tensor, gt_mask_tensor)
     Moreover, looking at their example of AUROC calculation, we do not need ROC thresholding first. However, I would appreciate your thoughts on this (see the sketch after this list).
  2. I have updated the type hint.
  3. I am currently not rounding the values.
  4. I have added a couple of new conditions and a new test for this flow. In addition to the conditions you mentioned, I added one more for f1_score: it is set to 0 when precision == 0.0 or recall == 0.0.
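For reference, a runnable sketch of the torchmetrics AUROC usage mentioned in item 1, assuming torchmetrics (>= 0.11) is installed; the toy tensors are illustrative, with ex_mask_tensor holding unthresholded explanation scores and gt_mask_tensor the binary ground truth:

import torch
from torchmetrics import AUROC

# Unthresholded explanation scores vs. a binary ground-truth mask:
ex_mask_tensor = torch.tensor([0.9, 0.1, 0.8, 0.3])
gt_mask_tensor = torch.tensor([1, 0, 1, 0])

auroc = AUROC(task="binary")
auc = auroc(ex_mask_tensor, gt_mask_tensor)  # no manual ROC thresholding needed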

def groundtruth_metrics(
        explanation: Explanation,
        groundtruth: Explanation) -> Tuple[float, float, float, float, float]:
    """accuracy_scores: Compute accuracy scores when
Contributor:

Describe the thresholding behaviour in the docstring, i.e., >0 thresholding for TP and FP.

Contributor:

Also explain the order of returned metrics

accuracy = (tp + tn) / (tp + fp + tn + fn)
recall = tp / (tp + fn)
f1_score = 2 * (precision * recall) / (precision + recall)
roc = ROC(task="binary")
Contributor:

Don't you need to do this before you threshold out the ex_mask_tensor? Because here roc receives a binary tensor, which it cannot threshold?

This comment refers to line 41.

Contributor Author:

Yes, that's right. I have updated the implementation to use the original ex_mask_tensor in ROC.

@BlazStojanovic (Contributor) left a comment:

Just left two more minor comments. This is looking very good, and once you incorporate these last two, we can approve and merge this! 👍🏻

Comment on lines 16 to 20
Currently we perform hard thresholding (where the threshold value
is set to 0) on explanation and groundtruth masks to get true
positives, true negatives, false positives and false negatives.
I.e., all values in explanation masks and ground truth masks which
are greater than 0 is set to 1.
Contributor:

I think it might actually be better to have an additional argument, threshold=0.0, which defaults to 0.0 but can be set by the user to control the thresholding of both masks. This docstring would then describe the default behavior.
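A sketch of the suggested thresholding behaviour; the helper name threshold_masks is hypothetical and simply isolates the hard-thresholding step that the threshold argument would control:

from typing import Tuple

from torch import Tensor


def threshold_masks(ex_mask_tensor: Tensor, gt_mask_tensor: Tensor,
                    threshold: float = 0.0) -> Tuple[Tensor, Tensor]:
    """Hard-threshold both masks: values above `threshold` become 1, the rest 0."""
    return ((ex_mask_tensor > threshold).long(),
            (gt_mask_tensor > threshold).long())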

Comment on lines 40 to 41
roc = ROC(task="binary")
fpr, tpr, thresholds = roc(ex_mask_tensor, gt_mask_tensor)
Contributor:

You're right, we don't need the roc calculation for auroc. I think we have two options here: we can return the full ROC curve as well, or we can remove these two lines and just return the AUROC. I leave the choice up to you.

Contributor Author:

The ROC curve gives three values (fpr, tpr, thresholds), whereas all other metrics return a single value. To maintain consistency, I have decided to calculate the AUROC and return that. Moreover, with the inclusion of the ROC curve, our return value would have nested values, in which case I believe returning a dict makes more sense.

@venomouscyanide (Contributor) commented:
@BlazStojanovic I'm struggling a bit with reStructuredText. Can you please share some best practices for .rst files and how to make sure the formatting is correct?

Comment on lines 15 to 16
r"""Returns different accuracy metrics on explanation when
groundtruth is available.
Contributor:

Suggested change
r"""Returns different accuracy metrics on explanation when
groundtruth is available.
r"""Compares an explanation with the ground truth explanation. Returns basic evaluation metrics - accuracy, recall, precision, f1_score, and auroc.

Comment on lines 18 to 23
.. note::
    Currently we perform hard thresholding (where the threshold value
    defaults to 0) on explanation and groundtruth masks to get true
    positives, true negatives, false positives and false negatives.
    I.e., all values in explanation masks and ground truth masks which
    are greater than the threshold value is set to 1.
Contributor:

You can remove this note

Comment on lines 29 to 30
threshold (float): threshold value to perform hard thresholding.
(default: :obj:`0.0`)
Contributor:

Suggested change
threshold (float): threshold value to perform hard thresholding.
(default: :obj:`0.0`)
threshold (float): threshold value to perform hard thresholding of the `explanation` and `groundtruth` masks.
(default: :obj:`0.0`)

@BlazStojanovic (Contributor) left a comment:

Thanks @shhs29 for addressing all the comments. This is looking good, so I think we can move towards merging this with the other explainability code @rusty1s!

@shhs29 (Contributor, Author) commented Dec 15, 2022

@BlazStojanovic @rusty1s For the fidelity metric now in main, the metrics are under a new folder. We have a different structure. Is that fine?

@rusty1s (Member) commented Dec 15, 2022

Yes, don't worry about it :)

@rusty1s (Member) commented Dec 16, 2022

Thank you!

Please note that I changed the code slightly. I think it is a bit dangerous to combine masks of different levels with each other, so I changed the interface to expect masks rather than Explanation objects. I also added a metrics argument to be able to select a subset of metrics to compute, and used torchmetrics consistently for computation. Hope the changes are okay with you.
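For readers of this thread, a rough sketch of what such a mask-based interface with a metrics argument could look like; the function name, signature, metric registry, and binarization step below are illustrative assumptions based on this comment, not the merged implementation:

from typing import List

import torch
from torch import Tensor
from torchmetrics.classification import (BinaryAccuracy, BinaryAUROC,
                                          BinaryF1Score, BinaryPrecision,
                                          BinaryRecall)

# Hypothetical registry of supported metrics (names chosen for illustration):
METRICS = {
    'accuracy': BinaryAccuracy,
    'recall': BinaryRecall,
    'precision': BinaryPrecision,
    'f1_score': BinaryF1Score,
    'auroc': BinaryAUROC,
}


def groundtruth_metrics_sketch(pred_mask: Tensor, target_mask: Tensor,
                               metrics: List[str]) -> List[Tensor]:
    """Compare a predicted mask with a ground-truth mask on the selected metrics."""
    pred = pred_mask.view(-1)                    # soft scores, e.g. in [0, 1]
    target = (target_mask.view(-1) > 0).long()   # binarized ground truth
    return [METRICS[name]()(pred, target) for name in metrics]


# Example usage:
pred = torch.tensor([0.9, 0.2, 0.8, 0.1])
target = torch.tensor([1.0, 0.0, 1.0, 1.0])
print(groundtruth_metrics_sketch(pred, target, ['accuracy', 'f1_score', 'auroc']))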

@rusty1s enabled auto-merge (squash) on December 16, 2022 09:29
@rusty1s merged commit e43aa42 into pyg-team:master on Dec 16, 2022