# InfraOwl Model Training

This notebook handles the complete model training pipeline for InfraOwl infrastructure detection.

## Training Pipeline
- Model architecture setup (EfficientNet-Lite/MobileNet)
- Transfer learning configuration
- Training with callbacks and monitoring
- Model evaluation and visualization
- TensorFlow Lite conversion for mobile deployment

In [None]:
import sys
sys.path.append('../scripts')

import os
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
from pathlib import Path
import yaml

import tensorflow as tf
from tensorflow import keras
from tensorflow.keras.preprocessing.image import ImageDataGenerator

# Import our training modules
from train_model import InfraOwlTrainer
from convert_to_tflite import TFLiteConverter

print("🎯 InfraOwl Model Training Notebook")
print("==================================")
print(f"TensorFlow version: {tf.__version__}")

# Check GPU availability.
# Uses the stable `tf.config.list_physical_devices` entry point rather than
# its `tf.config.experimental.` alias, which only exists for backwards
# compatibility.
gpus = tf.config.list_physical_devices('GPU')
if gpus:
    print(f"🚀 GPU available: {len(gpus)} device(s)")
    for gpu in gpus:
        print(f"  {gpu}")
else:
    # An empty device list means TensorFlow found no usable GPU.
    print("💻 Running on CPU")

## 1. Training Configuration

In [None]:
# Load the training configuration and echo the settings used by the rest of
# the notebook.
#
# The encoding is pinned to UTF-8 so that parsing does not depend on the
# platform's default locale encoding (which is not UTF-8 on every system).
with open('../configs/training_config.yaml', 'r', encoding='utf-8') as f:
    # safe_load only builds plain Python objects (dicts, lists, scalars).
    config = yaml.safe_load(f)

print("📋 Training Configuration:")
print(f"  Model Architecture: {config['model']['architecture']}")
print(f"  Input Size: {config['model']['input_size']}")
print(f"  Number of Classes: {config['model']['num_classes']}")
print(f"  Epochs: {config['training']['epochs']}")
print(f"  Batch Size: {config['training']['batch_size']}")
print(f"  Learning Rate: {config['training']['learning_rate']}")
print(f"  Optimizer: {config['training']['optimizer']}")
print(f"  Data Augmentation: {config['data']['augmentation']['enabled']}")

print(f"\n🏷️  Classes: {config['classes']}")

## 2. Data Loading and Inspection

In [None]:
# Check processed data availability
processed_data_path = Path('../data/processed')

# File suffixes counted as images. They are compared case-insensitively so
# that files such as "IMG_0001.JPG" or "tower.png" are not silently left out
# of the statistics.
# NOTE(review): adjust to whatever formats the preprocessing step writes.
IMAGE_EXTENSIONS = ('.jpg', '.jpeg', '.png', '.bmp')


def count_images(directory, extensions=IMAGE_EXTENSIONS):
    """Return the number of image files directly inside *directory*.

    Only regular files whose suffix, lower-cased, appears in *extensions*
    are counted. Sub-directories are not descended into.
    """
    return sum(
        1
        for entry in Path(directory).iterdir()
        if entry.is_file() and entry.suffix.lower() in extensions
    )


if not processed_data_path.exists():
    print("❌ Processed data not found!")
    print("Please run the data preprocessing notebook first.")
    print("File: 02_data_preprocessing.ipynb")
else:
    print("✅ Processed data found")

    # Display data statistics
    for split in ['train', 'validation', 'test']:
        split_path = processed_data_path / split
        if not split_path.exists():
            # A missing split is simply not reported.
            continue

        print(f"\n📊 {split.capitalize()} Data:")
        total_images = 0

        # Sorted so the per-class listing is stable across runs and platforms
        # (iterdir() yields entries in arbitrary order).
        class_dirs = sorted(p for p in split_path.iterdir() if p.is_dir())
        for class_dir in class_dirs:
            class_count = count_images(class_dir)
            total_images += class_count
            print(f"  {class_dir.name}: {class_count} images")

        print(f"  Total: {total_images} images")

In [None]:
# Build the trainer from the shared config and report what its data
# generators expose. Skipped entirely when the processed dataset is missing.
if not processed_data_path.exists():
    print("⏭️  Skipping trainer initialization (no processed data)")
