In [38]:
from ptflops import get_model_complexity_info
import torch.nn as nn
import torch
from pruneshift.networks import network
from pruneshift.prune import prune
import numpy as np

In [39]:
type(torch.tensor([1, 2.]).mean().item())

float

In [40]:
floats = []
for 
net = network("cifar100_resnet50")

In [None]:
prune(net2, "global_weight", 64)


In [48]:
# We can pass custom module hooks to allow specific calculation for pruned modules.
# Write a custom function wrapper to scale all the conv_counter.

net1 = network("cifar100_resnet50")
net2 = network("cifar100_resnet50")


def consider_pruning(flops: int, module: nn.Module, param_name: str) -> int:
    if not hasattr(module, param_name + "_mask"):
        factor = 1.
    else:
        factor = getattr(module, param_name + "_mask").mean().item()

    return int(factor * flops)


def conv_flops_counter_hook(conv_module, input, output):
    # Can have multiple inputs, getting the first one
    input = input[0]

    batch_size = input.shape[0]
    output_dims = list(output.shape[2:])

    kernel_dims = list(conv_module.kernel_size)
    in_channels = conv_module.in_channels
    out_channels = conv_module.out_channels
    groups = conv_module.groups

    filters_per_channel = out_channels // groups
    conv_per_position_flops = int(np.prod(kernel_dims)) * \
        in_channels * filters_per_channel

    active_elements_count = batch_size * int(np.prod(output_dims))

    overall_conv_flops = conv_per_position_flops * active_elements_count
    overall_conv_flops = consider_pruning(overall_conv_flops, conv_module, "weight")

    bias_flops = 0

    if conv_module.bias is not None:
        bias_flops = out_channels * active_elements_count
        # Modification!
        bias_flops = consider_pruning(bias_flops, conv_module, "bias")

    overall_flops = overall_conv_flops + bias_flops

    conv_module.__flops__ += int(overall_flops)


def linear_flops_counter_hook(module, input, output):
    input = input[0]
    # pytorch checks dimensions, so here we don't care much
    output_last_dim = output.shape[-1]
    if module.bias is not None:
        bias_flops = consider_pruning(output_last_dim, module, "bias")
    else:
        bias_flops = 0
        
    weight_flops = int(np.prod(input.shape)) * output_last_dim
    weight_flops = consider_pruning(weight_flops, module, "weight")
    
    module.__flops__ += weight_flops + bias_flops


def get_model_complexity_prune(model, input_res, print_per_layer_stat=False, as_strings=False):
    """ Calculates the model complexity taking into account pruning in conv2d and linear layers."""
    custom_modules_hooks = {nn.Conv2d: conv_flops_counter_hook,
                           nn.Linear: linear_flops_counter_hook}
    return get_model_complexity_info(model, input_res, print_per_layer_stat, as_strings=as_strings, custom_modules_hooks=custom_modules_hooks)

# print(get_model_complexity_info(net1, (3, 32, 32), False))
print(get_model_complexity_prune(net2, input_res=(3, 32, 32)))

(10722134.0, 23705252)


In [52]:
layer = nn.Conv2d(3, 10, 2)

get_model_complexity_prune(layer, (3, 10, 10))

(5264.0, 130)

In [54]:
layer._forward_hooks

OrderedDict()