# Dataset and Dataloader that cut images into patches

In [194]:
# classes for data loading and preprocessing
class Dataset:
    """Segmentation dataset that cuts every image/mask pair into square patches.

    Each item is read from disk with OpenCV, split into non-overlapping
    ``PATCH_SIZE`` x ``PATCH_SIZE`` tiles with ``tf.image.extract_patches``,
    and the mask is converted to one binary channel per class value.

    Args:
        images_dir (str): prefix of the image paths; ``"*.png"`` is appended
            verbatim, so a directory must end with a path separator.
        masks_dir (str): prefix of the mask paths, same convention as
            ``images_dir``.
        classes (list): accepted for interface compatibility but currently
            unused; the extracted mask value is fixed to 255.
        augmentation (albumentations.Compose): data transformation pipeline
            (e.g. flip, scale, etc.)
        preprocessing (albumentations.Compose): data preprocessing
            (e.g. normalization, shape manipulation, etc.)

    Raises:
        ValueError: if the number of images and masks found differs.
    """

    CLASSES = ['buildings']

    # Side length (pixels) of the square patches cut from each image.
    PATCH_SIZE = 256

    def __init__(
            self,
            images_dir,
            masks_dir,
            classes=None,
            augmentation=None,
            preprocessing=None,
    ):
        # glob returns files in arbitrary order; sort both lists so that
        # images_fps[i] and masks_fps[i] refer to the same sample.
        self.images_fps = sorted(glob.glob(images_dir + "*.png"))
        self.masks_fps = sorted(glob.glob(masks_dir + "*.png"))
        if len(self.images_fps) != len(self.masks_fps):
            raise ValueError(
                "found {} images but {} masks".format(
                    len(self.images_fps), len(self.masks_fps)
                )
            )

        # pixel values in the mask files that are turned into mask channels
        self.class_values = [255]
        self.augmentation = augmentation
        self.preprocessing = preprocessing

    def _extract_patches(self, batch):
        """Cut a 4-D (batch, rows, cols, channels) array into non-overlapping patches."""
        window = [1, self.PATCH_SIZE, self.PATCH_SIZE, 1]
        return tf.image.extract_patches(images=batch,
                                        sizes=window,
                                        strides=window,
                                        rates=[1, 1, 1, 1],
                                        padding='VALID')

    def __getitem__(self, i):
        """Return ``(image_patches, mask_patches)`` for sample ``i``.

        ``image_patches`` is the reshaped patch tensor of shape
        (n_patches, PATCH_SIZE, PATCH_SIZE, 3); ``mask_patches`` is a float
        array with one trailing channel per class value. A 1024x1024 input
        yields 16 patches.

        Raises:
            OSError: if the image or mask file cannot be read.
        """
        image_path = self.images_fps[i]
        mask_path = self.masks_fps[i]

        # read data (cv2.imread returns None instead of raising on failure)
        image = cv2.imread(image_path)
        if image is None:
            raise OSError("could not read image file: {}".format(image_path))
        image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)

        mask = cv2.imread(mask_path, 0)  # flag 0: load as single-channel grayscale
        if mask is None:
            raise OSError("could not read mask file: {}".format(mask_path))

        # extract_patches needs 4-D input; add the batch (and channel) axes
        # without assuming a fixed image size.
        image = image[np.newaxis, ...]
        mask = mask[np.newaxis, ..., np.newaxis]

        patch = self.PATCH_SIZE
        # each patch comes back flattened in the last axis; restore its shape
        image = tf.reshape(self._extract_patches(image), [-1, patch, patch, 3])
        mask = tf.reshape(self._extract_patches(mask), [-1, patch, patch])

        # extract certain classes from mask: one boolean channel per class value
        masks = [(mask == v) for v in self.class_values]
        mask = np.stack(masks, axis=-1).astype('float')

        # add background if mask is not binary
        if mask.shape[-1] != 1:
            background = 1 - mask.sum(axis=-1, keepdims=True)
            mask = np.concatenate((mask, background), axis=-1)

        # apply augmentations
        if self.augmentation:
            sample = self.augmentation(image=image, mask=mask)
            image, mask = sample['image'], sample['mask']

        # apply preprocessing
        if self.preprocessing:
            sample = self.preprocessing(image=image, mask=mask)
            image, mask = sample['image'], sample['mask']

        return image, mask

    def __len__(self):
        """Return the number of image files found."""
        return len(self.images_fps)

In [200]:
class Dataloder(keras.utils.Sequence):
    """Load data from dataset and form batches.

    Every dataset item is expected to be a tuple of array-likes. The items
    of a batch are joined along their first axis, so a dataset that returns
    several patches per item produces one flat array of patches per batch.

    Args:
        dataset: instance of Dataset class for image loading and preprocessing.
        batch_size: Integer number of dataset items in a batch.
        shuffle: Boolean, if `True` shuffle item indexes each epoch.
    """

    def __init__(self, dataset, batch_size=1, shuffle=False):
        super().__init__()
        self.dataset = dataset
        self.batch_size = batch_size
        self.shuffle = shuffle
        self.indexes = np.arange(len(dataset))

        self.on_epoch_end()

    def __getitem__(self, i):
        """Return batch ``i`` as a tuple ``(images, masks)`` of numpy arrays."""
        # collect batch data; go through self.indexes so that the
        # permutation made in on_epoch_end actually takes effect
        start = i * self.batch_size
        stop = (i + 1) * self.batch_size
        data = [self.dataset[self.indexes[j]] for j in range(start, stop)]

        # transpose list of tuples, then join the items along the first axis.
        # For batch_size == 1 this equals stacking and dropping the leading
        # axis; for larger batches it merges all items' patches.
        batch = [
            np.concatenate([np.asarray(sample) for sample in samples], axis=0)
            for samples in zip(*data)
        ]

        return (batch[0], batch[1])

    def __len__(self):
        """Denotes the number of batches per epoch"""
        return len(self.indexes) // self.batch_size

    def on_epoch_end(self):
        """Callback function to shuffle indexes each epoch"""
        if self.shuffle:
            self.indexes = np.random.permutation(self.indexes)