<a href="https://colab.research.google.com/github/ykitaguchi77/ImageProcessing/blob/master/split_dataset_for_crossvalidation.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# **Split datasets for cross validation**
Trainデータセットを5分割し、そのうち1つをvalセットに、残りの合計をtrainセットに割り当てるスクリプト（5-fold cross validation用。分割ごとに 0〜4 のフォルダを作成する）


In [1]:
import random
import glob
import os
import shutil
import numpy as np
import time

from google.colab import drive
drive.mount('/content/drive')

# Output layout produced by this notebook (k = split_num folds).
# Class folder names are whatever sub-folders exist in orig_path (e.g. grav / cont).
#
# orig_path/
#   ├── grav/
#   └── cont/
#         ↓
# dst_path/
#   ├── 0/
#   │   ├── train/{grav,cont}   # every fold except fold 0
#   │   └── val/{grav,cont}     # fold 0
#   ├── 1/
#   │   ├── train/{grav,cont}   # every fold except fold 1
#   │   └── val/{grav,cont}     # fold 1
#   ...
#   └── {k-1}/
#       ├── train/{grav,cont}
#       └── val/{grav,cont}

Mounted at /content/drive


'\n-----orig_data-----grav\n                |--cont\n↓\n↓\n\n-----dst_data[0]------dst_train[0]----grav\n  |                |               |-- cont\n  |                |--dst_val[0]------grav\n  |                                |--cont\n  |\n  |--dst_data[1]------dst_train[1]----grav\n  |                |               |-- cont\n  |                |--dst_val[1]------grav\n  |                                |--cont\n  ...\n  |--dst_data[1]------dst_train[9]----grav\n                   |               |-- cont\n                   |--dst_val[9]------grav\n                                   |--cont\n'

# **Module群**

In [2]:
def get_path(orig_path, dst_path, split_num):
    """Collect the file paths of every class folder under ``orig_path``.

    Each immediate sub-directory of ``orig_path`` is treated as one class.
    Class names and file paths are sorted, so the folds built from them are
    reproducible across runs (``os.listdir`` / ``glob`` order is arbitrary).

    Args:
        orig_path: Directory whose sub-directories are the classes.
        dst_path: Unused; kept so existing callers keep working.
        split_num: Number of cross-validation folds.

    Returns:
        ``(data_list, classes, split_length)``: ``data_list[i]`` is the sorted
        list of file paths inside ``classes[i]``; ``split_length`` is
        ``len(classes) // split_num`` (legacy value, unused by this notebook).
    """
    # Only directories are classes; stray files (e.g. .DS_Store) are skipped.
    classes = sorted(
        name for name in os.listdir(orig_path)
        if os.path.isdir(os.path.join(orig_path, name))
    )
    # glob.escape keeps characters such as '[' in the folder path literal.
    data_list = [
        sorted(glob.glob(os.path.join(glob.escape(os.path.join(orig_path, name)), '*')))
        for name in classes
    ]
    # NOTE(review): this counts classes, not files, so it is not a per-fold size.
    split_length = len(data_list) // split_num
    return data_list, classes, split_length

def makefolder(orig_path, dst_path, classes, split_num=5):
    """Create the ``dst_path/{fold}/{train,val}/{class}`` directory tree.

    Folders are created with ``exist_ok=True``, so re-running is safe and any
    missing sub-folder is filled in even when ``dst_path`` already exists.

    Args:
        orig_path: Unused; kept so existing callers keep working.
        dst_path: Root output directory.
        classes: Class names; each gets a folder in every train/val split.
        split_num: Number of folds. It used to be read from a global variable,
            which raised NameError when this was called from ``main()`` on a
            fresh kernel. The default 5 matches the notebook's configuration.
    """
    os.makedirs(dst_path, exist_ok=True)
    for fold in range(split_num):
        for subset in ('train', 'val'):
            subset_dir = os.path.join(dst_path, str(fold), subset)
            # Created even when `classes` is empty, as the original did.
            os.makedirs(subset_dir, exist_ok=True)
            for class_name in classes:
                os.makedirs(os.path.join(subset_dir, class_name), exist_ok=True)

