# In [11]:
import cv2
import numpy as np
import matplotlib.pyplot as plt
import os
from os.path import isfile, join
import errno
import csv

def get_dataset(dataset_folder, ids=False):
    """Index the files stored in per-taxon sub-directories of a dataset.

    Each immediate sub-directory of ``dataset_folder`` is treated as one
    taxon, and each regular file directly inside it as one sample. A
    progress indicator created with ``display(..., display_id=True)``
    (presumably IPython's ``display`` -- confirm) is updated after every
    taxon has been scanned.

    Args:
        dataset_folder: Root directory containing one sub-directory per taxon.
        ids: When truthy, each label is looked up in the module-level
            ``id_map`` by directory name; otherwise the directory name
            itself is used as the label.

    Returns:
        A tuple ``(x_set, y_set, result)`` where ``x_set`` is the list of
        file paths, ``y_set`` the label of each path (same order), and
        ``result`` a dict mapping each label to the list of its file paths.
        Taxa whose directory holds no regular file do not appear in
        ``result``.
    """
    samples = []
    labels = []
    by_label = {}

    # Only the first os.walk() entry is needed: its second field lists the
    # immediate sub-directories of the dataset root.
    _, taxon_names, _ = next(os.walk(dataset_folder))
    total = len(taxon_names)
    progress = display("0/" + str(total), display_id=True)

    for done, taxon_name in enumerate(taxon_names, start=1):
        # NOTE(review): there is no membership check here, so a taxon that is
        # absent from id_map presumably raises KeyError when ids is set.
        label = id_map[taxon_name] if ids else taxon_name

        taxon_dir = join(dataset_folder, taxon_name)
        paths = [
            join(taxon_dir, entry)
            for entry in os.listdir(taxon_dir)
            if isfile(join(taxon_dir, entry))
        ]

        samples.extend(paths)
        labels.extend([label] * len(paths))
        if paths:
            # Only create the dict entry when at least one file was found.
            by_label.setdefault(label, []).extend(paths)

        progress.update(str(done) + "/" + str(total))

    return samples, labels, by_label

def get_last_epoch(log_file):
    """Return the epoch number stored on the last row of a CSV training log.

    Args:
        log_file: Path to a comma-delimited log file whose first column
            holds the epoch number.

    Returns:
        The first field of the last non-blank row converted to ``int``, or
        ``0`` when the file does not exist or contains no non-blank row.

    Raises:
        ValueError: If the first field of the last non-blank row is not an
            integer (for example when the log holds only a header line).
    """
    if not os.path.exists(log_file):
        return 0

    # The with-block guarantees the handle is closed; newline='' is the
    # mode the csv module expects for files it reads.
    with open(log_file, newline='') as csvfile:
        last_row = None
        # Stream the rows and remember the last non-blank one; csv.reader
        # yields an empty list for a blank line, which has no epoch field.
        for row in csv.reader(csvfile, delimiter=','):
            if row:
                last_row = row

    if last_row is None:
        return 0
    return int(last_row[0])
    
def save_model(model, root):
    """Write a model's architecture and weights into the ``root`` directory.

    The string returned by ``model.to_json()`` is written to
    ``<root>/model.json`` and ``model.save_weights`` is asked to write
    ``<root>/model.h5``. Both destinations are printed afterwards.

    Args:
        model: Object exposing ``to_json()`` and ``save_weights(path)``
            (presumably a Keras model -- confirm against callers).
        root: Directory that receives both files.
    """
    weights_target = os.path.join(root, "model.h5")
    json_target = os.path.join(root, "model.json")

    # check_dirs is defined elsewhere in the project; presumably it creates
    # the missing parent directories of the given path -- TODO confirm.
    check_dirs(json_target)

    # Serialize before opening the file so a failing to_json() does not
    # leave a truncated model.json behind.
    architecture = model.to_json()
    with open(json_target, "w") as out:
        out.write(architecture)

    model.save_weights(weights_target)

    print("Saved model to", json_target)
    print("Saved weights to", weights_target)

def get_taxa_list(list_path):
    """Load every field of a space-delimited list file into one flat list.

    The file is parsed with ``csv.reader`` using a single space as the
    delimiter and ``|`` as the quote character, so an entry containing
    spaces can be wrapped in ``|...|``.

    Args:
        list_path: Path to the list file.

    Returns:
        All fields of all rows, concatenated in file order.
    """
    with open(list_path, newline='') as handle:
        rows = csv.reader(handle, delimiter=' ', quotechar='|')
        # Flatten row by row; blank lines produce empty rows and add nothing.
        return [entry for row in rows for entry in row]