# Cat Breed Classification - Improved Training Pipeline

This notebook implements a complete training pipeline with:
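- ✅ GlobalAveragePooling2D instead of Flatten (far fewer parameters in the classification head)
- ✅ Correct `steps_per_epoch` calculation (ceiling division, so the last partial batch is not dropped)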
- ✅ Two-stage training (Feature Extraction + Fine-tuning)
- ✅ Comprehensive evaluation with confusion matrix
- ✅ Learning curve visualization
- ✅ Centralized configuration

## 1. Setup and Imports

In [None]:
# Google Colab setup (uncomment if running on Colab)
# from google.colab import drive
# drive.mount('/content/drive')
# %cd /content/drive/MyDrive/cat-classification

In [None]:
import os
import sys
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
from pathlib import Path
import json
from datetime import datetime

# TensorFlow imports
import tensorflow as tf
from tensorflow.keras.preprocessing.image import ImageDataGenerator
from tensorflow.keras.applications import ResNet50V2
from tensorflow.keras.applications.resnet_v2 import preprocess_input
from tensorflow.keras.models import Model, load_model
from tensorflow.keras.layers import (
    Dense, GlobalAveragePooling2D, Dropout, Input
)
from tensorflow.keras.optimizers import Adam
from tensorflow.keras.callbacks import (
    ModelCheckpoint, EarlyStopping, ReduceLROnPlateau, TensorBoard
)
from sklearn.metrics import (
    classification_report, confusion_matrix, accuracy_score, top_k_accuracy_score
)

# Import configuration
sys.path.append('..')
import config

# Record library versions so results can be tied to a specific environment.
for _lib_label, _lib_version in (("TensorFlow", tf.__version__), ("Python", sys.version)):
    print(f"{_lib_label} version: {_lib_version}")

## 2. GPU Configuration

In [None]:
# Configure GPU
gpus = tf.config.list_physical_devices('GPU')
if gpus:
    try:
        # Enable memory growth
        for gpu in gpus:
            tf.config.experimental.set_memory_growth(gpu, True)
        print(f"✓ GPU available: {gpus}")
    except RuntimeError as e:
        print(f"⚠ GPU configuration error: {e}")
else:
    print("⚠ No GPU found. Training will run on CPU.")

# Enable mixed precision for faster training
if config.MIXED_PRECISION:
    from tensorflow.keras import mixed_precision
    policy = mixed_precision.Policy('mixed_float16')
    mixed_precision.set_global_policy(policy)
    print("✓ Mixed precision enabled")

## 3. Load Configuration and Setup Paths

In [None]:
# Create necessary directories
config.create_directories()

# Print configuration
config.print_config()

# Seed Python's `random`, NumPy and TensorFlow in one call so weight
# initialisation, dropout masks and NumPy-based augmentation draws repeat
# across runs. (Some GPU kernels can remain non-deterministic unless op
# determinism is also enabled.)
tf.keras.utils.set_random_seed(config.RANDOM_SEED)

# Get number of classes
try:
    NUM_CLASSES = config.get_num_classes()
except Exception as e:
    # Chain the original error so its traceback is preserved for debugging.
    raise ValueError(f"Cannot determine number of classes: {e}") from e

# A softmax classifier over fewer than two breeds is meaningless; fail here
# rather than deep inside model construction or training.
if NUM_CLASSES < 2:
    raise ValueError(f"Need at least 2 cat breeds to train a classifier, found {NUM_CLASSES}")
print(f"\n✓ Found {NUM_CLASSES} cat breeds")

## 4. Data Generators with Augmentation

In [None]:
# Training data generator with augmentation
train_datagen = ImageDataGenerator(
    preprocessing_function=preprocess_input,
    **config.AUGMENTATION_CONFIG
)

# Validation and test data generators (no augmentation)
val_test_datagen = ImageDataGenerator(
    preprocessing_function=preprocess_input
)

# Training iterator: shuffled, with a fixed seed for reproducible batch order.
train_generator = train_datagen.flow_from_directory(
    str(config.TRAIN_DIR),
    target_size=config.IMG_SIZE,
    batch_size=config.BATCH_SIZE,
    class_mode='categorical',
    shuffle=True,
    seed=config.RANDOM_SEED
)

class_indices = train_generator.class_indices
# Breed names ordered by their training label index.
train_class_names = sorted(class_indices, key=class_indices.get)

# Evaluation iterators reuse the training class list. If left to infer classes
# from their own sub-folders, a split that is missing a breed folder would get
# a different (shifted) name->index mapping than the model was trained on.
# shuffle=False keeps `.classes` in the same order as the predictions.
validation_generator = val_test_datagen.flow_from_directory(
    str(config.VAL_DIR),
    target_size=config.IMG_SIZE,
    batch_size=config.BATCH_SIZE,
    class_mode='categorical',
    classes=train_class_names,
    shuffle=False
)

test_generator = val_test_datagen.flow_from_directory(
    str(config.TEST_DIR),
    target_size=config.IMG_SIZE,
    batch_size=config.BATCH_SIZE,
    class_mode='categorical',
    classes=train_class_names,
    shuffle=False
)

# The model's output layer is sized from NUM_CLASSES while the one-hot labels
# are sized from the training folders; fail fast with a clear message instead
# of a shape-mismatch error inside model.fit().
if len(class_indices) != NUM_CLASSES:
    raise ValueError(
        f"config.get_num_classes() reported {NUM_CLASSES} classes but "
        f"{config.TRAIN_DIR} contains {len(class_indices)} class folders"
    )

# Save class indices for later inference
with open(config.MODELS_DIR / 'class_indices.json', 'w', encoding='utf-8') as f:
    json.dump(class_indices, f, indent=4)

print(f"\n✓ Data generators created successfully")
print(f"  Training samples: {train_generator.samples}")
print(f"  Validation samples: {validation_generator.samples}")
print(f"  Test samples: {test_generator.samples}")
print(f"  Number of classes: {len(class_indices)}")

## 5. Build Model with GlobalAveragePooling2D

In [None]:
def build_model(num_classes, trainable=False):
    """
    Build the cat breed classification model using transfer learning.

    Architecture: ResNet50V2 backbone (ImageNet weights, no top) ->
    GlobalAveragePooling2D -> Dense(config.DENSE_UNITS, relu) ->
    Dropout(config.DROPOUT_RATE) -> Dense(num_classes, softmax).

    Args:
        num_classes: Number of cat breeds to classify (size of the output layer).
        trainable: Whether base model layers are trainable.

    Returns:
        Uncompiled Keras ``Model``; call ``model.compile(...)`` before training.
    """
    # Keras image tensors are (height, width, channels). Derive the shape from
    # config.IMG_SIZE -- the exact value the data generators use as
    # `target_size` -- so the model input always matches the batches it is fed,
    # including for non-square images.
    input_shape = tuple(config.IMG_SIZE) + (config.IMG_CHANNELS,)

    # Load base model
    base_model = ResNet50V2(
        weights='imagenet',
        include_top=False,
        input_shape=input_shape
    )

    # Freeze/unfreeze base model
    base_model.trainable = trainable

    # Build model using Functional API
    inputs = Input(shape=input_shape)
    # training=False keeps the backbone's BatchNormalization layers in
    # inference mode, even after the backbone is unfrozen for fine-tuning.
    x = base_model(inputs, training=False)

    # GlobalAveragePooling2D instead of Flatten: far fewer head parameters.
    x = GlobalAveragePooling2D()(x)

    # Custom top layers
    x = Dense(config.DENSE_UNITS, activation='relu')(x)
    x = Dropout(config.DROPOUT_RATE)(x)

    # Output layer kept in float32 so softmax is numerically stable under mixed precision.
    outputs = Dense(num_classes, activation='softmax', dtype='float32')(x)

    return Model(inputs, outputs)

# Build the Stage 1 model with the ResNet50V2 backbone frozen.
model = build_model(NUM_CLASSES, trainable=False)

separator = "=" * 80
print(f"\n{separator}\nMODEL ARCHITECTURE\n{separator}")
model.summary()

# Trainable count = element count of every trainable variable; the rest is frozen.
total_params = model.count_params()
trainable_params = sum(tf.size(weight).numpy() for weight in model.trainable_weights)
non_trainable_params = total_params - trainable_params

print(
    f"\nTotal parameters: {total_params:,}\n"
    f"Trainable parameters: {trainable_params:,}\n"
    f"Non-trainable parameters: {non_trainable_params:,}"
)

## 6. Compile Model - Stage 1 (Feature Extraction)

In [None]:
# Stage 1 optimisation: only the new classification head receives gradient updates.
stage1_metrics = [
    'accuracy',
    tf.keras.metrics.TopKCategoricalAccuracy(k=5, name='top5_accuracy'),
]
model.compile(
    optimizer=Adam(learning_rate=config.LEARNING_RATE_STAGE1),
    loss='categorical_crossentropy',
    metrics=stage1_metrics,
)

print("✓ Model compiled for Stage 1 (Feature Extraction)")
print(f"  Learning rate: {config.LEARNING_RATE_STAGE1}")

## 7. Setup Callbacks

In [None]:
def get_callbacks(stage="stage1"):
    """
    Assemble the Keras callbacks for one training stage.

    All monitoring settings come from ``config``. ``stage`` only changes file and
    directory names, so each stage gets its own best-model checkpoint and its
    own TensorBoard run directory.

    Args:
        stage: Training stage name (for file naming).

    Returns:
        List of callbacks: ModelCheckpoint, EarlyStopping, ReduceLROnPlateau,
        TensorBoard, in that order.
    """
    best_model_file = config.MODELS_DIR / f'cat_classifier_{stage}_best.keras'

    return [
        # Full model (architecture + weights), not weights only.
        ModelCheckpoint(
            filepath=str(best_model_file),
            monitor=config.CHECKPOINT_MONITOR,
            mode=config.CHECKPOINT_MODE,
            save_best_only=config.CHECKPOINT_SAVE_BEST_ONLY,
            save_weights_only=False,
            verbose=1,
        ),
        EarlyStopping(
            monitor=config.EARLY_STOPPING_MONITOR,
            patience=config.EARLY_STOPPING_PATIENCE,
            restore_best_weights=config.EARLY_STOPPING_RESTORE_BEST,
            verbose=1,
        ),
        ReduceLROnPlateau(
            monitor=config.REDUCE_LR_MONITOR,
            factor=config.REDUCE_LR_FACTOR,
            patience=config.REDUCE_LR_PATIENCE,
            min_lr=config.REDUCE_LR_MIN_LR,
            verbose=1,
        ),
        # Timestamped run directory so repeated runs of a stage keep separate logs.
        TensorBoard(
            log_dir=str(config.LOGS_DIR / f"{stage}_{datetime.now().strftime('%Y%m%d-%H%M%S')}"),
            histogram_freq=1,
            write_graph=True,
        ),
    ]

# Stage 1 callbacks write to the "stage1" checkpoint and TensorBoard locations.
callbacks_stage1 = get_callbacks(stage="stage1")
print("✓ Callbacks configured for Stage 1")

## 8. Train Stage 1 - Feature Extraction (Fixed steps_per_epoch)

In [None]:
# Round the batch count up so the final partial batch is still visited each
# epoch (floor division would silently skip those images).
steps_per_epoch, validation_steps = (
    int(np.ceil(generator.samples / config.BATCH_SIZE))
    for generator in (train_generator, validation_generator)
)

stage1_header = [
    "\n" + "=" * 80,
    "STAGE 1: FEATURE EXTRACTION (Base Model Frozen)",
    "=" * 80,
    f"Steps per epoch: {steps_per_epoch}",
    f"Validation steps: {validation_steps}",
    f"Epochs: {config.EPOCHS_STAGE1}",
    f"Learning rate: {config.LEARNING_RATE_STAGE1}",
    "\nStarting training...\n",
]
print("\n".join(stage1_header))

# Train the classification head on top of the frozen backbone.
history_stage1 = model.fit(
    train_generator,
    validation_data=validation_generator,
    epochs=config.EPOCHS_STAGE1,
    steps_per_epoch=steps_per_epoch,
    validation_steps=validation_steps,
    callbacks=callbacks_stage1,
    verbose=1,
)

print("\n✓ Stage 1 training completed!")

# Persist per-epoch metrics so they can be re-plotted without retraining.
history_csv_stage1 = config.OUTPUTS_DIR / 'training_history_stage1.csv'
history_df_stage1 = pd.DataFrame(history_stage1.history)
history_df_stage1.to_csv(history_csv_stage1, index=False)
print(f"✓ Training history saved to {history_csv_stage1}")

## 9. Visualize Stage 1 Training

In [None]:
def plot_training_history(history, stage="stage1", save_dir=None):
    """
    Plot training/validation curves for one stage and save the figure as PNG.

    Panels: loss, accuracy, top-5 accuracy (if recorded) and learning rate
    (if recorded). Panels whose metric is absent are hidden, and validation
    curves are drawn only when the matching ``val_*`` key exists.

    Args:
        history: ``History`` object from ``model.fit`` (anything with a
            ``.history`` mapping), a dict of per-epoch metric lists, or a
            DataFrame such as one read back from ``training_history_<stage>.csv``.
        stage: Training stage name, used in the title and output filename.
        save_dir: Directory for the PNG; defaults to ``config.PLOTS_DIR``.
    """
    if hasattr(history, 'history'):
        history = history.history

    out_dir = config.PLOTS_DIR if save_dir is None else Path(save_dir)
    out_path = out_dir / f'training_history_{stage}.png'

    def _plot_metric(ax, key, train_label, val_label, title, ylabel):
        # Draw the train curve and, when logged, the matching validation curve.
        ax.plot(history[key], label=train_label, linewidth=2)
        if f'val_{key}' in history:
            ax.plot(history[f'val_{key}'], label=val_label, linewidth=2)
        ax.set_title(title, fontsize=14, fontweight='bold')
        ax.set_xlabel('Epoch')
        ax.set_ylabel(ylabel)
        ax.legend()
        ax.grid(True, alpha=0.3)

    fig, axes = plt.subplots(2, 2, figsize=(15, 10))
    fig.suptitle(f'Training Metrics - {stage.upper()}', fontsize=16, fontweight='bold')

    _plot_metric(axes[0, 0], 'loss', 'Train Loss', 'Val Loss', 'Model Loss', 'Loss')
    _plot_metric(axes[0, 1], 'accuracy', 'Train Accuracy', 'Val Accuracy',
                 'Model Accuracy', 'Accuracy')

    if 'top5_accuracy' in history:
        _plot_metric(axes[1, 0], 'top5_accuracy', 'Train Top-5 Acc', 'Val Top-5 Acc',
                     'Top-5 Accuracy', 'Top-5 Accuracy')
    else:
        axes[1, 0].axis('off')

    # ReduceLROnPlateau logs the learning rate as 'lr' in older tf.keras and as
    # 'learning_rate' in Keras 3; accept either.
    lr_key = next((k for k in ('lr', 'learning_rate') if k in history), None)
    if lr_key is not None:
        axes[1, 1].plot(history[lr_key], label='Learning Rate', linewidth=2, color='red')
        axes[1, 1].set_title('Learning Rate', fontsize=14, fontweight='bold')
        axes[1, 1].set_xlabel('Epoch')
        axes[1, 1].set_ylabel('LR')
        axes[1, 1].set_yscale('log')
        axes[1, 1].legend()
        axes[1, 1].grid(True, alpha=0.3)
    else:
        axes[1, 1].axis('off')

    fig.tight_layout()
    fig.savefig(out_path, dpi=300, bbox_inches='tight')
    plt.show()
    # Release the figure explicitly so repeated calls don't accumulate open figures.
    plt.close(fig)

    print(f"\n✓ Plot saved to {out_path}")

# Learning curves for the feature-extraction stage.
plot_training_history(history_stage1, stage="stage1")

## 10. Stage 2 - Fine-tuning (Unfreeze top layers)

In [None]:
print("\n" + "="*80)
print("STAGE 2: FINE-TUNING (Unfreezing top layers)")
print("="*80)

# Locate the ResNet50V2 backbone by type rather than by position, so the
# lookup keeps working if layers are ever added in front of it.
base_model = next(
    (layer for layer in model.layers if isinstance(layer, tf.keras.Model)),
    None,
)
if base_model is None:
    raise RuntimeError("Could not find the pretrained backbone inside `model`.")

# Unfreeze the whole backbone, then re-freeze all but the last N layers.
base_model.trainable = True

# Compute the split index explicitly: the slice `layers[:-N]` is empty when
# N == 0, which would silently leave the *entire* backbone trainable.
num_base_layers = len(base_model.layers)
num_to_unfreeze = min(max(int(config.UNFREEZE_LAYERS), 0), num_base_layers)
for layer in base_model.layers[:num_base_layers - num_to_unfreeze]:
    layer.trainable = False

print(f"Base model has {num_base_layers} layers")
print(f"Unfreezing last {num_to_unfreeze} layers")

# Count trainable parameters (these values are reused in the final summary).
trainable_params = sum([tf.size(w).numpy() for w in model.trainable_weights])
total_params = model.count_params()
print(f"\nTrainable parameters: {trainable_params:,} / {total_params:,} ({100*trainable_params/total_params:.2f}%)")

# Recompile so the trainable-flag changes take effect, with a lower learning rate.
model.compile(
    optimizer=Adam(learning_rate=config.LEARNING_RATE_STAGE2),
    loss='categorical_crossentropy',
    metrics=['accuracy', tf.keras.metrics.TopKCategoricalAccuracy(k=5, name='top5_accuracy')]
)

# Report the actual ratio instead of assuming a fixed 10x reduction.
lr_ratio = config.LEARNING_RATE_STAGE1 / config.LEARNING_RATE_STAGE2
print(f"\n✓ Model recompiled for fine-tuning")
print(f"  Learning rate: {config.LEARNING_RATE_STAGE2} ({lr_ratio:g}x lower than Stage 1)")

In [None]:
# Stage 2 reuses the Stage 1 step counts; only the callbacks (stage2
# checkpoint file and TensorBoard run) and the epoch budget differ.
callbacks_stage2 = get_callbacks(stage="stage2")

print("\nStarting Stage 2 training...\n")

history_stage2 = model.fit(
    train_generator,
    validation_data=validation_generator,
    epochs=config.EPOCHS_STAGE2,
    steps_per_epoch=steps_per_epoch,
    validation_steps=validation_steps,
    callbacks=callbacks_stage2,
    verbose=1,
)

print("\n✓ Stage 2 training completed!")

# Persist per-epoch metrics so they can be re-plotted without retraining.
history_csv_stage2 = config.OUTPUTS_DIR / 'training_history_stage2.csv'
history_df_stage2 = pd.DataFrame(history_stage2.history)
history_df_stage2.to_csv(history_csv_stage2, index=False)
print(f"✓ Training history saved to {history_csv_stage2}")

In [None]:
# Learning curves for the fine-tuning stage.
plot_training_history(history_stage2, stage="stage2")

## 11. Test Set Evaluation

In [None]:
print("\n" + "="*80)
print("TEST SET EVALUATION")
print("="*80)

# Rewind the iterator. test_generator was created with shuffle=False, so
# prediction rows come out in the same order as test_generator.classes.
test_generator.reset()

print("\nGenerating predictions on test set...")
test_steps = int(np.ceil(test_generator.samples / config.BATCH_SIZE))
predictions = model.predict(test_generator, steps=test_steps, verbose=1)
predicted_classes = np.argmax(predictions, axis=1)

# Get true labels
true_classes = test_generator.classes
# Guard the alignment assumption: exactly one prediction row per test image.
assert len(predictions) == len(true_classes), (
    f"Got {len(predictions)} predictions for {len(true_classes)} test images"
)

# Calculate metrics
test_accuracy = accuracy_score(true_classes, predicted_classes)
print(f"\n✓ Test Accuracy: {test_accuracy:.4f} ({test_accuracy*100:.2f}%)")

# Top-5 accuracy. Pass the full label set explicitly: without it sklearn infers
# the classes from y_true and raises if any breed is absent from the test split.
top5_acc = top_k_accuracy_score(
    true_classes, predictions, k=5, labels=np.arange(predictions.shape[1])
)
print(f"✓ Test Top-5 Accuracy: {top5_acc:.4f} ({top5_acc*100:.2f}%)")

## 12. Classification Report

In [None]:
# Class names ordered by label index, so position i is the breed for output unit i.
class_names = sorted(train_generator.class_indices, key=train_generator.class_indices.get)

# Generate classification report
print("\n" + "="*80)
print("CLASSIFICATION REPORT")
print("="*80)
# `labels` pins the report to every trained class. Otherwise, a breed missing
# from both the test labels and the predictions makes target_names the wrong length.
report = classification_report(
    true_classes,
    predicted_classes,
    labels=np.arange(len(class_names)),
    target_names=class_names,
)
print(report)

# Save report (UTF-8 so non-ASCII breed names are written correctly on any platform).
report_path = config.REPORTS_DIR / 'classification_report.txt'
with open(report_path, 'w', encoding='utf-8') as f:
    f.write(report)
print(f"\n✓ Classification report saved to {report_path}")

## 13. Confusion Matrix

In [None]:
# Pin the label set so rows/columns map 1:1 onto class_names even when a breed
# never appears in the test labels or predictions (sklearn would otherwise drop
# it, and the matrix would no longer line up with class_names).
cm = confusion_matrix(true_classes, predicted_classes, labels=np.arange(len(class_names)))

# Plot confusion matrix
fig, ax = plt.subplots(figsize=(20, 18))
sns.heatmap(cm, annot=False, fmt='d', cmap='Blues',
            xticklabels=class_names, yticklabels=class_names,
            cbar_kws={'label': 'Number of Predictions'}, ax=ax)
ax.set_title('Confusion Matrix - Cat Breed Classification', fontsize=16, fontweight='bold', pad=20)
ax.set_ylabel('True Label', fontsize=12)
ax.set_xlabel('Predicted Label', fontsize=12)
ax.tick_params(axis='x', labelrotation=90)
ax.tick_params(axis='y', labelrotation=0)
fig.tight_layout()
cm_path = config.PLOTS_DIR / 'confusion_matrix.png'
fig.savefig(cm_path, dpi=300, bbox_inches='tight')
plt.show()
plt.close(fig)

print(f"\n✓ Confusion matrix saved to {cm_path}")

# Per-class accuracy (recall). Breeds with no test images get NaN explicitly
# instead of triggering a divide-by-zero warning.
correct = cm.diagonal()
support = cm.sum(axis=1)
per_class_accuracy = np.divide(
    correct, support, out=np.full(len(support), np.nan), where=support > 0
)
accuracy_df = pd.DataFrame({
    'Breed': class_names,
    'Accuracy': per_class_accuracy,
    'Correct': correct,
    'Total': support
})
accuracy_df = accuracy_df.sort_values('Accuracy', ascending=False)

print("\nTop 10 Best Performing Breeds:")
print(accuracy_df.head(10).to_string(index=False))

print("\nTop 10 Worst Performing Breeds:")
print(accuracy_df.tail(10).to_string(index=False))

# Save accuracy dataframe
per_class_path = config.REPORTS_DIR / 'per_class_accuracy.csv'
accuracy_df.to_csv(per_class_path, index=False)
print(f"\n✓ Per-class accuracy saved to {per_class_path}")

## 14. Save Final Model

In [None]:
# Save final model
final_model_path = config.MODELS_DIR / 'cat_breed_classifier_final.keras'
model.save(str(final_model_path))
print(f"\n✓ Final model saved to {final_model_path}")

# Save model summary.
# - encoding='utf-8': recent Keras versions draw the summary table with
#   box-drawing characters that platform-default encodings (e.g. cp1252 on
#   Windows) cannot encode.
# - **kwargs: some Keras versions call print_fn with extra keyword arguments
#   (such as `line_break`), which a one-argument lambda would reject.
summary_txt_path = config.REPORTS_DIR / 'model_summary.txt'
with open(summary_txt_path, 'w', encoding='utf-8') as f:
    model.summary(print_fn=lambda line, **kwargs: f.write(line + '\n'))
print(f"✓ Model summary saved to {summary_txt_path}")

# Create training summary. The parameter counts are the values last computed
# in the fine-tuning (Stage 2) setup cell.
summary = {
    'timestamp': datetime.now().isoformat(),
    'model_architecture': config.BASE_MODEL_NAME,
    'num_classes': NUM_CLASSES,
    'total_parameters': int(total_params),
    'trainable_parameters': int(trainable_params),
    'stage1': {
        'epochs': len(history_stage1.history['loss']),
        'best_val_accuracy': float(max(history_stage1.history['val_accuracy'])),
        'best_val_loss': float(min(history_stage1.history['val_loss']))
    },
    'stage2': {
        'epochs': len(history_stage2.history['loss']),
        'best_val_accuracy': float(max(history_stage2.history['val_accuracy'])),
        'best_val_loss': float(min(history_stage2.history['val_loss']))
    },
    'test_metrics': {
        'accuracy': float(test_accuracy),
        'top5_accuracy': float(top5_acc)
    }
}

summary_json_path = config.REPORTS_DIR / 'training_summary.json'
with open(summary_json_path, 'w', encoding='utf-8') as f:
    json.dump(summary, f, indent=4)

print(f"✓ Training summary saved to {summary_json_path}")
print("\n" + "="*80)
print("TRAINING COMPLETE!")
print("="*80)
print(json.dumps(summary, indent=2))