In [1]:
import torch
import torch.nn as nn
import torch.nn.utils.prune as prune
import time

# Pick the compute device: use CUDA when it is available, otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
if device.type == "cuda":
    print("Using GPU:", torch.cuda.get_device_name(0))
else:
    print("GPU not available, using CPU. Performance gains from pruning might not be evident.")

# --- 1. Define the test convolutional layer ---
class TestConvNet(nn.Module):
    """Minimal network consisting of a single unpadded ``nn.Conv2d`` layer.

    The layer is moved to the notebook-level ``device`` when the network is
    constructed, so instances are ready to receive inputs on that device.
    """

    def __init__(self, in_channels, out_channels, kernel_size):
        super().__init__()
        conv_layer = nn.Conv2d(in_channels, out_channels, kernel_size, padding=0)
        # Module.to() returns the module itself, so the attribute holds the moved layer.
        self.conv = conv_layer.to(device)

    def forward(self, x):
        """Run ``x`` through the convolution and return the resulting tensor."""
        feature_map = self.conv(x)
        return feature_map

Using GPU: NVIDIA GeForce GTX 1660 Ti


In [11]:
def benchmark_pruning(model, input_tensor, num_runs=1000, num_warmup=100):
    """Measure the average forward-pass latency of ``model`` on ``input_tensor``.

    All passes run under ``torch.no_grad()`` so that every caller measures pure
    inference, regardless of whether the call site itself disabled autograd.

    When ``input_tensor`` lives on a CUDA device, timing uses CUDA events so the
    asynchronous kernel launches are measured on the GPU timeline. Otherwise a
    wall-clock timer is used, which makes the function usable on the CPU
    fallback path as well.

    Args:
        model: Callable (typically an ``nn.Module``) invoked as ``model(input_tensor)``.
        input_tensor: Input passed unchanged to every forward pass. Its
            ``is_cuda`` attribute selects the timing method.
        num_runs: Number of timed forward passes; must be positive.
        num_warmup: Number of untimed passes executed first so that one-time
            costs (lazy initialisation, algorithm selection) are excluded.

    Returns:
        Average time per forward pass, in milliseconds.

    Raises:
        ValueError: If ``num_runs`` is not positive.
    """
    if num_runs <= 0:
        raise ValueError(f"num_runs must be a positive integer, got {num_runs}")

    # Decide from the input rather than from global state: CUDA events are only
    # meaningful (and only constructible) when the work actually runs on a GPU.
    use_cuda_events = bool(getattr(input_tensor, "is_cuda", False))

    with torch.no_grad():
        # Warm-up passes are not timed.
        for _ in range(num_warmup):
            model(input_tensor)

        if use_cuda_events:
            start_event = torch.cuda.Event(enable_timing=True)
            end_event = torch.cuda.Event(enable_timing=True)

            start_event.record()
            for _ in range(num_runs):
                model(input_tensor)
            end_event.record()
            # Wait for the recorded events to complete before reading the timer.
            torch.cuda.synchronize()
            total_ms = start_event.elapsed_time(end_event)
        else:
            start_time = time.perf_counter()
            for _ in range(num_runs):
                model(input_tensor)
            # perf_counter() reports seconds; convert to milliseconds.
            total_ms = (time.perf_counter() - start_time) * 1000.0

    return total_ms / num_runs

In [21]:
# Layer geometry, input geometry and pruning ratio shared by the cells below.
in_channels, out_channels, kernel_size = 64, 128, 3
input_batch_size, input_height, input_width = 32, 32, 32
pruning_ratio = 0.5

# Seed the generator so the random input is reproducible, then move it to `device`.
torch.manual_seed(42)
input_data = torch.randn(
    input_batch_size, in_channels, input_height, input_width
).to(device)

print(f"\n--- Benchmarking Convolutional Layer Pruning (Input: {input_data.shape}) ---")

