In [None]:
# mushroom_master_local_gcs.py
# -*- coding: utf-8 -*-

"""

Requires:
- google-cloud-storage
- tensorflow, tensorflow_datasets, keras-cv
- scikit-learn
- matplotlib, pillow
- gcloud authentication on your local machine
"""

import os
import glob
import shutil
import pickle
import numpy as np
from PIL import Image

import tensorflow as tf
import tensorflow_datasets as tfds
from tensorflow import keras
from keras import layers
from keras.applications import ConvNeXtSmall
from sklearn.model_selection import train_test_split

# For data augmentation
import keras_cv
from keras_cv.layers import RandomApply

###############################################################################
# 1. Load iNaturalist Data (mini) via TFDS (for pretraining).
###############################################################################

def preprocess_inat_data(image, label, img_size=(224, 224)):
    """Resize one iNaturalist example and divide its pixels by 255.

    Args:
        image: Decoded image tensor from TFDS (presumably uint8 HxWx3 -- TODO confirm).
        label: Class label; returned unchanged.
        img_size: Target (height, width) passed to tf.image.resize.

    Returns:
        Tuple of (float32 resized image / 255.0, label).
    """
    resized_float = tf.cast(tf.image.resize(image, img_size), tf.float32)
    return resized_float / 255.0, label

print("Loading iNaturalist mini dataset...")
ds_inat_train, ds_inat_val = tfds.load(
    "i_naturalist2021_mini",
    split=["train[:90%]", "train[90%:]"],
    as_supervised=True
)

ds_inat_train = ds_inat_train.map(preprocess_inat_data, num_parallel_calls=tf.data.AUTOTUNE)
ds_inat_val = ds_inat_val.map(preprocess_inat_data, num_parallel_calls=tf.data.AUTOTUNE)

ds_inat_train = ds_inat_train.batch(256).prefetch(tf.data.AUTOTUNE)
ds_inat_val = ds_inat_val.batch(256).prefetch(tf.data.AUTOTUNE)

# iNaturalist 2021 mini has 11083 classes
num_inat_classes = 11083

inat_base_model = ConvNeXtSmall(
    input_shape=(224, 224, 3),
    weights=None,
    include_top=False
)
inat_base_model.trainable = True

inputs_inat = layers.Input(shape=(224, 224, 3))
x_inat = inat_base_model(inputs_inat)
x_inat = layers.GlobalAveragePooling2D()(x_inat)
outputs_inat = layers.Dense(num_inat_classes, activation='softmax')(x_inat)
inat_model = tf.keras.Model(inputs_inat, outputs_inat)

inat_model.compile(
    loss="sparse_categorical_crossentropy",
    optimizer=tf.keras.optimizers.Adam(learning_rate=0.001),
    metrics=["accuracy", tf.keras.metrics.SparseTopKCategoricalAccuracy(k=5)]
)

print("Starting pretraining on iNaturalist (mini)...")
inat_model.fit(
    ds_inat_train,
    validation_data=ds_inat_val,
    epochs=10  # Increase if you have time/resources
)

inat_checkpoint_path = "./inat_pretrained_checkpoint.keras"
inat_model.save(inat_checkpoint_path)
print(f"Saved iNaturalist pretrained model to {inat_checkpoint_path}")







In [None]:
###############################################################################
# 2. Read and organize the Mushroom data from GCS
###############################################################################
# Gather image file paths from GCS, label each one by its class directory, and
# split them into stratified train/val/test sets.

# Adjust these to match your GCS environment
bucket_name = "mushroom-master-central"           # e.g. "my-bucket"
source_directory = "cleaned_dataset"              # e.g. "cleaned_dataset"
gcs_prefix = f"gs://{bucket_name}/{source_directory}"

# The local-disk cleanup of hidden/corrupted images does not carry over to GCS, so
# here we only apply cheap path-based checks to skip files that are obviously not images.

def is_image_file(filename, extensions=("jpg", "jpeg", "png")):
    """Return True if ``filename`` looks like a usable image file, judged by its path alone.

    Only the final path component is inspected, so dots in directory names can never be
    mistaken for an extension. Hidden files (basename starting with ".") are rejected:
    e.g. macOS "._photo.jpg" metadata files carry an image extension but are not images,
    and would make image decoding fail later.

    Args:
        filename: Full path or object URL, e.g. "gs://bucket/dir/class/photo.JPG".
        extensions: Accepted extensions, lowercase and without the leading dot.

    Returns:
        True when the basename has one of ``extensions`` (case-insensitive).
    """
    basename = filename.rsplit("/", 1)[-1]
    if basename.startswith("."):
        return False
    _, dot, ext = basename.rpartition(".")
    if not dot:
        return False  # no extension at all
    return ext.lower() in extensions

