In [None]:
import torch
from torchvision import transforms
from utils.datasets import letterbox
from utils.general import non_max_suppression_kpt
from utils.plots import output_to_keypoint, plot_skeleton_kpts
import matplotlib.pyplot as plt
import matplotlib
matplotlib.use("TkAgg")
import cv2
import numpy as np
import socket
import json
import threading
from memory_profiler import profile

In [None]:

def load_model(weights_path='yolov7-w6-pose.pt'):
    """Load the YOLOv7 pose checkpoint and prepare it for inference.

    Args:
        weights_path: Path to the checkpoint file. The file must contain a
            pickled dict whose 'model' entry is the network. Defaults to
            'yolov7-w6-pose.pt' in the current working directory.

    Returns:
        The network in eval mode. When CUDA is available it is converted to
        float16 and moved to "cuda:0"; otherwise it stays float32 on the CPU.
    """
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
    print(device)

    # SECURITY: the checkpoint is a full pickle, so loading it can execute
    # arbitrary code. Only load weight files obtained from a trusted source.
    try:
        # PyTorch >= 2.6 defaults to weights_only=True, which refuses to
        # unpickle a whole model object, so opt out explicitly.
        checkpoint = torch.load(weights_path, map_location=device, weights_only=False)
    except TypeError:
        # Older PyTorch releases do not accept the weights_only argument.
        checkpoint = torch.load(weights_path, map_location=device)
    model = checkpoint['model']

    # Put in inference mode
    model.float().eval()

    if device.type == "cuda":
        # half() turns predictions into float16 tensors
        # which significantly lowers inference time
        model.half().to(device)
    return model

# Load the pose model once at module level; the functions defined in this
# notebook read this global `model` (for its `yaml` config and for inference).
model = load_model()

def saveData(output):
    """Convert raw model output into a dict of named keypoint coordinates.

    Args:
        output: Raw prediction tensor returned by the pose model.

    Returns:
        dict mapping each keypoint name to an ``(x, y)`` tuple. When several
        detections survive non-max suppression, each one overwrites the
        previous entries, so the dict describes the last detection. The dict
        is empty when nothing is detected.
    """
    keypoint_names = (
        "nose", "left_eye", "right_eye", "left_ear", "right_ear",
        "left_shoulder", "right_shoulder", "left_elbow", "right_elbow",
        "left_wrist", "right_wrist", "left_hip", "right_hip",
        "left_knee", "right_knee", "left_ankle", "right_ankle",
    )
    keypoints_dict = {}
    with torch.no_grad():
        output = non_max_suppression_kpt(output,
                                         0.25, # Confidence Threshold
                                         0.65, # IoU Threshold
                                         nc=model.yaml['nc'], # Number of Classes
                                         nkpt=model.yaml['nkpt'], # Number of Keypoints
                                         kpt_label=True)
        output = output_to_keypoint(output)
    for idx in range(output.shape[0]):
        kpts = output[idx, 7:].T
        # Each keypoint occupies three consecutive values (x, y, confidence),
        # matching the step of 3 passed to plot_skeleton_kpts and the parsing
        # in processKeyPointData, so keypoint i starts at index 3 * i.
        keypoints_dict.update({
            name: (kpts[3 * i], kpts[3 * i + 1])
            for i, name in enumerate(keypoint_names)
        })
    return keypoints_dict

In [None]:
def run_inference(url):
    """Run the pose model on a single image file.

    Args:
        url: Path of the image file, read with ``cv2.imread``.

    Returns:
        Tuple ``(output, image)``: the raw model predictions and the
        preprocessed input batch that was fed to the model.

    Raises:
        FileNotFoundError: If the image cannot be read from ``url``.
    """
    image = cv2.imread(url) # e.g. shape: (480, 640, 3)
    if image is None:
        # cv2.imread signals failure by returning None instead of raising.
        raise FileNotFoundError(f"Could not read image: {url!r}")
    # Resize and pad image
    image = letterbox(image, 960, stride=64, auto=True)[0] # e.g. shape: (768, 960, 3)
    # Apply transforms
    image = transforms.ToTensor()(image) # e.g. torch.Size([3, 768, 960])
    # Turn image into batch
    image = image.unsqueeze(0) # e.g. torch.Size([1, 3, 768, 960])
    # load_model() moves the model to the GPU in float16 when CUDA is
    # available, so the input must use the same device and dtype.
    reference = next(model.parameters())
    image = image.to(device=reference.device, dtype=reference.dtype)
    # Inference only: skip autograd bookkeeping.
    with torch.no_grad():
        output, _ = model(image) # e.g. torch.Size([1, 45900, 57])
    return output, image

