In [1]:
import tensorflow as tf
import pandas as pd
import numpy as np
import os

from tensorflow import keras
from keras.applications import MobileNetV2
from keras.models import Model
from keras.layers import Input, Dense, Dropout, GlobalMaxPooling2D
from tensorflow.keras.preprocessing.image import ImageDataGenerator
from sklearn.model_selection import train_test_split

from sklearn.metrics import accuracy_score, precision_score, recall_score, f1_score


In [2]:
base_dir = "./dataset"

In [3]:
class PrepareModel():
    def __init__(self, base):
        self.base_dir = base
        self.dataset = None
        self.train_df = None
        self.val_df = None
        self.test_df = None
        self.num_classes = 5
        
        self.train_generator = None
        self.val_generator = None
        self.test_generator = None
        
        self.model = None
        self.batch_size = 32
        self.epochs = 5
        self.input_shape = (224, 224, 3)
        self.history = None
        
    def create_dataset(self):
        df = pd.DataFrame(columns=['image', 'label'])
        
        for category in os.listdir(self.base_dir):
            sub_folder = os.path.join(self.base_dir, category)
            
            if not os.path.isdir(sub_folder):
                continue
            for img in os.listdir(sub_folder):
                img_path = os.path.join(sub_folder, img)
                img_label = category
                
                df.loc[len(df)] = {'image':img_path, 'label':img_label}
        self.dataset = df
        
        self.train_df, self.test_df = train_test_split(self.dataset, test_size=0.2, stratify=self.dataset['label'], random_state=42)
        self.train_df, self.val_df = train_test_split(self.train_df, test_size=0.2, stratify=self.train_df['label'], random_state=42)
        
        print(f'Train data: {self.train_df.shape}')
        print(f'Train data: {self.val_df.shape}')
        print(f'Train data: {self.test_df.shape}')
        
    def load_data(self):
        
        train_gen = ImageDataGenerator(
            rescale=1./255,
            shear_range=0.2,
            zoom_range=0.2,
            horizontal_flip=True,
            vertical_flip=True,
            rotation_range=50,
            brightness_range=[0.5, 1.5]
        )
        
        val_gen = ImageDataGenerator(
            rescale = 1./255
        )
        
        self.train_generator = train_gen.flow_from_dataframe(
            dataframe = self.train_df,
            x_col = 'image',
            y_col = 'label',
            target_size = (224, 224),
            bacth_size = 32,
            class_mode = 'categorical',
            shuffle = True
        )
        
        self.val_generator = val_gen.flow_from_dataframe(
            dataframe = self.val_df,
            x_col = 'image',
            y_col = 'label',
            target_size = (224, 224),
            bacth_size = 32,
            class_mode = 'categorical',
            shuffle = True
        )
        self.test_generator = val_gen.flow_from_dataframe(
            dataframe = self.train_df,
            x_col = 'image',
            y_col = 'label',
            target_size = (224, 224),
            bacth_size = 32,
            class_mode = 'categorical',
            shuffle = True
        )
        labels = self.test_generator.labels
        print(labels[:5])
        
    def load_model(self):
        model = MobileNetV2(weights='imagenet', include_top=False)
        
        for layer in model.layers:
            layer.trainable = False
            
        inputs = Input(shape=self.input_shape)
        
        x = model(inputs)
        x = GlobalMaxPooling2D()(x)
        x = Dense(1024, activation='relu')(x)
        x = Dropout(0.5)(x)
        
        Outputs = Dense(self.num_classes, activation='softmax')(x)
        
        self.model = Model(inputs, Outputs)
        
        self.model.compile(
            optimizer = 'Adam',
            loss = 'categorical_crossentropy',
            metrics = ['accuracy']
        )
        
        print(self.model.summary())
    
    def train_model(self):
        self.history = self.model.fit(
            self.train_generator,
            epochs = self.epochs,
            batch_size = self.batch_size,
            validation_data = self.val_generator
        )
        
    def test_model(self):
        y_pred = self.model.predict(self.test_generator)
        y_pred = np.argmax(y_pred)
        y_true = self.test_generator.labels
        print(f"Accuracy: {accuracy_score(y_true, y_pred)}")
        print(f"Precision: {precision_score(y_true, y_pred)}")
        print(f'Recall: {recall_score(y_true, y_pred)}')
        print(f'F1-score: {f1_score(y_true, y_pred)}')

In [4]:
model = PrepareModel(base_dir)

In [5]:
model.create_dataset()

Train data: (48000, 2)
Train data: (12000, 2)
Train data: (15000, 2)


In [6]:
model.load_data()

Found 48000 validated image filenames belonging to 5 classes.
Found 12000 validated image filenames belonging to 5 classes.
Found 48000 validated image filenames belonging to 5 classes.
[1, 1, 0, 4, 3]


In [7]:
model.load_model()

  model = MobileNetV2(weights='imagenet', include_top=False)


None


In [8]:
model.train_model()

Epoch 1/5


  self._warn_if_super_not_called()


[1m1500/1500[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m2573s[0m 2s/step - accuracy: 0.8445 - loss: 0.8414 - val_accuracy: 0.9452 - val_loss: 0.1469
Epoch 2/5
[1m1500/1500[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m1792s[0m 1s/step - accuracy: 0.9081 - loss: 0.2464 - val_accuracy: 0.9702 - val_loss: 0.0821
Epoch 3/5
[1m1500/1500[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m1572s[0m 1s/step - accuracy: 0.9174 - loss: 0.2248 - val_accuracy: 0.9489 - val_loss: 0.1351
Epoch 4/5
[1m1500/1500[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m1510s[0m 1s/step - accuracy: 0.9234 - loss: 0.2127 - val_accuracy: 0.9416 - val_loss: 0.1566
Epoch 5/5
[1m1500/1500[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m1533s[0m 1s/step - accuracy: 0.9255 - loss: 0.2049 - val_accuracy: 0.9488 - val_loss: 0.1414


In [None]:
model.test_model()