In [3]:
import numpy as np
import matplotlib.pyplot as plt

from tensorflow.keras.preprocessing.image import ImageDataGenerator 
from tensorflow.keras.layers import Conv2D, MaxPooling2D, Flatten, Dense, Dropout, BatchNormalization
from tensorflow.keras.models import Sequential
from tensorflow.keras.optimizers import Adam 
from tensorflow.keras.callbacks import ModelCheckpoint, EarlyStopping
from sklearn.metrics import classification_report, confusion_matrix 

In [4]:
# Set paths
train_path = 'Datasets/RAF-FER-SFEW/train'
test_path = 'Datasets/RAF-FER-SFEW/test'

# Image and augmentation parameters
img_width, img_height = 48, 48  # Adjust accordingly if images are of a different size
batch_size = 32
class_mode = 'categorical'  # Since you have multiple emotion categories

# Create train and validation set generators, include rescaling 
train_datagen = ImageDataGenerator(
        rescale=1./255,
        rotation_range=20, 
        width_shift_range=0.2,
        height_shift_range=0.2,
        horizontal_flip=True
)

validation_datagen = ImageDataGenerator(rescale=1./255)

train_generator = train_datagen.flow_from_directory(
    train_path,
    target_size=(img_width, img_height),
    batch_size=batch_size,
    class_mode=class_mode
)

validation_generator = validation_datagen.flow_from_directory(
    test_path,
    target_size=(img_width, img_height),
    batch_size=batch_size,
    class_mode=class_mode
)

Found 41839 images belonging to 7 classes.
Found 10651 images belonging to 7 classes.


In [8]:
# Analyze class distribution: (Optional but recommended)
print(train_generator.class_indices)  # To understand  emotion label mappings

# Define your CNN model (Same CNN architecture from before)
model = Sequential([
    Conv2D(32, (3, 3), activation='relu', input_shape=(img_width, img_height, 3)),  # Adjusted to 3 color channels
    BatchNormalization(), 
    MaxPooling2D(),

    Conv2D(64, (3, 3), activation='relu'),
    BatchNormalization(),
    MaxPooling2D(),

    # More Conv2D and MaxPooling2D layers as needed 

    Flatten(),
    Dense(128, activation='relu'),
    Dropout(0.5),  
    Dense(7, activation='softmax') 
])


# Compile the model (Consider weighted loss if imbalance is found)
model.compile(optimizer=Adam(), 
              loss='categorical_crossentropy',  # Or consider weighted loss
              metrics=['accuracy'])

checkpoint = ModelCheckpoint('best_model.h5', save_best_only=True) 
early_stopping = EarlyStopping(patience=5) 

# Train the model
history = model.fit(
    train_generator,
    epochs=50,
    validation_data=validation_generator,
    callbacks=[checkpoint, early_stopping]
)

# Evaluation
y_pred = model.predict(validation_generator)  # Predict using existing generator
y_pred_classes = np.argmax(y_pred, axis=1)    # Get class labels from predictions

# Assuming validation_generator has true labels accessible:
y_true = validation_generator.classes 

print(classification_report(y_true, y_pred_classes))
print(confusion_matrix(y_true, y_pred_classes))

{'angry': 0, 'disgust': 1, 'fear': 2, 'happy': 3, 'neutral': 4, 'sad': 5, 'surprise': 6}
Epoch 1/50
Epoch 2/50


  saving_api.save_model(


Epoch 3/50
Epoch 4/50
Epoch 5/50
Epoch 6/50
Epoch 7/50
Epoch 8/50

KeyboardInterrupt: 

In [9]:
# Evaluation
y_pred = model.predict(validation_generator)  # Predict using existing generator
y_pred_classes = np.argmax(y_pred, axis=1)    # Get class labels from predictions

# Assuming validation_generator has true labels accessible:
y_true = validation_generator.classes 

print(classification_report(y_true, y_pred_classes))
print(confusion_matrix(y_true, y_pred_classes))

              precision    recall  f1-score   support

           0       0.00      0.00      0.00      1196
           1       0.00      0.00      0.00       294
           2       0.00      0.00      0.00      1143
           3       0.29      0.33      0.31      3021
           4       0.00      0.00      0.00      1995
           5       0.18      0.01      0.01      1793
           6       0.11      0.66      0.19      1209

    accuracy                           0.17     10651
   macro avg       0.08      0.14      0.07     10651
weighted avg       0.12      0.17      0.11     10651

[[   0    0    0  390    0    5  801]
 [   0    0    0   86    0    0  208]
 [   0    0    0  350    1    9  783]
 [   1    0    0  991    0   17 2012]
 [   0    0    0  678    0   13 1304]
 [   0    0    1  577    0   11 1204]
 [   0    0    0  404    0    6  799]]


  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
