In [1]:
import torch
import torchvision.transforms as transforms
from torchvision import models
from PIL import Image
import json

# === Load class names ===
# Predictions are mapped back to names by position (see predict), so the
# order stored in this file is expected to match the model's output indices.
# JSON text is UTF-8 by specification; decode it explicitly instead of relying
# on the platform's default encoding (which differs on e.g. Windows).
with open("class_names.json", "r", encoding="utf-8") as f:
    class_names = json.load(f)

# === Image preprocessing pipeline (same as during training) ===
transform = transforms.Compose(
    [
        # Resize to 300x300, the EfficientNet-B3 input size.
        transforms.Resize(size=(300, 300)),
        # Convert the PIL image into a tensor.
        transforms.ToTensor(),
        # Per-channel ImageNet normalization.
        transforms.Normalize(
            mean=[0.485, 0.456, 0.406],
            std=[0.229, 0.224, 0.225],
        ),
    ]
)

# === Load and preprocess the image ===
def preprocess_image(image_path):
    """Load an image from disk and turn it into a single-item model batch.

    Args:
        image_path: Location of the image file; passed straight to
            ``PIL.Image.open``.

    Returns:
        The result of applying the module-level ``transform`` to the RGB
        image, with a leading batch dimension of size 1 added.
    """
    # Open inside a context manager so the file handle is released
    # deterministically instead of whenever the image object is collected.
    # ``convert`` loads the pixel data and returns a separate in-memory image,
    # so the result remains usable after the file has been closed.
    with Image.open(image_path) as source:
        image = source.convert('RGB')
    image = transform(image)
    image = image.unsqueeze(0)  # Add batch dimension
    return image

# === Load the model and weights ===
# Run on the GPU when one is available, otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# Build the architecture only; the weights come from the checkpoint below.
model = models.efficientnet_b3(weights=None)  # no pretrained weights
# Swap the stock classification head for one with one output per class name.
# This layout has to agree with the head stored in the checkpoint, otherwise
# load_state_dict reports missing/unexpected keys or shape mismatches.
model.classifier = torch.nn.Sequential(
    torch.nn.Dropout(0.5),
    torch.nn.Linear(model.classifier[1].in_features, len(class_names))
)
# SECURITY: torch.load is pickle-based and can execute arbitrary code from a
# crafted file. weights_only=True restricts unpickling to tensors and plain
# containers, which is all a state dict needs. Only load checkpoints from a
# trusted source regardless.
model.load_state_dict(
    torch.load(
        "efficientnet_b3_plant_disease.pth",
        map_location=device,  # remap tensors saved on another device
        weights_only=True,
    )
)
model.to(device)
# Inference mode: disables dropout and uses stored batch-norm statistics.
model.eval()

# === Predict function ===
def predict(image_path):
    """Classify one image file and return the name of the top-scoring class.

    Args:
        image_path: Path of the image to classify; forwarded to
            ``preprocess_image``.

    Returns:
        The entry of the module-level ``class_names`` at the index of the
        highest model output.
    """
    batch = preprocess_image(image_path).to(device)

    # Pure inference: skip gradient tracking for the forward pass.
    with torch.no_grad():
        scores = model(batch)

    # max over dim 1 yields (values, indices); only the index is needed.
    top_index = scores.max(dim=1)[1].item()
    return class_names[top_index]


# === Run a single prediction ===
# Classify one image with the loaded model and print the resulting class name.
image_path = "sample.jpg"  # image file to classify
predicted_disease = predict(image_path)
print(f"🔍 Predicted Disease: {predicted_disease}")


🔍 Predicted Disease: Potato___Early_blight
