In [None]:
import os
import shutil
import random

def split_dataset_train_test(source_dir, dest_dir, train_ratio=0.9, seed=42):
    """Copy a class-per-folder dataset into ``dest_dir/{train,test}/<class>/``.

    Each class sub-folder of ``source_dir`` is split independently: its files are
    shuffled with a seeded RNG, the first ``int(n * train_ratio)`` are copied to
    ``train`` and the remainder to ``test``. Files already present at the
    destination are not copied again.

    Args:
        source_dir: Directory whose (non-hidden) sub-directories are the classes.
        dest_dir: Output root; ``train/<class>`` and ``test/<class>`` are created.
        train_ratio: Fraction of each class copied to ``train`` (must be in [0, 1]).
        seed: Seed for the shuffle; the same seed and the same file names give the
            same split on any machine.

    Raises:
        ValueError: If ``train_ratio`` is outside [0, 1].
    """
    if not 0.0 <= train_ratio <= 1.0:
        raise ValueError(f'train_ratio must be in [0, 1], got {train_ratio}')

    # Private RNG: reproducible for `seed` without reseeding the global `random` module.
    rng = random.Random(seed)

    # os.listdir order is arbitrary, so sort before shuffling; otherwise the same seed
    # can yield different splits across filesystems. Class order matters too, because
    # all classes draw from the same RNG stream. Hidden entries (e.g. .ipynb_checkpoints,
    # .DS_Store) are neither classes nor images.
    classes = sorted(
        d for d in os.listdir(source_dir)
        if not d.startswith('.') and os.path.isdir(os.path.join(source_dir, d))
    )
    for split in ('train', 'test'):
        for cls in classes:
            os.makedirs(os.path.join(dest_dir, split, cls), exist_ok=True)

    for cls in classes:
        cls_dir = os.path.join(source_dir, cls)
        images = sorted(
            f for f in os.listdir(cls_dir)
            if not f.startswith('.') and os.path.isfile(os.path.join(cls_dir, f))
        )
        rng.shuffle(images)
        n_train = int(len(images) * train_ratio)
        for split, files in (('train', images[:n_train]), ('test', images[n_train:])):
            for f in files:
                src = os.path.join(cls_dir, f)
                dst = os.path.join(dest_dir, split, cls, f)
                # Skip existing files so a re-run does not redo completed copies.
                if not os.path.exists(dst):
                    shutil.copy2(src, dst)

# Split the raw dataset once. When '<dest_dir>/train' already exists the copy is
# skipped entirely, so an existing split is reused rather than re-drawn on top of it.
# NOTE(review): an interrupted first run leaves a partial split that this guard will
# never complete — delete dest_dir and re-run in that case.
source_dir, dest_dir = '../dataset', '../dataset_split'
if not os.path.exists(os.path.join(dest_dir, 'train')):
    split_dataset_train_test(source_dir=source_dir, dest_dir=dest_dir)

In [None]:
from tensorflow.keras.applications import MobileNetV2, EfficientNetB5
from tensorflow.keras.models import Model
from tensorflow.keras.layers import Dense, Dropout, GlobalAveragePooling2D
from tensorflow.keras.optimizers import Adam
from tensorflow.keras.callbacks import EarlyStopping
from tensorflow.keras.preprocessing.image import ImageDataGenerator
import matplotlib.pyplot as plt
import numpy as np
from tensorflow.keras import layers, models, Input


def custom_cnn(input_shape, num_classes):
    """Build and compile a small from-scratch CNN baseline.

    Architecture: two conv stages, each Conv2D(3x3, 'same', ReLU) -> BatchNorm ->
    MaxPool(2x2), with 64 and then 128 filters; then Flatten -> Dense(256, ReLU) ->
    Dropout(0.5) -> Dense(num_classes, softmax).

    Args:
        input_shape: Shape of one input image, e.g. (height, width, channels).
        num_classes: Number of softmax output units.

    Returns:
        A Keras ``Model`` named "CustomCNN", compiled with Adam, categorical
        cross-entropy and accuracy.
    """
    image_in = Input(shape=input_shape)

    features = image_in
    for n_filters in (64, 128):
        features = layers.Conv2D(n_filters, (3, 3), padding='same', activation='relu')(features)
        features = layers.BatchNormalization()(features)
        features = layers.MaxPooling2D((2, 2))(features)

    # NOTE(review): with a 128x128 input the two pools leave 32x32x128 features, so the
    # Dense(256) below holds ~33.5M weights; a GlobalAveragePooling2D would be far smaller.
    flat = layers.Flatten()(features)
    head = layers.Dense(256, activation='relu')(flat)
    head = layers.Dropout(0.5)(head)
    class_probs = layers.Dense(num_classes, activation='softmax')(head)

    cnn = models.Model(image_in, class_probs, name="CustomCNN")
    cnn.compile(optimizer=Adam(), loss='categorical_crossentropy', metrics=['accuracy'])
    return cnn

