In [1]:
import numpy as np
import matplotlib.pyplot as plt
import sys
import pdb
import caffe

In [2]:
# Inputs: the trained SE network whose channels will be pruned.
last_deploy, last_model = (
    'deploy_SE.prototxt',        # deploy definition of the (unpruned) network
    'googlenet_SE.caffemodel',   # trained weights to be pruned
)

In [None]:
# Run the trained SE network over every image listed in train.txt and average,
# per channel, the scale factors found in the blobs whose name contains 'up'.
# These per-channel averages ("channel weights") drive the pruning below.
caffe.set_mode_gpu()
caffe.set_device(0)

model_def = last_deploy
model_weights = last_model
mu = np.load('train_mean_global.npy')
mu = mu.mean(1).mean(1)    # per-channel mean pixel; assumes the .npy is (C, H, W) -- TODO confirm
net = caffe.Net(model_def,      # defines the structure of the model
                model_weights,  # contains the trained weights
                caffe.TEST)     # use test mode (e.g., don't perform dropout)

picture_num = 1    # images per forward pass
positive_count = 0
# Shrink the input batch first so the transformer is configured with the same
# input shape the net is actually fed.
net.blobs['data'].reshape(picture_num, 3, 224, 224)
transformer = caffe.io.Transformer({'data': net.blobs['data'].data.shape})
transformer.set_transpose('data', (2, 0, 1))
transformer.set_mean('data', mu)
transformer.set_raw_scale('data', 255)
transformer.set_channel_swap('data', (2, 1, 0))

layers_dict = {}    # blob name -> accumulated (later: mean) scale factor per channel
layers_list = []    # names of the scale blobs ('up' in the name), in net order
for layer_name, blob in net.blobs.items():
    if 'up' in layer_name:
        # The blob is read as [image, channel, 0, 0] below, so its channel count
        # is shape[1]. (size / 10 only worked for a deploy batch of exactly 10.)
        num_channels = blob.data.shape[1]
        layers_list.append(layer_name)
        layers_dict[layer_name] = np.zeros(num_channels)
        print('%s %d' % (layer_name, num_channels))

image_file = 'train.txt'    # one "<image path> <label>" entry per line
train_num = 0
with open(image_file, 'r') as image_list:
    for line in image_list:
        fields = line.split()
        if not fields:
            continue    # tolerate blank lines (e.g. a trailing newline)
        image_source = fields[0]    # the label column is not needed here
        image = caffe.io.load_image(image_source)
        net.blobs['data'].data[0] = transformer.preprocess('data', image)
        net.forward()
        # The scale factors are saved in the blobs named "xxx_up".
        for layer_name in layers_list:
            layers_dict[layer_name] += net.blobs[layer_name].data[0, :, 0, 0]
        print(image_source + ' processed.')
        train_num += 1
print('Done.')

if train_num == 0:
    raise ValueError('No images listed in %s; cannot compute channel weights.' % image_file)
for layer_name in layers_list:
    # Channel weight = mean of the scale factors collected over all images.
    layers_dict[layer_name] = layers_dict[layer_name] / train_num

In [None]:
# One figure per scale layer: mean scale factor (channel weight) of each channel.
for layer_name in layers_list:
    fig, ax = plt.subplots(num=layer_name)
    ax.plot(layers_dict[layer_name])
    ax.set_title(layer_name)
    ax.set_xlabel('channel index')
    ax.set_ylabel('mean scale factor')
    # Fixed [0, 1] axis with 0.1 ticks so layers can be compared at a glance.
    # (The old plt.ylim(0.0, 1.0, 10) passed 10 positionally as `emit`.)
    ax.set_ylim(0.0, 1.0)
    ax.set_yticks(np.linspace(0.0, 1.0, 11))
plt.show()
    
c_rate = 0.5    # prune rate: fraction of channels removed from every layer, e.g. 0.5 prunes 50% of the channels of a convolutional layer.
#k = 0.1    # k in our paper, for the dynamic-threshold alternative described below.
assert 0.0 <= c_rate <= 1.0, 'c_rate must be in [0, 1]'

sorted_layers_dict = {}    # channel weights of each layer, ascending
threshold_dict = {}        # smallest channel weight that is kept in each layer (inf if all pruned)
channels = {}              # ascending indices of the channels kept in each layer; their count is the
                           # layer's width after pruning and they select the parameters to reuse
for layer in layers_list:
    weights = layers_dict[layer]
    n_total = weights.size
    n_prune = int(n_total * c_rate)
    # Drop exactly the n_prune lowest-weight channels. A stable sort breaks ties
    # by channel index, so the kept count always matches c_rate (the old
    # "weight > sorted[n_prune]" rule pruned one channel too many, and more on ties).
    order = np.argsort(weights, kind='mergesort')
    sorted_layers_dict[layer] = weights[order]
    channels[layer] = np.sort(order[n_prune:])
    threshold_dict[layer] = weights[order[n_prune]] if n_prune < n_total else np.inf
    # Dynamic alternative from the paper: instead of the channels/threshold
    # assignments above, use
    #   threshold_dict[layer] = np.mean(weights) - (k - np.std(weights))
    #   channels[layer] = np.where(weights > threshold_dict[layer])[0]
    print('%s: %d -> %d channels (threshold %s)'
          % (layer, n_total, channels[layer].size, threshold_dict[layer]))

In [None]:
# Next steps (not automated in this notebook):
# 1. Write the pruned network definition: for each layer in `layers_list`, `channels[layer].size` is the number of channels it keeps.
# 2. Build the pruned model's initial parameters: `channels[layer]` holds the indices of the kept channels, so copy only the corresponding slices of the original parameters.
# 3. Fine-tune the pruned network from these parameters to recover the accuracy lost by pruning.