In [1]:
# Preprocess the MPII dataset
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from __future__ import unicode_literals
from scipy.io import loadmat
izip = zip
import json
import numpy as np

In [26]:
def fix_wrong_joints(joint):
    """
    Make the x-ordering of joints '2'/'3' agree with that of joints '12'/'13'.

    `joint` maps joint-id strings to coordinate sequences whose first element
    is the x position. When all four ids are present and the pair '2'/'3' is
    ordered along x in the opposite direction to the pair '12'/'13', the
    entries for '2' and '3' are exchanged. The dict is modified in place and
    also returned; if any of the four ids is missing it is returned untouched.

    NOTE(review): presumably '12'/'13' are the shoulders and '2'/'3' the hips
    in MPII numbering -- confirm against the dataset documentation.
    """
    if not all(key in joint for key in ('12', '13', '2', '3')):
        return joint

    ref_a_x, ref_b_x = joint['12'][0], joint['13'][0]

    # Decide once whether a swap is needed; the two cases are mutually
    # exclusive, and equal reference x positions never trigger a swap.
    if ref_a_x < ref_b_x:
        needs_swap = joint['3'][0] < joint['2'][0]
    elif ref_a_x > ref_b_x:
        needs_swap = joint['3'][0] > joint['2'][0]
    else:
        needs_swap = False

    if needs_swap:
        joint['2'], joint['3'] = joint['3'], joint['2']

    return joint


def save_joints():
    """
    Convert annotations mat file to json and save on disk.
    Only persons with annotations of all 16 joints will be written in the json.

    Reads 'mpii_human_pose_v1_u12_1.mat' from MPII_DATA_DIR and writes
    'data.json' into MPII_OUT_DIR, one JSON object per line with the keys
    'filename', 'train', 'head_rect', 'is_visible' and 'joint_pos'.
    Records belonging to the first 50 images are additionally echoed to
    stdout as a preview.
    """
    joint_data_fn = os.path.join(MPII_OUT_DIR, 'data.json')
    mat = loadmat(os.path.join(MPII_DATA_DIR, 'mpii_human_pose_v1_u12_1.mat'))

    # The context manager guarantees data.json is flushed and closed before
    # anything else (e.g. split_train_test) reads it, even if a malformed
    # annotation raises half-way through.
    with open(joint_data_fn, 'w') as fp:
        for i, (anno, train_flag) in enumerate(
                izip(mat['RELEASE']['annolist'][0, 0][0],
                     mat['RELEASE']['img_train'][0, 0][0])):

            img_fn = anno['image']['name'][0, 0][0]
            train_flag = int(train_flag)

            # Images whose annorect struct lacks an 'annopoints' field have
            # no joint annotations to export.
            if 'annopoints' not in str(anno['annorect'].dtype):
                continue

            annopoints = anno['annorect']['annopoints'][0]
            head_x1s = anno['annorect']['x1'][0]
            head_y1s = anno['annorect']['y1'][0]
            head_x2s = anno['annorect']['x2'][0]
            head_y2s = anno['annorect']['y2'][0]

            for annopoint, head_x1, head_y1, head_x2, head_y2 in \
                    izip(annopoints, head_x1s, head_y1s, head_x2s, head_y2s):
                # Persons without any annotated point are skipped.
                if len(annopoint) == 0:
                    continue

                head_rect = [float(head_x1[0, 0]),
                             float(head_y1[0, 0]),
                             float(head_x2[0, 0]),
                             float(head_y2[0, 0])]

                # joint coordinates
                annopoint = annopoint['point'][0, 0]
                j_id = [str(j_i[0, 0]) for j_i in annopoint['id'][0]]
                xs = [_x[0, 0] for _x in annopoint['x'][0]]
                ys = [_y[0, 0] for _y in annopoint['y'][0]]
                joint_pos = {}
                for _j_id, _x, _y in zip(j_id, xs, ys):
                    joint_pos[_j_id] = [float(_x), float(_y)]
                # joint_pos = fix_wrong_joints(joint_pos)

                # visibility flags
                if 'is_visible' in str(annopoint.dtype):
                    # A missing (empty) flag is stored as 0. Emptiness is
                    # tested through .size because the truth value of an
                    # empty array is deprecated / an error in NumPy, and the
                    # scalar is extracted explicitly instead of calling int()
                    # on a one-element array.
                    vis = {}
                    for _j_id, v in zip(j_id, annopoint['is_visible'][0]):
                        vis[_j_id] = int(v.flat[0]) if v.size > 0 else 0
                else:
                    vis = None

                # Only fully annotated persons (all 16 joints) are exported.
                if len(joint_pos) != 16:
                    continue

                data = {
                    'filename': img_fn,
                    'train': train_flag,
                    'head_rect': head_rect,
                    'is_visible': vis,
                    'joint_pos': joint_pos
                }
                line = json.dumps(data)

                # Preview: echo the records of the first 50 images.
                if i < 50:
                    print(line)

                print(line, file=fp)