def mobilenet_model(input_shape, num_classes, freeze_base=True, backbone=None):
    """Build and compile an ImageNet-pretrained transfer-learning classifier.

    The model takes pixels already scaled to [0, 1] (what the ``rescale=1./255``
    generators in this notebook produce). A ``Rescaling`` layer maps them into the
    range the chosen backbone was trained on, so the saved model carries its own
    preprocessing.

    Args:
        input_shape: (height, width, channels) of the input images.
        num_classes: Number of softmax output units.
        freeze_base: If True, every layer of the pretrained backbone is non-trainable,
            so only the new classification head is trained.
        backbone: Keras application constructor. ``None`` (default) selects
            ``MobileNetV2``, matching this function's name. ``EfficientNetB5`` is
            also supported.

    Returns:
        A Keras ``Model`` compiled with Adam, categorical cross-entropy and accuracy.

    Raises:
        ValueError: If ``backbone`` has no known input-range mapping.
    """
    if backbone is None:
        backbone = MobileNetV2

    # (scale, offset) taking [0, 1] pixels into each backbone's expected input range:
    # MobileNetV2 expects [-1, 1]; Keras EfficientNet rescales internally and expects [0, 255].
    input_ranges = {
        MobileNetV2: (2.0, -1.0),
        EfficientNetB5: (255.0, 0.0),
    }
    if backbone not in input_ranges:
        raise ValueError(f'Unsupported backbone {backbone!r}; expected MobileNetV2 or EfficientNetB5.')
    scale, offset = input_ranges[backbone]

    base_model = backbone(weights='imagenet', include_top=False, input_shape=input_shape)
    for layer in base_model.layers:
        layer.trainable = not freeze_base

    inputs = Input(shape=input_shape)
    x = layers.Rescaling(scale, offset=offset)(inputs)
    x = base_model(x)
    x = GlobalAveragePooling2D()(x)
    x = Dense(256, activation='relu')(x)
    x = Dropout(0.4)(x)
    predictions = Dense(num_classes, activation='softmax')(x)
    model = Model(inputs=inputs, outputs=predictions)
    model.compile(optimizer=Adam(), loss='categorical_crossentropy', metrics=['accuracy'])
    return model

def train_model(model, train_gen, val_gen, epochs=30):
    """Fit ``model`` on ``train_gen`` with early stopping on validation loss.

    Training halts after 5 consecutive epochs without a ``val_loss`` improvement,
    and the weights from the best epoch are restored before returning.

    Args:
        model: A compiled Keras model.
        train_gen: Training data iterator passed to ``model.fit``.
        val_gen: Validation data iterator monitored by early stopping.
        epochs: Maximum number of epochs.

    Returns:
        The object returned by ``model.fit``.
    """
    print('Training model...')
    stopper = EarlyStopping(
        monitor='val_loss',
        patience=5,
        restore_best_weights=True,
        verbose=1,
    )
    return model.fit(
        train_gen,
        epochs=epochs,
        validation_data=val_gen,
        callbacks=[stopper],
    )

def evaluate_model(model, history, test_gen, model_name):
    """Report test metrics, save the model, and plot the training curves.

    Args:
        model: Trained Keras model; ``evaluate`` must return exactly (loss, accuracy).
        history: Result of ``model.fit``; its ``.history`` dict must contain
            'accuracy', 'val_accuracy', 'loss' and 'val_loss'.
        test_gen: Test data iterator.
        model_name: File path passed to ``model.save``.
    """
    print('Evaluating on test set...')
    test_loss, test_acc = model.evaluate(test_gen, verbose=0)
    print(f"Test accuracy: {test_acc:.4f}, Test loss: {test_loss:.4f}")
    print(f'Saving model to {model_name}...')
    model.save(model_name)

    curves = history.history
    # (train key, val key, train label, val label, y-axis label / panel title)
    panels = (
        ('accuracy', 'val_accuracy', 'Train Acc', 'Val Acc', 'Accuracy'),
        ('loss', 'val_loss', 'Train Loss', 'Val Loss', 'Loss'),
    )
    fig, axes = plt.subplots(1, 2, figsize=(12, 5))
    for ax, (train_key, val_key, train_label, val_label, title) in zip(axes, panels):
        ax.plot(curves[train_key], label=train_label)
        ax.plot(curves[val_key], label=val_label)
        ax.set_xlabel('Epoch')
        ax.set_ylabel(title)
        ax.legend()
        ax.set_title(title)
    fig.tight_layout()
    plt.show()

