In [1]:
import tensorflow as tf
import tensorflow_model_optimization as tfmot
from tensorflow.keras.models import Model
from tensorflow.keras.layers import Input, Conv2D, BatchNormalization, Activation, MaxPooling2D, Flatten, Dense
from tensorflow.keras.regularizers import l1


2025-04-06 15:53:24.786391: I tensorflow/core/platform/cpu_feature_guard.cc:193] This TensorFlow binary is optimized with oneAPI Deep Neural Network Library (oneDNN) to use the following CPU instructions in performance-critical operations:  AVX2 AVX512F AVX512_VNNI AVX512_BF16 FMA
To enable them in other operations, rebuild TensorFlow with the appropriate compiler flags.
2025-04-06 15:53:24.844799: I tensorflow/core/util/port.cc:104] oneDNN custom operations are on. You may see slightly different numerical results due to floating-point round-off errors from different computation orders. To turn them off, set the environment variable `TF_ENABLE_ONEDNN_OPTS=0`.
2025-04-06 15:53:24.846780: W tensorflow/compiler/xla/stream_executor/platform/default/dso_loader.cc:64] Could not load dynamic library 'libcudart.so.11.0'; dlerror: libcudart.so.11.0: cannot open shared object file: No such file or directory
2025-04-06 15:53:24.846786: I tensorflow/compiler/xla/stream_executor/cuda/cudart_stub.cc

In [2]:
def build_model(input_shape, n_classes, filters_per_conv_layer, neurons_per_dense_layer):
    """Assemble the baseline CNN classifier as an uncompiled Keras functional model.

    Layout, in order:
      * one convolutional stage per entry of ``filters_per_conv_layer``:
        3x3 'valid' convolution -> batch normalization -> ReLU -> 2x2 max pooling;
      * a flatten layer;
      * one dense stage per entry of ``neurons_per_dense_layer``:
        dense -> batch normalization -> ReLU;
      * a dense layer with ``n_classes`` units followed by a softmax.

    Args:
        input_shape: Shape handed to ``Input`` (e.g. ``(28, 28, 1)``).
        n_classes: Number of units in the final dense layer.
        filters_per_conv_layer: Iterable of filter counts, one per conv stage.
        neurons_per_dense_layer: Iterable of unit counts, one per hidden dense stage.

    Returns:
        A ``Model`` named 'baseline_model' mapping the input to the softmax output.
    """
    model_input = Input(input_shape)
    tensor = model_input

    # Convolutional stages. The conv layers carry no bias term; each one is
    # followed directly by a batch-normalization layer.
    for idx, n_filters in enumerate(filters_per_conv_layer):
        conv_stage = (
            Conv2D(n_filters, (3, 3), strides=(1, 1), padding='valid',
                   kernel_initializer='lecun_uniform', kernel_regularizer=l1(0.0001),
                   use_bias=False, name=f'conv_{idx}'),
            BatchNormalization(name=f'bn_conv_{idx}'),
            Activation('relu', name=f'conv_act_{idx}'),
            MaxPooling2D(pool_size=(2, 2), name=f'pool_{idx}'),
        )
        for layer in conv_stage:
            tensor = layer(tensor)

    tensor = Flatten()(tensor)

    # Hidden dense stages, using the same initializer / L1 penalty as the convs.
    for idx, n_units in enumerate(neurons_per_dense_layer):
        dense_stage = (
            Dense(n_units, kernel_initializer='lecun_uniform', kernel_regularizer=l1(0.0001),
                  use_bias=False, name=f'dense_{idx}'),
            BatchNormalization(name=f'bn_dense_{idx}'),
            Activation('relu', name=f'dense_act_{idx}'),
        )
        for layer in dense_stage:
            tensor = layer(tensor)

    # Classification head: the logits layer keeps Keras' default settings
    # (no explicit initializer or regularizer is passed here).
    logits = Dense(n_classes, name='output_dense')(tensor)
    probabilities = Activation('softmax', name='output_softmax')(logits)
    return Model(inputs=[model_input], outputs=[probabilities], name='baseline_model')


In [3]:
def train_model(model, train_data, val_data, epochs=10):
    """Compile ``model`` with Adam / categorical cross-entropy and fit it.

    Args:
        model: Object exposing Keras-style ``compile`` and ``fit`` methods.
        train_data: Training data passed to ``fit`` as its first argument.
        val_data: Data passed to ``fit`` as ``validation_data``.
        epochs: Number of epochs passed to ``fit``.

    Returns:
        The same ``model`` object that was passed in. The value returned by
        ``fit`` is discarded.
    """
    compile_options = {
        'optimizer': 'adam',
        'loss': 'categorical_crossentropy',
        'metrics': ['accuracy'],
    }
    model.compile(**compile_options)
    model.fit(train_data, validation_data=val_data, epochs=epochs)
    return model


