<a href="https://colab.research.google.com/github/superrjj/PalaYan_App/blob/master/rice_disease_auto_train.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
# ============================================================================
# PART 1: GOOGLE COLAB SERVER SETUP FOR RICE DISEASE CLASSIFICATION
# ============================================================================

# Install required packages
!pip install flask flask-cors firebase-admin tensorflow opencv-python pillow requests pyngrok

# Import libraries
from flask import Flask, request, jsonify
from flask_cors import CORS
import firebase_admin
from firebase_admin import credentials, firestore, storage
import tensorflow as tf
from tensorflow.keras.preprocessing.image import ImageDataGenerator
import cv2
import numpy as np
import os
import requests
from PIL import Image
import json
from pyngrok import ngrok
import threading
import time

# Initialize Flask app
app = Flask(__name__)
# Enable CORS (default settings) on every route so browser/mobile clients on other origins can call the API.
CORS(app)

# Initialize Firebase Admin SDK
# Upload your service account JSON to Colab and update the path
# NOTE(review): both the key path and the bucket name below are placeholders and
# must be replaced with the real project values before this cell will run.
cred = credentials.Certificate('/content/your-service-account-key.json')
firebase_admin.initialize_app(cred, {
    'storageBucket': 'your-project-id.appspot.com'
})

# Firestore client (disease records, model metadata) and default Storage bucket (model artifacts).
db = firestore.client()
bucket = storage.bucket()

# Global variables shared by the background training thread and the Flask handlers.
model = None          # currently served tf.keras model, or None until one is loaded/trained
class_names = []      # class (directory) names; entry i is the label for model output i
is_training = False   # True while train_rice_disease_model() is running

# ============================================================================
# PART 2: DATA FETCHING AND PREPROCESSING FOR RICE DISEASES
# ============================================================================

def fetch_rice_disease_data_from_firebase():
    """Fetch all rice disease records from the ``rice_local_disease`` collection.

    Returns:
        list[dict]: One dict per Firestore document, with keys ``id``,
        ``name``, ``scientific_name``, ``image_urls``, ``description``,
        ``symptoms``, ``cause`` and ``treatments``. ``image_urls`` is always a
        list, even when the stored ``images`` field is missing, null, or a
        single URL string. On any Firestore error the error is printed and
        ``[]`` is returned.
    """
    try:
        disease_data = []
        for doc in db.collection('rice_local_disease').stream():
            data = doc.to_dict() or {}

            images = data.get('images')
            # A null field would make the download loop's enumerate() raise,
            # and a bare string would be iterated one character at a time.
            if isinstance(images, str):
                images = [images]

            disease_data.append({
                'id': doc.id,
                'name': data.get('diseaseName'),
                'scientific_name': data.get('scientificName'),
                'image_urls': images or [],  # Multiple images per disease
                'description': data.get('description'),
                'symptoms': data.get('symptoms'),
                'cause': data.get('cause'),
                'treatments': data.get('treatments')
            })

        print(f"Fetched {len(disease_data)} rice diseases from Firebase")
        return disease_data
    except Exception as e:
        print(f"Error fetching rice disease data: {e}")
        return []

