### V1.5
This notebook removes oversampling to avoid overfitting on rare mushroom species, changes the base model from ConvNeXt-Small to EfficientNetV2-S, adds some optimisations such as increased input image resolution, and improves model evaluation.

## Data Preparation
Classes are moderately imbalanced due to the nature of the dataset (aggregated from mushroomobserver.com), so both splits are stratified by class to keep each species' share consistent across the train, validation and test sets. A 0.8/0.1/0.1 split has been chosen to begin with because the dataset is quite small (https://encord.com/blog/train-val-test-split/); strictly it is 0.81/0.09/0.1, since the validation set is carved out of the 90% training portion.

In [None]:
# Install crcmod; presumably so gsutil can verify CRC32C checksums during the rsync below — confirm.
!pip install -U crcmod

import os
import numpy as np
from PIL import Image
import tensorflow as tf
from google.cloud import storage
from io import BytesIO
from google.colab import auth
from sklearn.model_selection import train_test_split
import shutil
import glob

# Authenticate this Colab session; the Cloud Storage client and gsutil calls below presumably rely on these credentials.
auth.authenticate_user()
from google.colab import drive
# Mount Google Drive (model checkpoints and training history are written under /content/drive/MyDrive).
drive.mount('/content/drive')

# NOTE(review): tensorflow is already imported above (`import tensorflow as tf`), so the upgraded
# package is not loaded into this kernel until the runtime is restarted. Install before the
# imports (or restart) if the newer version is actually required.
!pip install --upgrade tensorflow

# Google Cloud project/bucket holding the cleaned dataset; source_directory is the prefix to download.
project_id = "mushroom-master-136c0"
bucket_name = "mushroom-master-central"
source_directory = "cleaned_dataset/"


# NOTE(review): `client` and `bucket` are not used in this cell (the download goes through gsutil);
# confirm whether any later cell needs them.
client = storage.Client(project=project_id)
bucket = client.bucket(bucket_name)

# Mirror the bucket prefix to local disk (parallel rsync) so the input pipeline reads local files.
!mkdir -p /content/local_data/cleaned_dataset
!gsutil -m rsync -r gs://{bucket_name}/{source_directory} /content/local_data/cleaned_dataset

local_data_dir = "/content/local_data/cleaned_dataset"

def remove_hidden_files(folder_path):
    """Recursively delete every file and directory under *folder_path* whose name starts with '.'.

    Hidden directories are removed together with all of their contents. The root
    directory itself is never removed.

    Args:
        folder_path: directory to clean in place.
    """
    for root, dirs, files in os.walk(folder_path):
        kept_dirs = []
        for d in dirs:
            if d.startswith('.'):
                shutil.rmtree(os.path.join(root, d))
            else:
                kept_dirs.append(d)
        # Prune in place so os.walk does not try to descend into directories we just deleted
        # (previously this only worked because os.walk silently ignores listing errors).
        dirs[:] = kept_dirs
        for f in files:
            if f.startswith('.'):
                os.remove(os.path.join(root, f))

def prune_empty_classes(folder_path, min_count=10):
    """Delete class sub-folders of *folder_path* that hold fewer than *min_count* images.

    An image is a file ending in .jpg/.jpeg/.png (case-insensitive) whose name does not
    start with '.'. Images are counted recursively, skipping hidden sub-folders, which
    is consistent with the recursive walk used to index the dataset. A class whose
    images live in sub-folders is therefore no longer mistaken for an empty one.
    Non-directory entries directly in *folder_path* are left untouched.

    Warning: qualifying class folders are removed from disk with ``shutil.rmtree``.

    Args:
        folder_path: dataset root whose immediate sub-folders are classes.
        min_count: minimum number of images a class must have to be kept.
    """
    image_exts = (".jpg", ".jpeg", ".png")
    for class_name in os.listdir(folder_path):
        class_dir = os.path.join(folder_path, class_name)
        if not os.path.isdir(class_dir):
            continue
        image_count = 0
        for _, sub_dirs, files in os.walk(class_dir):
            sub_dirs[:] = [d for d in sub_dirs if not d.startswith('.')]
            image_count += sum(
                1 for f in files if f.lower().endswith(image_exts) and not f.startswith('.')
            )
        if image_count < min_count:
            shutil.rmtree(class_dir)

