In [1]:
import os
import numpy as np
import tensorflow as tf
import matplotlib.pyplot as plt
import seaborn as sns
from PIL import Image
from sklearn.metrics import confusion_matrix, classification_report

In [2]:
# ====================
# CONFIG
# ====================
dataset_root = 'ParkinsonsDetection\Dataset' 
img_size = (224, 224)
batch_size = 2
epochs = 10
class_names = ['Non-Demented', 'Very-Mild-Demented', 'Mild-Demented', 'Moderate-Demented', 'Severe-Demented']
num_classes = len(class_names)
class_to_idx = {name: i for i, name in enumerate(class_names)}

model_path = "models\dementia_model.h5"
conf_matrix_path = "confusion_matrix_val.png"
training_plot_path = "training_plot.png"

In [3]:
# ====================
# DATASET LOADER
# ====================
def load_data(image_dir, label_dir):
    images = []
    labels = []
    for filename in os.listdir(image_dir):
        if not filename.lower().endswith(('.png', '.jpg', '.jpeg')):
            continue
        img_path = os.path.join(image_dir, filename)
        label_path = os.path.join(label_dir, os.path.splitext(filename)[0] + '.txt')
        with open(label_path, 'r') as f:
            line = f.readline().strip()
            class_index = int(line.split()[0])  # ⬅️ Extract just the class number

        img = Image.open(img_path).convert('RGB').resize(img_size)
        images.append(np.array(img))
        labels.append(class_index)

    images = np.array(images) / 255.0
    labels = tf.keras.utils.to_categorical(labels, num_classes)
    return images, labels



In [4]:
# Load train and validation data
X_train, y_train = load_data(
    os.path.join(dataset_root, 'train', 'images'),
    os.path.join(dataset_root, 'train', 'labels')
)
X_val, y_val = load_data(
    os.path.join(dataset_root, 'test', 'images'),
    os.path.join(dataset_root, 'test', 'labels')
)

In [5]:
# ====================
# MODEL
# ====================
model = tf.keras.models.Sequential([
    tf.keras.layers.Conv2D(16, (3, 3), activation='relu', input_shape=(224, 224, 3)),
    tf.keras.layers.MaxPooling2D(),
    tf.keras.layers.Conv2D(32, (3, 3), activation='relu'),
    tf.keras.layers.MaxPooling2D(),
    tf.keras.layers.Conv2D(64, (3, 3), activation='relu'),
    tf.keras.layers.MaxPooling2D(),
    tf.keras.layers.Flatten(),
    tf.keras.layers.Dense(256, activation='relu'),
    tf.keras.layers.Dropout(0.5),
    tf.keras.layers.Dense(num_classes, activation='softmax')
])

model.compile(optimizer='adam', loss='categorical_crossentropy', metrics=['accuracy'])

# ====================
# TRAINING
# ====================
print("start training")
history = model.fit(
    X_train, y_train,
    validation_data=(X_val, y_val),
    batch_size=batch_size,
    epochs=epochs
)

start training
Epoch 1/10
Epoch 2/10
Epoch 3/10
Epoch 4/10
Epoch 5/10
Epoch 6/10
Epoch 7/10
Epoch 8/10
Epoch 9/10
Epoch 10/10


In [6]:
# ====================
# SAVE MODEL
# ====================
model.save(model_path)
print(f"✅ Model saved as {model_path}")

  saving_api.save_model(


✅ Model saved as models\dementia_model.h5


In [7]:
# ====================
# TRAINING PLOT
# ====================
plt.figure(figsize=(10, 5))
plt.plot(history.history['accuracy'], label='Train Accuracy')
plt.plot(history.history['val_accuracy'], label='Val Accuracy')
plt.plot(history.history['loss'], label='Train Loss')
plt.plot(history.history['val_loss'], label='Val Loss')
plt.title("Training and Validation Metrics")
plt.xlabel("Epoch")
plt.ylabel("Value")
plt.legend()
plt.grid(True)
plt.tight_layout()
plt.savefig(training_plot_path)
print(f"📈 Training plot saved as {training_plot_path}")
plt.close()


📈 Training plot saved as training_plot.png


In [8]:
# ====================
# CONFUSION MATRIX
# ====================
y_true = np.argmax(y_val, axis=1)
y_pred_probs = model.predict(X_val)
y_pred = np.argmax(y_pred_probs, axis=1)

cm = confusion_matrix(y_true, y_pred)
plt.figure(figsize=(8, 6))
sns.heatmap(cm, annot=True, fmt="d", cmap="Blues",
            xticklabels=class_names, yticklabels=class_names)
plt.xlabel("Predicted")
plt.ylabel("Actual")
plt.title("Confusion Matrix (Validation Set)")
plt.tight_layout()
plt.savefig(conf_matrix_path)
print(f"📊 Confusion matrix saved as {conf_matrix_path}")
plt.close()


📊 Confusion matrix saved as confusion_matrix_val.png


In [9]:
# ====================
# CLASSIFICATION REPORT
# ====================
print("\nClassification Report:\n")
print(classification_report(y_true, y_pred, target_names=class_names))


Classification Report:

                    precision    recall  f1-score   support

      Non-Demented       0.59      1.00      0.74        10
Very-Mild-Demented       1.00      0.89      0.94        44
     Mild-Demented       1.00      0.95      0.98        43
 Moderate-Demented       1.00      0.97      0.98        60
   Severe-Demented       0.00      0.00      0.00         0

          accuracy                           0.94       157
         macro avg       0.72      0.76      0.73       157
      weighted avg       0.97      0.94      0.95       157



  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