else:
    print("🔄 Initializing trainer and data generators...")

    trainer = InfraOwlTrainer('../configs/training_config.yaml')

    print("✅ Data generators created successfully!")
    # Sample counts per split, as reported by each generator.
    print(f"📊 Training samples: {trainer.train_generator.samples}")
    print(f"📊 Validation samples: {trainer.val_generator.samples}")
    print(f"📊 Test samples: {trainer.test_generator.samples}")
    # Iterating the class_indices mapping yields the class names (its keys).
    print(f"📊 Classes: {list(trainer.train_generator.class_indices)}")

## 3. Model Architecture

In [None]:
# Build and compile the network through the trainer, then report its size
# and try to render an architecture diagram.
if not processed_data_path.exists():
    print("⏭️  Skipping model creation (no processed data)")
else:
    print("🏗️  Creating model architecture...")

    model = trainer.create_model()
    model = trainer.compile_model(model)

    print("\n📋 Model Summary:")
    model.summary()

    print("\n📊 Model Structure:")
    print(f"Total parameters: {model.count_params():,}")

    # Split the parameter count into trainable and non-trainable weights.
    trainable_params = sum(
        map(tf.keras.backend.count_params, model.trainable_weights)
    )
    non_trainable_params = sum(
        map(tf.keras.backend.count_params, model.non_trainable_weights)
    )

    print(f"Trainable parameters: {trainable_params:,}")
    print(f"Non-trainable parameters: {non_trainable_params:,}")

    # Rendering the diagram is best-effort: any failure is reported as a
    # warning and the notebook carries on.
    diagram_options = {
        'to_file': '../outputs/model_architecture.png',
        'show_shapes': True,
        'show_layer_names': True,
        'rankdir': 'TB',
        'dpi': 150,
    }
    try:
        tf.keras.utils.plot_model(model, **diagram_options)
        print("📈 Model architecture diagram saved to ../outputs/model_architecture.png")
    except Exception as e:
        print(f"⚠️  Could not generate architecture diagram: {e}")

## 4. Data Augmentation Visualization

In [None]:
# Visualize one batch of training data as produced by the training generator
# (i.e. with augmentation applied when it is enabled in the config).
if processed_data_path.exists() and config['data']['augmentation']['enabled']:
    print("🔄 Visualizing data augmentation...")

    # Get a batch of training data
    sample_images, sample_labels = next(trainer.train_generator)

    # class_indices maps class name -> index. Invert it once so that a label
    # index resolves to its name without relying on dict ordering and without
    # rebuilding a key list for every image.
    index_to_class = {
        idx: name
        for name, idx in trainer.train_generator.class_indices.items()
    }

    # Display augmented samples on a fixed 2x4 grid (at most 8 images).
    fig, axes = plt.subplots(2, 4, figsize=(16, 8))

    # zip() stops at the shorter input, so a batch smaller than the grid is
    # handled without index arithmetic.
    for ax, img, label in zip(axes.flat, sample_images, sample_labels):
        # Denormalize image for display
        if img.max() <= 1.0:  # If normalized
            img = (img * 255).astype('uint8')

        ax.imshow(img)

        # argmax picks the labelled class; assumes one-hot/probability labels
        # as the original cell did.
        class_name = index_to_class[int(np.argmax(label))]
        ax.set_title(f'{class_name}')

    # Hide ticks on every cell, including cells left empty by a short batch.
    for ax in axes.flat:
        ax.axis('off')

    plt.suptitle('Training Data with Augmentation')
    plt.tight_layout()
    plt.show()

    # Reset generator so the preview batch does not shift its position.
    trainer.train_generator.reset()
elif processed_data_path.exists():
    print("ℹ️  Data augmentation is disabled")
else:
    print("⏭️  Skipping augmentation visualization (no processed data)")

## 5. Training Callbacks Setup

In [None]:
# Ask the trainer for its callbacks and list them, adding one detail line for
# the callback types this notebook knows how to describe.
if not processed_data_path.exists():
    print("⏭️  Skipping callbacks setup (no processed data)")
