In [1]:
from PIL import Image, ImageOps
import os
import numpy as np
import random

In [2]:
def get_filenames_from_folder(folders, sections):
    """
        Collect the path components of every regular file in each section folder.

        args:
            folders: A list of folder names, paired positionally with `sections`.
            sections: A list of snake section names used as the result keys.
        return:
            files_dict: A dictionary mapping each section to a list of dicts with
                'file_path' (the folder plus a trailing '/'),
                'file_name' (the file name without its final extension) and
                'file_type' (the final extension including the leading '.',
                or '' when the file has no extension).
    """
    files_dict = {}
    for section, folder in zip(sections, folders):
        files_list = []
        # os.listdir returns entries in arbitrary order; sorting keeps the
        # listing stable between runs and machines.
        for filename in sorted(os.listdir(folder)):
            # Sub-directories are not image files; leave them out.
            if not os.path.isfile(os.path.join(folder, filename)):
                continue
            # splitext splits on the LAST dot only, so 'a.1.png' keeps its real
            # extension and a name without any dot no longer raises IndexError.
            stem, extension = os.path.splitext(filename)
            files_list.append({
                'file_path': folder + '/',
                'file_name': stem,
                'file_type': extension
            })
        files_dict[section] = files_list
    return files_dict

In [3]:
def get_number_from_string(string):
    """
        Return the first run of consecutive ASCII digits in `string` as an int.

        args:
            string: The text to scan, e.g. a file name such as 'head_012_a'.
        return:
            The integer value of the first digit run ('head_012_a7' -> 12).
        raise:
            ValueError: if `string` contains no digit at all.
    """
    ascii_digits = {'0', '1', '2', '3', '4', '5', '6', '7', '8', '9'}
    digits = ''
    for char in string:
        if char in ascii_digits:
            digits += char
        elif digits:
            # First non-digit after the run has started: the number is complete.
            break
    if not digits:
        # Without this guard int('') fails with a message that does not say
        # which string was at fault.
        raise ValueError('no digits found in %r' % (string,))
    return int(digits)

In [4]:
def split_test_data(min_section, section_1, section_2, number_split):
    """
        Move matching (same number in the file name) entries of three sections
        into a test split.

        `min_section` is shuffled in place, then walked in that order. For each
        entry the first not-yet-selected entry of `section_1` and of `section_2`
        whose file name carries the same number (see `get_number_from_string`)
        is taken; the triple is selected only when both partners exist.
        Selection stops once `number_split` triples have been collected.
        The selected entries are removed from the three input lists.

        args:
            min_section: List of file-info dicts driving the selection
                (shuffled and reduced in place).
            section_1: List of file-info dicts of the second section
                (reduced in place).
            section_2: List of file-info dicts of the third section
                (reduced in place).
            number_split: Number of triples wanted; may be fractional, in which
                case selection continues until the count reaches or exceeds it.
        return:
            A tuple (min_section_test, section_1_test, section_2_test) of lists
            whose i-th elements belong together.
    """
    random.shuffle(min_section)

    def _group_by_number(entries):
        # Parse every file name once instead of once per comparison; entries
        # keep their original list order inside each group.
        groups = {}
        for entry in entries:
            number = get_number_from_string(entry['file_name'])
            groups.setdefault(number, []).append(entry)
        return groups

    def _first_unselected(candidates, selected):
        # First candidate (in list order) that has not been selected yet.
        for candidate in candidates:
            if candidate not in selected:
                return candidate
        return None

    section_1_by_number = _group_by_number(section_1)
    section_2_by_number = _group_by_number(section_2)

    min_section_test = []
    section_1_test = []
    section_2_test = []
    for min_sec in min_section:
        if len(min_section_test) >= number_split:
            break
        if min_sec in min_section_test:
            # An equal entry was already selected; never select it twice.
            continue
        number = get_number_from_string(min_sec['file_name'])
        sec_1 = _first_unselected(section_1_by_number.get(number, []), section_1_test)
        sec_2 = _first_unselected(section_2_by_number.get(number, []), section_2_test)
        if sec_1 is None or sec_2 is None:
            # No complete head/mid/tail triple exists for this number.
            continue
        min_section_test.append(min_sec)
        section_1_test.append(sec_1)
        section_2_test.append(sec_2)

    # Remove the selected entries from the caller's lists so the remaining
    # data no longer contains the test split.
    for min_sec in min_section_test:
        min_section.remove(min_sec)
    for sec_1 in section_1_test:
        section_1.remove(sec_1)
    for sec_2 in section_2_test:
        section_2.remove(sec_2)

    return min_section_test, section_1_test, section_2_test

