In [1]:
import copy
import os

import numpy as np
import torch
import torch.nn as nn
import torch.nn.init as init

In [14]:
def print_nonzeros(model):
    """Print a per-parameter sparsity report and return the alive percentage.

    One line is printed for every entry of ``model.named_parameters()`` with
    its nonzero count, total element count, nonzero percentage, pruned count
    and shape, followed by a summary line with the overall totals and the
    compression rate (total / nonzero).

    Args:
        model: Object exposing ``named_parameters()`` yielding
            ``(name, parameter)`` pairs whose ``.data`` is a tensor.

    Returns:
        Percentage of nonzero entries over all parameters, rounded to one
        decimal place. ``0.0`` when the model has no parameter elements.
    """
    nonzero = total = 0
    for name, p in model.named_parameters():
        tensor = p.data.cpu().numpy()
        nz_count = int(np.count_nonzero(tensor))
        # ``tensor.size`` is always an int; ``np.prod(tensor.shape)`` yields the
        # float 1.0 for 0-d parameters, which then prints as "1.0".
        total_params = int(tensor.size)
        nonzero += nz_count
        total += total_params
        # Guard zero-element parameters instead of dividing by zero.
        layer_pct = 100 * nz_count / total_params if total_params else 0.0
        print(f'{name:20} | nonzeros = {nz_count:7} / {total_params:7} ({layer_pct:6.2f}%) | total_pruned = {total_params - nz_count :7} | shape = {tensor.shape}')

    # A fully pruned model has an unbounded compression rate; report it as inf
    # rather than dividing by zero.
    compression = total / nonzero if nonzero else float('inf')
    pruned_pct = 100 * (total - nonzero) / total if total else 0.0
    print(f'alive: {nonzero}, pruned : {total - nonzero}, total: {total}, Compression rate : {compression:10.2f}x  ({pruned_pct:6.2f}% pruned)')

    if not total:
        return 0.0
    # Keep the NumPy float so rounding matches the previous return values.
    alive_pct = np.float64(nonzero) / total * 100
    return round(alive_pct, 1)

In [15]:
def original_initialization(mask_temp, initial_state_dict):
    """Reset the global ``model`` to its saved initial values with a mask applied.

    Every parameter whose name contains ``"weight"`` is replaced by the
    element-wise product of its saved initial value and the matching mask;
    every parameter whose name contains ``"bias"`` is restored to its saved
    initial value.

    Args:
        mask_temp: Sequence of NumPy arrays, one per "weight" parameter, in the
            order produced by ``model.named_parameters()``.
        initial_state_dict: Mapping from parameter name to the tensor holding
            its initial value.
    """
    global model

    step = 0
    for name, param in model.named_parameters():
        if "weight" in name:
            weight_dev = param.device
            masked = mask_temp[step] * initial_state_dict[name].cpu().numpy()
            param.data = torch.from_numpy(masked).to(weight_dev)
            step = step + 1
        if "bias" in name:
            # Clone: assigning the stored tensor directly would alias it, so any
            # later in-place update of the parameter would silently overwrite
            # the saved initial state. Also keep the bias on the parameter's device.
            param.data = initial_state_dict[name].clone().to(param.device)

In [16]:
def checkdir(directory):
    """Create ``directory`` (and any missing parents) unless it already exists.

    Args:
        directory: Path of the directory to create.

    Raises:
        FileExistsError: If the path exists but is not a directory.
    """
    # ``exist_ok=True`` removes the check-then-create race of a separate
    # ``os.path.exists`` test followed by ``os.makedirs``.
    os.makedirs(directory, exist_ok=True)

In [17]:
def weight_init(m):
    '''
    Initialise the parameters of a single module in place, based on its type.

    Usage:
        model = Model()
        model.apply(weight_init)

    Schemes:
        Conv1d / ConvTranspose1d:            weight ~ normal, bias ~ normal
        Conv2d/3d / ConvTranspose2d/3d:      weight ~ xavier normal, bias ~ normal
        BatchNorm1d/2d/3d:                   weight ~ normal(mean=1, std=0.02), bias = 0
        Linear:                              weight ~ xavier normal, bias ~ normal
        LSTM / LSTMCell / GRU / GRUCell:     parameters with >= 2 dims orthogonal,
                                             all others ~ normal

    Modules of any other type are left untouched. Missing (``None``) weights
    or biases, e.g. ``bias=False`` or ``affine=False``, are skipped.
    '''
    if isinstance(m, (nn.Conv1d, nn.ConvTranspose1d)):
        init.normal_(m.weight.data)
        if m.bias is not None:
            init.normal_(m.bias.data)
    elif isinstance(m, (nn.Conv2d, nn.Conv3d, nn.ConvTranspose2d, nn.ConvTranspose3d)):
        init.xavier_normal_(m.weight.data)
        if m.bias is not None:
            init.normal_(m.bias.data)
    elif isinstance(m, (nn.BatchNorm1d, nn.BatchNorm2d, nn.BatchNorm3d)):
        # With affine=False both weight and bias are None.
        if m.weight is not None:
            init.normal_(m.weight.data, mean=1, std=0.02)
        if m.bias is not None:
            init.constant_(m.bias.data, 0)
    elif isinstance(m, nn.Linear):
        init.xavier_normal_(m.weight.data)
        # Linear(bias=False) has no bias tensor to initialise.
        if m.bias is not None:
            init.normal_(m.bias.data)
    elif isinstance(m, (nn.LSTM, nn.LSTMCell, nn.GRU, nn.GRUCell)):
        for param in m.parameters():
            if len(param.shape) >= 2:
                init.orthogonal_(param.data)
            else:
                init.normal_(param.data)

