In [63]:
import tensorflow as tf
from tensorflow.keras import layers, Model
from tensorflow.keras.applications import ResNet50
from tensorflow.keras.optimizers import SGD
import numpy as np
import matplotlib.pyplot as plt
from tensorflow.keras.preprocessing.image import ImageDataGenerator
from sklearn.metrics import confusion_matrix, classification_report
import seaborn as sns


In [64]:
# Runtime setup: GPU memory growth first, then the global mixed-precision policy.
from tensorflow.keras.mixed_precision import set_global_policy

# Configure GPU memory growth. This has to happen before the GPUs are
# initialized (TensorFlow raises RuntimeError otherwise), so it is done ahead
# of every other TensorFlow call in this cell.
gpus = tf.config.experimental.list_physical_devices('GPU')
memory_growth_enabled = False
if gpus:
    try:
        for gpu in gpus:
            tf.config.experimental.set_memory_growth(gpu, True)
        memory_growth_enabled = True
    except RuntimeError as e:
        # Keep going with the default allocation strategy, but surface the reason.
        print(e)

# Enable mixed precision for better GPU memory utilization
set_global_policy('mixed_float16')

# Report what actually happened instead of an unconditional success message.
if memory_growth_enabled:
    print("GPU configuration set.")
elif gpus:
    print("GPU detected, but memory growth could not be enabled (see the error above).")
else:
    print("No GPU detected; training will run on the CPU.")


GPU configuration set.


In [65]:
# Dataset paths
dataset_path = r'D:\university\FER\fer_ckplus_kdef'
train_dir = f"{dataset_path}\\train"
val_dir = f"{dataset_path}\\val"
test_dir = f"{dataset_path}\\test"

# Image and batch parameters
img_size = (112, 112)  # Spatial size fed to the model; keep in sync with the model's Input shape
batch_size = 16
num_classes = 8
SEED = 42  # Makes shuffling and random augmentation repeatable across runs

# The ImageNet ResNet50 weights expect the Keras ResNet preprocessing
# (RGB -> BGR plus per-channel mean subtraction), not a plain 1/255 rescale.
resnet_preprocess = tf.keras.applications.resnet50.preprocess_input

# Data Augmentation for training
train_datagen = ImageDataGenerator(
    preprocessing_function=resnet_preprocess,
    rotation_range=20,
    width_shift_range=0.2,
    height_shift_range=0.2,
    shear_range=0.2,
    zoom_range=0.2,
    horizontal_flip=True
)

# Validation and test images get the same preprocessing but no augmentation.
val_test_datagen = ImageDataGenerator(preprocessing_function=resnet_preprocess)

# Load data
train_generator = train_datagen.flow_from_directory(
    train_dir, target_size=img_size, batch_size=batch_size, class_mode='categorical', seed=SEED
)

val_generator = val_test_datagen.flow_from_directory(
    val_dir, target_size=img_size, batch_size=batch_size, class_mode='categorical'
)

# shuffle=False keeps the prediction order aligned with `test_generator.classes`.
test_generator = val_test_datagen.flow_from_directory(
    test_dir, target_size=img_size, batch_size=batch_size, class_mode='categorical', shuffle=False
)

# Sanity checks: the three splits must agree on the label mapping, and the
# hard-coded class count must match what was found on disk.
assert train_generator.class_indices == val_generator.class_indices == test_generator.class_indices, \
    "train/val/test directories do not contain the same class folders"
assert train_generator.num_classes == num_classes, \
    f"expected {num_classes} classes, found {train_generator.num_classes}"


Found 23650 images belonging to 8 classes.
Found 2631 images belonging to 8 classes.
Found 6573 images belonging to 8 classes.


