Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add more algorithms #8

Merged
merged 14 commits into from
Dec 28, 2022
23 changes: 22 additions & 1 deletion example/streamlit_app/method_names.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,11 +3,22 @@
from enum import Enum
from typing import Optional

from src.explainer.conductance import LayerConductanceCVExplainer
from src.explainer.deconv import DeconvolutionCVExplainer
from src.explainer.deeplift import DeepLIFTCVExplainer, LayerDeepLIFTCVExplainer
from src.explainer.deeplift_shap import (
DeepLIFTSHAPCVExplainer,
LayerDeepLIFTSHAPCVExplainer,
)
from src.explainer.gradcam import GuidedGradCAMCVExplainer, LayerGradCAMCVExplainer
from src.explainer.gradient_shap import (
GradientSHAPCVExplainer,
LayerGradientSHAPCVExplainer,
)
from src.explainer.input_x_gradient import (
InputXGradientCVExplainer,
LayerInputXGradientCVExplainer,
)
from src.explainer.integrated_gradients import (
IntegratedGradientsCVExplainer,
LayerIntegratedGradientsCVExplainer,
Expand All @@ -18,10 +29,11 @@
NoiseTunnelCVExplainer,
)
from src.explainer.occulusion import OcculusionCVExplainer
from src.explainer.saliency import SaliencyCVExplainer


class MethodName(Enum):
"""XAI algorithm names."""
"""XAI algorithms names."""

OCCULUSION = OcculusionCVExplainer().algorithm_name
NOISE_TUNNEL = NoiseTunnelCVExplainer().algorithm_name
Expand All @@ -34,6 +46,15 @@ class MethodName(Enum):
LAYER_GRAD_CAM = LayerGradCAMCVExplainer().algorithm_name
INTEGRATED_GRADIENTS = IntegratedGradientsCVExplainer().algorithm_name
LAYER_INTEGRATED_GRADIENTS = LayerIntegratedGradientsCVExplainer().algorithm_name
SALIENCY = SaliencyCVExplainer().algorithm_name
DEEP_LIFT = DeepLIFTCVExplainer().algorithm_name
LAYER_DEEP_LIFT = LayerDeepLIFTCVExplainer().algorithm_name
DEEP_LIFT_SHAP = DeepLIFTSHAPCVExplainer().algorithm_name
LAYER_DEEP_LIFT_SHAP = LayerDeepLIFTSHAPCVExplainer().algorithm_name
DECONVOLUTION = DeconvolutionCVExplainer().algorithm_name
INPUT_X_GRADIENT = InputXGradientCVExplainer().algorithm_name
LAYER_INPUT_X_GRADIENT = LayerInputXGradientCVExplainer().algorithm_name
LAYER_CONDUCTANCE = LayerConductanceCVExplainer().algorithm_name

@classmethod
def from_string(cls, name: str) -> "MethodName":
Expand Down
26 changes: 25 additions & 1 deletion example/streamlit_app/run_streamlit.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,11 +21,22 @@
)

from src.explainer.base_explainer import CVExplainer
from src.explainer.conductance import LayerConductanceCVExplainer
from src.explainer.deconv import DeconvolutionCVExplainer
from src.explainer.deeplift import DeepLIFTCVExplainer, LayerDeepLIFTCVExplainer
from src.explainer.deeplift_shap import (
DeepLIFTSHAPCVExplainer,
LayerDeepLIFTSHAPCVExplainer,
)
from src.explainer.gradcam import GuidedGradCAMCVExplainer, LayerGradCAMCVExplainer
from src.explainer.gradient_shap import (
GradientSHAPCVExplainer,
LayerGradientSHAPCVExplainer,
)
from src.explainer.input_x_gradient import (
InputXGradientCVExplainer,
LayerInputXGradientCVExplainer,
)
from src.explainer.integrated_gradients import (
IntegratedGradientsCVExplainer,
LayerIntegratedGradientsCVExplainer,
Expand All @@ -36,6 +47,7 @@
NoiseTunnelCVExplainer,
)
from src.explainer.occulusion import OcculusionCVExplainer
from src.explainer.saliency import SaliencyCVExplainer

cache_path = os.environ.get("LOGDIR", "logs")

Expand All @@ -51,6 +63,15 @@
LayerGradientSHAPCVExplainer(),
LayerLRPCVExplainer(),
LayerGradCAMCVExplainer(),
SaliencyCVExplainer(),
DeepLIFTCVExplainer(),
LayerDeepLIFTCVExplainer(),
DeepLIFTSHAPCVExplainer(),
LayerDeepLIFTSHAPCVExplainer(),
DeconvolutionCVExplainer(),
InputXGradientCVExplainer(),
LayerInputXGradientCVExplainer(),
LayerConductanceCVExplainer(),
]

explainer_map = {entry.algorithm_name: entry for entry in explainer_list}
Expand All @@ -60,9 +81,12 @@
LayerGradientSHAPCVExplainer,
LayerIntegratedGradientsCVExplainer,
LayerLRPCVExplainer,
LayerDeepLIFTCVExplainer,
LayerDeepLIFTSHAPCVExplainer,
LayerInputXGradientCVExplainer,
LayerConductanceCVExplainer,
]


method_list = [e.value for e in MethodName]


Expand Down
2 changes: 1 addition & 1 deletion example/streamlit_app/settings.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,6 @@ class Settings: # pylint: disable = (too-few-public-methods)
sample_name_key: str = "sample_name"
epoch_number_key: str = "epoch_number"
date_selectbox_key: str = "date_selectbox"
selected_layer_key: str = "selected_layer"
selected_layer_key: str = "layer"
model_layers_key: str = "model_layers"
explain_key: str = "explain"
47 changes: 33 additions & 14 deletions src/explainer/base_explainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,8 @@
import torch
from captum.attr import visualization as viz

