In [1]:
import os
import numpy as np
import matplotlib.pyplot as plt
import tensorflow as tf
from tensorflow.keras import layers, models, optimizers, callbacks
from sklearn.metrics import classification_report, confusion_matrix

from utils.data_utils import get_data_loader  # Assuming this is a custom utility
from utils.viz_utils import plot_training_history, plot_confusion_matrix

# Pin the NumPy and TensorFlow global RNGs so repeated runs draw the same numbers.
np.random.seed(42)
tf.random.set_seed(42)

In [2]:
# Shared configuration for the data loaders and the network.
BATCH_SIZE = 32
NUM_CLASSES = 7
INPUT_SHAPE = (256, 256)               # (height, width) handed to the loaders as image_size
INPUT_SHAPE_MODEL = (*INPUT_SHAPE, 3)  # same spatial size plus 3 channels for the model input

In [19]:
# --- Define data directory ---
data_dir = "data/Teeth_Dataset"

# Keyword arguments that are identical for all three splits.
# NOTE(review): shuffle=True is also applied to the validation and test splits.
# Shuffling is not needed for evaluation; consider shuffle=False there so the
# batch order is stable between runs — confirm against get_data_loader.
_shared_loader_args = dict(
    data_dir=data_dir,
    image_size=INPUT_SHAPE,
    class_mode="categorical",
    shuffle=True,
    seed=42,
)

# --- Get augmented train loader ---
train_loader = get_data_loader(
    split="Training",
    batch_size=BATCH_SIZE,
    augment=True,
    **_shared_loader_args,
)

# --- Validation loader ---
val_loader = get_data_loader(
    split="Validation",
    batch_size=BATCH_SIZE,
    augment=False,
    **_shared_loader_args,
)

# --- Test loader ---
# Uses a smaller batch (16) than training/validation.
test_loader = get_data_loader(
    split="Testing",
    batch_size=16,
    augment=False,
    **_shared_loader_args,
)


Found 3087 images belonging to 7 classes.
Found 1028 images belonging to 7 classes.
Found 1028 images belonging to 7 classes.


In [4]:
# Pull one training batch and report its shapes as a sanity check
# before the model is built.
images, labels = next(iter(train_loader))
for caption, value in (
    ("Images shape:", images.shape),
    ("Labels shape:", labels.shape),
    ("Number of classes:", NUM_CLASSES),
):
    print(caption, value)


Images shape: (32, 256, 256, 3)
Labels shape: (32, 7)
Number of classes: 7


# **START SIMPLE MODEL**

In [5]:


# --- Define model ---
def _conv_block(filters, **first_conv_kwargs):
    """Return the layers of one convolutional stage.

    The stage is Conv -> BatchNorm -> Conv -> MaxPool(2x2) -> Dropout(0.25),
    with both convolutions using 3x3 kernels, ReLU, 'same' padding and
    he_normal initialisation. Extra keyword arguments (e.g. input_shape)
    are forwarded to the first convolution only.
    """
    return [
        layers.Conv2D(filters, (3, 3), activation='relu', padding='same',
                      kernel_initializer='he_normal', **first_conv_kwargs),
        layers.BatchNormalization(),
        layers.Conv2D(filters, (3, 3), activation='relu', padding='same',
                      kernel_initializer='he_normal'),
        layers.MaxPooling2D(pool_size=(2, 2)),
        layers.Dropout(0.25),
    ]


def _dense_block(units):
    """Return the layers of one classifier stage: Dense(ReLU) -> BatchNorm -> Dropout(0.5)."""
    return [
        layers.Dense(units, activation='relu', kernel_initializer='he_normal'),
        layers.BatchNormalization(),
        layers.Dropout(0.5),
    ]


# Layers are created strictly in network order so the automatically assigned
# layer names (conv2d, conv2d_1, ...) match the order shown by summary().
model_0 = models.Sequential([
    # Four conv stages; channel width doubles at every stage.
    *_conv_block(32, input_shape=INPUT_SHAPE_MODEL),
    *_conv_block(64),
    *_conv_block(128),
    *_conv_block(256),

    # Global Average Pooling instead of Flatten
    layers.GlobalAveragePooling2D(),

    # Dense classifier head ending in a softmax over NUM_CLASSES outputs
    *_dense_block(512),
    *_dense_block(256),
    layers.Dense(NUM_CLASSES, activation='softmax'),
])

model_0.summary()