def download_and_organize_rice_disease_images():
    """Rebuild ``/content/rice_disease_dataset`` from the Firestore disease records.

    Deletes any previous dataset, then downloads each disease's image URLs into
    ``/content/rice_disease_dataset/<class_name>/``. The class name is the
    ``diseaseName`` with spaces and ``/`` replaced by ``_``. Records without a
    name are skipped. Per-image download failures are printed and skipped.

    Returns:
        list[str]: Class names that received at least one image, sorted.
        Directories for classes that ended up empty are removed, so the
        returned list matches the dataset's subdirectories exactly. Sorting
        also matches the alphanumeric label order Keras'
        ``flow_from_directory`` uses when it infers classes on its own.
    """
    import shutil

    dataset_dir = '/content/rice_disease_dataset'
    if os.path.exists(dataset_dir):
        shutil.rmtree(dataset_dir)
    os.makedirs(dataset_dir, exist_ok=True)

    disease_data = fetch_rice_disease_data_from_firebase()
    populated_classes = set()

    for disease in disease_data:
        if not disease.get('name'):
            print(f"Skipping disease {disease.get('id')}: missing diseaseName")
            continue

        disease_name = disease['name'].replace(' ', '_').replace('/', '_')
        class_dir = f'{dataset_dir}/{disease_name}'
        os.makedirs(class_dir, exist_ok=True)

        # Download multiple images for this disease
        for idx, image_url in enumerate(disease.get('image_urls') or []):
            try:
                # Without a timeout a stalled host would block training forever.
                response = requests.get(image_url, timeout=30)
                if response.status_code == 200:
                    image_path = f"{class_dir}/{disease['id']}_{idx}.jpg"
                    with open(image_path, 'wb') as f:
                        f.write(response.content)
                    populated_classes.add(disease_name)
                    print(f"Downloaded: {disease['name']} - Image {idx+1}")
                else:
                    print(f"Failed to download image {idx+1} for {disease['name']}")
            except Exception as e:
                print(f"Error downloading {disease['name']} image {idx+1}: {e}")

    # A class with no images can never be learned; drop its empty directory so
    # it is neither counted as a class nor picked up by flow_from_directory.
    for entry in os.listdir(dataset_dir):
        if entry not in populated_classes:
            print(f"No images downloaded for class {entry}; excluding it")
            shutil.rmtree(os.path.join(dataset_dir, entry))

    return sorted(populated_classes)

# ============================================================================
# PART 3: MODEL TRAINING FOR RICE DISEASE CLASSIFICATION
# ============================================================================

def create_rice_disease_model(num_classes):
    """Build and compile a transfer-learning classifier on a frozen MobileNetV2.

    Args:
        num_classes: Number of output units; the final layer emits one softmax
            probability per class.

    Returns:
        A compiled ``tf.keras.Sequential`` model for (224, 224, 3) inputs,
        using Adam (lr=1e-4), categorical cross-entropy and accuracy.

    NOTE(review): MobileNetV2's ImageNet weights expect inputs scaled to
    [-1, 1] (``mobilenet_v2.preprocess_input``), but the training generator and
    the predict endpoint rescale pixels to [0, 1]. Confirm this is intended.
    """
    backbone = tf.keras.applications.MobileNetV2(
        input_shape=(224, 224, 3), include_top=False, weights='imagenet')
    # Only the new head below is trained; the ImageNet features stay fixed.
    backbone.trainable = False

    head_layers = [
        tf.keras.layers.GlobalAveragePooling2D(),
        tf.keras.layers.Dropout(0.2),
        tf.keras.layers.Dense(128, activation='relu'),
        tf.keras.layers.Dropout(0.2),
        tf.keras.layers.Dense(num_classes, activation='softmax'),
    ]
    classifier = tf.keras.Sequential([backbone, *head_layers])

    adam = tf.keras.optimizers.Adam(learning_rate=0.0001)
    classifier.compile(optimizer=adam,
                       loss='categorical_crossentropy',
                       metrics=['accuracy'])
    return classifier

