# Entraînement du Modèle avec Transfer Learning (MobileNetV2)

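Ce notebook entraîne un CNN de classification des maladies des plantes par transfer learning : un MobileNetV2 pré-entraîné sur ImageNet, gelé, sert d'extracteur de caractéristiques, et seule une nouvelle tête de classification est entraînée.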

In [None]:
import os
import json
import numpy as np
import matplotlib.pyplot as plt
import tensorflow as tf
from tensorflow.keras.preprocessing.image import ImageDataGenerator
from tensorflow.keras.applications import MobileNetV2
from tensorflow.keras.layers import Dense, GlobalAveragePooling2D, Dropout, Input
from tensorflow.keras.models import Model
from tensorflow.keras.optimizers import Adam
from tensorflow.keras.callbacks import EarlyStopping, ModelCheckpoint

# Report the installed TensorFlow version (helps when reproducing results).
print("Version TensorFlow : " + tf.__version__)

## 1. Générateurs de Données

In [None]:
TRAIN_DIR = '../Dataset/train'
VALID_DIR = '../Dataset/valid'
IMG_SIZE = (224, 224)
BATCH_SIZE = 32

# Training images: pixels rescaled to [0, 1] plus random geometric
# augmentation (rotation, shifts, shear, zoom, horizontal flip).
train_datagen = ImageDataGenerator(
    rescale=1./255,
    rotation_range=20,
    width_shift_range=0.2,
    height_shift_range=0.2,
    shear_range=0.2,
    zoom_range=0.2,
    horizontal_flip=True,
    fill_mode='nearest'
)

# Validation images: same [0, 1] rescaling, no augmentation.
valid_datagen = ImageDataGenerator(rescale=1./255)

train_generator = train_datagen.flow_from_directory(
    TRAIN_DIR,
    target_size=IMG_SIZE,
    batch_size=BATCH_SIZE,
    class_mode='categorical',
    shuffle=True
)

# shuffle=False keeps the sample order fixed so that predictions made later
# on this generator line up one-to-one with valid_generator.classes.
valid_generator = valid_datagen.flow_from_directory(
    VALID_DIR,
    target_size=IMG_SIZE,
    batch_size=BATCH_SIZE,
    class_mode='categorical',
    shuffle=False
)

num_classes = train_generator.num_classes
print(f"Nombre de classes : {num_classes}")

# Each generator assigns label indices from its own folder names. If the two
# folder sets differ, the same index denotes different classes in training and
# validation, and every validation metric would be silently wrong: fail fast.
if valid_generator.class_indices != train_generator.class_indices:
    mismatched = sorted(
        set(train_generator.class_indices) ^ set(valid_generator.class_indices)
    )
    raise ValueError(
        "Les classes de TRAIN_DIR et VALID_DIR ne correspondent pas : "
        f"{mismatched}"
    )

# Save the class mapping inverted to index -> class name, for decoding model
# outputs. Note that json.dump writes the integer keys as strings ("0", ...).
class_indices = {index: name for name, index in train_generator.class_indices.items()}
with open('class_indices.json', 'w', encoding='utf-8') as f:
    json.dump(class_indices, f)
print("Indices de classe sauvegardés dans class_indices.json")

## 2. Architecture du Modèle (Transfer Learning)

In [None]:
# MobileNetV2 backbone with ImageNet weights, without its original 1000-class head.
base_model = MobileNetV2(weights='imagenet', include_top=False, input_shape=IMG_SIZE + (3,))

# Freeze the backbone: only the new classification head below is trained.
base_model.trainable = False

# The data generators feed pixels rescaled to [0, 1] (rescale=1./255), whereas
# MobileNetV2's ImageNet weights expect inputs in [-1, 1] (the range produced by
# mobilenet_v2.preprocess_input). Map [0, 1] -> [-1, 1] inside the model so the
# saved model keeps taking [0, 1] images as input.
inputs = Input(shape=IMG_SIZE + (3,))
x = tf.keras.layers.Rescaling(2.0, offset=-1.0)(inputs)
# training=False keeps the backbone's BatchNormalization layers in inference
# mode, even if the backbone is later unfrozen for fine-tuning.
x = base_model(x, training=False)
x = GlobalAveragePooling2D()(x)
x = Dense(128, activation='relu')(x)
x = Dropout(0.2)(x)
predictions = Dense(num_classes, activation='softmax')(x)

model = Model(inputs=inputs, outputs=predictions)

# Softmax output + one-hot labels (class_mode='categorical') -> categorical cross-entropy.
model.compile(optimizer=Adam(learning_rate=0.001),
              loss='categorical_crossentropy',
              metrics=['accuracy'])

model.summary()

## 3. Entraînement

In [None]:
# Both callbacks pick the "best" epoch by the same metric (validation accuracy),
# so the weights EarlyStopping restores into `model` (evaluated and saved as
# plant_disease_model.h5 afterwards) are selected by the same criterion as the
# checkpoint written to model.h5.
# NOTE(review): in some tf.keras versions restore_best_weights only takes effect
# when early stopping actually triggers; if all EPOCHS run, `model` may hold the
# last epoch's weights - verify for the installed version.
callbacks = [
    EarlyStopping(monitor='val_accuracy', mode='max', patience=5,
                  restore_best_weights=True),
    ModelCheckpoint('model.h5', monitor='val_accuracy', mode='max',
                    save_best_only=True),
]

EPOCHS = 10
# Note: 10 epochs can take a while on CPU; adjust if needed.

history = model.fit(
    train_generator,
    epochs=EPOCHS,
    validation_data=valid_generator,
    callbacks=callbacks
)

## 4. Évaluation et Sauvegarde

In [None]:
# Learning curves: accuracy (left panel) and loss (right panel), train vs. validation.
acc, val_acc = history.history['accuracy'], history.history['val_accuracy']
loss, val_loss = history.history['loss'], history.history['val_loss']
epochs_range = range(len(acc))

fig, (ax_acc, ax_loss) = plt.subplots(1, 2, figsize=(15, 5))

ax_acc.plot(epochs_range, acc, label='Précision Entraînement')
ax_acc.plot(epochs_range, val_acc, label='Précision Validation')
ax_acc.legend(loc='lower right')
ax_acc.set_title('Précision Entraînement et Validation')

ax_loss.plot(epochs_range, loss, label='Perte Entraînement')
ax_loss.plot(epochs_range, val_loss, label='Perte Validation')
ax_loss.legend(loc='upper right')
ax_loss.set_title('Perte Entraînement et Validation')

plt.show()

# --- Detailed evaluation on the validation set ---
from sklearn.metrics import classification_report, confusion_matrix, f1_score
import seaborn as sns

# Rewind the validation generator so prediction starts from its first batch.
valid_generator.reset()

# valid_generator was created with shuffle=False, so row i of predicted_probs
# corresponds to true_classes[i].
print("Génération des prédictions...")
predicted_probs = model.predict(valid_generator)
predicted_classes = np.argmax(predicted_probs, axis=1)
true_classes = valid_generator.classes
# Class names ordered explicitly by their integer index.
class_labels = sorted(valid_generator.class_indices, key=valid_generator.class_indices.get)

# Evaluate against the full label set. Without `labels`, scikit-learn keeps only
# classes that occur in true_classes or predicted_classes: the confusion matrix
# would shrink and no longer line up with the tick labels, and
# classification_report would raise because target_names is longer.
all_labels = np.arange(len(class_labels))

# Confusion matrix (rows: true class, columns: predicted class).
cm = confusion_matrix(true_classes, predicted_classes, labels=all_labels)
plt.figure(figsize=(12, 10))
sns.heatmap(cm, annot=False, fmt='d', cmap='Blues', xticklabels=class_labels, yticklabels=class_labels)
plt.title('Matrice de Confusion')
plt.ylabel('Vraie Classe')
plt.xlabel('Classe Prédite')
plt.xticks(rotation=90)
plt.tight_layout()  # keep the rotated class names inside the figure
plt.show()

# Per-class precision / recall / F1. zero_division=0 yields the same 0 the
# default ('warn') uses for never-predicted classes, without a warning per class.
print("\n--- Rapport de Classification ---")
report = classification_report(
    true_classes,
    predicted_classes,
    labels=all_labels,
    target_names=class_labels,
    zero_division=0,
)
print(report)

# Overall scores. The model was compiled with metrics=['accuracy'], so
# evaluate() returns [loss, accuracy].
loss, accuracy = model.evaluate(valid_generator)
macro_f1 = f1_score(true_classes, predicted_classes, average='macro')
weighted_f1 = f1_score(true_classes, predicted_classes, average='weighted')

print("\n--- Scores Globaux ---")
print(f"Accuracy (Précision Globale) : {accuracy:.4f}")
print(f"Macro F1-Score : {macro_f1:.4f}")
print(f"Weighted F1-Score : {weighted_f1:.4f}")
print(f"Perte Globale : {loss:.4f}")

# Save the model with the weights currently in memory (HDF5 format).
model.save('plant_disease_model.h5')