In [None]:
import tensorflow as tf
from tensorflow import keras

import numpy as np
import tempfile
import zipfile
import os

import tensorflow_model_optimization as tfmot
import matplotlib.pyplot as plt

cluster_weights = tfmot.clustering.keras.cluster_weights
CentroidInitialization = tfmot.clustering.keras.CentroidInitialization
from ax.utils.notebook.plotting import render, init_notebook_plotting
from ax.plot.contour import plot_contour

import pickle
import tikzplotlib

**Note:** I have used TF 2.4.0 for the experiments. 

In [None]:
# Show the installed TensorFlow version so it can be compared with the one noted above (2.4.0).
tf.__version__

# 1. Define Baseline Model 

## 1.1. Load MNIST

In [None]:
# Load the MNIST train/test split through Keras. The four arrays bound here are
# module-level names that the training and evaluation helpers below read directly.
mnist = keras.datasets.mnist
(train_images, train_labels), (test_images, test_labels) = mnist.load_data()

# Normalize the input image so that each pixel value is between 0 to 1.
train_images = train_images / 255.0
test_images  = test_images / 255.0

## 1.2. Get Baseline Model

In [None]:
def get_baseline_model():
    """Build and train the baseline MNIST digit classifier.

    The network is one 3x3 convolution (12 filters, ReLU), a 2x2 max-pool and
    a dense layer that outputs 10 logits. It is trained for 5 epochs on the
    module-level ``train_images``/``train_labels`` with 10% of the samples
    held out for validation.

    Returns:
        The compiled and trained ``keras.Sequential`` model.
    """
    layers = [
        keras.layers.InputLayer(input_shape=(28, 28)),
        keras.layers.Reshape(target_shape=(28, 28, 1)),
        keras.layers.Conv2D(filters=12, kernel_size=(3, 3), activation='relu'),
        keras.layers.MaxPooling2D(pool_size=(2, 2)),
        keras.layers.Flatten(),
        keras.layers.Dense(10),
    ]
    classifier = keras.Sequential(layers)

    # The final Dense layer has no activation, so the loss works on raw logits.
    logits_loss = tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True)
    classifier.compile(optimizer='adam', loss=logits_loss, metrics=['accuracy'])

    classifier.fit(train_images, train_labels, epochs=5, validation_split=0.1)
    return classifier

# 2. Compression 

## 2.1. Auxiliary Functions 

### 2.1.1. Get size of zipped model 

In [None]:
def get_zipped_model_size(model):
    """Return the size in bytes of `model` after zip (DEFLATE) compression.

    Args:
        model: Either a serialized TFLite model (`bytes`), which is written
            verbatim to a temporary ``.tflite`` file, or a Keras model, which
            is saved without optimizer state to a temporary ``.h5`` file.

    Returns:
        Size in bytes of a zip archive that contains the serialized model.

    Both temporary files are removed and every descriptor handed out by
    ``tempfile.mkstemp`` is closed before returning, so calling this function
    many times in a row neither exhausts file descriptors nor fills the
    temporary directory.
    """
    model_path = None
    zipped_path = None
    try:
        if isinstance(model, bytes):
            fd, model_path = tempfile.mkstemp('.tflite')
            # os.fdopen takes ownership of the descriptor, so leaving the
            # `with` block closes it.
            with os.fdopen(fd, 'wb') as f:
                f.write(model)
        else:
            fd, model_path = tempfile.mkstemp('.h5')
            # Keras re-opens the file by path; release our own descriptor first.
            os.close(fd)
            tf.keras.models.save_model(model, model_path, include_optimizer=False)
        # The path printed here is temporary and is deleted before returning.
        print('Saved baseline model to:', model_path)

        fd, zipped_path = tempfile.mkstemp('.zip')
        os.close(fd)
        with zipfile.ZipFile(zipped_path, 'w', compression=zipfile.ZIP_DEFLATED) as archive:
            # The archive member keeps the temp-file path as its name, exactly
            # as before, so reported sizes stay comparable with earlier runs.
            archive.write(model_path)
        return os.path.getsize(zipped_path)
    finally:
        for path in (model_path, zipped_path):
            if path is not None and os.path.exists(path):
                os.remove(path)

### 2.1.2. Evaluate the accuracy of the model (for binaries) 

In [None]:
def evaluate_model(interpreter):
    """Compute the accuracy of a TFLite interpreter on the MNIST test set.

    Each image of the module-level ``test_images`` is run through
    `interpreter` one at a time; the arg-max of the output is compared with
    the matching entry of the module-level ``test_labels``.

    Args:
        interpreter: A TFLite interpreter. The only caller in this notebook
            calls ``allocate_tensors()`` on it before passing it in.

    Returns:
        Fraction of test images whose predicted digit equals the label.
    """
    in_idx = interpreter.get_input_details()[0]["index"]
    out_idx = interpreter.get_output_details()[0]["index"]

    def classify(image):
        # Add a leading batch axis and cast to float32 before feeding the model.
        batch = np.expand_dims(image, axis=0).astype(np.float32)
        interpreter.set_tensor(in_idx, batch)
        interpreter.invoke()
        # `tensor()` hands back a callable that yields the output array;
        # row 0 belongs to the single image in the batch.
        read_output = interpreter.tensor(out_idx)
        return np.argmax(read_output()[0])

    # Run predictions on every image in the "test" dataset.
    predicted = np.array([classify(image) for image in test_images])

    print('\n')
    # Element-wise comparison with the ground truth; the mean is the accuracy.
    return (predicted == test_labels).mean()

