In [None]:
import os
import numpy as np
from PIL import Image
import tensorflow as tf
from google.cloud import storage
from io import BytesIO
from google.colab import auth
from sklearn.model_selection import train_test_split
import shutil
import glob
import seaborn as sns
import matplotlib.pyplot as plt
from sklearn.metrics import confusion_matrix, classification_report

# Authenticate this Colab session for Google Cloud access, then mount Google
# Drive (the model files referenced below live under /content/drive).
auth.authenticate_user()
from google.colab import drive
drive.mount('/content/drive')

# NOTE(review): tensorflow was already imported at the top of this cell, so the
# running kernel keeps the version it loaded; the upgraded package only takes
# effect after a runtime restart. The version is also unpinned.
!pip install --upgrade tensorflow

# Google Cloud project / bucket that hold the dataset, and the prefix inside
# the bucket that is mirrored locally.
project_id = "mushroom-master-136c0"
bucket_name = "mushroom-master-central"
source_directory = "cleaned_dataset/"


# NOTE(review): `client` and `bucket` are not referenced again in the visible
# cells (the copy below goes through gsutil) — confirm whether they are needed.
client = storage.Client(project=project_id)
bucket = client.bucket(bucket_name)

# Mirror gs://<bucket_name>/<source_directory> onto the local disk of the VM.
!mkdir -p /content/local_data/cleaned_dataset
!gsutil -m rsync -r gs://{bucket_name}/{source_directory} /content/local_data/cleaned_dataset

# Same directory that the rsync above populates.
local_data_dir = "/content/local_data/cleaned_dataset"

# Saved Keras model and its TFLite export (both on the mounted Drive), plus the
# batch size and the (height, width) images are resized to by the input pipeline.
model_path  = '/content/drive/MyDrive/mushroom_masterv1_ftv1.4.keras'
tflite_path ='/content/drive/MyDrive/mushroom_master_mobilev1.4.tflite'
batch_size  = 128
image_size  = (224, 224)

def remove_hidden_files(folder_path):
    """Recursively delete every dot-prefixed file and directory under ``folder_path``.

    Args:
        folder_path: Root directory to clean. The root itself is never removed.

    Hidden directories are removed together with everything inside them.
    """
    for root, dirs, files in os.walk(folder_path):
        hidden_dirs = [d for d in dirs if d.startswith('.')]
        # Prune in place so os.walk does not try to descend into directories
        # that are about to be deleted (it would otherwise hit a missing path
        # and only "work" because walk swallows scandir errors by default).
        dirs[:] = [d for d in dirs if not d.startswith('.')]
        for d in hidden_dirs:
            shutil.rmtree(os.path.join(root, d))
        for f in files:
            if f.startswith('.'):
                os.remove(os.path.join(root, f))

def prune_empty_classes(folder_path, min_count=10):
    """Delete class sub-directories that hold fewer than ``min_count`` images.

    Args:
        folder_path: Directory whose immediate sub-directories are the classes.
        min_count: Minimum number of image files a class must contain to be kept.

    Only entries directly inside each class directory are counted (no
    recursion); an entry counts when its name ends in .jpg, .jpeg or .png,
    compared case-insensitively. Non-directory entries of ``folder_path`` are
    left alone.
    """
    image_suffixes = (".jpg", ".jpeg", ".png")
    for entry in os.listdir(folder_path):
        candidate = os.path.join(folder_path, entry)
        if not os.path.isdir(candidate):
            continue
        image_total = sum(
            1 for name in os.listdir(candidate) if name.lower().endswith(image_suffixes)
        )
        if image_total < min_count:
            shutil.rmtree(candidate)

