# Space Images Classifier - Using Kaggle dataset

https://www.kaggle.com/datasets/abhikalpsrivastava15/space-images-category

### Notebook 4 - Training stage 1 (Frozen backbone)

# Import libraries

In [1]:
import numpy as np
import pandas as pd

import matplotlib.pyplot as plt
import seaborn as sns

import sys
import os
import json

# Add the root folder to Python's module search path
sys.path.append(os.path.abspath(os.path.join(".."))) 
# Import the project configuration
from config import DEVICE, OUTPUT_PATH, BATCH_SIZE, NUM_WORKERS, EPOCHS_STAGE1, LEARNING_RATE_STAGE1
from models import SpaceClassifier
from train_utils import train_epoch, validate
from datasets import SpaceImageDataset, train_transforms, val_test_transforms

import shutil
from pathlib import Path
import cv2
from tqdm import tqdm
import random

import warnings
warnings.filterwarnings('ignore')

import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms, models
from torch.optim.lr_scheduler import ReduceLROnPlateau
import torchvision.transforms.functional as TF

from sklearn.model_selection import train_test_split
from sklearn.metrics import classification_report, confusion_matrix, accuracy_score
from sklearn.utils import class_weight

# Load callback/config variables

In [2]:
ROOT_PATH = Path("..")
# Path to the JSON file holding the dynamic configuration values
CONFIG_JSON_PATH = ROOT_PATH / "config_dynamic.json"

try:
    with open(CONFIG_JSON_PATH) as f:
        dynamic_config = json.load(f)
except FileNotFoundError:
    # Keep the best-effort fallback, but make it visible: with an empty config
    # NUM_CLASSES becomes 0 and the model below would be built with no outputs.
    # A plain print is used because warnings are globally silenced in this notebook.
    print(f"WARNING: {CONFIG_JSON_PATH} not found - falling back to default values")
    dynamic_config = {}

NUM_CLASSES = dynamic_config.get("NUM_CLASSES", 0)
class_names = dynamic_config.get("class_names", [])
# NOTE(review): the default of 5 is truthy, so training still runs when the key
# is missing - confirm this is intended (a boolean default may have been meant).
split_success = dynamic_config.get("split_success", 5)

# Path to the saved class-weights tensor (relative to the notebook's working directory)
WEIGHTS_PATH = Path("models") / "class_weights_tensor.pth"

# Load the tensor directly onto the configured device.
# weights_only=True restricts unpickling to tensor data, which is all this file contains.
class_weights_tensor = torch.load(WEIGHTS_PATH, map_location=DEVICE, weights_only=True)

In [3]:
# Display the loaded class weights as a quick sanity check (values and device)
class_weights_tensor

tensor([1.0052, 1.1092, 0.7798, 1.0904, 1.0461, 1.0546], device='mps:0')

# Training stage 1 with frozen backbone

## Create EfficientNet-B0 model

In [4]:
# Build the classifier and place it on the configured device in one step
model = SpaceClassifier(NUM_CLASSES, pretrained=True).to(DEVICE)

separator = "=" * 80
print(separator)
print("Model created: EfficientNet-B0")
print(f"Device: {DEVICE}")

# Tally every parameter once, remembering whether it will receive gradients
param_sizes = [(p.numel(), p.requires_grad) for p in model.parameters()]
total_params = sum(size for size, _ in param_sizes)
trainable_params = sum(size for size, is_trainable in param_sizes if is_trainable)
print(f"Total parameters: {total_params:,}")
print(f"Trainable parameters: {trainable_params:,}")

print(separator)

Model created: EfficientNet-B0
Device: mps
Total parameters: 4,796,290
Trainable parameters: 788,742


In [5]:
# Show the layers of the classification head attached to the backbone
print(model.backbone.classifier)

Sequential(
  (0): Dropout(p=0.4, inplace=False)
  (1): Linear(in_features=1280, out_features=512, bias=True)
  (2): ReLU()
  (3): Dropout(p=0.3, inplace=False)
  (4): Linear(in_features=512, out_features=256, bias=True)
  (5): ReLU()
  (6): Dropout(p=0.2, inplace=False)
  (7): Linear(in_features=256, out_features=6, bias=True)
)


## Create Dataset

In [6]:
# Resolve the split directories first, then build one dataset per split.
# The training split uses train_transforms; validation uses val_test_transforms.
train_dir = OUTPUT_PATH / "train"
val_dir = OUTPUT_PATH / "validation"

train_dataset = SpaceImageDataset(train_dir, transform=train_transforms)
val_dataset = SpaceImageDataset(val_dir, transform=val_test_transforms)