def get_images_and_labels_local(local_dir):
    """Index the images under *local_dir*, treating each top-level sub-folder as a class.

    Images are files ending in .jpg/.jpeg/.png (case-insensitive) that do not start with
    '.'. They are collected recursively, skipping hidden ('.'-prefixed) folders at every
    level. Folders and files are visited in sorted order, so the output order, and hence
    any seeded train/test split built on it, does not depend on the filesystem's
    listing order.

    Args:
        local_dir: dataset root directory.

    Returns:
        (paths, labels, class_names, class_to_index):
            paths: np.ndarray of image paths relative to *local_dir*;
            labels: np.ndarray of integer class indices aligned with *paths*;
            class_names: sorted list of class folder names;
            class_to_index: dict mapping class name -> index into class_names.
    """
    image_exts = (".jpg", ".jpeg", ".png")
    class_names = sorted(
        d for d in os.listdir(local_dir)
        if os.path.isdir(os.path.join(local_dir, d)) and not d.startswith('.')
    )
    class_to_index = {name: idx for idx, name in enumerate(class_names)}
    paths = []
    labels = []
    for class_name in class_names:
        for root, dirs, files in os.walk(os.path.join(local_dir, class_name)):
            # Sorting in place fixes os.walk's traversal order; dropping hidden
            # sub-folders mirrors the top-level class filter above.
            dirs[:] = sorted(d for d in dirs if not d.startswith('.'))
            rel_root = os.path.relpath(root, local_dir)
            for f in sorted(files):
                if f.lower().endswith(image_exts) and not f.startswith('.'):
                    paths.append(os.path.join(rel_root, f))
                    labels.append(class_to_index[class_name])
    return np.array(paths), np.array(labels), class_names, class_to_index

# Clean the local copy before indexing: delete dot-files/dot-folders, then delete any
# class folder that is left without .jpg/.jpeg/.png images.
remove_hidden_files(local_data_dir)
prune_empty_classes(local_data_dir, min_count=1) # min_count=1: only class folders with zero images are removed
# NOTE(review): stratified train_test_split raises a ValueError if any class has fewer than 2 samples
# in the array being split; min_count=1 does not guarantee that, so confirm the dataset satisfies it
# (or raise min_count).

all_image_paths, all_labels, class_names, class_to_index = get_images_and_labels_local(local_data_dir)

# Both splits are stratified so every class keeps roughly the same proportion in train, val and test;
# random_state=42 makes the split repeatable for a given input ordering.
train_val_paths, test_paths, train_val_labels, test_labels = train_test_split( # split into train and test
    all_image_paths, all_labels, test_size=0.1, stratify=all_labels, shuffle=True, random_state=42)

train_paths, val_paths, train_labels, val_labels = train_test_split( # split train into train and val - this actually results in a 0.81/0.09/0.1 split but this is rounded to 0.8/0.1/0.1 for simplicity and readability
    train_val_paths, train_val_labels, test_size=0.1, stratify=train_val_labels, shuffle=True, random_state=42)


## Data Preprocessing
Data needs to be:


*   Decoded from JPEG/PNG files into RGB grids of pixels
*   From there converted into floating-point tensors with values scaled to [0, 1]
*   Resized to a standard size (384×384)
*   Packed into batches



In [None]:
from PIL import Image
from tensorflow.keras import mixed_precision

version = "v1.5"  # NOTE(review): not referenced in the visible cells; presumably a tag for saved artefacts — confirm
# Mixed precision: compute in float16 while keeping variables in float32. The global policy applies to
# layers created after this call, so it must run before the augmentation layers (and any newly built layers) exist.
mixed_precision.set_global_policy('mixed_float16')
batch_size = 128  # images per batch for every split
image_size = (384, 384)  # (height, width) every image is resized to

def load_and_preprocess_image(path, label):
    """Read one image file and return it as a float32 tensor in [0, 1], plus its label.

    Args:
        path: scalar string tensor, the image path relative to the local dataset root.
        label: passed through unchanged.

    Returns:
        (image, label), where image has shape (image_size[0], image_size[1], 3).

    Intended for ``tf.data.Dataset.map``, which traces this function into a graph.
    Because of that, a Python ``try``/``except`` here cannot catch read/decode
    failures: they are raised when the dataset is iterated, not when this function
    runs. The previous dummy-image fallback was therefore dead code and has been
    removed. To tolerate unreadable files, drop failing elements at the pipeline
    level (e.g. ``Dataset.ignore_errors()``) or validate the files beforehand.
    """
    full_path = tf.strings.join(["/content/local_data/cleaned_dataset/", path])
    image_raw = tf.io.read_file(full_path)
    # expand_animations=False yields a single 3-D frame (first frame of an animated GIF),
    # giving tf.image.resize an input of fixed rank.
    img = tf.image.decode_image(image_raw, channels=3, expand_animations=False)
    img = tf.image.resize(img, image_size)
    # Bilinear resize already returns float32; the cast keeps that explicit before scaling to [0, 1].
    img = tf.cast(img, tf.float32) / 255.0
    return img, label

