## Section 1: Imports and Configuration
### Import the libraries used for data manipulation, visualization, and model building, then define the shared configuration and helper functions.

In [None]:
# Cassava Leaf Disease Classification

## Overview
# This notebook implements a classification model using EfficientNetB0 to detect leaf diseases from cassava images. The dataset contains five disease classes.

# Section 1: Imports and Configuration
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
from sklearn.model_selection import train_test_split
from sklearn.metrics import accuracy_score, classification_report, confusion_matrix, roc_curve, auc
from sklearn.preprocessing import label_binarize
from sklearn.utils.class_weight import compute_class_weight
import tensorflow as tf
from tensorflow.keras import models, layers
from tensorflow.keras.preprocessing.image import ImageDataGenerator
from tensorflow.keras.applications import EfficientNetB0
from tensorflow.keras.optimizers import Adam
from tensorflow.keras.callbacks import ModelCheckpoint, EarlyStopping, ReduceLROnPlateau
import os
from PIL import Image

# Configuration settings
class CFG:
    """Notebook-wide settings shared by the data pipeline, the model and training."""

    # Dataset root; "train.csv" and "train_images" are resolved relative to it.
    WORK_DIR = "../input/cassava-leaf-disease-classification/"

    # Images are resized to TARGET_SIZE x TARGET_SIZE; the classifier has NCLASSES outputs.
    TARGET_SIZE = 256
    NCLASSES = 5

    # Batch size used by both generators and the epoch budget passed to model.fit.
    BATCH_SIZE = 8
    EPOCHS = 10

# Utility functions
def visualize_class_distribution(labels):
    """Draw a bar chart counting the rows for each value of the 'label' column.

    Args:
        labels: Table handed to seaborn as ``data``; it must expose a 'label' column.
    """
    fig, ax = plt.subplots(figsize=(8, 6))
    sns.countplot(x='label', data=labels, palette='coolwarm', ax=ax)
    ax.set_title('Class Distribution', fontsize=16)
    ax.set_xlabel('Disease Classes', fontsize=14)
    ax.set_ylabel('Frequency', fontsize=14)
    fig.tight_layout()
    plt.show()

def load_and_display_sample_image(image_dir, image_id, label):
    """Show one image from disk with its class label in the figure title.

    Args:
        image_dir: Directory that contains the image file.
        image_id: File name of the image inside ``image_dir``.
        label: Value rendered in the figure title.
    """
    sample_image_path = os.path.join(image_dir, image_id)
    # The context manager closes the underlying file once the figure has been
    # rendered; the bare Image.open() call left the handle open.
    with Image.open(sample_image_path) as sample_image:
        plt.figure(figsize=(6, 6))
        plt.imshow(sample_image)
        plt.title(f"Sample Image - Class: {label}")
        plt.axis('off')
        plt.show()

def create_data_generators(labels, image_dir):
    """Build the augmented training generator and the plain validation generator.

    Both generators read the same dataframe with the same ``validation_split``
    and request the "training" and "validation" subsets respectively.

    Args:
        labels: DataFrame with an 'image_id' column (file names) and a 'label' column.
        image_dir: Directory that contains the image files.

    Returns:
        Tuple ``(train_generator, validation_generator)``.
    """
    val_fraction = 0.2

    # Options that must be identical for the two flows.
    flow_options = dict(
        dataframe=labels,
        directory=image_dir,
        x_col="image_id",
        y_col="label",
        target_size=(CFG.TARGET_SIZE, CFG.TARGET_SIZE),
        batch_size=CFG.BATCH_SIZE,
        class_mode="sparse",
        seed=42,
    )

    # No `rescale` here: Keras EfficientNet models contain their own rescaling
    # layer and expect raw pixel values in [0, 255]. Dividing by 255 beforehand
    # would squash the inputs into a near-constant range for the pretrained weights.
    train_datagen = ImageDataGenerator(
        validation_split=val_fraction,
        rotation_range=45,
        zoom_range=0.2,
        shear_range=0.15,
        brightness_range=[0.8, 1.2],
        channel_shift_range=50.0,
        horizontal_flip=True,
        vertical_flip=True,
        fill_mode='nearest'
    )
    train_generator = train_datagen.flow_from_dataframe(subset="training", **flow_options)

    # Validation images are fed without augmentation.
    validation_datagen = ImageDataGenerator(validation_split=val_fraction)

    # shuffle=False keeps the batches in dataframe order, so predictions made on
    # this generator line up with `validation_generator.classes` during evaluation.
    validation_generator = validation_datagen.flow_from_dataframe(
        subset="validation",
        shuffle=False,
        **flow_options
    )

    return train_generator, validation_generator

