In [5]:
# Import libraries
import tensorflow as tf
import numpy as np
from tensorflow.keras.models import Model

#### Basic implementation of magnitude-based pruning

In [20]:
def magnitude_based_pruning(model, pruning_percentage):
    """
    Prune a model in place by zeroing its smallest-magnitude kernel weights.

    A single global threshold is taken as the `pruning_percentage`-th
    percentile of the absolute values of every trainable weight whose name
    contains 'kernel'. Each kernel entry whose magnitude is not strictly
    greater than that threshold is set to zero. Trainable weights whose name
    does not contain 'kernel' are left untouched.

    Args:
        model: tf.keras model to prune (its weights are modified in place)
        pruning_percentage: percentage of weights to prune (0-100)
    Returns:
        pruned_model: the same model object that was passed in
    Raises:
        ValueError: if pruning_percentage is outside [0, 100]
    """
    # Get all weight tensors (kernels only)
    kernels = [w for w in model.trainable_weights if 'kernel' in w.name]

    if not kernels:
        print("No weights found to prune")
        return model

    # Fail with a clear message instead of an opaque np.percentile error.
    if not 0 <= pruning_percentage <= 100:
        raise ValueError(
            f"pruning_percentage must be in [0, 100], got {pruning_percentage}"
        )

    # With 0% the threshold would equal the smallest magnitude and the strict
    # comparison below would still zero that entry; pruning nothing must be a
    # genuine no-op.
    if pruning_percentage == 0:
        return model

    # Snapshot every kernel once as a NumPy array; the threshold is global, so
    # all magnitudes are pooled before the percentile is taken.
    kernel_values = [np.asarray(w.numpy()) for w in kernels]
    magnitudes = np.concatenate([np.abs(v).ravel() for v in kernel_values])

    # Compute threshold
    threshold = np.percentile(magnitudes, pruning_percentage)

    # Prune weights: keep only entries strictly above the threshold.
    for kernel, values in zip(kernels, kernel_values):
        keep_mask = (np.abs(values) > threshold).astype(values.dtype)
        # Update weights in model (shape and dtype are unchanged)
        kernel.assign(values * keep_mask)

    return model

In [23]:
def prune_with_fine_tuning(model, x_train, y_train, x_val, y_val,
                         final_sparsity=0.8, n_iterations=5, epochs_per_iter=2,
                         initial_lr=0.001):
    """
    Complete pruning pipeline with fine-tuning and learning rate reduction

    Each iteration raises the sparsity target linearly towards
    `final_sparsity`, prunes the model with `magnitude_based_pruning`, and
    fine-tunes it with a learning rate that is divided by 10 per iteration.
    After the last fine-tuning round the model is pruned once more so the
    returned weights really are at the target sparsity.

    Args:
        model: Compiled Keras model; it is modified in place and the learning
            rate of its own optimizer (`model.optimizer`) is overwritten
        x_train, y_train: Training data
        x_val, y_val: Validation data
        final_sparsity: Target sparsity (0-1)
        n_iterations: Number of pruning iterations
        epochs_per_iter: Fine-tuning epochs per iteration
        initial_lr: Initial learning rate

    Returns:
        Pruned and fine-tuned model
    """

    print("Initial evaluation:")
    model.evaluate(x_val, y_val, verbose=2)

    for i in range(n_iterations):
        # Multiply by the step fraction so the last step lands exactly on
        # final_sparsity instead of drifting past it through float round-off.
        current_target = final_sparsity * ((i + 1) / n_iterations)
        current_lr = initial_lr * (0.1 ** i)  # Reduce LR by 1/10 each iteration
        print(f"\nPruning iteration {i+1}/{n_iterations}")
        print(f"Target sparsity: {current_target:.1%}")

        # Update learning rate on the optimizer the model was compiled with;
        # fit() only ever uses model.optimizer, so a separately created
        # optimizer object would have no effect on training.
        model.optimizer.learning_rate = current_lr

        # Prune model
        model = magnitude_based_pruning(model, current_target * 100)

        # Fine-tune with reduced learning rate
        print("Fine-tuning...")
        model.fit(x_train, y_train,
                epochs=epochs_per_iter,
                validation_data=(x_val, y_val),
                verbose=1)

    if n_iterations > 0:
        # fit() applies no pruning mask, so gradient updates can move zeroed
        # weights away from zero again. Re-apply the final target so the
        # evaluated and returned model is actually sparse.
        model = magnitude_based_pruning(model, final_sparsity * 100)

    print("\nFinal evaluation:")
    model.evaluate(x_val, y_val, verbose=2)

    # Report the sparsity measured on the kernels rather than echoing the
    # requested target.
    kernels = [w.numpy() for w in model.trainable_weights if 'kernel' in w.name]
    total_params = sum(k.size for k in kernels)
    zero_params = sum(k.size - np.count_nonzero(k) for k in kernels)
    measured_sparsity = zero_params / total_params if total_params else 0.0
    print(f"Final sparsity: {measured_sparsity:.1%} (target: {final_sparsity:.1%})")

    return model