### 2.1.3. Get accuracy of the model (both for Keras models and for binaries)

In [None]:
def get_model_accuracy(model):
    """Return test-set accuracy for a Keras model or a serialized TFLite model.

    Args:
        model: A compiled Keras model, or the `bytes` of a TFLite model.

    Returns:
        Accuracy measured on the module-level ``test_images``/``test_labels``.
    """
    if not isinstance(model, bytes):
        # evaluate() returns [loss, metric...]; the models in this notebook
        # are compiled with metrics=['accuracy'], so index 1 is the accuracy.
        scores = model.evaluate(test_images, test_labels, verbose=1)
        return scores[1]

    # Serialized TFLite model: run it through an interpreter instead.
    tflite_interpreter = tf.lite.Interpreter(model_content=model)
    tflite_interpreter.allocate_tensors()
    return evaluate_model(tflite_interpreter)

## 2.2. Compression Function

In [None]:
def apply_compression(baseline_model, initial_sparsity=0.5, final_sparsity=0.8, post_train_quant=False, qaware=False, clusters=0):
    """Prune, optionally cluster and quantize a model, then convert it to TFLite.

    The following stages are applied to a fresh copy of `baseline_model`:

    1. Magnitude pruning with a polynomial sparsity schedule, fine-tuned for
       2 epochs on the module-level ``train_images``/``train_labels``.
    2. If ``clusters > 1``: weight clustering with that many centroids,
       fine-tuned for 1 epoch with a learning rate of 1e-5.
    3. If ``qaware``: quantization-aware fine-tuning for 1 epoch on the first
       1000 training images.
    4. Conversion to TFLite; the default converter optimizations are enabled
       when ``post_train_quant`` or ``qaware`` is set.

    Args:
        baseline_model: Trained Keras model. It is cloned, so the original
            object is not trained further.
        initial_sparsity: Sparsity at the first pruning step. Values above
            `final_sparsity` are clamped to `final_sparsity` so that the
            schedule never starts by pruning more than the requested target.
        final_sparsity: Sparsity reached at the end of the pruning schedule.
        post_train_quant: Enable post-training quantization in the converter.
        qaware: Run quantization-aware training before conversion.
        clusters: Number of weight clusters; values <= 1 skip clustering.

    Returns:
        The converted TFLite model as `bytes`.
    """
    def compile_classifier(net, optimizer='adam'):
        # All stages train the same classifier on raw logits.
        net.compile(optimizer=optimizer,
                    loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),
                    metrics=['accuracy'])

    prune_low_magnitude = tfmot.sparsity.keras.prune_low_magnitude

    # The searched range of final_sparsity starts at 0.0 while the default
    # initial_sparsity is 0.5. Without this clamp the schedule would decrease,
    # i.e. prune half of the weights at step 0 and then relax the mask again.
    initial_sparsity = min(initial_sparsity, final_sparsity)

    # Create a clone of the baseline model (to avoid various calls affecting each other).
    model = keras.models.clone_model(baseline_model)
    # Build with the real input shape of the baseline instead of a placeholder.
    model.build(baseline_model.input_shape)
    compile_classifier(model)
    model.set_weights(baseline_model.get_weights())

    # Compute the end step so that pruning finishes after 2 epochs.
    batch_size = 128
    epochs = 2
    validation_split = 0.1  # 10% of training set will be used for validation set.

    num_images = train_images.shape[0] * (1 - validation_split)
    end_step = np.ceil(num_images / batch_size).astype(np.int32) * epochs

    # Define model for pruning.
    pruning_params = {
        'pruning_schedule': tfmot.sparsity.keras.PolynomialDecay(initial_sparsity=initial_sparsity,
                                                                 final_sparsity=final_sparsity,
                                                                 begin_step=0,
                                                                 end_step=end_step)
    }
    model_for_pruning = prune_low_magnitude(model, **pruning_params)

    # `prune_low_magnitude` requires a recompile.
    compile_classifier(model_for_pruning)

    # UpdatePruningStep advances the pruning schedule during fit().
    callbacks = [
        tfmot.sparsity.keras.UpdatePruningStep()
    ]
    model_for_pruning.fit(train_images, train_labels,
                          batch_size=batch_size, epochs=epochs, validation_split=validation_split,
                          callbacks=callbacks)

    # Drop the pruning wrappers so that only the plain layers are exported.
    model_for_export = tfmot.sparsity.keras.strip_pruning(model_for_pruning)

    if clusters > 1:
        clustering_params = {
            'number_of_clusters': clusters,
            'cluster_centroids_init': CentroidInitialization.LINEAR
        }

        # Cluster a whole model.
        model_for_export = cluster_weights(model_for_export, **clustering_params)

        # Use smaller learning rate for fine-tuning clustered model.
        # TODO: Is the learning rate also a hyperparameter?
        compile_classifier(model_for_export,
                           optimizer=tf.keras.optimizers.Adam(learning_rate=1e-5))

        model_for_export.fit(
            train_images,
            train_labels,
            batch_size=500,
            epochs=1,
            validation_split=0.1)

        model_for_export = tfmot.clustering.keras.strip_clustering(model_for_export)

    if qaware:
        # Quantization-aware training: wrap the model, then fine-tune briefly.
        model_for_export = tfmot.quantization.keras.quantize_model(model_for_export)
        compile_classifier(model_for_export)

        train_images_subset = train_images[0:1000]  # out of 60000
        train_labels_subset = train_labels[0:1000]

        model_for_export.fit(train_images_subset, train_labels_subset,
                             batch_size=500, epochs=1, validation_split=0.1)

    converter = tf.lite.TFLiteConverter.from_keras_model(model_for_export)
    # TODO: Might also want to add representative samples to the post-training quantisation
    # https://www.tensorflow.org/model_optimization/guide/quantization/post_training
    if post_train_quant or qaware:
        # TODO: Or this could be [tf.lite.Optimize.OPTIMIZE_FOR_SIZE] instead.
        converter.optimizations = [tf.lite.Optimize.DEFAULT]

    tflite_model = converter.convert()
    return tflite_model

