"""
This module contains the `MediaPipeLandmarksModel` class, which is a deep learning-based
video embedding model utilizing the MediaPipe framework for extracting pose and hand
landmarks from video frames.
Classes:
MediaPipeLandmarksModel: A video embedding model that utilizes MediaPipe for pose and hand landmark extraction.
Example:
```
from sign_language_translator.models import MediaPipeLandmarksModel
from sign_language_translator.vision.utils import iter_frames_with_opencv
mediapipe_model = MediaPipeLandmarksModel(number_of_persons=1)
frame_sequence = iter_frames_with_opencv("video.mp4")
embedding = mediapipe_model.embed(frame_sequence, landmark_type="world")
print(embedding.shape)
```
"""

from os.path import join
from typing import Dict, Iterable, List, Optional, Union

try:
    import mediapipe
except ImportError:
    mediapipe = None

import numpy as np
import torch
from numpy.typing import NDArray

from sign_language_translator.config.assets import Assets
from sign_language_translator.models.video_embedding.video_embedding_model import (
    VideoEmbeddingModel,
)
from sign_language_translator.utils import ProgressStatusCallback


class MediaPipeLandmarksModel(VideoEmbeddingModel):
    """
    A video embedding model that uses MediaPipe to extract pose and hand landmarks from video frames.

    Args:
        pose_model_name (str): File name of the MediaPipe pose landmarker model asset to download and load. Defaults to "pose_landmarker_heavy.task".
        hand_model_name (str): File name of the MediaPipe hand landmarker model asset to download and load. Defaults to "hand_landmarker.task".
        number_of_persons (int): The maximum number of persons to detect in each frame. Defaults to 1.

    Attributes:
        n_persons (int): The maximum number of persons to detect in each frame.

    Methods:
        embed: Embeds a sequence of frames using pose and hand landmarks.
    """

    def __init__(
        self,
        pose_model_name="pose_landmarker_heavy.task",
        hand_model_name="hand_landmarker.task",
        number_of_persons: int = 1,
    ) -> None:
        if mediapipe is None:
            raise ImportError(
                "The 'mediapipe' package is required to use the 'MediaPipeLandmarksModel'. "
                "Install it using `pip install sign-language-translator[mediapipe]`."
            )

        self._pose_class = mediapipe.tasks.vision.PoseLandmarker
        self._hand_class = mediapipe.tasks.vision.HandLandmarker

        path = self.__download_and_get_model_path(f"models/mediapipe/{pose_model_name}")
        self._pose_options = mediapipe.tasks.vision.PoseLandmarkerOptions(
            base_options=mediapipe.tasks.BaseOptions(model_asset_path=path),
            running_mode=mediapipe.tasks.vision.RunningMode.VIDEO,
            output_segmentation_masks=False,
            num_poses=number_of_persons,
        )

        path = self.__download_and_get_model_path(f"models/mediapipe/{hand_model_name}")
        self._hand_options = mediapipe.tasks.vision.HandLandmarkerOptions(
            base_options=mediapipe.tasks.BaseOptions(model_asset_path=path),
            running_mode=mediapipe.tasks.vision.RunningMode.VIDEO,
            num_hands=number_of_persons * 2,
        )

        self.n_persons = number_of_persons

    def embed(
        self,
        frame_sequence: Iterable[Union[torch.Tensor, NDArray[np.uint8]]],
        landmark_type: str = "world" or "image" or "all",
        progress_callback: Optional[ProgressStatusCallback] = None,
        total_frames: Optional[int] = None,
        **kwargs,
    ) -> torch.Tensor:
        """
        Embed a sequence of frames (video) into a sequence of pose & hand landmarks.

        Args:
            frame_sequence (Iterable[torch.Tensor | NDArray[np.uint8]]): A sequence of video frames as 3D arrays of shape (height, width, channels) in RGB format.
            landmark_type (str): The type of landmarks to include in the embedding ("world", "image" or "all"). Defaults to "world".
            progress_callback (Optional[ProgressStatusCallback]): A callback invoked after each frame with the fraction of frames processed so far. Defaults to None.
            total_frames (Optional[int]): Total number of frames, used only for progress reporting; inferred automatically when `frame_sequence` has a length. Defaults to None.

        Returns:
            torch.Tensor: A tensor of shape (n_frames, embedding_size) containing the frame embeddings.
        """
        if mediapipe is None:
            raise ImportError(
                "The 'mediapipe' package is required to use the 'MediaPipeLandmarksModel'. "
                "Install it using `pip install sign-language-translator[mediapipe]`."
            )
        if landmark_type not in ("world", "image", "all"):
            raise ValueError(
                "landmark_type not supported, use 'world', 'image' or 'all'."
            )
        # TODO: Pose only or hands only

        if hasattr(frame_sequence, "__len__"):
            total_frames = len(frame_sequence)  # type: ignore

        embeddings = []
        # TODO: create here or in __init__ ??
        with self._pose_class.create_from_options(
            self._pose_options
        ) as pose_landmarker, self._hand_class.create_from_options(
            self._hand_options
        ) as hand_landmarker:
            for i, frame in enumerate(frame_sequence):
                # convert frame to mediapipe image
                mp_image = mediapipe.Image(
                    image_format=mediapipe.ImageFormat.SRGB,
                    data=np.array(frame),
                )

                # infer through models (the frame index serves as the monotonically
                # increasing timestamp required by the VIDEO running mode)
                pose_result = pose_landmarker.detect_for_video(mp_image, i)
                hand_result = hand_landmarker.detect_for_video(mp_image, i)

                # create & append the frame embedding
                poses = self._extract_from_pose_results(pose_result)
                hands = self._extract_from_hand_results(hand_result)
                persons = self._arange_pose_and_hands(poses, hands)
                frame_embedding = self._create_frame_embedding(persons, landmark_type)
                embeddings.append(frame_embedding)

                if progress_callback and total_frames:
                    progress_callback(
                        {"file": f"{i / total_frames:.1%}" if total_frames else "?%"}
                    )

        return torch.Tensor(embeddings)

    def _flatten_landmarks(self, landmarks) -> List[float]:
        return [
            value
            for lm in landmarks
            for value in [lm.x, lm.y, lm.z, lm.visibility, lm.presence]
        ]

    def _extract_from_pose_results(self, pose_result) -> Dict[str, List[List[float]]]:
        poses = {"image": [], "world": []}
        for pose_image, pose_world in zip(
            pose_result.pose_landmarks, pose_result.pose_world_landmarks
        ):
            poses["image"].append(self._flatten_landmarks(pose_image))
            poses["world"].append(self._flatten_landmarks(pose_world))

        return poses
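
    # Note: in MediaPipe's landmarker output, "image" landmarks are normalized to the
    # frame's width and height, while "world" landmarks are approximate real-world
    # coordinates in meters (for poses, the origin lies roughly between the hips).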

    def _extract_from_hand_results(self, hand_result) -> Dict[str, List[List[float]]]:
        hands = {
            "Left_image": [],
            "Left_world": [],
            "Right_image": [],
            "Right_world": [],
        }
        for hnd, image, world in zip(
            hand_result.handedness,
            hand_result.hand_landmarks,
            hand_result.hand_world_landmarks,
        ):
            # flatten & separate by handedness ("Left" / "Right")
            hands[hnd[0].display_name + "_image"].append(self._flatten_landmarks(image))
            hands[hnd[0].display_name + "_world"].append(self._flatten_landmarks(world))

        return hands

    def _arange_pose_and_hands(
        self,
        poses: Dict[str, List[List[float]]],
        hands: Dict[str, List[List[float]]],
    ) -> Dict[str, List[List[float]]]:
        # TODO: Match left & right hands to poses
        # by using minimum distance between hand image centers
        # np.linalg.norm(pose[left_hand_ids].mean(axis=...), hands.mean(axis=...).T).argmin(axis=...)

        # pad with zeros so every frame has a fixed number of persons
        default_hand = [0.0] * 5 * 21
        default_pose = [0.0] * 5 * 33
        for k in poses.keys():
            poses[k] += [default_pose] * (self.n_persons - len(poses[k]))
        for k in hands.keys():
            hands[k] += [default_hand] * (self.n_persons - len(hands[k]))

        return {
            key: [
                poses[key][p] + hands["Left_" + key][p] + hands["Right_" + key][p]
                for p in range(self.n_persons)
            ]  # TODO: order of persons should be the same across frames
            for key in ["image", "world"]
        }
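
    # A rough sketch of the hand-to-pose matching mentioned in the TODO above; it is
    # not wired into the pipeline, and the index range is an assumption based on
    # MediaPipe's pose topology (landmarks 15-22 are the wrist/hand keypoints).
    # Given pose_xy of shape (n_poses, 33, 2) and hand_xy of shape (n_hands, 21, 2)
    # in image coordinates:
    #
    #     pose_hand_centers = pose_xy[:, 15:23].mean(axis=1)        # (n_poses, 2)
    #     hand_centers = hand_xy.mean(axis=1)                       # (n_hands, 2)
    #     distances = np.linalg.norm(
    #         pose_hand_centers[:, None] - hand_centers[None], axis=-1
    #     )                                                          # (n_poses, n_hands)
    #     person_index_per_hand = distances.argmin(axis=0)           # (n_hands,)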

    def _create_frame_embedding(
        self, persons: Dict[str, List[List[float]]], landmark_type: str
    ) -> List[float]:
        embedding = []
        # flatten & concat
        if landmark_type in ("world", "all"):
            embedding.extend([value for person in persons["world"] for value in person])
        if landmark_type in ("image", "all"):
            embedding.extend([value for person in persons["image"] for value in person])

        return embedding

    def __download_and_get_model_path(self, model_local_path: str):
        Assets.download(
            model_local_path,
            progress_bar=True,
            leave=False,
            chunk_size=1048576,
        )
        return join(Assets.ROOT_DIR, model_local_path)
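

if __name__ == "__main__":  # pragma: no cover
    # Minimal usage sketch, assuming OpenCV (`pip install opencv-python`) is available
    # and a "video.mp4" file exists; neither is a requirement of this module itself.
    import cv2

    def iter_rgb_frames(path: str):
        """Yield video frames as RGB uint8 arrays using OpenCV (hypothetical helper)."""
        capture = cv2.VideoCapture(path)
        try:
            while True:
                success, bgr_frame = capture.read()
                if not success:
                    break
                yield cv2.cvtColor(bgr_frame, cv2.COLOR_BGR2RGB)
        finally:
            capture.release()

    model = MediaPipeLandmarksModel(number_of_persons=1)
    video_embedding = model.embed(iter_rgb_frames("video.mp4"), landmark_type="world")
    print(video_embedding.shape)  # e.g. (n_frames, 375) for a single person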