# Baseline: an unpruned layer, timed and summarised for comparison.
print("\n--- Original Model ---")
model_original = TestConvNet(in_channels, out_channels, kernel_size)
original_time = benchmark_pruning(model_original, input_data)
print(f"Original Conv2d layer inference time: {original_time:.4f} ms")
print(f"Original parameters: {sum(p.numel() for p in model_original.parameters())}\n")
# Only the convolution weights are counted here; the bias is not included.
non_zero_params = torch.count_nonzero(model_original.conv.weight.data).item()
print(f"Non-zero parameters: {non_zero_params}\n")


--- Benchmarking Convolutional Layer Pruning (Input: torch.Size([32, 64, 32, 32])) ---

--- Original Model ---
Original Conv2d layer inference time: 1.1373 ms
Original parameters: 73856

Non-zero parameters: 73728



In [43]:
model_masked_weight_pruned = TestConvNet(in_channels, out_channels, kernel_size)
with torch.no_grad():
    # Magnitude pruning of individual weights: every weight whose absolute
    # value lies below the `pruning_ratio` quantile of the layer is zeroed.
    # The tensor keeps its shape; only the values change.
    for module in model_masked_weight_pruned.modules():
        if not isinstance(module, nn.Conv2d):
            continue
        weight_magnitudes = module.weight.abs()
        threshold = torch.quantile(weight_magnitudes, q=pruning_ratio)
        mask = weight_magnitudes >= threshold
        module.weight.data.mul_(mask.float())

    masked_weight_pruned_time = benchmark_pruning(model_masked_weight_pruned, input_data)
    print(f"masked_weight_pruned Conv2d layer inference time: {masked_weight_pruned_time:.4f} ms")
    print(f"masked_weight_pruned parameters: {sum(p.numel() for p in model_masked_weight_pruned.parameters())}")
    # Count the surviving (non-zero) entries of the convolution weight only.
    non_zero_params = torch.count_nonzero(model_masked_weight_pruned.conv.weight.data).item()
    print(f"Non-zero parameters: {non_zero_params}\n")

masked_weight_pruned Conv2d layer inference time: 1.1366 ms
masked_weight_pruned parameters: 73856
Non-zero parameters: 36864



In [38]:
model_masked_kernel_pruned = TestConvNet(in_channels, out_channels, kernel_size)
with torch.no_grad():
    # Kernel-level masking: within each output filter, zero the 2-D kernels
    # (one per input channel) that have the smallest L1 norm.
    for module in model_masked_kernel_pruned.modules():
        if not isinstance(module, nn.Conv2d):
            continue

        # Reducing over the two spatial dims gives one norm per
        # (output filter, input channel) pair.
        norms_per_filter_input_channel = module.weight.norm(p=1, dim=(2, 3))
        num_elements_to_prune_per_filter = int(pruning_ratio * in_channels)

        for i in range(out_channels):
            # Norms of filter `i` across all of its input channels.
            current_filter_input_channel_norms = norms_per_filter_input_channel[i]

            # Ascending sort puts the weakest kernels first.
            sorted_norms, sorted_indices = torch.sort(
                current_filter_input_channel_norms, descending=False
            )

            # Input-channel indices whose kernels are removed for this filter.
            indices_to_prune_for_this_filter = sorted_indices[:num_elements_to_prune_per_filter]

            # Zero the selected kernels in place.
            module.weight.data[i, indices_to_prune_for_this_filter] = 0.0

    masked_kernel_pruned_time = benchmark_pruning(model_masked_kernel_pruned, input_data)
    print(f"masked_kernel_pruned Conv2d layer inference time: {masked_kernel_pruned_time:.4f} ms")
    print(f"masked_kernel_pruned parameters: {sum(p.numel() for p in model_masked_kernel_pruned.parameters())}")
    # Count the surviving (non-zero) entries of the convolution weight only.
    non_zero_params = torch.count_nonzero(model_masked_kernel_pruned.conv.weight.data).item()
    print(f"Non-zero parameters: {non_zero_params}\n")

