In [8]:
from PIL import Image
import torch
import torch.nn as nn
from torchvision import models, transforms
import json
import warnings

In [20]:
# Suppress every Python warning for the rest of the session.
warnings.filterwarnings("ignore")

# Size of the classifier head; must match the checkpoint loaded below.
# NOTE(review): presumably also the number of entries in class_names.json - confirm.
num_classes = 38

# Build an uninitialised ResNet-34 (`weights=None` replaces the deprecated
# `pretrained=False`) and swap its final layer for a `num_classes`-way head.
model = models.resnet34(weights=None)
model.fc = nn.Linear(model.fc.in_features, num_classes)

# `map_location="cpu"` lets a checkpoint that was saved from a CUDA device load
# on a CPU-only machine; inference below never moves tensors off the CPU.
# `weights_only=True` restricts unpickling to tensors and plain containers.
model.load_state_dict(
    torch.load(
        './model/resnet34_plant_disease.pth',
        map_location="cpu",
        weights_only=True,
    )
)
# Inference mode: fixes batch-norm statistics and disables dropout.
model.eval()

# Labels are looked up by the predicted class index in detect_disease().
# An explicit encoding keeps the read independent of the platform locale.
with open("class_names.json", "r", encoding="utf-8") as f:
    class_names = json.load(f)

# Preprocessing applied to every input image before it reaches the model:
# resize, take the central 224x224 crop, convert to a tensor, then normalise
# each channel with the given mean / std (the ImageNet statistics listed in
# the torchvision documentation).
transform = transforms.Compose([
    transforms.Resize(256),
    transforms.CenterCrop(224),
    transforms.ToTensor(),
    transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
])

def _format_prediction(label):
    """Turn a raw class label into the user-facing result text.

    Labels of the form ``<plant>___<disease>`` are split on the first
    ``___``; underscores become spaces and any parenthesised qualifier after
    the plant name is dropped. Labels without ``___`` are reported as-is.
    """
    # partition() splits on the first separator only, so a label containing
    # more than one "___" no longer breaks a two-name unpack.
    plant, separator, disease = label.partition("___")
    if not separator:
        return f"Plant: {label}"

    plant = plant.replace("_", " ").split(" (")[0]
    disease = disease.replace("_", " ")
    if "healthy" in disease.lower():
        return f'{plant} and it is healthy.\n'
    return f'Plant Name: {plant}\nDisease Name: {disease}'


def detect_disease(image_path):
    """Classify the plant image at *image_path* and describe the result.

    Args:
        image_path: Location of the image to classify (anything accepted by
            ``PIL.Image.open``).

    Returns:
        A human-readable string: the plant and disease names, a "healthy"
        message, or ``"Plant: <label>"`` when the label has no ``___``.

    Relies on the module-level ``transform``, ``model`` and ``class_names``.
    """
    # The context manager releases the file handle as soon as the pixel data
    # has been copied out by convert().
    with Image.open(image_path) as raw_image:
        rgb_image = raw_image.convert('RGB')

    # The model expects a batch dimension, hence unsqueeze(0).
    batch = transform(rgb_image).unsqueeze(0)

    # No gradients are needed for inference.
    with torch.no_grad():
        logits = model(batch)

    class_index = logits.argmax(dim=1).item()
    return _format_prediction(class_names[class_index])

# Demo run: classify one sample image and print the formatted result.
user_image_path = './fdf.webp'
detected_disease = detect_disease(user_image_path)

_report = f"Detected Plant Disease ➡️\n{detected_disease}"
print(_report)

Detected Plant Disease ➡️
Plant Name: Apple
Disease Name: Black rot