def create_dataset(image_paths, labels, shuffle=True, repeat=False, *, skip_unreadable=False):
    """Build a batched, prefetched tf.data pipeline of (image, label) pairs.

    Args:
        image_paths: sequence of image paths relative to the local dataset root.
        labels: integer class labels aligned with *image_paths*.
        shuffle: if True, shuffle the whole set of paths (buffer = dataset size,
            seed 42) before any image is decoded.
        repeat: if True, repeat indefinitely (for training with ``steps_per_epoch``).
        skip_unreadable: if True, silently drop elements whose file cannot be read or
            decoded instead of failing the iteration. Off by default so that broken
            files are noticed; note that enabling it on val/test changes which samples
            are evaluated.

    Returns:
        A ``tf.data.Dataset`` yielding batches of ``batch_size`` (the last batch may be
        smaller when not repeating).
    """
    # Convert the path/label sequences to tensors
    image_paths = tf.convert_to_tensor(image_paths, dtype=tf.string)
    labels = tf.convert_to_tensor(labels, dtype=tf.int32)

    dataset = tf.data.Dataset.from_tensor_slices((image_paths, labels))
    if shuffle:
        # Shuffling path strings (not decoded images) keeps a full-size buffer cheap.
        dataset = dataset.shuffle(buffer_size=len(image_paths), seed=42)
    if repeat:
        dataset = dataset.repeat()

    # Decode/resize/scale each image in parallel.
    dataset = dataset.map(load_and_preprocess_image, num_parallel_calls=tf.data.AUTOTUNE)
    if skip_unreadable:
        # Read/decode errors only surface at iteration time, so they must be handled here.
        dataset = dataset.ignore_errors()

    dataset = dataset.batch(batch_size)
    return dataset.prefetch(tf.data.AUTOTUNE)

# Build the three input pipelines: only training data is shuffled and repeated
# (training runs with a fixed steps_per_epoch); val/test keep a stable order.
train_ds = create_dataset(train_paths, train_labels, repeat=True)
val_ds = create_dataset(val_paths, val_labels, shuffle=False)
test_ds = create_dataset(test_paths, test_labels, shuffle=False)

print("Datasets created!")


## Data Augmentation

Data augmentation is applied on the fly to the training set only, increasing the variety of the training images to help account for the variability inherent in user-uploaded photos. Keras preprocessing layers (`keras.layers`) are used for their flexibility: moderately strong augmentations (random horizontal flips, rotations and zooms) are applied, striking a balance between added variety and computation time.

In [None]:
import tensorflow as tf
from keras import layers

# RandomRotation takes a fraction of a full turn: 30/360 allows up to +/-30 degrees.
rotation_factor = 30 / 360.0
# RandomZoom: negative factors zoom in, positive zoom out -> roughly 80%-120% scale.
zoom_factor = 0.2

# Applied in this order by data_augmentation(); each layer draws fresh randomness per call.
# A brightness jitter (keras_cv RandomBrightness, factor +/-0.3) was tried and is currently disabled.
data_augmentation_layers = [
    layers.RandomFlip(mode="horizontal"),
    layers.RandomRotation(factor=(-rotation_factor, rotation_factor)),
    layers.RandomZoom(
        height_factor=(-zoom_factor, zoom_factor),
        width_factor=(-zoom_factor, zoom_factor),
    ),
]

def data_augmentation(images, targets):
    """Run a batch of images through every augmentation layer, in list order.

    Args:
        images: batch of images, expected to hold pixel values in [0, 1].
        targets: labels, returned untouched.

    Returns:
        (augmented images clipped to [0, 1], targets)
    """
    augmented = images
    for augment in data_augmentation_layers:
        augmented = augment(augmented)
    # Clamp so the batch stays within the [0, 1] range the rest of the pipeline uses.
    return tf.clip_by_value(augmented, 0.0, 1.0), targets


# Augment the (already batched) training stream on the fly; val/test pipelines are left un-augmented.
augmented_train_ds = (
    train_ds
    .map(data_augmentation, num_parallel_calls=tf.data.AUTOTUNE)
    .prefetch(tf.data.AUTOTUNE)
)



