In [None]:
import os
import tensorflow as tf
import cv2
import imghdr

import numpy as np
from matplotlib import pyplot as plt

from tensorflow.keras.models import Sequential
from tensorflow.keras.layers import Conv2D, MaxPooling2D, Dense, Flatten, Dropout, SpatialDropout2D

In [None]:
data_dir = 'Detect_solar_dust'

image_exts = ['jpeg','jpg', 'bmp', 'png']

# Validate every file inside each class sub-directory before the dataset is built.
# Files whose detected type is not in image_exts are deleted; files that cannot be
# decoded are only reported and left on disk.
for image_class in os.listdir(data_dir):
    class_dir = os.path.join(data_dir, image_class)
    # A stray file next to the class folders would make os.listdir() raise,
    # so only descend into real directories.
    if not os.path.isdir(class_dir):
        continue
    for image in os.listdir(class_dir):
        image_path = os.path.join(class_dir, image)
        try:
            # cv2.imread reports an unreadable file by returning None, not by raising
            img = cv2.imread(image_path)
            tip = imghdr.what(image_path)
            if tip not in image_exts:
                print('Image not in ext list {}'.format(image_path))
                os.remove(image_path)
            elif img is None:
                print('Failed to load image {}'.format(image_path))
        except Exception as e:
            # Keep going with the remaining files, but show why this one failed
            print('Failed to load image {}: {}'.format(image_path, e))

In [None]:
# Seed Python, NumPy and TensorFlow in one call for reproducibility.
# np.random.seed alone leaves TensorFlow's own generator (used for dataset
# shuffling and weight initialisation) unseeded.
tf.keras.utils.set_random_seed(0)

In [None]:
# TODO: explore params
# Batch size and image size are written out explicitly (both equal the library
# defaults) because the model further down is built for 256x256 inputs.
data = tf.keras.utils.image_dataset_from_directory(
    data_dir,
    batch_size=32,
    image_size=(256, 256),
)

In [None]:
# Collect the integer class id of every image, one batch at a time.
labels = []
for images, labels_batch in data:
    labels.extend(labels_batch.numpy())

# Use one bin per class, centred on the integer class id, so each bar sits
# directly above its class-name tick (the default 10 bins do not line up).
n_classes = len(data.class_names)
plt.hist(labels, bins=np.arange(n_classes + 1) - 0.5, rwidth=0.8)
plt.xlabel('Class Name')
plt.ylabel('Count')
plt.title('Distribution of image classes')

# Set the x-axis tick labels to the class names
plt.xticks(ticks=range(n_classes), labels=data.class_names)
plt.show()

In [None]:
# Rescale pixel values from 0-255 (RGB) to 0-1 to optimize learning time;
# labels pass through unchanged.
# NOTE: `data` is rebound to the scaled dataset, so running this cell a second
# time would divide by 255 again.
def _scale_pixels(x, y):
    return x / 255, y

data = data.map(_scale_pixels)

# Pull a single batch out as NumPy arrays for inspection
batch = next(data.as_numpy_iterator())

In [None]:
# Show the first four images of the batch, each titled with its label
# (1 representing dirty, 0 clean)
fig, ax = plt.subplots(ncols=4, figsize=(20,20))
sample_images, sample_labels = batch[0][:4], batch[1]
for axis, picture, label in zip(ax, sample_images, sample_labels):
    axis.imshow(picture)
    axis.title.set_text(label)

In [None]:
# Number of batches in the dataset (the split sizes below are derived from it)
len(data)

In [None]:
# Split sizes are counted in batches. Validation and test get roughly 20% / 10%
# but never fewer than one batch: plain int() truncation can leave a split
# empty on a small dataset. Training takes every remaining batch, so rounding
# no longer drops batches silently.
n_batches = len(data)
if n_batches < 3:
    raise ValueError(
        'Need at least 3 batches for a train/val/test split, got {}'.format(n_batches)
    )

val_size = max(1, int(n_batches*.2)) # To finetune the model
test_size = max(1, int(n_batches*.1)) # To evaluate the model
train_size = n_batches - val_size - test_size # To train the model

