In [None]:
import argparse
import torch
import torch.nn as nn
from torchvision import transforms, models
from PIL import Image
import matplotlib.pyplot as plt

In [None]:
# Command-line interface for the prediction run.
#
# parse_args(None) reads sys.argv[1:], which is what you want when this file runs as a
# script. Inside a Jupyter kernel, sys.argv typically carries the kernel launcher's own
# flags (e.g. `-f <connection-file>`) and no --img, so parsing it exits with a usage error.
# In a notebook, set CLI_ARGS to an explicit argument list instead, e.g.
#     CLI_ARGS = ["--img", "leaf.jpg", "--checkpoint", "best_mobilenetv2_teadiseases.pth"]
CLI_ARGS = None  # None -> parse sys.argv (script behaviour, unchanged)

parser = argparse.ArgumentParser(description="Predict tea leaf disease from an image")
parser.add_argument(
    "--img", "-i",
    required=True,
    help="Path to the input tea‑leaf image"
)
parser.add_argument(
    "--checkpoint", "-c",
    default="best_mobilenetv2_teadiseases.pth",
    help="Path to the trained model checkpoint"
)
args = parser.parse_args(CLI_ARGS)

# Side length (pixels) of the square crop fed to the network.
IMAGE_SIZE = 224

# Human-readable labels, indexed by the position of the model's output logit.
# NOTE(review): the list is in ASCII-sorted order, as torchvision's ImageFolder would
# assign class indices from folder names — confirm it matches the training-time mapping.
CLASS_NAMES = [
    'algal_spot', 'brown_blight', 'gray_blight',
    'healthy', 'helopeltis', 'red-rust',
    'red-spider-infested', 'red_spot', 'white-spot',
]

# Run on the GPU when CUDA is available, otherwise on the CPU.
DEVICE = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

# Evaluation-time preprocessing:
#   1. resize so the shorter side is int(224 * 1.14) = 255 px,
#   2. take the central IMAGE_SIZE x IMAGE_SIZE crop,
#   3. convert the PIL image to a float tensor with values in [0, 1],
#   4. normalize each RGB channel with the standard ImageNet mean/std values.
# These steps must mirror whatever the checkpoint was trained/validated with.
test_transform = transforms.Compose(
    [
        transforms.Resize(int(IMAGE_SIZE * 1.14)),
        transforms.CenterCrop(IMAGE_SIZE),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
    ]
)

# Load the input image as 3-channel RGB and turn it into a batch of one on DEVICE.
# The context manager closes the underlying file deterministically; convert() returns
# a new, fully loaded image, so `image` stays usable after the file is closed.
with Image.open(args.img) as raw_image:
    image = raw_image.convert("RGB")
input_tensor = test_transform(image).unsqueeze(0).to(DEVICE)

# Rebuild the training-time architecture: a MobileNetV2 backbone created without
# pretrained weights, with its classifier replaced by dropout + a linear layer that
# emits one logit per entry in CLASS_NAMES.
model = models.mobilenet_v2(weights=None)
model.classifier = nn.Sequential(
    nn.Dropout(p=0.3),
    nn.Linear(in_features=model.last_channel, out_features=len(CLASS_NAMES)),
)

# Restore the trained parameters. load_state_dict is strict by default, so a head that
# does not match the checkpoint (e.g. a different class count) raises here.
# NOTE(review): torch.load unpickles the file, which can execute arbitrary code — only
# load trusted checkpoints, or pass weights_only=True if the installed PyTorch supports it.
state = torch.load(args.checkpoint, map_location=DEVICE)
model.load_state_dict(state)

# Module.to() and Module.eval() both return the module itself; eval() disables dropout.
model = model.to(DEVICE).eval()

# Forward pass without autograd tracking, then convert logits to class probabilities
# and keep the most likely classes (topk returns them sorted by descending probability).
with torch.no_grad():
    outputs = model(input_tensor)
    probs = torch.softmax(outputs, dim=1)
    # topk raises if k exceeds the number of classes, so never ask for more than exist.
    top5_probs, top5_indices = torch.topk(probs, k=min(5, probs.size(1)), dim=1)

In [None]:
# Print the ranked predictions. Iterate over the entries topk actually returned rather
# than assuming exactly five, so the loop can never index past the result.
print("Top-5 Predictions:")
ranked = zip(top5_probs[0].tolist(), top5_indices[0].tolist())
for rank, (class_prob, class_idx) in enumerate(ranked, start=1):
    print(f"{rank}. {CLASS_NAMES[class_idx]} ({class_prob*100:.2f}%)")

# Show the input image, titled with the single most likely class (entry 0 of the
# descending topk output).
best_prob = top5_probs[0][0].item()
best_idx = top5_indices[0][0].item()
fig, ax = plt.subplots()
ax.imshow(image)
ax.set_title(f"Top Prediction: {CLASS_NAMES[best_idx]} ({best_prob*100:.2f}%)")
ax.axis("off")
plt.show()
