In [10]:
import copy
import torch
import numpy as np

In [11]:
def print_nonzeros(model):
    """
    Count zero vs. non-zero elements across all parameters of a model.

    Every parameter tensor (weights and biases) is copied to CPU and inspected.
    A one-line summary of non-pruned, pruned and total element counts plus the
    pruned percentage is printed.

    Parameters:
        model : The model whose parameters will be analyzed.

    Returns:
        float : The percentage of non-zero parameters in the model, rounded to
            one decimal place. Returns 0.0 for a model without parameters.
    """
    non_zeros = 0
    total_parameters = 0
    for param in model.parameters():
        tensor = param.data.cpu().numpy()
        non_zeros += int(np.count_nonzero(tensor))
        # ``size`` is an int for every shape, unlike np.prod(()) which is 1.0.
        total_parameters += int(tensor.size)

    pruned = total_parameters - non_zeros
    # Guard against parameterless models to avoid a ZeroDivisionError.
    pruned_pct = 100 * pruned / total_parameters if total_parameters else 0.0
    print(f'Non pruned: {non_zeros}, pruned : {pruned}, total: {total_parameters}, Compression rate:{pruned_pct:6.2f}% pruned')

    if not total_parameters:
        return 0.0
    # percentage of non zero parameters
    return round((non_zeros / total_parameters) * 100, 1)

In [12]:
def original_initialization(mask_temp, model, initial_state_dict):
    """
    Restore a pruned model's parameters to their original (pre-pruning) values
    while keeping the same pruning structure.

    Parameters whose name contains "bias" are restored unmasked. Parameters
    whose name contains "weight" are restored multiplied element-wise by their
    mask, so pruned positions stay at zero. Masks are consumed in the order the
    weight parameters appear in ``model.named_parameters()``.

    Parameters:
        mask_temp : Binary masks, one per weight tensor in ``named_parameters()``
            order (1 = keep, 0 = pruned).
        model : The model to restore (its parameters are replaced in place).
        initial_state_dict : A state_dict containing the original parameters
            before pruning. It is only read. Restored tensors are copies, so
            later training of ``model`` does not modify it.
    """
    weight_index = 0
    for name, param in model.named_parameters():
        # restore bias fully
        if "bias" in name:
            # Clone instead of assigning the stored tensor directly. A direct
            # assignment makes the parameter share storage with the saved
            # initial state, and in-place optimizer updates would then silently
            # overwrite it. The .to() also moves it onto the parameter's device.
            param.data = initial_state_dict[name].detach().clone().to(param.device)

        # restore weights with mask applied
        if "weight" in name:
            original = initial_state_dict[name].detach().cpu().numpy()
            # The product is a fresh array, so the new parameter does not alias the state dict.
            masked_weights = mask_temp[weight_index] * original
            # Cast back to the parameter dtype so a float64 mask cannot promote the weights.
            param.data = torch.from_numpy(masked_weights).to(device=param.device, dtype=param.dtype)
            weight_index += 1

In [13]:
def checkdir(directory):
    """
    Ensure that a given directory exists, creating it (and any missing parents) if needed.

    Parameters:
        directory : Path to the directory to check/create.

    Raises:
        FileExistsError : If ``directory`` exists but is not a directory.
    """
    # Imported locally: no visible top-of-file import provides ``os``.
    import os

    # exist_ok avoids the check-then-create race and is a no-op for existing directories.
    os.makedirs(directory, exist_ok=True)

In [14]:
def prune_network(percent, mask, model, **kwargs):
    """
    Prune weights in a model by a given percentile of their magnitude.

    For every parameter whose name contains "weight" (biases are never pruned),
    the magnitude threshold is the ``percent``-th percentile of the weights that
    are still alive, meaning enabled by the mask and currently non-zero. Weights
    below the threshold are zeroed and their mask entry set to 0. All other
    entries keep their previous mask value. The parameter is then rewritten as
    ``weight * new_mask``, which also re-zeroes any previously pruned weight
    that training moved away from zero.

    Parameters:
        percent : Percentile value (0–100) used as the cutoff for pruning.
        mask : List of binary masks (1 = keep, 0 = prune), one per weight tensor
            in ``named_parameters()`` order. Updated in place.
        model : The model whose weights will be pruned (modified in place).
        **kwargs : Additional arguments (currently unused).

    Returns:
        The updated pruning masks (the same list object as ``mask``).
    """
    weight_index = 0
    for name, param in model.named_parameters():
        # Not pruning bias term
        if 'weight' not in name:
            continue

        tensor = param.data.cpu().numpy()
        layer_mask = mask[weight_index]

        # Only surviving weights define the threshold. Previously pruned
        # weights that training pushed away from zero must not skew it.
        alive = tensor[(layer_mask != 0) & (tensor != 0)]

        if alive.size:
            percentile_threshold = np.percentile(np.abs(alive), percent)
            # prune if abs(weight) < threshold, otherwise keep old mask value
            new_mask = np.where(np.abs(tensor) < percentile_threshold, 0, layer_mask)
        else:
            # Layer already fully pruned: np.percentile cannot handle an empty array.
            new_mask = layer_mask

        param.data = torch.from_numpy(tensor * new_mask).to(param.device)
        mask[weight_index] = new_mask
        weight_index += 1
    return mask



