# 1. Setup and Load Data

In [None]:
import tensorflow as tf
import os
import cv2
import imghdr
import numpy as np
import matplotlib.pyplot as plt
from tensorflow.keras import Sequential
from tensorflow.keras.layers import Conv2D, MaxPooling2D, Dense, Flatten, Input
from tensorflow.keras.metrics import Precision, Recall, BinaryAccuracy

In [None]:
# Avoid OOM errors by letting TensorFlow grow GPU memory on demand instead of
# reserving all of it up front. TensorFlow only allows this setting before the
# GPUs have been initialised; re-running this cell later raises RuntimeError,
# which is reported here instead of aborting the notebook.
gpus = tf.config.list_physical_devices('GPU')
try:
    for gpu in gpus:
        tf.config.experimental.set_memory_growth(gpu, True)
except RuntimeError as err:
    print(f'Could not enable GPU memory growth: {err}')

## 1.2 Remove Dodgy Images

In [None]:
# Remove "dodgy" images from every class folder under data_dir: files whose
# detected type is not in image_exts, and files OpenCV cannot decode. This
# runs before the folder is loaded with image_dataset_from_directory below.
data_dir = 'data'
image_exts = ['jpeg', 'jpg', 'bmp', 'png']

for image_class in os.listdir(data_dir):
    class_dir = os.path.join(data_dir, image_class)
    # Only class sub-directories hold images; skip stray top-level files
    # (e.g. .DS_Store), which would otherwise make os.listdir below raise.
    if not os.path.isdir(class_dir):
        continue
    for image in os.listdir(class_dir):
        image_path = os.path.join(class_dir, image)
        try:
            img = cv2.imread(image_path)
            tip = imghdr.what(image_path)

            if tip not in image_exts:
                print('Image not in extension list {}'.format(image_path))
                os.remove(image_path)
            elif img is None:
                # cv2.imread returns None rather than raising when it cannot
                # decode a file (empty or corrupt image).
                print('Unreadable image {}'.format(image_path))
                os.remove(image_path)
        except Exception as e:
            # Best effort: report and keep going with the remaining files.
            print('Issue with image {}: {}'.format(image_path, e))


## 1.3 Load Data

In [None]:
# Build a batched, labelled tf.data.Dataset from the 'data' folder (with the
# Keras defaults, labels are inferred from the class sub-folder names), then
# pull one batch as numpy arrays for inspection.
data = tf.keras.utils.image_dataset_from_directory('data')
data_iterator = data.as_numpy_iterator()

batch = next(data_iterator)
# batch is an (images, labels) pair; evaluate the image array's shape
batch[0].shape

In [None]:
# Preview up to four images from the sampled batch, titled with their
# integer class labels.
fig, ax = plt.subplots(ncols=4, figsize=(20, 20))
images, labels = batch
for axis, image, label in zip(ax, images[:4], labels):
    axis.imshow(image.astype(int))
    axis.title.set_text(label)


# 2. Preprocess Data


## 2.1 Scale Data

In [None]:
# Divide pixel values by 255 so they lie in [0, 1]; labels pass through as-is.
def _scale(images, labels):
    return images / 255, labels


scaled_data = data.map(_scale)

## 2.2 Split Data

In [None]:
# Split the batched dataset roughly 70/20/10 into train/validation/test by
# batch count. The test split takes every remaining batch, so no batch is
# silently dropped when the int() truncations do not add up to the total
# (e.g. 782 batches: 547 + 156 + 78 = 781 if test were also truncated).
# NOTE(review): image_dataset_from_directory shuffles by default; if that
# shuffle is re-applied on every pass over the dataset, the take/skip
# boundaries can shift between epochs and leak samples across splits.
# Confirm for the TF version in use.
num_batches = len(scaled_data)
train_size = int(num_batches * 0.7)
validation_size = int(num_batches * 0.2)
test_size = num_batches - train_size - validation_size

train_data = scaled_data.take(train_size)
validation_data = scaled_data.skip(train_size).take(validation_size)
test_data = scaled_data.skip(train_size + validation_size).take(test_size)

