In [1]:
#@title Authorize `wandb`
HAS_WANDB_ACCOUNT = True #@param ["True", "False"] {type:"raw"}
import wandb
if not HAS_WANDB_ACCOUNT:
    wandb.login(anonymous='allow')
else:
    wandb.login()

In [2]:
import tensorflow as tf
from tensorflow.keras.models import load_model
import tensorflow_model_optimization as tfmot
from pathlib import Path

In [40]:
import tempfile
import numpy as np

In [3]:
#@title training parameters
keras = tf.keras
l = keras.layers

batch_size = 128 #@param {type:"integer"}
num_classes = 10 #@param {type:"integer"}
epochs = 12 #@param {type:"integer"}
# input image dimensions
img_rows = 28 #@param {type:"integer"}
img_cols = 28 #@param {type:"integer"}
input_shape = (img_rows, img_cols, 1)

save_path = "./saved/base_mnist.h5" #@param {type:"string"}
save_path_pruning_model =  "./saved/pruning_mnist.h5" #@param {type:"string"}

In [4]:
Path(save_path).parent.mkdir(parents=True, exist_ok=True)
Path(save_path_pruning_model).parent.mkdir(parents=True, exist_ok=True)

In [5]:
# Load MNIST dataset.
mnist = tf.keras.datasets.mnist
(train_images, train_labels), (test_images, test_labels) = mnist.load_data()

# Normalize the input image so that each pixel value is between 0 to 1.
train_images = train_images / 255.0
test_images = test_images / 255.0

# Add a channel dimension.
train_images = train_images.reshape(-1, img_rows, img_cols, 1)
test_images = test_images.reshape(-1, img_rows, img_cols, 1)


In [6]:
def build_model(input_shape):
    inp = tf.keras.Input(shape=input_shape)
    x = l.Conv2D(32, 5, padding='same', activation='relu')(inp)
    x = l.MaxPooling2D((2, 2), (2, 2), padding='same')(x)
    x = l.BatchNormalization()(x)
    x = l.Conv2D(64, 5, padding='same', activation='relu')(x)
    x = l.MaxPooling2D((2, 2), (2, 2), padding='same')(x)
    x = l.Flatten()(x)
    x = l.Dense(1024, activation='relu')(x)
    x = l.Dropout(0.4)(x)
    out = l.Dense(num_classes, activation='softmax')(x)

    return tf.keras.models.Model([inp], [out])

In [7]:
model = build_model(input_shape)
model.summary()

