In [1]:
import os

import numpy as np
import tensorflow as tf
from PIL import Image, ImageFile
from tensorflow.keras.preprocessing.image import DirectoryIterator, ImageDataGenerator
from tensorflow.keras.preprocessing.image import load_img, img_to_array

# Enable loading of truncated images
ImageFile.LOAD_TRUNCATED_IMAGES = True


class _SafeDirectoryIterator(DirectoryIterator):
    """DirectoryIterator that substitutes a blank sample for unreadable files.

    Keras calls ``_get_batches_of_transformed_samples`` on the *iterator*
    returned by ``flow_from_directory``, so per-file error handling has to
    live here rather than on the ImageDataGenerator. A file that fails to
    load is replaced by an all-zero image with an all-zero label row, so one
    bad file does not abort the epoch. An all-zero label row contributes
    zero categorical cross-entropy loss, but note it still counts toward the
    'accuracy' metric.
    """

    def _get_batches_of_transformed_samples(self, index_array):
        # Only the one-hot ('categorical') layout is re-implemented; other
        # class modes and save_to_dir fall back to the stock Keras behaviour.
        if self.class_mode != 'categorical' or self.save_to_dir:
            return super()._get_batches_of_transformed_samples(index_array)

        batch_x = np.zeros((len(index_array),) + self.image_shape, dtype=self.dtype)
        batch_y = np.zeros((len(index_array), len(self.class_indices)), dtype=self.dtype)

        load_kwargs = {
            'color_mode': self.color_mode,
            'target_size': self.target_size,
            'interpolation': self.interpolation,
        }
        # keep_aspect_ratio only exists in newer Keras releases; pass it only when enabled.
        if getattr(self, 'keep_aspect_ratio', False):
            load_kwargs['keep_aspect_ratio'] = True

        filepaths = self.filepaths  # read the property once, outside the loop
        for i, j in enumerate(index_array):
            img_path = filepaths[j]
            try:
                img = load_img(img_path, **load_kwargs)
                x = img_to_array(img, data_format=self.data_format)
                if hasattr(img, 'close'):
                    img.close()
            except Exception as e:
                # Deliberate best effort: leave row i as zeros and keep training.
                print(f"Error loading image {img_path}: {e}")
                continue

            generator = self.image_data_generator
            if generator is not None:
                # Same pipeline as the stock iterator: random augmentation
                # (none configured by default), then rescale/normalisation.
                params = generator.get_random_transform(x.shape)
                x = generator.apply_transform(x, params)
                x = generator.standardize(x)

            batch_x[i] = x
            # self.classes holds the integer class index of each file; one-hot encode it.
            batch_y[i, self.classes[j]] = 1.0

        return batch_x, batch_y


class CustomImageDataGenerator(ImageDataGenerator):
    """ImageDataGenerator whose directory iterators tolerate unreadable images."""

    def flow_from_directory(self, *args, **kwargs):
        """Same arguments and return contract as ImageDataGenerator.flow_from_directory.

        The returned iterator is a ``_SafeDirectoryIterator``, which replaces
        images that fail to load with blank samples instead of raising.
        """
        iterator = super().flow_from_directory(*args, **kwargs)
        # Re-class instead of re-constructing so the parent's argument handling
        # (which differs between TF/Keras versions) stays untouched.
        # _SafeDirectoryIterator adds no state, only the overridden method.
        iterator.__class__ = _SafeDirectoryIterator
        return iterator

# Directories. Set the FACESHAPE_DATA_DIR environment variable to point at
# your own copy of the dataset; the fallback is the original download path.
DATA_DIR = os.environ.get(
    "FACESHAPE_DATA_DIR",
    r"C:\Users\abhid\Downloads\archive (2)\FaceShape Dataset",
)
train_dir = os.path.join(DATA_DIR, "training_set")
val_dir = os.path.join(DATA_DIR, "testing_set")

# Loader settings shared by both splits. The model's input layer must agree with IMG_SIZE.
IMG_SIZE = (64, 64)
BATCH_SIZE = 32
SEED = 42

# Data generators with custom preprocessing: scale pixel values by 1/255, no augmentation.
train_datagen = CustomImageDataGenerator(rescale=1./255)
val_datagen = CustomImageDataGenerator(rescale=1./255)

train_generator = train_datagen.flow_from_directory(
    train_dir,
    target_size=IMG_SIZE,
    batch_size=BATCH_SIZE,
    class_mode='categorical',
    seed=SEED,  # reproducible shuffle order
)

val_generator = val_datagen.flow_from_directory(
    val_dir,
    target_size=IMG_SIZE,
    batch_size=BATCH_SIZE,
    class_mode='categorical',
    shuffle=False,  # fixed order so predictions line up with val_generator.classes
)

# Both splits must map folder names to the same label indices; otherwise the
# validation labels would silently disagree with what the model was trained on.
assert train_generator.class_indices == val_generator.class_indices, (
    f"class mismatch: {train_generator.class_indices} vs {val_generator.class_indices}"
)