def list_gcs_files_recursive(prefix):
    """Walk ``prefix`` recursively and collect every file that passes is_image_file().

    tf.io.gfile.walk mirrors os.walk but also understands gs:// URLs, so this works for
    both GCS and local directories.

    Args:
        prefix: Root directory (local path or gs:// URL) to search.

    Returns:
        Sorted list of full paths of the image-looking files found.
    """
    candidate_paths = (
        os.path.join(dirpath, fname)
        for dirpath, _subdirs, fnames in tf.io.gfile.walk(prefix)
        for fname in fnames
    )
    return sorted(path for path in candidate_paths if is_image_file(path))

# Every image-looking file under gs://<bucket_name>/<source_directory>, in sorted order.
all_gcs_files = list_gcs_files_recursive(prefix=gcs_prefix)

# We'll parse class names by splitting out the directory structure:
# e.g. "gs://bucket/cleaned_dataset/<class_name>/image.jpg"
# so we want the piece after "cleaned_dataset/" until the next slash as the class_name.

def extract_class_name(file_path, root=None):
    """Return the class name (first directory below the dataset root) of an image path.

    Example: "gs://bucket/cleaned_dataset/<class_name>/image.jpg" -> "<class_name>".

    Args:
        file_path: Full path of an image under the dataset root.
        root: Name of the dataset directory that directly contains the class folders.
            Defaults to the module-level ``source_directory``.

    Returns:
        The path segment immediately after the first whole-segment match of ``root``
        (searching after the "scheme://bucket" part, if any).

    Raises:
        ValueError: if ``root`` is empty or not a path segment of ``file_path``, or if the
            file sits directly inside ``root`` with no class folder.
    """
    if root is None:
        root = source_directory
    root = root.strip("/")
    if not root:
        raise ValueError("root must be a non-empty directory name")

    # Match "/<root>/" as whole path segments: a bare substring match would also hit
    # class folders such as "foo_cleaned_dataset/".
    marker = f"/{root}/"
    # Skip "scheme://bucket" so a bucket that shares the root's name is not matched.
    scheme_end = file_path.find("://")
    search_from = file_path.find("/", scheme_end + 3) if scheme_end != -1 else 0
    start = file_path.find(marker, max(search_from, 0))
    if start == -1:
        raise ValueError(f"{file_path!r} is not under a '{root}/' directory")

    class_name, sep, _ = file_path[start + len(marker):].partition("/")
    if not sep or not class_name:
        raise ValueError(f"{file_path!r} is not inside a class subdirectory of '{root}/'")
    return class_name

# Fail early with an actionable message; otherwise an empty listing surfaces later as a
# cryptic error from train_test_split.
if not all_gcs_files:
    raise RuntimeError(
        f"No image files found under {gcs_prefix}; check bucket_name/source_directory "
        "and your gcloud credentials."
    )

# Map each class directory name to an integer label, assigned in order of first
# appearance in the sorted path list (deterministic for the same bucket contents).
class_names = {}
paths = []
labels = []
for fpath in all_gcs_files:
    # setdefault evaluates len(class_names) before inserting, so new classes get 0, 1, 2, ...
    class_idx = class_names.setdefault(extract_class_name(fpath), len(class_names))
    paths.append(fpath)
    labels.append(class_idx)

paths = np.array(paths)
labels = np.array(labels, dtype=int)
unique_class_names = sorted(class_names.keys(), key=lambda x: class_names[x])
num_classes = len(unique_class_names)

# Stratified split: hold out 10% for test, then 10% of the remainder for validation.
train_val_paths, test_paths, train_val_labels, test_labels = train_test_split(
    paths, labels, test_size=0.1, stratify=labels, shuffle=True, random_state=42)

train_paths, val_paths, train_labels, val_labels = train_test_split(
    train_val_paths, train_val_labels, test_size=0.1, stratify=train_val_labels,
    shuffle=True, random_state=42)

assert len(train_paths) + len(val_paths) + len(test_paths) == len(paths), "split lost samples"

print(f"Classes found: {num_classes}")
print(f"Train set size: {len(train_paths)}")
print(f"Val set size:   {len(val_paths)}")
print(f"Test set size:  {len(test_paths)}")

In [None]:
###############################################################################
# 3. Build tf.data.Dataset from GCS paths
###############################################################################

# Shared input-pipeline settings, read by parse_gcs_image and create_dataset.
image_size = (224, 224)  # (height, width) passed to tf.image.resize
batch_size = 256