def train_rice_disease_model():
    """Download the dataset, train a fresh classifier, save it and publish it.

    Meant to run on a background thread (``/retrain`` starts it that way). The
    module globals ``model`` and ``class_names`` are replaced together, and
    only after training and the local save succeed. ``/predict_disease`` keeps
    serving the previous model with its matching labels while a retrain is
    running, and also if the retrain fails.

    Returns:
        dict: ``{"status": "success", "message": ..., "classes": <int>}`` on
        success, otherwise ``{"status": "error", "message": ...}``.
    """
    global model, class_names, is_training

    if is_training:
        return {"status": "error", "message": "Training already in progress"}

    is_training = True

    try:
        # Download and organize data
        print("Downloading rice disease data...")
        new_class_names = download_and_organize_rice_disease_images()

        if len(new_class_names) < 2:
            return {"status": "error", "message": "Need at least 2 disease classes to train"}

        # Data augmentation for plant disease images (training subset only)
        train_datagen = ImageDataGenerator(
            rescale=1./255,
            rotation_range=30,
            width_shift_range=0.3,
            height_shift_range=0.3,
            shear_range=0.2,
            zoom_range=0.3,
            horizontal_flip=True,
            vertical_flip=True,
            brightness_range=[0.8, 1.2],
            validation_split=0.2
        )
        # Same rescale and split fraction, so Keras holds out the same files,
        # but no augmentation: validation metrics should measure real images.
        val_datagen = ImageDataGenerator(rescale=1./255, validation_split=0.2)

        # `classes=` pins label index i to new_class_names[i]. Without it Keras
        # orders labels alphanumerically by directory name, which need not
        # match the list saved to rice_disease_classes.json and used to decode
        # predictions.
        train_generator = train_datagen.flow_from_directory(
            '/content/rice_disease_dataset',
            target_size=(224, 224),
            batch_size=16,  # Smaller batch size for disease classification
            class_mode='categorical',
            classes=new_class_names,
            subset='training'
        )

        validation_generator = val_datagen.flow_from_directory(
            '/content/rice_disease_dataset',
            target_size=(224, 224),
            batch_size=16,
            class_mode='categorical',
            classes=new_class_names,
            subset='validation'
        )

        # Create and train a new model without touching the one being served.
        print(f"Training rice disease model with {len(new_class_names)} classes...")
        new_model = create_rice_disease_model(len(new_class_names))

        # Callbacks for better training
        callbacks = [
            tf.keras.callbacks.EarlyStopping(patience=5, restore_best_weights=True),
            tf.keras.callbacks.ReduceLROnPlateau(factor=0.2, patience=3)
        ]

        new_model.fit(
            train_generator,
            epochs=20,  # More epochs for disease classification
            validation_data=validation_generator,
            callbacks=callbacks,
            verbose=1
        )

        # Save model and class names locally
        new_model.save('/content/rice_disease_model.h5')
        with open('/content/rice_disease_classes.json', 'w') as f:
            json.dump(new_class_names, f)

        # Swap the served model and its labels together, now that training succeeded.
        model, class_names = new_model, new_class_names

        # Upload reads the global class_names, so it must run after the swap.
        upload_rice_disease_model_to_firebase()

        print("Rice disease model training completed successfully!")
        return {"status": "success", "message": "Rice disease model trained successfully", "classes": len(class_names)}

    except Exception as e:
        print(f"Training error: {e}")
        return {"status": "error", "message": str(e)}
    finally:
        # Every path that set the flag (success, early return, error) clears it.
        is_training = False

def upload_rice_disease_model_to_firebase():
    """Publish the locally saved model artifacts and record their metadata.

    Uploads ``/content/rice_disease_model.h5`` and
    ``/content/rice_disease_classes.json`` to the ``models/`` prefix of the
    default Storage bucket. It then overwrites the Firestore document
    ``model_info/rice_disease_classifier`` with their URLs, the current class
    count, a server timestamp and a Unix-seconds version number. Any failure is
    printed and swallowed, so callers are not notified.

    NOTE(review): nothing here makes the blobs public, so ``public_url`` is
    only fetchable if the bucket or objects are publicly readable. Confirm.
    """
    try:
        uploads = (
            ('models/rice_disease_model.h5', '/content/rice_disease_model.h5'),
            ('models/rice_disease_classes.json', '/content/rice_disease_classes.json'),
        )
        published = []
        for remote_path, local_path in uploads:
            target = bucket.blob(remote_path)
            target.upload_from_filename(local_path)
            published.append(target)
        weights_blob, labels_blob = published

        info_doc = db.collection('model_info').document('rice_disease_classifier')
        metadata = {
            'model_url': weights_blob.public_url,
            'classes_url': labels_blob.public_url,
            'num_classes': len(class_names),
            'last_updated': firestore.SERVER_TIMESTAMP,
            'version': int(time.time()),
            'model_type': 'rice_disease_classification',
        }
        info_doc.set(metadata)

        print("Rice disease model uploaded to Firebase successfully!")
    except Exception as e:
        print(f"Error uploading rice disease model: {e}")

