Cannot save pruned model with MultiHeadAttention Layer #1077

Open
christian-steinmeyer opened this issue Jun 27, 2023 · 2 comments

christian-steinmeyer commented Jun 27, 2023

Describe the bug
Saving a model in which a MultiHeadAttention layer is wrapped in PruneLowMagnitude fails with a duplicate dataset name error.

System information

TensorFlow version (installed from source or binary): 2.13.0rc1

TensorFlow Model Optimization version (installed from source or binary): 0.7.5

Python version: 3.10

Describe the expected behavior
Successful model save.

Describe the current behavior
When saving the pruned model to an .h5 file, I get a ValueError: Unable to create dataset (name already exists). The name the error reports is "mask:0".

Code to reproduce the issue

import tensorflow as tf
import tensorflow_model_optimization as tfmot
import tempfile

if __name__ == '__main__':
    # model
    inputs = tf.keras.layers.Input(shape=(28, 28, 3))
    x = tf.keras.layers.Conv2D(filters=128, kernel_size=3, activation='relu')(inputs)
    x = tf.keras.layers.MultiHeadAttention(num_heads=4, key_dim=128)(query=x, value=x, key=x)
    outputs = tf.keras.layers.Flatten()(x)
    model = tf.keras.Model(inputs=inputs, outputs=outputs)

    model.compile(optimizer='adam', loss='mse')

    # call model to initialize weights
    model(tf.ones((1, 28, 28, 3)))

    # prune model
    pruning_params = {
        'pruning_schedule': tfmot.sparsity.keras.PolynomialDecay(
            initial_sparsity=0.5,
            final_sparsity=0.9,
            begin_step=0,
            end_step=1,
            frequency=1,
        ),
    }
    model_for_pruning = tfmot.sparsity.keras.prune_low_magnitude(model, **pruning_params)

    with tempfile.TemporaryDirectory() as temp_dir:
        model_for_pruning.save(temp_dir + '/model.h5')  # <-- fails

Potentially related to #661 and #944.
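
To help narrow this down, here is a minimal sketch of a check for name collisions. It assumes (not confirmed in this report) that the .h5 writer stores each weight of a layer as a dataset named after that weight, so two weights called "mask:0" in the same wrapped layer would collide. The helper name duplicate_weight_names is hypothetical; it only relies on the standard model.layers, layer.weights and weight.name attributes.

```python
from collections import Counter


def duplicate_weight_names(model):
    """Map each layer name to the weight names it holds more than once."""
    duplicates = {}
    for layer in model.layers:
        # Assumption: the .h5 writer creates one dataset per weight, named
        # after the weight, so a repeated name within one layer would clash.
        counts = Counter(weight.name for weight in layer.weights)
        repeated = sorted(name for name, n in counts.items() if n > 1)
        if repeated:
            duplicates[layer.name] = repeated
    return duplicates


# Usage with the reproduction script above:
#   print(duplicate_weight_names(model_for_pruning))
```

If the wrapped MultiHeadAttention layer shows up in the result with "mask:0" listed, that would be consistent with the error above; an empty result would point to a different cause.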

christian-steinmeyer added the bug (Something isn't working) label on Jun 27, 2023
christian-steinmeyer (Author) commented

Tagging @Xhark, as you have worked on similar issues in the past.

dansuh17 (Member) commented

@Xhark Could you follow up on this one?
