# In [None]:  (notebook cell marker left over from export; kept as a comment so the module parses)
import torch
from torchvision import transforms
from torch.utils.data import Dataset

import pandas as pd
from skimage import io
import os
import glob



class PhoenixDataset(Dataset):
    '''
    Dataset pairing image-sequence clips with their text annotations.

    Every row of the corpus CSV is read as one string (column 0) whose
    fields are separated by '|'. Field 1 is a path pattern, relative to
    ``root_dir``, that matches the frame images of one clip; field 3 is
    the annotation string. The other fields are not used by this class.

    NOTE(review): the CSV is parsed with pandas' default ',' separator, so
    rows are expected not to contain commas -- confirm against the corpus.
    '''

    # Positions of the fields this class reads from a '|'-separated row.
    _CLIP_FIELD = 1
    _ANNOTATION_FIELD = 3

    def __init__(self, csv_dir, root_dir, transforms,
                 frame_shape=(3, 224, 224)):
        '''
        Args:
            csv_dir: path of the corpus CSV file.
            root_dir: directory the clip path patterns are relative to.
            transforms: callable applied to every frame in ``__getitem__``;
                its result must be assignable into a tensor of shape
                ``frame_shape``. May be None when only
                ``statistic_analysis`` is used.
            frame_shape: shape of one transformed frame. Defaults to
                (3, 224, 224), the size that used to be hard-coded.
        '''
        super().__init__()
        self.csv_file = pd.read_csv(csv_dir)
        self.root_dir = root_dir
        self.transforms = transforms
        self.frame_shape = tuple(frame_shape)

    def __len__(self):
        '''Return the number of rows (clip/annotation pairs) in the CSV.'''
        return len(self.csv_file)

    def _fields(self, idx):
        '''Return the '|'-separated fields of corpus row ``idx``.'''
        return self.csv_file.iloc[idx, 0].split('|')

    def __getitem__(self, idx):
        '''
        Load one clip and its annotation.

        Returns:
            dict with 'clip', a tensor of shape
            ``(num_frames, *frame_shape)`` holding the transformed frames,
            and 'annotation', the raw annotation string of the row.
        '''
        if torch.is_tensor(idx):
            idx = idx.tolist()

        fields = self._fields(idx)
        clip_path = os.path.join(self.root_dir, fields[self._CLIP_FIELD])
        collection = io.imread_collection(clip_path)

        # Pre-allocate the whole clip, then fill it frame by frame.
        clip = torch.zeros(len(collection), *self.frame_shape)
        for i, img in enumerate(collection):
            clip[i] = self.transforms(img)

        return {'clip': clip, 'annotation': fields[self._ANNOTATION_FIELD]}

    def statistic_analysis(self):
        '''
        Print summary statistics of the dataset.

        Reports the number of pairs, the max/min/average annotation length
        in whitespace-separated tokens, the array shape of the first frame
        of the first clip, and the max/min/average number of frames per
        clip. For an empty corpus only the pair count is printed.
        '''
        size = len(self.csv_file)
        print('corpus pairs: {}\n'.format(size))
        if size == 0:
            # max()/min() of an empty list and the averages would raise.
            return

        # Parse every row once and reuse the fields for all statistics.
        rows = [self._fields(i) for i in range(size)]

        annotations_len = [
            len(row[self._ANNOTATION_FIELD].split()) for row in rows]
        print('max_anotation_length: {}'.format(max(annotations_len)))
        print('min_anotation_length: {}'.format(min(annotations_len)))
        print('ave_anotation_length: {}\n'.format(sum(annotations_len)/size))

        clip_pathes = [
            os.path.join(self.root_dir, row[self._CLIP_FIELD])
            for row in rows]
        clip_len = [len(glob.glob(pattern)) for pattern in clip_pathes]

        # Only read a frame when the first pattern matches at least one
        # file; indexing an empty collection would raise IndexError.
        if clip_len[0] > 0:
            frame_size = io.imread_collection(clip_pathes[0])[0].shape
        else:
            frame_size = None
        print('frame_size: {}'.format(frame_size))

        print('max_clip_length: {}'.format(max(clip_len)))
        print('min_clip_length: {}'.format(min(clip_len)))
        print('ave_clip_length: {}'.format(sum(clip_len)/size))


# if __name__ == '__main__':

#     csv_root = '/media/xieliang555/新加卷/数据集/phoenix2014-release/phoenix-2014-multisigner/annotations/manual'
#     clip_root = '/media/xieliang555/新加卷/数据集/phoenix2014-release/phoenix-2014-multisigner/features/fullFrame-210x260px'

#     # test set
#     print('================ test set ==============')
#     test_csv_dir = os.path.join(csv_root, 'test.corpus.csv')
#     test_root_dir = os.path.join(clip_root, 'test')
#     test_set = PhoenixDataset(test_csv_dir, test_root_dir, transforms=None)
#     test_set.statistic_analysis()

#     # dev set
#     print('================ dev set ==============')
#     dev_csv_dir = os.path.join(csv_root, 'dev.corpus.csv')
#     dev_root_dir = os.path.join(clip_root, 'dev')
#     dev_set = PhoenixDataset(dev_csv_dir, dev_root_dir, transforms=None)
#     dev_set.statistic_analysis()

#     # training set
#     print('================ train set ==============')
#     train_csv_dir = os.path.join(csv_root, 'train.corpus.csv')
#     train_root_dir = os.path.join(clip_root, 'train')
#     train_set = PhoenixDataset(train_csv_dir, train_root_dir, transforms=None)
#     train_set.statistic_analysis()