In [66]:
def multi_scale_embedding_layer(input_tensor):
    """Embed a feature map using three parallel convolutions with different kernel sizes.

    The input goes through 3x3 (64 filters), 5x5 (128 filters) and 7x7
    (256 filters) convolutions, all with "same" padding and ReLU. The three
    outputs are concatenated along the channel axis, globally average-pooled,
    projected to 512 units and reshaped into a (16, 32) sequence.

    Args:
        input_tensor: Keras tensor accepted by ``Conv2D`` (the convolutions
            require a 4-D input).

    Returns:
        Keras tensor of shape (batch, 16, 32).
    """
    # (filters, kernel size) for each branch, from the finest to the coarsest scale.
    branch_specs = ((64, (3, 3)), (128, (5, 5)), (256, (7, 7)))
    branches = [
        layers.Conv2D(filters, kernel, padding="same", activation="relu")(input_tensor)
        for filters, kernel in branch_specs
    ]

    # Fuse the scales channel-wise, then collapse the spatial dimensions.
    merged = layers.concatenate(branches, axis=-1)
    pooled = layers.GlobalAveragePooling2D()(merged)

    # 512 = 16 * 32, so the dense output can be viewed as 16 steps of 32 features.
    embedding = layers.Dense(512, activation='relu')(pooled)
    return layers.Reshape((16, 32))(embedding)


In [67]:
def temporal_transformer(x):
    """Apply a single attention block over a 16-step sequence and pool it to a vector.

    Steps: a 512-unit ReLU projection, a reshape to 16 steps, 8-head
    self-attention followed by layer normalization, then a feed-forward
    branch (Dense 512 + Dropout 0.5) added to a linear projection of the
    normalized attention output, and finally an average over the step axis.

    Args:
        x: Keras tensor whose per-sample size after the first Dense layer can
            be arranged into 16 rows, e.g. the (batch, 16, 32) output of
            ``multi_scale_embedding_layer``.

    Returns:
        Keras tensor of shape (batch, 512).
    """
    tokens = layers.Dense(512, activation='relu')(x)

    # Force a 16-step layout; for a (batch, 16, 32) input this leaves the
    # shape at (batch, 16, 512).
    tokens = layers.Reshape((16, -1))(tokens)

    # Self-attention: the sequence serves as both query and value.
    attended = layers.MultiHeadAttention(num_heads=8, key_dim=32)(tokens, tokens)
    normed = layers.LayerNormalization()(attended)

    # Feed-forward branch.
    hidden = layers.Dense(512, activation='relu')(normed)
    hidden = layers.Dropout(0.5)(hidden)

    # Shortcut branch: linear projection to the feed-forward width, then sum.
    shortcut = layers.Dense(hidden.shape[-1])(normed)
    merged = layers.Add()([shortcut, hidden])

    # Average over the 16 steps.
    return layers.GlobalAveragePooling1D()(merged)


In [68]:
# Input layer. The spatial size is taken from `img_size` (the generators'
# target size) so the model always matches the batches it is fed.
input_shape = (*img_size, 3)
inputs = layers.Input(shape=input_shape)

# CNN Backbone (ResNet50 for feature extraction)
resnet_base = ResNet50(input_shape=input_shape, include_top=False, weights='imagenet')
resnet_base.trainable = False  # Freeze base model during initial training

# Extract spatial features and adjust dimensions.
# training=False keeps the backbone in inference mode.
cnn_features = resnet_base(inputs, training=False)
cnn_features = layers.GlobalAveragePooling2D()(cnn_features)  # Reduce spatial dimensions
cnn_features = layers.Dense(512, activation='relu')(cnn_features)  # Ensure feature consistency

# View the 512 features as a 16x32 single-channel map, the 4-D layout the
# Conv2D layers in multi_scale_embedding_layer require. A Keras Reshape layer
# is used instead of a raw tf.expand_dims on the symbolic tensor.
cnn_features = layers.Reshape((16, 32, 1))(cnn_features)

# Apply multi-scale embedding
ms_embedding = multi_scale_embedding_layer(cnn_features)

# Apply temporal transformer
tformer_output = temporal_transformer(ms_embedding)

# Classification head
x = layers.Dense(256, activation='relu')(tformer_output)
x = layers.Dropout(0.5)(x)
# The global policy is mixed_float16; the softmax output is computed in
# float32 to keep the probabilities and the loss numerically stable.
output = layers.Dense(num_classes, activation='softmax', dtype='float32')(x)

# Define the model
model = Model(inputs, output)

# Compile the model
model.compile(optimizer=SGD(learning_rate=0.01, momentum=0.9),
              loss='categorical_crossentropy',
              metrics=['accuracy'])

