# In [None]:
from collections import Counter
import os
from torchvision import datasets
import matplotlib.pyplot as plt
import mediapipe as mp
import cv2

# Root directory of the AffectNet subset; 'train' and 'test' subfolders are read below
data_path = "/CSCI2952X/datasets/affectnet_3750subset"

# One ImageFolder per split, with no transform applied at load time
train_dataset = datasets.ImageFolder(
    root=os.path.join(data_path, 'train'),
    transform=None,
)
test_dataset = datasets.ImageFolder(
    root=os.path.join(data_path, 'test'),
    transform=None,
)

# Report how many samples each class index has in each split
print("Training Dataset Distribution:", Counter(train_dataset.targets), sep="\n")
print("Test Dataset Distribution:", Counter(test_dataset.targets), sep="\n")

# Shared Mediapipe face mesh detector used by the landmark extraction below.
# NOTE(review): flag semantics (independent still images, single face,
# refined landmarks) are per the Mediapipe docs — confirm against the
# installed version.
mp_face_mesh = mp.solutions.face_mesh
face_mesh = mp_face_mesh.FaceMesh(
    static_image_mode=True,
    max_num_faces=1,
    refine_landmarks=True,
)

def extract_mediapipe_landmarks(image_np):
    """
    Extract facial landmarks using Mediapipe from a NumPy image.

    Args:
        image_np: Image array in OpenCV's BGR channel order (e.g. the result
            of ``cv2.imread``). ``None`` or an empty array is tolerated.

    Returns:
        A list of ``(x, y)`` integer pixel coordinates, one per landmark of
        every detected face, or ``None`` when the image is missing/empty or
        no face is detected.
    """
    # cv2.imread yields None for unreadable files; cvtColor would raise on
    # that (and on an empty array), so bail out early instead of crashing.
    if image_np is None or image_np.size == 0:
        return None

    # Landmark coordinates are normalized; scale them by the image size.
    height, width = image_np.shape[:2]

    # Ensure image is in RGB format
    image_rgb = cv2.cvtColor(image_np, cv2.COLOR_BGR2RGB)

    # Detect face landmarks
    results = face_mesh.process(image_rgb)
    if not results.multi_face_landmarks:
        return None  # No face detected

    return [
        (int(lm.x * width), int(lm.y * height))
        for face_landmarks in results.multi_face_landmarks
        for lm in face_landmarks.landmark
    ]

# Process images in the dataset: load each training image, extract its
# Mediapipe landmarks and display them overlaid on the image.
# NOTE: the progress line is 1-based while the per-image result lines use the
# 0-based dataset index.
for idx, (image_path, label) in enumerate(train_dataset.samples):  # ImageFolder stores samples as (path, label)
    print(f"Processing image {idx + 1}...")

    # Load the image using OpenCV (BGR format); imread returns None instead
    # of raising when the file is missing or cannot be decoded.
    image_np = cv2.imread(image_path)
    if image_np is None:
        print(f"Image {idx} (Label {label}): Could not read {image_path}, skipping")
        continue

    # Extract Mediapipe landmarks for the image
    landmarks = extract_mediapipe_landmarks(image_np)
    if not landmarks:
        print(f"Image {idx} (Label {label}): No face detected")
        continue

    print(f"Image {idx} (Label {label}) Landmarks: {landmarks[:5]}...")  # Display first 5 landmarks

    # Plot the image and landmarks
    fig = plt.figure(figsize=(5, 5))
    plt.imshow(cv2.cvtColor(image_np, cv2.COLOR_BGR2RGB))  # Convert BGR to RGB for display
    plt.title(f"Label: {label}")
    plt.axis('off')

    # Draw all landmarks with a single scatter call instead of one call
    # (and one artist) per point.
    xs, ys = zip(*landmarks)
    plt.scatter(xs, ys, c='red', s=5)

    plt.show()
    # Release the figure so thousands of images do not accumulate open
    # figures on backends where show() does not dispose of them.
    plt.close(fig)