In [6]:
# PGD hyper-parameters, read as module-level globals by PGDAttack.perturb.
# NOTE(review): alpha (~2/255) is larger than epsilon (1/255) - confirm intended.
k = 20               # number of attack iterations
epsilon = 1.0 / 255  # maximum perturbation allowed per input element
alpha = 0.00784      # step size applied at each attack iteration


class PGDAttack(object):
    """
    Projected Gradient Descent (PGD) attack for generating adversarial examples.

    Starting from a random point inside the L-infinity ball of radius
    ``epsilon`` around the clean input, the attack repeatedly steps along the
    sign of the cross-entropy gradient, projects back into that ball and clips
    the result to the ``[0, 1]`` range.

    Args:
        model: The target model to attack. It is called as ``model(inputs)``
            and its output is passed to ``F.cross_entropy``.
        alpha: Optional per-instance step size. ``None`` (default) means the
            module-level ``alpha`` is looked up each time ``perturb`` runs.
        epsilon: Optional per-instance perturbation bound. ``None`` (default)
            means the module-level ``epsilon`` is used.
        k: Optional per-instance iteration count. ``None`` (default) means the
            module-level ``k`` is used.
    """
    def __init__(self, model, alpha=None, epsilon=None, k=None):
        self.model = model
        # Overrides stay None unless given, so existing ``PGDAttack(model)``
        # callers keep reading the module-level settings at call time.
        self.alpha = alpha
        self.epsilon = epsilon
        self.k = k

    def perturb(self, input_natural, y):
        """
        Craft adversarial examples for a batch of inputs.

        Args:
            input_natural: Clean input tensor. The final clamp assumes valid
                values lie in ``[0, 1]``.
            y: Target labels passed to ``F.cross_entropy`` together with the
                model output.

        Returns:
            A detached tensor with the same shape as ``input_natural`` that
            lies within ``epsilon`` of it (element-wise) and within ``[0, 1]``
            whenever at least one iteration is run.
        """
        # Resolve hyper-parameters: instance override first, module global otherwise.
        step_size = alpha if self.alpha is None else self.alpha
        radius = epsilon if self.epsilon is None else self.epsilon
        num_steps = k if self.k is None else self.k

        # Detach the clean input so neither the random start nor the projection
        # below can attach the result to the caller's autograd graph.
        clean = input_natural.detach()
        lower_bound = clean - radius
        upper_bound = clean + radius

        # Random start inside the epsilon-ball.
        inputs = clean + torch.zeros_like(clean).uniform_(-radius, radius)

        for _ in range(num_steps):
            inputs.requires_grad_()
            # Gradients are needed even if the caller is inside torch.no_grad().
            with torch.enable_grad():
                output = self.model(inputs)
                loss = F.cross_entropy(output, y)
            # Gradient of the loss w.r.t. the current adversarial input.
            grad = torch.autograd.grad(loss, [inputs])[0]
            # Ascent step along the sign of the gradient.
            inputs = inputs.detach() + step_size * torch.sign(grad.detach())
            # Project back into the epsilon-ball around the clean input.
            inputs = torch.min(torch.max(inputs, lower_bound), upper_bound)
            # Keep values in the valid range. The original clamped an undefined
            # name ``x`` here, which raised NameError on every call.
            inputs = torch.clamp(inputs, 0, 1)
        return inputs

In [5]:
def evaluate_model(model, dataloader, device):
    """
    Run the model over every batch of a dataloader and collect its outputs.

    The model is switched to evaluation mode (and left in that mode) and all
    forward passes run under ``torch.no_grad()``.

    Args:
        model: Model to evaluate; called as ``model(x_batch)``.
        dataloader: Iterable yielding ``(x_batch, y_batch)`` pairs.
        device: Device the batches are moved to before the forward pass.

    Returns:
        tuple: ``(prediction, labels)`` where ``prediction`` holds the model
        outputs of all batches after a softmax over dimension 1 (class
        probabilities, not class indices) and ``labels`` holds the
        concatenated ``y_batch`` tensors. If the dataloader yields no batches,
        two empty tensors are returned and no softmax is applied.
    """
    model.eval()

    # Empty seed tensors keep the original dtype promotion and device placement;
    # torch.cat accepts an empty 1-D tensor next to tensors of any other shape.
    logit_chunks = [torch.Tensor().to(device)]
    label_chunks = [torch.LongTensor().to(device)]

    with torch.no_grad():
        for x_batch, y_batch in dataloader:
            x_batch, y_batch = x_batch.to(device), y_batch.to(device)
            # Collect per-batch results and concatenate once at the end;
            # concatenating inside the loop re-copies everything gathered so far.
            logit_chunks.append(model(x_batch))
            label_chunks.append(y_batch)

    logits = torch.cat(logit_chunks)
    labels = torch.cat(label_chunks)

    if len(logit_chunks) == 1:
        # No batches were produced: ``logits`` is a 1-D empty tensor and a
        # softmax over dim=1 would raise, so return the empty results as-is.
        return logits, labels

    # Pass the logits through softmax to obtain per-class probabilities.
    prediction = torch.nn.functional.softmax(logits, dim=1)

    return prediction, labels