# 3. Bayesian Optimisation

In [None]:
from ax import optimize
import sys
from ax.service.ax_client import AxClient

### Prepare the baseline model

In [None]:
# Train the baseline classifier from scratch (5 epochs, see get_baseline_model).
baseline_model = get_baseline_model()

### Save/Load Baseline Model
Uncomment the first cell below to save the freshly trained baseline model; the second cell loads a previously saved baseline model from disk.

In [None]:
# baseline_model.save("baseline_models/baseline_model_1601.h5")

In [None]:
# Swap in the previously saved baseline when it exists. The save cell above is
# commented out, so on a fresh checkout the file may be missing; in that case
# keep whatever `baseline_model` is already bound instead of failing here.
_saved_baseline_path = "baseline_models/baseline_model_1601.h5"
if os.path.exists(_saved_baseline_path):
    baseline_model = keras.models.load_model(_saved_baseline_path)
else:
    print("No saved baseline model at", _saved_baseline_path, "- keeping the current baseline_model.")

### Get baseline size and accuracy

In [None]:
# Reference points for the experiments: zipped size (bytes) and test accuracy of
# the uncompressed baseline. Later evaluate_fun variants read both constants.
BASELINE_SIZE = get_zipped_model_size(baseline_model)
BASELINE_ACCURACY = get_model_accuracy(baseline_model)

In [None]:
# Report the baseline reference values computed in the previous cell.
print("Baseline size:", BASELINE_SIZE)
print("Baseline accuracy:", BASELINE_ACCURACY)

## 3.1. BO without any constraint 

In [None]:
# Ax service client that drives the unconstrained experiment of this section.
ax_client = AxClient()

In [None]:
def evaluate_fun(p):
    """Evaluate one configuration for the unconstrained BO run.

    The compression pipeline is executed 3 times with the same settings and
    the zipped size and test accuracy of each result are recorded.

    Args:
        p: Mapping with the keys "final_sparsity", "clusters",
            "post_train_quant" and "qaware".

    Returns:
        Dict with "size" and "accuracy", each a ``(mean, std)`` tuple over
        the 3 repetitions.
    """
    # NOTE(review): the second tuple element is np.std over 3 runs, while Ax
    # documents that slot as the SEM - confirm this is intended.
    settings = {
        "final_sparsity": p.get("final_sparsity"),
        "clusters": p.get("clusters"),
        "post_train_quant": p.get("post_train_quant"),
        "qaware": p.get("qaware"),
    }

    print(settings["final_sparsity"])

    runs = []
    for _ in range(3):
        compressed = apply_compression(baseline_model, **settings)
        # Size is measured before accuracy, matching the order of the logs.
        runs.append((get_zipped_model_size(compressed), get_model_accuracy(compressed)))
    sizes, accuracies = zip(*runs)

    size_stats = (np.mean(sizes), np.std(sizes))
    accuracy_stats = (np.mean(accuracies), np.std(accuracies))
    return {"size": size_stats, "accuracy": accuracy_stats}

In [None]:
# Register the search space and objective for the unconstrained run:
# minimise the "size" metric returned by evaluate_fun, with no accuracy constraint.
ax_client.create_experiment(
    name="compression_experiment",
    parameters=[
          {
            # Target sparsity passed to apply_compression(final_sparsity=...).
            "name": "final_sparsity",
            "type": "range",
            "value_type": "float",
            "bounds": [0.0, 0.999],
          }, 
          {
            # Number of weight clusters; the lower bound 2 means clustering is always applied.
            "name": "clusters",
            "type": "range",
            "value_type": "int",
            "bounds": [2, 100],
          }, 
          {
            "name": "post_train_quant",
            "type": "choice",
            "value_type": "bool",
            "values": [True, False],
          }, 
          {
            "name": "qaware",
            "type": "choice",
            "value_type": "bool",
            "values": [True, False],
          }, 
        
        ],
    objective_name="size",
    minimize=True
)