def get_images_and_labels_local(local_dir):
    """Index every image under ``local_dir`` together with its class label.

    Each non-hidden immediate sub-directory of ``local_dir`` is one class;
    classes are numbered by their position in the sorted list of names. Images
    (.jpg/.jpeg/.png, case-insensitive, not dot-prefixed) are collected
    recursively from each class directory.

    Directories and files are visited in sorted order. ``os.walk`` otherwise
    yields entries in filesystem order, which varies between machines and would
    make a seeded ``train_test_split`` on the result select different samples.

    Args:
        local_dir: Root directory containing one sub-directory per class.

    Returns:
        Tuple ``(paths, labels, class_names, class_to_index)`` where ``paths``
        is an array of image paths relative to ``local_dir``, ``labels`` is the
        matching array of class indices, ``class_names`` is the sorted list of
        class directory names and ``class_to_index`` maps name -> index.
    """
    image_extensions = (".jpg", ".jpeg", ".png")
    class_names = sorted(
        d for d in os.listdir(local_dir)
        if os.path.isdir(os.path.join(local_dir, d)) and not d.startswith('.')
    )
    class_to_index = {name: idx for idx, name in enumerate(class_names)}
    paths = []
    labels = []
    for class_name in class_names:
        class_path = os.path.join(local_dir, class_name)
        for root, dirs, files in os.walk(class_path):
            # In-place sort fixes the order in which walk descends.
            dirs.sort()
            rel_root = os.path.relpath(root, local_dir)
            for f in sorted(files):
                if f.lower().endswith(image_extensions) and not f.startswith('.'):
                    paths.append(os.path.join(rel_root, f))
                    labels.append(class_to_index[class_name])
    return np.array(paths), np.array(labels), class_names, class_to_index

# Strip dot-prefixed files/directories from the local copy, then drop any class
# folder that contains no image files at all (min_count=1).
remove_hidden_files(local_data_dir)
prune_empty_classes(local_data_dir, min_count=1)

(
    all_image_paths,
    all_labels,
    class_names,
    class_to_index,
) = get_images_and_labels_local(local_data_dir)

# Hold out a stratified 10% test split; the fixed seed makes the split
# repeatable for a given ordering of the inputs.
train_val_paths, test_paths, train_val_labels, test_labels = train_test_split(
    all_image_paths,
    all_labels,
    test_size=0.1,
    stratify=all_labels,
    shuffle=True,
    random_state=42,
)

def load_and_preprocess_image(path, label):
    """Load one image, resize it to ``image_size`` and scale it to [0, 1].

    Args:
        path: String tensor with the image path relative to ``local_data_dir``.
        label: Class index, passed through unchanged.

    Returns:
        ``(img, label)`` where ``img`` is a float32 tensor of shape
        ``(image_size[0], image_size[1], 3)``.

    An image whose bytes cannot be decoded is replaced by an all-zero image.
    The decode runs inside ``tf.py_function`` because this function is traced
    into a graph by ``Dataset.map``: a plain Python ``try/except`` around graph
    ops only runs at trace time and never sees decode errors, which are raised
    later while the dataset is iterated.
    """
    def _decode_or_blank(file_path):
        # Executed eagerly by tf.py_function, so the exception is catchable here.
        try:
            image_raw = tf.io.read_file(file_path)
            return tf.image.decode_image(image_raw, channels=3, expand_animations=False)
        except tf.errors.InvalidArgumentError:
            # dummy image in case of error
            return tf.zeros((image_size[0], image_size[1], 3), dtype=tf.uint8)

    # Build the path from the shared config value instead of repeating the literal.
    full_path = tf.strings.join([local_data_dir, path], separator="/")
    img = tf.py_function(_decode_or_blank, [full_path], Tout=tf.uint8)
    # py_function outputs carry no static shape; restore rank and channel count.
    img.set_shape((None, None, 3))

    img = tf.image.resize(img, image_size)
    img = tf.cast(img, tf.float32) / 255.0
    return img, label

def create_dataset(image_paths, labels, shuffle=True, repeat=False):
    """Build a batched, prefetched ``tf.data`` pipeline of (image, label) pairs.

    Args:
        image_paths: Sequence of image paths, as expected by
            ``load_and_preprocess_image``.
        labels: Sequence of integer class indices, aligned with ``image_paths``.
        shuffle: When true, shuffle with a buffer covering every sample (seed 42).
        repeat: When true, repeat the dataset indefinitely.

    Returns:
        A ``tf.data.Dataset`` yielding batches of ``batch_size`` decoded images
        with their labels.
    """
    path_tensor = tf.convert_to_tensor(image_paths, dtype=tf.string)
    label_tensor = tf.convert_to_tensor(labels, dtype=tf.int32)

    pipeline = tf.data.Dataset.from_tensor_slices((path_tensor, label_tensor))
    if shuffle:
        pipeline = pipeline.shuffle(buffer_size=len(path_tensor), seed=42)
    if repeat:
        pipeline = pipeline.repeat()

    # Decode in parallel, then batch and overlap input work with model execution.
    return (
        pipeline
        .map(load_and_preprocess_image, num_parallel_calls=tf.data.AUTOTUNE)
        .batch(batch_size)
        .prefetch(tf.data.AUTOTUNE)
    )

