In [1]:
import tensorflow as tf
import numpy as np
from tensorflow.keras.datasets import fashion_mnist
from tensorflow.keras.models import Sequential
from tensorflow.keras.layers import Dense, Dropout, Flatten
from tensorflow.keras.applications import ResNet50
from tensorflow.keras.callbacks import EarlyStopping, ModelCheckpoint
from tensorflow.keras.optimizers import Adam
from tensorflow_model_optimization.quantization.keras import quantize_model



In [16]:

# Load the data
(x_train, y_train), (x_test, y_test) = keras.datasets.mnist.load_data()

# Reshape the data to 32x32x3 and normalize it
x_train = tf.image.grayscale_to_rgb(tf.expand_dims(x_train, axis=-1))
x_train = tf.image.resize(x_train, (32, 32))
x_train = x_train / 255.0
x_test = tf.image.grayscale_to_rgb(tf.expand_dims(x_test, axis=-1))
x_test = tf.image.resize(x_test, (32, 32))
x_test = x_test / 255.0


In [None]:
(train_images_rgb, train_labels), (test_images_rgb, test_labels) = (x_train, y_train), (x_test, y_test)
train_labels = keras.utils.to_categorical(train_labels, 10)
test_labels = keras.utils.to_categorical(test_labels, 10)

In [22]:

# Define ResNet50 model
resnet = ResNet50(weights='imagenet', include_top=False, input_shape=(32, 32, 3))

# Add custom layers on top of ResNet50
model = Sequential()
model.add(Flatten(input_shape=resnet.output_shape[1:]))
model.add(Dense(256, activation='relu'))
model.add(Dropout(0.5))
model.add(Dense(10, activation='softmax'))

# Combine ResNet50 and custom layers
model_input = tf.keras.layers.Input(shape=(32, 32, 3))
output = model(resnet(model_input))
model = tf.keras.Model(inputs=model_input, outputs=output)

# Compile model
model.compile(loss='sparse_categorical_crossentropy', optimizer=Adam(learning_rate=0.001), metrics=['accuracy'])

model.summary()

Model: "model_2"
_________________________________________________________________
 Layer (type)                Output Shape              Param #   
 input_7 (InputLayer)        [(None, 32, 32, 3)]       0         
                                                                 
 resnet50 (Functional)       (None, 1, 1, 2048)        23587712  
                                                                 
 sequential_2 (Sequential)   (None, 10)                527114    
                                                                 
Total params: 24,114,826
Trainable params: 24,061,706
Non-trainable params: 53,120
_________________________________________________________________


  super(Adam, self).__init__(name, **kwargs)


In [24]:

# Train model with early stopping and model checkpoint callbacks
es_callback = EarlyStopping(monitor='val_accuracy', patience=5, restore_best_weights=True)
mc_callback = ModelCheckpoint('best_model.h5', monitor='val_accuracy', save_best_only=True)
model.fit(x_train, y_train, validation_split=0.2, epochs=1, batch_size=128, callbacks=[es_callback, mc_callback])

# Load best model from checkpoint
model.load_weights('best_model.h5')

# Fine-tune model using QAT technique
quantize_model(model)
model.compile(loss='sparse_categorical_crossentropy', optimizer=Adam(lr=0.001), metrics=['accuracy'])
model.fit(x_train, y_train, validation_split=0.2, epochs=1, batch_size=128)

 38/375 [==>...........................] - ETA: 31:45 - loss: 0.8297 - accuracy: 0.7455

KeyboardInterrupt: 

In [None]:

# Evaluate model accuracy for different quantization levels
for bits in [16, 4, -1]:
    q_model = quantize_model(model, bits=bits)
    q_model.compile(loss='sparse_categorical_crossentropy', optimizer=Adam(lr=0.0001), metrics=['accuracy'])
    loss, acc = q_model.evaluate(x_test, y_test)
    print(f"Quantized model with {bits} bits : accuracy={acc}")
