# Space Images Classifier - Using Kaggle dataset

Dataset source: https://www.kaggle.com/datasets/abhikalpsrivastava15/space-images-category

### Notebook 4 - Training stage 1 (Frozen backbone)

# Import libraries

In [1]:
import numpy as np
import pandas as pd

import matplotlib.pyplot as plt
import seaborn as sns

import sys
import os
import json

# Add the root folder to Python's module search path
sys.path.append(os.path.abspath(os.path.join(".."))) 
# Import the project configuration
from config import DEVICE, OUTPUT_PATH, BATCH_SIZE, NUM_WORKERS, EPOCHS_STAGE1, LEARNING_RATE_STAGE1
from models import SpaceClassifier
from train_utils import train_epoch, validate
from datasets import SpaceImageDataset, train_transforms, val_test_transforms

import shutil
from pathlib import Path
import cv2
from tqdm import tqdm
import random

import warnings
warnings.filterwarnings('ignore')

import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms, models
from torch.optim.lr_scheduler import ReduceLROnPlateau
import torchvision.transforms.functional as TF

from sklearn.model_selection import train_test_split
from sklearn.metrics import classification_report, confusion_matrix, accuracy_score
from sklearn.utils import class_weight

# Load configuration variables and class weights

In [2]:
ROOT_PATH = Path("..")
# Dynamic configuration (class count, class names, split status) written by an earlier notebook.
CONFIG_JSON_PATH = ROOT_PATH / "config_dynamic.json"

try:
    with open(CONFIG_JSON_PATH) as f:
        dynamic_config = json.load(f)
except FileNotFoundError:
    # Continue with empty defaults, but say so: silently carrying on hides a missing pipeline step.
    print(f"WARNING: {CONFIG_JSON_PATH} not found - using empty defaults, stage 1 training will be skipped.")
    dynamic_config = {}

NUM_CLASSES = dynamic_config.get("NUM_CLASSES", 0)
class_names = dynamic_config.get("class_names", [])
# Default to False so training only runs when the config explicitly reports a successful split.
# (The previous default, 5, was truthy and let training start without any valid configuration.)
split_success = dynamic_config.get("split_success", False)

# Path to the saved class-weights tensor (relative to this notebook's working directory)
WEIGHTS_PATH = Path("models") / "class_weights_tensor.pth"

# Load the tensor directly onto the training device
class_weights_tensor = torch.load(WEIGHTS_PATH, map_location=DEVICE)

# CrossEntropyLoss needs exactly one weight per class; fail early with a clear message instead of
# a shape error deep inside the first training step.
if NUM_CLASSES and class_weights_tensor.numel() != NUM_CLASSES:
    raise ValueError(
        f"class_weights_tensor has {class_weights_tensor.numel()} entries "
        f"but NUM_CLASSES is {NUM_CLASSES}"
    )

In [3]:
# Display the loaded class weights (used by the weighted CrossEntropyLoss below) as a sanity check.
class_weights_tensor

tensor([1.0052, 1.1092, 0.7798, 1.0904, 1.0461, 1.0546], device='mps:0')

# Training stage 1 with frozen backbone

## Create EfficientNet-B0 model

In [4]:
# Build the classifier with pretrained weights and move it to the training device
# (nn.Module.to returns the module itself, so the call can be chained).
model = SpaceClassifier(NUM_CLASSES, pretrained=True).to(DEVICE)

separator = "=" * 80
print(separator)
print("Model created: EfficientNet-B0")
print(f"Device: {DEVICE}")

# Tally parameters in one pass: all of them vs. only those that will receive gradients.
total_params = 0
trainable_params = 0
for param in model.parameters():
    count = param.numel()
    total_params += count
    if param.requires_grad:
        trainable_params += count
print(f"Total parameters: {total_params:,}")
print(f"Trainable parameters: {trainable_params:,}")

print(separator)

Model created: EfficientNet-B0
Device: mps
Total parameters: 4,796,290
Trainable parameters: 788,742


In [5]:
# Inspect the classification head attached to the backbone.
classifier_head = model.backbone.classifier
print(classifier_head)

Sequential(
  (0): Dropout(p=0.5, inplace=False)
  (1): Linear(in_features=1280, out_features=512, bias=True)
  (2): ReLU()
  (3): Dropout(p=0.4, inplace=False)
  (4): Linear(in_features=512, out_features=256, bias=True)
  (5): ReLU()
  (6): Dropout(p=0.3, inplace=False)
  (7): Linear(in_features=256, out_features=6, bias=True)
)


## Create Dataset

