In [7]:
import numpy as np
import os
from collections import defaultdict
import random
from skimage.io import imread
from skimage.transform import resize

In [12]:
""" 
Modified from source: 
https://github.com/Vladkryvoruchko/PSPNet-Keras-tensorflow/blob/master/utils/preprocessing.py
"""
def update_inputs(batch_size, resize_tuple, num_classes):
    """Allocate fresh zero-filled buffers for one batch.

    Parameters
    ----------
    batch_size : int
        Number of samples in the batch.
    resize_tuple : sequence of int
        Target size; only the first two entries (height, width) are used.
    num_classes : int
        Depth of the one-hot label buffer.

    Returns
    -------
    (images, labels) : tuple of float64 ndarrays
        Shapes ``(batch_size, height, width, 3)`` and
        ``(batch_size, height, width, num_classes)``.
    """
    height, width = resize_tuple[0], resize_tuple[1]
    image_buffer = np.zeros((batch_size, height, width, 3))
    label_buffer = np.zeros((batch_size, height, width, num_classes))
    return image_buffer, label_buffer


def data_generator(image_dir='train_images', n_classes=50,
                   batch_size=16, resize_shape_tuple=(128, 128),
                   separator='_', test_nmb=50, anno_dir=None):
    """Build training and validation batch generators from paired image/label folders.

    Files in ``image_dir`` and ``anno_dir`` are paired by the part of the file
    name before the first ``separator`` (e.g. ``0001_img.png`` and
    ``0001_mask.png`` share the id ``0001`` when ``separator='_'``). Ids that
    lack either file are skipped. The pairs are sorted by id, then shuffled
    with the global ``random`` module (seed it for a reproducible split). The
    first ``test_nmb`` shuffled pairs are held out for validation.

    Parameters
    ----------
    image_dir : str
        Folder containing the input images.
    n_classes : int
        Depth of the one-hot label encoding.
    batch_size : int
        Samples per yielded batch.
    resize_shape_tuple : tuple of int
        ``(height, width)`` every sample is resized to.
    separator : str
        Character separating the id prefix from the rest of a file name.
    test_nmb : int
        Number of pairs held out for the validation generator.
    anno_dir : str
        Folder containing the annotation label maps. Required. It is placed
        last with a ``None`` default so existing positional calls keep
        binding to the same parameters.

    Returns
    -------
    (train_gen, val_gen) : tuple of generators
        Infinite ``generate`` generators, training first.

    Raises
    ------
    FileNotFoundError
        If ``image_dir`` or ``anno_dir`` does not exist.
    ValueError
        If ``anno_dir`` is not given.
    """
    if not os.path.exists(image_dir):
        raise FileNotFoundError('ERROR! The folder {} does not'
                                ' exist\n'.format(image_dir))
    if anno_dir is None:
        raise ValueError('anno_dir must be given: the folder holding the '
                         'annotation label maps')
    if not os.path.exists(anno_dir):
        raise FileNotFoundError('ERROR! The folder {} does not'
                                ' exist\n'.format(anno_dir))

    data = defaultdict(dict)

    for image_path in os.listdir(image_dir):
        # image number, might want to replace this or rename the images
        nmb = image_path.split(separator)[0]
        data[nmb]['image'] = image_path

    for anno_path in os.listdir(anno_dir):
        # image number, might want to replace this or rename the images
        nmb = anno_path.split(separator)[0]
        data[nmb]['annotation'] = anno_path

    # Only complete pairs can be loaded. Sorting first makes the shuffle (and
    # so the split) depend on the random seed alone, not on os.listdir order.
    # A real list is required: random.shuffle and slicing fail on dict views.
    values = [data[key] for key in sorted(data)
              if 'image' in data[key] and 'annotation' in data[key]]
    random.shuffle(values)

    train_values, val_values = values[test_nmb:], values[:test_nmb]
    return (generate(train_values, n_classes, batch_size,
                     resize_shape_tuple, image_dir, anno_dir),
            generate(val_values, n_classes, batch_size,
                     resize_shape_tuple, image_dir, anno_dir))

def _load_sample(vals, image_dir, anno_dir, resize_shape_tuple, n_classes):
    """Load one image/annotation pair: the resized RGB image and the resized one-hot label map."""
    img = resize(imread(os.path.join(image_dir, vals['image'])),
                 resize_shape_tuple)
    if img.ndim == 2:
        # Grayscale: replicate into three channels to fit the RGB batch buffer.
        img = np.stack([img] * 3, axis=-1)
    elif img.shape[-1] == 4:
        # RGBA: drop the alpha channel.
        img = img[..., :3]

    y = imread(os.path.join(anno_dir, vals['annotation']))
    if y.ndim != 2:
        # NOTE(review): assumes annotations are single-channel class-id maps;
        # colour-coded masks would need a palette -> id mapping first.
        raise ValueError('annotation {} must be a 2-D label map, got shape {}'
                         .format(vals['annotation'], y.shape))
    # Nearest-neighbour with preserve_range keeps integer class ids intact.
    # The default (interpolating, range-rescaling) resize would blend them.
    y = resize(y, resize_shape_tuple, order=0, preserve_range=True,
               anti_aliasing=False).astype(np.int64)
    # Pixels whose id falls outside [0, n_classes) become all-zero rows.
    onehot = (np.arange(n_classes) == y[:, :, None]).astype('float32')
    return img, onehot


