# In [None]:
import os
import cv2
import numpy as np
import tensorflow as tf
from tensorflow.keras.models import load_model
from tensorflow.keras.optimizers import Adam
from tensorflow.keras.utils import to_categorical
from glob import glob
import tensorflow_model_optimization as tfmot
from sklearn.utils.class_weight import compute_class_weight
import random
import gc
from tensorflow.keras.preprocessing.image import ImageDataGenerator


def setup_seed(seed):
    """Seed the Python, NumPy and TensorFlow global random generators.

    Args:
        seed: Value handed unchanged to ``random.seed``, ``np.random.seed``
            and ``tf.random.set_seed`` (in that order).
    """
    for seeder in (random.seed, np.random.seed, tf.random.set_seed):
        seeder(seed)

setup_seed(1)

# Release memory and Keras global state (e.g. left over from earlier notebook
# runs) before building the data pipeline and loading the model.
gc.collect()
tf.keras.backend.clear_session()

train_data_dir = 'path/to/train'
validation_data_dir = 'path/to/validation'
test_data_dir = 'path/to/test'

# Batch size shared by every generator created below.
batches = 32

# Training images: rescaled to [0, 1] plus random geometric augmentation.
train_datagen = ImageDataGenerator(
    rescale=1./255,
    rotation_range=40,
    width_shift_range=0.2,
    height_shift_range=0.2,
    shear_range=0.2,
    zoom_range=0.2,
    horizontal_flip=True,
    fill_mode='nearest'
)

# Validation/test images: rescaling only, no augmentation.
validation_datagen = ImageDataGenerator(rescale=1./255)

train_generator = train_datagen.flow_from_directory(
    train_data_dir,
    target_size=(224, 224),
    batch_size=batches,
    class_mode='categorical'
)

validation_generator = validation_datagen.flow_from_directory(
    validation_data_dir,
    target_size=(224, 224),
    batch_size=batches,
    class_mode='categorical'
)

# Held-out test set consumed by the evaluate() calls further down. It uses the
# same rescale-only preprocessing as validation, and a fixed order because
# evaluation gains nothing from shuffling.
test_generator = validation_datagen.flow_from_directory(
    test_data_dir,
    target_size=(224, 224),
    batch_size=batches,
    class_mode='categorical',
    shuffle=False
)

def get_class_names(directory):
    """Return the sorted names of the class sub-directories of *directory*.

    Only directories are considered. This mirrors how
    ``ImageDataGenerator.flow_from_directory`` discovers classes (sorted
    sub-directory names). Stray files such as ``.DS_Store`` would otherwise be
    treated as classes and shift every index after them.

    Args:
        directory: Dataset root containing one sub-directory per class.

    Returns:
        A sorted list of class (sub-directory) names.
    """
    return sorted(
        entry for entry in os.listdir(directory)
        if os.path.isdir(os.path.join(directory, entry))
    )

class_names = get_class_names(train_data_dir)

# Each class maps to its position in the sorted name list.
class_indices = dict(zip(class_names, range(len(class_names))))

# Show the name -> index mapping so it can be checked against the data folders.
for index, class_name in enumerate(class_names):
    print(f"Class Name: {class_name}, Index: {index}")

def load_and_resize_image(image_path, target_size=(224, 224)):
    """Read an image file, resize it and scale its pixels to [0, 1].

    Args:
        image_path: Path of the image file to read.
        target_size: Passed directly to ``cv2.resize``, which interprets it as
            ``(width, height)``. This is the reverse of the Keras
            ``(height, width)`` convention; the two agree for square sizes
            such as the default.

    Returns:
        A float32 array of shape ``(height, width, 3)`` in RGB channel order,
        with values in ``[0, 1]``.

    Raises:
        FileNotFoundError: if OpenCV cannot read *image_path*.
    """
    image = cv2.imread(image_path)
    if image is None:
        # cv2.imread reports failure by returning None rather than raising;
        # without this check the problem surfaces as an obscure cv2.resize error.
        raise FileNotFoundError(f"Could not read image: {image_path}")
    # OpenCV decodes to BGR, whereas the Keras generators this model is trained
    # with yield RGB; convert so both pipelines present the same channel order.
    image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
    image = cv2.resize(image, target_size)
    return image.astype('float32') / 255.0

