In [None]:
import os
import random
import shutil
from pathlib import Path
from sklearn.model_selection import train_test_split
from PIL import Image
import torchvision.transforms.functional as TF

# Dataset locations.
# data_path holds one sub-folder per class; the train/val copies are written
# under output_root.
data_path = Path("data/data_ver2/")  # We can change the path later
output_root = Path("data/data_ver2/split_data")
train_dir, val_dir = output_root / "train", output_root / "val"

# Make sure both split folders exist (no error if they are already there).
for split_dir in (train_dir, val_dir):
    os.makedirs(split_dir, exist_ok=True)

# Aggressive augmentation function
def augment_image(img):
    """Return a randomly perturbed copy of ``img``.

    Steps, in order:
      1. Horizontal flip, applied when a ``random.random()`` draw exceeds 0.5.
      2. Vertical flip, applied when a second independent draw exceeds 0.5.
      3. Rotation by an angle drawn uniformly from [-25, 25].
      4. Brightness adjustment by a factor drawn uniformly from [0.7, 1.3].
      5. Contrast adjustment by a factor drawn uniformly from [0.7, 1.3].

    All draws come from the global ``random`` module state, so seed it
    beforehand if reproducible output is needed.

    Args:
        img: Image accepted by ``torchvision.transforms.functional``
            (presumably a PIL image here; confirm against the caller).

    Returns:
        The transformed image.
    """
    # Flips: one coin toss each, horizontal first, then vertical.
    for flip in (TF.hflip, TF.vflip):
        if random.random() > 0.5:
            img = flip(img)

    # Continuous jitters: (operation, lower bound, upper bound) of the draw.
    jitter_steps = (
        (TF.rotate, -25, 25),
        (TF.adjust_brightness, 0.7, 1.3),
        (TF.adjust_contrast, 0.7, 1.3),
    )
    for apply_step, low, high in jitter_steps:
        img = apply_step(img, random.uniform(low, high))
    return img

# Collect image paths and labels.
# Every immediate sub-folder of data_path is treated as one class and its
# folder name is used as the label.
IMAGE_SUFFIXES = {".jpg", ".jpeg", ".png"}
all_image_paths = []
all_labels = []

for class_name in sorted(os.listdir(data_path)):
    class_folder = data_path / class_name
    # Skip stray files, and skip the split output folder: it lives inside
    # data_path and must never be picked up as if it were a class.
    if not class_folder.is_dir() or class_folder == output_root:
        continue
    # glob() yields entries in arbitrary, filesystem-dependent order. Sorting
    # fixes the input order so that random_state below really does produce
    # the same train/val split on every run and every machine.
    for img_path in sorted(class_folder.glob("*")):
        if img_path.is_file() and img_path.suffix.lower() in IMAGE_SUFFIXES:
            all_image_paths.append(img_path)
            all_labels.append(class_name)

# Stratified 80/20 split into train/val: class proportions are preserved in
# both parts, and the fixed random_state keeps the split reproducible.
train_paths, val_paths, train_labels, val_labels = train_test_split(
    all_image_paths,
    all_labels,
    test_size=0.2,
    stratify=all_labels,
    random_state=42,
)

# Copy files to split folders
def copy_to_split(paths, labels, split_dir):
    """Copy every image into ``split_dir/<label>/`` under its original file name.

    Args:
        paths: Source image paths (``str`` or ``Path``), one per label.
        labels: Class label of each image; used as the sub-folder name.
        split_dir: Root folder of the split (``str`` or ``Path``).

    Raises:
        ValueError: If ``paths`` and ``labels`` do not have the same length.
    """
    paths = list(paths)
    labels = list(labels)
    # zip() would silently drop the unmatched tail, so fail loudly instead.
    if len(paths) != len(labels):
        raise ValueError(
            f"paths and labels must have the same length "
            f"(got {len(paths)} and {len(labels)})"
        )

    split_dir = Path(split_dir)

    # Create each class folder once instead of once per image.
    for label in set(labels):
        (split_dir / label).mkdir(parents=True, exist_ok=True)

    for img_path, label in zip(paths, labels):
        img_path = Path(img_path)  # accept plain strings as well as Path objects
        shutil.copy(img_path, split_dir / label / img_path.name)

# Write the split to disk.
copy_to_split(train_paths, train_labels, train_dir)
copy_to_split(val_paths, val_labels, val_dir)

# Oversample minority classes in train/ using augmentation (especially folder 3 because it has only 2 tick images)
from collections import Counter

# Seed the augmentation so that re-running the cell regenerates the same files.
random.seed(42)

# Count current class distribution
train_counts = Counter(train_labels)
max_count = max(train_counts.values())

print("Initial train class counts:", train_counts)

# Map every class to the original files that were just copied into train/.
# Only these are valid augmentation bases.
originals_by_class = {}
for src_path, label in zip(train_paths, train_labels):
    originals_by_class.setdefault(label, []).append(train_dir / label / src_path.name)

for class_name in train_counts:
    class_folder = train_dir / class_name
    images = originals_by_class[class_name]

    # Delete augmented files left over from earlier runs. Without this, each
    # re-run used old aug_* files as bases and added max_count - count new
    # files on top of them, so classes kept growing and never ended up balanced.
    originals = set(images)
    for leftover in class_folder.glob("aug_*"):
        if leftover.is_file() and leftover not in originals:
            leftover.unlink()

    # Top the class up to the size of the largest class.
    num_to_generate = max_count - train_counts[class_name]

    for i in range(num_to_generate):
        base_img_path = random.choice(images)
        # convert() returns a fully loaded copy, so the file handle can be
        # released as soon as the with-block ends.
        with Image.open(base_img_path) as src_img:
            base_img = src_img.convert("RGB")
        aug_img = augment_image(base_img)
        aug_img.save(class_folder / f"aug_{i}_{base_img_path.name}")

# Summary
print("\n Data prepared.")
print(f"Train data saved to: {train_dir}")
print(f"Validation data saved to: {val_dir}")

# Check new distribution: files per class folder, listed in sorted class order.
new_counts = {
    class_dir.name: sum(1 for _ in class_dir.glob("*"))
    for class_dir in sorted(train_dir.iterdir())
    if class_dir.is_dir()
}
print("Balanced train class counts:", new_counts)

Initial train class counts: Counter({'2': 77, '1': 41, '4': 25, '3': 2})

 Data prepared.
Train data saved to: data/data_ver2/split_data/train
Validation data saved to: data/data_ver2/split_data/val
Balanced train class counts: {'3': 226, '4': 178, '1': 149, '2': 77}


In [None]:
# Create dataloaders
from torchvision import datasets, transforms
from torch.utils.data import DataLoader

# Build one ImageFolder dataset per split; the class sub-folder names under
# each root become the labels.
# NOTE(review): train_transform / val_transform are not defined in the cells
# visible here; presumably they come from another cell. Confirm that cell
# runs before this one on a fresh kernel.
train_dataset = datasets.ImageFolder(
    root="data/data_ver2/split_data/train",
    transform=train_transform,
)
val_dataset = datasets.ImageFolder(
    root="data/data_ver2/split_data/val",
    transform=val_transform,
)