In [15]:
def generate_mask(model):
    """
    Create an all-ones mask for every weight parameter of a model.

    A parameter counts as a weight when its name contains "weight". Masks are
    returned in ``model.named_parameters()`` order, which is the order the
    pruning helpers consume them in.

    Args:
        model: The neural network model.

    Returns:
        list: NumPy arrays of ones, each with the same shape and dtype as the
              corresponding weight tensor.
    """
    # Single pass, no global counter: a comprehension keeps the parameter order.
    return [
        np.ones_like(param.data.cpu().numpy())
        for name, param in model.named_parameters()
        if 'weight' in name
    ]

In [16]:
def train(model, train_dataloader, optimizer, criterion, scheduler):
    """
    Perform one full training epoch for the given model.

    Each batch is moved to the model's device, then a forward pass, loss,
    backward pass and optimizer step are run. The scheduler is stepped once at
    the end of the epoch.

    NOTE: pruned weights are not masked here, so their gradients can move
    zeroed weights away from zero during training.

    Args:
        model: The neural network model to be trained.
        train_dataloader: Iterable of ``(inputs, targets)`` batches.
        optimizer: The optimization algorithm.
        criterion: The loss function.
        scheduler: The learning rate scheduler.

    Returns:
        float: The loss value of the last batch of the epoch, or NaN if the
            dataloader yielded no batches.
    """
    # Use the device the model already lives on. Fall back to the previous
    # CUDA-if-available choice only for models without parameters.
    first_param = next(model.parameters(), None)
    if first_param is not None:
        device = first_param.device
    else:
        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    model.train()
    loss = None
    for x_batch, y_batch in train_dataloader:
        # Clear previous gradients
        optimizer.zero_grad()
        x_batch, y_batch = x_batch.to(device), y_batch.to(device)
        output = model(x_batch)
        loss = criterion(output, y_batch)
        loss.backward()
        optimizer.step()
    # update the learning rate once per epoch.
    scheduler.step()
    # .item() only once, at the end, to avoid a device sync per batch.
    return float('nan') if loss is None else loss.item()



In [17]:
def test(model, test_dataloader, criterion):
    """
    Evaluates the model's performance on the test dataset.

    Args:
        model: The trained neural network model to evaluate.
        test_dataloader: The dataloader for the test set.
        criterion: The loss function used to measure the model's error.

    Returns:
        float: The classification accuracy of the model on the test set.
    """
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model.eval()
    loss = 0
    correct = 0
    with torch.no_grad():
        for x_batch, y_batch in test_dataloader:
            x_batch, y_batch = x_batch.to(device), y_batch.to(device)
            output = model(x_batch)
            # Calculate and accumulate loss for the batch.
            loss += criterion(output, y_batch).item()
            
            # Get the predicted class.
            pred = output.argmax(1, keepdim=True)
            
            # Count correct predictions.
            correct += pred.eq(y_batch.data.view_as(pred)).sum().item()
        loss /= len(test_dataloader.dataset)
        accuracy = 100. * correct / len(test_dataloader.dataset)
    return accuracy

In [18]:
def image_explanation(index, x_batch, y_batch, a_batch_saliency, a_batch_integrad, file_name):
    """
    Plot one input image next to its Vanilla Gradient and Integrated Gradients maps.

    The figure has three panels (image, Vanilla Gradient, Integrated Gradients).
    It is saved to ``explanations/<file_name>.png`` (the directory is created if
    missing), shown, and then closed.

    Args:
        index: Index of the image in the batch.
        x_batch: Batch of input images. ``x_batch[index]`` is converted from
            channel-first to channel-last via ``np.moveaxis``, so presumably a
            (C, H, W) NumPy array per image (TODO confirm).
        y_batch: Labels for the input images; ``y_batch[index].item()`` is shown in the title.
        a_batch_saliency: Explanation maps generated using Vanilla Gradient.
        a_batch_integrad: Explanation maps generated using Integrated Gradients.
        file_name: Base name (without extension) of the PNG file to save.
    """
    import os

    # Figure size kept from the original layout: (2 * 4, 2) inches.
    fig, axes = plt.subplots(nrows=1, ncols=3, figsize=(8, 2))

    # Raw image, channel axis moved last for imshow.
    normal_image = np.moveaxis(x_batch[index], 0, -1)
    low, high = normal_image.min(), normal_image.max()
    value_range = high - low
    # Min-max rescale to [0, 1]. A constant image would divide by zero (NaN),
    # so show it as all zeros instead.
    if value_range > 0:
        normal_image_display = (normal_image - low) / value_range
    else:
        normal_image_display = np.zeros_like(normal_image)

    axes[0].imshow(normal_image_display)
    axes[0].title.set_text(f"Normal Image {y_batch[index].item()}")
    axes[0].axis("off")
    axes[1].imshow(a_batch_saliency[index], cmap="hot")
    axes[1].title.set_text("Vanilla Gradient")
    axes[1].axis("off")
    axes[2].imshow(a_batch_integrad[index], cmap="hot")
    axes[2].title.set_text("Integrated Gradients")
    axes[2].axis("off")
    # A SmoothGrad panel is not plotted. A fourth column would be needed to add one.
    plt.tight_layout()

    # savefig fails if the output directory does not exist yet.
    os.makedirs('explanations', exist_ok=True)
    fig.savefig(f'explanations/{file_name}.png')
    plt.show()
    # Release the figure so repeated calls do not accumulate open figures.
    plt.close(fig)