<a href="https://colab.research.google.com/github/rafyqmohammed/tomato-disease-detection-vgg/blob/main/model_training.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
# ==========================================================
# 1️⃣ Imports and basic configuration
# ==========================================================
import os

import matplotlib.pyplot as plt
import tensorflow as tf
from tensorflow.keras import layers, models, optimizers
from tensorflow.keras.applications import VGG16
from tensorflow.keras.preprocessing.image import ImageDataGenerator

# Log the TensorFlow build in use so training results can be tied to a version.
tf_version = tf.__version__
print("TensorFlow version:", tf_version)

# ==========================================================
# 2️⃣ Dataset preparation
# ==========================================================
# Expected on-disk layout (one sub-folder per class in each split):
# dataset/
# ├── train/
# │   ├── class1/
# │   ├── class2/
# │   ├── class3/
# │   └── class4/
# └── val/
#     ├── class1/
#     ├── class2/
#     ├── class3/
#     └── class4/

# 🔸 If your dataset lives on Google Drive:
# from google.colab import drive
# drive.mount('/content/drive')
# data_dir = '/content/drive/MyDrive/mon_dataset'

data_dir = '/content/dataset'  # adjust this path to your setup
# Derive both split directories from the dataset root.
train_dir, val_dir = (os.path.join(data_dir, split) for split in ('train', 'val'))

# ==========================================================
# 3️⃣ Data generators + data augmentation
# ==========================================================
# Input resolution fed to the VGG16 backbone, and mini-batch size.
IMG_SIZE = (224, 224)
BATCH_SIZE = 32

# NOTE(review): the Keras Applications docs state that VGG16 ImageNet weights
# expect `tf.keras.applications.vgg16.preprocess_input` rather than 1/255
# rescaling. 1/255 is kept here because the single-image prediction cell of
# this notebook scales the same way (and presumably so does any consumer of
# the saved model) — change all of them together if switching.
train_datagen = ImageDataGenerator(
    rescale=1./255,
    rotation_range=20,
    zoom_range=0.2,
    width_shift_range=0.2,
    height_shift_range=0.2,
    horizontal_flip=True,
)

# Validation images are only rescaled: no augmentation at evaluation time.
val_datagen = ImageDataGenerator(rescale=1./255)

train_generator = train_datagen.flow_from_directory(
    train_dir,
    target_size=IMG_SIZE,
    batch_size=BATCH_SIZE,
    class_mode='categorical',
)

# Reuse the training class list so a label index means the same class in both
# splits. Without `classes=`, flow_from_directory derives the mapping
# independently from each split's sub-folders, so a missing or extra folder in
# `val/` would silently shift every validation label.
# shuffle=False keeps the validation order stable and aligned with
# `val_generator.filenames`. Aggregate metrics don't need this, but
# per-sample analysis does.
val_generator = val_datagen.flow_from_directory(
    val_dir,
    target_size=IMG_SIZE,
    batch_size=BATCH_SIZE,
    class_mode='categorical',
    classes=list(train_generator.class_indices),
    shuffle=False,
)

num_classes = len(train_generator.class_indices)
if num_classes == 0:
    # Fail here rather than later with an obscure Dense(0) error.
    raise ValueError(f"No class sub-folders found in {train_dir!r}")
print("Nombre de classes détectées :", num_classes)
print("Classes :", train_generator.class_indices)

# ==========================================================
# 4️⃣ Load the pre-trained VGG16 backbone
# ==========================================================
# Convolutional part only (no ImageNet classifier top), ImageNet weights,
# 224x224 RGB input.
base_model = VGG16(
    include_top=False,
    weights='imagenet',
    input_shape=(224, 224, 3),
)
# ❄️ Freeze all VGG16 layers so only the new head below gets trained.
base_model.trainable = False