In [None]:
from sklearn.utils.class_weight import compute_class_weight

# 'balanced' weights each class by n_samples / (n_classes * class_count), so rare species
# contribute more to the training loss without resorting to oversampling.
classes = np.unique(train_labels)
weights = compute_class_weight(class_weight='balanced', classes=classes, y=train_labels)

# Keras expects a plain {int class index: weight} mapping.
class_weight = dict(zip(map(int, classes), weights))

## Fine Tuning
A fine-tuned version of the EfficientNetV2-S model (https://arxiv.org/abs/2104.00298), which replaces the ConvNeXt-Small base used in earlier versions. (ConvNeXt is a CNN that adopts design choices from Vision Transformers to stay competitive with state-of-the-art pure ViT architectures, https://arxiv.org/abs/2201.03545.) The top 4 layers, which hold the most abstract representations of the data, are fine-tuned on the mushroom dataset. This cell resumes training from the last saved checkpoint.

In [None]:
from tensorflow import keras
from keras import layers
from keras.applications import EfficientNetV2S
import pickle

# Paths
model_path      = '/content/drive/MyDrive/mushroom_masterv1_ftv1.5.keras'
history_path    = '/content/drive/MyDrive/mushroom_training_historyv1.5.pkl'

# Load the last saved checkpoint: training was interrupted by running out of compute,
# so this cell resumes from it instead of building a fresh model.
model = keras.models.load_model(model_path)

# NOTE(review): re-compiling replaces the optimizer restored by load_model, so Adam's moment
# estimates and any learning-rate reductions made by ReduceLROnPlateau before the interruption
# are discarded; training restarts at learning_rate=1e-4.
model.compile(
    loss = keras.losses.SparseCategoricalCrossentropy(),
    optimizer = keras.optimizers.Adam(learning_rate=1e-4, clipnorm=1.0),
    metrics = ["accuracy", keras.metrics.SparseTopKCategoricalAccuracy(5, name='top_5_accuracy')]
)

model.summary()

# Stop after 11 epochs without a val_loss drop of at least 0.005, restoring the best weights
# seen in this run when stopping. start_from_epoch is compared with the absolute epoch index,
# so with initial_epoch=33 below it has no effect on this resumed run.
early_stopping = keras.callbacks.EarlyStopping(
    monitor='val_loss',
    patience=11,
    start_from_epoch=10,
    min_delta=0.005,
    restore_best_weights=True
)

# Divide the learning rate by 10 after 4 epochs without val_loss improvement, down to 1e-6.
reduce_lr_plateau = keras.callbacks.ReduceLROnPlateau(
    monitor='val_loss',
    factor=0.1,
    patience=4,
    min_lr=1e-6
)

# Same callback as tf.keras.callbacks.TerminateOnNaN, via the `keras` alias used everywhere else here.
terminate_nan = keras.callbacks.TerminateOnNaN()

# The training stream repeats indefinitely, so an "epoch" is one pass worth of full batches.
steps_per_epoch = len(train_paths) // batch_size
# val_ds is finite and not repeated: let Keras consume all of it every epoch. Using
# len(val_paths) // batch_size steps would silently skip the final partial batch.
validation_steps = None

# Score the restored checkpoint first so ModelCheckpoint only overwrites it with a genuinely
# better model. Without a threshold, the first resumed epoch is always saved, even if its
# val_loss is worse than the checkpoint being resumed from.
resume_val_loss = model.evaluate(val_ds, verbose=0, return_dict=True)["loss"]

callbacks = [
    keras.callbacks.ModelCheckpoint(
        filepath=model_path,
        save_best_only=True,
        monitor="val_loss",
        initial_value_threshold=resume_val_loss,
    ),
    early_stopping,
    reduce_lr_plateau,
    terminate_nan
]

# `epochs` is the index of the final epoch (not a count of additional epochs); initial_epoch
# presumably matches where the interrupted run stopped — confirm before re-running.
history = model.fit(
    augmented_train_ds,
    epochs=140,
    initial_epoch=33,
    steps_per_epoch=steps_per_epoch,
    validation_data=val_ds,
    validation_steps=validation_steps,
    callbacks=callbacks,
    class_weight=class_weight
)


# Save training history.
# NOTE(review): history.history only covers the epochs run in this session; writing it to
# history_path replaces any history file already saved there.
with open(history_path, 'wb') as f:
    pickle.dump(history.history, f)