In [9]:
IMG_SIZE = 512
BATCH_SIZE = 16
EPOCHS = 5
VAL_SPLIT = 0.2

# Data Loading

In [10]:
import tensorflow as tf

dataset_path = "../../data-collection/image-backend/saved_images"

train_dataset = tf.keras.preprocessing.image_dataset_from_directory(
    dataset_path,
    labels='inferred',
    label_mode='int',
    class_names=['hazard', 'non-hazard'],
    color_mode='rgb',
    batch_size=BATCH_SIZE,
    image_size=(IMG_SIZE, IMG_SIZE),
    shuffle=True,
    seed=42,
    validation_split=VAL_SPLIT,
    subset="training",
)

Found 160 files belonging to 2 classes.
Using 128 files for training.


In [11]:
validation_dataset = tf.keras.preprocessing.image_dataset_from_directory(
    dataset_path,
    labels='inferred',
    label_mode='int',
    class_names=['hazard', 'non-hazard'],
    color_mode='rgb',
    batch_size=BATCH_SIZE,
    image_size=IMG_SIZE,
    shuffle=True,
    seed=42,
    validation_split=VAL_SPLIT,
    subset="validation",
)

Found 160 files belonging to 2 classes.
Using 32 files for validation.


In [12]:
train_dataset.class_names

['hazard', 'non-hazard']

# Model Fine-tuning

In [13]:
def build_model():
    base_model = tf.keras.applications.Xception(
        weights="imagenet",
        input_shape=(IMG_SIZE, IMG_SIZE, 3),
        include_top=False,
    ) 

    base_model.trainable = False

    inputs = tf.keras.Input(shape=(IMG_SIZE, IMG_SIZE, 3))

    x = tf.keras.applications.xception.preprocess_input(inputs)

    x = base_model(x, training=False)
    x = tf.keras.layers.GlobalAveragePooling2D()(x)
    x = tf.keras.layers.Dropout(0.2)(x)
    outputs = tf.keras.layers.Dense(1)(x)

    model = tf.keras.Model(inputs, outputs, name="Xception")

    model.compile(
        optimizer=tf.keras.optimizers.Adam(learning_rate=1e-2),
        loss='binary_crossentropy',
        metrics=['accuracy'],
    )  

    return model

In [14]:
model = build_model()
model.summary()

In [15]:
model.fit(train_dataset, epochs=EPOCHS, verbose=2, validation_data=validation_dataset)

Epoch 1/5
8/8 - 124s - 15s/step - accuracy: 0.5156 - loss: 7.8114 - val_accuracy: 0.4375 - val_loss: 9.0664
Epoch 2/5
8/8 - 130s - 16s/step - accuracy: 0.5156 - loss: 7.8072 - val_accuracy: 0.4375 - val_loss: 9.0664
Epoch 3/5
8/8 - 137s - 17s/step - accuracy: 0.5156 - loss: 7.8072 - val_accuracy: 0.4375 - val_loss: 9.0664
Epoch 4/5
8/8 - 125s - 16s/step - accuracy: 0.5156 - loss: 7.8072 - val_accuracy: 0.4375 - val_loss: 9.0664
Epoch 5/5
8/8 - 123s - 15s/step - accuracy: 0.5156 - loss: 7.8072 - val_accuracy: 0.4375 - val_loss: 9.0664


<keras.src.callbacks.history.History at 0x1574bfa10>

In [16]:
model.save("../models/Xception.keras")