# ==========================================================
# 5️⃣ Build the custom model
# ==========================================================
# Classification head stacked on the frozen backbone: pool the feature maps,
# one hidden ReLU layer with dropout, then a softmax over the detected classes.
classifier_head = [
    layers.GlobalAveragePooling2D(),
    layers.Dense(256, activation='relu'),
    layers.Dropout(0.5),
    layers.Dense(num_classes, activation='softmax'),
]
model = models.Sequential([base_model, *classifier_head])

model.summary()

# ==========================================================
# 6️⃣ Compile the model
# ==========================================================
LEARNING_RATE = 1e-4
# categorical_crossentropy matches the one-hot labels produced by the
# generators' class_mode='categorical'.
model.compile(
    loss='categorical_crossentropy',
    optimizer=optimizers.Adam(learning_rate=LEARNING_RATE),
    metrics=['accuracy'],
)

# ==========================================================
# 7️⃣ Train the model
# ==========================================================
EPOCHS = 10
# `history.history` collects the per-epoch train/validation metrics.
history = model.fit(
    train_generator,
    epochs=EPOCHS,
    validation_data=val_generator,
    verbose=1,
)

# ==========================================================
# 8️⃣ Evaluation on the validation set
# ==========================================================
# return_dict=True looks metrics up by name instead of relying on the
# positional [loss, *metrics] order of the default list return.
val_results = model.evaluate(val_generator, return_dict=True)
val_loss = val_results['loss']
val_acc = val_results['accuracy']
print(f"\n✅ Validation Accuracy: {val_acc:.4f}")
print(f"✅ Validation Loss: {val_loss:.4f}")

# ==========================================================
# 9️⃣ Training curves
# ==========================================================
# One panel per quantity: (title, [(history key, legend label), ...]).
curve_panels = [
    ("Accuracy", [('accuracy', 'train_acc'), ('val_accuracy', 'val_acc')]),
    ("Loss", [('loss', 'train_loss'), ('val_loss', 'val_loss')]),
]

plt.figure(figsize=(12, 5))
for panel_number, (panel_title, curves) in enumerate(curve_panels, start=1):
    plt.subplot(1, 2, panel_number)
    for history_key, legend_label in curves:
        plt.plot(history.history[history_key], label=legend_label)
    plt.title(panel_title)
    plt.legend()
plt.show()

# ==========================================================
# 🔟 Save the model
# ==========================================================
# Single source of truth for the output path, so the save call and the log
# message cannot drift apart.
# NOTE(review): the `.h5` suffix selects the legacy HDF5 format; recent Keras
# versions recommend the native `.keras` format. Kept as `.h5` in case
# downstream code loads this exact file name — confirm before switching.
MODEL_PATH = "vgg16_custom_model.h5"
model.save(MODEL_PATH)
print(f"✅ Modèle sauvegardé sous {MODEL_PATH}")

# ==========================================================
# 1️⃣1️⃣ Example prediction on a single image
# ==========================================================
import numpy as np
from tensorflow.keras.preprocessing import image

# `val_generator.filepaths` only lists files that flow_from_directory accepted
# as images, in a deterministic order. A raw os.listdir() can instead return
# hidden files (e.g. .DS_Store) or sub-folders, in arbitrary order, and raised
# IndexError on an empty class folder.
if not val_generator.filepaths:
    raise ValueError(f"No validation images found in {val_dir!r}")
img_path = val_generator.filepaths[0]
print("🖼️ Exemple d'image :", img_path)

# Same preprocessing as the validation generator: resize, then scale to [0, 1].
img = image.load_img(img_path, target_size=(224, 224))
img_array = image.img_to_array(img) / 255.0
img_array = np.expand_dims(img_array, axis=0)  # add a leading batch axis of size 1

pred = model.predict(img_array)
# Map the arg-max index back to a class name through an explicit inverse
# mapping rather than relying on the dict's key order matching the indices.
index_to_class = {index: name for name, index in train_generator.class_indices.items()}
pred_class = index_to_class[int(np.argmax(pred[0]))]
print(f"✅ Prédiction : {pred_class}")