In [None]:
def visualize_output_singlePicture(output, image):
    """Show one image batch with the detected skeletons drawn on top.

    Args:
        output: Raw prediction tensor returned by the pose model.
        image: Input batch given to the model; only ``image[0]`` is shown.
    """
    model_cfg = model.yaml
    detections = non_max_suppression_kpt(
        output,
        0.25,  # Confidence Threshold
        0.65,  # IoU Threshold
        nc=model_cfg['nc'],      # Number of Classes
        nkpt=model_cfg['nkpt'],  # Number of Keypoints
        kpt_label=True,
    )
    with torch.no_grad():
        detections = output_to_keypoint(detections)

    # Channels-first tensor scaled back to 0-255 -> channels-last uint8 array.
    canvas = (image[0].permute(1, 2, 0) * 255).cpu().numpy().astype(np.uint8)
    canvas = cv2.cvtColor(canvas, cv2.COLOR_RGB2BGR)

    detection_count = detections.shape[0]
    for row in range(detection_count):
        plot_skeleton_kpts(canvas, detections[row, 7:].T, 3)

    plt.figure(figsize=(12, 12))
    plt.axis('off')
    plt.imshow(canvas)
    plt.show()

In [None]:
# Smoke test on a single still image: run the model on 'sample.jpg' and
# display the detected skeletons.
output, image = run_inference('sample.jpg')  # test image
visualize_output_singlePicture(output, image)

In [None]:

def serverSetup(ip_address="localhost", port=8080):
    """Wait for a single TCP client and return the connected socket.

    Binds a listening socket to ``(ip_address, port)``, blocks until one
    client connects, then closes the listening socket.

    Args:
        ip_address: Address to bind to. Defaults to "localhost".
        port: TCP port to listen on. Defaults to 8080.

    Returns:
        The connected client socket. The caller is responsible for closing it.

    Raises:
        OSError: If the address cannot be bound or accepting fails.
    """
    with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as server_socket:
        # Let the cell be re-run right away instead of failing with
        # "Address already in use" while old connections linger.
        server_socket.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
        server_socket.bind((ip_address, port))

        server_socket.listen(1)
        print("Listening for incoming connections...")
        client_socket, address = server_socket.accept()
    # Leaving the `with` block closes only the listening socket; the accepted
    # connection stays open.
    print(f"Connection established with {address}")
    return client_socket



In [None]:

def visualize_output(frame, output):
    """Draw every detected skeleton on ``frame`` and return the first one's keypoints.

    Args:
        frame: Image array passed to ``plot_skeleton_kpts`` for drawing
            (the caller displays the same array afterwards).
        output: Raw prediction tensor returned by the pose model.

    Returns:
        The keypoint values of the first detection (``output[0, 7:].T``), or
        ``None`` when nothing is detected.
    """
    with torch.no_grad():
        output = non_max_suppression_kpt(output,
                                         0.25, # Confidence Threshold
                                         0.65, # IoU Threshold
                                         nc=model.yaml['nc'], # Number of Classes
                                         nkpt=model.yaml['nkpt'], # Number of Keypoints
                                         kpt_label=True)
        output = output_to_keypoint(output)

    if output.shape[0] == 0:
        return None

    # Draw all detections; returning from inside this loop would stop after
    # the first skeleton.
    for idx in range(output.shape[0]):
        plot_skeleton_kpts(frame, output[idx, 7:].T, 3)

    return output[0, 7:].T

