# Model Training

This notebook handles the model training process including:
- Model architecture definition
- Training loop implementation
- Hyperparameter configuration (batch size, epochs, learning-rate schedule)
- Model checkpointing
- Training visualization

In [None]:
# Import necessary libraries
import tensorflow as tf
from tensorflow.keras import layers, models, optimizers, callbacks
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
from pathlib import Path
import datetime
from tensorboard.plugins.hparams import api as hp

# Set random seeds (NumPy and TensorFlow global RNGs) for reproducibility
SEED = 42
tf.random.set_seed(SEED)
np.random.seed(SEED)

# GPU configuration: enable memory growth so TensorFlow allocates GPU memory
# as needed instead of reserving it all when the device is first used.
# `tf.config.list_physical_devices` is the stable API; the
# `tf.config.experimental.` alias previously used here is deprecated.
gpus = tf.config.list_physical_devices('GPU')
if gpus:
    try:
        for gpu in gpus:
            tf.config.experimental.set_memory_growth(gpu, True)
    except RuntimeError as e:
        # TensorFlow raises RuntimeError if the devices were already
        # initialized (e.g. this cell is re-run after GPUs have been used).
        # Memory growth can no longer be changed then, so report and continue.
        print(e)

## 1. Load Preprocessed Data

In [None]:
def _load_split(data_dir, split):
    """Load the ``X_<split>.npy`` / ``y_<split>.npy`` pair from *data_dir*."""
    features = np.load(data_dir / f'X_{split}.npy')
    labels = np.load(data_dir / f'y_{split}.npy')
    return features, labels


def _one_hot(indices, num_classes):
    """One-hot encode integer class *indices* into an (n, num_classes) float32 array."""
    return np.eye(num_classes, dtype=np.float32)[indices]


def load_preprocessed_data(data_dir='../processed_data'):
    """
    Load preprocessed satellite image data and one-hot encode its labels.

    Labels from all three splits are pooled into one sorted vocabulary, so
    every split is encoded against the same class indices and gets the same
    one-hot width, even when a split is missing some classes.

    Args:
        data_dir (str or Path): Directory containing ``X_train.npy``,
            ``y_train.npy``, ``X_val.npy``, ``y_val.npy``, ``X_test.npy``
            and ``y_test.npy``.

    Returns:
        tuple: ``(X_train, y_train_one_hot, X_val, y_val_one_hot,
        X_test, y_test_one_hot, unique_labels)``. Each one-hot array has
        shape ``(n_samples, len(unique_labels))`` and dtype float32.
        ``unique_labels`` is the sorted array of raw label values: column
        ``i`` of a one-hot row corresponds to ``unique_labels[i]``.
    """
    data_dir = Path(data_dir)

    X_train, y_train = _load_split(data_dir, 'train')
    X_val, y_val = _load_split(data_dir, 'val')
    X_test, y_test = _load_split(data_dir, 'test')

    # np.unique returns the labels sorted, so searchsorted maps each raw label
    # to its position in this vocabulary without a Python-level loop.
    unique_labels = np.unique(np.concatenate([y_train, y_val, y_test]))
    num_classes = len(unique_labels)

    # Fix the one-hot width explicitly. Inferring it per split (max index + 1)
    # yields a narrower matrix whenever a split lacks the highest class, which
    # then mismatches the model's output layer.
    y_train_one_hot = _one_hot(np.searchsorted(unique_labels, y_train), num_classes)
    y_val_one_hot = _one_hot(np.searchsorted(unique_labels, y_val), num_classes)
    y_test_one_hot = _one_hot(np.searchsorted(unique_labels, y_test), num_classes)

    return (
        X_train, y_train_one_hot,
        X_val, y_val_one_hot,
        X_test, y_test_one_hot,
        unique_labels
    )

# Pull in all three splits together with the ordered class vocabulary
(X_train, y_train,
 X_val, y_val,
 X_test, y_test,
 unique_labels) = load_preprocessed_data()