In [None]:
# Run 50 sequential BO trials and time the whole loop.
import time
start = time.time()

for i in range(50):
    print("Iteration", i)
    # Ask Ax for the next configuration to try.
    parameters, trial_index = ax_client.get_next_trial()
    # Local evaluation here can be replaced with deployment to external system.
    ax_client.complete_trial(trial_index=trial_index, raw_data=evaluate_fun(parameters))
    
print("TOTAL TIME TAKEN:", time.time() - start, "seconds")

In [None]:
# Persist the finished experiment. The output folder is created first so that
# a fresh checkout does not fail here after the long optimisation loop.
os.makedirs("results", exist_ok=True)
ax_client.save_to_json_file(filepath="results/baseline_1601.json")

## 3.2. BO with `outcome_constraint`

In [None]:
# Fresh Ax client for the outcome-constraint experiment; this rebinds `ax_client`.
ax_client = AxClient()

In [None]:
def evaluate_fun(p):
    """Evaluate one configuration for the BO run with an outcome constraint.

    Runs the compression pipeline 3 times with identical settings and
    aggregates zipped size and test accuracy.

    Args:
        p: Mapping with the keys "final_sparsity", "clusters",
            "post_train_quant" and "qaware".

    Returns:
        Dict with "size" and "accuracy", each a ``(mean, std)`` tuple over
        the 3 repetitions.
    """
    # NOTE(review): np.std is reported where Ax documents the SEM - confirm.
    sparsity = p.get("final_sparsity")
    n_clusters = p.get("clusters")
    use_ptq = p.get("post_train_quant")
    use_qat = p.get("qaware")

    measured_sizes = []
    measured_accuracies = []
    repetitions = 3
    while len(measured_sizes) < repetitions:
        compressed = apply_compression(
            baseline_model,
            final_sparsity=sparsity,
            clusters=n_clusters,
            post_train_quant=use_ptq,
            qaware=use_qat,
        )
        measured_sizes.append(get_zipped_model_size(compressed))
        measured_accuracies.append(get_model_accuracy(compressed))

    return {
        "size": (np.mean(measured_sizes), np.std(measured_sizes)),
        "accuracy": (np.mean(measured_accuracies), np.std(measured_accuracies)),
    }

In [None]:
# Same search space as section 3.1, but the optimiser is additionally told that
# only configurations with accuracy >= 0.95 are acceptable.
ax_client.create_experiment(
    name="compression_experiment",
    parameters=[
          {
            # Target sparsity passed to apply_compression(final_sparsity=...).
            "name": "final_sparsity",
            "type": "range",
            "value_type": "float",
            "bounds": [0.0, 0.999],
          }, 
          {
            # Number of weight clusters; the lower bound 2 means clustering is always applied.
            "name": "clusters",
            "type": "range",
            "value_type": "int",
            "bounds": [2, 100],
          }, 
          {
            "name": "post_train_quant",
            "type": "choice",
            "value_type": "bool",
            "values": [True, False],
          }, 
          {
            "name": "qaware",
            "type": "choice",
            "value_type": "bool",
            "values": [True, False],
          }, 
        
        ],
    objective_name="size",
    minimize=True,
    # "accuracy" is the second metric returned by evaluate_fun.
    outcome_constraints=["accuracy >= 0.95"]
)

In [None]:
# Run 50 sequential BO trials for the constrained experiment and time the loop.
import time
start = time.time()

for i in range(50):
    print("Iteration", i)
    # Ask Ax for the next configuration to try.
    parameters, trial_index = ax_client.get_next_trial()
    # Local evaluation here can be replaced with deployment to external system.
    ax_client.complete_trial(trial_index=trial_index, raw_data=evaluate_fun(parameters))
    
print("TOTAL TIME TAKEN:", time.time() - start, "seconds")

In [None]:
# Persist the finished experiment. The output folder is created first so that
# a fresh checkout does not fail here after the long optimisation loop.
os.makedirs("results", exist_ok=True)
ax_client.save_to_json_file(filepath="results/outconst_1601.json")

## 3.3. BO with ReturnInf

In [None]:
# Fresh Ax client for the "ReturnInf" (penalised objective) experiment; this rebinds `ax_client`.
ax_client = AxClient()

