
Did not support <class 'tensorflow.python.keras.layers.normalization_v2.BatchNormalization'>? #165

Closed
murdockhou opened this issue Dec 3, 2019 · 12 comments

@murdockhou

I followed the guide here and want to prune my own model, but it raises an error like this:

ValueError: Please initialize `Prune` with a supported layer. Layers should either be a `PrunableLayer` instance, or should be supported by the PruneRegistry. You passed: <class 'tensorflow.python.keras.layers.normalization_v2.BatchNormalization'>

@alanchiao

alanchiao commented Dec 5, 2019

Addressing this in PR.

Are you using TF 2.X? The library has not been fully tested with TF 2.X yet. See this issue to follow progress.
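
Until that support lands, a minimal workaround sketch is to prune only the layer types the PruneRegistry already supports and pass everything else, including BatchNormalization, through unchanged. This assumes the same sparsity import and the loaded_model / schedule values that appear later in this thread:

import tensorflow as tf
from tensorflow_model_optimization.sparsity import keras as sparsity

pruning_params = {
    'pruning_schedule': sparsity.PolynomialDecay(
        initial_sparsity=0.50, final_sparsity=0.90,
        begin_step=0, end_step=10000, frequency=1000)
}

def maybe_prune(layer):
    # Wrap only supported layer types (Conv2D also covers DepthwiseConv2D);
    # unsupported layers such as BatchNormalization are returned as-is.
    if isinstance(layer, (tf.keras.layers.Conv2D, tf.keras.layers.Dense)):
        return sparsity.prune_low_magnitude(layer, **pruning_params)
    return layer

pruned_model = tf.keras.models.clone_model(loaded_model, clone_function=maybe_prune)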

@alanchiao

alanchiao commented Dec 5, 2019

Added support in this commit. Note that it hasn't been released yet in the pip package, and you will currently need to build the package from source until the next release (0.2.0).
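
A quick way to check whether a source build actually includes the new support is sketched below; the prune_registry module path is internal and may differ between versions:

import tensorflow as tf
from tensorflow_model_optimization.python.core.sparsity.keras import prune_registry

# Expected to print True on a build that includes the commit above.
print(prune_registry.PruneRegistry.supports(tf.keras.layers.BatchNormalization()))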

@murdockhou

@alanchiao thanks, I'll try it.

@murdockhou

@alanchiao Hi, after I built the latest package from source, it raises another error:
File "/usr/local/lib/python3.6/dist-packages/tensorflow_model_optimization/python/core/sparsity/keras/pruning_wrapper.py", line 155, in __init__ 'PruneRegistry. You passed: {input}'.format(input=layer.__class__)) ValueError: Please initialize Prunewith a supported layer. Layers should either be aPrunableLayerinstance, or should be supported by the PruneRegistry. You passed: <class 'tensorflow.python.keras.engine.base_layer.TensorFlowOpLayer'>
I'm using TF 2.0.0

@alanchiao

Hi Shiwei, you'll need to share the model building code or a minimal reproducible example for me to figure out the issue. See New Issue > Bug for the kind of information that is needed.

There may be a compatibility issue with 2.X. That is something I'm actively working on, and if you share a minimal reproducible example, I may be able to ensure that it works for your case.

If you are using a custom Keras layer, you should look at these docs: https://www.tensorflow.org/model_optimization/guide/pruning/train_sparse_models#prune_a_custom_layer
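
For reference, a minimal sketch of that pattern, assuming PrunableLayer is exported from the same sparsity.keras namespace used elsewhere in this thread (the layer class below is purely illustrative):

import tensorflow as tf
from tensorflow_model_optimization.sparsity import keras as sparsity

class MyPrunableDense(tf.keras.layers.Dense, sparsity.PrunableLayer):
    def get_prunable_weights(self):
        # Prune only the kernel; the bias is typically left dense.
        return [self.kernel]

pruning_params = {
    'pruning_schedule': sparsity.PolynomialDecay(
        initial_sparsity=0.50, final_sparsity=0.90,
        begin_step=0, end_step=10000, frequency=1000)
}

# The custom layer now satisfies the PrunableLayer check in the error message above.
pruned_dense = sparsity.prune_low_magnitude(MyPrunableDense(10), **pruning_params)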

@alanchiao

Let me know if it ends up being a custom Keras layer issue. In that case, we may be able to improve the error message to make it easier for users like yourself in the future.

@alanchiao

@Xhark: FYI Jaehong, this is a potential 2.X compatibility issue we should try to resolve.

@murdockhou

Hi, @alanchiao , here is my code:

import tensorflow as tf
from tensorflow_model_optimization.sparsity import keras as sparsity
import math

NUM_CLASSES = 10

kernel_initializer = tf.keras.initializers.VarianceScaling

def swish(x):
    return x * tf.keras.activations.sigmoid(x)

def round_filters(filters, multiplier):
    depth_divisor = 8
    min_depth = None
    min_depth = min_depth or depth_divisor
    filters = filters * multiplier
    new_filters = max(min_depth, int(filters + depth_divisor / 2) // depth_divisor * depth_divisor)
    if new_filters < 0.9 * filters:
        new_filters += depth_divisor
    return int(new_filters)

def round_repeats(repeats, multiplier):
    if not multiplier:
        return repeats
    return int(math.ceil(multiplier * repeats))

def SEBlock(inputs, input_channels, ratio=0.25):

    num_reduced_filters = max(1, int(input_channels * ratio))
    branch = tf.keras.layers.GlobalAveragePooling2D()(inputs)
    # branch = tf.keras.layers.Lambda(lambda branch: tf.expand_dims(input=branch, axis=1))(branch)
    branch = tf.keras.backend.expand_dims(branch, 1)
    branch = tf.keras.backend.expand_dims(branch, 1)
    # branch = tf.keras.layers.Lambda(lambda branch: tf.expand_dims(input=branch, axis=1))(branch)
    branch = tf.keras.layers.Conv2D(filters=num_reduced_filters, kernel_size=(1, 1), strides=1, padding="same", kernel_initializer=kernel_initializer)(branch)
    branch = swish(branch)
    branch = tf.keras.layers.Conv2D(filters=input_channels, kernel_size=(1, 1), strides=1, padding='same', kernel_initializer=kernel_initializer)(branch)
    branch = tf.keras.activations.sigmoid(branch)
    output = inputs * branch

    return output

def MBConv(in_channels, out_channels, expansion_factor, stride, k, drop_connect_rate, inputs, training=False):
    x = tf.keras.layers.Conv2D(filters=in_channels * expansion_factor,kernel_size=(1, 1),strides=1,padding="same", kernel_initializer=kernel_initializer)(inputs)
    x = tf.keras.layers.BatchNormalization()(x, training=training)
    x = swish(x)
    x = tf.keras.layers.DepthwiseConv2D(kernel_size=(k, k), strides=stride, padding="same", kernel_initializer=kernel_initializer)(x)
    x = tf.keras.layers.BatchNormalization()(x, training=training)
    x = SEBlock(x, in_channels*expansion_factor)
    x = swish(x)
    x = tf.keras.layers.Conv2D(filters=out_channels,kernel_size=(1, 1),strides=1,padding="same", kernel_initializer=kernel_initializer)(x)
    x = tf.keras.layers.BatchNormalization()(x, training=training)
    if stride == 1 and in_channels == out_channels:
        if drop_connect_rate:
            x = tf.keras.layers.Dropout(rate=drop_connect_rate)(x, training=training)
        x = tf.keras.layers.Add()([x, inputs])

    return x

def build_mbconv_block(inputs, in_channels, out_channels, layers, stride, expansion_factor, k, drop_connect_rate, training):

    x = inputs
    for i in range(layers):
        if i == 0:
            x = MBConv(in_channels=in_channels, out_channels=out_channels, expansion_factor=expansion_factor,
                       stride=stride, k=k, drop_connect_rate=drop_connect_rate, inputs=x, training=training)
        else:
            x = MBConv(in_channels=out_channels, out_channels=out_channels, expansion_factor=expansion_factor,
                       stride=1, k=k, drop_connect_rate=drop_connect_rate, inputs=x, training=training)

    return x


def EfficientNet(inputs, width_coefficient, depth_coefficient, dropout_rate, drop_connect_rate=0.2, training=False):

    features = []

    x = tf.keras.layers.Conv2D(filters=round_filters(32, width_coefficient),kernel_size=(3, 3),strides=2, padding="same", kernel_initializer=kernel_initializer) (inputs)
    x = tf.keras.layers.BatchNormalization()(x, training=training)
    x = swish(x)

    x = build_mbconv_block(x, in_channels=round_filters(32, width_coefficient),
                           out_channels=round_filters(16, width_coefficient),
                           layers=round_repeats(1, depth_coefficient),
                           stride=1,
                           expansion_factor=1, k=3,
                           drop_connect_rate=drop_connect_rate,
                           training=training)
    features.append(x)

    x = build_mbconv_block(x, in_channels=round_filters(16, width_coefficient),
                           out_channels=round_filters(24, width_coefficient),
                           layers=round_repeats(2, depth_coefficient),
                           stride=1,
                           expansion_factor=6, k=3,
                           drop_connect_rate=drop_connect_rate,
                           training=training)
    features.append(x)

    x = build_mbconv_block(x, in_channels=round_filters(24, width_coefficient),
                           out_channels=round_filters(40, width_coefficient),
                           layers=round_repeats(2, depth_coefficient),
                           stride=2,
                           expansion_factor=6, k=5,
                           drop_connect_rate=drop_connect_rate,
                           training=training)
    features.append(x)

    x = build_mbconv_block(x, in_channels=round_filters(40, width_coefficient),
                           out_channels=round_filters(80, width_coefficient),
                           layers=round_repeats(3, depth_coefficient),
                           stride=2,
                           expansion_factor=6, k=3,
                           drop_connect_rate=drop_connect_rate,
                           training=training)
    features.append(x)

    x = build_mbconv_block(x, in_channels=round_filters(80, width_coefficient),
                           out_channels=round_filters(112, width_coefficient),
                           layers=round_repeats(3, depth_coefficient),
                           stride=1,
                           expansion_factor=6, k=5,
                           drop_connect_rate=drop_connect_rate,
                           training=training)
    features.append(x)

    x = build_mbconv_block(x, in_channels=round_filters(112, width_coefficient),
                           out_channels=round_filters(192, width_coefficient),
                           layers=round_repeats(4, depth_coefficient),
                           stride=2,
                           expansion_factor=6, k=5,
                           drop_connect_rate=drop_connect_rate,
                           training=training)
    features.append(x)

    x = build_mbconv_block(x, in_channels=round_filters(192, width_coefficient),
                           out_channels=round_filters(320, width_coefficient),
                           layers=round_repeats(1, depth_coefficient),
                           stride=1,
                           expansion_factor=6, k=3,
                           drop_connect_rate=drop_connect_rate,
                           training=training)
    features.append(x)

    x = tf.keras.layers.Conv2D(filters=round_filters(1280, width_coefficient), kernel_size=(1, 1), strides=1, padding='same')(x)
    x = tf.keras.layers.BatchNormalization()(x, training=training)
    x = swish(x)
    x = tf.keras.layers.GlobalAveragePooling2D()(x)
    x = tf.keras.layers.Dropout(rate=dropout_rate)(x, training=training)
    x = tf.keras.layers.Dense(units=1, activation=tf.keras.activations.softmax)(x)

    return x, features


def efficient_net_b0(inputs, training):
    return EfficientNet(inputs,
                        width_coefficient=1.0,
                        depth_coefficient=1.0,
                        dropout_rate=0.2,
                        drop_connect_rate=0.2,
                        training=training)

def up_sample(inputs, training=True):
    x = tf.keras.layers.UpSampling2D()(inputs)
    x = tf.keras.layers.BatchNormalization()(x, training=training)
    x = tf.keras.layers.ReLU()(x)
    return x

def singlePosenet(inputs, outc, training=True):

    _, features =  efficient_net_b0(inputs=inputs, training=training)

    # [ 1/2, 1/4, 1/8, 1/8, 1/16]
    outputs = []
    for i, name in enumerate(features):
        x = features[i]
        if x.shape[1] > inputs.shape[1] // 4:
            continue
        while x.shape[1] < (inputs.shape[1]//4):
            x = up_sample(x, training)
        outputs.append(x)

    quater_res = tf.keras.layers.Concatenate()(outputs)
    quater_res = tf.keras.layers.Conv2D(512, 3, 1, 'same', activation=tf.nn.relu)(quater_res)
    quater_res = tf.keras.layers.Conv2D(512, 3, 1, 'same', activation=tf.nn.relu)(quater_res)
    quater_res_out = tf.keras.layers.Conv2D(outc, 1, 1, 'same', name='quater', activation=None)(quater_res)

    half_res = up_sample(quater_res, training)
    half_res = tf.keras.layers.Conv2D(512, 3, 1, 'same', activation=tf.nn.relu)(half_res)
    half_res = tf.keras.layers.Conv2D(512, 3, 1, 'same', activation=tf.nn.relu)(half_res)
    half_res_out = tf.keras.layers.Conv2D(outc, 1, 1, 'same', name='half', activation=None)(half_res)

    return quater_res_out, half_res_out



inputs = tf.keras.Input(shape=(224,224,3), name='modelInput')
loaded_model = tf.keras.Model(inputs, singlePosenet(inputs, 13, True))
# loaded_model.load_weights('models/first_version/45-0.0073517.hdf5')

end_step = 10000
pruning_params = {
    'pruning_schedule': sparsity.PolynomialDecay(
        initial_sparsity=0.50,
        final_sparsity=0.90,
        begin_step=0,
        end_step=end_step,
        frequency=1000)
}

pruning_model = sparsity.prune_low_magnitude(loaded_model, **pruning_params)
pruning_model.summary()

and when I run this, it raises an error:

Traceback (most recent call last):
  File "tools/model_optim.py", line 211, in <module>
    pruning_model = sparsity.prune_low_magnitude(loaded_model, **pruning_params)
  File "/usr/local/lib/python3.6/dist-packages/tensorflow_model_optimization/python/core/sparsity/keras/prune.py", line 159, in prune_low_magnitude
    to_prune, input_tensors=None, clone_function=_add_pruning_wrapper)
  File "/usr/local/lib/python3.6/dist-packages/tensorflow_core/python/keras/models.py", line 422, in clone_model
    model, input_tensors=input_tensors, layer_fn=clone_function)
  File "/usr/local/lib/python3.6/dist-packages/tensorflow_core/python/keras/models.py", line 191, in _clone_functional_model
    model, new_input_layers, layer_fn)
  File "/usr/local/lib/python3.6/dist-packages/tensorflow_core/python/keras/models.py", line 241, in _clone_layers_and_model_config
    config = network.get_network_config(model, serialize_layer_fn=_copy_layer)
  File "/usr/local/lib/python3.6/dist-packages/tensorflow_core/python/keras/engine/network.py", line 1942, in get_network_config
    layer_config = serialize_layer_fn(layer)
  File "/usr/local/lib/python3.6/dist-packages/tensorflow_core/python/keras/models.py", line 238, in _copy_layer
    created_layers[layer.name] = layer_fn(layer)
  File "/usr/local/lib/python3.6/dist-packages/tensorflow_model_optimization/python/core/sparsity/keras/prune.py", line 140, in _add_pruning_wrapper
    return pruning_wrapper.PruneLowMagnitude(layer, **params)
  File "/usr/local/lib/python3.6/dist-packages/tensorflow_model_optimization/python/core/sparsity/keras/pruning_wrapper.py", line 155, in __init__
    'PruneRegistry. You passed: {input}'.format(input=layer.__class__))
ValueError: Please initialize `Prune` with a supported layer. Layers should either be a `PrunableLayer` instance, or should be supported by the PruneRegistry. You passed: <class 'tensorflow.python.keras.engine.base_layer.TensorFlowOpLayer'>


My tensorflow version: 2.0.0
tensorflow_model_optimization version: tf-model-optimization-nightly-0.1.3.dev20191209

@Xhark

Xhark commented Dec 9, 2019

If you use the Keras layer versions of multiply, sigmoid, and expand_dims (reshape), then it will work.
Here is some modified code; please verify that it works as you expected.

import tensorflow as tf
from tensorflow_model_optimization.sparsity import keras as sparsity
import math
from tensorflow.keras.backend import int_shape

NUM_CLASSES = 10

kernel_initializer = tf.keras.initializers.VarianceScaling

def swish(x):
    return tf.keras.layers.Multiply()([x, tf.keras.layers.Activation('sigmoid')(x)])

def round_filters(filters, multiplier):
    depth_divisor = 8
    min_depth = None
    min_depth = min_depth or depth_divisor
    filters = filters * multiplier
    new_filters = max(min_depth, int(filters + depth_divisor / 2) // depth_divisor * depth_divisor)
    if new_filters < 0.9 * filters:
        new_filters += depth_divisor
    return int(new_filters)

def round_repeats(repeats, multiplier):
    if not multiplier:
        return repeats
    return int(math.ceil(multiplier * repeats))

def SEBlock(inputs, input_channels, ratio=0.25):

    num_reduced_filters = max(1, int(input_channels * ratio))
    branch = tf.keras.layers.GlobalAveragePooling2D()(inputs)
    # branch = tf.keras.layers.Lambda(lambda branch: tf.expand_dims(input=branch, axis=1))(branch)
    branch = tf.keras.layers.Reshape((1, 1, int_shape(branch)[1]))(branch)
    # branch = tf.keras.layers.Lambda(lambda branch: tf.expand_dims(input=branch, axis=1))(branch)
    branch = tf.keras.layers.Conv2D(filters=num_reduced_filters, kernel_size=(1, 1), strides=1, padding="same", kernel_initializer=kernel_initializer)(branch)
    branch = swish(branch)
    branch = tf.keras.layers.Conv2D(filters=input_channels, kernel_size=(1, 1), strides=1, padding='same', kernel_initializer=kernel_initializer)(branch)
    branch = tf.keras.layers.Activation('sigmoid')(branch)
    output = tf.keras.layers.Multiply()([inputs, branch])

    return output

def MBConv(in_channels, out_channels, expansion_factor, stride, k, drop_connect_rate, inputs, training=False):
    x = tf.keras.layers.Conv2D(filters=in_channels * expansion_factor,kernel_size=(1, 1),strides=1,padding="same", kernel_initializer=kernel_initializer)(inputs)
    x = tf.keras.layers.BatchNormalization()(x, training=training)
    x = swish(x)
    x = tf.keras.layers.DepthwiseConv2D(kernel_size=(k, k), strides=stride, padding="same", kernel_initializer=kernel_initializer)(x)
    x = tf.keras.layers.BatchNormalization()(x, training=training)
    x = SEBlock(x, in_channels*expansion_factor)
    x = swish(x)
    x = tf.keras.layers.Conv2D(filters=out_channels,kernel_size=(1, 1),strides=1,padding="same", kernel_initializer=kernel_initializer)(x)
    x = tf.keras.layers.BatchNormalization()(x, training=training)
    if stride == 1 and in_channels == out_channels:
        if drop_connect_rate:
            x = tf.keras.layers.Dropout(rate=drop_connect_rate)(x, training=training)
        x = tf.keras.layers.Add()([x, inputs])

    return x

def build_mbconv_block(inputs, in_channels, out_channels, layers, stride, expansion_factor, k, drop_connect_rate, training):

    x = inputs
    for i in range(layers):
        if i == 0:
            x = MBConv(in_channels=in_channels, out_channels=out_channels, expansion_factor=expansion_factor,
                       stride=stride, k=k, drop_connect_rate=drop_connect_rate, inputs=x, training=training)
        else:
            x = MBConv(in_channels=out_channels, out_channels=out_channels, expansion_factor=expansion_factor,
                       stride=1, k=k, drop_connect_rate=drop_connect_rate, inputs=x, training=training)

    return x


def EfficientNet(inputs, width_coefficient, depth_coefficient, dropout_rate, drop_connect_rate=0.2, training=False):

    features = []

    x = tf.keras.layers.Conv2D(filters=round_filters(32, width_coefficient),kernel_size=(3, 3),strides=2, padding="same", kernel_initializer=kernel_initializer) (inputs)
    x = tf.keras.layers.BatchNormalization()(x, training=training)
    x = swish(x)

    x = build_mbconv_block(x, in_channels=round_filters(32, width_coefficient),
                           out_channels=round_filters(16, width_coefficient),
                           layers=round_repeats(1, depth_coefficient),
                           stride=1,
                           expansion_factor=1, k=3,
                           drop_connect_rate=drop_connect_rate,
                           training=training)
    features.append(x)

    x = build_mbconv_block(x, in_channels=round_filters(16, width_coefficient),
                           out_channels=round_filters(24, width_coefficient),
                           layers=round_repeats(2, depth_coefficient),
                           stride=1,
                           expansion_factor=6, k=3,
                           drop_connect_rate=drop_connect_rate,
                           training=training)
    features.append(x)

    x = build_mbconv_block(x, in_channels=round_filters(24, width_coefficient),
                           out_channels=round_filters(40, width_coefficient),
                           layers=round_repeats(2, depth_coefficient),
                           stride=2,
                           expansion_factor=6, k=5,
                           drop_connect_rate=drop_connect_rate,
                           training=training)
    features.append(x)

    x = build_mbconv_block(x, in_channels=round_filters(40, width_coefficient),
                           out_channels=round_filters(80, width_coefficient),
                           layers=round_repeats(3, depth_coefficient),
                           stride=2,
                           expansion_factor=6, k=3,
                           drop_connect_rate=drop_connect_rate,
                           training=training)
    features.append(x)

    x = build_mbconv_block(x, in_channels=round_filters(80, width_coefficient),
                           out_channels=round_filters(112, width_coefficient),
                           layers=round_repeats(3, depth_coefficient),
                           stride=1,
                           expansion_factor=6, k=5,
                           drop_connect_rate=drop_connect_rate,
                           training=training)
    features.append(x)

    x = build_mbconv_block(x, in_channels=round_filters(112, width_coefficient),
                           out_channels=round_filters(192, width_coefficient),
                           layers=round_repeats(4, depth_coefficient),
                           stride=2,
                           expansion_factor=6, k=5,
                           drop_connect_rate=drop_connect_rate,
                           training=training)
    features.append(x)

    x = build_mbconv_block(x, in_channels=round_filters(192, width_coefficient),
                           out_channels=round_filters(320, width_coefficient),
                           layers=round_repeats(1, depth_coefficient),
                           stride=1,
                           expansion_factor=6, k=3,
                           drop_connect_rate=drop_connect_rate,
                           training=training)
    features.append(x)

    x = tf.keras.layers.Conv2D(filters=round_filters(1280, width_coefficient), kernel_size=(1, 1), strides=1, padding='same')(x)
    x = tf.keras.layers.BatchNormalization()(x, training=training)
    x = swish(x)
    x = tf.keras.layers.GlobalAveragePooling2D()(x)
    x = tf.keras.layers.Dropout(rate=dropout_rate)(x, training=training)
    x = tf.keras.layers.Dense(units=1, activation=tf.keras.activations.softmax)(x)

    return x, features


def efficient_net_b0(inputs, training):
    return EfficientNet(inputs,
                        width_coefficient=1.0,
                        depth_coefficient=1.0,
                        dropout_rate=0.2,
                        drop_connect_rate=0.2,
                        training=training)

def up_sample(inputs, training=True):
    x = tf.keras.layers.UpSampling2D()(inputs)
    x = tf.keras.layers.BatchNormalization()(x, training=training)
    x = tf.keras.layers.ReLU()(x)
    return x

def singlePosenet(inputs, outc, training=True):

    _, features =  efficient_net_b0(inputs=inputs, training=training)

    # [ 1/2, 1/4, 1/8, 1/8, 1/16]
    outputs = []
    for i, name in enumerate(features):
        x = features[i]
        if x.shape[1] > inputs.shape[1] // 4:
            continue
        while x.shape[1] < (inputs.shape[1]//4):
            x = up_sample(x, training)
        outputs.append(x)

    quater_res = tf.keras.layers.Concatenate()(outputs)
    quater_res = tf.keras.layers.Conv2D(512, 3, 1, 'same', activation=tf.nn.relu)(quater_res)
    quater_res = tf.keras.layers.Conv2D(512, 3, 1, 'same', activation=tf.nn.relu)(quater_res)
    quater_res_out = tf.keras.layers.Conv2D(outc, 1, 1, 'same', name='quater', activation=None)(quater_res)

    half_res = up_sample(quater_res, training)
    half_res = tf.keras.layers.Conv2D(512, 3, 1, 'same', activation=tf.nn.relu)(half_res)
    half_res = tf.keras.layers.Conv2D(512, 3, 1, 'same', activation=tf.nn.relu)(half_res)
    half_res_out = tf.keras.layers.Conv2D(outc, 1, 1, 'same', name='half', activation=None)(half_res)

    return quater_res_out, half_res_out

I'm not sure whether this model also works well with your saved model.

@murdockhou

@alanchiao it works, and it loads my saved model successfully too. But the results of this model are a little bit different compared with the original model. Anyway, thanks for your help again.

@murdockhou

@alanchiao I have optimized the model for 5 epochs, and these are the params I used:

pruning_params = {
    'pruning_schedule': sparsity.PolynomialDecay(
        initial_sparsity=0.50,
        final_sparsity=0.90,
        begin_step=begin_step,
        end_step=end_step,
        frequency=1000)
}
pruning_model = sparsity.prune_low_magnitude(loaded_model, **pruning_params)

After I finished this, I saved the pruned model and converted it to a .pb file to test inference time, but the time cost is the same as the non-pruned original model. So, what is happening in the pruning func? Thanks~

murdockhou reopened this Dec 12, 2019
@alanchiao

The latency should be equivalent, as described here https://www.tensorflow.org/model_optimization/guide/pruning#overview.

If you follow #173, you can see when framework support will be released, resulting in latency improvements.
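
In the meantime the benefit shows up as model size after compression rather than latency. A minimal sketch, assuming pruning_model is the trained pruned model from the earlier comment:

import gzip
import tempfile
import tensorflow as tf
from tensorflow_model_optimization.sparsity import keras as sparsity

final_model = sparsity.strip_pruning(pruning_model)  # remove the pruning wrappers

_, keras_file = tempfile.mkstemp('.h5')
tf.keras.models.save_model(final_model, keras_file, include_optimizer=False)

with open(keras_file, 'rb') as f:
    print('gzipped size (bytes):', len(gzip.compress(f.read())))
# Compare against the gzipped size of the unpruned model; latency itself stays
# the same until the framework support tracked in #173 lands.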

alanchiao added the technique:pruning (Regarding tfmot.sparsity.keras APIs and docs) label on Feb 6, 2020