In [6]:
# The training split uses the training transforms; the validation split uses the val/test transforms.
train_dir = OUTPUT_PATH / "train"
val_dir = OUTPUT_PATH / "validation"
train_dataset = SpaceImageDataset(train_dir, transform=train_transforms)
val_dataset = SpaceImageDataset(val_dir, transform=val_test_transforms)



Dataset created from ../space_images_split/train
   -> 765 valid images
7 images skipped (see ../space_images_split/train/skipped_images.txt)
Dataset created from ../space_images_split/validation
   -> 161 valid images
2 images skipped (see ../space_images_split/validation/skipped_images.txt)


## Create Data loaders

In [7]:
# Pinned (page-locked) host memory only speeds up host-to-device copies on CUDA. PyTorch's
# DataLoader warns that pinning is not supported on MPS and does not use it there (that warning
# is hidden by the global warnings filter), so enable it only on CUDA devices.
PIN_MEMORY = DEVICE.type == 'cuda'

# Settings shared by both loaders.
loader_kwargs = dict(
    batch_size=BATCH_SIZE,
    num_workers=NUM_WORKERS,
    pin_memory=PIN_MEMORY,
)

# Shuffle only the training split; the validation order stays fixed so metrics are comparable.
train_loader = DataLoader(train_dataset, shuffle=True, **loader_kwargs)
val_loader = DataLoader(val_dataset, shuffle=False, **loader_kwargs)

## Training the model - Phase 1

In [None]:
if split_success:
    # Weighted loss using the per-class weights loaded above
    criterion = nn.CrossEntropyLoss(weight=class_weights_tensor)
    # Optimize only the parameters that require gradients (presumably just the head while the
    # backbone is frozen); handing frozen tensors to Adam is wasteful and obscures intent.
    trainable_parameters = [p for p in model.parameters() if p.requires_grad]
    optimizer = optim.Adam(trainable_parameters, lr=LEARNING_RATE_STAGE1)
    # Halve the learning rate if validation loss has not improved for 3 epochs (small dataset).
    # LR reductions are logged manually below because the scheduler's `verbose` flag is deprecated.
    scheduler = ReduceLROnPlateau(optimizer, mode='min', factor=0.5, patience=3)

    # Output files; the history name encodes the image-handling variant ("del") compared later.
    MODELS_DIR = Path("models")
    BEST_MODEL_PATH = MODELS_DIR / "stage1_best.pth"
    HISTORY_PATH = MODELS_DIR / "history_stage1_del_img.pt"
    MODELS_DIR.mkdir(parents=True, exist_ok=True)

    print("=" * 80)
    print("Starting Stage 1 Training")
    print("-" * 80)
    print("Estimated time on M1: 30-60 minutes\n")
    print("=" * 80)

    history_stage1 = {
        'train_loss': [], 'train_acc': [],
        'val_loss': [], 'val_acc': []
    }

    best_val_acc = 0.0
    patience_counter = 0
    PATIENCE = 10  # epochs without val-accuracy improvement before early stopping

    for epoch in range(EPOCHS_STAGE1):
        print("=" * 80)
        print(f"Epoch {epoch+1}/{EPOCHS_STAGE1}")
        print("-" * 60)

        # Train
        train_loss, train_acc = train_epoch(model, train_loader, criterion, optimizer, DEVICE)

        # Validate
        val_loss, val_acc, _, _ = validate(model, val_loader, criterion, DEVICE)

        # Scheduler step (driven by validation loss); detect LR reductions for logging
        lr_before = optimizer.param_groups[0]['lr']
        scheduler.step(val_loss)
        lr_after = optimizer.param_groups[0]['lr']

        # Record this epoch's metrics
        history_stage1['train_loss'].append(train_loss)
        history_stage1['train_acc'].append(train_acc)
        history_stage1['val_loss'].append(val_loss)
        history_stage1['val_acc'].append(val_acc)

        print(f"\nTrain Loss: {train_loss:.4f} | Train Acc: {train_acc*100:.2f}%")
        print(f"Val Loss:   {val_loss:.4f} | Val Acc:   {val_acc*100:.2f}%")
        if lr_after < lr_before:
            print(f"Learning rate reduced: {lr_before:.2e} -> {lr_after:.2e}")

        # Persist the full history every epoch, not only on improvement, so the saved curves
        # cover every completed epoch even when the run is interrupted or stops early.
        torch.save(history_stage1, HISTORY_PATH)

        # Save best model (by validation accuracy)
        if val_acc > best_val_acc:
            best_val_acc = val_acc
            torch.save(model.state_dict(), BEST_MODEL_PATH)
            print(f"Best model saved! (Val Acc: {val_acc*100:.2f}%)")
            patience_counter = 0
        else:
            patience_counter += 1

        # Early stopping
        if patience_counter >= PATIENCE:
            print(f"\nEarly stopping triggered (patience={PATIENCE})")
            break

    print("\nStage 1 training complete!")