else:
    print("⚙️  Setting up training callbacks...")

    callbacks = trainer.setup_callbacks()

    # Class name -> formatter for the detail line. Attributes are only read
    # for the callback type that actually matches.
    detail_formatters = {
        'ModelCheckpoint': lambda cb: f"     Saving best model to: {cb.filepath}",
        'EarlyStopping': lambda cb: f"     Monitoring: {cb.monitor}, Patience: {cb.patience}",
        'ReduceLROnPlateau': lambda cb: f"     Reducing LR by factor {cb.factor} with patience {cb.patience}",
        'TensorBoard': lambda cb: f"     Logging to: {cb.log_dir}",
    }

    print(f"📋 Configured Callbacks: {len(callbacks)}")
    for position, cb in enumerate(callbacks, 1):
        kind = type(cb).__name__
        print(f"  {position}. {kind}")

        # Show specific callback configurations
        formatter = detail_formatters.get(kind)
        if formatter is not None:
            print(formatter(cb))

    print("\n💡 Monitor training with TensorBoard:")
    print("   tensorboard --logdir ../logs")

## 6. Model Training

In [None]:
# Train the model.
#
# `history` is the flag later cells test (`history is not None`) to decide
# whether a trained model exists, so it starts as None and is only set by a
# successful fit.
history = None

if processed_data_path.exists():
    print("🚀 Starting model training...")
    print(f"Training for {config['training']['epochs']} epochs")
    print("This may take several minutes to hours depending on:")
    print("  • Dataset size")
    print("  • Model complexity")
    print("  • Hardware (CPU vs GPU)")
    print("\nTraining progress will be displayed below...\n")

    try:
        # Only the fit itself sits in this try, so "Training failed" is
        # reported for training errors and nothing else.
        history = trainer.train_model(model, callbacks)
    except Exception as e:
        # Notebook boundary: report and continue so later cells skip cleanly.
        print(f"❌ Training failed: {e}")
        history = None
    else:
        print("\n✅ Training completed successfully!")

        # Load best model from checkpoint. A failure here must not discard
        # the training history: fall back to the in-memory model instead.
        best_model_path = Path('../models/checkpoints/best_model.h5')
        if best_model_path.exists():
            try:
                model = keras.models.load_model(str(best_model_path))
                print("✅ Loaded best model from checkpoint")
            except Exception as e:
                print(f"⚠️  Could not load best checkpoint ({e}); "
                      "keeping the model from the end of training")
else:
    print("⏭️  Skipping training (no processed data)")

## 7. Training Results Visualization

In [None]:
# Plot training history: accuracy, loss, learning-rate schedule and a text
# summary on a 2x2 grid, saved to ../outputs/training_results.png.
if history is not None:
    print("📈 Visualizing training results...")

    metrics = history.history

    # Create comprehensive training plots
    fig, axes = plt.subplots(2, 2, figsize=(15, 10))

    # Plots 1 and 2: accuracy and loss share the same layout, so draw both
    # from one description (axis, history key, display name).
    curve_specs = (
        (axes[0, 0], 'accuracy', 'Accuracy'),
        (axes[0, 1], 'loss', 'Loss'),
    )
    for ax, key, title in curve_specs:
        ax.plot(metrics[key], label=f'Training {title}', linewidth=2)
        ax.plot(metrics[f'val_{key}'], label=f'Validation {title}', linewidth=2)
        ax.set_title(f'Model {title}')
        ax.set_xlabel('Epoch')
        ax.set_ylabel(title)
        ax.legend()
        ax.grid(True, alpha=0.3)

    # Plot 3: Learning Rate (if available).
    # Depending on the Keras release the history key is 'lr' or
    # 'learning_rate'; accept either so the schedule is not reported as
    # untracked on newer versions.
    lr_key = next(
        (key for key in ('lr', 'learning_rate') if key in metrics),
        None,
    )
    lr_ax = axes[1, 0]
    if lr_key is not None:
        lr_ax.plot(metrics[lr_key], linewidth=2, color='orange')
        lr_ax.set_xlabel('Epoch')
        lr_ax.set_ylabel('Learning Rate')
        lr_ax.set_yscale('log')
        lr_ax.grid(True, alpha=0.3)
    else:
        lr_ax.text(0.5, 0.5, 'Learning Rate\nNot Tracked',
                   ha='center', va='center', transform=lr_ax.transAxes)
    lr_ax.set_title('Learning Rate Schedule')

    # Plot 4: Training Summary
    final_train_acc = metrics['accuracy'][-1]
    final_val_acc = metrics['val_accuracy'][-1]
    best_val_acc = max(metrics['val_accuracy'])

    # Built line by line so every entry is aligned; a triple-quoted literal
    # picks up the source indentation unevenly.
    summary_text = "\n".join([
        "Training Summary:",
        "",
        f"Epochs: {len(metrics['accuracy'])}",
        f"Final Train Accuracy: {final_train_acc:.3f}",
        f"Final Val Accuracy: {final_val_acc:.3f}",
        f"Best Val Accuracy: {best_val_acc:.3f}",
        "",
        f"Model: {config['model']['architecture']}",
        f"Batch Size: {config['training']['batch_size']}",
        f"Learning Rate: {config['training']['learning_rate']}",
    ])

    axes[1, 1].text(0.1, 0.9, summary_text, transform=axes[1, 1].transAxes,
                    fontsize=12, verticalalignment='top', fontfamily='monospace')
    axes[1, 1].set_title('Training Summary')
    axes[1, 1].axis('off')

    plt.tight_layout()
    # savefig does not create missing directories; make sure the target exists.
    Path('../outputs').mkdir(parents=True, exist_ok=True)
    plt.savefig('../outputs/training_results.png', dpi=300, bbox_inches='tight')
    plt.show()

    print(f"📊 Best validation accuracy: {best_val_acc:.4f}")
    print("📈 Training plots saved to ../outputs/training_results.png")

