# In [None]:
import cv2
import numpy as np
from os.path import isfile, join
from os import listdir
from random import shuffle


class DataSetGenerator:
    """Serve one-hot labelled image batches from a directory of class folders.

    Every entry directly under ``data_dir`` that is not a regular file is
    treated as one class; the files inside it whose extension is ``jpg``
    are that class's samples.
    """

    def __init__(self, data_dir):
        """Scan ``data_dir`` for class folders and collect their image paths.

        Args:
            data_dir: Path of the directory that contains one sub-directory
                per class.
        """
        self.data_dir = data_dir
        # Class names; the position of a name in this list is its one-hot index.
        self.data_labels = self.get_data_labels()
        # data_info[i] is the shuffled list of image paths for data_labels[i].
        self.data_info = self.get_data_paths()

    def get_data_labels(self):
        """Return the names of the class folders found in ``self.data_dir``.

        Returns:
            A sorted list of every entry in ``self.data_dir`` that is not a
            regular file.
        """
        # os.listdir returns entries in arbitrary order; sorting keeps the
        # label -> one-hot index mapping stable between runs and machines.
        return sorted(
            name
            for name in listdir(self.data_dir)
            if not isfile(join(self.data_dir, name))
        )

    def get_data_paths(self):
        """Collect the ``jpg`` image paths of every class.

        Returns:
            A list parallel to ``self.data_labels``; element ``i`` is the
            shuffled list of image paths belonging to class ``i``.
        """
        data_paths = []
        for label in self.data_labels:
            path = join(self.data_dir, label)
            # Sort before shuffling so a seeded ``random`` gives the same
            # order regardless of the file system's listing order.
            img_list = [
                join(path, filename)
                for filename in sorted(listdir(path))
                if filename.split('.')[-1] == 'jpg'
            ]
            shuffle(img_list)
            data_paths.append(img_list)
        return data_paths

    def get_batches(self, batch_size=20, image_size=(64, 64), allchannel=False):
        """Yield ``(images, labels)`` batches drawn evenly from all classes.

        Images are read round-robin: each round takes the next unread image
        of every class that still has one. A batch is emitted every
        ``max(1, batch_size // number_of_classes)`` rounds, so it holds at
        most that many images per class.

        Iteration stops after the first round in which the last class in
        ``self.data_labels`` has no image left; images gathered since the
        previous batch are discarded at that point.

        Args:
            batch_size: Target number of images per batch.
            image_size: Output size handed to ``cv2.resize`` as ``dsize``,
                i.e. ``(width, height)``.
            allchannel: If true keep the channels returned by
                ``cv2.imread``; otherwise convert to grayscale with an
                explicit trailing channel axis of length 1.

        Yields:
            A tuple ``(images, labels)`` of ``uint8`` arrays, where
            ``labels[k]`` is the one-hot class vector of ``images[k]``.

        Raises:
            IOError: If ``cv2.imread`` cannot decode one of the image files.
        """
        num_classes = len(self.data_labels)
        if num_classes == 0:
            # Nothing to serve; also avoids dividing by zero below.
            return

        # Never let this drop to 0 (batch_size < num_classes), otherwise the
        # modulo test below would raise ZeroDivisionError.
        rounds_per_batch = max(1, batch_size // num_classes)
        dsize = tuple(image_size)

        images = []
        labels = []
        counter = 0
        empty = False
        while True:
            for i, paths in enumerate(self.data_info):
                if len(paths) < counter + 1:
                    # Class i is exhausted for this round.
                    empty = True
                    continue
                empty = False

                img = cv2.imread(paths[counter])
                if img is None:
                    # imread signals failure by returning None, not raising.
                    raise IOError('Could not read image: ' + paths[counter])
                # Bring every image to a common size so the batch can be
                # stacked into a single array.
                img = cv2.resize(img, dsize)
                if not allchannel:
                    img = cv2.cvtColor(img, cv2.COLOR_BGR2GRAY)
                    img = np.reshape(img, (img.shape[0], img.shape[1], 1))

                label = np.zeros(num_classes, dtype=int)
                label[i] = 1
                images.append(img)
                labels.append(label)
            counter += 1

            # ``empty`` now reflects the last class only (see docstring).
            if empty:
                break

            if counter % rounds_per_batch == 0:
                yield np.array(images, dtype=np.uint8), np.array(labels, dtype=np.uint8)
                images = []
                labels = []