In [9]:
import tensorflow as tf
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
from tensorflow.keras.layers import GlobalAveragePooling2D, Dense
from tensorflow.keras.models import Model
from tensorflow.keras.preprocessing import image_dataset_from_directory
from tensorflow.keras.applications import VGG16

In [5]:
# Dataset roots. Both are read with image_dataset_from_directory, which infers
# each image's label from the name of the sub-directory it lives in.
train_path, test_path = "dataset/train", "dataset/test"

# Number of target classes the classifier head must predict.
num_classes = 4

In [25]:
# ImageNet-pretrained VGG16 convolutional base, without its fully connected
# classifier. The 224x224 RGB input shape matches the image_size used when the
# datasets are loaded.
base_model = VGG16(
    include_top=False,
    weights='imagenet',
    input_shape=(224, 224, 3),
)

# Freeze every pretrained layer so that training only updates the new head.
for frozen_layer in base_model.layers:
    frozen_layer.trainable = False

In [43]:
# Classification head stacked on the frozen VGG16 feature maps:
# global average pooling -> 1024-unit ReLU layer -> softmax over the classes.
x = base_model.output
x = GlobalAveragePooling2D()(x)
x = Dense(1024, activation='relu')(x)
# Size the output layer from the configured num_classes rather than a literal 4,
# so the head cannot silently disagree with the number of classes in the data.
predictions = Dense(num_classes, activation='softmax')(x)

model = Model(inputs=base_model.input, outputs=predictions)

In [44]:
# image_dataset_from_directory is called without label_mode, so labels are
# integer class indices (its default). Sparse categorical cross-entropy consumes
# those directly. The head already ends in softmax, so the loss receives
# probabilities rather than logits.
model.compile(
    loss='sparse_categorical_crossentropy',
    optimizer='adam',
    metrics=['accuracy'],
)

In [45]:
# Shared loader settings. image_size must match the VGG16 input_shape. The
# training/validation pair must use the same validation_split and seed so the
# two subsets are disjoint and together cover the whole training directory.
common_loader_kwargs = dict(image_size=(224, 224), batch_size=32)
split_kwargs = dict(validation_split=0.2, seed=4242)

train_ds = image_dataset_from_directory(
    train_path, subset="training", **split_kwargs, **common_loader_kwargs
)
valid_ds = image_dataset_from_directory(
    train_path, subset="validation", **split_kwargs, **common_loader_kwargs
)
# The test split is only evaluated, so shuffling it buys nothing. A fixed order
# keeps per-sample inspection of predictions reproducible.
test_ds = image_dataset_from_directory(
    test_path, shuffle=False, **common_loader_kwargs
)

# Labels are inferred from sub-directory names. Grab the class names now,
# because the mapped datasets created below no longer expose `class_names`.
class_names = train_ds.class_names
assert len(class_names) == num_classes, (
    f"expected {num_classes} classes, found {class_names}"
)
assert test_ds.class_names == class_names, (
    f"test classes {test_ds.class_names} differ from train classes {class_names}"
)


def vgg16_preprocess(images, labels):
    """Apply VGG16's ImageNet input preprocessing to one batch.

    image_dataset_from_directory yields float pixels in [0, 255]. The pretrained
    VGG16 weights expect inputs passed through vgg16.preprocess_input, which
    converts RGB to BGR and subtracts the ImageNet per-channel means.
    Labels are returned unchanged.
    """
    return tf.keras.applications.vgg16.preprocess_input(images), labels


# Apply identical preprocessing to every split, and overlap input preparation
# with training via prefetch.
train_ds = train_ds.map(vgg16_preprocess, num_parallel_calls=tf.data.AUTOTUNE).prefetch(tf.data.AUTOTUNE)
valid_ds = valid_ds.map(vgg16_preprocess, num_parallel_calls=tf.data.AUTOTUNE).prefetch(tf.data.AUTOTUNE)
test_ds = test_ds.map(vgg16_preprocess, num_parallel_calls=tf.data.AUTOTUNE).prefetch(tf.data.AUTOTUNE)

Found 2249 files belonging to 4 classes.
Using 1800 files for training.
Found 2249 files belonging to 4 classes.
Using 449 files for validation.
Found 688 files belonging to 4 classes.


In [None]:
# Train the head (the VGG16 base layers are frozen) for a fixed 50 epochs,
# evaluating on the validation subset after each epoch. The returned History
# object keeps the per-epoch metrics used by the cells below.
history = model.fit(
    x=train_ds,
    epochs=50,
    validation_data=valid_ds,
)

In [None]:
# Persist the trained model. The .h5 extension selects Keras' HDF5 format.
# The model contains no input-preprocessing layer, so inference code must
# prepare images exactly as the training batches were prepared.
model_path = "marble.h5"
model.save(model_path)

In [None]:
# history.history maps each metric name to its list of per-epoch values, so
# the resulting frame has one column per metric and one row per epoch.
epoch_metrics = history.history
history_df = pd.DataFrame.from_dict(epoch_metrics)
history_df.head()

In [None]:
# Training vs. validation accuracy per epoch. Plot against 1-based epoch
# numbers so the first epoch appears at x=1, not at list index 0.
epoch_numbers = range(1, len(history.history["accuracy"]) + 1)

fig, ax = plt.subplots()
ax.plot(epoch_numbers, history.history["accuracy"], label="train")
ax.plot(epoch_numbers, history.history["val_accuracy"], label="val")
ax.set(title="Accuracy per epoch", xlabel="Epochs", ylabel="Accuracy")
ax.legend()
plt.show()

In [None]:
# Training vs. validation loss per epoch. Plot against 1-based epoch numbers
# so the first epoch appears at x=1, not at list index 0.
epoch_numbers = range(1, len(history.history["loss"]) + 1)

fig, ax = plt.subplots()
ax.plot(epoch_numbers, history.history["loss"], label="train")
ax.plot(epoch_numbers, history.history["val_loss"], label="val")
ax.set(title="Loss per epoch", xlabel="Epochs", ylabel="Loss")
ax.legend()
plt.show()

In [None]:
# Final evaluation on the held-out test split loaded from test_path. It is not
# used by model.fit for training or validation, so this is the unbiased estimate.
test_loss, test_acc = model.evaluate(test_ds)
print('Test accuracy:', test_acc)