def plot_training_history(history):
    """Plot training-vs-validation accuracy and loss curves side by side.

    Args:
        history: Object whose ``history`` dict holds the 'accuracy',
            'val_accuracy', 'loss' and 'val_loss' series.
    """
    def _draw_panel(position, train_key, val_key, train_label, val_label, title, y_label):
        # One subplot: the training curve, the validation curve, and the labelling.
        plt.subplot(1, 2, position)
        plt.plot(history.history[train_key], label=train_label)
        plt.plot(history.history[val_key], label=val_label)
        plt.title(title, fontsize=16)
        plt.xlabel('Epochs', fontsize=14)
        plt.ylabel(y_label, fontsize=14)
        plt.legend()

    plt.figure(figsize=(12, 5))
    _draw_panel(
        1, 'accuracy', 'val_accuracy',
        'Training Accuracy', 'Validation Accuracy',
        'Training and Validation Accuracy', 'Accuracy',
    )
    _draw_panel(
        2, 'loss', 'val_loss',
        'Training Loss', 'Validation Loss',
        'Training and Validation Loss', 'Loss',
    )

    plt.tight_layout()
    plt.show()

def evaluate_model(model, generator, class_indices):
    """Print a classification report and plot the confusion matrix for ``generator``.

    Predictions and ground-truth labels are taken from the same
    ``generator[i]`` batch, so they stay aligned even when the generator
    shuffles its samples. Comparing ``model.predict(generator)`` against
    ``generator.classes`` is only valid for an unshuffled generator.

    Args:
        model: Model exposing ``predict_on_batch`` that returns per-class scores.
        generator: Indexable batch source; ``generator[i]`` yields
            ``(images, labels)`` and ``len(generator)`` is the number of batches.
        class_indices: Mapping of class name -> integer class index.

    Returns:
        Tuple ``(val_preds, val_true)`` of 1-D integer arrays in batch order.
    """
    class_names = list(class_indices.keys())
    class_ids = list(class_indices.values())

    pred_batches, true_batches = [], []
    for batch_idx in range(len(generator)):
        images, targets = generator[batch_idx]
        scores = np.asarray(model.predict_on_batch(images))
        pred_batches.append(scores.argmax(axis=1))

        targets = np.asarray(targets)
        if targets.ndim > 1:
            # One-hot targets: reduce to class indices.
            targets = targets.argmax(axis=1)
        true_batches.append(targets.astype(int))

    val_preds = np.concatenate(pred_batches)
    val_true = np.concatenate(true_batches)

    # Classification report; explicit `labels` keeps the rows matched to
    # `target_names` even if a class never occurs in the evaluated data.
    report = classification_report(val_true, val_preds, labels=class_ids, target_names=class_names)
    print(report)

    # Confusion matrix
    cm = confusion_matrix(val_true, val_preds, labels=class_ids)
    sns.heatmap(cm, annot=True, fmt='d', cmap='Blues', xticklabels=class_names, yticklabels=class_names)
    plt.xlabel('Predicted')
    plt.ylabel('True')
    plt.title('Confusion Matrix')
    plt.show()

    return val_preds, val_true

def plot_roc_curve(model, generator, class_indices):
    """Plot a one-vs-rest ROC curve (with AUC in the legend) for every class.

    Probabilities and ground-truth labels are gathered from the same
    ``generator[i]`` batch so that they remain aligned even if the generator
    shuffles its samples.

    Args:
        model: Model exposing ``predict_on_batch`` that returns per-class scores.
        generator: Indexable batch source; ``generator[i]`` yields
            ``(images, labels)`` and ``len(generator)`` is the number of batches.
        class_indices: Mapping of class name -> integer class index; the index
            selects the matching column of the model output.
    """
    prob_batches, true_batches = [], []
    for batch_idx in range(len(generator)):
        images, targets = generator[batch_idx]
        prob_batches.append(np.asarray(model.predict_on_batch(images)))

        targets = np.asarray(targets)
        if targets.ndim > 1:
            # One-hot targets: reduce to class indices.
            targets = targets.argmax(axis=1)
        true_batches.append(targets.astype(int))

    pred_probs = np.concatenate(prob_batches)
    val_true = np.concatenate(true_batches)

    plt.figure(figsize=(10, 8))
    for class_name, class_id in class_indices.items():
        # Binary "is this class" target against that class's probability column.
        is_class = (val_true == class_id).astype(int)
        fpr, tpr, _ = roc_curve(is_class, pred_probs[:, class_id])
        roc_auc = auc(fpr, tpr)
        plt.plot(fpr, tpr, label=f'{class_name} (AUC = {roc_auc:.2f})')

    # Diagonal reference line.
    plt.plot([0, 1], [0, 1], 'k--')
    plt.title('Multi-Class ROC Curve', fontsize=16)
    plt.xlabel('False Positive Rate', fontsize=14)
    plt.ylabel('True Positive Rate', fontsize=14)
    plt.legend()
    plt.grid()
    plt.tight_layout()
    plt.show()

