[![Roboflow Notebooks](https://media.roboflow.com/notebooks/template/bannertest2-2.png?ik-sdk-version=javascript-1.4.3&updatedAt=1672932710194)](https://github.com/roboflow/notebooks)

# Zero-shot Pose Estimation using ViTPose++

The ViTPose model for human pose estimation was proposed in the paper [ViTPose: Simple Vision Transformer Baselines for Human Pose Estimation](https://arxiv.org/abs/2204.12484). Its architecture uses a standard, non-hierarchical Vision Transformer (ViT) as the backbone for keypoint estimation, with a simple decoder head on top that predicts heatmaps from a given image. [ViTPose++](https://arxiv.org/abs/2212.04246) is an improved version of the original ViTPose model that adds a mixture-of-experts (MoE) module to the ViT backbone and is pre-trained on more data.
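
To make the heatmap idea concrete, here is a minimal sketch of how one keypoint per heatmap can be read off by taking the location of the peak. It uses synthetic data and a made-up helper name, `decode_heatmaps`; it is an illustration of the general technique, not the post-processing that ViTPose actually performs.

```python
import numpy as np


def decode_heatmaps(heatmaps):
    """Return (x, y) peak coordinates and peak scores for each heatmap.

    heatmaps: array of shape (num_keypoints, height, width).
    """
    num_keypoints, height, width = heatmaps.shape
    flat = heatmaps.reshape(num_keypoints, -1)
    peak_index = flat.argmax(axis=1)
    scores = flat.max(axis=1)
    # Turn flat indices back into (row, column) positions
    ys, xs = np.unravel_index(peak_index, (height, width))
    coords = np.stack([xs, ys], axis=1)
    return coords, scores


# Two synthetic 4x6 heatmaps with a single peak each
heatmaps = np.zeros((2, 4, 6), dtype=np.float32)
heatmaps[0, 1, 4] = 0.9  # keypoint 0 peaks at x=4, y=1
heatmaps[1, 3, 2] = 0.6  # keypoint 1 peaks at x=2, y=3

coords, scores = decode_heatmaps(heatmaps)
print(coords.tolist())  # [[4, 1], [2, 3]]
```

The peak value doubles as a confidence score for the keypoint, which is why pose models usually report a score next to each coordinate.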

## ⚡ Before you start

Let's make sure that we have access to a GPU. We can use the `nvidia-smi` command to check. If there are any problems, navigate to `Edit` -> `Notebook settings` -> `Hardware accelerator`, set it to `GPU`, and then click `Save`.

In [None]:
!nvidia-smi
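
If we would rather run the same check from Python, a rough sketch using only the standard library could look like the following. The helper name `gpu_visible` is made up for this example, and the check only tells us whether `nvidia-smi` can list a GPU, not whether a particular framework can use it.

```python
import shutil
import subprocess


def gpu_visible():
    """Return True if `nvidia-smi` is installed and lists at least one GPU."""
    if shutil.which("nvidia-smi") is None:
        return False
    # `nvidia-smi -L` prints one line per GPU it can see
    result = subprocess.run(
        ["nvidia-smi", "-L"], capture_output=True, text=True
    )
    return result.returncode == 0 and "GPU" in result.stdout


print("GPU visible:", gpu_visible())
```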

## 🧪 Install Supervision and Inference

In [None]:
!pip install -qq -U inference-gpu
!pip install -qq -U supervision

We will also download an image and a video to demonstrate the capabilities of ViTPose.

In [None]:
!wget https://media.roboflow.com/notebooks/examples/yoga-pose.jpg
!wget https://media.roboflow.com/notebooks/examples/boxing.mp4

## 🤖 Defining the Models

We will use the `yolov11m` model, loaded with [inference](https://inference.roboflow.com), to detect people, and the [ViTPose++](https://huggingface.co/usyd-community/vitpose-base-simple) model, loaded with [transformers](https://huggingface.co/docs/transformers/index), for pose estimation.

In [None]:
from inference import get_model
from transformers import AutoProcessor, VitPoseForPoseEstimation

# Initialize the person detection model
person_detection_model = get_model("yolov11m-1280")

# Initialize ViTPose++ model for pose estimation
pose_model_name = "usyd-community/vitpose-base-simple"
pose_image_processor = AutoProcessor.from_pretrained(pose_model_name)
pose_model = VitPoseForPoseEstimation.from_pretrained(pose_model_name, device_map="cuda")

## 🔧 Utility functions for Pose Estimation

We will create a function called `get_person_detections` that gets the bounding boxes of people in a given image using the object detection model. The function converts the detection model's inference results into a [`sv.Detections`](https://supervision.roboflow.com/latest/detection/core/) object, which lets us easily pass the data on to pose estimation.

In [None]:
import supervision as sv


def get_person_detections(frame, confidence_threshold=0.7):
    # Get the inference results for the frame
    result = person_detection_model.infer(
        frame, confidence=confidence_threshold
    )[0]

    # Convert the inference results to a Detections object
    detections = sv.Detections.from_inference(result)

    # Filter out detections that are not persons
    detections = detections[detections.class_id == 0]
    return detections
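
The indexing `detections[detections.class_id == 0]` is boolean-mask filtering: the comparison produces one `True`/`False` value per detection, and only the `True` rows are kept. As a rough illustration, the sketch below applies the same idea to plain NumPy arrays standing in for the fields of `sv.Detections`. All values are made up, and the class ids assume the 80-class COCO ordering, where `0` is "person".

```python
import numpy as np

# Made-up detections: boxes in (x_min, y_min, x_max, y_max), class ids, scores
xyxy = np.array([
    [10, 20, 110, 220],
    [50, 60, 90, 100],
    [200, 40, 300, 260],
], dtype=np.float32)
class_id = np.array([0, 16, 0])  # assumed COCO ids: 0 is "person", 16 is "dog"
confidence = np.array([0.91, 0.88, 0.75])

# Boolean mask that is True only for person rows
is_person = class_id == 0

# The same mask selects matching rows from every per-detection array
person_xyxy = xyxy[is_person]
person_confidence = confidence[is_person]

print(person_xyxy.shape)  # (2, 4)
print(person_confidence.tolist())  # [0.91, 0.75]
```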

We also need a function called `get_person_key_points` that gets the keypoints and the pose labels using ViTPose++. The function takes a frame and the person detections, runs the pose model on each detected person, and returns the pose labels together with a [`sv.KeyPoints`](https://supervision.roboflow.com/latest/keypoint/core/) object, which lets us easily visualize the results using [Supervision's annotators](https://supervision.roboflow.com/latest/keypoint/annotators/).

In [None]:
import torch


def get_person_key_points(frame, detections):
    with torch.inference_mode():
        # Convert the detection data to xywh format
        person_detections_xywh = sv.xyxy_to_xywh(detections.xyxy)
        
        # Prepare the inputs for the ViTPose++ model and move them to the
        # same device as the model
        inputs = pose_image_processor(
            frame, boxes=[person_detections_xywh], return_tensors="pt"
        ).to(pose_model.device)

        # For ViTPose++ checkpoints (e.g. vitpose-plus-base) we additionally
        # provide dataset_index to specify which MoE experts to use for inference
        if pose_model.config.backbone_config.num_experts > 1:
            dataset_index = torch.tensor([0] * len(inputs["pixel_values"]))
            dataset_index = dataset_index.to(inputs["pixel_values"].device)
            inputs["dataset_index"] = dataset_index
        
        # run the pose estimation model
        outputs = pose_model(**inputs)

        # post-process the results
        pose_results = pose_image_processor.post_process_pose_estimation(
            outputs, boxes=[person_detections_xywh]
        )[0]
    
    # Convert the results to a KeyPoints object
    key_point = sv.KeyPoints.from_transformers(pose_results)

    # Convert the results to a list of pose labels (one list per person)
    pose_labels = [
        [
            pose_model.config.id2label[label.item()]
            for label in pose_result["labels"]
        ]
        for pose_result in pose_results
    ]

    return key_point, pose_labels
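
The box conversion at the start of `get_person_key_points` is needed because the detector returns corner coordinates `(x_min, y_min, x_max, y_max)`, while the pose processor takes boxes as `(x_min, y_min, width, height)`. The notebook relies on `sv.xyxy_to_xywh` for this; the hand-written version below is only a sketch to show the arithmetic, with made-up box values.

```python
import numpy as np


def xyxy_to_xywh(xyxy):
    """Convert (x_min, y_min, x_max, y_max) boxes to (x_min, y_min, width, height)."""
    xyxy = np.asarray(xyxy, dtype=np.float32)
    xywh = xyxy.copy()
    xywh[:, 2] = xyxy[:, 2] - xyxy[:, 0]  # width = x_max - x_min
    xywh[:, 3] = xyxy[:, 3] - xyxy[:, 1]  # height = y_max - y_min
    return xywh


boxes_xyxy = [[10, 20, 110, 220], [200, 40, 300, 260]]
print(xyxy_to_xywh(boxes_xyxy).tolist())
# [[10.0, 20.0, 100.0, 200.0], [200.0, 40.0, 100.0, 220.0]]
```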

## 📸 Performing Pose Estimation on an Image

### Reading the Image

In [None]:
from PIL import Image

# The image was downloaded to the current working directory earlier
image = Image.open("yoga-pose.jpg").convert("RGB")

### Using Supervision Annotators

We will define the annotators that let us easily plot annotations for a particular computer vision task.

- [`BoxAnnotator`](https://supervision.roboflow.com/latest/detection/annotators/) for drawing bounding boxes on an image using provided detections
- [`EdgeAnnotator`](https://supervision.roboflow.com/latest/keypoint/annotators/#supervision.keypoint.annotators.EdgeAnnotator) for drawing the skeleton edges of a pose on an image. It uses the specified key points to determine where the edges should be drawn.
- [`VertexLabelAnnotator`](https://supervision.roboflow.com/latest/keypoint/annotators/#supervision.keypoint.annotators.VertexLabelAnnotator) for drawing labels of skeleton vertices for poses on images. It uses the specified key points to determine where the labels should be drawn.

In [None]:
COLORS = [
    "#FF6347", "#FF6347", "#FF6347", "#FF6347",
    "#FF6347", "#FF1493", "#00FF00", "#FF1493",
    "#00FF00", "#FF1493", "#00FF00", "#FFD700",
    "#00BFFF", "#FFD700", "#00BFFF", "#FFD700",
    "#00BFFF"
]
COLORS = [sv.Color.from_hex(color_hex=c) for c in COLORS]

box_annotator = sv.BoxAnnotator(color=sv.Color.WHITE)

line_thickness = sv.calculate_optimal_line_thickness(resolution_wh=(image.width, image.height))
edge_annotator = sv.EdgeAnnotator(color=sv.Color.WHITE, thickness=line_thickness)

text_scale = sv.calculate_optimal_text_scale(resolution_wh=(image.width, image.height))
vertex_label_annotator = sv.VertexLabelAnnotator(
    color=COLORS,
    smart_position=True,
    border_radius=20,
    text_thickness=5,
    text_scale=text_scale
)
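
The 17 entries in `COLORS` line up with the 17 keypoints of a COCO-style skeleton, so each body part gets its own color. The sketch below rebuilds the same hex list from keypoint names to make the grouping visible. It assumes the model's keypoints follow the standard COCO order, and the names `COCO_KEYPOINTS` and `color_for` are made up for this illustration.

```python
# Standard COCO keypoint order (assumed to match the pose model's labels)
COCO_KEYPOINTS = [
    "nose", "left_eye", "right_eye", "left_ear", "right_ear",
    "left_shoulder", "right_shoulder", "left_elbow", "right_elbow",
    "left_wrist", "right_wrist", "left_hip", "right_hip",
    "left_knee", "right_knee", "left_ankle", "right_ankle",
]

ARM_JOINTS = ("shoulder", "elbow", "wrist")
LEG_JOINTS = ("hip", "knee", "ankle")


def color_for(keypoint):
    """Pick a hex color from the side and limb encoded in the keypoint name."""
    side, _, joint = keypoint.partition("_")
    if joint in ARM_JOINTS:
        return "#FF1493" if side == "left" else "#00FF00"
    if joint in LEG_JOINTS:
        return "#FFD700" if side == "left" else "#00BFFF"
    return "#FF6347"  # nose, eyes, and ears


hex_colors = [color_for(name) for name in COCO_KEYPOINTS]
print(len(hex_colors))  # 17
```

Read this way, the face shares one color, and each arm and each leg gets its own, which makes left and right limbs easy to tell apart in the annotated output.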

We will write a simple function called `get_annotated_frame` for visualizing the keypoints on a single image.

In [None]:
import numpy as np


def get_annotated_frame(
    frame, box_annotator, edge_annotator, vertex_label_annotator, idx=None
):
    # Convert the frame to a numpy array if it is a PIL Image
    # This is necessary because the `VertexLabelAnnotator` expects a numpy array
    frame = (
        np.array(frame)
        if isinstance(frame, Image.Image)
        else frame
    )

    # Get the person detections and key points
    person_detections = get_person_detections(frame)
    person_key_points, pose_labels = get_person_key_points(frame, person_detections)

    # Annotate the frame with the person detections
    annotated_frame = box_annotator.annotate(
        scene=frame.copy(),
        detections=person_detections
    )

    # Annotate the frame with the person key points
    annotated_frame = edge_annotator.annotate(
        scene=annotated_frame,
        key_points=person_key_points
    )

    # Annotate the frame with the key point labels. The annotator applies
    # one list of per-vertex labels to every skeleton, so one call is enough
    if pose_labels:
        annotated_frame = vertex_label_annotator.annotate(
            scene=annotated_frame,
            key_points=person_key_points,
            labels=pose_labels[0]
        )

    return annotated_frame
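
The conversion at the top of `get_annotated_frame` turns a PIL image into an array of shape `(height, width, 3)`. One detail worth keeping in mind is channel order: a PIL image converted with `.convert("RGB")` is in RGB order, while OpenCV-based readers typically yield BGR arrays. The sketch below, which uses a tiny made-up image, shows that reversing the last axis switches between the two.

```python
import numpy as np

# A 1x2 RGB image: one pure-red pixel and one pure-blue pixel
rgb = np.array([[[255, 0, 0], [0, 0, 255]]], dtype=np.uint8)

# Reversing the last axis swaps the R and B channels
bgr = rgb[:, :, ::-1]

print(rgb.shape)  # (1, 2, 3)
print(bgr[0, 0].tolist())  # [0, 0, 255]
```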

Finally, we call `get_annotated_frame` with the image and the annotators to get an annotated frame that visualizes the keypoints.

In [None]:
annotated_image = get_annotated_frame(
    image, box_annotator, edge_annotator, vertex_label_annotator
)
Image.fromarray(annotated_image)

## 🎥 Performing Pose Estimation on a Video

We re-define the annotators to work with our video.

In [None]:
# Get the video information: the resolution and the frame rate
video_info = sv.VideoInfo.from_video_path("boxing.mp4")

# Create the annotators for visualizing keypoints on the video
box_annotator = sv.BoxAnnotator(color=sv.Color.WHITE)
line_thickness = sv.calculate_optimal_line_thickness(resolution_wh=(video_info.width, video_info.height))
edge_annotator = sv.EdgeAnnotator(color=sv.Color.WHITE, thickness=line_thickness)
vertex_label_annotator = sv.VertexLabelAnnotator(
    color=COLORS,
    smart_position=True,
    border_radius=20,
    text_thickness=5,
    text_scale=1,
)

Finally, we use the `sv.process_video` function from Supervision to perform pose estimation on the entire video by using the `get_annotated_frame` function as a callback.

In [None]:
# Define the callback function for processing the video
# It wraps the same function we defined earlier to annotate the image
callback = lambda frame, _: get_annotated_frame(
    frame, box_annotator, edge_annotator, vertex_label_annotator
)

# Process the video and save the annotated video
sv.process_video(
    source_path="boxing.mp4",
    target_path="annotated_video.mp4",
    show_progress=True,
    callback=callback,
    max_frames=50,  # Set this to None to process the entire video
)
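
To clarify the callback contract used above, here is a toy sketch of a frame-processing loop: the callback receives each frame together with its index and returns the processed frame. The function `process_frames` is hypothetical and written only for this illustration; it is not how Supervision implements `sv.process_video`, which also handles reading and writing the video files.

```python
def process_frames(frames, callback, max_frames=None):
    """Apply callback(frame, index) to each frame and collect the results."""
    processed = []
    for index, frame in enumerate(frames):
        # Stop early once the optional frame limit is reached
        if max_frames is not None and index >= max_frames:
            break
        processed.append(callback(frame, index))
    return processed


# Toy "frames" are strings here; a real callback would receive image arrays
frames = ["frame-a", "frame-b", "frame-c"]
result = process_frames(
    frames, lambda frame, index: f"{index}:{frame}", max_frames=2
)
print(result)  # ['0:frame-a', '1:frame-b']
```

This is why the notebook's callback takes two arguments and ignores the second: the index is supplied on every call even though `get_annotated_frame` does not need it.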

<div align="center">
  <p>
    Looking for more tutorials or have questions?
    Check out our <a href="https://github.com/roboflow/notebooks">GitHub repo</a> for more notebooks,
    or visit our <a href="https://discord.gg/GbfgXGJ8Bk">Discord</a>.
  </p>
  
  <p>
    <strong>If you found this helpful, please consider giving us a ⭐
    <a href="https://github.com/roboflow/notebooks">on GitHub</a>!</strong>
  </p>

</div>