In [None]:
"""
Load face image dataset.
"""
import os
import random
import glob

import matplotlib.pyplot as plt
import numpy as np
from PIL import Image

import torch
from torch.utils.data import Dataset, DataLoader
import torchvision.transforms as transforms

#np.random.seed(1)
def load_data(data_path="./data/test_masked_cloth_cropped", train_size=460, val_size=20, test_size=20,
              batch_size=120, use_mask=True, label_mask=False, use_gpu=True):
    """
    Load the dataset and split it by identity into train, val and test sets.

    Every sub-directory of ``data_path`` is one identity. The identity names are
    sorted (so the split is reproducible for a given ``np.random`` seed regardless
    of filesystem listing order), shuffled with ``np.random.shuffle`` and sliced,
    so no identity appears in more than one split.

    Args:
        data_path: (string) directory containing one sub-directory per identity
        train_size: (int) number of identities in the training set
        val_size: (int) number of identities in the validation set
        test_size: (int or None) number of identities in the test set;
            None uses all identities left after train and val
        batch_size: (int) batch size
        use_mask: (bool) include masked face images in each batch
        label_mask: (bool) label samples as (identity position, is_masked)
            instead of identity position (only used when use_mask is true)
        use_gpu: (bool) move returned tensors to CUDA
    Return:
        out: (tuple) (train_dset, val_dset, test_dset), (train_loader, val_loader, test_loader)
    """
    idens_paths = sorted(glob.glob(os.path.join(data_path, '*')))
    # Only directories are identities; stray files would yield no images.
    idens = [os.path.basename(path) for path in idens_paths if os.path.isdir(path)]
    np.random.shuffle(idens)
    print(len(idens))

    val_end = train_size + val_size
    test_end = None if test_size is None else val_end + test_size

    train_dset = CustomDataset(data_path, idens[:train_size], batch_size, use_mask, label_mask, use_gpu=use_gpu)
    val_dset   = CustomDataset(data_path, idens[train_size:val_end], batch_size, use_mask, label_mask, use_gpu=use_gpu)
    test_dset  = CustomDataset(data_path, idens[val_end:test_end], batch_size, use_mask, label_mask, use_gpu=use_gpu)

    print('train_set:{}, val_set:{}, test_set:{}'.format(train_dset.size, val_dset.size, test_dset.size))
    # shuffle must stay False: CustomDataset builds each batch from consecutive
    # indices idx // batch_size, so indices must arrive in order.
    train_loader = DataLoader(train_dset, batch_size=batch_size, shuffle=False)
    val_loader = DataLoader(val_dset, batch_size=batch_size, shuffle=False)
    test_loader = DataLoader(test_dset, batch_size=batch_size, shuffle=False)

    return (train_dset, val_dset, test_dset), (train_loader, val_loader, test_loader)