def process_dataset(directory, batch_size, datagen, target_size=(224, 224)):
    """Endlessly yield ``(images, one_hot_labels)`` batches of PNG images.

    Images are read from ``directory/<class_name>/*.png``. Each image's class
    is the name of its parent directory, mapped to an index through the
    module-level ``class_indices`` and ``class_names``. The file list is
    reshuffled at the start of every pass, so successive epochs see
    different batches.

    Args:
        directory: Root directory containing one sub-directory per class.
        batch_size: Images per yielded batch. The last batch of a pass may be
            smaller.
        datagen: Not used by this function. It is accepted only for signature
            compatibility with existing callers.
        target_size: Forwarded to ``load_and_resize_image``.

    Yields:
        ``(images, labels)``, where ``images`` stacks the
        ``load_and_resize_image`` outputs of the batch and ``labels`` holds
        the matching one-hot rows from ``to_categorical``.

    Raises:
        ValueError: on the first ``next()``, if ``batch_size`` < 1 or no PNG
            images are found. Both cases would otherwise fail obscurely or
            loop forever without yielding.
        KeyError: if an image's parent directory is not in ``class_indices``.
    """
    if batch_size < 1:
        raise ValueError(f"batch_size must be >= 1, got {batch_size}")
    image_paths = glob(os.path.join(directory, '*', '*.png'))
    if not image_paths:
        # With no files the inner range is empty and `while True` would spin
        # forever without ever yielding.
        raise ValueError(f"No '*/*.png' images found under {directory!r}")
    num_classes = len(class_names)
    while True:
        random.shuffle(image_paths)
        for start in range(0, len(image_paths), batch_size):
            batch_paths = image_paths[start:start + batch_size]
            images = [load_and_resize_image(path, target_size) for path in batch_paths]
            # The class name is the immediate parent directory of each file.
            labels = [
                class_indices[os.path.basename(os.path.dirname(path))]
                for path in batch_paths
            ]
            yield np.array(images), to_categorical(labels, num_classes=num_classes)

def get_class_sample_counts(directory, class_indices):
    """Count the ``*.png`` files in each class sub-directory of *directory*.

    Args:
        directory: Dataset root containing one sub-directory per class.
        class_indices: Mapping whose keys are the class names to count.

    Returns:
        A dict ``{class_name: png_count}`` in the iteration order of
        *class_indices*. A missing class directory counts as 0.
    """
    return {
        name: len(glob(os.path.join(directory, name, '*.png')))
        for name in class_indices
    }

class_sample_counts = get_class_sample_counts(train_data_dir, class_indices)
samples_per_class = list(class_sample_counts.values())

# 'balanced' weighting is derived from the label distribution, so y must hold
# one label per training image. samples_per_class follows class_indices'
# iteration order, which keeps it aligned with class_labels.
class_labels = np.array(list(class_indices.values()))
train_labels = np.repeat(class_labels, samples_per_class)
labels_present = np.unique(train_labels)
class_weights = compute_class_weight(class_weight='balanced',
                                     classes=labels_present,
                                     y=train_labels)

# Keras expects a weight for every class index. A class with no training
# images never occurs in y, so its weight has no effect and is left at 1.0.
class_weight_dict = {int(label): 1.0 for label in class_labels}
class_weight_dict.update(
    {int(label): float(weight) for label, weight in zip(labels_present, class_weights)}
)
# NOTE(review): class_weight_dict is not passed to any fit() call in this
# script; add `class_weight=class_weight_dict` there if re-weighting is intended.

# The batch size is defined once, where the data generators are created.
# Re-assigning it here would not change the generators, only the step counts.

model = load_model('path/to/tensorflow/model')  # requires a TensorFlow/Keras model

# Baseline metrics of the unmodified model, for comparison with the pruned and
# PQAT versions below. len(test_generator) covers every image the generator
# found, including a final partial batch. A '*.png' count // batches would skip
# that remainder, ignore other image formats, and give 0 steps for small sets.
test_loss, test_accuracy = model.evaluate(
    test_generator,
    steps=len(test_generator)
)
print(f"Loaded Model Test Loss: {test_loss}")
print(f"Loaded Model Test Accuracy: {test_accuracy}")

def print_model_weights_sparsity(model):
    """Print, for every weight tensor of *model*, the fraction of exact zeros.

    For ``tf.keras.layers.Wrapper`` layers only ``trainable_weights`` are
    inspected, so the wrapper's own non-trainable variables are not reported.
    Weights whose name contains ``"quantize_layer"`` are skipped as auxiliary
    quantization variables.

    Args:
        model: A Keras model exposing ``.layers``.
    """
    for layer in model.layers:
        is_wrapper = isinstance(layer, tf.keras.layers.Wrapper)
        weights = layer.trainable_weights if is_wrapper else layer.weights

        for weight in weights:
            if "quantize_layer" in weight.name:
                continue  # Skip auxiliary quantization weights

            values = weight.numpy()
            total = values.size
            if total == 0:
                # Empty variables have no defined sparsity; avoid dividing by zero.
                print(f"{weight.name}: empty tensor, sparsity undefined")
                continue

            zeros = int(np.count_nonzero(values == 0))
            print(f"{weight.name}: {zeros / total:.2%} sparsity ({zeros}/{total})")

# Pruning phase: wrap the loaded model so low-magnitude weights are masked to
# zero while it is fine-tuned.
prune_low_magnitude = tfmot.sparsity.keras.prune_low_magnitude

# ConstantSparsity: target 50% sparsity starting at training step 0, with the
# pruning masks updated every 100 steps.
pruning_params = {
    'pruning_schedule': tfmot.sparsity.keras.ConstantSparsity(0.5, begin_step=0, frequency=100)
}