In [None]:
def evaluate_fun(p):
    """Evaluate one configuration, penalising models that are too inaccurate.

    Runs the compression pipeline 3 times. When the mean accuracy is below
    0.95 the reported size is replaced by ``(BASELINE_SIZE, 0)``, so such a
    configuration looks no better than the uncompressed baseline.

    Args:
        p: Mapping with the keys "final_sparsity", "clusters",
            "post_train_quant" and "qaware".

    Returns:
        Dict with "size" (possibly penalised) and "accuracy", each a
        ``(mean, std)`` tuple.
    """
    # NOTE(review): np.std is reported where Ax documents the SEM - confirm.
    settings = dict(
        final_sparsity=p.get("final_sparsity"),
        clusters=p.get("clusters"),
        post_train_quant=p.get("post_train_quant"),
        qaware=p.get("qaware"),
    )

    runs = []
    for _ in range(3):
        compressed = apply_compression(baseline_model, **settings)
        runs.append((get_zipped_model_size(compressed), get_model_accuracy(compressed)))
    sizes, accuracies = zip(*runs)

    mean_accuracy = np.mean(accuracies)
    if mean_accuracy < 0.95:
        # Accuracy target missed: report the baseline size as a penalty.
        final_size = (BASELINE_SIZE, 0)
    else:
        final_size = (np.mean(sizes), np.std(sizes))

    return {"size": final_size, "accuracy": (mean_accuracy, np.std(accuracies))}

In [None]:
# Search space and objective for the penalised-objective ("ReturnInf") run.
# NOTE(review): this call keeps the "accuracy >= 0.95" outcome constraint from
# section 3.2 in addition to the size penalty applied inside evaluate_fun -
# confirm that combining both is intended for this experiment.
ax_client.create_experiment(
    name="compression_experiment",
    parameters=[
          {
            # Target sparsity passed to apply_compression(final_sparsity=...).
            "name": "final_sparsity",
            "type": "range",
            "value_type": "float",
            "bounds": [0.0, 0.999],
          }, 
          {
            # Number of weight clusters; the lower bound 2 means clustering is always applied.
            "name": "clusters",
            "type": "range",
            "value_type": "int",
            "bounds": [2, 100],
          }, 
          {
            "name": "post_train_quant",
            "type": "choice",
            "value_type": "bool",
            "values": [True, False],
          }, 
          {
            "name": "qaware",
            "type": "choice",
            "value_type": "bool",
            "values": [True, False],
          }, 
        
        ],
    objective_name="size",
    minimize=True,
    outcome_constraints=["accuracy >= 0.95"]
)

In [None]:
# Run 50 sequential BO trials for the penalised-objective experiment and time the loop.
import time
start = time.time()

for i in range(50):
    print("Iteration", i)
    # Ask Ax for the next configuration to try.
    parameters, trial_index = ax_client.get_next_trial()
    # Local evaluation here can be replaced with deployment to external system.
    ax_client.complete_trial(trial_index=trial_index, raw_data=evaluate_fun(parameters))
    
print("TOTAL TIME TAKEN:", time.time() - start, "seconds")

In [None]:
# Persist the finished experiment. The output folder is created first so that
# a fresh checkout does not fail here after the long optimisation loop.
os.makedirs("results", exist_ok=True)
ax_client.save_to_json_file(filepath="results/returninf_1601.json")

## 3.4. BO with Linear Combination 

In [None]:
# Fresh Ax client for the linear-combination experiment; this rebinds `ax_client`.
ax_client = AxClient()

In [None]:
def evaluate_fun(p):
    """Evaluate one configuration for the linear-combination objective.

    Runs the compression pipeline 3 times and combines the results into a
    single scalar: mean size plus the accuracy lost relative to the baseline,
    scaled by the baseline size.

    Args:
        p: Mapping with the keys "final_sparsity", "clusters",
            "post_train_quant" and "qaware".

    Returns:
        Dict with "linear_combination" (a plain number), and "size" and
        "accuracy" as ``(mean, std)`` tuples.
    """
    # NOTE(review): np.std is reported where Ax documents the SEM, and the
    # combined objective carries no uncertainty value at all - confirm.
    settings = dict(
        final_sparsity=p.get("final_sparsity"),
        clusters=p.get("clusters"),
        post_train_quant=p.get("post_train_quant"),
        qaware=p.get("qaware"),
    )

    sizes, accuracies = [], []
    for _ in range(3):
        compressed = apply_compression(baseline_model, **settings)
        sizes.append(get_zipped_model_size(compressed))
        accuracies.append(get_model_accuracy(compressed))

    size_mean, size_std = np.mean(sizes), np.std(sizes)
    accuracy_mean, accuracy_std = np.mean(accuracies), np.std(accuracies)

    # One full point of accuracy lost (1.0) costs as much as the whole baseline size.
    accuracy_penalty = (BASELINE_ACCURACY - accuracy_mean) * BASELINE_SIZE
    linear_comb = size_mean + accuracy_penalty

    return {
        "linear_combination": linear_comb,
        "size": (size_mean, size_std),
        "accuracy": (accuracy_mean, accuracy_std),
    }