# Check the classes
print("Classes found: ", train_generator.class_indices)


Found 4000 images belonging to 5 classes.
Found 1000 images belonging to 5 classes.
Classes found:  {'Heart': 0, 'Oblong': 1, 'Oval': 2, 'Round': 3, 'Square': 4}


In [2]:
import os
from PIL import Image, ImageFile

# Enable loading of truncated images: let PIL load images whose data ends
# early instead of raising. Repeated from the first cell so this cell sets
# the flag itself.
ImageFile.LOAD_TRUNCATED_IMAGES = True

def remove_corrupted_images(
    directory,
    extensions=('.png', '.jpg', '.jpeg', '.bmp', '.ppm', '.tif', '.tiff'),
    dry_run=False,
):
    """Delete image files under ``directory`` that PIL cannot verify.

    Only files whose name ends with one of ``extensions`` (case-insensitive)
    are checked. Non-image files in the dataset tree (notes, CSVs, OS metadata
    such as Thumbs.db) are therefore left alone instead of being deleted as
    "corrupted". The default list is chosen to mirror the formats Keras'
    flow_from_directory loads; verify this against your Keras version.

    Args:
        directory: Root folder, scanned recursively. A missing folder yields no files.
        extensions: Iterable of file suffixes to check, or None to check every file.
        dry_run: If True, report corrupted files but do not delete them.

    Returns:
        List of paths that were removed (or would be removed when ``dry_run``
        is True), in deterministic sorted-walk order.
    """
    suffixes = None if extensions is None else tuple(ext.lower() for ext in extensions)
    removed = []
    for subdir, dirs, files in os.walk(directory):
        dirs.sort()  # deterministic traversal order
        for file in sorted(files):
            if suffixes is not None and not file.lower().endswith(suffixes):
                continue
            file_path = os.path.join(subdir, file)
            try:
                with Image.open(file_path) as img:
                    # Integrity check without fully decoding the pixel data.
                    img.verify()
            except (OSError, SyntaxError):
                # OSError also covers PIL's UnidentifiedImageError; the file is
                # closed by the `with` before we get here, so removal is safe.
                if dry_run:
                    print(f"Would remove corrupted image: {file_path}")
                else:
                    print(f"Removing corrupted image: {file_path}")
                    os.remove(file_path)
                removed.append(file_path)
    return removed

# Clean training and validation directories: delete files PIL cannot verify.
# NOTE: the generators in the first cell were created before this cleanup
# ran, so re-run that cell afterwards if anything gets removed here.
for split_dir in (train_dir, val_dir):
    remove_corrupted_images(split_dir)


In [None]:
import tensorflow as tf
from tensorflow.keras.models import Sequential
from tensorflow.keras.layers import Conv2D, MaxPooling2D, Flatten, Dense, Dropout
from tensorflow.keras.optimizers import Adam

import matplotlib.pyplot as plt

EPOCHS = 10
MODEL_PATH = 'face_shape_model.h5'

# Seed TensorFlow's RNG so weight initialisation and dropout are repeatable.
tf.random.set_seed(42)


def build_model(input_shape, num_classes):
    """Build the small three-block CNN classifier (uncompiled).

    Args:
        input_shape: (height, width, channels) of a single input image.
        num_classes: Number of softmax output units.

    Returns:
        A ``Sequential`` model.
    """
    return Sequential([
        tf.keras.Input(shape=input_shape),
        Conv2D(32, (3, 3), activation='relu'),
        MaxPooling2D(pool_size=(2, 2)),

        Conv2D(64, (3, 3), activation='relu'),
        MaxPooling2D(pool_size=(2, 2)),

        Conv2D(128, (3, 3), activation='relu'),
        MaxPooling2D(pool_size=(2, 2)),

        Flatten(),
        Dense(128, activation='relu'),
        Dropout(0.5),
        Dense(num_classes, activation='softmax'),  # one unit per face-shape class
    ])


# Model definition: take input size and class count from the data, so changing
# target_size or adding a class folder needs no edit here.
model = build_model(train_generator.image_shape, train_generator.num_classes)

# Compile the model
model.compile(optimizer=Adam(learning_rate=0.001), loss='categorical_crossentropy', metrics=['accuracy'])

# Model training. The generators are Keras Sequences, so Keras derives the
# step counts from len(generator), which rounds up. That way the final
# partial batch is not skipped; `samples // batch_size` would drop
# 1000 % 32 = 8 validation images every epoch.
history = model.fit(
    train_generator,
    validation_data=val_generator,
    epochs=EPOCHS,
)

# Save the model
model.save(MODEL_PATH)

# Plot training history
fig, ax = plt.subplots()
ax.plot(history.history['accuracy'], label='accuracy')
ax.plot(history.history['val_accuracy'], label='val_accuracy')
ax.set(title='Training vs. validation accuracy', xlabel='Epoch', ylabel='Accuracy')
ax.legend(loc='lower right')
plt.show()