In [42]:
def count_params(model):
    """
    Count the non-zero parameter entries of a model.

    Entries that are exactly zero are not counted, so the result is the
    number of non-zero values across all parameters rather than the total
    parameter count.

    Args:
        model: Model whose parameters are inspected.

    Returns:
        The accumulated non-zero count: a 0-dim tensor when the model has
        parameters, the plain integer ``0`` when it has none.
    """
    nonzero_total = 0
    for weights in model.parameters():
        nonzero_total += torch.count_nonzero(weights.data)
    return nonzero_total

In [25]:
def compute_sparsity_lenet(model):
    """
    Compute the global weight sparsity of a LeNet model, in percent.

    Only the ``weight`` tensors of ``conv_1``, ``conv_2``, ``fc_1``, ``fc_2``
    and ``fc_3`` are considered; biases are not part of the calculation.

    Args:
        model: Model exposing the five layers listed above as attributes.

    Returns:
        0-dim tensor: share of weights that are exactly zero, times 100.
    """
    zero_count = 0
    weight_count = 0
    for layer_name in ("conv_1", "conv_2", "fc_1", "fc_2", "fc_3"):
        weight = getattr(model, layer_name).weight
        # Exact zeros in this layer, and the layer's total number of weights.
        zero_count = zero_count + torch.sum(weight == 0)
        weight_count = weight_count + weight.nelement()

    return zero_count / weight_count * 100

In [27]:
def compute_sparsity_resnet(model):
    """
    Compute the global weight sparsity of a ResNet model, in percent.

    The layers taken into account are the stem (``conv1``, ``bn1``), the
    ``conv1``/``bn1``/``conv2``/``bn2`` layers of blocks 0 and 1 in each of
    ``layer1`` .. ``layer4``, and the final ``fc`` layer. Only their
    ``weight`` tensors are counted; biases and any other sub-modules the
    blocks may contain are not included.

    Args:
        model: The ResNet model for which sparsity is calculated.

    Returns:
        0-dim tensor: share of the counted weights that are exactly zero,
        times 100.
    """
    # Gather the counted layers once, in network order, instead of spelling
    # every term out by hand for both the numerator and the denominator.
    layers = [model.conv1, model.bn1]
    for stage in (model.layer1, model.layer2, model.layer3, model.layer4):
        for block in (stage[0], stage[1]):
            layers.extend((block.conv1, block.bn1, block.conv2, block.bn2))
    layers.append(model.fc)

    num = sum(torch.sum(layer.weight == 0) for layer in layers)
    denom = sum(layer.weight.nelement() for layer in layers)

    global_sparsity = num / denom * 100
    return global_sparsity

In [29]:
def compute_sparsity_vgg(model):
    """
    Compute the global weight sparsity of a VGG model, in percent.

    The layers taken into account are ``model.features`` at indices
    0, 2, 5, 7, 10, 12, 14, 17, 19, 21, 24, 26, 28 and ``model.classifier``
    at indices 1, 4, 6. Only their ``weight`` tensors are counted; biases are
    not included.

    Args:
        model: The VGG model for which sparsity is calculated.

    Returns:
        0-dim tensor: share of the counted weights that are exactly zero,
        times 100.
    """
    # Positions of the weighted layers inside the two sequential containers.
    # NOTE(review): these indices must match the VGG definition used in this
    # project - confirm if the architecture changes.
    feature_indices = (0, 2, 5, 7, 10, 12, 14, 17, 19, 21, 24, 26, 28)
    classifier_indices = (1, 4, 6)

    layers = [model.features[i] for i in feature_indices]
    layers += [model.classifier[i] for i in classifier_indices]

    num = sum(torch.sum(layer.weight == 0) for layer in layers)
    denom = sum(layer.weight.nelement() for layer in layers)

    global_sparsity = num / denom * 100
    return global_sparsity

