In [1]:
import numpy as np

In [2]:
import tensorflow as tf
from tensorflow.keras.datasets import mnist
from tensorflow.keras.utils import to_categorical
from tensorflow.keras.models import Sequential
from tensorflow.keras.layers import Dense, Flatten
from tensorflow.keras.initializers import Initializer

In [3]:
# Set a pruning threshold (percentage)
pruning_threshold = 20

In [4]:
# Load the MNIST data
(x_train, y_train), (x_test, y_test) = mnist.load_data()

# Preprocess the data
x_train = x_train.astype('float32') / 255.0
x_test = x_test.astype('float32') / 255.0

# Convert labels to one-hot encoding
y_train = to_categorical(y_train, 10)
y_test = to_categorical(y_test, 10)

Downloading data from https://storage.googleapis.com/tensorflow/tf-keras-datasets/mnist.npz
[1m11490434/11490434[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 0us/step


In [5]:
# Custom initializers
class CustomKernelInitializer(Initializer):
    def __init__(self, weights):
        self.weights = weights

    def __call__(self, shape, dtype=None):
        return self.weights

class CustomBiasInitializer(Initializer):
    def __init__(self, biases):
        self.biases = biases

    def __call__(self, shape, dtype=None):
        return self.biases

In [6]:
def prune_nodes(model, pruning_threshold):

    layers_to_prune = model.layers[:-1]

    # Create a new model with pruned layers
    new_model = Sequential()
    new_model.add(Flatten(input_shape=(28, 28)))

    for i, layer in enumerate(layers_to_prune):
        if isinstance(layer, Dense):
            weights, biases = layer.get_weights()

            if i > 1:
              weights = weights[~old_nodes_to_prune, :]

            # Calculate the sum of input weights to each node
            sum_input_weights = np.sum(np.abs(weights) / np.max(np.abs(weights)), axis=0)

            next_layer_weights, _ = model.layers[i + 1].get_weights()
            sum_output_weights    = np.sum(np.abs(next_layer_weights) / np.max(np.abs(next_layer_weights)), axis=1)

            # Total sum of weights for each node
            total_weights_sum = sum_input_weights + sum_output_weights

            # Determine nodes to prune based on the total sum of weights
            threshold = np.percentile(total_weights_sum, pruning_threshold)
            nodes_to_prune = total_weights_sum  < threshold

            # Prune the weights and biases
            pruned_weights  = weights[:, ~nodes_to_prune]
            pruned_biases   = biases[~nodes_to_prune]

            # Initialize the custom initializers
            kernel_initializer = CustomKernelInitializer(pruned_weights)
            bias_initializer = CustomBiasInitializer(pruned_biases)

            new_model.add(Dense(pruned_weights.shape[1], kernel_initializer = kernel_initializer, bias_initializer = bias_initializer, activation = 'relu'))

            old_nodes_to_prune = nodes_to_prune

    # Get weights and biases of the last layer
    weights, biases = model.layers[-1].get_weights()

    # Prune the weights and biases
    weights  = weights[~nodes_to_prune, :]

    # Initialize the custom initializers
    kernel_initializer = CustomKernelInitializer(weights)
    bias_initializer = CustomBiasInitializer(biases)

    new_model.add(Dense(10, kernel_initializer = kernel_initializer, bias_initializer = bias_initializer, activation='softmax'))

    return new_model

In [7]:
# Build a simple neural network model
original_model = Sequential([
    Flatten(input_shape=(28, 28)),
    Dense(128, activation='relu'),
    Dense(64, activation='relu'),
    Dense(10, activation='softmax')
])

  super().__init__(**kwargs)


In [8]:
# Compile the model
original_model.compile(optimizer='adam', loss='categorical_crossentropy', metrics=['accuracy'])

In [9]:
# Train the model
original_model.fit(x_train, y_train, epochs=5, batch_size=32, validation_split=0.2)

print(original_model.summary())

Epoch 1/5
[1m1500/1500[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m11s[0m 6ms/step - accuracy: 0.8576 - loss: 0.4884 - val_accuracy: 0.9594 - val_loss: 0.1353
Epoch 2/5
[1m1500/1500[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m6s[0m 4ms/step - accuracy: 0.9647 - loss: 0.1177 - val_accuracy: 0.9685 - val_loss: 0.1013
Epoch 3/5
[1m1500/1500[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m16s[0m 8ms/step - accuracy: 0.9774 - loss: 0.0761 - val_accuracy: 0.9702 - val_loss: 0.0971
Epoch 4/5
[1m1500/1500[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m11s[0m 8ms/step - accuracy: 0.9822 - loss: 0.0575 - val_accuracy: 0.9688 - val_loss: 0.1033
Epoch 5/5
[1m1500/1500[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m8s[0m 5ms/step - accuracy: 0.9850 - loss: 0.0463 - val_accuracy: 0.9732 - val_loss: 0.0945


None


In [10]:
# Evaluate the origina model
loss, accuracy = original_model.evaluate(x_test, y_test)
print(f'Test accuracy before pruning: {accuracy}')

[1m313/313[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m2s[0m 4ms/step - accuracy: 0.9730 - loss: 0.0950
Test accuracy before pruning: 0.9776999950408936


In [11]:
# Prune the nodes in the model and get the new pruned model
pruned_model = prune_nodes(original_model, pruning_threshold)

print(pruned_model.summary())

0.6994689
0.7966861
(64, 10)
(51, 10)


None


In [12]:
# Compile the pruned model
pruned_model.compile(optimizer='adam', loss='categorical_crossentropy', metrics=['accuracy'])

In [13]:
# Retrain the pruned model
pruned_model.fit(x_train, y_train, epochs=5, batch_size=32, validation_split=0.2)

Epoch 1/5
[1m1500/1500[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m7s[0m 4ms/step - accuracy: 0.9864 - loss: 0.0431 - val_accuracy: 0.9746 - val_loss: 0.0942
Epoch 2/5
[1m1500/1500[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m7s[0m 4ms/step - accuracy: 0.9907 - loss: 0.0297 - val_accuracy: 0.9755 - val_loss: 0.0886
Epoch 3/5
[1m1500/1500[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m5s[0m 4ms/step - accuracy: 0.9930 - loss: 0.0217 - val_accuracy: 0.9737 - val_loss: 0.0994
Epoch 4/5
[1m1500/1500[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m12s[0m 5ms/step - accuracy: 0.9939 - loss: 0.0186 - val_accuracy: 0.9749 - val_loss: 0.1017
Epoch 5/5
[1m1500/1500[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m6s[0m 4ms/step - accuracy: 0.9941 - loss: 0.0189 - val_accuracy: 0.9755 - val_loss: 0.1012


<keras.src.callbacks.history.History at 0x7ef2bcedce80>

In [14]:
# Evaluate the pruned model
loss, accuracy = pruned_model.evaluate(x_test, y_test)
print(f'Test accuracy after pruning: {accuracy}')

[1m313/313[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m1s[0m 3ms/step - accuracy: 0.9737 - loss: 0.1138
Test accuracy after pruning: 0.9779000282287598