## Section 2: Load and Explore Data
### Load the dataset and explore its structure to understand the data.

In [None]:

# Read the label table and store the class column as strings
# (presumably because the generators expect string class labels; verify).
train_labels = pd.read_csv(os.path.join(CFG.WORK_DIR, "train.csv")).astype({'label': str})
image_dir = os.path.join(CFG.WORK_DIR, "train_images")

# Visualize class distribution
visualize_class_distribution(train_labels)

# Display sample image: the first row of the label table
first_row = train_labels.iloc[0]
load_and_display_sample_image(image_dir, first_row['image_id'], first_row['label'])

## Section 3: Data Preprocessing
### Preprocessing includes normalizing pixel values and augmenting the dataset to improve model generalization.

In [None]:
# Build both generators from the same label table and image folder.
train_generator, validation_generator = create_data_generators(
    labels=train_labels,
    image_dir=image_dir,
)

## Section 4: Model Creation
### Define the EfficientNetB0 model architecture and prepare it for training.

In [None]:
# Section 4: Model Creation
def create_model():
    """Assemble and compile the EfficientNetB0-based classifier.

    An ImageNet-initialised EfficientNetB0 backbone (without its top) feeds
    global average pooling, 50% dropout, and an L2-regularised softmax layer
    with ``CFG.NCLASSES`` units.

    Returns:
        The compiled model (Adam with learning rate 0.001, sparse categorical
        cross-entropy loss, accuracy metric).
    """
    input_shape = (CFG.TARGET_SIZE, CFG.TARGET_SIZE, 3)
    backbone = EfficientNetB0(include_top=False, weights='imagenet', input_shape=input_shape)

    # Classification head stacked on the backbone's feature map.
    features = layers.GlobalAveragePooling2D()(backbone.output)
    features = layers.Dropout(0.5)(features)
    classifier_head = layers.Dense(
        CFG.NCLASSES,
        activation='softmax',
        kernel_regularizer=tf.keras.regularizers.l2(0.01),
    )
    predictions = classifier_head(features)

    model = models.Model(inputs=backbone.input, outputs=predictions)
    model.compile(
        optimizer=Adam(learning_rate=0.001),
        loss='sparse_categorical_crossentropy',
        metrics=['accuracy'],
    )
    return model

# Build the compiled model and print its layer-by-layer summary.
model = create_model()
model.summary()

In [None]:
# 'balanced' class weights computed from the label column of the full label table.
label_column = train_labels['label']
balanced_weights = compute_class_weight(
    class_weight='balanced',
    classes=np.unique(label_column),
    y=label_column,
)
# Key each weight by its position in the sorted unique labels.
class_weights = {position: weight for position, weight in enumerate(balanced_weights)}
print("Class Weights:", class_weights)

## Section 5: Model Training
### Train the model using the training and validation data generators.

In [None]:
# Section 5: Model Training
# All three callbacks monitor the validation loss (lower is better).
callbacks = [
    # Save the weights whenever val_loss improves.
    ModelCheckpoint('best_model.weights.h5', save_best_only=True, save_weights_only=True, monitor='val_loss', mode='min', verbose=1),
    # Stop after 5 epochs without a 0.001 improvement and restore the best weights.
    EarlyStopping(monitor='val_loss', min_delta=0.001, patience=5, mode='min', verbose=1, restore_best_weights=True),
    # Multiply the learning rate by 0.3 after 2 epochs without improvement.
    ReduceLROnPlateau(monitor='val_loss', factor=0.3, patience=2, min_delta=0.001,
                      mode='min', verbose=1)
]

# Take the step counts from the generators themselves. The previous
# int(len(train_labels) * fraction / BATCH_SIZE) arithmetic duplicated the
# 0.8/0.2 split constants and dropped the final partial batch, so a few samples
# were skipped each epoch and during validation.
STEPS_PER_EPOCH = len(train_generator)
VALIDATION_STEPS = len(validation_generator)

history = model.fit(
    train_generator,
    steps_per_epoch=STEPS_PER_EPOCH,
    epochs=CFG.EPOCHS,
    validation_data=validation_generator,
    validation_steps=VALIDATION_STEPS,
    class_weight=class_weights,
    callbacks=callbacks
)

In [None]:
# Plot the accuracy and loss curves recorded by model.fit.
plot_training_history(history)

## Section 6: Model Evaluation and Predictions
### Evaluate the trained model on the validation split: classification report, confusion matrix, and per-class ROC curves.

In [None]:
# Section 6: Evaluation and Metrics
# Both evaluation helpers use the validation generator and its class-name -> index mapping.
class_indices = validation_generator.class_indices
val_preds, val_true = evaluate_model(model, validation_generator, class_indices)
plot_roc_curve(model, validation_generator, class_indices)