# MobileNetV2 + Attention Mechanism for Leaf Disease Detection
## Transfer Learning from PlantVillage to Local Vietnamese Crops

This notebook implements a deep learning pipeline for detecting plant diseases whose lesions are small, blurry, or partially occluded, using:
- **MobileNetV2 backbone** for efficient inference
- **CBAM attention modules** for focusing on disease regions
- **U-Net segmentation** for disease region isolation
- **Transfer learning** from PlantVillage dataset
- **Fine-tuning** with local Vietnamese crop data (tomato, rice, etc.)

**Author**: Leaf Disease Detector Team  
**Date**: 2024  
**Framework**: TensorFlow 2.11+ with Keras

## Section 1: Import Required Libraries

In [None]:
# Import required libraries
import os
import sys
import warnings
warnings.filterwarnings('ignore')

# Deep Learning
import tensorflow as tf
from tensorflow import keras
from tensorflow.keras import layers, models
from tensorflow.keras.applications import MobileNetV2
from tensorflow.keras.preprocessing.image import ImageDataGenerator
from tensorflow.keras.callbacks import EarlyStopping, ReduceLROnPlateau, ModelCheckpoint, TensorBoard
from tensorflow.keras import backend as K  # same Keras that ships with TensorFlow

# Data Science
import numpy as np
import pandas as pd
from sklearn.model_selection import train_test_split
from sklearn.metrics import classification_report, confusion_matrix, accuracy_score, precision_recall_fscore_support

# Image Processing
import cv2
from PIL import Image
import matplotlib.pyplot as plt
import matplotlib.patches as patches
from matplotlib.colors import ListedColormap
import seaborn as sns

# Utilities
import json
from pathlib import Path
from datetime import datetime
import pickle

# Check TensorFlow version and GPU availability
# (is_built_with_cuda() only reports the build type, not whether a GPU is visible)
gpu_devices = tf.config.list_physical_devices('GPU')
print(f"TensorFlow Version: {tf.__version__}")
print(f"GPU Available: {len(gpu_devices) > 0}")
print(f"GPU Devices: {gpu_devices}")

# Set random seeds for reproducibility
np.random.seed(42)
tf.random.set_seed(42)

## Section 2: Configuration and Setup

In [None]:
# Configuration
CONFIG = {
    'INPUT_SHAPE': (224, 224, 3),
    'BATCH_SIZE': 32,
    'EPOCHS_PRETRAIN': 50,
    'EPOCHS_FINETUNE': 30,
    'LEARNING_RATE_PRETRAIN': 1e-4,
    'LEARNING_RATE_FINETUNE': 1e-5,
    'DROPOUT_RATE': 0.5,
    'NUM_CLASSES': 38,  # PlantVillage has 38 classes; adjust based on your dataset
    'MODEL_DIR': '../models',
    'LOG_DIR': '../logs',
    'DATA_DIR': '../data'
}

# Create output directories
os.makedirs(CONFIG['MODEL_DIR'], exist_ok=True)
os.makedirs(CONFIG['LOG_DIR'], exist_ok=True)

print("=" * 80)
print("CONFIGURATION")
print("=" * 80)
for key, value in CONFIG.items():
    print(f"{key}: {value}")
print("=" * 80)

## Section 3: Attention Mechanisms Implementation

In [None]:
class ChannelAttention(layers.Layer):
    """Channel Attention Module (CAM)
    Recalibrates channel-wise feature responses"""
    
    def __init__(self, reduction_ratio=16, **kwargs):
        super(ChannelAttention, self).__init__(**kwargs)
        self.reduction_ratio = reduction_ratio
    
    def build(self, input_shape):
        channels = input_shape[-1]
        self.avg_pool = layers.GlobalAveragePooling2D(keepdims=True)
        self.max_pool = layers.GlobalMaxPooling2D(keepdims=True)
        # Shared two-layer MLP; keep at least one hidden unit for narrow inputs
        self.fc1 = layers.Dense(max(channels // self.reduction_ratio, 1), activation='relu')
        self.fc2 = layers.Dense(channels)
        super(ChannelAttention, self).build(input_shape)
    
    def call(self, inputs):
        avg_out = self.fc2(self.fc1(self.avg_pool(inputs)))
        max_out = self.fc2(self.fc1(self.max_pool(inputs)))
        channel_out = keras.activations.sigmoid(avg_out + max_out)
        return inputs * channel_out
    
    def get_config(self):
        # Needed so the layer can be saved and reloaded with its arguments
        config = super(ChannelAttention, self).get_config()
        config.update({'reduction_ratio': self.reduction_ratio})
        return config


class SpatialAttention(layers.Layer):
    """Spatial Attention Module (SAM)
    Generates attention maps along the spatial dimension"""
    
    def __init__(self, kernel_size=7, **kwargs):
        super(SpatialAttention, self).__init__(**kwargs)
        self.kernel_size = kernel_size
        self.conv = layers.Conv2D(
            filters=1,
            kernel_size=kernel_size,
            padding='same',
            activation='sigmoid',
            use_bias=False
        )
    
    def call(self, inputs):
        # Pool along the channel axis of the NHWC tensor
        avg_out = tf.reduce_mean(inputs, axis=3, keepdims=True)
        max_out = tf.reduce_max(inputs, axis=3, keepdims=True)
        x = tf.concat([avg_out, max_out], axis=3)
        spatial_out = self.conv(x)
        return inputs * spatial_out
    
    def get_config(self):
        # Needed so the layer can be saved and reloaded with its arguments
        config = super(SpatialAttention, self).get_config()
        config.update({'kernel_size': self.kernel_size})
        return config


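class CbamAttention(layers.Layer):
    """Convolutional Block Attention Module (CBAM)
    Sequentially applies Channel and Spatial Attention"""
    
    def __init__(self, reduction_ratio=16, **kwargs):
        super(CbamAttention, self).__init__(**kwargs)
        self.reduction_ratio = reduction_ratio
        self.channel_attention = ChannelAttention(reduction_ratio=reduction_ratio)
        self.spatial_attention = SpatialAttention()
    
    def call(self, inputs):
        x = self.channel_attention(inputs)
        x = self.spatial_attention(x)
        return x
    
    def get_config(self):
        # Needed so the layer can be saved and reloaded with its arguments
        config = super(CbamAttention, self).get_config()
        config.update({'reduction_ratio': self.reduction_ratio})
        return config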

print("✓ Attention mechanisms defined successfully")
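
To make the channel-attention arithmetic concrete, the following is a minimal NumPy sketch of the same computation on a single feature map. It is an illustration, not the layer used in this notebook: the weights `w1` and `w2` are random stand-ins, the shared MLP is bias-free for brevity, and the helper name `channel_attention_np` is hypothetical.

```python
import numpy as np

def sigmoid(z):
    return 1.0 / (1.0 + np.exp(-z))

def channel_attention_np(x, w1, w2):
    """x: (H, W, C) feature map; w1: (C, C // r); w2: (C // r, C)."""
    avg_desc = x.mean(axis=(0, 1))   # (C,) global average pooling
    max_desc = x.max(axis=(0, 1))    # (C,) global max pooling

    def shared_mlp(desc):
        # Same two-layer MLP is applied to both descriptors
        return np.maximum(desc @ w1, 0.0) @ w2

    gate = sigmoid(shared_mlp(avg_desc) + shared_mlp(max_desc))  # (C,), values in (0, 1)
    return x * gate, gate            # gate is broadcast over H and W

rng = np.random.default_rng(0)
feature_map = rng.standard_normal((7, 7, 32))
w1 = rng.standard_normal((32, 2)) * 0.1   # reduction ratio 16 -> 2 hidden units
w2 = rng.standard_normal((2, 32)) * 0.1
attended, gate = channel_attention_np(feature_map, w1, w2)
print(attended.shape, gate.shape)
```

Each channel is scaled by a single number between 0 and 1, so the output keeps the input shape; the spatial module then does the analogous thing per pixel instead of per channel.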

## Section 4: Build MobileNetV2 with Attention Mechanism

In [None]:
def create_mobilenetv2_attention_model(
    num_classes,
    input_shape=(224, 224, 3),
    freeze_base=False,
    dropout_rate=0.5
):
    """Create MobileNetV2 with CBAM attention for disease detection"""
    
    # Load pretrained MobileNetV2
    base_model = MobileNetV2(
        input_shape=input_shape,
        include_top=False,
        weights='imagenet'
    )
    
    if freeze_base:
        base_model.trainable = False
    
    # Build attention-enhanced architecture
    inputs = keras.Input(shape=input_shape)
    
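    # Base feature extraction. The data generators rescale pixels to [0, 1],
    # but MobileNetV2's ImageNet weights expect [-1, 1], so map the range first.
    x = layers.Rescaling(scale=2.0, offset=-1.0)(inputs)
    x = base_model(x, training=False)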
    
    # Stage 1: CBAM attention on extracted features
    x = CbamAttention(reduction_ratio=16)(x)
    
    # Stage 2: Additional conv layers for fine-grained analysis
    x = layers.Conv2D(512, kernel_size=3, padding='same', activation='relu')(x)
    x = layers.BatchNormalization()(x)
    x = CbamAttention(reduction_ratio=16)(x)
    
    # Stage 3: Multi-scale feature pooling
    avg_pool = layers.GlobalAveragePooling2D()(x)
    max_pool = layers.GlobalMaxPooling2D()(x)
    x = layers.Concatenate()([avg_pool, max_pool])
    
    # Classification head
    x = layers.Dense(256, activation='relu')(x)
    x = layers.BatchNormalization()(x)
    x = layers.Dropout(dropout_rate)(x)
    
    x = layers.Dense(128, activation='relu')(x)
    x = layers.BatchNormalization()(x)
    x = layers.Dropout(dropout_rate)(x)
    
    outputs = layers.Dense(num_classes, activation='softmax')(x)
    
    model = models.Model(inputs=inputs, outputs=outputs)
    return model, base_model


# Create model
print("Creating MobileNetV2 + Attention model...")
model, base_model = create_mobilenetv2_attention_model(
    num_classes=CONFIG['NUM_CLASSES'],
    input_shape=CONFIG['INPUT_SHAPE'],
    freeze_base=True,
    dropout_rate=CONFIG['DROPOUT_RATE']
)

# Compile model
model.compile(
    optimizer=keras.optimizers.Adam(learning_rate=CONFIG['LEARNING_RATE_PRETRAIN']),
    loss='categorical_crossentropy',
    metrics=[
        'accuracy',
        keras.metrics.TopKCategoricalAccuracy(k=3, name='top_3_accuracy'),
        keras.metrics.Precision(),
        keras.metrics.Recall()
    ]
)

print("✓ Model created and compiled")
print("\nModel Summary:")
model.summary()
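
As a rough cross-check of what `model.summary()` reports for the layers added on top of the backbone, the sketch below works out the feature-map size and the parameter counts of the head by hand. The helper names are hypothetical, BatchNormalization and attention parameters are left out, and `num_classes = 38` is an assumption that should be replaced by your own class count.

```python
def conv2d_params(kernel_size, in_channels, out_channels, use_bias=True):
    """Trainable weights of a square Conv2D kernel."""
    weights = kernel_size * kernel_size * in_channels * out_channels
    return weights + (out_channels if use_bias else 0)

def dense_params(in_units, out_units):
    """Trainable weights of a Dense layer with bias."""
    return in_units * out_units + out_units

input_side = 224
feature_side = input_side // 32      # MobileNetV2 reduces resolution by 32x
backbone_channels = 1280             # width of MobileNetV2's last feature map
num_classes = 38                     # assumption: one output per PlantVillage class

head_params = {
    'conv_3x3_512': conv2d_params(3, backbone_channels, 512),
    'dense_256': dense_params(2 * 512, 256),   # avg-pool and max-pool vectors concatenated
    'dense_128': dense_params(256, 128),
    'classifier': dense_params(128, num_classes),
}

print(f"Backbone feature map: {feature_side} x {feature_side} x {backbone_channels}")
for name, count in head_params.items():
    print(f"{name}: {count:,} parameters")
```

Most of the head's weights sit in the 3x3 convolution, which is worth keeping in mind if the model has to be shrunk for edge deployment.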

## Section 5: Segmentation Model for Disease Region Isolation

In [None]:
def create_unet_segmentation_model(input_shape=(256, 256, 3), num_filters=32, depth=4):
    """U-Net model for leaf disease segmentation"""
    
    def conv_block(x, num_filters, kernel_size=3):
        x = layers.Conv2D(num_filters, kernel_size, padding='same', activation='relu', 
                         kernel_initializer='he_normal')(x)
        x = layers.BatchNormalization()(x)
        x = layers.Dropout(0.3)(x)
        x = layers.Conv2D(num_filters, kernel_size, padding='same', activation='relu',
                         kernel_initializer='he_normal')(x)
        x = layers.BatchNormalization()(x)
        return x
    
    inputs = keras.Input(shape=input_shape)
    
    # Encoder
    encoder_blocks = []
    x = inputs
    for i in range(depth):
        filters = num_filters * (2 ** i)
        x = conv_block(x, filters)
        encoder_blocks.append(x)
        if i < depth - 1:
            x = layers.MaxPooling2D(2)(x)
    
    # Decoder
    for i in range(depth - 2, -1, -1):
        filters = num_filters * (2 ** i)
        x = layers.UpSampling2D(2)(x)
        x = layers.Concatenate()([x, encoder_blocks[i]])
        x = conv_block(x, filters)
    
    # Output
    outputs = layers.Conv2D(1, 1, activation='sigmoid')(x)
    
    segmentation_model = models.Model(inputs=inputs, outputs=outputs)
    return segmentation_model

# Create segmentation model
print("Creating U-Net segmentation model...")
seg_model = create_unet_segmentation_model(
    input_shape=(256, 256, 3),
    num_filters=32,
    depth=4
)

seg_model.compile(
    optimizer=keras.optimizers.Adam(learning_rate=1e-3),
    loss='binary_crossentropy',
    metrics=['mae', keras.metrics.Precision(), keras.metrics.Recall()]
)

print("✓ Segmentation model created")
print("Segmentation Model Summary:")
seg_model.summary()
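
The skip connections in the U-Net above only line up when the input side length is divisible by `2 ** (depth - 1)`, because each of the `depth - 1` pooling steps halves the resolution and each upsampling step doubles it. The sketch below traces the shapes in plain Python under that assumption; `unet_shape_trace` is a hypothetical helper, not part of the notebook.

```python
def unet_shape_trace(side, num_filters=32, depth=4):
    """Return (encoder, decoder) lists of (spatial_side, filters) per block."""
    if side % (2 ** (depth - 1)) != 0:
        raise ValueError(
            f"Input side {side} must be divisible by {2 ** (depth - 1)} "
            "for the skip connections to match"
        )
    encoder = []
    for i in range(depth):
        encoder.append((side, num_filters * 2 ** i))
        if i < depth - 1:
            side //= 2               # MaxPooling2D(2)
    decoder = []
    for i in range(depth - 2, -1, -1):
        side *= 2                    # UpSampling2D(2)
        assert side == encoder[i][0] # must equal the skip connection's size
        decoder.append((side, num_filters * 2 ** i))
    return encoder, decoder

encoder_shapes, decoder_shapes = unet_shape_trace(256, num_filters=32, depth=4)
print("Encoder:", encoder_shapes)
print("Decoder:", decoder_shapes)
```

With the default settings the bottleneck runs at 32 x 32 with 256 filters, and an input such as 250 x 250 would be rejected before the model is built.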

## Section 6: Data Loading and Preprocessing

In [None]:
# Data generators for PlantVillage dataset
def create_data_generators(augment=True, validation_split=0.0):
    """Create training and validation data generators"""
    
    if augment:
        train_datagen = ImageDataGenerator(
            rescale=1./255,
            rotation_range=30,
            width_shift_range=0.2,
            height_shift_range=0.2,
            shear_range=0.2,
            zoom_range=0.2,
            horizontal_flip=True,
            vertical_flip=True,
            fill_mode='nearest',
            validation_split=validation_split
        )
    else:
        train_datagen = ImageDataGenerator(rescale=1./255, validation_split=validation_split)
    
    # Validation images are only rescaled, never augmented
    val_datagen = ImageDataGenerator(rescale=1./255, validation_split=validation_split)
    
    return train_datagen, val_datagen

# Load PlantVillage data from a directory with one sub-folder per class
def load_plantvillage_data(plantvillage_path, batch_size=32, validation_split=0.2):
    """Load PlantVillage dataset for transfer learning
    
    Expected directory structure:
    plantvillage_path/
    ├── disease_class_1/
    │   ├── image1.jpg
    │   ├── image2.jpg
    │   └── ...
    ├── disease_class_2/
    │   └── ...
    └── ...
    """
    
    train_datagen, val_datagen = create_data_generators(
        augment=True, validation_split=validation_split
    )
    
    print(f"Loading PlantVillage data from: {plantvillage_path}")
    print(f"Batch size: {batch_size}")
    print(f"Validation split: {validation_split}")
    
    # Both generators use the same split, so the two subsets do not overlap
    flow_args = dict(
        directory=plantvillage_path,
        target_size=CONFIG['INPUT_SHAPE'][:2],
        batch_size=batch_size,
        class_mode='categorical',
        seed=42
    )
    train_generator = train_datagen.flow_from_directory(
        subset='training', shuffle=True, **flow_args
    )
    val_generator = val_datagen.flow_from_directory(
        subset='validation', shuffle=False, **flow_args
    )
    return train_generator, val_generator

print("✓ Data loading functions defined")
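
Leaf-disease datasets are often imbalanced, with far more healthy images than images of a rare disease. One common remedy is to pass per-class weights to `model.fit(..., class_weight=...)`. The sketch below computes "balanced" weights from image counts; the counts are invented for illustration, `balanced_class_weights` is a hypothetical helper, and the class indices are assumed to follow the alphanumeric folder order that `flow_from_directory` uses.

```python
def balanced_class_weights(class_counts):
    """Map {class_name: n_images} to {class_index: weight}.

    weight = total_images / (n_classes * images_in_class), so rare
    classes receive proportionally larger weights.
    """
    names = sorted(class_counts)          # assumed to match the generator's class order
    total = sum(class_counts.values())
    n_classes = len(names)
    return {
        index: total / (n_classes * class_counts[name])
        for index, name in enumerate(names)
    }

# Illustrative counts only, not real dataset statistics
example_counts = {'Tomato___healthy': 900, 'Tomato___Early_blight': 100}
class_weights = balanced_class_weights(example_counts)
print(class_weights)
```

It is worth comparing the resulting indices against `train_generator.class_indices` before training, since a mismatch would silently weight the wrong classes.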

## Section 7: Transfer Learning - Pre-training on PlantVillage

In [None]:
# EXAMPLE: Pre-training on PlantVillage
# Uncomment and modify paths to use with your data

print("=" * 80)
print("PRE-TRAINING ON PLANTVILLAGE DATASET")
print("=" * 80)
print("""
To use pre-training:
1. Download PlantVillage dataset from: https://github.com/spMohanty/PlantVillage-Dataset
2. Extract and organize images by disease class
3. Set plantvillage_path to the dataset directory
4. Uncomment the code below and run

Expected directory structure:
plantvillage_path/
├── Apple___Apple_scab/
├── Apple___Black_rot/
├── Tomato___Bacterial_spot/
├── ... (38 classes in total, healthy classes included)
└── Tomato___healthy/
""")

# Example pre-training code (commented out)
# plantvillage_path = '/path/to/plantvillage/data'
# train_gen, val_gen = load_plantvillage_data(plantvillage_path)
# 
# callbacks = [
#     ModelCheckpoint('mobilenetv2_attention_plantvillage.h5', monitor='val_accuracy', save_best_only=True),
#     EarlyStopping(monitor='val_loss', patience=10, restore_best_weights=True),
#     ReduceLROnPlateau(monitor='val_loss', factor=0.5, patience=5),
#     TensorBoard(log_dir=CONFIG['LOG_DIR'])
# ]
# 
# history_pretrain = model.fit(
#     train_gen,
#     epochs=CONFIG['EPOCHS_PRETRAIN'],
#     validation_data=val_gen,
#     callbacks=callbacks,
#     verbose=1
# )
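
The callbacks above combine `ReduceLROnPlateau(patience=5, factor=0.5)` with `EarlyStopping(patience=10)`, so a stalled run has its learning rate halved twice before training stops. The pure-Python sketch below simulates that interaction on a made-up validation-loss curve. It is a simplified model of the patience counters only: it ignores `min_delta`, `cooldown`, and `min_lr`, and `simulate_callbacks` is a hypothetical helper.

```python
def simulate_callbacks(val_losses, lr=1e-4, lr_factor=0.5, lr_patience=5, stop_patience=10):
    """Return (stopped_epoch, final_lr); stopped_epoch is None if training ran to the end."""
    best = float('inf')
    lr_wait = 0
    stop_wait = 0
    for epoch, loss in enumerate(val_losses):
        if loss < best:
            # Improvement resets both patience counters
            best = loss
            lr_wait = 0
            stop_wait = 0
            continue
        lr_wait += 1
        stop_wait += 1
        if lr_wait >= lr_patience:
            lr *= lr_factor          # plateau detected: reduce the learning rate
            lr_wait = 0
        if stop_wait >= stop_patience:
            return epoch, lr         # no improvement for stop_patience epochs
    return None, lr

# Invented curve: three improving epochs followed by a flat plateau
example_losses = [1.0, 0.8, 0.7] + [0.7] * 12
stopped_epoch, final_lr = simulate_callbacks(example_losses)
print(f"Stopped at epoch index {stopped_epoch} with learning rate {final_lr}")
```

If the early-stopping patience were not larger than the plateau patience, training would end before a reduced learning rate ever had a chance to help.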

## Section 8: Fine-tuning on Local Vietnamese Crop Data

In [None]:
print("=" * 80)
print("FINE-TUNING ON LOCAL VIETNAMESE CROPS")
print("=" * 80)
print("""
Local dataset structure for Vietnamese crops (Tomato, Rice, etc.):

local_data_path/
├── tomato/
│   ├── early_blight/
│   │   ├── image1.jpg
│   │   ├── image2.jpg
│   │   └── ...
│   ├── late_blight/
│   ├── powdery_mildew/
│   └── healthy/
├── rice/
│   ├── blast/
│   ├── brown_spot/
│   ├── sheath_blight/
│   └── healthy/
└── ... (other crops)
""")

def finetune_on_local_data(
    model,
    local_data_path,
    freeze_base_layers=100,
    learning_rate=1e-5,
    epochs=30,
    batch_size=16
):
    """Fine-tune model on local Vietnamese crop data"""
    
    # The MobileNetV2 backbone is nested inside `model` as a single layer, so
    # slicing model.layers would also freeze the attention blocks and the head.
    # Locate the backbone and freeze only its first `freeze_base_layers` layers.
    backbone = next(layer for layer in model.layers if isinstance(layer, keras.Model))
    backbone.trainable = True
    print(f"Freezing first {freeze_base_layers} backbone layers...")
    for layer in backbone.layers[:freeze_base_layers]:
        layer.trainable = False
    
    # Recompile with lower learning rate for fine-tuning
    model.compile(
        optimizer=keras.optimizers.Adam(learning_rate=learning_rate),
        loss='categorical_crossentropy',
        metrics=['accuracy', keras.metrics.TopKCategoricalAccuracy(k=3)]
    )
    
    # Create data generators
    train_datagen, val_datagen = create_data_generators(augment=True)
    
    print(f"Loading local data from: {local_data_path}")
    # train_generator = train_datagen.flow_from_directory(...)
    # val_generator = val_datagen.flow_from_directory(...)
    
    print("Ready for fine-tuning. Data generators created.")
    return model

print("✓ Fine-tuning function defined")
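
Note that `flow_from_directory` treats only the top-level folders as classes, so pointing it at the nested `crop/disease/` layout shown above would yield the classes `tomato` and `rice` rather than the individual diseases. One way around this is to flatten the layout into `(filepath, label)` pairs and feed them to `flow_from_dataframe`. The sketch below builds such pairs with `pathlib`; the `crop___disease` label format and the helper name `list_nested_dataset` are assumptions chosen to mirror PlantVillage naming.

```python
from pathlib import Path
import tempfile

IMAGE_EXTENSIONS = {'.jpg', '.jpeg', '.png'}

def list_nested_dataset(root):
    """Return sorted (filepath, label) pairs with labels like 'tomato___early_blight'."""
    samples = []
    for image_path in sorted(Path(root).glob('*/*/*')):
        if image_path.is_file() and image_path.suffix.lower() in IMAGE_EXTENSIONS:
            crop, disease = image_path.parts[-3], image_path.parts[-2]
            samples.append((str(image_path), f"{crop}___{disease}"))
    return samples

# Demo on a throwaway directory with empty placeholder files
with tempfile.TemporaryDirectory() as tmp_dir:
    for relative in ['tomato/early_blight/a.jpg', 'tomato/healthy/b.jpg',
                     'rice/blast/c.png', 'rice/blast/notes.txt']:
        file_path = Path(tmp_dir) / relative
        file_path.parent.mkdir(parents=True, exist_ok=True)
        file_path.touch()
    demo_samples = list_nested_dataset(tmp_dir)
    demo_labels = sorted({label for _, label in demo_samples})

print(demo_labels)
```

The resulting pairs can be wrapped in a `pandas.DataFrame` with two columns and passed to `ImageDataGenerator.flow_from_dataframe` via its `x_col` and `y_col` arguments.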

## Section 9: Model Evaluation and Metrics

In [None]:
def evaluate_model(model, test_generator, class_labels):
    """Evaluate model performance on test set"""
    
    print("Evaluating model...")
    
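    # Predictions (test_generator must be built with shuffle=False so that
    # the prediction order matches test_generator.classes)
    y_pred_probs = model.predict(test_generator, verbose=1)
    y_pred = np.argmax(y_pred_probs, axis=1)
    y_true = test_generator.classes
    
    # Calculate metrics; zero_division=0 avoids warnings for classes never predicted
    accuracy = accuracy_score(y_true, y_pred)
    precision, recall, f1, _ = precision_recall_fscore_support(
        y_true, y_pred, average='weighted', zero_division=0
    )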
    
    print("=" * 80)
    print("MODEL EVALUATION RESULTS")
    print("=" * 80)
    print(f"Accuracy:  {accuracy:.4f}")
    print(f"Precision: {precision:.4f}")
    print(f"Recall:    {recall:.4f}")
    print(f"F1-Score:  {f1:.4f}")
    
    # Classification report
    print("\n" + "=" * 80)
    print("CLASSIFICATION REPORT")
    print("=" * 80)
    print(classification_report(y_true, y_pred, target_names=class_labels))
    
    # Confusion matrix
    cm = confusion_matrix(y_true, y_pred)
    
    return {
        'accuracy': accuracy,
        'precision': precision,
        'recall': recall,
        'f1_score': f1,
        'confusion_matrix': cm,
        'y_true': y_true,
        'y_pred': y_pred,
        'y_pred_probs': y_pred_probs
    }

def plot_confusion_matrix(cm, class_labels, figsize=(12, 10)):
    """Plot confusion matrix heatmap"""
    
    plt.figure(figsize=figsize)
    sns.heatmap(cm, annot=True, fmt='d', cmap='Blues', 
                xticklabels=class_labels, yticklabels=class_labels, cbar_kws={'label': 'Count'})
    plt.xlabel('Predicted')
    plt.ylabel('True')
    plt.title('Confusion Matrix')
    plt.tight_layout()
    plt.show()

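def plot_training_history(history, metrics=('accuracy', 'loss')):
    """Plot training history"""
    
    # Accept either a Keras History object or its plain `history` dict
    history = getattr(history, 'history', history)
    
    fig, axes = plt.subplots(1, len(metrics), figsize=(15, 4))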
    
    for idx, metric in enumerate(metrics):
        if len(metrics) == 1:
            ax = axes
        else:
            ax = axes[idx]
        
        ax.plot(history[metric], label=f'Train {metric}')
        if f'val_{metric}' in history:
            ax.plot(history[f'val_{metric}'], label=f'Val {metric}')
        
        ax.set_xlabel('Epoch')
        ax.set_ylabel(metric.capitalize())
        ax.set_title(f'{metric.capitalize()} Over Epochs')
        ax.legend()
        ax.grid(True)
    
    plt.tight_layout()
    plt.show()

print("✓ Evaluation functions defined")
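
To show what `average='weighted'` means in `evaluate_model`, the sketch below recomputes the weighted precision, recall, and F1 directly from a confusion matrix with NumPy. The 2 x 2 matrix is an invented example and `weighted_metrics_from_cm` is a hypothetical helper; rows are true classes and columns are predicted classes, as in scikit-learn.

```python
import numpy as np

def weighted_metrics_from_cm(cm):
    """Support-weighted precision, recall and F1 from a confusion matrix."""
    cm = np.asarray(cm, dtype=float)
    true_positives = np.diag(cm).copy()
    support = cm.sum(axis=1)      # number of true samples per class (rows)
    predicted = cm.sum(axis=0)    # number of predictions per class (columns)

    zeros = np.zeros_like(true_positives)
    precision = np.divide(true_positives, predicted, out=zeros.copy(), where=predicted > 0)
    recall = np.divide(true_positives, support, out=zeros.copy(), where=support > 0)
    denom = precision + recall
    f1 = np.divide(2 * precision * recall, denom, out=zeros.copy(), where=denom > 0)

    weights = support / support.sum()
    return {
        'accuracy': true_positives.sum() / cm.sum(),
        'precision': float(np.sum(weights * precision)),
        'recall': float(np.sum(weights * recall)),
        'f1_score': float(np.sum(weights * f1)),
    }

example_cm = [[5, 1],
              [2, 4]]
example_metrics = weighted_metrics_from_cm(example_cm)
print(example_metrics)
```

A useful sanity check falls out of the algebra: support-weighted recall always equals overall accuracy, so the two numbers printed by `evaluate_model` should match.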

## Section 10: Visualization of Predictions and Attention Maps

In [None]:
def visualize_predictions(images, predictions, true_labels, class_labels, num_samples=9):
    """Visualize model predictions with attention"""
    
    num_samples = min(num_samples, len(images), 9)  # the grid below has 3 x 3 = 9 cells
    fig, axes = plt.subplots(3, 3, figsize=(15, 12))
    axes = axes.flatten()
    
    for idx in range(num_samples):
        ax = axes[idx]
        
        # Display image
        img = images[idx]
        if img.dtype == np.float32 or img.dtype == np.float64:
            img = (img * 255).astype(np.uint8)
        ax.imshow(img)
        
        # Get prediction
        pred_class = np.argmax(predictions[idx])
        true_class = true_labels[idx]
        confidence = predictions[idx][pred_class]
        
        # Color: green if correct, red if wrong
        color = 'green' if pred_class == true_class else 'red'
        
        title = f"True: {class_labels[true_class]}\n"
        title += f"Pred: {class_labels[pred_class]} ({confidence:.2%})"
        ax.set_title(title, color=color, fontweight='bold')
        ax.axis('off')
    
    # Hide unused subplots
    for idx in range(num_samples, len(axes)):
        axes[idx].axis('off')
    
    plt.tight_layout()
    plt.show()

def create_attention_heatmap(model, image, layer_name=None):
    """Create attention heatmap for model interpretation"""
    
    # Default to the last layer whose output is still a 4D feature map
    if layer_name is None:
        layer_name = next(
            layer.name for layer in reversed(model.layers)
            if len(layer.output.shape) == 4
        )
    
    # Get intermediate layer outputs
    intermediate_layer_model = keras.Model(
        inputs=model.input,
        outputs=model.get_layer(layer_name).output
    )
    
    # Get intermediate output
    img_batch = np.expand_dims(image, axis=0)
    intermediate_output = intermediate_layer_model.predict(img_batch)
    
    # Average across channels
    attention_map = np.mean(intermediate_output[0], axis=-1)
    
    # Normalize to [0, 1]; the epsilon avoids division by zero on flat maps
    value_range = attention_map.max() - attention_map.min()
    attention_map = (attention_map - attention_map.min()) / (value_range + 1e-8)
    
    # Resize to match input image size (cv2.resize expects (width, height))
    attention_map_resized = cv2.resize(attention_map, (image.shape[1], image.shape[0]))
    
    return attention_map_resized

def visualize_segmentation(original_image, segmentation_mask, disease_regions, figsize=(15, 5)):
    """Visualize original image, segmentation mask, and disease regions"""
    
    fig, axes = plt.subplots(1, 3, figsize=figsize)
    
    # Original image
    axes[0].imshow(original_image)
    axes[0].set_title('Original Image')
    axes[0].axis('off')
    
    # Segmentation mask
    axes[1].imshow(segmentation_mask, cmap='gray')
    axes[1].set_title('Leaf Segmentation Mask')
    axes[1].axis('off')
    
    # Disease regions
    axes[2].imshow(disease_regions, cmap='hot')
    axes[2].set_title('Disease Regions (Heatmap)')
    axes[2].axis('off')
    
    plt.tight_layout()
    plt.show()

print("✓ Visualization functions defined")
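
A heatmap is usually easier to read when it is blended over the leaf image instead of shown in a separate panel. The sketch below is one simple way to do that with NumPy alone, tinting pixels red in proportion to the heatmap value. `overlay_heatmap` is a hypothetical helper; it assumes an RGB `uint8` image and a heatmap already normalized to [0, 1] with the same height and width.

```python
import numpy as np

def overlay_heatmap(image_uint8, heatmap, alpha=0.5):
    """Blend a [0, 1] heatmap over an RGB uint8 image as a red tint."""
    image = image_uint8.astype(np.float32)
    # Per-pixel blend weight, shape (H, W, 1) so it broadcasts over RGB
    weight = alpha * np.clip(heatmap, 0.0, 1.0)[..., np.newaxis]
    red = np.array([255.0, 0.0, 0.0], dtype=np.float32)
    blended = image * (1.0 - weight) + red * weight
    return np.clip(np.rint(blended), 0, 255).astype(np.uint8)

# Toy inputs: a flat grey image and a heatmap that is hot in one corner
demo_image = np.full((4, 4, 3), 100, dtype=np.uint8)
demo_heatmap = np.zeros((4, 4), dtype=np.float32)
demo_heatmap[0, 0] = 1.0
demo_overlay = overlay_heatmap(demo_image, demo_heatmap, alpha=1.0)
print(demo_overlay[0, 0], demo_overlay[3, 3])
```

Pixels where the heatmap is zero are left untouched, so the leaf texture remains visible outside the highlighted regions.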

## Section 11: Model Saving and Deployment

In [None]:
def save_model_and_metadata(model, seg_model, output_path, metadata=None):
    """Save trained models and metadata"""
    
    os.makedirs(output_path, exist_ok=True)
    
    # Save classification model
    model_path = os.path.join(output_path, 'mobilenetv2_attention_classifier.h5')
    model.save(model_path)
    print(f"✓ Saved classification model to {model_path}")
    
    # Save segmentation model
    seg_model_path = os.path.join(output_path, 'unet_segmentation.h5')
    seg_model.save(seg_model_path)
    print(f"✓ Saved segmentation model to {seg_model_path}")
    
    # Save metadata
    if metadata:
        metadata_path = os.path.join(output_path, 'metadata.json')
        with open(metadata_path, 'w') as f:
            json.dump(metadata, f, indent=4, default=str)
        print(f"✓ Saved metadata to {metadata_path}")

def load_model_for_inference(model_path, custom_objects=None):
    """Load trained model for inference"""
    
    if custom_objects is None:
        custom_objects = {
            'ChannelAttention': ChannelAttention,
            'SpatialAttention': SpatialAttention,
            'CbamAttention': CbamAttention
        }
    
    model = keras.models.load_model(model_path, custom_objects=custom_objects)
    return model

# Save model example
timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
output_dir = os.path.join(CONFIG['MODEL_DIR'], f'mobilenetv2_attention_{timestamp}')

metadata = {
    'model_type': 'MobileNetV2 + CBAM Attention',
    'input_shape': CONFIG['INPUT_SHAPE'],
    'num_classes': CONFIG['NUM_CLASSES'],
    'includes_segmentation': True,
    'training_date': timestamp,
    'framework': 'TensorFlow',
    'version': tf.__version__  # record the version actually used for training
}

print("Model saving functions defined")
print(f"\nExample output directory: {output_dir}")
# save_model_and_metadata(model, seg_model, output_dir, metadata)
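
Before deploying, it helps to measure inference latency on the target hardware instead of relying on an estimate. The sketch below times any prediction callable with `time.perf_counter` and reports the median over several runs after a short warm-up. `median_latency_ms` is a hypothetical helper, and the lambda in the demo is a stand-in for something like `lambda batch: model.predict(batch, verbose=0)`.

```python
import statistics
import time

def median_latency_ms(predict_fn, sample, warmup=2, runs=10):
    """Median wall-clock time of predict_fn(sample) in milliseconds."""
    for _ in range(warmup):
        predict_fn(sample)           # warm-up calls are not timed
    timings = []
    for _ in range(runs):
        start = time.perf_counter()
        predict_fn(sample)
        timings.append((time.perf_counter() - start) * 1000.0)
    return statistics.median(timings)

# Demo with a trivial stand-in for a model
recorded_calls = []
demo_latency = median_latency_ms(lambda x: recorded_calls.append(x), sample=None,
                                 warmup=2, runs=5)
print(f"Median latency: {demo_latency:.3f} ms over 5 runs")
```

The median is used instead of the mean because the first calls to a TensorFlow model tend to include one-off graph tracing costs that would otherwise skew the result.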

## Section 12: Complete Training Pipeline Example

In [None]:
print("=" * 80)
print("COMPLETE TRAINING PIPELINE WORKFLOW")
print("=" * 80)

workflow_steps = """
STEP-BY-STEP TRAINING WORKFLOW:

1. **Prepare PlantVillage Dataset**
   - Download from: https://github.com/spMohanty/PlantVillage-Dataset
   - Organize by disease class
   - Expected: 38 classes across 14 crops

2. **Phase 1: Pre-training on PlantVillage**
   - Load MobileNetV2 + Attention model with frozen base
   - Train on PlantVillage data (50 epochs)
   - Output: mobilenetv2_attention_plantvillage.h5

3. **Phase 2: Fine-tuning on Local Data**
   - Prepare local Vietnamese crop data:
     * Tomato (cà chua) - early blight, late blight, etc.
     * Rice (lúa) - blast, brown spot, sheath blight
     * Other crops in Gia Lai region
   - Unfreeze top layers of pre-trained model
   - Train with lower learning rate (30 epochs)
   - Output: mobilenetv2_attention_finetuned.h5

4. **Segmentation Training (Optional)**
   - Train U-Net on annotated leaf images
   - Create binary masks (disease/no disease)
   - Output: unet_segmentation.h5

5. **Evaluation**
   - Test on a held-out test set that was not used for training or validation
   - Generate confusion matrix and classification report
   - Calculate precision, recall, F1-score

6. **Deployment**
   - Save models as .h5 files
   - Integrate with backend server
   - Use for real-time predictions on leaf images

EXPECTED PERFORMANCE (targets, not results measured in this notebook):
- Accuracy on PlantVillage: 92-97%
- Accuracy on local Vietnamese crops: 88-95%
- Inference time: <500ms per image on CPU

KEY ADVANTAGES:
✓ MobileNetV2: Efficient, suitable for edge deployment
✓ Attention mechanisms: Focus on small/unclear disease regions
✓ U-Net segmentation: Separate disease from background
✓ Transfer learning: Leverage large public datasets
✓ Fine-tuning: Adapt to local crop characteristics
✓ Two-model pipeline: Classification + segmentation (trained separately)
"""

print(workflow_steps)

print("\n" + "=" * 80)
print("NEXT STEPS")
print("=" * 80)
print("""
1. Download PlantVillage dataset
2. Prepare local Vietnamese crop dataset
3. Update data paths in this notebook
4. Uncomment training cells and execute
5. Monitor training with TensorBoard
6. Evaluate on test set
7. Deploy trained models to production
""")