# Hand gesture recognition model customization guide

<table align="left" class="buttons">
  <td>
    <a href="https://colab.research.google.com/github/googlesamples/mediapipe/blob/main/examples/customization/gesture_recognizer.ipynb" target="_blank">
      <img src="https://developers.google.com/static/mediapipe/solutions/customization/colab-logo-32px_1920.png" alt="Colab logo"> Run in Colab
    </a>
  </td>

  <td>
    <a href="https://github.com/googlesamples/mediapipe/blob/main/examples/customization/gesture_recognizer.ipynb" target="_blank">
      <img src="https://developers.google.com/static/mediapipe/solutions/customization/github-logo-32px_1920.png" alt="GitHub logo">
      View on GitHub
    </a>
  </td>
</table>

In [None]:
#@title License information
# Copyright 2023 The MediaPipe Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

The MediaPipe Model Maker package is a low-code solution for customizing on-device machine learning (ML) models.

This notebook shows the end-to-end process of customizing a gesture recognizer model for recognizing some common hand gestures in the [HaGRID](https://www.kaggle.com/datasets/innominate817/hagrid-sample-30k-384p) dataset.

## Prerequisites

Install the MediaPipe Model Maker package.

In [None]:
!pip install --upgrade pip
!pip install mediapipe-model-maker

Import the required libraries.

In [None]:
from google.colab import files
import os
import tensorflow as tf
assert tf.__version__.startswith('2')

from mediapipe_model_maker import gesture_recognizer

import matplotlib.pyplot as plt

In [None]:
from google.colab import drive
drive.mount('/content/drive')

## Simple End-to-End Example

This end-to-end example uses Model Maker to customize a model for on-device gesture recognition.

### Get the dataset

A gesture recognition dataset for Model Maker must use the following layout: `<dataset_path>/<label_name>/<img_name>.*`. In addition, one of the label names must be `none` (case-insensitive). The `none` label represents any hand gesture that isn't classified as one of the other gestures.
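The layout rule can be checked before any landmarks are computed. The following is a minimal sketch, not part of Model Maker: `list_gesture_labels` is a hypothetical helper that mirrors the label handling in the dataset code further down (one label per subdirectory, a case-insensitive `none` label, and `none` moved to the front of the label list).

```python
import os


def list_gesture_labels(dataset_path: str) -> list[str]:
  """Returns the label names under dataset_path, with the `none` label first.

  Hypothetical helper: every subdirectory of dataset_path is one label, and
  loose files at the top level are ignored.
  """
  labels = sorted(
      name for name in os.listdir(dataset_path)
      if os.path.isdir(os.path.join(dataset_path, name)))
  lowered = [name.lower() for name in labels]
  if 'none' not in lowered:
    raise ValueError('Dataset must contain a "none" label folder.')
  # Move the `none` label to the front, as the dataset loader does.
  labels.insert(0, labels.pop(lowered.index('none')))
  return labels
```

For a folder containing `rock/`, `paper/`, `scissors/`, and `none/`, this returns `['none', 'paper', 'rock', 'scissors']`.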

This example reads a rock paper scissors dataset from the Google Drive folder mounted above. Set `dataset_path` to the location of your own dataset.

In [None]:
# Dataset folder inside the mounted Google Drive; change this to your own path.
dataset_path = "drive/MyDrive/ColabWS/GestureCustomize/gesture2"

Verify the rock paper scissors dataset by printing the labels. There should be 4 gesture labels, with one of them being the `none` gesture.

In [None]:
print(dataset_path)
labels = []
for i in os.listdir(dataset_path):
  if os.path.isdir(os.path.join(dataset_path, i)):
    labels.append(i)
print(labels)

To better understand the dataset, display example images for each gesture (`NUM_EXAMPLES` sets how many per gesture).

In [None]:
import cv2
from google.colab.patches import cv2_imshow

NUM_EXAMPLES = 1

for label in labels:
  label_dir = os.path.join(dataset_path, label)
  example_filenames = os.listdir(label_dir)[:NUM_EXAMPLES]
  for i in range(NUM_EXAMPLES):
    img=cv2.imread(os.path.join(label_dir, example_filenames[i]))
    cv2_imshow(img)


### Run the example
The workflow consists of 4 steps which have been separated into their own code blocks.

**Load the dataset**

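Load the dataset located at `dataset_path`. The stock `gesture_recognizer.Dataset.from_folder` method does this in one call: it runs the pre-packaged hand detection model from MediaPipe Hands to detect the hand landmarks in each image, omits any images without detected hands, and returns a dataset that contains the extracted hand landmark positions from each image rather than the images themselves. This notebook instead uses a modified copy of that dataset library, `DatasetII`, which splits loading into two steps: `from_folder_base` runs landmark detection and displays every image with its detected landmarks, and `from_folder` builds the dataset from the results. In between, you can discard badly detected samples by setting their entries to `None`.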
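The filtering step pairs each detection result with the label of its image folder by position. Below is a minimal pure-Python sketch, assuming one detection result per image path stored in the same order as the paths, with landmark vectors shortened to placeholder lists of floats. `pair_hand_data_with_labels` is a hypothetical helper, not a Model Maker API.

```python
import os
from typing import Optional


def pair_hand_data_with_labels(
    image_paths: list[str],
    hand_data: list[Optional[list[float]]],
    label_names: list[str],
) -> list[tuple[list[float], int]]:
  """Pairs each detection result with the label index of its image folder.

  hand_data[i] must be the result for image_paths[i], or None if no hand was
  detected (or the sample was discarded by hand).
  """
  if len(image_paths) != len(hand_data):
    raise ValueError('Expected one hand_data entry per image path.')
  index_by_label = {name: i for i, name in enumerate(label_names)}
  pairs = []
  for path, sample in zip(image_paths, hand_data):
    if sample is None:
      continue  # No hand detected: the image is omitted from the dataset.
    label = os.path.basename(os.path.dirname(path))
    pairs.append((sample, index_by_label[label]))
  return pairs
```

Because the pairing is positional, the image paths must be in the same order when the landmarks are computed and when the labels are assigned. Shuffling only one of the two lists attaches samples to the wrong gesture.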

The `HandDataPreprocessingParams` class contains two configurable options for the data loading process:
* `shuffle`: A boolean controlling whether to shuffle the dataset. Defaults to `True`.
* `min_detection_confidence`: A float between 0 and 1 controlling the confidence threshold for hand detection. Defaults to 0.7.

Split the dataset: 80% for training, 10% for validation, and 10% for testing.
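The 80/10/10 split happens in two stages: `split(0.8)` keeps 80% of the samples for training, then `split(0.5)` halves the remaining 20% into validation and test sets. The sketch below shows the arithmetic, assuming each split truncates its first part to an integer count (the exact rounding inside Model Maker is an assumption here). `split_sizes` is a hypothetical helper.

```python
def split_sizes(num_samples: int, train_fraction: float = 0.8,
                validation_share_of_rest: float = 0.5) -> tuple[int, int, int]:
  """Returns (train, validation, test) sizes for a two-stage split.

  Assumes each split keeps int(size * fraction) samples for its first part
  and puts everything else in the second part.
  """
  train = int(num_samples * train_fraction)
  rest = num_samples - train
  validation = int(rest * validation_share_of_rest)
  test = rest - validation
  return train, validation, test
```

With 100 valid samples this gives 80 training, 10 validation, and 10 test samples. With 57 samples it gives 45, 6, and 6.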

In [None]:
# Copyright 2022 The MediaPipe Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Gesture recognition dataset library."""

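import dataclasses
import os
import random
from typing import List, Optional

# OpenCV and the Colab display helper are used to visualize the detected
# landmarks in _get_hand_data().
import cv2
from google.colab.patches import cv2_imshow
import tensorflow as tf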

from mediapipe_model_maker.python.core.data import classification_dataset
from mediapipe_model_maker.python.core.utils import model_util
from mediapipe_model_maker.python.vision.gesture_recognizer import constants
from mediapipe_model_maker.python.vision.gesture_recognizer import metadata_writer
from mediapipe.python._framework_bindings import image as image_module
from mediapipe.tasks.python.core import base_options as base_options_module
from mediapipe.tasks.python.vision import hand_landmarker as hand_landmarker_module

_Image = image_module.Image
_HandLandmarker = hand_landmarker_module.HandLandmarker
_HandLandmarkerOptions = hand_landmarker_module.HandLandmarkerOptions
_HandLandmarkerResult = hand_landmarker_module.HandLandmarkerResult


@dataclasses.dataclass
class HandDataPreprocessingParams:
  """A dataclass that wraps the hand data preprocessing hyperparameters.

  Attributes:
    shuffle: A boolean controlling whether to shuffle the dataset. Defaults
      to True.
    min_detection_confidence: confidence threshold for hand detection.
  """
  shuffle: bool = True
  min_detection_confidence: float = 0.7


@dataclasses.dataclass
class HandData:
  """A dataclass that represents hand data for training the gesture model.

  See https://google.github.io/mediapipe/solutions/hands#mediapipe-hands for
  more details of the hand gesture data API.

  Attributes:
    hand: normalized hand landmarks of shape 21x3 from the screen based
      hand-landmark model.
    world_hand: hand landmarks of shape 21x3 in world coordinates.
    handedness: Collection of handedness confidence of the detected hands (i.e.
      is it a left or right hand).
  """
  hand: List[List[float]]
  world_hand: List[List[float]]
  handedness: List[float]


def _validate_data_sample(data: _HandLandmarkerResult) -> bool:
  """Validates the input hand data sample.

  Args:
    data: input hand data sample.

  Returns:
    False if any of `hand_landmarks`, `hand_world_landmarks`, or `handedness`
    is None or empty. Otherwise, True.
  """
  if data.hand_landmarks is None or not data.hand_landmarks:
    return False
  if data.hand_world_landmarks is None or not data.hand_world_landmarks:
    return False
  if data.handedness is None or not data.handedness:
    return False
  return True


def _get_hand_data(all_image_paths: List[str],
                   min_detection_confidence: float) -> List[Optional[HandData]]:
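  """Computes hand data (landmarks and handedness) for the input images.

  Each image is also displayed: with its detected landmarks drawn on it, or
  unchanged if no hand is detected.

  Args:
    all_image_paths: all input image paths.
    min_detection_confidence: hand detection confidence threshold.

  Returns:
    A list with one entry per input path: a HandData object, or None if no
    hand is detected.
  """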
  hand_data_result = []
  hand_detector_model_buffer = model_util.load_tflite_model_buffer(
      constants.HAND_DETECTOR_TFLITE_MODEL_FILE.get_path()
  )
  hand_landmarks_detector_model_buffer = model_util.load_tflite_model_buffer(
      constants.HAND_LANDMARKS_DETECTOR_TFLITE_MODEL_FILE.get_path()
  )
  hand_landmarker_writer = metadata_writer.HandLandmarkerMetadataWriter(
      hand_detector_model_buffer, hand_landmarks_detector_model_buffer)
  hand_landmarker_options = _HandLandmarkerOptions(
      base_options=base_options_module.BaseOptions(
          model_asset_buffer=hand_landmarker_writer.populate()),
      num_hands=1,
      min_hand_detection_confidence=min_detection_confidence,
      min_hand_presence_confidence=0.5,
      min_tracking_confidence=1,
  )
  with _HandLandmarker.create_from_options(
      hand_landmarker_options) as hand_landmarker:
    for index, path in enumerate(all_image_paths):
      tf.compat.v1.logging.info('Loading image %s', path)
      image = _Image.create_from_file(path)
      data = hand_landmarker.detect(image)

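      # Load the image with OpenCV for display, and print its index so that
      # badly detected samples can be discarded by index later.
      img = cv2.imread(path)
      print(path, index)
      if not _validate_data_sample(data):
        # No hand detected: show the raw image so the failure can be checked.
        cv2_imshow(img)
        hand_data_result.append(None)
        continue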
      hand_landmarks = [[hand_landmark.x, hand_landmark.y, hand_landmark.z]
                        for hand_landmark in data.hand_landmarks[0]]
      hand_world_landmarks = [[
          hand_landmark.x, hand_landmark.y, hand_landmark.z
      ] for hand_landmark in data.hand_world_landmarks[0]]
      handedness_scores = [
          handedness.score for handedness in data.handedness[0]
      ]
      hand_data_result.append(
          HandData(
              hand=hand_landmarks,
              world_hand=hand_world_landmarks,
              handedness=handedness_scores))

      height, width, _ = img.shape
      # Mark each landmark: black at even indices, green at odd indices
      # (OpenCV colors are in BGR order).
      for i, landmark in enumerate(hand_landmarks):
        x = int(landmark[0] * width)
        y = int(landmark[1] * height)
        if i % 2 == 0:  # Even-index landmark
          color = (0, 0, 0)  # Black
        else:
          color = (0, 255, 0)  # Green
        cv2.circle(img, (x, y), 5, color, -1)

      # Connect the landmarks finger by finger, from the wrist (landmark 0)
      # out to each fingertip, with one BGR color per finger.
      finger_chains = [
          ([0, 1, 2, 3, 4], (0, 0, 255)),          # Thumb: red
          ([0, 5, 6, 7, 8], (0, 255, 255)),        # Index: yellow
          ([0, 9, 10, 11, 12], (255, 0, 255)),     # Middle: magenta
          ([0, 13, 14, 15, 16], (255, 0, 0)),      # Ring: blue
          ([0, 17, 18, 19, 20], (255, 255, 255)),  # Pinky: white
      ]
      points = [(int(x * width), int(y * height))
                for x, y, _ in hand_landmarks]
      for chain, color in finger_chains:
        for start, end in zip(chain, chain[1:]):
          cv2.line(img, points[start], points[end], color, 2)
      cv2_imshow(img)
  return hand_data_result

class DatasetII(classification_dataset.ClassificationDataset):
  """Dataset library for hand gesture recognizer."""

  @classmethod
  def from_folder_base(
      cls,
      dirname: str,
      hparams: Optional[HandDataPreprocessingParams] = None
  ) -> List[Optional[HandData]]:
    """Runs hand landmark detection on every image under `dirname`.

    Images are processed in sorted path order, and the returned list is not
    shuffled, so that `from_folder` can pair each entry with the label of the
    same image. Shuffling happens in `from_folder` instead; `hparams.shuffle`
    is ignored here.

    Args:
      dirname: Name of the directory containing the data files.
      hparams: Optional hyperparameters for processing input hand gesture
        images. Only `min_detection_confidence` is used.

    Returns:
      One entry per image: a HandData object, or None if no hand is detected.

    Raises:
      ValueError: if the input data directory is empty.
    """
    data_root = os.path.abspath(dirname)

    # Assumes the image data of the same label are in the same subdirectory.
    # Sorting keeps the order identical to the one used in from_folder().
    all_image_paths = sorted(tf.io.gfile.glob(data_root + r'/*/*'))
    if not all_image_paths:
      raise ValueError('Image dataset directory is empty.')

    if not hparams:
      hparams = HandDataPreprocessingParams()

    return _get_hand_data(
        all_image_paths=all_image_paths,
        min_detection_confidence=hparams.min_detection_confidence)

  @classmethod
  def from_folder(
      cls,
      dirname: str,
      hand_data: List[Optional[HandData]],
      shuffle: bool = True,
  ) -> classification_dataset.ClassificationDataset:
    """Builds a dataset from precomputed hand data and the folder labels.

    Directory contents are expected to be in the format:
    <root_dir>/<gesture_name>/*.jpg. One of the `gesture_name` must be `none`
    (case insensitive). The `none` sub-directory is expected to contain images
    of hands that don't belong to other gesture classes in <root_dir>. Assumes
    the image data of the same label are in the same subdirectory.

    Args:
      dirname: Name of the directory containing the data files. Must be the
        directory that was passed to `from_folder_base`.
      hand_data: The list returned by `from_folder_base`. Entries set to None
        are excluded from the dataset.
      shuffle: Whether to shuffle the samples. Shuffling is applied after each
        sample has been paired with its label.

    Returns:
      Dataset containing landmarks, labels, and other related info.

    Raises:
      ValueError: if the input data directory is empty, the label set does not
        contain label 'none' (case insensitive), `hand_data` does not have one
        entry per image, or no valid hand remains.
    """
    data_root = os.path.abspath(dirname)

    # Same sorted order as in from_folder_base(), so hand_data[i] is the
    # detection result for all_image_paths[i].
    all_image_paths = sorted(tf.io.gfile.glob(data_root + r'/*/*'))
    if not all_image_paths:
      raise ValueError('Image dataset directory is empty.')
    if len(hand_data) != len(all_image_paths):
      raise ValueError('hand_data must have one entry per image in dirname.')

    label_names = sorted(
        name for name in os.listdir(data_root)
        if os.path.isdir(os.path.join(data_root, name)))
    if 'none' not in [v.lower() for v in label_names]:
      raise ValueError('Label set does not contain label "None".')
    # Move label 'none' to the front of label list.
    none_idx = [v.lower() for v in label_names].index('none')
    none_value = label_names.pop(none_idx)
    label_names.insert(0, none_value)

    index_by_label = dict(
        (name, index) for index, name in enumerate(label_names))
    all_gesture_indices = [
        index_by_label[os.path.basename(os.path.dirname(path))]
        for path in all_image_paths
    ]

    # Get a list of the valid hand landmark sample in the hand data list.
    valid_indices = [
        i for i in range(len(hand_data)) if hand_data[i] is not None
    ]
    if shuffle:
      # Shuffle indices, not paths, so every sample keeps its own label.
      random.shuffle(valid_indices)
    # Remove 'None' element from the hand data and label list.
    valid_hand_data = [dataclasses.asdict(hand_data[i]) for i in valid_indices]
    if not valid_hand_data:
      raise ValueError('No valid hand is detected.')

    valid_label = [all_gesture_indices[i] for i in valid_indices]

    # Convert list of dictionaries to dictionary of lists.
    hand_data_dict = {
        k: [lm[k] for lm in valid_hand_data] for k in valid_hand_data[0]
    }
    hand_ds = tf.data.Dataset.from_tensor_slices(hand_data_dict)

    embedder_model = model_util.load_keras_model(
        constants.GESTURE_EMBEDDER_KERAS_MODEL_FILES.get_path()
    )

    hand_ds = hand_ds.batch(batch_size=1)
    hand_embedding_ds = hand_ds.map(
        map_func=lambda feature: embedder_model(dict(feature)),
        num_parallel_calls=tf.data.experimental.AUTOTUNE)
    hand_embedding_ds = hand_embedding_ds.unbatch()

    # Create label dataset
    label_ds = tf.data.Dataset.from_tensor_slices(
        tf.cast(valid_label, tf.int64))

    label_one_hot_ds = label_ds.map(
        map_func=lambda index: tf.one_hot(index, len(label_names)),
        num_parallel_calls=tf.data.experimental.AUTOTUNE)

    # Create a dataset with (hand_embedding, one_hot_label) pairs
    hand_embedding_label_ds = tf.data.Dataset.zip(
        (hand_embedding_ds, label_one_hot_ds))

    tf.compat.v1.logging.info(
        'Load valid hands with size: {}, num_label: {}, labels: {}.'.format(
            len(valid_hand_data), len(label_names), ','.join(label_names)))
    return DatasetII(
        dataset=hand_embedding_label_ds,
        label_names=label_names,
        size=len(valid_hand_data),
    )
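
The landmark-drawing code in `_get_hand_data` connects the 21 hand landmarks finger by finger, starting at the wrist (landmark 0). The same topology can be written as data. The sketch below is illustrative only: `FINGER_CHAINS` and `skeleton_segments` are hypothetical names, and the function returns pixel line segments instead of drawing them, so it runs without OpenCV.

```python
# Landmark 0 is the wrist; each finger is a chain of four more landmarks
# ending at the fingertip (indices as used in the drawing code above).
FINGER_CHAINS = {
    'thumb': [0, 1, 2, 3, 4],
    'index': [0, 5, 6, 7, 8],
    'middle': [0, 9, 10, 11, 12],
    'ring': [0, 13, 14, 15, 16],
    'pinky': [0, 17, 18, 19, 20],
}


def skeleton_segments(landmarks, width, height):
  """Converts normalized landmarks into pixel line segments per finger.

  landmarks: 21 [x, y, z] triples with x and y normalized to [0, 1].
  Returns a dict mapping each finger to a list of ((x0, y0), (x1, y1)).
  """
  points = [(int(x * width), int(y * height)) for x, y, _ in landmarks]
  return {
      finger: [(points[a], points[b]) for a, b in zip(chain, chain[1:])]
      for finger, chain in FINGER_CHAINS.items()
  }
```

Five fingers with four segments each give the 20 lines drawn per image.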

In [None]:
hand_data_res = DatasetII.from_folder_base(dirname=dataset_path,
    hparams=gesture_recognizer.HandDataPreprocessingParams())

In [None]:
valid_indices = [i for i, d in enumerate(hand_data_res) if d is not None]

print(valid_indices)

In [None]:
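# Indices of samples to discard, e.g. images whose detected landmarks looked
# wrong in the output above. Set these yourself.
discard_indices = [1, 3, 5]
for index in discard_indices:
  if 0 <= index < len(hand_data_res):
    hand_data_res[index] = None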

In [None]:
valid_indices = [i for i, d in enumerate(hand_data_res) if d is not None]

print(valid_indices)

In [None]:
data = DatasetII.from_folder(dirname=dataset_path,
    hand_data=hand_data_res)

In [None]:
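# Built-in one-step loader, replaced here by the two-step DatasetII loader so
# that samples can be inspected and discarded before training:
# data = gesture_recognizer.Dataset.from_folder(
#     dirname=dataset_path,
#     hparams=gesture_recognizer.HandDataPreprocessingParams()
# )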
train_data, rest_data = data.split(0.8)
validation_data, test_data = rest_data.split(0.5)

**Train the model**

Train the custom gesture recognizer by calling the `create` method with the training data, the validation data, and a `GestureRecognizerOptions` object that holds the model options and hyperparameters. For more information on model options and hyperparameters, see the [Hyperparameters](#hyperparameters) section below.

In [None]:
hparams = gesture_recognizer.HParams(export_dir="exported_model")
options = gesture_recognizer.GestureRecognizerOptions(hparams=hparams)
model = gesture_recognizer.GestureRecognizer.create(
    train_data=train_data,
    validation_data=validation_data,
    options=options
)

**Evaluate the model performance**

After training the model, evaluate it on a test dataset and print the loss and accuracy metrics.
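The labels in this dataset are one-hot vectors (built with `tf.one_hot` in the dataset code), so the reported accuracy can be read as the fraction of test samples whose top-scoring gesture matches the label. Below is a plain-Python sketch of that metric, assuming the model reports standard categorical accuracy. `categorical_accuracy` here is a hypothetical helper, not the Keras implementation.

```python
def categorical_accuracy(probabilities, one_hot_labels):
  """Fraction of samples whose highest-scoring class is the labeled class."""
  if len(probabilities) != len(one_hot_labels):
    raise ValueError('Expected one label per prediction.')
  if not probabilities:
    raise ValueError('Expected at least one sample.')
  correct = 0
  for scores, label in zip(probabilities, one_hot_labels):
    # Index of the highest score (first one wins on ties).
    predicted = max(range(len(scores)), key=scores.__getitem__)
    correct += int(label[predicted] == 1)
  return correct / len(probabilities)
```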

In [None]:
loss, acc = model.evaluate(test_data, batch_size=1)
print(f"Test loss: {loss}, Test accuracy: {acc}")

**Export to TensorFlow Lite Model**

After creating the model, convert and export it to the TensorFlow Lite format for later use in an on-device application. The export also includes model metadata, which includes the label file.

In [None]:
model.export_model()
!ls exported_model

In [None]:
files.download('exported_model/gesture_recognizer.task')

## Run the model on-device

To use the TFLite model for on-device usage through MediaPipe Tasks, refer to the Gesture Recognizer [overview page](https://developers.google.com/mediapipe/solutions/vision/gesture_recognizer).

## Hyperparameters {:#hyperparameters}


You can further customize the model using the `GestureRecognizerOptions` class, which has two optional parameters for `ModelOptions` and `HParams`. Use the `ModelOptions` class to customize parameters related to the model itself, and the `HParams` class to customize other parameters related to training and saving the model.

`ModelOptions` has the following customizable parameters that affect accuracy:
* `dropout_rate`: The fraction of the input units to drop. Used in the dropout layer. Defaults to 0.05.
* `layer_widths`: A list of hidden layer widths for the gesture model. Each element in the list creates a new hidden layer with the specified width. The hidden layers are separated with BatchNorm, Dropout, and ReLU. Defaults to an empty list (no hidden layers).

`HParams` has the following list of customizable parameters which affect model accuracy:
* `learning_rate`: The learning rate to use for gradient descent training. Defaults to 0.001.
* `batch_size`: Batch size for training. Defaults to 2.
* `epochs`: Number of training iterations over the dataset. Defaults to 10.
* `steps_per_epoch`: An optional integer that indicates the number of training steps per epoch. If not set, the training pipeline calculates the default steps per epoch as the training dataset size divided by batch size.
* `shuffle`: True if the dataset is shuffled before training. Defaults to False.
* `lr_decay`: Learning rate decay to use for gradient descent training. Defaults to 0.99.
* `gamma`: Gamma parameter for focal loss. Defaults to 2.
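Three of these defaults can be made concrete with a little arithmetic. The sketch below assumes floor division for the default `steps_per_epoch`, one decay step per epoch for `lr_decay`, and the standard focal-loss form FL(p_t) = -(1 - p_t)^gamma * log(p_t) without a class-weighting term. Model Maker's exact implementation may differ, and the function names are hypothetical.

```python
import math


def default_steps_per_epoch(train_size: int, batch_size: int) -> int:
  """Training dataset size divided by batch size (floor division assumed)."""
  return train_size // batch_size


def decayed_learning_rate(epoch: int, learning_rate: float = 0.001,
                          lr_decay: float = 0.99) -> float:
  """Learning rate after `epoch` epochs, assuming one decay step per epoch."""
  return learning_rate * lr_decay ** epoch


def focal_loss(p_true: float, gamma: float = 2.0) -> float:
  """Focal loss for one sample, given the probability of its true class.

  With gamma = 0 this reduces to cross-entropy, -log(p_true). A larger gamma
  down-weights samples the model already classifies confidently.
  """
  return -((1.0 - p_true) ** gamma) * math.log(p_true)
```

For example, 100 training samples with the default batch size of 2 give 50 steps per epoch, and after 10 epochs the default learning rate has decayed to about 0.000904 (0.001 × 0.99^10).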

Additional `HParams` parameter that does not affect model accuracy:
* `export_dir`: The location of the model checkpoint files and exported model files.

For example, the following trains a new model with a dropout rate of 0.2 and a learning rate of 0.003.

In [None]:
hparams = gesture_recognizer.HParams(learning_rate=0.003, export_dir="exported_model_2")
model_options = gesture_recognizer.ModelOptions(dropout_rate=0.2)
options = gesture_recognizer.GestureRecognizerOptions(model_options=model_options, hparams=hparams)
model_2 = gesture_recognizer.GestureRecognizer.create(
    train_data=train_data,
    validation_data=validation_data,
    options=options
)

Evaluate the newly trained model.

In [None]:
loss, accuracy = model_2.evaluate(test_data)
print(f"Test loss: {loss}, Test accuracy: {accuracy}")