def _build_generators(train_dir, test_dir, img_size, batch_size):
    """Create the (train, validation, test) directory iterators, pixels rescaled to [0, 1].

    Validation is a 20% ``validation_split`` subset of ``train_dir``; only the training
    iterator is shuffled (seeded), so validation/test order is stable.
    """
    train_val_datagen = ImageDataGenerator(rescale=1./255, validation_split=0.2)
    test_datagen = ImageDataGenerator(rescale=1./255)

    train_gen = train_val_datagen.flow_from_directory(
        train_dir,
        target_size=img_size,
        batch_size=batch_size,
        class_mode='categorical',
        subset='training',
        shuffle=True,
        seed=42
    )
    val_gen = train_val_datagen.flow_from_directory(
        train_dir,
        target_size=img_size,
        batch_size=batch_size,
        class_mode='categorical',
        subset='validation',
        shuffle=False
    )
    test_gen = test_datagen.flow_from_directory(
        test_dir,
        target_size=img_size,
        batch_size=batch_size,
        class_mode='categorical',
        shuffle=False
    )
    return train_gen, val_gen, test_gen


def main(base_dir='../dataset_split', epochs=8, fine_tune=False, fine_tune_epochs=10):
    """Train the transfer-learning classifier and optionally fine-tune it.

    Args:
        base_dir: Split root containing ``train/<class>`` and ``test/<class>`` folders.
        epochs: Maximum epochs for the frozen-backbone stage (early stopping may end it sooner).
        fine_tune: If True, afterwards unfreeze every layer, recompile with a low learning
            rate, train again, then evaluate and save a second model.
        fine_tune_epochs: Maximum epochs for the fine-tuning stage.
    """
    train_dir = f'{base_dir}/train'
    test_dir = f'{base_dir}/test'

    img_size = (128, 128)
    batch_size = 32

    print('Loading data for MobileNetV2...')
    train_gen, val_gen, test_gen = _build_generators(train_dir, test_dir, img_size, batch_size)
    num_classes = train_gen.num_classes

    # Stage 1: backbone frozen, so only the new classification head is trained.
    model = mobilenet_model(img_size + (3,), num_classes, freeze_base=True)
    history = train_model(model, train_gen, val_gen, epochs=epochs)
    evaluate_model(model, history, test_gen, model_name='mushroom_mobilenet_frozen.h5')

    if not fine_tune:
        return

    # Stage 2 (opt-in): unfreeze everything. Keras requires a recompile for the trainable
    # change to take effect; the small learning rate limits damage to pretrained weights.
    print('Unfreezing base model for fine-tuning...')
    for layer in model.layers:
        layer.trainable = True
    model.compile(optimizer=Adam(1e-5), loss='categorical_crossentropy', metrics=['accuracy'])
    history_finetuned = train_model(model, train_gen, val_gen, epochs=fine_tune_epochs)
    evaluate_model(model, history_finetuned, test_gen, model_name='mushroom_mobilenet_finetuned.h5')

if __name__ == '__main__':
    main()

Loading data for MobileNetV2...
Found 8056 images belonging to 106 classes.
Found 1908 images belonging to 106 classes.
Found 1166 images belonging to 106 classes.
Training model...
Epoch 1/8
[1m252/252[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m44s[0m 166ms/step - accuracy: 0.0459 - loss: 4.5747 - val_accuracy: 0.2385 - val_loss: 3.1600
Epoch 2/8
[1m252/252[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m42s[0m 167ms/step - accuracy: 0.2614 - loss: 3.0074 - val_accuracy: 0.3501 - val_loss: 2.6240
Epoch 3/8
[1m252/252[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m42s[0m 167ms/step - accuracy: 0.3770 - loss: 2.3832 - val_accuracy: 0.3705 - val_loss: 2.4685
Epoch 4/8
[1m252/252[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m43s[0m 172ms/step - accuracy: 0.4613 - loss: 2.0002 - val_accuracy: 0.4025 - val_loss: 2.3606
Epoch 5/8
[1m252/252[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m42s[0m 166ms/step - accuracy: 0.5278 - loss: 1.7343 - val_accuracy: 0.4109 - val_loss: 2.358

KeyboardInterrupt: 