def parse_gcs_image(file_path, label):
    """Load one image from ``file_path`` (GCS or local) and prepare it for the model.

    Steps: read the raw bytes -> decode with 3 channels (first frame only for animated
    formats, since expand_animations=False) -> resize to the module-level ``image_size``
    -> cast to float32 and divide by 255.

    Returns:
        Tuple of (float32 image tensor, label); the label is passed through unchanged.
    """
    encoded_bytes = tf.io.read_file(file_path)
    decoded = tf.image.decode_image(encoded_bytes, channels=3, expand_animations=False)
    pixels = tf.cast(tf.image.resize(decoded, image_size), tf.float32)
    return pixels / 255.0, label

def create_dataset(paths_arr, labels_arr, shuffle=True, repeat=False):
    """Build a batched, prefetched tf.data pipeline over (path, label) pairs.

    Stage order: optional full-size shuffle (seed 42) -> optional infinite repeat ->
    parallel parse_gcs_image -> batch by the module-level ``batch_size`` -> prefetch.

    Args:
        paths_arr: Image file paths.
        labels_arr: Integer labels aligned with ``paths_arr``.
        shuffle: Shuffle the pairs with a buffer as large as the whole set.
        repeat: Repeat the dataset indefinitely.

    Returns:
        A tf.data.Dataset yielding (image_batch, label_batch).
    """
    dataset = tf.data.Dataset.from_tensor_slices((paths_arr, labels_arr))
    # Shuffling cheap path strings before decoding lets the buffer span the entire set.
    dataset = dataset.shuffle(len(paths_arr), seed=42) if shuffle else dataset
    dataset = dataset.repeat() if repeat else dataset
    return (
        dataset
        .map(parse_gcs_image, num_parallel_calls=tf.data.AUTOTUNE)
        .batch(batch_size)
        .prefetch(tf.data.AUTOTUNE)
    )

# Only the training split is shuffled and repeated; val/test keep a fixed order and a single pass.
train_ds = create_dataset(train_paths, train_labels, repeat=True)
val_ds = create_dataset(val_paths, val_labels, shuffle=False, repeat=False)
test_ds = create_dataset(test_paths, test_labels, shuffle=False, repeat=False)



In [None]:
###############################################################################
# 4. Data Augmentation for the Mushroom Dataset
###############################################################################

# Augmentation strength. RandomRotation's factor is a fraction of a full turn
# (30/360 -> up to +/-30 degrees); RandomZoom's factors give up to +/-20% zoom in/out.
rotation_factor = 30 / 360.0
zoom_factor = 0.2

# Applied in list order by data_augmentation(). Images arrive scaled to [0, 1].
data_augmentation_layers = [
    layers.RandomFlip("horizontal"),
    layers.RandomRotation(factor=(-rotation_factor, rotation_factor)),
    layers.RandomZoom(height_factor=(-zoom_factor, zoom_factor), width_factor=(-zoom_factor, zoom_factor)),
    # value_range must match the pixel scale. The keras_cv default of (0, 255) would turn
    # the +/-0.3 factor into a +/-76 shift on [0, 1] images, saturating nearly every image
    # to pure black or white once clipped.
    keras_cv.layers.RandomBrightness(factor=(-0.3, 0.3), value_range=(0.0, 1.0)),
]

def data_augmentation(images, targets):
    """Pass a batch through every layer in data_augmentation_layers, in list order.

    Pixel values are clipped back into [0, 1] afterwards; targets pass through untouched.
    """
    augmented = images
    for augment_layer in data_augmentation_layers:
        augmented = augment_layer(augmented)
    return tf.clip_by_value(augmented, 0.0, 1.0), targets

# Augment already-batched training data on the fly and keep the pipeline prefetching.
augmented_train_ds = (
    train_ds
    .map(data_augmentation, num_parallel_calls=tf.data.AUTOTUNE)
    .prefetch(tf.data.AUTOTUNE)
)


In [None]:
###############################################################################
# 5. Fine-Tuning on Mushroom Images
###############################################################################

# Reload the full pretrained iNat classifier saved by the pretraining cell.
print(f"Loading iNaturalist pretrained model from {inat_checkpoint_path} ...")
inat_pretrained_model = keras.models.load_model(inat_checkpoint_path)

# The checkpoint is Input -> ConvNeXtSmall -> GlobalAveragePooling2D -> Dense(softmax).
# Reuse only the nested ConvNeXt backbone. Calling the whole checkpoint would return
# (batch, num_inat_classes) probabilities, which GlobalAveragePooling2D cannot consume.
# Also, its `.layers[:-4]` would freeze nothing, since it has at most four top-level layers.
inat_base_model = next(
    (layer for layer in inat_pretrained_model.layers if isinstance(layer, keras.Model)),
    None,
)
if inat_base_model is None:
    raise ValueError(
        f"No nested backbone model found in {inat_checkpoint_path}; "
        "expected the ConvNeXtSmall sub-model saved by the pretraining cell."
    )