# Quick sanity report on what was loaded
print("Training data shape:", X_train.shape)
print("Number of classes:", len(unique_labels))
print("Classes:", unique_labels)

## 2. Define Model Architecture

In [None]:
def create_satellite_model(input_shape, num_classes, base_filters=32,
                           learning_rate=1e-4, conv_dropout=0.25,
                           dense_units=256, dense_dropout=0.5):
    """
    Create and compile a CNN for satellite image classification.

    Architecture: three Conv2D -> BatchNorm -> MaxPool -> Dropout blocks with
    ``base_filters``, ``2 * base_filters`` and ``4 * base_filters`` filters,
    then Flatten -> Dense(relu) -> BatchNorm -> Dropout -> Dense(softmax).

    Args:
        input_shape (tuple): Shape of a single input image (no batch axis).
        num_classes (int): Number of classification categories (softmax units).
        base_filters (int): Number of filters in the first conv block.
        learning_rate (float): Adam learning rate.
        conv_dropout (float): Dropout rate after each conv block.
        dense_units (int): Width of the hidden Dense layer.
        dense_dropout (float): Dropout rate after the hidden Dense layer.

    Returns:
        tf.keras.Model: Model compiled with Adam, categorical cross-entropy
        (expects one-hot targets) and an accuracy metric.
    """
    # Declare the input with an explicit Input layer; passing `input_shape=`
    # to the first layer of a Sequential model is deprecated in current Keras.
    model_layers = [layers.Input(shape=input_shape)]

    # Three convolutional blocks, doubling the filter count each time.
    for multiplier in (1, 2, 4):
        model_layers += [
            layers.Conv2D(base_filters * multiplier, (3, 3), activation='relu', padding='same'),
            layers.BatchNormalization(),
            layers.MaxPooling2D((2, 2)),
            layers.Dropout(conv_dropout),
        ]

    # Classifier head
    model_layers += [
        layers.Flatten(),
        layers.Dense(dense_units, activation='relu'),
        layers.BatchNormalization(),
        layers.Dropout(dense_dropout),
        layers.Dense(num_classes, activation='softmax'),
    ]

    model = models.Sequential(model_layers)

    model.compile(
        optimizer=optimizers.Adam(learning_rate=learning_rate),
        loss='categorical_crossentropy',
        metrics=['accuracy']
    )

    return model

# Size the network from the loaded data: per-image shape and class count
model = create_satellite_model(
    num_classes=len(unique_labels),
    input_shape=X_train.shape[1:],
)

# Print the layer-by-layer architecture and parameter counts
model.summary()

## 3. Training Configuration

In [None]:
# Create a timestamped TensorBoard log directory (one sub-directory per run)
log_dir = str(Path("logs") / "fit" / datetime.datetime.now().strftime("%Y%m%d-%H%M%S"))

# Single definition of the "best epoch". EarlyStopping(restore_best_weights=True)
# can put best-epoch weights back into `model`, and ModelCheckpoint writes
# best-epoch weights to disk. Both must rank epochs by the same metric.
# Otherwise best_model.keras and the in-memory model (which is what gets
# saved as the final artifact) can hold weights from different epochs.
MONITOR_METRIC = 'val_loss'
MONITOR_MODE = 'min'

# Define callbacks
tensorboard_callback = callbacks.TensorBoard(
    log_dir=log_dir,
    histogram_freq=1,   # log weight histograms every epoch
    profile_batch=0     # profiler disabled
)

early_stopping = callbacks.EarlyStopping(
    monitor=MONITOR_METRIC,
    mode=MONITOR_MODE,
    patience=10,
    restore_best_weights=True
)

model_checkpoint = callbacks.ModelCheckpoint(
    filepath='best_model.keras',
    monitor=MONITOR_METRIC,
    save_best_only=True,
    mode=MONITOR_MODE
)

