In [1]:
import tensorflow.compat.v1 as tf
tf.enable_eager_execution()

from tensorflow.python.keras.applications import ResNet50
from tensorflow.python.keras.layers import Layer, Conv2D, Wrapper
from tensorflow.python.keras.models import Model, Sequential

from tensorflow_model_optimization.python.core.sparsity import keras as pruning

import numpy as np
from scipy import stats

import os
import sys
import re

In [2]:
# Loading model, collecting convolutional layers
model = ResNet50()
conv_layers = [layer for layer in model.layers if layer.name.count("conv") == 2]
test_layer = conv_layers[0]

In [11]:
class PruneWrapper(Wrapper):
    """
    Augment layer to make prunable
    
    What I need:
    layer - The layer that is going to be made pruneable
    mask - Create a mask variable for the layer that is being pruned, that will store T or F values to flag weights to be pruned
    prune_specs - A function that will be passed the model, that will determine the method in which it will be pruned
    ***
    prune specs = {
        'prune_level': "",
        'pruning_method': ""
    }
    
    
    """
    def __init__(self, layer: Layer, prune_level: str = "weights", prune_method: str = "taylor_first_order", **kwargs):
        super(PruneWrapper, self).__init__(layer, **kwargs)
        
        self._layer = layer
        self._pruning_layer = self._copy_layer(layer)
        self.rankings = []
        
        if prune_level not in ("weights", "filter"):
            raise ValueError("Incompatible prune_level:\n\n" /
                            "Supported Levels:\n" / 
                            "weights\nfilter")
            
        if prune_method not in ("taylor_first_order", "taylor_second_order", "oracle"):
            raise ValueError("Incompatible prune_method:\n\n" /
                            "Supported Methods:\n" / 
                            "taylor_first_order\taylor_second_order\noracle")
        
        self.prune_level = prune_level
        self.prune_method = prune_method
        
        
        # Kwargs
        self.scope = kwargs.get("scope", "")
        
        
        # Mask to track weight pruning
        wandb = layer.get_weights()
        weights = wandb[0]
        
        
        with tf.variable_scope(scope):
            mask = tf.Variable(initial_value=tf.ones(weights.shape, dtype=tf.float32, name=None),
                               trainable=False,
                               name="mask",
                               dtype=tf.float32,
                               aggregation=tf.VariableAggregation.MEAN,
                               shape=weights.shape)
        

    def _copy_layer(self, layer: Layer):
        config = layer.get_config()
        copy_layer = None
        
        if isinstance(layer, Conv2D):
            copy_layer = Conv2D(**config)
            
        return copy_layer
            
        
    def _get_layer(self):
        return self._layer
    
    
    def _get_pruning_layer(self):
        return self._pruning_layer
    
    
    def get_rankings(self):
        return self.rankings
    
    
    def get_wandb(self):
        return self.layer.get_weights()
    
    
    def get_weights(self):
        return self.layer.get_weights()[0]
    
    
    def get_prune_level(self):
        return self.prune_level
    
    
    def set_prune_level(self, prune_level):
        if prune_level not in ("weights", "filter", "layer"):
            raise ValueError("Incompatible prune_level:\n\n" /
                            "Supported Levels:\n" / 
                            "weights\nfilter\nlayer")            
        self.prune_level = prune_level

        
    def get_prune_method(self):
        return self.prune_method
    
    
    def set_prune_method(self, prune_method):
        if prune_method not in ("taylor_first_order", "taylor_second_order", "oracle"):
            raise ValueError("Incompatible prune_method:\n\n" /
                            "Supported Methods:\n" / 
                            "taylor_first_order\taylor_second_order\noracle")
        self.prune_method = prune_method
        
        
    def get_mask(self):
        return self.mask
    
    
    def set_mask(self, new_mask: tf.Variable):
        self.mask.assign(new_mask)
        
    
    def prune(self):
        pruning_layer = self._get_pruning_layer()
        mask = self.get_mask()
        self.pruning_layer = mask * pruning_layer
    

In [12]:
prunable_layer = PruneWrapper(test_layer)

TypeError: _variable_v1_call() got an unexpected keyword argument 'initializer'