from src.explainer.model_utils import standardize_matrix


class CVExplainer(ABC):
"""Abstract explainer class."""
Expand Down Expand Up @@ -80,19 +82,36 @@ def visualize(
if len(transformed_img.shape) >= 3:
transformed_img_np = np.transpose(transformed_img_np, (1, 2, 0))

figure, _ = viz.visualize_image_attr_multiple(
attr=attributions_np,
original_image=transformed_img_np,
methods=["original_image", "heat_map", "heat_map", "heat_map"],
signs=["all", "positive", "negative", "all"],
titles=[
"Original image",
"Positive attributes",
"Negative attributes",
"All attributes",
],
show_colorbar=True,
use_pyplot=False,
)
attributions_np = standardize_matrix(matrix=attributions_np)
transformed_img_np = standardize_matrix(matrix=transformed_img_np)

try:
figure, _ = viz.visualize_image_attr_multiple(
attr=attributions_np,
original_image=transformed_img_np,
methods=["original_image", "heat_map", "heat_map", "heat_map"],
signs=["all", "positive", "negative", "all"],
titles=[
"Original image",
"Positive attributes",
"Negative attributes",
"All attributes",
],
show_colorbar=True,
use_pyplot=False,
)
except AssertionError:
figure, _ = viz.visualize_image_attr_multiple(
attr=attributions_np,
original_image=transformed_img_np,
methods=["original_image", "heat_map"],
signs=["all", "positive"],
titles=[
"Original image",
"Positive attributes",
],
show_colorbar=True,
use_pyplot=False,
)
adamwawrzynski marked this conversation as resolved.
Show resolved Hide resolved

return figure
69 changes: 69 additions & 0 deletions src/explainer/conductance.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,69 @@
"""File with Conductance algorithm explainer classes."""

from typing import Optional

import torch
from captum.attr import LayerConductance

from src.explainer.base_explainer import CVExplainer


class LayerConductanceCVExplainer(CVExplainer):
"""Layer Conductance algorithm explainer."""

def create_explainer(self, **kwargs) -> LayerConductance:
"""Create explainer object.

Raises:
RuntimeError: When passed arguments are invalid.

Returns:
Explainer object.
"""
model: Optional[torch.nn.Module] = kwargs.get("model", None)
layer: Optional[torch.nn.Module] = kwargs.get("layer", None)
if model is None or layer is None:
raise RuntimeError(
f"Missing or `None` arguments `model` or `layer` passed: {kwargs}"
)

conductance = LayerConductance(forward_func=model, layer=layer)

return conductance

def calculate_features(
self,
model: torch.nn.Module,
input_data: torch.Tensor,
adamwawrzynski marked this conversation as resolved.
Show resolved Hide resolved
pred_label_idx: int,
**kwargs,
) -> torch.Tensor:
"""Generate features image with Layer Conductance algorithm explainer.

Args:
model: Any DNN model You want to use.
input_data: Input image.
pred_label_idx: Predicted label.

Returns:
Features matrix.
"""
layer: Optional[torch.nn.Module] = kwargs.get("layer", None)

conductance = self.create_explainer(model=model, layer=layer)
attributions = conductance.attribute(
input_data,
baselines=torch.rand( # pylint: disable = (no-member)
1,
input_data.shape[1],
input_data.shape[2],
input_data.shape[3],
),
target=pred_label_idx,
)
if attributions.shape[0] == 0:
raise RuntimeError(
"Error occured during attribution calculation. "
+ "Make sure You are applying this method to CNN network.",
)
return attributions
78 changes: 78 additions & 0 deletions src/explainer/deconv.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,78 @@
"""File with Deconvolution algorithm explainer classes."""

from abc import abstractmethod
from typing import Optional, Union

import torch
from captum.attr import Deconvolution, NeuronDeconvolution

from src.explainer.base_explainer import CVExplainer
from src.explainer.model_utils import modify_modules


class BaseDeconvolutionCVExplainer(CVExplainer):
"""Base Deconvolution algorithm explainer."""

@abstractmethod
def create_explainer(self, **kwargs) -> Union[Deconvolution, NeuronDeconvolution]:
"""Create explainer object.

Raises:
RuntimeError: When passed arguments are invalid.

Returns:
Explainer object.
"""

def calculate_features(
self,
model: torch.nn.Module,
input_data: torch.Tensor,
pred_label_idx: int,
**kwargs,
) -> torch.Tensor:
"""Generate features image with Deconvolution algorithm explainer.

Args:
model: Any DNN model You want to use.
input_data: Input image.
pred_label_idx: Predicted label.

Returns:
Features matrix.
"""
layer: Optional[torch.nn.Module] = kwargs.get("layer", None)

deconv = self.create_explainer(model=model, layer=layer)
attributions = deconv.attribute(
input_data,
target=pred_label_idx,
)
if attributions.shape[0] == 0:
raise RuntimeError(
"Error occured during attribution calculation. "
+ "Make sure You are applying this method to CNN network.",
)
return attributions


class DeconvolutionCVExplainer(BaseDeconvolutionCVExplainer):
"""Base Deconvolution algorithm explainer."""

def create_explainer(self, **kwargs) -> Union[Deconvolution, NeuronDeconvolution]:
"""Create explainer object.

Raises:
RuntimeError: When passed arguments are invalid.

Returns:
Explainer object.
"""
model: Optional[torch.nn.Module] = kwargs.get("model", None)
if model is None:
raise RuntimeError(f"Missing or `None` argument `model` passed: {kwargs}")

model = modify_modules(model=model)
deconv = Deconvolution(model=model)

return deconv
Loading