# Halve the learning rate after 5 epochs without improvement (before
# early stopping's 10-epoch patience runs out), never going below 1e-6.
lr_reducer = callbacks.ReduceLROnPlateau(
    monitor=MONITOR_METRIC,
    mode=MONITOR_MODE,
    factor=0.5,
    patience=5,
    min_lr=1e-6
)

# Combine callbacks
callbacks_list = [
    tensorboard_callback,
    early_stopping,
    model_checkpoint,
    lr_reducer
]

## 4. Training Loop

In [None]:
# Fixed hyperparameters for this run
BATCH_SIZE, EPOCHS = 32, 50

# Fit on the training split; validation metrics are computed every epoch
# and drive the callbacks assembled in callbacks_list.
history = model.fit(
    x=X_train,
    y=y_train,
    batch_size=BATCH_SIZE,
    epochs=EPOCHS,
    verbose=1,
    callbacks=callbacks_list,
    validation_data=(X_val, y_val),
)

# Score whatever weights training left in `model` on the test split
test_metrics = model.evaluate(X_test, y_test, verbose=0)
test_loss, test_accuracy = test_metrics
print(f"\nTest Accuracy: {test_accuracy*100:.2f}%")

## 5. Training Visualization

In [None]:
def plot_training_metrics(history):
    """
    Draw training vs. validation curves for accuracy and loss, side by side.

    Args:
        history (keras.callbacks.History): Object returned by ``model.fit``;
            its ``history`` dict must contain the keys ``accuracy``,
            ``val_accuracy``, ``loss`` and ``val_loss``.
    """
    curves = history.history
    fig, axes = plt.subplots(1, 2, figsize=(12, 4))

    # (train key, val key, train label, val label, title, y-axis label)
    panels = (
        ('accuracy', 'val_accuracy', 'Training Accuracy', 'Validation Accuracy',
         'Model Accuracy', 'Accuracy'),
        ('loss', 'val_loss', 'Training Loss', 'Validation Loss',
         'Model Loss', 'Loss'),
    )

    for ax, (train_key, val_key, train_label, val_label, title, y_label) in zip(axes, panels):
        ax.plot(curves[train_key], label=train_label)
        ax.plot(curves[val_key], label=val_label)
        ax.set_title(title)
        ax.set_xlabel('Epoch')
        ax.set_ylabel(y_label)
        ax.legend()

    fig.tight_layout()
    plt.show()

# Render accuracy and loss curves for the run above
plot_training_metrics(history=history)

## 6. Save Trained Model

In [None]:
def save_model_artifacts(model, unique_labels, output_dir='../models'):
    """
    Save the trained model, its weights, and the class-label list.

    Writes three files into *output_dir*, creating it (and any parents) if
    missing:

    - ``satellite_classifier.keras``: full model, via ``model.save``
    - ``model_weights.weights.h5``: weights only, via ``model.save_weights``
    - ``class_labels.csv``: one label per row, in the order given

    Args:
        model (tf.keras.Model): Trained model.
        unique_labels (array-like): Class labels in output-index order, e.g.
            the ``unique_labels`` returned by ``load_preprocessed_data``.
        output_dir (str or Path): Destination directory. Defaults to
            ``'../models'`` (resolved relative to the current working
            directory), which was previously hard-coded.
    """
    output_dir = Path(output_dir)
    output_dir.mkdir(parents=True, exist_ok=True)

    # Save full model
    model.save(output_dir / 'satellite_classifier.keras')

    # Save model weights (Keras 3 requires the `.weights.h5` suffix)
    model.save_weights(output_dir / 'model_weights.weights.h5')

    # Save class labels so predictions can be mapped back to names
    pd.Series(unique_labels).to_csv(output_dir / 'class_labels.csv', index=False)

    print("Model artifacts saved successfully.")

# Persist the trained network, its weights and the label vocabulary
save_model_artifacts(model=model, unique_labels=unique_labels)