In [1]:
import numpy as np
import torch
from torch.utils.data import Dataset, DataLoader
from sklearn.model_selection import train_test_split
import matplotlib.pyplot as plt
import scipy.misc

In [2]:
class NTUDataset(Dataset):
    """Map-style PyTorch dataset that pairs each sample in ``x`` with its label."""

    def __init__(self, x, labels):
        """Keep references to the samples ``x`` and their ``labels`` (no copy is made)."""
        self.x = x
        self.labels = labels

    def __len__(self):
        """Return the number of samples, i.e. ``len(x)``."""
        return len(self.x)

    def __getitem__(self, index):
        """Return the ``(sample, label)`` pair stored at ``index``.

        The label is converted with ``int()`` before being returned.
        """
        sample = self.x[index]
        target = int(self.labels[index])
        return sample, target

In [1]:
class NTUDataloader():
    '''Builds train / validation / test ``DataLoader`` objects for NTU skeleton data.

    The .npy splits are loaded and zero-padded in ``__init__``; every batch is
    converted into 3 x 224 x 224 float32 image maps inside ``collate_fn``.
    Given in paper that batch_size is 32 for VA-CNN network and skeleton maps
    are resized to 224x224.
    '''

    MAX_FRAMES = 300         # the longest video in the dataset has 300 frames
    IMAGE_SIZE = (224, 224)  # spatial size of every generated image map
    PIXEL_OFFSET = 110       # subtracted from each resized map; added back for display

    def __init__(self):
        '''Load both splits from the working directory and build the three datasets.'''
        # Value range used to rescale coordinates to [0, 255]. Hard-coded;
        # presumably the extremes observed in the data -- TODO confirm.
        self.max = 5.18858098984
        self.min = -5.28981208801
        self.params = {'batch_size': 32, 'shuffle': True, 'num_workers': 8,
                       'collate_fn': self.collate_fn, 'pin_memory': True}
        self.data = self.load_data(train=True)
        self.data_labels = np.load('train_ntu_label.npy')
        self.test = self.load_data(train=False)
        self.test_labels = np.load('test_ntu_label.npy')
        self.create_train_val_data()
        print(self.train.shape)
        print(self.test.shape)
        self.train_set = NTUDataset(self.train, self.train_labels)
        self.val_set = NTUDataset(self.val, self.val_labels)
        self.test_set = NTUDataset(self.test, self.test_labels)

    def load_data(self, train):
        '''Read one split from disk and zero-pad every video to ``MAX_FRAMES`` frames.

        Args:
            train: truthy reads 'train_ntu_data.npy', falsy reads 'test_ntu_data.npy'.

        Returns:
            float32 array of shape (num_videos, MAX_FRAMES, 150); rows past a
            video's own length stay zero.
        '''
        file_name = 'train_ntu_data.npy' if train else 'test_ntu_data.npy'
        # NOTE(review): frame counts are read per video below, so the file
        # presumably holds a ragged object array; recent NumPy needs
        # allow_pickle=True for that. Confirm the file is trusted before
        # enabling it, since unpickling can execute arbitrary code.
        temp = np.load(file_name)
        tot_vid = temp.shape[0]
        data = np.zeros((tot_vid, self.MAX_FRAMES, 150), dtype=np.float32)
        for i in range(tot_vid):
            no_of_frames = temp[i].shape[0]
            # no_of_frames * 50 * 3 => no_of_frames * 150
            data[i, :no_of_frames] = np.reshape(temp[i], (no_of_frames, 150))
        return data

    def create_train_val_data(self, split_ratio=0.05):
        '''Split ``self.data`` into train / validation parts with a fixed seed.

        Args:
            split_ratio: fraction of the samples held out for validation.
        '''
        (self.train, self.val,
         self.train_labels, self.val_labels) = train_test_split(
            self.data, self.data_labels, test_size=split_ratio, random_state=10000)
        return

    def _eval_params(self):
        '''Loader settings for evaluation: same as ``self.params`` but without shuffling.'''
        # Evaluation batches are kept in dataset order so predictions can be
        # aligned with labels and runs are repeatable.
        return dict(self.params, shuffle=False)

    def train_data_loader(self):
        '''Return a shuffled ``DataLoader`` over the training set.'''
        return DataLoader(self.train_set, **self.params)

    def val_data_loader(self):
        '''Return an in-order ``DataLoader`` over the validation set.'''
        return DataLoader(self.val_set, **self._eval_params())

    def test_data_loader(self):
        '''Return an in-order ``DataLoader`` over the test set.'''
        return DataLoader(self.test_set, **self._eval_params())

    def collate_fn(self, batch):
        '''Collate ``(sequence, label)`` pairs into batch tensors.

        Returns:
            ``[x, maxmin, y]`` where ``x`` is the stacked float image maps,
            ``maxmin`` repeats ``[self.max, self.min]`` once per sample and
            ``y`` is a ``LongTensor`` of labels.
        '''
        x, y = zip(*batch)
        y = torch.LongTensor(y)
        maxmin = torch.FloatTensor([[self.max, self.min]] * len(y))
        x = torch.stack([torch.from_numpy(image) for image in self.create_imagemap(x)], 0)
        return [x, maxmin, y]

    @staticmethod
    def _strip_padding(seq):
        '''Drop all-zero rows, then drop one 75-column half if it is entirely zero.'''
        seq = seq[np.any(seq != 0, axis=1)]
        # The 150 columns are handled as two 75-column halves (presumably one
        # per tracked skeleton -- TODO confirm); an empty half carries no data.
        if np.count_nonzero(seq[:, 0:75]) == 0:
            seq = np.delete(seq, np.s_[:75], axis=1)
        elif np.count_nonzero(seq[:, 75:150]) == 0:
            seq = np.delete(seq, np.s_[75:150], axis=1)
        return seq

    def _sequence_to_image(self, seq):
        '''Convert one padded sequence into a resized (H, W, 3) float32 image.

        Raises:
            ValueError: if the sequence has no non-zero row, so there is
                nothing to resize.
        '''
        seq = self._strip_padding(seq)
        if seq.shape[0] == 0:
            raise ValueError('sequence contains no non-zero frames; cannot build an image map')
        scaled = 255 * (seq - self.min) / (self.max - self.min)
        # Every consecutive triple of columns becomes the 3 channels of one pixel.
        image = np.reshape(scaled, (scaled.shape[0], scaled.shape[1] // 3, 3))
        # NOTE(review): scipy.misc.imresize was deprecated in SciPy 1.0 and
        # removed in 1.3, so this needs an older SciPy (with Pillow). Any
        # replacement must be checked against imresize's output before use.
        return scipy.misc.imresize(image, self.IMAGE_SIZE).astype(np.float32)

    def create_imagemap(self, x):
        '''Convert an iterable of sequences into channel-first image maps.

        Returns:
            list of float32 arrays shaped (3, 224, 224), each shifted down by
            ``PIXEL_OFFSET``.
        '''
        images = []
        for each_seq in x:
            image = self._sequence_to_image(each_seq) - self.PIXEL_OFFSET
            # (H, W, C) -> (C, H, W)
            images.append(np.transpose(image, (2, 0, 1)))
        return images

    def imagemap_images(self, seq):
        '''Show the resized image map of one sequence (without the offset shift).'''
        image = self._sequence_to_image(seq)
        plt.imshow(image.astype(np.uint8))
        plt.show()

    def plot_image(self, x, transformed=False):
        '''Show a channel-first (C, H, W) image map.

        Args:
            x: image map as produced by ``create_imagemap``; it is not modified.
            transformed: when falsy, ``PIXEL_OFFSET`` is added back before
                plotting; when truthy the values are plotted as they are.
        '''
        # (C, H, W) -> (H, W, C)
        image = np.transpose(x, (1, 2, 0))
        if not transformed:
            # Build a new array: the transpose above is a view, so an
            # in-place add would change the caller's data.
            image = image + self.PIXEL_OFFSET
        plt.imshow(image.astype(np.uint8))
        plt.show()
        
    
    

In [4]:
# d = NTUDataloader()
# d.imagemap_images(d.train[0,:])
# d.plot_image(d.train[0,:])
# d.collate_fn(d.train)

In [5]:
# print(d.train[0,:].shape)