def processKeyPointData(rawData, min_confidence=0.8):
    """Turn a flat keypoint sequence into a JSON-friendly dict for the Unity client.

    ``rawData`` is read as consecutive ``(x, y, confidence)`` triples, one per
    keypoint, in the order of ``names`` below. A trailing incomplete triple
    and any triples beyond the 17 known names are ignored.

    Intended to be called once per processed frame.
    TODO: determine the ideal frame rate at which this should be called.

    Args:
        rawData: Flat sequence (list or 1-D array) of keypoint values, or
            ``None``/empty when nothing was detected.
        min_confidence: Keypoints are kept only when their confidence is
            strictly greater than this value. Defaults to 0.8.

    Returns:
        dict mapping keypoint name to ``[x, y]`` as plain Python floats, so
        the result can be passed to ``json.dumps``. Empty when ``rawData`` is
        ``None`` or empty, or when no keypoint is confident enough.
    """
    names = (
        "nose", "left_eye", "right_eye", "left_ear", "right_ear",
        "left_shoulder", "right_shoulder", "left_elbow", "right_elbow",
        "left_wrist", "right_wrist", "left_hip", "right_hip",
        "left_knee", "right_knee", "left_ankle", "right_ankle",
    )
    if rawData is None or len(rawData) == 0:
        return {}

    # De-interleave the triples; zip() stops at the shortest slice, which
    # drops a trailing incomplete triple.
    xs = rawData[0::3]
    ys = rawData[1::3]
    confidences = rawData[2::3]

    return {
        name: [float(x), float(y)]
        for name, x, y, conf in zip(names, xs, ys, confidences)
        if conf > min_confidence
    }

In [None]:
#Live cam pose detection and keypoint extraction 

# NOTE: visualize_output is also defined in an earlier cell; this definition
# shadows it once this cell runs, so the two must be kept in sync.
def visualize_output(frame, output):
    """Draw every detected skeleton on ``frame`` and return the first one's keypoints.

    Args:
        frame: Image array passed to ``plot_skeleton_kpts`` for drawing
            (the caller displays the same array afterwards).
        output: Raw prediction tensor returned by the pose model.

    Returns:
        The keypoint values of the first detection (``output[0, 7:].T``), or
        ``None`` when nothing is detected.
    """
    with torch.no_grad():
        output = non_max_suppression_kpt(output,
                                         0.25, # Confidence Threshold
                                         0.65, # IoU Threshold
                                         nc=model.yaml['nc'], # Number of Classes
                                         nkpt=model.yaml['nkpt'], # Number of Keypoints
                                         kpt_label=True)
        output = output_to_keypoint(output)

    if output.shape[0] == 0:
        return None

    # Draw all detections; returning from inside this loop would stop after
    # the first skeleton.
    for idx in range(output.shape[0]):
        plot_skeleton_kpts(frame, output[idx, 7:].T, 3)

    return output[0, 7:].T

# Load the YOLOv7 pose model for the live-camera session.
model = load_model()
# Inputs must use the same device and dtype as the model (load_model() picks
# float16 on the GPU when CUDA is available).
model_param = next(model.parameters())

# Webcam video capture. A 128x128 frame is requested because higher
# resolutions caused frame-rate issues.
cap = cv2.VideoCapture(0)
cap.set(cv2.CAP_PROP_FRAME_WIDTH, 128)
cap.set(cv2.CAP_PROP_FRAME_HEIGHT, 128)
cv2.namedWindow("Pose Estimation", cv2.WINDOW_NORMAL)

# Display window size = 950 x 950
cv2.resizeWindow("Pose Estimation", 950, 950)