## Create Data loaders

In [7]:
# Pinned (page-locked) host memory only speeds up host-to-device copies on CUDA.
# On MPS it is unsupported and merely triggers a warning, so enable it for CUDA only.
use_pinned_memory = DEVICE.type == 'cuda'

# Training batches are reshuffled every epoch
train_loader = DataLoader(
    train_dataset,
    batch_size=BATCH_SIZE,
    shuffle=True,
    num_workers=NUM_WORKERS,
    pin_memory=use_pinned_memory
)

# Validation keeps a fixed order so metrics are comparable across epochs
val_loader = DataLoader(
    val_dataset,
    batch_size=BATCH_SIZE,
    shuffle=False,
    num_workers=NUM_WORKERS,
    pin_memory=use_pinned_memory
)

## Training the model - Phase 1

In [8]:
if split_success:
    # Single source of truth for the stage-1 checkpoint location
    # (relative to the notebook's working directory).
    BEST_MODEL_PATH = Path("models") / "stage1_best.pth"
    BEST_MODEL_PATH.parent.mkdir(parents=True, exist_ok=True)

    # Loss (weighted per class with class_weights_tensor), optimizer and LR scheduler.
    criterion = nn.CrossEntropyLoss(weight=class_weights_tensor)
    optimizer = optim.Adam(model.parameters(), lr=LEARNING_RATE_STAGE1)
    # The scheduler halves the LR after 5 epochs without validation-loss improvement.
    # The deprecated `verbose` argument is not used; LR changes are logged in the loop.
    scheduler = ReduceLROnPlateau(optimizer, mode='min', factor=0.5, patience=5)

    print("=" * 80)
    print(f"Starting Stage 1 Training")
    print("-" * 80)
    print(f"Estimated time on M1: 30-60 minutes\n")
    print("=" * 80)

    # Per-epoch metrics, appended once per completed epoch
    history_stage1 = {
        'train_loss': [], 'train_acc': [],
        'val_loss': [], 'val_acc': []
    }

    best_val_acc = 0.0
    best_model_saved = False  # True once a checkpoint has been written in THIS run
    patience_counter = 0
    PATIENCE = 10  # epochs without validation-accuracy improvement before stopping

    for epoch in range(EPOCHS_STAGE1):
        print("=" * 80)
        print(f"Epoch {epoch+1}/{EPOCHS_STAGE1}")
        print("-" * 60)

        # Train
        train_loss, train_acc = train_epoch(model, train_loader, criterion, optimizer, DEVICE)

        # Validate (the two extra return values are not needed here)
        val_loss, val_acc, _, _ = validate(model, val_loader, criterion, DEVICE)

        # Scheduler step, driven by validation loss; report any LR reduction explicitly
        lr_before = optimizer.param_groups[0]['lr']
        scheduler.step(val_loss)
        lr_after = optimizer.param_groups[0]['lr']
        if lr_after < lr_before:
            print(f"Learning rate reduced: {lr_before:.2e} -> {lr_after:.2e}")

        # Save history
        history_stage1['train_loss'].append(train_loss)
        history_stage1['train_acc'].append(train_acc)
        history_stage1['val_loss'].append(val_loss)
        history_stage1['val_acc'].append(val_acc)

        print(f"\nTrain Loss: {train_loss:.4f} | Train Acc: {train_acc*100:.2f}%")
        print(f"Val Loss:   {val_loss:.4f} | Val Acc:   {val_acc*100:.2f}%")

        # Save best model (selection and early stopping both track validation accuracy)
        if val_acc > best_val_acc:
            best_val_acc = val_acc
            torch.save(model.state_dict(), BEST_MODEL_PATH)
            best_model_saved = True
            print(f"Best model saved! (Val Acc: {val_acc*100:.2f}%)")
            patience_counter = 0
        else:
            patience_counter += 1

        # Early stopping
        if patience_counter >= PATIENCE:
            print(f"\nEarly stopping triggered (patience={PATIENCE})")
            break

    print("\nStage 1 training complete!")

    # Restore the best weights, but only if this run actually wrote a checkpoint;
    # otherwise a stale file from a previous run could be loaded silently
    # (or the load would fail if no file exists at all).
    if best_model_saved:
        best_state = torch.load(BEST_MODEL_PATH, map_location=DEVICE, weights_only=True)
        model.load_state_dict(best_state)
    else:
        print("No checkpoint was saved during this run; keeping the current model weights.")

else:
    # The split flag from the dynamic config is falsy: skip training entirely
    history_stage1 = None

print("=" * 80)

