### V1.2o
This notebook trains the model without utilising transfer learning - the model is purely trained on the mushroom dataset

In [None]:
# Upgrade TensorFlow to the newest release available on PyPI. The version is not
# pinned, so the installed version can differ from one run to the next.
!pip install tensorflow --upgrade


## Data Preparation
Classes are moderately imbalanced due to the nature of the dataset (aggregated from mushroomobserver.com), so every split is stratified by class to keep each class fairly represented in the train, validation and test sets. A 0.8/0.1/0.1 train/validation/test split has been chosen to begin with because the dataset is quite small (see https://encord.com/blog/train-val-test-split/).

In [None]:

import os
import numpy as np
from PIL import Image
import tensorflow as tf
from google.cloud import storage
from io import BytesIO

import numpy as np

import tensorflow_datasets as tfds
from google.colab import auth
from sklearn.model_selection import train_test_split
import os
auth.authenticate_user()

from google.colab import drive
drive.mount('/content/drive')
import shutil
import glob

# Google Cloud project and bucket that hold the dataset, the folder inside the
# bucket, and the version tag that later cells embed in output file names.
project_id = "mushroom-master-136c0"
bucket_name = "mushroom-master-central"
source_directory = "cleaned_dataset/"
version = "v1.2o"

# GCS client handles. The download below goes through gsutil rather than these objects.
client = storage.Client(project=project_id)
bucket = client.bucket(bucket_name)

# Mirror the bucket folder onto local disk (-m: parallel transfer, -r: recursive).
!mkdir -p /content/local_data/cleaned_dataset
!gsutil -m rsync -r gs://{bucket_name}/{source_directory} /content/local_data/cleaned_dataset

# Local root of the mirrored dataset; must stay in sync with the path used in the
# shell commands above.
local_data_dir = "/content/local_data/cleaned_dataset"

def remove_hidden_files(folder_path):
    """Delete every dot-prefixed file and directory found under ``folder_path``.

    Hidden directories are removed together with their contents. A hidden entry
    that is a symbolic link to a directory is unlinked instead (its target is
    left alone), because ``shutil.rmtree`` refuses to operate on symlinks.

    Args:
        folder_path: Root directory to clean, walked recursively.
    """
    for root, dirs, files in os.walk(folder_path):
        for d in dirs:
            if not d.startswith('.'):
                continue
            hidden_path = os.path.join(root, d)
            if os.path.islink(hidden_path):
                os.unlink(hidden_path)
            else:
                shutil.rmtree(hidden_path)
        # Prune in place so os.walk does not try to descend into the
        # directories that were just deleted.
        dirs[:] = [d for d in dirs if not d.startswith('.')]
        for f in files:
            if f.startswith('.'):
                os.remove(os.path.join(root, f))

def remove_corrupted_images(folder_path, valid_exts=(".jpg", ".jpeg", ".png")):
    """Delete image files under ``folder_path`` that Pillow cannot verify and decode.

    Extensions are matched case-insensitively, the same way the dataset indexer
    matches them, so files such as ``IMG_001.JPG`` are checked too.

    Args:
        folder_path: Root directory, searched recursively.
        valid_exts: File-name suffixes that identify image files.
    """
    suffixes = tuple(ext.lower() for ext in valid_exts)
    for root, _, files in os.walk(folder_path):
        for name in files:
            if not name.lower().endswith(suffixes):
                continue
            file_path = os.path.join(root, name)
            try:
                with Image.open(file_path) as img:
                    img.verify()
                # The image object must be reopened after verify() before it
                # can be used again; load() then forces a full decode.
                with Image.open(file_path) as img:
                    img.load()
            except Exception:
                # Any failure to verify or decode marks the file as unusable.
                # Unlike a bare except, this lets KeyboardInterrupt through.
                os.remove(file_path)

def prune_empty_classes(folder_path, min_count=10):
    """Delete class folders that contain fewer than ``min_count`` image files.

    Only sub-directories directly inside ``folder_path`` are considered, and
    only files directly inside each of them whose names end in .jpg, .jpeg or
    .png (case-insensitive) are counted.

    Args:
        folder_path: Directory whose immediate sub-directories are the classes.
        min_count: Minimum number of images a class needs in order to be kept.
    """
    image_suffixes = (".jpg", ".jpeg", ".png")
    for entry in os.listdir(folder_path):
        candidate = os.path.join(folder_path, entry)
        if not os.path.isdir(candidate):
            continue
        image_total = sum(
            1 for fname in os.listdir(candidate)
            if fname.lower().endswith(image_suffixes)
        )
        if image_total < min_count:
            shutil.rmtree(candidate)

def get_images_and_labels_local(local_dir):
    """Index the images under ``local_dir``, using one sub-directory per class.

    Classes are the non-hidden immediate sub-directories of ``local_dir``,
    sorted by name; a class's index is its position in that sorted list.
    Images are collected recursively inside each class directory.
    Directories and files are visited in sorted order so the returned arrays
    are identical on every machine. The seeded train/val/test split relies on
    that ordering to be reproducible.

    Args:
        local_dir: Root directory of the dataset.

    Returns:
        Tuple ``(paths, labels, class_names, class_to_index)``:
        paths - array of image paths relative to ``local_dir``;
        labels - integer class index for each path;
        class_names - sorted list of class directory names;
        class_to_index - mapping from class name to index.
    """
    image_suffixes = (".jpg", ".jpeg", ".png")
    class_names = sorted([
        d for d in os.listdir(local_dir)
        if os.path.isdir(os.path.join(local_dir, d)) and not d.startswith('.')
    ])
    class_to_index = {name: idx for idx, name in enumerate(class_names)}
    paths = []
    labels = []
    for class_name in class_names:
        class_path = os.path.join(local_dir, class_name)
        for root, dirs, files in os.walk(class_path):
            # Sorting in place fixes the order in which os.walk descends.
            dirs.sort()
            for f in sorted(files):
                if f.lower().endswith(image_suffixes) and not f.startswith('.'):
                    rel_path = os.path.join(os.path.relpath(root, local_dir), f)
                    paths.append(rel_path)
                    labels.append(class_to_index[class_name])
    # An explicit dtype keeps the labels integral even when no image was found.
    return np.array(paths), np.array(labels, dtype=np.int64), class_names, class_to_index


# Clean the local mirror before indexing it: drop dot-files, drop images that
# cannot be decoded, then drop class folders left without any image.
remove_hidden_files(local_data_dir)
remove_corrupted_images(local_data_dir)
prune_empty_classes(local_data_dir, min_count=1)

all_image_paths, all_labels, class_names, class_to_index = get_images_and_labels_local(local_data_dir)
num_classes = len(class_names)

# Target split: 0.8 train / 0.1 validation / 0.1 test, stratified by class.
# NOTE(review): stratified splitting needs at least 2 samples of every class at
# each stage; min_count=1 above does not guarantee that.
test_fraction = 0.1
val_fraction = 0.1

train_val_paths, test_paths, train_val_labels, test_labels = train_test_split(
    all_image_paths, all_labels, test_size=test_fraction, stratify=all_labels, shuffle=True, random_state=42)

# The validation set is carved out of the remaining 90%, so its share has to be
# rescaled (0.1 / 0.9). Reusing 0.1 here would give a 0.81/0.09/0.10 split.
train_paths, val_paths, train_labels, val_labels = train_test_split(
    train_val_paths, train_val_labels, test_size=val_fraction / (1.0 - test_fraction),
    stratify=train_val_labels, shuffle=True, random_state=42)

batch_size = 128
image_size = (224, 224)

def parse_image(filename, label):
    """Load one image from the local dataset folder and prepare it for the model.

    Args:
        filename: Image path relative to the local dataset root (string tensor).
        label: Class label, passed through unchanged.

    Returns:
        ``(image, label)`` where ``image`` is a float32 tensor resized to
        ``image_size`` with pixel values divided by 255.
    """
    encoded = tf.io.read_file(
        tf.strings.join(["/content/local_data/cleaned_dataset/", filename])
    )
    decoded = tf.image.decode_jpeg(encoded, channels=3)
    resized = tf.image.resize(decoded, image_size)
    return tf.cast(resized, tf.float32) / 255.0, label

AUTOTUNE = tf.data.AUTOTUNE

def create_dataset(paths, labels, shuffle=True, repeat=False):
    """Build a batched ``tf.data`` pipeline from image paths and labels.

    Args:
        paths: Image paths relative to the dataset root.
        labels: Class label for each path.
        shuffle: When true, shuffle with a buffer covering every sample (seed 42).
        repeat: When true, repeat the dataset indefinitely.

    Returns:
        A dataset of ``(images, labels)`` batches of size ``batch_size``,
        decoded through ``parse_image`` and prefetched.
    """
    pipeline = tf.data.Dataset.from_tensor_slices((paths, labels))
    if shuffle:
        pipeline = pipeline.shuffle(len(paths), seed=42)
    if repeat:
        pipeline = pipeline.repeat()
    return (
        pipeline
        .map(parse_image, num_parallel_calls=AUTOTUNE)
        .batch(batch_size)
        .prefetch(AUTOTUNE)
    )

# The training pipeline is shuffled and repeats indefinitely; the validation and
# test pipelines keep their order and make a single pass.
train_ds = create_dataset(train_paths, train_labels, shuffle=True, repeat=True)
val_ds = create_dataset(val_paths, val_labels, shuffle=False)
test_ds = create_dataset(test_paths, test_labels, shuffle=False)

# Report the number of samples in each split.
print("Train set:", len(train_paths), "Val set:", len(val_paths), "Test set:", len(test_paths))

## Data Augmentation

Data augmentation is used to boost the effective size of the training set and to help account for the variability inherent in user-uploaded images. In this version the augmentations are implemented with Keras preprocessing layers (`RandomFlip`, `RandomRotation`, `RandomZoom`) rather than the imgaug library (https://imgaug.readthedocs.io/en/latest/) - moderately strong augmentations are performed, striking a balance between expanding the dataset and computation time.

In [None]:
from keras import layers

# Rotation range expressed as a fraction of a full turn (30 degrees out of 360).
rotation_factor = 30 / 360.0
zoom_factor = 0.2  # Scaling between 80% and 120%

# Augmentation layers, applied in list order by data_augmentation().
data_augmentation_layers = [
    layers.RandomFlip("horizontal"),  # Random horizontal flip (probability of 0.5)
    layers.RandomRotation(factor=(-rotation_factor, rotation_factor)),  # Random rotation
    layers.RandomZoom(height_factor=(-zoom_factor, zoom_factor), width_factor=(-zoom_factor, zoom_factor)),  # Random scaling
]

def data_augmentation(images, targets):
    """Pass images through every augmentation layer in order; targets are untouched.

    Args:
        images: Image tensor to augment.
        targets: Labels belonging to ``images``.

    Returns:
        ``(augmented_images, targets)`` with pixel values clamped to [0, 1].
    """
    augmented = images
    for augment in data_augmentation_layers:
        augmented = augment(augmented)
    # Clamp so pixel values stay inside [0, 1].
    return tf.clip_by_value(augmented, 0.0, 1.0), targets


# Apply data augmentation to training dataset. train_ds is already batched, so
# each call to data_augmentation receives a whole batch.
augmented_train_ds = train_ds.map(data_augmentation, num_parallel_calls=tf.data.AUTOTUNE)
augmented_train_ds = augmented_train_ds.prefetch(tf.data.AUTOTUNE)

## Fine Tuning
This version trains a ConvNeXt-Small model from randomly initialised weights (`weights=None`) instead of fine-tuning a pretrained ConvNeXt-Base model - ConvNeXt is a CNN that adopts design characteristics of Vision Transformers to stay competitive with state-of-the-art pure ViT architectures (https://arxiv.org/abs/2201.03545). Every layer of the backbone, together with a new softmax classification head, is trained on the mushroom dataset.

In [None]:
from tensorflow import keras
from keras import utils
from keras.applications import ConvNeXtSmall
import pickle

# Output locations on Google Drive for the best checkpoint and the training history.
model_path = f'/content/drive/MyDrive/mushroom_masterv1_ft{version}.keras'
history_path = f'/content/drive/MyDrive/mushroom_training_history{version}.pkl'

# ConvNeXt-Small backbone without pretrained weights (weights=None) and without
# its ImageNet classification head; pooling='avg' makes it output one feature
# vector per image.
base_model = ConvNeXtSmall(
    include_top=False,
    # The input pipeline already scales pixels to [0, 1]. The backbone's
    # built-in normalisation expects [0, 255] inputs, so it must be disabled.
    include_preprocessing=False,
    weights=None,
    input_shape=(224, 224, 3),
    pooling='avg'
)

inputs = layers.Input(shape=(224, 224, 3))

x = base_model(inputs)

# Classification head: one softmax unit per mushroom class.
outputs = layers.Dense(num_classes, activation='softmax')(x)
model = tf.keras.Model(inputs, outputs)

# Labels are integer class indices, hence the sparse loss and sparse top-k metric.
model.compile(
    loss="sparse_categorical_crossentropy",
    optimizer=keras.optimizers.Adam(learning_rate=0.001, clipnorm=1.0),
    metrics=[
        "accuracy",
        tf.keras.metrics.SparseTopKCategoricalAccuracy(k=5, name='top_5_accuracy'),
    ],
)

model.summary()

# From epoch 10 onwards, stop once val_loss has gone 8 epochs without improving
# by at least 0.005, and restore the best weights seen.
early_stopping = keras.callbacks.EarlyStopping(
    monitor='val_loss',
    patience=8,
    start_from_epoch=10,
    min_delta=0.005,
    restore_best_weights=True
)

# Divide the learning rate by 10 after 4 epochs without val_loss improvement,
# never going below 1e-6.
reduce_lr_plateau = keras.callbacks.ReduceLROnPlateau(
    monitor='val_loss',
    factor=0.1,
    patience=4,
    min_lr=1e-6
)

# The training dataset repeats forever, so fit() needs an explicit epoch length.
# max(1, ...) keeps it valid when there are fewer samples than one batch.
steps_per_epoch = max(1, len(train_paths) // batch_size)
# val_ds is finite, so None lets Keras consume all of it every epoch. The
# previous floor division silently dropped the final partial batch.
validation_steps = None

callbacks = [
    # Keep only the checkpoint with the lowest validation loss.
    keras.callbacks.ModelCheckpoint(
        filepath=model_path,
        save_best_only=True,
        monitor="val_loss"
    ),
    early_stopping,
    reduce_lr_plateau
]

history = model.fit(
    augmented_train_ds,
    epochs=35,
    steps_per_epoch=steps_per_epoch,
    validation_data=val_ds,
    validation_steps=validation_steps,
    callbacks=callbacks
)

# Save training history
with open(history_path, 'wb') as f:
    pickle.dump(history.history, f)

# Write the diagram under MyDrive, next to the checkpoint and history, rather
# than into the Drive mount root.
utils.plot_model(model, to_file='/content/drive/MyDrive/model_architecture.png', show_shapes=True, show_layer_names=True)

# Display Data

Display the model validation accuracy and loss at each epoch of the training cycle, and accuracy against test set.



In [None]:
# Evaluate the model on the test data
!pip install tensorflow --upgrade
import tensorflow as tf
version = 'v1.2o'
model_path = f'/content/drive/MyDrive/mushroom_masterv1_ft{version}.keras'
model = tf.keras.models.load_model(model_path)
# test_loss, test_acc, test_top_5_acc = model.evaluate(test_ds, verbose=2)
#print(f"Test accuracy: {test_acc}, Test Top 5 Accuracy: {test_top_5_acc}")

import pickle
history_path = f'/content/drive/MyDrive/mushroom_training_history{version}.pkl'
with open(history_path, 'rb') as f:
    history = pickle.load(f)

import matplotlib.pyplot as plt

# Plot training & validation accuracy values
plt.plot(history['accuracy'])
plt.plot(history['val_accuracy'])
plt.plot(history['top_5_accuracy'])
plt.plot(history['val_top_5_accuracy'])

#plt.plot([test_acc] * len(history['accuracy']), linestyle='--')  # Horizontal line for test accuracy
plt.title('Model accuracy')
plt.ylabel('Accuracy')
plt.xlabel('Epoch')
plt.legend(['Train', 'Validation', 'Test', 'Top 5', 'Val Top 5'], loc='upper left')
plt.show()

# Plot training & validation loss values
plt.plot(history['loss'])
plt.plot(history['val_loss'])
#plt.plot([test_loss] * len(history['loss']), linestyle='--')  # Horizontal line for test loss
plt.title('Model loss')
plt.ylabel('Loss')
plt.xlabel('Epoch')
plt.legend(['Train', 'Validation', 'Test'], loc='upper left')
plt.show()

# Convert to TFLite
The trained model must be converted to a TFLite file to be compatible with the mobile app. From https://medium.com/@hellokhorshed/a-step-by-step-guide-to-convert-keras-model-to-tensorflow-lite-tflite-model-6c8d08707488

In [None]:
# Convert the in-memory Keras model to TFLite and write it to Google Drive.
# NOTE(review): the original comment here said the conversion was skipped
# because the model underperformed, yet this cell still runs it - confirm
# whether it should be kept.
converter = tf.lite.TFLiteConverter.from_keras_model(model)
tflite_model = converter.convert()
with open(f'/content/drive/MyDrive/mushroom_master_mobile{version}.tflite', 'wb') as f:
  f.write(tflite_model)

In [None]:
# Mount Google Drive at /content/drive. The data preparation cell performs the
# same mount. NOTE(review): presumably kept so the later cells can be run on
# their own in a fresh session - confirm it is still needed.
from google.colab import drive
drive.mount('/content/drive')