In [20]:
import torch
import torch.nn as nn
import torch.nn.functional as F

class ToyModel(nn.Module):
    r"""
    Example toy model from the original paper (page 10)

    https://arxiv.org/pdf/1703.01365.pdf


    f(x1, x2) = RELU(ReLU(x1) - 1 - ReLU(x2))
    """

    def __init__(self):
        super().__init__()
        self.fc_layer = nn.Linear(1, 1)

    def forward(self, input1, input2):
        print(input1)
        print(input2)
        relu_out1 = F.relu(input1)
        relu_out2 = F.relu(input2)
        print(F.relu(relu_out1 - 1 - relu_out2).shape)
        return F.relu(relu_out1 - 1 - relu_out2)

In [46]:
from captum.attr import IntegratedGradients
model = ToyModel()

# defining model input tensors
input1 = torch.tensor([3.0], requires_grad=True)
input2 = torch.tensor([1.0], requires_grad=True)

# print(model.forward(input1, input2))
# print(model.forward(baseline1, baseline2))


# defining baselines for each input tensor
baseline1 = torch.tensor([0.0])
baseline2 = torch.tensor([0.0])

# defining and applying integrated gradients on ToyModel and the
ig = IntegratedGradients(model)
attributions, approximation_error = ig.attribute((input1, input2),
                                                 baselines=(baseline1, baseline2),
                                                 method='gausslegendre',
                                                 return_convergence_delta=True)

print(attributions)

tensor([1.7004e-03, 8.9520e-03, 2.1969e-02, 4.0703e-02, 6.5084e-02, 9.5015e-02,
        1.3038e-01, 1.7105e-01, 2.1686e-01, 2.6763e-01, 3.2317e-01, 3.8326e-01,
        4.4767e-01, 5.1616e-01, 5.8845e-01, 6.6426e-01, 7.4331e-01, 8.2529e-01,
        9.0988e-01, 9.9675e-01, 1.0856e+00, 1.1760e+00, 1.2677e+00, 1.3602e+00,
        1.4534e+00, 1.5466e+00, 1.6398e+00, 1.7323e+00, 1.8240e+00, 1.9144e+00,
        2.0033e+00, 2.0901e+00, 2.1747e+00, 2.2567e+00, 2.3357e+00, 2.4116e+00,
        2.4838e+00, 2.5523e+00, 2.6167e+00, 2.6768e+00, 2.7324e+00, 2.7831e+00,
        2.8290e+00, 2.8696e+00, 2.9050e+00, 2.9349e+00, 2.9593e+00, 2.9780e+00,
        2.9910e+00, 2.9983e+00], grad_fn=<CatBackward>)
tensor([5.6680e-04, 2.9840e-03, 7.3230e-03, 1.3568e-02, 2.1695e-02, 3.1672e-02,
        4.3461e-02, 5.7016e-02, 7.2285e-02, 8.9209e-02, 1.0772e-01, 1.2775e-01,
        1.4922e-01, 1.7205e-01, 1.9615e-01, 2.2142e-01, 2.4777e-01, 2.7510e-01,
        3.0329e-01, 3.3225e-01, 3.6186e-01, 3.9200e-01, 4.2255e-

In [42]:
import torch
import torch.nn as nn
import torch.nn.functional as F

class ToySoftmaxModel(nn.Module):
    r"""
    Model architecture from:

    https://adventuresinmachinelearning.com/pytorch-tutorial-deep-learning/
    """

    def __init__(self, num_in, num_hidden, num_out):
        super().__init__()
        self.num_in = num_in
        self.num_hidden = num_hidden
        self.num_out = num_out
        self.lin1 = nn.Linear(num_in, num_hidden)
        self.lin2 = nn.Linear(num_hidden, num_hidden)
        self.lin3 = nn.Linear(num_hidden, num_out)
        self.softmax = nn.Softmax(dim=1)

    def forward(self, input):
        lin1 = F.relu(self.lin1(input))
        lin2 = F.relu(self.lin2(lin1))
        lin3 = self.lin3(lin2)
        return self.softmax(lin3)

In [45]:
from captum.attr import IntegratedGradients
num_in = 40
input = torch.arange(0.0, num_in * 1.0, requires_grad=True).unsqueeze(0)
baseline = torch.tensor([0.0]*num_in, requires_grad=True).unsqueeze(0)

print(input.shape, baseline.shape)

# 10-class classification model
model = ToySoftmaxModel(num_in, 20, 10)

# attribution score will be computed with respect to target class
target_class_index = 5

print(torch.sum(model.forward(input)))
print(torch.sum(model.forward(baseline)))

# applying integrated gradients on the SoftmaxModel and input data point
ig = IntegratedGradients(model)
attributions, approximation_error = ig.attribute(input, 
                                                 baselines=baseline,
                                                 target=target_class_index,
                                                 return_convergence_delta=True)

print(torch.sum(attributions))

torch.Size([1, 40]) torch.Size([1, 40])
tensor(1.0000, grad_fn=<SumBackward0>)
tensor(1., grad_fn=<SumBackward0>)
tensor(-0.0943, dtype=torch.float64, grad_fn=<SumBackward0>)