In [4]:
def prune_model(model, begin_step=0, end_step=2000, final_sparsity=0.5):
    """Wrap ``model`` for magnitude pruning and prepare it for fine-tuning.

    The sparsity target follows a ``PolynomialDecay`` schedule that starts at
    0.0 and reaches ``final_sparsity`` between ``begin_step`` and ``end_step``.
    The wrapped model is compiled here (Adam, categorical cross-entropy,
    accuracy metric); it is NOT fitted - the caller runs ``fit`` with the
    returned callbacks.

    Args:
        model: The model to wrap with ``prune_low_magnitude``.
        begin_step: Step at which the pruning schedule starts.
        end_step: Step at which the pruning schedule ends.
        final_sparsity: Sparsity level the schedule ends at.

    Returns:
        A ``(pruned_model, callbacks)`` tuple, where ``callbacks`` is a list
        holding an ``UpdatePruningStep`` callback followed by an
        ``EarlyStopping(patience=5, restore_best_weights=True)`` callback.
    """
    sparsity_api = tfmot.sparsity.keras

    schedule = sparsity_api.PolynomialDecay(
        initial_sparsity=0.0,
        final_sparsity=final_sparsity,
        begin_step=begin_step,
        end_step=end_step,
    )

    pruned_model = sparsity_api.prune_low_magnitude(model, pruning_schedule=schedule)
    pruned_model.compile(optimizer='adam', loss='categorical_crossentropy', metrics=['accuracy'])

    # UpdatePruningStep must be passed to fit() for the pruning wrappers to advance.
    step_updater = sparsity_api.UpdatePruningStep()
    # NOTE(review): EarlyStopping uses its default monitor here, and
    # restore_best_weights may roll back to an epoch recorded before the schedule
    # reached final_sparsity - confirm this is the intended trade-off.
    early_stopper = tf.keras.callbacks.EarlyStopping(patience=5, restore_best_weights=True)

    return pruned_model, [step_updater, early_stopper]


In [5]:

# Function to strip and quantize the pruned model
def strip_and_qat_model(pruned_model, train_data, val_data, epochs=5):
    """Strip pruning wrappers, then fine-tune with quantization-aware training (QAT).

    Steps: ``strip_pruning`` on ``pruned_model`` -> ``quantize_model`` on the
    result -> compile (Adam, categorical cross-entropy, accuracy) -> ``fit``.

    Args:
        pruned_model: Model previously wrapped for pruning.
        train_data: Training data passed to ``fit``.
        val_data: Data passed to ``fit`` as ``validation_data``.
        epochs: Number of QAT fine-tuning epochs.

    Returns:
        The quantization-aware model after fitting.
    """
    unwrapped_model = tfmot.sparsity.keras.strip_pruning(pruned_model)

    # NOTE(review): this uses the plain quantize_model path, so nothing here
    # explicitly keeps pruned weights at zero during fine-tuning. If sparsity must
    # survive QAT, check TFMOT's prune-preserving QAT scheme - TODO confirm.
    qat_model = tfmot.quantization.keras.quantize_model(unwrapped_model)

    qat_model.compile(optimizer='adam', loss='categorical_crossentropy', metrics=['accuracy'])
    qat_model.fit(train_data, validation_data=val_data, epochs=epochs)
    return qat_model


In [6]:
def export_tflite_model(model, filename='model_quant.tflite'):
    """Convert a Keras model to TFLite and write the result to ``filename``.

    The converter is created with ``TFLiteConverter.from_keras_model`` and run
    with ``tf.lite.Optimize.DEFAULT`` as its only optimization.

    Args:
        model: Keras model handed to the TFLite converter.
        filename: Destination path; opened in binary write mode, so an existing
            file at that path is overwritten.

    Returns:
        None. A confirmation line is printed after the file is written.
    """
    tflite_converter = tf.lite.TFLiteConverter.from_keras_model(model)
    tflite_converter.optimizations = [tf.lite.Optimize.DEFAULT]

    # Convert before touching the filesystem, so a failed conversion does not
    # leave an empty or truncated file behind.
    serialized_model = tflite_converter.convert()

    with open(filename, "wb") as out_file:
        out_file.write(serialized_model)

    print(f"Saved quantized model to: {filename}")


In [7]:
import tensorflow as tf
from tensorflow.keras.utils import to_categorical

from sklearn.model_selection import train_test_split

# Load MNIST dataset
(x_train, y_train), (x_test, y_test) = tf.keras.datasets.mnist.load_data()

# Normalize the pixel values to [0, 1]
x_train = x_train.astype("float32") / 255.0
x_test  = x_test.astype("float32") / 255.0

# Reshape to add channel dimension (28x28x1)
x_train = x_train.reshape((-1, 28, 28, 1))
x_test  = x_test.reshape((-1, 28, 28, 1))

