In [134]:
# import libraries
import cv2
import os
import json
import numpy as np
import pandas as pd

In [None]:
def oks(y_true, y_pred, visibility):
    KAPPA = np.array([1] * len(y_true))
    SCALE = 1.0
    # Compute the L2/Euclidean Distance
    distances = np.linalg.norm(y_pred - y_true, axis=-1)
    # Compute the exponential part of the equation
    exp_vector = np.exp(-(distances**2) / (2 * (SCALE**2) * (KAPPA**2)))
    # The numerator expression
    numerator = np.dot(exp_vector, visibility.astype(bool).astype(int))
    # The denominator expression
    denominator = np.sum(visibility.astype(bool).astype(int))
    return numerator / denominator    

In [169]:
def get_oks(src, skip_cnt):
    dirfiles  = search_dir(src)
    
    pred_array_list = []
    gt_array_list = []
    mask_list = []
    
    for df in dirfiles:
        if df.endswith('.DS_Store'):
            continue
        dfs = search_dir(df)
        
        try:
            json_file = [file for file in dfs if file.endswith('frameinfo_edited.json')][0]
        except:
            json_file = [file for file in dfs if file.endswith('_frameInfo.json')][0]
            
        rgb_image = [file for file in dfs if file.endswith('_color.mp4')][0]    
        cap = cv2.VideoCapture(rgb_image)
        
        cnt = 0
    
        local_dl = '{}/{}'.format(dst, 'color')
        os.makedirs(local_dl, exist_ok=True)
        
        with open(json_file, 'r') as f:
            jdict = json.loads(f.read())
        
        intr_mat = [[1054.0*0.25, 0.0, 979.0*0.25],
                     [0.0, 1058.0*0.25, 519.0*0.25],
                     [0.0, 0.0, 1.0]]
        
        pred_array = []
        gt_array = []
        mask_array = []
        while True:
            retval, frame = cap.read()
            if not retval:
                break   
                
            if cnt % skip_cnt != 0:
                pass
                
            resized_width = int(frame.shape[1]*resize_ratio)
            resized_height = int(frame.shape[0]*resize_ratio)
                                
            resized_frame = cv2.resize(frame, [resized_width, resized_height])
            results = pose.process(resized_frame)
            if results.pose_landmarks is None:
                continue
            
            dict_mediapipe = dict()
            for i, lm in enumerate(results.pose_landmarks.landmark):
                if keypoint_mapping_mediapipe[i] > -1:
                    dict_mediapipe[keypoint_mapping_mediapipe[i]] = [lm.x, lm.y]
            
            dict_label = dict()
            if len(jdict['infos']) > cnt:
                joints = jdict['infos'][cnt]['positions']
            
                for i in range(len(joints)):
                    # pass if no points
                    if keypoint_mapping[i] < 0:
                        continue
                    
                    point = np.array([[joints[i]['x'], joints[i]['y'], joints[i]['z']]])                     
                    point = point[:, :3]/ point[:, 2:3]
                    point = np.transpose(np.matmul(intr_mat, np.transpose(point)))
                    point_xy = (point[0, 0]/resized_width, point[0, 1]/resized_height)
                    dict_label[keypoint_mapping[i]] = [point_xy[0], point_xy[1]]
                    
                pred, gt = [], []
                mask = []
                for kc in keypoint_class:
                    pred.append(dict_mediapipe[kc])
                    gt.append(dict_label[kc])
                    mask.append(True)
                pred_array.append(pred)
                gt_array.append(gt)
                mask_array.append(mask)
                cnt += 1
        pred_array_list.extend(pred_array)
        gt_array_list.extend(gt_array)
        mask_list.extend(mask_array)
        cap.release()
    return gt_array_list, pred_array_list, mask_list

In [170]:
def search_dir(basedir):
    dirnames = os.listdir(basedir)
    return [ os.path.join(basedir, dirname) for dirname in dirnames]
    
import mediapipe as mp
mp_pose = mp.solutions.pose

pose = mp_pose.Pose(min_detection_confidence=0.5, min_tracking_confidence=0.5)
keypoint_mapping = [0, -1, 1, 2, 3, 20, 4, 5, 6, 7, -1, 8, 9, 10, 11,  12, 13,14,15,16,17,18,19,-1, -1, -1, -1, -1, 21, 22, 23, 24]

keypoint_mapping_mediapipe = [-1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, #10
                              4, 8, 5, 9, 6, 10, -1, -1, 21, 23, 22, 24, #11~
                              12, 16, 13, 17, 14, 18, -1, -1, 15, 19]

keypoint_class12 = [
    4,5,6,8,9,10,21,23
]

keypoint_class346 = [
    12,13,14,15,16,17,18,19
]

keypoint_class5 = [
    4,5,6,8,9,10,21,22,23,24
]

result = []

for i in range(60):
    src_int = i+1
    src = f'../data/labeled/c{src_int}/'
    dst = f'../data/labeled-processed/c{src_int}/'
    keypoint_class = []
    if src_int <= 20 : # 1 ~ 20 
        keypoint_class = keypoint_class12
    elif 41 <= src_int <= 50: # 41~50
        keypoint_class = keypoint_class5
    else:
        keypoint_class = keypoint_class346
        
    resize_ratio = 0.5
                                  
    json_files = []
    
    gt, pred, mask = (get_oks(src, 3))
    np_gt = np.array(gt).reshape(-1, 2)
    np_pred = np.array(pred).reshape(-1, 2)
    np_mask = np.array(mask).reshape(-1, 1)
    acc = oks(np_gt, np_pred, np_mask)
    # print(acc)
    res = dict()
    res['class_id'] = src_int
    res['oks'] = acc[0]
    result.append(res)
    print(i, acc)
    # result[src_int] = acc
    
    df = pd.DataFrame.from_dict(result)
    df.to_csv('result_total.csv', index=False, header=True)

0 [0.99662615]
1 [0.99450506]
2 [0.9992118]
3 [0.99870484]
4 [0.99689641]
5 [0.99511831]
6 [0.98639682]
7 [0.99796046]
8 [0.98898283]
9 [0.99148606]
10 [0.99868636]
11 [0.99239485]
12 [0.99819637]
13 [0.99531137]
14 [0.98666602]
15 [0.99864324]
16 [0.9895958]
17 [0.9973947]
18 [0.99453501]
19 [0.99524923]
20 [0.99517939]
21 [0.99768879]
22 [0.99219265]
23 [0.99882217]
24 [0.98278562]
25 [0.9914704]
26 [0.99624779]
27 [0.99742666]
28 [0.9993107]
29 [0.99382922]
30 [0.99587002]
31 [0.99269935]
32 [0.99771083]
33 [0.99840977]
34 [0.99334274]
35 [0.99768509]
36 [0.99922639]
37 [0.9980287]
38 [0.99790731]
39 [0.99356694]
40 [0.99823858]
41 [0.99885572]
42 [0.99950561]
43 [0.99867778]
44 [0.99672227]
45 [0.99859572]
46 [0.99707028]
47 [0.99871793]
48 [0.9970692]
49 [0.99958609]
50 [0.99845697]
51 [0.99878368]
52 [0.99933152]
53 [0.99692903]
54 [0.99827803]
55 [0.998378]
56 [0.99129895]
57 [0.99917938]
58 [0.99981468]
59 [0.99983262]
