# In [None]:
from google.colab import drive
import torch
import os
import h5py
import numpy as np
from torchvision.transforms import ToPILImage

# Mount Google Drive at /content/gdrive so that the dataset path below can be
# accessed through the regular filesystem. This must run before any code that
# reads from DATASET_GOOGLE_DRIVE_PATH.
drive.mount('/content/gdrive')

# Directory on the mounted drive that FaceMaskData reads from: it is joined
# with 'train_files.txt' / 'test_files.txt' and with every file name listed
# inside those text files.
DATASET_GOOGLE_DRIVE_PATH = '/content/gdrive/MyDrive/CS523/Project/dataset'


def get_configuration(batch_size):
    """Build the settings dictionary used to construct the DataLoader.

    Args:
        batch_size: Value stored unchanged under the 'batch_size' key.

    Returns:
        A new dict with 'batch_size' set to ``batch_size`` and 'num_workers'
        fixed at 1.
    """
    return {'batch_size': batch_size, 'num_workers': 1}

def get_file_name(files_list):
    """Read the names listed in a text file, one name per line.

    Trailing whitespace (including the newline) is removed from every line.
    Lines that are empty after stripping are skipped: an empty name would
    otherwise be joined onto the dataset directory by the caller and resolve
    to the directory itself rather than to a data file.

    Args:
        files_list: Path of the text file to read.

    Returns:
        A list of the non-empty, right-stripped lines in file order.
    """
    with open(files_list) as f:
        stripped_lines = (line.rstrip() for line in f)
        return [name for name in stripped_lines if name]

# dataset class 
class FaceMaskData(torch.utils.data.Dataset):
    """Dataset that loads images and labels from HDF5 files into memory.

    The HDF5 files to read are named, one per line, in 'train_files.txt' or
    'test_files.txt' inside DATASET_GOOGLE_DRIVE_PATH. Every listed file must
    provide an 'image' dataset (4-D; its shape is unpacked as
    count, H, W, C) and a 'labels' dataset. The contents of all files are
    stacked along the first axis and kept as NumPy arrays.

    Attributes set by the constructor:
        data_dir: Directory the list file and HDF5 files are read from.
        transform: Optional callable applied to each image on access.
        files: File names read from the list file.
        image, label: Stacked arrays of all images / labels.
        num_images: Number of stacked images (the dataset length).
        total_num_imgs, H, W, C: Shape of the 'image' dataset of the last
            file that was read.
    """

    def __init__(self, mode, config, transform=None):
        """Load every HDF5 file of the selected split.

        Args:
            mode: 'train' selects 'train_files.txt'; any other value selects
                'test_files.txt'.
            config: Accepted for interface compatibility; not used here.
            transform: Optional callable applied to an image in
                ``__getitem__``.

        Raises:
            ValueError: If the list file names no HDF5 files.
        """
        super().__init__()

        self.data_dir = DATASET_GOOGLE_DRIVE_PATH
        self.transform = transform

        list_name = 'train_files.txt' if mode == 'train' else 'test_files.txt'
        list_path = os.path.join(self.data_dir, list_name)
        self.files = get_file_name(list_path)

        # np.vstack would fail on an empty list with a message that does not
        # point at the real cause, so report it explicitly.
        if not self.files:
            raise ValueError('no HDF5 files listed in ' + list_path)

        image_parts = []
        label_parts = []

        # Copy the data of every listed HDF5 file into memory. The context
        # manager closes each file once its arrays have been read, so no
        # open handle is left behind.
        for file_name in self.files:
            path = os.path.join(self.data_dir, file_name)
            with h5py.File(path, 'r') as h5_file:
                (self.total_num_imgs,
                 self.H, self.W, self.C) = h5_file['image'].shape
                image_parts.append(h5_file['image'][:])
                label_parts.append(h5_file['labels'][:])

        self.image = np.vstack(image_parts)
        self.label = np.vstack(label_parts)

        self.num_images = len(self.image)

    def __getitem__(self, index):
        """Return the ``(image, label)`` pair stored at ``index``.

        The image is passed through ``self.transform`` when one is set; the
        label is returned as a ``torch.FloatTensor``.
        """
        image = self.image[index]
        label = self.label[index]
        if self.transform:
            image = self.transform(image)
        return image, torch.FloatTensor(label)

    def __len__(self):
        """Return the number of images held by the dataset."""
        return self.num_images
      

# Build the training dataset and a shuffling DataLoader over it.
config = get_configuration(16)
dataset = FaceMaskData('train', config)
data_loader = torch.utils.data.DataLoader(dataset=dataset,
                                          batch_size=config['batch_size'],
                                          shuffle=True,
                                          num_workers=config['num_workers'])

# Use the GPU when the runtime has one; otherwise stay on the CPU instead of
# raising from an unconditional .cuda() call.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# One full pass over the loader. After the loop, `image` and `label` hold the
# last batch that was produced.
for i, data in enumerate(data_loader):
    image, label = data
    image, label = image.to(device), label.to(device)

# show sample image
to_img = ToPILImage()
# permute(2, 1, 0) reverses the three axes exactly like `.T` did, without
# relying on `.T` for a tensor that is not 2-D (deprecated in PyTorch).
# NOTE(review): ToPILImage reads a tensor as C x H x W. If the stored images
# are H x W x C, as the shape unpacking in FaceMaskData suggests, this sample
# is C x W x H and is displayed transposed; permute(2, 0, 1) would show it
# upright. Confirm against the data before changing.
sample = image[0].permute(2, 1, 0)
sample.size()
to_img(sample)