else:
    print("Skipping Stage 1 training: split_success is not set in the dynamic config.")
    history_stage1 = None

print("=" * 80)

Starting Stage 1 Training
--------------------------------------------------------------------------------
Estimated time on M1: 30-60 minutes

Epoch 1/25
------------------------------------------------------------


Training: 100%|██████████| 24/24 [00:29<00:00,  1.21s/it, loss=1.4653, acc=35.56%]
Validation: 100%|██████████| 6/6 [00:24<00:00,  4.01s/it, loss=2.4556, acc=47.83%]



Train Loss: 1.6289 | Train Acc: 35.56%
Val Loss:   1.2816 | Val Acc:   47.83%
Best model saved! (Val Acc: 47.83%)
Epoch 2/25
------------------------------------------------------------


Training: 100%|██████████| 24/24 [00:27<00:00,  1.17s/it, loss=0.8312, acc=51.24%]
Validation: 100%|██████████| 6/6 [00:23<00:00,  3.90s/it, loss=3.1586, acc=54.04%]



Train Loss: 1.2509 | Train Acc: 51.24%
Val Loss:   1.1691 | Val Acc:   54.04%
Best model saved! (Val Acc: 54.04%)
Epoch 3/25
------------------------------------------------------------


Training: 100%|██████████| 24/24 [00:28<00:00,  1.17s/it, loss=1.5532, acc=55.69%]
Validation: 100%|██████████| 6/6 [00:23<00:00,  3.89s/it, loss=2.7361, acc=54.66%]



Train Loss: 1.2045 | Train Acc: 55.69%
Val Loss:   1.1057 | Val Acc:   54.66%
Best model saved! (Val Acc: 54.66%)
Epoch 4/25
------------------------------------------------------------


Training: 100%|██████████| 24/24 [00:28<00:00,  1.19s/it, loss=0.9222, acc=58.56%]
Validation: 100%|██████████| 6/6 [00:23<00:00,  3.89s/it, loss=2.8848, acc=63.35%]



Train Loss: 1.1029 | Train Acc: 58.56%
Val Loss:   1.0657 | Val Acc:   63.35%
Best model saved! (Val Acc: 63.35%)
Epoch 5/25
------------------------------------------------------------


Training: 100%|██████████| 24/24 [00:28<00:00,  1.18s/it, loss=0.8515, acc=60.92%]
Validation: 100%|██████████| 6/6 [00:23<00:00,  3.89s/it, loss=2.6332, acc=60.87%]



Train Loss: 1.0172 | Train Acc: 60.92%
Val Loss:   1.0463 | Val Acc:   60.87%
Epoch 6/25
------------------------------------------------------------


Training: 100%|██████████| 24/24 [00:27<00:00,  1.15s/it, loss=0.9101, acc=60.78%]
Validation: 100%|██████████| 6/6 [00:24<00:00,  4.01s/it, loss=2.7932, acc=60.87%]



Train Loss: 1.0449 | Train Acc: 60.78%
Val Loss:   1.0408 | Val Acc:   60.87%
Epoch 7/25
------------------------------------------------------------


Training: 100%|██████████| 24/24 [00:28<00:00,  1.18s/it, loss=0.9872, acc=60.26%]
Validation: 100%|██████████| 6/6 [00:23<00:00,  3.90s/it, loss=3.8262, acc=60.87%]



Train Loss: 0.9992 | Train Acc: 60.26%
Val Loss:   1.0595 | Val Acc:   60.87%
Epoch 8/25
------------------------------------------------------------


Training: 100%|██████████| 24/24 [00:28<00:00,  1.17s/it, loss=1.0120, acc=60.13%]
Validation: 100%|██████████| 6/6 [00:23<00:00,  3.90s/it, loss=2.9080, acc=63.98%]



Train Loss: 1.0369 | Train Acc: 60.13%
Val Loss:   1.0156 | Val Acc:   63.98%
Best model saved! (Val Acc: 63.98%)
Epoch 9/25
------------------------------------------------------------


Training: 100%|██████████| 24/24 [00:28<00:00,  1.17s/it, loss=0.9623, acc=60.92%]
Validation: 100%|██████████| 6/6 [00:23<00:00,  3.98s/it, loss=3.6110, acc=62.11%]



Train Loss: 0.9683 | Train Acc: 60.92%
Val Loss:   1.0016 | Val Acc:   62.11%
Epoch 10/25
------------------------------------------------------------


