This repository has been archived by the owner on Jul 2, 2021. It is now read-only.
-
Notifications
You must be signed in to change notification settings - Fork 306
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Merge pull request #787 from yuyu2172/point-data
Add COCOKeypointDataset and vis_keypoint_coco
- Loading branch information
Showing
9 changed files
with
562 additions
and
4 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,166 @@ | ||
from collections import defaultdict | ||
import json | ||
import numpy as np | ||
import os | ||
|
||
from chainercv.chainer_experimental.datasets.sliceable import GetterDataset | ||
from chainercv.datasets.coco.coco_utils import get_coco | ||
from chainercv import utils | ||
|
||
|
||
class COCOKeypointDataset(GetterDataset):

    """Keypoint dataset for `MS COCO`_.

    This only returns annotation for objects categorized to the "person"
    category.

    .. _`MS COCO`: http://cocodataset.org/#home

    Args:
        data_dir (string): Path to the root of the training data. If this is
            :obj:`auto`, this class will automatically download data for you
            under :obj:`$CHAINER_DATASET_ROOT/pfnet/chainercv/coco`.
        split ({'train', 'val'}): Select a split of the dataset.
        year ({'2014', '2017'}): Use a dataset released in :obj:`year`.
        use_crowded (bool): If true, use bounding boxes that are labeled as
            crowded in the original annotation. The default value is
            :obj:`False`.
        return_area (bool): If true, this dataset returns areas of masks
            around objects. The default value is :obj:`False`.
        return_crowded (bool): If true, this dataset returns a boolean array
            that indicates whether bounding boxes are labeled as crowded
            or not. The default value is :obj:`False`.

    Raises:
        ValueError: If :obj:`split` is neither :obj:`'train'` nor
            :obj:`'val'`.

    This dataset returns the following data.

    .. csv-table::
        :header: name, shape, dtype, format

        :obj:`img`, ":math:`(3, H, W)`", :obj:`float32`, \
        "RGB, :math:`[0, 255]`"
        :obj:`point` [#coco_point_1]_, ":math:`(R, K, 2)`", :obj:`float32`, \
        ":math:`(y, x)`"
        :obj:`visible` [#coco_point_1]_, ":math:`(R, K)`", :obj:`bool`, \
        "true when a keypoint is visible."
        :obj:`label` [#coco_point_1]_, ":math:`(R,)`", :obj:`int32`, \
        ":math:`[0, \\#fg\\_class - 1]`"
        :obj:`bbox` [#coco_point_1]_, ":math:`(R, 4)`", :obj:`float32`, \
        ":math:`(y_{min}, x_{min}, y_{max}, x_{max})`"
        :obj:`area` [#coco_point_1]_ [#coco_point_2]_, ":math:`(R,)`", \
        :obj:`float32`, --
        :obj:`crowded` [#coco_point_3]_, ":math:`(R,)`", :obj:`bool`, --

    .. [#coco_point_1] If :obj:`use_crowded = True`, :obj:`point`, \
        :obj:`visible`, :obj:`bbox`, \
        :obj:`label` and :obj:`area` contain crowded instances.
    .. [#coco_point_2] :obj:`area` is available \
        if :obj:`return_area = True`.
    .. [#coco_point_3] :obj:`crowded` is available \
        if :obj:`return_crowded = True`.

    """

    def __init__(self, data_dir='auto', split='train', year='2017',
                 use_crowded=False,
                 return_area=False, return_crowded=False):
        if split not in ['train', 'val']:
            raise ValueError('Unsupported split is given.')
        super(COCOKeypointDataset, self).__init__()
        self.use_crowded = use_crowded
        if data_dir == 'auto':
            data_dir = get_coco(split, split, year, 'instances')

        self.img_root = os.path.join(
            data_dir, 'images', '{}{}'.format(split, year))
        self.data_dir = data_dir

        point_anno_path = os.path.join(
            self.data_dir, 'annotations', 'person_keypoints_{}{}.json'.format(
                split, year))
        # Use a context manager so the annotation file is closed promptly
        # instead of being left to the garbage collector.
        with open(point_anno_path, 'r') as f:
            annos = json.load(f)

        # image id -> image property dict (holds at least 'file_name').
        self.id_to_prop = {}
        for prop in annos['images']:
            self.id_to_prop[prop['id']] = prop
        # Sorted so that the example index -> image id mapping is stable.
        self.ids = sorted(self.id_to_prop)

        self.cat_ids = [cat['id'] for cat in annos['categories']]

        # image id -> list of annotations; images without any annotation
        # yield an empty list.
        self.id_to_anno = defaultdict(list)
        for anno in annos['annotations']:
            self.id_to_anno[anno['image_id']].append(anno)

        self.add_getter('img', self._get_image)
        self.add_getter(
            ['point', 'visible', 'bbox', 'label', 'area', 'crowded'],
            self._get_annotations)
        keys = ('img', 'point', 'visible', 'label', 'bbox')
        if return_area:
            keys += ('area',)
        if return_crowded:
            keys += ('crowded',)
        self.keys = keys

    def __len__(self):
        """Return the number of images listed in the annotation file."""
        return len(self.ids)

    def _get_image(self, i):
        """Read the :obj:`i`-th image as a float32 color array."""
        img_path = os.path.join(
            self.img_root, self.id_to_prop[self.ids[i]]['file_name'])
        img = utils.read_image(img_path, dtype=np.float32, color=True)
        return img

    def _get_annotations(self, i):
        """Build the annotation arrays of the :obj:`i`-th image.

        Returns:
            tuple: :obj:`(point, visible, bbox, label, area, crowded)` with
            degenerate boxes removed and, unless :obj:`use_crowded` is set,
            crowded instances removed as well.

        """
        # List[{'segmentation', 'area', 'iscrowd',
        #       'image_id', 'bbox', 'category_id', 'id'}]
        annotation = self.id_to_anno[self.ids[i]]
        bbox = np.array([ann['bbox'] for ann in annotation],
                        dtype=np.float32)
        if len(bbox) == 0:
            # np.array([]) is 1-D; give it the column axis indexed below.
            bbox = np.zeros((0, 4), dtype=np.float32)
        # (x, y, width, height)  -> (x_min, y_min, x_max, y_max)
        bbox[:, 2] = bbox[:, 0] + bbox[:, 2]
        bbox[:, 3] = bbox[:, 1] + bbox[:, 3]
        # (x_min, y_min, x_max, y_max) -> (y_min, x_min, y_max, x_max)
        bbox = bbox[:, [1, 0, 3, 2]]

        label = np.array([self.cat_ids.index(ann['category_id'])
                          for ann in annotation], dtype=np.int32)

        area = np.array([ann['area']
                         for ann in annotation], dtype=np.float32)

        # np.bool_ rather than the np.bool alias, which was deprecated in
        # NumPy 1.20 and removed in 1.24.
        crowded = np.array([ann['iscrowd']
                            for ann in annotation], dtype=np.bool_)

        point = np.array(
            [anno['keypoints'] for anno in annotation], dtype=np.float32)
        if len(point) > 0:
            # 'keypoints' is a flat list of (x, y, v) triplets.
            x = point[:, 0::3]
            y = point[:, 1::3]
            # 0: not labeled; 1: labeled, not inside mask;
            # 2: labeled and inside mask
            v = point[:, 2::3]
            # Any labeled keypoint (v > 0) is reported as visible.
            visible = v > 0
            point = np.stack((y, x), axis=2)
        else:
            point = np.empty((0, 0, 2), dtype=np.float32)
            visible = np.empty((0, 0), dtype=np.bool_)

        # Remove invisible boxes
        bbox_area = np.prod(bbox[:, 2:] - bbox[:, :2], axis=1)
        keep_mask = np.logical_and(bbox[:, 0] <= bbox[:, 2],
                                   bbox[:, 1] <= bbox[:, 3])
        keep_mask = np.logical_and(keep_mask, bbox_area > 0)

        if not self.use_crowded:
            keep_mask = np.logical_and(keep_mask, np.logical_not(crowded))

        point = point[keep_mask]
        visible = visible[keep_mask]
        bbox = bbox[keep_mask]
        label = label[keep_mask]
        area = area[keep_mask]
        crowded = crowded[keep_mask]
        return point, visible, bbox, label, area, crowded
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,5 +1,6 @@ | ||
from chainercv.visualizations.vis_bbox import vis_bbox # NOQA | ||
from chainercv.visualizations.vis_image import vis_image # NOQA | ||
from chainercv.visualizations.vis_instance_segmentation import vis_instance_segmentation # NOQA | ||
from chainercv.visualizations.vis_keypoint_coco import vis_keypoint_coco # NOQA | ||
from chainercv.visualizations.vis_point import vis_point # NOQA | ||
from chainercv.visualizations.vis_semantic_segmentation import vis_semantic_segmentation # NOQA |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,171 @@ | ||
from __future__ import division | ||
|
||
import numpy as np | ||
|
||
from chainercv.datasets import coco_keypoint_names | ||
from chainercv.visualizations.vis_image import vis_image | ||
|
||
|
||
# Index of the "person" category inside ``coco_keypoint_names``.
human_id = 0

