In [1]:
import torch
import torch.nn as nn
import numpy as np


LRP-0 (LRP-Zero) Rule:

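This rule redistributes the relevance R_k of each output neuron k to the input neurons j in proportion to their contributions a_j w_jk to the pre-activation z_k = Σ_j a_j w_jk + b_k. It uses the raw (signed) weights: R_j = Σ_k (a_j w_jk / z_k) R_k. For bias-free ReLU networks it is equivalent to Gradient × Input, which tends to give noisy explanations, so it is usually applied only to the upper layers, close to the output.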

In [None]:
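def lrp_zero(R, layer):
    # Assumes `layer` is an nn.Linear whose input activations were stored as
    # `layer.input` during the forward pass (e.g., with a forward hook)
    a, W = layer.input, layer.weight
    Z = torch.mm(a, W.t())  # z_k = sum_j a_j w_jk, using the raw (signed) weights
    if layer.bias is not None:
        Z = Z + layer.bias
    # Tiny sign-aware stabilizer: only guards against division by exactly zero
    Z = Z + 1e-9 * torch.where(Z >= 0, torch.ones_like(Z), -torch.ones_like(Z))
    S = R / Z
    C = torch.mm(S, W)
    R_next = a * C
    return R_next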
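
The rule functions in this notebook read `layer.input`, which `nn.Linear` does not provide by itself. Below is a minimal sketch, assuming the input is attached with a forward hook (the helper name `store_input` is hypothetical). It also writes LRP-0 out element by element and checks it against the vectorized form R_j = a_j Σ_k w_jk (R_k / z_k). Note that PyTorch stores the weight as `W[k, j]` (output by input). Double precision and a bias-free layer keep the conservation check tight.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

# Hypothetical helper: nn.Linear does not keep its input, so a forward hook
# stores it as `layer.input`, the attribute the rule functions read
def store_input(module, inputs, output):
    module.input = inputs[0].detach()

layer = nn.Linear(4, 3, bias=False).double()
layer.register_forward_hook(store_input)

x = torch.rand(1, 4, dtype=torch.float64)  # nonnegative inputs, as after a ReLU
with torch.no_grad():
    out = layer(x)                          # the hook fills layer.input here

R = out.clone()                             # use the output itself as relevance
a, W = layer.input, layer.weight.detach()

# LRP-0 element by element: R_j = sum_k a_j * w_jk / z_k * R_k
R_loop = torch.zeros_like(a)
for k in range(W.shape[0]):
    z_k = (a[0] * W[k]).sum()
    for j in range(W.shape[1]):
        R_loop[0, j] += a[0, j] * W[k, j] / z_k * R[0, k]

# The same computation in vectorized (matrix) form
R_matrix = a * ((R / (a @ W.t())) @ W)

print(R_loop)
print(R_matrix)
print(R_loop.sum().item(), R.sum().item())  # conserved: no bias, no epsilon
```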


LRP-ε (Epsilon) Rule:

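The LRP-ε rule adds a small positive term ε to the denominator, with the same sign as z_k: R_j = Σ_k (a_j w_jk / (z_k + ε·sign(z_k))) R_k. This keeps the division stable when z_k is close to zero. Part of the relevance is absorbed by ε when the contributions to a neuron are weak or cancel each other out. As ε grows, only the most salient contributions survive, which gives sparser and less noisy explanations. With ε = 0 the rule reduces to LRP-0. It is typically used in the middle layers of the network.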

In [None]:
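def lrp_epsilon(R, layer, epsilon=0.01):
    # Assumes `layer.input` holds the activations stored during the forward pass
    a, W = layer.input, layer.weight
    Z = torch.mm(a, W.t())
    if layer.bias is not None:
        Z = Z + layer.bias
    # Add epsilon with the sign of z_k, moving the denominator away from zero
    Z = Z + epsilon * torch.where(Z >= 0, torch.ones_like(Z), -torch.ones_like(Z))
    S = R / Z
    C = torch.mm(S, W)
    R_next = a * C
    return R_next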
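
The absorbing effect of ε can be checked numerically. The sketch below uses a hypothetical bias-free helper `eps_rule` and random data. The total relevance passed to the inputs equals Σ_k R_k |z_k| / (|z_k| + ε), so with positive incoming relevance it shrinks as ε grows, and it matches the incoming total only at ε = 0.

```python
import torch

torch.manual_seed(0)

def eps_rule(a, W, R, eps):
    # Hypothetical bias-free LRP-epsilon with a sign-aware stabilizer:
    # z_k is pushed away from zero by eps, never across it
    z = a @ W.t()
    z = z + eps * torch.where(z >= 0, torch.ones_like(z), -torch.ones_like(z))
    return a * ((R / z) @ W)

a = torch.rand(2, 5, dtype=torch.float64)   # batch of 2, 5 nonnegative inputs
W = torch.randn(3, 5, dtype=torch.float64)  # 3 output neurons
R = torch.rand(2, 3, dtype=torch.float64)   # positive relevance of the outputs

eps_values = [0.0, 0.1, 1.0, 10.0]
totals = []
for eps in eps_values:
    totals.append(eps_rule(a, W, R, eps).sum().item())
    print(f"epsilon={eps}: relevance kept {totals[-1]:.4f} of {R.sum().item():.4f}")
```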


LRP-γ (Gamma) Rule:

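The LRP-γ rule adds a fraction γ of the positive part of each weight: R_j = Σ_k (a_j (w_jk + γ w_jk^+) / Σ_j a_j (w_jk + γ w_jk^+)) R_k. This favors positive contributions over negative ones. The larger γ is, the more the negative contributions are cancelled out, which makes the explanation more stable. It is typically used in the lower layers, close to the input.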

In [None]:
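def lrp_gamma(R, layer, gamma=0.1):
    # Assumes `layer.input` holds the activations stored during the forward pass
    a = layer.input
    # Modified weights w + gamma * w^+ (negative weights are left unchanged)
    V = layer.weight + gamma * torch.clamp(layer.weight, min=0)
    Z = torch.mm(a, V.t())
    if layer.bias is not None:
        Z = Z + layer.bias + gamma * torch.clamp(layer.bias, min=0)
    Z = Z + 1e-9 * torch.where(Z >= 0, torch.ones_like(Z), -torch.ones_like(Z))  # avoid division by zero
    S = R / Z
    C = torch.mm(S, V)
    R_next = a * C
    return R_next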
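
A worked example on a toy neuron with inputs a = (1, 1), one excitatory weight 1, one inhibitory weight -0.5, and incoming relevance 1. The helper `gamma_rule` is a hypothetical stripped-down version (no bias, no stabilizer). With γ = 0 (plain LRP-0) the inputs receive 2 and -1. With γ = 1 they receive 4/3 and -1/3, so the inhibitory input gets less negative relevance while the total stays 1.

```python
import torch

def gamma_rule(a, W, R, gamma):
    # Hypothetical minimal LRP-gamma: modified weights w + gamma * w^+
    V = W + gamma * W.clamp(min=0)
    Z = a @ V.t()
    return a * ((R / Z) @ V)

a = torch.tensor([[1.0, 1.0]])
W = torch.tensor([[1.0, -0.5]])   # one excitatory, one inhibitory connection
R = torch.tensor([[1.0]])

results = {}
for gamma in [0.0, 0.25, 1.0]:
    results[gamma] = gamma_rule(a, W, R, gamma)
    print(gamma, results[gamma].tolist())
```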


LRP-αβ (Alpha Beta) Rule:

This rule uses two parameters, alpha and beta, to treat the positive and negative contributions z_jk = a_j w_jk separately, each normalized by its own sum: R_j = Σ_k (α z_jk^+ / Σ_j z_jk^+ - β z_jk^- / Σ_j z_jk^-) R_k. Relevance is conserved when α - β = 1. A common choice is α = 1 and β = 0, which considers only positive contributions (for nonnegative inputs this is the z^+ rule below). Another frequent setting is α = 2, β = 1.

In [None]:
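def lrp_alpha_beta(R, layer, alpha=1, beta=0):
    # Relevance is conserved when alpha - beta == 1. Assumes `layer.input` holds
    # nonnegative activations (e.g., after a ReLU), so the sign of a_j * w_jk is the
    # sign of w_jk; the bias is ignored
    a = layer.input
    V_pos = torch.clamp(layer.weight, min=0)
    V_neg = torch.clamp(layer.weight, max=0)
    # Positive and negative contributions are normalized by their own sums
    Z_pos = torch.mm(a, V_pos.t()) + 1e-9
    Z_neg = torch.mm(a, V_neg.t()) - 1e-9
    R_pos = a * torch.mm(R / Z_pos, V_pos)
    R_neg = a * torch.mm(R / Z_neg, V_neg)
    return alpha * R_pos - beta * R_neg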
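
A worked example on a toy neuron with inputs a = (1, 1), weights (1, -0.5) and incoming relevance 1. The helper `alpha_beta_rule` is a hypothetical stand-alone variant that takes a and W directly and assumes nonnegative inputs. The positive part sends (1, 0) and the negative part (0, 1). So α = 1, β = 0 returns (1, 0), while α = 2, β = 1 returns (2, -1). Both sum to 1 because α - β = 1.

```python
import torch

def alpha_beta_rule(a, W, R, alpha, beta):
    # Hypothetical sketch; assumes a >= 0, so sign(a_j * w_jk) = sign(w_jk)
    W_pos, W_neg = W.clamp(min=0), W.clamp(max=0)
    # Each part is normalized by its own sum (small stabilizers avoid 0/0)
    R_pos = a * ((R / (a @ W_pos.t() + 1e-9)) @ W_pos)
    R_neg = a * ((R / (a @ W_neg.t() - 1e-9)) @ W_neg)
    return alpha * R_pos - beta * R_neg

a = torch.tensor([[1.0, 1.0]])
W = torch.tensor([[1.0, -0.5]])
R = torch.tensor([[1.0]])

results = {}
for alpha, beta in [(1.0, 0.0), (2.0, 1.0)]:
    results[(alpha, beta)] = alpha_beta_rule(a, W, R, alpha, beta)
    print(alpha, beta, results[(alpha, beta)].tolist())
```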


LRP-Flat (Flat) Rule:

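The LRP-Flat rule distributes the relevance of each output neuron equally among all the input neurons connected to it, regardless of their activations and weights: R_j = Σ_k R_k / n_k, where n_k is the number of inputs connected to neuron k (all of them, for a fully connected layer). This rule is often used as a baseline or for comparison with other rules.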

In [None]:
def lrp_flat(R, layer):
    # Every input connected to output k receives an equal share of R_k,
    # independent of activations and weights (dense layer: all inputs connected)
    contribution = torch.ones_like(layer.weight)
    Z = contribution.sum(dim=1)  # number of inputs connected to each output
    S = R / Z
    R_next = torch.mm(S, contribution)
    return R_next
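
A quick check, using a hypothetical helper `flat_rule` that reads the weight matrix only for its shape. Two outputs with relevance 0.6 and 0.4 feed four inputs, so each input receives (0.6 + 0.4) / 4 = 0.25. Two different random weight matrices produce identical results.

```python
import torch

def flat_rule(R, W):
    # Hypothetical minimal flat rule: W only provides the connectivity (shape)
    ones = torch.ones_like(W)
    Z = ones.sum(dim=1)            # number of inputs feeding each output k
    return (R / Z) @ ones

R = torch.tensor([[0.6, 0.4]])     # relevance of 2 output neurons
W1 = torch.randn(2, 4)
W2 = torch.randn(2, 4)             # different weights, same shape

R1, R2 = flat_rule(R, W1), flat_rule(R, W2)
print(R1)                          # each of the 4 inputs gets 0.25
print(torch.equal(R1, R2))         # the weights are never used
```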


LRP-Composite Rule:

Composite rules apply different LRP rules to different layers of the network. A common choice is LRP-γ in the lower layers (close to the input), LRP-ε in the middle layers and LRP-0 in the upper layers (close to the output), but the best assignment depends on the architecture and on the desired explanation.

In [None]:
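def lrp_composite(R, layer, rule_for_layer):
    # Dispatch to a rule chosen per layer by the caller (e.g., by depth)
    rules = {
        'zero': lrp_zero,
        'epsilon': lrp_epsilon,
        'gamma': lrp_gamma,
        'alpha_beta': lrp_alpha_beta,
        'flat': lrp_flat,
        'w_square': lrp_w_square,
        'z_plus': lrp_z_plus,
    }
    if rule_for_layer not in rules:
        raise ValueError(f"Unknown LRP rule: {rule_for_layer!r}")
    return rules[rule_for_layer](R, layer)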
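
Below is a minimal end-to-end sketch of a composite on a small bias-free ReLU network. It uses a common assignment: LRP-γ near the input, LRP-ε in the middle, LRP-0 at the top. The helper names (`rule_zero`, `rule_epsilon`, `rule_gamma`, `store_input`, `explain`), the network sizes and the values ε = γ = 0.25 are hypothetical and chosen only for illustration. ReLU layers are skipped because relevance passes through them unchanged. As a sanity check, applying LRP-0 to every layer of a bias-free ReLU network should reproduce Gradient × Input.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

def _stabilize(z, eps):
    # Sign-aware stabilizer: move z away from zero without changing its sign
    return z + eps * torch.where(z >= 0, torch.ones_like(z), -torch.ones_like(z))

def rule_zero(R, layer):
    z = _stabilize(layer.input @ layer.weight.t(), 1e-12)
    return layer.input * ((R / z) @ layer.weight)

def rule_epsilon(R, layer, epsilon=0.25):
    z = _stabilize(layer.input @ layer.weight.t(), epsilon)
    return layer.input * ((R / z) @ layer.weight)

def rule_gamma(R, layer, gamma=0.25):
    V = layer.weight + gamma * layer.weight.clamp(min=0)
    z = _stabilize(layer.input @ V.t(), 1e-12)
    return layer.input * ((R / z) @ V)

RULES = {'zero': rule_zero, 'epsilon': rule_epsilon, 'gamma': rule_gamma}

def store_input(module, inputs, output):
    # Forward hook: keep the layer input as `layer.input`
    module.input = inputs[0].detach()

# Small bias-free ReLU network in double precision
model = nn.Sequential(
    nn.Linear(6, 8, bias=False), nn.ReLU(),
    nn.Linear(8, 8, bias=False), nn.ReLU(),
    nn.Linear(8, 3, bias=False),
).double()
linears = [m for m in model if isinstance(m, nn.Linear)]
for m in linears:
    m.register_forward_hook(store_input)

def explain(x, target, rules_bottom_to_top):
    # One rule name per Linear layer, listed from the input side to the output side
    with torch.no_grad():
        out = model(x)
        # Start from the target logit; every other output gets zero relevance
        R = torch.zeros_like(out)
        R[:, target] = out[:, target]
        # ReLUs pass relevance through unchanged, so only Linear layers are visited
        for layer, name in zip(reversed(linears), reversed(rules_bottom_to_top)):
            R = RULES[name](R, layer)
    return R

x = torch.rand(1, 6, dtype=torch.float64)

# Composite: gamma near the input, epsilon in the middle, LRP-0 at the top
R_composite = explain(x, target=0, rules_bottom_to_top=['gamma', 'epsilon', 'zero'])
print(R_composite)

# Sanity check: LRP-0 everywhere should match Gradient x Input here
R_zero = explain(x, target=0, rules_bottom_to_top=['zero', 'zero', 'zero'])
x_req = x.clone().requires_grad_(True)
model(x_req)[0, 0].backward()
grad_x_input = (x_req * x_req.grad).detach()
print(torch.allclose(R_zero, grad_x_input))
```

Keeping the rule names in a list ordered from input to output makes it easy to swap assignments without touching the propagation loop.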


LRP-w^2 Rule:

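This rule redistributes relevance in proportion to the squared weights, R_j = Σ_k (w_jk^2 / Σ_j w_jk^2) R_k. It emphasizes the magnitude of each weight and does not use the input activations at all, which makes it suitable for the first layer when the inputs are real-valued (unbounded).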

In [None]:
def lrp_w_square(R, layer):
    # Relevance is split in proportion to w_jk^2; the input activations are not used
    W_square = torch.pow(layer.weight, 2)
    Z = W_square.sum(dim=1) + 1e-9  # sum of squared weights per output neuron
    S = R / Z
    R_next = torch.mm(S, W_square)
    return R_next
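
A small worked example, using a hypothetical stand-alone helper `w_square_rule` that takes the weight matrix directly. With weights (1, -2, 0) feeding a single output of relevance 5, the squared weights (1, 4, 0) sum to 5, so the inputs receive 1, 4 and 0, whatever their activations.

```python
import torch

def w_square_rule(R, W):
    # Hypothetical minimal w^2 rule: only the squared weights matter
    W_sq = W ** 2
    Z = W_sq.sum(dim=1)            # sum over inputs for each output k
    return (R / Z) @ W_sq

W = torch.tensor([[1.0, -2.0, 0.0]])
R = torch.tensor([[5.0]])
R_in = w_square_rule(R, W)
print(R_in)  # tensor([[1., 4., 0.]])
```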


LRP-z^+ Rule:

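Similar to LRP-0, but only the positive contributions z_jk^+ = max(0, a_j w_jk) are kept and the bias is ignored: R_j = Σ_k (z_jk^+ / Σ_j z_jk^+) R_k. For nonnegative inputs (e.g., after a ReLU) this amounts to keeping only the positive weights. It coincides with LRP-αβ for α = 1, β = 0 and with LRP-γ in the limit γ → ∞.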

In [None]:
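def lrp_z_plus(R, layer):
    # Assumes `layer.input` holds nonnegative activations (e.g., after a ReLU)
    # stored during the forward pass; the bias is ignored
    V = torch.clamp(layer.weight, min=0)  # keep only positive weights
    Z = torch.mm(layer.input, V.t()) + 1e-9  # Z >= 0 here, so a positive constant suffices
    S = R / Z
    C = torch.mm(S, V)
    R_next = layer.input * C
    return R_next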
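
Consider a toy neuron with inputs a = (1, 1), weights (1, -0.5) and relevance 1. The z^+ rule gives all of the relevance to the excitatory input. The sketch below uses hypothetical helpers `z_plus_rule` and `gamma_rule` (no bias, no stabilizer). It also checks that LRP-γ with a very large γ approaches the z^+ result.

```python
import torch

def z_plus_rule(a, W, R):
    # Hypothetical minimal z+ rule for nonnegative inputs
    W_pos = W.clamp(min=0)
    return a * ((R / (a @ W_pos.t())) @ W_pos)

def gamma_rule(a, W, R, gamma):
    # Hypothetical minimal LRP-gamma: modified weights w + gamma * w^+
    V = W + gamma * W.clamp(min=0)
    return a * ((R / (a @ V.t())) @ V)

a = torch.tensor([[1.0, 1.0]])
W = torch.tensor([[1.0, -0.5]])
R = torch.tensor([[1.0]])

R_zplus = z_plus_rule(a, W, R)
R_gamma_large = gamma_rule(a, W, R, gamma=1e4)
print(R_zplus)        # tensor([[1., 0.]])
print(R_gamma_large)  # close to the z+ result
```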