In [18]:
def prune_by_percentile(percent, resample=False, reinit=False,**kwargs):
    """Magnitude-prune every "weight" parameter of the global ``model`` in place.

    For each parameter whose name contains ``'weight'``, the ``percent``-th
    percentile of the absolute values of its currently nonzero entries is
    used as a threshold: entries with a smaller magnitude are set to zero and
    the matching entry of the global ``mask`` list is zeroed as well.
    Parameters without ``'weight'`` in their name (e.g. biases) are skipped.

    Args:
        percent: Percentile in ``[0, 100]`` passed to ``np.percentile``.
        resample: Accepted but not used by this function.
        reinit: Accepted but not used by this function.
        **kwargs: Accepted but not used by this function.

    Side effects:
        Rewrites ``param.data`` of the pruned parameters, updates the global
        ``mask`` entries and leaves the global ``step`` at 0.
    """
    global step
    global mask
    global model

    step = 0
    for name, param in model.named_parameters():

        # We do not prune bias term
        if 'weight' not in name:
            continue

        tensor = param.data.cpu().numpy()
        alive = tensor[np.nonzero(tensor)]  # flattened array of nonzero values

        if alive.size == 0:
            # Nothing left to prune in this layer. np.percentile on an empty
            # array errors out or returns NaN depending on the NumPy version,
            # so skip it and keep the existing mask entry.
            step += 1
            continue

        percentile_value = np.percentile(abs(alive), percent)

        # Keep the previous mask value wherever the weight survives.
        weight_dev = param.device
        new_mask = np.where(abs(tensor) < percentile_value, 0, mask[step])

        # Apply new weight and mask
        param.data = torch.from_numpy(tensor * new_mask).to(weight_dev)
        mask[step] = new_mask
        step += 1
    step = 0



In [19]:
# Function to make an empty mask of the same size as the model
def make_mask(model):
    """Create the global pruning mask for ``model``.

    The global ``mask`` becomes a list holding one all-ones NumPy array per
    parameter whose name contains ``'weight'``, in ``named_parameters()``
    order, each shaped like that parameter. The global ``step`` is reset to 0.

    Args:
        model: Object exposing ``named_parameters()``.
    """
    global step
    global mask
    mask = [
        np.ones_like(weights.data.cpu().numpy())
        for param_name, weights in model.named_parameters()
        if 'weight' in param_name
    ]
    step = 0

In [20]:
def train(model, train_loader, optimizer, criterion,scheduler):
    """Train ``model`` for one pass over ``train_loader``.

    Each batch is moved to the model's device, run through the model, and the
    optimizer is stepped on the resulting loss. ``scheduler.step()`` is called
    once after the loop. No gradient masking is applied here, so entries that
    were zeroed by pruning are updated like any other weight.

    Args:
        model: Module to train; switched to training mode.
        train_loader: Iterable yielding ``(inputs, targets)`` tensor pairs.
        optimizer: Optimizer updating the model's parameters.
        criterion: Callable ``criterion(output, targets)`` returning the loss.
        scheduler: Learning-rate scheduler, stepped once per call.

    Returns:
        Loss of the last batch as a Python float, or ``nan`` when the loader
        yielded no batches.
    """
    # Follow the model's own device; fall back to the CUDA-availability
    # heuristic only for models without parameters.
    first_param = next(model.parameters(), None)
    if first_param is not None:
        device = first_param.device
    else:
        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    model.train()
    train_loss = None
    for imgs, targets in train_loader:
        optimizer.zero_grad()
        imgs, targets = imgs.to(device), targets.to(device)
        output = model(imgs)
        train_loss = criterion(output, targets)
        train_loss.backward()
        optimizer.step()
    scheduler.step()

    if train_loss is None:
        # Empty loader: there is no batch loss to report.
        return float("nan")
    return train_loss.item()



In [21]:
# Function for Testing
def test(model, test_loader, criterion):
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model.eval()
    test_loss = 0
    correct = 0
    with torch.no_grad():
        for data, target in test_loader:
            data, target = data.to(device), target.to(device)
            output = model(data)
            test_loss += criterion(output, target).item()  # sum up batch loss
            pred = output.argmax(1, keepdim=True)  # get the index of the max log-probability
            correct += pred.eq(target.data.view_as(pred)).sum().item()
        test_loss /= len(test_loader.dataset)
        accuracy = 100. * correct / len(test_loader.dataset)
    return accuracy

