In [29]:
import os
import numpy as np
import json
import torch
from torch.utils.data import Dataset, DataLoader

In [40]:
class FreiDataset(Dataset):
    """FreiHAND training annotations exposed as a PyTorch ``Dataset``.

    Reads ``training_K.json`` (per-sample camera matrices),
    ``training_mano.json`` (MANO annotations) and ``training_xyz.json``
    (3D keypoints) from ``root_dir``. Each item is a dict with

    * ``'uv'``   -- the sample's 3D keypoints projected into image space with
      that sample's ``K`` (NumPy array, one ``(u, v)`` row per keypoint);
    * ``'mano'`` -- the MANO entry exactly as loaded from JSON,

    optionally passed through ``transform``.

    Args:
        root_dir: Directory containing the three annotation files.
        transform: Optional callable applied to each sample dict.

    Raises:
        FileNotFoundError: If an annotation file cannot be opened.
        ValueError: If the three annotation files have different lengths.
    """

    def __init__(self, root_dir, transform=None):
        K_list = self._json_load(os.path.join(root_dir, 'training_K.json'))
        mano_list = self._json_load(os.path.join(root_dir, 'training_mano.json'))
        xyz_list = self._json_load(os.path.join(root_dir, 'training_xyz.json'))

        # Explicit check instead of ``assert``: asserts are stripped under
        # ``python -O``, and ``zip`` below would then silently truncate to the
        # shortest list, pairing annotations from different samples.
        if not len(K_list) == len(mano_list) == len(xyz_list):
            raise ValueError(
                'Size mismatch: %d K, %d mano, %d xyz entries.'
                % (len(K_list), len(mano_list), len(xyz_list)))

        self.db_data_anno = list(zip(K_list, mano_list, xyz_list))
        self.transform = transform

    @staticmethod
    def _json_load(p):
        """Load and return the JSON document at path ``p``.

        ``open`` raises ``FileNotFoundError`` (naming the path) when the file
        is missing, so no separate existence check is needed.
        """
        with open(p, 'r') as fi:
            return json.load(fi)

    def __len__(self):
        """Number of annotated samples."""
        return len(self.db_data_anno)

    def __getitem__(self, idx):
        """Return the sample dict ``{'uv', 'mano'}`` at ``idx`` (transformed if set)."""
        # Accept tensor indices such as ``dataset[torch.tensor(0)]``.
        if torch.is_tensor(idx):
            idx = idx.tolist()

        K, mano, xyz = self.db_data_anno[idx]
        sample = {'uv': self.projectPoints(xyz, K), 'mano': mano}

        if self.transform is not None:
            sample = self.transform(sample)

        return sample

    @staticmethod
    def projectPoints(xyz, K):
        """Project 3D coordinates into image space.

        Args:
            xyz: ``(N, 3)`` array-like of 3D points, one point per row.
            K: Camera matrix (presumably 3x3 intrinsics) as an array-like.

        Returns:
            ``(N, 2)`` NumPy array: each homogeneous projection ``K @ p``
            divided by its last component.
        """
        xyz = np.asarray(xyz)
        K = np.asarray(K)
        uv = xyz @ K.T  # row-wise form of (K @ xyz.T).T
        return uv[:, :2] / uv[:, -1:]
    
    
    
class ToTensor:
    """Convert the ``'uv'`` and ``'mano'`` entries of a sample to float32 tensors.

    Accepts NumPy arrays or nested lists for either entry. Only these two keys
    are kept in the returned dict.
    """

    def __call__(self, sample):
        # torch.as_tensor replaces the legacy typed constructor
        # torch.FloatTensor(...); dtype=float32 keeps the same output dtype.
        return {key: torch.as_tensor(sample[key], dtype=torch.float32)
                for key in ('uv', 'mano')}

In [43]:
# Root of the FreiHAND release. Set the FREIHAND_DIR environment variable to
# point elsewhere instead of editing this absolute local path.
FREIHAND_DIR = os.environ.get(
    'FREIHAND_DIR', '/Users/romanriazantsev/Dev/Datasets/FreiHAND_pub_v1')

dataset = FreiDataset(FREIHAND_DIR, transform=ToTensor())

# Sanity-check the first four samples' tensor shapes.
for i in range(min(4, len(dataset))):
    sample = dataset[i]
    print(i, sample['uv'].size(), sample['mano'].size())

0 torch.Size([21, 2]) torch.Size([1, 61])
1 torch.Size([21, 2]) torch.Size([1, 61])
2 torch.Size([21, 2]) torch.Size([1, 61])
3 torch.Size([21, 2]) torch.Size([1, 61])


In [44]:
# Shuffled mini-batches of 4 samples, loaded by 4 worker processes.
# NOTE(review): workers must be able to unpickle FreiDataset/ToTensor; with the
# "spawn" start method, classes defined in a notebook may fail to load there —
# try num_workers=0 if worker startup errors appear.
dataloader = DataLoader(dataset=dataset, shuffle=True, batch_size=4, num_workers=4)

In [47]:
# Peek at the first four batches; each tensor gains a leading batch dimension.
# zip stops on range exhaustion before drawing a fifth batch from the loader.
for i_batch, sample_batched in zip(range(4), dataloader):
    print(i_batch,
          sample_batched['uv'].size(),
          sample_batched['mano'].size())

0 torch.Size([4, 21, 2]) torch.Size([4, 1, 61])
1 torch.Size([4, 21, 2]) torch.Size([4, 1, 61])
2 torch.Size([4, 21, 2]) torch.Size([4, 1, 61])
3 torch.Size([4, 21, 2]) torch.Size([4, 1, 61])