else:
    print("⏭️  Skipping training visualization (no training history)")

## 8. Model Evaluation

In [None]:
# Evaluate model on test data: overall loss/accuracy, confusion matrix,
# per-class report, and a text file with the results.
if history is not None and processed_data_path.exists():
    print("📊 Evaluating model on test data...")

    try:
        # Evaluate on test set
        test_loss, test_accuracy = trainer.evaluate_model(model)

        print("\n🎯 Final Test Results:")
        print(f"  Test Accuracy: {test_accuracy:.4f}")
        print(f"  Test Loss: {test_loss:.4f}")

        # Generate predictions for confusion matrix
        print("\n🔮 Generating predictions...")
        # Rewind the generator first so predictions start at the first sample
        # after the evaluation pass above.
        trainer.test_generator.reset()
        predictions = model.predict(trainer.test_generator, verbose=1)
        predicted_classes = np.argmax(predictions, axis=1)
        # NOTE(review): pairing predictions with `.classes` assumes the test
        # generator was created with shuffle=False — confirm in train_model.py.
        true_classes = trainer.test_generator.classes

        # Plot confusion matrix
        from sklearn.metrics import confusion_matrix, classification_report

        # Order names by their class index rather than by dict order, and pass
        # the full index list as `labels` so a class missing from both the
        # test labels and the predictions still gets a row instead of making
        # `target_names` mismatch and raise.
        class_indices = trainer.test_generator.class_indices
        class_names = sorted(class_indices, key=class_indices.get)
        label_ids = [class_indices[name] for name in class_names]

        cm = confusion_matrix(true_classes, predicted_classes, labels=label_ids)

        # savefig/open do not create missing directories.
        Path('../outputs').mkdir(parents=True, exist_ok=True)

        plt.figure(figsize=(10, 8))
        sns.heatmap(cm, annot=True, fmt='d', cmap='Blues',
                    xticklabels=class_names, yticklabels=class_names)
        plt.title('Confusion Matrix - Test Set')
        plt.xlabel('Predicted Label')
        plt.ylabel('True Label')
        plt.tight_layout()
        plt.savefig('../outputs/confusion_matrix.png', dpi=300, bbox_inches='tight')
        plt.show()

        # Classification report
        report = classification_report(true_classes, predicted_classes,
                                       labels=label_ids,
                                       target_names=class_names)
        print("\n📋 Classification Report:")
        print(report)

        # Save detailed results. UTF-8 is explicit so class names with
        # non-ASCII characters are written the same way on every platform.
        with open('../outputs/test_results.txt', 'w', encoding='utf-8') as f:
            f.write("InfraOwl Model Test Results\n")
            f.write("========================\n\n")
            f.write(f"Test Accuracy: {test_accuracy:.4f}\n")
            f.write(f"Test Loss: {test_loss:.4f}\n\n")
            f.write("Classification Report:\n")
            f.write(report)

        print("📄 Detailed results saved to ../outputs/test_results.txt")

    except Exception as e:
        # Notebook boundary: report the failure and let later cells run.
        print(f"❌ Evaluation failed: {e}")

else:
    print("⏭️  Skipping evaluation (no trained model)")

