### V1.4
This model uses oversampling to address class imbalance and data augmentation to improve the model's generalisation. The dataset now also includes the Danish Fungi dataset, and has been denoised to remove irrelevant images. Additionally, this notebook is intended to be trained on a more powerful A100 GPU, rather than a v2-8 TPU, so that training converges faster in wall-clock time.

## Data Preparation
Classes are moderately imbalanced due to the nature of the dataset (aggregated from mushroomobserver.org), so the data must be split carefully so that every class is represented fairly: each split below is stratified by class. A 0.8/0.1/0.1 train/validation/test split has been chosen to begin with because the dataset is quite small (see https://encord.com/blog/train-val-test-split/).

In [None]:
!pip install -U crcmod

import os
import numpy as np
from PIL import Image
import tensorflow as tf
from google.cloud import storage
from io import BytesIO
from google.colab import auth
from sklearn.model_selection import train_test_split
import shutil
import glob

auth.authenticate_user()
from google.colab import drive
drive.mount('/content/drive')

!pip install --upgrade tensorflow

project_id = "mushroom-master-136c0"
bucket_name = "mushroom-master-central"
source_directory = "cleaned_dataset/"
version = "v1.4"

client = storage.Client(project=project_id)
bucket = client.bucket(bucket_name)

!mkdir -p /content/local_data/cleaned_dataset
!gsutil -m rsync -r gs://{bucket_name}/{source_directory} /content/local_data/cleaned_dataset

local_data_dir = "/content/local_data/cleaned_dataset"

def remove_hidden_files(folder_path):
    """Delete every hidden file and directory (name starting with '.') under folder_path.

    Hidden directories are removed together with all of their contents.

    Args:
        folder_path: root directory to clean in place.
    """
    for root, dirs, files in os.walk(folder_path):
        for d in dirs:
            if d.startswith('.'):
                shutil.rmtree(os.path.join(root, d))
        # Prune the deleted directories in place: a top-down os.walk reads `dirs`
        # after this point to decide where to descend, and must not visit paths
        # that no longer exist.
        dirs[:] = [d for d in dirs if not d.startswith('.')]
        for f in files:
            if f.startswith('.'):
                os.remove(os.path.join(root, f))

def prune_empty_classes(folder_path, min_count=10):
    """Delete class folders under folder_path that contain fewer than min_count images.

    Images (.jpg/.jpeg/.png, case-insensitive) are counted recursively and hidden
    files are ignored. This is the same rule get_images_and_labels_local uses to
    collect a class's images, so a folder is removed only when the loader would
    really see too few images in it. Non-directory entries are left untouched.

    Args:
        folder_path: dataset root whose immediate sub-directories are classes.
        min_count: minimum number of images a class folder must contain to be kept.
    """
    for class_name in os.listdir(folder_path):
        class_dir = os.path.join(folder_path, class_name)
        if not os.path.isdir(class_dir):
            continue
        image_count = sum(
            1
            for _, _, files in os.walk(class_dir)
            for f in files
            if f.lower().endswith((".jpg", ".jpeg", ".png")) and not f.startswith('.')
        )
        if image_count < min_count:
            shutil.rmtree(class_dir)

def get_images_and_labels_local(local_dir):
    """Collect image paths and integer labels from a class-per-folder dataset.

    Every non-hidden sub-directory of local_dir is a class, labelled by its index
    in alphabetical order. Images (.jpg/.jpeg/.png, case-insensitive, non-hidden)
    are gathered recursively from each class folder.

    Directories and files are visited in sorted order, so the returned arrays (and
    therefore any seeded split built from them) do not depend on the filesystem's
    listing order.

    Args:
        local_dir: dataset root directory.

    Returns:
        (paths, labels, class_names, class_to_index):
          paths: np.ndarray of image paths relative to local_dir.
          labels: np.ndarray of integer labels aligned with paths.
          class_names: sorted list of class folder names.
          class_to_index: dict mapping each class name to its label.
    """
    class_names = sorted(
        d for d in os.listdir(local_dir)
        if os.path.isdir(os.path.join(local_dir, d)) and not d.startswith('.')
    )
    class_to_index = {name: idx for idx, name in enumerate(class_names)}
    paths = []
    labels = []
    for class_name in class_names:
        class_path = os.path.join(local_dir, class_name)
        for root, dirs, files in os.walk(class_path):
            dirs.sort()  # fixes the order in which os.walk visits sub-directories
            for f in sorted(files):
                if f.lower().endswith((".jpg", ".jpeg", ".png")) and not f.startswith('.'):
                    rel_path = os.path.join(os.path.relpath(root, local_dir), f)
                    paths.append(rel_path)
                    labels.append(class_to_index[class_name])
    return np.array(paths), np.array(labels), class_names, class_to_index

# Clean the local copy: delete hidden files/directories, then drop class folders
# left without any images (min_count=1 only removes folders with zero images).
# NOTE(review): stratified splitting requires at least 2 images per class (more in
# practice, because the train/val part is re-stratified below); if train_test_split
# raises ValueError about the least populated class, consider raising min_count.
remove_hidden_files(local_data_dir)
prune_empty_classes(local_data_dir, min_count=1) # drop class folders with no images (e.g. ones that only held now-deleted hidden files)

all_image_paths, all_labels, class_names, class_to_index = get_images_and_labels_local(local_data_dir)

train_val_paths, test_paths, train_val_labels, test_labels = train_test_split( # split into train and test
    all_image_paths, all_labels, test_size=0.1, stratify=all_labels, shuffle=True, random_state=42)

train_paths, val_paths, train_labels, val_labels = train_test_split( # split train into train and val - this actually results in a 0.81/0.09/0.1 split but this is rounded to 0.8/0.1/0.1 for simplicity and readability
    train_val_paths, train_val_labels, test_size=0.1, stratify=train_val_labels, shuffle=True, random_state=42)


## Oversampling
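This section uses oversampling to reduce the negative effect of the class imbalance in the dataset. Only the training set is oversampled; the validation and test sets keep their original class distribution.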

In [None]:
from collections import Counter

def oversample_data(paths, labels, seed=None):
    """Balance classes by repeating each class's samples up to the largest class size.

    Each class's samples are cycled (a, b, c, a, b, ...) until the class has exactly
    as many entries as the most frequent class. The combined result is then shuffled.

    Args:
        paths: array-like of sample paths, aligned with labels.
        labels: array-like of class labels.
        seed: optional seed for the final shuffle. With None (the default), NumPy's
            global random state is used; pass an int for a reproducible order.

    Returns:
        (new_paths, new_labels) as NumPy arrays of length n_classes * max_count.
        Empty input yields empty arrays.
    """
    paths = np.asarray(paths)
    labels = np.asarray(labels)
    if labels.size == 0:
        return paths[:0], labels[:0]

    class_counts = Counter(labels.tolist())
    max_count = max(class_counts.values())

    new_paths = []
    new_labels = []
    for c in class_counts:  # classes in order of first appearance
        c_indices = np.flatnonzero(labels == c)
        # np.resize repeats its input cyclically up to the requested length, i.e.
        # "tile enough copies, then trim the overshoot to max_count".
        new_paths.append(np.resize(paths[c_indices], max_count))
        new_labels.append(np.resize(labels[c_indices], max_count))

    new_paths = np.concatenate(new_paths)
    new_labels = np.concatenate(new_labels)

    # Shuffle paths and labels with the same permutation so they stay aligned.
    if seed is None:
        order = np.random.permutation(len(new_paths))
    else:
        order = np.random.default_rng(seed).permutation(len(new_paths))
    return new_paths[order], new_labels[order]

# Oversample the *training* set only: val_paths/test_paths are left untouched so
# validation and test metrics are computed on the original class distribution.
train_paths, train_labels = oversample_data(train_paths, train_labels)

## Data Preprocessing
Data needs to be:


*   Decoded from JPEG/PNG files into RGB grids of pixels
*   Converted into floating-point tensors (scaled to the range [0, 1])
*   Resized to a standard size (224×224)
*   Packed into batches



In [None]:
from PIL import Image

batch_size = 128  # images per batch for every dataset built by create_dataset
image_size = (224, 224)  # (height, width) every image is resized to; matches the model's 224x224x3 input

def load_and_preprocess_image(path, label):
    """Read one image from the local dataset copy and return it as a float tensor.

    Args:
        path: scalar string tensor holding the image path relative to the local
            dataset root.
        label: the sample's label; returned unchanged.

    Returns:
        (img, label), where img is a float32 tensor of shape image_size + (3,)
        with pixel values scaled to [0, 1].
    """
    full_path = tf.strings.join(["/content/local_data/cleaned_dataset/", path])
    # NOTE: this try/except only has an effect when the function runs eagerly.
    # Under Dataset.map the function is traced into a graph once, so a decode
    # failure at iteration time is raised by the dataset iterator and is NOT
    # caught here; the zero-image fallback never triggers in the input pipeline.
    try:
        image_raw = tf.io.read_file(full_path)
        # expand_animations=False keeps the result rank-3 even for animated files.
        img = tf.image.decode_image(image_raw, channels=3, expand_animations=False)
    except tf.errors.InvalidArgumentError:
        # dummy image in case of error
        img = tf.zeros((image_size[0], image_size[1], 3), dtype=tf.uint8)

    img = tf.image.resize(img, image_size)
    img = tf.cast(img, tf.float32) / 255.0
    return img, label

def create_dataset(image_paths, labels, shuffle=True, repeat=False, skip_errors=True):
    """Build a batched, prefetched tf.data pipeline of (image, label) pairs.

    Args:
        image_paths: sequence of image paths relative to the local dataset root.
        labels: integer labels aligned with image_paths.
        shuffle: shuffle with a buffer as large as the dataset (seed 42).
        repeat: repeat the data indefinitely (use with steps_per_epoch).
        skip_errors: drop, and log a warning for, any sample whose file cannot be
            read or decoded, instead of aborting iteration. A Python try/except in
            the mapped function cannot catch such errors once it runs inside
            Dataset.map, so without this a single bad file stops training or
            evaluation.

    Returns:
        A tf.data.Dataset yielding batches of batch_size (image, label) pairs.
    """
    # Convert lists/arrays of image paths and labels to tensors
    image_paths = tf.convert_to_tensor(image_paths, dtype=tf.string)
    labels = tf.convert_to_tensor(labels, dtype=tf.int32)

    dataset = tf.data.Dataset.from_tensor_slices((image_paths, labels))
    if shuffle:
        dataset = dataset.shuffle(buffer_size=len(image_paths), seed=42)
    if repeat:
        dataset = dataset.repeat()

    # Load and preprocess each image in parallel.
    dataset = dataset.map(load_and_preprocess_image, num_parallel_calls=tf.data.AUTOTUNE)
    if skip_errors:
        dataset = dataset.ignore_errors(log_warning=True)

    dataset = dataset.batch(batch_size)
    dataset = dataset.prefetch(tf.data.AUTOTUNE)
    return dataset

# Create the datasets. Only the training set is repeated (fit() bounds each epoch
# with steps_per_epoch); validation and test are not shuffled, so they iterate in
# val_paths/test_paths order.
train_ds = create_dataset(train_paths, train_labels, shuffle=True, repeat=True)
val_ds = create_dataset(val_paths, val_labels, shuffle=False)
test_ds = create_dataset(test_paths, test_labels, shuffle=False)

print("Datasets created!")


## Data Augmentation

Data augmentation increases the effective variety of the training data, which helps the model cope with the variability inherent in user-uploaded images. Random transformations are applied on the fly to each training batch, so the number of stored images does not change. The keras.layers module is used for its flexibility. Moderately strong augmentations are performed, striking a balance between diversifying the data and computation time.

In [None]:
import tensorflow as tf
from keras import layers

# Random augmentation layers applied, in order, to each training batch.
rotation_factor = 30 / 360.0  # RandomRotation takes a fraction of a full turn, so this is ±30 degrees
zoom_factor = 0.2  # Scaling between 80% and 120%

data_augmentation_layers = [
    layers.RandomFlip("horizontal"),  # Random horizontal flip (probability of 0.5)
    layers.RandomRotation(factor=(-rotation_factor, rotation_factor)),  # Random rotation of up to ±30 degrees
    layers.RandomZoom(height_factor=(-zoom_factor, zoom_factor), width_factor=(-zoom_factor, zoom_factor)),  # Random scaling
    # keras_cv.layers.RandomBrightness(factor=(-0.3, 0.3)),  # Random brightness adjustment (disabled; keras_cv is not imported in this notebook)
]

def data_augmentation(images, targets):
    """Randomly augment a batch of images; targets are returned unchanged.

    Applies every layer in data_augmentation_layers in order, then clips pixel
    values to [0, 1] (the range load_and_preprocess_image produces) as a safeguard
    against interpolation/fill effects of the geometric transforms.

    Args:
        images: batch of float images, presumably (batch, height, width, 3).
        targets: batch labels.

    Returns:
        (augmented_images, targets).
    """
    for layer in data_augmentation_layers:
        # training=True is passed explicitly: Keras random preprocessing layers act
        # as the identity in inference mode, and this function runs inside
        # Dataset.map rather than inside fit(), where Keras would set the mode.
        images = layer(images, training=True)
    images = tf.clip_by_value(images, 0.0, 1.0)
    return images, targets


# Apply data augmentation to the training dataset only (val_ds/test_ds stay un-augmented)
augmented_train_ds = train_ds.map(data_augmentation, num_parallel_calls=tf.data.AUTOTUNE)
augmented_train_ds = augmented_train_ds.prefetch(tf.data.AUTOTUNE)  # prepare augmented batches ahead of the training step



## Fine Tuning
This section fine-tunes the ConvNeXt-Small model. ConvNeXt is a CNN that incorporates design characteristics of Vision Transformers, which keeps it competitive with state-of-the-art pure ViT architectures (https://arxiv.org/abs/2201.03545). Only the top 4 layers of the backbone, which hold the most abstract representations of the data, are unfrozen and fine-tuned on the mushroom dataset, together with a new classification head.

In [None]:
from tensorflow import keras
from keras import layers
from keras.applications import ConvNeXtSmall
import pickle

model_path = f'/content/drive/MyDrive/mushroom_masterv1_ft{version}.keras'
history_path = f'/content/drive/MyDrive/mushroom_training_history{version}.pkl'


# ImageNet-pretrained ConvNeXt-Small backbone without its classification head.
base_model = ConvNeXtSmall(
    input_shape=(224, 224, 3),
    weights='imagenet',
    include_top=False
)
# Freeze every backbone layer except the last 4.
base_model.trainable = True
for layer in base_model.layers[:-4]:
    layer.trainable = False

inputs = keras.Input(shape=(224, 224, 3))
# Keras' ConvNeXt models apply their own ImageNet normalization internally and
# expect raw pixel values in [0, 255]. The tf.data pipeline yields images scaled
# to [0, 1], so scale them back up inside the model. The backbone then sees the
# value range its pretrained weights were trained on, while the saved Keras and
# TFLite models still accept [0, 1] inputs.
x = layers.Rescaling(255.0)(inputs)
x = base_model(x)
x = layers.GlobalAveragePooling2D()(x)
outputs = layers.Dense(len(class_names), activation='softmax')(x)
model = keras.Model(inputs, outputs)

model.compile(
    loss="sparse_categorical_crossentropy",
    optimizer=keras.optimizers.Adam(learning_rate=0.001, clipnorm=1.0),
    metrics=["accuracy", tf.keras.metrics.SparseTopKCategoricalAccuracy(k=5, name='top_5_accuracy')]
)

model.summary()

# Stop once val_loss stops improving (never before epoch 10) and keep the best weights.
early_stopping = keras.callbacks.EarlyStopping(
    monitor='val_loss',
    patience=8,
    start_from_epoch=10,
    min_delta=0.005,
    restore_best_weights=True
)

# Cut the learning rate 10x after 4 epochs without val_loss improvement.
reduce_lr_plateau = keras.callbacks.ReduceLROnPlateau(
    monitor='val_loss',
    factor=0.1,
    patience=4,
    min_lr=1e-6
)

terminate_nan = tf.keras.callbacks.TerminateOnNaN()

# train_ds repeats indefinitely, so an "epoch" is defined as this many full batches.
steps_per_epoch = len(train_paths) // batch_size
# val_ds is finite and not repeated: with validation_steps=None Keras evaluates
# every validation batch, including the final partial one, each epoch
# (floor-dividing here would silently skip up to batch_size - 1 samples).
validation_steps = None

callbacks = [
    # Save the model to Drive whenever val_loss reaches a new best.
    keras.callbacks.ModelCheckpoint(
        filepath=model_path,
        save_best_only=True,
        monitor="val_loss"
    ),
    early_stopping,
    reduce_lr_plateau,
    terminate_nan
]

history = model.fit(
    augmented_train_ds,
    epochs=150,
    steps_per_epoch=steps_per_epoch,
    validation_data=val_ds,
    validation_steps=validation_steps,
    callbacks=callbacks
)

# Save training history
with open(history_path, 'wb') as f:
    pickle.dump(history.history, f)

# Display Data

Display the model validation accuracy and loss at each epoch of the training cycle, and accuracy against test set.



In [None]:
# Evaluate the model on the test data (uses the model currently in memory)
test_loss, test_acc, test_top_5_acc = model.evaluate(test_ds, verbose=2)
print(f"Test accuracy: {test_acc}, Test Top 5 Accuracy: {test_top_5_acc}")

# Reload the per-epoch metrics dict pickled after training, so the plots below can
# be produced from the saved file.
import pickle
history_path = f'/content/drive/MyDrive/mushroom_training_history{version}.pkl'
with open(history_path, 'rb') as f:
    history = pickle.load(f)

import matplotlib.pyplot as plt

# Plot training & validation accuracy values (top-1 and top-5)
plt.plot(history['accuracy'])
plt.plot(history['val_accuracy'])
plt.plot(history['top_5_accuracy'])
plt.plot(history['val_top_5_accuracy'])
plt.plot([test_acc] * len(history['accuracy']), linestyle='--')  # Horizontal line for test accuracy
plt.plot([test_top_5_acc] * len(history['accuracy']), linestyle='--')  # Horizontal line for test top-5 accuracy
plt.title('Model accuracy')
plt.ylabel('Accuracy')
plt.xlabel('Epoch')
plt.legend(['Train', 'Validation', 'Top 5', 'Val Top 5', 'Test', 'Test Top 5'], loc='upper left')
plt.show()

# Plot training & validation loss values
plt.plot(history['loss'])
plt.plot(history['val_loss'])
plt.plot([test_loss] * len(history['loss']), linestyle='--')  # Horizontal line for test loss
plt.title('Model loss')
plt.ylabel('Loss')
plt.xlabel('Epoch')
plt.legend(['Train', 'Validation', 'Test'], loc='upper left')
plt.show()

# Convert to TFLite
The trained model must be converted to a TFLite file to be compatible with the mobile app. From https://medium.com/@hellokhorshed/a-step-by-step-guide-to-convert-keras-model-to-tensorflow-lite-tflite-model-6c8d08707488

In [None]:
# Convert the in-memory Keras model to TFLite using the converter's default settings
# (no optimizations/quantisation are requested here) and save it to Drive.
converter = tf.lite.TFLiteConverter.from_keras_model(model)
tflite_model = converter.convert()
with open(f'/content/drive/MyDrive/mushroom_master_mobile{version}.tflite', 'wb') as f:
  f.write(tflite_model)

## Advanced Evaluation
As this is the model used in the final build of the app, a more thorough evaluation is carried out, including a confusion matrix, a per-class classification report and an evaluation of the .tflite model.

In [None]:
import tensorflow as tf
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
from sklearn.metrics import confusion_matrix, classification_report

version    = "v1.4"
model_path = f'/content/drive/MyDrive/mushroom_masterv1_ft{version}.keras'

# Since this evaluation may take place after the initial training session, the best
# checkpoint is re-loaded into Colab and re-compiled. tf.keras is used (rather than
# the bare `keras` name imported in the training cell) so this cell does not depend
# on the training cell having been run. test_ds and class_names must still come
# from the data-preparation cells.
model = tf.keras.models.load_model(model_path, compile=False)

model.compile(
    loss="sparse_categorical_crossentropy",
    optimizer=tf.keras.optimizers.Adam(learning_rate=0.001, clipnorm=1.0),
    metrics=["accuracy", tf.keras.metrics.SparseTopKCategoricalAccuracy(k=5, name='top_5_accuracy')]
)


# Confusion matrix and classification report using scikit-learn and seaborn (for the CM heatmap).
# test_ds is built with shuffle=False, so labels and predictions gathered in two
# separate passes stay aligned.
y_true = np.concatenate([y.numpy() for _, y in test_ds])
y_prob = model.predict(test_ds, verbose=0)
y_pred = np.argmax(y_prob, axis=1)

cm = confusion_matrix(y_true, y_pred, labels=range(len(class_names)))

plt.figure(figsize=(12, 10))
sns.heatmap(
    cm,
    cmap="Blues",
    xticklabels=class_names,
    yticklabels=class_names,
    square=True,
    cbar=False,
)
plt.title("Confusion matrix — test set")
plt.xlabel("Predicted")
plt.ylabel("True")
plt.show()

print(classification_report(y_true, y_pred, target_names=class_names, digits=3))


# Evaluate the TFLite model separately, since conversion to TFLite may change the
# model's numerics and therefore its accuracy.
tflite_path = f'/content/drive/MyDrive/mushroom_master_mobile{version}.tflite'
interpreter = tf.lite.Interpreter(model_path=tflite_path)
interpreter.allocate_tensors()
input_idx  = interpreter.get_input_details()[0]["index"]
output_idx = interpreter.get_output_details()[0]["index"]

top1_hits = top5_hits = n_samples = 0
allocated_shape = tuple(interpreter.get_input_details()[0]["shape"])
for batch_x, batch_y in test_ds:
    x_np = batch_x.numpy().astype(np.float32)
    # The converted model's input has a fixed batch dimension (typically 1), and
    # set_tensor rejects a tensor of any other shape. Resize the interpreter input
    # whenever the batch shape changes (the final batch is usually smaller).
    if x_np.shape != allocated_shape:
        interpreter.resize_tensor_input(input_idx, list(x_np.shape))
        interpreter.allocate_tensors()
        allocated_shape = x_np.shape
    interpreter.set_tensor(input_idx, x_np)
    interpreter.invoke()
    preds = interpreter.get_tensor(output_idx)

    top1 = np.argmax(preds, axis=1)
    top5 = np.argsort(preds, axis=1)[:, -5:]  # indices of the 5 highest scores per row

    y_np = batch_y.numpy()
    top1_hits += np.sum(top1 == y_np)
    top5_hits += np.sum(np.any(top5 == y_np[:, None], axis=1))
    n_samples += y_np.size

print(f"TFLite top‑1 accuracy : {top1_hits / n_samples:.4f}")
print(f"TFLite top‑5 accuracy : {top5_hits / n_samples:.4f}")
