In [6]:
import numpy as np
import torch
from torch.utils.data import Dataset, DataLoader
from sklearn.model_selection import train_test_split

In [7]:
class NTUDataset(Dataset):
    """Map-style PyTorch dataset pairing each sample with an integer class label."""

    def __init__(self, x, labels):
        """Keep references to the sample container ``x`` and the parallel ``labels``."""
        self.labels = labels
        self.x = x

    def __len__(self):
        """Return how many samples are stored, i.e. ``len(self.x)``."""
        return len(self.x)

    def __getitem__(self, index):
        """Return the ``(sample, label)`` pair at ``index``.

        The sample is returned as stored; the label is converted with ``int()``.
        """
        sample = self.x[index]
        label = int(self.labels[index])
        return sample, label

In [8]:
class NTUDataloader:
    '''Builds train / validation / test DataLoaders over the NTU skeleton arrays.

    The paper specifies batch_size 32 for the VA-CNN network and skeleton maps
    resized to 224x224. NOTE(review): no resizing happens in this class as
    written -- it is presumably meant to live in ``torgb``, which is a stub.

    On construction it loads ``train_ntu_data.npy`` / ``train_ntu_label.npy``
    (split into train and validation subsets) and ``test_ntu_data.npy`` /
    ``test_ntu_label.npy`` from the current working directory.
    '''

    def __init__(self, batch_size=32, num_workers=8):
        '''Load the arrays, split train/val, and wrap each split in an NTUDataset.

        Args:
            batch_size: samples per batch for every loader (paper value: 32).
            num_workers: DataLoader worker processes.
        '''
        # Shared DataLoader kwargs. 'shuffle' applies to the training loader;
        # the validation and test loaders override it to False.
        self.params = {'batch_size': batch_size, 'shuffle': True,
                       'num_workers': num_workers, 'pin_memory': True,
                       'collate_fn': self.collate_fn}
        self.data = np.load('train_ntu_data.npy')
        self.data_labels = np.load('train_ntu_label.npy')
        self.test = np.load('test_ntu_data.npy')
        self.test_labels = np.load('test_ntu_label.npy')
        self.create_train_val_data()
        self.train_set = NTUDataset(self.train, self.train_labels)
        self.val_set = NTUDataset(self.val, self.val_labels)
        self.test_set = NTUDataset(self.test, self.test_labels)

    def create_train_val_data(self, split_ratio=0.05, random_state=10000):
        '''Split ``self.data`` / ``self.data_labels`` into train and validation parts.

        Sets ``self.train``, ``self.val``, ``self.train_labels`` and ``self.val_labels``.

        Args:
            split_ratio: fraction of samples held out for validation.
            random_state: seed forwarded to ``train_test_split`` so the split is
                reproducible across runs.
        '''
        (self.train, self.val,
         self.train_labels, self.val_labels) = train_test_split(
            self.data, self.data_labels,
            test_size=split_ratio, random_state=random_state)

    def _make_loader(self, dataset, shuffle):
        '''Build a DataLoader from ``self.params`` with only ``shuffle`` overridden.'''
        return DataLoader(dataset, **{**self.params, 'shuffle': shuffle})

    def train_data_loader(self):
        '''Return a DataLoader over the training split (shuffled per ``self.params``).'''
        return DataLoader(self.train_set, **self.params)

    def val_data_loader(self):
        '''Return a DataLoader over the validation split, in a fixed sequential order.'''
        # Evaluation does not need shuffling; a fixed order keeps runs comparable.
        return self._make_loader(self.val_set, shuffle=False)

    def test_data_loader(self):
        '''Return a DataLoader over the test split, in a fixed sequential order.'''
        return self._make_loader(self.test_set, shuffle=False)

    def collate_fn(self, batch):
        '''Collate a list of ``(sample, label)`` pairs into batch tensors.

        Args:
            batch: list of ``(sample, label)`` pairs as produced by NTUDataset.

        Returns:
            ``[x, maxmin, y]`` where ``x`` stacks the arrays returned by
            ``torgb`` along a new leading dimension, ``maxmin`` is a float32
            tensor built from ``torgb``'s second return value, and ``y`` is an
            int64 tensor of the labels.

        Raises:
            RuntimeError: if ``torgb`` produces no images (the current stub does).
        '''
        samples, labels = zip(*batch)
        images, maxmin = self.torgb(samples)
        # len() rather than truthiness so a future ndarray return also works.
        if len(images) == 0:
            raise RuntimeError(
                'torgb returned no images for a non-empty batch; it is a stub '
                'and must be implemented before batches can be collated')
        x = torch.stack([torch.from_numpy(img) for img in images], 0)
        y = torch.tensor(labels, dtype=torch.long)
        return [x, torch.tensor(maxmin, dtype=torch.float32), y]

    def torgb(self, ske_joints):
        '''Convert a batch of skeleton samples into image arrays (UNIMPLEMENTED STUB).

        NOTE(review): currently ignores ``ske_joints`` and returns two empty
        lists, so ``collate_fn`` cannot produce a batch. Presumably the intended
        implementation maps each skeleton to an image (224x224 per the class
        docstring) and records per-sample max/min values -- confirm against
        the paper before implementing.

        Returns:
            ``(rgb, maxmin)``: ``rgb`` must be numpy arrays (``collate_fn`` calls
            ``torch.from_numpy`` on each); ``maxmin`` must be convertible to a
            float32 tensor.
        '''
        rgb = []
        maxmin = []
        # NOTE(review): not read anywhere visible; presumably state for the real implementation.
        self.idx = 0
        return rgb, maxmin
    
    
    