In [2]:
import cv2
import mediapipe as mp
import numpy as np
import threading
import tensorflow as tf




In [3]:
# Text overlaid on the video until the first prediction replaces it.
label = "Warmup...."
# Frames per prediction window; presumably must equal the sequence length the
# LSTM was trained on — TODO confirm against the training notebook.
n_time_steps = 10
# Per-frame landmark vectors collected for the current window.
lm_list = []

# Relative path: resolved against the notebook's current working directory.
model = tf.keras.models.load_model(
    "../../LSTM/models/nodwave.h5",
)




In [4]:
def make_landmark_timestep(results):
    """Flatten one frame's pose landmarks into a single feature vector.

    Each landmark contributes four values in the order ``x, y, z, visibility``,
    and landmarks are emitted in the order ``results.pose_landmarks.landmark``
    yields them. For MediaPipe Pose this is presumably 33 landmarks -> 132
    values per frame (TODO confirm against the model's input width).

    Parameters
    ----------
    results:
        A pose-detection result whose ``pose_landmarks`` is not ``None``;
        the caller is responsible for that check (an ``AttributeError`` is
        raised otherwise).

    Returns
    -------
    list
        The flattened ``[x0, y0, z0, v0, x1, y1, ...]`` values.
    """
    return [
        value
        for lm in results.pose_landmarks.landmark
        for value in (lm.x, lm.y, lm.z, lm.visibility)
    ]

In [5]:
def draw_class_on_image(label, img):
    """Overlay ``label`` as text near the top-left corner of ``img``.

    The text is drawn in place by ``cv2.putText`` and the same image object is
    returned, so callers can write ``img = draw_class_on_image(text, img)``.
    """
    # Pixel position (x, y) of the text's bottom-left corner.
    origin = (10, 30)
    # (0, 255, 0) is green whether the image is BGR or RGB.
    green = (0, 255, 0)
    # Positional args: font face, font scale 1, colour, thickness 2, lineType 2.
    # NOTE(review): OpenCV documents lineType as LINE_4/LINE_8/LINE_AA
    # (4/8/16); the value 2 is kept unchanged from the original.
    cv2.putText(img, label, origin, cv2.FONT_HERSHEY_SIMPLEX, 1, green, 2, 2)
    return img

In [16]:
# Define labels
# 0 - nodding
# 1 - handwave
# 2 - nothing
def detect(model, lm_list):
    """Classify one window of pose landmarks and publish the result.

    Parameters
    ----------
    model:
        Loaded Keras model. Its ``predict`` output is reduced with argmax over
        axis 1, so it is assumed to return (batch, n_classes) scores — TODO
        confirm.
    lm_list:
        Sequence of per-frame landmark vectors forming one window; a leading
        batch axis of size 1 is added before prediction.

    Returns
    -------
    str
        The predicted action name: "Nodding", "Waving" or "Nothing".

    Raises
    ------
    KeyError
        If the model predicts a class index outside 0-2.

    Side effects
    ------------
    Writes the prediction to the module-level ``label`` so the render loop can
    draw it. This function is started on a background thread by the webcam
    loop.
    """
    global label
    actions = {
        0: "Nodding",
        1: "Waving",
        2: "Nothing"
    }
    batch = np.expand_dims(np.array(lm_list), axis=0)
    # verbose=0: recent tf.keras versions default to verbose="auto", which
    # prints a progress bar on every call and floods the notebook output.
    scores = model.predict(batch, verbose=0)
    predicted = actions[int(np.argmax(scores, axis=1)[0])]
    label = predicted
    return predicted

In [None]:
mp_drawing = mp.solutions.drawing_utils  # Drawing helpers
mp_pose = mp.solutions.pose  # Mediapipe Solutions
cap = cv2.VideoCapture(0)
lm_list = []
# Window length comes from the config cell (n_time_steps) rather than a second
# hard-coded 10, so the buffer size and the model input length cannot drift.
time_steps = n_time_steps

try:
    with mp_pose.Pose(min_detection_confidence=0.5, min_tracking_confidence=0.5) as pose:
        while cap.isOpened():
            ret, frame = cap.read()
            # read() signals failure through ret (frame is then None, e.g. the
            # camera was unplugged or is busy); stop instead of crashing below.
            if not ret:
                break

            # Recolor Feed: the frame is BGR, MediaPipe is fed RGB.
            image = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
            image.flags.writeable = False

            # Make Detections
            results = pose.process(image)

            # Classify once every `time_steps` frames that contain a pose.
            if results.pose_landmarks:
                lm = make_landmark_timestep(results)
                lm_list.append(lm)
                if len(lm_list) == time_steps:
                    # Run the prediction off the render loop; detect() updates
                    # the global `label` when it finishes.
                    # NOTE(review): a new thread is started for every window even
                    # if the previous one is still running, so predictions may
                    # overlap and finish out of order.
                    t1 = threading.Thread(target=detect, args=(model, lm_list,))
                    t1.start()
                    # Rebind (not clear) so the running thread keeps its own list.
                    lm_list = []

            # Recolor image back to BGR for rendering
            image.flags.writeable = True
            image = cv2.cvtColor(image, cv2.COLOR_RGB2BGR)

            # Pose Detections
            mp_drawing.draw_landmarks(image, results.pose_landmarks, mp_pose.POSE_CONNECTIONS,
                                      mp_drawing.DrawingSpec(color=(245, 117, 66), thickness=2, circle_radius=4),
                                      mp_drawing.DrawingSpec(color=(245, 66, 230), thickness=2, circle_radius=2)
                                      )
            # Prediction
            image = draw_class_on_image(label, image)

            cv2.imshow('Webcam feed', image)

            # Quit on 'q'.
            if cv2.waitKey(10) & 0xFF == ord('q'):
                break
finally:
    # Release the camera and close the window even if the loop raised;
    # otherwise the device may stay held by the kernel until it restarts.
    cap.release()
    cv2.destroyAllWindows()