In [1]:
import numpy as np

from qonnx.core.datatype import DataType
from qonnx.transformation.infer_datatypes import InferDataTypes
from qonnx.transformation.infer_shapes import InferShapes

from qonnx.util.cleanup import cleanup as qonnx_cleanup

from finn.util.visualization import showInNetron
from qonnx.core.modelwrapper import ModelWrapper

from qonnx.custom_op.registry import getCustomOp

import onnx.helper as oh
import qonnx.util.basic as util

In [2]:
# from onnx import __version__, IR_VERSION
# from onnx.defs import onnx_opset_version
# print(f"onnx.__version__={__version__!r}, opset={onnx_opset_version()}, IR_VERSION={IR_VERSION}")

In [3]:
prune_folder = './manual_pruning/pruning/sparse24_mul8/'

# Load Model and Clean

In [4]:
# model_file = './onnx_models/MY_MBLNET_V2_RESNET_classifier__best_mean_F1__BIPOLAR_Out__QONNX.onnx'

model_file = './onnx_models/Mobilenetv2_Mini_Resnet_Sparse24__best_F1__Bipolar.onnx'

In [5]:
# Run the QONNX cleanup utility on the exported model. The cleaned copy is
# written into the pruning folder and is the file every later step reloads.
qonnx_clean_filename = prune_folder + '00_prune_clean.onnx'
qonnx_cleanup(model_file, out_file=qonnx_clean_filename)

In [6]:
showInNetron(qonnx_clean_filename)

Serving './manual_pruning/pruning/sparse24_mul8/00_prune_clean.onnx' at http://0.0.0.0:8083


# Analyze layers to prune

Look for weight scales whose magnitude is non-zero but below a small threshold $\epsilon$:

$$
0 < |\mathrm{scale}| < \epsilon
$$

Store all initializers in a list and look for the corresponding convolution afterwards.

In [7]:
model = ModelWrapper(qonnx_clean_filename)

In [8]:
# Collect the name of every initializer in the graph; the scale tensors are
# selected from this list by name in the next cell.
all_inits_names = [init.name for init in model.graph.initializer]

print(f'Number of initializers = {len(all_inits_names)}')

Number of initializers = 376


In [9]:
# Scales with a magnitude strictly between 0 and eps are treated as numerically zero.
eps = 1e-10

# Maps a Quant node name to {1: set of channel indices whose scale is below eps}.
# NOTE(review): the meaning of the key 1 is not visible in this notebook -- confirm.
layers_to_prune = {}

for init_name in all_inits_names:
    # Only initializers named like "<Quant node>_param1" are inspected; the
    # notebook treats them as the scale input of the Quant node.
    if "Quant" not in init_name or "param1" not in init_name:
        continue

    np_abs_val = np.abs(model.get_initializer(init_name))
    # Logical AND of the two conditions (was written as a product of boolean masks).
    zero_idx = (np_abs_val < eps) & (np_abs_val > 0)
    if not zero_idx.any():
        # No scale under epsilon in this initializer: nothing to record.
        continue

    # Indices along axis 0 of the scale tensor, i.e. the affected channels.
    zero_layer = np.where(zero_idx)[0]
    quant_layer_name = init_name.split("_param")[0]
    # Stored as plain Python ints so the printed set does not depend on the numpy version.
    layers_to_prune[quant_layer_name] = {1: set(zero_layer.tolist())}

In [10]:
# Summarise the detection result: one line per Quant node with its flagged channel indices.
print(f'Number of layers to prune: {len(layers_to_prune)}')
for k, v in layers_to_prune.items():
    print(k, v)