class CustomDataset(Dataset):
    """
    Custom face dataset that composes each batch from a random subset of identities.

    Samples are generated batch by batch: when ``__getitem__`` receives an index
    belonging to a batch other than the current one (``idx // batch_size``), a new
    batch is drawn. Batches are only coherent when indices arrive in order, i.e. with
    ``DataLoader(shuffle=False)`` and the same ``batch_size`` as this dataset.
    NOTE(review): with ``num_workers > 0`` each worker keeps its own copy of the batch
    state, so batches would not line up — confirm this is only used with
    ``num_workers=0``.

    Labels are the 1-based position of the identity within the current batch (not a
    global identity id), or ``(position, is_masked)`` when both ``use_mask`` and
    ``label_mask`` are true.
    """
    def __init__(self, data_path, idens, batch_size, use_mask=True, label_mask=False, img_size=224, use_gpu=True):
        """
        Args:
            data_path: (string) root directory containing one sub-directory per identity
            idens: (sequence of string) identity directory names in this split
            batch_size: (int) number of samples per batch
            use_mask: (bool) draw 3 clear + 3 masked images per identity instead of 6 clear
            label_mask: (bool) label samples as (position, is_masked); only used with use_mask
            img_size: (int) size passed to transforms.Resize
            use_gpu: (bool) move returned tensors to CUDA
        """
        self.data_path = data_path
        self.idens = idens
        self.n_idens = len(idens)
        self.batch_size = batch_size
        self.img_size = img_size
        self.use_mask = use_mask
        self.label_mask = label_mask
        self.use_gpu = use_gpu

        # dict {iden => [clear_img_filename, ...]}, filled lazily
        self.iden_clear_faces = {}
        # dict {iden => [masked_img_filename, ...]}, filled lazily
        self.iden_masked_faces = {}

        # An estimated epoch size; it does not need to be exact because batches
        # are sampled on the fly. Use self.n_idens * 10 for debugging.
        self.size = self.n_idens * 100

        # (filename, label) samples of the current batch
        self.batch_data = []
        # Index of the batch currently held in self.batch_data
        self.batch_count = None

        # Built once instead of on every __getitem__ call. The Normalize values are
        # the usual ImageNet channel statistics.
        self.preprocess = transforms.Compose([
            transforms.Resize(self.img_size),
            transforms.ToTensor(),
            transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
        ])

    @staticmethod
    def _sample(files, k):
        """Return up to k distinct items of files in random order (all of them if fewer exist)."""
        return random.sample(files, min(k, len(files)))

    def _load_iden_files(self, iden):
        """Return (clear_files, masked_files) for an identity, globbing on first use and caching."""
        if iden not in self.iden_clear_faces:
            # NOTE(review): glob's '*' also matches '_', so every file matching
            # '*_*_*.jpg' also matches '*_*.jpg'; the "clear" list is therefore a
            # superset of the masked list. Confirm the intended filename scheme.
            self.iden_clear_faces[iden] = glob.glob('{}/{}/*_*.jpg'.format(self.data_path, iden))
            self.iden_masked_faces[iden] = glob.glob('{}/{}/*_*_*.jpg'.format(self.data_path, iden))
        return self.iden_clear_faces[iden], self.iden_masked_faces[iden]

    def _new_batch(self):
        """
        Draw a new batch into self.batch_data.

        Up to 20 identities are sampled without replacement. For each, 3 clear and
        3 masked images are drawn (6 clear images when use_mask is false), so a full
        batch holds up to 20 * 6 = 120 samples. Identities with too few images
        contribute as many as they have. The samples are shuffled before storing.
        """
        idens_per_batch = min(self.n_idens, 20)
        if self.use_mask:
            clear_per_iden, masked_per_iden = 3, 3
        else:
            clear_per_iden, masked_per_iden = 6, 0
        # (position, is_masked) labels only apply when masked images are used.
        label_mask = self.use_mask and self.label_mask

        # Sample indices (not items) so that idens may be any indexable sequence.
        idens = [self.idens[idx] for idx in random.sample(range(self.n_idens), idens_per_batch)]

        batch_data = []
        # nid (int): 1-based position of the identity within this batch, used as its label
        for nid, iden in enumerate(idens, start=1):
            clears, masks = self._load_iden_files(iden)
            clear_label = (nid, False) if label_mask else nid
            batch_data.extend((f, clear_label) for f in self._sample(clears, clear_per_iden))
            if self.use_mask:
                masked_label = (nid, True) if label_mask else nid
                batch_data.extend((f, masked_label) for f in self._sample(masks, masked_per_iden))

        # Use the same RNG as the sampling above so one random.seed() reproduces batches.
        random.shuffle(batch_data)
        self.batch_data = batch_data

    def __len__(self):
        return self.size

    def __getitem__(self, idx):
        """
        Return (image_tensor, label_tensor) for global index idx.

        label_tensor holds the identity's position within the batch (int), or
        [position, is_masked] when masked labels are enabled.
        """
        batch_count, b_idx = divmod(idx, self.batch_size)
        if batch_count != self.batch_count:
            self.batch_count = batch_count
            self._new_batch()
        if not self.batch_data:
            raise RuntimeError('no images found for batch {}'.format(batch_count))

        # A batch can hold fewer than batch_size samples (fewer than 20 identities or
        # too few images); wrap around instead of raising IndexError mid-epoch.
        filename, label = self.batch_data[b_idx % len(self.batch_data)]
        # Close the file handle and force 3 channels to match Normalize's mean/std.
        with Image.open(filename) as img:
            input_tensor = self.preprocess(img.convert('RGB'))
        label_tensor = torch.tensor(label)

        if self.use_gpu:
            input_tensor = input_tensor.to('cuda')
            label_tensor = label_tensor.to('cuda')

        return input_tensor, label_tensor
