# Space Images Classifier - Using Kaggle dataset

https://www.kaggle.com/datasets/abhikalpsrivastava15/space-images-category

### This notebook covers data preparation

# Import libraries

In [1]:
import numpy as np
import pandas as pd

import matplotlib.pyplot as plt
import seaborn as sns

import sys
import os
import json
# Add the root folder to Python's module search path
sys.path.append(os.path.abspath(os.path.join(".."))) 
# Import the project configuration
from config import DEVICE, ORIGINAL_DATA_PATH, OUTPUT_PATH, IMG_SIZE, BATCH_SIZE, NUM_WORKERS, SEED, TRAIN_RATIO, VAL_RATIO, TEST_RATIO

import shutil
from pathlib import Path
import cv2
from tqdm import tqdm
import random

import warnings
warnings.filterwarnings('ignore')

import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms, models
from torch.optim.lr_scheduler import ReduceLROnPlateau
import torchvision.transforms.functional as TF

from sklearn.model_selection import train_test_split
from sklearn.metrics import classification_report, confusion_matrix, accuracy_score
from sklearn.utils import class_weight

# Data preparation

In [2]:
# Create Train/Val/Test folders (Only executed once)

# Lower-cased file suffixes treated as images when scanning a class folder
_IMAGE_SUFFIXES = frozenset({'.jpg', '.jpeg', '.png'})


def _list_class_images(cls_path):
    """Return the image files directly inside ``cls_path``, sorted by path.

    Matching on the lower-cased suffix picks up any capitalisation of the
    extension and yields each file exactly once; sorting makes the result
    independent of the order in which the filesystem lists entries.
    """
    return sorted(
        p for p in cls_path.iterdir()
        if p.is_file() and p.suffix.lower() in _IMAGE_SUFFIXES
    )


def _plot_split_summary(df_summary, target_pcts, save_path):
    """Plot per-class counts and percentages of each split, save the figure, then show it."""
    labels = df_summary['class'].tolist()
    # Positions follow the summary rows, so classes skipped for having no images
    # cannot cause a length mismatch between x positions and bar heights
    x = np.arange(len(labels))
    width = 0.25

    fig, axes = plt.subplots(1, 2, figsize=(16, 6))
    panels = [
        (axes[0], ('train', 'validation', 'test'),
         'Number of images', 'Image Distribution across splits'),
        (axes[1], ('train_pct', 'val_pct', 'test_pct'),
         'Percentage (%)', 'Percentage Distribution'),
    ]
    for ax, columns, y_label, title in panels:
        for offset, column, legend in zip((-width, 0, width), columns,
                                          ('Train', 'Validation', 'Test')):
            ax.bar(x + offset, df_summary[column], width, label=legend)
        ax.set_xlabel('Class', fontsize=12, fontweight='bold')
        ax.set_ylabel(y_label, fontsize=12, fontweight='bold')
        ax.set_title(title, fontsize=14, fontweight='bold')
        ax.set_xticks(x)
        ax.set_xticklabels(labels, rotation=45, ha='right')
        ax.legend()
        ax.grid(axis='y', alpha=0.3)

    # Reference lines at the requested split percentages
    for pct in sorted(set(target_pcts)):
        axes[1].axhline(y=pct, color='gray', linestyle='--', alpha=0.5)

    plt.tight_layout()
    plt.savefig(save_path, dpi=150, bbox_inches='tight')
    plt.show()


