In [6]:
import cv2
import numpy as np
import os
import tensorflow as tf
from keras.models import Model
from keras.layers import Input, Conv2D, BatchNormalization, DepthwiseConv2D, GlobalAveragePooling2D, Dense, Activation, Add, Reshape, Multiply
from sklearn.preprocessing import LabelBinarizer
from sklearn.model_selection import train_test_split
from sklearn.metrics import confusion_matrix

# Đăng ký hàm swish
@tf.keras.utils.register_keras_serializable()
def swish(x):
    return tf.nn.swish(x)

data_dir = "dataset"
labels = ['Mild_Demented', 'Moderate_Demented', 'Non_Demented', 'Very_Mild_Demented']
size = 28
num_classes = len(labels)
x = []
y = []

for label in labels:
    files = os.listdir(os.path.join(data_dir, label))
    for file in files:
        image = cv2.imread(os.path.join(data_dir, label, file), 1)
        print(os.path.join(data_dir, label, file))
        resized_image = cv2.resize(image, (size, size))
        x.append(resized_image)
        y.append(label)

x = np.array(x)
y = np.array(y)
x_train, x_test, y_train, y_test = train_test_split(x, y, test_size=0.2, random_state=42)
x_train = x_train.astype('float32') / 255
x_test = x_test.astype('float32') / 255
label_binarizer = LabelBinarizer()
y_train_onehot = label_binarizer.fit_transform(y_train)
y_test_onehot = label_binarizer.transform(y_test)

def SEBlock(inputs, ratio=4):
    filters = inputs.shape[-1]
    se = GlobalAveragePooling2D()(inputs)
    se = Dense(filters // ratio, activation='relu')(se)
    se = Dense(filters, activation='sigmoid')(se)
    se = Reshape((1, 1, filters))(se)
    return Multiply()([inputs, se])

def MBConvBlock(inputs, expansion_ratio, output_dim, stride):
    expand_dim = expansion_ratio * inputs.shape[-1]

    x = Conv2D(expand_dim, kernel_size=1, padding='same')(inputs)
    x = BatchNormalization()(x)
    x = Activation(swish)(x)

    x = DepthwiseConv2D(kernel_size=3, strides=stride, padding='same')(x)
    x = BatchNormalization()(x)
    x = Activation(swish)(x)

    x = SEBlock(x)

    x = Conv2D(output_dim, kernel_size=1, padding='same')(x)
    x = BatchNormalization()(x)

    if stride == 1 and inputs.shape[-1] == output_dim:
        x = Add()([x, inputs])

    return x

def build_efficient_net(input_shape, num_classes):
    inputs = Input(shape=input_shape)

    x = Conv2D(32, kernel_size=3, strides=2, padding='same')(inputs)
    x = BatchNormalization()(x)
    x = Activation(swish)(x)

    x = MBConvBlock(x, expansion_ratio=1, output_dim=16, stride=1)
    x = MBConvBlock(x, expansion_ratio=6, output_dim=24, stride=2)
    x = MBConvBlock(x, expansion_ratio=6, output_dim=40, stride=2)
    x = MBConvBlock(x, expansion_ratio=6, output_dim=80, stride=2)
    x = MBConvBlock(x, expansion_ratio=6, output_dim=112, stride=1)
    x = MBConvBlock(x, expansion_ratio=6, output_dim=192, stride=2)
    x = MBConvBlock(x, expansion_ratio=6, output_dim=320, stride=1)

    x = Conv2D(1280, kernel_size=1, padding='same')(x)
    x = BatchNormalization()(x)
    x = Activation(swish)(x)

    x = GlobalAveragePooling2D()(x)

    outputs = Dense(num_classes, activation='softmax')(x)

    model = Model(inputs, outputs)
    return model

# Khởi tạo và huấn luyện mô hình
input_shape = (size, size, 3)
model = build_efficient_net(input_shape, num_classes)

model.compile(
    optimizer=tf.keras.optimizers.Adam(),
    loss='categorical_crossentropy',
    metrics=['accuracy']
)

history = model.fit(x_train, y_train_onehot, epochs=10, batch_size=36, validation_data=(x_test, y_test_onehot))

# Lưu trữ mô hình
model.save('dementia_detection_model.h5')
print("Model saved to dementia_detection_model.h5")

# Đánh giá mô hình
y_pred = model.predict(x_test)
y_pred_classes = np.argmax(y_pred, axis=1)
y_test_classes = np.argmax(y_test_onehot, axis=1)

conf_matrix = confusion_matrix(y_test_classes, y_pred_classes)
correct_predictions = np.sum(np.diag(conf_matrix))
total_samples = np.sum(conf_matrix)
accuracy = correct_predictions / total_samples

print("Accuracy:", accuracy * 100)


dataset\Mild_Demented\mild.jpg
dataset\Mild_Demented\mild_10.jpg
dataset\Mild_Demented\mild_100.jpg
dataset\Mild_Demented\mild_101.jpg
dataset\Mild_Demented\mild_102.jpg
dataset\Mild_Demented\mild_103.jpg
dataset\Mild_Demented\mild_104.jpg
dataset\Mild_Demented\mild_105.jpg
dataset\Mild_Demented\mild_106.jpg
dataset\Mild_Demented\mild_107.jpg
dataset\Mild_Demented\mild_108.jpg
dataset\Mild_Demented\mild_109.jpg
dataset\Mild_Demented\mild_11.jpg
dataset\Mild_Demented\mild_110.jpg
dataset\Mild_Demented\mild_111.jpg
dataset\Mild_Demented\mild_112.jpg
dataset\Mild_Demented\mild_113.jpg
dataset\Mild_Demented\mild_114.jpg
dataset\Mild_Demented\mild_115.jpg
dataset\Mild_Demented\mild_116.jpg
dataset\Mild_Demented\mild_117.jpg
dataset\Mild_Demented\mild_118.jpg
dataset\Mild_Demented\mild_119.jpg
dataset\Mild_Demented\mild_12.jpg
dataset\Mild_Demented\mild_120.jpg
dataset\Mild_Demented\mild_121.jpg
dataset\Mild_Demented\mild_122.jpg
dataset\Mild_Demented\mild_123.jpg
dataset\Mild_Demented\mild_



Model saved to dementia_detection_model.h5
[1m40/40[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m1s[0m 13ms/step
Accuracy: 87.109375