In [None]:
# Same search space as the previous sections; the objective is now the scalar
# "linear_combination" metric returned by evaluate_fun, with no outcome constraint.
ax_client.create_experiment(
    name="compression_experiment",
    parameters=[
          {
            # Target sparsity passed to apply_compression(final_sparsity=...).
            "name": "final_sparsity",
            "type": "range",
            "value_type": "float",
            "bounds": [0.0, 0.999],
          }, 
          {
            # Number of weight clusters; the lower bound 2 means clustering is always applied.
            "name": "clusters",
            "type": "range",
            "value_type": "int",
            "bounds": [2, 100],
          }, 
          {
            "name": "post_train_quant",
            "type": "choice",
            "value_type": "bool",
            "values": [True, False],
          }, 
          {
            "name": "qaware",
            "type": "choice",
            "value_type": "bool",
            "values": [True, False],
          }, 
        
        ],
    objective_name="linear_combination",
    minimize=True
)

In [None]:
# Run 50 sequential BO trials for the linear-combination experiment and time the loop.
import time
start = time.time()

for i in range(50):
    print("Iteration", i)
    # Ask Ax for the next configuration to try.
    parameters, trial_index = ax_client.get_next_trial()
    # Local evaluation here can be replaced with deployment to external system.
    ax_client.complete_trial(trial_index=trial_index, raw_data=evaluate_fun(parameters))
    
print("TOTAL TIME TAKEN:", time.time() - start, "seconds")

In [None]:
# Persist the finished experiment. The output folder is created first so that
# a fresh checkout does not fail here after the long optimisation loop.
os.makedirs("results", exist_ok=True)
ax_client.save_to_json_file(filepath="results/lin_comb_1701.json")

# 4. Other Optimisation Methods 

## 4.1. Random Search 

The evaluation function remains similar: 

In [None]:
def evaluate_fun(final_sparsity, clusters, post_train_quant, qaware):
    """Evaluate one configuration for random search.

    Runs the compression pipeline 3 times with the given settings.

    Args:
        final_sparsity: Target pruning sparsity.
        clusters: Number of weight clusters.
        post_train_quant: Whether to enable post-training quantization.
        qaware: Whether to run quantization-aware training.

    Returns:
        Dict with "size" and "accuracy", each a ``(mean, std)`` tuple over
        the 3 repetitions.
    """
    runs = []
    for _ in range(3):
        compressed = apply_compression(
            baseline_model,
            final_sparsity=final_sparsity,
            clusters=clusters,
            post_train_quant=post_train_quant,
            qaware=qaware,
        )
        # Size first, then accuracy - same order as the other evaluation helpers.
        runs.append((get_zipped_model_size(compressed), get_model_accuracy(compressed)))
    sizes, accuracies = zip(*runs)

    return {
        "size": (np.mean(sizes), np.std(sizes)),
        "accuracy": (np.mean(accuracies), np.std(accuracies)),
    }

In [None]:
def get_random_parameters():
    """Draw one random configuration from the search space used by the BO runs.

    Returns:
        Dict with the keys expected by ``evaluate_fun``:
        "final_sparsity" (float in [0.0, 0.999)), "clusters" (int in
        [2, 100]), "post_train_quant" (bool) and "qaware" (bool).
    """
    return {
        # Bounded like the Ax "final_sparsity" range so that random search is
        # compared on the same search space (np.random.rand() reached 1.0).
        "final_sparsity": float(np.random.uniform(0.0, 0.999)),
        # randint excludes the upper bound, hence 101 for an inclusive 100.
        "clusters": int(np.random.randint(2, 101)),
        "post_train_quant": bool(np.random.randint(0, 2)),
        "qaware": bool(np.random.randint(0, 2)),
    }

In [None]:
# Random search: evaluate independently sampled configurations and time the loop.
# NOTE(review): this runs 55 evaluations while each BO experiment runs 50 -
# confirm the budgets are meant to differ.
import time
start = time.time()

# Each entry is a (parameters, evaluation result) pair.
results = [] 

for i in range(55):
    print("Iteration", i)
    parameters = get_random_parameters()
    print("Parameters", parameters)
    res = (parameters, evaluate_fun(**parameters))
    results.append(res)
    
print("TOTAL TIME TAKEN:", time.time() - start, "seconds")

In [None]:
# Store the random-search results. open() does not create missing folders, so
# make sure the output directory exists before writing.
os.makedirs('results', exist_ok=True)
with open('results/random_search_results_1701.pickle', 'wb') as file:
    pickle.dump(results, file, protocol=pickle.HIGHEST_PROTOCOL)

## 4.2. Grid Search (Exhaustive Search)

In [None]:
def evaluate_fun(final_sparsity, clusters, post_train_quant, qaware):
    """Evaluate one grid point of the exhaustive search.

    Runs the compression pipeline 3 times with the given settings.

    Args:
        final_sparsity: Target pruning sparsity.
        clusters: Number of weight clusters.
        post_train_quant: Whether to enable post-training quantization.
        qaware: Whether to run quantization-aware training.

    Returns:
        Dict with "size" and "accuracy", each a ``(mean, std)`` tuple over
        the 3 repetitions.
    """
    settings = {
        "final_sparsity": final_sparsity,
        "clusters": clusters,
        "post_train_quant": post_train_quant,
        "qaware": qaware,
    }

    measured_sizes, measured_accuracies = [], []
    for _ in range(3):
        compressed = apply_compression(baseline_model, **settings)
        measured_sizes.append(get_zipped_model_size(compressed))
        measured_accuracies.append(get_model_accuracy(compressed))

    size_stats = (np.mean(measured_sizes), np.std(measured_sizes))
    accuracy_stats = (np.mean(measured_accuracies), np.std(measured_accuracies))
    return {"size": size_stats, "accuracy": accuracy_stats}