def create_train_val_test_split(original_path, output_path, train_ratio, val_ratio, test_ratio,
                                seed=42, overwrite=None):
    """
    Create a stratified train/validation/test split on disk.

    Every class folder found under ``original_path`` is split independently, so
    the class distribution is preserved in all three splits. Images are copied
    (not moved) into ``output_path/{train,validation,test}/<class>/``.

    Args:
        original_path: Folder containing one sub-folder per class.
        output_path: Folder that receives the split, the summary CSV and the plot.
        train_ratio, val_ratio, test_ratio: Fractions of each class assigned to
            each split. They must be non-negative and sum to 1. The test split
            receives whatever remains after the train and validation counts
            have been rounded down.
        seed: Seed for the per-class shuffle. Together with the sorted file
            listing this makes the split reproducible.
        overwrite: What to do when ``output_path`` already exists. ``None``
            (default) asks interactively, ``True`` deletes and recreates the
            split, ``False`` keeps the existing split.

    Returns:
        True if a split is available (newly created or kept), False if the
        inputs are invalid or no images were found.
    """
    original_path = Path(original_path)
    output_path = Path(output_path)

    # Validate ratios before touching the filesystem: ratios summing above 1
    # would otherwise produce a negative test count and an empty test set
    ratios = (train_ratio, val_ratio, test_ratio)
    if any(r < 0 for r in ratios) or abs(sum(ratios) - 1.0) > 1e-6:
        print(f"Error : Invalid split ratios {ratios} (each must be >= 0 and they must sum to 1)")
        return False

    if not original_path.exists():
        print(f"Error : Dataset not found at {original_path}")
        print("Please download from: https://www.kaggle.com/datasets/abhikalpsrivastava15/space-images-category")
        return False

    # Check if split already exists
    if output_path.exists():
        if overwrite is None:
            response = input(f"{output_path} already exists. Recreate split? [YES/NO]: ")
            overwrite = response.strip().lower() == 'yes'
        if not overwrite:
            print("Using existing split")
            return True
        shutil.rmtree(output_path)

    # Get all classes
    classes = sorted(d.name for d in original_path.iterdir() if d.is_dir())

    if not classes:
        print("Error : No class folders found in dataset")
        return False

    print('=' * 80)
    print(f"Found {len(classes)} classes: {classes}")

    # Create output directories
    train_path = output_path / "train"
    val_path = output_path / "validation"
    test_path = output_path / "test"
    split_paths = (train_path, val_path, test_path)

    for split_path in split_paths:
        for cls in classes:
            (split_path / cls).mkdir(parents=True, exist_ok=True)

    print(f"Created directory structure")

    # Split each class
    split_summary = []

    print(f"Splitting images (seed={seed} for reproducibility)")

    # Loop through classes with progress bar
    for cls in tqdm(classes, desc="Processing classes"):
        images = _list_class_images(original_path / cls)

        if not images:
            print(f"No images found for class: {cls}")
            continue

        # A private generator seeded identically for every class keeps the
        # shuffle reproducible without altering NumPy's global random state
        order = np.random.RandomState(seed).permutation(len(images))
        shuffled = [images[i] for i in order]

        # Calculate split sizes; the test split takes the remainder
        n_total = len(shuffled)
        n_train = int(n_total * train_ratio)
        n_val = int(n_total * val_ratio)
        n_test = n_total - n_train - n_val

        # Consecutive slices of one shuffled list cannot share a file
        partitions = (
            (train_path, shuffled[:n_train]),
            (val_path, shuffled[n_train:n_train + n_val]),
            (test_path, shuffled[n_train + n_val:]),
        )
        for split_path, split_images in partitions:
            for img in split_images:
                shutil.copy2(img, split_path / cls / img.name)

        split_summary.append({
            'class': cls,
            'total': n_total,
            'train': n_train,
            'validation': n_val,
            'test': n_test,
            'train_pct': n_train / n_total * 100,
            'val_pct': n_val / n_total * 100,
            'test_pct': n_test / n_total * 100
        })

    if not split_summary:
        # Everything under output_path at this point is the empty skeleton
        # created above; remove it so a later call does not mistake it for a
        # finished split
        shutil.rmtree(output_path)
        print("Error : No images found in any class folder")
        return False

    # Display summary
    df_summary = pd.DataFrame(split_summary)
    print(f"\n{'=' * 80}")
    print("Overview split summary")
    print('-' * 80)
    print(df_summary.to_string(index=False))
    print('=' * 80)

    # Save summary
    df_summary.to_csv(output_path / 'split_summary.csv', index=False)
    print(f"Split summary saved: {output_path / 'split_summary.csv'}")

    # Verify no overlap, for every class rather than only the first one
    print("Verifying that there is no data leaking")
    leaking_classes = []
    for cls in classes:
        train_files, val_files, test_files = (
            {f.name for f in (split_path / cls).iterdir()} for split_path in split_paths
        )
        if (train_files & val_files) or (train_files & test_files) or (val_files & test_files):
            leaking_classes.append(cls)

    if leaking_classes:
        print("Warning : Data leaking detected")
        print(f"Classes with overlapping files: {leaking_classes}")
    else:
        print("Verified : No overlap between train/val/test sets")
        print('=' * 80)

    # Visualize split distribution
    _plot_split_summary(
        df_summary,
        [ratio * 100 for ratio in ratios],
        output_path / 'split_visualization.png',
    )

    print('=' * 80)
    print(f"Split visualization saved: {output_path / 'split_visualization.png'}")
    if leaking_classes:
        print("Data split complete, but data leaking was detected (see warning above)")
    else:
        print("Data split complete with no data leaking")
    print('=' * 80)

    return True