In [22]:
def evaluate_model(model, dataloader, device):
    """
    Run the model over a dataloader and collect class probabilities and labels.

    Args:
        model: Module to evaluate; switched to evaluation mode.
        dataloader: Iterable yielding ``(x_batch, y_batch)`` tensor pairs.
        device: Device the batches and the collected results are placed on.

    Returns:
        Tuple ``(prediction, labels)``: the softmax (over dim 1) of the
        concatenated model outputs, and the concatenated labels.
    """
    model.eval()
    # Seed both lists with the same empty tensors the results used to start
    # from, so dtype promotion and empty-loader behaviour stay unchanged.
    output_chunks = [torch.Tensor().to(device)]
    label_chunks = [torch.LongTensor().to(device)]

    with torch.no_grad():
        for x_batch, y_batch in dataloader:
            x_batch, y_batch = x_batch.to(device), y_batch.to(device)
            output_chunks.append(model(x_batch))
            label_chunks.append(y_batch)

        # Concatenate once instead of re-copying the growing result per batch.
        prediction = torch.cat(output_chunks)
        labels = torch.cat(label_chunks)

    # passing the logits through Softmax layer to get class probabilities
    prediction = torch.nn.functional.softmax(prediction, dim=1)

    return prediction, labels

In [23]:
def count_params(model):
    """
    Count the nonzero entries across all parameters of a model.

    Args:
        model: Object exposing ``named_parameters()``.

    Returns:
        Running sum of ``torch.count_nonzero`` over every parameter's data:
        a 0-dim tensor, or the int ``0`` when there are no parameters.
    """
    nonzero_total = 0
    for _name, weights in model.named_parameters():
        nonzero_total += torch.count_nonzero(weights.data)
    return nonzero_total

In [24]:
def compute_sparsity_resnet(model):
    """
    Calculate the global weight sparsity of a ResNet-style model.

    The following ``.weight`` tensors are counted: ``conv1``, ``bn1``, the
    ``conv1``/``bn1``/``conv2``/``bn2`` of blocks 0 and 1 in each of
    ``layer1`` .. ``layer4``, and ``fc``. Biases and any other submodules
    are not included.

    Args:
        model: Model exposing the attributes listed above.

    Returns:
        0-dim tensor: percentage of exactly-zero entries among the counted
        weights.
    """
    tracked = [model.conv1, model.bn1]
    for stage in (model.layer1, model.layer2, model.layer3, model.layer4):
        # Only the first two blocks of every stage are inspected.
        for block in (stage[0], stage[1]):
            tracked.extend((block.conv1, block.bn1, block.conv2, block.bn2))
    tracked.append(model.fc)

    num = sum(torch.sum(module.weight == 0) for module in tracked)
    denom = sum(module.weight.nelement() for module in tracked)
    global_sparsity = num/denom * 100
    return global_sparsity

In [25]:
def image_explanation(index, x_batch, y_batch, a_batch_saliency, a_batch_integrad, a_batch_smoothgrad, file_name):
    """
    Plot one input image next to three explanation maps, save and show the figure.

    The figure has four panels in a row: the input image, then the Vanilla
    Gradient, Integrated Gradients and SmoothGrad maps drawn with the "hot"
    colormap. It is written to ``Explanations/<file_name>.png``; the target
    directory is created when missing.

    Args:
        index : Index of the image in the batch.
        x_batch: Batch of input images. Axis 0 of ``x_batch[index]`` is moved
            to the last position before plotting (presumably channel-first
            images with values in [0, 1] -- confirm against the caller).
        y_batch: Labels for the input images; ``y_batch[index].item()`` is
            shown in the first panel's title.
        a_batch_saliency: Explanation maps generated using Vanilla Gradient.
        a_batch_integrad : Explanation maps generated using Integrated Gradients.
        a_batch_smoothgrad: Explanation maps generated using SmoothGrad.
        file_name: Name (without extension) of the file to save the figure to.
    """
    nr_images = 2
    fig, axes = plt.subplots(nrows=1, ncols=4, figsize=(nr_images*4, int(nr_images)))

    # Input image
    axes[0].imshow(np.moveaxis(x_batch[index], 0, -1), vmin=0.0, vmax=1.0)
    axes[0].title.set_text(f"Normal Image {y_batch[index].item()}")
    axes[0].axis("off")

    # Explanation maps, one panel each
    panels = (
        ("Vanilla Gradient", a_batch_saliency),
        ("Integrated Gradients", a_batch_integrad),
        ("SmoothGrad", a_batch_smoothgrad),
    )
    for ax, (title, attributions) in zip(axes[1:], panels):
        ax.imshow(attributions[index], cmap="hot")
        ax.title.set_text(title)
        ax.axis("off")
    plt.tight_layout()

    out_path = f'Explanations/{file_name}.png'
    # savefig does not create missing directories, so make sure the target exists.
    os.makedirs(os.path.dirname(out_path), exist_ok=True)
    plt.savefig(out_path)
    plt.show()