In [27]:
import os, shutil
import random
import numpy as np
import tensorflow as tf
from tensorflow.keras import (models, layers, optimizers)
from tensorflow.keras.applications import VGG16
from tensorflow.keras.applications.vgg16 import preprocess_input

import plotly.express as ex
import plotly.graph_objects as go

In [28]:
# Root folder of the processed dataset, given relative to the notebook's
# working directory; later cells read the class folders under "<root>/train".
dataset_path = "../data/processed"

In [29]:
# Print the number of entries in each class folder of the training split
# (dogs first, then cats).
for class_name in ("dogs", "cats"):
    class_dir = dataset_path + "/train/" + class_name
    print(len(os.listdir(class_dir)))

1414
1422


In [30]:
# Shared hyper-parameters for the input pipeline and the network.
IMG_SIZE = 150   # images are resized to IMG_SIZE x IMG_SIZE pixels
BATCH_SIZE = 16  # number of images per batch

In [31]:
# Build the training and validation datasets from the class sub-folders of
# "<dataset_path>/train", holding out 20% of the files for validation.
#
# `batch_size` is derived from BATCH_SIZE instead of repeating the literal 16:
# other cells use both names, so two independent values could drift apart and
# silently mis-size the epochs. The lowercase name is kept because later cells
# read it.
batch_size = BATCH_SIZE
x_train, x_valid = tf.keras.utils.image_dataset_from_directory(
    dataset_path + "/train",
    batch_size=batch_size,
    image_size=(IMG_SIZE, IMG_SIZE),
    validation_split=0.2,
    subset="both",  # return the (training, validation) pair from a single call
    seed=42,        # fixed seed so the split is the same on every run
)

Found 2836 files belonging to 2 classes.
Using 2269 files for training.
Using 567 files for validation.


In [32]:
# len() of a batched dataset is its number of batches, so each figure below is
# batches * batch_size. Because the last batch may be partial, this is an upper
# bound on the number of images (2272 here vs. 2269 training files reported by
# the loader), not an exact count.
x_train_count, x_valid_count = (
    len(split) * batch_size for split in (x_train, x_valid)
)
print(x_train_count, x_valid_count, sep="\n")

2272
576


In [33]:
# Preprocessing + augmentation pipeline, used as the first stage of the model.
# Reference: https://www.tensorflow.org/guide/keras/preprocessing_layers
#
# NOTE(review): the Keras docs for VGG16 recommend vgg16.preprocess_input
# rather than a plain 1/255 rescale for the ImageNet weights - confirm the
# rescale is intentional.
# NOTE(review): per the Keras docs RandomRotation's factor is a fraction of a
# full turn, so 0.4 would allow rotations of up to +/-144 degrees - confirm.
_augmentation_steps = [
    layers.Resizing(IMG_SIZE, IMG_SIZE),
    layers.Rescaling(1.0 / 255),
    layers.RandomRotation(0.4),
    layers.RandomWidth(0.2),
    layers.RandomHeight(0.2),
    layers.RandomZoom(0.2),
    layers.RandomFlip("horizontal"),
    # RandomWidth/RandomHeight change the spatial size, so resize back to the
    # fixed IMG_SIZE x IMG_SIZE before handing the batch to the next stage.
    layers.Resizing(IMG_SIZE, IMG_SIZE),
]
data_augmentation = tf.keras.Sequential(_augmentation_steps)

In [34]:
# VGG16 convolutional base with ImageNet weights and without its classifier
# head (include_top=False). Keras caches the downloaded weights under
# ~/.keras/models.
base_model = VGG16(
    weights='imagenet',
    include_top=False,
    # Derived from IMG_SIZE (previously a hard-coded 150) so the base always
    # matches the size produced by the dataset loader and augmentation stage.
    input_shape=(IMG_SIZE, IMG_SIZE, 3)
)
# Freeze the pretrained weights; only layers added on top of the base train.
base_model.trainable = False

In [35]:
# Full classifier: augmentation pipeline -> frozen VGG16 base -> dense head
# ending in a single sigmoid unit for the two-class (cats/dogs) problem.
model = models.Sequential([
    data_augmentation,
    base_model,
    layers.Flatten(),
    # No input_dim here: the input width is inferred from Flatten's output.
    # The old hard-coded 4 * 4 * 512 was redundant on a non-first layer and
    # only valid for 150x150 images.
    layers.Dense(256, activation='relu'),
    layers.Dense(1, activation='sigmoid'),
])
model.build(input_shape=(None, IMG_SIZE, IMG_SIZE, 3))
model.compile(
    optimizer=optimizers.RMSprop(learning_rate=2e-5),
    loss='binary_crossentropy',
    # Keep the short name 'acc': the plotting cells read the 'acc' and
    # 'val_acc' keys from history.history.
    metrics=['acc']
)
model.summary()

Model: "sequential_7"
_________________________________________________________________
 Layer (type)                Output Shape              Param #   
 sequential_6 (Sequential)   (None, 150, 150, 3)       0         
                                                                 
 vgg16 (Functional)          (None, 4, 4, 512)         14714688  
                                                                 
 flatten_2 (Flatten)         (None, 8192)              0         
                                                                 
 dense_8 (Dense)             (None, 256)               2097408   
                                                                 
 dense_9 (Dense)             (None, 1)                 257       
                                                                 
Total params: 16,812,353
Trainable params: 2,097,665
Non-trainable params: 14,714,688
_________________________________________________________________


In [36]:
# Train for 30 epochs, validating on the held-out split after each epoch.
#
# The datasets are finite and already batched, so Keras runs each epoch until
# the dataset is exhausted. That is the same number of batches the previous
# repeat() + steps_per_epoch/validation_steps arithmetic produced, but without
# relying on `batch_size` and `BATCH_SIZE` holding the same value.
history = model.fit(
    x_train,
    epochs=30,
    validation_data=x_valid,
)

Epoch 1/30
  8/142 [>.............................] - ETA: 1:45 - loss: 0.6830 - acc: 0.5625

KeyboardInterrupt: 

In [None]:
# Write the model to "model_t.h5" in the current working directory.
model.save("./model_t.h5")

In [None]:
# To reuse the saved model instead of retraining, uncomment the next line:
# model = models.load_model("./model_t.h5")

In [None]:
# Pull the per-epoch metrics recorded by model.fit; the loss series and the
# epoch axis are also reused by the loss plot in the next cell.
hist = history.history
acc, vall_acc = hist['acc'], hist['val_acc']
loss, val_loss = hist['loss'], hist['val_loss']
epochs = list(range(1, len(acc) + 1))  # 1-based epoch numbers for the x-axis

# Training vs. validation accuracy over the epochs.
fig = go.Figure(data=[
    go.Scatter(x=epochs, y=acc, name="Training Accuracy"),
    go.Scatter(x=epochs, y=vall_acc, name="Validation Accuracy"),
])
fig.update_layout(
    title="Training and Validation Accuracy",
    xaxis_title="Epochs",
    yaxis_title="Accuracy",
)
fig.show()

In [None]:
# Training vs. validation loss over the epochs, using the series extracted in
# the previous cell.
fig = go.Figure()
for series, label in ((loss, "Training Loss"), (val_loss, "Validation Loss")):
    fig.add_trace(go.Scatter(x=epochs, y=series, name=label))
fig.update_layout(
    title="Training and Validation Loss",
    xaxis_title="Epochs",
    yaxis_title="Loss",
)
fig.show()