# UpdatePruningStep advances the step counter the pruning wrappers read during
# fit(); TFMOT requires it when training a pruned model.
callbacks = [
    tfmot.sparsity.keras.UpdatePruningStep()
]

pruned_model = prune_low_magnitude(model, **pruning_params)

# Small learning rate: this only fine-tunes an already-trained model.
opt = Adam(learning_rate=5e-5)

pruned_model.compile(
    loss='categorical_crossentropy',
    optimizer=opt,
    metrics=['accuracy']
)

# Fine-tune with pruning active. Step counts come from the generators
# themselves (ceil(samples / batch_size)), so every image is used and the
# counts always match the generators' own batch size.
pruned_model.fit(
    train_generator,
    epochs=1,
    steps_per_epoch=len(train_generator),
    validation_data=validation_generator,
    validation_steps=len(validation_generator),
    callbacks=callbacks
)

# Remove the pruning wrappers, leaving plain layers with the pruned weights.
stripped_pruned_model = tfmot.sparsity.keras.strip_pruning(pruned_model)

print_model_weights_sparsity(stripped_pruned_model)

_, pruned_model_accuracy = pruned_model.evaluate(
    test_generator,
    steps=len(test_generator)
)
print(f"Pruned Model Test Accuracy: {pruned_model_accuracy}")

# PQAT (prune-preserving quantization-aware training): annotate the stripped,
# sparse model and apply the 8-bit scheme designed to keep pruned weights at
# zero while training with simulated quantization.
quant_aware_annotate_model = tfmot.quantization.keras.quantize_annotate_model(
    stripped_pruned_model)
pqat_model = tfmot.quantization.keras.quantize_apply(
    quant_aware_annotate_model,
    tfmot.experimental.combine.Default8BitPrunePreserveQuantizeScheme())

# Fine-tuning rate, matching the pruning phase. A step as large as 1e-1 with
# Adam tends to wreck an already-trained network. Adjust as needed.
opt = Adam(learning_rate=5e-5)
pqat_model.compile(
    loss='categorical_crossentropy',
    optimizer=opt,
    metrics=['accuracy']
)

print('Train pqat model:')

# No UpdatePruningStep callback here: strip_pruning removed every pruning
# wrapper, so the callback would find nothing to update. Sparsity is preserved
# by the prune-preserving quantize scheme instead.
pqat_model.fit(
    train_generator,
    epochs=1,
    steps_per_epoch=len(train_generator),
    validation_data=validation_generator,
    validation_steps=len(validation_generator)
)

pqat_loss, pqat_accuracy = pqat_model.evaluate(
    test_generator,
    steps=len(test_generator)
)
print(f"PQAT Model Test Loss: {pqat_loss}")
print(f"PQAT Model Test Accuracy: {pqat_accuracy}")

# Convert the PQAT Keras model to TensorFlow Lite using the converter's default
# settings (no `optimizations` set).
# NOTE(review): the TFMOT QAT guide sets `converter.optimizations =
# [tf.lite.Optimize.DEFAULT]` to obtain a quantized model from a QAT model;
# without it this export is presumably float. Confirm that is intended.
converter = tf.lite.TFLiteConverter.from_keras_model(pqat_model)
tflite_model = converter.convert()

# Save the serialized TensorFlow Lite model (bytes) to a file.
with open('model.tflite', 'wb') as f: #Change the model name to an appropriate name
    f.write(tflite_model)

# Calibration data for full-integer quantization. Keras directory iterators
# never stop on their own, so the number of samples is bounded explicitly.
# Images are rescale-only (no augmentation) so that the recorded activation
# ranges reflect real inputs. One image is fed per step.
num_calibration_samples = 100  # a few hundred is typical; adjust as needed
calibration_generator = validation_datagen.flow_from_directory(
    train_data_dir,
    target_size=(224, 224),
    batch_size=1,
    class_mode=None,  # yield images only; labels are not needed for calibration
    shuffle=True,
    seed=1
)


def representative_data_gen():
    """Yield up to ``num_calibration_samples`` single-image input lists."""
    for _ in range(min(num_calibration_samples, len(calibration_generator))):
        yield [next(calibration_generator)]


# Convert the TensorFlow model to TensorFlow Lite with integer quantization.
converter = tf.lite.TFLiteConverter.from_keras_model(pqat_model)
converter.optimizations = [tf.lite.Optimize.DEFAULT]
converter.target_spec.supported_ops = [tf.lite.OpsSet.TFLITE_BUILTINS_INT8]
converter.inference_input_type = tf.uint8  # or tf.int8
converter.inference_output_type = tf.uint8  # or tf.int8
# Assigned only after representative_data_gen is defined above.
converter.representative_dataset = representative_data_gen

tflite_quant_model = converter.convert()

# Save the quantized model
with open('quant_model.tflite', 'wb') as f:
    f.write(tflite_quant_model)

# Compiling for the Edge TPU requires the Edge TPU Compiler (`edgetpu_compiler`),
# a separate command-line tool with no Python API. Run it from a shell (or a
# notebook `!` cell) on the quantized model, e.g.: edgetpu_compiler quant_model.tflite