In [3]:
# Build the train/validation/test folders from the project configuration
# and remember whether a usable split is available
split_success = create_train_val_test_split(
    original_path=ORIGINAL_DATA_PATH,
    output_path=OUTPUT_PATH,
    train_ratio=TRAIN_RATIO,
    val_ratio=VAL_RATIO,
    test_ratio=TEST_RATIO,
    seed=SEED,
)

Using existing split


# Custom PyTorch Datasets

In [11]:
class SpaceImageDataset(Dataset):
    """Custom PyTorch Dataset for space images.

    Expects ``root_dir`` to contain one sub-folder per class. Class names are
    sorted alphabetically and mapped to consecutive integer labels.

    Attributes:
        root_dir: Dataset root as a ``Path``.
        transform: Optional callable applied to each RGB image array.
        classes: Sorted list of class (folder) names.
        class_to_idx: Mapping from class name to integer label.
        samples: List of ``(image_path, label)`` tuples in deterministic order.
    """

    # Lower-cased suffixes recognised as images; comparing against the
    # lower-cased file suffix accepts any capitalisation of the extension
    IMAGE_SUFFIXES = ('.jpg', '.jpeg', '.png')

    # Call the initializer/constructor to set up everything the object needs
    def __init__(self, root_dir, transform=None):
        self.root_dir = Path(root_dir)
        self.transform = transform
        self.classes = sorted(d.name for d in self.root_dir.iterdir() if d.is_dir())
        self.class_to_idx = {cls: idx for idx, cls in enumerate(self.classes)}

        # Load all image paths and labels. Each folder is listed once and
        # sorted, so every file appears exactly once and the sample order does
        # not depend on how the filesystem enumerates entries.
        self.samples = []
        for cls in self.classes:
            label = self.class_to_idx[cls]
            for img_path in sorted((self.root_dir / cls).iterdir()):
                if img_path.is_file() and img_path.suffix.lower() in self.IMAGE_SUFFIXES:
                    self.samples.append((img_path, label))

    # Number of samples (Needed for batching)
    def __len__(self):
        return len(self.samples)

    # Called when PyTorch wants an image
    def __getitem__(self, idx):
        """Return ``(image, label)`` for the sample at position ``idx``."""
        img_path, label = self.samples[idx]

        # cv2.imread returns None instead of raising when a file cannot be decoded
        image = cv2.imread(str(img_path))

        if image is None:
            # Keep the batch shape intact with a black placeholder, but make
            # the substitution visible instead of silently training on it
            warnings.warn(
                f"Could not read image {img_path}; substituting a black image",
                RuntimeWarning,
            )
            image = np.zeros((IMG_SIZE, IMG_SIZE, 3), dtype=np.uint8)
        else:
            # OpenCV loads images as BGR; the transforms expect RGB
            image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)

        if self.transform is not None:
            image = self.transform(image)

        return image, label

In [5]:
# Define transforms
# Channel-wise mean and std of ImageNet images: pre-trained models such as
# ResNet, VGG or EfficientNet were trained on inputs normalized with these
# values, and normalizing also helps training converge faster
_imagenet_mean = [0.485, 0.456, 0.406]
_imagenet_std = [0.229, 0.224, 0.225]

# Training pipeline: resize, random augmentation, then tensor + normalization
_train_steps = [
    # cv2.imread() yields a NumPy array, while the steps below work on PIL images
    transforms.ToPILImage(),
    transforms.Resize((IMG_SIZE, IMG_SIZE)),
    transforms.RandomRotation(45),
    transforms.RandomHorizontalFlip(),
    transforms.RandomVerticalFlip(),
    transforms.ColorJitter(brightness=0.2, contrast=0.2),
    transforms.ToTensor(),
    transforms.Normalize(mean=_imagenet_mean, std=_imagenet_std),
]
train_transforms = transforms.Compose(_train_steps)

