In [1]:
import os
from collections import defaultdict
from pathlib import Path

In [2]:
def count_samples_in_split(split_path: str, include_hidden: bool = False) -> dict:
    """
    Count samples in each class directory within a split.

    Each immediate subdirectory of ``split_path`` is treated as one class, and
    every regular file directly inside it counts as one sample. Entries whose
    names start with ``.`` (e.g. ``.DS_Store``, ``.ipynb_checkpoints``) are
    skipped by default: OS/Jupyter metadata would otherwise inflate the sample
    counts or appear as a spurious class.

    Args:
        split_path: Path to the split directory (train/val/test)
        include_hidden: If True, also count dot-prefixed class directories and
            files (restores the unfiltered behaviour).

    Returns:
        Dictionary mapping class names (in sorted order) to sample counts,
        followed by a final 'TOTAL' entry holding the sum over all classes.
    """
    def _is_visible(name: str) -> bool:
        # Dot-prefixed names are metadata, not data, unless explicitly requested.
        return include_hidden or not name.startswith('.')

    # scandir + context manager: closes the directory handle and avoids the
    # separate os.path.isdir/isfile path joins per entry.
    with os.scandir(split_path) as entries:
        class_entries = sorted(
            (entry for entry in entries if _is_visible(entry.name) and entry.is_dir()),
            key=lambda entry: entry.name,
        )

    class_counts = {}
    total_samples = 0

    for class_entry in class_entries:
        with os.scandir(class_entry.path) as files:
            # Only regular files count; nested subdirectories are not samples.
            file_count = sum(
                1 for f in files if _is_visible(f.name) and f.is_file()
            )

        class_counts[class_entry.name] = file_count
        total_samples += file_count

    # Inserted last so iteration order is: classes (sorted), then the total.
    class_counts['TOTAL'] = total_samples
    return class_counts

In [3]:
# Root of the processed dataset. Set the DL_PROJECT_DATA_DIR environment
# variable to run this notebook on a machine other than the author's.
base_path = os.environ.get(
    "DL_PROJECT_DATA_DIR",
    "/Users/noahmv/Desktop/noah/studies/Master's/Chalmers/2nd year/sp_1/DML/Project/code/DL_Project_Processed_Data",
)

splits = ['train', 'val', 'test']
results = {}

# Summary table width: two left-aligned 10-char columns plus one separator space.
TABLE_WIDTH = 10 + 1 + 10


def _print_heading(title: str) -> None:
    """Print ``title`` followed by a dashed underline of exactly the same length."""
    print(title)
    print("-" * len(title))


_print_heading("Dataset Sample Counts")

for split in splits:
    split_path = os.path.join(base_path, split)

    # isdir (not exists): a plain file here would make the directory listing fail.
    if not os.path.isdir(split_path):
        print(f"Warning: {split_path} does not exist or is not a directory")
        continue

    print()
    _print_heading(f"{split.upper()} Split:")

    class_counts = count_samples_in_split(split_path)
    results[split] = class_counts

    # Per-class breakdown; the 'TOTAL' entry is printed with the same format.
    for class_name, count in class_counts.items():
        print(f"  {class_name}: {count} samples")

# Summary table (print("\n") emits two newlines -> two blank lines before the title)
print("\n")
_print_heading("Summary Table")
print(f"{'Split':<10} {'Samples':<10}")
print("-" * TABLE_WIDTH)

for split in splits:
    if split in results:
        print(f"{split.upper():<10} {results[split]['TOTAL']:<10}")

print("-" * TABLE_WIDTH)
total_all = sum(counts['TOTAL'] for counts in results.values())
print(f"{'TOTAL':<10} {total_all:<10}")

Dataset Sample Counts
---------------------

TRAIN Split:
-----------
  (0, 0, 0, 0, 0): 400 samples
  (0, 0, 0, 0, 1): 240 samples
  (0, 0, 0, 1, 0): 240 samples
  (0, 0, 1, 0, 0): 240 samples
  (0, 1, 0, 0, 0): 240 samples
  (1, 0, 0, 0, 0): 68 samples
  multiple: 400 samples
  TOTAL: 1828 samples

VAL Split:
---------
  (0, 0, 0, 0, 0): 50 samples
  (0, 0, 0, 0, 1): 30 samples
  (0, 0, 0, 1, 0): 30 samples
  (0, 0, 1, 0, 0): 30 samples
  (0, 1, 0, 0, 0): 30 samples
  (1, 0, 0, 0, 0): 8 samples
  multiple: 50 samples
  TOTAL: 228 samples

TEST Split:
----------
  (0, 0, 0, 0, 0): 50 samples
  (0, 0, 0, 0, 1): 30 samples
  (0, 0, 0, 1, 0): 30 samples
  (0, 0, 1, 0, 0): 30 samples
  (0, 1, 0, 0, 0): 30 samples
  (1, 0, 0, 0, 0): 9 samples
  multiple: 50 samples
  TOTAL: 229 samples


Summary Table
--------------
Split      Samples   
---------------------
TRAIN      1828      
VAL        228       
TEST       229       
---------------------
TOTAL      2285      