def write_line(datum, fp):
    """
    Emit one CSV record for `datum` to the open text file `fp`:
      image_name, x1, y1, x2, y2, ...
      where xi, yi - coordinates of the i-th joint

    `datum` must provide 'filename' and 'joint_pos'; the keys of 'joint_pos'
    are joint ids given as strings, and joints are written in ascending
    numeric id order.
    """
    # Convert ids to int so that e.g. '10' sorts after '2'.
    by_index = sorted([int(key), xy] for key, xy in datum['joint_pos'].items())
    coords = np.array([xy for _, xy in by_index]).flatten()

    fields = [str(datum['filename'])]
    fields += [str(value) for value in coords]

    print(','.join(fields), file=fp)


def split_train_test():
    """
    Split the records of data.json into train and test CSV files.

    Reads every line of 'data.json' in MPII_OUT_DIR, shuffles the record
    indices with a fixed seed (1701) so the split is reproducible, and writes
    the first 10% (rounded down) of the shuffled records to 'test_joints.csv'
    and the remainder to 'train_joints.csv', both in MPII_OUT_DIR. Each
    record is formatted by write_line().
    """
    # Read the input first so the output files are not truncated when
    # data.json is missing or unreadable.
    with open(os.path.join(MPII_OUT_DIR, 'data.json')) as fp_data:
        all_data = fp_data.readlines()

    N = len(all_data)
    N_test = int(N * 0.1)
    N_train = N - N_test

    print('N:{}'.format(N))
    print('N_train:{}'.format(N_train))
    print('N_test:{}'.format(N_test))

    # Fixed seed: the same data.json always yields the same split.
    np.random.seed(1701)
    perm = np.random.permutation(N)
    test_indices = perm[:N_test]
    train_indices = perm[N_test:]

    print('train_indices:{}'.format(len(train_indices)))
    print('test_indices:{}'.format(len(test_indices)))

    test_fn = os.path.join(MPII_OUT_DIR, 'test_joints.csv')
    train_fn = os.path.join(MPII_OUT_DIR, 'train_joints.csv')

    # Context managers make sure both CSV files are flushed and closed.
    with open(test_fn, 'w') as fp_test, open(train_fn, 'w') as fp_train:
        for indices, fp_out in ((train_indices, fp_train),
                                (test_indices, fp_test)):
            for i in indices:
                datum = json.loads(all_data[i].strip())
                write_line(datum, fp_out)

In [27]:
import os

# Project root; the leading '~' is expanded to the current user's home.
ROOT_DIR = os.path.expanduser('~/forgit/deeppose_tf')

# 'out' sub-directory of the project root.
OUTPUT_DIR = os.path.join(ROOT_DIR, 'out')

# Dataset directories, all located under the project root.
MPII_DATASET_ROOT = os.path.join(ROOT_DIR, 'datasets/mpii')
LSP_EXT_DATASET_ROOT = os.path.join(ROOT_DIR, 'datasets/lsp_ext')
LSP_DATASET_ROOT = os.path.join(ROOT_DIR, 'datasets/lsp')

# MPII input and output both live in the MPII dataset directory.
MPII_DATA_DIR = MPII_OUT_DIR = MPII_DATASET_ROOT

In [28]:
# Sanity check: show the resolved MPII dataset directory before processing.
print(MPII_DATASET_ROOT)

/Users/peli/forgit/deeppose_tf/datasets/mpii


In [29]:
# Order matters: save_joints() writes data.json, which split_train_test()
# then reads to produce train_joints.csv and test_joints.csv.
save_joints()
split_train_test()

{"head_rect": [627.0, 100.0, 706.0, 198.0], "train": 1, "is_visible": {"13": 1, "2": 1, "12": 1, "8": 0, "1": 1, "4": 1, "15": 1, "5": 1, "11": 1, "7": 1, "0": 1, "6": 0, "9": 0, "10": 1, "14": 1, "3": 0}, "joint_pos": {"13": [692.0, 185.0], "2": [573.0, 185.0], "12": [601.0, 167.0], "8": [637.0201, 189.8183], "1": [616.0, 269.0], "4": [661.0, 221.0], "15": [688.0, 313.0], "5": [656.0, 231.0], "11": [553.0, 161.0], "7": [647.0, 176.0], "0": [620.0, 394.0], "6": [610.0, 187.0], "9": [695.9799, 108.1817], "10": [606.0, 217.0], "14": [693.0, 240.0], "3": [647.0, 188.0]}, "filename": "015601864.jpg"}
{"head_rect": [841.0, 145.0, 902.0, 228.0], "train": 1, "is_visible": {"13": 1, "2": 0, "12": 0, "8": 0, "1": 1, "4": 1, "15": 1, "5": 1, "11": 1, "7": 0, "0": 1, "6": 0, "9": 0, "10": 1, "14": 1, "3": 1}, "joint_pos": {"13": [924.0, 206.0], "2": [945.0, 223.0], "12": [888.0, 174.0], "8": [912.4915, 190.6586], "1": [910.0, 279.0], "4": [961.0, 315.0], "15": [955.0, 263.0], "5": [960.0, 403.0],