Model: "model"
_________________________________________________________________
Layer (type)                 Output Shape              Param #   
input_1 (InputLayer)         [(None, 28, 28, 1)]       0         
_________________________________________________________________
conv2d (Conv2D)              (None, 28, 28, 32)        832       
_________________________________________________________________
max_pooling2d (MaxPooling2D) (None, 14, 14, 32)        0         
_________________________________________________________________
batch_normalization (BatchNo (None, 14, 14, 32)        128       
_________________________________________________________________
conv2d_1 (Conv2D)            (None, 14, 14, 64)        51264     
_________________________________________________________________
max_pooling2d_1 (MaxPooling2 (None, 7, 7, 64)          0         
_________________________________________________________________
flatten (Flatten)            (None, 3136)              0     

In [8]:
dataset = tf.data.Dataset.from_tensor_slices((train_images, train_labels))
dataset = dataset.batch(batch_size=batch_size)

In [9]:
#@title Train a simple CNN without any pruning
wandb_run_id = "vanilla-training-cnn" #@param {type:"string"}
if HAS_WANDB_ACCOUNT:
        wandb.init(entity='ilab', project='tensorflow_pruning', id=wandb_run_id)
else:
    wandb.init(id=wandb_run_id)

In [14]:
optimizer = tf.keras.optimizers.Adam()
loss = tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True)

In [15]:
features, labels = next(iter(dataset))
print(features.shape)
print(labels.shape)

(128, 28, 28, 1)
(128,)


In [16]:
for epoch in range(epochs):
    for x, y in dataset:
        with tf.GradientTape() as tape:
            y_pred = model([x])
            loss_value = loss(y, y_pred)
            wandb.log({'epoch': epoch, 'loss': loss_value})
            gradients = tape.gradient(loss_value, model.trainable_variables)
            optimizer.apply_gradients(zip(gradients, model.trainable_variables))
        

model.save(save_path)

In [23]:
test_dataset = tf.data.Dataset.from_tensor_slices((test_images, test_labels))
test_dataset = test_dataset.batch(batch_size=16)

In [24]:
features, labels = next(iter(test_dataset))
print(features.shape)
print(labels.shape)

(16, 28, 28, 1)
(16,)


In [25]:
test_accuracy = tf.keras.metrics.Accuracy()
def evaluate_model(eval_model):
    for x, y in test_dataset:
        logits = eval_model([x])
        prediction = tf.argmax(logits, axis=1, output_type=tf.int32)
        test_accuracy(prediction, y)
    
    print("Test set accuracy: {:.3%}".format(test_accuracy.result()))  
    

In [26]:
evaluate_model(model)

Test set accuracy: 99.150%


In [27]:
#@title Take a trained network, prune it with more training
target_sparsity = 0.5 #@param {type:"number"}
begin_step = 0 #@param {type:"integer"}
end_step =  -1 #@param {type:"integer"}
frequency = 1 #@param {type:"integer"}
epochs = 4 #@param {type:"integer"}
wandb_run_id = "pruning-trained-net" #@param {type:"string"}

if HAS_WANDB_ACCOUNT:
        wandb.init(entity='ilab', project='tensorflow_pruning', id=wandb_run_id)
else:
    wandb.init(id=wandb_run_id)
    
# Define pruning schedule
pruning_params = {
      'pruning_schedule': tfmot.sparsity.keras.ConstantSparsity(
          target_sparsity=target_sparsity,
          begin_step=begin_step,
          end_step=end_step,
          frequency=frequency
      )
}

In [28]:
loaded_model = load_model(save_path)
loaded_model_weights = loaded_model.get_weights()

base_model = build_model(input_shape)
base_model.set_weights(loaded_model_weights)
model_for_pruning = tfmot.sparsity.keras.prune_low_magnitude(base_model, **pruning_params)

Instructions for updating:
Please use `layer.add_weight` method instead.


In [31]:
# Pruning
unused_arg = -1

log_dir = tempfile.mkdtemp()
model.optimizer = optimizer
step_callback = tfmot.sparsity.keras.UpdatePruningStep()
step_callback.set_model(model)
log_callback = tfmot.sparsity.keras.PruningSummaries(log_dir=log_dir) # Log sparsity and other metrics in Tensorboard.
log_callback.set_model(model)

In [33]:
#optimizer = tf.keras.optimizers.Adam()
step_callback.on_train_begin() # run pruning callback
for epoch in range(epochs):
    log_callback.on_epoch_begin(epoch=unused_arg) # run pruning callback
    for x, y in dataset:
        step_callback.on_train_batch_begin(batch=unused_arg) # run pruning callback
        with tf.GradientTape() as tape:
            y_pred = model_for_pruning([x])
            loss_value = loss(y, y_pred)
            wandb.log({'epoch': epoch, 'loss': loss_value})
            gradients = tape.gradient(loss_value, model_for_pruning.trainable_variables)
            optimizer.apply_gradients(zip(gradients, model_for_pruning.trainable_variables))
    step_callback.on_epoch_end(batch=unused_arg) # run pruning callback
        

model_for_pruning.save(save_path_pruning_model)

In [34]:
evaluate_model(model_for_pruning)

Test set accuracy: 98.925%


In [36]:
#docs_infra: no_execute
%load_ext tensorboard
%tensorboard --logdir={log_dir}

In [38]:
model_for_export = tfmot.sparsity.keras.strip_pruning(model_for_pruning)

_, pruned_keras_file = tempfile.mkstemp('.h5')
tf.keras.models.save_model(model_for_export, pruned_keras_file, include_optimizer=False)
print('Saved pruned Keras model to:', pruned_keras_file)

Saved pruned Keras model to: /tmp/tmpj84dws08.h5


In [39]:
converter = tf.lite.TFLiteConverter.from_keras_model(model_for_export)
converter.optimizations = [tf.lite.Optimize.DEFAULT]
quantized_and_pruned_tflite_model = converter.convert()

_, quantized_and_pruned_tflite_file = tempfile.mkstemp('.tflite')

with open(quantized_and_pruned_tflite_file, 'wb') as f:
  f.write(quantized_and_pruned_tflite_model)

print("TF Lite model saved to: ",quantized_and_pruned_tflite_file)

TF Lite model saved to:  /tmp/tmp9pc4o_q3.tflite


In [42]:
def evaluate_tflite_model(interpreter):
    input_index = interpreter.get_input_details()[0]["index"]
    output_index = interpreter.get_output_details()[0]["index"]

    # Run predictions on ever y image in the "test" dataset.
    prediction_digits = []
    for i, test_image in enumerate(test_images):
        if i % 1000 == 0:
            print('Evaluated on {n} results so far.'.format(n=i))
        # Pre-processing: add batch dimension and convert to float32 to match with
        # the model's input data format.
        test_image = np.expand_dims(test_image, axis=0).astype(np.float32)
        interpreter.set_tensor(input_index, test_image)

        # Run inference.
        interpreter.invoke()

        # Post-processing: remove batch dimension and find the digit with highest
        # probability.
        output = interpreter.tensor(output_index)
        digit = np.argmax(output()[0])
        prediction_digits.append(digit)

    print('\n')
    # Compare prediction results with ground truth labels to calculate accuracy.
    prediction_digits = np.array(prediction_digits)
    accuracy = (prediction_digits == test_labels).mean()
    return accuracy

In [43]:
interpreter = tf.lite.Interpreter(model_content=quantized_and_pruned_tflite_model)
interpreter.allocate_tensors()

test_accuracy = evaluate_tflite_model(interpreter)

print('Pruned and quantized TFLite test_accuracy:', test_accuracy)

Evaluated on 0 results so far.
Evaluated on 1000 results so far.
Evaluated on 2000 results so far.
Evaluated on 3000 results so far.
Evaluated on 4000 results so far.
Evaluated on 5000 results so far.
Evaluated on 6000 results so far.
Evaluated on 7000 results so far.
Evaluated on 8000 results so far.
Evaluated on 9000 results so far.


Pruned and quantized TFLite test_accuracy: 0.9871
