<a href="https://colab.research.google.com/github/swalehaparvin/Model-Compression-Techniques/blob/main/Activation_Pruning.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [2]:
import torch
import torch.nn as nn
import torch.nn.functional as F

In [3]:
# Define a simple neural network
class SimpleNet(nn.Module):
    def __init__(self):
        super(SimpleNet, self).__init__()
        self.fc1 = nn.Linear(784, 256)
        self.fc2 = nn.Linear(256, 128)
        self.fc3 = nn.Linear(128, 10)

    def forward(self, x):
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        x = self.fc3(x)
        return x

In [4]:
# Activation pruning function
def prune_activations(activations, pruning_ratio=0.5):
    """
    Prune activations by setting the smallest values to zero

    Args:
        activations: Tensor of activations
        pruning_ratio: Percentage of activations to prune (0.0 to 1.0)

    Returns:
        Pruned activations tensor
    """
    # Calculate threshold based on pruning ratio
    k = int(pruning_ratio * activations.numel())

    if k > 0:
        # Get the k-th smallest value as threshold
        threshold = torch.kthvalue(activations.flatten(), k).values

        # Create mask: 1 for values above threshold, 0 for values below
        mask = (activations > threshold).float()
        # Apply mask to prune activations
        pruned_activations = activations * mask
    else:
        pruned_activations = activations

    return pruned_activations

In [5]:
# Custom layer with activation pruning
class PrunedReLU(nn.Module):
    def __init__(self, pruning_ratio=0.3):
        super(PrunedReLU, self).__init__()
        self.pruning_ratio = pruning_ratio

    def forward(self, x):
        x = F.relu(x)
        x = prune_activations(x, self.pruning_ratio)
        return x

In [6]:
# Network with activation pruning
class PrunedNet(nn.Module):
    def __init__(self, pruning_ratio=0.3):
        super(PrunedNet, self).__init__()
        self.fc1 = nn.Linear(784, 256)
        self.fc2 = nn.Linear(256, 128)
        self.fc3 = nn.Linear(128, 10)
        self.pruned_relu = PrunedReLU(pruning_ratio)

    def forward(self, x):
        x = self.pruned_relu(self.fc1(x))
        x = self.pruned_relu(self.fc2(x))
        x = self.fc3(x)
        return x

In [9]:
# Example usage
if __name__ == "__main__":
    # Create sample data
    batch_size = 32
    input_size = 784
    x = torch.randn(batch_size, input_size)

    # Test activation pruning function
    print("Testing activation pruning:")
    activations = torch.randn(4, 4)
    print("Original activations:")
    print(activations)

    pruned = prune_activations(activations, pruning_ratio=0.5)
    print("\nPruned activations (50% pruning):")
    print(pruned)

    # Test networks
    print("\n" + "="*50)
    print("Testing networks:")

    # Regular network
    regular_net = SimpleNet()
    regular_output = regular_net(x)
    print(f"Regular network output shape: {regular_output.shape}")

    # Pruned network
    pruned_net = PrunedNet(pruning_ratio=0.3)
    pruned_output = pruned_net(x)
    print(f"Pruned network output shape: {pruned_output.shape}")

    # Count non-zero activations
    with torch.no_grad():
        # Get activations from first layer
        activations_regular = F.relu(regular_net.fc1(x))
        activations_pruned = pruned_net.pruned_relu(pruned_net.fc1(x))

        non_zero_regular = torch.sum(activations_regular != 0).item()
        non_zero_pruned = torch.sum(activations_pruned != 0).item()
        total_activations = activations_regular.numel()
        print(f"\nActivation sparsity:")
        print(f"Regular network: {non_zero_regular}/{total_activations} non-zero ({non_zero_regular/total_activations*100:.1f}%)")
        print(f"Pruned network: {non_zero_pruned}/{total_activations} non-zero ({non_zero_pruned/total_activations*100:.1f}%)")

Testing activation pruning:
Original activations:
tensor([[-0.5588, -0.4993, -0.6085,  1.2434],
        [-0.5221,  0.5522, -0.3042,  0.0119],
        [-0.2190,  0.3823,  0.2324, -0.7004],
        [-0.6220, -2.4627,  0.7492,  0.8651]])

Pruned activations (50% pruning):
tensor([[-0.0000, -0.0000, -0.0000,  1.2434],
        [-0.0000,  0.5522, -0.0000,  0.0119],
        [-0.2190,  0.3823,  0.2324, -0.0000],
        [-0.0000, -0.0000,  0.7492,  0.8651]])

Testing networks:
Regular network output shape: torch.Size([32, 10])
Pruned network output shape: torch.Size([32, 10])

Activation sparsity:
Regular network: 4125/8192 non-zero (50.4%)
Pruned network: 4050/8192 non-zero (49.4%)
