Downloading the HigherHRNet repo:

In [None]:
# This is my fork of HRNet, this is the same as the official repo, minus some dependencies  
!git clone https://github.com/ramarlina/Higher-HRNet-Human-Pose-Estimation.git

# adding repo to python's paths since we're not going to install it
import sys 
sys.path.append("Higher-HRNet-Human-Pose-Estimation/lib")

# Creating a HigherHRNet pose estimation model

Custom code for parsing the YAML config file into an object whose nested sections can be read as attributes (e.g. `config.MODEL.NUM_JOINTS`):

In [None]:
import json 
import yaml

# Loading the yaml file
config_file = "Higher-HRNet-Human-Pose-Estimation/experiments/coco/higher_hrnet/w32_512_adam_lr1e-3.yaml"

# yaml.safe_load instead of a bare yaml.load(...): PyYAML 6.0+ requires an explicit
# Loader for yaml.load, and safe_load never constructs arbitrary Python objects.
# Assumes the config is plain YAML (mappings/lists/scalars, no Python tags) — TODO confirm.
# The `with` block also closes the file handle, which the original left open.
with open(config_file) as config_fh:
    config_json = yaml.safe_load(config_fh)
 
def walk(node):
    """
    Turn one level of a parsed YAML mapping into a plain dict.

    Every nested dict value is wrapped in a ConfigParser, so it supports
    attribute access. Any other value (scalar, list, ...) is kept unchanged.
    """
    return {
        key: ConfigParser(value) if isinstance(value, dict) else value
        for key, value in node.items()
    }


class ConfigParser:
    """
    Attribute- and item-accessible view over a nested config dict.

    Both ``cfg.MODEL.NUM_JOINTS`` and ``cfg['MODEL']['NUM_JOINTS']`` work.
    """

    def __init__(self, cfg_json):
        # Every top-level key becomes an instance attribute.
        self.__dict__ = walk(cfg_json)

    def __getitem__(self, idx):
        return vars(self)[idx]

    def __setitem__(self, key, value):
        vars(self)[key] = value

    def __repr__(self):
        # Show only the key names, in insertion order, as a JSON list.
        return json.dumps([*vars(self)])

# Wrap the raw YAML dict so nested sections are reachable as attributes
config = ConfigParser(config_json)

# Quick sanity check of the fields the rest of the notebook relies on
print("Weights: ", config["MODEL"]["PRETRAINED"])
print("Num Joints: ", config["MODEL"]["NUM_JOINTS"])

Instantiating the model

In [None]:
from models.pose_higher_hrnet import PoseHigherResolutionNet 
import torch

# Device the network runs on: "cpu", or "cuda" for the GPU.
# NOTE: inputs passed to the model must live on this same device.
device = "cpu"

# Build the network from the parsed config, then move its weights to `device`
model = PoseHigherResolutionNet(config)
model = model.to(device)

Loading pre-trained weights from the official Google Drive folder

In [None]:
# downloading pretrained weights from https://drive.google.com/drive/folders/1zJbBbIHVQmHJp89t5CD1VF5TIzldpHXn
!gdown https://drive.google.com/uc?id=1V9Iz0ZYy9m8VeaspfKECDW0NKlGsYmO1

# loading weights
state_dict = torch.load("./pose_higher_hrnet_w32_512.pth")
model.load_state_dict(state_dict)

# Inference

Helper functions for loading and preprocessing an image, predicting its pose with the model, and visualizing the prediction

In [None]:
from utils.transforms import resize_align_multi_scale 
from utils.transforms import get_multi_scale_size
import cv2
import torchvision
import numpy as np
from matplotlib import pyplot as plt 

# Preprocessing applied to the resized image before it enters the network.
# ToTensor turns an H x W x C uint8 array into a C x H x W float tensor scaled
# to [0, 1]. Normalize then standardizes each channel with the standard
# ImageNet mean/std, listed in RGB order.
transforms = torchvision.transforms.Compose(
    [
        torchvision.transforms.ToTensor(),
        torchvision.transforms.Normalize(
            mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]
        ),
    ]
)
 
def load_image(fname, resolution=(512,512)):
    """
    Load an image from disk and build the model's input tensor.

    Args:
        fname: path to an image file readable by OpenCV.
        resolution: network input size. Only ``resolution[0]`` is used; it is
            passed as the input size to ``resize_align_multi_scale``.

    Returns:
        image: the full-resolution image as an RGB numpy array.
        image_resized: the resized, normalized image as a tensor with a leading
            batch dimension of 1, ready to be fed to the model.

    Raises:
        FileNotFoundError: if OpenCV cannot read ``fname``.
    """
    image = cv2.imread(fname)
    # cv2.imread signals failure by returning None rather than raising.
    if image is None:
        raise FileNotFoundError(f"Could not read image: {fname}")
    # OpenCV loads BGR; the mean/std in `transforms` are in RGB order.
    image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)

    # Resize/align to the requested input size (single scale: 1.0, 1.0).
    # Previously this hard-coded 512 and ignored `resolution`.
    image_resized, _center, _scale = resize_align_multi_scale(
        image, resolution[0], 1.0, 1.0
    )
    image_resized = transforms(image_resized)

    # Add the batch dimension: (C, H, W) -> (1, C, H, W)
    return image, image_resized.unsqueeze(0)