Training: 100%|██████████| 24/24 [00:28<00:00,  1.20s/it, loss=1.1488, acc=61.96%]
Validation: 100%|██████████| 6/6 [00:23<00:00,  3.91s/it, loss=3.2989, acc=63.35%]



Train Loss: 0.9988 | Train Acc: 61.96%
Val Loss:   0.9971 | Val Acc:   63.35%
Epoch 11/25
------------------------------------------------------------


Training: 100%|██████████| 24/24 [00:28<00:00,  1.19s/it, loss=0.9439, acc=63.01%]
Validation: 100%|██████████| 6/6 [00:23<00:00,  3.89s/it, loss=2.9579, acc=57.14%]



Train Loss: 0.9957 | Train Acc: 63.01%
Val Loss:   1.0452 | Val Acc:   57.14%
Epoch 12/25
------------------------------------------------------------


Training: 100%|██████████| 24/24 [00:27<00:00,  1.15s/it, loss=1.0456, acc=62.88%]
Validation: 100%|██████████| 6/6 [00:23<00:00,  3.97s/it, loss=2.6211, acc=62.73%]



Train Loss: 0.9597 | Train Acc: 62.88%
Val Loss:   1.0086 | Val Acc:   62.73%
Epoch 13/25
------------------------------------------------------------


Training: 100%|██████████| 24/24 [00:27<00:00,  1.17s/it, loss=1.1007, acc=65.75%]
Validation: 100%|██████████| 6/6 [00:23<00:00,  3.94s/it, loss=3.2140, acc=66.46%]



Train Loss: 0.9382 | Train Acc: 65.75%
Val Loss:   1.0056 | Val Acc:   66.46%
Best model saved! (Val Acc: 66.46%)
Epoch 14/25
------------------------------------------------------------


Training: 100%|██████████| 24/24 [00:27<00:00,  1.15s/it, loss=0.9024, acc=64.05%]
Validation: 100%|██████████| 6/6 [00:23<00:00,  3.89s/it, loss=2.7715, acc=63.35%]



Train Loss: 0.9128 | Train Acc: 64.05%
Val Loss:   1.0336 | Val Acc:   63.35%
Epoch 15/25
------------------------------------------------------------


Training: 100%|██████████| 24/24 [00:28<00:00,  1.19s/it, loss=0.7568, acc=66.27%]
Validation: 100%|██████████| 6/6 [00:23<00:00,  3.96s/it, loss=3.4995, acc=62.73%]



Train Loss: 0.9022 | Train Acc: 66.27%
Val Loss:   0.9846 | Val Acc:   62.73%
Epoch 16/25
------------------------------------------------------------


Training: 100%|██████████| 24/24 [00:27<00:00,  1.15s/it, loss=1.0082, acc=65.23%]
Validation: 100%|██████████| 6/6 [00:23<00:00,  3.96s/it, loss=3.1776, acc=65.84%]



Train Loss: 0.9268 | Train Acc: 65.23%
Val Loss:   0.9971 | Val Acc:   65.84%
Epoch 17/25
------------------------------------------------------------


Training: 100%|██████████| 24/24 [00:28<00:00,  1.17s/it, loss=0.7814, acc=66.01%]
Validation: 100%|██████████| 6/6 [00:23<00:00,  3.89s/it, loss=3.1234, acc=62.11%]



Train Loss: 0.8767 | Train Acc: 66.01%
Val Loss:   1.0042 | Val Acc:   62.11%
Epoch 18/25
------------------------------------------------------------


Training: 100%|██████████| 24/24 [00:27<00:00,  1.15s/it, loss=0.6975, acc=64.44%]
Validation: 100%|██████████| 6/6 [00:23<00:00,  3.94s/it, loss=3.6010, acc=65.84%]



Train Loss: 0.8953 | Train Acc: 64.44%
Val Loss:   0.9645 | Val Acc:   65.84%
Epoch 19/25
------------------------------------------------------------


Training: 100%|██████████| 24/24 [00:28<00:00,  1.19s/it, loss=0.9465, acc=66.93%]
Validation: 100%|██████████| 6/6 [00:23<00:00,  3.91s/it, loss=3.4140, acc=63.35%]



Train Loss: 0.8487 | Train Acc: 66.93%
Val Loss:   0.9851 | Val Acc:   63.35%
Epoch 20/25
------------------------------------------------------------


Training: 100%|██████████| 24/24 [00:27<00:00,  1.16s/it, loss=0.7136, acc=68.89%]
Validation: 100%|██████████| 6/6 [00:23<00:00,  3.88s/it, loss=3.8031, acc=63.35%]