# ============================================================================
# PART 4: FLASK API ENDPOINTS FOR RICE DISEASE
# ============================================================================

@app.route('/retrain', methods=['POST'])
def retrain_rice_disease_model():
    """Endpoint to trigger rice disease model retraining.

    Accepts an optional JSON body; ``diseaseName`` is used only for logging.
    Training runs on a background thread, so the response returns immediately.

    Responds:
        200 when a training thread was started;
        409 when a training run is already in progress;
        500 on unexpected errors.
    """
    try:
        # silent=True: a missing or non-JSON body yields None instead of
        # raising. A non-object JSON payload is treated as empty too.
        data = request.get_json(silent=True)
        if not isinstance(data, dict):
            data = {}
        disease_name = data.get('diseaseName', 'Unknown')

        # Don't claim a retrain "started" when the training function would
        # immediately bail out because another run holds the flag.
        if is_training:
            return jsonify({
                "status": "error",
                "message": "Training already in progress"
            }), 409

        print(f"Rice disease model retraining triggered by new disease: {disease_name}")

        # Start training in a separate thread
        thread = threading.Thread(target=train_rice_disease_model)
        thread.start()

        return jsonify({
            "status": "success",
            "message": f"Rice disease model retraining started for: {disease_name}"
        })

    except Exception as e:
        return jsonify({
            "status": "error",
            "message": str(e)
        }), 500

@app.route('/training_status', methods=['GET'])
def get_training_status():
    """Report whether a retrain is running and the state of the served model."""
    status = {
        "is_training": is_training,
        "model_loaded": model is not None,
        "num_classes": len(class_names or []),
    }
    return jsonify(status)

@app.route('/predict_disease', methods=['POST'])
def predict_rice_disease():
    """Classify an uploaded rice image.

    Expects a multipart/form-data request carrying the image in the ``image``
    field.

    Responds:
        200 with ``predicted_disease``, ``confidence``, ``disease_info`` and
        ``all_predictions`` (class name -> probability);
        400 if no model is loaded, the ``image`` field is missing, or the
        upload cannot be decoded as an image;
        500 on any other failure, including a model/class-list size mismatch.
    """
    # Read both globals once. A retrain finishing mid-request must not pair
    # one model's outputs with another model's class list.
    current_model = model
    current_classes = class_names

    if current_model is None:
        return jsonify({
            "status": "error",
            "message": "Rice disease model not loaded"
        }), 400

    upload = request.files.get('image')
    if upload is None:
        return jsonify({
            "status": "error",
            "message": "No image provided: send it as multipart form field 'image'"
        }), 400

    try:
        try:
            image_array = _load_image_for_model(upload.stream)
        except OSError as e:
            # Pillow's UnidentifiedImageError (unreadable/unsupported upload)
            # subclasses OSError; that is a client error, not a server one.
            return jsonify({
                "status": "error",
                "message": f"Invalid image: {e}"
            }), 400

        # Make prediction
        scores = current_model.predict(image_array)[0]
        if len(scores) != len(current_classes):
            return jsonify({
                "status": "error",
                "message": (f"Model has {len(scores)} outputs but "
                            f"{len(current_classes)} class names are loaded")
            }), 500

        predicted_class_idx = int(np.argmax(scores))
        confidence = float(scores[predicted_class_idx])
        predicted_disease = current_classes[predicted_class_idx]

        # Get additional disease info from Firestore
        disease_info = get_disease_info(predicted_disease)

        return jsonify({
            "status": "success",
            "predicted_disease": predicted_disease,
            "confidence": confidence,
            "disease_info": disease_info,
            "all_predictions": {
                name: float(score)
                for name, score in zip(current_classes, scores)
            }
        })

    except Exception as e:
        return jsonify({
            "status": "error",
            "message": str(e)
        }), 500


def _load_image_for_model(stream):
    """Decode an image stream into a (1, 224, 224, 3) float batch scaled to [0, 1]."""
    image = Image.open(stream)
    if image.mode != 'RGB':
        image = image.convert('RGB')
    image = image.resize((224, 224))
    return np.expand_dims(np.array(image) / 255.0, axis=0)