# Print model summary
model.summary()


Model: "model_1"
__________________________________________________________________________________________________
 Layer (type)                   Output Shape         Param #     Connected to                     
 input_21 (InputLayer)          [(None, 112, 112, 3  0           []                               
                                )]                                                                
                                                                                                  
 resnet50 (Functional)          (None, 4, 4, 2048)   23587712    ['input_21[0][0]']               
                                                                                                  
 global_average_pooling2d_13 (G  (None, 2048)        0           ['resnet50[0][0]']               
 lobalAveragePooling2D)                                                                           
                                                                                            

In [69]:
# Callbacks, both driven by the validation loss:
#   - EarlyStopping stops after 5 epochs without improvement and restores the best weights.
#   - ReduceLROnPlateau multiplies the learning rate by 0.2 after 3 stagnant epochs (floor 1e-6).
early_stopping = tf.keras.callbacks.EarlyStopping(
    monitor='val_loss',
    patience=5,
    restore_best_weights=True,
)
reduce_lr = tf.keras.callbacks.ReduceLROnPlateau(
    monitor='val_loss',
    factor=0.2,
    patience=3,
    min_lr=1e-6,
)

# Train for at most 20 epochs, validating on the validation generator.
fit_options = dict(
    validation_data=val_generator,
    epochs=20,
    callbacks=[early_stopping, reduce_lr],
)
history = model.fit(train_generator, **fit_options)


Epoch 1/20
Epoch 2/20

KeyboardInterrupt: 

In [None]:
# Score the model on the test generator. evaluate() returns the loss followed
# by the compiled metric (accuracy).
test_metrics = model.evaluate(test_generator)
test_loss, test_accuracy = test_metrics
print(f"Test Accuracy: {test_accuracy * 100:.2f}%")


In [None]:
def plot_training_curves(history):
    """Plot training and validation loss and accuracy side by side.

    Args:
        history: Object with a ``history`` mapping that contains the
            'loss', 'val_loss', 'accuracy' and 'val_accuracy' series, such as
            the value returned by ``model.fit``.
    """
    # One row per panel:
    # (train key, validation key, train label, validation label, title, y-axis label)
    panels = (
        ('loss', 'val_loss', 'Train Loss', 'Validation Loss', 'Loss Curve', 'Loss'),
        ('accuracy', 'val_accuracy', 'Train Accuracy', 'Validation Accuracy',
         'Accuracy Curve', 'Accuracy'),
    )
    metrics = history.history

    plt.figure(figsize=(12, 5))
    for position, (train_key, val_key, train_label, val_label, title, y_label) in enumerate(panels, start=1):
        plt.subplot(1, 2, position)
        plt.plot(metrics[train_key], label=train_label)
        plt.plot(metrics[val_key], label=val_label)
        plt.title(title)
        plt.xlabel('Epochs')
        plt.ylabel(y_label)
        plt.legend()

    plt.show()

plot_training_curves(history)


In [None]:
# Predictions: the arg-max over the class axis gives the predicted class index.
y_pred = model.predict(test_generator)
y_pred_classes = np.argmax(y_pred, axis=1)
y_true = test_generator.classes

# Class names ordered by their integer index, plus the full list of indices.
# Passing the indices explicitly keeps every class in the matrix and the
# report even when a class is absent from both the labels and the
# predictions; otherwise the inferred label set would be shorter than
# `class_labels` and the tick labels / target names would be misaligned.
class_indices = test_generator.class_indices
class_labels = sorted(class_indices, key=class_indices.get)
label_ids = list(range(len(class_labels)))

# Confusion matrix
cm = confusion_matrix(y_true, y_pred_classes, labels=label_ids)

plt.figure(figsize=(8, 6))
sns.heatmap(cm, annot=True, fmt="d", cmap="Blues", xticklabels=class_labels, yticklabels=class_labels)
plt.title('Confusion Matrix')
plt.xlabel('Predicted')
plt.ylabel('Actual')
plt.show()

# Classification report
print("Classification Report:")
print(classification_report(y_true, y_pred_classes, labels=label_ids, target_names=class_labels))
