In [69]:
from pathlib import Path

from PIL import Image
import pandas as pd
import numpy as np

import sklearn.model_selection as skms

import tensorflow as tf

from tensorflow.keras.preprocessing.image import ImageDataGenerator
from tensorflow.keras.applications.mobilenet_v2 import MobileNetV2, preprocess_input
from tensorflow.keras.layers import Dense

import plotly.express as px
import plotly.graph_objects as go

In [70]:
image_dir = Path('./Fish_Dataset/Fish_Dataset')

# Get filepaths and labels
# This ges all the filepaths that match the regex
filepaths = pd.Series(list(image_dir.glob(r'**/*.png')), name = 'Filepath')

# This takes all those filepaths and retrieves the immediate folder up that the filepaths exist in
labels = pd.Series(list(map(lambda x: os.path.split(os.path.dirname(x))[1], filepaths)), name = 'Label')

In [71]:
# Combine the two Series side by side (aligned on their shared index) into a two-column frame
image_paths = pd.concat([filepaths, labels], axis = 1)
# Filepath holds pathlib.Path objects at this point; cast both columns to pandas' string dtype
image_paths = image_paths.astype('string')
# DataFrame.info() prints its summary itself and returns None, so call it directly
# instead of passing its return value to display() (which would just render "None")
image_paths.info()
display(image_paths)

<class 'pandas.core.frame.DataFrame'>
RangeIndex: 18000 entries, 0 to 17999
Data columns (total 2 columns):
 #   Column    Non-Null Count  Dtype 
---  ------    --------------  ----- 
 0   Filepath  18000 non-null  string
 1   Label     18000 non-null  string
dtypes: string(2)
memory usage: 281.4 KB


None

Unnamed: 0,Filepath,Label
0,Fish_Dataset\Fish_Dataset\Black Sea Sprat\Blac...,Black Sea Sprat
1,Fish_Dataset\Fish_Dataset\Black Sea Sprat\Blac...,Black Sea Sprat
2,Fish_Dataset\Fish_Dataset\Black Sea Sprat\Blac...,Black Sea Sprat
3,Fish_Dataset\Fish_Dataset\Black Sea Sprat\Blac...,Black Sea Sprat
4,Fish_Dataset\Fish_Dataset\Black Sea Sprat\Blac...,Black Sea Sprat
...,...,...
17995,Fish_Dataset\Fish_Dataset\Trout\Trout GT\00996...,Trout GT
17996,Fish_Dataset\Fish_Dataset\Trout\Trout GT\00997...,Trout GT
17997,Fish_Dataset\Fish_Dataset\Trout\Trout GT\00998...,Trout GT
17998,Fish_Dataset\Fish_Dataset\Trout\Trout GT\00999...,Trout GT


In [72]:
# Remove all the image paths whose label ends with "GT": keep only rows whose Label does NOT end with it
image_paths = image_paths[~image_paths['Label'].str.endswith('GT')]
# info() prints directly and returns None; don't wrap it in display()
image_paths.info()
display(image_paths)

<class 'pandas.core.frame.DataFrame'>
Int64Index: 9000 entries, 0 to 16999
Data columns (total 2 columns):
 #   Column    Non-Null Count  Dtype 
---  ------    --------------  ----- 
 0   Filepath  9000 non-null   string
 1   Label     9000 non-null   string
dtypes: string(2)
memory usage: 210.9 KB


None

Unnamed: 0,Filepath,Label
0,Fish_Dataset\Fish_Dataset\Black Sea Sprat\Blac...,Black Sea Sprat
1,Fish_Dataset\Fish_Dataset\Black Sea Sprat\Blac...,Black Sea Sprat
2,Fish_Dataset\Fish_Dataset\Black Sea Sprat\Blac...,Black Sea Sprat
3,Fish_Dataset\Fish_Dataset\Black Sea Sprat\Blac...,Black Sea Sprat
4,Fish_Dataset\Fish_Dataset\Black Sea Sprat\Blac...,Black Sea Sprat
...,...,...
16995,Fish_Dataset\Fish_Dataset\Trout\Trout\00996.png,Trout
16996,Fish_Dataset\Fish_Dataset\Trout\Trout\00997.png,Trout
16997,Fish_Dataset\Fish_Dataset\Trout\Trout\00998.png,Trout
16998,Fish_Dataset\Fish_Dataset\Trout\Trout\00999.png,Trout


