In [1]:
from torchvision.models.googlenet import googlenet

In [2]:
import torch
import torch.nn as nn
import torch_pruning as tp
import random

In [3]:
def random_prune(model, example_inputs, output_transform, max_pruned=10, seed=None, model_name=None):
    """Randomly remove channels from every Conv2d / BatchNorm2d layer of ``model``.

    A torch_pruning ``DependencyGraph`` is built once. Then, for each prunable
    layer, ``min(channels // 2, max_pruned)`` randomly chosen channels are
    removed through a pruning plan, so that dependent layers are pruned
    together. Afterwards the model is printed, a forward pass on
    ``example_inputs`` is run as a smoke test, and a parameter-count summary
    is printed.

    Args:
        model: ``nn.Module`` to prune. It is moved to CPU, switched to eval
            mode and modified in place.
        example_inputs: input passed to the dependency-graph builder and to the
            final forward pass.
        output_transform: optional callable applied to the model output (both
            when building the dependency graph and in the smoke test); ``None``
            means the raw output is used.
        max_pruned: upper bound on the number of channels removed per layer.
        seed: if given, a private ``random.Random(seed)`` picks the channels so
            the result is reproducible. If ``None``, the global ``random``
            module is used (original behaviour).
        model_name: label printed in the summary. Defaults to the lower-cased
            class name of ``model`` (``'googlenet'`` for torchvision's GoogLeNet).

    Returns:
        The same ``model`` object, after pruning.
    """
    model.cpu().eval()
    rng = random if seed is None else random.Random(seed)

    # Collect the layers up front so the list being iterated is fixed before
    # any pruning plan is executed.
    layers = [m for m in model.modules() if isinstance(m, (nn.Conv2d, nn.BatchNorm2d))]
    ori_size = tp.utils.count_params(model)
    dep_graph = tp.DependencyGraph().build_dependency(
        model, example_inputs=example_inputs, output_transform=output_transform
    )

    for layer in layers:
        prune_fn = tp.prune_conv if isinstance(layer, nn.Conv2d) else tp.prune_batchnorm
        # Re-read the channel count every iteration: executing an earlier plan
        # may already have shrunk this layer via the dependency graph.
        ch = tp.utils.count_prunable_channels(layer)
        n_prune = min(ch // 2, max_pruned)
        if n_prune <= 0:
            # Nothing to remove (e.g. a 1-channel layer): don't build an empty plan.
            continue
        idxs = rng.sample(range(ch), n_prune)
        plan = dep_graph.get_pruning_plan(layer, prune_fn, idxs)
        plan.exec()

    print(model)
    with torch.no_grad():
        out = model(example_inputs)
        if output_transform is not None:
            out = output_transform(out)
    name = model_name if model_name is not None else type(model).__name__.lower()
    print(name)
    print("  Params: %s => %s" % (ori_size, tp.utils.count_params(model)))
    print("  Output: ", out.shape)
    print("------------------------------------------------------\n")
    return model


In [4]:
# Fix the RNGs so the example input and the randomly chosen channels are
# reproducible across Restart Kernel -> Run All.
SEED = 0
random.seed(SEED)
torch.manual_seed(SEED)

# Dummy batch: used to trace the dependency graph and for the forward pass
# after pruning.
example_inputs = torch.randn(1, 3, 256, 256)
# None: the model output is used as-is (no transform applied).
output_transform = None

# NOTE(review): `pretrained=True` is deprecated in torchvision >= 0.13 in favour
# of `weights=...`; kept unchanged for compatibility with whichever torchvision
# version this notebook targets -- confirm before switching.
model = googlenet(pretrained=True)