Number of layers to prune: 15
Quant_12 {1: {28}}
Quant_13 {1: {28}}
Quant_15 {1: {0, 83, 45, 54}}
Quant_16 {1: {0, 83, 45, 54}}
Quant_18 {1: {64, 33, 2, 99, 102, 7, 40, 41, 12, 13, 112, 51, 20, 89, 59, 94}}
Quant_19 {1: {64, 33, 2, 99, 102, 7, 40, 41, 12, 13, 112, 51, 20, 89, 59, 94}}
Quant_21 {1: {3, 7, 9, 12, 21, 24, 27, 29, 31, 32, 39, 41, 43, 44, 50, 57, 63, 68, 71, 73, 75, 80, 83, 90, 94, 96, 97, 99, 103, 113, 117, 124, 125}}
Quant_22 {1: {3, 7, 9, 12, 21, 24, 27, 29, 31, 32, 39, 41, 43, 44, 50, 57, 63, 68, 71, 73, 75, 80, 83, 90, 94, 96, 97, 99, 103, 113, 117, 124, 125}}
Quant_24 {1: {30}}
Quant_25 {1: {30}}
Quant_26 {1: {10, 51, 14, 7}}
Quant_27 {1: {0, 1, 5, 7, 10, 11, 12, 13, 14, 15, 16, 18, 20, 21, 22, 23, 24, 25, 26, 27, 28, 30, 34, 35, 37, 38, 45, 47, 49, 51, 54, 56, 57, 58, 59, 66, 67, 68, 70, 71, 72, 73, 74, 75, 76, 78, 79, 82, 85, 89, 92, 93, 97, 98, 100, 103, 104, 109, 110, 111, 112, 115, 117, 118, 121, 126, 127}}
Quant_28 {1: {0, 1, 5, 7, 10, 11, 12, 13, 14, 15, 16, 18

In [11]:
# print(len(layers_to_prune["Quant_30"][1]))

### Get all Convolutions or Linears to be pruned

In [12]:
all_nodes = model.graph.node

# Names of the nodes fed directly by each Quant node selected for pruning.
convs_to_prune = []

for node in all_nodes:
    # Dict membership replaces the inner scan over every key for every node;
    # keys are unique, so at most one match per node exists either way.
    if node.name in layers_to_prune:
        # The first direct successor of the weight Quant node is taken as the layer to prune.
        successor_node = model.find_direct_successors(node)[0]
        convs_to_prune.append(successor_node.name)

print("All convolutions to prune")
for conv in convs_to_prune:
    print(conv)

# # Remove last 5 convs, as they are harder to prune
# for i in range(5):
#     convs_to_prune.pop()

# # Print again
# print("\nEasy convolutions to prune")
# for conv in convs_to_prune:
#     print(conv)

All convolutions to prune
Conv_12
Conv_13
Conv_15
Conv_16
Conv_18
Conv_19
Conv_21
Conv_22
Conv_24
Conv_25
Conv_26
Conv_27
Conv_28
Conv_29
Conv_30


### Get Sparsity to compare after pruning

Retrieve the weights of every layer to prune and compute, per layer, the fraction of weight values that are exactly zero:
$$
Sparsity = \frac{N_{Zeros}}{N_{Total}}
$$

In [13]:
def get_sparsity(model_wrapper, layers_to_prune):
    """Report how many weight values are exactly zero for each layer.

    For every key of ``layers_to_prune`` the initializer named
    ``<key>_param0`` is read from ``model_wrapper``.

    Returns:
        dict mapping ``<key>_param0`` to a dict with the entries
        ``"zeros"`` (count of values equal to 0), ``"total"`` (number of
        values) and ``"sparsity"`` (zeros / total, rounded to 2 decimals).
    """

    def _summarise(weights):
        # Everything that is not counted as non-zero is a zero.
        total = weights.size
        zeros = total - np.count_nonzero(weights)
        return {"zeros": zeros, "total": total, "sparsity": round(zeros / total, 2)}

    report = {}
    for layer_name in layers_to_prune.keys():
        tensor_name = f"{layer_name}_param0"
        report[tensor_name] = _summarise(model_wrapper.get_initializer(tensor_name))
    return report

In [14]:
# Baseline sparsity of the weight tensors, recorded before any pruning is applied.
sparsity_before_pruning = get_sparsity(model, layers_to_prune)
for k, v in sparsity_before_pruning.items():
    print(k, v)

Quant_12_param0 {'zeros': 122, 'total': 1152, 'sparsity': 0.11}
Quant_13_param0 {'zeros': 37, 'total': 432, 'sparsity': 0.09}
Quant_15_param0 {'zeros': 315, 'total': 2304, 'sparsity': 0.14}
Quant_16_param0 {'zeros': 86, 'total': 864, 'sparsity': 0.1}
Quant_18_param0 {'zeros': 902, 'total': 4096, 'sparsity': 0.22}
Quant_19_param0 {'zeros': 211, 'total': 1152, 'sparsity': 0.18}
Quant_21_param0 {'zeros': 1358, 'total': 4096, 'sparsity': 0.33}
Quant_22_param0 {'zeros': 352, 'total': 1152, 'sparsity': 0.31}
Quant_24_param0 {'zeros': 226, 'total': 2048, 'sparsity': 0.11}
Quant_25_param0 {'zeros': 56, 'total': 576, 'sparsity': 0.1}
Quant_26_param0 {'zeros': 720, 'total': 4096, 'sparsity': 0.18}
Quant_27_param0 {'zeros': 5019, 'total': 8192, 'sparsity': 0.61}
Quant_28_param0 {'zeros': 642, 'total': 1152, 'sparsity': 0.56}
Quant_29_param0 {'zeros': 5004, 'total': 8192, 'sparsity': 0.61}
Quant_30_param0 {'zeros': 3045, 'total': 8192, 'sparsity': 0.37}


# Prune the Layers Manually

- Convolution + BN: every convolution is followed by a Batch Norm, so the two can be pruned together. Output channels are pruned, which affects the next convolution: its weights must be adapted along the input-channel axis.

- Convolution + BN + ReLU: if convolution+bn is followed by ReLU, it must be pruned too.

- DW Layer: this is a special case. If a Conv layer is pruned and the next layer is depthwise, that layer must be pruned too, so that its groups parameter matches the new number of channels.

## Process:
```python
def prune_conv(model, conv: str)
```
Args: 
- model: ModelWrapper of the model to prune
- conv: string of the convolution layer to prune
>**Steps**: <br>
>1. Get Conv Node from model.
>2. Find direct predecessors: [1] will be the convolution weights, so store it.
>3. Modify convolution weights.
>4. Modify the output shape of convolution weights.
>5. Modify convolution output shape.
>6. Find direct successor, which is batch norm layer.
>7. Modify batch norm layer: weights and shape.
>8. Find direct successor of batch norm: modify the output shape of ReLU+Quant or only Quant.
>9. Find direct successor of last Quant node: it will be next convolution.
>10. Prune the weights according to output channel pruned in previous convolution.

#### Prune weights of convolution

In [15]:
def prune_conv_weights(model, quant_node, do_mul_val = True, mul_val = 8, eps = 1e-10):
    """Drop the output channels of a weight Quant node whose scale is numerically zero.

    ``quant_node.input[0]`` (weights) and ``quant_node.input[1]`` (scale) are read
    as initializers, indexed along axis 0 with the channels to keep, written back,
    and the shape of ``quant_node.output[0]`` is updated to the new channel count.

    Args:
        model: object exposing ``get_initializer``, ``set_initializer``,
            ``get_tensor_shape`` and ``set_tensor_shape`` (a ModelWrapper here).
        quant_node: Quant node producing the weights of the layer to prune.
        do_mul_val: when true, keep enough zero-scale channels so that the
            number of remaining channels is a multiple of ``mul_val``.
        mul_val: required divisor of the remaining channel count.
        eps: a channel is "zero" when ``0 < |scale| < eps`` and kept when
            ``|scale| > eps``. Defaults to the 1e-10 that was previously hard-coded.

    Returns:
        Tuple ``(non_zero_idx, zero_idx, new_ch)``: indices of the channels kept
        (including zero-scale channels kept as padding), indices of all
        zero-scale channels, and the resulting number of channels.

    Raises:
        AssertionError: if padding is requested but there are not enough
            zero-scale channels to reach a multiple of ``mul_val``.
    """

    print(f'\n############ Pruning Weights of {quant_node.name} node ############')
    quant_0 = quant_node.input[0]
    quant_1 = quant_node.input[1]
    np_q0 = model.get_initializer(quant_0)
    np_q1 = model.get_initializer(quant_1)
    print(f'Quant 0 shape: {np_q0.shape}')
    print(f'Quant 1 shape: {np_q1.shape}')

    np_q1_abs = np.abs(np_q1)
    # [0] keeps only the axis-0 coordinate of each hit, i.e. the channel index.
    zero_idx = np.where((np_q1_abs < eps) & (np_q1_abs > 0))[0]
    non_zero_idx = np.where(np_q1_abs > eps)[0]
    print("-------------------------------------")
    print(f'*** Zero IDX, channels to be pruned ({zero_idx.size} elements):\n{zero_idx}')
    print(f'### Non Zero IDX, channels to keep ({non_zero_idx.size} elements):\n{non_zero_idx}')
    print("-------------------------------------")

    # Optionally keep the channel count a multiple of mul_val.
    # This may force some zero-scale channels to be kept.
    print(f'\n------ Prune to muliple of {mul_val}: {do_mul_val}')
    if do_mul_val:
        non_zero_idx_remainder = non_zero_idx.size % mul_val
        if non_zero_idx_remainder == 0:
            print(f'Elements to keep is multiple of {mul_val}')
        else:
            # Despite the name, this is the scale of the first kept channel, not a mean.
            # It replaces the tiny scale of the zero channels that are kept as padding.
            np_q1_mean = np_q1[non_zero_idx][0]
            np_q1_copy = np_q1.copy() # Original array is read-only, so it must be copied first
            np_q1_copy[zero_idx] = np_q1_mean
            np_q1 = np_q1_copy
            print(f'Mean of channels to keep, to use it as default for zero elements: {np_q1_mean}')
            print(f'Elements to keep is not multiple of {mul_val} -> calculate new Non Zero IDX')
            n_to_keep = mul_val - non_zero_idx_remainder
            # The lowest-numbered zero-scale channels are the ones kept as padding.
            zero_idx_to_keep = zero_idx[:n_to_keep]
            print(f'{n_to_keep} zero elements must be kept\n{zero_idx_to_keep}')
            non_zero_idx = np.concatenate((zero_idx_to_keep, non_zero_idx))
            non_zero_idx.sort()
            assert non_zero_idx.size % mul_val == 0, f'Non Zero IDX calculated is not multiple of {mul_val}'
            print(f'### Multiple of {mul_val} Non Zero IDX, channels to keep:\n{non_zero_idx}')
            print("-------------------------------------")

    new_np_q0 = np_q0[non_zero_idx]
    print(f'New Quant 0 shape: {new_np_q0.shape}')
    new_np_q1 = np_q1[non_zero_idx]
    print(f'New Quant 1 shape: {new_np_q1.shape}')

    model.set_initializer(
        tensor_name = quant_0,
        tensor_value = new_np_q0)
    model.set_initializer(
        tensor_name = quant_1,
        tensor_value = new_np_q1)

    # Only axis 0 changes; the remaining dimensions are carried over whatever
    # the rank of the tensor is (the original unpacking required exactly 4-D).
    old_shape = tuple(model.get_tensor_shape(quant_node.output[0]))
    print(f'{quant_node.name} output original shape: {old_shape}')
    new_ch = new_np_q0.shape[0]
    new_shape = (new_ch, *old_shape[1:])
    print(f'{quant_node.name} output new shape: {new_shape}')

    model.set_tensor_shape(quant_node.output[0], new_shape)

    return non_zero_idx, zero_idx, new_ch

##### Test weight pruning

In [16]:
def test_prune_conv_weights(model, conv: str):
    """Run only the weight-pruning step on the convolution named ``conv``.

    Returns the ``(non_zero_idx, zero_idx, new_ch)`` tuple produced by
    ``prune_conv_weights``.
    """
    target_conv = model.get_node_from_name(conv)
    # Predecessor [1] of the convolution is the node that provides its weights.
    weight_quant = model.find_direct_predecessors(target_conv)[1]

    kept_idx, pruned_idx, out_channels = prune_conv_weights(model=model, quant_node=weight_quant)
    return kept_idx, pruned_idx, out_channels

In [17]:
# Reload the cleaned model so this check starts from an unpruned graph.
model = ModelWrapper(qonnx_clean_filename)
non_zero_idx, zero_idx, new_ch = test_prune_conv_weights(model, "Conv_0")


############ Pruning Weights of Quant_0 node ############
Quant 0 shape: (24, 3, 3, 3)
Quant 1 shape: (24, 1, 1, 1)
-------------------------------------
*** Zero IDX, channels to be pruned (0 elements):
[]
### Non Zero IDX, channels to keep (24 elements):
[ 0  1  2  3  4  5  6  7  8  9 10 11 12 13 14 15 16 17 18 19 20 21 22 23]
-------------------------------------

------ Prune to muliple of 8: True
Elements to keep is multiple of 8
New Quant 0 shape: (24, 3, 3, 3)
New Quant 1 shape: (24, 1, 1, 1)
Quant_0 output original shape: (24, 3, 3, 3)
Quant_0 output new shape: (24, 3, 3, 3)


In [18]:
test_prune_weights = prune_folder + "01_test_prune_weights.onnx"
model.save(test_prune_weights)

In [19]:
showInNetron(test_prune_weights)

Stopping http://0.0.0.0:8083
Serving './manual_pruning/pruning/sparse24_mul8/01_test_prune_weights.onnx' at http://0.0.0.0:8083


#### Prune convolution output

In [20]:
def prune_conv_out(model, conv_node, new_ch):
    """Update a convolution's output shape (and group count) after channel pruning.

    Args:
        model: object exposing ``get_tensor_shape`` and ``set_tensor_shape``.
        conv_node: convolution node whose output channels were pruned; its
            output tensor shape is unpacked as four dimensions.
        new_ch: number of output channels left after pruning.

    Returns:
        The new output shape as a 4-tuple with dimension 1 replaced by ``new_ch``.
    """

    print(f'\n############ Convert Convolution Output Shape {conv_node.name} node ############')
    batch, ch, w, h = model.get_tensor_shape(conv_node.output[0])
    print(f'Old Conv output shape: {(batch, ch, w, h)}')
    new_shape = (batch, new_ch, w, h)
    print(f'New Conv output shape: {new_shape}')

    model.set_tensor_shape(conv_node.output[0], new_shape)

    # Change groups.
    # The attribute is looked up by name: the position of "group" in the
    # attribute list is not guaranteed, so indexing attribute[1] could read or
    # overwrite an unrelated attribute. A missing attribute means group = 1.
    group_attr = next((attr for attr in conv_node.attribute if attr.name == "group"), None)
    conv_group = 1 if group_attr is None else group_attr.i
    if conv_group != 1:
        # Any group count other than 1 is treated as depthwise: groups follow the channel count.
        print(f'---> DW Conv found. Group {conv_group}, changed to {new_ch}')
        group_attr.i = new_ch
    else:
        print(f'DW Conv not found. Group = {conv_group}')

    return new_shape

##### Test conv output pruning

In [21]:
def test_prune_conv_output(model, conv: str):
    """Prune the weights of ``conv`` and propagate the new channel count to its output.

    Returns ``(non_zero_idx, zero_idx, new_ch, new_shape)`` as produced by
    ``prune_conv_weights`` and ``prune_conv_out``.
    """
    target_conv = model.get_node_from_name(conv)
    # Predecessor [1] of the convolution is the node that provides its weights.
    weight_quant = model.find_direct_predecessors(target_conv)[1]

    # Step 1: drop the zero-scale output channels from the weights.
    kept_idx, pruned_idx, out_channels = prune_conv_weights(model=model, quant_node=weight_quant)
    # Step 2: make the convolution output shape match.
    out_shape = prune_conv_out(model, target_conv, out_channels)

    return kept_idx, pruned_idx, out_channels, out_shape

In [22]:
# Reload the cleaned model so this check starts from an unpruned graph.
model = ModelWrapper(qonnx_clean_filename)
non_zero_idx, zero_idx, new_ch, new_shape = test_prune_conv_output(model, "Conv_0")


############ Pruning Weights of Quant_0 node ############
Quant 0 shape: (24, 3, 3, 3)
Quant 1 shape: (24, 1, 1, 1)
-------------------------------------
*** Zero IDX, channels to be pruned (0 elements):
[]
### Non Zero IDX, channels to keep (24 elements):
[ 0  1  2  3  4  5  6  7  8  9 10 11 12 13 14 15 16 17 18 19 20 21 22 23]
-------------------------------------

------ Prune to muliple of 8: True
Elements to keep is multiple of 8
New Quant 0 shape: (24, 3, 3, 3)
New Quant 1 shape: (24, 1, 1, 1)
Quant_0 output original shape: (24, 3, 3, 3)
Quant_0 output new shape: (24, 3, 3, 3)

############ Convert Convolution Output Shape Conv_0 node ############
Old Conv output shape: (1, 24, 112, 112)
New Conv output shape: (1, 24, 112, 112)
DW Conv not found. Group = 1


In [23]:
print(f'Non Zero IDX: {non_zero_idx}\nZero IDX: {zero_idx}\nNew Channels: {new_ch}\nNew Shape: {new_shape}')

Non Zero IDX: [ 0  1  2  3  4  5  6  7  8  9 10 11 12 13 14 15 16 17 18 19 20 21 22 23]
Zero IDX: []
New Channels: 24
New Shape: (1, 24, 112, 112)


In [24]:
test_prune_conv_out = prune_folder + "02_test_prune_conv_out.onnx"
model.save(test_prune_conv_out)

In [25]:
showInNetron(test_prune_conv_out)

Stopping http://0.0.0.0:8083
Serving './manual_pruning/pruning/sparse24_mul8/02_test_prune_conv_out.onnx' at http://0.0.0.0:8083


#### Prune batch norm

In [26]:
def prune_bn(model, bn_node, new_shape, non_zero_idx, zero_idx):
    """Prune the per-channel parameters of a batch-norm node.

    The four initializers at ``bn_node.input[1..4]`` (scale, bias, mean and
    variance, in that order) are indexed with ``non_zero_idx``. Zero-scale
    channels that were kept as padding get the parameters of the first
    genuinely non-zero channel instead of their own.

    Args:
        model: object exposing ``get_initializer``, ``set_initializer`` and
            ``set_tensor_shape``.
        bn_node: batch-norm node that follows the pruned convolution.
        new_shape: shape to set on ``bn_node.output[0]``.
        non_zero_idx: original channel indices that are kept (may contain
            zero-scale channels kept as padding).
        zero_idx: original channel indices of all zero-scale channels.
    """

    print(f'\n############ Prune Batch Norm {bn_node.name} node ############')
    bn_names = [bn_node.input[i] for i in range(1, 5)]
    bn_values = [model.get_initializer(name) for name in bn_names]

    for i, value in enumerate(bn_values):
        print(f'BN{i} shape: {value.shape}')

    print("-------------------------------------")
    print(f'*** Zero IDX, scale value of channels to be pruned:\n{bn_values[0][zero_idx]}')
    print("-------------------------------------")

    # Channels that are kept AND genuinely non-zero: non_zero_idx may include
    # some zero_idx entries that were kept to reach the required multiple.
    real_non_zero_idx = np.setdiff1d(non_zero_idx, zero_idx)
    # Replacement values: parameters of the first genuinely non-zero channel.
    fill_values = [value[real_non_zero_idx][0] for value in bn_values]
    print(f'Real Non Zero IDX, after removing Zero IDX to keep multiple of 4:\n{real_non_zero_idx}')
    real_zero_idx = np.intersect1d(zero_idx, non_zero_idx)
    print(f'Real Zero IDX kept:\n{real_zero_idx}')
    print("-------------------------------------")

    # real_zero_idx is expressed in ORIGINAL channel numbers, but the pruned
    # arrays are indexed by position inside non_zero_idx. Translate first;
    # using the original numbers overwrote the wrong channels or raised IndexError.
    kept_zero_pos = np.flatnonzero(np.isin(non_zero_idx, zero_idx))

    new_values = []
    for name, value, fill in zip(bn_names, bn_values, fill_values):
        pruned = value[non_zero_idx]  # fancy indexing returns a writable copy
        pruned[kept_zero_pos] = fill
        print(f'New {name} shape: {pruned.shape}')
        new_values.append(pruned)

    for name, pruned in zip(bn_names, new_values):
        model.set_initializer(
            tensor_name = name,
            tensor_value = pruned)

    model.set_tensor_shape(bn_node.output[0], new_shape)

##### Test prune batch norm

In [27]:
def test_prune_conv_bn(model, conv: str):
    """Prune ``conv`` (weights and output shape) together with its batch-norm successor.

    Returns ``(non_zero_idx, zero_idx, new_ch, new_shape)``.
    """
    target_conv = model.get_node_from_name(conv)
    # Predecessor [1] of the convolution is the node that provides its weights.
    weight_quant = model.find_direct_predecessors(target_conv)[1]

    # Weights first, then the convolution output shape.
    kept_idx, pruned_idx, out_channels = prune_conv_weights(model=model, quant_node=weight_quant)
    out_shape = prune_conv_out(model=model, conv_node=target_conv, new_ch=out_channels)

    # The first direct successor of the convolution is handled as its batch norm.
    batch_norm = model.find_direct_successors(target_conv)[0]
    prune_bn(model, batch_norm, out_shape, kept_idx, pruned_idx)

    return kept_idx, pruned_idx, out_channels, out_shape

In [28]:
# Reload the cleaned model so this check starts from an unpruned graph.
model = ModelWrapper(qonnx_clean_filename)
non_zero_idx, zero_idx, new_ch, new_shape = test_prune_conv_bn(model, "Conv_0")


############ Pruning Weights of Quant_0 node ############
Quant 0 shape: (24, 3, 3, 3)
Quant 1 shape: (24, 1, 1, 1)
-------------------------------------
*** Zero IDX, channels to be pruned (0 elements):
[]
### Non Zero IDX, channels to keep (24 elements):
[ 0  1  2  3  4  5  6  7  8  9 10 11 12 13 14 15 16 17 18 19 20 21 22 23]
-------------------------------------

------ Prune to muliple of 8: True
Elements to keep is multiple of 8
New Quant 0 shape: (24, 3, 3, 3)
New Quant 1 shape: (24, 1, 1, 1)
Quant_0 output original shape: (24, 3, 3, 3)
Quant_0 output new shape: (24, 3, 3, 3)

############ Convert Convolution Output Shape Conv_0 node ############
Old Conv output shape: (1, 24, 112, 112)
New Conv output shape: (1, 24, 112, 112)
DW Conv not found. Group = 1

############ Prune Batch Norm BatchNormalization_0 node ############
BN0 shape: (24,)
BN1 shape: (24,)
BN2 shape: (24,)
BN3 shape: (24,)
-------------------------------------
*** Zero IDX, scale value of channels to be pruned

In [29]:
test_prune_conv_bn = prune_folder + "03_test_prune_conv_bn.onnx"
model.save(test_prune_conv_bn)

In [30]:
showInNetron(test_prune_conv_bn)

Stopping http://0.0.0.0:8083
Serving './manual_pruning/pruning/sparse24_mul8/03_test_prune_conv_bn.onnx' at http://0.0.0.0:8083


#### Prune ReLU

In [31]:
def prune_relu(model, relu_node, new_shape):
    """Set the output tensor shape of a ReLU node to ``new_shape``.

    Only the shape annotation of ``relu_node.output[0]`` is changed.
    """
    banner = f'\n############ Update output shape of {relu_node.name} node ############'
    print(banner)
    print(f'New shape: {new_shape}')

    output_tensor = relu_node.output[0]
    model.set_tensor_shape(output_tensor, new_shape)

##### Test update ReLU output shape

In [32]:
def test_prune_relu(model, conv: str):
    """Prune ``conv`` and its batch norm, then update a following ReLU if present.

    Returns ``(non_zero_idx, zero_idx, new_ch, new_shape)``.
    """
    target_conv = model.get_node_from_name(conv)
    # Predecessor [1] of the convolution is the node that provides its weights.
    weight_quant = model.find_direct_predecessors(target_conv)[1]

    # Weights first, then the convolution output shape.
    kept_idx, pruned_idx, out_channels = prune_conv_weights(model=model, quant_node=weight_quant)
    out_shape = prune_conv_out(model=model, conv_node=target_conv, new_ch=out_channels)

    # The first direct successor of the convolution is handled as its batch norm.
    batch_norm = model.find_direct_successors(target_conv)[0]
    prune_bn(model, batch_norm, out_shape, kept_idx, pruned_idx)

    # A ReLU is recognised by its node name; any other successor is left untouched here.
    after_bn = model.find_direct_successors(batch_norm)[0]
    if "Relu" in after_bn.name:
        prune_relu(model=model, relu_node=after_bn, new_shape=out_shape)

    return kept_idx, pruned_idx, out_channels, out_shape

In [33]:
# Reload the cleaned model so this check starts from an unpruned graph.
model = ModelWrapper(qonnx_clean_filename)
non_zero_idx, zero_idx, new_ch, new_shape = test_prune_relu(model, "Conv_0")


############ Pruning Weights of Quant_0 node ############
Quant 0 shape: (24, 3, 3, 3)
Quant 1 shape: (24, 1, 1, 1)
-------------------------------------
*** Zero IDX, channels to be pruned (0 elements):
[]
### Non Zero IDX, channels to keep (24 elements):
[ 0  1  2  3  4  5  6  7  8  9 10 11 12 13 14 15 16 17 18 19 20 21 22 23]
-------------------------------------

------ Prune to muliple of 8: True
Elements to keep is multiple of 8
New Quant 0 shape: (24, 3, 3, 3)
New Quant 1 shape: (24, 1, 1, 1)
Quant_0 output original shape: (24, 3, 3, 3)
Quant_0 output new shape: (24, 3, 3, 3)

############ Convert Convolution Output Shape Conv_0 node ############
Old Conv output shape: (1, 24, 112, 112)
New Conv output shape: (1, 24, 112, 112)
DW Conv not found. Group = 1

############ Prune Batch Norm BatchNormalization_0 node ############
BN0 shape: (24,)
BN1 shape: (24,)
BN2 shape: (24,)
BN3 shape: (24,)
-------------------------------------
*** Zero IDX, scale value of channels to be pruned

In [34]:
test_prune_relu = prune_folder + "04_test_prune_relu.onnx"
model.save(test_prune_relu)

In [35]:
showInNetron(test_prune_relu)

Stopping http://0.0.0.0:8083
Serving './manual_pruning/pruning/sparse24_mul8/04_test_prune_relu.onnx' at http://0.0.0.0:8083


#### Prune quant Output

In [36]:
def prune_quant_out(model, quant_node, new_shape):
    """Set the output tensor shape of an activation Quant node to ``new_shape``.

    Only the shape annotation of ``quant_node.output[0]`` is changed.
    """
    banner = f'\n############ Update output shape of {quant_node.name} node ############'
    print(banner)
    print(f'New shape: {new_shape}')

    output_tensor = quant_node.output[0]
    model.set_tensor_shape(output_tensor, new_shape)

##### Test update of quant output

In [37]:
def test_prune_quant_out(model, conv: str):
    """Prune ``conv`` and its batch norm, then update the ReLU (if any) and Quant that follow.

    Returns ``(non_zero_idx, zero_idx, new_ch, new_shape)``.

    Raises:
        Exception: if the node reached after the batch norm (and optional
            ReLU) is not a Quant node, judged by its name.
    """
    target_conv = model.get_node_from_name(conv)
    # Predecessor [1] of the convolution is the node that provides its weights.
    weight_quant = model.find_direct_predecessors(target_conv)[1]

    # Weights first, then the convolution output shape.
    kept_idx, pruned_idx, out_channels = prune_conv_weights(model=model, quant_node=weight_quant)
    out_shape = prune_conv_out(model=model, conv_node=target_conv, new_ch=out_channels)

    # The first direct successor of the convolution is handled as its batch norm.
    batch_norm = model.find_direct_successors(target_conv)[0]
    prune_bn(model, batch_norm, out_shape, kept_idx, pruned_idx)

    # Walk past an optional ReLU so that `current` ends up on the Quant node.
    current = model.find_direct_successors(batch_norm)[0]
    if "Relu" in current.name:
        prune_relu(model=model, relu_node=current, new_shape=out_shape)
        current = model.find_direct_successors(current)[0]

    if "Quant" not in current.name:
        raise Exception(f'Node following BN is not ReLU or Quant: {current.name}')
    prune_quant_out(model, current, out_shape)

    return kept_idx, pruned_idx, out_channels, out_shape

In [38]:
# Reload the cleaned model so this check starts from an unpruned graph.
model = ModelWrapper(qonnx_clean_filename)
non_zero_idx, zero_idx, new_ch, new_shape = test_prune_quant_out(model, "Conv_0")


############ Pruning Weights of Quant_0 node ############
Quant 0 shape: (24, 3, 3, 3)
Quant 1 shape: (24, 1, 1, 1)
-------------------------------------
*** Zero IDX, channels to be pruned (0 elements):
[]
### Non Zero IDX, channels to keep (24 elements):
[ 0  1  2  3  4  5  6  7  8  9 10 11 12 13 14 15 16 17 18 19 20 21 22 23]
-------------------------------------

------ Prune to muliple of 8: True
Elements to keep is multiple of 8
New Quant 0 shape: (24, 3, 3, 3)
New Quant 1 shape: (24, 1, 1, 1)
Quant_0 output original shape: (24, 3, 3, 3)
Quant_0 output new shape: (24, 3, 3, 3)

############ Convert Convolution Output Shape Conv_0 node ############
Old Conv output shape: (1, 24, 112, 112)
New Conv output shape: (1, 24, 112, 112)
DW Conv not found. Group = 1

############ Prune Batch Norm BatchNormalization_0 node ############
BN0 shape: (24,)
BN1 shape: (24,)
BN2 shape: (24,)
BN3 shape: (24,)
-------------------------------------
*** Zero IDX, scale value of channels to be pruned

In [39]:
test_prune_quant_relu = prune_folder + "05_test_prune_quant_relu.onnx"
model.save(test_prune_quant_relu)

In [40]:
showInNetron(test_prune_quant_relu)

Stopping http://0.0.0.0:8083
Serving './manual_pruning/pruning/sparse24_mul8/05_test_prune_quant_relu.onnx' at http://0.0.0.0:8083


#### Prune Next Conv In Channels -> Weights

In [41]:
def prune_next_conv_inp(model, conv_node, non_zero_idx, zero_idx):
    """Prune the weights of the following convolution along axis 1.

    The weight initializer of ``conv_node`` is reduced to the entries listed
    in ``non_zero_idx`` along axis 1, written back, and the output shape of
    the weight Quant node is updated.

    Args:
        model: object exposing ``find_direct_predecessors``,
            ``get_initializer``, ``set_initializer`` and ``set_tensor_shape``.
        conv_node: convolution that consumes the pruned activation.
        non_zero_idx: channel indices kept by the previous layer.
        zero_idx: zero-scale channel indices of the previous layer. Kept for
            interface compatibility; the diagnostic below derives the removed
            channels from ``non_zero_idx`` instead.
    """

    print(f'\n............ Pruning Weights of {conv_node.name} node ............')
    conv_node_predec = model.find_direct_predecessors(conv_node)
    quant_node = conv_node_predec[1] # Quant node is [1]

    quant_0 = quant_node.input[0]
    np_q0 = model.get_initializer(quant_0)
    print(f'{quant_0} shape: {np_q0.shape}')

    new_np_q0 = np_q0[:, non_zero_idx]
    new_shape = new_np_q0.shape
    print(f'New {quant_0} shape: {new_shape}')

    print("-------------------------------------")
    # Check only the channels that are actually dropped. zero_idx also holds
    # zero-scale channels that were kept as padding, which made this
    # diagnostic report on weights that are not removed.
    removed_idx = np.setdiff1d(np.arange(np_q0.shape[1]), non_zero_idx)
    all_weights_zero = np.all(np_q0[:, removed_idx] == 0.)
    print(f'*** All weights removed are zero? {all_weights_zero}')
    print(f'### Non Zero IDX, channels to keep:\n{non_zero_idx}')
    print("-------------------------------------")

    model.set_initializer(
        tensor_name = quant_0,
        tensor_value = new_np_q0)

    print(f'Modify output shape of {quant_node.name} node: {new_shape}')
    model.set_tensor_shape(quant_node.output[0], new_shape)

##### Test Prune Next Convolution

In [42]:
def test_prune_next_conv_inp(model, conv: str):
    """Prune ``conv`` and everything up to the weights of the next convolution.

    Steps: weights, convolution output shape, batch norm, optional ReLU,
    activation Quant, and finally the input-channel axis of the next
    convolution's weights (skipped when that convolution is grouped).

    Returns ``(non_zero_idx, zero_idx, new_ch, new_shape)``.

    Raises:
        Exception: if the node reached after the batch norm (and optional
            ReLU) is not a Quant node, judged by its name.
    """

    conv_node = model.get_node_from_name(conv)
    conv_node_predec = model.find_direct_predecessors(conv_node)
    conv_node_weights = conv_node_predec[1]

    # Prune weights
    non_zero_idx, zero_idx, new_ch = prune_conv_weights(model=model, quant_node=conv_node_weights)
    # Update conv out shape
    new_shape = prune_conv_out(model=model, conv_node=conv_node, new_ch=new_ch)
    # Prune batch norm
    bn_node = model.find_direct_successors(conv_node)[0]
    prune_bn(model, bn_node, new_shape, non_zero_idx, zero_idx)

    # Find Batch Norm successor
    bn_successor_node = model.find_direct_successors(bn_node)[0]
    if "Relu" in bn_successor_node.name:
        # Prune relu output
        prune_relu(model=model, relu_node=bn_successor_node, new_shape=new_shape)
        # Update successor to relu quant node
        bn_successor_node = model.find_direct_successors(bn_successor_node)[0]
    if "Quant" in bn_successor_node.name:
        # Always update the shape of Quant Node: it will be preceded by relu or batch norm
        prune_quant_out(model, bn_successor_node, new_shape)
    else:
        raise Exception(f'Node following BN is not ReLU or Quant: {bn_successor_node.name}')

    # Prune next conv weights, so everything fits
    next_successor_node = model.find_direct_successors(bn_successor_node)[0]
    if "Conv" in next_successor_node.name:
        conv_succesor_node = next_successor_node
        print(f'\nNext successor node is a convolution: {conv_succesor_node.name}')
        # Check if next conv is DW. If so, skip, as whole pruning process must be done.
        # "group" is looked up by name because its position in the attribute
        # list is not guaranteed; a missing attribute means group = 1.
        group_attr = next((attr for attr in conv_succesor_node.attribute if attr.name == "group"), None)
        conv_group = 1 if group_attr is None else group_attr.i
        if conv_group != 1:
            print(f'---> DW Conv found. Group = {conv_group}. Skip')
        else:
            print(f'DW Conv not found. Group = {conv_group}. Prune Conv Weights')
            prune_next_conv_inp(model, conv_succesor_node, non_zero_idx, zero_idx)
    # Successor could be Average Pool too, keep in mind
    elif "AveragePool" in next_successor_node.name:
        avgpool_succesor_node = next_successor_node
        print(f'\nNext successor node is average pooling: {avgpool_succesor_node.name}')

    return non_zero_idx, zero_idx, new_ch, new_shape

In [43]:
# Reload the cleaned model so this check starts from an unpruned graph.
model = ModelWrapper(qonnx_clean_filename)
non_zero_idx, zero_idx, new_ch, new_shape = test_prune_next_conv_inp(model, "Conv_0")


############ Pruning Weights of Quant_0 node ############
Quant 0 shape: (24, 3, 3, 3)
Quant 1 shape: (24, 1, 1, 1)
-------------------------------------
*** Zero IDX, channels to be pruned (0 elements):
[]
### Non Zero IDX, channels to keep (24 elements):
[ 0  1  2  3  4  5  6  7  8  9 10 11 12 13 14 15 16 17 18 19 20 21 22 23]
-------------------------------------

------ Prune to muliple of 8: True
Elements to keep is multiple of 8
New Quant 0 shape: (24, 3, 3, 3)
New Quant 1 shape: (24, 1, 1, 1)
Quant_0 output original shape: (24, 3, 3, 3)
Quant_0 output new shape: (24, 3, 3, 3)

############ Convert Convolution Output Shape Conv_0 node ############
Old Conv output shape: (1, 24, 112, 112)
New Conv output shape: (1, 24, 112, 112)
DW Conv not found. Group = 1

############ Prune Batch Norm BatchNormalization_0 node ############
BN0 shape: (24,)
BN1 shape: (24,)
BN2 shape: (24,)
BN3 shape: (24,)
-------------------------------------
*** Zero IDX, scale value of channels to be pruned

In [44]:
test_prune_next_conv_inp_file = prune_folder + "06_test_prune_next_conv_inp.onnx"
model.save(test_prune_next_conv_inp_file)

In [45]:
showInNetron(test_prune_next_conv_inp_file)

Stopping http://0.0.0.0:8083
Serving './manual_pruning/pruning/sparse24_mul8/06_test_prune_next_conv_inp.onnx' at http://0.0.0.0:8083


#### Prune add node

In [46]:
def prune_add_node(model, add_node, new_shape):
    """Set the output tensor shape of an Add node to ``new_shape``.

    Only the shape annotation of ``add_node.output[0]`` is changed.
    """
    banner = f'\n############ Update output shape of {add_node.name} node ############'
    print(banner)
    print(f'New shape: {new_shape}')

    output_tensor = add_node.output[0]
    model.set_tensor_shape(output_tensor, new_shape)

##### Test prune add node

In [47]:
def test_prune_conv_b4_avgpool(model, conv: str):
    """Prune one convolution and propagate the new channel count downstream.

    Steps, in order: prune the weight Quant node feeding ``conv``, update the
    Conv output shape, prune the following BatchNorm, update the optional ReLU
    and the activation Quant, then adjust whatever consumes that Quant:

    * a Conv: its input channels are pruned, unless it is a grouped (DW) conv;
    * an Add on the second branch of a fork: the Add and its Quant get the new
      shape and the Conv after them is handled like the case above;
    * an Add as first successor, or an AveragePool: only reported here.

    Args:
        model: Graph wrapper (a ``ModelWrapper`` in this notebook) that the
            pruning helpers update.
        conv: Name of the Conv node to prune, e.g. ``"Conv_26"``.

    Returns:
        Tuple ``(non_zero_idx, zero_idx, new_ch, new_shape)``; the first three
        values come from ``prune_conv_weights`` and ``new_shape`` from
        ``prune_conv_out``.

    Raises:
        Exception: If BN is not followed by [ReLU ->] Quant, or if the Add
            branch of a fork is not followed by Quant -> Conv.
    """

    def prune_input_unless_depthwise(next_conv):
        """Prune the input channels of ``next_conv`` unless its group is not 1."""
        # Look "group" up by name instead of by position in the attribute list;
        # when the attribute is absent, the ONNX default of 1 applies.
        conv_group = next((attr.i for attr in next_conv.attribute if attr.name == "group"), 1)
        # Check if next conv is DW. If so, skip, as whole pruning process must be done
        if conv_group != 1:
            print(f'---> DW Conv found. Group = {conv_group}. Skip')
        else:
            print(f'DW Conv not found. Group = {conv_group}. Prune Conv Weights')
            prune_next_conv_inp(model, next_conv, non_zero_idx, zero_idx)

    conv_node = model.get_node_from_name(conv)
    # Predecessor [1] of the Conv is the Quant node producing its weights
    conv_node_weights = model.find_direct_predecessors(conv_node)[1]

    # Prune weights
    non_zero_idx, zero_idx, new_ch = prune_conv_weights(model=model, quant_node=conv_node_weights)
    # Update conv out shape
    new_shape = prune_conv_out(model=model, conv_node=conv_node, new_ch=new_ch)
    # Prune batch norm
    bn_node = model.find_direct_successors(conv_node)[0]
    prune_bn(model, bn_node, new_shape, non_zero_idx, zero_idx)

    # Find Batch Norm successor
    bn_successor_node = model.find_direct_successors(bn_node)[0]
    if "Relu" in bn_successor_node.name:
        # Prune relu output
        prune_relu(model=model, relu_node=bn_successor_node, new_shape=new_shape)
        # Update successor to relu quant node
        bn_successor_node = model.find_direct_successors(bn_successor_node)[0]
    if "Quant" in bn_successor_node.name:
        # Always update the shape of Quant Node: it will be preceded by relu or batch norm
        prune_quant_out(model, bn_successor_node, new_shape)
    else:
        raise Exception(f'Node following BN is not ReLU or Quant: {bn_successor_node.name}')

    # Prune next conv weights, so everything fits
    # Check first if it is fork node
        # Fork node:
            # [0] -> Conv
            # [1] -> Add
    next_successor_nodes = model.find_direct_successors(bn_successor_node)
    next_successor_node = next_successor_nodes[0]
    fork_node = len(next_successor_nodes) >= 2
    if fork_node:
        print(f'This successor node is a fork: {bn_successor_node.name}')
        next_successor_node_fork = next_successor_nodes[1]
        print("\n\t%%%%%%%%%%%%% This is the right branch of the fork")
        if "Add" in next_successor_node_fork.name:
            add_node = next_successor_node_fork
            prune_add_node(model, add_node, new_shape)
            quant_add_node = model.find_direct_successors(add_node)[0]
            if "Quant" in quant_add_node.name:
                prune_quant_out(model, quant_add_node, new_shape)
            else:
                raise Exception(f'Node following Add is not Quant: {quant_add_node.name}')
            conv_after_quant = model.find_direct_successors(quant_add_node)[0]
            if "Conv" in conv_after_quant.name:
                print(f'\nNext successor node is a convolution: {conv_after_quant.name}')
                prune_input_unless_depthwise(conv_after_quant)
            else:
                raise Exception(f'Node following Quant Add is not Conv: {conv_after_quant.name}')
        else:
            print(f'\nNext successor node is a fork, but not followed by Add node: {bn_successor_node.name}')

    # Always adjust the input weights of next conv, left side of the Fork if it is the case
    if "Conv" in next_successor_node.name:
        if fork_node:
            print("\n\t%%%%%%%%%%%%% This is the left branch of the fork")
        print(f'\nNext successor node is a convolution: {next_successor_node.name}')
        prune_input_unless_depthwise(next_successor_node)
    elif "Add" in next_successor_node.name:
        print(f'\nNext successor node is add: {next_successor_node.name}. Skip')
    # Successor could be Average Pool too, keep in mind
    elif "AveragePool" in next_successor_node.name:
        print(f'\nNext successor node is average pooling: {next_successor_node.name}')

    return non_zero_idx, zero_idx, new_ch, new_shape

In [48]:
# Start again from the cleaned, unpruned model and prune Conv_26 only.
model = ModelWrapper(qonnx_clean_filename)
non_zero_idx, zero_idx, new_ch, new_shape = test_prune_conv_b4_avgpool(model, "Conv_26")


############ Pruning Weights of Quant_26 node ############
Quant 0 shape: (64, 64, 1, 1)
Quant 1 shape: (64, 1, 1, 1)
-------------------------------------
*** Zero IDX, channels to be pruned (4 elements):
[ 7 10 14 51]
### Non Zero IDX, channels to keep (60 elements):
[ 0  1  2  3  4  5  6  8  9 11 12 13 15 16 17 18 19 20 21 22 23 24 25 26
 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50
 52 53 54 55 56 57 58 59 60 61 62 63]
-------------------------------------

------ Prune to muliple of 8: True
Mean of channels to keep, to use it as default for zero elements: [[[0.01414997]]]
Elements to keep is not multiple of 8 -> calculate new Non Zero IDX
4 zero elements must be kept
[ 7 10 14 51]
### Multiple of 8 Non Zero IDX, channels to keep:
[ 0  1  2  3  4  5  6  7  8  9 10 11 12 13 14 15 16 17 18 19 20 21 22 23
 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47
 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63]
-------------------------------

In [49]:
# Save the model with Conv_26 pruned so it can be inspected.
test_prune_conv_b4_avgpool_file = prune_folder + "07_test_prune_conv_b4_avgpool.onnx"
model.save(test_prune_conv_b4_avgpool_file)

In [50]:
# Serve the saved model in Netron to check the updated shapes visually.
showInNetron(test_prune_conv_b4_avgpool_file)

Stopping http://0.0.0.0:8083
Serving './manual_pruning/pruning/sparse24_mul8/07_test_prune_conv_b4_avgpool.onnx' at http://0.0.0.0:8083


##### Test prune convs 26, 27, 28, 29 and 30
Check that the ResNet (residual) blocks are pruned properly.

In [51]:
# Start from the cleaned, unpruned model and prune the convolutions one after
# another, in graph order, so each call sees the shapes left by the previous one.
model = ModelWrapper(qonnx_clean_filename)
for conv_name in ("Conv_26", "Conv_27", "Conv_28", "Conv_29", "Conv_30"):
    # The returned indices/shapes are not needed here; not unpacking them into
    # `_` also avoids shadowing IPython's last-output variable.
    test_prune_conv_b4_avgpool(model, conv_name)


############ Pruning Weights of Quant_26 node ############
Quant 0 shape: (64, 64, 1, 1)
Quant 1 shape: (64, 1, 1, 1)
-------------------------------------
*** Zero IDX, channels to be pruned (4 elements):
[ 7 10 14 51]
### Non Zero IDX, channels to keep (60 elements):
[ 0  1  2  3  4  5  6  8  9 11 12 13 15 16 17 18 19 20 21 22 23 24 25 26
 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50
 52 53 54 55 56 57 58 59 60 61 62 63]
-------------------------------------

------ Prune to muliple of 8: True
Mean of channels to keep, to use it as default for zero elements: [[[0.01414997]]]
Elements to keep is not multiple of 8 -> calculate new Non Zero IDX
4 zero elements must be kept
[ 7 10 14 51]
### Multiple of 8 Non Zero IDX, channels to keep:
[ 0  1  2  3  4  5  6  7  8  9 10 11 12 13 14 15 16 17 18 19 20 21 22 23
 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47
 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63]
-------------------------------

In [52]:
# Save the model with Conv_26..Conv_30 pruned so it can be inspected.
test_prune_resnet = prune_folder + "08_test_prune_resnet.onnx"
model.save(test_prune_resnet)

In [53]:
# Serve the saved model in Netron to check the updated shapes visually.
showInNetron(test_prune_resnet)

Stopping http://0.0.0.0:8083
Serving './manual_pruning/pruning/sparse24_mul8/08_test_prune_resnet.onnx' at http://0.0.0.0:8083


#### Prune Avg Pool and Reshape

In [54]:
def prune_avgpool_reshape(model, avgpool_node, new_shape):
    """Propagate a pruned channel count through AveragePool -> Mul -> Trunc -> Reshape.

    The AveragePool, Mul and Trunc outputs are set to
    ``(new_shape[0], new_shape[1], 1, 1)`` and the Reshape output to
    ``(new_shape[0], new_shape[1])``. The Gemm node found after the Reshape is
    returned untouched so the caller can prune its weights.

    Args:
        model: Graph wrapper exposing ``get_tensor_shape``, ``set_tensor_shape``
            and ``find_direct_successors``.
        avgpool_node: The AveragePool node that starts the chain.
        new_shape: Shape of the tensor feeding the pooling node; only its
            first two entries are used.

    Returns:
        The Gemm node that follows the Reshape node.

    Raises:
        Exception: If a node in the chain has no successor, or if the first
            successor does not have the expected name (Mul, Trunc, Reshape, Gemm).
    """

    def next_in_chain(node, expected, previous):
        """Return the first successor of ``node`` after checking its name."""
        successors = model.find_direct_successors(node)
        # A missing successor (None or empty result) would otherwise surface as
        # an unrelated indexing error.
        if not successors:
            raise Exception(f'Node following {previous} is not {expected}: {node.name} has no successor')
        candidate = successors[0]
        if expected not in candidate.name:
            raise Exception(f'Node following {previous} is not {expected}: {candidate.name}')
        return candidate

    def update_output_shape(node, shape):
        """Log and store ``shape`` for the first output of ``node``."""
        print(f'\n############ Update output shape of {node.name} node ############')
        print(f'New shape: {shape}')
        model.set_tensor_shape(node.output[0], shape)

    print(f'\n############ Update output shape of {avgpool_node.name} node ############')
    avgpool_old_shape = model.get_tensor_shape(avgpool_node.output[0])
    print(f'Old Avg Pool output shape: {avgpool_old_shape}')
    # The spatial dimensions are written as 1x1; batch and channels come from new_shape
    avgpool_new_shape = (new_shape[0], new_shape[1], 1, 1)
    print(f'New shape: {avgpool_new_shape}')
    model.set_tensor_shape(avgpool_node.output[0], avgpool_new_shape)

    mul_node = next_in_chain(avgpool_node, "Mul", "AvgPool")
    update_output_shape(mul_node, avgpool_new_shape)

    trunc_node = next_in_chain(mul_node, "Trunc", "Mul")
    update_output_shape(trunc_node, avgpool_new_shape)

    reshape_node = next_in_chain(trunc_node, "Reshape", "Trunc")
    # Reshape output keeps only the first two dimensions
    update_output_shape(reshape_node, (avgpool_new_shape[0], avgpool_new_shape[1]))

    # The Gemm output shape is not changed here, the node is only located
    gemm_node = next_in_chain(reshape_node, "Gemm", "Reshape")
    print(f'\n############ Gemm node found: {gemm_node.name} node ############')

    return gemm_node

##### Test prune avgpool and reshape

In [55]:
def test_prune_conv_and_avgpool(model, conv: str):
    """Prune one convolution and propagate the new channel count downstream.

    Steps, in order: prune the weight Quant node feeding ``conv``, update the
    Conv output shape, prune the following BatchNorm, update the optional ReLU
    and the activation Quant, then adjust whatever consumes that Quant:

    * a Conv: its input channels are pruned, unless it is a grouped (DW) conv;
    * an Add on the second branch of a fork: the Add and its Quant get the new
      shape and the Conv after them is handled like the case above;
    * an Add as first successor: only reported;
    * an AveragePool: the shapes of the pooling chain are updated through
      ``prune_avgpool_reshape`` (the Gemm node it returns is not used here).

    Args:
        model: Graph wrapper (a ``ModelWrapper`` in this notebook) that the
            pruning helpers update.
        conv: Name of the Conv node to prune, e.g. ``"Conv_30"``.

    Returns:
        Tuple ``(non_zero_idx, zero_idx, new_ch, new_shape)``; the first three
        values come from ``prune_conv_weights`` and ``new_shape`` from
        ``prune_conv_out``.

    Raises:
        Exception: If BN is not followed by [ReLU ->] Quant, if the Add branch
            of a fork is not followed by Quant -> Conv, or if
            ``prune_avgpool_reshape`` rejects the pooling chain.
    """

    def prune_input_unless_depthwise(next_conv):
        """Prune the input channels of ``next_conv`` unless its group is not 1."""
        # Look "group" up by name instead of by position in the attribute list;
        # when the attribute is absent, the ONNX default of 1 applies.
        conv_group = next((attr.i for attr in next_conv.attribute if attr.name == "group"), 1)
        # Check if next conv is DW. If so, skip, as whole pruning process must be done
        if conv_group != 1:
            print(f'---> DW Conv found. Group = {conv_group}. Skip')
        else:
            print(f'DW Conv not found. Group = {conv_group}. Prune Conv Weights')
            prune_next_conv_inp(model, next_conv, non_zero_idx, zero_idx)

    conv_node = model.get_node_from_name(conv)
    # Predecessor [1] of the Conv is the Quant node producing its weights
    conv_node_weights = model.find_direct_predecessors(conv_node)[1]

    # Prune weights
    non_zero_idx, zero_idx, new_ch = prune_conv_weights(model=model, quant_node=conv_node_weights)
    # Update conv out shape
    new_shape = prune_conv_out(model=model, conv_node=conv_node, new_ch=new_ch)
    # Prune batch norm
    bn_node = model.find_direct_successors(conv_node)[0]
    prune_bn(model, bn_node, new_shape, non_zero_idx, zero_idx)

    # Find Batch Norm successor
    bn_successor_node = model.find_direct_successors(bn_node)[0]
    if "Relu" in bn_successor_node.name:
        # Prune relu output
        prune_relu(model=model, relu_node=bn_successor_node, new_shape=new_shape)
        # Update successor to relu quant node
        bn_successor_node = model.find_direct_successors(bn_successor_node)[0]
    if "Quant" in bn_successor_node.name:
        # Always update the shape of Quant Node: it will be preceded by relu or batch norm
        prune_quant_out(model, bn_successor_node, new_shape)
    else:
        raise Exception(f'Node following BN is not ReLU or Quant: {bn_successor_node.name}')

    # Prune next conv weights, so everything fits
    # Check first if it is fork node
        # Fork node:
            # [0] -> Conv
            # [1] -> Add
    next_successor_nodes = model.find_direct_successors(bn_successor_node)
    next_successor_node = next_successor_nodes[0]
    fork_node = len(next_successor_nodes) >= 2
    if fork_node:
        print(f'This successor node is a fork: {bn_successor_node.name}')
        next_successor_node_fork = next_successor_nodes[1]
        print("\n\t%%%%%%%%%%%%% This is the right branch of the fork")
        if "Add" in next_successor_node_fork.name:
            add_node = next_successor_node_fork
            prune_add_node(model, add_node, new_shape)
            quant_add_node = model.find_direct_successors(add_node)[0]
            if "Quant" in quant_add_node.name:
                prune_quant_out(model, quant_add_node, new_shape)
            else:
                raise Exception(f'Node following Add is not Quant: {quant_add_node.name}')
            conv_after_quant = model.find_direct_successors(quant_add_node)[0]
            if "Conv" in conv_after_quant.name:
                print(f'\nNext successor node is a convolution: {conv_after_quant.name}')
                prune_input_unless_depthwise(conv_after_quant)
            else:
                raise Exception(f'Node following Quant Add is not Conv: {conv_after_quant.name}')
        else:
            print(f'\nNext successor node is a fork, but not followed by Add node: {bn_successor_node.name}')

    # Always adjust the input weights of next conv, left side of the Fork if it is the case
    if "Conv" in next_successor_node.name:
        if fork_node:
            print("\n\t%%%%%%%%%%%%% This is the left branch of the fork")
        print(f'\nNext successor node is a convolution: {next_successor_node.name}')
        prune_input_unless_depthwise(next_successor_node)
    elif "Add" in next_successor_node.name:
        print(f'\nNext successor node is add: {next_successor_node.name}. Skip')
    # Successor could be Average Pool too, keep in mind
    elif "AveragePool" in next_successor_node.name:
        print(f'\nNext successor node is average pooling: {next_successor_node.name}')
        # Only the shapes of the pooling chain are updated; the returned Gemm node is ignored
        prune_avgpool_reshape(model, next_successor_node, new_shape)

    return non_zero_idx, zero_idx, new_ch, new_shape

In [56]:
# Start again from the cleaned, unpruned model and prune Conv_30 only;
# the returned indices and shapes are not needed here.
model = ModelWrapper(qonnx_clean_filename)
_, _, _, _ = test_prune_conv_and_avgpool(model, "Conv_30")


############ Pruning Weights of Quant_30 node ############
Quant 0 shape: (128, 64, 1, 1)
Quant 1 shape: (128, 1, 1, 1)
-------------------------------------
*** Zero IDX, channels to be pruned (26 elements):
[  4   8  13  14  24  36  41  48  52  59  61  62  71  75  77  78  82  86
  91  96 101 105 110 118 120 126]
### Non Zero IDX, channels to keep (102 elements):
[  0   1   2   3   5   6   7   9  10  11  12  15  16  17  18  19  20  21
  22  23  25  26  27  28  29  30  31  32  33  34  35  37  38  39  40  42
  43  44  45  46  47  49  50  51  53  54  55  56  57  58  60  63  64  65
  66  67  68  69  70  72  73  74  76  79  80  81  83  84  85  87  88  89
  90  92  93  94  95  97  98  99 100 102 103 104 106 107 108 109 111 112
 113 114 115 116 117 119 121 122 123 124 125 127]
-------------------------------------

------ Prune to muliple of 8: True
Mean of channels to keep, to use it as default for zero elements: [[[0.00795446]]]
Elements to keep is not multiple of 8 -> calculate new Non Z

In [57]:
# Save the model with Conv_30 pruned so it can be inspected.
test_prune_avgpool = prune_folder + "09_test_prune_avgpool.onnx"
model.save(test_prune_avgpool)

In [58]:
# Serve the saved model in Netron to check the updated shapes visually.
showInNetron(test_prune_avgpool)

Stopping http://0.0.0.0:8083
Serving './manual_pruning/pruning/sparse24_mul8/09_test_prune_avgpool.onnx' at http://0.0.0.0:8083


#### Prune GEMM

In [59]:
def prune_next_gemm_inp(model, gemm_node, non_zero_idx, zero_idx):
    """Keep only the selected columns of the weight matrix feeding a Gemm node.

    The weights are the initializer named by the first input of the Quant node
    found as predecessor [1] of ``gemm_node``. Columns listed in
    ``non_zero_idx`` are kept, the result is written back as the initializer
    and the Quant node's output shape is set to the new weight shape.

    Args:
        model: Graph wrapper exposing ``find_direct_predecessors``,
            ``get_initializer``, ``set_initializer`` and ``set_tensor_shape``.
        gemm_node: Gemm node whose weight Quant is pruned.
        non_zero_idx: Column indices to keep.
        zero_idx: Column indices being dropped; only used to report whether
            the dropped weights were all zero.
    """
    print(f'\n............ Pruning Weights of {gemm_node.name} node ............')
    weight_quant = model.find_direct_predecessors(gemm_node)[1]  # Quant node is [1]

    weight_name = weight_quant.input[0]
    weights = model.get_initializer(weight_name)
    print(f'{weight_name} shape: {weights.shape}')

    kept_weights = weights[:, non_zero_idx]
    kept_shape = kept_weights.shape
    print(f'New {weight_name} shape: {kept_shape}')

    # Informational check only: the result is printed, never enforced.
    dropped_weights = weights[:, zero_idx]
    dropped_all_zero = np.all(dropped_weights == 0.)
    print("-------------------------------------")
    print(f'*** All weights removed are zero? {dropped_all_zero}')
    print(f'### Non Zero IDX, channels to keep:\n{non_zero_idx}')
    print("-------------------------------------")

    model.set_initializer(tensor_name=weight_name, tensor_value=kept_weights)

    print(f'Modify output shape of {weight_quant.name} node: {kept_shape}')
    model.set_tensor_shape(weight_quant.output[0], kept_shape)

##### Test prune GEMM

In [60]:
def test_prune_any_conv(model, conv: str, do_mul_val = True):
    """Prune one convolution and propagate the new channel count downstream.

    Steps, in order: prune the weight Quant node feeding ``conv``, update the
    Conv output shape, prune the following BatchNorm, update the optional ReLU
    and the activation Quant, then adjust whatever consumes that Quant:

    * a Conv: its input channels are pruned, unless it is a grouped (DW) conv;
    * an Add on the second branch of a fork: the Add and its Quant get the new
      shape and the Conv after them is handled like the case above;
    * an Add as first successor: only reported;
    * an AveragePool: the pooling chain shapes are updated through
      ``prune_avgpool_reshape`` and the input weights of the Gemm it returns
      are pruned with ``prune_next_gemm_inp``.

    Args:
        model: Graph wrapper (a ``ModelWrapper`` in this notebook) that the
            pruning helpers update.
        conv: Name of the Conv node to prune, e.g. ``"Conv_30"``.
        do_mul_val: Forwarded unchanged to ``prune_conv_weights``.

    Returns:
        Tuple ``(non_zero_idx, zero_idx, new_ch, new_shape)``; the first three
        values come from ``prune_conv_weights`` and ``new_shape`` from
        ``prune_conv_out``.

    Raises:
        Exception: If BN is not followed by [ReLU ->] Quant, if the Add branch
            of a fork is not followed by Quant -> Conv, or if
            ``prune_avgpool_reshape`` rejects the pooling chain.
    """

    def prune_input_unless_depthwise(next_conv):
        """Prune the input channels of ``next_conv`` unless its group is not 1."""
        # Look "group" up by name instead of by position in the attribute list;
        # when the attribute is absent, the ONNX default of 1 applies.
        conv_group = next((attr.i for attr in next_conv.attribute if attr.name == "group"), 1)
        # Check if next conv is DW. If so, skip, as whole pruning process must be done
        if conv_group != 1:
            print(f'---> DW Conv found. Group = {conv_group}. Skip')
        else:
            print(f'DW Conv not found. Group = {conv_group}. Prune Conv Weights')
            prune_next_conv_inp(model, next_conv, non_zero_idx, zero_idx)

    conv_node = model.get_node_from_name(conv)
    # Predecessor [1] of the Conv is the Quant node producing its weights
    conv_node_weights = model.find_direct_predecessors(conv_node)[1]

    # Prune weights
    non_zero_idx, zero_idx, new_ch = prune_conv_weights(model=model, quant_node=conv_node_weights, do_mul_val=do_mul_val)
    # Update conv out shape
    new_shape = prune_conv_out(model=model, conv_node=conv_node, new_ch=new_ch)
    # Prune batch norm
    bn_node = model.find_direct_successors(conv_node)[0]
    prune_bn(model, bn_node, new_shape, non_zero_idx, zero_idx)

    # Find Batch Norm successor
    bn_successor_node = model.find_direct_successors(bn_node)[0]
    if "Relu" in bn_successor_node.name:
        # Prune relu output
        prune_relu(model=model, relu_node=bn_successor_node, new_shape=new_shape)
        # Update successor to relu quant node
        bn_successor_node = model.find_direct_successors(bn_successor_node)[0]
    if "Quant" in bn_successor_node.name:
        # Always update the shape of Quant Node: it will be preceded by relu or batch norm
        prune_quant_out(model, bn_successor_node, new_shape)
    else:
        raise Exception(f'Node following BN is not ReLU or Quant: {bn_successor_node.name}')

    # Prune next conv weights, so everything fits
    # Check first if it is fork node
        # Fork node:
            # [0] -> Conv
            # [1] -> Add
    next_successor_nodes = model.find_direct_successors(bn_successor_node)
    next_successor_node = next_successor_nodes[0]
    fork_node = len(next_successor_nodes) >= 2
    if fork_node:
        print(f'This successor node is a fork: {bn_successor_node.name}')
        next_successor_node_fork = next_successor_nodes[1]
        print("\n\t%%%%%%%%%%%%% This is the right branch of the fork")
        if "Add" in next_successor_node_fork.name:
            add_node = next_successor_node_fork
            prune_add_node(model, add_node, new_shape)
            quant_add_node = model.find_direct_successors(add_node)[0]
            if "Quant" in quant_add_node.name:
                prune_quant_out(model, quant_add_node, new_shape)
            else:
                raise Exception(f'Node following Add is not Quant: {quant_add_node.name}')
            conv_after_quant = model.find_direct_successors(quant_add_node)[0]
            if "Conv" in conv_after_quant.name:
                print(f'\nNext successor node is a convolution: {conv_after_quant.name}')
                prune_input_unless_depthwise(conv_after_quant)
            else:
                raise Exception(f'Node following Quant Add is not Conv: {conv_after_quant.name}')
        else:
            print(f'\nNext successor node is a fork, but not followed by Add node: {bn_successor_node.name}')

    # Always adjust the input weights of next conv, left side of the Fork if it is the case
    if "Conv" in next_successor_node.name:
        if fork_node:
            print("\n\t%%%%%%%%%%%%% This is the left branch of the fork")
        print(f'\nNext successor node is a convolution: {next_successor_node.name}')
        prune_input_unless_depthwise(next_successor_node)
    elif "Add" in next_successor_node.name:
        print(f'\nNext successor node is add: {next_successor_node.name}. Skip')
    # Successor could be Average Pool too, keep in mind
    elif "AveragePool" in next_successor_node.name:
        print(f'\nNext successor node is average pooling: {next_successor_node.name}')
        # Update the pooling chain first, then prune the Gemm weights it leads to
        gemm_node = prune_avgpool_reshape(model, next_successor_node, new_shape)
        prune_next_gemm_inp(model, gemm_node, non_zero_idx, zero_idx)

    return non_zero_idx, zero_idx, new_ch, new_shape

In [61]:
# Start again from the cleaned, unpruned model and prune Conv_30 only;
# the returned indices and shapes are not needed here.
model = ModelWrapper(qonnx_clean_filename)
_, _, _, _ = test_prune_any_conv(model, "Conv_30")


############ Pruning Weights of Quant_30 node ############
Quant 0 shape: (128, 64, 1, 1)
Quant 1 shape: (128, 1, 1, 1)
-------------------------------------
*** Zero IDX, channels to be pruned (26 elements):
[  4   8  13  14  24  36  41  48  52  59  61  62  71  75  77  78  82  86
  91  96 101 105 110 118 120 126]
### Non Zero IDX, channels to keep (102 elements):
[  0   1   2   3   5   6   7   9  10  11  12  15  16  17  18  19  20  21
  22  23  25  26  27  28  29  30  31  32  33  34  35  37  38  39  40  42
  43  44  45  46  47  49  50  51  53  54  55  56  57  58  60  63  64  65
  66  67  68  69  70  72  73  74  76  79  80  81  83  84  85  87  88  89
  90  92  93  94  95  97  98  99 100 102 103 104 106 107 108 109 111 112
 113 114 115 116 117 119 121 122 123 124 125 127]
-------------------------------------

------ Prune to muliple of 8: True
Mean of channels to keep, to use it as default for zero elements: [[[0.00795446]]]
Elements to keep is not multiple of 8 -> calculate new Non Z

In [62]:
# Save the model with Conv_30 pruned so it can be inspected.
test_prune_any_conv_file = prune_folder + "10_test_prune_any_conv_file.onnx"
model.save(test_prune_any_conv_file)

In [63]:
# Serve the saved model in Netron to check the updated shapes visually.
showInNetron(test_prune_any_conv_file)

Stopping http://0.0.0.0:8083
Serving './manual_pruning/pruning/sparse24_mul8/10_test_prune_any_conv_file.onnx' at http://0.0.0.0:8083


# Test Pruning of first 2 Convs

In [64]:
# model = ModelWrapper(qonnx_clean_filename)
# # non_zero_idx, zero_idx, new_ch, new_shape = test_prune_conv(model, "Conv_0")
# # non_zero_idx, zero_idx, new_ch, new_shape = test_prune_conv(model, "Conv_1")

# _, _, _, _ = test_prune_conv(model, "Conv_0")
# _, _, _, _ = test_prune_conv(model, "Conv_1")

In [65]:
# test_prune_2_conv = prune_folder + "10_test_prune_2_conv.onnx"

# model = model.transform(InferShapes())
# model.save(test_prune_2_conv)

In [66]:
# showInNetron(test_prune_2_conv)

# Test Pruning the Whole Model - NO MUL 4 or 8

In [67]:
# Pruning mode for the cells below: False makes the pruning loop call
# test_prune_any_conv with do_mul_val=False and selects the "no_mul8" output file.
# NOTE(review): the flag is named mul4 but the logs and file names refer to
# multiples of 8 - confirm which multiple is intended.
set_mul4 = False

In [68]:
# Start from the cleaned, unpruned model and prune every selected convolution.
model = ModelWrapper(qonnx_clean_filename)

for conv in convs_to_prune:
    print('\n______________________________________________________________________________________________________')
    print(f'                                                {conv} ')
    print('______________________________________________________________________________________________________')

    # The flag is exactly the value wanted for do_mul_val, so forward it
    # directly instead of branching on it. The returned indices and shapes
    # are not needed here.
    test_prune_any_conv(model, conv, do_mul_val=bool(set_mul4))


______________________________________________________________________________________________________
                                                Conv_12 
______________________________________________________________________________________________________

############ Pruning Weights of Quant_12 node ############
Quant 0 shape: (48, 24, 1, 1)
Quant 1 shape: (48, 1, 1, 1)
-------------------------------------
*** Zero IDX, channels to be pruned (1 elements):
[28]
### Non Zero IDX, channels to keep (47 elements):
[ 0  1  2  3  4  5  6  7  8  9 10 11 12 13 14 15 16 17 18 19 20 21 22 23
 24 25 26 27 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47]
-------------------------------------

------ Prune to muliple of 8: False
New Quant 0 shape: (47, 24, 1, 1)
New Quant 1 shape: (47, 1, 1, 1)
Quant_12 output original shape: (48, 24, 1, 1)
Quant_12 output new shape: (47, 24, 1, 1)

############ Convert Convolution Output Shape Conv_12 node ############
Old Conv output shape: (1,

In [69]:
# Run QONNX shape inference on the pruned model before saving it.
model = model.transform(InferShapes())

In [70]:
# Pick the output file that matches the pruning mode used above, then save.
prune_all_convs_onnx = prune_folder + (
    "20_prune_all_convs_mul8.onnx" if set_mul4 else "30_prune_all_convs_no_mul8.onnx"
)
model.save(prune_all_convs_onnx)

In [71]:
# Serve the fully pruned model in Netron for visual inspection.
showInNetron(prune_all_convs_onnx)

Stopping http://0.0.0.0:8083
Serving './manual_pruning/pruning/sparse24_mul8/30_prune_all_convs_no_mul8.onnx' at http://0.0.0.0:8083


### Compare Sparsity NO MUL 4 or 8

In [72]:
sparsity_after_pruning = get_sparsity(model, layers_to_prune)

# zip() silently stops at the shorter input, so make sure both reports cover
# the same number of layers before comparing them pairwise.
assert len(sparsity_before_pruning) == len(sparsity_after_pruning), (
    f'{len(sparsity_before_pruning)} layers before pruning, '
    f'{len(sparsity_after_pruning)} layers after pruning'
)

for k1, k2 in zip(sparsity_before_pruning.keys(), sparsity_after_pruning.keys()):
    # Validate the pairing before reading and printing the values.
    assert k1 == k2, f'{k1} is not the same as {k2}'
    before = sparsity_before_pruning[k1]["sparsity"]
    after = sparsity_after_pruning[k2]["sparsity"]
    print(f'{k1}: \tbefore: {before:<4} - after: {after:<4}')

Quant_12_param0: 	before: 0.11 - after: 0.09
Quant_13_param0: 	before: 0.09 - after: 0.07
Quant_15_param0: 	before: 0.14 - after: 0.1 
Quant_16_param0: 	before: 0.1  - after: 0.06
Quant_18_param0: 	before: 0.22 - after: 0.11
Quant_19_param0: 	before: 0.18 - after: 0.07
Quant_21_param0: 	before: 0.33 - after: 0.1 
Quant_22_param0: 	before: 0.31 - after: 0.06
Quant_24_param0: 	before: 0.11 - after: 0.1 
Quant_25_param0: 	before: 0.1  - after: 0.08
Quant_26_param0: 	before: 0.18 - after: 0.11
Quant_27_param0: 	before: 0.61 - after: 0.13
Quant_28_param0: 	before: 0.56 - after: 0.07
Quant_29_param0: 	before: 0.61 - after: 0.13
Quant_30_param0: 	before: 0.37 - after: 0.16


# Test Pruning the Whole Model - YES MUL 4 or 8

In [73]:
# Pruning mode for the cells below: True makes the pruning loop call
# test_prune_any_conv with its default do_mul_val and selects the "mul8" output file.
# NOTE(review): the flag is named mul4 but the logs and file names refer to
# multiples of 8 - confirm which multiple is intended.
set_mul4 = True

In [74]:
# Start from the cleaned, unpruned model and prune every selected convolution.
model = ModelWrapper(qonnx_clean_filename)

for conv in convs_to_prune:
    print('\n______________________________________________________________________________________________________')
    print(f'                                                {conv} ')
    print('______________________________________________________________________________________________________')

    # The flag is exactly the value wanted for do_mul_val, so forward it
    # directly instead of branching on it. The returned indices and shapes
    # are not needed here.
    test_prune_any_conv(model, conv, do_mul_val=bool(set_mul4))


______________________________________________________________________________________________________
                                                Conv_12 
______________________________________________________________________________________________________

############ Pruning Weights of Quant_12 node ############
Quant 0 shape: (48, 24, 1, 1)
Quant 1 shape: (48, 1, 1, 1)
-------------------------------------
*** Zero IDX, channels to be pruned (1 elements):
[28]
### Non Zero IDX, channels to keep (47 elements):
[ 0  1  2  3  4  5  6  7  8  9 10 11 12 13 14 15 16 17 18 19 20 21 22 23
 24 25 26 27 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47]
-------------------------------------

------ Prune to muliple of 8: True
Mean of channels to keep, to use it as default for zero elements: [[[0.02151186]]]
Elements to keep is not multiple of 8 -> calculate new Non Zero IDX
1 zero elements must be kept
[28]
### Multiple of 8 Non Zero IDX, channels to keep:
[ 0  1  2  3  4  5  6

In [75]:
# Run QONNX shape inference on the pruned model before saving it.
model = model.transform(InferShapes())

In [76]:
# Pick the output file that matches the pruning mode used above, then save.
prune_all_convs_onnx = prune_folder + (
    "20_prune_all_convs_mul8.onnx" if set_mul4 else "30_prune_all_convs_no_mul8.onnx"
)
model.save(prune_all_convs_onnx)

In [77]:
# Serve the fully pruned model in Netron for visual inspection.
showInNetron(prune_all_convs_onnx)

Stopping http://0.0.0.0:8083
Serving './manual_pruning/pruning/sparse24_mul8/20_prune_all_convs_mul8.onnx' at http://0.0.0.0:8083


### Compare Sparsity YES MUL 4 or 8

In [78]:
sparsity_after_pruning = get_sparsity(model, layers_to_prune)

# zip() silently stops at the shorter input, so make sure both reports cover
# the same number of layers before comparing them pairwise.
assert len(sparsity_before_pruning) == len(sparsity_after_pruning), (
    f'{len(sparsity_before_pruning)} layers before pruning, '
    f'{len(sparsity_after_pruning)} layers after pruning'
)

for k1, k2 in zip(sparsity_before_pruning.keys(), sparsity_after_pruning.keys()):
    # Validate the pairing before reading and printing the values.
    assert k1 == k2, f'{k1} is not the same as {k2}'
    before = sparsity_before_pruning[k1]["sparsity"]
    after = sparsity_after_pruning[k2]["sparsity"]
    print(f'{k1}: \tbefore: {before:<4} - after: {after:<4}')

Quant_12_param0: 	before: 0.11 - after: 0.11
Quant_13_param0: 	before: 0.09 - after: 0.09
Quant_15_param0: 	before: 0.14 - after: 0.14
Quant_16_param0: 	before: 0.1  - after: 0.1 
Quant_18_param0: 	before: 0.22 - after: 0.11
Quant_19_param0: 	before: 0.18 - after: 0.07
Quant_21_param0: 	before: 0.33 - after: 0.11
Quant_22_param0: 	before: 0.31 - after: 0.07
Quant_24_param0: 	before: 0.11 - after: 0.11
Quant_25_param0: 	before: 0.1  - after: 0.1 
Quant_26_param0: 	before: 0.18 - after: 0.18
Quant_27_param0: 	before: 0.61 - after: 0.23
Quant_28_param0: 	before: 0.56 - after: 0.11
Quant_29_param0: 	before: 0.61 - after: 0.22
Quant_30_param0: 	before: 0.37 - after: 0.23


# Prune Conv 22, as it is multiple of 4 or 8

In [79]:
# model = ModelWrapper(qonnx_clean_filename)
# _, _, _, _ = test_prune_any_conv(model, "Conv_21")
# _, _, _, _ = test_prune_any_conv(model, "Conv_22")

In [80]:
# prune_only_conv_22 = prune_folder + "50_prune_only_conv_22.onnx"
# model.save(prune_only_conv_22)

In [81]:
# showInNetron(prune_only_conv_22)