In [9]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision.transforms as transforms
from torchvision.datasets import MNIST
from torch.utils.data import DataLoader

# Define your PyTorch model
class SimpleCNN(nn.Module):
    def __init__(self):
        super(SimpleCNN, self).__init__()
        self.conv1 = nn.Conv2d(1, 32, 3, padding=1)
        self.conv2 = nn.Conv2d(32, 64, 3, padding=1)
        self.fc1 = nn.Linear(64 * 7 * 7, 128)
        self.fc2 = nn.Linear(128, 10)

    def forward(self, x):
        x = F.relu(self.conv1(x))
        x = F.max_pool2d(x, 2)
        x = F.relu(self.conv2(x))
        x = F.max_pool2d(x, 2)
        x = x.view(-1, 64 * 7 * 7)
        x = F.relu(self.fc1(x))
        x = self.fc2(x)
        return x

# Initialize the model and load its weights
model = SimpleCNN()
model.eval()

# Load a batch of MNIST test samples
transform = transforms.Compose([transforms.ToTensor(), transforms.Normalize((0.5,), (0.5,))])
testset = MNIST(root='./data', train=False, transform=transform, download=True)
testloader = DataLoader(testset, batch_size=64, shuffle=False)

# Define a function to calculate accuracy
def get_accuracy(model, dataloader):
    model.eval()
    correct = 0
    total = 0
    with torch.no_grad():
        for images, labels in dataloader:
            outputs = model(images)
            _, predicted = torch.max(outputs.data, 1)
            total += labels.size(0)
            correct += (predicted == labels).sum().item()
    accuracy = 100 * correct / total
    return accuracy

In [10]:
num_epochs = 1
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=0.001)

for epoch in range(num_epochs):
    model.train()
    running_loss = 0.0
    for images, labels in trainloader:
        optimizer.zero_grad()
        outputs = model(images)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()
        running_loss += loss.item()
    print(f'Epoch {epoch + 1}, Loss: {running_loss / len(trainloader)}')

    # Calculate and print training accuracy
    train_accuracy = get_accuracy(model, trainloader)
    print(f'Training Accuracy: {train_accuracy:.2f}%')

    # Calculate and print test accuracy
    test_accuracy = get_accuracy(model, testloader)
    print(f'Test Accuracy: {test_accuracy:.2f}%')

print('Finished Training')

Epoch 1, Loss: 0.14999926203652333
Training Accuracy: 98.50%
Test Accuracy: 98.52%
Epoch 2, Loss: 0.046102329016613985


KeyboardInterrupt: ignored

In [11]:
transform = transforms.Compose([transforms.ToTensor(), transforms.Normalize((0.5,), (0.5,))])
testset = MNIST(root='./data', train=False, transform=transform, download=True)
testloader = DataLoader(testset, batch_size=64, shuffle=False)
images, labels = next(iter(testloader))

# Select a single image from the batch (you can change the index)
input_image = images[0].unsqueeze(0)
input_image.requires_grad = True

# Define a function to calculate the integrated gradients
def integrated_gradients(input_tensor, model, baseline=None, num_steps=100):
    if baseline is None:
        baseline = torch.zeros_like(input_tensor)

    # Create a linear interpolation path from the baseline to the input
    path = torch.linspace(0, 1, num_steps, device=input_tensor.device).unsqueeze(1)
    interpolated_inputs = baseline + path * (input_tensor - baseline)

    # Compute the gradients with respect to the input at each step of the path
    gradients = []
    for step_input in interpolated_inputs:
        step_input.requires_grad_()
        output = model(step_input)
        output.sum().backward()  # You can change this based on your specific task
        gradient = step_input.grad
        gradients.append(gradient)

    # Calculate the integrated gradients as the average gradient over the path
    integrated_gradients = torch.mean(torch.stack(gradients), dim=0)

    return integrated_gradients

# Calculate integrated gradients for the selected image
baseline_tensor = torch.zeros_like(input_image)
ig = integrated_gradients(input_image, model, baseline=baseline_tensor)

# The 'ig' tensor now contains the integrated gradients for each pixel in the image
print(ig)

RuntimeError: ignored