Starting Stage 1 Training
--------------------------------------------------------------------------------
Estimated time on M1: 30-60 minutes

Epoch 1/25
------------------------------------------------------------


Training: 100%|██████████| 25/25 [00:28<00:00,  1.14s/it, loss=0.8966, acc=38.47%]
Validation: 100%|██████████| 6/6 [00:23<00:00,  3.97s/it, loss=2.1977, acc=48.47%]



Train Loss: 1.5849 | Train Acc: 38.47%
Val Loss:   1.2901 | Val Acc:   48.47%
Best model saved! (Val Acc: 48.47%)
Epoch 2/25
------------------------------------------------------------


Training: 100%|██████████| 25/25 [00:27<00:00,  1.11s/it, loss=1.0217, acc=55.96%]
Validation: 100%|██████████| 6/6 [00:23<00:00,  3.89s/it, loss=2.1168, acc=50.31%]



Train Loss: 1.1535 | Train Acc: 55.96%
Val Loss:   1.2116 | Val Acc:   50.31%
Best model saved! (Val Acc: 50.31%)
Epoch 3/25
------------------------------------------------------------


Training: 100%|██████████| 25/25 [00:27<00:00,  1.11s/it, loss=1.2719, acc=59.59%]
Validation: 100%|██████████| 6/6 [00:23<00:00,  3.88s/it, loss=2.4240, acc=51.53%]



Train Loss: 1.0511 | Train Acc: 59.59%
Val Loss:   1.2189 | Val Acc:   51.53%
Best model saved! (Val Acc: 51.53%)
Epoch 4/25
------------------------------------------------------------


Training: 100%|██████████| 25/25 [00:27<00:00,  1.12s/it, loss=1.6481, acc=60.75%]
Validation: 100%|██████████| 6/6 [00:23<00:00,  3.90s/it, loss=2.5256, acc=50.92%]



Train Loss: 0.9989 | Train Acc: 60.75%
Val Loss:   1.1723 | Val Acc:   50.92%
Epoch 5/25
------------------------------------------------------------


Training: 100%|██████████| 25/25 [00:27<00:00,  1.10s/it, loss=2.3716, acc=65.54%]
Validation: 100%|██████████| 6/6 [00:23<00:00,  3.90s/it, loss=1.3650, acc=53.37%]



Train Loss: 0.9495 | Train Acc: 65.54%
Val Loss:   1.1528 | Val Acc:   53.37%
Best model saved! (Val Acc: 53.37%)
Epoch 6/25
------------------------------------------------------------


Training: 100%|██████████| 25/25 [00:27<00:00,  1.11s/it, loss=0.7461, acc=65.41%]
Validation: 100%|██████████| 6/6 [00:23<00:00,  3.89s/it, loss=1.7410, acc=56.44%]



Train Loss: 0.9164 | Train Acc: 65.41%
Val Loss:   1.1137 | Val Acc:   56.44%
Best model saved! (Val Acc: 56.44%)
Epoch 7/25
------------------------------------------------------------


Training: 100%|██████████| 25/25 [00:28<00:00,  1.14s/it, loss=1.9419, acc=67.75%]
Validation: 100%|██████████| 6/6 [00:23<00:00,  3.91s/it, loss=2.1232, acc=57.06%]



Train Loss: 0.9010 | Train Acc: 67.75%
Val Loss:   1.1073 | Val Acc:   57.06%
Best model saved! (Val Acc: 57.06%)
Epoch 8/25
------------------------------------------------------------


Training: 100%|██████████| 25/25 [00:27<00:00,  1.11s/it, loss=1.4759, acc=68.01%]
Validation: 100%|██████████| 6/6 [00:23<00:00,  3.92s/it, loss=2.0145, acc=55.83%]



Train Loss: 0.8845 | Train Acc: 68.01%
Val Loss:   1.1331 | Val Acc:   55.83%
Epoch 9/25
------------------------------------------------------------


Training: 100%|██████████| 25/25 [00:28<00:00,  1.12s/it, loss=2.0547, acc=65.67%]
Validation: 100%|██████████| 6/6 [00:23<00:00,  3.90s/it, loss=1.8932, acc=57.06%]



Train Loss: 0.8715 | Train Acc: 65.67%
Val Loss:   1.1153 | Val Acc:   57.06%
Epoch 10/25
------------------------------------------------------------


Training: 100%|██████████| 25/25 [00:28<00:00,  1.13s/it, loss=1.5913, acc=70.98%]
Validation: 100%|██████████| 6/6 [00:23<00:00,  3.94s/it, loss=1.3646, acc=59.51%]