# 3. Deep Learning Model

## 3.1 Build Deep Learning Model

In [None]:
# Small CNN for binary classification: three Conv2D(3x3, stride 1) + max-pool
# stages over 256x256x3 inputs, then a dense head ending in a single sigmoid
# unit. Compiled with Adam and binary cross-entropy.
model = Sequential()
model.add(Input(shape=(256, 256, 3)))
for filters in (16, 32, 16):
    model.add(Conv2D(filters, (3, 3), 1, activation='relu'))
    model.add(MaxPooling2D())
model.add(Flatten())
model.add(Dense(256, activation='relu'))
model.add(Dense(1, activation='sigmoid'))

model.compile(
    optimizer='adam',
    loss=tf.losses.BinaryCrossentropy(),
    metrics=['accuracy'],
)

model.summary()

## 3.2 Train

In [None]:
# Train for 20 epochs, evaluating on the validation split after each epoch;
# TensorBoard event files are written under ./logs.
logdir = 'logs'
tensorboard_callback = tf.keras.callbacks.TensorBoard(log_dir=logdir)

hist = model.fit(
    train_data,
    validation_data=validation_data,
    epochs=20,
    callbacks=[tensorboard_callback],
)


## 3.3 Plot Performance


In [None]:
# Plot training vs. validation loss per epoch from the fit() history.
fig = plt.figure()
for series, colour in (('loss', 'teal'), ('val_loss', 'orange')):
    plt.plot(hist.history[series], color=colour, label=series)
fig.suptitle('Loss', fontsize=20)
plt.legend(loc='upper left')
plt.show()


# 4. Evaluate

## 4.1 Evaluate

In [None]:
# Accumulate precision, recall and binary accuracy over every test batch,
# then print the aggregated results.
pre = Precision()
re = Recall()
acc = BinaryAccuracy()

for X, y in test_data.as_numpy_iterator():
    # verbose=0 stops predict() from printing a progress bar for every batch.
    yhat = model.predict(X, verbose=0)
    for metric in (pre, re, acc):
        metric.update_state(y, yhat)

print(f'Precision: {pre.result().numpy()}, Recall: {re.result().numpy()}, Accuracy: {acc.result().numpy()}')

## 4.2 Test on a Single Image

In [None]:
# Classify a single image file with the trained model.
# cv2.imread loads pixels in BGR order, while the training images come from
# image_dataset_from_directory (RGB by default). Convert once and use the RGB
# array for both display and prediction so the model sees the channel order
# it was trained on.
img = cv2.imread('12499.jpg')
if img is None:
    # cv2.imread signals a missing or undecodable file by returning None.
    raise FileNotFoundError("Could not read image '12499.jpg'")
img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
plt.imshow(img)
plt.show()

# Resize to the model's 256x256 input size (result is a float32 tensor).
resize = tf.image.resize(img, (256, 256))
plt.imshow(resize.numpy().astype(int))
plt.show()

# Scale to [0, 1] as in training and add a leading batch axis of size 1.
yhat = model.predict(np.expand_dims(resize / 255, 0))

# Single sigmoid output. Presumably class index 1 is 'Dog' (inferred class
# names are sorted) -- confirm against data.class_names.
if yhat[0][0] > 0.5:
    print('Predicted class is Dog')
else:
    print('Predicted class is Cat')

# 5. Save the Model
## 5.1 Save, Reload and Re-check the Model

In [None]:
from tensorflow.keras.models import load_model

# Save the trained model in the native .keras format, reload it, and re-run
# the single-image prediction to check that the round trip works.
model_path = os.path.join('models', 'catdog.keras')
# Create the target directory first, in case model.save does not create
# missing parent directories itself.
os.makedirs(os.path.dirname(model_path), exist_ok=True)
model.save(model_path)

new_model = load_model(model_path)

# Same preprocessing as the in-memory prediction: scale to [0, 1], add batch axis.
yhatnew = new_model.predict(np.expand_dims(resize / 255, 0))

if yhatnew[0][0] > 0.5:
    print('Predicted class is Dog')
else:
    print('Predicted class is Cat')