<a href="https://colab.research.google.com/github/svantepihl/Thesis-MaskDetection/blob/master/train_model_final.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

`NOTE!`  
`This notebook needs to be run in a TPU environment`

`The dataset is loaded from a private GCP bucket, but it can be downloaded here:`

[LINK](https://drive.google.com/drive/folders/18UJsRrjrW4lIlKbYhNQbQcrLFoApxymp?usp=sharing)

# Imports

In [None]:
from datetime import datetime
import math
import re
import os
import time
import sys
import numpy as np
import pandas as pd
import sklearn
from matplotlib import pyplot as plt
from sklearn.model_selection import train_test_split
from sklearn.metrics import classification_report

# Installs

In [None]:
# Installs the TensorBoard profiler plugin so the profile captured during training
# (profile_batch in the TensorBoard callback) can be viewed in TensorBoard.
!pip install -U tensorboard_plugin_profile # Tensorboard plugin for profiling performance 

# Authenticate with Google Cloud Platform

In [None]:
# Authenticate the Colab user so the private GCS buckets used below (dataset and
# TensorBoard logs) can be read/written. Skipped when not running inside Colab.
if 'google.colab' in sys.modules:
   from google.colab import auth
   auth.authenticate_user()

# Tensorflow version

In [None]:
# Colab-only magic: selects the TensorFlow 2.x runtime.
%tensorflow_version 2.x
# K (Keras backend) is used by the augmentation helpers (K.dot, K.cast, K.clip).
import tensorflow as tf, tensorflow.keras.backend as K
print("Tensorflow version " + tf.__version__)

In [None]:
# Report total RAM; below 20 GB, suggest switching to Colab's High-RAM runtime shape.
from psutil import virtual_memory

ram_gb = virtual_memory().total / 1e9
print(f'Your runtime has {ram_gb:.1f} gigabytes of available RAM\n')

if ram_gb >= 20:
  print('You are using a high-RAM runtime!')
else:
  print('To enable a high-RAM runtime, select the Runtime > "Change runtime type"',
        'menu, and then select High-RAM in the Runtime shape dropdown. Then, ',
        're-execute this cell.',
        sep='\n')

# TPU config  
Reference: [TPUs in Google Colab](https://colab.research.google.com/notebooks/tpu.ipynb)

In [None]:
# Detect a TPU and build the matching distribution strategy. A ValueError from the
# resolver means no TPU is available; fall back to the default strategy.
try:
    tpu = tf.distribute.cluster_resolver.TPUClusterResolver()  # TPU detection.
    print('Running on TPU ', tpu.master())
except ValueError:
    tpu = None

if not tpu:
    # Default distribution strategy in TensorFlow: works on CPU and a single GPU.
    strategy = tf.distribute.get_strategy()
else:
    tf.config.experimental_connect_to_cluster(tpu)
    tf.tpu.experimental.initialize_tpu_system(tpu)
    strategy = tf.distribute.TPUStrategy(tpu)

print("REPLICAS: ", strategy.num_replicas_in_sync)

# Constants

All constants are capitalized.

In [None]:
#@title New run?

SEED = 1337
# Also seed the global RNGs so unseeded random ops (e.g. dataset shuffles, weight
# initialisation) are reproducible across runs, not only the ops given seed=SEED.
tf.random.set_seed(SEED)
np.random.seed(SEED)

AUTO = tf.data.experimental.AUTOTUNE

IMAGE_WIDTH = 224
IMAGE_HEIGHT = 224
IMAGE_CHANNELS = 3
# Image tensors are laid out [height, width, channels] (see tf.reshape in
# decode_image and the Keras Input shape), so height comes first.
IMAGE_SIZE = [IMAGE_HEIGHT, IMAGE_WIDTH]
IMAGE_SHAPE = [*IMAGE_SIZE, IMAGE_CHANNELS]

EPOCHS = 400

# TRAINING SETTINGS
# 128 images per TPU core; a fixed 64 on CPU / single GPU.
if tpu is not None:
  BATCH_SIZE = 128 * strategy.num_replicas_in_sync
else:
  BATCH_SIZE = 64
VALIDATION_BATCH_SIZE = BATCH_SIZE
TEST_BATCH_SIZE = BATCH_SIZE
AUG_BATCH = BATCH_SIZE

GCS_DATASET_PATTERN = 'gs://facemask-detection-thesis-32-tfrecords-jpeg-224x224/*.tfrec' # GCS bucket where dataset is stored

GCS_LOG_BUCKET = 'gs://facemask-detection-thesis-training-logs/' # To store training logs for tensorboard

NEW_RUN = True #@param {type:"boolean"}
if NEW_RUN:
  now = datetime.now()
  dt_string = now.strftime("%Y-%m-%d_%H")

  # NOTE(review): this path assumes Google Drive is mounted at /content/drive; no
  # drive.mount() call is visible in this notebook, so otherwise the folders are
  # created on the VM's ephemeral local disk.
  RUN_FOLDER = '/content/drive/MyDrive/MaskedFace/Final/RUN-' + dt_string +'/'
  # The folder name only has hour resolution: exist_ok lets a re-run within the same
  # hour reuse the folder instead of failing with FileExistsError.
  os.makedirs(RUN_FOLDER, exist_ok=True)
  print("Created folder: "+ RUN_FOLDER)

  MODEL_FOLDER = RUN_FOLDER + 'Models/'
  os.makedirs(MODEL_FOLDER, exist_ok=True)
  print("Created folder: "+ MODEL_FOLDER)
# With NEW_RUN unchecked, RUN_FOLDER and MODEL_FOLDER must already be defined,
# because later cells write into them.

CLASSES = ['MaskCorrect', 'MaskOnChin', 'MaskOnlyOnMouth', 'NoMask']

# Fraction of the TFRecord files held out for validation.
TRAIN_AND_VALIDATION_SPLIT = 0.20

# Utility functions

In [None]:
def write_vars_to_file(f, **kwargs):
    """Write every keyword argument to the open text file `f`, one `name = value` per line."""
    f.writelines("%s = %s\n" % (key, value) for key, value in kwargs.items())

def get_dataset_labels(dataset):
  """Collect every label of a batched (image, label) dataset into one numpy array."""
  columns = list(zip(*dataset.unbatch()))  # [all_images, all_labels]
  _, label_column = columns
  return np.array(label_column)

def dataset_to_numpy_util(dataset, N):
  """Return up to `N` shuffled (images, labels) from a batched dataset as numpy arrays.

  The dataset is unbatched, shuffled with a buffer of `N` and re-batched to `N`;
  only that first batch is materialized, so fewer than `N` items come back when
  the dataset is smaller.

  Raises:
    ValueError: if the dataset yields no elements at all.
  """
  first_batch = next(iter(dataset.unbatch().shuffle(N).batch(N)), None)
  if first_batch is None:
    raise ValueError("dataset_to_numpy_util: dataset is empty")
  images, labels = first_batch
  return images.numpy(), labels.numpy()

def whole_dataset_to_numpy_util(dataset):
  """Materialize an entire batched (image, label) dataset as an (images, labels) pair of arrays."""
  image_column, label_column = zip(*dataset.unbatch())
  return np.array(image_column), np.array(label_column)

def title_from_label_and_target(label, correct_label, classes=None):
  """Build a plot title comparing a predicted class with the true class.

  Args:
    label: predicted scores or one-hot vector; its argmax is the predicted class.
    correct_label: one-hot ground-truth vector.
    classes: class names indexed by class number; defaults to the global CLASSES.

  Returns:
    (title, correct): title is e.g. "NoMask [True]" or
    "NoMask [False, should be MaskCorrect]"; correct is a boolean.
  """
  if classes is None:
    classes = CLASSES
  predicted = np.argmax(label, axis=-1)  # scores / one-hot -> class number
  expected = np.argmax(correct_label, axis=-1)  # one-hot -> class number
  correct = (predicted == expected)
  if correct:
    title = "{} [{}]".format(classes[predicted], str(correct))
  else:
    title = "{} [{}, should be {}]".format(classes[predicted], str(correct), classes[expected])
  return title, correct

def display_one_grayscale_image(image, title, subplot, red=False):
    """Draw the first channel of `image` in grey into `subplot`; return the next subplot code."""
    plt.subplot(subplot)
    plt.axis('off')
    first_channel = np.asarray(image)[:, :, 0]
    plt.imshow(first_channel, cmap='gray', vmin=0, vmax=255)  # grey scale over 0-255
    title_color = 'red' if red else 'black'
    plt.title(title, fontsize=16, color=title_color)
    return subplot + 1

def display_one_image(image, title, subplot, red=False):
    """Draw `image` into `subplot` with `title` (red when `red`); return the next subplot code."""
    plt.subplot(subplot)
    plt.axis('off')
    # cmap/vmin/vmax only affect single-channel data; matplotlib draws RGB images as-is.
    plt.imshow(image, cmap='gray', vmin=0, vmax=255)
    plt.title(title, fontsize=16, color=('red' if red else 'black'))
    return subplot + 1
  
def display_9_images_from_dataset(dataset, grayscale = False):
  """Show 9 randomly drawn images of a batched (image, one-hot label) dataset in a 3x3 grid.

  With `grayscale` set, only the first channel of each image is drawn, in grey.
  """
  draw_image = display_one_grayscale_image if grayscale else display_one_image
  next_subplot = 331
  plt.figure(figsize=(13,13))
  images, labels = dataset_to_numpy_util(dataset, 9)
  # zip with range(9) stops after at most nine images.
  for i, image in zip(range(9), images):
    class_name = CLASSES[np.argmax(labels[i], axis=-1)]
    next_subplot = draw_image(image, class_name, next_subplot)

  plt.tight_layout()
  plt.subplots_adjust(wspace=0.1, hspace=0.1)
  plt.show()
  
def display_9_images_with_predictions(images, predictions, labels):
  """Show up to 9 images titled with predicted vs. true class; wrong predictions in red."""
  plt.figure(figsize=(13,13))
  next_subplot = 331
  for i, image in zip(range(9), images):
    title, correct = title_from_label_and_target(predictions[i], labels[i])
    next_subplot = display_one_image(image, title, next_subplot, not correct)

  plt.tight_layout()
  plt.subplots_adjust(wspace=0.1, hspace=0.1)
  plt.show()
  
def display_training_curves(training, validation, title, subplot):
  """Plot one metric's training and validation curves in panel `subplot`.

  A subplot code ending in 1 (e.g. 211) marks the first panel and opens a new
  10x10 figure before drawing.
  """
  if subplot % 10 == 1:
    plt.subplots(figsize=(10,10), facecolor='#F0F0F0')
    plt.tight_layout()
  ax = plt.subplot(subplot)
  ax.set_facecolor('#F8F8F8')
  for series in (training, validation):
    ax.plot(series)
  ax.set_title('model ' + title)
  ax.set_ylabel(title)
  ax.set_xlabel('epoch')
  ax.legend(['train', 'valid.'])

# Augmentations

## Rotate, shear, shift, zoom

In [None]:
def get_mat(rotation, shear, height_zoom, width_zoom, height_shift, width_shift):
    """Build the 3x3 matrix used to map destination pixel indices back onto source indices.

    Args:
        rotation, shear: angles in degrees (converted to radians below).
        height_zoom, width_zoom: zoom factors; the zoom matrix holds their reciprocals.
        height_shift, width_shift: translations, placed in the third column of the
            shift matrix.
        All six are assumed to be shape-[1] float32 tensors, since they are concatenated
        with shape-[1] float32 constants (rotate_shear_shift_zoom passes such tensors).

    Returns:
        A [3, 3] tensor equal to rotation @ shear @ zoom @ shift, meant to multiply
        homogeneous (x, y, 1) index vectors.
    """
    # returns 3x3 transform matrix which transforms indices

    # CONVERT DEGREES TO RADIANS
    rotation = math.pi * rotation / 180.
    shear = math.pi * shear / 180.

    # ROTATION MATRIX
    c1 = tf.math.cos(rotation)
    s1 = tf.math.sin(rotation)
    one = tf.constant([1],dtype='float32')
    zero = tf.constant([0],dtype='float32')
    rotation_matrix = tf.reshape( tf.concat([c1,s1,zero, -s1,c1,zero, zero,zero,one],axis=0),[3,3] )

    # SHEAR MATRIX
    c2 = tf.math.cos(shear)
    s2 = tf.math.sin(shear)
    shear_matrix = tf.reshape( tf.concat([one,s2,zero, zero,c2,zero, zero,zero,one],axis=0),[3,3] )

    # ZOOM MATRIX (reciprocal factors because the matrix maps destination -> source)
    zoom_matrix = tf.reshape( tf.concat([one/height_zoom,zero,zero, zero,one/width_zoom,zero, zero,zero,one],axis=0),[3,3] )

    # SHIFT MATRIX (translation in the homogeneous column)
    shift_matrix = tf.reshape( tf.concat([one,zero,height_shift, zero,one,width_shift, zero,zero,one],axis=0),[3,3] )

    return K.dot(K.dot(rotation_matrix, shear_matrix), K.dot(zoom_matrix, shift_matrix))

In [None]:
def rotate_shear_shift_zoom(image,label):
    """Randomly rotate, shear, zoom and shift one square image; the label passes through.

    Args:
        image: a single [DIM, DIM, 3] image (not a batch), DIM = IMAGE_SIZE[0]; the
            image is assumed square because DIM is used for both axes.
        label: returned unchanged.

    Uses inverse mapping: each destination pixel's centred index is multiplied by the
    get_mat() matrix to find a source index, whose value is gathered. Source indices
    are clipped to the image, so destinations that map outside the frame repeat
    border pixels.
    """
    # input image - is one image of size [dim,dim,3] not a batch of [b,dim,dim,3]
    # output - image randomly rotated, sheared, zoomed, and shifted
    DIM = IMAGE_SIZE[0]
    XDIM = DIM%2 # 1 for odd sizes (e.g. 331); adjusts the lower clip bound below

    # One random draw per parameter: rotation ~ N(0, 14^2) degrees, shear ~ N(0, 5^2)
    # degrees, zoom ~ 1 + N(0, 0.1^2), shift ~ N(0, 12^2).
    rot = 14. * tf.random.normal([1],dtype='float32')
    shr = 5. * tf.random.normal([1],dtype='float32') 
    h_zoom = 1.0 + tf.random.normal([1],dtype='float32')/10.
    w_zoom = 1.0 + tf.random.normal([1],dtype='float32')/10.
    h_shift = 12. * tf.random.normal([1],dtype='float32') 
    w_shift = 12. * tf.random.normal([1],dtype='float32') 
  
    # GET TRANSFORMATION MATRIX
    m = get_mat(rot,shr,h_zoom,w_zoom,h_shift,w_shift) 

    # LIST DESTINATION PIXEL INDICES
    # Centred integer coordinates of every destination pixel (x repeated per row,
    # y tiled across columns); z = 1 makes them homogeneous for the 3x3 matrix.
    x = tf.repeat( tf.range(DIM//2,-DIM//2,-1), DIM )
    y = tf.tile( tf.range(-DIM//2,DIM//2),[DIM] )
    z = tf.ones([DIM*DIM],dtype='int32')
    idx = tf.stack( [x,y,z] )
    
    

    # ROTATE DESTINATION PIXELS ONTO ORIGIN PIXELS
    idx2 = K.dot(m,tf.cast(idx,dtype='float32'))
    idx2 = K.cast(idx2,dtype='int32')
    # Keep source coordinates inside the image (out-of-frame -> border pixel).
    idx2 = K.clip(idx2,-DIM//2+XDIM+1,DIM//2)
    
    # FIND ORIGIN PIXEL VALUES
    # Convert centred coordinates back to [row, col] array indices and gather.
    idx3 = tf.stack( [DIM//2-idx2[0,], DIM//2-1+idx2[1,]] )
    d = tf.gather_nd(image,tf.transpose(idx3))
    
    image = tf.reshape(d,[DIM,DIM,3])
    return image,label

## Image augmentations

In [None]:
def img_augment(image, one_hot_class):
    """Apply random colour jitter and a random horizontal flip to one image.

    The label is returned unchanged.

    Each random op gets its own seed. Random ops with the same seed traced into one
    graph (as tf.data's map does) produce the same underlying uniform draw. Sharing
    seed=SEED therefore made the hue shift, saturation factor, contrast factor and
    flip decision perfectly correlated instead of independent.
    """
    image = tf.image.random_hue(image, 0.05, seed=SEED)
    image = tf.image.random_saturation(image, 0.6, 1.5, seed=SEED + 1)
    image = tf.image.random_contrast(image, 0.7, 1.3, seed=SEED + 2)
    image = tf.image.random_flip_left_right(image, seed=SEED + 3)
    image = tf.image.random_brightness(image, 0.3, seed=SEED + 4)
    return image, one_hot_class

# Dataset functions

In [None]:
def decode_image(image_data):
    """Decode JPEG bytes into a 3-channel image tensor of static shape [*IMAGE_SIZE, 3].

    Pixel values are left unscaled; scaling happens inside the model. The explicit
    reshape gives the tensor the fully static shape that TPUs need.
    """
    decoded = tf.image.decode_jpeg(image_data, channels=3)
    return tf.reshape(decoded, [*IMAGE_SIZE, 3])

In [None]:
def read_tfrecord(example):
  """Parse one serialized TFRecord example into (image, one_hot_class).

  The record's integer "class" feature is parsed but not returned, because
  training uses the one-hot vector.
  """
  features = {
    "image": tf.io.FixedLenFeature([], tf.string),  # JPEG bytes (tf.string = bytestring)
    "class": tf.io.FixedLenFeature([], tf.int64),   # scalar class index (not returned)
    "one_hot_class": tf.io.VarLenFeature(tf.float32),
  }
  example = tf.io.parse_single_example(example, features)
  image = decode_image(example['image'])
  # VarLenFeature parses to a sparse tensor: densify it and pin its static length to
  # the number of classes so batched labels have a known shape.
  one_hot_class = tf.reshape(tf.sparse.to_dense(example['one_hot_class']), [len(CLASSES)])
  return image, one_hot_class

In [None]:
def load_dataset(filenames, ordered = False):
    """Stream (image, one_hot_class) pairs from the given TFRecord files.

    Files are read in parallel. Unless `ordered` is True, deterministic ordering is
    disabled so elements are used as soon as they arrive. Order is irrelevant for
    training because the training pipeline shuffles anyway.
    """
    options = tf.data.Options()
    if not ordered:
        options.experimental_deterministic = False  # faster, order not preserved

    records = tf.data.TFRecordDataset(filenames, num_parallel_reads=AUTO)
    return (records
            .with_options(options)
            .map(read_tfrecord, num_parallel_calls=AUTO))

In [None]:
def get_training_dataset(dataset):
    """Turn the raw training dataset into the infinite, augmented, shuffled, batched pipeline."""
    return (dataset
            .repeat()  # infinite; fit() bounds each epoch with steps_per_epoch
            .map(img_augment, num_parallel_calls=AUTO)
            .map(rotate_shear_shift_zoom, num_parallel_calls=AUTO)
            .shuffle(2048)
            .batch(BATCH_SIZE)
            .prefetch(AUTO))  # prepare the next batch while the current one trains

In [None]:
def get_validation_dataset(dataset):
    """Batch the validation data, cache it after the first pass and prefetch batches."""
    batched = dataset.batch(VALIDATION_BATCH_SIZE)
    cached = batched.cache()
    return cached.prefetch(AUTO)

# Load data from GCS

In [None]:
%%time
filenames = tf.io.gfile.glob(GCS_DATASET_PATTERN)

# Split into train/validation

In [None]:
# Split at TFRecord-file level. random_state makes the split reproducible, so re-running
# the notebook keeps the same files in the validation set.
TRAIN_FILENAMES, VALIDATION_FILENAMES = train_test_split(
    filenames, test_size=TRAIN_AND_VALIDATION_SPLIT, random_state=SEED)

# Summarise dataset

In [None]:
def count_data_items(filenames):
    """Total number of examples across TFRecord files.

    The count is encoded in each file's base name, e.g. ``00-2000.tfrec`` holds 2000
    items. Only the base name is parsed, so dashes and digits in the bucket or
    directory part of the path cannot be picked up by mistake.

    Returns:
        An int (0 for an empty list).

    Raises:
        ValueError: if a file name does not contain ``-<count>.``.
    """
    total = 0
    for filename in filenames:
        basename = filename.rsplit('/', 1)[-1]
        match = re.search(r"-([0-9]+)\.", basename)
        if match is None:
            raise ValueError("Cannot read item count from file name: " + filename)
        total += int(match.group(1))
    return total

In [None]:
NUM_TRAINING_IMAGES = count_data_items(TRAIN_FILENAMES)
NUM_VALIDATION_IMAGES = count_data_items(VALIDATION_FILENAMES)

# The training pipeline repeats forever, so Keras needs an explicit number of steps
# per epoch; images past the last full batch simply continue into the next epoch.
STEPS_PER_EPOCH = NUM_TRAINING_IMAGES // BATCH_SIZE
if STEPS_PER_EPOCH == 0:
  raise ValueError("Fewer training images ({}) than BATCH_SIZE ({})".format(NUM_TRAINING_IMAGES, BATCH_SIZE))
TRAIN_STEPS = STEPS_PER_EPOCH  # alias kept for compatibility; same value

# The train/validation split partitions the files, so the total is their sum.
print("TOTAL IMAGES: ", int(NUM_TRAINING_IMAGES + NUM_VALIDATION_IMAGES))
print("TRAINING IMAGES: ", int(NUM_TRAINING_IMAGES), ", STEPS PER EPOCH: ", int(STEPS_PER_EPOCH))
print("VALIDATION IMAGES ",int(NUM_VALIDATION_IMAGES))

# Callbacks

## Learning rate

In [None]:
# Learning rate settings (consumed by learning_rate_fn and saved to LR.txt)
start_lr = 1e-15
min_lr = 1e-15
# The peak LR scales with the replica count; 8 replicas are assumed off-TPU.
max_lr = 0.00005 * (strategy.num_replicas_in_sync if tpu is not None else 8)
rampup_epochs = 200
sustain_epochs = 50
exp_decay = .9

# Learning rate function
def learning_rate_fn(epoch):
    """Per-epoch learning rate: linear ramp-up, constant plateau, then exponential decay.

    Reads the module-level start_lr, min_lr, max_lr, rampup_epochs, sustain_epochs and
    exp_decay at call time.
    """
    if epoch < rampup_epochs:
        # Linear ramp from start_lr at epoch 0 towards max_lr.
        return (max_lr - start_lr)/rampup_epochs * epoch + start_lr
    if epoch < rampup_epochs + sustain_epochs:
        return max_lr
    decay_epochs = epoch - rampup_epochs - sustain_epochs
    return (max_lr - min_lr) * exp_decay**decay_epochs + min_lr
    


# Plot the learning-rate schedule over all epochs (computed once, then plotted).
rng = list(range(EPOCHS))
y = [learning_rate_fn(x) for x in rng]
lr_fig, lr_ax = plt.subplots()
lr_ax.plot(rng, y)
lr_ax.set(title='Learning rate schedule', xlabel='epoch', ylabel='learning rate')
print(y[0], y[-1])

### Save learning rate settings 

In [None]:
# Record this run's LR schedule parameters in the run folder. The context manager
# closes the file even if writing fails; append mode keeps entries from earlier
# runs written into the same folder.
with open(RUN_FOLDER + 'LR.txt', mode='a+') as lr_file:
  write_vars_to_file(lr_file,
                     start_lr=start_lr,
                     min_lr=min_lr,
                     max_lr=max_lr,
                     rampup_epochs=rampup_epochs,
                     sustain_epochs=sustain_epochs,
                     exp_decay=exp_decay)


In [None]:
def get_learning_rate():
  """Keras callback that sets the learning rate from learning_rate_fn every epoch (verbose)."""
  def schedule(epoch):
    return learning_rate_fn(epoch)
  return tf.keras.callbacks.LearningRateScheduler(schedule, verbose=True)

## Early stopping

In [None]:
def get_earlystopping_callback(epoch_patience = 60):
  """EarlyStopping on val_loss: stop after `epoch_patience` epochs without improvement."""
  settings = dict(monitor='val_loss', patience=epoch_patience, verbose=True)
  return tf.keras.callbacks.EarlyStopping(**settings)

## Model checkpoint

In [None]:
def get_checkpoint_callback(model_name):
  """ModelCheckpoint writing MODEL_FOLDER + model_name whenever val_loss hits a new best."""
  return tf.keras.callbacks.ModelCheckpoint(
      MODEL_FOLDER + model_name,
      monitor='val_loss',
      verbose=1,
      save_best_only=True,
  )

## Tensorboard callback

In [None]:
def get_tensorboard_callback():
  """TensorBoard callback logging to a fresh timestamped folder under GCS_LOG_BUCKET.

  Weight histograms are written every epoch and batches 20-50 are profiled.
  """
  run_dir = '{}{}'.format(GCS_LOG_BUCKET, datetime.now().strftime("%Y%m%d-%H%M%S"))
  return tf.keras.callbacks.TensorBoard(
      log_dir=run_dir,
      histogram_freq=1,
      profile_batch=(20, 50),
  )

# Load and start Tensorboard

In [None]:
# Load the TensorBoard notebook extension, which provides the %tensorboard magic used below.
%load_ext tensorboard

In [None]:
# Get the TPU profiling service address (port 8466 instead of the TPU's 8470). It is
# needed to capture profile information with TensorBoard. Only possible on a Colab TPU
# runtime: without a TPU, `tpu` is None and COLAB_TPU_ADDR is not set.
if tpu is not None:
  service_addr = tpu.master().replace(':8470', ':8466')
  tpu_worker = os.environ['COLAB_TPU_ADDR'].replace('8470', '8466')
  print(tf.profiler.experimental.client.monitor(tpu_worker,1))
  print(tpu_worker)
else:
  print('No TPU detected: skipping TPU profiler setup.')

In [None]:
# Launch TensorBoard.
%tensorboard --logdir=gs://facemask-detection-thesis-training-logs/  

# Create model

In [None]:
def create_model():
  """Build and compile the mask classifier: a MobileNetV3Small backbone plus a small dense head.

  The model takes images of IMAGE_SHAPE with unscaled pixel values (scaling is done
  in-graph by mobilenet_v3.preprocess_input). It outputs softmax probabilities over
  CLASSES and is compiled with Adam and categorical cross-entropy on one-hot labels.
  """
  # ImageNet-pretrained backbone without its classification head.
  base_model = tf.keras.applications.MobileNetV3Small(
    input_shape=IMAGE_SHAPE,
    minimalistic=True,
    include_top=False,
    weights='imagenet'
  )

  inputs = tf.keras.Input(shape=IMAGE_SHAPE)  # named `inputs` to avoid shadowing builtin input()
  x = tf.keras.applications.mobilenet_v3.preprocess_input(inputs)
  x = base_model(x)
  x = tf.keras.layers.GlobalAveragePooling2D()(x)
  x = tf.keras.layers.Dense(128, activation='relu')(x)
  x = tf.keras.layers.Dropout(0.2)(x)
  # One output unit per class name, so the head stays in sync with CLASSES.
  outputs = tf.keras.layers.Dense(len(CLASSES), activation='softmax')(x)

  model = tf.keras.Model(inputs, outputs)

  model.compile(
      optimizer='adam',
      loss='categorical_crossentropy',  # labels are one-hot vectors
      metrics=['categorical_accuracy']
  )
  return model

In [None]:
def create_tpu_model():
  """Create and compile the model inside the active distribution strategy's scope."""
  scope = strategy.scope()
  with scope:
    model = create_model()
  return model

## Save model architecture

In [None]:
# Build the model once (outside the strategy scope) to print its summary, then save
# its architecture as JSON (no weights) in the run folder and free it again.
model = create_model()
model.summary()
json_model = model.to_json()
with open(RUN_FOLDER+'model.json', 'w') as json_file:
    json_file.write(json_model)
del model, json_model

# Training

In [None]:
def train_model():
  """Train a fresh model under the distribution strategy and return (model, history).

  After fit() finishes, the weights saved by the checkpoint callback (best val_loss)
  are loaded back into the returned model.
  """
  # Early stopping (get_earlystopping_callback) is deliberately not used for this run.
  checkpoint_name = 'model_checkpoint.h5'
  callbacks = [
      get_tensorboard_callback(),
      get_learning_rate(),
      get_checkpoint_callback(checkpoint_name),
  ]

  train_dataset = load_dataset(TRAIN_FILENAMES)
  val_dataset = load_dataset(VALIDATION_FILENAMES)

  model = create_tpu_model()
  history = model.fit(
      get_training_dataset(train_dataset),
      validation_data=get_validation_dataset(val_dataset),
      steps_per_epoch=STEPS_PER_EPOCH,
      epochs=EPOCHS,
      callbacks=callbacks,
  )

  print('Load best weights for model prediction')
  model.load_weights(MODEL_FOLDER + checkpoint_name)
  return model, history

In [None]:
# Run the full training loop; returns the best-checkpoint model and its Keras History.
final_model, history  = train_model()

In [None]:
# Save the trained (best-checkpoint) model as a single .h5 file in the run's Models folder.
saved_model_path = MODEL_FOLDER + 'final_model.h5'

final_model.save(saved_model_path)

In [None]:
# Training vs. validation curves for accuracy and loss. The display_training_curves
# helper from the utility section opens its own 10x10 figure when given subplot 211,
# so it is reused here rather than redefined (a second definition would shadow it).
display_training_curves(history.history['categorical_accuracy'], history.history['val_categorical_accuracy'], 'categorical_accuracy', 211)
display_training_curves(history.history['loss'], history.history['val_loss'], 'loss', 212)
plt.tight_layout()  # after both panels exist, so titles and labels do not overlap
plt.show()

In [None]:
# Rebuild the model outside strategy.scope() and copy the trained weights into it
# (get_weights() pulls them off the TPU replicas). The TFLite converters below work
# on this plain, non-distributed copy.
restored_model = create_model();
restored_model.set_weights(final_model.get_weights()) # this copies the weights from TPU


# Convert to TFLite Model (Not quantized)

In [None]:
# Float TFLite model (no optimizations set), written next to the Keras model.
converter = tf.lite.TFLiteConverter.from_keras_model(restored_model)
tflite_model = converter.convert()

tflite_path = MODEL_FOLDER + "final_model.tflite"
with open(tflite_path, 'wb') as f:
  f.write(tflite_model)

# Convert to TFLite model (quantized)

In [None]:
# Un-batched (image, one_hot_class) stream from the training files, used only as
# calibration data for the post-training quantization below.
quant_dataset = load_dataset(TRAIN_FILENAMES)

In [None]:
def representative_data_gen():
  """Yield up to 200 single-image float32 batches to calibrate post-training quantization.

  Each item is a one-element list holding a [1, H, W, 3] array, the format expected
  by the TFLite converter's representative_dataset.
  """
  # NOTE(review): create_model() already applies mobilenet_v3.preprocess_input in-graph;
  # applying it again here only matches inference inputs because Keras documents it as a
  # pass-through for MobileNetV3 — confirm for the TF version in use.
  for image_batch, _ in quant_dataset.batch(1).take(200):
    as_float32 = image_batch.numpy().astype(np.float32)
    yield [tf.keras.applications.mobilenet_v3.preprocess_input(as_float32)]

# Post-training integer quantization of the restored model, calibrated on
# representative_data_gen, with uint8 model input and output.
converter = tf.lite.TFLiteConverter.from_keras_model(restored_model)
converter.experimental_new_converter = True
# This enables quantization
converter.optimizations = [tf.lite.Optimize.DEFAULT]
# This sets the representative dataset used to calibrate activation ranges
converter.representative_dataset = representative_data_gen
# Prefer int8 kernels. Because TFLITE_BUILTINS is also listed, ops without an int8
# implementation fall back to float kernels; conversion does NOT fail on them.
converter.target_spec.supported_ops = [tf.lite.OpsSet.TFLITE_BUILTINS_INT8,tf.lite.OpsSet.TFLITE_BUILTINS]
# For full integer quantization, though supported types defaults to int8 only, we explicitly declare it for clarity.
converter.target_spec.supported_types = [tf.int8]
# These set the input and output tensors to uint8 (added in r2.3)
converter.inference_input_type = tf.uint8
converter.inference_output_type = tf.uint8
tflite_quant_model = converter.convert()

tflite_quant_path = MODEL_FOLDER + "final_model_quant.tflite"
with open(tflite_quant_path, 'wb') as f:
  f.write(tflite_quant_model)

# Compare quantized vs non-quantized

Load a shuffled sample of the validation set on which to compare the two models.

In [None]:
# Batched validation pipeline. At most 5000 shuffled images are pulled into memory, so
# on a larger validation set the comparison runs on a random subset.
validation_dataset = get_validation_dataset(load_dataset(VALIDATION_FILENAMES))

batch_images, batch_labels = dataset_to_numpy_util(validation_dataset,5000)
# No preprocess_input needed here: create_model() applies it inside the model graph.

In [None]:
# Accuracy of the float Keras model on the sampled images. predict() runs in
# mini-batches; calling the model directly on all ~5000 images would build every
# activation for the whole set at once and can exhaust memory.
logits = restored_model.predict(batch_images, verbose=0)  # softmax probabilities, not raw logits
prediction = np.argmax(logits, axis=1)
truth = np.argmax(batch_labels, axis=1)

keras_accuracy = tf.keras.metrics.Accuracy()
keras_accuracy(truth, prediction)  # (y_true, y_pred)

print("Raw model accuracy: {:.3%}".format(keras_accuracy.result()))

In [None]:
def set_input_tensor(interpreter, input):
  """Write one image into the interpreter's first input tensor, quantizing it if needed.

  Args:
    interpreter: an allocated tf.lite.Interpreter (anything providing
      get_input_details() and tensor()).
    input: a single image without the batch axis, in raw pixel values. The name
      shadows the builtin but is kept for keyword callers.

  If the input tensor is quantized (non-zero scale, e.g. the uint8 input of the
  quantized model), values are mapped with round(x / scale + zero_point). The
  result is clipped to the tensor dtype's range. With scale 1 and zero point 0 this
  is a plain copy. Float inputs (scale 0) are copied unchanged.
  """
  input_details = interpreter.get_input_details()[0]
  # View of the input buffer without the batch axis; writing into it sets the input.
  input_tensor = interpreter.tensor(input_details['index'])()[0]
  scale, zero_point = input_details['quantization']
  if scale:
    dtype = input_details['dtype']
    limits = np.iinfo(dtype)
    quantized = np.round(np.asarray(input, dtype=np.float32) / scale + zero_point)
    input_tensor[:, :] = np.clip(quantized, limits.min, limits.max).astype(dtype)
  else:
    input_tensor[:, :] = input

def classify_image(interpreter, input):
  """Run the TFLite interpreter on one image and return the index of the top class.

  A quantized output (non-zero scale) is dequantized as scale * (q - zero_point)
  before the argmax. The output is cast to float32 first, because subtracting the
  zero point directly from a uint8 array wraps around below 0 and can change which
  class wins. A float output (scale 0) is used as-is; multiplying it by the zero
  scale would make every image classify as class 0.
  """
  set_input_tensor(interpreter, input)
  interpreter.invoke()
  output_details = interpreter.get_output_details()[0]
  output = interpreter.get_tensor(output_details['index']).astype(np.float32)
  scale, zero_point = output_details['quantization']
  if scale:
    output = scale * (output - zero_point)
  top_1 = np.argmax(output)
  return top_1

interpreter = tf.lite.Interpreter(tflite_quant_path)
interpreter.allocate_tensors()

# Ground truth as class indices, and the quantized model's prediction for every image.
batch_truth = np.argmax(batch_labels, axis=1)
batch_prediction = [classify_image(interpreter, image) for image in batch_images]

# Fraction of predictions matching the ground truth.
tflite_accuracy = tf.keras.metrics.Accuracy()
tflite_accuracy(batch_prediction, batch_truth)
print("Quant TF Lite accuracy: {:.3%}".format(tflite_accuracy.result()))