def generate(values, n_classes, batch_size, resize_shape_tuple, image_dir, anno_dir):
    """Yield ``(images, labels)`` batches forever, cycling over ``values``.

    Parameters
    ----------
    values : iterable of dict
        Each entry maps ``'image'`` and ``'annotation'`` to file names inside
        ``image_dir`` and ``anno_dir``. It is materialised into a list once.
    n_classes : int
        Depth of the one-hot label encoding.
    batch_size : int
        Samples per batch. A trailing partial batch is dropped on every pass.
    resize_shape_tuple : tuple of int
        ``(height, width)`` every image and label map is resized to.
    image_dir, anno_dir : str
        Folders holding the images and the annotation label maps.

    Yields
    ------
    images : ndarray, shape (batch_size, height, width, 3)
    labels : ndarray, shape (batch_size, height, width, n_classes)

    Raises
    ------
    ValueError
        On the first ``next()``, if ``batch_size`` < 1 or ``values`` holds
        fewer than ``batch_size`` entries. In that case no batch could ever
        be produced and the loop would spin forever. Also raised if an
        annotation is not a 2-D label map.
    """
    values = list(values)
    if batch_size < 1 or len(values) < batch_size:
        raise ValueError('generate() needs batch_size >= 1 and at least '
                         'batch_size={} samples, got {}'
                         .format(batch_size, len(values)))
    while True:
        images, labels = update_inputs(batch_size, resize_shape_tuple, n_classes)
        for i, vals in enumerate(values):
            img, y = _load_sample(vals, image_dir, anno_dir,
                                  resize_shape_tuple, n_classes)
            images[i % batch_size] = img
            labels[i % batch_size] = y
            if (i + 1) % batch_size == 0:
                yield images, labels
                # Fresh buffers, so a consumer holding the yielded batch is
                # not overwritten.
                images, labels = update_inputs(batch_size, resize_shape_tuple,
                                               n_classes)