Model: "sequential"
_________________________________________________________________
 Layer (type)                Output Shape              Param #   
 conv2d (Conv2D)             (None, 256, 256, 32)      896       
                                                                 
 batch_normalization (BatchN  (None, 256, 256, 32)     128       
 ormalization)                                                   
                                                                 
 conv2d_1 (Conv2D)           (None, 256, 256, 32)      9248      
                                                                 
 max_pooling2d (MaxPooling2D  (None, 128, 128, 32)     0         
 )                                                               
                                                                 
 dropout (Dropout)           (None, 128, 128, 32)      0         
                                                                 
 conv2d_2 (Conv2D)           (None, 128, 128, 64)      1

In [6]:

# --- Compile model ---
# The metric order here fixes the order of the values returned by evaluate().
model_0.compile(
    loss='categorical_crossentropy',
    optimizer=optimizers.Adam(learning_rate=0.001),
    metrics=['accuracy', 'Precision', 'Recall'],
)

# Stop when val_loss has not improved for 10 epochs and roll back to the best weights.
early_stopping = callbacks.EarlyStopping(
    monitor='val_loss',
    patience=10,
    restore_best_weights=True,
    verbose=1,
)

# Halve the learning rate after 5 epochs without val_loss improvement (floor 1e-7).
reduce_lr = callbacks.ReduceLROnPlateau(
    monitor='val_loss',
    factor=0.5,
    patience=5,
    min_lr=1e-7,
    verbose=1,
)

# Keep on disk only the weights with the best val_accuracy seen so far.
# Note this callback tracks val_accuracy while the two above track val_loss.
checkpoint = callbacks.ModelCheckpoint(
    filepath='models/best_model_0_teeth_model.h5',
    monitor='val_accuracy',
    save_best_only=True,
    verbose=1,
)

callbacks_list = [early_stopping, reduce_lr, checkpoint]

# --- Train ---
# One epoch is a full pass over each loader; at most 100 epochs, subject to early stopping.
history_model_0 = model_0.fit(
    train_loader,
    validation_data=val_loader,
    epochs=100,
    steps_per_epoch=len(train_loader),
    validation_steps=len(val_loader),
    callbacks=callbacks_list,
    verbose=1,
)


Epoch 1/100
Epoch 1: val_accuracy improved from -inf to 0.13716, saving model to models\best_model_0_teeth_model.h5
Epoch 2/100
Epoch 2: val_accuracy improved from 0.13716 to 0.16926, saving model to models\best_model_0_teeth_model.h5
Epoch 3/100
Epoch 3: val_accuracy improved from 0.16926 to 0.19553, saving model to models\best_model_0_teeth_model.h5
Epoch 4/100
Epoch 4: val_accuracy improved from 0.19553 to 0.21693, saving model to models\best_model_0_teeth_model.h5
Epoch 5/100
Epoch 5: val_accuracy improved from 0.21693 to 0.27335, saving model to models\best_model_0_teeth_model.h5
Epoch 6/100
Epoch 6: val_accuracy did not improve from 0.27335
Epoch 7/100
Epoch 7: val_accuracy did not improve from 0.27335
Epoch 8/100
Epoch 8: val_accuracy did not improve from 0.27335
Epoch 9/100
Epoch 9: val_accuracy improved from 0.27335 to 0.35895, saving model to models\best_model_0_teeth_model.h5
Epoch 10/100
Epoch 10: val_accuracy did not improve from 0.35895
Epoch 11/100
Epoch 11: val_accuracy

In [10]:
# Visualise the training run from the History object returned by model_0.fit
# (plot_training_history is the project helper imported from utils.viz_utils).
plot_training_history(history_model_0)

In [11]:
# Score the trained model on the test split. evaluate() returns the loss followed
# by the metrics in the order given to compile(): accuracy, precision, recall.
print("\nEvaluating on test set...")
test_metrics = model_0.evaluate(test_loader, verbose=1)
test_loss, test_accuracy, test_precision, test_recall = test_metrics


Evaluating on test set...


In [13]:
# Report the test-set metrics returned by model_0.evaluate, to four decimals.
# The header is a plain string: it has no placeholders, so no f-prefix is needed.
print("\nTest Results:")
print(f"Test Loss: {test_loss:.4f}")
print(f"Test Accuracy: {test_accuracy:.4f}")
print(f"Test Precision: {test_precision:.4f}")
print(f"Test Recall: {test_recall:.4f}")


Test Results:
Test Loss: 0.0409
Test Accuracy: 0.9874
Test Precision: 0.9873
Test Recall: 0.9864


In [21]:
print("\nGenerating detailed classification report...")
class_names = list(test_loader.class_indices.keys())
print("Inferred class names:", class_names)

