# In [None]:
import os
import cv2
import shutil
import mediapipe as mp
from tqdm import tqdm
from torchvision import datasets

# Initialise Mediapipe Face Mesh once; this single instance is reused by
# extract_mediapipe_landmarks() for every image processed below.
mp_face_mesh = mp.solutions.face_mesh
face_mesh = mp_face_mesh.FaceMesh(
    static_image_mode=True,  # inputs are independent still images, not a video stream
    max_num_faces=1,         # detect at most one face per image
    refine_landmarks=True,   # refine eye/lip landmarks and add iris landmarks (per Mediapipe docs)
)

# Input dataset root and output root; the output mirrors the input's
# <phase>/<class>/ directory layout.
input_root = "./affectnet_3750subset"
output_root = "./affectnet_3750subset_with_landmarks"

# Recreate the input folder structure under output_root.
os.makedirs(output_root, exist_ok=True)
for phase in ["train", "test"]:
    input_phase_dir = os.path.join(input_root, phase)
    output_phase_dir = os.path.join(output_root, phase)
    os.makedirs(output_phase_dir, exist_ok=True)
    for class_dir in os.listdir(input_phase_dir):
        input_class_dir = os.path.join(input_phase_dir, class_dir)
        # Only mirror real class sub-directories; stray files next to them
        # (e.g. .DS_Store, Thumbs.db) must not become output directories.
        if not os.path.isdir(input_class_dir):
            continue
        output_class_dir = os.path.join(output_phase_dir, class_dir)
        os.makedirs(output_class_dir, exist_ok=True)

# Landmark extraction
def extract_mediapipe_landmarks(image_np):
    """
    Detect facial landmarks with the module-level Mediapipe ``face_mesh``.

    Args:
        image_np: image as returned by ``cv2.imread`` (BGR channel order);
            it is converted to RGB before being handed to Mediapipe.

    Returns:
        A flat list of ``(x, y)`` integer pixel coordinates (Mediapipe's
        normalised coordinates scaled by the image width/height and truncated
        with ``int()``), covering every landmark of every detected face, or
        ``None`` when no face is detected.
    """
    detection = face_mesh.process(cv2.cvtColor(image_np, cv2.COLOR_BGR2RGB))
    faces = detection.multi_face_landmarks
    if not faces:
        return None

    height, width = image_np.shape[0], image_np.shape[1]
    return [
        (int(point.x * width), int(point.y * height))
        for face in faces
        for point in face.landmark
    ]

# Landmark drawing
def draw_landmarks_on_image(
    image_np, landmarks, *, radius=1, color=(0, 255, 0), thickness=-1
):
    """
    Draw each landmark as a dot on ``image_np``.

    The image is modified in place (``cv2.circle`` draws into the array) and
    the same object is returned for convenience; pass a copy if the original
    pixels must be preserved.

    Args:
        image_np: image array to draw on (BGR order for the default colour).
        landmarks: iterable of ``(x, y)`` integer pixel coordinates, e.g. the
            output of ``extract_mediapipe_landmarks``.
        radius: dot radius in pixels (default 1).
        color: dot colour as a BGR tuple (default green).
        thickness: ``cv2.circle`` thickness; a negative value draws a filled
            circle (default -1).

    Returns:
        ``image_np`` itself.
    """
    for x, y in landmarks:
        cv2.circle(image_np, (x, y), radius=radius, color=color, thickness=thickness)
    return image_np

# Process every image: draw landmarks when a face is found, otherwise copy
# the original file unchanged so the output dataset stays complete.
for phase in ["train", "test"]:
    input_phase_dir = os.path.join(input_root, phase)
    output_phase_dir = os.path.join(output_root, phase)

    # Sorted for a deterministic processing order across runs/platforms.
    for class_dir in sorted(os.listdir(input_phase_dir)):
        input_class_dir = os.path.join(input_phase_dir, class_dir)
        # Skip stray files (e.g. .DS_Store) next to the class folders;
        # os.listdir() on them would raise NotADirectoryError.
        if not os.path.isdir(input_class_dir):
            continue
        output_class_dir = os.path.join(output_phase_dir, class_dir)
        # Idempotent; guarantees the target exists even if the setup cell was
        # skipped or the input tree changed since it ran.
        os.makedirs(output_class_dir, exist_ok=True)

        image_names = sorted(os.listdir(input_class_dir))
        for image_name in tqdm(image_names, desc=f"Processing {phase}/{class_dir}"):
            input_image_path = os.path.join(input_class_dir, image_name)
            output_image_path = os.path.join(output_class_dir, image_name)

            # Load the image; imread returns None for unreadable/non-image files.
            image_np = cv2.imread(input_image_path)
            if image_np is None:
                print(f"Failed to load image: {input_image_path}")
                continue

            landmarks = extract_mediapipe_landmarks(image_np)

            if landmarks:
                # Draw on a copy so image_np itself is left untouched, then save.
                image_with_landmarks = draw_landmarks_on_image(image_np.copy(), landmarks)
                # imwrite signals a failed write by returning False.
                if not cv2.imwrite(output_image_path, image_with_landmarks):
                    print(f"Failed to write image: {output_image_path}")
            else:
                # No face detected: copy the original file unchanged.
                shutil.copy(input_image_path, output_image_path)