masked_kernel_pruned Conv2d layer inference time: 1.1350 ms
masked_kernel_pruned parameters: 73856
Non-zero parameters: 36864



In [39]:
model_masked_filter_pruned = TestConvNet(in_channels, out_channels, kernel_size)
with torch.no_grad():
    # Filter-level masking: zero entire output filters with the smallest L1
    # norm. The weight tensor keeps its original shape.
    for module in model_masked_filter_pruned.modules():
        if not isinstance(module, nn.Conv2d):
            continue

        # One L1 norm per output filter (reduced over input channels and spatial dims).
        norms_per_filter = module.weight.norm(p=1, dim=(1, 2, 3))
        num_filters_to_prune = int(pruning_ratio * out_channels)

        # largest=False selects the filters with the lowest norms.
        indices_to_prune = torch.topk(
            norms_per_filter, num_filters_to_prune, largest=False
        ).indices

        module.weight.data[indices_to_prune] = 0.0

    masked_filter_pruned_time = benchmark_pruning(model_masked_filter_pruned, input_data)
    print(f"masked_filter_pruned Conv2d layer inference time: {masked_filter_pruned_time:.4f} ms")
    print(f"masked_filter_pruned parameters: {sum(p.numel() for p in model_masked_filter_pruned.parameters())}")
    # Count the surviving (non-zero) entries of the convolution weight only.
    non_zero_params = torch.count_nonzero(model_masked_filter_pruned.conv.weight.data).item()
    print(f"Non-zero parameters: {non_zero_params}\n")

masked_filter_pruned Conv2d layer inference time: 1.1321 ms
masked_filter_pruned parameters: 73856
Non-zero parameters: 36864



In [48]:
# Structural filter pruning: instead of zeroing filters, build a smaller layer
# that contains only the retained filters.
temp_model_for_indices = TestConvNet(in_channels, out_channels, kernel_size)
with torch.no_grad():
    output_filter_l1_norms_temp = torch.norm(
        temp_model_for_indices.conv.weight.data, p=1, dim=(1, 2, 3)
    )
    num_filters_to_prune = int(pruning_ratio * out_channels)
    # Indices of the filters to keep (those with the largest norms).
    # Note: largest=True selects the filters to retain, not the ones to prune.
    num_filters_to_retain = out_channels - num_filters_to_prune
    retained_filters_indices_sorted = torch.topk(
        output_filter_l1_norms_temp, num_filters_to_retain, largest=True
    ).indices
    # Put the kept indices in ascending order so the original filter order is preserved.
    retained_filters_indices = torch.sort(retained_filters_indices_sorted).values

# New layer with fewer output channels, filled with the retained weights and biases.
model_no_mask_filter_pruned = TestConvNet(in_channels, len(retained_filters_indices), kernel_size)
source_conv = temp_model_for_indices.conv
model_no_mask_filter_pruned.conv.weight.data = source_conv.weight.data[retained_filters_indices]
if source_conv.bias is not None:
    model_no_mask_filter_pruned.conv.bias.data = source_conv.bias.data[retained_filters_indices]

no_mask_filter_pruned_time = benchmark_pruning(model_no_mask_filter_pruned, input_data)
print(f"no_mask_filter_pruned Conv2d layer inference time: {no_mask_filter_pruned_time:.4f} ms")
print(f"no_mask_filter_pruned parameters: {sum(p.numel() for p in model_no_mask_filter_pruned.parameters())}")
# Count the non-zero entries of the (now smaller) convolution weight only.
non_zero_params = torch.count_nonzero(model_no_mask_filter_pruned.conv.weight.data).item()
print(f"Non-zero parameters: {non_zero_params}\n")

no_mask_filter_pruned Conv2d layer inference time: 0.5739 ms
no_mask_filter_pruned parameters: 36928
Non-zero parameters: 36864