def split_data_list(data_list, split_num, seed=None):
    """Split one class's file list into ``split_num`` cross-validation folds.

    Fold ``i`` becomes the validation set of split ``i``; all remaining files
    become that split's training set. Fold sizes follow ``np.array_split``:
    the first ``len(data_list) % split_num`` folds get one extra item.

    Args:
        data_list: File paths of a single class.
        split_num: Number of folds (must be >= 1, else ``ValueError``).
        seed: If given, a copy of ``data_list`` is shuffled with this seed
            before splitting, so folds are not tied to file-name order.
            ``None`` (default) keeps the original contiguous split.

    Returns:
        ``(dst_train, dst_val)``: two lists of length ``split_num`` whose
        elements are lists of paths.
    """
    items = list(data_list)  # copy: never mutate the caller's list
    if seed is not None:
        random.Random(seed).shuffle(items)

    # Split index positions rather than the strings. Because each fold is a
    # contiguous slice, its complement is two slices: O(n) per fold instead of
    # the old O(n^2) `x not in fold` scan.
    fold_sizes = [len(idx) for idx in np.array_split(np.arange(len(items)), split_num)]

    dst_train, dst_val = [], []
    start = 0
    for size in fold_sizes:
        stop = start + size
        dst_val.append(items[start:stop])
        dst_train.append(items[:start] + items[stop:])
        start = stop
    return dst_train, dst_val

def copy_to_folders(split_num, class_name, dst_train, dst_val,
                    dst_path='/content/drive/My Drive/gravcont_crossvalidation'):
    """Copy one class's train/val files of every fold into the output tree.

    Files go to ``dst_path/{fold}/train/{class_name}`` and
    ``dst_path/{fold}/val/{class_name}``.

    Args:
        split_num: Number of folds; ``dst_train``/``dst_val`` must have at
            least this many entries.
        class_name: Class sub-folder name.
        dst_train: ``dst_train[i]`` is an iterable of source paths for fold i's train set.
        dst_val: ``dst_val[i]`` is an iterable of source paths for fold i's val set.
        dst_path: Output root (defaults to the previously hard-coded location).
    """
    for fold in range(split_num):
        for subset, files in (('train', dst_train[fold]), ('val', dst_val[fold])):
            target_dir = os.path.join(dst_path, str(fold), subset, class_name)
            # If the folder were missing, shutil.copy would treat target_dir as
            # a *file* name and silently overwrite it with every image.
            os.makedirs(target_dir, exist_ok=True)
            for src in files:
                shutil.copy(src, target_dir)
            # One summary line per folder instead of one per copied file.
            print(f'{target_dir}: {len(files)} files')
def main():
    """Build the cross-validation folders for the Grav_bootcamp/PrePlusTrain dataset.

    Collects per-class file lists, creates the output tree, then splits and
    copies each class. copy_to_folders is not given dst_path here, so it
    writes to its own default output root.
    """
    orig_path = "/content/drive/My Drive/Grav_bootcamp/PrePlusTrain"
    dst_path = "/content/drive/My Drive/gravcont_crossvalidation"  # output folder name
    split_num = 5  # number of folds to split the data into

    per_class_files, class_names, _unused_length = get_path(orig_path, dst_path, split_num)
    makefolder(orig_path, dst_path, class_names)

    for files, name in zip(per_class_files, class_names):
        train_folds, val_folds = split_data_list(files, split_num)
        copy_to_folders(split_num, name, train_folds, val_folds)

In [None]:
if __name__ == "__main__":
    # In a notebook kernel __name__ is "__main__", so this cell always runs main().
    t_begin = time.time()
    main()
    print(f"elapsed_time:{time.time() - t_begin}[sec]")

In [None]:
# Same pipeline as main(), but for the PrePlusTrain/558 source folder.
# NOTE(review): this writes into the same dst_path that main() uses, so running
# both cells mixes the two source datasets in one output tree.
orig_path = "/content/drive/My Drive/PrePlusTrain/558"
dst_path = "/content/drive/My Drive/gravcont_crossvalidation"  # output folder name
split_num = 5  # number of folds; kept global because makefolder may read it

data_list, classes, split_length = get_path(orig_path, dst_path, split_num)
makefolder(orig_path, dst_path, classes)

print(classes)
for class_name, class_files in zip(classes, data_list):
    dst_train, dst_val = split_data_list(class_files, split_num)
    print(class_name)
    copy_to_folders(split_num, class_name, dst_train, dst_val)

# **作ったフォルダを削除したいとき**

In [None]:
# Removing the output tree is destructive, so it only runs when explicitly
# enabled; otherwise "Run All" would wipe the folds created by the cells above.
CONFIRM_DELETE = False  # set to True to delete dst_path

dst_path = "/content/drive/My Drive/gravcont_crossvalidation"
if CONFIRM_DELETE:
    try:
        shutil.rmtree(dst_path)
    except FileNotFoundError:
        pass  # never created or already deleted: nothing to do