<a href="https://colab.research.google.com/github/shrii21/JPA/blob/main/Plant_Disease_Web_Predictor.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [3]:
import os
import io
import numpy as np
import tensorflow as tf
from tensorflow.keras import models, layers
from PIL import Image
from flask import Flask, request, jsonify, render_template_string

# --- Configuration ---
# NOTE: This web app only runs predictions with an already-trained model; it
# does not train one (training needs the 'PlantVillage' dataset and substantial
# time). It expects a trained model file named 'potatoes.h5' in the current
# working directory (see _load_model_on_startup).

IMAGE_SIZE = 256  # model input height and width, in pixels
CHANNELS = 3      # RGB colour channels

# Display labels, indexed by position in the model's softmax output.
# NOTE(review): this order must match the class order used during training
# (e.g. tf.keras.utils.image_dataset_from_directory sorts class sub-folder
# names alphanumerically) — TODO confirm against the training notebook.
CLASS_NAMES = [
    "Potato Early Blight",
    "Potato Late Blight",
    "Potato Healthy",
    "Tomato Bacterial Spot",
    "Tomato Early Blight",
    "Tomato Healthy",
    "Corn Common Rust",
    "Corn Gray Leaf Spot",
    "Corn Healthy",
    "Apple Scab",
    "Apple Black Rot",
]
# Derived rather than hard-coded so the output-layer size can never drift out
# of sync with the label list.
N_CLASSES = len(CLASS_NAMES)

# Rebuild the CNN architecture used by the training notebook (cell 1fcdb1e1).
def create_model(input_shape, n_classes):
    """Build an uncompiled, randomly initialised copy of the training CNN.

    Resizing and 1/255 rescaling are deliberately not part of this network:
    ``preprocess_image`` performs both on each uploaded image before inference.

    Args:
        input_shape: Per-sample input shape without the batch dimension,
            e.g. ``(IMAGE_SIZE, IMAGE_SIZE, CHANNELS)``.
        n_classes: Number of units in the final softmax layer.

    Returns:
        A ``Sequential`` Keras model.
    """
    # Six Conv2D(3x3, ReLU) + MaxPooling2D(2x2) stages; only the first has 32 filters.
    conv_filters = (32, 64, 64, 64, 64, 64)

    # Explicit Input layer (instead of input_shape= on the first Conv2D) avoids
    # the Keras UserWarning about passing input_shape to a layer.
    stack = [layers.Input(shape=input_shape)]
    for filters in conv_filters:
        stack.append(layers.Conv2D(filters, kernel_size=(3, 3), activation='relu'))
        stack.append(layers.MaxPooling2D((2, 2)))

    # Classifier head.
    stack.extend([
        layers.Flatten(),
        layers.Dense(64, activation='relu'),
        layers.Dense(n_classes, activation='softmax'),
    ])
    return models.Sequential(stack)

# --- Model Loading and Initialization ---
app = Flask(__name__)
# Keras model used by the /predict route; assigned by _load_model_on_startup()
# when this module is executed.
model = None

def _load_model_on_startup(model_path="potatoes.h5"):
    """Load the trained Keras model into the module-level ``model`` global.

    Falls back to a freshly initialised (untrained) network built by
    ``create_model`` when ``model_path`` does not exist or fails to load, so
    the web UI still starts — but predictions from that fallback are
    meaningless.

    Args:
        model_path: Path to the saved Keras model, resolved relative to the
            current working directory. Defaults to ``"potatoes.h5"``.
    """
    global model
    input_shape = (IMAGE_SIZE, IMAGE_SIZE, CHANNELS)
    print(f"Attempting to load model from: {model_path}")

    if not os.path.exists(model_path):
        print("MODEL NOT FOUND! Using untrained architecture.")
        model = create_model(input_shape=input_shape, n_classes=N_CLASSES)
    else:
        try:
            # compile=False: only inference is needed, and it avoids failures
            # from custom losses/metrics stored in the file.
            model = tf.keras.models.load_model(model_path, compile=False)
            print("Model loaded successfully.")
        except Exception as e:
            print(f"Error loading model: {e}. Creating untrained architecture instead.")
            model = create_model(input_shape=input_shape, n_classes=N_CLASSES)

    # predict() does not require compile(); compiling only keeps fit()/evaluate()
    # usable on this object.
    model.compile(optimizer='adam', loss='sparse_categorical_crossentropy', metrics=['accuracy'])

    # A loaded model whose output width differs from CLASS_NAMES would make
    # /predict raise IndexError or attach the wrong labels, so warn up front.
    try:
        n_outputs = int(model.outputs[0].shape[-1])
    except (AttributeError, IndexError, TypeError):
        n_outputs = None  # output width not introspectable (e.g. unbuilt subclassed model)
    if n_outputs is not None and n_outputs != len(CLASS_NAMES):
        print(f"WARNING: model predicts {n_outputs} classes but CLASS_NAMES "
              f"has {len(CLASS_NAMES)} entries; predicted labels will be wrong.")

    model.summary()