tempDict = None
client_socket = None
counter = 0
try:
    # Blocks until the client application connects.
    client_socket = serverSetup()
    while True:
        # Read a frame from the video stream
        ret, frame = cap.read()
        if not ret:
            break

        # Resize and pad image
        image = letterbox(frame, 128, stride=64, auto=True)[0]
        # Apply transforms
        image = transforms.ToTensor()(image)
        # Turn image into batch on the model's device/dtype
        image = image.unsqueeze(0).to(device=model_param.device, dtype=model_param.dtype)
        # Make predictions on the image batch (inference only, no autograd)
        with torch.no_grad():
            output, _ = model(image)

        # Visualize the detected keypoints on the video frame.
        # NOTE(review): keypoints are predicted on the letterboxed image but
        # drawn on the raw frame — confirm both have the same size.
        tempDict = visualize_output(frame, output)

        # Send keypoints on every other frame, starting with the first.
        if counter % 2 == 0:
            # visualize_output returns None when nobody is detected.
            if tempDict is not None:
                try:
                    tempDict = processKeyPointData(tempDict)
                    joint_positions_json = json.dumps(tempDict)
                    # Newline-delimited JSON message to the connected client.
                    client_socket.sendall((joint_positions_json + "\n").encode())
                except ConnectionError as e:
                    # The client is gone; further sends can never succeed.
                    print(e)
                    break
                except Exception as e:
                    # Best effort: report the problem and keep the video running.
                    print(e)
            counter = 0
        counter += 1

        # Video with keypoints.
        cv2.imshow('Pose Estimation', frame)

        if cv2.waitKey(1) & 0xFF == ord('q'):
            break
finally:
    # Release the socket, camera and window on every exit path, including
    # errors and end of stream.
    if client_socket is not None:
        client_socket.close()
    cap.release()
    cv2.destroyAllWindows()

In [None]:

# Load the YOLOv7 pose model for the live-camera session.
model = load_model()
# Inputs must use the same device and dtype as the model (load_model() picks
# float16 on the GPU when CUDA is available).
model_param = next(model.parameters())

# Webcam video capture. A 60x60 frame is requested because higher resolutions
# caused frame-rate issues.
cap = cv2.VideoCapture(0)
cap.set(cv2.CAP_PROP_FRAME_WIDTH, 60)
cap.set(cv2.CAP_PROP_FRAME_HEIGHT, 60)
cv2.namedWindow("Pose Estimation", cv2.WINDOW_NORMAL)

# Display window size = 950 x 950
cv2.resizeWindow("Pose Estimation", 950, 950)

# Raw keypoints of the most recent detection; kept after the loop so a later
# cell can visualize them.
tempDict = None
client_socket = None
frame_num = 0
try:
    # Blocks until the client application connects.
    client_socket = serverSetup()
    while True:
        # Read a frame from the video stream
        ret, frame = cap.read()
        if not ret:
            break

        frame_num += 1

        # Process only every second frame to reduce load.
        if frame_num % 2 != 0:
            continue

        # Resize image (no padding; the aspect ratio is not preserved).
        # NOTE(review): 60 is not a multiple of the stride (64) used with
        # letterbox elsewhere in this notebook, and keypoints predicted on
        # the 60x60 image are drawn on the raw frame — confirm this is intended.
        image = cv2.resize(frame, (60, 60))
        image = transforms.ToTensor()(image)
        image = image.unsqueeze(0).to(device=model_param.device, dtype=model_param.dtype)

        # Make predictions on the image batch (inference only, no autograd)
        with torch.no_grad():
            output, _ = model(image)

        # Visualize the detected keypoints on the video frame
        keypoints = visualize_output(frame, output)

        # visualize_output returns None when nobody is detected.
        if keypoints is not None:
            tempDict = keypoints
            try:
                joint_positions_json = json.dumps(processKeyPointData(keypoints))
                # Newline-delimited JSON message to the connected client.
                client_socket.sendall((joint_positions_json + "\n").encode())
            except ConnectionError as e:
                # The client is gone; further sends can never succeed.
                print(e)
                break
            except Exception as e:
                # Best effort: report the problem and keep the video running.
                print(e)

        # Video with keypoints.
        cv2.imshow('Pose Estimation', frame)
        if cv2.waitKey(1) & 0xFF == ord('q'):
            break
finally:
    # Release the socket, camera and window on every exit path, including
    # errors and end of stream.
    if client_socket is not None:
        client_socket.close()
    cap.release()
    cv2.destroyAllWindows()

In [None]:
#Visualizing Keyframes 
t = processKeyPointData(tempDict)

%matplotlib inline
plt.gca().invert_yaxis()
plt.scatter(t[1],t[2])


for i, name in enumerate(t[0]):
    plt.annotate(name, (t[1][i], t[2][i]))
