In [2]:
import errno
import math
import os
import random
from os.path import isfile, join

import cv2
import matplotlib.pyplot as plt
import numpy as np
    
def check_dirs(filename):
    """Ensure the parent directory of ``filename`` exists, creating it if needed.

    The file itself is neither created nor opened; only its directory
    component is inspected.

    Args:
        filename: Path of a file whose containing directory should exist.

    Raises:
        OSError: If the directory is missing and cannot be created.
    """
    directory = os.path.dirname(filename)
    # A bare file name has an empty dirname. os.makedirs('') raises, and there
    # is nothing to create anyway, so treat it as "already satisfied".
    if not directory or os.path.exists(directory):
        return
    # exist_ok=True tolerates the directory being created by someone else
    # between the existence check above and this call.
    os.makedirs(directory, exist_ok=True)
    
def get_selected_taxons(path=None):
    """Load the taxon-name to taxon-id mapping from a comma-separated file.

    The first line is treated as a header and skipped. Each following
    non-blank line supplies the taxon name in its first column and an integer
    id in its second column; any further columns are ignored.

    Args:
        path: File to read. When omitted, the module-level
            ``SELECTED_TAXONS`` path is used.

    Returns:
        dict mapping taxon name (str) to taxon id (int). Empty when the file
        has no data lines.

    Raises:
        ValueError: If the second column of a data line is not an integer.
        IndexError: If a non-blank data line has fewer than two columns.
    """
    if path is None:
        path = SELECTED_TAXONS
    selected_taxons = {}
    # The context manager guarantees the handle is closed even if parsing fails.
    with open(path, 'r') as f:
        # Skip the header; the default keeps an empty file from raising.
        next(f, None)
        for line in f:
            line = line.strip()
            if not line:
                # Blank lines (e.g. a trailing newline) carry no data.
                continue
            fields = line.split(',')
            selected_taxons[fields[0]] = int(fields[1])
    return selected_taxons

def convert_to_square(image):
    """Pad an image to a square by replicating its border pixels.

    The shorter of the two leading axes is padded on both sides until it
    matches the longer one. When the padding amount is odd, the extra pixel
    goes to the bottom / right.

    Args:
        image: Array whose first two axes are height and width. A trailing
            channel axis, if present, does not take part in the size
            computation.

    Returns:
        The padded image produced by ``cv2.copyMakeBorder`` with
        ``cv2.BORDER_REPLICATE``.
    """
    # Only the first two axes are spatial; slicing keeps the unpacking valid
    # for images that carry a trailing channel axis.
    h, w = image.shape[:2]
    # Use the spatial extent only, so a channel count can never be picked as
    # the target side length.
    square_size = max(h, w)
    delta_w = square_size - w
    delta_h = square_size - h
    top, bottom = delta_h // 2, delta_h - (delta_h // 2)
    left, right = delta_w // 2, delta_w - (delta_w // 2)
    square_image = cv2.copyMakeBorder(image, top, bottom, left, right, cv2.BORDER_REPLICATE)
    return square_image

def expand(image):
    """Return a three-channel version of a two-dimensional image.

    A 2-D input gains a trailing axis whose single plane is repeated three
    times. Inputs with any other number of dimensions are returned as-is.

    Args:
        image: Array-like object exposing a ``shape`` attribute.

    Returns:
        An array of shape ``(rows, cols, 3)`` for 2-D input, otherwise the
        unmodified input object.
    """
    has_two_axes = len(image.shape) == 2
    if not has_two_axes:
        return image
    # Append a length-1 trailing axis, then copy it three times along that axis.
    with_channel_axis = np.expand_dims(image, axis=-1)
    return np.repeat(with_channel_axis, repeats=3, axis=2)

def get_dataset(percentage_train = 0.9):
    """Collect image paths and labels under ``DATASET_PATH`` and split them.

    Every immediate sub-directory of ``DATASET_PATH`` is treated as one taxon.
    Its name is looked up in ``id_map`` to obtain the label, and each regular
    file inside it becomes one sample. Sub-directories that are missing from
    ``id_map`` are reported with a warning and skipped. The samples are
    shuffled with the ``random`` module before being split.

    Args:
        percentage_train: Fraction of the shuffled samples assigned to the
            training split; the remainder forms the test split.

    Returns:
        ``(x_train, y_train, x_test, y_test)``. The ``x_*`` tuples hold file
        paths and the ``y_*`` tuples hold the matching values from ``id_map``.
        A split that receives no samples is returned as an empty tuple.
    """
    x_set = []
    y_set = []
    # The first item yielded by os.walk describes DATASET_PATH itself; index 1
    # is the list of its immediate sub-directories.
    taxons_dirs = next(os.walk(DATASET_PATH))[1]
    n_taxons = len(taxons_dirs)
    disp_progress = display("0/"+str(n_taxons),display_id=True)
    for i, taxon in enumerate(taxons_dirs):
        if taxon not in id_map:
            print("WARNING: Taxon",taxon,"not found in id_map !")
            continue
        taxon_id = id_map[taxon]
        path = join(DATASET_PATH, taxon)
        for entry in os.listdir(path):
            file_path = join(path, entry)
            if isfile(file_path):
                x_set.append(file_path)
                y_set.append(taxon_id)
        disp_progress.update(str(i+1)+"/"+str(n_taxons))
    # Shuffling
    xy = list(zip(x_set, y_set))
    random.shuffle(xy)
    # Plain list slicing keeps each label's original type. Routing the pairs
    # through a NumPy array would coerce paths and labels to one string dtype.
    split_at = int(percentage_train * len(xy))
    xy_train, xy_test = xy[:split_at], xy[split_at:]
    # zip(*[]) yields nothing to unpack, so empty splits are handled explicitly.
    x_train, y_train = zip(*xy_train) if xy_train else ((), ())
    x_test, y_test = zip(*xy_test) if xy_test else ((), ())
    return x_train, y_train, x_test, y_test

In [None]:
# using Keras sequence
class DiatomSequence(Sequence):
    """Batch provider that reads images from disk one batch at a time.

    Only the file paths and labels are kept in memory; the images of a batch
    are read with ``cv2.imread`` when that batch is requested.
    """

    def __init__(self, x_set, y_set, batch_size):
        """Store the sample list and batch size.

        Args:
            x_set: Sliceable sequence of image file paths.
            y_set: Sliceable sequence of labels, aligned with ``x_set``.
            batch_size: Number of samples per batch.
        """
        # Let the base class perform whatever setup it defines.
        super().__init__()
        self.x, self.y = x_set, y_set
        self.batch_size = batch_size

    def __len__(self):
        """Return the number of batches, counting a final partial batch."""
        return math.ceil(len(self.x) / self.batch_size)

    def __getitem__(self, idx):
        """Load and return batch number ``idx``.

        Args:
            idx: Zero-based batch index.

        Returns:
            Tuple ``(images, labels)``: an array built from the images read
            for this batch, and the labels converted to an integer array.

        Raises:
            IOError: If an image file of the batch cannot be read.
        """
        start = idx * self.batch_size
        stop = start + self.batch_size
        batch_x = self.x[start:stop]
        batch_y = self.y[start:stop]
        images = []
        for file_name in batch_x:
            image = cv2.imread(file_name)
            # cv2.imread signals failure by returning None instead of raising;
            # fail here with the offending path rather than deep inside training.
            if image is None:
                raise IOError("Could not read image: " + str(file_name))
            images.append(image)
        return np.array(images), np.array(batch_y).astype(int)