# Load once at import time so the first request does not pay the load cost.
_load_model_on_startup()

def preprocess_image(image, target_size=None):
    """Convert a PIL image into a single-image batch for ``model.predict``.

    Does by hand what the notebook's Resizing and Rescaling(1./255) layers
    did: force 3-channel RGB, resize to ``target_size`` x ``target_size``,
    scale pixel values to [0, 1] as float32 and add a leading batch axis.

    Args:
        image: A ``PIL.Image.Image`` in any mode (RGB, RGBA, L, P, ...).
        target_size: Output height/width in pixels; ``None`` uses ``IMAGE_SIZE``.

    Returns:
        ``np.ndarray`` of shape ``(1, target_size, target_size, 3)``, dtype float32.
    """
    size = IMAGE_SIZE if target_size is None else target_size

    # The network expects exactly 3 channels. PNGs with alpha (RGBA), grayscale
    # (L) and palette (P) images would otherwise produce 4 or 1 channels and
    # make model.predict fail with a shape mismatch.
    if image.mode != "RGB":
        image = image.convert("RGB")

    # NOTE(review): PIL's default resampling filter may differ from the Keras
    # Resizing layer's default (bilinear) — confirm if exact parity matters.
    image = image.resize((size, size))

    # Equivalent to keras img_to_array (float32, channels-last), then rescale.
    img_array = np.asarray(image, dtype=np.float32) / 255.0

    # Add the batch dimension expected by model.predict.
    return np.expand_dims(img_array, 0)

@app.route('/', methods=['GET'])
def index():
    """Serve the upload-and-predict page (``HTML_TEMPLATE``) at the site root."""
    page = render_template_string(HTML_TEMPLATE)
    return page

@app.route('/predict', methods=['POST'])
def predict_image():
    """Classify one uploaded leaf image.

    Expects a multipart/form-data POST with the image in the ``file`` field.

    Returns:
        200 with JSON ``{"prediction": str, "confidence": float,
        "status": "success"}``, where ``confidence`` is the top softmax
        probability as a percentage rounded to 2 decimals;
        400 with ``{"error": ...}`` if the field is missing, empty, or not a
        decodable image;
        500 with ``{"error": ...}`` for any other failure during prediction.
    """
    if 'file' not in request.files:
        return jsonify({'error': 'No file part'}), 400

    file = request.files['file']
    if file.filename == '':
        return jsonify({'error': 'No selected file'}), 400

    # Decode eagerly: Image.open() is lazy, so without load() a corrupt or
    # non-image upload would only fail later and be reported as a server error.
    try:
        image = Image.open(io.BytesIO(file.read()))
        image.load()
    except (OSError, Image.DecompressionBombError) as e:
        print(f"Invalid image upload: {e}")
        return jsonify({'error': 'Uploaded file is not a valid image'}), 400

    try:
        processed_image = preprocess_image(image)

        # verbose=0 suppresses the per-call Keras progress bar in server logs.
        predictions = model.predict(processed_image, verbose=0)
        probabilities = predictions[0]

        predicted_class_index = int(np.argmax(probabilities))
        predicted_class = CLASS_NAMES[predicted_class_index]
        # float(): numpy scalars such as np.float32 are not JSON serializable
        # by Flask's default JSON provider.
        confidence = round(100 * float(probabilities[predicted_class_index]), 2)

        return jsonify({
            'prediction': predicted_class,
            'confidence': confidence,
            'status': 'success'
        })

    except Exception as e:
        print(f"Prediction error: {e}")
        return jsonify({'error': f'An error occurred during prediction: {str(e)}'}), 500