Train Loss: 0.8131 | Train Acc: 68.89%
Val Loss:   0.9724 | Val Acc:   63.35%
Epoch 21/25
------------------------------------------------------------


Training: 100%|██████████| 24/24 [00:28<00:00,  1.19s/it, loss=0.8324, acc=68.10%]
Validation: 100%|██████████| 6/6 [00:23<00:00,  3.99s/it, loss=3.3189, acc=63.35%]



Train Loss: 0.8261 | Train Acc: 68.10%
Val Loss:   0.9978 | Val Acc:   63.35%
Epoch 22/25
------------------------------------------------------------


Training: 100%|██████████| 24/24 [00:29<00:00,  1.23s/it, loss=1.0563, acc=67.45%]
Validation: 100%|██████████| 6/6 [00:23<00:00,  3.95s/it, loss=3.3645, acc=63.98%]



Train Loss: 0.8774 | Train Acc: 67.45%
Val Loss:   0.9644 | Val Acc:   63.98%
Epoch 23/25
------------------------------------------------------------


Training:  71%|███████   | 17/24 [00:06<00:01,  3.78it/s, loss=0.6546, acc=65.62%]

# Comparison of the two methods for handling invalid images: skipping vs. deleting them

In [None]:
# Stage 1 histories from two runs, one per invalid-image handling method
# (presumably "skp" = invalid images skipped at load time, "del" = invalid images deleted - TODO confirm).
# torch and matplotlib are already imported in the first cell.
history_skp = torch.load("models/history_stage1_skp_img.pt")
history_del = torch.load("models/history_stage1_del_img.pt")

print("Keys in history:", history_skp.keys())

# Actually verify the histories are comparable instead of assuming it: only metrics present
# in both are compared, and any mismatch is reported rather than raising KeyError.
only_in_one = set(history_skp) ^ set(history_del)
if only_in_one:
    print("WARNING: keys present in only one history (not compared):", sorted(only_in_one))
shared_keys = [key for key in history_skp if key in history_del]

# Quick numeric comparison: last-epoch value and best value ("best" = max for accuracies, min for losses)
for key in shared_keys:
    skp_vals = history_skp[key]
    del_vals = history_del[key]

    print(f"\n=== {key.upper()} ===")
    if not skp_vals or not del_vals:
        # An empty list would make [-1], max() and min() raise.
        print("No recorded epochs in at least one history - skipped.")
        continue

    pick_best = max if 'acc' in key else min
    print(f"Last (skp): {skp_vals[-1]:.4f}")
    print(f"Last (del): {del_vals[-1]:.4f}")
    print(f"Best (skp): {pick_best(skp_vals):.4f}")
    print(f"Best (del): {pick_best(del_vals):.4f}")


In [None]:
epochs_skp = range(1, len(history_skp['train_loss']) + 1)
epochs_del = range(1, len(history_del['train_loss']) + 1)

# One entry per run: (legend prefix, x values, history dict, line colour).
runs = [
    ('SKP', epochs_skp, history_skp, 'b'),
    ('DEL', epochs_del, history_del, 'r'),
]
# One entry per panel: (metric suffix in the history keys, legend suffix, title, y-axis label).
panels = [
    ('loss', 'Loss', "Loss Comparison", "Loss"),
    ('acc', 'Acc', "Accuracy Comparison", "Accuracy"),
]

fig, axes = plt.subplots(1, 2, figsize=(12, 6))
for ax, (metric, suffix, title, ylabel) in zip(axes, panels):
    # Solid lines = training, dashed lines = validation; blue = SKP, red = DEL.
    for prefix, epochs, history, color in runs:
        ax.plot(epochs, history[f'train_{metric}'], f'{color}-', label=f'{prefix} Train {suffix}')
        ax.plot(epochs, history[f'val_{metric}'], f'{color}--', label=f'{prefix} Val {suffix}')
    ax.set_title(title)
    ax.set_xlabel("Epoch")
    ax.set_ylabel(ylabel)
    ax.legend()
    ax.grid(True)

fig.tight_layout()
plt.show()


- Left plot: the DEL method has a more stable validation-loss curve that is slightly lower overall, which is a good sign of better generalization.

- Right plot: for training accuracy, both methods reach ~70–73%, with DEL slightly ahead at some points. For validation accuracy, DEL stays above SKP for most epochs: SKP seems to plateau around 55–60%, while DEL reaches ~61–63%. Overall, the DEL method tends to be more stable, suggesting better generalization.

- Overall, DEL appears to be the better method, with:
    1. Slightly lower train/val loss.
    2. Higher and more stable validation accuracy.
    3. Less overfitting behavior.

# End of notebook 4 - Training stage 1