In [73]:
# Split the data into train and test sets, holding out 25% of the rows for testing
# (0.25 is also scikit-learn's default test_size; it is spelled out here for clarity).
# No validation set is split off here: it is carved out of `train` by the training
# generator's validation_split.
train, test = skms.train_test_split(image_paths, test_size = 0.25, random_state = 25, shuffle = True)
display(train)
display(test)

Unnamed: 0,Filepath,Label
2576,Fish_Dataset\Fish_Dataset\Gilt-Head Bream\Gilt...,Gilt-Head Bream
69,Fish_Dataset\Fish_Dataset\Black Sea Sprat\Blac...,Black Sea Sprat
12324,Fish_Dataset\Fish_Dataset\Shrimp\Shrimp\00325.png,Shrimp
4654,Fish_Dataset\Fish_Dataset\Hourse Mackerel\Hour...,Hourse Mackerel
12741,Fish_Dataset\Fish_Dataset\Shrimp\Shrimp\00742.png,Shrimp
...,...,...
2175,Fish_Dataset\Fish_Dataset\Gilt-Head Bream\Gilt...,Gilt-Head Bream
16447,Fish_Dataset\Fish_Dataset\Trout\Trout\00448.png,Trout
4934,Fish_Dataset\Fish_Dataset\Hourse Mackerel\Hour...,Hourse Mackerel
12618,Fish_Dataset\Fish_Dataset\Shrimp\Shrimp\00619.png,Shrimp


Unnamed: 0,Filepath,Label
16740,Fish_Dataset\Fish_Dataset\Trout\Trout\00741.png,Trout
8888,Fish_Dataset\Fish_Dataset\Red Sea Bream\Red Se...,Red Sea Bream
10700,Fish_Dataset\Fish_Dataset\Sea Bass\Sea Bass\00...,Sea Bass
2639,Fish_Dataset\Fish_Dataset\Gilt-Head Bream\Gilt...,Gilt-Head Bream
647,Fish_Dataset\Fish_Dataset\Black Sea Sprat\Blac...,Black Sea Sprat
...,...,...
4690,Fish_Dataset\Fish_Dataset\Hourse Mackerel\Hour...,Hourse Mackerel
4589,Fish_Dataset\Fish_Dataset\Hourse Mackerel\Hour...,Hourse Mackerel
14789,Fish_Dataset\Fish_Dataset\Striped Red Mullet\S...,Striped Red Mullet
16957,Fish_Dataset\Fish_Dataset\Trout\Trout\00958.png,Trout


In [74]:
# ImageDataGenerator objects turn the path/label frames into batches of images on demand,
# so the full dataset never has to sit in memory at once. Both run MobileNetV2's own
# preprocess_input on every image; only the training generator reserves 20% of its rows
# for validation (picked out later with subset='validation').
test_gen = ImageDataGenerator(preprocessing_function = preprocess_input)
train_gen = ImageDataGenerator(
    validation_split = 0.2,
    preprocessing_function = preprocess_input,
)

In [75]:
def _flow_from_paths(generator, dataframe, **overrides):
    """Create a batch iterator that loads the images listed in ``dataframe``.

    Settings shared by every split are fixed here:
    - Filepath/Label columns
    - one-hot ('categorical') labels
    - 224x224 resize
    - batches of 32
    - seed 25

    ``overrides`` supplies split-specific options such as ``subset`` or ``shuffle``.
    """
    return generator.flow_from_dataframe(
        dataframe = dataframe,
        x_col = 'Filepath',
        y_col = 'Label',
        class_mode = 'categorical',
        target_size = (224, 224),
        batch_size = 32,
        seed = 25,
        **overrides
    )