# Skeleton edges as pairs of keypoint indices. Each edge is written by
# keypoint name and resolved to its index in the person keypoint list.
coco_point_skeleton = [
    [coco_keypoint_names[human_id].index(src_name),
     coco_keypoint_names[human_id].index(dst_name)]
    for src_name, dst_name in (
        # face
        ('left_eye', 'right_eye'),
        ('left_eye', 'nose'),
        ('right_eye', 'nose'),
        ('right_eye', 'right_ear'),
        ('left_eye', 'left_ear'),
        # arms
        ('right_shoulder', 'right_elbow'),
        ('right_elbow', 'right_wrist'),
        ('left_shoulder', 'left_elbow'),
        ('left_elbow', 'left_wrist'),
        # legs
        ('right_hip', 'right_knee'),
        ('right_knee', 'right_ankle'),
        ('left_hip', 'left_knee'),
        ('left_knee', 'left_ankle'),
        # torso
        ('right_shoulder', 'left_shoulder'),
        ('right_hip', 'left_hip'),
    )
]
|
||
|
||
def vis_keypoint_coco(
        img, point, visible=None,
        point_score=None, thresh=2,
        markersize=3, linewidth=1, ax=None):
    """Visualize keypoints organized as in COCO.

    Example:

        >>> from chainercv.datasets import COCOKeypointDataset
        >>> from chainercv.visualizations import vis_keypoint_coco
        >>> import matplotlib.pyplot as plt
        >>> data = COCOKeypointDataset(split='val')
        >>> img, point, visible = data[10][:3]
        >>> vis_keypoint_coco(img, point, visible)
        >>> plt.show()

    Args:
        img (~numpy.ndarray): See the table below.
            If this is :obj:`None`, no image is displayed.
        point (~numpy.ndarray): See the table below.
        visible (~numpy.ndarray): See the table below. If this is
            :obj:`None`, all points are assumed to be visible.
        point_score (~numpy.ndarray): See the table below. If this
            is :obj:`None`, the confidence of all points is infinitely
            large. The array passed in is not modified.
        thresh (float): Points with confidence below :obj:`thresh` are
            not visualized.
        markersize (float): The size of vertices.
        linewidth (float): The thickness of edges.
        ax (matplotlib.axes.Axis): The visualization is displayed on this
            axis. If this is :obj:`None` (default), a new axis is created.

    .. csv-table::
        :header: name, shape, dtype, format

        :obj:`img`, ":math:`(3, H, W)`", :obj:`float32`, \
        "RGB, :math:`[0, 255]`"
        :obj:`point`, ":math:`(R, K, 2)`", :obj:`float32`, \
        ":math:`(y, x)`"
        :obj:`visible`, ":math:`(R, K)`", :obj:`bool`, \
        "true when a keypoint is visible."
        :obj:`point_score`, ":math:`(R, K)`", :obj:`float32`, --

    Returns:
        ~matplotlib.axes.Axes:
        Returns the Axes object with the plot for further tweaking.

    Raises:
        ValueError: If the shapes of :obj:`point`, :obj:`visible` and
            :obj:`point_score` are inconsistent, or if :obj:`visible` is
            not a boolean array.

    """
    from matplotlib import pyplot as plt

    # Returns newly instantiated matplotlib.axes.Axes object if ax is None
    ax = vis_image(img, ax=ax)

    # Resolve the keypoint indices once instead of once per instance.
    names = coco_keypoint_names[human_id]
    nose = names.index('nose')
    r_shoulder = names.index('right_shoulder')
    l_shoulder = names.index('left_shoulder')
    r_hip = names.index('right_hip')
    l_hip = names.index('left_hip')

    # One color per skeleton edge plus two for the synthetic
    # nose--mid-shoulder and mid-shoulder--mid-hip edges.
    n_edge = len(coco_point_skeleton)
    cmap = plt.get_cmap('rainbow')
    colors = [cmap(i) for i in np.linspace(0, 1, n_edge + 2)]

    if point_score is None:
        point_score = np.inf * np.ones(point.shape[:2], dtype=np.float32)
    if point_score.shape != point.shape[:2]:
        raise ValueError('Mismatch in the number of instances or joints.')
    if point.shape[1:] != (len(names), 2):
        raise ValueError('point has invalid shape')

    if visible is not None:
        # np.bool_ rather than the np.bool alias, which was deprecated in
        # NumPy 1.20 and removed in 1.24.
        if visible.dtype != np.bool_:
            raise ValueError('The dtype of `visible` should be np.bool')
        if visible.shape != point.shape[:2]:
            raise ValueError('Mismatch in the number of instances or joints.')
        # Force invisible points below any threshold. np.where builds a new
        # array, so the caller's point_score is left untouched.
        point_score = np.where(visible, point_score, -np.inf)

    for pnt, pnt_sc in zip(point, point_score):
        for edge_idx, (i0, i1) in enumerate(coco_point_skeleton):
            s0 = pnt_sc[i0]
            y0 = pnt[i0, 0]
            x0 = pnt[i0, 1]
            s1 = pnt_sc[i1]
            y1 = pnt[i1, 0]
            x1 = pnt[i1, 1]
            if s0 > thresh and s1 > thresh:
                line = ax.plot([x0, x1], [y0, y1])
                plt.setp(line, color=colors[edge_idx],
                         linewidth=linewidth, alpha=0.7)
            if s0 > thresh:
                ax.plot(
                    x0, y0, '.', color=colors[edge_idx],
                    markersize=markersize, alpha=0.7)
            if s1 > thresh:
                ax.plot(
                    x1, y1, '.', color=colors[edge_idx],
                    markersize=markersize, alpha=0.7)

        # for better visualization, add mid shoulder / mid hip
        mid_shoulder = (pnt[r_shoulder, :2] + pnt[l_shoulder, :2]) / 2
        mid_shoulder_sc = np.minimum(pnt_sc[r_shoulder], pnt_sc[l_shoulder])

        mid_hip = (pnt[r_hip, :2] + pnt[l_hip, :2]) / 2
        mid_hip_sc = np.minimum(pnt_sc[r_hip], pnt_sc[l_hip])

        if mid_shoulder_sc > thresh and pnt_sc[nose] > thresh:
            y = [mid_shoulder[0], pnt[nose, 0]]
            x = [mid_shoulder[1], pnt[nose, 1]]
            line = ax.plot(x, y)
            plt.setp(
                line, color=colors[n_edge],
                linewidth=linewidth, alpha=0.7)
        if mid_shoulder_sc > thresh and mid_hip_sc > thresh:
            y = [mid_shoulder[0], mid_hip[0]]
            x = [mid_shoulder[1], mid_hip[1]]
            line = ax.plot(x, y)
            plt.setp(
                line, color=colors[n_edge + 1],
                linewidth=linewidth, alpha=0.7)

    return ax
Oops, something went wrong.