Train Loss: 0.8393 | Train Acc: 70.98%
Val Loss:   1.1237 | Val Acc:   59.51%
Best model saved! (Val Acc: 59.51%)
Epoch 11/25
------------------------------------------------------------


Training: 100%|██████████| 25/25 [00:28<00:00,  1.13s/it, loss=1.0052, acc=70.85%]
Validation: 100%|██████████| 6/6 [00:23<00:00,  3.96s/it, loss=1.5138, acc=58.90%]



Train Loss: 0.8224 | Train Acc: 70.85%
Val Loss:   1.1501 | Val Acc:   58.90%
Epoch 12/25
------------------------------------------------------------


Training: 100%|██████████| 25/25 [00:28<00:00,  1.13s/it, loss=1.2525, acc=70.98%]
Validation: 100%|██████████| 6/6 [00:23<00:00,  3.88s/it, loss=1.6903, acc=57.06%]



Train Loss: 0.7310 | Train Acc: 70.98%
Val Loss:   1.2197 | Val Acc:   57.06%
Epoch 13/25
------------------------------------------------------------


Training: 100%|██████████| 25/25 [00:28<00:00,  1.12s/it, loss=0.8887, acc=68.26%]
Validation: 100%|██████████| 6/6 [00:23<00:00,  3.87s/it, loss=1.6152, acc=54.60%]



Train Loss: 0.7791 | Train Acc: 68.26%
Val Loss:   1.2065 | Val Acc:   54.60%
Epoch 14/25
------------------------------------------------------------


Training: 100%|██████████| 25/25 [00:28<00:00,  1.14s/it, loss=0.9720, acc=69.95%]
Validation: 100%|██████████| 6/6 [00:23<00:00,  3.87s/it, loss=1.4256, acc=58.28%]



Train Loss: 0.7740 | Train Acc: 69.95%
Val Loss:   1.1499 | Val Acc:   58.28%
Epoch 15/25
------------------------------------------------------------


Training: 100%|██████████| 25/25 [00:27<00:00,  1.12s/it, loss=1.1144, acc=72.28%]
Validation: 100%|██████████| 6/6 [00:24<00:00,  4.02s/it, loss=1.4179, acc=57.67%]



Train Loss: 0.7312 | Train Acc: 72.28%
Val Loss:   1.1310 | Val Acc:   57.67%
Epoch 16/25
------------------------------------------------------------


Training: 100%|██████████| 25/25 [00:28<00:00,  1.15s/it, loss=1.7703, acc=70.60%]
Validation: 100%|██████████| 6/6 [00:23<00:00,  3.94s/it, loss=1.7114, acc=58.90%]



Train Loss: 0.7802 | Train Acc: 70.60%
Val Loss:   1.1262 | Val Acc:   58.90%
Epoch 17/25
------------------------------------------------------------


Training: 100%|██████████| 25/25 [00:28<00:00,  1.13s/it, loss=1.3745, acc=74.09%]
Validation: 100%|██████████| 6/6 [00:23<00:00,  3.90s/it, loss=1.6769, acc=55.83%]



Train Loss: 0.7381 | Train Acc: 74.09%
Val Loss:   1.1937 | Val Acc:   55.83%
Epoch 18/25
------------------------------------------------------------


Training: 100%|██████████| 25/25 [00:28<00:00,  1.13s/it, loss=0.6238, acc=72.41%]
Validation: 100%|██████████| 6/6 [00:23<00:00,  3.90s/it, loss=1.7710, acc=57.67%]



Train Loss: 0.7039 | Train Acc: 72.41%
Val Loss:   1.1546 | Val Acc:   57.67%
Epoch 19/25
------------------------------------------------------------


Training: 100%|██████████| 25/25 [00:28<00:00,  1.13s/it, loss=0.6617, acc=71.76%]
Validation: 100%|██████████| 6/6 [00:23<00:00,  3.89s/it, loss=1.5408, acc=57.67%]



Train Loss: 0.7008 | Train Acc: 71.76%
Val Loss:   1.1522 | Val Acc:   57.67%
Epoch 20/25
------------------------------------------------------------


Training: 100%|██████████| 25/25 [00:28<00:00,  1.13s/it, loss=0.5674, acc=77.46%]
Validation: 100%|██████████| 6/6 [00:23<00:00,  3.89s/it, loss=1.2391, acc=57.67%]


Train Loss: 0.6483 | Train Acc: 77.46%
Val Loss:   1.1503 | Val Acc:   57.67%

Early stopping triggered (patience=10)

Stage 1 training complete!





# End of notebook 4 - Training stage 1