In [None]:
import os
import glob
import random
import pandas as pd
import sys
sys.path.append(r"../../data/data_checking_code/")

import folder_tool

def generate_kfold_splits(dataset_ids, num_folds, seed=None):
    """Shuffle ``dataset_ids`` and partition them into ``num_folds`` folds.

    The list is shuffled **in place** (the caller sees the shuffled order) and
    then cut into contiguous slices of ``len(dataset_ids) // num_folds`` items;
    the last slice additionally absorbs any remainder.

    Every fold is paired with a validation fold: the fold that precedes it in
    the rotation, with fold 0 wrapping around to the last fold.

    Args:
        dataset_ids: Mutable list of sample identifiers; shuffled in place.
        num_folds: Number of folds to create; must be at least 1.
        seed: Optional seed. When given, a private ``random.Random(seed)`` does
            the shuffling so the split is reproducible; when ``None`` the
            module-level ``random`` state is used.

    Returns:
        A tuple ``(kfold_splits, kfold_val_splits)`` of two lists, each of
        length ``num_folds``. ``kfold_val_splits[i]`` is the same list object
        as ``kfold_splits[i - 1]``.

    Raises:
        ValueError: If ``num_folds`` is smaller than 1.
    """
    if num_folds < 1:
        raise ValueError(f"num_folds must be >= 1, got {num_folds}")

    if seed is None:
        random.shuffle(dataset_ids)
    else:
        random.Random(seed).shuffle(dataset_ids)

    fold_size = len(dataset_ids) // num_folds
    kfold_splits = []
    for fold in range(num_folds):
        fold_start = fold * fold_size
        # The final fold runs to the end of the list so no id is dropped when
        # the dataset size is not a multiple of num_folds.
        fold_end = fold_start + fold_size if fold < num_folds - 1 else len(dataset_ids)
        kfold_splits.append(dataset_ids[fold_start:fold_end])

    # Rotate by one: index -1 wraps to the last fold. Deriving the range from
    # num_folds (instead of a fixed 5) keeps one validation fold per fold.
    kfold_val_splits = [kfold_splits[fold - 1] for fold in range(num_folds)]
    return kfold_splits, kfold_val_splits

# Load the project data settings; only the 'y_train_dir' entry is used below.
# NOTE(review): the return shape of data_json_setting_load is taken from this
# unpacking only — confirm against folder_tool.
data_info, result_info_path = folder_tool.data_json_setting_load('../0_data_json_setting/0_run.json')
y_train_dir = data_info['y_train_dir']

num_folds = 5

# Refuse to overwrite ANY existing split file, not just fold 1: the splits are
# random, so regenerating a subset would leave mutually inconsistent folds.
for fold in range(1, num_folds + 1):
    fold_csv_path = f"./kfold{num_folds}_fold_{fold}.csv"
    assert not os.path.exists(fold_csv_path), fold_csv_path + ' exist'

dataset_ids = [os.path.basename(x) for x in glob.glob(f'../{y_train_dir}/*.PNG')]
# An empty match would silently write header-only CSVs that then block reruns.
assert dataset_ids, f'no *.PNG files found in ../{y_train_dir}'

# generate_kfold_splits shuffles dataset_ids in place, so the rows written
# below follow the shuffled order.
kfold_splits, val_split = generate_kfold_splits(dataset_ids, num_folds)

for fold, (fold_ids, val_ids) in enumerate(zip(kfold_splits, val_split), start=1):
    # Sets give O(1) membership checks instead of scanning the fold lists for
    # every image.
    test_ids = set(fold_ids)
    valid_ids = set(val_ids)

    # 'test' takes precedence over 'valid'; everything else is 'train'.
    fold_rows = [
        [image_id, 'test' if image_id in test_ids else 'valid' if image_id in valid_ids else 'train']
        for image_id in dataset_ids
    ]

    fold_df = pd.DataFrame(fold_rows, columns=['image_id', 'type'])
    fold_csv_path = f"./kfold{num_folds}_fold_{fold}.csv"
    fold_df.to_csv(fold_csv_path, index=False)


In [None]:
# Summarise how many images each fold CSV assigns to train / valid / test.
num_folds = 5
split_labels = ("train", "valid", "test")
fold_counts = {"fold": []}
for label in split_labels:
    fold_counts[label] = []

for fold in range(1, num_folds + 1):
    fold_csv_path = f"./kfold{num_folds}_fold_{fold}.csv"
    fold_df = pd.read_csv(fold_csv_path)
    split_column = fold_df["type"]

    fold_counts["fold"].append(f"fold_{fold}")
    # One tally per split label, in the same column order as the table.
    for label in split_labels:
        fold_counts[label].append((split_column == label).sum())

counts_table = pd.DataFrame(fold_counts)
print(counts_table)