In [10]:
def image_explanation_lenet(index, x_batch, y_batch, a_batch_saliency, a_batch_integrad, a_batch_smoothgrad, file_name):
    """
    Visualize and save explanation maps for a single image using LeNet.

    Draws a one-row, four-panel figure: the input image (reshaped to 28x28
    and shown in grayscale) followed by the three explanation maps. The
    figure is written to ``Explanations/<file_name>.png`` and then shown.

    Args:
        index: Index of the image in the batch.
        x_batch: Batch of input images; ``x_batch[index]`` must be reshapeable
            to ``(28, 28)``.
        y_batch: Labels for the input images; ``y_batch[index].item()`` is
            shown in the first panel's title.
        a_batch_saliency: Explanation maps generated using Vanilla Gradient.
        a_batch_integrad: Explanation maps generated using Integrated Gradients.
        a_batch_smoothgrad: Explanation maps generated using SmoothGrad.
        file_name: Name (without extension) for the file to save the
            visualization to.
    """
    # Imported locally so this cell does not depend on the notebook's import cell.
    import os

    fig, axes = plt.subplots(nrows=1, ncols=4, figsize=(8.0, 2))

    # Input image. "gray" is the canonical colormap name; the "grey" spelling
    # is only registered as an alias by newer Matplotlib releases.
    axes[0].imshow(np.reshape(x_batch[index], (28, 28)), cmap="gray")
    axes[0].set_title(f"Normal Image {y_batch[index].item()}")
    axes[0].axis("off")

    # The three explanation maps share the same rendering settings.
    panels = (
        (a_batch_saliency, "Vanilla Gradient"),
        (a_batch_integrad, "Integrated Gradients"),
        (a_batch_smoothgrad, "SmoothGrad"),
    )
    for ax, (attributions, title) in zip(axes[1:], panels):
        ax.imshow(attributions[index], cmap="hot")
        ax.set_title(title)
        ax.axis("off")

    fig.tight_layout()

    # Create the output directory on demand; savefig raises if it is missing.
    output_path = f'Explanations/{file_name}.png'
    os.makedirs(os.path.dirname(output_path), exist_ok=True)
    fig.savefig(output_path)

    plt.show()

In [8]:
def image_explanation(index, x_batch, y_batch, a_batch_saliency, a_batch_integrad, a_batch_smoothgrad, file_name):
    """
    Visualize and save explanation maps for a single image using ResNet or VGG.

    Draws a one-row, four-panel figure: the input image followed by the three
    explanation maps. The figure is written to
    ``Explanations/<file_name>.png`` and then shown.

    Args:
        index: Index of the image in the batch.
        x_batch: Batch of input images. The first axis of ``x_batch[index]``
            is moved to the last position before plotting (presumably
            channels-first data with values in [0, 1] - confirm with caller).
        y_batch: Labels for the input images; ``y_batch[index].item()`` is
            shown in the first panel's title.
        a_batch_saliency: Explanation maps generated using Vanilla Gradient.
        a_batch_integrad: Explanation maps generated using Integrated Gradients.
        a_batch_smoothgrad: Explanation maps generated using SmoothGrad.
        file_name: Name (without extension) for the file to save the
            visualization to.
    """
    # Imported locally so this cell does not depend on the notebook's import cell.
    import os

    fig, axes = plt.subplots(nrows=1, ncols=4, figsize=(8, 2))

    # Input image, with the first axis moved last for imshow.
    axes[0].imshow(np.moveaxis(x_batch[index], 0, -1), vmin=0.0, vmax=1.0)
    axes[0].set_title(f"Normal Image {y_batch[index].item()}")
    axes[0].axis("off")

    # The three explanation maps share the same rendering settings.
    panels = (
        (a_batch_saliency, "Vanilla Gradient"),
        (a_batch_integrad, "Integrated Gradients"),
        (a_batch_smoothgrad, "SmoothGrad"),
    )
    for ax, (attributions, title) in zip(axes[1:], panels):
        ax.imshow(attributions[index], cmap="hot")
        ax.set_title(title)
        ax.axis("off")

    fig.tight_layout()

    # Create the output directory on demand; savefig raises if it is missing.
    output_path = f'Explanations/{file_name}.png'
    os.makedirs(os.path.dirname(output_path), exist_ok=True)
    fig.savefig(output_path)
    plt.show()