In [None]:
""" 
Modified from source: 
https://github.com/Vladkryvoruchko/PSPNet-Keras-tensorflow/blob/master/utils/preprocessing.py
"""
from keras.utils import Sequence
class DataGenerator(Sequence):
    """Keras ``Sequence`` of ``(images, labels)`` batches for semantic segmentation.

    Images in ``image_dir`` are paired with label maps in ``anno_dir`` by the
    file-name prefix before ``separator`` (e.g. ``0001_img.png`` with
    ``0001_mask.png`` when ``separator='_'``). Only prefixes present in both
    folders are used. Images are resized to ``resize_tuple``. Label maps are
    resized with nearest-neighbour interpolation and one-hot encoded into
    ``n_classes`` channels.

    Parameters
    ----------
    batch_size : int
        Samples per batch. A trailing partial batch is dropped.
    n_classes : int
        Depth of the one-hot label encoding.
    resize_tuple : tuple of int
        ``(height, width)`` of every sample.
    image_dir, anno_dir : str
        Folders holding the images and the annotation label maps.
    separator : str, optional
        Character separating the pairing prefix from the rest of a file name.
    shuffle : bool, optional
        Shuffle the pairs (global ``random`` module) at construction and
        after every epoch.

    Raises
    ------
    FileNotFoundError
        If either folder does not exist.
    """

    def __init__(self, batch_size, n_classes, resize_tuple, image_dir, anno_dir,
                 separator='_', shuffle=True):
        # Newer Keras versions expect the base-class initializer to run.
        super().__init__()
        self.batch_size = batch_size
        self.n_classes = n_classes
        self.resize_tuple = resize_tuple
        self.image_dir = image_dir
        self.anno_dir = anno_dir
        self.separator = separator
        self.shuffle = shuffle
        self.pairs = self._collect_pairs(image_dir, anno_dir, separator)
        if self.shuffle:
            random.shuffle(self.pairs)

    def __len__(self):
        """Number of full batches per epoch."""
        return len(self.pairs) // self.batch_size

    def __getitem__(self, index):
        """Return batch ``index`` as ``(images, labels)`` arrays."""
        if not 0 <= index < len(self):
            raise IndexError('batch index {} out of range for {} batches'
                             .format(index, len(self)))
        start = index * self.batch_size
        images, labels = self._update_inputs(self.batch_size, self.resize_tuple,
                                             self.n_classes)
        for slot, pair in enumerate(self.pairs[start:start + self.batch_size]):
            images[slot], labels[slot] = self._load_sample(
                pair, self.image_dir, self.anno_dir, self.resize_tuple,
                self.n_classes)
        return images, labels

    def on_epoch_end(self):
        """Reshuffle the pairs between epochs when ``shuffle`` is enabled."""
        if self.shuffle:
            random.shuffle(self.pairs)

    @staticmethod
    def _collect_pairs(image_dir, anno_dir, separator='_'):
        """Return ``{'image': ..., 'annotation': ...}`` dicts for prefixes in both folders, sorted by prefix."""
        for directory in (image_dir, anno_dir):
            if not os.path.exists(directory):
                raise FileNotFoundError('ERROR! The folder {} does not'
                                        ' exist\n'.format(directory))
        data = defaultdict(dict)
        for key, directory in (('image', image_dir), ('annotation', anno_dir)):
            for name in os.listdir(directory):
                # image number, might want to replace this or rename the images
                data[name.split(separator)[0]][key] = name
        # Sorting makes the order independent of os.listdir, so shuffles are
        # reproducible under a fixed seed.
        return [data[prefix] for prefix in sorted(data)
                if 'image' in data[prefix] and 'annotation' in data[prefix]]

    @staticmethod
    def _load_sample(pair, image_dir, anno_dir, resize_tuple, n_classes):
        """Load one pair: the resized RGB image and its resized one-hot label map."""
        img = resize(imread(os.path.join(image_dir, pair['image'])), resize_tuple)
        if img.ndim == 2:
            # Grayscale: replicate into three channels to fit the RGB buffer.
            img = np.stack([img] * 3, axis=-1)
        elif img.shape[-1] == 4:
            # RGBA: drop the alpha channel.
            img = img[..., :3]

        y = imread(os.path.join(anno_dir, pair['annotation']))
        if y.ndim != 2:
            # NOTE(review): assumes single-channel class-id maps; colour masks
            # would need a palette -> id mapping first.
            raise ValueError('annotation {} must be a 2-D label map, got shape {}'
                             .format(pair['annotation'], y.shape))
        # Nearest-neighbour with preserve_range keeps integer class ids intact.
        y = resize(y, resize_tuple, order=0, preserve_range=True,
                   anti_aliasing=False).astype(np.int64)
        # Ids outside [0, n_classes) become all-zero rows.
        onehot = (np.arange(n_classes) == y[:, :, None]).astype('float32')
        return img, onehot

    @staticmethod
    def _update_inputs(batch_size, resize_tuple, num_classes):
        """Allocate zero-filled ``(images, labels)`` buffers for one batch."""
        return np.zeros([batch_size, resize_tuple[0], resize_tuple[1], 3]), \
               np.zeros([batch_size, resize_tuple[0], resize_tuple[1], num_classes])

    @staticmethod
    def data_generator(image_dir='train_images', anno_dir='', n_classes=50,
                       batch_size=16, resize_shape_tuple=(128, 128),
                       separator='_', test_nmb=50):
        """Return ``(train_gen, val_gen)`` infinite generators.

        The first ``test_nmb`` shuffled pairs are used for validation.
        """
        pairs = DataGenerator._collect_pairs(image_dir, anno_dir, separator)
        random.shuffle(pairs)
        return (DataGenerator.generate(pairs[test_nmb:], n_classes, batch_size,
                                       resize_shape_tuple, image_dir, anno_dir),
                DataGenerator.generate(pairs[:test_nmb], n_classes, batch_size,
                                       resize_shape_tuple, image_dir, anno_dir))

    @staticmethod
    def generate(values, n_classes, batch_size, resize_shape_tuple, image_dir, anno_dir):
        """Yield ``(images, labels)`` batches forever, cycling over ``values``.

        Raises ValueError on the first ``next()`` if fewer than ``batch_size``
        samples are available, since no batch could ever be produced.
        """
        values = list(values)
        if batch_size < 1 or len(values) < batch_size:
            raise ValueError('generate() needs batch_size >= 1 and at least '
                             'batch_size={} samples, got {}'
                             .format(batch_size, len(values)))
        while True:
            images, labels = DataGenerator._update_inputs(
                batch_size, resize_shape_tuple, n_classes)
            for i, vals in enumerate(values):
                images[i % batch_size], labels[i % batch_size] = \
                    DataGenerator._load_sample(vals, image_dir, anno_dir,
                                               resize_shape_tuple, n_classes)
                if (i + 1) % batch_size == 0:
                    yield images, labels
                    images, labels = DataGenerator._update_inputs(
                        batch_size, resize_shape_tuple, n_classes)