In [None]:
# Exhaustive search over a fixed grid:
# 7 sparsity levels x 4 cluster counts (2, 32, 62, 92) x 2 x 2 = 112 configurations.
import time
start = time.time()

# Each entry is a (parameter list, evaluation result) pair.
results = [] 

for final_sparsity in np.linspace(0, 0.99, num=7): 
    for clusters in range(2, 101, 30): 
        for post_train_quant in [True, False]: 
            for qaware in [True, False]: 
                parameters = [final_sparsity, clusters, post_train_quant, qaware]
                print("Testing with", parameters)
                # The sparsity is converted from a NumPy scalar to a plain float for the call.
                res = evaluate_fun(float(final_sparsity), clusters, post_train_quant, qaware)
                results.append((parameters, res))
                
print("TOTAL TIME TAKEN:", time.time() - start, "seconds")

### Saving the Output

In [None]:
# Store the grid-search results. open() does not create missing folders, so
# make sure the output directory exists before writing.
os.makedirs('results', exist_ok=True)
with open('results/grid_search_results_1601.pickle', 'wb') as file:
    pickle.dump(results, file, protocol=pickle.HIGHEST_PROTOCOL)

# Sensitivity Analysis

In [None]:
# Use a larger default font for all figures produced below.
plt.rcParams["font.size"] = "13"

## Pruning sparsity

In [None]:
# Sweep the final pruning sparsity (10 values from 0.80 to 0.99) with clustering
# and quantization left at their defaults, recording size and accuracy per model.
model_sizes = [] 
model_accuracies = [] 
all_models = [] 
# in_sp is initialised here but never filled in this cell.
in_sp = []
fin_sp = []

for final_sparsity in np.linspace(0.80, 0.99, 10): 
        fin_sp.append(final_sparsity)
        print("final_sparsity =", str(final_sparsity))
        model_with_pruning = apply_compression(baseline_model, 
                                           initial_sparsity=0.5,
                                           final_sparsity=float(final_sparsity))
        all_models.append(model_with_pruning)
        model_sizes.append(get_zipped_model_size(model_with_pruning))
        model_accuracies.append(get_model_accuracy(model_with_pruning))

In [None]:
# Twin-axis plot of the sparsity sweep: model size (left axis, red) and
# inference accuracy in percent (right axis, blue) against the sparsity level.
# The x values repeat the grid used by the sweep cell above.
t = np.linspace(0.80, 0.99, 10)
data1 = model_sizes
data2 = [x*100 for x in model_accuracies]

fig, ax1 = plt.subplots(figsize=(6.5,5))

color = 'tab:red'
ax1.set_xlabel('Level of sparsity')
ax1.set_ylabel('Model size (bytes)', color=color)
ax1.plot(t, data1, color=color, linewidth=2)
ax1.tick_params(axis='y', labelcolor=color)

ax2 = ax1.twinx()  # instantiate a second axes that shares the same x-axis

color = 'tab:blue'
ax2.set_ylabel('Inference accuracy (%)', color=color)  # we already handled the x-label with ax1
ax2.plot(t, data2, color=color, linewidth=2)
ax2.tick_params(axis='y', labelcolor=color)

fig.tight_layout()  # otherwise the right y-label is slightly clipped
plt.show()

In [None]:
# Save the sparsity figure; create the output folder first because savefig
# does not create missing directories.
os.makedirs("graphs", exist_ok=True)
fig.savefig("graphs/sparsity.svg")

## Number of clusters 

In [None]:
# Sweep the number of weight clusters with the pruning schedule held at 0.0
# sparsity (initial and final), recording accuracy and size per model.
wc_models = [] 
wc_accuracies = [] 
wc_sizes = [] 
wc_options = [2 ,5, 10, 16, 20, 25, 50]


for n in wc_options: 
    res = apply_compression(baseline_model, initial_sparsity=0.0, final_sparsity=0.0, clusters=n)
    wc_accuracies.append(get_model_accuracy(res))
    wc_sizes.append(get_zipped_model_size(res))
    wc_models.append(res)

In [None]:
# Twin-axis plot of the cluster sweep: model size (left axis, red) and
# inference accuracy in percent (right axis, blue) against the number of clusters.
t = wc_options
data1 = wc_sizes
data2 = [x*100 for x in wc_accuracies]

fig, ax1 = plt.subplots(figsize=(6.5,5))

color = 'tab:red'
ax1.set_xlabel('Number of clusters')
ax1.set_ylabel('Model size (bytes)', color=color)
ax1.plot(t, data1, color=color, linewidth=2)
ax1.tick_params(axis='y', labelcolor=color)

ax2 = ax1.twinx()  # instantiate a second axes that shares the same x-axis

