# In [None]:
import os
import random
import numpy as np
import pandas as pd
from sklearn.model_selection import train_test_split

# Build a stratified train / validation / test split of the labelled images.
#
# Reads `image_labels.csv` from `label_dir` (columns used: `filename`,
# `class_id`), derives the site and date from each filename, collects the image
# paths and class ids, and splits them with scikit-learn's train_test_split.

# Single seed shared by every source of randomness in this script.
# NOTE: random.seed() only affects Python's built-in `random` module.
# train_test_split does not use that generator, so the seed is also passed
# explicitly to each split below via `random_state`; without it the splits
# differ on every run.
SEED = 42
random.seed(SEED)

# Define the paths to your image and label directories and the number of classes
image_dir = '/path/to/your/image/directory'
label_dir = '/path/to/your/label/directory'
num_classes = 3

# Create a list of the class names
class_names = ['duck', 'goose', 'crane']

# Load the image labels into a pandas DataFrame. It gets its own name so that
# `labels` below refers only to the final NumPy array of class ids.
labels_df = pd.read_csv(os.path.join(label_dir, 'image_labels.csv'))

# Extract the site and date information from the image filenames: the first
# and second underscore-separated fields respectively.
labels_df['site'] = labels_df['filename'].apply(lambda x: x.split('_')[0])
labels_df['date'] = labels_df['filename'].apply(lambda x: x.split('_')[1])

# Collect the image paths and labels site by site, then date by date.
# sort=False keeps groups in order of first appearance and rows keep their
# original order inside each group, so the resulting sequence is deterministic
# for a given CSV.
image_paths = []
labels_list = []
for _, site_rows in labels_df.groupby('site', sort=False):
    for _, date_rows in site_rows.groupby('date', sort=False):
        image_paths.extend(
            os.path.join(image_dir, filename)
            for filename in date_rows['filename']
        )
        labels_list.extend(date_rows['class_id'])

# Convert the image paths and labels to numpy arrays for stratified sampling
image_paths = np.array(image_paths)
labels = np.array(labels_list)

# Use stratified sampling to split the dataset into training, validation, and
# test sets. First hold out 15% of all images as the test set.
trainval_images, test_images, trainval_labels, test_labels = train_test_split(
    image_paths, labels, test_size=0.15, stratify=labels, random_state=SEED)

# Then take 18% of the remaining 85% (about 15% of the total) for validation.
train_images, val_images, train_labels, val_labels = train_test_split(
    trainval_images, trainval_labels, test_size=0.18,
    stratify=trainval_labels, random_state=SEED)

# Print the number of images in each set and the class distribution.
# minlength pads the counts to at least `num_classes` entries so the three
# distributions line up even if a split contains no sample of the highest
# class id (assumes class ids are 0 .. num_classes - 1 -- TODO confirm).
print(f'Total number of images: {len(image_paths)}')
print(f'Train set: {len(train_images)} images')
print(np.bincount(train_labels, minlength=num_classes))
print(f'Validation set: {len(val_images)} images')
print(np.bincount(val_labels, minlength=num_classes))
print(f'Test set: {len(test_images)} images')
print(np.bincount(test_labels, minlength=num_classes))