Skip to content

Commit

Permalink
Merge pull request #14 from softwaremill/fix/lrp_rules
Browse files Browse the repository at this point in the history
Fixed missing LRP propagation rules
  • Loading branch information
kamilrzechowski committed Jan 4, 2023
2 parents 1cad9b8 + a5658db commit c3154fd
Show file tree
Hide file tree
Showing 3 changed files with 28 additions and 7 deletions.
3 changes: 3 additions & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -158,3 +158,6 @@ cython_debug/
# and can be added to the global gitignore or merged into this file. For a more nuclear
# option (not recommended) you can uncomment the following to ignore the entire idea folder.
#.idea/

# VS Code
.vscode/
31 changes: 25 additions & 6 deletions autoxai/explainer/lrp.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@

import torch
from captum.attr import LRP, LayerLRP
from captum.attr._utils.lrp_rules import EpsilonRule, GammaRule

from autoxai.explainer.base_explainer import CVExplainer
from autoxai.explainer.model_utils import modify_modules
Expand Down Expand Up @@ -45,12 +46,30 @@ def calculate_features(

lrp = self.create_explainer(model=model, layer=layer)

attributions = lrp.attribute(
input_data,
target=pred_label_idx,
)
attributions = lrp.attribute(input_data, target=pred_label_idx)
return attributions

def add_rules(self, model: torch.nn.Module) -> torch.nn.Module:
    """Attach layer-wise relevance propagation rules to the model's modules,
    following the composite strategy of https://arxiv.org/pdf/1910.09840.pdf.

    The first half of the modules (lower layers) receive the GammaRule,
    the remaining middle modules the EpsilonRule, and the very last module
    the basic LRP-0 rule (EpsilonRule with epsilon == 0).

    Args:
        model: DNN object to be modified.

    Returns:
        Modified DNN object.
    """
    modules = list(model.modules())
    midpoint = len(modules) // 2
    last_index = len(modules) - 1
    for position, submodule in enumerate(modules):
        if position <= midpoint:
            submodule.rule = GammaRule()
        elif position == last_index:
            submodule.rule = EpsilonRule(epsilon=0)  # LRP-0
        else:
            submodule.rule = EpsilonRule()

    return model


class LRPCVExplainer(BaseLRPCVExplainer):
"""LRP algorithm explainer."""
Expand All @@ -68,7 +87,7 @@ def create_explainer(self, **kwargs) -> Union[LRP, LayerLRP]:
if model is None:
raise RuntimeError(f"Missing or `None` argument `model` passed: {kwargs}")

model = modify_modules(model)
model = self.add_rules(modify_modules(model))

return LRP(model=model)

Expand All @@ -92,6 +111,6 @@ def create_explainer(self, **kwargs) -> Union[LRP, LayerLRP]:
f"Missing or `None` arguments `model` or `layer` passed: {kwargs}"
)

model = modify_modules(model)
model = self.add_rules(modify_modules(model))

return LayerLRP(model=model, layer=layer)
1 change: 0 additions & 1 deletion autoxai/explainer/model_utils.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,4 @@
"""File contains functions to modify DNN models."""

import torch


Expand Down

0 comments on commit c3154fd

Please sign in to comment.