In [5]:
def split_validation_data(section, number_split):
    """
        Shuffle `section` in place and move its leading entries into a
        validation split.

        args:
            section: List of entries; shuffled in place and shortened by the
                entries that are returned.
            number_split: Number of entries wanted; may be fractional, in which
                case entries are taken until the count reaches or exceeds it.
        return:
            A new list holding the entries taken from the front of the
            shuffled `section`.
    """
    random.shuffle(section)

    # Count how many leading entries to take: stop as soon as the count
    # reaches number_split or the section is exhausted.
    take = 0
    while take < len(section) and take < number_split:
        take += 1

    picked = section[:take]
    # Drop the picked entries from the caller's list.
    del section[:take]
    return picked

In [6]:
def load_image_to_folders(files_1, files_2, files_3, folder_1, folder_2, folder_3):
    """
        Re-save every listed image as '<file_name>.jpg' in its target folder.

        The i-th file list is written to the i-th folder. Target folders are
        created when missing.

        args:
            files_1, files_2, files_3: Lists of file-info dicts with the keys
                'file_path', 'file_name' and 'file_type'; their concatenation
                is the source path.
            folder_1, folder_2, folder_3: Destination folder prefixes. They are
                concatenated directly with the file name, so they should end
                with a path separator.
    """
    # Modes Pillow's JPEG writer accepts; anything else (e.g. 'RGBA', 'P')
    # makes save() raise "cannot write mode ... as JPEG".
    jpeg_modes = ('1', 'L', 'RGB', 'RGBX', 'CMYK', 'YCbCr')
    for files, folder in zip([files_1, files_2, files_3], [folder_1, folder_2, folder_3]):
        for file in files:
            source = file['file_path'] + file['file_name'] + file['file_type']
            destination = folder + file['file_name'] + '.jpg'
            destination_dir = os.path.dirname(destination)
            if destination_dir:
                os.makedirs(destination_dir, exist_ok=True)
            # Image.open never returns None (it raises on failure) and keeps
            # the file open until closed, so use it as a context manager.
            with Image.open(source) as img:
                writable = img if img.mode in jpeg_modes else img.convert('RGB')
                writable.save(destination)

In [7]:
# --- Split configuration ---------------------------------------------------
NUM_CLASSES = 14
# Classes whose 'Head' list (instead of 'Tail') drives the test-set matching.
# NOTE(review): presumably Head is the smallest section for class 12 - confirm.
HEAD_DRIVEN_CLASSES = {12}
# Percentage moved to the test split, and then (from what is left) to the
# validation split.
SPLIT_PERCENT = 20
# Fixed seed so that re-running the cell reproduces the same split; without it
# a re-run into the same output folders can leave one image in both train and
# test.
SEED = 42
SOURCE_ROOT = './images/original/class'
TARGET_ROOT = './images/Head_Mid_Tail/class'

sections = ['Head',
          'Mid',
          'Tail']

random.seed(SEED)

for i in range(NUM_CLASSES):
    folders = [SOURCE_ROOT + str(i) + '/' + section for section in sections]
    files_dict = get_filenames_from_folder(folders, sections)

    # Directory listing order is OS-dependent; sort so the seeded shuffles
    # start from the same order on every run.
    def _entry_key(entry):
        return (entry['file_name'], entry['file_type'])

    head = sorted(files_dict['Head'], key=_entry_key)
    mid = sorted(files_dict['Mid'], key=_entry_key)
    tail = sorted(files_dict['Tail'], key=_entry_key)
    files_by_name = {'head': head, 'mid': mid, 'tail': tail}

    # The first section drives the test matching; 'mid' always comes last.
    if i in HEAD_DRIVEN_CLASSES:
        order = ['head', 'tail', 'mid']
    else:
        order = ['tail', 'head', 'mid']
    ordered_files = [files_by_name[name] for name in order]

    # Each split call removes the selected entries from the lists in place,
    # so the validation share is taken from what remains after the test split
    # and the lists that are left form the training split.
    number_split_test = len(ordered_files[0]) * SPLIT_PERCENT / 100
    test_split = split_test_data(ordered_files[0], ordered_files[1], ordered_files[2],
                                 number_split_test)
    validation_split = [split_validation_data(files, len(files) * SPLIT_PERCENT / 100)
                        for files in ordered_files]
    train_split = ordered_files

    for subset, split in (('test', test_split),
                          ('validation', validation_split),
                          ('train', train_split)):
        destinations = [TARGET_ROOT + str(i) + '/' + subset + '/' + name + '/'
                        for name in order]
        arguments = list(split) + destinations
        load_image_to_folders(*arguments)