Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Feat/keypoints from mediapipe #1232

Merged
merged 11 commits into from
Jun 13, 2024
87 changes: 87 additions & 0 deletions supervision/keypoint/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -185,6 +185,93 @@ class IDs, and confidences of each keypoint.
data=data,
)

@classmethod
def from_mediapipe(
    cls, mediapipe_results, resolution_wh: Tuple[int, int]
) -> KeyPoints:
    """
    Creates a KeyPoints instance from a
    [Pose landmark detection](https://ai.google.dev/edge/mediapipe/solutions/vision/pose_landmarker/python)
    inference result.

    Args:
        mediapipe_results
            (Union[mediapipe.tasks.python.vision.pose_landmarker.PoseLandmarkerResult,
            mediapipe.python.solution_base.SolutionOutputs]):
            The output results from Mediapipe. It supports both: the inference
            result from mp.tasks.vision.pose_landmarker.PoseLandmarker and the
            legacy one from mp.solutions.pose.Pose.
        resolution_wh (Tuple[int, int]): A tuple of the form `(width, height)`
            representing the resolution of the frame.

    Returns:
        KeyPoints: A new KeyPoints object.

    Example:
        ```python
        import cv2
        import mediapipe as mp
        import supervision as sv

        image = cv2.imread(<SOURCE_IMAGE_PATH>)
        image_height, image_width, _ = image.shape
        mediapipe_image = mp.Image(
            image_format=mp.ImageFormat.SRGB,
            data=cv2.cvtColor(image, cv2.COLOR_BGR2RGB),
        )

        # Download model task from: https://ai.google.dev/edge/mediapipe/solutions/vision/pose_landmarker/index#models
        options = mp.tasks.vision.PoseLandmarkerOptions(
            base_options=mp.tasks.BaseOptions(
                model_asset_path="pose_landmarker_heavy.task"
            ),
            running_mode=mp.tasks.vision.RunningMode.IMAGE,
            num_poses=2,
        )

        PoseLandmarker = mp.tasks.vision.PoseLandmarker
        with PoseLandmarker.create_from_options(options) as landmarker:
            pose_landmarker_result = landmarker.detect(mediapipe_image)

        keypoints = sv.KeyPoints.from_mediapipe(
            pose_landmarker_result, (image_width, image_height)
        )
        ```
    """
    results = mediapipe_results.pose_landmarks
    # The new Tasks API (PoseLandmarker) yields a list of poses; the legacy
    # mp.solutions.pose.Pose yields a single landmark container (or None).
    # Normalize both shapes to a list of per-pose landmark sequences.
    if not isinstance(results, list):
        results = [] if results is None else [list(results.landmark)]

    if not results:
        return cls.empty()

    width, height = resolution_wh
    xy = []
    confidence = []
    for pose in results:
        # Landmark coordinates are normalized to [0, 1]; scale to pixels.
        xy.append([[landmark.x * width, landmark.y * height] for landmark in pose])
        confidence.append([landmark.visibility for landmark in pose])

    return cls(
        xy=np.array(xy, dtype=np.float32),
        confidence=np.array(confidence, dtype=np.float32),
    )

@classmethod
def from_ultralytics(cls, ultralytics_results) -> KeyPoints:
"""
Expand Down
38 changes: 38 additions & 0 deletions supervision/keypoint/skeletons.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,44 @@ class Skeleton(Enum):
(17, 15),
]

# Edge list for the 33-keypoint MediaPipe pose topology (1-indexed vertex
# pairs). NOTE(review): presumably these mirror MediaPipe's 0-indexed
# POSE_CONNECTIONS shifted by +1 — confirm against the MediaPipe pose
# landmarker documentation before relying on specific edges.
GHUM = [
    (1, 2),
    (1, 5),
    (2, 3),
    (3, 4),
    (4, 8),
    (5, 6),
    (6, 7),
    (7, 9),
    (10, 11),
    (12, 13),
    (12, 14),
    (12, 24),
    (13, 15),
    (13, 25),
    (14, 16),
    (15, 17),
    (16, 18),
    (15, 19),
    (16, 22),
    (17, 19),
    (17, 21),
    (17, 23),
    (18, 20),
    (19, 21),
    (24, 25),
    (24, 26),
    (25, 27),
    (26, 28),
    (27, 29),
    (28, 30),
    (28, 32),
    (29, 31),
    (29, 33),
    (30, 32),
    (31, 33),
]


# Lookup tables keyed by a skeleton's edge count / vertex count, mapping to
# its edge list. Declared empty here; presumably populated from the Skeleton
# enum members below this chunk — confirm in the full file.
SKELETONS_BY_EDGE_COUNT: Dict[int, Edges] = {}
SKELETONS_BY_VERTEX_COUNT: Dict[int, Edges] = {}
Expand Down