# Convert labels to one-hot encoding (required by the 'categorical_crossentropy' loss used below)
y_train = to_categorical(y_train, 10)
y_test  = to_categorical(y_test, 10)

# Split off a validation set: 10% of the training data, with a fixed seed so the split is reproducible
x_train, x_val, y_train, y_val = train_test_split(x_train, y_train, test_size=0.1, random_state=42)

# Create tf.data.Dataset objects (optional but recommended for performance)
batch_size = 1024

# Shuffle with a buffer covering the whole training split: a buffer smaller than
# the dataset only shuffles within a sliding window. Only the training pipeline
# is shuffled; validation and test keep their order.
# prefetch() lets input preparation overlap with the training step.
train_data = (
    tf.data.Dataset.from_tensor_slices((x_train, y_train))
    .shuffle(len(x_train))
    .batch(batch_size)
    .prefetch(tf.data.AUTOTUNE)
)
val_data   = tf.data.Dataset.from_tensor_slices((x_val, y_val)).batch(batch_size).prefetch(tf.data.AUTOTUNE)
test_data  = tf.data.Dataset.from_tensor_slices((x_test, y_test)).batch(batch_size).prefetch(tf.data.AUTOTUNE)

# Optional: set number of epochs
n_epochs = 10


2025-04-06 15:53:26.020326: W tensorflow/compiler/xla/stream_executor/platform/default/dso_loader.cc:64] Could not load dynamic library 'libcuda.so.1'; dlerror: libcuda.so.1: cannot open shared object file: No such file or directory
2025-04-06 15:53:26.020349: W tensorflow/compiler/xla/stream_executor/cuda/cuda_driver.cc:265] failed call to cuInit: UNKNOWN ERROR (303)
2025-04-06 15:53:26.020360: I tensorflow/compiler/xla/stream_executor/cuda/cuda_diagnostics.cc:156] kernel driver does not appear to be running on this host (theodoros-MS-7D75): /proc/driver/nvidia/version does not exist
2025-04-06 15:53:26.020558: I tensorflow/core/platform/cpu_feature_guard.cc:193] This TensorFlow binary is optimized with oneAPI Deep Neural Network Library (oneDNN) to use the following CPU instructions in performance-critical operations:  AVX2 AVX512F AVX512_VNNI AVX512_BF16 FMA
To enable them in other operations, rebuild TensorFlow with the appropriate compiler flags.


In [8]:
import tensorflow_model_optimization as tfmot

# 1. Build baseline model
model = build_model(input_shape=(28, 28, 1), n_classes=10,
                    filters_per_conv_layer=[16, 16, 24],
                    neurons_per_dense_layer=[42, 64])

# 2. Apply QAT BEFORE training: the untrained baseline is wrapped with
#    quantization-aware layers, so all training happens on the wrapped model.
quantize_model = tfmot.quantization.keras.quantize_model
qat_model = quantize_model(model)

# 3. Compile and train as usual
qat_model.compile(
    optimizer='adam',
    loss='categorical_crossentropy',
    metrics=['accuracy']
)

# Train for the notebook-level `n_epochs` setting (defined in the data-preparation
# cell) rather than a second hard-coded epoch count that could drift out of sync.
history = qat_model.fit(train_data, validation_data=val_data, epochs=n_epochs)


Instructions for updating:
Lambda fuctions will be no more assumed to be used in the statement where they are used, or at least in the same block. https://github.com/tensorflow/tensorflow/issues/56089
Epoch 1/10
Epoch 2/10
Epoch 3/10
Epoch 4/10
Epoch 5/10
Epoch 6/10
Epoch 7/10
Epoch 8/10
Epoch 9/10
Epoch 10/10


In [9]:
# Print the layer-by-layer structure of the quantization-aware model.
qat_model.summary()
# 4. Evaluate the quantized model on the test split (`test_data`).
#    The model was compiled with metrics=['accuracy'], so evaluate() yields loss and accuracy.
loss, acc = qat_model.evaluate(test_data)
print(f"QAT Model Test Accuracy: {acc:.4f}")

Model: "baseline_model"
_________________________________________________________________
 Layer (type)                Output Shape              Param #   
 input_1 (InputLayer)        [(None, 28, 28, 1)]       0         
                                                                 
 quantize_layer (QuantizeLay  (None, 28, 28, 1)        3         
 er)                                                             
                                                                 
 quant_conv_0 (QuantizeWrapp  (None, 26, 26, 16)       177       
 erV2)                                                           
                                                                 
 quant_bn_conv_0 (QuantizeWr  (None, 26, 26, 16)       65        
 apperV2)                                                        
                                                                 
 quant_conv_act_0 (QuantizeW  (None, 26, 26, 16)       3         
 rapperV2)                                          