# Rewind the loader so the loop below collects exactly one pass over the test
# set, even if an earlier cell left the iterator part-way through an epoch.
# Guarded with hasattr because the loader comes from a project utility and
# may not expose a Keras-style reset().
if hasattr(test_loader, "reset"):
    test_loader.reset()

y_true = []
y_pred = []

steps = len(test_loader)

for i in range(steps):
    batch_images, batch_labels = next(test_loader)
    batch_images = tf.cast(batch_images, tf.float32)

    # One forward pass per loader batch. predict_on_batch avoids the per-call
    # setup cost of model.predict(), which is meant for whole datasets rather
    # than being called repeatedly inside a Python loop.
    predictions = model_0.predict_on_batch(batch_images)

    # Labels are one-hot and predictions are class probabilities; argmax
    # turns both into integer class ids.
    y_true.extend(np.argmax(batch_labels, axis=1))
    y_pred.extend(np.argmax(predictions, axis=1))

    # NOTE: tf.keras.backend.clear_session() was removed from this loop.
    # It resets Keras' global state and must not be called while model_0
    # is still being used.

    # Print progress
    print(f"Processed batch {i+1}/{steps}")

# Convert lists to numpy arrays
y_true = np.array(y_true)
y_pred = np.array(y_pred)



Generating detailed classification report...
Inferred class names: ['CaS', 'CoS', 'Gum', 'MC', 'OC', 'OLP', 'OT']
Processed batch 1/65
Processed batch 2/65
Processed batch 3/65
Processed batch 4/65
Processed batch 5/65
Processed batch 6/65
Processed batch 7/65
Processed batch 8/65
Processed batch 9/65
Processed batch 10/65
Processed batch 11/65
Processed batch 12/65
Processed batch 13/65
Processed batch 14/65
Processed batch 15/65
Processed batch 16/65
Processed batch 17/65
Processed batch 18/65
Processed batch 19/65
Processed batch 20/65
Processed batch 21/65
Processed batch 22/65
Processed batch 23/65
Processed batch 24/65
Processed batch 25/65
Processed batch 26/65
Processed batch 27/65
Processed batch 28/65
Processed batch 29/65
Processed batch 30/65
Processed batch 31/65
Processed batch 32/65
Processed batch 33/65
Processed batch 34/65
Processed batch 35/65
Processed batch 36/65
Processed batch 37/65
Processed batch 38/65
Processed batch 39/65
Processed batch 40/65
Processed batc

In [22]:
# Guard: every ground-truth label must be paired with exactly one prediction
# before any metric is computed.
if not len(y_true) == len(y_pred):
    raise ValueError(
        f"Mismatch in lengths: y_true ({len(y_true)}) vs y_pred ({len(y_pred)})"
    )

In [27]:
# Generate classification report
# `labels` is passed explicitly so row i is always matched with class_names[i].
# Without it, sklearn infers the label set from the data and raises if any
# class is missing from y_true/y_pred, because target_names would then be longer.
# Assumes class ids are 0..len(class_names)-1 in class_names order — confirm
# against the loader's class_indices.
report = classification_report(
    y_true,
    y_pred,
    labels=np.arange(len(class_names)),
    target_names=class_names,
    digits=4,
)
print(report)

              precision    recall  f1-score   support

         CaS     0.9866    1.0000    0.9932       147
         CoS     1.0000    1.0000    1.0000       156
         Gum     1.0000    0.9917    0.9958       120
          MC     0.9821    0.9593    0.9706       172
          OC     0.9817    0.9640    0.9727       111
         OLP     0.9637    1.0000    0.9815       186
          OT     1.0000    0.9853    0.9926       136

    accuracy                         0.9864      1028
   macro avg     0.9877    0.9857    0.9866      1028
weighted avg     0.9865    0.9864    0.9864      1028



In [30]:
# Plot the confusion matrix for the collected test labels (y_true) and
# predictions (y_pred), with axes labelled by class_names
# (plot_confusion_matrix is the project helper imported from utils.viz_utils).
plot_confusion_matrix(y_true, y_pred, class_names=class_names)

In [31]:
# Persist the final model in HDF5 format. The path is defined once so the
# confirmation message always reports the file that was actually written
# (the old message named 'model_0_teeth_model.h5', which is not the saved path).
final_model_path = 'models/final_model_0_teeth_model.h5'
model_0.save(final_model_path)
print(f"\nModel saved as '{final_model_path}'")


Model saved as 'model_0_teeth_model.h5'