# Training and validation both read from `train`; train_gen's validation_split decides which rows go where
train_images = _flow_from_paths(train_gen, train, subset = 'training')
val_images = _flow_from_paths(train_gen, train, subset = 'validation')
# shuffle=False keeps batches in the row order of `test`, so predictions can be matched back to rows
test_images = _flow_from_paths(test_gen, test, shuffle = False)

Found 5400 validated image filenames belonging to 9 classes.
Found 1350 validated image filenames belonging to 9 classes.
Found 2250 validated image filenames belonging to 9 classes.


In [76]:
def get_compiled_model(num_classes = 9, input_shape = (224, 224, 3)):
    """Build and compile a transfer-learning classifier on top of MobileNetV2.

    The architecture is:
    - a frozen, ImageNet-pretrained MobileNetV2 backbone (without its own classification
      layers, global-average-pooled);
    - two 128-unit ReLU dense layers;
    - a softmax output layer.

    Args:
        num_classes: number of output classes (default 9, the number of fish classes
            in this dataset).
        input_shape: (height, width, channels) of the input images; should match the
            ``target_size`` used by the image generators plus 3 colour channels.

    Returns:
        A ``tf.keras.Model`` compiled with the Adam optimizer, categorical cross-entropy
        loss and an accuracy metric.
    """
    mobilenet = MobileNetV2(
        input_shape = input_shape,
        include_top = False,  # drop MobileNetV2's own classification layers; our dense head replaces them
        weights = 'imagenet',  # start from ImageNet-pretrained weights
        pooling = 'avg'  # global average pooling -> a single feature vector per image
    )

    # Freeze the pretrained backbone so only the new dense layers are trained
    mobilenet.trainable = False

    inputs = mobilenet.input
    x = Dense(128, activation = 'relu')(mobilenet.output)
    x = Dense(128, activation = 'relu')(x)
    outputs = Dense(num_classes, activation = 'softmax')(x)

    model = tf.keras.Model(inputs = inputs, outputs = outputs)
    model.compile(optimizer = 'adam', loss = 'categorical_crossentropy', metrics = ['accuracy'])
    return model


In [78]:
model = get_compiled_model()
# Train the dense head for 3 epochs, monitoring the 20% validation subset of the training split
history = model.fit(train_images, validation_data = val_images, epochs = 3)

Epoch 1/3
Epoch 2/3
Epoch 3/3


In [82]:
# Class-probability predictions for every test image, in the row order of `test` (test_images is unshuffled)
predict_results = model.predict(x = test_images)

In [83]:
# Score the model on the held-out test set; given the compile() metrics this is expected
# to be [loss, accuracy]. Show it as the cell's output instead of discarding it.
evaluation = model.evaluate(test_images)
evaluation



In [84]:
# Most likely class index for each test image
result_indices = np.argmax(predict_results, axis = 1)
# The model's output columns follow the class -> index mapping of the generator it was
# trained on, so decode with train_images' mapping. Invert it explicitly: relying on dict
# iteration order matching the index values is fragile, and the test generator's mapping
# would shift if a class were missing from `test`.
index_to_label = {index: label for label, index in train_images.class_indices.items()}
result_labels = [index_to_label[index] for index in sorted(index_to_label)]
results = [result_labels[result_index] for result_index in result_indices]

In [89]:
# Show the first `head` test images with their true and predicted labels
head = 20
sample = test.head(head)

# `results` follows the row order of `test` because test_images was built with shuffle=False
# (this assumes flow_from_dataframe kept every row of `test`). zip pairs each sampled row with
# its own prediction and also stops cleanly if `test` has fewer than `head` rows.
for filepath, label, predicted in zip(sample['Filepath'], sample['Label'], results):
    # Read the pixels inside `with` so each file handle is closed before plotting
    with Image.open(filepath) as img:
        pixels = np.asarray(img)
    fig = px.imshow(pixels, title = f'Is: {label} | Predict: {predicted}')
    fig.show()