import pandas as pd

# Unshuffled, so the order of predictions matches the order of test_labels.
test_ds = create_dataset(test_paths, test_labels, shuffle=False)

print(f"Loading model from {model_path} ...")
model = tf.keras.models.load_model(model_path)

# NOTE(review): unpacking three values assumes the saved model was compiled with
# exactly two metrics (top-1 and top-5 accuracy, in that order) — confirm.
test_loss, test_acc, test_top5 = model.evaluate(test_ds, verbose=2)
print(f"\nTest loss      : {test_loss:.4f}")
print(f"Test top‑1 acc : {test_acc:.4f}")
print(f"Test top‑5 acc : {test_top5:.4f}")

# test_ds preserves input order, so the ground truth is test_labels itself;
# iterating the dataset again would decode every test image a further time.
y_true = np.asarray(test_labels)
y_prob = model.predict(test_ds, verbose=0)
y_pred = np.argmax(y_prob, axis=1)

# Per-class metrics, sorted to show the worst- and best-performing classes.
# `labels` is passed explicitly: without it, a class that appears in neither
# y_true nor y_pred makes the label count differ from len(target_names) and
# classification_report raises. zero_division=0 keeps the default score for
# such classes without emitting a warning per class.
report_dict = classification_report(
    y_true,
    y_pred,
    labels=np.arange(len(class_names)),
    target_names=class_names,
    output_dict=True,
    zero_division=0,
)

df_report = pd.DataFrame(report_dict).transpose()

# The last three rows are the aggregate rows (accuracy / averages), not classes.
df_per_class = df_report.iloc[:-3].copy()

# Ascending f1-score: weakest classes first.
df_sorted = df_per_class.sort_values(by='f1-score')


df_display = df_sorted[['precision', 'recall', 'f1-score', 'support']]
df_display.style.format({
    'precision': '{:.3f}',
    'recall': '{:.3f}',
    'f1-score': '{:.3f}',
    'support': '{:.0f}'
})

In [None]:
!pip install ai-edge-litert
from ai_edge_litert.interpreter import Interpreter

# Evaluate the quantised TFLite model on the same test set, one image per
# interpreter invocation.
interpreter = Interpreter(model_path=tflite_path)
interpreter.allocate_tensors()
input_details = interpreter.get_input_details()[0]
inp_idx  = input_details["index"]
out_idx  = interpreter.get_output_details()[0]["index"]

# A model with an integer input tensor expects quantised values
# (q = x / scale + zero_point); feeding it raw float32 makes set_tensor fail.
# Float-input models report scale 0 and take the data unchanged.
input_dtype = input_details["dtype"]
input_scale, input_zero_point = input_details["quantization"]
needs_input_quant = np.issubdtype(input_dtype, np.integer) and input_scale != 0
if needs_input_quant:
    input_range = np.iinfo(input_dtype)

top1_hits = top5_hits = n_samples = 0

for x_batch, y_batch in test_ds:
    x_batch = x_batch.numpy().astype(np.float32)
    y_batch = y_batch.numpy()

    if needs_input_quant:
        x_batch = np.clip(
            np.round(x_batch / input_scale + input_zero_point),
            input_range.min,
            input_range.max,
        )
    x_batch = x_batch.astype(input_dtype)

    for img, lbl in zip(x_batch, y_batch):
        input_data = np.expand_dims(img, axis=0)
        interpreter.set_tensor(inp_idx, input_data)
        interpreter.invoke()
        # Ranking is unaffected by output quantisation (positive scale), so the
        # raw output values are compared directly.
        preds = interpreter.get_tensor(out_idx)[0]

        top1 = np.argmax(preds)
        top5 = np.argsort(preds)[-5:]

        top1_hits += int(top1 == lbl)
        top5_hits += int(lbl in top5)
        n_samples += 1

print(f"\nTFLite top‑1 acc : {top1_hits / n_samples:.4f}")
print(f"TFLite top‑5 acc : {top5_hits / n_samples:.4f}")

