# Libraries

In [None]:
%load_ext autoreload
%autoreload 2

import numpy as np
import pandas as pd
from glob import glob
from PIL import Image
import os, json, cv2, torch, yaml
from matplotlib import pyplot as plt
from torch.utils.data import DataLoader
import torchvision.transforms as transforms

import sys
# Put the parent of the current working directory on the import path so the
# `lib.*` imports below resolve (presumably the repo root when the notebook is
# run from its own directory — confirm if the layout changes).
sys.path.append("../")
# Expose only GPU 0 to this process.
# NOTE(review): this only takes effect if CUDA has not been initialised yet;
# torch is already imported above, which is presumably fine because torch
# initialises CUDA lazily — keep this line before any CUDA call.
os.environ['CUDA_VISIBLE_DEVICES'] = '0'

from lib.config import cfg
import lib.models as models
import lib.dataset as dataset
from lib.utils.draw import draw_heatmaps

# Build the config, model, and validation dataset

In [None]:
# Load config file: overlay the experiment YAML onto the shared `cfg` node.
config_file = "../experiments/coco/hrnet/multi_person_39.yaml"

# `cfg` is kept frozen (presumably a yacs CfgNode); thaw it only for the
# merge, then freeze it again before later cells read it.
cfg.defrost()
cfg.merge_from_file(config_file)
cfg.freeze()

In [None]:
# Model: instantiate the architecture named by cfg.MODEL.NAME and load weights.

# Resolve `lib.models.<cfg.MODEL.NAME>.get_pose_net` by attribute lookup
# rather than eval() on a config-supplied string.
model = getattr(models, cfg.MODEL.NAME).get_pose_net(cfg, is_train=False)

if cfg.TEST.MODEL_FILE:
    model_state_file = cfg.TEST.MODEL_FILE
    strict = False
else:
    # `final_output_dir` was never defined in this notebook, so this branch
    # always raised NameError. Rebuild it from the config instead.
    # NOTE(review): assumes the <OUTPUT_DIR>/<DATASET>/<MODEL.NAME>/<cfg name>
    # layout used by HRNet's create_logger — confirm against lib/utils.
    final_output_dir = os.path.join(
        cfg.OUTPUT_DIR,
        cfg.DATASET.DATASET,
        cfg.MODEL.NAME,
        os.path.splitext(os.path.basename(config_file))[0],
    )
    model_state_file = os.path.join(final_output_dir, 'final_state.pth')
    strict = True

print('=> loading model from {}'.format(model_state_file))
# Load onto the CPU so a checkpoint saved on any device loads here; the model
# is moved to the GPU below.
state_dict = torch.load(model_state_file, map_location='cpu')
incompatible = model.load_state_dict(state_dict, strict=strict)
if not strict and (incompatible.missing_keys or incompatible.unexpected_keys):
    # strict=False skips mismatched keys silently (e.g. a 'module.' prefix
    # from DataParallel), which would leave those layers untrained.
    print('=> WARNING: missing keys: {}'.format(incompatible.missing_keys))
    print('=> WARNING: unexpected keys: {}'.format(incompatible.unexpected_keys))

model.eval()
model.cuda();

In [None]:
# Dataset: validation split, normalised with the usual torchvision/ImageNet
# channel statistics.

mean = [0.485, 0.456, 0.406]
std = [0.229, 0.224, 0.225]
# numpy copies, used later to de-normalise tensors back into displayable images
m = np.array(mean)
s = np.array(std)
normalize = transforms.Normalize(mean, std)

# Resolve `lib.dataset.<cfg.DATASET.DATASET>` by attribute lookup rather than
# eval() on a config-supplied string.
valid_dataset = getattr(dataset, cfg.DATASET.DATASET)(
    cfg,
    cfg.DATASET.ROOT,
    cfg.DATASET.TEST_SET,
    False,  # NOTE(review): presumably `is_train` (HRNet signature) — confirm
    cfg.DATASET.MULTI_PERSON,
    transforms.Compose([
        transforms.ToTensor(),
        normalize,
    ]),
)

# Inference

In [None]:
# Get an item: a random validation sample.
random_idx = np.random.randint(0, len(valid_dataset))
print("random_idx", random_idx)
data = valid_dataset[random_idx]
# NOTE: `input` shadows the builtin; the name is kept for later cells.
input, target, target_weight, meta = data

# Get image: undo Normalize and turn the CHW float tensor into HWC uint8.
image = input.numpy().transpose((1, 2, 0))
image = 255 * (image * s[None, None, :] + m[None, None, :])
# Clip before casting: values slightly outside [0, 255] (float error) do not
# saturate when cast to uint8.
image = np.clip(image, 0, 255).astype('uint8')
H, W = image.shape[:2]

# Model inference on a batch of one.
with torch.no_grad():
    inputs = input.unsqueeze(0).cuda()
    outputs = model(inputs)
    heatmaps = outputs[0].cpu().numpy()

# Visualize output.
# The heatmaps are float confidences (plotted on a 0..1 scale elsewhere), so
# casting them straight to uint8 truncated nearly every value to 0. Clip and
# scale to 0..255 first.
# NOTE(review): assumes draw_heatmaps expects 0..255 uint8 maps — confirm.
heatmaps_u8 = (np.clip(heatmaps, 0.0, 1.0) * 255).astype('uint8')
drawn_image = draw_heatmaps(image, heatmaps_u8)
plt.figure(figsize=(30, 15))
plt.subplot(1, 2, 1); plt.imshow(image); plt.title("image"); plt.axis('off')
plt.subplot(1, 2, 2); plt.imshow(drawn_image); plt.title("drawn_image"); plt.axis('off')
plt.show()

In [None]:
# COCO keypoint names in COCO order, used to title each heatmap panel.
keypoints = [
    "nose","left_eye","right_eye","left_ear","right_ear",
    "left_shoulder","right_shoulder","left_elbow","right_elbow",
    "left_wrist","right_wrist","left_hip","right_hip",
    "left_knee","right_knee","left_ankle","right_ankle",
]

# Panel grid: the input image first, then one panel per heatmap channel.
# Derive the row count from the channel count. The hard-coded 6x3 grid only
# fit exactly 17 heatmaps. The old figsize also swapped base_w/base_h and
# counted 5 rows for a 6-row grid.
base_w = 15
base_h = 15
n_cols = 3
n_panels = len(heatmaps) + 1
n_rows = (n_panels + n_cols - 1) // n_cols  # ceiling division
plt.figure(figsize=(n_cols * base_w, n_rows * base_h))

plt.subplot(n_rows, n_cols, 1); plt.imshow(image); plt.title("image"); plt.axis('off')
for i, heatmap in enumerate(heatmaps):
    # Upsample to image resolution; cv2 takes dsize as (width, height).
    heatmap = cv2.resize(heatmap, (W, H), interpolation=cv2.INTER_LINEAR)
    title = keypoints[i] if i < len(keypoints) else "channel {}".format(i)
    plt.subplot(n_rows, n_cols, i + 2)
    plt.imshow(image)
    plt.imshow(heatmap, alpha=0.75, vmin=0.0, vmax=1.0, cmap='jet')
    plt.title(title)
    plt.axis('off')
plt.show()