# Validation/test pipeline: same preprocessing without any random augmentation
_eval_steps = [
    transforms.ToPILImage(),
    transforms.Resize((IMG_SIZE, IMG_SIZE)),
    transforms.ToTensor(),
    transforms.Normalize(mean=_imagenet_mean, std=_imagenet_std),
]
val_test_transforms = transforms.Compose(_eval_steps)

In [10]:
try:
    # Create datasets (augmentation is applied to the training set only)
    train_dataset = SpaceImageDataset(OUTPUT_PATH / "train", transform=train_transforms)
    val_dataset = SpaceImageDataset(OUTPUT_PATH / "validation", transform=val_test_transforms)
    test_dataset = SpaceImageDataset(OUTPUT_PATH / "test", transform=val_test_transforms)

    NUM_CLASSES = len(train_dataset.classes)
    class_names = train_dataset.classes

    ROOT_PATH = Path("..")
    # Path to save JSON
    CONFIG_JSON_PATH = ROOT_PATH / "config_dynamic.json"

    # Save to a JSON file for scalability
    data = {
        "NUM_CLASSES": NUM_CLASSES,
        "class_names": class_names,
        "split_success": split_success
    }
    with open(CONFIG_JSON_PATH, "w") as f:
        json.dump(data, f)
    print("=" * 80)
    print(f"Dynamic config json file created with NUM_CLASSES, class_names, split_success variables added")

    print(f"Datasets created:")
    print(f"  Training:   {len(train_dataset):5d} images")
    print(f"  Validation: {len(val_dataset):5d} images")
    print(f"  Test:       {len(test_dataset):5d} images")
    print("=" * 80)
    print(f"Classes ({NUM_CLASSES}): {class_names}")

    # Create data loaders
    # Pinned (page-locked) host memory speeds up host-to-GPU copies on CUDA.
    # PyTorch does not support it on MPS, so it is enabled for CUDA only.
    loader_options = dict(
        batch_size=BATCH_SIZE,
        num_workers=NUM_WORKERS,
        pin_memory=DEVICE.type == 'cuda',
    )

    # Randomized order for training, fixed order for evaluation
    train_loader = DataLoader(train_dataset, shuffle=True, **loader_options)
    val_loader = DataLoader(val_dataset, shuffle=False, **loader_options)
    test_loader = DataLoader(test_dataset, shuffle=False, **loader_options)

    print(f"DataLoaders created (batch_size={BATCH_SIZE})")
    print("=" * 80)

    # Calculate class weights to balance imbalanced datasets
    labels = [label for _, label in train_dataset.samples]
    class_weights_array = class_weight.compute_class_weight(
        'balanced',
        classes=np.unique(labels),
        y=labels
    )

    # Convert weights to a PyTorch tensor on the configured device
    class_weights_tensor = torch.FloatTensor(class_weights_array).to(DEVICE)

    # Make sure the target folder exists; torch.save does not create it
    weights_file = Path("models") / "class_weights_tensor.pth"
    weights_file.parent.mkdir(parents=True, exist_ok=True)
    # NOTE(review): the tensor is saved on DEVICE, so loading it on a machine
    # without that device needs torch.load(..., map_location=...)
    torch.save(class_weights_tensor, weights_file)
    print("Class weights calculated")
    print("Class weights tensor:", class_weights_tensor)

except Exception as e:
    print("Cannot create datasets or DataLoaders")
    print(f"Error: {e}")

Dynamic config json file created with NUM_CLASSES, class_names, split_success variables added
Datasets created:
  Training:     772 images
  Validation:   163 images
  Test:         172 images
Classes (6): ['constellation', 'cosmos space', 'galaxies', 'nebula', 'planets', 'stars']
DataLoaders created (batch_size=32)
Class weights calculated
Class weights tensor: tensor([1.0052, 1.1092, 0.7798, 1.0904, 1.0461, 1.0546], device='mps:0')


# End of notebook 2 - Data preparation