# Freeze all but the top 4 layers of the backbone
for layer in inat_base_model.layers[:-4]:
    layer.trainable = False

# Build final model for mushroom classification: backbone features -> pool -> softmax
inputs_mushroom = layers.Input(shape=(224, 224, 3))
x_mushroom = inat_base_model(inputs_mushroom)
x_mushroom = layers.GlobalAveragePooling2D()(x_mushroom)
outputs_mushroom = layers.Dense(num_classes, activation='softmax')(x_mushroom)
model = tf.keras.Model(inputs_mushroom, outputs_mushroom)

# Adam with gradient-norm clipping at 1.0; sparse cross-entropy because labels are integer ids.
fine_tune_optimizer = keras.optimizers.Adam(learning_rate=0.001, clipnorm=1.0)
model.compile(
    optimizer=fine_tune_optimizer,
    loss="sparse_categorical_crossentropy",
    metrics=["accuracy", tf.keras.metrics.SparseTopKCategoricalAccuracy(k=5)],
)

model_path = "./mushroom_masterv1_ft.keras"
history_path = "./mushroom_training_history.pkl"

early_stopping = keras.callbacks.EarlyStopping(
    monitor='val_loss',
    patience=8,
    start_from_epoch=10,
    min_delta=0.005,
    restore_best_weights=True
)

reduce_lr_plateau = keras.callbacks.ReduceLROnPlateau(
    monitor='val_loss',
    factor=0.1,
    patience=4,
    min_lr=1e-6
)

steps_per_epoch = len(train_paths) // batch_size
validation_steps = len(val_paths) // batch_size if len(val_paths) > batch_size else None

callbacks = [
    keras.callbacks.ModelCheckpoint(
        filepath=model_path,
        save_best_only=True,
        monitor="val_loss"
    ),
    early_stopping,
    reduce_lr_plateau
]

print("Beginning fine-tuning on mushroom images...")
history_obj = model.fit(
    augmented_train_ds,
    epochs=150,
    steps_per_epoch=steps_per_epoch,
    validation_data=val_ds,
    validation_steps=validation_steps,
    callbacks=callbacks
)

# Save training history
with open(history_path, 'wb') as f:
    pickle.dump(history_obj.history, f)
print(f"Training history saved to: {history_path}")


In [None]:
###############################################################################
# 6. Evaluate on Test Set
###############################################################################

# evaluate() returns [loss, accuracy, top-5 accuracy] -- loss first, then metrics in compile() order.
test_metrics = model.evaluate(test_ds, verbose=2)
test_loss, test_acc, test_top_5_acc = test_metrics
print(f"Test Loss: {test_loss:.4f} | Test Accuracy: {test_acc:.4f} | Top-5 Accuracy: {test_top_5_acc:.4f}")





In [None]:
###############################################################################
# 7. Plot Accuracy/Loss
###############################################################################
import matplotlib.pyplot as plt

with open(history_path, 'rb') as history_file:
    history = pickle.load(history_file)


def _plot_history_panel(train_key, val_key, test_value, legend_labels, title, ylabel):
    """Draw one figure: train/val curves per epoch plus the test value as a flat dashed line.

    Args:
        train_key, val_key: Keys into the loaded ``history`` dict.
        test_value: Scalar test-set result, repeated across all epochs.
        legend_labels: (train, val, test) legend strings.
        title, ylabel: Figure title and y-axis label.
    """
    train_curve = history[train_key]
    train_label, val_label, test_label = legend_labels
    _, ax = plt.subplots()
    ax.plot(train_curve, label=train_label)
    ax.plot(history[val_key], label=val_label)
    ax.plot([test_value] * len(train_curve), linestyle='--', label=test_label)
    ax.set_title(title)
    ax.set_xlabel('Epoch')
    ax.set_ylabel(ylabel)
    ax.legend()
    plt.show()


_plot_history_panel('accuracy', 'val_accuracy', test_acc,
                    ('Train Acc', 'Val Acc', 'Test Acc'), 'Model Accuracy', 'Accuracy')
_plot_history_panel('loss', 'val_loss', test_loss,
                    ('Train Loss', 'Val Loss', 'Test Loss'), 'Model Loss', 'Loss')

In [None]:
###############################################################################
# 8. Convert to TFLite
###############################################################################

# Convert the fine-tuned Keras model to a TFLite flatbuffer and write it next to the notebook.
tflite_path = "./mushroom_master_mobile.tflite"
tflite_model = tf.lite.TFLiteConverter.from_keras_model(model).convert()
with open(tflite_path, 'wb') as tflite_file:
    tflite_file.write(tflite_model)

print("TFLite model saved to mushroom_master_mobile.tflite")
print("Done.")