color = 'tab:blue'
ax2.set_ylabel('Inference accuracy (%)', color=color)  # we already handled the x-label with ax1
ax2.plot(t, data2, color=color, linewidth=2)
ax2.tick_params(axis='y', labelcolor=color)

fig.tight_layout()  # otherwise the right y-label is slightly clipped

In [None]:
# Save the clusters figure; create the output folder first because savefig
# does not create missing directories.
os.makedirs("graphs", exist_ok=True)
fig.savefig("graphs/clusters.svg")

# Weight distribution experiment

In [None]:
# Flatten the first four weight arrays of the baseline model into one flat list
# of scalar values for the histogram below.
all_weights = []
for i in range(4):
    layer_values = baseline_model.get_weights()[i].flatten()
    all_weights.extend(layer_values)

In [None]:
# Histogram (300 bins) of every weight value of the baseline model.
plt.rcParams["font.size"] = "13"
plt.figure(figsize=(6,5))
plt.hist(all_weights, bins=300)
plt.xlabel("Weight value")
plt.ylabel("Count")
# savefig does not create missing directories, so make sure the folder exists.
os.makedirs("graphs", exist_ok=True)
plt.savefig("graphs/weights.svg")
plt.show()

In [None]:
def get_pruned_model(baseline_model, final_sparsity=0.8, initial_sparsity=0.5):
    """Return a pruned Keras copy of `baseline_model` (no TFLite conversion).

    A clone of the baseline is pruned by weight magnitude with a polynomial
    sparsity schedule, fine-tuned for 2 epochs on the module-level
    ``train_images``/``train_labels``, and returned with the pruning wrappers
    stripped.

    Args:
        baseline_model: Trained Keras model. It is cloned, so the original
            object is not trained further.
        final_sparsity: Sparsity reached at the end of the pruning schedule.
        initial_sparsity: Sparsity at the first pruning step. Values above
            `final_sparsity` are clamped to `final_sparsity`. This name used
            to be read without ever being defined, which made every call fail
            with a NameError.

    Returns:
        The pruned Keras model.
    """
    def compile_classifier(net):
        # Both the clone and the pruning wrapper train on raw logits.
        net.compile(optimizer='adam',
                    loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),
                    metrics=['accuracy'])

    prune_low_magnitude = tfmot.sparsity.keras.prune_low_magnitude

    # Keep the schedule non-decreasing when a low target sparsity is requested.
    start_sparsity = min(initial_sparsity, final_sparsity)

    # Create a clone of the baseline model (to avoid various calls affecting each other).
    model = keras.models.clone_model(baseline_model)
    # Build with the real input shape of the baseline instead of a placeholder.
    model.build(baseline_model.input_shape)
    compile_classifier(model)
    model.set_weights(baseline_model.get_weights())

    # Compute the end step so that pruning finishes after 2 epochs.
    batch_size = 128
    epochs = 2
    validation_split = 0.1  # 10% of training set will be used for validation set.

    num_images = train_images.shape[0] * (1 - validation_split)
    end_step = np.ceil(num_images / batch_size).astype(np.int32) * epochs

    # Define model for pruning.
    pruning_params = {
        'pruning_schedule': tfmot.sparsity.keras.PolynomialDecay(initial_sparsity=start_sparsity,
                                                                 final_sparsity=final_sparsity,
                                                                 begin_step=0,
                                                                 end_step=end_step)
    }
    model_for_pruning = prune_low_magnitude(model, **pruning_params)

    # `prune_low_magnitude` requires a recompile.
    compile_classifier(model_for_pruning)

    # UpdatePruningStep advances the pruning schedule during fit().
    callbacks = [
        tfmot.sparsity.keras.UpdatePruningStep()
    ]
    model_for_pruning.fit(train_images, train_labels,
                          batch_size=batch_size, epochs=epochs, validation_split=validation_split,
                          callbacks=callbacks)

    # Drop the pruning wrappers so that the plain Keras layers are returned.
    model_for_export = tfmot.sparsity.keras.strip_pruning(model_for_pruning)
    return model_for_export

In [None]:
# Prune the baseline to 80% sparsity; the result stays a Keras model so its weights can be inspected.
pruned_model = get_pruned_model(baseline_model, final_sparsity=0.8)

In [None]:
# Flatten the first four weight arrays of the pruned model into one flat list,
# then keep only the non-zero values so the histogram is not dominated by the
# zeroed-out entries.
all_weights_after_pruning = []
for i in range(4):
    layer_values = pruned_model.get_weights()[i].flatten()
    all_weights_after_pruning.extend(layer_values)
all_weights_after_pruning = [value for value in all_weights_after_pruning if value != 0]

In [None]:
# Histogram (300 bins) of the non-zero weight values that remain after pruning.
plt.rcParams["font.size"] = "13"
plt.figure(figsize=(6,5))
plt.hist(all_weights_after_pruning, bins=300)
plt.xlabel("Weight value")
plt.ylabel("Count")
# savefig does not create missing directories, so make sure the folder exists.
os.makedirs("graphs", exist_ok=True)
plt.savefig("graphs/weights_after_pruning.png")
plt.show()