In [1]:
import csv
import os
import pickle
import random
from pathlib import Path
from typing import List, Optional, Tuple, Union

import cv2
import mediapipe as mp
import numpy as np
from PIL import Image
from tqdm import tqdm

# All data lives in a `data` directory next to the current working directory's parent;
# DATA_PATH is simply an alias of GENERAL_DATA_PATH.
current_working_dir = os.getcwd()
GENERAL_DATA_PATH = Path(current_working_dir).parent / 'data'
DATA_PATH = GENERAL_DATA_PATH

# Ordered keypoint names, used by PoseExtractor to name the CSV columns.
# NOTE(review): pickle.load can execute arbitrary code — only load files you trust.
keypoints_names_file = DATA_PATH / 'keypoints_names.pkl'
with keypoints_names_file.open('rb') as keypoints_fh:
    KEYPOINTS_NAMES = pickle.load(keypoints_fh)

In [8]:
def get_image_numpy(image_file_name : str, dataset_type : str = 'train') -> np.ndarray:
    """Load ``DATA_PATH/<dataset_type>/<image_file_name>`` as a numpy array.

    Args:
        image_file_name: File name inside the split directory.
        dataset_type: Split sub-directory to read from; defaults to ``'train'``.

    Returns:
        The pixel data as returned by ``np.array`` on the PIL image.
    """
    # Image.open is lazy and keeps the file handle open; the `with` block closes it
    # deterministically once the pixels have been copied into the array.
    with Image.open(DATA_PATH / dataset_type / image_file_name) as image:
        return np.array(image)

def get_random_image_numpy() -> np.ndarray:
    """Load a randomly chosen entry of ``DATA_PATH/train`` as a numpy array.

    The choice is made with ``random.choice`` over ``os.listdir``, so it follows the
    global ``random`` state and raises ``IndexError`` if the directory is empty.
    """
    chosen_file_name = random.choice(os.listdir(DATA_PATH / 'train'))
    return get_image_numpy(chosen_file_name)

