# Example code for splitting the dataset.
This file is used to split the dataset into training, validation, and testing sets.

In [1]:
import os
import numpy as np
from sklearn.model_selection import train_test_split
from tqdm import tqdm
import cv2
from collections import defaultdict
import csv

In [2]:
def load_images(data_dir):
    """Load every image under ``data_dir``, labelled by its sub-directory name.

    ``data_dir`` is expected to contain one sub-directory per class; each
    ``.jpg``/``.jpeg``/``.png`` file inside a class directory becomes one
    sample. Files that OpenCV cannot decode are skipped (with a notice) so
    that no ``None`` entry ends up in the returned list.

    Args:
        data_dir: Path of the directory holding the class sub-directories.

    Returns:
        A tuple ``(images, labels)`` of two equally long lists: the arrays
        returned by ``cv2.imread`` (left in OpenCV's BGR channel order, no
        colour conversion is applied) and the matching class names.
    """
    images = []
    labels = []

    classes = sorted(
        d for d in os.listdir(data_dir)
        if os.path.isdir(os.path.join(data_dir, d))
    )
    for class_name in classes:
        class_dir = os.path.join(data_dir, class_name)
        # os.listdir() order is arbitrary; sorting keeps the sample order (and
        # therefore any seeded split made from it) reproducible.
        for file in tqdm(sorted(os.listdir(class_dir)), desc=f"Loading {class_name}"):
            if not file.lower().endswith(('.jpg', '.jpeg', '.png')):
                continue
            path = os.path.join(class_dir, file)
            img = cv2.imread(path)
            if img is None:
                # cv2.imread reports an unreadable/corrupt file by returning
                # None instead of raising; keeping it would crash later on.
                tqdm.write(f"Skipping unreadable image: {path}")
                continue
            images.append(img)
            labels.append(class_name)
    return images, labels

In [3]:
def split_sets(images, labels, train_ratio=0.6, val_ratio=0.2, test_ratio=0.2, random_state=42):
    """Split samples into train, validation and test subsets.

    The split is done in two steps with ``train_test_split``: first the
    training part is separated, then the remainder is divided between
    validation and test. When ``labels`` is given both steps are stratified
    on the labels; when ``labels`` is ``None`` the samples are split without
    stratification and every label slot in the result is ``None``.

    Args:
        images: Sequence of samples.
        labels: Sequence of labels aligned with ``images``, or ``None``.
        train_ratio: Fraction of samples for the training set.
        val_ratio: Fraction of samples for the validation set.
        test_ratio: Fraction of samples for the test set.
        random_state: Seed passed to both ``train_test_split`` calls.

    Returns:
        ``(X_train, y_train), (X_val, y_val), (X_test, y_test)``.

    Raises:
        ValueError: If the three ratios do not sum to 1.
    """
    # Explicit raise instead of assert: asserts are stripped under ``python -O``.
    if not np.isclose(train_ratio + val_ratio + test_ratio, 1.0):
        raise ValueError(
            "train_ratio, val_ratio and test_ratio must sum to 1.0, got "
            f"{train_ratio} + {val_ratio} + {test_ratio}"
        )

    if labels is None:
        # Unlabelled data: nothing to stratify on, split the samples alone.
        X_train, X_temp = train_test_split(
            images,
            train_size=train_ratio,
            random_state=random_state,
        )
        val_test_ratio = val_ratio / (val_ratio + test_ratio)
        X_val, X_test = train_test_split(
            X_temp,
            test_size=1 - val_test_ratio,
            random_state=random_state,
        )
        return (X_train, None), (X_val, None), (X_test, None)

    X_train, X_temp, y_train, y_temp = train_test_split(
        images, labels,
        train_size=train_ratio,
        stratify=labels,
        random_state=random_state,
    )

    # Share of the remaining (non-training) samples that goes to validation;
    # the rest of that remainder becomes the test set.
    val_test_ratio = val_ratio / (val_ratio + test_ratio)
    X_val, X_test, y_val, y_test = train_test_split(
        X_temp, y_temp,
        test_size=1 - val_test_ratio,
        stratify=y_temp,
        random_state=random_state,
    )
    return (X_train, y_train), (X_val, y_val), (X_test, y_test)

In [4]:
def _write_image(path, image):
    """Write ``image`` to ``path``; raise ``OSError`` if OpenCV reports failure."""
    # cv2.imwrite signals failure through its boolean return value.
    if not cv2.imwrite(path, image):
        raise OSError(f"Failed to write image: {path}")


