In [None]:
%run Include.ipynb
%run FileIO.ipynb
import glob
import cv2

class Data(torch.utils.data.Dataset):
    """Grayscale image dataset whose samples pair an image with a sidecar ``.dat`` path.

    Every file directly inside ``root`` matching ``*.<target_type>`` becomes one
    sample. For an image named ``foo.png`` the companion path is
    ``<root_pds>/foo.png.dat``; that path is only recorded here, never opened.
    Samples are ordered by file name so indices are reproducible.

    Args:
        root: Directory scanned (non-recursively) for images.
        root_pds: Directory expected to hold the per-image ``.dat`` files.
        target_type: File extension to match, without the dot (e.g. ``"png"``).
        transform: Optional callable applied to each ``np.uint8`` image in
            ``__getitem__``.

    Raises:
        FileNotFoundError: If no file in ``root`` matches ``*.<target_type>``.
        OSError: If OpenCV cannot read the first image (checked eagerly so a
            bad dataset fails at construction time).
    """

    def __init__(self, root, root_pds, target_type, transform=None):
        self.root = root
        self.transform = transform
        # Glob relative to an escaped ``root`` instead of calling os.chdir().
        # chdir mutated the process-wide cwd and, for a relative ``root``, left
        # the stored ``root/<file>`` paths resolving to ``root/root/<file>``.
        # Sorting makes the index -> file mapping independent of filesystem order.
        pattern = os.path.join(glob.escape(root), "*." + target_type)
        names = sorted(os.path.basename(path) for path in glob.glob(pattern))
        self.address_book = [os.path.join(root, name) for name in names]
        self.address_book_pds = [os.path.join(root_pds, name + ".dat") for name in names]
        if not self.address_book:
            raise FileNotFoundError(
                "No '*.%s' files found in %s" % (target_type, root))

        # Print a summary of one sample so the user can sanity-check the data.
        img_tease = self._read_gray(self.address_book[0])
        print("Image shape: " + str(img_tease.shape))
        print("Image value range: %.2f - %.2f" % (np.amin(img_tease), np.amax(img_tease)))
        print("Image data type: " + str(type(img_tease[0][0])))
        print("Required data type is np.uint8")

    @staticmethod
    def _read_gray(path):
        """Read ``path`` as a single-channel image, raising if OpenCV cannot."""
        img = cv2.imread(path, cv2.IMREAD_GRAYSCALE)
        if img is None:
            # cv2.imread reports missing/undecodable files by returning None,
            # not by raising, so turn that into an error naming the file.
            raise OSError("Could not read image: %s" % path)
        return img

    def __len__(self):
        """Return the number of matched image files."""
        return len(self.address_book)

    def __getitem__(self, idx):
        """Return ``{'image': img, 'pd_path': str}`` for sample ``idx``.

        ``img`` is the grayscale image converted with ``np.uint8`` and then
        passed through ``transform`` if one was given; ``pd_path`` is the
        companion ``.dat`` path, which is not opened here.
        """
        if torch.is_tensor(idx):
            # idx may arrive as a tensor; tolist() turns a 0-d tensor into a
            # plain Python number usable as a list index.
            idx = idx.tolist()
        img = np.uint8(self._read_gray(self.address_book[idx]))
        pd_path = self.address_book_pds[idx]
        if self.transform:
            img = self.transform(img)

        instance = {'image': img, 'pd_path': pd_path}
        return instance
    
class Data_fetcher:
    """Builds ``torch.utils.data.DataLoader`` objects for the supported datasets."""

    @staticmethod
    def fetch_dataset(name, batch_size, batch_workers, shuffle, drop_last, scalor):
        """Return a DataLoader over the dataset called ``name``.

        Args:
            name: ``"cifar10"`` (torchvision CIFAR-10 rooted at
                ``FLAGS.data_path`` with ``download=True``, resized to 64x64 and
                normalized to mean/std 0.5 per channel) or ``"custom"`` (the
                ``Data`` image folder configured by ``FLAGS.data_path``,
                ``FLAGS.pds_path`` and ``FLAGS.data_extension``).
            batch_size: Samples per batch.
            batch_workers: Number of loader workers; coerced with ``int()``.
            shuffle: Passed straight to the DataLoader.
            drop_last: Passed straight to the DataLoader.
            scalor: Used as both mean and std of ``Normalize`` for the custom
                dataset (applied after ``ToTensor``); unused for CIFAR-10.

        Raises:
            NotImplementedError: If ``name`` is not a recognised dataset.
        """
        # Reject unknown names up front so each branch below is a plain build.
        if name not in ("cifar10", "custom"):
            raise NotImplementedError('Unrecognized dataset %s' % name)

        if name == "cifar10":
            cifar_pipeline = transforms.Compose([
                transforms.Resize([64, 64]),
                transforms.ToTensor(),
                transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
            ])
            dataset = dset.CIFAR10(
                root=FLAGS.data_path,
                download=True,
                transform=cifar_pipeline,
            )
        else:
            # Data yields a uint8 numpy image; go through PIL so ToTensor
            # rescales it before normalization.
            custom_pipeline = transforms.Compose([
                transforms.ToPILImage(),
                transforms.ToTensor(),
                transforms.Normalize([scalor], [scalor]),
            ])
            dataset = Data(
                FLAGS.data_path,
                FLAGS.pds_path,
                FLAGS.data_extension,
                transform=custom_pipeline,
            )

        loader = torch.utils.data.DataLoader(
            dataset,
            batch_size=batch_size,
            shuffle=shuffle,
            num_workers=int(batch_workers),
            drop_last=drop_last,
        )
        return loader