class PoseExtractor:
    """Extract MediaPipe Pose landmarks from a directory of images and write them to CSV.

    Images are read from ``<source_data_path>/train`` or ``<source_data_path>/test`` and
    one CSV row is written per image to ``<destination_data_path>/<save_file_name>``.
    Each row holds ``x, y, z, visibility`` for every landmark, in the order given by
    ``keypoints_names``; images without a detected pose get a row of empty fields.
    """

    keypoints_names = KEYPOINTS_NAMES.copy()
    # Values per extracted pose: 4 values (x, y, z, visibility) per landmark.
    # NOTE(review): 132 == 33 * 4, i.e. this assumes KEYPOINTS_NAMES lists the 33
    # MediaPipe Pose landmarks — confirm.
    extraction_output_len = 132
    # Placeholder row for images with no detected pose. It must have exactly
    # `extraction_output_len` fields so it lines up with the header and with real rows.
    empty_extraction = [None] * extraction_output_len

    def __init__(
            self,
            source_data_path : Union[str, os.PathLike],
            destination_data_path : Union[str, os.PathLike],
            model_complexity : int,
            min_detection_confidence : float = 0.5
            ) -> None:
        """Create a MediaPipe Pose model configured for independent still images.

        Args:
            source_data_path: Directory containing the ``train`` and ``test`` image folders.
            destination_data_path: Directory where CSV files are written.
            model_complexity: MediaPipe Pose model complexity (0, 1 or 2).
            min_detection_confidence: Minimum detection confidence for a pose to be reported.
        """
        self.source_data_path = Path(source_data_path)
        self.destination_data_path = Path(destination_data_path)
        self.pose = mp.solutions.pose.Pose(
            static_image_mode=True,
            model_complexity=model_complexity,
            # smooth_landmarks and min_tracking_confidence only matter in video (tracking)
            # mode; with static_image_mode=True MediaPipe ignores them.
            smooth_landmarks=True,
            enable_segmentation=False,
            smooth_segmentation=False,
            min_detection_confidence=min_detection_confidence,
            min_tracking_confidence=0.5
        )

    def close(self) -> None:
        """Release the resources held by the underlying MediaPipe Pose graph."""
        self.pose.close()

    def __enter__(self) -> 'PoseExtractor':
        return self

    def __exit__(self, *exc_info) -> None:
        self.close()

    @property
    def columns_names(self) -> List[str]:
        """CSV header: ``<name>_x, <name>_y, <name>_z, <name>_visibility`` per keypoint."""
        return [
            f'{keypoint_name}_{suffix}'
            for keypoint_name in self.keypoints_names
            for suffix in ('x', 'y', 'z', 'visibility')
        ]

    @staticmethod
    def _dataset_type(train : bool) -> str:
        """Map the ``train`` flag to the split sub-directory name."""
        return 'train' if train else 'test'

    def load_image_as_ndarray(
            self,
            image_file_name : str,
            train : bool
            ) -> np.ndarray:
        """Read ``image_file_name`` from the train or test split as a numpy array.

        The underlying file handle is closed before returning.
        """
        image_path = self.source_data_path / self._dataset_type(train) / image_file_name
        with Image.open(image_path) as image:
            return np.array(image)

    def extract_pose(
            self,
            image: np.ndarray
            ) -> Optional[List[float]]:
        """Run MediaPipe Pose on one RGB image.

        Returns:
            A flat list ``[x0, y0, z0, vis0, x1, ...]`` with four values per landmark,
            or ``None`` when no pose is detected.
        """
        extraction_res = self.pose.process(image).pose_landmarks

        if extraction_res is None:
            return None

        landmarks = np.array(
            [[landmark.x, landmark.y, landmark.z, landmark.visibility]
             for landmark in extraction_res.landmark],
            dtype=np.float64,
        )
        # Row-major ravel keeps the per-landmark (x, y, z, visibility) grouping.
        return list(landmarks.ravel())

    def gray_to_rgb(self, image: np.ndarray) -> np.ndarray:
        """Replicate a single-channel image into three RGB channels."""
        return cv2.cvtColor(image, cv2.COLOR_GRAY2RGB)

    def _to_rgb(self, image: np.ndarray) -> Optional[np.ndarray]:
        """Coerce a loaded image to 3-channel RGB, or return None if unsupported."""
        if image.ndim == 2:
            return self.gray_to_rgb(image)
        if image.ndim == 3 and image.shape[2] == 4:
            # Drop the alpha channel instead of silently skipping RGBA images.
            return cv2.cvtColor(image, cv2.COLOR_RGBA2RGB)
        if image.ndim == 3 and image.shape[2] == 3:
            return image
        # NOTE(review): palette ('P') images load as 2-D index arrays and are treated
        # as grayscale above — convert them with PIL first if such files exist.
        return None

    def extract_poses_and_write_to_csv(
            self,
            save_file_name : str,
            sample : bool,
            train : bool = True,
            n_samples : int = 1,
            delimiter : str = ';',
            include_file_names : bool = False
            ) -> None:
        """Extract a pose for every image of a split and write them to a CSV file.

        Args:
            save_file_name: Output file name, created under ``destination_data_path``.
            sample: If True, process only ``n_samples`` images drawn with ``random.sample``.
            train: Read from the ``train`` split if True, otherwise from ``test``.
            n_samples: Number of images to draw when ``sample`` is True (ignored otherwise).
            delimiter: CSV field delimiter.
            include_file_names: If True, prepend a ``file_name`` column so each row can
                be matched back to its image. Off by default to keep the original layout.

        One header row is written, then exactly one row per image in processing order.
        Images that cannot be converted to RGB, or in which no pose is detected, get
        ``empty_extraction`` (written as empty fields).

        Raises:
            ValueError: from ``random.sample`` if ``n_samples`` exceeds the image count.
        """
        # NOTE: os.listdir order is filesystem-dependent; it is kept unchanged so existing
        # outputs keep their row order. Use include_file_names=True to identify rows.
        images = os.listdir(self.source_data_path / self._dataset_type(train))

        if sample:
            images = random.sample(images, n_samples)

        header = self.columns_names
        if include_file_names:
            header = ['file_name', *header]

        with open(self.destination_data_path / save_file_name, 'w', newline='') as write_file:
            writer = csv.writer(write_file, delimiter=delimiter)
            writer.writerow(header)

            # Rows are streamed instead of buffered: memory stays flat and, because the
            # `with` block flushes on exit, rows processed before a failure are kept.
            for image_filename in tqdm(images):
                image = self._to_rgb(self.load_image_as_ndarray(image_filename, train))
                extracted_pose = None if image is None else self.extract_pose(image)
                row = self.empty_extraction if extracted_pose is None else extracted_pose
                if include_file_names:
                    row = [image_filename, *row]
                writer.writerow(row)

In [9]:
# Both the image source and the CSV destination point at the shared data directory
# (DATA_PATH is an alias of GENERAL_DATA_PATH).
pose_extractor = PoseExtractor(
    DATA_PATH,
    GENERAL_DATA_PATH,
    model_complexity=1,
    min_detection_confidence=0.5,
)

In [10]:
# Writes test-poses.csv into the destination directory (GENERAL_DATA_PATH). The method
# returns None, so its result is not bound to a name.
pose_extractor.extract_poses_and_write_to_csv(
    save_file_name='test-poses.csv',
    sample=False,
    train=False,
)

100%|██████████| 4232/4232 [04:23<00:00, 16.08it/s]