In [None]:
# Carve the batches into three consecutive ranges:
#   first val_size batches          -> validation
#   next test_size batches          -> test
#   train_size batches after those  -> training
# NOTE(review): image_dataset_from_directory shuffles by default; if it
# reshuffles on every pass, these skip/take windows may not contain the same
# samples from one epoch to the next — confirm the splits stay disjoint.
held_out = val_size + test_size
val = data.take(val_size)
test = data.skip(val_size).take(test_size)
train = data.skip(held_out).take(train_size)

In [None]:
# Small CNN for binary classification: one convolution/pooling stage followed
# by a dense head with a single sigmoid output.
model = Sequential()

# 16 filters of size 3x3, stride 1, over 256x256 three-channel input
model.add(Conv2D(16, (3,3), 1, activation="leaky_relu", input_shape=(256,256, 3)))
# Keeps the maximum value of each area (default is (2,2))
model.add(MaxPooling2D())
# (Optional regularisation such as SpatialDropout2D(0.2) could be inserted here.)

# Condense the feature maps into a single dimension
model.add(Flatten())

model.add(Dense(128, activation="leaky_relu"))
model.add(Dropout(0.5))
model.add(Dense(1, activation="sigmoid"))

In [None]:
# Adam optimizer with binary cross-entropy, matching the single sigmoid output
model.compile(
    optimizer='adam',
    loss=tf.losses.BinaryCrossentropy(),
    metrics=['accuracy'],
)

In [None]:
# Print the layer-by-layer summary of the model
model.summary()

In [None]:
# Train for 26 epochs on the training split, passing the validation split
# so that val_loss / val_accuracy are recorded in the returned history.
hist = model.fit(
    x=train,
    validation_data=val,
    epochs=26,
)

In [None]:
# Training vs. validation loss per epoch
fig = plt.figure()
for series, colour in (('loss', 'teal'), ('val_loss', 'orange')):
    plt.plot(hist.history[series], color=colour, label=series)
fig.suptitle('Loss', fontsize=20)
plt.legend(loc="upper left")
plt.show()

In [None]:
# Training vs. validation accuracy per epoch
fig = plt.figure()
for series, colour in (('accuracy', 'teal'), ('val_accuracy', 'orange')):
    plt.plot(hist.history[series], color=colour, label=series)
fig.suptitle('Accuracy', fontsize=20)
plt.legend(loc="upper left")
plt.show()


In [None]:
from tensorflow.keras.metrics import Precision, Recall, BinaryAccuracy

pre = Precision()
re = Recall()
acc = BinaryAccuracy()

# Accumulate precision, recall and accuracy over every batch of the test split.
for X, y in test.as_numpy_iterator():
    # predict_on_batch runs a single forward pass on the batch as given;
    # model.predict is meant for large inputs and prints a progress bar per call.
    yhat = model.predict_on_batch(X)
    for metric in (pre, re, acc):
        metric.update_state(y, yhat)

print(pre.result(), re.result(), acc.result())

In [None]:
# Manual test, replace the image path with your image
test_image_path = 'manual_test_images/test5.png'

img = cv2.imread(test_image_path)
# cv2.imread returns None instead of raising when the file is missing or
# unreadable; fail here with a clear message rather than inside cvtColor.
if img is None:
    raise FileNotFoundError('Could not read image {}'.format(test_image_path))

# OpenCV loads images in BGR order; convert for display only (img itself stays BGR)
plt.imshow(cv2.cvtColor(img, cv2.COLOR_BGR2RGB))
plt.show()

In [None]:
# cv2.imread yields channels in BGR order, whereas the training images were
# loaded as RGB. Convert first so the model (and the preview below) sees the
# same channel order it was trained on, then resize to the 256x256 model input.
rgb_img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
resize = tf.image.resize(rgb_img, (256,256))
plt.imshow(resize.numpy().astype(int))
plt.show()

In [None]:
# Scale to 0-1 like the training data and add a leading batch axis
scaled = resize / 255
yhat = model.predict(np.expand_dims(scaled, axis=0))

yhat

In [None]:
# Threshold the sigmoid output at 0.5: above means dirty, otherwise clean
verdict = 'Predicted class is dirty' if yhat > 0.5 else 'Predicted class is clean'
print(verdict)