def _save_labelled_set(set_dir, set_name, X, y):
    """Write each image to ``set_dir/<label>/<n>.jpg``, numbering per label."""
    for class_name in set(y):
        os.makedirs(os.path.join(set_dir, str(class_name)), exist_ok=True)

    counter = defaultdict(int)
    for image, class_name in tqdm(zip(X, y), desc=f"Saving {set_name} set"):
        _write_image(
            os.path.join(set_dir, str(class_name), str(counter[class_name]) + '.jpg'),
            image,
        )
        counter[class_name] += 1


def _save_test_set(set_dir, X, y, random_state, csv_path):
    """Write shuffled images to ``set_dir/<n>.jpg`` and their labels to a CSV."""
    os.makedirs(set_dir, exist_ok=True)

    indices = np.arange(len(X))
    if random_state is None:
        # Global NumPy RNG, so a caller's np.random.seed() is still honoured.
        np.random.shuffle(indices)
    else:
        np.random.RandomState(random_state).shuffle(indices)

    csv_data = []
    for counter, source_index in enumerate(tqdm(indices.tolist(), desc="Saving test set")):
        file_name = str(counter) + '.jpg'
        _write_image(os.path.join(set_dir, file_name), X[source_index])
        csv_data.append([file_name, y[source_index]])

    csv_headers = ['file_name', 'label']
    with open(csv_path, "w", newline="", encoding="utf-8") as f:
        writer = csv.writer(f)
        writer.writerow(csv_headers)
        writer.writerows(csv_data)


def save_set(set_dir, X, y, *, random_state=None, csv_path="output.csv"):
    """Write one subset of images to disk as JPEG files.

    The layout depends on the last component of ``set_dir``:

    * any name other than ``test``: images go to ``set_dir/<label>/<n>.jpg``,
      with ``n`` counted separately for each label;
    * ``test``: images are shuffled and written flat as ``set_dir/<n>.jpg``
      so the file names carry no label information, and the
      ``file_name``/``label`` pairs are written to ``csv_path``.

    Args:
        set_dir: Output directory; its base name selects the layout.
        X: Sequence of images accepted by ``cv2.imwrite``.
        y: Sequence of labels aligned with ``X``.
        random_state: Optional seed for the test-set shuffle. ``None``
            (default) uses NumPy's global random state.
        csv_path: Where the test-set label CSV is written. The default,
            ``"output.csv"``, is relative to the current working directory,
            not to ``set_dir``.

    Raises:
        OSError: If an image cannot be written.
    """
    # normpath drops a trailing separator, so ".../test/" is still recognised
    # as the test set (basename alone would yield '').
    set_name = os.path.basename(os.path.normpath(set_dir))
    if set_name != 'test':
        _save_labelled_set(set_dir, set_name, X, y)
    else:
        _save_test_set(set_dir, X, y, random_state, csv_path)

def save_sets(set_dir, X_train, y_train, X_val, y_val, X_test, y_test):
    """Save the train, validation and test subsets below ``set_dir``.

    Each subset is handed to ``save_set`` with its own sub-directory
    (``train``, ``val`` and ``test``, in that order).
    """
    subsets = (
        ('train', X_train, y_train),
        ('val', X_val, y_val),
        ('test', X_test, y_test),
    )
    for subdir, samples, targets in subsets:
        save_set(os.path.join(set_dir, subdir), samples, targets)

In [5]:
# Read all images from ../dataset together with their class labels.
images, labels = load_images("../dataset")
# Split using split_sets' default ratios and seed.
(X_train, y_train), (X_val, y_val), (X_test, y_test) = split_sets(images, labels)
# Write the three subsets to ../dataset_split/{train,val,test}.
save_sets("../dataset_split", X_train, y_train, X_val, y_val, X_test, y_test)

Loading african_elephant: 100%|██████████| 1300/1300 [00:03<00:00, 376.19it/s]
Loading airliner: 100%|██████████| 1300/1300 [00:02<00:00, 496.29it/s]
Loading banana: 100%|██████████| 1300/1300 [00:03<00:00, 380.84it/s]
Loading convertible_car: 100%|██████████| 1300/1300 [00:03<00:00, 331.75it/s]
Loading golden_retriever: 100%|██████████| 1300/1300 [00:03<00:00, 407.21it/s]
Loading goldfish: 100%|██████████| 1300/1300 [00:03<00:00, 379.47it/s]
Loading parachute: 100%|██████████| 1300/1300 [00:03<00:00, 390.72it/s]
Loading rugby: 100%|██████████| 1300/1300 [00:03<00:00, 355.34it/s]
Loading sunglasses: 100%|██████████| 1300/1300 [00:03<00:00, 403.51it/s]
Loading tiger_cat: 100%|██████████| 1300/1300 [00:04<00:00, 309.10it/s]
Saving train set: 7800it [00:09, 794.27it/s]
Saving val set: 2600it [00:03, 683.73it/s]
Saving test set: 2600it [00:03, 746.02it/s]
