## Data Preparation
Classes are moderately imbalanced due to the nature of the dataset (aggregated from mushroomobserver.com), so the samples have to be split deliberately so that every class is fairly represented in each subset. A 0.8/0.1/0.1 train/validation/test split has been chosen to begin with because the dataset is quite small (https://encord.com/blog/train-val-test-split/).

In [None]:
import os
import numpy as np
from PIL import Image
import logging

# Root folder of the dataset; each non-hidden sub-directory is treated as one class.
source_directory = '/content/drive/Othercomputers/My MacBook Pro/MO_MI_images'
# Version tag interpolated into the file name of the saved labels list.
version = 'v1'

def is_valid_file(filepath):
    """Check if file is a valid image.

    Hidden entries (basename starting with '.') and anything that is not a
    regular file are rejected without being opened. Every other file is
    opened with Pillow and passed through ``verify()``.

    Args:
        filepath: Path of the candidate file.

    Returns:
        True if Pillow opens and verifies the file without error, else False.
        A warning is logged for files that fail the Pillow check.
    """
    if os.path.basename(filepath).startswith('.') or not os.path.isfile(filepath):
        return False

    try:
        # The context manager releases the file handle even when verify()
        # raises, so a folder full of corrupt files cannot leak descriptors.
        with Image.open(filepath) as img:
            img.verify()
    except (IOError, SyntaxError):
        logging.warning(f"Invalid image file: {filepath}")
        return False
    return True

def get_images_and_labels(source_dir, min_images_per_class=250):
    """Collect image paths and integer labels for every sufficiently large class.

    Each non-hidden sub-directory of ``source_dir`` is a candidate class.
    A class is kept only if it contains at least ``min_images_per_class``
    files accepted by ``is_valid_file``; smaller classes are skipped with a
    logged warning.

    Args:
        source_dir: Directory whose sub-directories are the classes.
        min_images_per_class: Minimum number of valid images a class needs
            to be kept.

    Returns:
        A tuple ``(all_image_paths, all_labels, class_names, class_to_index)``:
        image paths and their label indices (parallel lists, grouped by
        class), the sorted names of the kept classes, and the mapping from
        class name to label index.
    """
    candidate_dirs = sorted(
        entry for entry in os.listdir(source_dir)
        if not entry.startswith('.') and os.path.isdir(os.path.join(source_dir, entry))
    )

    # Keep only the classes that have enough valid images.
    paths_by_class = {}
    for name in candidate_dirs:
        folder = os.path.join(source_dir, name)
        kept_paths = []
        for filename in os.listdir(folder):
            candidate = os.path.join(folder, filename)
            if is_valid_file(candidate):
                kept_paths.append(candidate)

        if len(kept_paths) < min_images_per_class:
            logging.warning(f"Skipping class '{name}' (only {len(kept_paths)} images)")
            continue
        paths_by_class[name] = kept_paths

    # Label indices follow the alphabetical order of the surviving classes.
    class_names = sorted(paths_by_class)
    class_to_index = {name: position for position, name in enumerate(class_names)}

    # Flatten into two parallel lists: one path and one label per image.
    all_image_paths = []
    all_labels = []
    for name in class_names:
        members = paths_by_class[name]
        all_image_paths += members
        all_labels += [class_to_index[name]] * len(members)

    return all_image_paths, all_labels, class_names, class_to_index

# Discover the usable classes and gather every image path with its label index.
all_image_paths, all_labels, class_names, class_to_index = get_images_and_labels(source_directory)
num_classes = len(class_names)
print(f"Number of classes: {num_classes}")
print(f"Total images: {len(all_image_paths)}")

# Save labels to a file, one class name per line in label-index order.
# The path must be absolute (leading '/'): a relative 'content/drive/...' path
# is resolved against the current working directory and misses the Drive mount.
labels_path = f'/content/drive/MyDrive/labels{version}.txt'
with open(labels_path, 'w', encoding='utf-8') as f:
    f.writelines(f"{label}\n" for label in class_names)

from sklearn.model_selection import train_test_split

# Convert to numpy arrays
all_image_paths = np.array(all_image_paths)
all_labels = np.array(all_labels)

# Target fractions of the WHOLE dataset: 0.8 train / 0.1 validation / 0.1 test.
test_fraction = 0.1
val_fraction = 0.1

# First split into train/val and test (stratified so every class keeps its share)
train_val_paths, test_paths, train_val_labels, test_labels = train_test_split(
    all_image_paths, all_labels, test_size=test_fraction, stratify=all_labels, random_state=42)

# Then split train/val into train and val. test_size is relative to the data
# passed in (the remaining 90%), so the validation fraction is rescaled;
# using 0.1 directly would yield a 0.81/0.09/0.10 split instead.
train_paths, val_paths, train_labels, val_labels = train_test_split(
    train_val_paths, train_val_labels,
    test_size=val_fraction / (1.0 - test_fraction),
    stratify=train_val_labels, random_state=42)

print(f"Training samples: {len(train_paths)}")
print(f"Validation samples: {len(val_paths)}")
print(f"Test samples: {len(test_paths)}")

## Data Preprocessing
Data needs to be:


*   Decoded from JPEGs into RGB grids of pixels
*   From there converted into floating-point tensors
*   Reshaped into a standard size
*   Packed into batches

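Keras provides the function `image_dataset_from_directory` for this, but because the train/validation/test split above works on explicit arrays of file paths and labels, the same steps are implemented here as a `tf.data` pipeline. The training data is also shuffled to expose the model to a more diverse ordering of examples. (https://www.markovml.com/glossary/data-shuffling)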



In [None]:
import tensorflow as tf

# Samples per batch: deliberately massive to make use of the ~300GB of RAM on
# the TPU runtime (presumably a v2-8 TPU — confirm).
batch_size = 1536
# (height, width) that every image is resized to before batching.
image_size = (224, 224)

def load_and_preprocess_image(path, label):
    """Read one image file and turn it into a normalised, fixed-size tensor.

    Args:
        path: Path of the image file to read.
        label: Label associated with the image; returned unchanged.

    Returns:
        A tuple ``(image, label)`` where ``image`` has 3 channels, has been
        resized to ``image_size`` and divided by 255.
    """
    raw_bytes = tf.io.read_file(path)
    decoded = tf.image.decode_jpeg(raw_bytes, channels=3)
    resized = tf.image.resize(decoded, image_size)
    # Divide by 255 so pixel values fall in the [0, 1] range.
    return resized / 255.0, label

def create_dataset(image_paths, labels, shuffle=True):
    """Build a batched, prefetched ``tf.data`` pipeline from paths and labels.

    Args:
        image_paths: Image file paths, one per sample.
        labels: Labels aligned with ``image_paths``.
        shuffle: When true, shuffle the (path, label) pairs with a buffer
            covering the whole dataset and a fixed seed of 42.

    Returns:
        A dataset yielding batches of ``batch_size`` preprocessed
        ``(image, label)`` pairs.
    """
    pairs = tf.data.Dataset.from_tensor_slices((image_paths, labels))
    if shuffle:
        # Shuffle the cheap path strings before any image is decoded.
        pairs = pairs.shuffle(buffer_size=len(image_paths), seed=42)
    return (
        pairs
        .map(load_and_preprocess_image, num_parallel_calls=tf.data.AUTOTUNE)
        .batch(batch_size)
        .prefetch(tf.data.AUTOTUNE)
    )

# Build one pipeline per split. Only the training split is shuffled;
# validation and test keep their order.
train_ds = create_dataset(train_paths, train_labels, shuffle=True)
val_ds = create_dataset(val_paths, val_labels, shuffle=False)
test_ds = create_dataset(test_paths, test_labels, shuffle=False)

## Data Augmentation

Data augmentation is used to boost the effective size of the training set and to help account for the variability inherent in user-uploaded images. The KerasCV library (https://keras.io/keras_cv/) is used together with the built-in Keras preprocessing layers for its flexibility - moderately strong augmentations are performed, striking a balance between expanding the dataset and computation time.

In [None]:
!pip install keras-cv --upgrade
import keras_cv
import tensorflow as tf
from keras_cv.layers import RandomApply
from keras import layers

# Keras expresses rotation as a fraction of a full turn, so this is +/-30 degrees.
rotation_factor = 30 / 360.0
zoom_factor = 0.2  # Scaling between 80% and 120%

# Augmentations applied, in this order, to each training batch.
data_augmentation_layers = [
    layers.RandomFlip("horizontal"),  # Random horizontal flip (probability of 0.5)
    layers.RandomRotation(factor=(-rotation_factor, rotation_factor)),  # Random rotation
    layers.RandomZoom(height_factor=(-zoom_factor, zoom_factor), width_factor=(-zoom_factor, zoom_factor)),  # Random scaling
    RandomApply(
        layer=keras_cv.layers.RandomGaussianBlur(kernel_size=3, factor=(0.0, 1.0)),
        rate=0.5  # Apply with 50% probability
    ),
    # Random brightness adjustment. The images were normalised to [0, 1]
    # during preprocessing, so value_range must say so: the layer's default
    # of (0, 255) would scale the brightness shift for 8-bit pixel values.
    keras_cv.layers.RandomBrightness(factor=(-0.3, 0.3), value_range=(0, 1)),
    RandomApply(
        # Cut out a patch covering 20% of the height and 20% of the width.
        layer=keras_cv.layers.RandomCutout(
            height_factor=(0.2, 0.2),
            width_factor=(0.2, 0.2),
        ),
        rate=0.5  # Apply with 50% probability
    ),
]

def data_augmentation(images, targets):
    """Pass a batch of images through every augmentation layer in order.

    Args:
        images: Batch of images to augment.
        targets: Labels for the batch; returned untouched.

    Returns:
        A tuple ``(augmented_images, targets)``.
    """
    augmented = images
    for augment in data_augmentation_layers:
        augmented = augment(augmented)
    return augmented, targets

# Augment the (already batched) training dataset, then prefetch the
# augmented batches.
augmented_train_ds = (
    train_ds
    .map(data_augmentation, num_parallel_calls=tf.data.AUTOTUNE)
    .prefetch(tf.data.AUTOTUNE)
)



## Fine Tuning
A fine-tuned version of EfficientNetV2M is used. EfficientNetV2 combines high top-1 accuracy on the ImageNet benchmark with efficient training times (https://arxiv.org/pdf/2104.00298). The 'M' (Medium) model is used to strike a balance between computational time and accuracy. The top 4 layers - the most abstract representations of the data - are then fine tuned and trained on the mushroom dataset.

In [None]:
import tensorflow as tf
from tensorflow import keras
from keras import layers
from keras.applications import EfficientNetV2M
from tensorflow.keras.callbacks import ReduceLROnPlateau
import os
import pickle

# Define paths for saving model and history
model_path = '/content/drive/MyDrive/mushroom_masterv1_ft.keras'
history_path = '/content/drive/MyDrive/mushroom_training_history.pkl'

# Connect to and initialise the TPU. If the resolver raises ValueError
# (expected when no TPU is attached to the runtime), fall back to
# MirroredStrategy so the rest of the notebook still runs.
try:
    resolver = tf.distribute.cluster_resolver.TPUClusterResolver()
    tf.config.experimental_connect_to_cluster(resolver)
    tf.tpu.experimental.initialize_tpu_system(resolver)
    strategy = tf.distribute.TPUStrategy(resolver)
    # The messages describe what actually happened on each branch.
    print("TPU initialized. Using TPUStrategy.")
except ValueError:
    print("No TPU found. Falling back to MirroredStrategy.")
    strategy = tf.distribute.MirroredStrategy()

# Lists the logical TPU devices; empty when the fallback strategy is in use.
print("All devices: ", tf.config.list_logical_devices('TPU'))

with strategy.scope(): # variables must be created inside the distribution strategy defined above
    if os.path.exists(model_path):
        # Load the model if it exists - saves rebuilding and recompiling it.
        # NOTE: a checkpoint saved by an earlier version of this cell keeps
        # the architecture it was saved with; delete it to rebuild the model.
        model = keras.models.load_model(model_path)
        print("Model loaded from disk.")
    else:
        # Build the full EfficientNetV2M graph without weights so the local
        # ImageNet-21k weights can be loaded by layer name; layers whose
        # shapes differ are skipped.
        # include_preprocessing=False: the built-in preprocessing expects raw
        # [0, 255] pixels, but the input pipeline already scales images to
        # [0, 1], so rescaling is done explicitly further down instead.
        base_model = EfficientNetV2M(
            input_shape=(224, 224, 3),
            weights=None,
            include_top=True,
            include_preprocessing=False)

        base_model.load_weights(
            '/content/drive/MyDrive/efficientnetv2-m-21k.h5',
            by_name=True, skip_mismatch=True)

        # Drop the classification head: keep everything up to the final
        # feature activation.
        base_model = keras.Model(
            inputs=base_model.input,
            outputs=base_model.get_layer(name='top_activation').output)

        # Freeze everything except the last 4 layers, which are fine tuned.
        base_model.trainable = True
        for layer in base_model.layers[:-4]:
            layer.trainable = False

        inputs = keras.Input(shape=(224, 224, 3))
        # Map the [0, 1] pixels from the input pipeline to the [-1, 1] range
        # the backbone expects when its own preprocessing is disabled.
        x = layers.Rescaling(scale=2.0, offset=-1.0)(inputs)
        x = base_model(x)
        x = layers.GlobalAveragePooling2D()(x)
        # NOTE(review): this Dense layer has no activation, so it is a purely
        # linear projection ahead of the classifier - confirm that is intended.
        x = layers.Dense(256)(x)
        x = layers.Dropout(0.25)(x)

        outputs = layers.Dense(num_classes, activation='softmax')(x)
        model = keras.Model(inputs, outputs)

        # Sparse loss because the labels are integer class indices.
        model.compile(loss="sparse_categorical_crossentropy",
                      optimizer=keras.optimizers.Adam(learning_rate=1e-3), # https://arxiv.org/pdf/1412.6980 - may need to adjust LR
                      metrics=["accuracy"])

    model.summary()

# Stop once val_loss has gone 5 epochs without improving, and roll the model
# back to its best weights - avoids running the model unnecessarily.
early_stopping = keras.callbacks.EarlyStopping(
    restore_best_weights=True, patience=5, monitor='val_loss')

# Checkpoint first, early stopping second. The checkpoint only overwrites
# model_path when val_loss improves, which protects against training
# stopping unexpectedly.
callbacks = [
    keras.callbacks.ModelCheckpoint(
        filepath=model_path, monitor="val_loss", save_best_only=True),
    early_stopping,
]

# Train the model for at most 50 epochs, validating on val_ds after each epoch.
# NOTE(review): this fits on `train_ds`; the `augmented_train_ds` pipeline built
# in the Data Augmentation section is never passed to fit() - confirm whether
# training was meant to use the augmented dataset.
history = model.fit(
    train_ds,
    epochs=50,
    validation_data=val_ds,
    callbacks=callbacks
)

# Save the training history (only the plain `history.history` attribute is
# pickled, not the History object itself)
with open(history_path, 'wb') as f:
    pickle.dump(history.history, f)



# Data Analysis

In [None]:
import pickle
# Reload the pickled training history from Drive so the curves can be plotted
# without re-running training.
history_path = '/content/drive/MyDrive/mushroom_training_history.pkl'
with open(history_path, mode='rb') as history_file:
    history = pickle.load(history_file)

import matplotlib.pyplot as plt

# Draw one figure per metric (accuracy first, then loss), each with the
# training curve followed by the validation curve.
for train_key, val_key, figure_title, axis_label in (
    ('accuracy', 'val_accuracy', 'Model accuracy', 'Accuracy'),
    ('loss', 'val_loss', 'Model loss', 'Loss'),
):
    plt.plot(history[train_key])
    plt.plot(history[val_key])
    plt.title(figure_title)
    plt.ylabel(axis_label)
    plt.xlabel('Epoch')
    plt.legend(['Train', 'Validation'], loc='upper left')
    plt.show()