# Importing Libraries

In [1]:
from flask import Flask, request, jsonify
from PIL import Image
import torch
from torchvision import transforms
import io
from torchvision.models import resnet50
import torch.nn as nn

In [2]:
# Initialize Flask app
app = Flask(__name__)

In [3]:
# Define the model architecture
model_path = 'Models/best_ResNet-50.pth'

# Load a ResNet-50 model (ensure you define the same architecture used when training)
model = resnet50(pretrained=False)  # Use pretrained=False since you'll load your custom weights

# Modify the last layer if needed (e.g., for specific number of classes)
num_classes = 7  # Update with the number of classes in your model
model.fc = nn.Linear(model.fc.in_features, num_classes)

# Load the state dictionary
state_dict = torch.load(model_path, map_location=torch.device('cpu'))  # Load to CPU
model.load_state_dict(state_dict)

# Set the model to evaluation mode
model.eval()

print("Model loaded successfully!")



Model loaded successfully!


  state_dict = torch.load(model_path, map_location=torch.device('cpu'))  # Load to CPU


In [4]:
# Define image transformations (should match the ones used during training)
image_transforms = transforms.Compose([
    transforms.Resize((256, 256)),
    transforms.CenterCrop((224, 224)),
    transforms.ToTensor(),
])

# Define the class labels (update this with your actual class labels)
class_labels = ["BacterialBlight", "BacterialLeafBlight", "BacterialLeafStreak", "Blast", "BrownSpot", "Normal", "SheathBlight"]  # Replace with your actual labels


@app.route('/predict', methods=['POST'])
def predict():
    if 'image' not in request.files:
        return jsonify({"error": "No image file provided"}), 400
    
    # Get the image file
    file = request.files['image']
    
    try:
        # Open the image
        image = Image.open(file.stream).convert('RGB')
        
        # Apply transformations
        image = image_transforms(image).unsqueeze(0)  # Add batch dimension
        
        # Perform prediction
        with torch.no_grad():
            outputs = model(image)
            _, predicted_class = torch.max(outputs, 1)
            predicted_label = class_labels[predicted_class.item()]
        
        # Return the result
        return jsonify({"prediction": predicted_label}), 200
    
    except Exception as e:
        return jsonify({"error": str(e)}), 500

# Run the Flask app
if __name__ == '__main__':
    app.run(host='0.0.0.0', port=5000, debug=True, use_reloader=False)


 * Serving Flask app '__main__'
 * Debug mode: on


 * Running on all addresses (0.0.0.0)
 * Running on http://127.0.0.1:5000
 * Running on http://10.50.44.202:5000
Press CTRL+C to quit
127.0.0.1 - - [20/Jan/2025 07:58:22] "GET / HTTP/1.1" 404 -
127.0.0.1 - - [20/Jan/2025 07:58:22] "GET /favicon.ico HTTP/1.1" 404 -
10.50.44.202 - - [20/Jan/2025 07:58:27] "GET / HTTP/1.1" 404 -
10.50.44.202 - - [20/Jan/2025 07:58:28] "GET /favicon.ico HTTP/1.1" 404 -
127.0.0.1 - - [20/Jan/2025 07:58:50] "GET / HTTP/1.1" 404 -
127.0.0.1 - - [20/Jan/2025 07:59:36] "GET /predict HTTP/1.1" 405 -
127.0.0.1 - - [20/Jan/2025 08:02:37] "POST /predict HTTP/1.1" 200 -