def predict(model, X, original_size):
    """
    Run the model on a preprocessed batch and pick one keypoint per joint.

    Every output of the model is resized to ``original_size``. The first
    ``n_joints`` channels of all outputs are averaged into one heatmap per
    joint, and each joint's location is the argmax of its heatmap. Only the
    first item of the batch is decoded. Puts ``model`` in eval mode as a side
    effect.

    Args:
        model: torch module returning a sequence of 4-D tensors
            (N, C_i, h_i, w_i). The channel count of the last one is used as
            the number of joints.
        X: preprocessed input batch, e.g. the tensor returned by ``load_image``.
            It is moved to the model's device before the forward pass.
        original_size: (height, width) the heatmaps are resized to, so that
            keypoints are expressed in original-image pixel coordinates.

    Returns:
        pts: float array of shape (n_joints, 3). Row ``i`` is
            (x, y, peak heatmap value) for joint ``i``.
        confidence: float array of shape (n_joints,) holding the same peak
            heatmap value per joint.
    """
    model.eval()

    # Keep the input on the same device as the weights, so running the model
    # on "cuda" works without the caller moving X by hand.
    first_param = next(model.parameters(), None)
    if first_param is not None:
        X = X.to(first_param.device)

    # Pure inference: don't build an autograd graph.
    with torch.no_grad():
        outputs = model(X)
        n_joints = outputs[-1].shape[1]

        heatmaps = [
            torch.nn.functional.interpolate(
                output,
                size=(original_size[0], original_size[1]),
                mode='bilinear',
                align_corners=False
            )[:, :n_joints]
            for output in outputs
        ]
        # Average over however many outputs the model produced (not a fixed 2).
        hm = (sum(heatmaps) / len(heatmaps)).cpu().numpy()

    pts = np.zeros((n_joints, 3))
    confidence = np.zeros(n_joints)

    for i, joint in enumerate(hm[0]):
        row, col = np.unravel_index(np.argmax(joint), joint.shape)
        peak = joint[row, col]
        pts[i] = (col, row, peak)   # (x, y) order, as used for drawing
        confidence[i] = peak

    return pts, confidence

def visualize_pose(image, pts):
    """
    Draw the predicted skeleton and keypoints on a copy of an image.

    Also opens a new 10x10-inch matplotlib figure, so the caller's subsequent
    ``plt.imshow`` renders at that size.

    Args:
        image: RGB image array (as returned by ``load_image``). Not modified.
        pts: array with one row per joint whose first two columns are (x, y)
            in image pixel coordinates. Presumably the COCO 17-keypoint order,
            since the config is the COCO experiment — TODO confirm.

    Returns:
        A copy of ``image`` with limbs drawn as green lines and joints as
        filled red circles (colors as seen in RGB).
    """
    skeleton = [
        [15, 13], [13, 11], [16, 14], [14, 12], [11, 12], [5, 11], [6, 12], [5, 6], [5, 7],
        [6, 8], [7, 9], [8, 10], [1, 2], [0, 1], [0, 2], [1, 3], [2, 4],
        # [3, 5], [4, 6] are intentionally replaced by [0, 5], [0, 6]
        [0, 5], [0, 6]
    ]

    plt.figure(figsize=(10,10))

    # Draw on a copy so the caller's image stays untouched; otherwise repeated
    # calls would stack drawings on the same array.
    canvas = image.copy()

    for start, end in skeleton:
        x1, y1 = pts[start][:2]
        x2, y2 = pts[end][:2]
        canvas = cv2.line(
            canvas, (int(x1), int(y1)), (int(x2), int(y2)),
            (0,255,0), 5
        )

    for pt in pts:
        canvas = cv2.circle(canvas, (int(pt[0]), int(pt[1])), 10, (255,0,0), -1)

    return canvas
 

# Running inference on a test image

In [None]:
# downloading some image to test the model on
!wget https://storage.needpix.com/rsynced_images/man-1453062_1280.jpg

Predicting body pose

In [None]:
# Load the test image: `image` is the full-resolution RGB array, `X` the model input
image, X = load_image("man-1453062_1280.jpg")

# Predict keypoints in original-image pixel coordinates; shape[:2] is (height, width)
pts, confidence = predict(model, X, image.shape[:2])

# Draw the skeleton and display it without (meaningless) axis ticks.
# plt.show() also avoids echoing the AxesImage repr as the cell output.
viz = visualize_pose(image, pts)

plt.imshow(viz)
plt.axis("off")
plt.show()