## 9. Model Saving

In [None]:
# Persist the trained model through the trainer and, if the expected file is
# present afterwards, report where it is and how large it is.
if history is None:
    print("⏭️  Skipping model saving (no trained model)")
else:
    print("💾 Saving trained model...")

    try:
        trainer.save_model(model)

        # Verify saved model
        saved_model_path = Path('../models/saved_models/infraowl_model.h5')
        if saved_model_path.exists():
            size_bytes = saved_model_path.stat().st_size
            size_in_mb = size_bytes / (1024 * 1024)  # bytes -> MB
            print("✅ Model saved successfully!")
            print(f"📁 Location: {saved_model_path}")
            print(f"📏 Size: {size_in_mb:.2f} MB")

    except Exception as e:
        print(f"❌ Model saving failed: {e}")

## 10. TensorFlow Lite Conversion

In [None]:
# Run the project's TFLite converter and list the .tflite files found in the
# models directory afterwards.
if history is None:
    print("⏭️  Skipping TFLite conversion (no trained model)")
else:
    print("📱 Converting to TensorFlow Lite...")

    try:
        converter = TFLiteConverter('../configs/training_config.yaml')
        converter.run_conversion()

        print("\n✅ TensorFlow Lite conversion completed!")

        # Show TFLite models with their size in MB.
        tflite_dir = Path('../models/tflite_models')
        if tflite_dir.exists():
            print("\n📱 TensorFlow Lite Models:")
            for artifact in tflite_dir.glob('*.tflite'):
                artifact_mb = artifact.stat().st_size / (1024 * 1024)
                print(f"  {artifact.name}: {artifact_mb:.2f} MB")

    except Exception as e:
        # Point at the standalone script as the manual fallback.
        print(f"❌ TFLite conversion failed: {e}")
        print("You can manually convert later using:")
        print("python ../scripts/convert_to_tflite.py")

## 11. Training Summary and Next Steps

In [None]:
# Final summary: headline metrics, a checklist of expected artefacts, and
# suggested next steps.
print("🎯 InfraOwl Training Summary")
print("==========================")

if history is not None:
    print("✅ Training completed successfully!")
    print("\n📊 Results:")
    print(f"  Final Training Accuracy: {history.history['accuracy'][-1]:.4f}")
    print(f"  Final Validation Accuracy: {history.history['val_accuracy'][-1]:.4f}")
    print(f"  Best Validation Accuracy: {max(history.history['val_accuracy']):.4f}")

    # test_accuracy only exists if the evaluation cell ran successfully; at
    # notebook top level locals() is the global namespace.
    if 'test_accuracy' in locals():
        print(f"  Test Accuracy: {test_accuracy:.4f}")

    print("\n📁 Generated Files:")
    output_files = [
        '../models/saved_models/infraowl_model.h5',
        '../models/checkpoints/best_model.h5',
        '../outputs/training_results.png',
        '../outputs/training_history.png',
        '../outputs/model_summary.txt',
        # Artefacts written by this notebook's own cells (architecture
        # diagram, evaluation plot and report).
        '../outputs/model_architecture.png',
        '../outputs/confusion_matrix.png',
        '../outputs/test_results.txt',
    ]

    # Mark each expected artefact as present or missing on disk.
    for file_path in output_files:
        marker = "✅" if Path(file_path).exists() else "❌"
        print(f"  {marker} {file_path}")

    print("\n🚀 Next Steps:")
    print("1. 📊 Detailed Evaluation:")
    print("   python ../scripts/evaluate_model.py")
    print("\n2. 📱 Deploy to Flutter App:")
    print("   Copy the TFLite model to ../assets/")
    print("   Update the labels.txt file if needed")
    print("\n3. 🔄 Improve Model (if needed):")
    print("   • Collect more training data")
    print("   • Adjust hyperparameters")
    print("   • Try different architectures")
    print("   • Increase training epochs")

else:
    print("❌ Training was not completed")
    print("\nTo start training:")
    print("1. Ensure processed data exists (run 02_data_preprocessing.ipynb)")
    print("2. Re-run this notebook or use: python ../scripts/train_model.py")

print("\n💡 Tips:")
print("• Monitor training with TensorBoard: tensorboard --logdir ../logs")
print("• Adjust configuration in ../configs/training_config.yaml")
print("• Use GPU for faster training if available")
print("• Regularly backup your trained models")