In [24]:
# Seed Python, NumPy and TensorFlow so weight initialisation and batch
# shuffling are repeatable across "Restart & Run All".
tf.keras.utils.set_random_seed(42)

# 1. Load MNIST and scale pixel values to [0, 1]
(x_train, y_train), (x_test, y_test) = tf.keras.datasets.mnist.load_data()
x_train = x_train.astype('float32') / 255.0
x_test = x_test.astype('float32') / 255.0

# 2. Build and compile model
# An explicit Input layer replaces the `input_shape=` layer argument, which
# newer Keras versions warn about for Sequential models.
model = tf.keras.Sequential([
    tf.keras.Input(shape=(28, 28)),
    tf.keras.layers.Flatten(),
    tf.keras.layers.Dense(512, activation='relu'),
    tf.keras.layers.Dense(256, activation='relu'),
    tf.keras.layers.Dense(10, activation='softmax')
])
model.compile(optimizer='adam',
            loss='sparse_categorical_crossentropy',
            metrics=['accuracy'])

# 3. Initial training
print("Training initial model...")
model.fit(x_train, y_train, epochs=2, validation_split=0.2)

# 4. Prune with fine-tuning
# NOTE: the test split doubles as the validation data during pruning here.
pruned_model = prune_with_fine_tuning(
    model,
    x_train, y_train,
    x_test, y_test,
    final_sparsity=0.8,
    n_iterations=5,
    epochs_per_iter=2,
    initial_lr=0.001
)

Training initial model...
[1m1500/1500[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m20s[0m 12ms/step - accuracy: 0.8879 - loss: 0.3598 - val_accuracy: 0.9640 - val_loss: 0.1210
Initial evaluation:
313/313 - 2s - 6ms/step - accuracy: 0.9643 - loss: 0.1117

Pruning iteration 1/5
Target sparsity: 16.0%
Fine-tuning...
Epoch 1/2
[1m1875/1875[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m21s[0m 11ms/step - accuracy: 0.9708 - loss: 0.0919 - val_accuracy: 0.9750 - val_loss: 0.0816
Epoch 2/2
[1m1875/1875[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m41s[0m 11ms/step - accuracy: 0.9811 - loss: 0.0569 - val_accuracy: 0.9742 - val_loss: 0.0780

Pruning iteration 2/5
Target sparsity: 32.0%
Fine-tuning...
Epoch 1/2
[1m1875/1875[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m20s[0m 11ms/step - accuracy: 0.9872 - loss: 0.0393 - val_accuracy: 0.9761 - val_loss: 0.0809
Epoch 2/2
[1m1875/1875[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m22s[0m 11ms/step - accuracy: 0.9896 - loss: 0.0317 -