-
Notifications
You must be signed in to change notification settings - Fork 184
/
decode.py
69 lines (54 loc) · 2.67 KB
/
decode.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
import numpy as np
from posenet.constants import *
def traverse_to_targ_keypoint(
edge_id, source_keypoint, target_keypoint_id, scores, offsets, output_stride, displacements
):
height = scores.shape[0]
width = scores.shape[1]
source_keypoint_indices = np.clip(
np.round(source_keypoint / output_stride), a_min=0, a_max=[height - 1, width - 1]).astype(np.int32)
displaced_point = source_keypoint + displacements[
source_keypoint_indices[0], source_keypoint_indices[1], edge_id]
displaced_point_indices = np.clip(
np.round(displaced_point / output_stride), a_min=0, a_max=[height - 1, width - 1]).astype(np.int32)
score = scores[displaced_point_indices[0], displaced_point_indices[1], target_keypoint_id]
image_coord = displaced_point_indices * output_stride + offsets[
displaced_point_indices[0], displaced_point_indices[1], target_keypoint_id]
return score, image_coord
def decode_pose(
root_score, root_id, root_image_coord,
scores,
offsets,
output_stride,
displacements_fwd,
displacements_bwd
):
num_parts = scores.shape[2]
num_edges = len(PARENT_CHILD_TUPLES)
instance_keypoint_scores = np.zeros(num_parts)
instance_keypoint_coords = np.zeros((num_parts, 2))
instance_keypoint_scores[root_id] = root_score
instance_keypoint_coords[root_id] = root_image_coord
for edge in reversed(range(num_edges)):
target_keypoint_id, source_keypoint_id = PARENT_CHILD_TUPLES[edge]
if (instance_keypoint_scores[source_keypoint_id] > 0.0 and
instance_keypoint_scores[target_keypoint_id] == 0.0):
score, coords = traverse_to_targ_keypoint(
edge,
instance_keypoint_coords[source_keypoint_id],
target_keypoint_id,
scores, offsets, output_stride, displacements_bwd)
instance_keypoint_scores[target_keypoint_id] = score
instance_keypoint_coords[target_keypoint_id] = coords
for edge in range(num_edges):
source_keypoint_id, target_keypoint_id = PARENT_CHILD_TUPLES[edge]
if (instance_keypoint_scores[source_keypoint_id] > 0.0 and
instance_keypoint_scores[target_keypoint_id] == 0.0):
score, coords = traverse_to_targ_keypoint(
edge,
instance_keypoint_coords[source_keypoint_id],
target_keypoint_id,
scores, offsets, output_stride, displacements_fwd)
instance_keypoint_scores[target_keypoint_id] = score
instance_keypoint_coords[target_keypoint_id] = coords
return instance_keypoint_scores, instance_keypoint_coords