# --- HTML/JavaScript Frontend (Styled with Tailwind CSS) ---
# Rendered through Flask's render_template_string (Jinja), so the page must not
# contain Jinja delimiters ("{{", "{%", "{#"); JS template literals "${...}" are safe.
HTML_TEMPLATE = """
<!DOCTYPE html>
<html lang="en">
<head>
    <meta charset="UTF-8">
    <meta name="viewport" content="width=device-width, initial-scale=1.0">
    <title>Plant Disease Predictor</title>
    <!-- Load Tailwind CSS -->
    <script src="https://cdn.tailwindcss.com"></script>
    <style>
        @import url('https://fonts.googleapis.com/css2?family=Inter:wght@100..900&display=swap');
        body {
            font-family: 'Inter', sans-serif;
            background-color: #f7f9fb;
        }
        .container {
            max-width: 90%;
            margin: 0 auto;
            padding: 1.5rem;
        }
        #imagePreview {
            width: 100%;
            height: 256px;
            border-radius: 0.75rem;
            border: 2px dashed #cbd5e1;
            display: flex;
            align-items: center;
            justify-content: center;
            color: #64748b;
            background-color: #f1f5f9;
            transition: all 0.3s;
        }
        #imagePreview.has-image {
            border: none;
        }
        /* object-fit only affects replaced elements, so it belongs on the img. */
        #previewImage {
            object-fit: cover;
        }
    </style>
</head>
<body>
    <div class="container min-h-screen flex items-center justify-center">
        <div class="w-full max-w-lg bg-white p-8 rounded-xl shadow-2xl">
            <h1 class="text-3xl font-extrabold text-green-700 mb-6 text-center">
                Plant Leaf Predictor
            </h1>
            <p class="text-gray-600 mb-8 text-center">
                Upload a plant leaf image to predict its health status or disease type.
            </p>

            <div id="imagePreview" class="mb-6">
                <span id="previewText">Select an image file (PNG, JPG)</span>
                <img id="previewImage" class="hidden w-full h-full rounded-xl" alt="Preview">
            </div>

            <input type="file" id="fileInput" accept="image/*" class="hidden" onchange="previewFile()">

            <button onclick="document.getElementById('fileInput').click()"
                    class="w-full bg-green-500 hover:bg-green-600 text-white font-semibold py-3 px-4 rounded-xl
                           shadow-lg transition duration-200 ease-in-out transform hover:scale-[1.01] mb-4">
                Choose Image
            </button>

            <button id="predictButton" onclick="uploadFile()" disabled
                    class="w-full bg-blue-500 text-white font-semibold py-3 px-4 rounded-xl
                           shadow-lg transition duration-200 ease-in-out disabled:opacity-50 disabled:cursor-not-allowed">
                Analyze Leaf
            </button>

            <div id="resultContainer" class="mt-8 p-6 bg-gray-100 rounded-xl hidden">
                <p id="loadingIndicator" class="text-center text-blue-600 font-medium hidden">
                    <svg class="animate-spin -ml-1 mr-3 h-5 w-5 text-blue-600 inline" xmlns="http://www.w3.org/2000/svg" fill="none" viewBox="0 0 24 24">
                        <circle class="opacity-25" cx="12" cy="12" r="10" stroke="currentColor" stroke-width="4"></circle>
                        <path class="opacity-75" fill="currentColor" d="M4 12a8 8 0 018-8V0C5.373 0 0 5.373 0 12h4zm2 5.291A7.962 7.962 0 014 12H0c0 3.042 1.135 5.824 3 7.938l3-2.647z"></path>
                    </svg>
                    Analyzing image...
                </p>
                <div id="predictionOutput" class="space-y-3">
                    <h2 class="text-xl font-bold text-gray-800">Analysis Result</h2>
                    <p class="text-lg">
                        <span class="font-medium text-gray-700">Predicted Class:</span>
                        <span id="predictedClass" class="font-extrabold text-green-600"></span>
                    </p>
                    <p class="text-sm">
                        <span class="font-medium text-gray-700">Confidence:</span>
                        <span id="confidenceScore" class="font-bold text-gray-600"></span>
                    </p>
                </div>
            </div>

            <div id="errorBox" class="mt-4 p-3 bg-red-100 border border-red-400 text-red-700 rounded-xl hidden">
                An error occurred during prediction. Please try another image.
            </div>

        </div>
    </div>

    <script>
        const fileInput = document.getElementById('fileInput');
        const previewImage = document.getElementById('previewImage');
        const previewText = document.getElementById('previewText');
        const imagePreviewDiv = document.getElementById('imagePreview');
        const predictButton = document.getElementById('predictButton');
        const resultContainer = document.getElementById('resultContainer');
        const loadingIndicator = document.getElementById('loadingIndicator');
        const predictionOutput = document.getElementById('predictionOutput');
        const errorBox = document.getElementById('errorBox');

        // --- Utility Functions ---
        // Toggle only the in-flight UI state: the spinner, and the Analyze button
        // (disabled while a request runs or when no file is selected). Result and
        // error visibility are decided by uploadFile once the outcome is known.
        function setButtonState(isLoading) {
            predictButton.disabled = isLoading || !fileInput.files.length;
            loadingIndicator.classList.toggle('hidden', !isLoading);
        }

        // Show an error message and hide the (now irrelevant) result panel.
        function showError(message) {
            errorBox.textContent = message;
            errorBox.classList.remove('hidden');
            resultContainer.classList.add('hidden');
        }

        // --- Event Handlers ---
        function previewFile() {
            const file = fileInput.files[0];
            errorBox.classList.add('hidden');

            if (file) {
                const reader = new FileReader();
                reader.onload = function(e) {
                    previewImage.src = e.target.result;
                    previewImage.classList.remove('hidden');
                    previewText.classList.add('hidden');
                    imagePreviewDiv.classList.add('has-image');
                    predictButton.disabled = false;
                    resultContainer.classList.add('hidden');
                };
                reader.readAsDataURL(file);
            } else {
                previewImage.classList.add('hidden');
                previewText.classList.remove('hidden');
                imagePreviewDiv.classList.remove('has-image');
                predictButton.disabled = true;
                resultContainer.classList.add('hidden');
            }
        }

        async function uploadFile() {
            const file = fileInput.files[0];
            if (!file) return;

            // Show the panel with only the spinner; hide any previous result.
            errorBox.classList.add('hidden');
            predictionOutput.classList.add('hidden');
            resultContainer.classList.remove('hidden');
            setButtonState(true);

            const formData = new FormData();
            formData.append('file', file);

            try {
                const response = await fetch('/predict', {
                    method: 'POST',
                    body: formData
                });

                // The server returns a JSON body for errors too (400/500); read it
                // so its message can be shown instead of only the status code.
                let data = null;
                try {
                    data = await response.json();
                } catch (parseError) {
                    data = null;
                }

                if (!response.ok) {
                    const detail = (data && data.error) ? data.error : `Server returned status: ${response.status}`;
                    console.error('Prediction request failed:', detail);
                    showError(`Network or Server Error: ${detail}`);
                } else if (data && data.status === 'success') {
                    document.getElementById('predictedClass').textContent = data.prediction;
                    document.getElementById('confidenceScore').textContent = `${data.confidence}%`;
                    predictionOutput.classList.remove('hidden');
                } else {
                    // Handle application-level error (e.g., model failed)
                    console.error("Prediction failed:", data && data.error);
                    showError((data && data.error) || "Prediction failed. Check server logs.");
                }

            } catch (error) {
                console.error('Error during fetch:', error);
                showError(`Network or Server Error: ${error.message}`);
            } finally {
                setButtonState(false);
            }
        }
    </script>
</body>
</html>
"""

if __name__ == '__main__':
    # Inside a notebook kernel __name__ is also '__main__', and app.run() would
    # block the cell; there the hosting (canvas) environment starts the server.
    # Alternatively use `flask run`. To start Flask's development server when
    # executing this file directly, opt in explicitly:
    #     RUN_FLASK_SERVER=1 python this_file.py
    # (debug mode is left off: the Werkzeug debugger allows code execution.)
    if os.environ.get('RUN_FLASK_SERVER') == '1':
        app.run()

Attempting to load model from: potatoes.h5
MODEL NOT FOUND! Using untrained architecture.