def get_disease_info(disease_name):
    """Look up the Firestore record behind a predicted class name.

    ``disease_name`` is a class/directory name from training, where spaces and
    ``/`` in the original ``diseaseName`` were replaced by ``_``.

    Returns:
        dict | None: ``scientific_name``, ``description``, ``symptoms``,
        ``cause`` and ``treatments`` of the first matching record, or ``None``
        when nothing matches or the lookup fails (the error is printed).
    """
    def to_class_name(name):
        # Same normalization the dataset download applies to diseaseName.
        return name.replace(' ', '_').replace('/', '_')

    def summarize(data):
        return {
            'scientific_name': data.get('scientificName'),
            'description': data.get('description'),
            'symptoms': data.get('symptoms'),
            'cause': data.get('cause'),
            'treatments': data.get('treatments')
        }

    try:
        diseases = db.collection('rice_local_disease')

        # Fast path: one indexed query. It works whenever the original name
        # contained only spaces (no '_' or '/') where the class name has '_'.
        candidate = disease_name.replace('_', ' ')
        for doc in diseases.where('diseaseName', '==', candidate).limit(1).stream():
            return summarize(doc.to_dict())

        # Fallback: names that contained '_' or '/' cannot be recovered by
        # swapping '_' back to ' ', so compare on the normalized form instead.
        for doc in diseases.stream():
            data = doc.to_dict() or {}
            name = data.get('diseaseName')
            if isinstance(name, str) and to_class_name(name) == disease_name:
                return summarize(data)
    except Exception as e:
        print(f"Error getting disease info: {e}")

    return None

# ============================================================================
# PART 5: SERVER STARTUP FOR RICE DISEASE MODEL
# ============================================================================

def load_existing_rice_disease_model():
    """Load the last published model and class list from Firebase Storage, if any.

    Downloads ``models/rice_disease_model.h5`` and
    ``models/rice_disease_classes.json`` into ``/content``. The module globals
    ``model`` and ``class_names`` are assigned together, and only after both
    artifacts have loaded, so a partial failure never leaves a model without
    its labels. All errors are printed, not raised.
    """
    global model, class_names

    try:
        model_blob = bucket.blob('models/rice_disease_model.h5')
        classes_blob = bucket.blob('models/rice_disease_classes.json')

        if not model_blob.exists():
            print("No existing rice disease model found. POST /retrain to train one.")
            return
        if not classes_blob.exists():
            print("Rice disease model found but its class list is missing. POST /retrain to rebuild it.")
            return

        model_blob.download_to_filename('/content/rice_disease_model.h5')
        classes_blob.download_to_filename('/content/rice_disease_classes.json')

        loaded_model = tf.keras.models.load_model('/content/rice_disease_model.h5')
        with open('/content/rice_disease_classes.json', 'r') as f:
            loaded_classes = json.load(f)

        model, class_names = loaded_model, loaded_classes
        print(f"Rice disease model loaded with {len(class_names)} classes")

    except Exception as e:
        print(f"Error loading existing rice disease model: {e}")

@app.route('/')
def home():
    """Describe the API: its endpoints plus the current model status."""
    endpoint_help = {
        "/retrain": "POST - Trigger model retraining",
        "/training_status": "GET - Check training status",
        "/predict_disease": "POST - Predict rice disease from image",
    }
    payload = {
        "message": "Rice Disease Classification API",
        "endpoints": endpoint_help,
        "model_loaded": model is not None,
        "num_classes": len(class_names or []),
    }
    return jsonify(payload)

if __name__ == '__main__':
    # Load existing model (if one was published) before serving requests.
    load_existing_rice_disease_model()

    # Setup ngrok tunnel. pyngrok >= 5 returns an NgrokTunnel object whose
    # string form is a verbose repr; older versions return the URL string.
    tunnel = ngrok.connect(5000)
    public_url = getattr(tunnel, 'public_url', tunnel)
    print(f"Public URL: {public_url}")
    print("Rice Disease Classification API is running!")

    # Start Flask server (blocks this cell until interrupted)
    app.run(host='0.0.0.0', port=5000, debug=False)