In [1]:
import numpy as np
import tensorflow as tf
import tensorflow_hub as hub
import pathlib
import datetime

# Model configuration
# Target (height, width) passed to the dataset loader as `image_size`; also
# forms the spatial part of the model's input shape.
IMAGE_SHAPE = (224, 224)
# Images per batch, and number of full passes over the training split.
batch_size, num_epochs = 32, 2

# Load and prepare dataset
# Root folder expected to contain one sub-directory of images per class.
data_root = pathlib.Path('./dataset')
class_dirs = ['leaf', 'non_leaf']
data_paths = [data_root / class_name for class_name in class_dirs]

# Fail fast with a real exception instead of `assert`: assertions are removed
# when Python runs with -O, and `exists()` would also accept a regular file
# where a directory of images is required.
for path in data_paths:
    if not path.is_dir():
        raise FileNotFoundError(f"Could not find directory at {path}")

# Create datasets
# Both splits must be built with identical arguments (in particular the same
# seed and validation_split); otherwise the training and validation subsets
# are not guaranteed to be disjoint. Keeping the shared arguments in a single
# dict makes it impossible for the two calls to drift apart.
_split_options = dict(
    validation_split=0.2,
    seed=123,
    image_size=IMAGE_SHAPE,
    batch_size=batch_size,
    # Pin the classes to the directories checked earlier rather than letting
    # Keras infer them from whatever sub-directories happen to exist under
    # data_root. This also fixes the label order to the order of class_dirs.
    class_names=class_dirs,
)

train_ds = tf.keras.utils.image_dataset_from_directory(
    str(data_root), subset="training", **_split_options
)

val_ds = tf.keras.utils.image_dataset_from_directory(
    str(data_root), subset="validation", **_split_options
)

# Class names as an array; its length sizes the classifier's output layer.
class_names = np.array(train_ds.class_names)

# Normalize and optimize data pipeline
# Multiplies every pixel value by 1/255 (presumably mapping raw 0-255
# intensities to the 0-1 range — confirm against the loader's output).
normalization_layer = tf.keras.layers.Rescaling(1./255)


def _scale_cache_prefetch(dataset):
    """Rescale the images of `dataset`, then cache and prefetch its elements.

    Each element is an (images, labels) pair; only the images go through
    `normalization_layer`, labels are passed along unchanged.

    NOTE(review): `cache()` is called without a filename, so elements are
    presumably held in memory after the first pass. With ~20k training images
    at 224x224x3 that could be on the order of 12 GB as float32 — confirm it
    fits in RAM.
    """
    scaled = dataset.map(lambda images, labels: (normalization_layer(images), labels))
    cached = scaled.cache()
    return cached.prefetch(buffer_size=tf.data.AUTOTUNE)


train_ds = _scale_cache_prefetch(train_ds)
val_ds = _scale_cache_prefetch(val_ds)

# Create model
# Feature-vector module loaded from TF Hub and kept frozen (trainable=False),
# so only the Dense layer added on top is trained.
feature_extractor_url = "https://tfhub.dev/google/tf2-preview/mobilenet_v2/feature_vector/4"
feature_extractor_layer = hub.KerasLayer(
    feature_extractor_url,
    trainable=False,
    input_shape=IMAGE_SHAPE + (3,),  # IMAGE_SHAPE plus a trailing dimension of 3
)

# Stack the layers one by one: frozen extractor first, classifier head second.
model = tf.keras.Sequential()
model.add(feature_extractor_layer)
# One output unit per class and no activation, i.e. the model emits raw
# logits (the loss is configured with from_logits=True accordingly).
model.add(tf.keras.layers.Dense(len(class_names)))

# Compile and train
# Sparse categorical cross-entropy is computed from raw logits
# (from_logits=True), matching the activation-free Dense output layer.
_optimizer = tf.keras.optimizers.Adam()
_loss = tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True)
model.compile(optimizer=_optimizer, loss=_loss, metrics=['acc'])

# Each run logs to its own timestamped TensorBoard directory.
_run_stamp = datetime.datetime.now().strftime("%Y%m%d-%H%M%S")
log_dir = "logs/fit/" + _run_stamp
tensorboard_callback = tf.keras.callbacks.TensorBoard(histogram_freq=1, log_dir=log_dir)

# `history` keeps the object returned by fit() for later inspection.
history = model.fit(
    train_ds,
    epochs=num_epochs,
    validation_data=val_ds,
    callbacks=[tensorboard_callback],
)

# Save model
# The path is relative, so the file lands one level above the current working
# directory. NOTE(review): reloading a model that contains a hub.KerasLayer
# presumably needs `custom_objects={'KerasLayer': hub.KerasLayer}` — verify.
_model_path = '../leaf_nonleaf.h5'
model.save(_model_path)

Found 25518 files belonging to 2 classes.
Using 20415 files for training.
Found 25518 files belonging to 2 classes.
Using 5103 files for validation.
Epoch 1/2
Epoch 2/2
